Skip to main content

forge_runtime/function/
registry.rs

1use ahash::AHashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use forge_core::{
7    ForgeMutation, ForgeQuery, FunctionInfo, FunctionKind, MutationContext, QueryContext, Result,
8};
9use serde_json::Value;
10
11/// Normalize args for deserialization.
12/// - Keeps `null` as-is so unit `()` deserializes correctly.
13/// - Treats `{}` as `null` so no-arg functions accept empty objects.
14/// - Unwraps `{"args": ...}` or `{"input": ...}` wrapper if present (callers may use either format).
15fn normalize_args(args: Value) -> Value {
16    match &args {
17        Value::Object(map) if map.is_empty() => Value::Null,
18        Value::Object(map) if map.len() == 1 => {
19            if map.contains_key("args") {
20                map.get("args").cloned().unwrap_or(Value::Null)
21            } else if map.contains_key("input") {
22                map.get("input").cloned().unwrap_or(Value::Null)
23            } else {
24                args
25            }
26        }
27        _ => args,
28    }
29}
30
/// Type-erased query handler: called with a borrowed `QueryContext` and raw
/// JSON args, it returns a boxed, pinned future resolving to `Result<Value>`.
///
/// The `'_` bound ties the returned future to the lifetime of the context
/// borrow, so the future is allowed to keep using the context while it runs.
/// The closure itself is `Send + Sync` and lives behind an `Arc`, so cloning
/// the alias shares one handler rather than copying it.
pub type BoxedQueryFn = Arc<
    dyn Fn(&QueryContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
        + Send
        + Sync,
>;
37
/// Type-erased mutation handler: called with a borrowed `MutationContext` and
/// raw JSON args, it returns a boxed, pinned future resolving to
/// `Result<Value>`.
///
/// Same shape as `BoxedQueryFn`, differing only in the context type: the
/// returned future is tied to the context borrow (`'_`), and the `Send + Sync`
/// closure is shared through an `Arc`.
pub type BoxedMutationFn = Arc<
    dyn Fn(&MutationContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
        + Send
        + Sync,
>;
43
/// Entry in the function registry.
///
/// Each variant carries the function's metadata; query and mutation entries
/// additionally carry the type-erased handler that runs the function.
pub enum FunctionEntry {
    /// A registered query together with its executable handler.
    Query {
        /// Metadata describing the query.
        info: FunctionInfo,
        /// Handler invoked with a `QueryContext` and JSON args.
        handler: BoxedQueryFn,
    },
    /// A registered mutation together with its executable handler.
    Mutation {
        /// Metadata describing the mutation.
        info: FunctionInfo,
        /// Handler invoked with a `MutationContext` and JSON args.
        handler: BoxedMutationFn,
    },
    /// Webhook registered for metadata access (info lookup, MCP tool list,
    /// observability). Execution always goes through the dedicated webhook HTTP
    /// route with signature validation — this entry carries no handler.
    Webhook { info: FunctionInfo },
}
59
60impl FunctionEntry {
61    pub fn info(&self) -> &FunctionInfo {
62        match self {
63            FunctionEntry::Query { info, .. } => info,
64            FunctionEntry::Mutation { info, .. } => info,
65            FunctionEntry::Webhook { info } => info,
66        }
67    }
68
69    pub fn kind(&self) -> FunctionKind {
70        match self {
71            FunctionEntry::Query { .. } => FunctionKind::Query,
72            FunctionEntry::Mutation { .. } => FunctionKind::Mutation,
73            FunctionEntry::Webhook { .. } => FunctionKind::Webhook,
74        }
75    }
76}
77
/// Registry of all FORGE functions.
///
/// Maps a function's registered name to its `FunctionEntry` (query, mutation,
/// or webhook metadata). Cloning the registry clones the map; query and
/// mutation handlers are `Arc`s, so the clones share the same handlers.
#[derive(Clone)]
pub struct FunctionRegistry {
    /// Entries keyed by the name taken from each function's `FunctionInfo`.
    functions: AHashMap<String, FunctionEntry>,
}
83
84impl Clone for FunctionEntry {
85    fn clone(&self) -> Self {
86        match self {
87            FunctionEntry::Query { info, handler } => FunctionEntry::Query {
88                info: info.clone(),
89                handler: Arc::clone(handler),
90            },
91            FunctionEntry::Mutation { info, handler } => FunctionEntry::Mutation {
92                info: info.clone(),
93                handler: Arc::clone(handler),
94            },
95            FunctionEntry::Webhook { info } => FunctionEntry::Webhook { info: info.clone() },
96        }
97    }
98}
99
100impl FunctionRegistry {
101    /// Create a new empty registry.
102    pub fn new() -> Self {
103        Self {
104            functions: AHashMap::new(),
105        }
106    }
107
108    /// Register a query function.
109    pub fn register_query<Q: ForgeQuery>(&mut self)
110    where
111        Q::Args: serde::de::DeserializeOwned + Send + 'static,
112        Q::Output: serde::Serialize + Send + 'static,
113    {
114        let info = Q::info();
115        let name = info.name.to_string();
116
117        let handler: BoxedQueryFn = Arc::new(move |ctx, args| {
118            Box::pin(async move {
119                let parsed_args: Q::Args = serde_json::from_value(normalize_args(args))
120                    .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
121                let result = Q::execute(ctx, parsed_args).await?;
122                serde_json::to_value(result).map_err(|e| {
123                    forge_core::ForgeError::internal_with("Failed to serialize result", e)
124                })
125            })
126        });
127
128        self.functions
129            .insert(name, FunctionEntry::Query { info, handler });
130    }
131
132    /// Register a mutation function.
133    pub fn register_mutation<M: ForgeMutation>(&mut self)
134    where
135        M::Args: serde::de::DeserializeOwned + Send + 'static,
136        M::Output: serde::Serialize + Send + 'static,
137    {
138        let info = M::info();
139        let name = info.name.to_string();
140
141        let handler: BoxedMutationFn = Arc::new(move |ctx, args| {
142            Box::pin(async move {
143                let parsed_args: M::Args = serde_json::from_value(normalize_args(args))
144                    .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
145                let result = M::execute(ctx, parsed_args).await?;
146                serde_json::to_value(result).map_err(|e| {
147                    forge_core::ForgeError::internal_with("Failed to serialize result", e)
148                })
149            })
150        });
151
152        self.functions
153            .insert(name, FunctionEntry::Mutation { info, handler });
154    }
155
156    /// Get a function by name.
157    pub fn get(&self, name: &str) -> Option<&FunctionEntry> {
158        self.functions.get(name)
159    }
160
161    /// Get all function names.
162    pub fn function_names(&self) -> impl Iterator<Item = &str> {
163        self.functions.keys().map(|s| s.as_str())
164    }
165
166    /// Get all functions.
167    pub fn functions(&self) -> impl Iterator<Item = (&str, &FunctionEntry)> {
168        self.functions.iter().map(|(k, v)| (k.as_str(), v))
169    }
170
171    /// Get the number of registered functions.
172    pub fn len(&self) -> usize {
173        self.functions.len()
174    }
175
176    /// Check if the registry is empty.
177    pub fn is_empty(&self) -> bool {
178        self.functions.is_empty()
179    }
180
181    /// Get all queries.
182    pub fn queries(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
183        self.functions.iter().filter_map(|(name, entry)| {
184            if let FunctionEntry::Query { info, .. } = entry {
185                Some((name.as_str(), info))
186            } else {
187                None
188            }
189        })
190    }
191
192    /// Register a webhook's metadata into the function registry.
193    ///
194    /// Webhooks are registered here for metadata access (info lookup, MCP tool
195    /// list, observability) only. They carry no executable handler because
196    /// execution requires signature validation that lives in the dedicated
197    /// webhook HTTP route. Direct RPC calls to webhook names are rejected by
198    /// `FunctionRouter`.
199    pub fn register_webhook_info(&mut self, info: FunctionInfo) {
200        let name = info.name.to_string();
201        self.functions.insert(name, FunctionEntry::Webhook { info });
202    }
203
204    /// Get all mutations.
205    pub fn mutations(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
206        self.functions.iter().filter_map(|(name, entry)| {
207            if let FunctionEntry::Mutation { info, .. } = entry {
208                Some((name.as_str(), info))
209            } else {
210                None
211            }
212        })
213    }
214
215    /// Get all webhooks.
216    pub fn webhooks(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
217        self.functions.iter().filter_map(|(name, entry)| {
218            if let FunctionEntry::Webhook { info } = entry {
219                Some((name.as_str(), info))
220            } else {
221                None
222            }
223        })
224    }
225}
226
227impl Default for FunctionRegistry {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a public `FunctionInfo` fixture with the given name and kind.
    /// Every optional setting is `None`, the dependency/column lists are
    /// empty, and every flag other than `is_public` is off.
    fn info(name: &'static str, kind: FunctionKind) -> FunctionInfo {
        FunctionInfo {
            name,
            description: None,
            kind,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            http_timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            changed_columns: &[],
            transactional: false,
            consistent: false,
            max_upload_size_bytes: None,
            requires_tenant_scope: false,
        }
    }

    // --- Argument normalization ---

    #[test]
    fn normalize_args_passes_null_through_untouched() {
        // JSON null is what unit `()` deserializes from, so it must come back
        // exactly as it went in.
        let normalized = normalize_args(json!(null));
        assert_eq!(normalized, json!(null));
    }

    #[test]
    fn normalize_args_treats_empty_object_as_null() {
        // Frontends serializing a typed payload send `{}` for unit args;
        // no-arg handlers must accept that without special-casing.
        let normalized = normalize_args(json!({}));
        assert_eq!(normalized, json!(null));
    }

    #[test]
    fn normalize_args_unwraps_args_envelope() {
        // The payload under `args` is returned whatever its JSON type.
        let object_payload = normalize_args(json!({"args": {"id": 7}}));
        assert_eq!(object_payload, json!({"id": 7}));

        let scalar_payload = normalize_args(json!({"args": 42}));
        assert_eq!(scalar_payload, json!(42));

        let null_payload = normalize_args(json!({"args": null}));
        assert_eq!(null_payload, json!(null));
    }

    #[test]
    fn normalize_args_unwraps_input_envelope() {
        let unwrapped = normalize_args(json!({"input": [1, 2, 3]}));
        assert_eq!(unwrapped, json!([1, 2, 3]));
    }

    #[test]
    fn normalize_args_keeps_other_single_key_objects_intact() {
        // Only `args`/`input` count as envelopes; a handler with
        // `struct Args { id: u32 }` must still receive {"id": ...} unchanged.
        let original = json!({"id": 7});
        let normalized = normalize_args(original.clone());
        assert_eq!(normalized, original);
    }

    #[test]
    fn normalize_args_keeps_multi_key_objects_intact() {
        let original = json!({"name": "alice", "age": 30});
        let normalized = normalize_args(original.clone());
        assert_eq!(normalized, original);
    }

    #[test]
    fn normalize_args_keeps_non_object_values_intact() {
        // Numbers, strings, arrays and booleans are never rewritten.
        assert_eq!(normalize_args(json!(42)), json!(42));
        assert_eq!(normalize_args(json!("hello")), json!("hello"));
        assert_eq!(normalize_args(json!([1, 2])), json!([1, 2]));
        assert_eq!(normalize_args(json!(true)), json!(true));
    }

    // --- Registry construction + lookup ---

    #[test]
    fn new_registry_is_empty() {
        let registry = FunctionRegistry::new();

        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
        assert_eq!(registry.function_names().count(), 0);
        assert!(registry.get("anything").is_none());
    }

    #[test]
    fn register_webhook_info_stores_entry_under_function_name() {
        let mut registry = FunctionRegistry::new();
        registry.register_webhook_info(info("stripe_webhook", FunctionKind::Webhook));

        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);

        let stored = registry.get("stripe_webhook").expect("registered");
        assert_eq!(stored.info().name, "stripe_webhook");
        assert_eq!(stored.kind(), FunctionKind::Webhook);
    }

    #[test]
    fn register_same_name_overwrites_existing_entry() {
        // Last writer wins: whichever registration runs later is the one kept.
        let mut registry = FunctionRegistry::new();
        registry.register_webhook_info(info("dup", FunctionKind::Webhook));

        let replacement = FunctionInfo {
            is_public: false,
            ..info("dup", FunctionKind::Webhook)
        };
        registry.register_webhook_info(replacement);

        assert_eq!(registry.len(), 1);
        let stored = registry.get("dup").expect("present");
        assert!(!stored.info().is_public);
    }

    #[test]
    fn iterators_partition_by_kind() {
        let mut registry = FunctionRegistry::new();
        // Only webhook entries are registered here, because they need no
        // handler. This checks that the per-kind accessors filter on the
        // `FunctionEntry` variant; handler invocation is not exercised.
        registry.register_webhook_info(info("hook_a", FunctionKind::Webhook));
        registry.register_webhook_info(info("hook_b", FunctionKind::Webhook));

        // Map iteration order is not fixed, so compare the sorted names.
        let mut names: Vec<&str> = registry.function_names().collect();
        names.sort_unstable();
        assert_eq!(names, ["hook_a", "hook_b"]);

        // The webhook accessor yields both entries.
        assert_eq!(registry.webhooks().count(), 2);
        // The query and mutation accessors skip webhook entries.
        assert_eq!(registry.queries().count(), 0);
        assert_eq!(registry.mutations().count(), 0);
    }
}