1use crate::{
4 JsonValue, Plugin, QueryTarget,
5 error::{Error, Result},
6};
7use futures::Stream;
8use hipcheck_common::proto::{
9 self, InitiateQueryProtocolRequest, InitiateQueryProtocolResponse, Query as PluginQuery,
10 QueryState,
11};
12use hipcheck_common::{
13 chunk::QuerySynthesizer,
14 types::{Query, QueryDirection},
15};
16use serde::Serialize;
17use std::{
18 collections::{HashMap, VecDeque},
19 future::poll_fn,
20 pin::Pin,
21 result::Result as StdResult,
22 sync::Arc,
23};
24use tokio::sync::mpsc::{self, error::TrySendError};
25use tonic::Status;
26
27impl From<Status> for Error {
28 fn from(_value: Status) -> Error {
29 Error::SessionChannelClosed
31 }
32}
33
/// Maps a session ID to the sender half of the channel used to forward
/// incoming messages to that session.
type SessionTracker = HashMap<i32, mpsc::Sender<Option<PluginQuery>>>;
35
/// Accumulates query keys for a single [`QueryTarget`] so that they can be
/// submitted together through [`PluginEngine::batch_query`].
pub struct QueryBuilder<'engine> {
    /// Keys queued so far, in insertion order.
    keys: Vec<JsonValue>,
    /// The target every queued key is sent to.
    target: QueryTarget,
    /// Engine used to submit the batch when `send` is called.
    plugin_engine: &'engine mut PluginEngine,
}
42
43impl<'engine> QueryBuilder<'engine> {
44 fn new<T>(plugin_engine: &'engine mut PluginEngine, target: T) -> Result<QueryBuilder<'engine>>
46 where
47 T: TryInto<QueryTarget, Error: Into<Error>>,
48 {
49 let target: QueryTarget = target.try_into().map_err(|e| e.into())?;
50 Ok(Self {
51 plugin_engine,
52 target,
53 keys: vec![],
54 })
55 }
56
57 pub fn query(&mut self, key: JsonValue) -> usize {
61 let len = self.keys.len();
62 self.keys.push(key);
63 len
64 }
65
66 pub async fn send(self) -> Result<Vec<JsonValue>> {
68 self.plugin_engine.batch_query(self.target, self.keys).await
69 }
70}
71
/// Per-session handle through which a plugin sends queries, receives replies
/// and records concerns.
pub struct PluginEngine {
    /// Session ID; `send` stamps it onto every outgoing query.
    id: usize,
    /// Outgoing channel carrying query protocol responses.
    tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
    /// Incoming channel of raw query messages for this session.
    rx: mpsc::Receiver<Option<PluginQuery>>,
    /// Concerns added by `record_concern` and not yet removed by `take_concerns`.
    concerns: Vec<String>,
    /// Channel on which the `Drop` impl reports this session's ID
    /// (not used for that purpose when the `mock_engine` feature is enabled).
    drop_tx: mpsc::Sender<i32>,
    /// Canned responses consulted by `query_inner` when the `mock_engine`
    /// feature is enabled.
    mock_responses: MockResponses,
}
86
87impl PluginEngine {
/// Builds an engine backed by `mock_responses`.
///
/// Delegates to the `From<MockResponses>` conversion for `PluginEngine`.
#[cfg(feature = "mock_engine")]
#[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
pub fn mock(mock_responses: MockResponses) -> Self {
    Self::from(mock_responses)
}
95
/// Starts a [`QueryBuilder`] that collects keys for `target`.
///
/// # Errors
///
/// Returns the converted error if `target` cannot be turned into a
/// [`QueryTarget`].
pub fn batch<T>(&mut self, target: T) -> Result<QueryBuilder<'_>>
where
    T: TryInto<QueryTarget, Error: Into<Error>>,
{
    QueryBuilder::new(self, target)
}
105
106 async fn query_inner(
107 &mut self,
108 target: QueryTarget,
109 input: Vec<JsonValue>,
110 ) -> Result<Vec<JsonValue>> {
111 if cfg!(feature = "mock_engine") {
113 let mut results = Vec::with_capacity(input.len());
114 for i in input {
115 match self.mock_responses.0.get(&(target.clone(), i)) {
116 Some(res) => match res {
117 Ok(val) => results.push(val.clone()),
118 Err(e) => {
119 tracing::error!("Error parsing mock_engine response: {e}");
120 return Err(Error::UnexpectedPluginQueryInputFormat);
121 }
122 },
123 None => {
124 return Err(Error::UnknownPluginQuery(
125 target.to_string().into_boxed_str(),
126 ));
127 }
128 }
129 }
130 Ok(results)
131 }
132 else {
134 let query = Query {
135 id: 0,
136 direction: QueryDirection::Request,
137 publisher: target.publisher,
138 plugin: target.plugin,
139 query: target.query.unwrap_or_else(|| "".to_owned()),
140 key: input,
141 output: vec![],
142 concerns: vec![],
143 };
144 self.send(query).await?;
145 let response = self.recv().await?;
146 match response {
147 Some(response) => Ok(response.output),
148 None => Err(Error::SessionChannelClosed),
149 }
150 }
151 }
152
153 pub async fn query<T, V>(&mut self, target: T, input: V) -> Result<JsonValue>
158 where
159 T: TryInto<QueryTarget, Error: Into<Error>>,
160 V: Serialize,
161 {
162 let query_target: QueryTarget = target.try_into().map_err(|e| e.into())?;
163 tracing::trace!("querying {}", query_target.to_string());
164 let input: JsonValue = serde_json::to_value(input)
165 .map_err(|source| Error::InvalidJsonInQueryKey(Box::new(source)))?;
166 let mut response = self.query_inner(query_target, vec![input]).await?;
168 Ok(response.pop().unwrap())
169 }
170
171 pub async fn batch_query<T, V>(&mut self, target: T, keys: Vec<V>) -> Result<Vec<JsonValue>>
176 where
177 T: TryInto<QueryTarget, Error: Into<Error>>,
178 V: Serialize,
179 {
180 let target: QueryTarget = target.try_into().map_err(|e| e.into())?;
181 tracing::trace!("querying {}", target.to_string());
182 let mut input = Vec::with_capacity(keys.len());
183 for key in keys {
184 let jsonified_key = serde_json::to_value(key)
185 .map_err(|source| Error::InvalidJsonInQueryKey(Box::new(source)))?;
186 input.push(jsonified_key);
187 }
188 self.query_inner(target, input).await
189 }
190
/// Returns this session's ID.
fn id(&self) -> usize {
    self.id
}
194
195 async fn recv_raw(&mut self) -> Result<Option<VecDeque<PluginQuery>>> {
196 let mut out = VecDeque::new();
197
198 tracing::trace!("SDK: awaiting raw rx recv");
199
200 let opt_first = self.rx.recv().await.ok_or(Error::SessionChannelClosed)?;
201
202 let Some(first) = opt_first else {
203 return Ok(None);
205 };
206 out.push_back(first);
207
208 loop {
210 match self.rx.try_recv() {
211 Ok(Some(msg)) => {
212 out.push_back(msg);
213 }
214 Ok(None) => {
215 tracing::warn!(
216 "None received, gRPC channel closed. we may not close properly if None is not returned again"
217 );
218 break;
219 }
220 Err(_) => {
222 break;
223 }
224 }
225 }
226
227 Ok(Some(out))
228 }
229
230 async fn send(&self, mut query: Query) -> Result<()> {
232 query.id = self.id(); let queries = hipcheck_common::chunk::prepare(query)?;
234 for pq in queries {
235 let query = InitiateQueryProtocolResponse { query: Some(pq) };
236 self.tx
237 .send(Ok(query))
238 .await
239 .map_err(|source| Error::FailedToSendQueryFromSessionToServer(Box::new(source)))?;
240 }
241 Ok(())
242 }
243
244 async fn send_session_err<P>(&mut self) -> crate::error::Result<()>
245 where
246 P: Plugin,
247 {
248 let query = proto::Query {
249 id: self.id() as i32,
250 state: QueryState::Unspecified as i32,
251 publisher_name: P::PUBLISHER.to_owned(),
252 plugin_name: P::NAME.to_owned(),
253 query_name: "".to_owned(),
254 key: vec![],
255 output: vec![],
256 concern: self.take_concerns(),
257 split: false,
258 };
259 self.tx
260 .send(Ok(InitiateQueryProtocolResponse { query: Some(query) }))
261 .await
262 .map_err(|source| Error::FailedToSendQueryFromSessionToServer(Box::new(source)))
263 }
264
265 async fn recv(&mut self) -> Result<Option<Query>> {
266 let mut synth = QuerySynthesizer::default();
267 let mut res: Option<Query> = None;
268 while res.is_none() {
269 let Some(msg_chunks) = self.recv_raw().await? else {
270 return Ok(None);
271 };
272 res = synth.add(msg_chunks.into_iter())?;
273 }
274 Ok(res)
275 }
276
277 async fn handle_session_fallible<P>(&mut self, plugin: Arc<P>) -> crate::error::Result<()>
278 where
279 P: Plugin,
280 {
281 let Some(query) = self.recv().await? else {
282 return Err(Error::SessionChannelClosed);
283 };
284
285 if query.direction == QueryDirection::Response {
286 return Err(Error::ReceivedReplyWhenExpectingRequest);
287 }
288
289 let name = query.query;
290
291 if query.key.len() != 1 {
293 return Err(Error::UnspecifiedQueryState);
294 }
295 let key = query.key.first().unwrap().clone();
296
297 let query = plugin
301 .queries()
302 .filter_map(|x| if x.name == name { Some(x.inner) } else { None })
303 .next()
304 .or_else(|| plugin.default_query())
305 .ok_or_else(|| {
306 if name.is_empty() {
307 Error::NoDefaultQuery
308 } else {
309 Error::UnknownPluginQuery(name.clone().into_boxed_str())
310 }
311 })?;
312
313 #[cfg(feature = "print-timings")]
314 let _0 = crate::benchmarking::print_scope_time!(format!("{}/{}", P::NAME, name));
315
316 let value = query.run(self, key).await?;
317
318 #[cfg(feature = "print-timings")]
319 drop(_0);
320
321 let query = Query {
322 id: self.id(),
323 direction: QueryDirection::Response,
324 publisher: P::PUBLISHER.to_owned(),
325 plugin: P::NAME.to_owned(),
326 query: name.to_owned(),
327 key: vec![],
328 output: vec![value],
329 concerns: self.take_concerns(),
330 };
331
332 self.send(query).await
333 }
334
335 async fn handle_session<P>(&mut self, plugin: Arc<P>)
336 where
337 P: Plugin,
338 {
339 if let Err(e) = self.handle_session_fallible(plugin).await {
340 let res_err_send = match e {
341 Error::FailedToSendQueryFromSessionToServer(_) => {
342 tracing::error!("Failed to send message to Hipcheck core, analysis will hang.");
343 return;
344 }
345 other => {
346 tracing::error!("{}", other);
347 self.send_session_err::<P>().await
348 }
349 };
350 if res_err_send.is_err() {
351 tracing::error!("Failed to send message to Hipcheck core, analysis will hang.");
352 }
353 }
354 }
355
356 pub fn record_concern<S: AsRef<str>>(&mut self, concern: S) {
359 fn inner(engine: &mut PluginEngine, concern: &str) {
360 engine.concerns.push(concern.to_owned());
361 }
362 inner(self, concern.as_ref())
363 }
364
/// Returns the concerns recorded so far that have not yet been taken.
#[cfg(feature = "mock_engine")]
#[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
pub fn get_concerns(&self) -> &[String] {
    &self.concerns
}
371
372 fn take_concerns(&mut self) -> Vec<String> {
373 self.concerns.drain(..).collect()
374 }
375}
376
/// Builds a [`PluginEngine`] with session ID 0, no concerns, and the given
/// mock responses.
#[cfg(feature = "mock_engine")]
#[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
impl From<MockResponses> for PluginEngine {
    fn from(value: MockResponses) -> Self {
        // Only one half of each channel is kept; the other half is dropped
        // right away.
        let tx = mpsc::channel(1).0;
        let rx = mpsc::channel(1).1;
        let drop_tx = mpsc::channel(1).0;

        PluginEngine {
            id: 0,
            tx,
            rx,
            concerns: Vec::new(),
            drop_tx,
            mock_responses: value,
        }
    }
}
395
396impl Drop for PluginEngine {
397 fn drop(&mut self) {
399 if cfg!(feature = "mock_engine") {
400 let _ = self.drop_tx.max_capacity();
403 } else {
404 while let Err(e) = self.drop_tx.try_send(self.id as i32) {
405 match e {
406 TrySendError::Closed(_) => {
407 break;
408 }
409 TrySendError::Full(_) => (),
410 }
411 }
412 }
413 }
414}
415
/// Boxed stream of incoming query protocol requests; each item is either a
/// request or a gRPC [`Status`] error.
type PluginQueryStream = Box<
    dyn Stream<Item = StdResult<InitiateQueryProtocolRequest, Status>> + Send + Unpin + 'static,
>;
419
/// Reads incoming requests from one stream and routes them to per-session
/// [`PluginEngine`]s, creating a new engine when a new session starts.
pub(crate) struct HcSessionSocket {
    /// Outgoing response channel; cloned into every engine created by `listen`.
    tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
    /// Stream of incoming requests, polled by `message`.
    rx: PluginQueryStream,
    /// Sender cloned into every engine so it can report its ID when dropped.
    drop_tx: mpsc::Sender<i32>,
    /// Receiver of dropped session IDs, drained by `cleanup_sessions`.
    drop_rx: mpsc::Receiver<i32>,
    /// Currently tracked sessions, keyed by session ID.
    sessions: SessionTracker,
}
427
428impl std::fmt::Debug for HcSessionSocket {
431 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432 f.debug_struct("HcSessionSocket")
433 .field("tx", &self.tx)
434 .field("rx", &"<rx>")
435 .field("drop_tx", &self.drop_tx)
436 .field("drop_rx", &self.drop_rx)
437 .field("sessions", &self.sessions)
438 .finish()
439 }
440}
441
442impl HcSessionSocket {
443 pub(crate) fn new(
444 tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
445 rx: impl Stream<Item = StdResult<InitiateQueryProtocolRequest, Status>> + Send + Unpin + 'static,
446 ) -> Self {
447 let (drop_tx, drop_rx) = mpsc::channel(10);
450 Self {
451 tx,
452 rx: Box::new(rx),
453 drop_tx,
454 drop_rx,
455 sessions: HashMap::new(),
456 }
457 }
458
459 fn cleanup_sessions(&mut self) {
461 while let Ok(id) = self.drop_rx.try_recv() {
462 match self.sessions.remove(&id) {
463 Some(_) => tracing::trace!("Cleaned up session {id}"),
464 None => {
465 tracing::warn!(
466 "HcSessionSocket got request to drop a session that does not exist"
467 )
468 }
469 }
470 }
471 }
472
473 async fn message(&mut self) -> StdResult<Option<PluginQuery>, Status> {
474 let fut = poll_fn(|cx| Pin::new(&mut *self.rx).poll_next(cx));
475
476 match fut.await {
477 Some(Ok(m)) => Ok(m.query),
478 Some(Err(e)) => Err(e),
479 None => Ok(None),
480 }
481 }
482
/// Processes incoming messages until one of them starts a new session, and
/// returns the [`PluginEngine`] created for that session.
///
/// Messages that belong to an already tracked session are forwarded to it and
/// the loop continues. Messages for which `decide_action` returns an error are
/// logged and skipped.
///
/// Returns `Ok(None)` when `message` yields `None`, i.e. the stream ended or a
/// request carried no query.
///
/// # Errors
///
/// Returns the error converted from the [`Status`] reported by `message`.
///
/// # Panics
///
/// Panics if the first message cannot be queued on the channel of a newly
/// created session.
pub(crate) async fn listen(&mut self) -> Result<Option<PluginEngine>> {
    loop {
        let Some(raw) = self.message().await.map_err(Error::from)? else {
            return Ok(None);
        };
        let id = raw.id;

        // Forget sessions whose engines were dropped before deciding how to
        // route this message.
        self.cleanup_sessions();

        match self.decide_action(&raw) {
            Ok(HandleAction::ForwardMsgToExistingSession(tx)) => {
                tracing::trace!("SDK: forwarding message to session {id}");

                // A failed forward means the session's receiver is gone, so
                // stop tracking the session.
                if let Err(_e) = tx.send(Some(raw)).await {
                    tracing::error!("Error forwarding msg to session {id}");
                    self.sessions.remove(&id);
                };
            }
            Ok(HandleAction::CreateSession) => {
                tracing::trace!("SDK: creating new session {id}");

                let (in_tx, rx) = mpsc::channel::<Option<PluginQuery>>(10);
                let tx = self.tx.clone();

                let session = PluginEngine {
                    id: id as usize,
                    concerns: vec![],
                    tx,
                    rx,
                    drop_tx: self.drop_tx.clone(),
                    mock_responses: MockResponses::new(),
                };

                // Queue the message that opened the session so the engine
                // sees it as its first message.
                in_tx.send(Some(raw)).await.expect(
                    "Failed sending message to newly created Session, should never happen",
                );

                tracing::trace!("SDK: adding new session {id} to tracker");
                self.sessions.insert(id, in_tx);

                return Ok(Some(session));
            }
            Err(e) => tracing::error!("{}", e),
        }
    }
}
533
534 fn decide_action(&mut self, query: &PluginQuery) -> Result<HandleAction<'_>> {
535 if let Some(tx) = self.sessions.get_mut(&query.id) {
536 return Ok(HandleAction::ForwardMsgToExistingSession(tx));
537 }
538
539 if [QueryState::SubmitInProgress, QueryState::SubmitComplete].contains(&query.state()) {
540 return Ok(HandleAction::CreateSession);
541 }
542
543 Err(Error::ReceivedReplyWhenExpectingRequest)
544 }
545
546 pub(crate) async fn run<P>(&mut self, plugin: Arc<P>) -> Result<()>
547 where
548 P: Plugin,
549 {
550 loop {
551 let Some(mut engine) = self
552 .listen()
553 .await
554 .map_err(|_| Error::SessionChannelClosed)?
555 else {
556 tracing::trace!("Channel closed by remote");
557 break;
558 };
559
560 let cloned_plugin = plugin.clone();
561 tokio::spawn(async move {
562 engine.handle_session(cloned_plugin).await;
563 });
564 }
565
566 Ok(())
567 }
568}
569
/// What `HcSessionSocket::listen` should do with an incoming message, as
/// decided by `decide_action`.
enum HandleAction<'s> {
    /// The message belongs to a tracked session; forward it on this sender.
    ForwardMsgToExistingSession(&'s mut mpsc::Sender<Option<PluginQuery>>),
    /// The message starts a new session.
    CreateSession,
}
574
/// Table of canned query results, keyed by `(target, key)` pair.
///
/// `PluginEngine::query_inner` reads from this table when the `mock_engine`
/// feature is enabled.
#[derive(Default, Debug)]
pub struct MockResponses(pub(crate) HashMap<(QueryTarget, JsonValue), Result<JsonValue>>);
582
583impl MockResponses {
584 pub fn new() -> Self {
585 Self(HashMap::new())
586 }
587}
588
589impl MockResponses {
590 #[cfg(feature = "mock_engine")]
591 pub fn insert<T, V, W>(
592 &mut self,
593 query_target: T,
594 query_value: V,
595 query_response: Result<W>,
596 ) -> Result<()>
597 where
598 T: TryInto<QueryTarget, Error: Into<crate::Error>>,
599 V: serde::Serialize,
600 W: serde::Serialize,
601 {
602 let query_target: QueryTarget = query_target.try_into().map_err(|e| e.into())?;
603 let query_value: JsonValue = serde_json::to_value(query_value)
604 .map_err(|source| crate::Error::InvalidJsonInQueryKey(Box::new(source)))?;
605 let query_response = match query_response {
606 Ok(v) => serde_json::to_value(v)
607 .map_err(|source| crate::Error::InvalidJsonInQueryKey(Box::new(source))),
608 Err(e) => Err(e),
609 };
610 self.0.insert((query_target, query_value), query_response);
611 Ok(())
612 }
613}
614
#[cfg(test)]
mod test {
    use super::*;

    /// Keys queued on a builder get consecutive indices, and the responses
    /// come back in the order the keys were queued.
    #[cfg(feature = "mock_engine")]
    #[tokio::test]
    async fn test_query_builder() {
        let mut mock_responses = MockResponses::new();
        mock_responses
            .insert("mitre/foo", "abcd", Ok(1234))
            .unwrap();
        mock_responses
            .insert("mitre/foo", "efgh", Ok(5678))
            .unwrap();

        let mut engine = PluginEngine::mock(mock_responses);
        let mut builder = engine.batch("mitre/foo").unwrap();

        let first_idx = builder.query("abcd".into());
        assert_eq!(first_idx, 0);
        let second_idx = builder.query("efgh".into());
        assert_eq!(second_idx, 1);

        let response = builder.send().await.unwrap();
        assert_eq!(response.first(), Some(&JsonValue::from(1234_i32)));
        assert_eq!(response.get(1), Some(&JsonValue::from(5678_i32)));
    }
}
645}