1use crate::{
4 error::{Error, Result},
5 JsonValue, Plugin, QueryTarget,
6};
7use futures::Stream;
8use hipcheck_common::proto::{
9 self, InitiateQueryProtocolRequest, InitiateQueryProtocolResponse, Query as PluginQuery,
10 QueryState,
11};
12use hipcheck_common::{
13 chunk::QuerySynthesizer,
14 types::{Query, QueryDirection},
15};
16use serde::Serialize;
17use std::{
18 collections::{HashMap, VecDeque},
19 future::poll_fn,
20 pin::Pin,
21 result::Result as StdResult,
22 sync::Arc,
23};
24use tokio::sync::mpsc::{self, error::TrySendError};
25use tonic::Status;
26
27impl From<Status> for Error {
28 fn from(_value: Status) -> Error {
29 Error::SessionChannelClosed
31 }
32}
33
/// Live sessions known to `HcSessionSocket`, keyed by session id.
///
/// Each value is the sending half of the channel whose receiver is owned by that session's
/// `PluginEngine`. `PluginEngine::recv_raw` reports a received `None` to its caller as
/// `Ok(None)`.
type SessionTracker = HashMap<i32, mpsc::Sender<Option<PluginQuery>>>;
35
/// Accumulates keys for a single [`QueryTarget`] so they can be sent as one batched query.
///
/// Created by [`PluginEngine::batch`]; the builder mutably borrows the engine until
/// [`QueryBuilder::send`] consumes it.
pub struct QueryBuilder<'engine> {
    /// Keys queued so far, in the order `query` was called.
    keys: Vec<JsonValue>,
    /// The plugin query every queued key is sent to.
    target: QueryTarget,
    /// Engine used to send the batch.
    plugin_engine: &'engine mut PluginEngine,
}
42
43impl<'engine> QueryBuilder<'engine> {
44 fn new<T>(plugin_engine: &'engine mut PluginEngine, target: T) -> Result<QueryBuilder<'engine>>
46 where
47 T: TryInto<QueryTarget, Error: Into<Error>>,
48 {
49 let target: QueryTarget = target.try_into().map_err(|e| e.into())?;
50 Ok(Self {
51 plugin_engine,
52 target,
53 keys: vec![],
54 })
55 }
56
57 pub fn query(&mut self, key: JsonValue) -> usize {
61 let len = self.keys.len();
62 self.keys.push(key);
63 len
64 }
65
66 pub async fn send(self) -> Result<Vec<JsonValue>> {
68 self.plugin_engine.batch_query(self.target, self.keys).await
69 }
70}
71
/// Per-session handle through which a plugin's query code exchanges messages with
/// Hipcheck core.
///
/// `HcSessionSocket::listen` creates one for each new session id. It receives that
/// session's messages on `rx` and writes outgoing queries/responses on `tx`.
pub struct PluginEngine {
    /// Session id; `send` stamps it onto every outgoing query.
    id: usize,
    /// Outbound channel, cloned from the owning `HcSessionSocket`.
    tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
    /// Messages the owning `HcSessionSocket` forwards to this session.
    rx: mpsc::Receiver<Option<PluginQuery>>,
    /// Concerns recorded via `record_concern`; drained into the next outgoing response or
    /// error report.
    concerns: Vec<String>,
    /// Used on drop to tell the owning `HcSessionSocket` to forget this session's id.
    drop_tx: mpsc::Sender<i32>,
    /// Canned answers consulted instead of Hipcheck core when the `mock_engine` feature
    /// is enabled.
    mock_responses: MockResponses,
}
86
87impl PluginEngine {
88 #[cfg(feature = "mock_engine")]
89 #[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
90 pub fn mock(mock_responses: MockResponses) -> Self {
93 mock_responses.into()
94 }
95
96 pub fn batch<T>(&mut self, target: T) -> Result<QueryBuilder>
100 where
101 T: TryInto<QueryTarget, Error: Into<Error>>,
102 {
103 QueryBuilder::new(self, target)
104 }
105
106 async fn query_inner(
107 &mut self,
108 target: QueryTarget,
109 input: Vec<JsonValue>,
110 ) -> Result<Vec<JsonValue>> {
111 if cfg!(feature = "mock_engine") {
113 let mut results = Vec::with_capacity(input.len());
114 for i in input {
115 match self.mock_responses.0.get(&(target.clone(), i)) {
116 Some(res) => {
117 match res {
118 Ok(val) => results.push(val.clone()),
119 Err(_) => return Err(Error::UnexpectedPluginQueryInputFormat),
121 }
122 }
123 None => return Err(Error::UnknownPluginQuery),
124 }
125 }
126 Ok(results)
127 }
128 else {
130 let query = Query {
131 id: 0,
132 direction: QueryDirection::Request,
133 publisher: target.publisher,
134 plugin: target.plugin,
135 query: target.query.unwrap_or_else(|| "".to_owned()),
136 key: input,
137 output: vec![],
138 concerns: vec![],
139 };
140 self.send(query).await?;
141 let response = self.recv().await?;
142 match response {
143 Some(response) => Ok(response.output),
144 None => Err(Error::SessionChannelClosed),
145 }
146 }
147 }
148
149 pub async fn query<T, V>(&mut self, target: T, input: V) -> Result<JsonValue>
154 where
155 T: TryInto<QueryTarget, Error: Into<Error>>,
156 V: Serialize,
157 {
158 let query_target: QueryTarget = target.try_into().map_err(|e| e.into())?;
159 tracing::trace!("querying {}", query_target.to_string());
160 let input: JsonValue = serde_json::to_value(input).map_err(Error::InvalidJsonInQueryKey)?;
161 let mut response = self.query_inner(query_target, vec![input]).await?;
163 Ok(response.pop().unwrap())
164 }
165
166 pub async fn batch_query<T, V>(&mut self, target: T, keys: Vec<V>) -> Result<Vec<JsonValue>>
171 where
172 T: TryInto<QueryTarget, Error: Into<Error>>,
173 V: Serialize,
174 {
175 let target: QueryTarget = target.try_into().map_err(|e| e.into())?;
176 tracing::trace!("querying {}", target.to_string());
177 let mut input = Vec::with_capacity(keys.len());
178 for key in keys {
179 let jsonified_key = serde_json::to_value(key).map_err(Error::InvalidJsonInQueryKey)?;
180 input.push(jsonified_key);
181 }
182 self.query_inner(target, input).await
183 }
184
185 fn id(&self) -> usize {
186 self.id
187 }
188
189 async fn recv_raw(&mut self) -> Result<Option<VecDeque<PluginQuery>>> {
190 let mut out = VecDeque::new();
191
192 tracing::trace!("SDK: awaiting raw rx recv");
193
194 let opt_first = self.rx.recv().await.ok_or(Error::SessionChannelClosed)?;
195
196 let Some(first) = opt_first else {
197 return Ok(None);
199 };
200 out.push_back(first);
201
202 loop {
204 match self.rx.try_recv() {
205 Ok(Some(msg)) => {
206 out.push_back(msg);
207 }
208 Ok(None) => {
209 tracing::warn!("None received, gRPC channel closed. we may not close properly if None is not returned again");
210 break;
211 }
212 Err(_) => {
214 break;
215 }
216 }
217 }
218
219 Ok(Some(out))
220 }
221
222 async fn send(&self, mut query: Query) -> Result<()> {
224 query.id = self.id(); let queries = hipcheck_common::chunk::prepare(query)?;
226 for pq in queries {
227 let query = InitiateQueryProtocolResponse { query: Some(pq) };
228 self.tx
229 .send(Ok(query))
230 .await
231 .map_err(Error::FailedToSendQueryFromSessionToServer)?;
232 }
233 Ok(())
234 }
235
236 async fn send_session_err<P>(&mut self) -> crate::error::Result<()>
237 where
238 P: Plugin,
239 {
240 let query = proto::Query {
241 id: self.id() as i32,
242 state: QueryState::Unspecified as i32,
243 publisher_name: P::PUBLISHER.to_owned(),
244 plugin_name: P::NAME.to_owned(),
245 query_name: "".to_owned(),
246 key: vec![],
247 output: vec![],
248 concern: self.take_concerns(),
249 split: false,
250 };
251 self.tx
252 .send(Ok(InitiateQueryProtocolResponse { query: Some(query) }))
253 .await
254 .map_err(Error::FailedToSendQueryFromSessionToServer)
255 }
256
257 async fn recv(&mut self) -> Result<Option<Query>> {
258 let mut synth = QuerySynthesizer::default();
259 let mut res: Option<Query> = None;
260 while res.is_none() {
261 let Some(msg_chunks) = self.recv_raw().await? else {
262 return Ok(None);
263 };
264 res = synth.add(msg_chunks.into_iter())?;
265 }
266 Ok(res)
267 }
268
269 async fn handle_session_fallible<P>(&mut self, plugin: Arc<P>) -> crate::error::Result<()>
270 where
271 P: Plugin,
272 {
273 let Some(query) = self.recv().await? else {
274 return Err(Error::SessionChannelClosed);
275 };
276
277 if query.direction == QueryDirection::Response {
278 return Err(Error::ReceivedReplyWhenExpectingRequest);
279 }
280
281 let name = query.query;
282
283 if query.key.len() != 1 {
285 return Err(Error::UnspecifiedQueryState);
286 }
287 let key = query.key.first().unwrap().clone();
288
289 let query = plugin
293 .queries()
294 .filter_map(|x| if x.name == name { Some(x.inner) } else { None })
295 .next()
296 .or_else(|| plugin.default_query())
297 .ok_or_else(|| Error::UnknownPluginQuery)?;
298
299 #[cfg(feature = "print-timings")]
300 let _0 = crate::benchmarking::print_scope_time!(format!("{}/{}", P::NAME, name));
301
302 let value = query.run(self, key).await?;
303
304 #[cfg(feature = "print-timings")]
305 drop(_0);
306
307 let query = Query {
308 id: self.id(),
309 direction: QueryDirection::Response,
310 publisher: P::PUBLISHER.to_owned(),
311 plugin: P::NAME.to_owned(),
312 query: name.to_owned(),
313 key: vec![],
314 output: vec![value],
315 concerns: self.take_concerns(),
316 };
317
318 self.send(query).await
319 }
320
321 async fn handle_session<P>(&mut self, plugin: Arc<P>)
322 where
323 P: Plugin,
324 {
325 use crate::error::Error::*;
326 if let Err(e) = self.handle_session_fallible(plugin).await {
327 let res_err_send = match e {
328 FailedToSendQueryFromSessionToServer(_) => {
329 tracing::error!("Failed to send message to Hipcheck core, analysis will hang.");
330 return;
331 }
332 other => {
333 tracing::error!("{}", other);
334 self.send_session_err::<P>().await
335 }
336 };
337 if res_err_send.is_err() {
338 tracing::error!("Failed to send message to Hipcheck core, analysis will hang.");
339 }
340 }
341 }
342
343 pub fn record_concern<S: AsRef<str>>(&mut self, concern: S) {
346 fn inner(engine: &mut PluginEngine, concern: &str) {
347 engine.concerns.push(concern.to_owned());
348 }
349 inner(self, concern.as_ref())
350 }
351
352 #[cfg(feature = "mock_engine")]
353 #[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
354 pub fn get_concerns(&self) -> &[String] {
356 &self.concerns
357 }
358
359 fn take_concerns(&mut self) -> Vec<String> {
360 self.concerns.drain(..).collect()
361 }
362}
363
#[cfg(feature = "mock_engine")]
#[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
impl From<MockResponses> for PluginEngine {
    /// Wrap `mock_responses` in an engine with id 0 whose channels are all disconnected.
    ///
    /// With the `mock_engine` feature on, `query_inner` answers from the mocked responses
    /// and never touches the channels.
    fn from(mock_responses: MockResponses) -> Self {
        // The unused half of each channel is discarded immediately (`_`), leaving every
        // endpoint disconnected.
        let (drop_tx, _) = mpsc::channel(1);
        let (_, rx) = mpsc::channel(1);
        let (tx, _) = mpsc::channel(1);

        PluginEngine {
            mock_responses,
            drop_tx,
            rx,
            tx,
            concerns: Vec::new(),
            id: 0,
        }
    }
}
382
383impl Drop for PluginEngine {
384 fn drop(&mut self) {
386 if cfg!(feature = "mock_engine") {
387 let _ = self.drop_tx.max_capacity();
390 } else {
391 while let Err(e) = self.drop_tx.try_send(self.id as i32) {
392 match e {
393 TrySendError::Closed(_) => {
394 break;
395 }
396 TrySendError::Full(_) => (),
397 }
398 }
399 }
400 }
401}
402
/// Boxed stream of inbound plugin-protocol requests (or gRPC errors) read by
/// `HcSessionSocket`.
type PluginQueryStream = Box<
    dyn Stream<Item = StdResult<InitiateQueryProtocolRequest, Status>> + Send + Unpin + 'static,
>;
406
/// Routes messages from the plugin's single inbound request stream to per-session
/// `PluginEngine`s, keyed by session id.
pub(crate) struct HcSessionSocket {
    /// Outbound channel; a clone is given to every session's `PluginEngine`.
    tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
    /// Inbound requests.
    rx: PluginQueryStream,
    /// Cloned into each session so it can report its id when dropped.
    drop_tx: mpsc::Sender<i32>,
    /// Ids of dropped sessions, drained by `cleanup_sessions`.
    drop_rx: mpsc::Receiver<i32>,
    /// Live sessions, keyed by session id.
    sessions: SessionTracker,
}
414
415impl std::fmt::Debug for HcSessionSocket {
418 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419 f.debug_struct("HcSessionSocket")
420 .field("tx", &self.tx)
421 .field("rx", &"<rx>")
422 .field("drop_tx", &self.drop_tx)
423 .field("drop_rx", &self.drop_rx)
424 .field("sessions", &self.sessions)
425 .finish()
426 }
427}
428
429impl HcSessionSocket {
430 pub(crate) fn new(
431 tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
432 rx: impl Stream<Item = StdResult<InitiateQueryProtocolRequest, Status>> + Send + Unpin + 'static,
433 ) -> Self {
434 let (drop_tx, drop_rx) = mpsc::channel(10);
437 Self {
438 tx,
439 rx: Box::new(rx),
440 drop_tx,
441 drop_rx,
442 sessions: HashMap::new(),
443 }
444 }
445
446 fn cleanup_sessions(&mut self) {
448 while let Ok(id) = self.drop_rx.try_recv() {
449 match self.sessions.remove(&id) {
450 Some(_) => tracing::trace!("Cleaned up session {id}"),
451 None => {
452 tracing::warn!(
453 "HcSessionSocket got request to drop a session that does not exist"
454 )
455 }
456 }
457 }
458 }
459
460 async fn message(&mut self) -> StdResult<Option<PluginQuery>, Status> {
461 let fut = poll_fn(|cx| Pin::new(&mut *self.rx).poll_next(cx));
462
463 match fut.await {
464 Some(Ok(m)) => Ok(m.query),
465 Some(Err(e)) => Err(e),
466 None => Ok(None),
467 }
468 }
469
470 pub(crate) async fn listen(&mut self) -> Result<Option<PluginEngine>> {
471 loop {
472 let Some(raw) = self.message().await.map_err(Error::from)? else {
473 return Ok(None);
474 };
475 let id = raw.id;
476
477 self.cleanup_sessions();
482
483 match self.decide_action(&raw) {
484 Ok(HandleAction::ForwardMsgToExistingSession(tx)) => {
485 tracing::trace!("SDK: forwarding message to session {id}");
486
487 if let Err(_e) = tx.send(Some(raw)).await {
488 tracing::error!("Error forwarding msg to session {id}");
489 self.sessions.remove(&id);
490 };
491 }
492 Ok(HandleAction::CreateSession) => {
493 tracing::trace!("SDK: creating new session {id}");
494
495 let (in_tx, rx) = mpsc::channel::<Option<PluginQuery>>(10);
496 let tx = self.tx.clone();
497
498 let session = PluginEngine {
499 id: id as usize,
500 concerns: vec![],
501 tx,
502 rx,
503 drop_tx: self.drop_tx.clone(),
504 mock_responses: MockResponses::new(),
505 };
506
507 in_tx.send(Some(raw)).await.expect(
508 "Failed sending message to newly created Session, should never happen",
509 );
510
511 tracing::trace!("SDK: adding new session {id} to tracker");
512 self.sessions.insert(id, in_tx);
513
514 return Ok(Some(session));
515 }
516 Err(e) => tracing::error!("{}", e),
517 }
518 }
519 }
520
521 fn decide_action(&mut self, query: &PluginQuery) -> Result<HandleAction<'_>> {
522 if let Some(tx) = self.sessions.get_mut(&query.id) {
523 return Ok(HandleAction::ForwardMsgToExistingSession(tx));
524 }
525
526 if [QueryState::SubmitInProgress, QueryState::SubmitComplete].contains(&query.state()) {
527 return Ok(HandleAction::CreateSession);
528 }
529
530 Err(Error::ReceivedReplyWhenExpectingRequest)
531 }
532
533 pub(crate) async fn run<P>(&mut self, plugin: Arc<P>) -> Result<()>
534 where
535 P: Plugin,
536 {
537 loop {
538 let Some(mut engine) = self
539 .listen()
540 .await
541 .map_err(|_| Error::SessionChannelClosed)?
542 else {
543 tracing::trace!("Channel closed by remote");
544 break;
545 };
546
547 let cloned_plugin = plugin.clone();
548 tokio::spawn(async move {
549 engine.handle_session(cloned_plugin).await;
550 });
551 }
552
553 Ok(())
554 }
555}
556
/// What `HcSessionSocket::listen` should do with an inbound message.
enum HandleAction<'s> {
    /// The id belongs to a tracked session; forward the message through this sender.
    ForwardMsgToExistingSession(&'s mut mpsc::Sender<Option<PluginQuery>>),
    /// The id is unknown and the message is in a submit state; start a new session.
    CreateSession,
}
561
/// Canned answers for a mock `PluginEngine`, keyed by (query target, JSON-encoded key).
///
/// A stored `Err` makes the mock engine fail that lookup with
/// `Error::UnexpectedPluginQueryInputFormat`, since the original error cannot be cloned out.
#[derive(Default, Debug)]
pub struct MockResponses(pub(crate) HashMap<(QueryTarget, JsonValue), Result<JsonValue>>);
569
570impl MockResponses {
571 pub fn new() -> Self {
572 Self(HashMap::new())
573 }
574}
575
576impl MockResponses {
577 #[cfg(feature = "mock_engine")]
578 pub fn insert<T, V, W>(
579 &mut self,
580 query_target: T,
581 query_value: V,
582 query_response: Result<W>,
583 ) -> Result<()>
584 where
585 T: TryInto<QueryTarget, Error: Into<crate::Error>>,
586 V: serde::Serialize,
587 W: serde::Serialize,
588 {
589 let query_target: QueryTarget = query_target.try_into().map_err(|e| e.into())?;
590 let query_value: JsonValue =
591 serde_json::to_value(query_value).map_err(crate::Error::InvalidJsonInQueryKey)?;
592 let query_response = match query_response {
593 Ok(v) => serde_json::to_value(v).map_err(crate::Error::InvalidJsonInQueryKey),
594 Err(e) => Err(e),
595 };
596 self.0.insert((query_target, query_value), query_response);
597 Ok(())
598 }
599}
600
#[cfg(test)]
mod test {
    use super::*;

    /// A batch built from two mocked keys returns their answers in the order the keys
    /// were queued, and `query` reports each key's position.
    #[cfg(feature = "mock_engine")]
    #[tokio::test]
    async fn test_query_builder() {
        let mut mocks = MockResponses::new();
        for (key, answer) in [("abcd", 1234), ("efgh", 5678)] {
            mocks.insert("mitre/foo", key, Ok(answer)).unwrap();
        }

        let mut engine = PluginEngine::mock(mocks);
        let mut builder = engine.batch("mitre/foo").unwrap();
        assert_eq!(builder.query("abcd".into()), 0);
        assert_eq!(builder.query("efgh".into()), 1);

        let response = builder.send().await.unwrap();
        assert_eq!(response.first(), Some(&JsonValue::from(1234)));
        assert_eq!(response.get(1), Some(&JsonValue::from(5678)));
    }
}