1use crate::{
4 engine::HcSessionSocket,
5 error::{Error, Result},
6 Plugin, QuerySchema,
7};
8use hipcheck_common::{
9 proto::{
10 plugin_service_server::{PluginService, PluginServiceServer},
11 ConfigurationStatus, ExplainDefaultQueryRequest as ExplainDefaultQueryReq,
12 ExplainDefaultQueryResponse as ExplainDefaultQueryResp,
13 GetDefaultPolicyExpressionRequest as GetDefaultPolicyExpressionReq,
14 GetDefaultPolicyExpressionResponse as GetDefaultPolicyExpressionResp,
15 GetQuerySchemasRequest as GetQuerySchemasReq,
16 GetQuerySchemasResponse as GetQuerySchemasResp,
17 InitiateQueryProtocolRequest as InitiateQueryProtocolReq,
18 InitiateQueryProtocolResponse as InitiateQueryProtocolResp,
19 SetConfigurationRequest as SetConfigurationReq,
20 SetConfigurationResponse as SetConfigurationResp,
21 },
22 types::LogLevel,
23};
24use std::{
25 net::{Ipv4Addr, SocketAddr},
26 result::Result as StdResult,
27 sync::Arc,
28};
29use tokio::sync::mpsc;
30use tokio_stream::wrappers::ReceiverStream as RecvStream;
31use tonic::{transport::Server, Code, Request as Req, Response as Resp, Status, Streaming};
32use tracing::error;
33
/// The network interface the plugin's gRPC server binds to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
    /// The IPv4 loopback address, `127.0.0.1`.
    Loopback,
    /// The IPv4 unspecified address, `0.0.0.0`.
    Any,
    /// An explicitly chosen IPv4 address.
    Other(Ipv4Addr),
}

