// forge_runtime/function/registry.rs
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use forge_core::{
    ActionContext, ForgeAction, ForgeMutation, ForgeQuery, FunctionInfo, FunctionKind,
    MutationContext, QueryContext, Result,
};
use serde_json::Value;

12pub type BoxedQueryFn = Arc<
14 dyn Fn(&QueryContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
15 + Send
16 + Sync,
17>;
18
19pub type BoxedMutationFn = Arc<
20 dyn Fn(&MutationContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
21 + Send
22 + Sync,
23>;
24
25pub type BoxedActionFn = Arc<
26 dyn Fn(&ActionContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
27 + Send
28 + Sync,
29>;
30
31pub enum FunctionEntry {
33 Query {
34 info: FunctionInfo,
35 handler: BoxedQueryFn,
36 },
37 Mutation {
38 info: FunctionInfo,
39 handler: BoxedMutationFn,
40 },
41 Action {
42 info: FunctionInfo,
43 handler: BoxedActionFn,
44 },
45}
46
47impl FunctionEntry {
48 pub fn info(&self) -> &FunctionInfo {
49 match self {
50 FunctionEntry::Query { info, .. } => info,
51 FunctionEntry::Mutation { info, .. } => info,
52 FunctionEntry::Action { info, .. } => info,
53 }
54 }
55
56 pub fn kind(&self) -> FunctionKind {
57 match self {
58 FunctionEntry::Query { .. } => FunctionKind::Query,
59 FunctionEntry::Mutation { .. } => FunctionKind::Mutation,
60 FunctionEntry::Action { .. } => FunctionKind::Action,
61 }
62 }
63}
64
65#[derive(Clone)]
67pub struct FunctionRegistry {
68 functions: HashMap<String, FunctionEntry>,
69}
70
71impl Clone for FunctionEntry {
72 fn clone(&self) -> Self {
73 match self {
74 FunctionEntry::Query { info, handler } => FunctionEntry::Query {
75 info: info.clone(),
76 handler: Arc::clone(handler),
77 },
78 FunctionEntry::Mutation { info, handler } => FunctionEntry::Mutation {
79 info: info.clone(),
80 handler: Arc::clone(handler),
81 },
82 FunctionEntry::Action { info, handler } => FunctionEntry::Action {
83 info: info.clone(),
84 handler: Arc::clone(handler),
85 },
86 }
87 }
88}
89
90impl FunctionRegistry {
91 pub fn new() -> Self {
93 Self {
94 functions: HashMap::new(),
95 }
96 }
97
98 pub fn register_query<Q: ForgeQuery>(&mut self)
100 where
101 Q::Args: serde::de::DeserializeOwned + Send + 'static,
102 Q::Output: serde::Serialize + Send + 'static,
103 {
104 let info = Q::info();
105 let name = info.name.to_string();
106
107 let handler: BoxedQueryFn = Arc::new(move |ctx, args| {
108 Box::pin(async move {
109 let parsed_args: Q::Args = serde_json::from_value(args)
110 .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
111 let result = Q::execute(ctx, parsed_args).await?;
112 serde_json::to_value(result)
113 .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
114 })
115 });
116
117 self.functions
118 .insert(name, FunctionEntry::Query { info, handler });
119 }
120
121 pub fn register_mutation<M: ForgeMutation>(&mut self)
123 where
124 M::Args: serde::de::DeserializeOwned + Send + 'static,
125 M::Output: serde::Serialize + Send + 'static,
126 {
127 let info = M::info();
128 let name = info.name.to_string();
129
130 let handler: BoxedMutationFn = Arc::new(move |ctx, args| {
131 Box::pin(async move {
132 let parsed_args: M::Args = serde_json::from_value(args)
133 .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
134 let result = M::execute(ctx, parsed_args).await?;
135 serde_json::to_value(result)
136 .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
137 })
138 });
139
140 self.functions
141 .insert(name, FunctionEntry::Mutation { info, handler });
142 }
143
144 pub fn register_action<A: ForgeAction>(&mut self)
146 where
147 A::Args: serde::de::DeserializeOwned + Send + 'static,
148 A::Output: serde::Serialize + Send + 'static,
149 {
150 let info = A::info();
151 let name = info.name.to_string();
152
153 let handler: BoxedActionFn = Arc::new(move |ctx, args| {
154 Box::pin(async move {
155 let parsed_args: A::Args = serde_json::from_value(args)
156 .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
157 let result = A::execute(ctx, parsed_args).await?;
158 serde_json::to_value(result)
159 .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
160 })
161 });
162
163 self.functions
164 .insert(name, FunctionEntry::Action { info, handler });
165 }
166
167 pub fn get(&self, name: &str) -> Option<&FunctionEntry> {
169 self.functions.get(name)
170 }
171
172 pub fn function_names(&self) -> impl Iterator<Item = &str> {
174 self.functions.keys().map(|s| s.as_str())
175 }
176
177 pub fn functions(&self) -> impl Iterator<Item = (&str, &FunctionEntry)> {
179 self.functions.iter().map(|(k, v)| (k.as_str(), v))
180 }
181
182 pub fn len(&self) -> usize {
184 self.functions.len()
185 }
186
187 pub fn is_empty(&self) -> bool {
189 self.functions.is_empty()
190 }
191
192 pub fn queries(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
194 self.functions.iter().filter_map(|(name, entry)| {
195 if let FunctionEntry::Query { info, .. } = entry {
196 Some((name.as_str(), info))
197 } else {
198 None
199 }
200 })
201 }
202
203 pub fn mutations(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
205 self.functions.iter().filter_map(|(name, entry)| {
206 if let FunctionEntry::Mutation { info, .. } = entry {
207 Some((name.as_str(), info))
208 } else {
209 None
210 }
211 })
212 }
213
214 pub fn actions(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
216 self.functions.iter().filter_map(|(name, entry)| {
217 if let FunctionEntry::Action { info, .. } = entry {
218 Some((name.as_str(), info))
219 } else {
220 None
221 }
222 })
223 }
224}
225
226impl Default for FunctionRegistry {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_registry() {
        let registry = FunctionRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.get("nonexistent").is_none());
    }

    /// Every read-only accessor on an empty registry yields nothing.
    #[test]
    fn test_empty_registry_iterators_yield_nothing() {
        let registry = FunctionRegistry::new();
        assert_eq!(registry.function_names().count(), 0);
        assert_eq!(registry.functions().count(), 0);
        assert_eq!(registry.queries().count(), 0);
        assert_eq!(registry.mutations().count(), 0);
        assert_eq!(registry.actions().count(), 0);
    }

    /// `Default` and `Clone` both produce an empty registry from an empty one.
    #[test]
    fn test_default_and_clone_are_empty() {
        let registry = FunctionRegistry::default();
        assert!(registry.is_empty());
        let cloned = registry.clone();
        assert!(cloned.is_empty());
        assert_eq!(cloned.len(), registry.len());
    }
}