Skip to main content

gestalt/
router.rs

1use std::any::Any;
2use std::collections::BTreeMap;
3use std::future::Future;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use schemars::JsonSchema;
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use serde_json::Value;
12
13use crate::api::{IntoResponse, Request};
14use crate::catalog::{Catalog, CatalogOperation, schema_json, schema_parameters};
15use crate::error::{Error, INTERNAL_ERROR_MESSAGE, Result};
16use crate::provider_server::OperationResult;
17
/// Declarative description of a typed provider operation, built with [`Operation::new`]
/// and the builder methods, then passed to [`Router::register`].
///
/// `In` and `Out` are the request and response types. They are only used at the type level
/// (schema generation and decoding during registration) and no value of either is stored.
///
/// `Clone` and `Debug` are implemented by hand: `#[derive]` would add `In: Clone` /
/// `Out: Clone` (and `Debug`) bounds, making operations over ordinary schema types
/// impossible to clone or print even though no `In`/`Out` value is held.
pub struct Operation<In, Out> {
    /// Operation id; `Router::register` trims it and rejects it when empty or duplicated.
    pub id: String,
    /// Method name; `Operation::new` sets `"POST"`.
    pub method: String,
    /// Human-readable title (trimmed when registered).
    pub title: String,
    /// Human-readable description (trimmed when registered).
    pub description: String,
    /// Roles copied into the catalog entry.
    pub allowed_roles: Vec<String>,
    /// Tags copied into the catalog entry.
    pub tags: Vec<String>,
    /// When `true`, the catalog entry is marked read-only and gets a `read_only_hint`.
    pub read_only: bool,
    /// Explicit visibility; `None` leaves it unspecified.
    pub visible: Option<bool>,
    // `fn() -> (In, Out)` records the type parameters without implying ownership of an
    // `In`/`Out` value, so `Send`/`Sync` of `Operation` do not depend on them.
    _types: PhantomData<fn() -> (In, Out)>,
}

impl<In, Out> Clone for Operation<In, Out> {
    fn clone(&self) -> Self {
        Self {
            id: self.id.clone(),
            method: self.method.clone(),
            title: self.title.clone(),
            description: self.description.clone(),
            allowed_roles: self.allowed_roles.clone(),
            tags: self.tags.clone(),
            read_only: self.read_only,
            visible: self.visible,
            _types: PhantomData,
        }
    }
}

