1use std::any::Any;
2use std::collections::BTreeMap;
3use std::future::Future;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use schemars::JsonSchema;
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use serde_json::Value;
12
13use crate::api::{IntoResponse, Request};
14use crate::catalog::{Catalog, CatalogOperation, schema_json, schema_parameters};
15use crate::error::{Error, INTERNAL_ERROR_MESSAGE, Result};
16use crate::provider_server::OperationResult;
17
/// Declarative description of one provider operation, built fluently and
/// handed to `Router::register`.
///
/// `In` and `Out` are the request and response payload types. No values of
/// those types are stored; they exist only at the type level so the router can
/// derive JSON schemas for them at registration time.
pub struct Operation<In, Out> {
    /// Operation identifier; `Router::register` trims it and rejects blanks.
    pub id: String,
    /// Method name recorded in the catalog.
    pub method: String,
    /// Human-readable title (trimmed when registered).
    pub title: String,
    /// Longer description (trimmed when registered).
    pub description: String,
    /// Roles copied verbatim into the catalog entry.
    pub allowed_roles: Vec<String>,
    /// Free-form tags copied verbatim into the catalog entry.
    pub tags: Vec<String>,
    /// Whether the operation is declared side-effect free.
    pub read_only: bool,
    /// Explicit visibility; `None` leaves it unspecified.
    pub visible: Option<bool>,
    // `fn() -> (In, Out)` ties the payload types to the struct without claiming
    // ownership of them, so `Operation` is `Send + Sync` whatever `In`/`Out` are.
    _types: PhantomData<fn() -> (In, Out)>,
}

// Implemented by hand because `#[derive(Clone)]` would require `In: Clone` and
// `Out: Clone`, even though no payload values are ever stored.
impl<In, Out> Clone for Operation<In, Out> {
    fn clone(&self) -> Self {
        Self {
            id: self.id.clone(),
            method: self.method.clone(),
            title: self.title.clone(),
            description: self.description.clone(),
            allowed_roles: self.allowed_roles.clone(),
            tags: self.tags.clone(),
            read_only: self.read_only,
            visible: self.visible,
            _types: PhantomData,
        }
    }
}

