1use crate::symbol::{NumType, Symbol};
8
/// A user-defined function described by a `--define` specification string.
///
/// Values are built by [`UserFunction::parse`], which fills each field from one of the
/// four colon-separated parts of the specification (weight, name, description, formula).
#[derive(Clone, Debug)]
pub struct UserFunction {
    /// Weight taken from the first part of the specification, parsed as a `u16`.
    // NOTE(review): the `dead_code` allowance suggests this field is not read by
    // non-test code yet — confirm before removing the attribute.
    #[allow(dead_code)]
    pub weight: u16,
    /// Function name taken from the second part of the specification; never empty
    /// when produced by `parse`.
    pub name: String,
    /// Free-form description taken verbatim from the third part of the specification.
    // NOTE(review): presumably only kept for display purposes — confirm against callers.
    #[allow(dead_code)]
    pub description: String,
    /// Operations parsed from the formula (fourth part), in the order they were written.
    pub body: Vec<UdfOp>,
    /// Numeric type computed from `body` by `infer_num_type` at parse time.
    // NOTE(review): as written, `infer_num_type` returns `NumType::Transcendental`
    // for every body, so this field currently carries no extra information.
    #[allow(dead_code)]
    pub num_type: NumType,
}
35
/// A single operation in the body of a [`UserFunction`].
#[derive(Clone, Debug, PartialEq)]
pub enum UdfOp {
    /// A symbol recognised by `Symbol::from_byte`; its stack requirements come from
    /// the symbol's `seft()` classification.
    Symbol(Symbol),
    /// Written as `|` in a formula. Stack validation requires one value to be present
    /// and counts one additional value afterwards.
    Dup,
    /// Written as `@` in a formula. Stack validation requires two values to be present
    /// and leaves the stack depth unchanged.
    Swap,
}
46
47impl UserFunction {
48 pub fn parse(spec: &str) -> Result<Self, String> {
52 let parts: Vec<&str> = spec.split(':').collect();
53 if parts.len() != 4 {
54 return Err(format!(
55 "Invalid --define format: expected 4 colon-separated parts, got {}",
56 parts.len()
57 ));
58 }
59
60 let weight: u16 = parts[0]
61 .parse()
62 .map_err(|_| format!("Invalid weight: {}", parts[0]))?;
63
64 let name = parts[1].to_string();
65 if name.is_empty() {
66 return Err("Function name cannot be empty".to_string());
67 }
68
69 let description = parts[2].to_string();
70
71 let body = parse_udf_formula(parts[3])?;
73
74 let num_type = infer_num_type(&body);
76
77 Ok(UserFunction {
78 weight,
79 name,
80 description,
81 body,
82 num_type,
83 })
84 }
85
86 #[allow(dead_code)]
92 pub fn stack_effect(&self) -> i32 {
93 calculate_stack_effect(&self.body)
94 }
95}
96
97fn parse_udf_formula(formula: &str) -> Result<Vec<UdfOp>, String> {
99 let mut ops = Vec::new();
100
101 if let Some(ch) = formula.chars().find(|c| !c.is_ascii()) {
102 return Err(format!(
103 "Non-ASCII symbol '{}' in function definition; formulas must use ASCII symbols",
104 ch
105 ));
106 }
107
108 for b in formula.bytes() {
109 match b as char {
110 '|' => ops.push(UdfOp::Dup),
111 '@' => ops.push(UdfOp::Swap),
112 _ => {
113 if let Some(sym) = Symbol::from_byte(b) {
115 ops.push(UdfOp::Symbol(sym));
116 } else {
117 return Err(format!(
118 "Unknown symbol '{}' in function definition",
119 b as char
120 ));
121 }
122 }
123 }
124 }
125
126 validate_stack_behavior(&ops)?;
127
128 Ok(ops)
129}
130
131fn validate_stack_behavior(ops: &[UdfOp]) -> Result<(), String> {
132 let mut depth: i32 = 1;
133
134 for (idx, op) in ops.iter().enumerate() {
135 let (required_depth, delta, op_name) = match op {
136 UdfOp::Symbol(sym) => match sym.seft() {
137 crate::symbol::Seft::A => (0, 1, "constant"),
138 crate::symbol::Seft::B => (1, 0, "unary"),
139 crate::symbol::Seft::C => (2, -1, "binary"),
140 },
141 UdfOp::Dup => (1, 1, "dup"),
142 UdfOp::Swap => (2, 0, "swap"),
143 };
144
145 if depth < required_depth {
146 return Err(format!(
147 "Invalid function: stack underflow at op {} ({})",
148 idx + 1,
149 op_name
150 ));
151 }
152
153 depth += delta;
154 }
155
156 if depth != 1 {
157 let effect = depth - 1;
158 return Err(format!(
159 "Invalid function: stack effect is {} (should be 0 for a unary function)",
160 effect
161 ));
162 }
163
164 Ok(())
165}
166
167fn calculate_stack_effect(ops: &[UdfOp]) -> i32 {
169 let mut effect = 0;
170
171 for op in ops {
172 match op {
173 UdfOp::Symbol(sym) => {
174 let seft = sym.seft();
176 match seft {
177 crate::symbol::Seft::A => {
178 effect += 1;
180 }
181 crate::symbol::Seft::B => {
182 effect -= 1; effect += 1; }
187 crate::symbol::Seft::C => {
188 effect -= 2; effect += 1; }
192 }
193 }
194 UdfOp::Dup => {
195 effect -= 1;
197 effect += 2;
198 }
199 UdfOp::Swap => {
200 }
203 }
204 }
205
206 effect
207}
208
209fn infer_num_type(ops: &[UdfOp]) -> NumType {
211 for op in ops {
212 if let UdfOp::Symbol(sym) = op {
213 let result = sym.result_type(&[]);
216 if matches!(result, NumType::Transcendental) {
217 return NumType::Transcendental;
218 }
219 }
220 }
221
222 NumType::Transcendental
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `spec`, failing the test with the parser's message if it is rejected.
    fn parse_ok(spec: &str) -> UserFunction {
        match UserFunction::parse(spec) {
            Ok(udf) => udf,
            Err(msg) => panic!("expected {:?} to parse, got error: {}", spec, msg),
        }
    }

    /// Asserts that `spec` is rejected with an error message containing `needle`.
    fn assert_rejected(spec: &str, needle: &str) {
        let message = UserFunction::parse(spec).expect_err("specification should be rejected");
        assert!(
            message.contains(needle),
            "error {:?} does not mention {:?}",
            message,
            needle
        );
    }

    #[test]
    fn test_parse_sinh() {
        let sinh = parse_ok("4:sinh:hyperbolic sine:E|r-2/");

        assert_eq!(sinh.weight, 4);
        assert_eq!(sinh.name, "sinh");
        assert_eq!(sinh.description, "hyperbolic sine");
        assert_eq!(sinh.stack_effect(), 0);
    }

    #[test]
    fn test_parse_xex() {
        let xex = parse_ok("4:XeX:x*exp(x):|E*");

        assert_eq!(xex.weight, 4);
        assert_eq!(xex.name, "XeX");
        assert_eq!(xex.stack_effect(), 0);

        // Comparing whole slices checks both the length and every element.
        let expected = [
            UdfOp::Dup,
            UdfOp::Symbol(Symbol::Exp),
            UdfOp::Symbol(Symbol::Mul),
        ];
        assert_eq!(xex.body.as_slice(), &expected[..]);
    }

    #[test]
    fn test_parse_cosh() {
        let cosh = parse_ok("4:cosh:hyperbolic cosine:E|r+2/");

        assert_eq!(cosh.stack_effect(), 0);
    }

    #[test]
    fn test_invalid_stack_effect() {
        assert_rejected("4:bad:bad function:12+", "stack effect");
    }

    #[test]
    fn test_unknown_symbol() {
        assert_rejected("4:bad:bad function:xyz", "Unknown symbol");
    }

    #[test]
    fn test_stack_underflow_swap_rejected() {
        assert_rejected("4:bad:bad function:@", "stack underflow");
    }

    #[test]
    fn test_stack_underflow_binary_rejected() {
        assert_rejected("4:bad:bad function:+1", "stack underflow");
    }

    #[test]
    fn test_non_ascii_symbol_rejected() {
        assert_rejected("4:bad:bad function:ı", "Non-ASCII");
    }
}