impl<In, Out> std::fmt::Debug for Operation<In, Out> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Operation")
            .field("id", &self.id)
            .field("method", &self.method)
            .field("title", &self.title)
            .field("description", &self.description)
            .field("allowed_roles", &self.allowed_roles)
            .field("tags", &self.tags)
            .field("read_only", &self.read_only)
            .field("visible", &self.visible)
            .finish_non_exhaustive()
    }
}
30
31impl<In, Out> Operation<In, Out>
32where
33    In: JsonSchema,
34    Out: JsonSchema,
35{
36    pub fn new(id: impl Into<String>) -> Self {
37        Self {
38            id: id.into(),
39            method: "POST".to_owned(),
40            title: String::new(),
41            description: String::new(),
42            allowed_roles: Vec::new(),
43            tags: Vec::new(),
44            read_only: false,
45            visible: None,
46            _types: PhantomData,
47        }
48    }
49
50    pub fn method(mut self, method: impl AsRef<str>) -> Self {
51        let method = method.as_ref().trim().to_ascii_uppercase();
52        if !method.is_empty() {
53            self.method = method;
54        }
55        self
56    }
57
58    pub fn title(mut self, title: impl Into<String>) -> Self {
59        self.title = title.into();
60        self
61    }
62
63    pub fn description(mut self, description: impl Into<String>) -> Self {
64        self.description = description.into();
65        self
66    }
67
68    pub fn allowed_roles(mut self, allowed_roles: impl Into<Vec<String>>) -> Self {
69        self.allowed_roles = allowed_roles.into();
70        self
71    }
72
73    pub fn tags(mut self, tags: impl Into<Vec<String>>) -> Self {
74        self.tags = tags.into();
75        self
76    }
77
78    pub fn read_only(mut self, read_only: bool) -> Self {
79        self.read_only = read_only;
80        self
81    }
82
83    pub fn visible(mut self, visible: bool) -> Self {
84        self.visible = Some(visible);
85        self
86    }
87}
88
89type Handler<P> = Arc<
90    dyn Fn(Arc<P>, Value, Request) -> Pin<Box<dyn Future<Output = OperationResult> + Send>>
91        + Send
92        + Sync,
93>;
94
95pub struct Router<P> {
96    catalog: Catalog,
97    handlers: BTreeMap<String, Handler<P>>,
98}
99
100impl<P> Clone for Router<P> {
101    fn clone(&self) -> Self {
102        Self {
103            catalog: self.catalog.clone(),
104            handlers: self.handlers.clone(),
105        }
106    }
107}
108
109impl<P> Default for Router<P> {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115impl<P> Router<P> {
116    pub fn new() -> Self {
117        Self {
118            catalog: Catalog::default(),
119            handlers: BTreeMap::new(),
120        }
121    }
122
123    pub fn with_name(mut self, name: impl Into<String>) -> Self {
124        let name = name.into();
125        if !name.trim().is_empty() {
126            self.catalog.name = name;
127        }
128        self
129    }
130
131    pub fn catalog(&self) -> &Catalog {
132        &self.catalog
133    }
134
135    pub async fn execute(
136        &self,
137        provider: Arc<P>,
138        operation: &str,
139        params: Value,
140        request: Request,
141    ) -> OperationResult {
142        let Some(handler) = self.handlers.get(operation) else {
143            return OperationResult::error(404, "unknown operation");
144        };
145
146        match tokio::spawn(handler(provider, params, request)).await {
147            Ok(result) => result,
148            Err(error) => OperationResult::error(500, join_error_message(error)),
149        }
150    }
151}
152
153impl<P> Router<P>
154where
155    P: Send + Sync + 'static,
156{
157    pub fn register<In, Out, F, Fut, R, E>(
158        mut self,
159        operation: Operation<In, Out>,
160        handler: F,
161    ) -> Result<Self>
162    where
163        In: DeserializeOwned + JsonSchema + Send + 'static,
164        Out: Serialize + JsonSchema + Send + 'static,
165        F: Fn(Arc<P>, In, Request) -> Fut + Send + Sync + 'static,
166        Fut: Future<Output = std::result::Result<R, E>> + Send + 'static,
167        R: IntoResponse<Out> + Send + 'static,
168        E: Into<Error> + Send + 'static,
169    {
170        let operation_id = operation.id.trim();
171        if operation_id.is_empty() {
172            return Err(Error::bad_request("operation id is required"));
173        }
174        if self.handlers.contains_key(operation_id) {
175            return Err(Error::bad_request(format!(
176                "duplicate operation id {:?}",
177                operation_id
178            )));
179        }
180
181        let input_schema = schema_json::<In>()?;
182        let output_schema = schema_json::<Out>()?;
183        let parameters = schema_parameters(&input_schema);
184        let input_schema_str = serde_json::to_string(&input_schema).unwrap_or_default();
185        let output_schema_str = serde_json::to_string(&output_schema).unwrap_or_default();
186        let annotations = Some(crate::generated::v1::OperationAnnotations {
187            read_only_hint: operation.read_only.then_some(true),
188            ..Default::default()
189        });
190        self.catalog.operations.push(CatalogOperation {
191            id: operation_id.to_owned(),
192            method: operation.method.clone(),
193            title: operation.title.trim().to_owned(),
194            description: operation.description.trim().to_owned(),
195            input_schema: input_schema_str,
196            output_schema: output_schema_str,
197            annotations,
198            parameters,
199            required_scopes: Vec::new(),
200            tags: operation.tags.clone(),
201            read_only: operation.read_only,
202            visible: operation.visible,
203            transport: String::new(),
204            allowed_roles: operation.allowed_roles.clone(),
205        });
206
207        let handler = Arc::new(handler);
208        let operation_id = operation_id.to_owned();
209        self.handlers.insert(
210            operation_id.clone(),
211            Arc::new(
212                move |provider: Arc<P>, raw_params: Value, request: Request| {
213                    let handler = Arc::clone(&handler);
214                    let operation_id = operation_id.clone();
215                    Box::pin(async move {
216                        let input = match decode_params::<In>(&operation_id, raw_params) {
217                            Ok(input) => input,
218                            Err(error) => return OperationResult::from_error(error),
219                        };
220
221                        match handler(provider, input, request).await {
222                            Ok(response) => {
223                                OperationResult::from_response(response.into_response())
224                            }
225                            Err(error) => OperationResult::from_error(error.into()),
226                        }
227                    })
228                },
229            ),
230        );
231
232        Ok(self)
233    }
234}
235
236fn decode_params<In: DeserializeOwned>(operation_id: &str, raw_params: Value) -> Result<In> {
237    let empty = is_empty_object(&raw_params);
238    match serde_json::from_value::<In>(raw_params) {
239        Ok(input) => Ok(input),
240        Err(error) if empty => serde_json::from_value::<In>(Value::Null).map_err(|_| {
241            Error::bad_request(format!("decode params for {:?}: {}", operation_id, error))
242        }),
243        Err(error) => Err(Error::bad_request(format!(
244            "decode params for {:?}: {}",
245            operation_id, error
246        ))),
247    }
248}
249
250fn is_empty_object(value: &Value) -> bool {
251    matches!(value, Value::Object(map) if map.is_empty())
252}
253
254fn join_error_message(error: tokio::task::JoinError) -> String {
255    if error.is_panic() {
256        let payload = error.try_into_panic().expect("panic payload");
257        log_panic_payload(payload);
258    } else {
259        eprintln!("internal error in Gestalt operation task: {error}");
260    }
261    INTERNAL_ERROR_MESSAGE.to_owned()
262}
263
/// Writes a panic payload to stderr, including its text when the payload is a string.
///
/// Payloads of type `&'static str` or `String` are printed with their message. Any other
/// payload type produces a message-less line.
fn log_panic_payload(payload: Box<dyn Any + Send + 'static>) {
    let message: Option<&str> = match payload.downcast_ref::<&'static str>() {
        Some(text) => Some(*text),
        None => payload.downcast_ref::<String>().map(String::as_str),
    };
    match message {
        Some(text) => eprintln!("panic in Gestalt operation: {}", text),
        None => eprintln!("panic in Gestalt operation"),
    }
}
273
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;

    /// Provider stand-in; the handlers under test never read it.
    #[derive(Clone, Default)]
    struct TestProvider;

    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct Input {
        query: String,
    }

    #[derive(serde::Serialize, schemars::JsonSchema)]
    struct Output {
        query: String,
    }

    /// Echo handler shared by the registration tests.
    async fn echo(
        _provider: Arc<TestProvider>,
        input: Input,
        _request: Request,
    ) -> std::result::Result<crate::Response<Output>, Infallible> {
        Ok(crate::ok(Output { query: input.query }))
    }

    #[tokio::test]
    async fn router_execute_returns_not_found_for_unknown_operation() {
        let provider = Arc::new(TestProvider);
        let empty_params = Value::Object(serde_json::Map::new());
        let result = Router::<TestProvider>::new()
            .execute(provider, "missing", empty_params, Request::default())
            .await;
        assert_eq!(result.status, 404);
    }

    #[test]
    fn router_rejects_duplicate_ids() {
        let router = Router::<TestProvider>::new()
            .register(Operation::<Input, Output>::new("search"), echo)
            .expect("first registration");

        let Err(err) = router.register(Operation::<Input, Output>::new("search"), echo) else {
            panic!("duplicate id should fail");
        };
        assert!(err.message().contains("duplicate operation id"));
    }
}