Skip to main content

shape_runtime/
const_eval.rs

//! Const evaluation for annotation `metadata()` handlers.
//!
//! This module provides compile-time evaluation of Shape expressions.
//! Only a subset of expressions is allowed: literals, object/array construction,
//! annotation parameters, and const arithmetic, comparison, and logical operators.
//!
//! ## Purpose
//!
//! Const evaluation enables:
//! - **LSP metadata extraction without runtime execution** — show code lenses and hover info
//! - **Compiler optimizations** based on static metadata (pure functions, cacheable results)
//! - **Static analysis and documentation generation**
//!
//! ## How It Works
//!
//! When the LSP encounters an `annotation ... { ... }` definition, it:
//!
//! 1. **Parses the annotation definition** (from the current file or from imports)
//! 2. **Finds the `metadata()` handler** in the handlers list
//! 3. **Const-evaluates the handler body** using this module
//! 4. **Extracts special properties** from the result:
//!    - `code_lens: [...]` → creates IDE action buttons
//!    - `pure: true` → marks the function for compiler optimization
//!    - Custom properties → stored for the user's own tooling
//!
//! Example:
//!
//! ```shape
//! annotation strategy() {
//!     metadata() {
//!         {
//!             is_strategy: true,              // Custom metadata
//!             code_lens: [                    // Special: IDE integration
//!                 { title: "▶ Run", command: "shape.runBacktest" }
//!             ]
//!         }
//!     }
//! }
//! ```
//!
//! When a function has `@strategy`, the LSP:
//! 1. Looks up the `@strategy` annotation definition
//! 2. Const-evaluates `metadata()` → `{ is_strategy: true, code_lens: [...] }`
//! 3. Creates a "▶ Run" button above the function
//!
//! ## Allowed Constructs
//!
//! - Literals: `42`, `"hello"`, `true`, `null`
//! - Objects: `{ key: value, ... }`
//! - Arrays: `[1, 2, 3]`
//! - Annotation parameters (captured in scope)
//! - Const arithmetic: `2 + 2`, `"a" + "b"`
//! - Comparisons and logical operators on const operands: `1 < 2`, `a == b`, `x && y`, `!flag`
//!
//! ## Not Allowed
//!
//! - Function calls (runtime dependency)
//! - Variable references (except annotation parameters)
//! - `ctx` or `fn` access (runtime state)
//! - Object spread (`...`)
//! - Side effects
//! - Non-const conditionals

