1use clap::crate_name;
2use std::collections::HashMap;
3use std::io::Error;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6use std::thread;
7use std::time::{Duration, Instant};
8use strum::VariantNames;
9use thrift::protocol::*;
10use thrift::transport::*;
11
12use crate::_osquery as osquery;
13use crate::_osquery::{TExtensionManagerSyncClient, TExtensionSyncClient};
14use crate::client::Client;
15use crate::plugin::{OsqueryPlugin, Plugin, Registry};
16use crate::util::OptionToThriftResult;
17
/// How long the run loop sleeps between liveness pings to osquery (500 ms).
const DEFAULT_PING_INTERVAL: Duration = Duration::from_micros(500_000);
19
/// Cloneable handle that can request a running [`Server`] to shut down.
///
/// Every clone shares the server's shutdown flag, so stopping through any
/// handle is visible to all of them and to the server itself.
#[derive(Clone)]
pub struct ServerStopHandle {
    shutdown_flag: Arc<AtomicBool>,
}

impl ServerStopHandle {
    /// Requests shutdown by raising the shared flag.
    ///
    /// This only sets the flag; it does not wait for the server to finish.
    pub fn stop(&self) {
        let flag: &AtomicBool = &self.shutdown_flag;
        flag.store(true, Ordering::Release);
    }

