1use std::any::Any;
2use std::collections::BTreeMap;
3use std::future::Future;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use schemars::JsonSchema;
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use serde_json::Value;
12
13use crate::api::{IntoResponse, Request};
14use crate::catalog::{Catalog, CatalogOperation, schema_json, schema_parameters};
15use crate::error::{Error, INTERNAL_ERROR_MESSAGE, Result};
16use crate::provider_server::OperationResult;
17
/// Describes a single operation: its id, method, display text and catalog
/// flags, tagged with the handler's input (`In`) and output (`Out`) types.
///
/// `In` and `Out` are type markers only; no value of either type is stored.
pub struct Operation<In, Out> {
    /// Operation identifier; `Router::register` trims it and rejects blank or duplicate ids.
    pub id: String,
    /// Method string; `new` initialises it to `"POST"`.
    pub method: String,
    /// Display title; `Router::register` trims it before cataloguing.
    pub title: String,
    /// Longer description; `Router::register` trims it before cataloguing.
    pub description: String,
    /// Role names copied verbatim into the catalog entry.
    pub allowed_roles: Vec<String>,
    /// Tags copied verbatim into the catalog entry.
    pub tags: Vec<String>,
    /// When `true`, the catalog entry is flagged read-only and gets a read-only hint annotation.
    pub read_only: bool,
    /// Explicit visibility flag; `None` until `visible` is called.
    pub visible: Option<bool>,
    _types: PhantomData<(In, Out)>,
}

// `Clone` and `Debug` are implemented by hand instead of derived: a derive
// would add `Clone`/`Debug` bounds on `In` and `Out`, even though those types
// only ever appear inside `PhantomData` and never need to be cloned or printed.
impl<In, Out> Clone for Operation<In, Out> {
    fn clone(&self) -> Self {
        Self {
            id: self.id.clone(),
            method: self.method.clone(),
            title: self.title.clone(),
            description: self.description.clone(),
            allowed_roles: self.allowed_roles.clone(),
            tags: self.tags.clone(),
            read_only: self.read_only,
            visible: self.visible,
            _types: PhantomData,
        }
    }
}

