forge_runtime/function/
registry.rs1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use forge_core::{
7 ForgeMutation, ForgeQuery, FunctionInfo, FunctionKind, MutationContext, QueryContext, Result,
8};
9use serde_json::Value;
10
11fn normalize_args(args: Value) -> Value {
16 let unwrapped = match &args {
18 Value::Object(map) if map.len() == 1 && map.contains_key("args") => {
19 map.get("args").cloned().unwrap_or(Value::Null)
20 }
21 _ => args,
22 };
23
24 match &unwrapped {
26 Value::Object(map) if map.is_empty() => Value::Null,
27 _ => unwrapped,
28 }
29}
30
31pub type BoxedQueryFn = Arc<
33 dyn Fn(&QueryContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
34 + Send
35 + Sync,
36>;
37
38pub type BoxedMutationFn = Arc<
39 dyn Fn(&MutationContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
40 + Send
41 + Sync,
42>;
43
44pub enum FunctionEntry {
46 Query {
47 info: FunctionInfo,
48 handler: BoxedQueryFn,
49 },
50 Mutation {
51 info: FunctionInfo,
52 handler: BoxedMutationFn,
53 },
54}
55
56impl FunctionEntry {
57 pub fn info(&self) -> &FunctionInfo {
58 match self {
59 FunctionEntry::Query { info, .. } => info,
60 FunctionEntry::Mutation { info, .. } => info,
61 }
62 }
63
64 pub fn kind(&self) -> FunctionKind {
65 match self {
66 FunctionEntry::Query { .. } => FunctionKind::Query,
67 FunctionEntry::Mutation { .. } => FunctionKind::Mutation,
68 }
69 }
70}
71
72#[derive(Clone)]
74pub struct FunctionRegistry {
75 functions: HashMap<String, FunctionEntry>,
76}
77
78impl Clone for FunctionEntry {
79 fn clone(&self) -> Self {
80 match self {
81 FunctionEntry::Query { info, handler } => FunctionEntry::Query {
82 info: info.clone(),
83 handler: Arc::clone(handler),
84 },
85 FunctionEntry::Mutation { info, handler } => FunctionEntry::Mutation {
86 info: info.clone(),
87 handler: Arc::clone(handler),
88 },
89 }
90 }
91}
92
93impl FunctionRegistry {
94 pub fn new() -> Self {
96 Self {
97 functions: HashMap::new(),
98 }
99 }
100
101 pub fn register_query<Q: ForgeQuery>(&mut self)
103 where
104 Q::Args: serde::de::DeserializeOwned + Send + 'static,
105 Q::Output: serde::Serialize + Send + 'static,
106 {
107 let info = Q::info();
108 let name = info.name.to_string();
109
110 let handler: BoxedQueryFn = Arc::new(move |ctx, args| {
111 Box::pin(async move {
112 let parsed_args: Q::Args = serde_json::from_value(normalize_args(args))
113 .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
114 let result = Q::execute(ctx, parsed_args).await?;
115 serde_json::to_value(result)
116 .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
117 })
118 });
119
120 self.functions
121 .insert(name, FunctionEntry::Query { info, handler });
122 }
123
124 pub fn register_mutation<M: ForgeMutation>(&mut self)
126 where
127 M::Args: serde::de::DeserializeOwned + Send + 'static,
128 M::Output: serde::Serialize + Send + 'static,
129 {
130 let info = M::info();
131 let name = info.name.to_string();
132
133 let handler: BoxedMutationFn = Arc::new(move |ctx, args| {
134 Box::pin(async move {
135 let parsed_args: M::Args = serde_json::from_value(normalize_args(args))
136 .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
137 let result = M::execute(ctx, parsed_args).await?;
138 serde_json::to_value(result)
139 .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
140 })
141 });
142
143 self.functions
144 .insert(name, FunctionEntry::Mutation { info, handler });
145 }
146
147 pub fn get(&self, name: &str) -> Option<&FunctionEntry> {
149 self.functions.get(name)
150 }
151
152 pub fn function_names(&self) -> impl Iterator<Item = &str> {
154 self.functions.keys().map(|s| s.as_str())
155 }
156
157 pub fn functions(&self) -> impl Iterator<Item = (&str, &FunctionEntry)> {
159 self.functions.iter().map(|(k, v)| (k.as_str(), v))
160 }
161
162 pub fn len(&self) -> usize {
164 self.functions.len()
165 }
166
167 pub fn is_empty(&self) -> bool {
169 self.functions.is_empty()
170 }
171
172 pub fn queries(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
174 self.functions.iter().filter_map(|(name, entry)| {
175 if let FunctionEntry::Query { info, .. } = entry {
176 Some((name.as_str(), info))
177 } else {
178 None
179 }
180 })
181 }
182
183 pub fn mutations(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
185 self.functions.iter().filter_map(|(name, entry)| {
186 if let FunctionEntry::Mutation { info, .. } = entry {
187 Some((name.as_str(), info))
188 } else {
189 None
190 }
191 })
192 }
193}
194
195impl Default for FunctionRegistry {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::FunctionRegistry;

    /// A freshly constructed registry holds nothing and resolves no names.
    #[test]
    fn test_empty_registry() {
        let empty = FunctionRegistry::new();

        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert!(empty.get("nonexistent").is_none());
    }
}