Skip to main content

gestalt/
router.rs

1use std::any::Any;
2use std::collections::BTreeMap;
3use std::future::Future;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use schemars::JsonSchema;
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use serde_json::Value;
12
13use crate::api::{IntoResponse, Request};
14use crate::catalog::{Catalog, CatalogOperation, schema_json, schema_parameters};
15use crate::error::{Error, INTERNAL_ERROR_MESSAGE, Result};
16use crate::provider_server::OperationResult;
17
/// Describes one statically declared executable operation.
///
/// `In` and `Out` are the handler's input and output types. They exist only at
/// the type level and no values of them are stored. `Clone` and `Debug` are
/// therefore implemented by hand, so neither requires `In`/`Out` to implement
/// those traits (a derive would add those bounds).
pub struct Operation<In, Out> {
    /// Stable operation id.
    pub id: String,
    /// HTTP method advertised in the catalog.
    pub method: String,
    /// Human-readable title.
    pub title: String,
    /// Human-readable description.
    pub description: String,
    /// Host-side roles allowed to invoke the operation.
    pub allowed_roles: Vec<String>,
    /// Free-form catalog tags.
    pub tags: Vec<String>,
    /// Whether the operation is read-only.
    pub read_only: bool,
    /// Optional catalog visibility override.
    pub visible: Option<bool>,
    // `fn() -> (In, Out)` marks the type parameters without claiming to own
    // values of them, so `Operation` is `Send + Sync` whatever `In`/`Out` are
    // (and stays covariant in both, like the previous `PhantomData<(In, Out)>`).
    _types: PhantomData<fn() -> (In, Out)>,
}

impl<In, Out> Clone for Operation<In, Out> {
    fn clone(&self) -> Self {
        Self {
            id: self.id.clone(),
            method: self.method.clone(),
            title: self.title.clone(),
            description: self.description.clone(),
            allowed_roles: self.allowed_roles.clone(),
            tags: self.tags.clone(),
            read_only: self.read_only,
            visible: self.visible,
            _types: PhantomData,
        }
    }
}

