forge_runtime/function/
registry.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use forge_core::{
7    ActionContext, ForgeAction, ForgeMutation, ForgeQuery, FunctionInfo, FunctionKind,
8    MutationContext, QueryContext, Result,
9};
10use serde_json::Value;
11
/// Shared, type-erased query handler: called with a `&QueryContext` and JSON
/// arguments, it returns a pinned, boxed `Send` future resolving to `Result<Value>`.
///
/// Despite the name, the handler itself is held in an `Arc` (only the returned
/// future is boxed) and must be `Send + Sync`. The future is bounded by the
/// elided lifetime `'_`, so it may borrow from the context it was given.
pub type BoxedQueryFn = Arc<
    dyn Fn(&QueryContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
        + Send
        + Sync,
>;
18
/// Shared, type-erased mutation handler: called with a `&MutationContext` and
/// JSON arguments, it returns a pinned, boxed `Send` future resolving to
/// `Result<Value>`.
///
/// The handler is held in an `Arc` and must be `Send + Sync`; the returned future
/// is bounded by the elided lifetime `'_`, so it may borrow from the context.
pub type BoxedMutationFn = Arc<
    dyn Fn(&MutationContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
        + Send
        + Sync,
>;
24
/// Shared, type-erased action handler: called with an `&ActionContext` and JSON
/// arguments, it returns a pinned, boxed `Send` future resolving to
/// `Result<Value>`.
///
/// The handler is held in an `Arc` and must be `Send + Sync`; the returned future
/// is bounded by the elided lifetime `'_`, so it may borrow from the context.
pub type BoxedActionFn = Arc<
    dyn Fn(&ActionContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
        + Send
        + Sync,
>;
30
/// Entry in the function registry.
///
/// Each variant pairs a function's `FunctionInfo` with the type-erased handler
/// matching its context type (query, mutation, or action).
pub enum FunctionEntry {
    /// A registered query.
    Query {
        /// Metadata describing the function.
        info: FunctionInfo,
        /// JSON-in / JSON-out handler invoked with a `&QueryContext`.
        handler: BoxedQueryFn,
    },
    /// A registered mutation.
    Mutation {
        /// Metadata describing the function.
        info: FunctionInfo,
        /// JSON-in / JSON-out handler invoked with a `&MutationContext`.
        handler: BoxedMutationFn,
    },
    /// A registered action.
    Action {
        /// Metadata describing the function.
        info: FunctionInfo,
        /// JSON-in / JSON-out handler invoked with an `&ActionContext`.
        handler: BoxedActionFn,
    },
}
46
47impl FunctionEntry {
48    pub fn info(&self) -> &FunctionInfo {
49        match self {
50            FunctionEntry::Query { info, .. } => info,
51            FunctionEntry::Mutation { info, .. } => info,
52            FunctionEntry::Action { info, .. } => info,
53        }
54    }
55
56    pub fn kind(&self) -> FunctionKind {
57        match self {
58            FunctionEntry::Query { .. } => FunctionKind::Query,
59            FunctionEntry::Mutation { .. } => FunctionKind::Mutation,
60            FunctionEntry::Action { .. } => FunctionKind::Action,
61        }
62    }
63}
64
/// Registry of all FORGE functions.
///
/// Holds one `FunctionEntry` per registered function name. `Clone` is derived,
/// so cloning the registry clones the map and each entry in it.
#[derive(Clone)]
pub struct FunctionRegistry {
    // Keyed by the name taken from each function's `FunctionInfo` at registration.
    functions: HashMap<String, FunctionEntry>,
}
70
71impl Clone for FunctionEntry {
72    fn clone(&self) -> Self {
73        match self {
74            FunctionEntry::Query { info, handler } => FunctionEntry::Query {
75                info: info.clone(),
76                handler: Arc::clone(handler),
77            },
78            FunctionEntry::Mutation { info, handler } => FunctionEntry::Mutation {
79                info: info.clone(),
80                handler: Arc::clone(handler),
81            },
82            FunctionEntry::Action { info, handler } => FunctionEntry::Action {
83                info: info.clone(),
84                handler: Arc::clone(handler),
85            },
86        }
87    }
88}
89
90impl FunctionRegistry {
91    /// Create a new empty registry.
92    pub fn new() -> Self {
93        Self {
94            functions: HashMap::new(),
95        }
96    }
97
98    /// Register a query function.
99    pub fn register_query<Q: ForgeQuery>(&mut self)
100    where
101        Q::Args: serde::de::DeserializeOwned + Send + 'static,
102        Q::Output: serde::Serialize + Send + 'static,
103    {
104        let info = Q::info();
105        let name = info.name.to_string();
106
107        let handler: BoxedQueryFn = Arc::new(move |ctx, args| {
108            Box::pin(async move {
109                let parsed_args: Q::Args = serde_json::from_value(args)
110                    .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
111                let result = Q::execute(ctx, parsed_args).await?;
112                serde_json::to_value(result)
113                    .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
114            })
115        });
116
117        self.functions
118            .insert(name, FunctionEntry::Query { info, handler });
119    }
120
121    /// Register a mutation function.
122    pub fn register_mutation<M: ForgeMutation>(&mut self)
123    where
124        M::Args: serde::de::DeserializeOwned + Send + 'static,
125        M::Output: serde::Serialize + Send + 'static,
126    {
127        let info = M::info();
128        let name = info.name.to_string();
129
130        let handler: BoxedMutationFn = Arc::new(move |ctx, args| {
131            Box::pin(async move {
132                let parsed_args: M::Args = serde_json::from_value(args)
133                    .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
134                let result = M::execute(ctx, parsed_args).await?;
135                serde_json::to_value(result)
136                    .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
137            })
138        });
139
140        self.functions
141            .insert(name, FunctionEntry::Mutation { info, handler });
142    }
143
144    /// Register an action function.
145    pub fn register_action<A: ForgeAction>(&mut self)
146    where
147        A::Args: serde::de::DeserializeOwned + Send + 'static,
148        A::Output: serde::Serialize + Send + 'static,
149    {
150        let info = A::info();
151        let name = info.name.to_string();
152
153        let handler: BoxedActionFn = Arc::new(move |ctx, args| {
154            Box::pin(async move {
155                let parsed_args: A::Args = serde_json::from_value(args)
156                    .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
157                let result = A::execute(ctx, parsed_args).await?;
158                serde_json::to_value(result)
159                    .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
160            })
161        });
162
163        self.functions
164            .insert(name, FunctionEntry::Action { info, handler });
165    }
166
167    /// Get a function by name.
168    pub fn get(&self, name: &str) -> Option<&FunctionEntry> {
169        self.functions.get(name)
170    }
171
172    /// Get all function names.
173    pub fn function_names(&self) -> impl Iterator<Item = &str> {
174        self.functions.keys().map(|s| s.as_str())
175    }
176
177    /// Get all functions.
178    pub fn functions(&self) -> impl Iterator<Item = (&str, &FunctionEntry)> {
179        self.functions.iter().map(|(k, v)| (k.as_str(), v))
180    }
181
182    /// Get the number of registered functions.
183    pub fn len(&self) -> usize {
184        self.functions.len()
185    }
186
187    /// Check if the registry is empty.
188    pub fn is_empty(&self) -> bool {
189        self.functions.is_empty()
190    }
191
192    /// Get all queries.
193    pub fn queries(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
194        self.functions.iter().filter_map(|(name, entry)| {
195            if let FunctionEntry::Query { info, .. } = entry {
196                Some((name.as_str(), info))
197            } else {
198                None
199            }
200        })
201    }
202
203    /// Get all mutations.
204    pub fn mutations(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
205        self.functions.iter().filter_map(|(name, entry)| {
206            if let FunctionEntry::Mutation { info, .. } = entry {
207                Some((name.as_str(), info))
208            } else {
209                None
210            }
211        })
212    }
213
214    /// Get all actions.
215    pub fn actions(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
216        self.functions.iter().filter_map(|(name, entry)| {
217            if let FunctionEntry::Action { info, .. } = entry {
218                Some((name.as_str(), info))
219            } else {
220                None
221            }
222        })
223    }
224}
225
226impl Default for FunctionRegistry {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// A new registry reports itself as empty through the scalar accessors.
    #[test]
    fn test_empty_registry() {
        let registry = FunctionRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.get("nonexistent").is_none());
    }

    /// Every iterator accessor yields nothing on an empty registry.
    #[test]
    fn test_empty_registry_iterators() {
        let registry = FunctionRegistry::new();
        assert_eq!(registry.function_names().count(), 0);
        assert!(registry.functions().next().is_none());
        assert!(registry.queries().next().is_none());
        assert!(registry.mutations().next().is_none());
        assert!(registry.actions().next().is_none());
    }

    /// `Default` and `Clone` both produce an empty registry from an empty one.
    #[test]
    fn test_default_and_clone_are_empty() {
        let registry = FunctionRegistry::default();
        assert!(registry.is_empty());

        let copy = registry.clone();
        assert!(copy.is_empty());
        assert_eq!(copy.len(), registry.len());
    }
}
243}