pforge_runtime/
registry.rs1use crate::{Error, Handler, Result};
2use rustc_hash::FxHashMap;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
/// Heap-allocated, pinned, `Send` future that yields `T` and may borrow data
/// for the lifetime `'a`.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
8
9pub struct HandlerRegistry {
60 handlers: FxHashMap<String, Arc<dyn HandlerEntry>>,
61}
62
63trait HandlerEntry: Send + Sync {
64 fn dispatch(&self, params: &[u8]) -> BoxFuture<'static, Result<Vec<u8>>>;
66
67 fn input_schema(&self) -> schemars::schema::RootSchema;
69 fn output_schema(&self) -> schemars::schema::RootSchema;
70}
71
/// Wrapper that holds a concrete handler behind an `Arc`.
///
/// The struct itself is deliberately unbounded: the `Handler` requirement is
/// stated on the `impl` blocks that actually need it, so the bound does not
/// have to be repeated everywhere the type is merely named.
struct HandlerEntryImpl<H> {
    /// Shared handle to the wrapped handler; cloned into each dispatch future.
    handler: Arc<H>,
}
75
76impl<H: Handler> HandlerEntryImpl<H> {
77 fn new(handler: H) -> Self {
78 Self {
79 handler: Arc::new(handler),
80 }
81 }
82}
83
84impl<H> HandlerEntry for HandlerEntryImpl<H>
85where
86 H: Handler,
87 H::Input: 'static,
88 H::Output: 'static,
89{
90 fn dispatch(&self, params: &[u8]) -> BoxFuture<'static, Result<Vec<u8>>> {
91 let input: H::Input = match serde_json::from_slice(params) {
92 Ok(input) => input,
93 Err(e) => return Box::pin(async move { Err(e.into()) }),
94 };
95
96 let handler = self.handler.clone();
97 Box::pin(async move {
98 let output = handler.handle(input).await.map_err(Into::into)?;
99 serde_json::to_vec(&output).map_err(Into::into)
100 })
101 }
102
103 fn input_schema(&self) -> schemars::schema::RootSchema {
104 H::input_schema()
105 }
106
107 fn output_schema(&self) -> schemars::schema::RootSchema {
108 H::output_schema()
109 }
110}
111
112impl HandlerRegistry {
113 pub fn new() -> Self {
115 Self {
116 handlers: FxHashMap::default(),
117 }
118 }
119
120 pub fn register<H>(&mut self, name: impl Into<String>, handler: H)
122 where
123 H: Handler,
124 H::Input: 'static,
125 H::Output: 'static,
126 {
127 let entry = HandlerEntryImpl::new(handler);
128 self.handlers.insert(name.into(), Arc::new(entry));
129 }
130
131 pub fn has_handler(&self, name: &str) -> bool {
133 self.handlers.contains_key(name)
134 }
135
136 #[inline(always)]
138 pub async fn dispatch(&self, tool: &str, params: &[u8]) -> Result<Vec<u8>> {
139 match self.handlers.get(tool) {
140 Some(handler) => handler.dispatch(params).await,
141 None => Err(Error::ToolNotFound(tool.to_string())),
142 }
143 }
144
145 pub fn len(&self) -> usize {
147 self.handlers.len()
148 }
149
150 pub fn is_empty(&self) -> bool {
152 self.handlers.is_empty()
153 }
154
155 pub fn get_input_schema(&self, tool: &str) -> Option<schemars::schema::RootSchema> {
157 self.handlers.get(tool).map(|h| h.input_schema())
158 }
159
160 pub fn get_output_schema(&self, tool: &str) -> Option<schemars::schema::RootSchema> {
162 self.handlers.get(tool).map(|h| h.output_schema())
163 }
164}
165
166impl Default for HandlerRegistry {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    /// Input payload accepted by both test handlers.
    #[derive(Debug, Serialize, Deserialize, JsonSchema)]
    struct TestInput {
        value: i32,
    }

    /// Output payload produced by `TestHandler`.
    #[derive(Debug, Serialize, Deserialize, JsonSchema)]
    struct TestOutput {
        result: i32,
    }

    /// Handler that doubles the incoming value.
    struct TestHandler;

    #[async_trait]
    impl crate::Handler for TestHandler {
        type Input = TestInput;
        type Output = TestOutput;
        type Error = crate::Error;

        async fn handle(&self, input: Self::Input) -> Result<Self::Output> {
            let result = input.value * 2;
            Ok(TestOutput { result })
        }
    }

    /// Handler that always fails with `Error::Handler`.
    struct ErrorHandler;

    #[async_trait]
    impl crate::Handler for ErrorHandler {
        type Input = TestInput;
        type Output = TestOutput;
        type Error = crate::Error;

        async fn handle(&self, _input: Self::Input) -> Result<Self::Output> {
            Err(crate::Error::Handler(String::from("test error")))
        }
    }

    /// Builds a registry containing a single `TestHandler` registered as "test".
    fn registry_with_test_handler() -> HandlerRegistry {
        let mut registry = HandlerRegistry::new();
        registry.register("test", TestHandler);
        registry
    }

    /// JSON-encodes a `TestInput` carrying `value`.
    fn encode_input(value: i32) -> Vec<u8> {
        serde_json::to_vec(&TestInput { value }).unwrap()
    }

    #[tokio::test]
    async fn test_registry_new() {
        let registry = HandlerRegistry::new();
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
    }

    #[tokio::test]
    async fn test_registry_register() {
        let registry = registry_with_test_handler();

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
        assert!(registry.has_handler("test"));
        assert!(!registry.has_handler("nonexistent"));
    }

    #[tokio::test]
    async fn test_registry_dispatch() {
        let registry = registry_with_test_handler();
        let payload = encode_input(21);

        let response = registry.dispatch("test", &payload).await;
        assert!(response.is_ok());

        let output: TestOutput = serde_json::from_slice(&response.unwrap()).unwrap();
        assert_eq!(output.result, 42);
    }

    #[tokio::test]
    async fn test_registry_dispatch_tool_not_found() {
        let registry = HandlerRegistry::new();
        let payload = encode_input(21);

        let response = registry.dispatch("nonexistent", &payload).await;
        assert!(response.is_err());
        assert!(matches!(response, Err(crate::Error::ToolNotFound(_))));
    }

    #[tokio::test]
    async fn test_registry_dispatch_invalid_input() {
        let registry = registry_with_test_handler();

        // Well-formed JSON, but it lacks the `value` field `TestInput` requires.
        let invalid_input = b"{\"invalid\": \"json\"}";
        let response = registry.dispatch("test", invalid_input).await;
        assert!(response.is_err());
    }

    #[tokio::test]
    async fn test_registry_dispatch_handler_error() {
        let mut registry = HandlerRegistry::new();
        registry.register("error", ErrorHandler);
        let payload = encode_input(21);

        let response = registry.dispatch("error", &payload).await;
        assert!(response.is_err());
        assert!(matches!(response, Err(crate::Error::Handler(_))));
    }

    #[tokio::test]
    async fn test_registry_get_schemas() {
        let registry = registry_with_test_handler();

        assert!(registry.get_input_schema("test").is_some());
        assert!(registry.get_output_schema("test").is_some());
        assert!(registry.get_input_schema("nonexistent").is_none());
    }

    #[tokio::test]
    async fn test_registry_multiple_handlers() {
        let names = ["handler1", "handler2", "handler3"];

        let mut registry = HandlerRegistry::new();
        for name in names.iter() {
            registry.register(*name, TestHandler);
        }

        assert_eq!(registry.len(), 3);
        for name in names.iter() {
            assert!(registry.has_handler(name));
        }
    }

    #[tokio::test]
    async fn test_schema_not_default() {
        let registry = registry_with_test_handler();
        let default_json =
            serde_json::to_string(&schemars::schema::RootSchema::default()).unwrap();

        let input_schema = registry.get_input_schema("test").unwrap();
        let input_json = serde_json::to_string(&input_schema).unwrap();
        assert_ne!(
            input_json, default_json,
            "Input schema should not be Default::default()"
        );

        let output_schema = registry.get_output_schema("test").unwrap();
        let output_json = serde_json::to_string(&output_schema).unwrap();
        assert_ne!(
            output_json, default_json,
            "Output schema should not be Default::default()"
        );
    }

    #[tokio::test]
    async fn test_schema_properties() {
        let registry = registry_with_test_handler();

        let input_has_object = registry
            .get_input_schema("test")
            .unwrap()
            .schema
            .object
            .is_some();
        assert!(input_has_object, "Input schema should have object");

        let output_has_object = registry
            .get_output_schema("test")
            .unwrap()
            .schema
            .object
            .is_some();
        assert!(output_has_object, "Output schema should have object");
    }
}