Skip to main content

varpulis_runtime/
udf.rs

1//! User-defined functions with type signatures
2//!
3//! Provides traits for native Rust scalar and aggregate UDFs that bypass
4//! the VPL interpreter for maximum performance.
5//!
6//! # Example
7//!
8//! ```rust,no_run
9//! use varpulis_runtime::udf::{ScalarUDF, Signature, TypeConstraint};
10//! use varpulis_core::{Type, Value};
11//!
12//! struct DoubleUdf;
13//! impl ScalarUDF for DoubleUdf {
14//!     fn name(&self) -> &str { "double" }
15//!     fn signature(&self) -> Signature {
16//!         Signature {
17//!             input_types: vec![TypeConstraint::Numeric],
18//!             return_type: Type::Float,
19//!             variadic: None,
20//!         }
21//!     }
22//!     fn evaluate(&self, args: &[Value]) -> Option<Value> {
23//!         match &args[0] {
24//!             Value::Int(i) => Some(Value::Float(*i as f64 * 2.0)),
25//!             Value::Float(f) => Some(Value::Float(f * 2.0)),
26//!             _ => None,
27//!         }
28//!     }
29//! }
30//! ```
31
32use std::sync::Arc;
33
34use rustc_hash::FxHashMap;
35use varpulis_core::{Type, Value};
36
/// Type constraint for a function parameter.
///
/// A constraint describes which concrete [`Type`]s are acceptable in one
/// argument position. It is used for positional parameters
/// ([`Signature::input_types`]) and for variadic tails
/// ([`VariadicSpec::arg_type`]); [`TypeConstraint::matches`] performs the check.
#[derive(Debug, Clone)]
pub enum TypeConstraint {
    /// Must be exactly this type (compared with `==`).
    Exact(Type),
    /// Must be a numeric type (Int or Float), as decided by `Type::is_numeric`.
    Numeric,
    /// Any type is accepted.
    Any,
    /// Must be one of the listed types. An empty list accepts nothing.
    OneOf(Vec<Type>),
}
49
50impl TypeConstraint {
51    /// Check whether a concrete type satisfies this constraint.
52    pub fn matches(&self, ty: &Type) -> bool {
53        match self {
54            Self::Exact(expected) => ty == expected,
55            Self::Numeric => ty.is_numeric(),
56            Self::Any => true,
57            Self::OneOf(types) => types.contains(ty),
58        }
59    }
60}
61
/// Specification for variadic arguments.
///
/// Describes the arguments a call may pass *after* the positional parameters
/// listed in [`Signature::input_types`].
#[derive(Debug, Clone)]
pub struct VariadicSpec {
    /// Minimum number of variadic arguments.
    ///
    /// This is counted in addition to the positional parameters: a call must
    /// supply at least `input_types.len() + min_args` arguments in total.
    pub min_args: usize,
    /// Type constraint for each variadic argument.
    ///
    /// The same constraint is applied to every argument that follows the
    /// positional ones.
    pub arg_type: TypeConstraint,
}
70
/// Function signature describing inputs and output.
///
/// Signatures are what [`UdfRegistry::validate_call`] checks a call's argument
/// types against.
#[derive(Debug, Clone)]
pub struct Signature {
    /// Type constraints for each positional parameter, in call order
    /// (constraint `i` is checked against argument `i`).
    pub input_types: Vec<TypeConstraint>,
    /// Return type of the function; this is the type reported for a valid call.
    pub return_type: Type,
    /// Optional variadic specification (additional args beyond `input_types`).
    ///
    /// When `None`, a call must pass exactly `input_types.len()` arguments.
    pub variadic: Option<VariadicSpec>,
}
81
/// Trait for native Rust scalar UDFs.
///
/// Scalar UDFs evaluate once per event, producing a single output value.
/// Implementations must be `Send + Sync`; the registry stores them behind
/// `Arc<dyn ScalarUDF>`.
pub trait ScalarUDF: Send + Sync {
    /// Function name used in VPL source.
    ///
    /// This is also the key under which [`UdfRegistry::register_scalar`]
    /// stores the UDF.
    fn name(&self) -> &str;
    /// Type signature for validation and optimization.
    ///
    /// Returned by value; [`UdfRegistry::validate_call`] calls this each time
    /// it validates a call to this UDF.
    fn signature(&self) -> Signature;
    /// Evaluate the function with the given arguments.
    ///
    /// NOTE(review): the meaning of a `None` result is not defined in this
    /// file. The module-level example returns `None` for argument types it
    /// does not support — confirm how the engine treats it.
    fn evaluate(&self, args: &[Value]) -> Option<Value>;
}
93
/// Trait for native Rust aggregate UDFs.
///
/// Aggregate UDFs produce an [`Accumulator`] that collects values across
/// a window or group. Implementations must be `Send + Sync`; the registry
/// stores them behind `Arc<dyn AggregateUDF>`.
pub trait AggregateUDF: Send + Sync {
    /// Function name used in VPL source.
    ///
    /// This is also the key under which [`UdfRegistry::register_aggregate`]
    /// stores the UDF.
    fn name(&self) -> &str;
    /// Type signature for validation.
    ///
    /// Returned by value; [`UdfRegistry::validate_call`] calls this each time
    /// it validates a call to this UDF.
    fn signature(&self) -> Signature;
    /// Create a fresh accumulator instance.
    fn init(&self) -> Box<dyn Accumulator>;
}
106
/// Stateful accumulator for aggregate UDFs.
///
/// Instances are created by [`AggregateUDF::init`] and handled as
/// `Box<dyn Accumulator>`.
pub trait Accumulator: Send + Sync {
    /// Incorporate a new value.
    fn update(&mut self, value: &Value);
    /// Merge another accumulator's state into this one.
    ///
    /// `other` is only visible through this trait's methods (for example
    /// [`Accumulator::finish`]); the trait itself offers no downcasting hook.
    fn merge(&mut self, other: &dyn Accumulator);
    /// Produce the final aggregate result.
    ///
    /// Takes `&self`, so the accumulator can keep receiving updates afterwards.
    fn finish(&self) -> Value;
    /// Reset to the initial state.
    fn reset(&mut self);
    /// Clone this accumulator into a boxed trait object.
    ///
    /// Provided as a method because `Clone` cannot be called through
    /// `dyn Accumulator`.
    fn clone_box(&self) -> Box<dyn Accumulator>;
}
120
/// Registry of native UDFs.
///
/// The engine checks this registry before falling through to VPL-interpreted
/// user functions, allowing native Rust implementations to take priority.
///
/// Scalar and aggregate UDFs live in separate maps, so the same name may be
/// registered once in each.
pub struct UdfRegistry {
    /// Scalar UDFs keyed by the name they reported at registration time.
    scalar_udfs: FxHashMap<String, Arc<dyn ScalarUDF>>,
    /// Aggregate UDFs keyed by the name they reported at registration time.
    aggregate_udfs: FxHashMap<String, Arc<dyn AggregateUDF>>,
}
129
130impl std::fmt::Debug for UdfRegistry {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        f.debug_struct("UdfRegistry")
133            .field("scalar_udfs", &self.scalar_udfs.keys().collect::<Vec<_>>())
134            .field(
135                "aggregate_udfs",
136                &self.aggregate_udfs.keys().collect::<Vec<_>>(),
137            )
138            .finish_non_exhaustive()
139    }
140}
141
142impl UdfRegistry {
143    pub fn new() -> Self {
144        Self {
145            scalar_udfs: FxHashMap::default(),
146            aggregate_udfs: FxHashMap::default(),
147        }
148    }
149
150    /// Register a scalar UDF.
151    pub fn register_scalar(&mut self, udf: Arc<dyn ScalarUDF>) {
152        self.scalar_udfs.insert(udf.name().to_string(), udf);
153    }
154
155    /// Register an aggregate UDF.
156    pub fn register_aggregate(&mut self, udf: Arc<dyn AggregateUDF>) {
157        self.aggregate_udfs.insert(udf.name().to_string(), udf);
158    }
159
160    /// Look up a scalar UDF by name.
161    pub fn get_scalar(&self, name: &str) -> Option<&Arc<dyn ScalarUDF>> {
162        self.scalar_udfs.get(name)
163    }
164
165    /// Look up an aggregate UDF by name.
166    pub fn get_aggregate(&self, name: &str) -> Option<&Arc<dyn AggregateUDF>> {
167        self.aggregate_udfs.get(name)
168    }
169
170    /// Validate a function call against registered UDF signatures.
171    ///
172    /// Returns `Ok(return_type)` if the call is valid, or `Err(reason)` otherwise.
173    pub fn validate_call(&self, name: &str, arg_types: &[Type]) -> Result<Type, String> {
174        if let Some(udf) = self.scalar_udfs.get(name) {
175            let sig = udf.signature();
176            validate_signature(&sig, arg_types)?;
177            return Ok(sig.return_type);
178        }
179        if let Some(udf) = self.aggregate_udfs.get(name) {
180            let sig = udf.signature();
181            validate_signature(&sig, arg_types)?;
182            return Ok(sig.return_type);
183        }
184        Err(format!("unknown UDF: {name}"))
185    }
186
187    /// Returns true if the registry has no registered UDFs.
188    pub fn is_empty(&self) -> bool {
189        self.scalar_udfs.is_empty() && self.aggregate_udfs.is_empty()
190    }
191}
192
193impl Default for UdfRegistry {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
199fn validate_signature(sig: &Signature, arg_types: &[Type]) -> Result<(), String> {
200    let required = sig.input_types.len();
201
202    if let Some(variadic) = &sig.variadic {
203        if arg_types.len() < required + variadic.min_args {
204            return Err(format!(
205                "expected at least {} arguments, got {}",
206                required + variadic.min_args,
207                arg_types.len()
208            ));
209        }
210        // Validate required args
211        for (i, constraint) in sig.input_types.iter().enumerate() {
212            if !constraint.matches(&arg_types[i]) {
213                return Err(format!(
214                    "argument {} type mismatch: expected {:?}, got {:?}",
215                    i, constraint, arg_types[i]
216                ));
217            }
218        }
219        // Validate variadic args
220        for (i, ty) in arg_types[required..].iter().enumerate() {
221            if !variadic.arg_type.matches(ty) {
222                return Err(format!(
223                    "variadic argument {} type mismatch: expected {:?}, got {:?}",
224                    i, variadic.arg_type, ty
225                ));
226            }
227        }
228    } else {
229        if arg_types.len() != required {
230            return Err(format!(
231                "expected {} arguments, got {}",
232                required,
233                arg_types.len()
234            ));
235        }
236        for (i, constraint) in sig.input_types.iter().enumerate() {
237            if !constraint.matches(&arg_types[i]) {
238                return Err(format!(
239                    "argument {} type mismatch: expected {:?}, got {:?}",
240                    i, constraint, arg_types[i]
241                ));
242            }
243        }
244    }
245
246    Ok(())
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Signature shared by the numeric fixtures: one numeric argument in, a float out.
    fn numeric_to_float() -> Signature {
        Signature {
            input_types: vec![TypeConstraint::Numeric],
            return_type: Type::Float,
            variadic: None,
        }
    }

    /// Registry holding only the `double` scalar fixture.
    fn registry_with_double() -> UdfRegistry {
        let mut registry = UdfRegistry::new();
        registry.register_scalar(Arc::new(DoubleUdf));
        registry
    }

    /// Scalar fixture that doubles a numeric argument.
    struct DoubleUdf;

    impl ScalarUDF for DoubleUdf {
        fn name(&self) -> &'static str {
            "double"
        }

        fn signature(&self) -> Signature {
            numeric_to_float()
        }

        fn evaluate(&self, args: &[Value]) -> Option<Value> {
            let doubled = match &args[0] {
                Value::Int(i) => *i as f64 * 2.0,
                Value::Float(f) => f * 2.0,
                _ => return None,
            };
            Some(Value::Float(doubled))
        }
    }

    /// Accumulator fixture keeping a running floating-point sum.
    struct SumAccumulator {
        total: f64,
    }

    impl Accumulator for SumAccumulator {
        fn update(&mut self, value: &Value) {
            // Non-numeric values are ignored.
            let delta = match value {
                Value::Int(i) => *i as f64,
                Value::Float(f) => *f,
                _ => return,
            };
            self.total += delta;
        }

        fn merge(&mut self, other: &dyn Accumulator) {
            // In practice we'd downcast; for testing, fold in the other side's result.
            self.update(&other.finish());
        }

        fn finish(&self) -> Value {
            Value::Float(self.total)
        }

        fn reset(&mut self) {
            self.total = 0.0;
        }

        fn clone_box(&self) -> Box<dyn Accumulator> {
            Box::new(SumAccumulator { total: self.total })
        }
    }

    /// Aggregate fixture backed by [`SumAccumulator`].
    struct CustomSumUdf;

    impl AggregateUDF for CustomSumUdf {
        fn name(&self) -> &'static str {
            "custom_sum"
        }

        fn signature(&self) -> Signature {
            numeric_to_float()
        }

        fn init(&self) -> Box<dyn Accumulator> {
            Box::new(SumAccumulator { total: 0.0 })
        }
    }

    #[test]
    fn test_scalar_udf_evaluate() {
        let cases = [
            (Value::Int(5), Some(Value::Float(10.0))),
            (Value::Float(2.5), Some(Value::Float(5.0))),
            (Value::Null, None),
        ];
        for (input, expected) in cases {
            assert_eq!(DoubleUdf.evaluate(&[input]), expected);
        }
    }

    #[test]
    fn test_registry_lookup() {
        let mut registry = UdfRegistry::new();
        assert!(registry.is_empty());

        registry.register_scalar(Arc::new(DoubleUdf));
        registry.register_aggregate(Arc::new(CustomSumUdf));
        assert!(!registry.is_empty());

        let double = registry.get_scalar("double");
        assert!(double.is_some());
        let custom_sum = registry.get_aggregate("custom_sum");
        assert!(custom_sum.is_some());
        let missing = registry.get_scalar("unknown");
        assert!(missing.is_none());
    }

    #[test]
    fn test_validate_call_valid() {
        let registry = registry_with_double();

        let outcome = registry.validate_call("double", &[Type::Int]);
        assert_eq!(outcome, Ok(Type::Float));
    }

    #[test]
    fn test_validate_call_wrong_type() {
        let registry = registry_with_double();

        let outcome = registry.validate_call("double", &[Type::Str]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_call_wrong_arity() {
        let registry = registry_with_double();

        let outcome = registry.validate_call("double", &[Type::Int, Type::Int]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_unknown_udf() {
        let empty = UdfRegistry::new();
        assert!(empty.validate_call("nonexistent", &[]).is_err());
    }

    #[test]
    fn test_accumulator_lifecycle() {
        let mut acc = CustomSumUdf.init();

        acc.update(&Value::Int(10));
        acc.update(&Value::Float(5.5));
        assert_eq!(acc.finish(), Value::Float(15.5));

        // A clone carries the state accumulated so far.
        let snapshot = acc.clone_box();
        assert_eq!(snapshot.finish(), Value::Float(15.5));

        acc.reset();
        assert_eq!(acc.finish(), Value::Float(0.0));
    }

    #[test]
    fn test_type_constraint_matches() {
        let any = TypeConstraint::Any;
        assert!(any.matches(&Type::Int));
        assert!(any.matches(&Type::Str));

        let numeric = TypeConstraint::Numeric;
        assert!(numeric.matches(&Type::Int));
        assert!(numeric.matches(&Type::Float));
        assert!(!numeric.matches(&Type::Str));

        let exactly_int = TypeConstraint::Exact(Type::Int);
        assert!(exactly_int.matches(&Type::Int));
        assert!(!exactly_int.matches(&Type::Float));

        let int_or_str = TypeConstraint::OneOf(vec![Type::Int, Type::Str]);
        assert!(int_or_str.matches(&Type::Str));
        assert!(!int_or_str.matches(&Type::Bool));
    }

    #[test]
    fn test_variadic_validation() {
        /// Fixture taking one or more strings and no positional parameters.
        struct ConcatUdf;

        impl ScalarUDF for ConcatUdf {
            fn name(&self) -> &'static str {
                "concat"
            }

            fn signature(&self) -> Signature {
                let strings_only = VariadicSpec {
                    min_args: 1,
                    arg_type: TypeConstraint::Exact(Type::Str),
                };
                Signature {
                    input_types: vec![],
                    return_type: Type::Str,
                    variadic: Some(strings_only),
                }
            }

            fn evaluate(&self, _args: &[Value]) -> Option<Value> {
                None
            }
        }

        let mut registry = UdfRegistry::new();
        registry.register_scalar(Arc::new(ConcatUdf));

        let check = |args: &[Type]| registry.validate_call("concat", args);

        // Valid: 1 string arg
        assert!(check(&[Type::Str]).is_ok());
        // Valid: 3 string args
        assert!(check(&[Type::Str, Type::Str, Type::Str]).is_ok());
        // Invalid: 0 args
        assert!(check(&[]).is_err());
        // Invalid: int arg
        assert!(check(&[Type::Int]).is_err());
    }
}