use crate::{
JsonValue, Plugin, QueryTarget,
error::{Error, Result},
};
use futures::Stream;
use hipcheck_common::proto::{
self, InitiateQueryProtocolRequest, InitiateQueryProtocolResponse, Query as PluginQuery,
QueryState,
};
use hipcheck_common::{
chunk::QuerySynthesizer,
types::{Query, QueryDirection},
};
use serde::Serialize;
use std::{
collections::{HashMap, VecDeque},
future::poll_fn,
pin::Pin,
result::Result as StdResult,
sync::Arc,
};
use tokio::sync::mpsc::{self, error::TrySendError};
use tonic::Status;
impl From<Status> for Error {
fn from(_value: Status) -> Error {
Error::SessionChannelClosed
}
}
/// Maps a query id to the sender used to forward incoming messages to the session
/// that owns that id.
type SessionTracker = HashMap<i32, mpsc::Sender<Option<PluginQuery>>>;
/// Accumulates keys for a batched query against a single target.
///
/// Obtained from [`PluginEngine::batch`]; keys are queued with `query` and the whole
/// batch is dispatched with `send`.
pub struct QueryBuilder<'engine> {
/// Keys queued so far, in insertion order.
keys: Vec<JsonValue>,
/// The query endpoint every queued key is sent to.
target: QueryTarget,
/// Engine used to dispatch the batch when `send` is called.
plugin_engine: &'engine mut PluginEngine,
}
impl<'engine> QueryBuilder<'engine> {
fn new<T>(plugin_engine: &'engine mut PluginEngine, target: T) -> Result<QueryBuilder<'engine>>
where
T: TryInto<QueryTarget, Error: Into<Error>>,
{
let target: QueryTarget = target.try_into().map_err(|e| e.into())?;
Ok(Self {
plugin_engine,
target,
keys: vec![],
})
}
pub fn query(&mut self, key: JsonValue) -> usize {
let len = self.keys.len();
self.keys.push(key);
len
}
pub async fn send(self) -> Result<Vec<JsonValue>> {
self.plugin_engine.batch_query(self.target, self.keys).await
}
}
/// Per-session handle through which plugin code issues queries, records concerns, and
/// exchanges protocol messages.
pub struct PluginEngine {
/// Session id; stamped onto every outgoing query and reported on drop.
id: usize,
/// Outgoing channel for protocol responses produced by this session.
tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
/// Incoming channel of raw query messages for this session. A `None` item received
/// first is reported by the receive path as "no message" (`Ok(None)`).
rx: mpsc::Receiver<Option<PluginQuery>>,
/// Concerns recorded via `record_concern`; drained into the next outgoing response.
concerns: Vec<String>,
/// Channel on which the session id is sent when this engine is dropped.
drop_tx: mpsc::Sender<i32>,
/// Canned answers consulted by `query_inner` when the `mock_engine` feature is enabled.
mock_responses: MockResponses,
}
impl PluginEngine {
/// Builds an engine that answers queries from `mock_responses`.
///
/// Delegates to the `From<MockResponses>` conversion for `PluginEngine`.
#[cfg(feature = "mock_engine")]
#[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
pub fn mock(mock_responses: MockResponses) -> Self {
    Self::from(mock_responses)
}
/// Starts a batched query against `target`.
///
/// The returned [`QueryBuilder`] mutably borrows this engine until it is sent or dropped.
///
/// # Errors
///
/// Fails if `target` cannot be converted into a [`QueryTarget`].
pub fn batch<T>(&mut self, target: T) -> Result<QueryBuilder<'_>>
where
    T: TryInto<QueryTarget, Error: Into<Error>>,
{
    let engine = self;
    QueryBuilder::new(engine, target)
}
async fn query_inner(
&mut self,
target: QueryTarget,
input: Vec<JsonValue>,
) -> Result<Vec<JsonValue>> {
if cfg!(feature = "mock_engine") {
let mut results = Vec::with_capacity(input.len());
for i in input {
match self.mock_responses.0.get(&(target.clone(), i)) {
Some(res) => match res {
Ok(val) => results.push(val.clone()),
Err(e) => {
tracing::error!("Error parsing mock_engine response: {e}");
return Err(Error::UnexpectedPluginQueryInputFormat);
}
},
None => {
return Err(Error::UnknownPluginQuery(
target.to_string().into_boxed_str(),
));
}
}
}
Ok(results)
}
else {
let query = Query {
id: 0,
direction: QueryDirection::Request,
publisher: target.publisher,
plugin: target.plugin,
query: target.query.unwrap_or_else(|| "".to_owned()),
key: input,
output: vec![],
concerns: vec![],
};
self.send(query).await?;
let response = self.recv().await?;
match response {
Some(response) => Ok(response.output),
None => Err(Error::SessionChannelClosed),
}
}
}
pub async fn query<T, V>(&mut self, target: T, input: V) -> Result<JsonValue>
where
T: TryInto<QueryTarget, Error: Into<Error>>,
V: Serialize,
{
let query_target: QueryTarget = target.try_into().map_err(|e| e.into())?;
tracing::trace!("querying {}", query_target.to_string());
let input: JsonValue = serde_json::to_value(input)
.map_err(|source| Error::InvalidJsonInQueryKey(Box::new(source)))?;
let mut response = self.query_inner(query_target, vec![input]).await?;
Ok(response.pop().unwrap())
}
pub async fn batch_query<T, V>(&mut self, target: T, keys: Vec<V>) -> Result<Vec<JsonValue>>
where
T: TryInto<QueryTarget, Error: Into<Error>>,
V: Serialize,
{
let target: QueryTarget = target.try_into().map_err(|e| e.into())?;
tracing::trace!("querying {}", target.to_string());
let mut input = Vec::with_capacity(keys.len());
for key in keys {
let jsonified_key = serde_json::to_value(key)
.map_err(|source| Error::InvalidJsonInQueryKey(Box::new(source)))?;
input.push(jsonified_key);
}
self.query_inner(target, input).await
}
fn id(&self) -> usize {
self.id
}
async fn recv_raw(&mut self) -> Result<Option<VecDeque<PluginQuery>>> {
let mut out = VecDeque::new();
tracing::trace!("SDK: awaiting raw rx recv");
let opt_first = self.rx.recv().await.ok_or(Error::SessionChannelClosed)?;
let Some(first) = opt_first else {
return Ok(None);
};
out.push_back(first);
loop {
match self.rx.try_recv() {
Ok(Some(msg)) => {
out.push_back(msg);
}
Ok(None) => {
tracing::warn!(
"None received, gRPC channel closed. we may not close properly if None is not returned again"
);
break;
}
Err(_) => {
break;
}
}
}
Ok(Some(out))
}
async fn send(&self, mut query: Query) -> Result<()> {
query.id = self.id(); let queries = hipcheck_common::chunk::prepare(query)?;
for pq in queries {
let query = InitiateQueryProtocolResponse { query: Some(pq) };
self.tx
.send(Ok(query))
.await
.map_err(|source| Error::FailedToSendQueryFromSessionToServer(Box::new(source)))?;
}
Ok(())
}
async fn send_session_err<P>(&mut self) -> crate::error::Result<()>
where
P: Plugin,
{
let query = proto::Query {
id: self.id() as i32,
state: QueryState::Unspecified as i32,
publisher_name: P::PUBLISHER.to_owned(),
plugin_name: P::NAME.to_owned(),
query_name: "".to_owned(),
key: vec![],
output: vec![],
concern: self.take_concerns(),
split: false,
};
self.tx
.send(Ok(InitiateQueryProtocolResponse { query: Some(query) }))
.await
.map_err(|source| Error::FailedToSendQueryFromSessionToServer(Box::new(source)))
}
/// Receives raw messages until the synthesizer yields a complete [`Query`].
///
/// Returns `Ok(None)` if the raw receive reports no message before a query is complete.
///
/// # Errors
///
/// Propagates errors from the raw receive and from `QuerySynthesizer::add`.
async fn recv(&mut self) -> Result<Option<Query>> {
    let mut synth = QuerySynthesizer::default();
    loop {
        let chunks = match self.recv_raw().await? {
            Some(chunks) => chunks,
            None => return Ok(None),
        };
        // `add` yields `None` until it has seen enough messages for a whole query.
        if let Some(query) = synth.add(chunks.into_iter())? {
            return Ok(Some(query));
        }
    }
}
async fn handle_session_fallible<P>(&mut self, plugin: Arc<P>) -> crate::error::Result<()>
where
P: Plugin,
{
let Some(query) = self.recv().await? else {
return Err(Error::SessionChannelClosed);
};
if query.direction == QueryDirection::Response {
return Err(Error::ReceivedReplyWhenExpectingRequest);
}
let name = query.query;
if query.key.len() != 1 {
return Err(Error::UnspecifiedQueryState);
}
let key = query.key.first().unwrap().clone();
let query = plugin
.queries()
.filter_map(|x| if x.name == name { Some(x.inner) } else { None })
.next()
.or_else(|| plugin.default_query())
.ok_or_else(|| {
if name.is_empty() {
Error::NoDefaultQuery
} else {
Error::UnknownPluginQuery(name.clone().into_boxed_str())
}
})?;
#[cfg(feature = "print-timings")]
let _0 = crate::benchmarking::print_scope_time!(format!("{}/{}", P::NAME, name));
let value = query.run(self, key).await?;
#[cfg(feature = "print-timings")]
drop(_0);
let query = Query {
id: self.id(),
direction: QueryDirection::Response,
publisher: P::PUBLISHER.to_owned(),
plugin: P::NAME.to_owned(),
query: name.to_owned(),
key: vec![],
output: vec![value],
concerns: self.take_concerns(),
};
self.send(query).await
}
async fn handle_session<P>(&mut self, plugin: Arc<P>)
where
P: Plugin,
{
if let Err(e) = self.handle_session_fallible(plugin).await {
let res_err_send = match e {
Error::FailedToSendQueryFromSessionToServer(_) => {
tracing::error!("Failed to send message to Hipcheck core, analysis will hang.");
return;
}
other => {
tracing::error!("{}", other);
self.send_session_err::<P>().await
}
};
if res_err_send.is_err() {
tracing::error!("Failed to send message to Hipcheck core, analysis will hang.");
}
}
}
pub fn record_concern<S: AsRef<str>>(&mut self, concern: S) {
fn inner(engine: &mut PluginEngine, concern: &str) {
engine.concerns.push(concern.to_owned());
}
inner(self, concern.as_ref())
}
/// Returns the concerns recorded so far, in the order they were recorded.
#[cfg(feature = "mock_engine")]
#[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
pub fn get_concerns(&self) -> &[String] {
    self.concerns.as_slice()
}
/// Removes and returns every concern recorded so far, leaving the engine with none.
///
/// Uses `std::mem::take` to hand over the existing buffer rather than draining it
/// element-by-element into a freshly allocated `Vec`.
fn take_concerns(&mut self) -> Vec<String> {
    std::mem::take(&mut self.concerns)
}
}
/// Builds a mock engine with session id 0 whose answers come from `value`.
///
/// The opposite end of each channel is dropped immediately, so the engine's channels
/// are not connected to anything.
#[cfg(feature = "mock_engine")]
#[cfg_attr(docsrs, doc(cfg(feature = "mock_engine")))]
impl From<MockResponses> for PluginEngine {
    fn from(value: MockResponses) -> Self {
        let tx = mpsc::channel(1).0;
        let rx = mpsc::channel(1).1;
        let drop_tx = mpsc::channel(1).0;
        PluginEngine {
            id: 0,
            tx,
            rx,
            concerns: Vec::new(),
            drop_tx,
            mock_responses: value,
        }
    }
}
/// Reports this session's id on the drop channel when the engine goes away.
impl Drop for PluginEngine {
    fn drop(&mut self) {
        if cfg!(feature = "mock_engine") {
            // Mock engines only touch the sender; nothing is sent.
            let _ = self.drop_tx.max_capacity();
            return;
        }
        let session_id = self.id as i32;
        // `drop` cannot await, so retry the non-blocking send while the channel is
        // full; give up once the send succeeds or the channel is closed.
        loop {
            match self.drop_tx.try_send(session_id) {
                Ok(()) | Err(TrySendError::Closed(_)) => break,
                Err(TrySendError::Full(_)) => continue,
            }
        }
    }
}
/// Boxed, type-erased stream of incoming query protocol requests (or gRPC errors).
type PluginQueryStream = Box<
dyn Stream<Item = StdResult<InitiateQueryProtocolRequest, Status>> + Send + Unpin + 'static,
>;
/// Demultiplexes the incoming request stream into per-id sessions.
pub(crate) struct HcSessionSocket {
/// Outgoing response channel; cloned into every session created by `listen`.
tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
/// Incoming stream of query protocol requests.
rx: PluginQueryStream,
/// Sender cloned into every session so it can report its id when dropped.
drop_tx: mpsc::Sender<i32>,
/// Receives the ids of dropped sessions; drained by `cleanup_sessions`.
drop_rx: mpsc::Receiver<i32>,
/// Currently tracked sessions, keyed by query id.
sessions: SessionTracker,
}
/// Manual `Debug` implementation: the `rx` field is printed as the placeholder
/// `"<rx>"` rather than its contents.
impl std::fmt::Debug for HcSessionSocket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            tx,
            rx: _,
            drop_tx,
            drop_rx,
            sessions,
        } = self;
        let mut out = f.debug_struct("HcSessionSocket");
        out.field("tx", tx);
        out.field("rx", &"<rx>");
        out.field("drop_tx", drop_tx);
        out.field("drop_rx", drop_rx);
        out.field("sessions", sessions);
        out.finish()
    }
}
impl HcSessionSocket {
pub(crate) fn new(
tx: mpsc::Sender<StdResult<InitiateQueryProtocolResponse, Status>>,
rx: impl Stream<Item = StdResult<InitiateQueryProtocolRequest, Status>> + Send + Unpin + 'static,
) -> Self {
let (drop_tx, drop_rx) = mpsc::channel(10);
Self {
tx,
rx: Box::new(rx),
drop_tx,
drop_rx,
sessions: HashMap::new(),
}
}
/// Removes every session whose id is currently queued on the drop channel.
///
/// Never blocks: stops as soon as no more ids are immediately available.
fn cleanup_sessions(&mut self) {
    while let Ok(id) = self.drop_rx.try_recv() {
        if self.sessions.remove(&id).is_some() {
            tracing::trace!("Cleaned up session {id}");
        } else {
            tracing::warn!("HcSessionSocket got request to drop a session that does not exist");
        }
    }
}
/// Awaits the next item of the incoming stream and extracts its query.
///
/// Returns `Ok(None)` both when the stream has ended and when the request carries no
/// query; returns `Err` with the gRPC status if the stream yields an error.
async fn message(&mut self) -> StdResult<Option<PluginQuery>, Status> {
    let next = poll_fn(|cx| Pin::new(&mut *self.rx).poll_next(cx)).await;
    // Option<Result<Req, Status>> -> Result<Option<Req>, Status>, then pull out `query`.
    let request = next.transpose()?;
    Ok(request.and_then(|req| req.query))
}
pub(crate) async fn listen(&mut self) -> Result<Option<PluginEngine>> {
loop {
let Some(raw) = self.message().await.map_err(Error::from)? else {
return Ok(None);
};
let id = raw.id;
self.cleanup_sessions();
match self.decide_action(&raw) {
Ok(HandleAction::ForwardMsgToExistingSession(tx)) => {
tracing::trace!("SDK: forwarding message to session {id}");
if let Err(_e) = tx.send(Some(raw)).await {
tracing::error!("Error forwarding msg to session {id}");
self.sessions.remove(&id);
};
}
Ok(HandleAction::CreateSession) => {
tracing::trace!("SDK: creating new session {id}");
let (in_tx, rx) = mpsc::channel::<Option<PluginQuery>>(10);
let tx = self.tx.clone();
let session = PluginEngine {
id: id as usize,
concerns: vec![],
tx,
rx,
drop_tx: self.drop_tx.clone(),
mock_responses: MockResponses::new(),
};
in_tx.send(Some(raw)).await.expect(
"Failed sending message to newly created Session, should never happen",
);
tracing::trace!("SDK: adding new session {id} to tracker");
self.sessions.insert(id, in_tx);
return Ok(Some(session));
}
Err(e) => tracing::error!("{}", e),
}
}
}
fn decide_action(&mut self, query: &PluginQuery) -> Result<HandleAction<'_>> {
if let Some(tx) = self.sessions.get_mut(&query.id) {
return Ok(HandleAction::ForwardMsgToExistingSession(tx));
}
if [QueryState::SubmitInProgress, QueryState::SubmitComplete].contains(&query.state()) {
return Ok(HandleAction::CreateSession);
}
Err(Error::ReceivedReplyWhenExpectingRequest)
}
/// Serves sessions until the incoming stream ends.
///
/// Each new session returned by `listen` is handled on its own spawned Tokio task.
///
/// # Errors
///
/// Returns `SessionChannelClosed` if `listen` fails for any reason.
pub(crate) async fn run<P>(&mut self, plugin: Arc<P>) -> Result<()>
where
    P: Plugin,
{
    while let Some(mut engine) = self
        .listen()
        .await
        .map_err(|_| Error::SessionChannelClosed)?
    {
        let session_plugin = plugin.clone();
        tokio::spawn(async move {
            engine.handle_session(session_plugin).await;
        });
    }
    tracing::trace!("Channel closed by remote");
    Ok(())
}
}
/// What `HcSessionSocket::listen` should do with an incoming message.
enum HandleAction<'s> {
/// Forward the message to an already-tracked session through this sender.
ForwardMsgToExistingSession(&'s mut mpsc::Sender<Option<PluginQuery>>),
/// Create a new session for the message's id.
CreateSession,
}
/// Table of canned query results, keyed by `(target, key)`, used by the mock engine.
///
/// Each entry is either the JSON output to return or an error to report for that query.
#[derive(Default, Debug)]
pub struct MockResponses(pub(crate) HashMap<(QueryTarget, JsonValue), Result<JsonValue>>);
impl MockResponses {
pub fn new() -> Self {
Self(HashMap::new())
}
}
impl MockResponses {
    /// Registers the mock result for querying `query_target` with key `query_value`.
    ///
    /// An `Ok` response is serialized to JSON before being stored; an `Err` response is
    /// stored as-is. If serializing the response fails, that serialization error is
    /// stored as the entry's result rather than returned. An existing entry for the
    /// same `(target, key)` pair is replaced.
    ///
    /// # Errors
    ///
    /// Fails if `query_target` is not a valid [`QueryTarget`] or if `query_value` cannot
    /// be serialized to JSON (`InvalidJsonInQueryKey`).
    #[cfg(feature = "mock_engine")]
    pub fn insert<T, V, W>(
        &mut self,
        query_target: T,
        query_value: V,
        query_response: Result<W>,
    ) -> Result<()>
    where
        T: TryInto<QueryTarget, Error: Into<crate::Error>>,
        V: serde::Serialize,
        W: serde::Serialize,
    {
        let target: QueryTarget = query_target.try_into().map_err(|e| e.into())?;
        let key: JsonValue = serde_json::to_value(query_value)
            .map_err(|source| crate::Error::InvalidJsonInQueryKey(Box::new(source)))?;
        let response = query_response.and_then(|value| {
            serde_json::to_value(value)
                .map_err(|source| crate::Error::InvalidJsonInQueryKey(Box::new(source)))
        });
        self.0.insert((target, key), response);
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// A batch built through `QueryBuilder` returns the mocked outputs in key order.
    #[cfg(feature = "mock_engine")]
    #[tokio::test]
    async fn test_query_builder() {
        let mut mock_responses = MockResponses::new();
        for (key, value) in [("abcd", 1234_i32), ("efgh", 5678_i32)] {
            mock_responses
                .insert("mitre/foo", key, Ok(value))
                .unwrap();
        }

        let mut engine = PluginEngine::mock(mock_responses);
        let mut builder = engine.batch("mitre/foo").unwrap();

        // `query` hands back the index each key occupies in the batch.
        let first_idx = builder.query("abcd".into());
        assert_eq!(first_idx, 0);
        let second_idx = builder.query("efgh".into());
        assert_eq!(second_idx, 1);

        let response = builder.send().await.unwrap();
        let expected_first: JsonValue = 1234_i32.into();
        let expected_second: JsonValue = 5678_i32.into();
        assert_eq!(response.first().unwrap(), &expected_first);
        assert_eq!(response.get(1).unwrap(), &expected_second);
    }
}