hipcheck_sdk/
server.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use crate::{
4	engine::HcSessionSocket,
5	error::{Error, Result},
6	Plugin, QuerySchema,
7};
8use hipcheck_common::{
9	proto::{
10		plugin_service_server::{PluginService, PluginServiceServer},
11		ConfigurationStatus, ExplainDefaultQueryRequest as ExplainDefaultQueryReq,
12		ExplainDefaultQueryResponse as ExplainDefaultQueryResp,
13		GetDefaultPolicyExpressionRequest as GetDefaultPolicyExpressionReq,
14		GetDefaultPolicyExpressionResponse as GetDefaultPolicyExpressionResp,
15		GetQuerySchemasRequest as GetQuerySchemasReq,
16		GetQuerySchemasResponse as GetQuerySchemasResp,
17		InitiateQueryProtocolRequest as InitiateQueryProtocolReq,
18		InitiateQueryProtocolResponse as InitiateQueryProtocolResp,
19		SetConfigurationRequest as SetConfigurationReq,
20		SetConfigurationResponse as SetConfigurationResp,
21	},
22	types::LogLevel,
23};
24use std::{
25	net::{Ipv4Addr, SocketAddr},
26	result::Result as StdResult,
27	sync::Arc,
28};
29use tokio::sync::mpsc;
30use tokio_stream::wrappers::ReceiverStream as RecvStream;
31use tonic::{transport::Server, Code, Request as Req, Response as Resp, Status, Streaming};
32use tracing::error;
33
/// The network address a plugin server binds its gRPC listener to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
	/// The IPv4 loopback address, `127.0.0.1`.
	Loopback,
	/// The IPv4 unspecified address, `0.0.0.0` (binds on all local interfaces).
	Any,
	/// Any other IPv4 address.
	Other(Ipv4Addr),
}

