use clap::crate_name;
use std::collections::HashMap;
use std::io::Error;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
use strum::VariantNames;
use thrift::protocol::*;
use thrift::transport::*;
use crate::_osquery as osquery;
use crate::_osquery::{TExtensionManagerSyncClient, TExtensionSyncClient};
use crate::client::Client;
use crate::plugin::{OsqueryPlugin, Plugin, Registry};
use crate::util::OptionToThriftResult;
/// Default pause between consecutive pings in the server's run loop; `Server::new`
/// stores it as the server's `ping_interval`.
const DEFAULT_PING_INTERVAL: Duration = Duration::from_millis(500);
/// Cloneable handle for requesting a server stop from outside the server.
///
/// Every clone shares the same atomic flag, so a stop requested through one
/// handle is observed through all of them.
#[derive(Clone)]
pub struct ServerStopHandle {
    /// Shared flag; `true` once a shutdown has been requested.
    shutdown_flag: Arc<AtomicBool>,
}

impl ServerStopHandle {
    /// Returns `true` while no shutdown has been requested.
    pub fn is_running(&self) -> bool {
        let stop_requested = self.shutdown_flag.load(Ordering::Acquire);
        !stop_requested
    }

    /// Requests a shutdown by raising the shared flag.
    pub fn stop(&self) {
        let flag: &AtomicBool = &self.shutdown_flag;
        flag.store(true, Ordering::Release);
    }
}
/// Extension server: registers its plugins over the client socket, serves
/// incoming plugin calls on a listener thread, and pings in a loop until a
/// shutdown is requested.
pub struct Server<P: OsqueryPlugin + Clone + Send + Sync + 'static> {
// Extension name sent in the registration request.
name: String,
// Socket path handed to the client; also the prefix of the listener socket path.
socket_path: String,
// Client used to register, ping and deregister.
client: Client,
// Plugins added through `register_plugin`.
plugins: Vec<P>,
// Sleep duration between pings in the run loop.
ping_interval: Duration,
// UUID taken from the registration status; `None` until `start` has run.
uuid: Option<osquery::ExtensionRouteUUID>,
// Set to `true` at the end of `start`; not read by any code visible in this file.
started: bool,
// Shutdown flag shared with stop handles, signal handlers and the request handler.
shutdown_flag: Arc<AtomicBool>,
// Join handle of the thread running the listener; taken when the thread is joined.
listener_thread: Option<thread::JoinHandle<()>>,
// Path the listener was started on; used to wake the listener during shutdown.
listen_path: Option<String>,
}
impl<P: OsqueryPlugin + Clone + Send + 'static> Server<P> {
pub fn new(name: Option<&str>, socket_path: &str) -> Result<Self, Error> {
let mut reg: HashMap<String, HashMap<String, Plugin>> = HashMap::new();
for var in Registry::VARIANTS {
reg.insert((*var).to_string(), HashMap::new());
}
let name = name.unwrap_or(crate_name!());
let client = Client::new(socket_path, Default::default())?;
Ok(Server {
name: name.to_string(),
socket_path: socket_path.to_string(),
client,
plugins: Vec::new(),
ping_interval: DEFAULT_PING_INTERVAL,
uuid: None,
started: false,
shutdown_flag: Arc::new(AtomicBool::new(false)),
listener_thread: None,
listen_path: None,
})
}
/// Appends `plugin` to the plugins this server will expose and returns a
/// shared reference to the server.
pub fn register_plugin(&mut self, plugin: P) -> &Self {
    self.plugins.push(plugin);
    &*self
}
/// Starts the server, runs the ping loop until shutdown is requested, then
/// performs shutdown cleanup.
///
/// # Errors
///
/// Returns the error from `start` if startup fails; in that case neither the
/// ping loop nor the cleanup runs.
pub fn run(&mut self) -> thrift::Result<()> {
    let outcome = self.start();
    if outcome.is_ok() {
        self.run_loop();
        self.shutdown_and_cleanup();
    }
    outcome
}
#[cfg(unix)]
pub fn run_with_signal_handling(&mut self) -> thrift::Result<()> {
use signal_hook::consts::{SIGINT, SIGTERM};
use signal_hook::flag;
if let Err(e) = flag::register(SIGINT, self.shutdown_flag.clone()) {
log::warn!("Failed to register SIGINT handler: {e}");
}
if let Err(e) = flag::register(SIGTERM, self.shutdown_flag.clone()) {
log::warn!("Failed to register SIGTERM handler: {e}");
}
self.start()?;
self.run_loop();
self.shutdown_and_cleanup();
Ok(())
}
/// Pings through the client at `ping_interval` until shutdown is requested.
///
/// A failed ping is logged, raises the shutdown flag, and ends the loop.
fn run_loop(&mut self) {
    loop {
        if self.should_shutdown() {
            break;
        }

        let ping = self.client.ping();
        if let Err(e) = ping {
            log::warn!("Ping failed, initiating shutdown: {e}");
            self.request_shutdown();
            break;
        }

        thread::sleep(self.ping_interval);
    }
}
/// Tears the server down in a fixed order: join the listener thread,
/// deregister the extension (only if a UUID was recorded), notify the plugins,
/// and remove the socket file.
///
/// A failed deregistration is logged and does not stop the remaining steps.
fn shutdown_and_cleanup(&mut self) {
    log::info!("Shutting down");
    self.join_listener_thread();

    let deregistration = self
        .uuid
        .map(|uuid| self.client.deregister_extension(uuid));
    if let Some(Err(e)) = deregistration {
        log::warn!("Failed to deregister from osquery: {e}");
    }

    self.notify_plugins_shutdown();
    self.cleanup_socket();
}
/// Waits briefly for the listener thread to finish and joins it.
///
/// The handle is taken out of `self`, so a second call is a no-op. While the
/// thread is still running it is poked via `wake_listener` once per poll. If it
/// has not finished after `JOIN_TIMEOUT`, a warning is logged and the handle is
/// dropped without joining.
fn join_listener_thread(&mut self) {
    const JOIN_TIMEOUT: Duration = Duration::from_millis(100);
    const POLL_INTERVAL: Duration = Duration::from_millis(10);

    let Some(handle) = self.listener_thread.take() else {
        return;
    };
    log::debug!("Waiting for listener thread to exit");

    let wait_started = Instant::now();
    loop {
        if handle.is_finished() {
            break;
        }
        if wait_started.elapsed() > JOIN_TIMEOUT {
            log::warn!(
                "Listener thread did not exit within {:?}, orphaning (will terminate on process exit)",
                JOIN_TIMEOUT
            );
            return;
        }
        self.wake_listener();
        thread::sleep(POLL_INTERVAL);
    }

    // The thread has finished, so this join does not block; it only surfaces a panic.
    if let Err(e) = handle.join() {
        log::warn!("Listener thread panicked: {e:?}");
    }
}
fn start(&mut self) -> thrift::Result<()> {
let stat = self.client.register_extension(
osquery::InternalExtensionInfo {
name: Some(self.name.clone()),
version: Some("1.0".to_string()),
sdk_version: Some("Unknown".to_string()),
min_sdk_version: Some("Unknown".to_string()),
},
self.generate_registry()?,
)?;
log::info!(
"Status {} registering extension {} ({}): {}",
stat.code.unwrap_or(0),
self.name,
stat.uuid.unwrap_or(0),
stat.message.unwrap_or_else(|| "No message".to_string())
);
self.uuid = stat.uuid;
let listen_path = format!("{}.{}", self.socket_path, self.uuid.unwrap_or(0));
let processor = osquery::ExtensionManagerSyncProcessor::new(Handler::new(
&self.plugins,
self.shutdown_flag.clone(),
)?);
let i_tr_fact: Box<dyn TReadTransportFactory + Send> =
Box::new(TBufferedReadTransportFactory::new());
let i_pr_fact: Box<dyn TInputProtocolFactory + Send> =
Box::new(TBinaryInputProtocolFactory::new());
let o_tr_fact: Box<dyn TWriteTransportFactory + Send> =
Box::new(TBufferedWriteTransportFactory::new());
let o_pr_fact: Box<dyn TOutputProtocolFactory + Send> =
Box::new(TBinaryOutputProtocolFactory::new());
let mut server =
thrift::server::TServer::new(i_tr_fact, i_pr_fact, o_tr_fact, o_pr_fact, processor, 10);
self.listen_path = Some(listen_path.clone());
let listener_thread = thread::spawn(move || {
if let Err(e) = server.listen_uds(listen_path) {
log::debug!("Listener thread exited: {e}");
}
});
self.listener_thread = Some(listener_thread);
self.started = true;
Ok(())
}
/// Builds the registry sent at registration: one route table per
/// `Registry` variant name, filled with each plugin's routes under the
/// plugin's name.
///
/// # Errors
///
/// Returns an error if a plugin reports a registry name that has no table.
fn generate_registry(&self) -> thrift::Result<osquery::ExtensionRegistry> {
    let mut registry = osquery::ExtensionRegistry::new();
    registry.extend(
        Registry::VARIANTS
            .iter()
            .map(|variant| ((*variant).to_string(), osquery::ExtensionRouteTable::new())),
    );

    for plugin in &self.plugins {
        let registry_name = plugin.registry().to_string();
        let route_table = registry
            .get_mut(registry_name.as_str())
            .ok_or_thrift_err(|| format!("Failed to register plugin {}", plugin.name()))?;
        route_table.insert(plugin.name(), plugin.routes());
    }

    Ok(registry)
}
fn should_shutdown(&self) -> bool {
self.shutdown_flag.load(Ordering::Acquire)
}
fn request_shutdown(&self) {
self.shutdown_flag.store(true, Ordering::Release);
}
/// Opens (and immediately drops) a connection to the listener socket, if a
/// listen path has been recorded. The connection result is deliberately ignored.
fn wake_listener(&self) {
    let Some(path) = self.listen_path.as_deref() else {
        return;
    };
    let _ = std::os::unix::net::UnixStream::connect(path);
}
/// Calls `shutdown` on every registered plugin.
///
/// Each call is wrapped in `catch_unwind`, so a panic in one plugin's
/// `shutdown` is logged and the remaining plugins are still notified.
fn notify_plugins_shutdown(&self) {
    let plugin_count = self.plugins.len();
    log::debug!("Notifying {} plugins of shutdown", plugin_count);

    for plugin in self.plugins.iter() {
        let plugin_name = plugin.name();
        let guarded_call = std::panic::AssertUnwindSafe(|| {
            plugin.shutdown();
        });
        let outcome = std::panic::catch_unwind(guarded_call);
        if let Err(e) = outcome {
            log::error!("Plugin '{plugin_name}' panicked during shutdown: {e:?}");
        }
    }
}
/// Removes the listener socket file, if there is one to remove.
///
/// The path recorded by `start` (`listen_path`) is preferred, because `start`
/// binds the listener to `<socket_path>.0` when registration returned no UUID;
/// rebuilding the path from `uuid` alone would leave that file behind. When no
/// listen path was recorded, the path is derived from the UUID as a fallback.
///
/// A missing file is not an error; any other removal failure is logged.
fn cleanup_socket(&self) {
    let socket_path = match (self.listen_path.as_deref(), self.uuid) {
        (Some(path), _) => path.to_string(),
        (None, Some(uuid)) => format!("{}.{}", self.socket_path, uuid),
        (None, None) => {
            log::debug!("No socket to clean up (uuid not set)");
            return;
        }
    };

    log::debug!("Cleaning up socket: {socket_path}");
    if let Err(e) = std::fs::remove_file(&socket_path) {
        if e.kind() != std::io::ErrorKind::NotFound {
            log::warn!("Failed to remove socket file {socket_path}: {e}");
        }
    }
}
pub fn get_stop_handle(&self) -> ServerStopHandle {
ServerStopHandle {
shutdown_flag: self.shutdown_flag.clone(),
}
}
pub fn stop(&self) {
self.request_shutdown();
}
pub fn is_running(&self) -> bool {
!self.should_shutdown()
}
}
/// Request handler given to the thrift processor: looks plugins up by
/// registry name and plugin name, and raises the shutdown flag on request.
struct Handler<P: OsqueryPlugin + Clone> {
    /// Registry name -> plugin name -> plugin.
    registry: HashMap<String, HashMap<String, P>>,
    /// Flag shared with the server; set when a shutdown RPC arrives.
    shutdown_flag: Arc<AtomicBool>,
}

