Skip to main content

gestalt/
router.rs

1use std::any::Any;
2use std::collections::BTreeMap;
3use std::future::Future;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use schemars::JsonSchema;
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use serde_json::Value;
12
13use crate::api::{IntoResponse, Request};
14use crate::catalog::{
15    Catalog, CatalogOperation, OperationAnnotations, schema_json, schema_parameters,
16};
17use crate::error::{Error, INTERNAL_ERROR_MESSAGE, Result};
18use crate::provider_server::OperationResult;
19
/// Describes one statically declared executable operation.
///
/// `In` and `Out` are the request and response types of the operation. They are only
/// tracked at the type level (no values of either type are stored), so `Operation` is
/// `Clone` and `Debug` regardless of whether `In`/`Out` implement those traits.
pub struct Operation<In, Out> {
    /// Stable operation id (trimmed at registration; must be non-empty and unique
    /// within one router).
    pub id: String,
    /// HTTP method advertised in the catalog.
    pub method: String,
    /// Human-readable title.
    pub title: String,
    /// Human-readable description.
    pub description: String,
    /// Host-side roles allowed to invoke the operation.
    pub allowed_roles: Vec<String>,
    /// Free-form catalog tags.
    pub tags: Vec<String>,
    /// Whether the operation is read-only.
    pub read_only: bool,
    /// Optional catalog visibility override.
    pub visible: Option<bool>,
    // `fn() -> (In, Out)` records the type parameters without claiming ownership of
    // `In`/`Out` values, so the definition stays `Send + Sync` (and covariant) even
    // when the schema types themselves are not thread-safe.
    _types: PhantomData<fn() -> (In, Out)>,
}

// Hand-written instead of derived: `#[derive(Clone)]` would require `In: Clone` and
// `Out: Clone` even though no `In`/`Out` values are stored.
impl<In, Out> Clone for Operation<In, Out> {
    fn clone(&self) -> Self {
        Self {
            id: self.id.clone(),
            method: self.method.clone(),
            title: self.title.clone(),
            description: self.description.clone(),
            allowed_roles: self.allowed_roles.clone(),
            tags: self.tags.clone(),
            read_only: self.read_only,
            visible: self.visible,
            _types: PhantomData,
        }
    }
}