impl Host {
	/// Build the socket address for this host on the given `port`.
	fn to_socket_addr(&self, port: u16) -> SocketAddr {
		let ip = match self {
			Host::Loopback => Ipv4Addr::LOCALHOST,
			Host::Any => Ipv4Addr::UNSPECIFIED,
			Host::Other(ip) => *ip,
		};
		SocketAddr::new(ip.into(), port)
	}
}
53
54/// Runs the Hipcheck plugin protocol based on the user's implementation of the `Plugin` trait.
55///
56/// This struct implements the underlying gRPC protocol that is not exposed to the plugin author.
57pub struct PluginServer<P> {
58	plugin: Arc<P>,
59	curr_host: Host,
60}
61
62impl<P: Plugin> PluginServer<P> {
63	/// Create a new plugin server for the provided plugin.
64	pub fn register(plugin: P, log_level_opt: impl Into<Option<LogLevel>>) -> PluginServer<P> {
65		#[cfg(feature = "log_forwarding")]
66		{
67			let log_level = log_level_opt.into().unwrap_or(LogLevel::Error);
68			crate::init_tracing_logger(log_level);
69		}
70
71		PluginServer {
72			plugin: Arc::new(plugin),
73			curr_host: Host::Any, // default
74		}
75	}
76
77	/// Run the plugin server on the loopback address and provided port.
78	pub async fn listen_local(self, port: u16) -> Result<()> {
79		self.listen(Host::Loopback, port).await
80	}
81
82	/// Run the plugin server on the provided port.
83	pub async fn listen(mut self, host: Host, port: u16) -> Result<()> {
84		self.curr_host = host.clone();
85		let service = PluginServiceServer::new(self);
86		let host_addr = host.to_socket_addr(port);
87
88		Server::builder()
89			.add_service(service)
90			.serve(host_addr)
91			.await
92			.map_err(Error::FailedToStartServer)?;
93
94		Ok(())
95	}
96}
97
98/// The result of running a query, where the error is of the type `tonic::Status`.
99pub type QueryResult<T> = StdResult<T, Status>;
100
101#[tonic::async_trait]
102impl<P: Plugin> PluginService for PluginServer<P> {
103	type GetQuerySchemasStream = RecvStream<QueryResult<GetQuerySchemasResp>>;
104	type InitiateQueryProtocolStream = RecvStream<QueryResult<InitiateQueryProtocolResp>>;
105
106	async fn set_configuration(
107		&self,
108		req: Req<SetConfigurationReq>,
109	) -> QueryResult<Resp<SetConfigurationResp>> {
110		let config = serde_json::from_str(&req.into_inner().configuration)
111			.map_err(|e| Status::from_error(Box::new(e)))?;
112		match self.plugin.set_config(config) {
113			Ok(_) => Ok(Resp::new(SetConfigurationResp {
114				status: ConfigurationStatus::None as i32,
115				message: "".to_owned(),
116			})),
117			Err(e) => Ok(Resp::new(e.into())),
118		}
119	}
120
121	async fn get_default_policy_expression(
122		&self,
123		_req: Req<GetDefaultPolicyExpressionReq>,
124	) -> QueryResult<Resp<GetDefaultPolicyExpressionResp>> {
125		// The request is empty, so we do nothing.
126		match self.plugin.default_policy_expr() {
127			Ok(policy_expression) => Ok(Resp::new(GetDefaultPolicyExpressionResp {
128				policy_expression,
129			})),
130			Err(e) => Err(Status::new(
131				tonic::Code::NotFound,
132				format!(
133					"Error determining default policy expr for {}/{}: {}",
134					P::PUBLISHER,
135					P::NAME,
136					e
137				),
138			)),
139		}
140	}
141
142	async fn explain_default_query(
143		&self,
144		_req: Req<ExplainDefaultQueryReq>,
145	) -> QueryResult<Resp<ExplainDefaultQueryResp>> {
146		match self.plugin.explain_default_query() {
147			Ok(explanation) => Ok(Resp::new(ExplainDefaultQueryResp {
148				explanation: explanation
149					.unwrap_or_else(|| "No default query explanation provided".to_owned()),
150			})),
151			Err(e) => Err(Status::new(
152				tonic::Code::NotFound,
153				format!(
154					"Error explaining default query expr for {}/{}: {}",
155					P::PUBLISHER,
156					P::NAME,
157					e
158				),
159			)),
160		}
161	}
162
163	async fn get_query_schemas(
164		&self,
165		_req: Req<GetQuerySchemasReq>,
166	) -> QueryResult<Resp<Self::GetQuerySchemasStream>> {
167		// Ignore the input, it's empty.
168		let query_schemas = self.plugin.schemas().collect::<Vec<QuerySchema>>();
169		// TODO: does this need to be configurable?
170		let (tx, rx) = mpsc::channel(10);
171		tokio::spawn(async move {
172			for x in query_schemas {
173				let input_schema = serde_json::to_string(&x.input_schema);
174				let output_schema = serde_json::to_string(&x.output_schema);
175
176				let schema_resp = match (input_schema, output_schema) {
177					(Ok(input_schema), Ok(output_schema)) => Ok(GetQuerySchemasResp {
178						query_name: x.query_name.to_string(),
179						key_schema: input_schema,
180						output_schema,
181					}),
182					(Ok(_), Err(e)) => Err(Status::new(
183						Code::FailedPrecondition,
184						format!("Error converting output schema to String: {}", e),
185					)),
186					(Err(_), Ok(e)) => Err(Status::new(
187						Code::FailedPrecondition,
188						format!("Error converting input schema to String: {}", e),
189					)),
190					(Err(e1), Err(e2)) => Err(Status::new(
191						Code::FailedPrecondition,
192						format!(
193							"Error converting input and output schema to String: {} {}",
194							e1, e2
195						),
196					)),
197				};
198
199				if tx.send(schema_resp).await.is_err() {
200					// TODO: handle this?
201					panic!();
202				}
203			}
204		});
205		Ok(Resp::new(RecvStream::new(rx)))
206	}
207
208	async fn initiate_query_protocol(
209		&self,
210		req: Req<Streaming<InitiateQueryProtocolReq>>,
211	) -> QueryResult<Resp<Self::InitiateQueryProtocolStream>> {
212		let rx = req.into_inner();
213		// TODO: - make channel size configurable
214		let (tx, out_rx) = match self.curr_host {
215			Host::Loopback => mpsc::channel::<QueryResult<InitiateQueryProtocolResp>>(10),
216			_ => mpsc::channel::<QueryResult<InitiateQueryProtocolResp>>(100),
217		};
218
219		let cloned_plugin = self.plugin.clone();
220		let tx_clone = tx.clone();
221		tokio::spawn(async move {
222			let mut channel = HcSessionSocket::new(tx, rx);
223			if let Err(e) = channel.run(cloned_plugin).await {
224				error!("Channel error: {e}");
225				if !tx_clone.is_closed() {
226					if let Err(send_err) = tx_clone
227						.send(Err(tonic::Status::internal(format!("Session error: {e}"))))
228						.await
229					{
230						error!("Failed to send error through channel: {send_err}");
231					}
232				}
233			}
234		});
235
236		Ok(Resp::new(RecvStream::new(out_rx)))
237	}
238}