impl<In, Out> std::fmt::Debug for Operation<In, Out> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Operation")
            .field("id", &self.id)
            .field("method", &self.method)
            .field("title", &self.title)
            .field("description", &self.description)
            .field("allowed_roles", &self.allowed_roles)
            .field("tags", &self.tags)
            .field("read_only", &self.read_only)
            .field("visible", &self.visible)
            .finish_non_exhaustive()
    }
}
39
40impl<In, Out> Operation<In, Out>
41where
42    In: JsonSchema,
43    Out: JsonSchema,
44{
45    /// Creates a new operation definition with the supplied stable id.
46    pub fn new(id: impl Into<String>) -> Self {
47        Self {
48            id: id.into(),
49            method: "POST".to_owned(),
50            title: String::new(),
51            description: String::new(),
52            allowed_roles: Vec::new(),
53            tags: Vec::new(),
54            read_only: false,
55            visible: None,
56            _types: PhantomData,
57        }
58    }
59
60    /// Overrides the HTTP verb advertised in the derived catalog.
61    pub fn method(mut self, method: impl AsRef<str>) -> Self {
62        let method = method.as_ref().trim().to_ascii_uppercase();
63        if !method.is_empty() {
64            self.method = method;
65        }
66        self
67    }
68
69    /// Sets the human-readable title shown in the derived catalog.
70    pub fn title(mut self, title: impl Into<String>) -> Self {
71        self.title = title.into();
72        self
73    }
74
75    /// Sets the human-readable description shown in the derived catalog.
76    pub fn description(mut self, description: impl Into<String>) -> Self {
77        self.description = description.into();
78        self
79    }
80
81    /// Restricts the operation to the supplied host-side roles.
82    pub fn allowed_roles(mut self, allowed_roles: impl Into<Vec<String>>) -> Self {
83        self.allowed_roles = allowed_roles.into();
84        self
85    }
86
87    /// Attaches free-form tags to the derived catalog entry.
88    pub fn tags(mut self, tags: impl Into<Vec<String>>) -> Self {
89        self.tags = tags.into();
90        self
91    }
92
93    /// Marks the operation as read-only in the derived catalog.
94    pub fn read_only(mut self, read_only: bool) -> Self {
95        self.read_only = read_only;
96        self
97    }
98
99    /// Controls whether the operation should be visible to callers.
100    pub fn visible(mut self, visible: bool) -> Self {
101        self.visible = Some(visible);
102        self
103    }
104}
105
106type Handler<P> = Arc<
107    dyn Fn(Arc<P>, Value, Request) -> Pin<Box<dyn Future<Output = OperationResult> + Send>>
108        + Send
109        + Sync,
110>;
111
112/// Dispatches typed operations and exposes the corresponding static catalog.
113pub struct Router<P> {
114    catalog: Catalog,
115    handlers: BTreeMap<String, Handler<P>>,
116}
117
118impl<P> Clone for Router<P> {
119    fn clone(&self) -> Self {
120        Self {
121            catalog: self.catalog.clone(),
122            handlers: self.handlers.clone(),
123        }
124    }
125}
126
127impl<P> Default for Router<P> {
128    fn default() -> Self {
129        Self::new()
130    }
131}
132
133impl<P> Router<P> {
134    /// Creates an empty router.
135    pub fn new() -> Self {
136        Self {
137            catalog: Catalog::default(),
138            handlers: BTreeMap::new(),
139        }
140    }
141
142    /// Returns a copy of the router with the catalog name overridden.
143    pub fn with_name(mut self, name: impl Into<String>) -> Self {
144        let name = name.into();
145        if !name.trim().is_empty() {
146            self.catalog.name = name;
147        }
148        self
149    }
150
151    /// Returns the router's derived static catalog.
152    pub fn catalog(&self) -> &Catalog {
153        &self.catalog
154    }
155
156    /// Executes one named operation against provider.
157    pub async fn execute(
158        &self,
159        provider: Arc<P>,
160        operation: &str,
161        params: Value,
162        request: Request,
163    ) -> OperationResult {
164        let Some(handler) = self.handlers.get(operation) else {
165            return OperationResult::error(404, "unknown operation");
166        };
167
168        match tokio::spawn(handler(provider, params, request)).await {
169            Ok(result) => result,
170            Err(error) => OperationResult::error(500, join_error_message(error)),
171        }
172    }
173}
174
175impl<P> Router<P>
176where
177    P: Send + Sync + 'static,
178{
179    /// Registers a typed handler and derives the corresponding catalog entry
180    /// from its `serde` and `schemars` types.
181    pub fn register<In, Out, F, Fut, R, E>(
182        mut self,
183        operation: Operation<In, Out>,
184        handler: F,
185    ) -> Result<Self>
186    where
187        In: DeserializeOwned + JsonSchema + Send + 'static,
188        Out: Serialize + JsonSchema + Send + 'static,
189        F: Fn(Arc<P>, In, Request) -> Fut + Send + Sync + 'static,
190        Fut: Future<Output = std::result::Result<R, E>> + Send + 'static,
191        R: IntoResponse<Out> + Send + 'static,
192        E: Into<Error> + Send + 'static,
193    {
194        let operation_id = operation.id.trim();
195        if operation_id.is_empty() {
196            return Err(Error::bad_request("operation id is required"));
197        }
198        if self.handlers.contains_key(operation_id) {
199            return Err(Error::bad_request(format!(
200                "duplicate operation id {:?}",
201                operation_id
202            )));
203        }
204
205        let input_schema = schema_json::<In>()?;
206        let output_schema = schema_json::<Out>()?;
207        let parameters = schema_parameters(&input_schema);
208        let input_schema_str = serde_json::to_string(&input_schema).unwrap_or_default();
209        let output_schema_str = serde_json::to_string(&output_schema).unwrap_or_default();
210        let annotations = Some(crate::generated::v1::OperationAnnotations {
211            read_only_hint: operation.read_only.then_some(true),
212            ..Default::default()
213        });
214        self.catalog.operations.push(CatalogOperation {
215            id: operation_id.to_owned(),
216            method: operation.method.clone(),
217            title: operation.title.trim().to_owned(),
218            description: operation.description.trim().to_owned(),
219            input_schema: input_schema_str,
220            output_schema: output_schema_str,
221            annotations,
222            parameters,
223            required_scopes: Vec::new(),
224            tags: operation.tags.clone(),
225            read_only: operation.read_only,
226            visible: operation.visible,
227            transport: String::new(),
228            allowed_roles: operation.allowed_roles.clone(),
229        });
230
231        let handler = Arc::new(handler);
232        let operation_id = operation_id.to_owned();
233        self.handlers.insert(
234            operation_id.clone(),
235            Arc::new(
236                move |provider: Arc<P>, raw_params: Value, request: Request| {
237                    let handler = Arc::clone(&handler);
238                    let operation_id = operation_id.clone();
239                    Box::pin(async move {
240                        let input = match decode_params::<In>(&operation_id, raw_params) {
241                            Ok(input) => input,
242                            Err(error) => return OperationResult::from_error(error),
243                        };
244
245                        match handler(provider, input, request).await {
246                            Ok(response) => {
247                                OperationResult::from_response(response.into_response())
248                            }
249                            Err(error) => OperationResult::from_error(error.into()),
250                        }
251                    })
252                },
253            ),
254        );
255
256        Ok(self)
257    }
258}
259
260fn decode_params<In: DeserializeOwned>(operation_id: &str, raw_params: Value) -> Result<In> {
261    let empty = is_empty_object(&raw_params);
262    match serde_json::from_value::<In>(raw_params) {
263        Ok(input) => Ok(input),
264        Err(error) if empty => serde_json::from_value::<In>(Value::Null).map_err(|_| {
265            Error::bad_request(format!("decode params for {:?}: {}", operation_id, error))
266        }),
267        Err(error) => Err(Error::bad_request(format!(
268            "decode params for {:?}: {}",
269            operation_id, error
270        ))),
271    }
272}
273
274fn is_empty_object(value: &Value) -> bool {
275    matches!(value, Value::Object(map) if map.is_empty())
276}
277
278fn join_error_message(error: tokio::task::JoinError) -> String {
279    if error.is_panic() {
280        let payload = error.try_into_panic().expect("panic payload");
281        log_panic_payload(payload);
282    } else {
283        eprintln!("internal error in Gestalt operation task: {error}");
284    }
285    INTERNAL_ERROR_MESSAGE.to_owned()
286}
287
/// Writes a panic payload to stderr.
///
/// `panic!` payloads are normally a `&'static str` (literal message) or a
/// `String` (formatted message). Any other payload type is reported without
/// its contents.
fn log_panic_payload(payload: Box<dyn Any + Send + 'static>) {
    let detail: Option<&str> = if let Some(literal) = payload.downcast_ref::<&'static str>() {
        Some(*literal)
    } else if let Some(owned) = payload.downcast_ref::<String>() {
        Some(owned.as_str())
    } else {
        None
    };

    match detail {
        Some(text) => eprintln!("panic in Gestalt operation: {}", text),
        None => eprintln!("panic in Gestalt operation"),
    }
}
297
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Default)]
    struct TestProvider;

    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct Input {
        query: String,
    }

    #[derive(serde::Serialize, schemars::JsonSchema)]
    struct Output {
        query: String,
    }

    /// Echoes the query back. It panics when the query is `"panic"`, so the
    /// router's panic handling can be exercised.
    async fn echo(
        _provider: Arc<TestProvider>,
        input: Input,
        _request: Request,
    ) -> std::result::Result<crate::Response<Output>, std::convert::Infallible> {
        if input.query == "panic" {
            panic!("handler panicked on purpose");
        }
        Ok(crate::ok(Output { query: input.query }))
    }

    /// Builds a router with `echo` registered under `id`.
    fn echo_router(id: &str) -> Router<TestProvider> {
        Router::<TestProvider>::new()
            .register(Operation::<Input, Output>::new(id), echo)
            .expect("registration")
    }

    #[tokio::test]
    async fn router_execute_returns_not_found_for_unknown_operation() {
        let router = Router::<TestProvider>::new();
        let result = router
            .execute(
                Arc::new(TestProvider),
                "missing",
                Value::Object(Default::default()),
                Request::default(),
            )
            .await;
        assert_eq!(result.status, 404);
    }

    #[tokio::test]
    async fn router_execute_reports_handler_panic_as_internal_error() {
        let router = echo_router("search");
        let result = router
            .execute(
                Arc::new(TestProvider),
                "search",
                serde_json::json!({ "query": "panic" }),
                Request::default(),
            )
            .await;
        assert_eq!(result.status, 500);
    }

    #[test]
    fn router_rejects_duplicate_ids() {
        let router = echo_router("search");
        let result = router.register(Operation::<Input, Output>::new("search"), echo);

        match result {
            Ok(_) => panic!("duplicate id should fail"),
            Err(err) => assert!(err.message().contains("duplicate operation id")),
        }
    }

    #[test]
    fn router_rejects_blank_ids() {
        let result =
            Router::<TestProvider>::new().register(Operation::<Input, Output>::new("   "), echo);

        match result {
            Ok(_) => panic!("blank id should fail"),
            Err(err) => assert!(err.message().contains("operation id is required")),
        }
    }

    #[test]
    fn register_trims_id_and_records_catalog_entry() {
        let router = echo_router("  search  ");
        let operations = &router.catalog().operations;
        assert_eq!(operations.len(), 1);
        assert_eq!(operations[0].id, "search");
        assert_eq!(operations[0].method, "POST");
    }

    #[test]
    fn decode_params_retries_empty_object_as_null() {
        decode_params::<()>("noop", Value::Object(Default::default()))
            .expect("`{}` should decode as `()`");
        let missing: Option<Input> = decode_params("search", Value::Object(Default::default()))
            .expect("`{}` should decode as `None`");
        assert!(missing.is_none());
    }

    #[test]
    fn decode_params_error_names_the_operation() {
        let error = match decode_params::<Input>("search", serde_json::json!({ "query": 7 })) {
            Ok(_) => panic!("a numeric query should not decode"),
            Err(error) => error,
        };
        assert!(error.message().contains("decode params for \"search\""));
    }
}