1use ahash::AHashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use forge_core::{
7 ForgeMutation, ForgeQuery, FunctionInfo, FunctionKind, MutationContext, QueryContext, Result,
8};
9use serde_json::Value;
10
11fn normalize_args(args: Value) -> Value {
16 match &args {
17 Value::Object(map) if map.is_empty() => Value::Null,
18 Value::Object(map) if map.len() == 1 => {
19 if map.contains_key("args") {
20 map.get("args").cloned().unwrap_or(Value::Null)
21 } else if map.contains_key("input") {
22 map.get("input").cloned().unwrap_or(Value::Null)
23 } else {
24 args
25 }
26 }
27 _ => args,
28 }
29}
30
31pub type BoxedQueryFn = Arc<
33 dyn Fn(&QueryContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
34 + Send
35 + Sync,
36>;
37
38pub type BoxedMutationFn = Arc<
39 dyn Fn(&MutationContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
40 + Send
41 + Sync,
42>;
43
44pub enum FunctionEntry {
46 Query {
47 info: FunctionInfo,
48 handler: BoxedQueryFn,
49 },
50 Mutation {
51 info: FunctionInfo,
52 handler: BoxedMutationFn,
53 },
54 Webhook { info: FunctionInfo },
58}
59
60impl FunctionEntry {
61 pub fn info(&self) -> &FunctionInfo {
62 match self {
63 FunctionEntry::Query { info, .. } => info,
64 FunctionEntry::Mutation { info, .. } => info,
65 FunctionEntry::Webhook { info } => info,
66 }
67 }
68
69 pub fn kind(&self) -> FunctionKind {
70 match self {
71 FunctionEntry::Query { .. } => FunctionKind::Query,
72 FunctionEntry::Mutation { .. } => FunctionKind::Mutation,
73 FunctionEntry::Webhook { .. } => FunctionKind::Webhook,
74 }
75 }
76}
77
78#[derive(Clone)]
80pub struct FunctionRegistry {
81 functions: AHashMap<String, FunctionEntry>,
82}
83
84impl Clone for FunctionEntry {
85 fn clone(&self) -> Self {
86 match self {
87 FunctionEntry::Query { info, handler } => FunctionEntry::Query {
88 info: info.clone(),
89 handler: Arc::clone(handler),
90 },
91 FunctionEntry::Mutation { info, handler } => FunctionEntry::Mutation {
92 info: info.clone(),
93 handler: Arc::clone(handler),
94 },
95 FunctionEntry::Webhook { info } => FunctionEntry::Webhook { info: info.clone() },
96 }
97 }
98}
99
100impl FunctionRegistry {
101 pub fn new() -> Self {
103 Self {
104 functions: AHashMap::new(),
105 }
106 }
107
108 pub fn register_query<Q: ForgeQuery>(&mut self)
110 where
111 Q::Args: serde::de::DeserializeOwned + Send + 'static,
112 Q::Output: serde::Serialize + Send + 'static,
113 {
114 let info = Q::info();
115 let name = info.name.to_string();
116
117 let handler: BoxedQueryFn = Arc::new(move |ctx, args| {
118 Box::pin(async move {
119 let parsed_args: Q::Args = serde_json::from_value(normalize_args(args))
120 .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
121 let result = Q::execute(ctx, parsed_args).await?;
122 serde_json::to_value(result).map_err(|e| {
123 forge_core::ForgeError::internal_with("Failed to serialize result", e)
124 })
125 })
126 });
127
128 self.functions
129 .insert(name, FunctionEntry::Query { info, handler });
130 }
131
132 pub fn register_mutation<M: ForgeMutation>(&mut self)
134 where
135 M::Args: serde::de::DeserializeOwned + Send + 'static,
136 M::Output: serde::Serialize + Send + 'static,
137 {
138 let info = M::info();
139 let name = info.name.to_string();
140
141 let handler: BoxedMutationFn = Arc::new(move |ctx, args| {
142 Box::pin(async move {
143 let parsed_args: M::Args = serde_json::from_value(normalize_args(args))
144 .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
145 let result = M::execute(ctx, parsed_args).await?;
146 serde_json::to_value(result).map_err(|e| {
147 forge_core::ForgeError::internal_with("Failed to serialize result", e)
148 })
149 })
150 });
151
152 self.functions
153 .insert(name, FunctionEntry::Mutation { info, handler });
154 }
155
156 pub fn get(&self, name: &str) -> Option<&FunctionEntry> {
158 self.functions.get(name)
159 }
160
161 pub fn function_names(&self) -> impl Iterator<Item = &str> {
163 self.functions.keys().map(|s| s.as_str())
164 }
165
166 pub fn functions(&self) -> impl Iterator<Item = (&str, &FunctionEntry)> {
168 self.functions.iter().map(|(k, v)| (k.as_str(), v))
169 }
170
171 pub fn len(&self) -> usize {
173 self.functions.len()
174 }
175
176 pub fn is_empty(&self) -> bool {
178 self.functions.is_empty()
179 }
180
181 pub fn queries(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
183 self.functions.iter().filter_map(|(name, entry)| {
184 if let FunctionEntry::Query { info, .. } = entry {
185 Some((name.as_str(), info))
186 } else {
187 None
188 }
189 })
190 }
191
192 pub fn register_webhook_info(&mut self, info: FunctionInfo) {
200 let name = info.name.to_string();
201 self.functions.insert(name, FunctionEntry::Webhook { info });
202 }
203
204 pub fn mutations(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
206 self.functions.iter().filter_map(|(name, entry)| {
207 if let FunctionEntry::Mutation { info, .. } = entry {
208 Some((name.as_str(), info))
209 } else {
210 None
211 }
212 })
213 }
214
215 pub fn webhooks(&self) -> impl Iterator<Item = (&str, &FunctionInfo)> {
217 self.functions.iter().filter_map(|(name, entry)| {
218 if let FunctionEntry::Webhook { info } = entry {
219 Some((name.as_str(), info))
220 } else {
221 None
222 }
223 })
224 }
225}
226
227impl Default for FunctionRegistry {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
#[cfg(test)]
// Tests are expected to panic on failure, so panicking helpers are acceptable here.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal public `FunctionInfo` with the given name and kind.
    fn info(name: &'static str, kind: FunctionKind) -> FunctionInfo {
        FunctionInfo {
            name,
            description: None,
            kind,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            http_timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            changed_columns: &[],
            transactional: false,
            consistent: false,
            max_upload_size_bytes: None,
            requires_tenant_scope: false,
        }
    }

    #[test]
    fn normalize_args_passes_null_through_untouched() {
        assert_eq!(normalize_args(json!(null)), json!(null));
    }

    #[test]
    fn normalize_args_treats_empty_object_as_null() {
        assert_eq!(normalize_args(json!({})), json!(null));
    }

    #[test]
    fn normalize_args_unwraps_args_envelope() {
        assert_eq!(normalize_args(json!({"args": {"id": 7}})), json!({"id": 7}));
        assert_eq!(normalize_args(json!({"args": 42})), json!(42));
        assert_eq!(normalize_args(json!({"args": null})), json!(null));
    }

    #[test]
    fn normalize_args_unwraps_input_envelope() {
        assert_eq!(
            normalize_args(json!({"input": [1, 2, 3]})),
            json!([1, 2, 3])
        );
    }

    #[test]
    fn normalize_args_unwraps_null_input_envelope() {
        assert_eq!(normalize_args(json!({"input": null})), json!(null));
    }

    #[test]
    fn normalize_args_keeps_objects_with_both_envelope_keys_intact() {
        // Only a *single-key* envelope is unwrapped; two keys means a real payload.
        let v = json!({"args": 1, "input": 2});
        assert_eq!(normalize_args(v.clone()), v);
    }

    #[test]
    fn normalize_args_keeps_other_single_key_objects_intact() {
        let v = json!({"id": 7});
        assert_eq!(normalize_args(v.clone()), v);
    }

    #[test]
    fn normalize_args_keeps_multi_key_objects_intact() {
        let v = json!({"name": "alice", "age": 30});
        assert_eq!(normalize_args(v.clone()), v);
    }

    #[test]
    fn normalize_args_keeps_non_object_values_intact() {
        assert_eq!(normalize_args(json!(42)), json!(42));
        assert_eq!(normalize_args(json!("hello")), json!("hello"));
        assert_eq!(normalize_args(json!([1, 2])), json!([1, 2]));
        assert_eq!(normalize_args(json!(true)), json!(true));
    }

    #[test]
    fn new_registry_is_empty() {
        let reg = FunctionRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
        assert!(reg.get("anything").is_none());
        assert_eq!(reg.function_names().count(), 0);
    }

    #[test]
    fn default_registry_is_empty() {
        let reg = FunctionRegistry::default();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn register_webhook_info_stores_entry_under_function_name() {
        let mut reg = FunctionRegistry::new();
        reg.register_webhook_info(info("stripe_webhook", FunctionKind::Webhook));

        assert_eq!(reg.len(), 1);
        assert!(!reg.is_empty());
        let entry = reg.get("stripe_webhook").expect("registered");
        assert_eq!(entry.kind(), FunctionKind::Webhook);
        assert_eq!(entry.info().name, "stripe_webhook");
    }

    #[test]
    fn register_same_name_overwrites_existing_entry() {
        let mut reg = FunctionRegistry::new();
        reg.register_webhook_info(info("dup", FunctionKind::Webhook));
        let mut second = info("dup", FunctionKind::Webhook);
        second.is_public = false;
        reg.register_webhook_info(second);

        assert_eq!(reg.len(), 1);
        assert!(!reg.get("dup").expect("present").info().is_public);
    }

    #[test]
    fn cloned_registry_is_independent_copy() {
        let mut reg = FunctionRegistry::new();
        reg.register_webhook_info(info("hook", FunctionKind::Webhook));

        let mut copy = reg.clone();
        copy.register_webhook_info(info("other", FunctionKind::Webhook));

        assert_eq!(reg.len(), 1);
        assert!(reg.get("other").is_none());
        assert_eq!(copy.len(), 2);
        assert_eq!(
            copy.get("hook").expect("cloned entry").kind(),
            FunctionKind::Webhook
        );
    }

    #[test]
    fn iterators_partition_by_kind() {
        let mut reg = FunctionRegistry::new();
        reg.register_webhook_info(info("hook_a", FunctionKind::Webhook));
        reg.register_webhook_info(info("hook_b", FunctionKind::Webhook));

        let names: Vec<&str> = reg.function_names().collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"hook_a"));
        assert!(names.contains(&"hook_b"));

        let webhooks: Vec<&str> = reg.webhooks().map(|(n, _)| n).collect();
        assert_eq!(webhooks.len(), 2);
        assert_eq!(reg.queries().count(), 0);
        assert_eq!(reg.mutations().count(), 0);
    }
}
369}