impl<P: OsqueryPlugin + Clone> Handler<P> {
    /// Builds a handler holding a clone of every plugin, grouped by registry.
    ///
    /// An (initially empty) group is created for every `Registry` variant name.
    ///
    /// # Errors
    ///
    /// Returns an error if a plugin reports a registry name that has no group.
    fn new(plugins: &[P], shutdown_flag: Arc<AtomicBool>) -> thrift::Result<Self> {
        let mut registry: HashMap<String, HashMap<String, P>> = Registry::VARIANTS
            .iter()
            .map(|variant| ((*variant).to_string(), HashMap::new()))
            .collect();

        for plugin in plugins {
            let group = registry
                .get_mut(plugin.registry().to_string().as_str())
                .ok_or_thrift_err(|| format!("Failed to register plugin {}", plugin.name()))?;
            group.insert(plugin.name(), plugin.clone());
        }

        Ok(Self {
            registry,
            shutdown_flag,
        })
    }
}
/// Extension-side RPCs: ping, plugin call dispatch, and shutdown.
impl<P: OsqueryPlugin + Clone> osquery::ExtensionSyncHandler for Handler<P> {
    /// Answers a ping with a default status.
    fn handle_ping(&self) -> thrift::Result<osquery::ExtensionStatus> {
        Ok(osquery::ExtensionStatus::default())
    }