62use shape_ast::ast::{Expr, Literal, ObjectEntry};
63use shape_ast::error::{Result, ShapeError};
64use shape_value::ValueWord;
65use std::collections::HashMap;
66use std::sync::Arc;
67
68/// Const evaluator for metadata() handlers
69#[derive(Debug, Clone)]
70pub struct ConstEvaluator {
71    /// Annotation parameters available during evaluation
72    /// Maps parameter name → const value
73    params: HashMap<String, ValueWord>,
74}
75
76impl ConstEvaluator {
77    /// Create a new const evaluator with annotation parameters
78    pub fn new() -> Self {
79        Self {
80            params: HashMap::new(),
81        }
82    }
83
84    /// Create a const evaluator with annotation parameters
85    pub fn with_params(params: HashMap<String, ValueWord>) -> Self {
86        Self {
87            params: params.into_iter().map(|(k, v)| (k, v)).collect(),
88        }
89    }
90
91    /// Add an annotation parameter to the scope
92    pub fn add_param(&mut self, name: String, value: ValueWord) {
93        self.params.insert(name, value);
94    }
95
96    /// Add an annotation parameter to the scope (ValueWord, avoids ValueWord conversion)
97    pub fn add_param_nb(&mut self, name: String, value: ValueWord) {
98        self.params.insert(name, value);
99    }
100
101    /// Evaluate an expression as a const (compile-time) value
102    ///
103    /// Returns an error if the expression uses non-const constructs.
104    pub fn eval(&self, expr: &Expr) -> Result<ValueWord> {
105        Ok(self.eval_nb(expr)?.clone())
106    }
107
108    /// Evaluate an expression as a const ValueWord value (avoids ValueWord materialization)
109    pub fn eval_as_nb(&self, expr: &Expr) -> Result<ValueWord> {
110        self.eval_nb(expr)
111    }
112
113    /// Evaluate an expression as a const ValueWord value
114    fn eval_nb(&self, expr: &Expr) -> Result<ValueWord> {
115        match expr {
116            // Literals are always const
117            Expr::Literal(lit, _) => match lit {
118                Literal::Int(i) => Ok(ValueWord::from_f64(*i as f64)),
119                Literal::UInt(u) => Ok(ValueWord::from_native_u64(*u)),
120                Literal::TypedInt(v, _) => Ok(ValueWord::from_i64(*v)),
121                Literal::Number(n) => Ok(ValueWord::from_f64(*n)),
122                Literal::Decimal(d) => {
123                    use rust_decimal::prelude::ToPrimitive;
124                    Ok(ValueWord::from_f64(d.to_f64().unwrap_or(0.0)))
125                }
126                Literal::String(s) => Ok(ValueWord::from_string(Arc::new(s.clone()))),
127                Literal::FormattedString { value, .. } => {
128                    Ok(ValueWord::from_string(Arc::new(value.clone())))
129                }
130                Literal::ContentString { value, .. } => {
131                    Ok(ValueWord::from_string(Arc::new(value.clone())))
132                }
133                Literal::Char(c) => Ok(ValueWord::from_char(*c)),
134                Literal::Bool(b) => Ok(ValueWord::from_bool(*b)),
135                Literal::None => Ok(ValueWord::none()),
136                Literal::Unit => Ok(ValueWord::unit()),
137                Literal::Timeframe(tf) => Ok(ValueWord::from_timeframe(*tf)),
138            },
139
140            // Object literals - recursively evaluate all values
141            Expr::Object(entries, _) => {
142                let mut pairs: Vec<(String, ValueWord)> = Vec::new();
143                for entry in entries {
144                    match entry {
145                        ObjectEntry::Field {
146                            key,
147                            value,
148                            type_annotation: _,
149                        } => {
150                            let val = self.eval_nb(value)?;
151                            pairs.push((key.clone(), val));
152                        }
153                        ObjectEntry::Spread(_) => {
154                            return Err(ShapeError::RuntimeError {
155                                message: "Object spread (...) not allowed in const context"
156                                    .to_string(),
157                                location: None,
158                            });
159                        }
160                    }
161                }
162                let ref_pairs: Vec<(&str, ValueWord)> =
163                    pairs.iter().map(|(k, v)| (k.as_str(), v.clone())).collect();
164                Ok(crate::type_schema::typed_object_from_nb_pairs(&ref_pairs))
165            }
166
167            // Array literals - recursively evaluate all elements
168            Expr::Array(elements, _) => {
169                let mut arr = Vec::new();
170                for elem in elements {
171                    arr.push(self.eval_nb(elem)?);
172                }
173                Ok(ValueWord::from_array(Arc::new(arr)))
174            }
175
176            // Identifiers - only allowed if they're annotation parameters
177            Expr::Identifier(name, _span) => {
178                self.params
179                    .get(name)
180                    .cloned()
181                    .ok_or_else(|| ShapeError::RuntimeError {
182                        message: format!(
183                            "Cannot reference variable '{}' in const context (metadata()). \
184                             Only annotation parameters are allowed.",
185                            name
186                        ),
187                        location: None,
188                    })
189            }
190
191            // Binary operations - only const arithmetic/string concat
192            Expr::BinaryOp {
193                left,
194                op,
195                right,
196                span: _,
197            } => {
198                let left_val = self.eval_nb(left)?;
199                let right_val = self.eval_nb(right)?;
200
201                use shape_ast::ast::BinaryOp;
202                match op {
203                    // Arithmetic
204                    BinaryOp::Add => self.const_add_nb(left_val, right_val),
205                    BinaryOp::Sub => {
206                        self.const_arith_nb(left_val, right_val, "subtraction", |a, b| a - b)
207                    }
208                    BinaryOp::Mul => {
209                        self.const_arith_nb(left_val, right_val, "multiplication", |a, b| a * b)
210                    }
211                    BinaryOp::Div => {
212                        let a = left_val.as_f64().ok_or_else(|| ShapeError::RuntimeError {
213                            message: "Const division only works on numbers".to_string(),
214                            location: None,
215                        })?;
216                        let b = right_val.as_f64().ok_or_else(|| ShapeError::RuntimeError {
217                            message: "Const division only works on numbers".to_string(),
218                            location: None,
219                        })?;
220                        if b == 0.0 {
221                            Err(ShapeError::RuntimeError {
222                                message: "Division by zero in const context".to_string(),
223                                location: None,
224                            })
225                        } else {
226                            Ok(ValueWord::from_f64(a / b))
227                        }
228                    }
229                    BinaryOp::Mod => {
230                        self.const_arith_nb(left_val, right_val, "modulo", |a, b| a % b)
231                    }
232
233                    // Comparison
234                    BinaryOp::Equal => Ok(ValueWord::from_bool(left_val.vw_equals(&right_val))),
235                    BinaryOp::NotEqual => Ok(ValueWord::from_bool(!left_val.vw_equals(&right_val))),
236                    BinaryOp::Less => self.const_compare_nb(left_val, right_val, |a, b| a < b),
237                    BinaryOp::LessEq => self.const_compare_nb(left_val, right_val, |a, b| a <= b),
238                    BinaryOp::Greater => self.const_compare_nb(left_val, right_val, |a, b| a > b),
239                    BinaryOp::GreaterEq => {
240                        self.const_compare_nb(left_val, right_val, |a, b| a >= b)
241                    }
242
243                    // Logical
244                    BinaryOp::And => Ok(ValueWord::from_bool(
245                        left_val.is_truthy() && right_val.is_truthy(),
246                    )),
247                    BinaryOp::Or => Ok(ValueWord::from_bool(
248                        left_val.is_truthy() || right_val.is_truthy(),
249                    )),
250
251                    // Not allowed in const context
252                    _ => Err(ShapeError::RuntimeError {
253                        message: format!("Binary operator {:?} not allowed in const context", op),
254                        location: None,
255                    }),
256                }
257            }
258
259            // Unary operations
260            Expr::UnaryOp {
261                op,
262                operand,
263                span: _,
264            } => {
265                let val = self.eval_nb(operand)?;
266                use shape_ast::ast::UnaryOp;
267                match op {
268                    UnaryOp::Not => Ok(ValueWord::from_bool(!val.is_truthy())),
269                    UnaryOp::Neg => {
270                        if let Some(n) = val.as_f64() {
271                            Ok(ValueWord::from_f64(-n))
272                        } else {
273                            Err(ShapeError::RuntimeError {
274                                message: "Cannot negate non-number in const context".to_string(),
275                                location: None,
276                            })
277                        }
278                    }
279                    UnaryOp::BitNot => Err(ShapeError::RuntimeError {
280                        message: "Bitwise NOT not allowed in const context".to_string(),
281                        location: None,
282                    }),
283                }
284            }
285
286            // Everything else is not allowed in const context
287            Expr::FunctionCall { .. } => Err(ShapeError::RuntimeError {
288                message: "Function calls are not allowed in const context (metadata())".to_string(),
289                location: None,
290            }),
291
292            Expr::PropertyAccess { .. } => Err(ShapeError::RuntimeError {
293                message:
294                    "Property access (obj.field) is not allowed in const context (metadata()). \
295                         Cannot access runtime state like ctx.* or fn.*"
296                        .to_string(),
297                location: None,
298            }),
299
300            _ => Err(ShapeError::RuntimeError {
301                message: format!(
302                    "Expression type not allowed in const context (metadata()): {:?}",
303                    expr
304                ),
305                location: None,
306            }),
307        }
308    }
309
310    // Const arithmetic operations (ValueWord)
311
312    fn const_add_nb(&self, left: ValueWord, right: ValueWord) -> Result<ValueWord> {
313        if let (Some(a), Some(b)) = (left.as_f64(), right.as_f64()) {
314            return Ok(ValueWord::from_f64(a + b));
315        }
316        if let (Some(a), Some(b)) = (left.as_str(), right.as_str()) {
317            return Ok(ValueWord::from_string(Arc::new(format!("{}{}", a, b))));
318        }
319        Err(ShapeError::RuntimeError {
320            message: "Const addition only works on numbers or strings".to_string(),
321            location: None,
322        })
323    }
324
325    fn const_arith_nb(
326        &self,
327        left: ValueWord,
328        right: ValueWord,
329        op_name: &str,
330        f: fn(f64, f64) -> f64,
331    ) -> Result<ValueWord> {
332        let a = left.as_f64().ok_or_else(|| ShapeError::RuntimeError {
333            message: format!("Const {} only works on numbers", op_name),
334            location: None,
335        })?;
336        let b = right.as_f64().ok_or_else(|| ShapeError::RuntimeError {
337            message: format!("Const {} only works on numbers", op_name),
338            location: None,
339        })?;
340        Ok(ValueWord::from_f64(f(a, b)))
341    }
342
343    fn const_compare_nb(
344        &self,
345        left: ValueWord,
346        right: ValueWord,
347        cmp: fn(f64, f64) -> bool,
348    ) -> Result<ValueWord> {
349        let a = left.as_f64().ok_or_else(|| ShapeError::RuntimeError {
350            message: "Const comparison only works on numbers".to_string(),
351            location: None,
352        })?;
353        let b = right.as_f64().ok_or_else(|| ShapeError::RuntimeError {
354            message: "Const comparison only works on numbers".to_string(),
355            location: None,
356        })?;
357        Ok(ValueWord::from_bool(cmp(a, b)))
358    }
359}
360
361impl Default for ConstEvaluator {
362    fn default() -> Self {
363        Self::new()
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use shape_ast::ast::BinaryOp;
    use shape_ast::ast::Span;
    use std::sync::Arc;

    /// Number literal expression.
    fn num(n: f64) -> Expr {
        Expr::Literal(Literal::Number(n), Span::DUMMY)
    }

    /// Boolean literal expression.
    fn boolean(b: bool) -> Expr {
        Expr::Literal(Literal::Bool(b), Span::DUMMY)
    }

    /// Binary-operation expression with boxed operands.
    fn binary(left: Expr, op: BinaryOp, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
            span: Span::DUMMY,
        }
    }

    #[test]
    fn test_const_number_literal() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::Literal(Literal::Number(42.0), Span::DUMMY);
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(result, ValueWord::from_f64(42.0));
    }

    #[test]
    fn test_const_string_literal() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::Literal(Literal::String("hello".to_string()), Span::DUMMY);
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(
            result,
            ValueWord::from_string(Arc::new("hello".to_string()))
        );
    }

    #[test]
    fn test_const_formatted_string_literal() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::Literal(
            Literal::FormattedString {
                value: "value: {x}".to_string(),
                mode: shape_ast::ast::InterpolationMode::Braces,
            },
            Span::DUMMY,
        );
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(
            result,
            ValueWord::from_string(Arc::new("value: {x}".to_string()))
        );
    }

    #[test]
    fn test_const_boolean_literal() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::Literal(Literal::Bool(true), Span::DUMMY);
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(result, ValueWord::from_bool(true));
    }

    #[test]
    fn test_const_object_literal() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::Object(
            vec![
                ObjectEntry::Field {
                    key: "key1".to_string(),
                    value: Expr::Literal(Literal::Number(42.0), Span::DUMMY),
                    type_annotation: None,
                },
                ObjectEntry::Field {
                    key: "key2".to_string(),
                    value: Expr::Literal(Literal::String("value".to_string()), Span::DUMMY),
                    type_annotation: None,
                },
            ],
            Span::DUMMY,
        );
        let result = evaluator.eval(&expr).unwrap();

        let obj =
            crate::type_schema::typed_object_to_hashmap_nb(&result).expect("Expected TypedObject");
        assert_eq!(obj.get("key1").and_then(|v| v.as_f64()), Some(42.0));
        assert_eq!(obj.get("key2").and_then(|v| v.as_str()), Some("value"));
    }

    #[test]
    fn test_const_array_literal() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::Array(
            vec![
                Expr::Literal(Literal::Number(1.0), Span::DUMMY),
                Expr::Literal(Literal::Number(2.0), Span::DUMMY),
                Expr::Literal(Literal::Number(3.0), Span::DUMMY),
            ],
            Span::DUMMY,
        );
        let result = evaluator.eval(&expr).unwrap();

        let arr = result.as_any_array().expect("Expected array").to_generic();
        assert_eq!(arr.len(), 3);
        assert_eq!(arr[0].as_f64(), Some(1.0));
        assert_eq!(arr[1].as_f64(), Some(2.0));
        assert_eq!(arr[2].as_f64(), Some(3.0));
    }

    #[test]
    fn test_const_arithmetic_add() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Literal(Literal::Number(2.0), Span::DUMMY)),
            op: shape_ast::ast::BinaryOp::Add,
            right: Box::new(Expr::Literal(Literal::Number(3.0), Span::DUMMY)),
            span: Span::DUMMY,
        };
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(result, ValueWord::from_f64(5.0));
    }

    #[test]
    fn test_const_string_concat() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Literal(
                Literal::String("hello ".to_string()),
                Span::DUMMY,
            )),
            op: shape_ast::ast::BinaryOp::Add,
            right: Box::new(Expr::Literal(
                Literal::String("world".to_string()),
                Span::DUMMY,
            )),
            span: Span::DUMMY,
        };
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(
            result,
            ValueWord::from_string(Arc::new("hello world".to_string()))
        );
    }

    #[test]
    fn test_const_annotation_param() {
        let mut evaluator = ConstEvaluator::new();
        evaluator.add_param("period".to_string(), ValueWord::from_f64(20.0));

        let expr = Expr::Identifier("period".to_string(), Span::DUMMY);
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(result, ValueWord::from_f64(20.0));
    }

    #[test]
    fn test_const_nested_object() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::Object(
            vec![
                ObjectEntry::Field {
                    key: "is_test".to_string(),
                    value: Expr::Literal(Literal::Bool(true), Span::DUMMY),
                    type_annotation: None,
                },
                ObjectEntry::Field {
                    key: "code_lens".to_string(),
                    value: Expr::Array(
                        vec![Expr::Object(
                            vec![
                                ObjectEntry::Field {
                                    key: "title".to_string(),
                                    value: Expr::Literal(
                                        Literal::String("Run".to_string()),
                                        Span::DUMMY,
                                    ),
                                    type_annotation: None,
                                },
                                ObjectEntry::Field {
                                    key: "command".to_string(),
                                    value: Expr::Literal(
                                        Literal::String("run".to_string()),
                                        Span::DUMMY,
                                    ),
                                    type_annotation: None,
                                },
                            ],
                            Span::DUMMY,
                        )],
                        Span::DUMMY,
                    ),
                    type_annotation: None,
                },
            ],
            Span::DUMMY,
        );
        let result = evaluator.eval(&expr).unwrap();

        let obj =
            crate::type_schema::typed_object_to_hashmap_nb(&result).expect("Expected TypedObject");
        assert_eq!(obj.get("is_test").and_then(|v| v.as_bool()), Some(true));
        assert!(
            obj.get("code_lens")
                .and_then(|v| v.as_any_array())
                .is_some()
        );
    }

    #[test]
    fn test_const_function_call_fails() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::FunctionCall {
            name: "foo".to_string(),
            args: vec![],
            named_args: vec![],
            span: Span::DUMMY,
        };
        let result = evaluator.eval(&expr);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("not allowed in const context")
        );
    }

    #[test]
    fn test_const_undefined_variable_fails() {
        let evaluator = ConstEvaluator::new();
        let expr = Expr::Identifier("undefined_var".to_string(), Span::DUMMY);
        let result = evaluator.eval(&expr);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("annotation parameters")
        );
    }

    #[test]
    fn test_const_subtraction_and_multiplication() {
        let evaluator = ConstEvaluator::new();
        let diff = evaluator
            .eval(&binary(num(10.0), BinaryOp::Sub, num(4.0)))
            .unwrap();
        assert_eq!(diff, ValueWord::from_f64(6.0));
        let product = evaluator
            .eval(&binary(num(3.0), BinaryOp::Mul, num(4.0)))
            .unwrap();
        assert_eq!(product, ValueWord::from_f64(12.0));
    }

    #[test]
    fn test_const_division_by_zero_fails() {
        let evaluator = ConstEvaluator::new();
        let err = evaluator
            .eval(&binary(num(1.0), BinaryOp::Div, num(0.0)))
            .unwrap_err();
        assert!(err.to_string().contains("Division by zero"));
    }

    #[test]
    fn test_const_comparisons() {
        let evaluator = ConstEvaluator::new();
        let less = evaluator
            .eval(&binary(num(2.0), BinaryOp::Less, num(3.0)))
            .unwrap();
        assert_eq!(less, ValueWord::from_bool(true));
        let greater_eq = evaluator
            .eval(&binary(num(3.0), BinaryOp::GreaterEq, num(4.0)))
            .unwrap();
        assert_eq!(greater_eq, ValueWord::from_bool(false));
        let equal = evaluator
            .eval(&binary(num(2.0), BinaryOp::Equal, num(2.0)))
            .unwrap();
        assert_eq!(equal, ValueWord::from_bool(true));
    }

    #[test]
    fn test_const_logical_ops() {
        let evaluator = ConstEvaluator::new();
        let and = evaluator
            .eval(&binary(boolean(true), BinaryOp::And, boolean(false)))
            .unwrap();
        assert_eq!(and, ValueWord::from_bool(false));
        let or = evaluator
            .eval(&binary(boolean(false), BinaryOp::Or, boolean(true)))
            .unwrap();
        assert_eq!(or, ValueWord::from_bool(true));
    }

    #[test]
    fn test_const_mixed_type_addition_fails() {
        let evaluator = ConstEvaluator::new();
        let text = Expr::Literal(Literal::String("a".to_string()), Span::DUMMY);
        let err = evaluator
            .eval(&binary(text, BinaryOp::Add, num(1.0)))
            .unwrap_err();
        assert!(err.to_string().contains("numbers or strings"));
    }

    #[test]
    fn test_const_with_params() {
        let mut params = std::collections::HashMap::new();
        params.insert("period".to_string(), ValueWord::from_f64(14.0));
        let evaluator = ConstEvaluator::with_params(params);

        let expr = Expr::Identifier("period".to_string(), Span::DUMMY);
        assert_eq!(evaluator.eval(&expr).unwrap(), ValueWord::from_f64(14.0));
    }
}