hipcheck_sdk/
engine.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use crate::{
4	error::{Error, Result},
5	JsonValue, Plugin, QueryTarget,
6};
7use futures::Stream;
8use hipcheck_common::proto::{
9	self, InitiateQueryProtocolRequest, InitiateQueryProtocolResponse, Query as PluginQuery,
10	QueryState,
11};
12use hipcheck_common::{
13	chunk::QuerySynthesizer,
14	types::{Query, QueryDirection},
15};
16use serde::Serialize;
17use std::{
18	collections::{HashMap, VecDeque},
19	future::poll_fn,
20	pin::Pin,
21	result::Result as StdResult,
22	sync::Arc,
23};
24use tokio::sync::mpsc::{self, error::TrySendError};
25use tonic::Status;
26
27impl From<Status> for Error {
28	fn from(_value: Status) -> Error {
29		// TODO: higher-fidelity handling?
30		Error::SessionChannelClosed
31	}
32}
33
34type SessionTracker = HashMap<i32, mpsc::Sender<Option<PluginQuery>>>;
35
36/// Used for building a up a `Vec` of keys to send to specific hipcheck plugin
37pub struct QueryBuilder<'engine> {
38	keys: Vec<JsonValue>,
39	target: QueryTarget,
40	plugin_engine: &'engine mut PluginEngine,
41}
42
43impl<'engine> QueryBuilder<'engine> {
44	/// Create a new `QueryBuilder` to dynamically add keys to send to `target` plugin
45	fn new<T>(plugin_engine: &'engine mut PluginEngine, target: T) -> Result<QueryBuilder<'engine>>
46	where
47		T: TryInto<QueryTarget, Error: Into<Error>>,
48	{
49		let target: QueryTarget = target.try_into().map_err(|e| e.into())?;
50		Ok(Self {
51			plugin_engine,
52			target,
53			keys: vec![],
54		})
55	}
56
57	/// Add a key to the internal list of keys to be sent to `target`
58	///
59	/// Returns the index `key` was inserted was inserted to
60	pub fn query(&mut self, key: JsonValue) -> usize {
61		let len = self.keys.len();
62		self.keys.push(key);
63		len
64	}
65
66	/// Send all of the provided keys to `target` plugin endpont and wait for query results
67	pub async fn send(self) -> Result<Vec<JsonValue>> {
68		self.plugin_engine.batch_query(self.target, self.keys).await
69	}
70}
71
72/// Manages a particular query session.
73///
74/// This struct invokes a `Query` trait object, passing a handle to itself to `Query::run()`. This
75/// allows the query logic to request information from other Hipcheck plugins in order to complete.
76pub struct PluginEngine {
77	id: usize,
78	tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
79	rx: mpsc::Receiver<Option<PluginQuery>>,
80	concerns: Vec<String>,
81	// So that we can remove ourselves when we get dropped
82	drop_tx: mpsc::Sender<i32>,
83	// When unit testing, this enables the user to mock plugin responses to various inputs
84	mock_responses: MockResponses,
85}
86
87impl PluginEngine {
88	#[cfg(feature = "mock_engine")]
89	#[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
90	/// Constructor for use in unit tests, `query()` function will reference this map instead of
91	/// trying to connect to Hipcheck core for a response value
92	pub fn mock(mock_responses: MockResponses) -> Self {
93		mock_responses.into()
94	}
95
96	/// Convenience function to expose a `QueryBuilder` to make it convenient to dynamically build
97	/// up queries to plugins and send them off to the `target` plugin, in as few GRPC calls as
98	/// possible
99	pub fn batch<T>(&mut self, target: T) -> Result<QueryBuilder>
100	where
101		T: TryInto<QueryTarget, Error: Into<Error>>,
102	{
103		QueryBuilder::new(self, target)
104	}
105
106	async fn query_inner(
107		&mut self,
108		target: QueryTarget,
109		input: Vec<JsonValue>,
110	) -> Result<Vec<JsonValue>> {
111		// If doing a mock engine, look to the `mock_responses` field for the query answer
112		if cfg!(feature = "mock_engine") {
113			let mut results = Vec::with_capacity(input.len());
114			for i in input {
115				match self.mock_responses.0.get(&(target.clone(), i)) {
116					Some(res) => {
117						match res {
118							Ok(val) => results.push(val.clone()),
119							// TODO: since Error is not Clone, is there a better way to deal with this
120							Err(_) => return Err(Error::UnexpectedPluginQueryInputFormat),
121						}
122					}
123					None => return Err(Error::UnknownPluginQuery),
124				}
125			}
126			Ok(results)
127		}
128		// Normal execution, send messages to hipcheck core to query other plugin
129		else {
130			let query = Query {
131				id: 0,
132				direction: QueryDirection::Request,
133				publisher: target.publisher,
134				plugin: target.plugin,
135				query: target.query.unwrap_or_else(|| "".to_owned()),
136				key: input,
137				output: vec![],
138				concerns: vec![],
139			};
140			self.send(query).await?;
141			let response = self.recv().await?;
142			match response {
143				Some(response) => Ok(response.output),
144				None => Err(Error::SessionChannelClosed),
145			}
146		}
147	}
148
149	/// Query another Hipcheck plugin `target` with key `input`. On success, the JSONified result
150	/// of the query is returned. `target` will often be a string of the format
151	/// `"publisher/plugin[/query]"`, where the bracketed substring is optional if the plugin's
152	/// default query endpoint is desired. `input` must be of a type implementing `serde::Serialize`,
153	pub async fn query<T, V>(&mut self, target: T, input: V) -> Result<JsonValue>
154	where
155		T: TryInto<QueryTarget, Error: Into<Error>>,
156		V: Serialize,
157	{
158		let query_target: QueryTarget = target.try_into().map_err(|e| e.into())?;
159		tracing::trace!("querying {}", query_target.to_string());
160		let input: JsonValue = serde_json::to_value(input).map_err(Error::InvalidJsonInQueryKey)?;
161		// since there input had one value, there will only be one response
162		let mut response = self.query_inner(query_target, vec![input]).await?;
163		Ok(response.pop().unwrap())
164	}
165
166	/// Query another Hipcheck plugin `target` with Vec of `inputs`. On success, the JSONified result
167	/// of the query is returned. `target` will often be a string of the format
168	/// `"publisher/plugin[/query]"`, where the bracketed substring is optional if the plugin's
169	/// default query endpoint is desired. `keys` must be a Vec containing a type which implements `serde::Serialize`,
170	pub async fn batch_query<T, V>(&mut self, target: T, keys: Vec<V>) -> Result<Vec<JsonValue>>
171	where
172		T: TryInto<QueryTarget, Error: Into<Error>>,
173		V: Serialize,
174	{
175		let target: QueryTarget = target.try_into().map_err(|e| e.into())?;
176		tracing::trace!("querying {}", target.to_string());
177		let mut input = Vec::with_capacity(keys.len());
178		for key in keys {
179			let jsonified_key = serde_json::to_value(key).map_err(Error::InvalidJsonInQueryKey)?;
180			input.push(jsonified_key);
181		}
182		self.query_inner(target, input).await
183	}
184
185	fn id(&self) -> usize {
186		self.id
187	}
188
189	async fn recv_raw(&mut self) -> Result<Option<VecDeque<PluginQuery>>> {
190		let mut out = VecDeque::new();
191
192		tracing::trace!("SDK: awaiting raw rx recv");
193
194		let opt_first = self.rx.recv().await.ok_or(Error::SessionChannelClosed)?;
195
196		let Some(first) = opt_first else {
197			// Underlying gRPC channel closed
198			return Ok(None);
199		};
200		out.push_back(first);
201
202		// If more messages in the queue, opportunistically read more
203		loop {
204			match self.rx.try_recv() {
205				Ok(Some(msg)) => {
206					out.push_back(msg);
207				}
208				Ok(None) => {
209					tracing::warn!("None received, gRPC channel closed. we may not close properly if None is not returned again");
210					break;
211				}
212				// Whether empty or disconnected, we return what we have
213				Err(_) => {
214					break;
215				}
216			}
217		}
218
219		Ok(Some(out))
220	}
221
222	// Send a gRPC query from plugin to the hipcheck server
223	async fn send(&self, mut query: Query) -> Result<()> {
224		query.id = self.id(); // incoming id value is just a placeholder
225		let queries = hipcheck_common::chunk::prepare(query)?;
226		for pq in queries {
227			let query = InitiateQueryProtocolResponse { query: Some(pq) };
228			self.tx
229				.send(Ok(query))
230				.await
231				.map_err(Error::FailedToSendQueryFromSessionToServer)?;
232		}
233		Ok(())
234	}
235
236	async fn send_session_err<P>(&mut self) -> crate::error::Result<()>
237	where
238		P: Plugin,
239	{
240		let query = proto::Query {
241			id: self.id() as i32,
242			state: QueryState::Unspecified as i32,
243			publisher_name: P::PUBLISHER.to_owned(),
244			plugin_name: P::NAME.to_owned(),
245			query_name: "".to_owned(),
246			key: vec![],
247			output: vec![],
248			concern: self.take_concerns(),
249			split: false,
250		};
251		self.tx
252			.send(Ok(InitiateQueryProtocolResponse { query: Some(query) }))
253			.await
254			.map_err(Error::FailedToSendQueryFromSessionToServer)
255	}
256
257	async fn recv(&mut self) -> Result<Option<Query>> {
258		let mut synth = QuerySynthesizer::default();
259		let mut res: Option<Query> = None;
260		while res.is_none() {
261			let Some(msg_chunks) = self.recv_raw().await? else {
262				return Ok(None);
263			};
264			res = synth.add(msg_chunks.into_iter())?;
265		}
266		Ok(res)
267	}
268
269	async fn handle_session_fallible<P>(&mut self, plugin: Arc<P>) -> crate::error::Result<()>
270	where
271		P: Plugin,
272	{
273		let Some(query) = self.recv().await? else {
274			return Err(Error::SessionChannelClosed);
275		};
276
277		if query.direction == QueryDirection::Response {
278			return Err(Error::ReceivedReplyWhenExpectingRequest);
279		}
280
281		let name = query.query;
282
283		// Per RFD 0009, there should only be one query key per query
284		if query.key.len() != 1 {
285			return Err(Error::UnspecifiedQueryState);
286		}
287		let key = query.key.first().unwrap().clone();
288
289		// if we find the plugin by name, run it
290		// if not, check if there is a default plugin and run that one
291		// otherwise error out
292		let query = plugin
293			.queries()
294			.filter_map(|x| if x.name == name { Some(x.inner) } else { None })
295			.next()
296			.or_else(|| plugin.default_query())
297			.ok_or_else(|| Error::UnknownPluginQuery)?;
298
299		#[cfg(feature = "print-timings")]
300		let _0 = crate::benchmarking::print_scope_time!(format!("{}/{}", P::NAME, name));
301
302		let value = query.run(self, key).await?;
303
304		#[cfg(feature = "print-timings")]
305		drop(_0);
306
307		let query = Query {
308			id: self.id(),
309			direction: QueryDirection::Response,
310			publisher: P::PUBLISHER.to_owned(),
311			plugin: P::NAME.to_owned(),
312			query: name.to_owned(),
313			key: vec![],
314			output: vec![value],
315			concerns: self.take_concerns(),
316		};
317
318		self.send(query).await
319	}
320
321	async fn handle_session<P>(&mut self, plugin: Arc<P>)
322	where
323		P: Plugin,
324	{
325		use crate::error::Error::*;
326		if let Err(e) = self.handle_session_fallible(plugin).await {
327			let res_err_send = match e {
328				FailedToSendQueryFromSessionToServer(_) => {
329					tracing::error!("Failed to send message to Hipcheck core, analysis will hang.");
330					return;
331				}
332				other => {
333					tracing::error!("{}", other);
334					self.send_session_err::<P>().await
335				}
336			};
337			if res_err_send.is_err() {
338				tracing::error!("Failed to send message to Hipcheck core, analysis will hang.");
339			}
340		}
341	}
342
343	/// Records a string-like concern that will be emitted in the final Hipcheck report. Intended
344	/// for use within a `Query` trait impl.
345	pub fn record_concern<S: AsRef<str>>(&mut self, concern: S) {
346		fn inner(engine: &mut PluginEngine, concern: &str) {
347			engine.concerns.push(concern.to_owned());
348		}
349		inner(self, concern.as_ref())
350	}
351
352	#[cfg(feature = "mock_engine")]
353	#[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
354	/// Exposes the current set of concerns recorded by `PluginEngine`
355	pub fn get_concerns(&self) -> &[String] {
356		&self.concerns
357	}
358
359	fn take_concerns(&mut self) -> Vec<String> {
360		self.concerns.drain(..).collect()
361	}
362}
363
#[cfg(feature = "mock_engine")]
#[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
impl From<MockResponses> for PluginEngine {
	/// Build a `PluginEngine` (session id 0, no concerns) that answers queries from
	/// `responses`.
	///
	/// The opposite ends of all three channels are dropped immediately, so the resulting
	/// engine is not connected to Hipcheck core.
	fn from(responses: MockResponses) -> Self {
		let (drop_tx, _) = mpsc::channel(1);
		let (_, rx) = mpsc::channel(1);
		let (tx, _) = mpsc::channel(1);

		PluginEngine {
			id: 0,
			tx,
			rx,
			concerns: Vec::new(),
			drop_tx,
			mock_responses: responses,
		}
	}
}
382
383impl Drop for PluginEngine {
384	// Notify to have self removed from session tracker
385	fn drop(&mut self) {
386		if cfg!(feature = "mock_engine") {
387			// "use" drop_tx to prevent 'unused' warning. Less messy than trying to gate the
388			// existence of "drop_tx" var itself.
389			let _ = self.drop_tx.max_capacity();
390		} else {
391			while let Err(e) = self.drop_tx.try_send(self.id as i32) {
392				match e {
393					TrySendError::Closed(_) => {
394						break;
395					}
396					TrySendError::Full(_) => (),
397				}
398			}
399		}
400	}
401}
402
403type PluginQueryStream = Box<
404	dyn Stream<Item = StdResult<InitiateQueryProtocolRequest, Status>> + Send + Unpin + 'static,
405>;
406
407pub(crate) struct HcSessionSocket {
408	tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
409	rx: PluginQueryStream,
410	drop_tx: mpsc::Sender<i32>,
411	drop_rx: mpsc::Receiver<i32>,
412	sessions: SessionTracker,
413}
414
415// This is implemented manually since the stream trait object
416// can't impl `Debug`.
417impl std::fmt::Debug for HcSessionSocket {
418	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419		f.debug_struct("HcSessionSocket")
420			.field("tx", &self.tx)
421			.field("rx", &"<rx>")
422			.field("drop_tx", &self.drop_tx)
423			.field("drop_rx", &self.drop_rx)
424			.field("sessions", &self.sessions)
425			.finish()
426	}
427}
428
429impl HcSessionSocket {
430	pub(crate) fn new(
431		tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
432		rx: impl Stream<Item = StdResult<InitiateQueryProtocolRequest, Status>> + Send + Unpin + 'static,
433	) -> Self {
434		// channel for QuerySession objects to notify us they dropped
435		// TODO: make this configurable
436		let (drop_tx, drop_rx) = mpsc::channel(10);
437		Self {
438			tx,
439			rx: Box::new(rx),
440			drop_tx,
441			drop_rx,
442			sessions: HashMap::new(),
443		}
444	}
445
446	/// Clean up completed sessions by going through all drop messages.
447	fn cleanup_sessions(&mut self) {
448		while let Ok(id) = self.drop_rx.try_recv() {
449			match self.sessions.remove(&id) {
450				Some(_) => tracing::trace!("Cleaned up session {id}"),
451				None => {
452					tracing::warn!(
453						"HcSessionSocket got request to drop a session that does not exist"
454					)
455				}
456			}
457		}
458	}
459
460	async fn message(&mut self) -> StdResult<Option<PluginQuery>, Status> {
461		let fut = poll_fn(|cx| Pin::new(&mut *self.rx).poll_next(cx));
462
463		match fut.await {
464			Some(Ok(m)) => Ok(m.query),
465			Some(Err(e)) => Err(e),
466			None => Ok(None),
467		}
468	}
469
470	pub(crate) async fn listen(&mut self) -> Result<Option<PluginEngine>> {
471		loop {
472			let Some(raw) = self.message().await.map_err(Error::from)? else {
473				return Ok(None);
474			};
475			let id = raw.id;
476
477			// While we were waiting for a message, some session objects may have
478			// dropped, handle them before we look at the ID of this message.
479			// The downside of this strategy is that once we receive our last message,
480			// we won't clean up any sessions that close after
481			self.cleanup_sessions();
482
483			match self.decide_action(&raw) {
484				Ok(HandleAction::ForwardMsgToExistingSession(tx)) => {
485					tracing::trace!("SDK: forwarding message to session {id}");
486
487					if let Err(_e) = tx.send(Some(raw)).await {
488						tracing::error!("Error forwarding msg to session {id}");
489						self.sessions.remove(&id);
490					};
491				}
492				Ok(HandleAction::CreateSession) => {
493					tracing::trace!("SDK: creating new session {id}");
494
495					let (in_tx, rx) = mpsc::channel::<Option<PluginQuery>>(10);
496					let tx = self.tx.clone();
497
498					let session = PluginEngine {
499						id: id as usize,
500						concerns: vec![],
501						tx,
502						rx,
503						drop_tx: self.drop_tx.clone(),
504						mock_responses: MockResponses::new(),
505					};
506
507					in_tx.send(Some(raw)).await.expect(
508						"Failed sending message to newly created Session, should never happen",
509					);
510
511					tracing::trace!("SDK: adding new session {id} to tracker");
512					self.sessions.insert(id, in_tx);
513
514					return Ok(Some(session));
515				}
516				Err(e) => tracing::error!("{}", e),
517			}
518		}
519	}
520
521	fn decide_action(&mut self, query: &PluginQuery) -> Result<HandleAction<'_>> {
522		if let Some(tx) = self.sessions.get_mut(&query.id) {
523			return Ok(HandleAction::ForwardMsgToExistingSession(tx));
524		}
525
526		if [QueryState::SubmitInProgress, QueryState::SubmitComplete].contains(&query.state()) {
527			return Ok(HandleAction::CreateSession);
528		}
529
530		Err(Error::ReceivedReplyWhenExpectingRequest)
531	}
532
533	pub(crate) async fn run<P>(&mut self, plugin: Arc<P>) -> Result<()>
534	where
535		P: Plugin,
536	{
537		loop {
538			let Some(mut engine) = self
539				.listen()
540				.await
541				.map_err(|_| Error::SessionChannelClosed)?
542			else {
543				tracing::trace!("Channel closed by remote");
544				break;
545			};
546
547			let cloned_plugin = plugin.clone();
548			tokio::spawn(async move {
549				engine.handle_session(cloned_plugin).await;
550			});
551		}
552
553		Ok(())
554	}
555}
556
557enum HandleAction<'s> {
558	ForwardMsgToExistingSession(&'s mut mpsc::Sender<Option<PluginQuery>>),
559	CreateSession,
560}
561
562/// A map of query endpoints to mock return values.
563///
564/// When using the `mock_engine` feature, calling `PluginEngine::query()` will cause this
565/// structure to be referenced instead of trying to communicate with Hipcheck core. Allows
566/// constructing a `PluginEngine` with which to write unit tests.
567#[derive(Default, Debug)]
568pub struct MockResponses(pub(crate) HashMap<(QueryTarget, JsonValue), Result<JsonValue>>);
569
570impl MockResponses {
571	pub fn new() -> Self {
572		Self(HashMap::new())
573	}
574}
575
576impl MockResponses {
577	#[cfg(feature = "mock_engine")]
578	pub fn insert<T, V, W>(
579		&mut self,
580		query_target: T,
581		query_value: V,
582		query_response: Result<W>,
583	) -> Result<()>
584	where
585		T: TryInto<QueryTarget, Error: Into<crate::Error>>,
586		V: serde::Serialize,
587		W: serde::Serialize,
588	{
589		let query_target: QueryTarget = query_target.try_into().map_err(|e| e.into())?;
590		let query_value: JsonValue =
591			serde_json::to_value(query_value).map_err(crate::Error::InvalidJsonInQueryKey)?;
592		let query_response = match query_response {
593			Ok(v) => serde_json::to_value(v).map_err(crate::Error::InvalidJsonInQueryKey),
594			Err(e) => Err(e),
595		};
596		self.0.insert((query_target, query_value), query_response);
597		Ok(())
598	}
599}
600
601#[cfg(test)]
602mod test {
603	use super::*;
604
605	#[cfg(feature = "mock_engine")]
606	#[tokio::test]
607	async fn test_query_builder() {
608		let mut mock_responses = MockResponses::new();
609		mock_responses
610			.insert("mitre/foo", "abcd", Ok(1234))
611			.unwrap();
612		mock_responses
613			.insert("mitre/foo", "efgh", Ok(5678))
614			.unwrap();
615		let mut engine = PluginEngine::mock(mock_responses);
616		let mut builder = engine.batch("mitre/foo").unwrap();
617		let idx = builder.query("abcd".into());
618		assert_eq!(idx, 0);
619		let idx = builder.query("efgh".into());
620		assert_eq!(idx, 1);
621		let response = builder.send().await.unwrap();
622		assert_eq!(
623			response.first().unwrap(),
624			&<i32 as Into<JsonValue>>::into(1234)
625		);
626		assert_eq!(
627			response.get(1).unwrap(),
628			&<i32 as Into<JsonValue>>::into(5678)
629		);
630	}
631}