// Hand-written for the same reason as `Clone`: no `In: Debug`/`Out: Debug` bounds.
impl<In, Out> std::fmt::Debug for Operation<In, Out> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Operation")
            .field("id", &self.id)
            .field("method", &self.method)
            .field("title", &self.title)
            .field("description", &self.description)
            .field("allowed_roles", &self.allowed_roles)
            .field("tags", &self.tags)
            .field("read_only", &self.read_only)
            .field("visible", &self.visible)
            .finish_non_exhaustive()
    }
}
41
42impl<In, Out> Operation<In, Out>
43where
44    In: JsonSchema,
45    Out: JsonSchema,
46{
47    /// Creates a new operation definition with the supplied stable id.
48    pub fn new(id: impl Into<String>) -> Self {
49        Self {
50            id: id.into(),
51            method: "POST".to_owned(),
52            title: String::new(),
53            description: String::new(),
54            allowed_roles: Vec::new(),
55            tags: Vec::new(),
56            read_only: false,
57            visible: None,
58            _types: PhantomData,
59        }
60    }
61
62    /// Overrides the HTTP verb advertised in the derived catalog.
63    pub fn method(mut self, method: impl AsRef<str>) -> Self {
64        let method = method.as_ref().trim().to_ascii_uppercase();
65        if !method.is_empty() {
66            self.method = method;
67        }
68        self
69    }
70
71    /// Sets the human-readable title shown in the derived catalog.
72    pub fn title(mut self, title: impl Into<String>) -> Self {
73        self.title = title.into();
74        self
75    }
76
77    /// Sets the human-readable description shown in the derived catalog.
78    pub fn description(mut self, description: impl Into<String>) -> Self {
79        self.description = description.into();
80        self
81    }
82
83    /// Restricts the operation to the supplied host-side roles.
84    pub fn allowed_roles(mut self, allowed_roles: impl Into<Vec<String>>) -> Self {
85        self.allowed_roles = allowed_roles.into();
86        self
87    }
88
89    /// Attaches free-form tags to the derived catalog entry.
90    pub fn tags(mut self, tags: impl Into<Vec<String>>) -> Self {
91        self.tags = tags.into();
92        self
93    }
94
95    /// Marks the operation as read-only in the derived catalog.
96    pub fn read_only(mut self, read_only: bool) -> Self {
97        self.read_only = read_only;
98        self
99    }
100
101    /// Controls whether the operation should be visible to callers.
102    pub fn visible(mut self, visible: bool) -> Self {
103        self.visible = Some(visible);
104        self
105    }
106}
107
108type Handler<P> = Arc<
109    dyn Fn(Arc<P>, Value, Request) -> Pin<Box<dyn Future<Output = OperationResult> + Send>>
110        + Send
111        + Sync,
112>;
113
114/// Dispatches typed operations and exposes the corresponding static catalog.
115pub struct Router<P> {
116    catalog: Catalog,
117    handlers: BTreeMap<String, Handler<P>>,
118}
119
120impl<P> Clone for Router<P> {
121    fn clone(&self) -> Self {
122        Self {
123            catalog: self.catalog.clone(),
124            handlers: self.handlers.clone(),
125        }
126    }
127}
128
129impl<P> Default for Router<P> {
130    fn default() -> Self {
131        Self::new()
132    }
133}
134
135impl<P> Router<P> {
136    /// Creates an empty router.
137    pub fn new() -> Self {
138        Self {
139            catalog: Catalog::default(),
140            handlers: BTreeMap::new(),
141        }
142    }
143
144    /// Returns a copy of the router with the catalog name overridden.
145    pub fn with_name(mut self, name: impl Into<String>) -> Self {
146        let name = name.into();
147        if !name.trim().is_empty() {
148            self.catalog.name = name;
149        }
150        self
151    }
152
153    /// Returns the router's derived static catalog.
154    pub fn catalog(&self) -> &Catalog {
155        &self.catalog
156    }
157
158    /// Executes one named operation against provider.
159    pub async fn execute(
160        &self,
161        provider: Arc<P>,
162        operation: &str,
163        params: Value,
164        request: Request,
165    ) -> OperationResult {
166        let Some(handler) = self.handlers.get(operation) else {
167            return OperationResult::error(404, "unknown operation");
168        };
169
170        match tokio::spawn(handler(provider, params, request)).await {
171            Ok(result) => result,
172            Err(error) => OperationResult::error(500, join_error_message(error)),
173        }
174    }
175}
176
177impl<P> Router<P>
178where
179    P: Send + Sync + 'static,
180{
181    /// Registers a typed handler and derives the corresponding catalog entry
182    /// from its `serde` and `schemars` types.
183    pub fn register<In, Out, F, Fut, R, E>(
184        mut self,
185        operation: Operation<In, Out>,
186        handler: F,
187    ) -> Result<Self>
188    where
189        In: DeserializeOwned + JsonSchema + Send + 'static,
190        Out: Serialize + JsonSchema + Send + 'static,
191        F: Fn(Arc<P>, In, Request) -> Fut + Send + Sync + 'static,
192        Fut: Future<Output = std::result::Result<R, E>> + Send + 'static,
193        R: IntoResponse<Out> + Send + 'static,
194        E: Into<Error> + Send + 'static,
195    {
196        let operation_id = operation.id.trim();
197        if operation_id.is_empty() {
198            return Err(Error::bad_request("operation id is required"));
199        }
200        if self.handlers.contains_key(operation_id) {
201            return Err(Error::bad_request(format!(
202                "duplicate operation id {:?}",
203                operation_id
204            )));
205        }
206
207        let input_schema = schema_json::<In>()?;
208        let output_schema = schema_json::<Out>()?;
209        let parameters = schema_parameters(&input_schema);
210        let input_schema_str = serde_json::to_string(&input_schema).unwrap_or_default();
211        let output_schema_str = serde_json::to_string(&output_schema).unwrap_or_default();
212        let annotations = Some(OperationAnnotations {
213            read_only_hint: operation.read_only.then_some(true),
214            ..Default::default()
215        });
216        self.catalog.operations.push(CatalogOperation {
217            id: operation_id.to_owned(),
218            method: operation.method.clone(),
219            title: operation.title.trim().to_owned(),
220            description: operation.description.trim().to_owned(),
221            input_schema: input_schema_str,
222            output_schema: output_schema_str,
223            annotations,
224            parameters,
225            required_scopes: Vec::new(),
226            tags: operation.tags.clone(),
227            read_only: operation.read_only,
228            visible: operation.visible,
229            transport: String::new(),
230            allowed_roles: operation.allowed_roles.clone(),
231        });
232
233        let handler = Arc::new(handler);
234        let operation_id = operation_id.to_owned();
235        self.handlers.insert(
236            operation_id.clone(),
237            Arc::new(
238                move |provider: Arc<P>, raw_params: Value, request: Request| {
239                    let handler = Arc::clone(&handler);
240                    let operation_id = operation_id.clone();
241                    Box::pin(async move {
242                        let input = match decode_params::<In>(&operation_id, raw_params) {
243                            Ok(input) => input,
244                            Err(error) => return OperationResult::from_error(error),
245                        };
246
247                        match handler(provider, input, request).await {
248                            Ok(response) => {
249                                OperationResult::from_response(response.into_response())
250                            }
251                            Err(error) => OperationResult::from_error(error.into()),
252                        }
253                    })
254                },
255            ),
256        );
257
258        Ok(self)
259    }
260}
261
262fn decode_params<In: DeserializeOwned>(operation_id: &str, raw_params: Value) -> Result<In> {
263    let empty = is_empty_object(&raw_params);
264    match serde_json::from_value::<In>(raw_params) {
265        Ok(input) => Ok(input),
266        Err(error) if empty => serde_json::from_value::<In>(Value::Null).map_err(|_| {
267            Error::bad_request(format!("decode params for {:?}: {}", operation_id, error))
268        }),
269        Err(error) => Err(Error::bad_request(format!(
270            "decode params for {:?}: {}",
271            operation_id, error
272        ))),
273    }
274}
275
276fn is_empty_object(value: &Value) -> bool {
277    matches!(value, Value::Object(map) if map.is_empty())
278}
279
280fn join_error_message(error: tokio::task::JoinError) -> String {
281    if error.is_panic() {
282        let payload = error.try_into_panic().expect("panic payload");
283        log_panic_payload(payload);
284    } else {
285        eprintln!("internal error in Gestalt operation task: {error}");
286    }
287    INTERNAL_ERROR_MESSAGE.to_owned()
288}
289
/// Writes a panic payload to stderr, including its text when the payload is a
/// string (`&'static str` from `panic!("literal")` or `String` from formatted panics).
fn log_panic_payload(payload: Box<dyn Any + Send + 'static>) {
    let detail = match payload.downcast_ref::<&'static str>() {
        Some(text) => Some(*text),
        None => payload.downcast_ref::<String>().map(String::as_str),
    };
    match detail {
        Some(text) => eprintln!("panic in Gestalt operation: {}", text),
        None => eprintln!("panic in Gestalt operation"),
    }
}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Default)]
    struct TestProvider;

    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct Input {
        query: String,
    }

    #[derive(serde::Serialize, schemars::JsonSchema)]
    struct Output {
        query: String,
    }

    #[tokio::test]
    async fn router_execute_returns_not_found_for_unknown_operation() {
        let router = Router::<TestProvider>::new();
        let provider = Arc::new(TestProvider);
        let params = Value::Object(Default::default());

        let result = router
            .execute(provider, "missing", params, Request::default())
            .await;

        assert_eq!(result.status, 404);
    }

    #[test]
    fn router_rejects_duplicate_ids() {
        // Capture-free closures are `Copy`, so one echo handler serves both registrations.
        let echo = |_provider: Arc<TestProvider>, input: Input, _request: Request| async move {
            Ok::<crate::Response<Output>, std::convert::Infallible>(crate::ok(Output {
                query: input.query,
            }))
        };

        let router = Router::<TestProvider>::new()
            .register(Operation::<Input, Output>::new("search"), echo)
            .expect("first registration");
        let result = router.register(Operation::<Input, Output>::new("search"), echo);

        let Err(err) = result else {
            panic!("duplicate id should fail");
        };
        assert!(err.message().contains("duplicate operation id"));
    }
}