1use std::any::Any;
2use std::collections::BTreeMap;
3use std::future::Future;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use schemars::JsonSchema;
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use serde_json::Value;
12
13use crate::api::{IntoResponse, Request};
14use crate::catalog::{
15 Catalog, CatalogOperation, OperationAnnotations, schema_json, schema_parameters,
16};
17use crate::error::{Error, INTERNAL_ERROR_MESSAGE, Result};
18use crate::provider_server::OperationResult;
19
/// Metadata describing an operation that can be registered on a [`Router`].
///
/// `In` and `Out` are the operation's input and output types. They are only
/// tracked at the type level (no value of either is ever stored), so `Clone`
/// and `Debug` are implemented by hand below rather than derived: a derive
/// would add `In: Clone, Out: Clone` (resp. `Debug`) bounds that this metadata
/// does not need.
pub struct Operation<In, Out> {
    /// Operation identifier; [`Router::register`] trims it and rejects empty
    /// or duplicate ids.
    pub id: String,
    /// Method name; [`Operation::new`] defaults it to `"POST"`.
    pub method: String,
    /// Human-readable title (trimmed when registered).
    pub title: String,
    /// Human-readable description (trimmed when registered).
    pub description: String,
    /// Roles copied into the catalog entry's `allowed_roles`.
    pub allowed_roles: Vec<String>,
    /// Tags copied into the catalog entry.
    pub tags: Vec<String>,
    /// Read-only flag; when `true` it also sets the catalog entry's
    /// `read_only_hint` annotation.
    pub read_only: bool,
    /// Visibility flag, copied as-is into the catalog entry's `visible` field.
    pub visible: Option<bool>,
    // `fn() -> (In, Out)` keeps the marker covariant in `In`/`Out` (like the
    // plain tuple would) while making `Send`/`Sync` independent of them, since
    // no `In`/`Out` value is ever owned by this struct.
    _types: PhantomData<fn() -> (In, Out)>,
}

impl<In, Out> Clone for Operation<In, Out> {
    fn clone(&self) -> Self {
        Self {
            id: self.id.clone(),
            method: self.method.clone(),
            title: self.title.clone(),
            description: self.description.clone(),
            allowed_roles: self.allowed_roles.clone(),
            tags: self.tags.clone(),
            read_only: self.read_only,
            visible: self.visible,
            _types: PhantomData,
        }
    }
}

