// forge_runtime/function/registry.rs

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use forge_core::{
    ForgeError, ForgeMutation, ForgeQuery, FunctionInfo, FunctionKind, MutationContext,
    QueryContext, Result,
};
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::Value;

11/// Normalize args for deserialization.
12/// - Converts empty objects `{}` to `null` to support unit type `()` deserialization.
13/// - Unwraps `{"args": ...}` wrapper if present (frontend may send wrapped args).
14///   This allows frontend to send `{}` for functions with no arguments.
15fn normalize_args(args: Value) -> Value {
16    // First, unwrap {"args": ...} wrapper if present
17    let unwrapped = match &args {
18        Value::Object(map) if map.len() == 1 && map.contains_key("args") => {
19            map.get("args").cloned().unwrap_or(Value::Null)
20        }
21        _ => args,
22    };
23
24    // Then normalize empty objects to null
25    match &unwrapped {
26        Value::Object(map) if map.is_empty() => Value::Null,
27        _ => unwrapped,
28    }
29}
30
31/// Type alias for a boxed function that executes with JSON args and returns JSON result.
32pub type BoxedQueryFn = Arc<
33    dyn Fn(&QueryContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
34        + Send
35        + Sync,
36>;
37
38pub type BoxedMutationFn = Arc<
39    dyn Fn(&MutationContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
40        + Send
41        + Sync,
42>;
43
44/// Entry in the function registry.
45pub enum FunctionEntry {
46    Query {
47        info: FunctionInfo,
48        handler: BoxedQueryFn,
49    },
50    Mutation {
51        info: FunctionInfo,
52        handler: BoxedMutationFn,
53    },
54}
55
56impl FunctionEntry {
57    pub fn info(&self) -> &FunctionInfo {
58        match self {
59            FunctionEntry::Query { info, .. } => info,
60            FunctionEntry::Mutation { info, .. } => info,
61        }
62    }
63
64    pub fn kind(&self) -> FunctionKind {
65        match self {
66            FunctionEntry::Query { .. } => FunctionKind::Query,
67            FunctionEntry::Mutation { .. } => FunctionKind::Mutation,
68        }
69    }
70}
71
72/// Registry of all FORGE functions.
73#[derive(Clone)]
74pub struct FunctionRegistry {
75    functions: HashMap<String, FunctionEntry>,
76}
77
78impl Clone for FunctionEntry {
79    fn clone(&self) -> Self {
80        match self {
81            FunctionEntry::Query { info, handler } => FunctionEntry::Query {
82                info: info.clone(),
83                handler: Arc::clone(handler),
84            },
85            FunctionEntry::Mutation { info, handler } => FunctionEntry::Mutation {
86                info: info.clone(),
87                handler: Arc::clone(handler),
88            },
89        }
90    }
91}
92
93impl FunctionRegistry {
94    /// Create a new empty registry.
95    pub fn new() -> Self {
96        Self {
97            functions: HashMap::new(),
98        }
99    }
100
101    /// Register a query function.
102    pub fn register_query<Q: ForgeQuery>(&mut self)
103    where
104        Q::Args: serde::de::DeserializeOwned + Send + 'static,
105        Q::Output: serde::Serialize + Send + 'static,
106    {
107        let info = Q::info();
108        let name = info.name.to_string();
109
110        let handler: BoxedQueryFn = Arc::new(move |ctx, args| {
111            Box::pin(async move {
112                let parsed_args: Q::Args = serde_json::from_value(normalize_args(args))
113                    .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
114                let result = Q::execute(ctx, parsed_args).await?;
115                serde_json::to_value(result)
116                    .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
117            })
118        });
119
120        self.functions
121            .insert(name, FunctionEntry::Query { info, handler });
122    }
123
124    /// Register a mutation function.
125    pub fn register_mutation<M: ForgeMutation>(&mut self)
126    where
127        M::Args: serde::de::DeserializeOwned + Send + 'static,
128        M::Output: serde::Serialize + Send + 'static,
129    {
130        let info = M::info();
131        let name = info.name.to_string();
132
133        let handler: BoxedMutationFn = Arc::new(move |ctx, args| {
134            Box::pin(async move {
135                let parsed_args: M::Args = serde_json::from_value(normalize_args(args))
136                    .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
137                let result = M::execute(ctx, parsed_args).await?;
138                serde_json::to_value(result)
139                    .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
140            })
141        });
142
143        self.functions
144            .insert(name, FunctionEntry::Mutation { info, handler });
145    }
146
147    /// Get a function by name.
148    pub fn get(&self, name: &str) -> Option<&FunctionEntry> {
149        self.functions.get(name)
150    }
151
152    /// Get all function names.
153    pub fn function_names(&self) -> impl Iterator<Item = &str> {
154        self.functions.keys().map(|s| s.as_str())
155    }
156
157    /// Get all functions.
158    pub fn functions(&self) -> impl Iterator<Item = (&str, &FunctionEntry)> {
159        self.functions.iter().map(|(k, v)| (k.as_str(), v))
160    }
161
162    /// Get the number of registered functions.
163    pub fn len(&self) -> usize {
164        self.functions.len()
165    }
166
167    /// Check if the registry is empty.
168    pub fn is_empty(&self) -> bool {
169        self.functions.is_empty()
170    }
171
172    /// Get all queries.
173    pub fn queries(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
174        self.functions.iter().filter_map(|(name, entry)| {
175            if let FunctionEntry::Query { info, .. } = entry {
176                Some((name.as_str(), info))
177            } else {
178                None
179            }
180        })
181    }
182
183    /// Get all mutations.
184    pub fn mutations(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
185        self.functions.iter().filter_map(|(name, entry)| {
186            if let FunctionEntry::Mutation { info, .. } = entry {
187                Some((name.as_str(), info))
188            } else {
189                None
190            }
191        })
192    }
193}
194
195impl Default for FunctionRegistry {
196    fn default() -> Self {
197        Self::new()
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_empty_registry() {
        let registry = FunctionRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.get("nonexistent").is_none());
        assert_eq!(registry.function_names().count(), 0);
        assert_eq!(registry.functions().count(), 0);
        assert_eq!(registry.queries().count(), 0);
        assert_eq!(registry.mutations().count(), 0);
    }

    #[test]
    fn normalize_args_maps_empty_object_to_null() {
        assert_eq!(normalize_args(json!({})), Value::Null);
    }

    #[test]
    fn normalize_args_unwraps_args_wrapper() {
        assert_eq!(normalize_args(json!({ "args": { "id": 7 } })), json!({ "id": 7 }));
        // The unwrapped value still gets the `{}` -> null treatment.
        assert_eq!(normalize_args(json!({ "args": {} })), Value::Null);
        assert_eq!(normalize_args(json!({ "args": null })), Value::Null);
        assert_eq!(normalize_args(json!({ "args": [1, 2] })), json!([1, 2]));
    }

    #[test]
    fn normalize_args_leaves_other_values_untouched() {
        assert_eq!(normalize_args(json!({ "id": 7 })), json!({ "id": 7 }));
        // Only a single-key `{"args": ...}` object counts as a wrapper.
        assert_eq!(
            normalize_args(json!({ "args": 1, "extra": 2 })),
            json!({ "args": 1, "extra": 2 })
        );
        assert_eq!(normalize_args(json!([1, 2])), json!([1, 2]));
        assert_eq!(normalize_args(Value::Null), Value::Null);
    }
}