osquery_rust_ng/
server.rs

1use clap::crate_name;
2use std::collections::HashMap;
3use std::io::Error;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6use std::thread;
7use std::time::{Duration, Instant};
8use strum::VariantNames;
9use thrift::protocol::*;
10use thrift::transport::*;
11
12use crate::_osquery as osquery;
13use crate::_osquery::{TExtensionManagerSyncClient, TExtensionSyncClient};
14use crate::client::Client;
15use crate::plugin::{OsqueryPlugin, Plugin, Registry};
16use crate::util::OptionToThriftResult;
17
18const DEFAULT_PING_INTERVAL: Duration = Duration::from_millis(500);
19
/// Handle that allows stopping the server from another thread.
///
/// The handle holds a clone of the server's shared shutdown flag. It can be
/// cloned freely and moved across threads; every clone controls the same flag.
///
/// # Thread Safety
///
/// `ServerStopHandle` is `Clone + Send + Sync` (it only contains an
/// `Arc<AtomicBool>`). Multiple calls to `stop()` are safe and idempotent.
///
/// # Example
///
/// ```ignore
/// let mut server = Server::new(None, "/path/to/socket")?;
/// let handle = server.get_stop_handle();
///
/// // In another thread:
/// std::thread::spawn(move || {
///     // ... some condition ...
///     handle.stop();
/// });
///
/// server.run()?; // Will exit when stop() is called
/// ```
#[derive(Clone, Debug)]
pub struct ServerStopHandle {
    /// Shared flag; `true` means shutdown has been requested.
    shutdown_flag: Arc<AtomicBool>,
}

impl ServerStopHandle {
    /// Request the server to stop.
    ///
    /// Sets the shared shutdown flag with `Release` ordering. Calling this
    /// more than once has no additional effect. The server observes the flag
    /// the next time its run loop checks it.
    pub fn stop(&self) {
        self.shutdown_flag.store(true, Ordering::Release);
    }

