1use shape_ast::ast::{Expr, Literal, ObjectEntry};
63use shape_ast::error::{Result, ShapeError};
64use shape_value::ValueWord;
65use std::collections::HashMap;
66use std::sync::Arc;
67
68#[derive(Debug, Clone)]
70pub struct ConstEvaluator {
71 params: HashMap<String, ValueWord>,
74}
75
76impl ConstEvaluator {
77 pub fn new() -> Self {
79 Self {
80 params: HashMap::new(),
81 }
82 }
83
84 pub fn with_params(params: HashMap<String, ValueWord>) -> Self {
86 Self {
87 params: params.into_iter().map(|(k, v)| (k, v)).collect(),
88 }
89 }
90
91 pub fn add_param(&mut self, name: String, value: ValueWord) {
93 self.params.insert(name, value);
94 }
95
96 pub fn add_param_nb(&mut self, name: String, value: ValueWord) {
98 self.params.insert(name, value);
99 }
100
101 pub fn eval(&self, expr: &Expr) -> Result<ValueWord> {
105 Ok(self.eval_nb(expr)?.clone())
106 }
107
108 pub fn eval_as_nb(&self, expr: &Expr) -> Result<ValueWord> {
110 self.eval_nb(expr)
111 }
112
113 fn eval_nb(&self, expr: &Expr) -> Result<ValueWord> {
115 match expr {
116 Expr::Literal(lit, _) => match lit {
118 Literal::Int(i) => Ok(ValueWord::from_f64(*i as f64)),
119 Literal::UInt(u) => Ok(ValueWord::from_native_u64(*u)),
120 Literal::TypedInt(v, _) => Ok(ValueWord::from_i64(*v)),
121 Literal::Number(n) => Ok(ValueWord::from_f64(*n)),
122 Literal::Decimal(d) => {
123 use rust_decimal::prelude::ToPrimitive;
124 Ok(ValueWord::from_f64(d.to_f64().unwrap_or(0.0)))
125 }
126 Literal::String(s) => Ok(ValueWord::from_string(Arc::new(s.clone()))),
127 Literal::FormattedString { value, .. } => {
128 Ok(ValueWord::from_string(Arc::new(value.clone())))
129 }
130 Literal::ContentString { value, .. } => {
131 Ok(ValueWord::from_string(Arc::new(value.clone())))
132 }
133 Literal::Char(c) => Ok(ValueWord::from_char(*c)),
134 Literal::Bool(b) => Ok(ValueWord::from_bool(*b)),
135 Literal::None => Ok(ValueWord::none()),
136 Literal::Unit => Ok(ValueWord::unit()),
137 Literal::Timeframe(tf) => Ok(ValueWord::from_timeframe(*tf)),
138 },
139
140 Expr::Object(entries, _) => {
142 let mut pairs: Vec<(String, ValueWord)> = Vec::new();
143 for entry in entries {
144 match entry {
145 ObjectEntry::Field {
146 key,
147 value,
148 type_annotation: _,
149 } => {
150 let val = self.eval_nb(value)?;
151 pairs.push((key.clone(), val));
152 }
153 ObjectEntry::Spread(_) => {
154 return Err(ShapeError::RuntimeError {
155 message: "Object spread (...) not allowed in const context"
156 .to_string(),
157 location: None,
158 });
159 }
160 }
161 }
162 let ref_pairs: Vec<(&str, ValueWord)> =
163 pairs.iter().map(|(k, v)| (k.as_str(), v.clone())).collect();
164 Ok(crate::type_schema::typed_object_from_nb_pairs(&ref_pairs))
165 }
166
167 Expr::Array(elements, _) => {
169 let mut arr = Vec::new();
170 for elem in elements {
171 arr.push(self.eval_nb(elem)?);
172 }
173 Ok(ValueWord::from_array(Arc::new(arr)))
174 }
175
176 Expr::Identifier(name, _span) => {
178 self.params
179 .get(name)
180 .cloned()
181 .ok_or_else(|| ShapeError::RuntimeError {
182 message: format!(
183 "Cannot reference variable '{}' in const context (metadata()). \
184 Only annotation parameters are allowed.",
185 name
186 ),
187 location: None,
188 })
189 }
190
191 Expr::BinaryOp {
193 left,
194 op,
195 right,
196 span: _,
197 } => {
198 let left_val = self.eval_nb(left)?;
199 let right_val = self.eval_nb(right)?;
200
201 use shape_ast::ast::BinaryOp;
202 match op {
203 BinaryOp::Add => self.const_add_nb(left_val, right_val),
205 BinaryOp::Sub => {
206 self.const_arith_nb(left_val, right_val, "subtraction", |a, b| a - b)
207 }
208 BinaryOp::Mul => {
209 self.const_arith_nb(left_val, right_val, "multiplication", |a, b| a * b)
210 }
211 BinaryOp::Div => {
212 let a = left_val.as_f64().ok_or_else(|| ShapeError::RuntimeError {
213 message: "Const division only works on numbers".to_string(),
214 location: None,
215 })?;
216 let b = right_val.as_f64().ok_or_else(|| ShapeError::RuntimeError {
217 message: "Const division only works on numbers".to_string(),
218 location: None,
219 })?;
220 if b == 0.0 {
221 Err(ShapeError::RuntimeError {
222 message: "Division by zero in const context".to_string(),
223 location: None,
224 })
225 } else {
226 Ok(ValueWord::from_f64(a / b))
227 }
228 }
229 BinaryOp::Mod => {
230 self.const_arith_nb(left_val, right_val, "modulo", |a, b| a % b)
231 }
232
233 BinaryOp::Equal => Ok(ValueWord::from_bool(left_val.vw_equals(&right_val))),
235 BinaryOp::NotEqual => Ok(ValueWord::from_bool(!left_val.vw_equals(&right_val))),
236 BinaryOp::Less => self.const_compare_nb(left_val, right_val, |a, b| a < b),
237 BinaryOp::LessEq => self.const_compare_nb(left_val, right_val, |a, b| a <= b),
238 BinaryOp::Greater => self.const_compare_nb(left_val, right_val, |a, b| a > b),
239 BinaryOp::GreaterEq => {
240 self.const_compare_nb(left_val, right_val, |a, b| a >= b)
241 }
242
243 BinaryOp::And => Ok(ValueWord::from_bool(
245 left_val.is_truthy() && right_val.is_truthy(),
246 )),
247 BinaryOp::Or => Ok(ValueWord::from_bool(
248 left_val.is_truthy() || right_val.is_truthy(),
249 )),
250
251 _ => Err(ShapeError::RuntimeError {
253 message: format!("Binary operator {:?} not allowed in const context", op),
254 location: None,
255 }),
256 }
257 }
258
259 Expr::UnaryOp {
261 op,
262 operand,
263 span: _,
264 } => {
265 let val = self.eval_nb(operand)?;
266 use shape_ast::ast::UnaryOp;
267 match op {
268 UnaryOp::Not => Ok(ValueWord::from_bool(!val.is_truthy())),
269 UnaryOp::Neg => {
270 if let Some(n) = val.as_f64() {
271 Ok(ValueWord::from_f64(-n))
272 } else {
273 Err(ShapeError::RuntimeError {
274 message: "Cannot negate non-number in const context".to_string(),
275 location: None,
276 })
277 }
278 }
279 UnaryOp::BitNot => Err(ShapeError::RuntimeError {
280 message: "Bitwise NOT not allowed in const context".to_string(),
281 location: None,
282 }),
283 }
284 }
285
286 Expr::FunctionCall { .. } => Err(ShapeError::RuntimeError {
288 message: "Function calls are not allowed in const context (metadata())".to_string(),
289 location: None,
290 }),
291
292 Expr::PropertyAccess { .. } => Err(ShapeError::RuntimeError {
293 message:
294 "Property access (obj.field) is not allowed in const context (metadata()). \
295 Cannot access runtime state like ctx.* or fn.*"
296 .to_string(),
297 location: None,
298 }),
299
300 _ => Err(ShapeError::RuntimeError {
301 message: format!(
302 "Expression type not allowed in const context (metadata()): {:?}",
303 expr
304 ),
305 location: None,
306 }),
307 }
308 }
309
310 fn const_add_nb(&self, left: ValueWord, right: ValueWord) -> Result<ValueWord> {
313 if let (Some(a), Some(b)) = (left.as_f64(), right.as_f64()) {
314 return Ok(ValueWord::from_f64(a + b));
315 }
316 if let (Some(a), Some(b)) = (left.as_str(), right.as_str()) {
317 return Ok(ValueWord::from_string(Arc::new(format!("{}{}", a, b))));
318 }
319 Err(ShapeError::RuntimeError {
320 message: "Const addition only works on numbers or strings".to_string(),
321 location: None,
322 })
323 }
324
325 fn const_arith_nb(
326 &self,
327 left: ValueWord,
328 right: ValueWord,
329 op_name: &str,
330 f: fn(f64, f64) -> f64,
331 ) -> Result<ValueWord> {
332 let a = left.as_f64().ok_or_else(|| ShapeError::RuntimeError {
333 message: format!("Const {} only works on numbers", op_name),
334 location: None,
335 })?;
336 let b = right.as_f64().ok_or_else(|| ShapeError::RuntimeError {
337 message: format!("Const {} only works on numbers", op_name),
338 location: None,
339 })?;
340 Ok(ValueWord::from_f64(f(a, b)))
341 }
342
343 fn const_compare_nb(
344 &self,
345 left: ValueWord,
346 right: ValueWord,
347 cmp: fn(f64, f64) -> bool,
348 ) -> Result<ValueWord> {
349 let a = left.as_f64().ok_or_else(|| ShapeError::RuntimeError {
350 message: "Const comparison only works on numbers".to_string(),
351 location: None,
352 })?;
353 let b = right.as_f64().ok_or_else(|| ShapeError::RuntimeError {
354 message: "Const comparison only works on numbers".to_string(),
355 location: None,
356 })?;
357 Ok(ValueWord::from_bool(cmp(a, b)))
358 }
359}
360
361impl Default for ConstEvaluator {
362 fn default() -> Self {
363 Self::new()
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use shape_ast::ast::{BinaryOp, Span};
    use std::sync::Arc;

    // ---- AST / value construction helpers ------------------------------------

    /// Wraps a literal in an `Expr` with a dummy span.
    fn lit(literal: Literal) -> Expr {
        Expr::Literal(literal, Span::DUMMY)
    }

    /// Number literal expression.
    fn num(n: f64) -> Expr {
        lit(Literal::Number(n))
    }

    /// String literal expression.
    fn text(s: &str) -> Expr {
        lit(Literal::String(s.to_string()))
    }

    /// Untyped object field `key: value`.
    fn field(key: &str, value: Expr) -> ObjectEntry {
        ObjectEntry::Field {
            key: key.to_string(),
            value,
            type_annotation: None,
        }
    }

    /// Object literal expression.
    fn object(entries: Vec<ObjectEntry>) -> Expr {
        Expr::Object(entries, Span::DUMMY)
    }

    /// Array literal expression.
    fn array(items: Vec<Expr>) -> Expr {
        Expr::Array(items, Span::DUMMY)
    }

    /// `left + right` expression.
    fn add(left: Expr, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op: BinaryOp::Add,
            right: Box::new(right),
            span: Span::DUMMY,
        }
    }

    /// Expected runtime value for a string.
    fn string_value(s: &str) -> ValueWord {
        ValueWord::from_string(Arc::new(s.to_string()))
    }

    // ---- Tests -----------------------------------------------------------------

    #[test]
    fn test_const_number_literal() {
        let evaluated = ConstEvaluator::new().eval(&num(42.0)).unwrap();
        assert_eq!(evaluated, ValueWord::from_f64(42.0));
    }

    #[test]
    fn test_const_string_literal() {
        let evaluated = ConstEvaluator::new().eval(&text("hello")).unwrap();
        assert_eq!(evaluated, string_value("hello"));
    }

    #[test]
    fn test_const_formatted_string_literal() {
        // The raw template text is kept; no interpolation happens in const context.
        let template = lit(Literal::FormattedString {
            value: "value: {x}".to_string(),
            mode: shape_ast::ast::InterpolationMode::Braces,
        });
        let evaluated = ConstEvaluator::new().eval(&template).unwrap();
        assert_eq!(evaluated, string_value("value: {x}"));
    }

    #[test]
    fn test_const_boolean_literal() {
        let evaluated = ConstEvaluator::new()
            .eval(&lit(Literal::Bool(true)))
            .unwrap();
        assert_eq!(evaluated, ValueWord::from_bool(true));
    }

    #[test]
    fn test_const_object_literal() {
        let expr = object(vec![field("key1", num(42.0)), field("key2", text("value"))]);
        let evaluated = ConstEvaluator::new().eval(&expr).unwrap();

        let fields = crate::type_schema::typed_object_to_hashmap_nb(&evaluated)
            .expect("Expected TypedObject");
        assert_eq!(fields.get("key1").and_then(|v| v.as_f64()), Some(42.0));
        assert_eq!(fields.get("key2").and_then(|v| v.as_str()), Some("value"));
    }

    #[test]
    fn test_const_array_literal() {
        let expr = array(vec![num(1.0), num(2.0), num(3.0)]);
        let evaluated = ConstEvaluator::new().eval(&expr).unwrap();

        let items = evaluated.as_any_array().expect("Expected array").to_generic();
        assert_eq!(items.len(), 3);
        for (index, expected) in [1.0, 2.0, 3.0].iter().enumerate() {
            assert_eq!(items[index].as_f64(), Some(*expected));
        }
    }

    #[test]
    fn test_const_arithmetic_add() {
        let evaluated = ConstEvaluator::new()
            .eval(&add(num(2.0), num(3.0)))
            .unwrap();
        assert_eq!(evaluated, ValueWord::from_f64(5.0));
    }

    #[test]
    fn test_const_string_concat() {
        let evaluated = ConstEvaluator::new()
            .eval(&add(text("hello "), text("world")))
            .unwrap();
        assert_eq!(evaluated, string_value("hello world"));
    }

    #[test]
    fn test_const_annotation_param() {
        let mut evaluator = ConstEvaluator::new();
        evaluator.add_param("period".to_string(), ValueWord::from_f64(20.0));

        let evaluated = evaluator
            .eval(&Expr::Identifier("period".to_string(), Span::DUMMY))
            .unwrap();
        assert_eq!(evaluated, ValueWord::from_f64(20.0));
    }

    #[test]
    fn test_const_nested_object() {
        let code_lens = array(vec![object(vec![
            field("title", text("Run")),
            field("command", text("run")),
        ])]);
        let expr = object(vec![
            field("is_test", lit(Literal::Bool(true))),
            field("code_lens", code_lens),
        ]);
        let evaluated = ConstEvaluator::new().eval(&expr).unwrap();

        let fields = crate::type_schema::typed_object_to_hashmap_nb(&evaluated)
            .expect("Expected TypedObject");
        assert_eq!(fields.get("is_test").and_then(|v| v.as_bool()), Some(true));
        assert!(fields
            .get("code_lens")
            .and_then(|v| v.as_any_array())
            .is_some());
    }

    #[test]
    fn test_const_function_call_fails() {
        let call = Expr::FunctionCall {
            name: "foo".to_string(),
            args: vec![],
            named_args: vec![],
            span: Span::DUMMY,
        };
        let outcome = ConstEvaluator::new().eval(&call);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not allowed in const context"));
    }

    #[test]
    fn test_const_undefined_variable_fails() {
        let unknown = Expr::Identifier("undefined_var".to_string(), Span::DUMMY);
        let outcome = ConstEvaluator::new().eval(&unknown);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("annotation parameters"));
    }
}