    /// Dispatches `request` to the plugin named `item` in `registry` and
    /// returns that plugin's response.
    ///
    /// # Errors
    ///
    /// Returns an error when `registry` is not a known registry or when it
    /// holds no plugin named `item`.
    fn handle_call(
        &self,
        registry: String,
        item: String,
        request: osquery::ExtensionPluginRequest,
    ) -> thrift::Result<osquery::ExtensionResponse> {
        log::trace!("Registry: {registry}");
        log::trace!("Item: {item}");
        log::trace!("Request: {request:?}");

        let plugins = self.registry.get(registry.as_str()).ok_or_thrift_err(|| {
            format!(
                "Failed to get registry:{} from registries",
                registry.as_str()
            )
        })?;

        // Message previously read "Failed to item:..."; the missing verb is restored
        // so it matches the registry lookup error above.
        let plugin = plugins.get(item.as_str()).ok_or_thrift_err(|| {
            format!(
                "Failed to get item:{} from registry:{}",
                item.as_str(),
                registry.as_str()
            )
        })?;

        Ok(plugin.handle_call(request))
    }

    /// Raises the shared shutdown flag in response to a shutdown RPC.
    fn handle_shutdown(&self) -> thrift::Result<()> {
        log::debug!("Shutdown RPC received from osquery");
        self.shutdown_flag.store(true, Ordering::Release);
        Ok(())
    }
}
/// Manager-side RPCs. The two listing calls return empty collections; every
/// other call answers with status code 1 and a "not supported" message.
impl<P: OsqueryPlugin + Clone> osquery::ExtensionManagerSyncHandler for Handler<P> {
    /// Returns an empty extension list.
    fn handle_extensions(&self) -> thrift::Result<osquery::InternalExtensionList> {
        Ok(osquery::InternalExtensionList::new())
    }