impl Host {
    /// Builds the IPv4 socket address for this host on the given `port`.
    fn to_socket_addr(&self, port: u16) -> SocketAddr {
        let ip = match self {
            Host::Loopback => Ipv4Addr::LOCALHOST,
            Host::Any => Ipv4Addr::UNSPECIFIED,
            Host::Other(ip) => *ip,
        };
        SocketAddr::new(ip.into(), port)
    }
}
53
54pub struct PluginServer<P> {
58 plugin: Arc<P>,
59 curr_host: Host,
60}
61
62impl<P: Plugin> PluginServer<P> {
63 pub fn register(plugin: P, log_level_opt: impl Into<Option<LogLevel>>) -> PluginServer<P> {
65 #[cfg(feature = "log_forwarding")]
66 {
67 let log_level = log_level_opt.into().unwrap_or(LogLevel::Error);
68 crate::init_tracing_logger(log_level);
69 }
70
71 PluginServer {
72 plugin: Arc::new(plugin),
73 curr_host: Host::Any, }
75 }
76
77 pub async fn listen_local(self, port: u16) -> Result<()> {
79 self.listen(Host::Loopback, port).await
80 }
81
82 pub async fn listen(mut self, host: Host, port: u16) -> Result<()> {
84 self.curr_host = host.clone();
85 let service = PluginServiceServer::new(self);
86 let host_addr = host.to_socket_addr(port);
87
88 Server::builder()
89 .add_service(service)
90 .serve(host_addr)
91 .await
92 .map_err(Error::FailedToStartServer)?;
93
94 Ok(())
95 }
96}
97
98pub type QueryResult<T> = StdResult<T, Status>;
100
101#[tonic::async_trait]
102impl<P: Plugin> PluginService for PluginServer<P> {
103 type GetQuerySchemasStream = RecvStream<QueryResult<GetQuerySchemasResp>>;
104 type InitiateQueryProtocolStream = RecvStream<QueryResult<InitiateQueryProtocolResp>>;
105
106 async fn set_configuration(
107 &self,
108 req: Req<SetConfigurationReq>,
109 ) -> QueryResult<Resp<SetConfigurationResp>> {
110 let config = serde_json::from_str(&req.into_inner().configuration)
111 .map_err(|e| Status::from_error(Box::new(e)))?;
112 match self.plugin.set_config(config) {
113 Ok(_) => Ok(Resp::new(SetConfigurationResp {
114 status: ConfigurationStatus::None as i32,
115 message: "".to_owned(),
116 })),
117 Err(e) => Ok(Resp::new(e.into())),
118 }
119 }
120
121 async fn get_default_policy_expression(
122 &self,
123 _req: Req<GetDefaultPolicyExpressionReq>,
124 ) -> QueryResult<Resp<GetDefaultPolicyExpressionResp>> {
125 match self.plugin.default_policy_expr() {
127 Ok(policy_expression) => Ok(Resp::new(GetDefaultPolicyExpressionResp {
128 policy_expression,
129 })),
130 Err(e) => Err(Status::new(
131 tonic::Code::NotFound,
132 format!(
133 "Error determining default policy expr for {}/{}: {}",
134 P::PUBLISHER,
135 P::NAME,
136 e
137 ),
138 )),
139 }
140 }
141
142 async fn explain_default_query(
143 &self,
144 _req: Req<ExplainDefaultQueryReq>,
145 ) -> QueryResult<Resp<ExplainDefaultQueryResp>> {
146 match self.plugin.explain_default_query() {
147 Ok(explanation) => Ok(Resp::new(ExplainDefaultQueryResp {
148 explanation: explanation
149 .unwrap_or_else(|| "No default query explanation provided".to_owned()),
150 })),
151 Err(e) => Err(Status::new(
152 tonic::Code::NotFound,
153 format!(
154 "Error explaining default query expr for {}/{}: {}",
155 P::PUBLISHER,
156 P::NAME,
157 e
158 ),
159 )),
160 }
161 }
162
163 async fn get_query_schemas(
164 &self,
165 _req: Req<GetQuerySchemasReq>,
166 ) -> QueryResult<Resp<Self::GetQuerySchemasStream>> {
167 let query_schemas = self.plugin.schemas().collect::<Vec<QuerySchema>>();
169 let (tx, rx) = mpsc::channel(10);
171 tokio::spawn(async move {
172 for x in query_schemas {
173 let input_schema = serde_json::to_string(&x.input_schema);
174 let output_schema = serde_json::to_string(&x.output_schema);
175
176 let schema_resp = match (input_schema, output_schema) {
177 (Ok(input_schema), Ok(output_schema)) => Ok(GetQuerySchemasResp {
178 query_name: x.query_name.to_string(),
179 key_schema: input_schema,
180 output_schema,
181 }),
182 (Ok(_), Err(e)) => Err(Status::new(
183 Code::FailedPrecondition,
184 format!("Error converting output schema to String: {}", e),
185 )),
186 (Err(_), Ok(e)) => Err(Status::new(
187 Code::FailedPrecondition,
188 format!("Error converting input schema to String: {}", e),
189 )),
190 (Err(e1), Err(e2)) => Err(Status::new(
191 Code::FailedPrecondition,
192 format!(
193 "Error converting input and output schema to String: {} {}",
194 e1, e2
195 ),
196 )),
197 };
198
199 if tx.send(schema_resp).await.is_err() {
200 panic!();
202 }
203 }
204 });
205 Ok(Resp::new(RecvStream::new(rx)))
206 }
207
208 async fn initiate_query_protocol(
209 &self,
210 req: Req<Streaming<InitiateQueryProtocolReq>>,
211 ) -> QueryResult<Resp<Self::InitiateQueryProtocolStream>> {
212 let rx = req.into_inner();
213 let (tx, out_rx) = match self.curr_host {
215 Host::Loopback => mpsc::channel::<QueryResult<InitiateQueryProtocolResp>>(10),
216 _ => mpsc::channel::<QueryResult<InitiateQueryProtocolResp>>(100),
217 };
218
219 let cloned_plugin = self.plugin.clone();
220 let tx_clone = tx.clone();
221 tokio::spawn(async move {
222 let mut channel = HcSessionSocket::new(tx, rx);
223 if let Err(e) = channel.run(cloned_plugin).await {
224 error!("Channel error: {e}");
225 if !tx_clone.is_closed() {
226 if let Err(send_err) = tx_clone
227 .send(Err(tonic::Status::internal(format!("Session error: {e}"))))
228 .await
229 {
230 error!("Failed to send error through channel: {send_err}");
231 }
232 }
233 }
234 });
235
236 Ok(Resp::new(RecvStream::new(out_rx)))
237 }
238}