Skip to main content

forge_runtime/function/
registry.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use forge_core::{
7    ForgeMutation, ForgeQuery, FunctionInfo, FunctionKind, MutationContext, QueryContext, Result,
8};
9use serde_json::Value;
10
11/// Normalize args for deserialization.
12/// - Keeps `null` as-is so unit `()` deserializes correctly.
13/// - Unwraps `{"args": ...}` or `{"input": ...}` wrapper if present (callers may use either format).
14fn normalize_args(args: Value) -> Value {
15    match &args {
16        Value::Object(map) if map.len() == 1 => {
17            if map.contains_key("args") {
18                map.get("args").cloned().unwrap_or(Value::Null)
19            } else if map.contains_key("input") {
20                map.get("input").cloned().unwrap_or(Value::Null)
21            } else {
22                args
23            }
24        }
25        _ => args,
26    }
27}
28
29/// Type alias for a boxed function that executes with JSON args and returns JSON result.
30pub type BoxedQueryFn = Arc<
31    dyn Fn(&QueryContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
32        + Send
33        + Sync,
34>;
35
36pub type BoxedMutationFn = Arc<
37    dyn Fn(&MutationContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
38        + Send
39        + Sync,
40>;
41
42/// Entry in the function registry.
43pub enum FunctionEntry {
44    Query {
45        info: FunctionInfo,
46        handler: BoxedQueryFn,
47    },
48    Mutation {
49        info: FunctionInfo,
50        handler: BoxedMutationFn,
51    },
52}
53
54impl FunctionEntry {
55    pub fn info(&self) -> &FunctionInfo {
56        match self {
57            FunctionEntry::Query { info, .. } => info,
58            FunctionEntry::Mutation { info, .. } => info,
59        }
60    }
61
62    pub fn kind(&self) -> FunctionKind {
63        match self {
64            FunctionEntry::Query { .. } => FunctionKind::Query,
65            FunctionEntry::Mutation { .. } => FunctionKind::Mutation,
66        }
67    }
68}
69
70/// Registry of all FORGE functions.
71#[derive(Clone)]
72pub struct FunctionRegistry {
73    functions: HashMap<String, FunctionEntry>,
74}
75
76impl Clone for FunctionEntry {
77    fn clone(&self) -> Self {
78        match self {
79            FunctionEntry::Query { info, handler } => FunctionEntry::Query {
80                info: info.clone(),
81                handler: Arc::clone(handler),
82            },
83            FunctionEntry::Mutation { info, handler } => FunctionEntry::Mutation {
84                info: info.clone(),
85                handler: Arc::clone(handler),
86            },
87        }
88    }
89}
90
91impl FunctionRegistry {
92    /// Create a new empty registry.
93    pub fn new() -> Self {
94        Self {
95            functions: HashMap::new(),
96        }
97    }
98
99    /// Register a query function.
100    pub fn register_query<Q: ForgeQuery>(&mut self)
101    where
102        Q::Args: serde::de::DeserializeOwned + Send + 'static,
103        Q::Output: serde::Serialize + Send + 'static,
104    {
105        let info = Q::info();
106        let name = info.name.to_string();
107
108        let handler: BoxedQueryFn = Arc::new(move |ctx, args| {
109            Box::pin(async move {
110                let parsed_args: Q::Args = serde_json::from_value(normalize_args(args))
111                    .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
112                let result = Q::execute(ctx, parsed_args).await?;
113                serde_json::to_value(result)
114                    .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
115            })
116        });
117
118        self.functions
119            .insert(name, FunctionEntry::Query { info, handler });
120    }
121
122    /// Register a mutation function.
123    pub fn register_mutation<M: ForgeMutation>(&mut self)
124    where
125        M::Args: serde::de::DeserializeOwned + Send + 'static,
126        M::Output: serde::Serialize + Send + 'static,
127    {
128        let info = M::info();
129        let name = info.name.to_string();
130
131        let handler: BoxedMutationFn = Arc::new(move |ctx, args| {
132            Box::pin(async move {
133                let parsed_args: M::Args = serde_json::from_value(normalize_args(args))
134                    .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
135                let result = M::execute(ctx, parsed_args).await?;
136                serde_json::to_value(result)
137                    .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
138            })
139        });
140
141        self.functions
142            .insert(name, FunctionEntry::Mutation { info, handler });
143    }
144
145    /// Get a function by name.
146    pub fn get(&self, name: &str) -> Option<&FunctionEntry> {
147        self.functions.get(name)
148    }
149
150    /// Get all function names.
151    pub fn function_names(&self) -> impl Iterator<Item = &str> {
152        self.functions.keys().map(|s| s.as_str())
153    }
154
155    /// Get all functions.
156    pub fn functions(&self) -> impl Iterator<Item = (&str, &FunctionEntry)> {
157        self.functions.iter().map(|(k, v)| (k.as_str(), v))
158    }
159
160    /// Get the number of registered functions.
161    pub fn len(&self) -> usize {
162        self.functions.len()
163    }
164
165    /// Check if the registry is empty.
166    pub fn is_empty(&self) -> bool {
167        self.functions.is_empty()
168    }
169
170    /// Get all queries.
171    pub fn queries(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
172        self.functions.iter().filter_map(|(name, entry)| {
173            if let FunctionEntry::Query { info, .. } = entry {
174                Some((name.as_str(), info))
175            } else {
176                None
177            }
178        })
179    }
180
181    /// Get all mutations.
182    pub fn mutations(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
183        self.functions.iter().filter_map(|(name, entry)| {
184            if let FunctionEntry::Mutation { info, .. } = entry {
185                Some((name.as_str(), info))
186            } else {
187                None
188            }
189        })
190    }
191}
192
193impl Default for FunctionRegistry {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_empty_registry() {
        let registry = FunctionRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.get("nonexistent").is_none());
        assert_eq!(registry.function_names().count(), 0);
        assert_eq!(registry.functions().count(), 0);
        assert_eq!(registry.queries().count(), 0);
        assert_eq!(registry.mutations().count(), 0);
    }

    #[test]
    fn test_default_registry_is_empty() {
        let registry = FunctionRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_normalize_args_keeps_null() {
        // Unit `()` args deserialize from `null`, so it must pass through untouched.
        assert_eq!(normalize_args(Value::Null), Value::Null);
    }

    #[test]
    fn test_normalize_args_unwraps_args_and_input() {
        assert_eq!(normalize_args(json!({"args": {"id": 1}})), json!({"id": 1}));
        assert_eq!(normalize_args(json!({"input": [1, 2]})), json!([1, 2]));
        assert_eq!(normalize_args(json!({"args": null})), Value::Null);
    }

    #[test]
    fn test_normalize_args_leaves_other_values_alone() {
        let single = json!({"id": 1});
        assert_eq!(normalize_args(single.clone()), single);

        // A wrapper key alongside other keys is not a wrapper.
        let multi = json!({"args": 1, "extra": 2});
        assert_eq!(normalize_args(multi.clone()), multi);

        assert_eq!(normalize_args(json!("text")), json!("text"));
        assert_eq!(normalize_args(json!(42)), json!(42));
    }
}