Skip to main content

ries_rs/
udf.rs

//! User-defined functions for RIES.
//!
//! Parses and evaluates user-defined functions specified via the `--define` option.
//! Functions are defined as postfix expressions built from the existing symbols
//! plus two stack operations: `|` (dup) and `@` (swap).
6
7use crate::symbol::{NumType, Symbol};
8
9/// A user-defined function
10#[derive(Clone, Debug)]
11pub struct UserFunction {
12    /// Weight (complexity) of this function
13    ///
14    /// This field is part of the public API and is used when generating expressions
15    /// that include user-defined functions.
16    #[allow(dead_code)]
17    pub weight: u16,
18    /// Short name (single or few characters)
19    pub name: String,
20    /// Description (for display)
21    ///
22    /// This field is part of the public API for documentation and display purposes.
23    #[allow(dead_code)]
24    pub description: String,
25    /// The body of the function as a postfix expression
26    /// Uses standard symbols plus special stack operations
27    pub body: Vec<UdfOp>,
28    /// Numeric type of result
29    ///
30    /// This field is part of the public API for type inference when combining
31    /// expressions that use user-defined functions.
32    #[allow(dead_code)]
33    pub num_type: NumType,
34}
35
36/// Operations that can appear in a user-defined function
37#[derive(Clone, Debug, PartialEq)]
38pub enum UdfOp {
39    /// A standard RIES symbol (constant or operator)
40    Symbol(Symbol),
41    /// Duplicate top of stack (|)
42    Dup,
43    /// Swap top two stack elements (@)
44    Swap,
45}
46
47impl UserFunction {
48    /// Parse a user-defined function from a definition string
49    /// Format: "weight:name:description:formula"
50    /// Example: "4:sinh:hyperbolic sine:E|r-2/"
51    pub fn parse(spec: &str) -> Result<Self, String> {
52        let parts: Vec<&str> = spec.split(':').collect();
53        if parts.len() != 4 {
54            return Err(format!(
55                "Invalid --define format: expected 4 colon-separated parts, got {}",
56                parts.len()
57            ));
58        }
59
60        let weight: u16 = parts[0]
61            .parse()
62            .map_err(|_| format!("Invalid weight: {}", parts[0]))?;
63
64        let name = parts[1].to_string();
65        if name.is_empty() {
66            return Err("Function name cannot be empty".to_string());
67        }
68
69        let description = parts[2].to_string();
70
71        // Parse the formula (postfix expression)
72        let body = parse_udf_formula(parts[3])?;
73
74        // Determine the numeric type based on the operations used
75        let num_type = infer_num_type(&body);
76
77        Ok(UserFunction {
78            weight,
79            name,
80            description,
81            body,
82            num_type,
83        })
84    }
85
86    /// Get the stack effect of this function (pushed - popped)
87    /// For a unary function, this should be 0 (pop 1, push 1)
88    ///
89    /// This method is part of the public API for library consumers who need
90    /// to validate user-defined functions before use.
91    #[allow(dead_code)]
92    pub fn stack_effect(&self) -> i32 {
93        calculate_stack_effect(&self.body)
94    }
95}
96
97/// Parse a UDF formula string into a vector of operations
98fn parse_udf_formula(formula: &str) -> Result<Vec<UdfOp>, String> {
99    let mut ops = Vec::new();
100
101    if let Some(ch) = formula.chars().find(|c| !c.is_ascii()) {
102        return Err(format!(
103            "Non-ASCII symbol '{}' in function definition; formulas must use ASCII symbols",
104            ch
105        ));
106    }
107
108    for b in formula.bytes() {
109        match b as char {
110            '|' => ops.push(UdfOp::Dup),
111            '@' => ops.push(UdfOp::Swap),
112            _ => {
113                // Try to parse as a standard symbol
114                if let Some(sym) = Symbol::from_byte(b) {
115                    ops.push(UdfOp::Symbol(sym));
116                } else {
117                    return Err(format!(
118                        "Unknown symbol '{}' in function definition",
119                        b as char
120                    ));
121                }
122            }
123        }
124    }
125
126    validate_stack_behavior(&ops)?;
127
128    Ok(ops)
129}
130
131fn validate_stack_behavior(ops: &[UdfOp]) -> Result<(), String> {
132    let mut depth: i32 = 1;
133
134    for (idx, op) in ops.iter().enumerate() {
135        let (required_depth, delta, op_name) = match op {
136            UdfOp::Symbol(sym) => match sym.seft() {
137                crate::symbol::Seft::A => (0, 1, "constant"),
138                crate::symbol::Seft::B => (1, 0, "unary"),
139                crate::symbol::Seft::C => (2, -1, "binary"),
140            },
141            UdfOp::Dup => (1, 1, "dup"),
142            UdfOp::Swap => (2, 0, "swap"),
143        };
144
145        if depth < required_depth {
146            return Err(format!(
147                "Invalid function: stack underflow at op {} ({})",
148                idx + 1,
149                op_name
150            ));
151        }
152
153        depth += delta;
154    }
155
156    if depth != 1 {
157        let effect = depth - 1;
158        return Err(format!(
159            "Invalid function: stack effect is {} (should be 0 for a unary function)",
160            effect
161        ));
162    }
163
164    Ok(())
165}
166
167/// Calculate the net stack effect of a sequence of operations
168fn calculate_stack_effect(ops: &[UdfOp]) -> i32 {
169    let mut effect = 0;
170
171    for op in ops {
172        match op {
173            UdfOp::Symbol(sym) => {
174                // Use the symbol's Seft to determine stack effect
175                let seft = sym.seft();
176                match seft {
177                    crate::symbol::Seft::A => {
178                        // Constant: pushes 1, pops 0 → effect +1
179                        effect += 1;
180                    }
181                    crate::symbol::Seft::B => {
182                        // Unary: pushes 1, pops 1 → effect 0
183                        // But net effect is 0 since we pop first
184                        effect -= 1; // pop
185                        effect += 1; // push
186                    }
187                    crate::symbol::Seft::C => {
188                        // Binary: pushes 1, pops 2 → effect -1
189                        effect -= 2; // pop 2
190                        effect += 1; // push 1
191                    }
192                }
193            }
194            UdfOp::Dup => {
195                // Dup: pops 1, pushes 2 → effect +1
196                effect -= 1;
197                effect += 2;
198            }
199            UdfOp::Swap => {
200                // Swap: pops 2, pushes 2 → effect 0
201                // No net change
202            }
203        }
204    }
205
206    effect
207}
208
209/// Infer the numeric type of a function based on its operations
210fn infer_num_type(ops: &[UdfOp]) -> NumType {
211    for op in ops {
212        if let UdfOp::Symbol(sym) = op {
213            // If any operation produces transcendental results, the function is transcendental
214            // Use result_type with an empty arg_types to check
215            let result = sym.result_type(&[]);
216            if matches!(result, NumType::Transcendental) {
217                return NumType::Transcendental;
218            }
219        }
220    }
221
222    // Default to transcendental for safety
223    NumType::Transcendental
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a specification that is expected to be rejected and return the
    /// error message, panicking with the parsed value if it was accepted.
    fn parse_error(spec: &str) -> String {
        match UserFunction::parse(spec) {
            Ok(udf) => panic!("expected {:?} to be rejected, got {:?}", spec, udf),
            Err(message) => message,
        }
    }

    #[test]
    fn test_parse_sinh() {
        // sinh(x) = (e^x - e^-x) / 2
        // In postfix: E|r-2/ (exp, dup, recip, subtract, 2, divide)
        let udf = UserFunction::parse("4:sinh:hyperbolic sine:E|r-2/").unwrap();

        assert_eq!(udf.weight, 4);
        assert_eq!(udf.name, "sinh");
        assert_eq!(udf.description, "hyperbolic sine");
        assert_eq!(udf.stack_effect(), 0);
    }

    #[test]
    fn test_parse_xex() {
        // XeX(x) = x * e^x
        // In postfix: |E* (dup, exp, multiply)
        let udf = UserFunction::parse("4:XeX:x*exp(x):|E*").unwrap();

        assert_eq!(udf.weight, 4);
        assert_eq!(udf.name, "XeX");
        assert_eq!(udf.stack_effect(), 0);

        // Verify the body
        assert_eq!(
            udf.body,
            vec![
                UdfOp::Dup,
                UdfOp::Symbol(Symbol::Exp),
                UdfOp::Symbol(Symbol::Mul),
            ]
        );
    }

    #[test]
    fn test_parse_cosh() {
        // cosh(x) = (e^x + e^-x) / 2
        // In postfix: E|r+2/
        let udf = UserFunction::parse("4:cosh:hyperbolic cosine:E|r+2/").unwrap();

        assert_eq!(udf.stack_effect(), 0);
    }

    #[test]
    fn test_invalid_stack_effect() {
        // Two constants and one binary operator leave an extra value on the
        // stack, so this is not a valid unary function.
        let message = parse_error("4:bad:bad function:12+");
        assert!(message.contains("stack effect"));
    }

    #[test]
    fn test_unknown_symbol() {
        let message = parse_error("4:bad:bad function:xyz");
        assert!(message.contains("Unknown symbol"));
    }

    #[test]
    fn test_stack_underflow_swap_rejected() {
        let message = parse_error("4:bad:bad function:@");
        assert!(message.contains("stack underflow"));
    }

    #[test]
    fn test_stack_underflow_binary_rejected() {
        let message = parse_error("4:bad:bad function:+1");
        assert!(message.contains("stack underflow"));
    }

    #[test]
    fn test_non_ascii_symbol_rejected() {
        let message = parse_error("4:bad:bad function:ı");
        assert!(message.contains("Non-ASCII"));
    }

    #[test]
    fn test_wrong_part_count_rejected() {
        // Too few fields: the description is missing.
        let too_few = parse_error("4:sinh:E|r-2/");
        assert!(too_few.contains("expected 4 colon-separated parts, got 3"));

        // Too many fields: a colon inside the description splits it in two.
        let too_many = parse_error("4:sinh:hyperbolic: sine:E|r-2/");
        assert!(too_many.contains("expected 4 colon-separated parts, got 5"));
    }

    #[test]
    fn test_invalid_weight_rejected() {
        // Not a number at all.
        let message = parse_error("heavy:sinh:hyperbolic sine:E|r-2/");
        assert!(message.contains("Invalid weight: heavy"));

        // Negative and out-of-range values do not fit in a u16.
        assert!(parse_error("-1:sinh:hyperbolic sine:E|r-2/").contains("Invalid weight: -1"));
        assert!(parse_error("70000:sinh:hyperbolic sine:E|r-2/").contains("Invalid weight: 70000"));
    }

    #[test]
    fn test_empty_name_rejected() {
        let message = parse_error("4::hyperbolic sine:E|r-2/");
        assert!(message.contains("Function name cannot be empty"));
    }
}