use crate::{
Plugin, QuerySchema,
engine::HcSessionSocket,
error::{Error, Result},
};
use hipcheck_common::{
proto::{
ConfigurationStatus, ExplainDefaultQueryRequest as ExplainDefaultQueryReq,
ExplainDefaultQueryResponse as ExplainDefaultQueryResp,
GetDefaultPolicyExpressionRequest as GetDefaultPolicyExpressionReq,
GetDefaultPolicyExpressionResponse as GetDefaultPolicyExpressionResp,
GetQuerySchemasRequest as GetQuerySchemasReq,
GetQuerySchemasResponse as GetQuerySchemasResp,
InitiateQueryProtocolRequest as InitiateQueryProtocolReq,
InitiateQueryProtocolResponse as InitiateQueryProtocolResp,
SetConfigurationRequest as SetConfigurationReq,
SetConfigurationResponse as SetConfigurationResp,
plugin_service_server::{PluginService, PluginServiceServer},
},
types::LogLevel,
};
use std::{
net::{Ipv4Addr, SocketAddr},
result::Result as StdResult,
sync::Arc,
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream as RecvStream;
use tonic::{Code, Request as Req, Response as Resp, Status, Streaming, transport::Server};
use tracing::error;
/// The network interface a plugin server binds to.
///
/// Only IPv4 addresses are supported.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Host {
    /// The IPv4 loopback address, `127.0.0.1`.
    Loopback,
    /// The IPv4 unspecified address, `0.0.0.0` (all IPv4 interfaces).
    Any,
    /// A specific IPv4 address.
    Other(Ipv4Addr),
}

impl Host {
    /// The concrete IPv4 address this host refers to.
    fn ip(&self) -> Ipv4Addr {
        match *self {
            Host::Loopback => Ipv4Addr::LOCALHOST,
            Host::Any => Ipv4Addr::UNSPECIFIED,
            Host::Other(ip) => ip,
        }
    }

    /// Build the socket address used to bind this host on `port`.
    fn to_socket_addr(&self, port: u16) -> SocketAddr {
        SocketAddr::from((self.ip(), port))
    }
}
/// Serves a plugin over gRPC through the `PluginService` interface.
///
/// Created with `PluginServer::register`, then started with `listen` or `listen_local`.
pub struct PluginServer<P> {
    /// The host the server is bound to. `register` sets it to `Host::Any`, and `listen`
    /// overwrites it with the chosen host. Query sessions read it to size their outbound
    /// channel.
    curr_host: Host,
    /// The wrapped plugin, shared with the tasks spawned per request.
    plugin: Arc<P>,
}
impl<P: Plugin> PluginServer<P> {
    /// Wrap `plugin` in a new server whose host defaults to [`Host::Any`] until one of the
    /// `listen*` methods chooses a host.
    ///
    /// With the `log_forwarding` feature enabled, this also calls
    /// `crate::init_tracing_logger` with `log_level_opt`, falling back to
    /// [`LogLevel::Error`] when no level is given. Without that feature the log level is
    /// accepted but ignored.
    pub fn register(plugin: P, log_level_opt: impl Into<Option<LogLevel>>) -> PluginServer<P> {
        #[cfg(feature = "log_forwarding")]
        {
            let log_level = log_level_opt.into().unwrap_or(LogLevel::Error);
            crate::init_tracing_logger(log_level);
        }
        // Without `log_forwarding` the argument has no use; consume it explicitly so the
        // build does not emit an unused-variable warning.
        #[cfg(not(feature = "log_forwarding"))]
        let _ = log_level_opt;

        PluginServer {
            plugin: Arc::new(plugin),
            curr_host: Host::Any,
        }
    }

    /// Serve the plugin on the loopback interface (`127.0.0.1`) at `port`.
    ///
    /// Equivalent to `self.listen(Host::Loopback, port)`.
    pub async fn listen_local(self, port: u16) -> Result<()> {
        self.listen(Host::Loopback, port).await
    }

    /// Serve the plugin over gRPC on `host:port`, returning once the server exits.
    ///
    /// # Errors
    ///
    /// Returns [`Error::FailedToStartServer`] wrapping the transport error if tonic's
    /// `serve` fails (for example, when the address cannot be bound).
    pub async fn listen(mut self, host: Host, port: u16) -> Result<()> {
        // Resolve the address before `host` moves into `self`, so no clone is needed.
        let host_addr = host.to_socket_addr(port);
        // Recorded so `initiate_query_protocol` can size its outbound channel per host.
        self.curr_host = host;
        let service = PluginServiceServer::new(self);
        Server::builder()
            .add_service(service)
            .serve(host_addr)
            .await
            .map_err(|source| Error::FailedToStartServer(Box::new(source)))?;
        Ok(())
    }
}
pub type QueryResult<T> = StdResult<T, Status>;
#[tonic::async_trait]
impl<P: Plugin> PluginService for PluginServer<P> {
    type GetQuerySchemasStream = RecvStream<QueryResult<GetQuerySchemasResp>>;
    type InitiateQueryProtocolStream = RecvStream<QueryResult<InitiateQueryProtocolResp>>;

    /// Deserialize the client-supplied JSON configuration and pass it to the plugin.
    ///
    /// JSON that fails to parse is returned as an error [`Status`]. If the plugin rejects
    /// the configuration, its error is converted (via `Into`) into the response body
    /// instead, so the RPC itself still returns `Ok`.
    async fn set_configuration(
        &self,
        req: Req<SetConfigurationReq>,
    ) -> QueryResult<Resp<SetConfigurationResp>> {
        let config = serde_json::from_str(&req.into_inner().configuration)
            .map_err(|e| Status::from_error(Box::new(e)))?;
        match self.plugin.set_config(config) {
            Ok(_) => Ok(Resp::new(SetConfigurationResp {
                status: ConfigurationStatus::None as i32,
                message: "".to_owned(),
            })),
            Err(e) => Ok(Resp::new(e.into())),
        }
    }