impl<In, Out> std::fmt::Debug for Operation<In, Out> {
    // Lists the same fields, in the same order, as the derived implementation did.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Operation")
            .field("id", &self.id)
            .field("method", &self.method)
            .field("title", &self.title)
            .field("description", &self.description)
            .field("allowed_roles", &self.allowed_roles)
            .field("tags", &self.tags)
            .field("read_only", &self.read_only)
            .field("visible", &self.visible)
            .field("_types", &self._types)
            .finish()
    }
}
39
40impl<In, Out> Operation<In, Out>
41where
42 In: JsonSchema,
43 Out: JsonSchema,
44{
45 pub fn new(id: impl Into<String>) -> Self {
47 Self {
48 id: id.into(),
49 method: "POST".to_owned(),
50 title: String::new(),
51 description: String::new(),
52 allowed_roles: Vec::new(),
53 tags: Vec::new(),
54 read_only: false,
55 visible: None,
56 _types: PhantomData,
57 }
58 }
59
60 pub fn method(mut self, method: impl AsRef<str>) -> Self {
62 let method = method.as_ref().trim().to_ascii_uppercase();
63 if !method.is_empty() {
64 self.method = method;
65 }
66 self
67 }
68
69 pub fn title(mut self, title: impl Into<String>) -> Self {
71 self.title = title.into();
72 self
73 }
74
75 pub fn description(mut self, description: impl Into<String>) -> Self {
77 self.description = description.into();
78 self
79 }
80
81 pub fn allowed_roles(mut self, allowed_roles: impl Into<Vec<String>>) -> Self {
83 self.allowed_roles = allowed_roles.into();
84 self
85 }
86
87 pub fn tags(mut self, tags: impl Into<Vec<String>>) -> Self {
89 self.tags = tags.into();
90 self
91 }
92
93 pub fn read_only(mut self, read_only: bool) -> Self {
95 self.read_only = read_only;
96 self
97 }
98
99 pub fn visible(mut self, visible: bool) -> Self {
101 self.visible = Some(visible);
102 self
103 }
104}
105
106type Handler<P> = Arc<
107 dyn Fn(Arc<P>, Value, Request) -> Pin<Box<dyn Future<Output = OperationResult> + Send>>
108 + Send
109 + Sync,
110>;
111
/// Holds the operation catalog together with the handlers that execute each
/// catalogued operation for a provider of type `P`.
pub struct Router<P> {
    /// Catalog exposed through `catalog()`; `register` appends one entry per operation.
    catalog: Catalog,
    /// Type-erased handlers keyed by the trimmed operation id.
    handlers: BTreeMap<String, Handler<P>>,
}
117
118impl<P> Clone for Router<P> {
119 fn clone(&self) -> Self {
120 Self {
121 catalog: self.catalog.clone(),
122 handlers: self.handlers.clone(),
123 }
124 }
125}
126
127impl<P> Default for Router<P> {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133impl<P> Router<P> {
134 pub fn new() -> Self {
136 Self {
137 catalog: Catalog::default(),
138 handlers: BTreeMap::new(),
139 }
140 }
141
142 pub fn with_name(mut self, name: impl Into<String>) -> Self {
144 let name = name.into();
145 if !name.trim().is_empty() {
146 self.catalog.name = name;
147 }
148 self
149 }
150
151 pub fn catalog(&self) -> &Catalog {
153 &self.catalog
154 }
155
156 pub async fn execute(
158 &self,
159 provider: Arc<P>,
160 operation: &str,
161 params: Value,
162 request: Request,
163 ) -> OperationResult {
164 let Some(handler) = self.handlers.get(operation) else {
165 return OperationResult::error(404, "unknown operation");
166 };
167
168 match tokio::spawn(handler(provider, params, request)).await {
169 Ok(result) => result,
170 Err(error) => OperationResult::error(500, join_error_message(error)),
171 }
172 }
173}
174
175impl<P> Router<P>
176where
177 P: Send + Sync + 'static,
178{
179 pub fn register<In, Out, F, Fut, R, E>(
182 mut self,
183 operation: Operation<In, Out>,
184 handler: F,
185 ) -> Result<Self>
186 where
187 In: DeserializeOwned + JsonSchema + Send + 'static,
188 Out: Serialize + JsonSchema + Send + 'static,
189 F: Fn(Arc<P>, In, Request) -> Fut + Send + Sync + 'static,
190 Fut: Future<Output = std::result::Result<R, E>> + Send + 'static,
191 R: IntoResponse<Out> + Send + 'static,
192 E: Into<Error> + Send + 'static,
193 {
194 let operation_id = operation.id.trim();
195 if operation_id.is_empty() {
196 return Err(Error::bad_request("operation id is required"));
197 }
198 if self.handlers.contains_key(operation_id) {
199 return Err(Error::bad_request(format!(
200 "duplicate operation id {:?}",
201 operation_id
202 )));
203 }
204
205 let input_schema = schema_json::<In>()?;
206 let output_schema = schema_json::<Out>()?;
207 let parameters = schema_parameters(&input_schema);
208 let input_schema_str = serde_json::to_string(&input_schema).unwrap_or_default();
209 let output_schema_str = serde_json::to_string(&output_schema).unwrap_or_default();
210 let annotations = Some(crate::generated::v1::OperationAnnotations {
211 read_only_hint: operation.read_only.then_some(true),
212 ..Default::default()
213 });
214 self.catalog.operations.push(CatalogOperation {
215 id: operation_id.to_owned(),
216 method: operation.method.clone(),
217 title: operation.title.trim().to_owned(),
218 description: operation.description.trim().to_owned(),
219 input_schema: input_schema_str,
220 output_schema: output_schema_str,
221 annotations,
222 parameters,
223 required_scopes: Vec::new(),
224 tags: operation.tags.clone(),
225 read_only: operation.read_only,
226 visible: operation.visible,
227 transport: String::new(),
228 allowed_roles: operation.allowed_roles.clone(),
229 });
230
231 let handler = Arc::new(handler);
232 let operation_id = operation_id.to_owned();
233 self.handlers.insert(
234 operation_id.clone(),
235 Arc::new(
236 move |provider: Arc<P>, raw_params: Value, request: Request| {
237 let handler = Arc::clone(&handler);
238 let operation_id = operation_id.clone();
239 Box::pin(async move {
240 let input = match decode_params::<In>(&operation_id, raw_params) {
241 Ok(input) => input,
242 Err(error) => return OperationResult::from_error(error),
243 };
244
245 match handler(provider, input, request).await {
246 Ok(response) => {
247 OperationResult::from_response(response.into_response())
248 }
249 Err(error) => OperationResult::from_error(error.into()),
250 }
251 })
252 },
253 ),
254 );
255
256 Ok(self)
257 }
258}
259
260fn decode_params<In: DeserializeOwned>(operation_id: &str, raw_params: Value) -> Result<In> {
261 let empty = is_empty_object(&raw_params);
262 match serde_json::from_value::<In>(raw_params) {
263 Ok(input) => Ok(input),
264 Err(error) if empty => serde_json::from_value::<In>(Value::Null).map_err(|_| {
265 Error::bad_request(format!("decode params for {:?}: {}", operation_id, error))
266 }),
267 Err(error) => Err(Error::bad_request(format!(
268 "decode params for {:?}: {}",
269 operation_id, error
270 ))),
271 }
272}
273
274fn is_empty_object(value: &Value) -> bool {
275 matches!(value, Value::Object(map) if map.is_empty())
276}
277
278fn join_error_message(error: tokio::task::JoinError) -> String {
279 if error.is_panic() {
280 let payload = error.try_into_panic().expect("panic payload");
281 log_panic_payload(payload);
282 } else {
283 eprintln!("internal error in Gestalt operation task: {error}");
284 }
285 INTERNAL_ERROR_MESSAGE.to_owned()
286}
287
/// Writes a panic payload to stderr.
///
/// The panic text is included when the payload is a `&'static str` or a
/// `String`; any other payload type produces a message without detail.
fn log_panic_payload(payload: Box<dyn Any + Send + 'static>) {
    let message: Option<&str> = payload
        .downcast_ref::<&'static str>()
        .copied()
        .or_else(|| payload.downcast_ref::<String>().map(String::as_str));

    match message {
        Some(text) => eprintln!("panic in Gestalt operation: {}", text),
        None => eprintln!("panic in Gestalt operation"),
    }
}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Stateless provider; the handlers in these tests never read it.
    #[derive(Clone, Default)]
    struct TestProvider;

    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct Input {
        query: String,
    }

    #[derive(serde::Serialize, schemars::JsonSchema)]
    struct Output {
        query: String,
    }

    /// Handler shared by the registration tests: returns the input query unchanged.
    async fn echo_query(
        _provider: Arc<TestProvider>,
        input: Input,
        _request: Request,
    ) -> std::result::Result<crate::Response<Output>, std::convert::Infallible> {
        Ok(crate::ok(Output { query: input.query }))
    }

    #[tokio::test]
    async fn router_execute_returns_not_found_for_unknown_operation() {
        let router = Router::<TestProvider>::new();
        let provider = Arc::new(TestProvider);
        let params = Value::Object(Default::default());

        let result = router
            .execute(provider, "missing", params, Request::default())
            .await;

        assert_eq!(result.status, 404);
    }

    #[test]
    fn router_rejects_duplicate_ids() {
        let router = Router::<TestProvider>::new()
            .register(Operation::<Input, Output>::new("search"), echo_query)
            .expect("first registration");

        let second = router.register(Operation::<Input, Output>::new("search"), echo_query);

        let Err(err) = second else {
            panic!("duplicate id should fail");
        };
        assert!(err.message().contains("duplicate operation id"));
    }
}