forge_runtime/function/
registry.rs1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use forge_core::{
7 ForgeMutation, ForgeQuery, FunctionInfo, FunctionKind, MutationContext, QueryContext, Result,
8};
9use serde_json::Value;
10
11fn normalize_args(args: Value) -> Value {
15 match &args {
16 Value::Object(map) if map.len() == 1 => {
17 if map.contains_key("args") {
18 map.get("args").cloned().unwrap_or(Value::Null)
19 } else if map.contains_key("input") {
20 map.get("input").cloned().unwrap_or(Value::Null)
21 } else {
22 args
23 }
24 }
25 _ => args,
26 }
27}
28
29pub type BoxedQueryFn = Arc<
31 dyn Fn(&QueryContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
32 + Send
33 + Sync,
34>;
35
36pub type BoxedMutationFn = Arc<
37 dyn Fn(&MutationContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
38 + Send
39 + Sync,
40>;
41
42pub enum FunctionEntry {
44 Query {
45 info: FunctionInfo,
46 handler: BoxedQueryFn,
47 },
48 Mutation {
49 info: FunctionInfo,
50 handler: BoxedMutationFn,
51 },
52}
53
54impl FunctionEntry {
55 pub fn info(&self) -> &FunctionInfo {
56 match self {
57 FunctionEntry::Query { info, .. } => info,
58 FunctionEntry::Mutation { info, .. } => info,
59 }
60 }
61
62 pub fn kind(&self) -> FunctionKind {
63 match self {
64 FunctionEntry::Query { .. } => FunctionKind::Query,
65 FunctionEntry::Mutation { .. } => FunctionKind::Mutation,
66 }
67 }
68}
69
70#[derive(Clone)]
72pub struct FunctionRegistry {
73 functions: HashMap<String, FunctionEntry>,
74}
75
76impl Clone for FunctionEntry {
77 fn clone(&self) -> Self {
78 match self {
79 FunctionEntry::Query { info, handler } => FunctionEntry::Query {
80 info: info.clone(),
81 handler: Arc::clone(handler),
82 },
83 FunctionEntry::Mutation { info, handler } => FunctionEntry::Mutation {
84 info: info.clone(),
85 handler: Arc::clone(handler),
86 },
87 }
88 }
89}
90
91impl FunctionRegistry {
92 pub fn new() -> Self {
94 Self {
95 functions: HashMap::new(),
96 }
97 }
98
99 pub fn register_query<Q: ForgeQuery>(&mut self)
101 where
102 Q::Args: serde::de::DeserializeOwned + Send + 'static,
103 Q::Output: serde::Serialize + Send + 'static,
104 {
105 let info = Q::info();
106 let name = info.name.to_string();
107
108 let handler: BoxedQueryFn = Arc::new(move |ctx, args| {
109 Box::pin(async move {
110 let parsed_args: Q::Args = serde_json::from_value(normalize_args(args))
111 .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
112 let result = Q::execute(ctx, parsed_args).await?;
113 serde_json::to_value(result)
114 .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
115 })
116 });
117
118 self.functions
119 .insert(name, FunctionEntry::Query { info, handler });
120 }
121
122 pub fn register_mutation<M: ForgeMutation>(&mut self)
124 where
125 M::Args: serde::de::DeserializeOwned + Send + 'static,
126 M::Output: serde::Serialize + Send + 'static,
127 {
128 let info = M::info();
129 let name = info.name.to_string();
130
131 let handler: BoxedMutationFn = Arc::new(move |ctx, args| {
132 Box::pin(async move {
133 let parsed_args: M::Args = serde_json::from_value(normalize_args(args))
134 .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
135 let result = M::execute(ctx, parsed_args).await?;
136 serde_json::to_value(result)
137 .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
138 })
139 });
140
141 self.functions
142 .insert(name, FunctionEntry::Mutation { info, handler });
143 }
144
145 pub fn get(&self, name: &str) -> Option<&FunctionEntry> {
147 self.functions.get(name)
148 }
149
150 pub fn function_names(&self) -> impl Iterator<Item = &str> {
152 self.functions.keys().map(|s| s.as_str())
153 }
154
155 pub fn functions(&self) -> impl Iterator<Item = (&str, &FunctionEntry)> {
157 self.functions.iter().map(|(k, v)| (k.as_str(), v))
158 }
159
160 pub fn len(&self) -> usize {
162 self.functions.len()
163 }
164
165 pub fn is_empty(&self) -> bool {
167 self.functions.is_empty()
168 }
169
170 pub fn queries(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
172 self.functions.iter().filter_map(|(name, entry)| {
173 if let FunctionEntry::Query { info, .. } = entry {
174 Some((name.as_str(), info))
175 } else {
176 None
177 }
178 })
179 }
180
181 pub fn mutations(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
183 self.functions.iter().filter_map(|(name, entry)| {
184 if let FunctionEntry::Mutation { info, .. } = entry {
185 Some((name.as_str(), info))
186 } else {
187 None
188 }
189 })
190 }
191}
192
193impl Default for FunctionRegistry {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed registry holds nothing and finds nothing.
    #[test]
    fn test_empty_registry() {
        let empty = FunctionRegistry::new();
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
        assert!(empty.get("nonexistent").is_none());
    }
}