// Implemented by hand for the same reason as `Clone`: deriving would demand
// `In: Debug` and `Out: Debug`.
impl<In, Out> std::fmt::Debug for Operation<In, Out> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Operation")
            .field("id", &self.id)
            .field("method", &self.method)
            .field("title", &self.title)
            .field("description", &self.description)
            .field("allowed_roles", &self.allowed_roles)
            .field("tags", &self.tags)
            .field("read_only", &self.read_only)
            .field("visible", &self.visible)
            .finish_non_exhaustive()
    }
}
30
31impl<In, Out> Operation<In, Out>
32where
33 In: JsonSchema,
34 Out: JsonSchema,
35{
36 pub fn new(id: impl Into<String>) -> Self {
37 Self {
38 id: id.into(),
39 method: "POST".to_owned(),
40 title: String::new(),
41 description: String::new(),
42 allowed_roles: Vec::new(),
43 tags: Vec::new(),
44 read_only: false,
45 visible: None,
46 _types: PhantomData,
47 }
48 }
49
50 pub fn method(mut self, method: impl AsRef<str>) -> Self {
51 let method = method.as_ref().trim().to_ascii_uppercase();
52 if !method.is_empty() {
53 self.method = method;
54 }
55 self
56 }
57
58 pub fn title(mut self, title: impl Into<String>) -> Self {
59 self.title = title.into();
60 self
61 }
62
63 pub fn description(mut self, description: impl Into<String>) -> Self {
64 self.description = description.into();
65 self
66 }
67
68 pub fn allowed_roles(mut self, allowed_roles: impl Into<Vec<String>>) -> Self {
69 self.allowed_roles = allowed_roles.into();
70 self
71 }
72
73 pub fn tags(mut self, tags: impl Into<Vec<String>>) -> Self {
74 self.tags = tags.into();
75 self
76 }
77
78 pub fn read_only(mut self, read_only: bool) -> Self {
79 self.read_only = read_only;
80 self
81 }
82
83 pub fn visible(mut self, visible: bool) -> Self {
84 self.visible = Some(visible);
85 self
86 }
87}
88
89type Handler<P> = Arc<
90 dyn Fn(Arc<P>, Value, Request) -> Pin<Box<dyn Future<Output = OperationResult> + Send>>
91 + Send
92 + Sync,
93>;
94
/// Registry of operations for a provider of type `P`.
///
/// Pairs the catalog describing every registered operation with the
/// type-erased handlers that execute them.
pub struct Router<P> {
    /// Catalog metadata: router name plus one entry per registered operation.
    catalog: Catalog,
    /// Handlers keyed by the (trimmed) operation id given to `register`.
    handlers: BTreeMap<String, Handler<P>>,
}
99
100impl<P> Clone for Router<P> {
101 fn clone(&self) -> Self {
102 Self {
103 catalog: self.catalog.clone(),
104 handlers: self.handlers.clone(),
105 }
106 }
107}
108
109impl<P> Default for Router<P> {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115impl<P> Router<P> {
116 pub fn new() -> Self {
117 Self {
118 catalog: Catalog::default(),
119 handlers: BTreeMap::new(),
120 }
121 }
122
123 pub fn with_name(mut self, name: impl Into<String>) -> Self {
124 let name = name.into();
125 if !name.trim().is_empty() {
126 self.catalog.name = name;
127 }
128 self
129 }
130
131 pub fn catalog(&self) -> &Catalog {
132 &self.catalog
133 }
134
135 pub async fn execute(
136 &self,
137 provider: Arc<P>,
138 operation: &str,
139 params: Value,
140 request: Request,
141 ) -> OperationResult {
142 let Some(handler) = self.handlers.get(operation) else {
143 return OperationResult::error(404, "unknown operation");
144 };
145
146 match tokio::spawn(handler(provider, params, request)).await {
147 Ok(result) => result,
148 Err(error) => OperationResult::error(500, join_error_message(error)),
149 }
150 }
151}
152
153impl<P> Router<P>
154where
155 P: Send + Sync + 'static,
156{
157 pub fn register<In, Out, F, Fut, R, E>(
158 mut self,
159 operation: Operation<In, Out>,
160 handler: F,
161 ) -> Result<Self>
162 where
163 In: DeserializeOwned + JsonSchema + Send + 'static,
164 Out: Serialize + JsonSchema + Send + 'static,
165 F: Fn(Arc<P>, In, Request) -> Fut + Send + Sync + 'static,
166 Fut: Future<Output = std::result::Result<R, E>> + Send + 'static,
167 R: IntoResponse<Out> + Send + 'static,
168 E: Into<Error> + Send + 'static,
169 {
170 let operation_id = operation.id.trim();
171 if operation_id.is_empty() {
172 return Err(Error::bad_request("operation id is required"));
173 }
174 if self.handlers.contains_key(operation_id) {
175 return Err(Error::bad_request(format!(
176 "duplicate operation id {:?}",
177 operation_id
178 )));
179 }
180
181 let input_schema = schema_json::<In>()?;
182 let output_schema = schema_json::<Out>()?;
183 let parameters = schema_parameters(&input_schema);
184 let input_schema_str = serde_json::to_string(&input_schema).unwrap_or_default();
185 let output_schema_str = serde_json::to_string(&output_schema).unwrap_or_default();
186 let annotations = Some(crate::generated::v1::OperationAnnotations {
187 read_only_hint: operation.read_only.then_some(true),
188 ..Default::default()
189 });
190 self.catalog.operations.push(CatalogOperation {
191 id: operation_id.to_owned(),
192 method: operation.method.clone(),
193 title: operation.title.trim().to_owned(),
194 description: operation.description.trim().to_owned(),
195 input_schema: input_schema_str,
196 output_schema: output_schema_str,
197 annotations,
198 parameters,
199 required_scopes: Vec::new(),
200 tags: operation.tags.clone(),
201 read_only: operation.read_only,
202 visible: operation.visible,
203 transport: String::new(),
204 allowed_roles: operation.allowed_roles.clone(),
205 });
206
207 let handler = Arc::new(handler);
208 let operation_id = operation_id.to_owned();
209 self.handlers.insert(
210 operation_id.clone(),
211 Arc::new(
212 move |provider: Arc<P>, raw_params: Value, request: Request| {
213 let handler = Arc::clone(&handler);
214 let operation_id = operation_id.clone();
215 Box::pin(async move {
216 let input = match decode_params::<In>(&operation_id, raw_params) {
217 Ok(input) => input,
218 Err(error) => return OperationResult::from_error(error),
219 };
220
221 match handler(provider, input, request).await {
222 Ok(response) => {
223 OperationResult::from_response(response.into_response())
224 }
225 Err(error) => OperationResult::from_error(error.into()),
226 }
227 })
228 },
229 ),
230 );
231
232 Ok(self)
233 }
234}
235
236fn decode_params<In: DeserializeOwned>(operation_id: &str, raw_params: Value) -> Result<In> {
237 let empty = is_empty_object(&raw_params);
238 match serde_json::from_value::<In>(raw_params) {
239 Ok(input) => Ok(input),
240 Err(error) if empty => serde_json::from_value::<In>(Value::Null).map_err(|_| {
241 Error::bad_request(format!("decode params for {:?}: {}", operation_id, error))
242 }),
243 Err(error) => Err(Error::bad_request(format!(
244 "decode params for {:?}: {}",
245 operation_id, error
246 ))),
247 }
248}
249
250fn is_empty_object(value: &Value) -> bool {
251 matches!(value, Value::Object(map) if map.is_empty())
252}
253
254fn join_error_message(error: tokio::task::JoinError) -> String {
255 if error.is_panic() {
256 let payload = error.try_into_panic().expect("panic payload");
257 log_panic_payload(payload);
258 } else {
259 eprintln!("internal error in Gestalt operation task: {error}");
260 }
261 INTERNAL_ERROR_MESSAGE.to_owned()
262}
263
/// Writes a panic payload to stderr, including its message when the payload
/// is a string.
///
/// Only `&'static str` and `String` payloads are rendered; any other payload
/// type is reported without a message.
fn log_panic_payload(payload: Box<dyn Any + Send + 'static>) {
    let detail: Option<&str> = match payload.downcast_ref::<&'static str>() {
        Some(text) => Some(*text),
        None => payload.downcast_ref::<String>().map(String::as_str),
    };

    match detail {
        Some(text) => eprintln!("panic in Gestalt operation: {}", text),
        None => eprintln!("panic in Gestalt operation"),
    }
}
273
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Default)]
    struct TestProvider;

    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct Input {
        query: String,
    }

    #[derive(serde::Serialize, schemars::JsonSchema)]
    struct Output {
        query: String,
    }

    /// Handler used by the registration tests: echoes the query back.
    async fn echo(
        _provider: Arc<TestProvider>,
        input: Input,
        _request: Request,
    ) -> std::result::Result<crate::Response<Output>, std::convert::Infallible> {
        Ok(crate::ok(Output { query: input.query }))
    }

    /// A fresh `search` operation description.
    fn search() -> Operation<Input, Output> {
        Operation::new("search")
    }

    #[tokio::test]
    async fn router_execute_returns_not_found_for_unknown_operation() {
        let router = Router::<TestProvider>::new();
        let provider = Arc::new(TestProvider);
        let params = Value::Object(Default::default());

        let result = router
            .execute(provider, "missing", params, Request::default())
            .await;

        assert_eq!(result.status, 404);
    }

    #[test]
    fn router_rejects_duplicate_ids() {
        let router = Router::<TestProvider>::new()
            .register(search(), echo)
            .expect("first registration");

        let Err(err) = router.register(search(), echo) else {
            panic!("duplicate id should fail");
        };
        assert!(err.message().contains("duplicate operation id"));
    }
}