impl<In, Out> std::fmt::Debug for Operation<In, Out> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Operation")
            .field("id", &self.id)
            .field("method", &self.method)
            .field("title", &self.title)
            .field("description", &self.description)
            .field("allowed_roles", &self.allowed_roles)
            .field("tags", &self.tags)
            .field("read_only", &self.read_only)
            .field("visible", &self.visible)
            .finish_non_exhaustive()
    }
}
41
42impl<In, Out> Operation<In, Out>
43where
44 In: JsonSchema,
45 Out: JsonSchema,
46{
47 pub fn new(id: impl Into<String>) -> Self {
49 Self {
50 id: id.into(),
51 method: "POST".to_owned(),
52 title: String::new(),
53 description: String::new(),
54 allowed_roles: Vec::new(),
55 tags: Vec::new(),
56 read_only: false,
57 visible: None,
58 _types: PhantomData,
59 }
60 }
61
62 pub fn method(mut self, method: impl AsRef<str>) -> Self {
64 let method = method.as_ref().trim().to_ascii_uppercase();
65 if !method.is_empty() {
66 self.method = method;
67 }
68 self
69 }
70
71 pub fn title(mut self, title: impl Into<String>) -> Self {
73 self.title = title.into();
74 self
75 }
76
77 pub fn description(mut self, description: impl Into<String>) -> Self {
79 self.description = description.into();
80 self
81 }
82
83 pub fn allowed_roles(mut self, allowed_roles: impl Into<Vec<String>>) -> Self {
85 self.allowed_roles = allowed_roles.into();
86 self
87 }
88
89 pub fn tags(mut self, tags: impl Into<Vec<String>>) -> Self {
91 self.tags = tags.into();
92 self
93 }
94
95 pub fn read_only(mut self, read_only: bool) -> Self {
97 self.read_only = read_only;
98 self
99 }
100
101 pub fn visible(mut self, visible: bool) -> Self {
103 self.visible = Some(visible);
104 self
105 }
106}
107
108type Handler<P> = Arc<
109 dyn Fn(Arc<P>, Value, Request) -> Pin<Box<dyn Future<Output = OperationResult> + Send>>
110 + Send
111 + Sync,
112>;
113
/// Maps operation ids to type-erased handlers and keeps the [`Catalog`]
/// describing every registered operation. `P` is the provider type; each
/// handler receives it as `Arc<P>`.
pub struct Router<P> {
    /// Catalog metadata; `register` pushes one entry per registered operation.
    catalog: Catalog,
    /// Handlers keyed by the (trimmed) operation id.
    handlers: BTreeMap<String, Handler<P>>,
}
119
120impl<P> Clone for Router<P> {
121 fn clone(&self) -> Self {
122 Self {
123 catalog: self.catalog.clone(),
124 handlers: self.handlers.clone(),
125 }
126 }
127}
128
129impl<P> Default for Router<P> {
130 fn default() -> Self {
131 Self::new()
132 }
133}
134
135impl<P> Router<P> {
136 pub fn new() -> Self {
138 Self {
139 catalog: Catalog::default(),
140 handlers: BTreeMap::new(),
141 }
142 }
143
144 pub fn with_name(mut self, name: impl Into<String>) -> Self {
146 let name = name.into();
147 if !name.trim().is_empty() {
148 self.catalog.name = name;
149 }
150 self
151 }
152
153 pub fn catalog(&self) -> &Catalog {
155 &self.catalog
156 }
157
158 pub async fn execute(
160 &self,
161 provider: Arc<P>,
162 operation: &str,
163 params: Value,
164 request: Request,
165 ) -> OperationResult {
166 let Some(handler) = self.handlers.get(operation) else {
167 return OperationResult::error(404, "unknown operation");
168 };
169
170 match tokio::spawn(handler(provider, params, request)).await {
171 Ok(result) => result,
172 Err(error) => OperationResult::error(500, join_error_message(error)),
173 }
174 }
175}
176
177impl<P> Router<P>
178where
179 P: Send + Sync + 'static,
180{
181 pub fn register<In, Out, F, Fut, R, E>(
184 mut self,
185 operation: Operation<In, Out>,
186 handler: F,
187 ) -> Result<Self>
188 where
189 In: DeserializeOwned + JsonSchema + Send + 'static,
190 Out: Serialize + JsonSchema + Send + 'static,
191 F: Fn(Arc<P>, In, Request) -> Fut + Send + Sync + 'static,
192 Fut: Future<Output = std::result::Result<R, E>> + Send + 'static,
193 R: IntoResponse<Out> + Send + 'static,
194 E: Into<Error> + Send + 'static,
195 {
196 let operation_id = operation.id.trim();
197 if operation_id.is_empty() {
198 return Err(Error::bad_request("operation id is required"));
199 }
200 if self.handlers.contains_key(operation_id) {
201 return Err(Error::bad_request(format!(
202 "duplicate operation id {:?}",
203 operation_id
204 )));
205 }
206
207 let input_schema = schema_json::<In>()?;
208 let output_schema = schema_json::<Out>()?;
209 let parameters = schema_parameters(&input_schema);
210 let input_schema_str = serde_json::to_string(&input_schema).unwrap_or_default();
211 let output_schema_str = serde_json::to_string(&output_schema).unwrap_or_default();
212 let annotations = Some(OperationAnnotations {
213 read_only_hint: operation.read_only.then_some(true),
214 ..Default::default()
215 });
216 self.catalog.operations.push(CatalogOperation {
217 id: operation_id.to_owned(),
218 method: operation.method.clone(),
219 title: operation.title.trim().to_owned(),
220 description: operation.description.trim().to_owned(),
221 input_schema: input_schema_str,
222 output_schema: output_schema_str,
223 annotations,
224 parameters,
225 required_scopes: Vec::new(),
226 tags: operation.tags.clone(),
227 read_only: operation.read_only,
228 visible: operation.visible,
229 transport: String::new(),
230 allowed_roles: operation.allowed_roles.clone(),
231 });
232
233 let handler = Arc::new(handler);
234 let operation_id = operation_id.to_owned();
235 self.handlers.insert(
236 operation_id.clone(),
237 Arc::new(
238 move |provider: Arc<P>, raw_params: Value, request: Request| {
239 let handler = Arc::clone(&handler);
240 let operation_id = operation_id.clone();
241 Box::pin(async move {
242 let input = match decode_params::<In>(&operation_id, raw_params) {
243 Ok(input) => input,
244 Err(error) => return OperationResult::from_error(error),
245 };
246
247 match handler(provider, input, request).await {
248 Ok(response) => {
249 OperationResult::from_response(response.into_response())
250 }
251 Err(error) => OperationResult::from_error(error.into()),
252 }
253 })
254 },
255 ),
256 );
257
258 Ok(self)
259 }
260}
261
262fn decode_params<In: DeserializeOwned>(operation_id: &str, raw_params: Value) -> Result<In> {
263 let empty = is_empty_object(&raw_params);
264 match serde_json::from_value::<In>(raw_params) {
265 Ok(input) => Ok(input),
266 Err(error) if empty => serde_json::from_value::<In>(Value::Null).map_err(|_| {
267 Error::bad_request(format!("decode params for {:?}: {}", operation_id, error))
268 }),
269 Err(error) => Err(Error::bad_request(format!(
270 "decode params for {:?}: {}",
271 operation_id, error
272 ))),
273 }
274}
275
276fn is_empty_object(value: &Value) -> bool {
277 matches!(value, Value::Object(map) if map.is_empty())
278}
279
280fn join_error_message(error: tokio::task::JoinError) -> String {
281 if error.is_panic() {
282 let payload = error.try_into_panic().expect("panic payload");
283 log_panic_payload(payload);
284 } else {
285 eprintln!("internal error in Gestalt operation task: {error}");
286 }
287 INTERNAL_ERROR_MESSAGE.to_owned()
288}
289
/// Writes the message carried by a task's panic payload to stderr.
///
/// Only `&'static str` and `String` payloads, the types `panic!` normally
/// produces, have their text printed. Any other payload is reported without
/// its contents.
fn log_panic_payload(payload: Box<dyn Any + Send + 'static>) {
    let text: Option<&str> = match payload.downcast_ref::<&'static str>() {
        Some(literal) => Some(*literal),
        None => payload.downcast_ref::<String>().map(String::as_str),
    };
    if let Some(text) = text {
        eprintln!("panic in Gestalt operation: {}", text);
    } else {
        eprintln!("panic in Gestalt operation");
    }
}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Default)]
    struct TestProvider;

    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct Input {
        query: String,
    }

    #[derive(serde::Serialize, schemars::JsonSchema)]
    struct Output {
        query: String,
    }

    /// Handler shared by the registration tests: echoes the query back.
    async fn echo(
        _provider: Arc<TestProvider>,
        input: Input,
        _request: Request,
    ) -> std::result::Result<crate::Response<Output>, std::convert::Infallible> {
        Ok(crate::ok(Output { query: input.query }))
    }

    fn empty_object() -> Value {
        Value::Object(Default::default())
    }

    #[tokio::test]
    async fn router_execute_returns_not_found_for_unknown_operation() {
        let router = Router::<TestProvider>::new();
        let result = router
            .execute(
                Arc::new(TestProvider),
                "missing",
                empty_object(),
                Request::default(),
            )
            .await;
        assert_eq!(result.status, 404);
    }

    #[test]
    fn router_rejects_duplicate_ids() {
        let router = Router::<TestProvider>::new()
            .register(Operation::<Input, Output>::new("search"), echo)
            .expect("first registration");
        let result = router.register(Operation::<Input, Output>::new("search"), echo);

        match result {
            Ok(_) => panic!("duplicate id should fail"),
            Err(err) => assert!(err.message().contains("duplicate operation id")),
        }
    }

    #[test]
    fn router_rejects_duplicate_ids_after_trimming() {
        let router = Router::<TestProvider>::new()
            .register(Operation::<Input, Output>::new("search"), echo)
            .expect("first registration");
        let result = router.register(Operation::<Input, Output>::new("  search "), echo);

        match result {
            Ok(_) => panic!("whitespace must not make a duplicate id distinct"),
            Err(err) => assert!(err.message().contains("duplicate operation id")),
        }
    }

    #[test]
    fn router_rejects_blank_operation_id() {
        let result =
            Router::<TestProvider>::new().register(Operation::<Input, Output>::new("   "), echo);

        match result {
            Ok(_) => panic!("blank id should fail"),
            Err(err) => assert!(err.message().contains("operation id is required")),
        }
    }

    #[test]
    fn register_records_normalized_catalog_entry() {
        let router = Router::<TestProvider>::new()
            .register(
                Operation::<Input, Output>::new("  search  ")
                    .method(" get ")
                    .title("  Search  ")
                    .description("  Finds things.  ")
                    .read_only(true),
                echo,
            )
            .expect("registration");

        let entry = &router.catalog().operations[0];
        assert_eq!(entry.id, "search");
        assert_eq!(entry.method, "GET");
        assert_eq!(entry.title, "Search");
        assert_eq!(entry.description, "Finds things.");
        assert!(entry.read_only);
    }

    #[test]
    fn with_name_ignores_blank_names() {
        let router = Router::<TestProvider>::new().with_name("   ");
        assert_eq!(router.catalog().name, Catalog::default().name);
    }

    #[test]
    fn decode_params_retries_empty_object_as_null() {
        let decoded = decode_params::<Option<Input>>("search", empty_object())
            .expect("empty object should decode as null");
        assert!(decoded.is_none());
        assert!(decode_params::<()>("ping", empty_object()).is_ok());
    }

    #[test]
    fn decode_params_reports_operation_in_error() {
        match decode_params::<Input>("search", serde_json::json!({ "query": 1 })) {
            Ok(_) => panic!("non-string query should fail"),
            Err(err) => assert!(err.message().contains("decode params for \"search\"")),
        }
    }

    #[test]
    fn is_empty_object_only_matches_empty_objects() {
        assert!(is_empty_object(&empty_object()));
        assert!(!is_empty_object(&Value::Null));
        assert!(!is_empty_object(&serde_json::json!([])));
        assert!(!is_empty_object(&serde_json::json!({ "query": "x" })));
    }
}