    /// Returns `true` until shutdown has been requested through any holder
    /// of the shared flag.
    pub fn is_running(&self) -> bool {
        let stop_requested = self.shutdown_flag.load(Ordering::Acquire);
        !stop_requested
    }
}
66
67pub struct Server<P: OsqueryPlugin + Clone + Send + Sync + 'static> {
68 name: String,
69 socket_path: String,
70 client: Client,
71 plugins: Vec<P>,
72 ping_interval: Duration,
73 uuid: Option<osquery::ExtensionRouteUUID>,
74 started: bool,
76 shutdown_flag: Arc<AtomicBool>,
77 listener_thread: Option<thread::JoinHandle<()>>,
79 listen_path: Option<String>,
81}
82
83impl<P: OsqueryPlugin + Clone + Send + 'static> Server<P> {
84 pub fn new(name: Option<&str>, socket_path: &str) -> Result<Self, Error> {
85 let mut reg: HashMap<String, HashMap<String, Plugin>> = HashMap::new();
86 for var in Registry::VARIANTS {
87 reg.insert((*var).to_string(), HashMap::new());
88 }
89
90 let name = name.unwrap_or(crate_name!());
91
92 let client = Client::new(socket_path, Default::default())?;
93
94 Ok(Server {
95 name: name.to_string(),
96 socket_path: socket_path.to_string(),
97 client,
98 plugins: Vec::new(),
99 ping_interval: DEFAULT_PING_INTERVAL,
100 uuid: None,
101 started: false,
102 shutdown_flag: Arc::new(AtomicBool::new(false)),
103 listener_thread: None,
104 listen_path: None,
105 })
106 }
107
108 pub fn register_plugin(&mut self, plugin: P) -> &Self {
113 self.plugins.push(plugin);
114 self
115 }
116
117 pub fn run(&mut self) -> thrift::Result<()> {
128 self.start()?;
129 self.run_loop();
130 self.shutdown_and_cleanup();
131 Ok(())
132 }
133
134 #[cfg(unix)]
151 pub fn run_with_signal_handling(&mut self) -> thrift::Result<()> {
152 use signal_hook::consts::{SIGINT, SIGTERM};
153 use signal_hook::flag;
154
155 if let Err(e) = flag::register(SIGINT, self.shutdown_flag.clone()) {
160 log::warn!("Failed to register SIGINT handler: {e}");
161 }
162 if let Err(e) = flag::register(SIGTERM, self.shutdown_flag.clone()) {
163 log::warn!("Failed to register SIGTERM handler: {e}");
164 }
165
166 self.start()?;
167 self.run_loop();
168 self.shutdown_and_cleanup();
169 Ok(())
170 }
171
172 fn run_loop(&mut self) {
174 while !self.should_shutdown() {
175 if let Err(e) = self.client.ping() {
176 log::warn!("Ping failed, initiating shutdown: {e}");
177 self.request_shutdown();
178 break;
179 }
180 thread::sleep(self.ping_interval);
181 }
182 }
183
184 fn shutdown_and_cleanup(&mut self) {
186 log::info!("Shutting down");
187
188 self.join_listener_thread();
189
190 if let Some(uuid) = self.uuid {
192 if let Err(e) = self.client.deregister_extension(uuid) {
193 log::warn!("Failed to deregister from osquery: {e}");
194 }
195 }
196
197 self.notify_plugins_shutdown();
198 self.cleanup_socket();
199 }
200
201 fn join_listener_thread(&mut self) {
212 const JOIN_TIMEOUT: Duration = Duration::from_millis(100);
213 const POLL_INTERVAL: Duration = Duration::from_millis(10);
214
215 let Some(thread) = self.listener_thread.take() else {
216 return;
217 };
218
219 log::debug!("Waiting for listener thread to exit");
220 let start = Instant::now();
221
222 while !thread.is_finished() {
223 if start.elapsed() > JOIN_TIMEOUT {
224 log::warn!(
225 "Listener thread did not exit within {:?}, orphaning (will terminate on process exit)",
226 JOIN_TIMEOUT
227 );
228 return;
229 }
230 self.wake_listener();
231 thread::sleep(POLL_INTERVAL);
232 }
233
234 if let Err(e) = thread.join() {
236 log::warn!("Listener thread panicked: {e:?}");
237 }
238 }
239
240 fn start(&mut self) -> thrift::Result<()> {
241 let stat = self.client.register_extension(
242 osquery::InternalExtensionInfo {
243 name: Some(self.name.clone()),
244 version: Some("1.0".to_string()),
245 sdk_version: Some("Unknown".to_string()),
246 min_sdk_version: Some("Unknown".to_string()),
247 },
248 self.generate_registry()?,
249 )?;
250
251 log::info!(
253 "Status {} registering extension {} ({}): {}",
254 stat.code.unwrap_or(0),
255 self.name,
256 stat.uuid.unwrap_or(0),
257 stat.message.unwrap_or_else(|| "No message".to_string())
258 );
259 self.uuid = stat.uuid;
262 let listen_path = format!("{}.{}", self.socket_path, self.uuid.unwrap_or(0));
263
264 let processor = osquery::ExtensionManagerSyncProcessor::new(Handler::new(
265 &self.plugins,
266 self.shutdown_flag.clone(),
267 )?);
268 let i_tr_fact: Box<dyn TReadTransportFactory + Send> =
269 Box::new(TBufferedReadTransportFactory::new());
270 let i_pr_fact: Box<dyn TInputProtocolFactory + Send> =
271 Box::new(TBinaryInputProtocolFactory::new());
272 let o_tr_fact: Box<dyn TWriteTransportFactory + Send> =
273 Box::new(TBufferedWriteTransportFactory::new());
274 let o_pr_fact: Box<dyn TOutputProtocolFactory + Send> =
275 Box::new(TBinaryOutputProtocolFactory::new());
276
277 let mut server =
278 thrift::server::TServer::new(i_tr_fact, i_pr_fact, o_tr_fact, o_pr_fact, processor, 10);
279
280 self.listen_path = Some(listen_path.clone());
282
283 let listener_thread = thread::spawn(move || {
287 if let Err(e) = server.listen_uds(listen_path) {
288 log::debug!("Listener thread exited: {e}");
290 }
291 });
292
293 self.listener_thread = Some(listener_thread);
294 self.started = true;
295
296 Ok(())
297 }
298
299 fn generate_registry(&self) -> thrift::Result<osquery::ExtensionRegistry> {
300 let mut registry = osquery::ExtensionRegistry::new();
301
302 for var in Registry::VARIANTS {
303 registry.insert((*var).to_string(), osquery::ExtensionRouteTable::new());
304 }
305
306 for plugin in self.plugins.iter() {
307 registry
308 .get_mut(plugin.registry().to_string().as_str())
309 .ok_or_thrift_err(|| format!("Failed to register plugin {}", plugin.name()))?
310 .insert(plugin.name(), plugin.routes());
311 }
312 Ok(registry)
313 }
314
315 fn should_shutdown(&self) -> bool {
317 self.shutdown_flag.load(Ordering::Acquire)
318 }
319
320 fn request_shutdown(&self) {
322 self.shutdown_flag.store(true, Ordering::Release);
323 }
324
325 fn wake_listener(&self) {
349 if let Some(ref path) = self.listen_path {
350 let _ = std::os::unix::net::UnixStream::connect(path);
351 }
352 }
353
354 fn notify_plugins_shutdown(&self) {
357 log::debug!("Notifying {} plugins of shutdown", self.plugins.len());
358 for plugin in &self.plugins {
359 let plugin_name = plugin.name();
360 if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
361 plugin.shutdown();
362 })) {
363 log::error!("Plugin '{plugin_name}' panicked during shutdown: {e:?}");
364 }
365 }
366 }
367
368 fn cleanup_socket(&self) {
371 let Some(uuid) = self.uuid else {
372 log::debug!("No socket to clean up (uuid not set)");
373 return;
374 };
375
376 let socket_path = format!("{}.{}", self.socket_path, uuid);
377 log::debug!("Cleaning up socket: {socket_path}");
378
379 if let Err(e) = std::fs::remove_file(&socket_path) {
380 if e.kind() != std::io::ErrorKind::NotFound {
381 log::warn!("Failed to remove socket file {socket_path}: {e}");
382 }
383 }
384 }
385
386 pub fn get_stop_handle(&self) -> ServerStopHandle {
392 ServerStopHandle {
393 shutdown_flag: self.shutdown_flag.clone(),
394 }
395 }
396
397 pub fn stop(&self) {
403 self.request_shutdown();
404 }
405
406 pub fn is_running(&self) -> bool {
412 !self.should_shutdown()
413 }
414}
415
416struct Handler<P: OsqueryPlugin + Clone> {
417 registry: HashMap<String, HashMap<String, P>>,
418 shutdown_flag: Arc<AtomicBool>,
419}
420
421impl<P: OsqueryPlugin + Clone> Handler<P> {
422 fn new(plugins: &[P], shutdown_flag: Arc<AtomicBool>) -> thrift::Result<Self> {
423 let mut reg: HashMap<String, HashMap<String, P>> = HashMap::new();
424 for var in Registry::VARIANTS {
425 reg.insert((*var).to_string(), HashMap::new());
426 }
427
428 for plugin in plugins.iter() {
429 reg.get_mut(plugin.registry().to_string().as_str())
430 .ok_or_thrift_err(|| format!("Failed to register plugin {}", plugin.name()))?
431 .insert(plugin.name(), plugin.clone());
432 }
433
434 Ok(Handler {
435 registry: reg,
436 shutdown_flag,
437 })
438 }
439}
440
441impl<P: OsqueryPlugin + Clone> osquery::ExtensionSyncHandler for Handler<P> {
442 fn handle_ping(&self) -> thrift::Result<osquery::ExtensionStatus> {
443 Ok(osquery::ExtensionStatus::default())
444 }
445
446 fn handle_call(
447 &self,
448 registry: String,
449 item: String,
450 request: osquery::ExtensionPluginRequest,
451 ) -> thrift::Result<osquery::ExtensionResponse> {
452 log::trace!("Registry: {registry}");
453 log::trace!("Item: {item}");
454 log::trace!("Request: {request:?}");
455
456 let plugin = self
457 .registry
458 .get(registry.as_str())
459 .ok_or_thrift_err(|| {
460 format!(
461 "Failed to get registry:{} from registries",
462 registry.as_str()
463 )
464 })?
465 .get(item.as_str())
466 .ok_or_thrift_err(|| {
467 format!(
468 "Failed to item:{} from registry:{}",
469 item.as_str(),
470 registry.as_str()
471 )
472 })?;
473
474 Ok(plugin.handle_call(request))
475 }
476
477 fn handle_shutdown(&self) -> thrift::Result<()> {
478 log::debug!("Shutdown RPC received from osquery");
479 self.shutdown_flag.store(true, Ordering::Release);
480 Ok(())
481 }
482}
483
484impl<P: OsqueryPlugin + Clone> osquery::ExtensionManagerSyncHandler for Handler<P> {
485 fn handle_extensions(&self) -> thrift::Result<osquery::InternalExtensionList> {
486 Ok(osquery::InternalExtensionList::new())
488 }
489
490 fn handle_options(&self) -> thrift::Result<osquery::InternalOptionList> {
491 Ok(osquery::InternalOptionList::new())
493 }
494
495 fn handle_register_extension(
496 &self,
497 _info: osquery::InternalExtensionInfo,
498 _registry: osquery::ExtensionRegistry,
499 ) -> thrift::Result<osquery::ExtensionStatus> {
500 Ok(osquery::ExtensionStatus {
502 code: Some(1),
503 message: Some("Extension registration not supported".to_string()),
504 uuid: None,
505 })
506 }
507
508 fn handle_deregister_extension(
509 &self,
510 _uuid: osquery::ExtensionRouteUUID,
511 ) -> thrift::Result<osquery::ExtensionStatus> {
512 Ok(osquery::ExtensionStatus {
514 code: Some(1),
515 message: Some("Extension deregistration not supported".to_string()),
516 uuid: None,
517 })
518 }
519
520 fn handle_query(&self, _sql: String) -> thrift::Result<osquery::ExtensionResponse> {
521 Ok(osquery::ExtensionResponse::new(
523 osquery::ExtensionStatus {
524 code: Some(1),
525 message: Some("Query execution not supported".to_string()),
526 uuid: None,
527 },
528 vec![],
529 ))
530 }
531
532 fn handle_get_query_columns(&self, _sql: String) -> thrift::Result<osquery::ExtensionResponse> {
533 Ok(osquery::ExtensionResponse::new(
535 osquery::ExtensionStatus {
536 code: Some(1),
537 message: Some("Query column introspection not supported".to_string()),
538 uuid: None,
539 },
540 vec![],
541 ))
542 }
543}