    /// Return the plugin's default policy expression.
    ///
    /// A plugin error is reported as a `NotFound` status naming the plugin.
    async fn get_default_policy_expression(
        &self,
        _req: Req<GetDefaultPolicyExpressionReq>,
    ) -> QueryResult<Resp<GetDefaultPolicyExpressionResp>> {
        match self.plugin.default_policy_expr() {
            Ok(policy_expression) => Ok(Resp::new(GetDefaultPolicyExpressionResp {
                policy_expression,
            })),
            Err(e) => Err(Status::new(
                tonic::Code::NotFound,
                format!(
                    "Error determining default policy expr for {}/{}: {}",
                    P::PUBLISHER,
                    P::NAME,
                    e
                ),
            )),
        }
    }

    /// Return the plugin's explanation of its default query.
    ///
    /// If the plugin provides none, a placeholder message is returned. A plugin error is
    /// reported as a `NotFound` status naming the plugin.
    async fn explain_default_query(
        &self,
        _req: Req<ExplainDefaultQueryReq>,
    ) -> QueryResult<Resp<ExplainDefaultQueryResp>> {
        match self.plugin.explain_default_query() {
            Ok(explanation) => Ok(Resp::new(ExplainDefaultQueryResp {
                explanation: explanation
                    .unwrap_or_else(|| "No default query explanation provided".to_owned()),
            })),
            Err(e) => Err(Status::new(
                tonic::Code::NotFound,
                format!(
                    "Error explaining default query expr for {}/{}: {}",
                    P::PUBLISHER,
                    P::NAME,
                    e
                ),
            )),
        }
    }

    /// Stream one response per query schema the plugin exposes.
    ///
    /// Schemas are serialized on a spawned task. A schema that fails to serialize is sent
    /// as a `FailedPrecondition` status in its place, and streaming continues.
    async fn get_query_schemas(
        &self,
        _req: Req<GetQuerySchemasReq>,
    ) -> QueryResult<Resp<Self::GetQuerySchemasStream>> {
        let query_schemas = self.plugin.schemas().collect::<Vec<QuerySchema>>();
        let (tx, rx) = mpsc::channel(10);
        tokio::spawn(async move {
            for schema in query_schemas {
                let schema_resp = query_schema_to_resp(&schema);
                // `send` fails only once the receiver is dropped, i.e. the client stopped
                // reading the stream. Nobody is left to deliver to, so stop quietly instead
                // of panicking the task.
                if tx.send(schema_resp).await.is_err() {
                    break;
                }
            }
        });
        Ok(Resp::new(RecvStream::new(rx)))
    }

    /// Start a bidirectional query session with the client.
    ///
    /// Inbound messages are handed to an `HcSessionSocket` running on a spawned task, and
    /// its responses are streamed back. The outbound channel holds 10 messages when the
    /// server is bound to loopback and 100 otherwise.
    ///
    /// If the session fails, the error is logged. If the client is still listening, the
    /// error is also forwarded to it as an `Internal` status.
    async fn initiate_query_protocol(
        &self,
        req: Req<Streaming<InitiateQueryProtocolReq>>,
    ) -> QueryResult<Resp<Self::InitiateQueryProtocolStream>> {
        let rx = req.into_inner();
        // Exhaustive match so a new `Host` variant forces a decision here.
        let capacity = match self.curr_host {
            Host::Loopback => 10,
            Host::Any | Host::Other(_) => 100,
        };
        let (tx, out_rx) = mpsc::channel::<QueryResult<InitiateQueryProtocolResp>>(capacity);
        let cloned_plugin = self.plugin.clone();
        // Kept back so a session failure can still be reported after `tx` moves into the socket.
        let err_tx = tx.clone();
        tokio::spawn(async move {
            let mut channel = HcSessionSocket::new(tx, rx);
            if let Err(e) = channel.run(cloned_plugin).await {
                error!("Channel error: {e}");
                if !err_tx.is_closed()
                    && let Err(send_err) = err_tx
                        .send(Err(tonic::Status::internal(format!("Session error: {e}"))))
                        .await
                {
                    error!("Failed to send error through channel: {send_err}");
                }
            }
        });
        Ok(Resp::new(RecvStream::new(out_rx)))
    }
}

/// Serialize one [`QuerySchema`] into its gRPC response form.
///
/// The input (key) schema and the output schema are each rendered as JSON strings. If
/// either fails to serialize, a `FailedPrecondition` status carrying the actual
/// serialization error(s) is returned instead.
fn query_schema_to_resp(schema: &QuerySchema) -> QueryResult<GetQuerySchemasResp> {
    let input_schema = serde_json::to_string(&schema.input_schema);
    let output_schema = serde_json::to_string(&schema.output_schema);
    match (input_schema, output_schema) {
        (Ok(key_schema), Ok(output_schema)) => Ok(GetQuerySchemasResp {
            query_name: schema.query_name.to_string(),
            key_schema,
            output_schema,
        }),
        (Ok(_), Err(e)) => Err(Status::new(
            Code::FailedPrecondition,
            format!("Error converting output schema to String: {}", e),
        )),
        // Report the input schema's own serialization error, not the (successful) output.
        (Err(e), Ok(_)) => Err(Status::new(
            Code::FailedPrecondition,
            format!("Error converting input schema to String: {}", e),
        )),
        (Err(e1), Err(e2)) => Err(Status::new(
            Code::FailedPrecondition,
            format!(
                "Error converting input and output schema to String: {} {}",
                e1, e2
            ),
        )),
    }
}