    /// Check if the server is still running.
    ///
    /// Returns `true` while the shutdown flag is unset and `false` once
    /// shutdown has been requested through any holder of the flag.
    pub fn is_running(&self) -> bool {
        !self.shutdown_flag.load(Ordering::Acquire)
    }
}
66
67pub struct Server<P: OsqueryPlugin + Clone + Send + Sync + 'static> {
68    name: String,
69    socket_path: String,
70    client: Client,
71    plugins: Vec<P>,
72    ping_interval: Duration,
73    uuid: Option<osquery::ExtensionRouteUUID>,
74    // Used to ensure tests wait until the server is actually started
75    started: bool,
76    shutdown_flag: Arc<AtomicBool>,
77    /// Handle to the listener thread for graceful shutdown
78    listener_thread: Option<thread::JoinHandle<()>>,
79    /// Path to the listener socket for wake-up connection on shutdown
80    listen_path: Option<String>,
81}
82
83impl<P: OsqueryPlugin + Clone + Send + 'static> Server<P> {
84    pub fn new(name: Option<&str>, socket_path: &str) -> Result<Self, Error> {
85        let mut reg: HashMap<String, HashMap<String, Plugin>> = HashMap::new();
86        for var in Registry::VARIANTS {
87            reg.insert((*var).to_string(), HashMap::new());
88        }
89
90        let name = name.unwrap_or(crate_name!());
91
92        let client = Client::new(socket_path, Default::default())?;
93
94        Ok(Server {
95            name: name.to_string(),
96            socket_path: socket_path.to_string(),
97            client,
98            plugins: Vec::new(),
99            ping_interval: DEFAULT_PING_INTERVAL,
100            uuid: None,
101            started: false,
102            shutdown_flag: Arc::new(AtomicBool::new(false)),
103            listener_thread: None,
104            listen_path: None,
105        })
106    }
107
108    ///
109    /// Registers a plugin, something which implements the OsqueryPlugin trait.
110    /// Consumes the plugin.
111    ///
112    pub fn register_plugin(&mut self, plugin: P) -> &Self {
113        self.plugins.push(plugin);
114        self
115    }
116
117    /// Run the server, blocking until shutdown is requested.
118    ///
119    /// This method starts the server, registers with osquery, and enters a loop
120    /// that pings osquery periodically. The loop exits when shutdown is triggered
121    /// by any of:
122    /// - osquery calling the shutdown RPC
123    /// - Connection to osquery being lost
124    /// - `stop()` being called from another thread
125    ///
126    /// For signal handling (SIGTERM/SIGINT), use `run_with_signal_handling()` instead.
127    pub fn run(&mut self) -> thrift::Result<()> {
128        self.start()?;
129        self.run_loop();
130        self.shutdown_and_cleanup();
131        Ok(())
132    }
133
134    /// Run the server with signal handling enabled (Unix only).
135    ///
136    /// This method registers handlers for SIGTERM and SIGINT that will trigger
137    /// graceful shutdown. Use this instead of `run()` if you want the server to
138    /// respond to OS signals (e.g., systemd sending SIGTERM, or Ctrl+C sending SIGINT).
139    ///
140    /// The loop exits when shutdown is triggered by any of:
141    /// - SIGTERM or SIGINT signal received
142    /// - osquery calling the shutdown RPC
143    /// - Connection to osquery being lost
144    /// - `stop()` being called from another thread
145    ///
146    /// # Platform Support
147    ///
148    /// This method is only available on Unix platforms. For Windows, use `run()`
149    /// and implement your own signal handling.
150    #[cfg(unix)]
151    pub fn run_with_signal_handling(&mut self) -> thrift::Result<()> {
152        use signal_hook::consts::{SIGINT, SIGTERM};
153        use signal_hook::flag;
154
155        // Register signal handlers that set our shutdown flag.
156        // signal_hook::flag::register atomically sets the bool when signal received.
157        // Errors are rare (e.g., invalid signal number) and non-fatal - signals
158        // just won't trigger shutdown, but other shutdown mechanisms still work.
159        if let Err(e) = flag::register(SIGINT, self.shutdown_flag.clone()) {
160            log::warn!("Failed to register SIGINT handler: {e}");
161        }
162        if let Err(e) = flag::register(SIGTERM, self.shutdown_flag.clone()) {
163            log::warn!("Failed to register SIGTERM handler: {e}");
164        }
165
166        self.start()?;
167        self.run_loop();
168        self.shutdown_and_cleanup();
169        Ok(())
170    }
171
172    /// The main ping loop. Exits when should_shutdown() returns true.
173    fn run_loop(&mut self) {
174        while !self.should_shutdown() {
175            if let Err(e) = self.client.ping() {
176                log::warn!("Ping failed, initiating shutdown: {e}");
177                self.request_shutdown();
178                break;
179            }
180            thread::sleep(self.ping_interval);
181        }
182    }
183
184    /// Common shutdown logic: wake listener, join thread, deregister, notify plugins, cleanup socket.
185    fn shutdown_and_cleanup(&mut self) {
186        log::info!("Shutting down");
187
188        self.join_listener_thread();
189
190        // Deregister from osquery (best-effort, allows faster cleanup than timeout)
191        if let Some(uuid) = self.uuid {
192            if let Err(e) = self.client.deregister_extension(uuid) {
193                log::warn!("Failed to deregister from osquery: {e}");
194            }
195        }
196
197        self.notify_plugins_shutdown();
198        self.cleanup_socket();
199    }
200
201    /// Attempt to join the listener thread with a timeout.
202    ///
203    /// The thrift listener has an infinite loop that we cannot control, so we use
204    /// a timed join: repeatedly wake the listener and check if it has exited.
205    /// If it doesn't exit within the timeout, we orphan the thread (it will be
206    /// cleaned up when the process exits).
207    ///
208    /// This is a pragmatic solution per:
209    /// - <https://matklad.github.io/2019/08/23/join-your-threads.html>
210    /// - <https://github.com/rust-lang/rust/issues/26446>
211    fn join_listener_thread(&mut self) {
212        const JOIN_TIMEOUT: Duration = Duration::from_millis(100);
213        const POLL_INTERVAL: Duration = Duration::from_millis(10);
214
215        let Some(thread) = self.listener_thread.take() else {
216            return;
217        };
218
219        log::debug!("Waiting for listener thread to exit");
220        let start = Instant::now();
221
222        while !thread.is_finished() {
223            if start.elapsed() > JOIN_TIMEOUT {
224                log::warn!(
225                    "Listener thread did not exit within {:?}, orphaning (will terminate on process exit)",
226                    JOIN_TIMEOUT
227                );
228                return;
229            }
230            self.wake_listener();
231            thread::sleep(POLL_INTERVAL);
232        }
233
234        // Thread finished, now we can join without blocking
235        if let Err(e) = thread.join() {
236            log::warn!("Listener thread panicked: {e:?}");
237        }
238    }
239
240    fn start(&mut self) -> thrift::Result<()> {
241        let stat = self.client.register_extension(
242            osquery::InternalExtensionInfo {
243                name: Some(self.name.clone()),
244                version: Some("1.0".to_string()),
245                sdk_version: Some("Unknown".to_string()),
246                min_sdk_version: Some("Unknown".to_string()),
247            },
248            self.generate_registry()?,
249        )?;
250
251        //if stat.code != Some(0) {
252        log::info!(
253            "Status {} registering extension {} ({}): {}",
254            stat.code.unwrap_or(0),
255            self.name,
256            stat.uuid.unwrap_or(0),
257            stat.message.unwrap_or_else(|| "No message".to_string())
258        );
259        //}
260
261        self.uuid = stat.uuid;
262        let listen_path = format!("{}.{}", self.socket_path, self.uuid.unwrap_or(0));
263
264        let processor = osquery::ExtensionManagerSyncProcessor::new(Handler::new(
265            &self.plugins,
266            self.shutdown_flag.clone(),
267        )?);
268        let i_tr_fact: Box<dyn TReadTransportFactory + Send> =
269            Box::new(TBufferedReadTransportFactory::new());
270        let i_pr_fact: Box<dyn TInputProtocolFactory + Send> =
271            Box::new(TBinaryInputProtocolFactory::new());
272        let o_tr_fact: Box<dyn TWriteTransportFactory + Send> =
273            Box::new(TBufferedWriteTransportFactory::new());
274        let o_pr_fact: Box<dyn TOutputProtocolFactory + Send> =
275            Box::new(TBinaryOutputProtocolFactory::new());
276
277        let mut server =
278            thrift::server::TServer::new(i_tr_fact, i_pr_fact, o_tr_fact, o_pr_fact, processor, 10);
279
280        // Store the listen path for wake-up connection on shutdown
281        self.listen_path = Some(listen_path.clone());
282
283        // Spawn the listener in a background thread so we can check shutdown flag
284        // in run_loop(). The thrift listen_uds() blocks forever, so without this
285        // the server cannot gracefully shutdown.
286        let listener_thread = thread::spawn(move || {
287            if let Err(e) = server.listen_uds(listen_path) {
288                // Log but don't panic - listener exiting is expected on shutdown
289                log::debug!("Listener thread exited: {e}");
290            }
291        });
292
293        self.listener_thread = Some(listener_thread);
294        self.started = true;
295
296        Ok(())
297    }
298
299    fn generate_registry(&self) -> thrift::Result<osquery::ExtensionRegistry> {
300        let mut registry = osquery::ExtensionRegistry::new();
301
302        for var in Registry::VARIANTS {
303            registry.insert((*var).to_string(), osquery::ExtensionRouteTable::new());
304        }
305
306        for plugin in self.plugins.iter() {
307            registry
308                .get_mut(plugin.registry().to_string().as_str())
309                .ok_or_thrift_err(|| format!("Failed to register plugin {}", plugin.name()))?
310                .insert(plugin.name(), plugin.routes());
311        }
312        Ok(registry)
313    }
314
315    /// Check if shutdown has been requested.
316    fn should_shutdown(&self) -> bool {
317        self.shutdown_flag.load(Ordering::Acquire)
318    }
319
320    /// Request shutdown by setting the shutdown flag.
321    fn request_shutdown(&self) {
322        self.shutdown_flag.store(true, Ordering::Release);
323    }
324
325    /// Wake the blocking listener thread by making a dummy connection.
326    ///
327    /// # Why This Workaround Exists
328    ///
329    /// The thrift crate's `TServer::listen_uds()` blocks forever on `accept()` with no
330    /// shutdown mechanism - it only exposes `new()`, `listen()`, and `listen_uds()`.
331    /// See: <https://docs.rs/thrift/latest/thrift/server/struct.TServer.html>
332    ///
333    /// More elegant alternatives and why we can't use them:
334    /// - `shutdown(fd, SHUT_RD)`: Thrift owns the socket, we have no access to the raw FD
335    /// - Async (tokio): Thrift uses a synchronous API
336    /// - Non-blocking + poll: Would require modifying thrift internals
337    /// - `close()` on listener: Doesn't reliably wake threads on Linux
338    ///
339    /// The dummy connection pattern is a documented workaround:
340    /// <https://stackoverflow.com/questions/2486335/wake-up-thread-blocked-on-accept-call>
341    ///
342    /// # How It Works
343    ///
344    /// 1. Shutdown flag is set (by caller)
345    /// 2. We connect to our own socket, which unblocks `accept()`
346    /// 3. The listener thread receives the connection, checks shutdown flag, and exits
347    /// 4. The connection is immediately dropped (never read from)
348    fn wake_listener(&self) {
349        if let Some(ref path) = self.listen_path {
350            let _ = std::os::unix::net::UnixStream::connect(path);
351        }
352    }
353
354    /// Notify all registered plugins that shutdown is occurring.
355    /// Uses catch_unwind to ensure all plugins are notified even if one panics.
356    fn notify_plugins_shutdown(&self) {
357        log::debug!("Notifying {} plugins of shutdown", self.plugins.len());
358        for plugin in &self.plugins {
359            let plugin_name = plugin.name();
360            if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
361                plugin.shutdown();
362            })) {
363                log::error!("Plugin '{plugin_name}' panicked during shutdown: {e:?}");
364            }
365        }
366    }
367
368    /// Clean up the socket file created during start().
369    /// Logs errors (except NotFound, which is expected if socket was already cleaned up).
370    fn cleanup_socket(&self) {
371        let Some(uuid) = self.uuid else {
372            log::debug!("No socket to clean up (uuid not set)");
373            return;
374        };
375
376        let socket_path = format!("{}.{}", self.socket_path, uuid);
377        log::debug!("Cleaning up socket: {socket_path}");
378
379        if let Err(e) = std::fs::remove_file(&socket_path) {
380            if e.kind() != std::io::ErrorKind::NotFound {
381                log::warn!("Failed to remove socket file {socket_path}: {e}");
382            }
383        }
384    }
385
386    /// Get a handle that can be used to stop the server from another thread.
387    ///
388    /// The returned handle can be cloned and shared across threads. Calling
389    /// `stop()` on the handle will cause the server's `run()` method to exit
390    /// gracefully on the next iteration.
391    pub fn get_stop_handle(&self) -> ServerStopHandle {
392        ServerStopHandle {
393            shutdown_flag: self.shutdown_flag.clone(),
394        }
395    }
396
397    /// Request the server to stop.
398    ///
399    /// This is a convenience method equivalent to calling `stop()` on a
400    /// `ServerStopHandle`. The server will exit its `run()` loop on the next
401    /// iteration.
402    pub fn stop(&self) {
403        self.request_shutdown();
404    }
405
406    /// Check if the server is still running.
407    ///
408    /// Returns `true` if the server has not been requested to stop,
409    /// `false` if `stop()` has been called or shutdown has been triggered
410    /// by another mechanism (e.g., osquery shutdown RPC, connection loss).
411    pub fn is_running(&self) -> bool {
412        !self.should_shutdown()
413    }
414}
415
416struct Handler<P: OsqueryPlugin + Clone> {
417    registry: HashMap<String, HashMap<String, P>>,
418    shutdown_flag: Arc<AtomicBool>,
419}
420
421impl<P: OsqueryPlugin + Clone> Handler<P> {
422    fn new(plugins: &[P], shutdown_flag: Arc<AtomicBool>) -> thrift::Result<Self> {
423        let mut reg: HashMap<String, HashMap<String, P>> = HashMap::new();
424        for var in Registry::VARIANTS {
425            reg.insert((*var).to_string(), HashMap::new());
426        }
427
428        for plugin in plugins.iter() {
429            reg.get_mut(plugin.registry().to_string().as_str())
430                .ok_or_thrift_err(|| format!("Failed to register plugin {}", plugin.name()))?
431                .insert(plugin.name(), plugin.clone());
432        }
433
434        Ok(Handler {
435            registry: reg,
436            shutdown_flag,
437        })
438    }
439}
440
441impl<P: OsqueryPlugin + Clone> osquery::ExtensionSyncHandler for Handler<P> {
442    fn handle_ping(&self) -> thrift::Result<osquery::ExtensionStatus> {
443        Ok(osquery::ExtensionStatus::default())
444    }
445
446    fn handle_call(
447        &self,
448        registry: String,
449        item: String,
450        request: osquery::ExtensionPluginRequest,
451    ) -> thrift::Result<osquery::ExtensionResponse> {
452        log::trace!("Registry: {registry}");
453        log::trace!("Item: {item}");
454        log::trace!("Request: {request:?}");
455
456        let plugin = self
457            .registry
458            .get(registry.as_str())
459            .ok_or_thrift_err(|| {
460                format!(
461                    "Failed to get registry:{} from registries",
462                    registry.as_str()
463                )
464            })?
465            .get(item.as_str())
466            .ok_or_thrift_err(|| {
467                format!(
468                    "Failed to item:{} from registry:{}",
469                    item.as_str(),
470                    registry.as_str()
471                )
472            })?;
473
474        Ok(plugin.handle_call(request))
475    }
476
477    fn handle_shutdown(&self) -> thrift::Result<()> {
478        log::debug!("Shutdown RPC received from osquery");
479        self.shutdown_flag.store(true, Ordering::Release);
480        Ok(())
481    }
482}
483
484impl<P: OsqueryPlugin + Clone> osquery::ExtensionManagerSyncHandler for Handler<P> {
485    fn handle_extensions(&self) -> thrift::Result<osquery::InternalExtensionList> {
486        // Extension management not supported - return empty list
487        Ok(osquery::InternalExtensionList::new())
488    }
489
490    fn handle_options(&self) -> thrift::Result<osquery::InternalOptionList> {
491        // Extension options not supported - return empty list
492        Ok(osquery::InternalOptionList::new())
493    }
494
495    fn handle_register_extension(
496        &self,
497        _info: osquery::InternalExtensionInfo,
498        _registry: osquery::ExtensionRegistry,
499    ) -> thrift::Result<osquery::ExtensionStatus> {
500        // Nested extension registration not supported
501        Ok(osquery::ExtensionStatus {
502            code: Some(1),
503            message: Some("Extension registration not supported".to_string()),
504            uuid: None,
505        })
506    }
507
508    fn handle_deregister_extension(
509        &self,
510        _uuid: osquery::ExtensionRouteUUID,
511    ) -> thrift::Result<osquery::ExtensionStatus> {
512        // Nested extension deregistration not supported
513        Ok(osquery::ExtensionStatus {
514            code: Some(1),
515            message: Some("Extension deregistration not supported".to_string()),
516            uuid: None,
517        })
518    }
519
520    fn handle_query(&self, _sql: String) -> thrift::Result<osquery::ExtensionResponse> {
521        // Query execution not supported
522        Ok(osquery::ExtensionResponse::new(
523            osquery::ExtensionStatus {
524                code: Some(1),
525                message: Some("Query execution not supported".to_string()),
526                uuid: None,
527            },
528            vec![],
529        ))
530    }
531
532    fn handle_get_query_columns(&self, _sql: String) -> thrift::Result<osquery::ExtensionResponse> {
533        // Query column introspection not supported
534        Ok(osquery::ExtensionResponse::new(
535            osquery::ExtensionStatus {
536                code: Some(1),
537                message: Some("Query column introspection not supported".to_string()),
538                uuid: None,
539            },
540            vec![],
541        ))
542    }
543}