osquery_rust_ng/
server.rs

1use clap::crate_name;
2use std::collections::HashMap;
3use std::io::Error;
4use std::os::unix::net::UnixStream;
5use std::thread;
6use std::time::Duration;
7use strum::VariantNames;
8use thrift::protocol::*;
9use thrift::transport::*;
10
11use crate::_osquery as osquery;
12use crate::_osquery::{TExtensionManagerSyncClient, TExtensionSyncClient};
13use crate::client::Client;
14use crate::plugin::{OsqueryPlugin, Plugin, Registry};
15use crate::util::OptionToThriftResult;
16
/// Default for `Server::timeout`: one second.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(1);
/// Default pause between pings to osquery's extension manager in `Server::run`: five seconds.
const DEFAULT_PING_INTERVAL: Duration = Duration::from_secs(5);
19
20#[allow(clippy::type_complexity)]
21pub struct Server<P: OsqueryPlugin + Clone + Send + Sync + 'static> {
22    name: String,
23    socket_path: String,
24    client: Client,
25    plugins: Vec<P>,
26    server: Option<
27        thrift::server::TServer<
28            osquery::ExtensionManagerSyncProcessor<Handler<P>>,
29            Box<dyn TReadTransportFactory>,
30            Box<dyn TInputProtocolFactory>,
31            Box<dyn TWriteTransportFactory>,
32            Box<dyn TOutputProtocolFactory>,
33        >,
34    >,
35    #[allow(dead_code)]
36    transport: Option<
37        osquery::ExtensionSyncClient<
38            TBinaryInputProtocol<UnixStream>,
39            TBinaryOutputProtocol<UnixStream>,
40        >,
41    >,
42    #[allow(dead_code)]
43    timeout: Duration,
44    ping_interval: Duration,
45    //mutex: Mutex<u32>,
46    uuid: Option<osquery::ExtensionRouteUUID>,
47    // Used to ensure tests wait until the server is actually started
48    started: bool,
49}
50
51impl<P: OsqueryPlugin + Clone + Send + 'static> Server<P> {
52    pub fn new(name: Option<&str>, socket_path: &str) -> Result<Self, Error> {
53        let mut reg: HashMap<String, HashMap<String, Plugin>> = HashMap::new();
54        for var in Registry::VARIANTS {
55            reg.insert((*var).to_string(), HashMap::new());
56        }
57
58        let name = name.unwrap_or(crate_name!());
59
60        let client = Client::new(socket_path, Default::default())?;
61
62        Ok(Server {
63            name: name.to_string(),
64            socket_path: socket_path.to_string(),
65            client,
66            plugins: Vec::new(),
67            server: None,
68            transport: None,
69            timeout: DEFAULT_TIMEOUT,
70            ping_interval: DEFAULT_PING_INTERVAL,
71            uuid: None,
72            started: false,
73        })
74    }
75
76    ///
77    /// Registers a plugin, something which implements the OsqueryPlugin trait.
78    /// Consumes the plugin.
79    ///
80    pub fn register_plugin(&mut self, plugin: P) -> &Self {
81        self.plugins.push(plugin);
82        self
83    }
84
85    pub fn run(&mut self) -> thrift::Result<()> {
86        self.start()?;
87        loop {
88            self.client.ping()?;
89            thread::sleep(self.ping_interval);
90        }
91    }
92
93    fn start(&mut self) -> thrift::Result<()> {
94        let stat = self.client.register_extension(
95            osquery::InternalExtensionInfo {
96                name: Some(self.name.clone()),
97                version: Some("1.0".to_string()),
98                sdk_version: Some("Unknown".to_string()),
99                min_sdk_version: Some("Unknown".to_string()),
100            },
101            self.generate_registry()?,
102        )?;
103
104        //if stat.code != Some(0) {
105        log::info!(
106            "Status {} registering extension {} ({}): {}",
107            stat.code.unwrap_or(0),
108            self.name,
109            stat.uuid.unwrap_or(0),
110            stat.message.unwrap_or_else(|| "No message".to_string())
111        );
112        //}
113
114        self.uuid = stat.uuid;
115        let listen_path = format!("{}.{}", self.socket_path, self.uuid.unwrap_or(0));
116
117        let processor = osquery::ExtensionManagerSyncProcessor::new(Handler::new(&self.plugins)?);
118        let i_tr_fact: Box<dyn TReadTransportFactory> =
119            Box::new(TBufferedReadTransportFactory::new());
120        let i_pr_fact: Box<dyn TInputProtocolFactory> =
121            Box::new(TBinaryInputProtocolFactory::new());
122        let o_tr_fact: Box<dyn TWriteTransportFactory> =
123            Box::new(TBufferedWriteTransportFactory::new());
124        let o_pr_fact: Box<dyn TOutputProtocolFactory> =
125            Box::new(TBinaryOutputProtocolFactory::new());
126
127        let mut server =
128            thrift::server::TServer::new(i_tr_fact, i_pr_fact, o_tr_fact, o_pr_fact, processor, 10);
129
130        match server.listen_uds(listen_path.clone()) {
131            Ok(_) => {}
132            Err(e) => {
133                log::error!("FATAL: {e} while binding to {listen_path}")
134            }
135        }
136        self.server = Some(server);
137
138        self.started = true;
139
140        Ok(())
141    }
142
143    fn generate_registry(&self) -> thrift::Result<osquery::ExtensionRegistry> {
144        let mut registry = osquery::ExtensionRegistry::new();
145
146        for var in Registry::VARIANTS {
147            registry.insert((*var).to_string(), osquery::ExtensionRouteTable::new());
148        }
149
150        for plugin in self.plugins.iter() {
151            registry
152                .get_mut(plugin.registry().to_string().as_str())
153                .ok_or_thrift_err(|| format!("Failed to register plugin {}", plugin.name()))?
154                .insert(plugin.name(), plugin.routes());
155        }
156        Ok(registry)
157    }
158}
159
160struct Handler<P: OsqueryPlugin + Clone> {
161    registry: HashMap<String, HashMap<String, P>>,
162}
163
164impl<P: OsqueryPlugin + Clone> Handler<P> {
165    fn new(plugins: &[P]) -> thrift::Result<Self> {
166        let mut reg: HashMap<String, HashMap<String, P>> = HashMap::new();
167        for var in Registry::VARIANTS {
168            reg.insert((*var).to_string(), HashMap::new());
169        }
170
171        for plugin in plugins.iter() {
172            reg.get_mut(plugin.registry().to_string().as_str())
173                .ok_or_thrift_err(|| format!("Failed to register plugin {}", plugin.name()))?
174                .insert(plugin.name(), plugin.clone());
175        }
176
177        Ok(Handler { registry: reg })
178    }
179}
180
181impl<P: OsqueryPlugin + Clone> osquery::ExtensionSyncHandler for Handler<P> {
182    fn handle_ping(&self) -> thrift::Result<osquery::ExtensionStatus> {
183        Ok(osquery::ExtensionStatus::default())
184    }
185
186    fn handle_call(
187        &self,
188        registry: String,
189        item: String,
190        request: osquery::ExtensionPluginRequest,
191    ) -> thrift::Result<osquery::ExtensionResponse> {
192        log::trace!("Registry: {registry}");
193        log::trace!("Item: {item}");
194        log::trace!("Request: {request:?}");
195
196        let plugin = self
197            .registry
198            .get(registry.as_str())
199            .ok_or_thrift_err(|| {
200                format!(
201                    "Failed to get registry:{} from registries",
202                    registry.as_str()
203                )
204            })?
205            .get(item.as_str())
206            .ok_or_thrift_err(|| {
207                format!(
208                    "Failed to item:{} from registry:{}",
209                    item.as_str(),
210                    registry.as_str()
211                )
212            })?;
213
214        Ok(plugin.handle_call(request))
215    }
216
217    fn handle_shutdown(&self) -> thrift::Result<()> {
218        log::trace!("Shutdown");
219
220        self.registry.iter().for_each(|(_, v)| {
221            v.iter().for_each(|(_, p)| {
222                p.shutdown();
223            });
224        });
225
226        Ok(())
227    }
228}
229
230impl<P: OsqueryPlugin + Clone> osquery::ExtensionManagerSyncHandler for Handler<P> {
231    fn handle_extensions(&self) -> thrift::Result<osquery::InternalExtensionList> {
232        todo!()
233    }
234
235    fn handle_options(&self) -> thrift::Result<osquery::InternalOptionList> {
236        todo!()
237    }
238
239    fn handle_register_extension(
240        &self,
241        _info: osquery::InternalExtensionInfo,
242        _registry: osquery::ExtensionRegistry,
243    ) -> thrift::Result<osquery::ExtensionStatus> {
244        todo!()
245    }
246
247    fn handle_deregister_extension(
248        &self,
249        _uuid: osquery::ExtensionRouteUUID,
250    ) -> thrift::Result<osquery::ExtensionStatus> {
251        todo!()
252    }
253
254    fn handle_query(&self, _sql: String) -> thrift::Result<osquery::ExtensionResponse> {
255        todo!()
256    }
257
258    fn handle_get_query_columns(&self, _sql: String) -> thrift::Result<osquery::ExtensionResponse> {
259        todo!()
260    }
261}