    /// Returns an empty option list.
    fn handle_options(&self) -> thrift::Result<osquery::InternalOptionList> {
        Ok(osquery::InternalOptionList::new())
    }

    /// Rejects the registration request; both arguments are ignored.
    fn handle_register_extension(
        &self,
        _info: osquery::InternalExtensionInfo,
        _registry: osquery::ExtensionRegistry,
    ) -> thrift::Result<osquery::ExtensionStatus> {
        Ok(unsupported_status("Extension registration not supported"))
    }

    /// Rejects the deregistration request; the UUID is ignored.
    fn handle_deregister_extension(
        &self,
        _uuid: osquery::ExtensionRouteUUID,
    ) -> thrift::Result<osquery::ExtensionStatus> {
        Ok(unsupported_status("Extension deregistration not supported"))
    }

    /// Rejects the query; the SQL text is ignored and no rows are returned.
    fn handle_query(&self, _sql: String) -> thrift::Result<osquery::ExtensionResponse> {
        Ok(unsupported_response("Query execution not supported"))
    }

    /// Rejects the column lookup; the SQL text is ignored and no rows are returned.
    fn handle_get_query_columns(&self, _sql: String) -> thrift::Result<osquery::ExtensionResponse> {
        Ok(unsupported_response(
            "Query column introspection not supported",
        ))
    }
}

/// Builds the status (code 1, no UUID) used to reject an unsupported call.
fn unsupported_status(message: &str) -> osquery::ExtensionStatus {
    osquery::ExtensionStatus {
        code: Some(1),
        message: Some(message.to_string()),
        uuid: None,
    }
}

/// Builds a response carrying an "unsupported" status and an empty row list.
fn unsupported_response(message: &str) -> osquery::ExtensionResponse {
    osquery::ExtensionResponse::new(unsupported_status(message), vec![])
}