Skip to main content

hipcheck_sdk/
engine.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use crate::{
4	JsonValue, Plugin, QueryTarget,
5	error::{Error, Result},
6};
7use futures::Stream;
8use hipcheck_common::proto::{
9	self, InitiateQueryProtocolRequest, InitiateQueryProtocolResponse, Query as PluginQuery,
10	QueryState,
11};
12use hipcheck_common::{
13	chunk::QuerySynthesizer,
14	types::{Query, QueryDirection},
15};
16use serde::Serialize;
17use std::{
18	collections::{HashMap, VecDeque},
19	future::poll_fn,
20	pin::Pin,
21	result::Result as StdResult,
22	sync::Arc,
23};
24use tokio::sync::mpsc::{self, error::TrySendError};
25use tonic::Status;
26
27impl From<Status> for Error {
28	fn from(_value: Status) -> Error {
29		// TODO: higher-fidelity handling?
30		Error::SessionChannelClosed
31	}
32}
33
34type SessionTracker = HashMap<i32, mpsc::Sender<Option<PluginQuery>>>;
35
36/// Used for building a up a `Vec` of keys to send to specific hipcheck plugin
37pub struct QueryBuilder<'engine> {
38	keys: Vec<JsonValue>,
39	target: QueryTarget,
40	plugin_engine: &'engine mut PluginEngine,
41}
42
43impl<'engine> QueryBuilder<'engine> {
44	/// Create a new `QueryBuilder` to dynamically add keys to send to `target` plugin
45	fn new<T>(plugin_engine: &'engine mut PluginEngine, target: T) -> Result<QueryBuilder<'engine>>
46	where
47		T: TryInto<QueryTarget, Error: Into<Error>>,
48	{
49		let target: QueryTarget = target.try_into().map_err(|e| e.into())?;
50		Ok(Self {
51			plugin_engine,
52			target,
53			keys: vec![],
54		})
55	}
56
57	/// Add a key to the internal list of keys to be sent to `target`
58	///
59	/// Returns the index `key` was inserted was inserted to
60	pub fn query(&mut self, key: JsonValue) -> usize {
61		let len = self.keys.len();
62		self.keys.push(key);
63		len
64	}
65
66	/// Send all of the provided keys to `target` plugin endpont and wait for query results
67	pub async fn send(self) -> Result<Vec<JsonValue>> {
68		self.plugin_engine.batch_query(self.target, self.keys).await
69	}
70}
71
72/// Manages a particular query session.
73///
74/// This struct invokes a `Query` trait object, passing a handle to itself to `Query::run()`. This
75/// allows the query logic to request information from other Hipcheck plugins in order to complete.
76pub struct PluginEngine {
77	id: usize,
78	tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
79	rx: mpsc::Receiver<Option<PluginQuery>>,
80	concerns: Vec<String>,
81	// So that we can remove ourselves when we get dropped
82	drop_tx: mpsc::Sender<i32>,
83	// When unit testing, this enables the user to mock plugin responses to various inputs
84	mock_responses: MockResponses,
85}
86
87impl PluginEngine {
88	#[cfg(feature = "mock_engine")]
89	#[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
90	/// Constructor for use in unit tests, `query()` function will reference this map instead of
91	/// trying to connect to Hipcheck core for a response value
92	pub fn mock(mock_responses: MockResponses) -> Self {
93		mock_responses.into()
94	}
95
96	/// Convenience function to expose a `QueryBuilder` to make it convenient to dynamically build
97	/// up queries to plugins and send them off to the `target` plugin, in as few GRPC calls as
98	/// possible
99	pub fn batch<T>(&mut self, target: T) -> Result<QueryBuilder<'_>>
100	where
101		T: TryInto<QueryTarget, Error: Into<Error>>,
102	{
103		QueryBuilder::new(self, target)
104	}
105
106	async fn query_inner(
107		&mut self,
108		target: QueryTarget,
109		input: Vec<JsonValue>,
110	) -> Result<Vec<JsonValue>> {
111		// If doing a mock engine, look to the `mock_responses` field for the query answer
112		if cfg!(feature = "mock_engine") {
113			let mut results = Vec::with_capacity(input.len());
114			for i in input {
115				match self.mock_responses.0.get(&(target.clone(), i)) {
116					Some(res) => match res {
117						Ok(val) => results.push(val.clone()),
118						Err(e) => {
119							tracing::error!("Error parsing mock_engine response: {e}");
120							return Err(Error::UnexpectedPluginQueryInputFormat);
121						}
122					},
123					None => {
124						return Err(Error::UnknownPluginQuery(
125							target.to_string().into_boxed_str(),
126						));
127					}
128				}
129			}
130			Ok(results)
131		}
132		// Normal execution, send messages to hipcheck core to query other plugin
133		else {
134			let query = Query {
135				id: 0,
136				direction: QueryDirection::Request,
137				publisher: target.publisher,
138				plugin: target.plugin,
139				query: target.query.unwrap_or_else(|| "".to_owned()),
140				key: input,
141				output: vec![],
142				concerns: vec![],
143			};
144			self.send(query).await?;
145			let response = self.recv().await?;
146			match response {
147				Some(response) => Ok(response.output),
148				None => Err(Error::SessionChannelClosed),
149			}
150		}
151	}
152
153	/// Query another Hipcheck plugin `target` with key `input`. On success, the JSONified result
154	/// of the query is returned. `target` will often be a string of the format
155	/// `"publisher/plugin[/query]"`, where the bracketed substring is optional if the plugin's
156	/// default query endpoint is desired. `input` must be of a type implementing `serde::Serialize`,
157	pub async fn query<T, V>(&mut self, target: T, input: V) -> Result<JsonValue>
158	where
159		T: TryInto<QueryTarget, Error: Into<Error>>,
160		V: Serialize,
161	{
162		let query_target: QueryTarget = target.try_into().map_err(|e| e.into())?;
163		tracing::trace!("querying {}", query_target.to_string());
164		let input: JsonValue = serde_json::to_value(input)
165			.map_err(|source| Error::InvalidJsonInQueryKey(Box::new(source)))?;
166		// since there input had one value, there will only be one response
167		let mut response = self.query_inner(query_target, vec![input]).await?;
168		Ok(response.pop().unwrap())
169	}
170
171	/// Query another Hipcheck plugin `target` with Vec of `inputs`. On success, the JSONified result
172	/// of the query is returned. `target` will often be a string of the format
173	/// `"publisher/plugin[/query]"`, where the bracketed substring is optional if the plugin's
174	/// default query endpoint is desired. `keys` must be a Vec containing a type which implements `serde::Serialize`,
175	pub async fn batch_query<T, V>(&mut self, target: T, keys: Vec<V>) -> Result<Vec<JsonValue>>
176	where
177		T: TryInto<QueryTarget, Error: Into<Error>>,
178		V: Serialize,
179	{
180		let target: QueryTarget = target.try_into().map_err(|e| e.into())?;
181		tracing::trace!("querying {}", target.to_string());
182		let mut input = Vec::with_capacity(keys.len());
183		for key in keys {
184			let jsonified_key = serde_json::to_value(key)
185				.map_err(|source| Error::InvalidJsonInQueryKey(Box::new(source)))?;
186			input.push(jsonified_key);
187		}
188		self.query_inner(target, input).await
189	}
190
191	fn id(&self) -> usize {
192		self.id
193	}
194
195	async fn recv_raw(&mut self) -> Result<Option<VecDeque<PluginQuery>>> {
196		let mut out = VecDeque::new();
197
198		tracing::trace!("SDK: awaiting raw rx recv");
199
200		let opt_first = self.rx.recv().await.ok_or(Error::SessionChannelClosed)?;
201
202		let Some(first) = opt_first else {
203			// Underlying gRPC channel closed
204			return Ok(None);
205		};
206		out.push_back(first);
207
208		// If more messages in the queue, opportunistically read more
209		loop {
210			match self.rx.try_recv() {
211				Ok(Some(msg)) => {
212					out.push_back(msg);
213				}
214				Ok(None) => {
215					tracing::warn!(
216						"None received, gRPC channel closed. we may not close properly if None is not returned again"
217					);
218					break;
219				}
220				// Whether empty or disconnected, we return what we have
221				Err(_) => {
222					break;
223				}
224			}
225		}
226
227		Ok(Some(out))
228	}
229
230	// Send a gRPC query from plugin to the hipcheck server
231	async fn send(&self, mut query: Query) -> Result<()> {
232		query.id = self.id(); // incoming id value is just a placeholder
233		let queries = hipcheck_common::chunk::prepare(query)?;
234		for pq in queries {
235			let query = InitiateQueryProtocolResponse { query: Some(pq) };
236			self.tx
237				.send(Ok(query))
238				.await
239				.map_err(|source| Error::FailedToSendQueryFromSessionToServer(Box::new(source)))?;
240		}
241		Ok(())
242	}
243
244	async fn send_session_err<P>(&mut self) -> crate::error::Result<()>
245	where
246		P: Plugin,
247	{
248		let query = proto::Query {
249			id: self.id() as i32,
250			state: QueryState::Unspecified as i32,
251			publisher_name: P::PUBLISHER.to_owned(),
252			plugin_name: P::NAME.to_owned(),
253			query_name: "".to_owned(),
254			key: vec![],
255			output: vec![],
256			concern: self.take_concerns(),
257			split: false,
258		};
259		self.tx
260			.send(Ok(InitiateQueryProtocolResponse { query: Some(query) }))
261			.await
262			.map_err(|source| Error::FailedToSendQueryFromSessionToServer(Box::new(source)))
263	}
264
265	async fn recv(&mut self) -> Result<Option<Query>> {
266		let mut synth = QuerySynthesizer::default();
267		let mut res: Option<Query> = None;
268		while res.is_none() {
269			let Some(msg_chunks) = self.recv_raw().await? else {
270				return Ok(None);
271			};
272			res = synth.add(msg_chunks.into_iter())?;
273		}
274		Ok(res)
275	}
276
277	async fn handle_session_fallible<P>(&mut self, plugin: Arc<P>) -> crate::error::Result<()>
278	where
279		P: Plugin,
280	{
281		let Some(query) = self.recv().await? else {
282			return Err(Error::SessionChannelClosed);
283		};
284
285		if query.direction == QueryDirection::Response {
286			return Err(Error::ReceivedReplyWhenExpectingRequest);
287		}
288
289		let name = query.query;
290
291		// Per RFD 0009, there should only be one query key per query
292		if query.key.len() != 1 {
293			return Err(Error::UnspecifiedQueryState);
294		}
295		let key = query.key.first().unwrap().clone();
296
297		// if we find the plugin by name, run it
298		// if not, check if there is a default plugin and run that one
299		// otherwise error out
300		let query = plugin
301			.queries()
302			.filter_map(|x| if x.name == name { Some(x.inner) } else { None })
303			.next()
304			.or_else(|| plugin.default_query())
305			.ok_or_else(|| {
306				if name.is_empty() {
307					Error::NoDefaultQuery
308				} else {
309					Error::UnknownPluginQuery(name.clone().into_boxed_str())
310				}
311			})?;
312
313		#[cfg(feature = "print-timings")]
314		let _0 = crate::benchmarking::print_scope_time!(format!("{}/{}", P::NAME, name));
315
316		let value = query.run(self, key).await?;
317
318		#[cfg(feature = "print-timings")]
319		drop(_0);
320
321		let query = Query {
322			id: self.id(),
323			direction: QueryDirection::Response,
324			publisher: P::PUBLISHER.to_owned(),
325			plugin: P::NAME.to_owned(),
326			query: name.to_owned(),
327			key: vec![],
328			output: vec![value],
329			concerns: self.take_concerns(),
330		};
331
332		self.send(query).await
333	}
334
335	async fn handle_session<P>(&mut self, plugin: Arc<P>)
336	where
337		P: Plugin,
338	{
339		if let Err(e) = self.handle_session_fallible(plugin).await {
340			let res_err_send = match e {
341				Error::FailedToSendQueryFromSessionToServer(_) => {
342					tracing::error!("Failed to send message to Hipcheck core, analysis will hang.");
343					return;
344				}
345				other => {
346					tracing::error!("{}", other);
347					self.send_session_err::<P>().await
348				}
349			};
350			if res_err_send.is_err() {
351				tracing::error!("Failed to send message to Hipcheck core, analysis will hang.");
352			}
353		}
354	}
355
356	/// Records a string-like concern that will be emitted in the final Hipcheck report. Intended
357	/// for use within a `Query` trait impl.
358	pub fn record_concern<S: AsRef<str>>(&mut self, concern: S) {
359		fn inner(engine: &mut PluginEngine, concern: &str) {
360			engine.concerns.push(concern.to_owned());
361		}
362		inner(self, concern.as_ref())
363	}
364
365	#[cfg(feature = "mock_engine")]
366	#[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
367	/// Exposes the current set of concerns recorded by `PluginEngine`
368	pub fn get_concerns(&self) -> &[String] {
369		&self.concerns
370	}
371
372	fn take_concerns(&mut self) -> Vec<String> {
373		self.concerns.drain(..).collect()
374	}
375}
376
#[cfg(feature = "mock_engine")]
#[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
impl From<MockResponses> for PluginEngine {
	/// Build a detached engine that answers queries from `responses`.
	///
	/// Each channel's opposite end is dropped immediately, so the engine is not wired to any
	/// Hipcheck core session.
	fn from(responses: MockResponses) -> Self {
		Self {
			id: 0,
			concerns: Vec::new(),
			tx: mpsc::channel(1).0,
			rx: mpsc::channel(1).1,
			drop_tx: mpsc::channel(1).0,
			mock_responses: responses,
		}
	}
}

396impl Drop for PluginEngine {
397	// Notify to have self removed from session tracker
398	fn drop(&mut self) {
399		if cfg!(feature = "mock_engine") {
400			// "use" drop_tx to prevent 'unused' warning. Less messy than trying to gate the
401			// existence of "drop_tx" var itself.
402			let _ = self.drop_tx.max_capacity();
403		} else {
404			while let Err(e) = self.drop_tx.try_send(self.id as i32) {
405				match e {
406					TrySendError::Closed(_) => {
407						break;
408					}
409					TrySendError::Full(_) => (),
410				}
411			}
412		}
413	}
414}
415
416type PluginQueryStream = Box<
417	dyn Stream<Item = StdResult<InitiateQueryProtocolRequest, Status>> + Send + Unpin + 'static,
418>;
419
420pub(crate) struct HcSessionSocket {
421	tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
422	rx: PluginQueryStream,
423	drop_tx: mpsc::Sender<i32>,
424	drop_rx: mpsc::Receiver<i32>,
425	sessions: SessionTracker,
426}
427
428// This is implemented manually since the stream trait object
429// can't impl `Debug`.
430impl std::fmt::Debug for HcSessionSocket {
431	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432		f.debug_struct("HcSessionSocket")
433			.field("tx", &self.tx)
434			.field("rx", &"<rx>")
435			.field("drop_tx", &self.drop_tx)
436			.field("drop_rx", &self.drop_rx)
437			.field("sessions", &self.sessions)
438			.finish()
439	}
440}
441
442impl HcSessionSocket {
443	pub(crate) fn new(
444		tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
445		rx: impl Stream<Item = StdResult<InitiateQueryProtocolRequest, Status>> + Send + Unpin + 'static,
446	) -> Self {
447		// channel for QuerySession objects to notify us they dropped
448		// TODO: make this configurable
449		let (drop_tx, drop_rx) = mpsc::channel(10);
450		Self {
451			tx,
452			rx: Box::new(rx),
453			drop_tx,
454			drop_rx,
455			sessions: HashMap::new(),
456		}
457	}
458
459	/// Clean up completed sessions by going through all drop messages.
460	fn cleanup_sessions(&mut self) {
461		while let Ok(id) = self.drop_rx.try_recv() {
462			match self.sessions.remove(&id) {
463				Some(_) => tracing::trace!("Cleaned up session {id}"),
464				None => {
465					tracing::warn!(
466						"HcSessionSocket got request to drop a session that does not exist"
467					)
468				}
469			}
470		}
471	}
472
473	async fn message(&mut self) -> StdResult<Option<PluginQuery>, Status> {
474		let fut = poll_fn(|cx| Pin::new(&mut *self.rx).poll_next(cx));
475
476		match fut.await {
477			Some(Ok(m)) => Ok(m.query),
478			Some(Err(e)) => Err(e),
479			None => Ok(None),
480		}
481	}
482
483	pub(crate) async fn listen(&mut self) -> Result<Option<PluginEngine>> {
484		loop {
485			let Some(raw) = self.message().await.map_err(Error::from)? else {
486				return Ok(None);
487			};
488			let id = raw.id;
489
490			// While we were waiting for a message, some session objects may have
491			// dropped, handle them before we look at the ID of this message.
492			// The downside of this strategy is that once we receive our last message,
493			// we won't clean up any sessions that close after
494			self.cleanup_sessions();
495
496			match self.decide_action(&raw) {
497				Ok(HandleAction::ForwardMsgToExistingSession(tx)) => {
498					tracing::trace!("SDK: forwarding message to session {id}");
499
500					if let Err(_e) = tx.send(Some(raw)).await {
501						tracing::error!("Error forwarding msg to session {id}");
502						self.sessions.remove(&id);
503					};
504				}
505				Ok(HandleAction::CreateSession) => {
506					tracing::trace!("SDK: creating new session {id}");
507
508					let (in_tx, rx) = mpsc::channel::<Option<PluginQuery>>(10);
509					let tx = self.tx.clone();
510
511					let session = PluginEngine {
512						id: id as usize,
513						concerns: vec![],
514						tx,
515						rx,
516						drop_tx: self.drop_tx.clone(),
517						mock_responses: MockResponses::new(),
518					};
519
520					in_tx.send(Some(raw)).await.expect(
521						"Failed sending message to newly created Session, should never happen",
522					);
523
524					tracing::trace!("SDK: adding new session {id} to tracker");
525					self.sessions.insert(id, in_tx);
526
527					return Ok(Some(session));
528				}
529				Err(e) => tracing::error!("{}", e),
530			}
531		}
532	}
533
534	fn decide_action(&mut self, query: &PluginQuery) -> Result<HandleAction<'_>> {
535		if let Some(tx) = self.sessions.get_mut(&query.id) {
536			return Ok(HandleAction::ForwardMsgToExistingSession(tx));
537		}
538
539		if [QueryState::SubmitInProgress, QueryState::SubmitComplete].contains(&query.state()) {
540			return Ok(HandleAction::CreateSession);
541		}
542
543		Err(Error::ReceivedReplyWhenExpectingRequest)
544	}
545
546	pub(crate) async fn run<P>(&mut self, plugin: Arc<P>) -> Result<()>
547	where
548		P: Plugin,
549	{
550		loop {
551			let Some(mut engine) = self
552				.listen()
553				.await
554				.map_err(|_| Error::SessionChannelClosed)?
555			else {
556				tracing::trace!("Channel closed by remote");
557				break;
558			};
559
560			let cloned_plugin = plugin.clone();
561			tokio::spawn(async move {
562				engine.handle_session(cloned_plugin).await;
563			});
564		}
565
566		Ok(())
567	}
568}
569
570enum HandleAction<'s> {
571	ForwardMsgToExistingSession(&'s mut mpsc::Sender<Option<PluginQuery>>),
572	CreateSession,
573}
574
575/// A map of query endpoints to mock return values.
576///
577/// When using the `mock_engine` feature, calling `PluginEngine::query()` will cause this
578/// structure to be referenced instead of trying to communicate with Hipcheck core. Allows
579/// constructing a `PluginEngine` with which to write unit tests.
580#[derive(Default, Debug)]
581pub struct MockResponses(pub(crate) HashMap<(QueryTarget, JsonValue), Result<JsonValue>>);
582
583impl MockResponses {
584	pub fn new() -> Self {
585		Self(HashMap::new())
586	}
587}
588
589impl MockResponses {
590	#[cfg(feature = "mock_engine")]
591	pub fn insert<T, V, W>(
592		&mut self,
593		query_target: T,
594		query_value: V,
595		query_response: Result<W>,
596	) -> Result<()>
597	where
598		T: TryInto<QueryTarget, Error: Into<crate::Error>>,
599		V: serde::Serialize,
600		W: serde::Serialize,
601	{
602		let query_target: QueryTarget = query_target.try_into().map_err(|e| e.into())?;
603		let query_value: JsonValue = serde_json::to_value(query_value)
604			.map_err(|source| crate::Error::InvalidJsonInQueryKey(Box::new(source)))?;
605		let query_response = match query_response {
606			Ok(v) => serde_json::to_value(v)
607				.map_err(|source| crate::Error::InvalidJsonInQueryKey(Box::new(source))),
608			Err(e) => Err(e),
609		};
610		self.0.insert((query_target, query_value), query_response);
611		Ok(())
612	}
613}
614
#[cfg(test)]
mod test {
	use super::*;

	#[cfg(feature = "mock_engine")]
	#[tokio::test]
	async fn test_query_builder() {
		let mut mocks = MockResponses::new();
		for (key, value) in [("abcd", 1234), ("efgh", 5678)] {
			mocks.insert("mitre/foo", key, Ok(value)).unwrap();
		}

		let mut engine = PluginEngine::mock(mocks);
		let mut builder = engine.batch("mitre/foo").unwrap();
		assert_eq!(builder.query("abcd".into()), 0);
		assert_eq!(builder.query("efgh".into()), 1);

		let response = builder.send().await.unwrap();
		assert_eq!(response[0], JsonValue::from(1234_i32));
		assert_eq!(response[1], JsonValue::from(5678_i32));
	}
}