1use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::checker::{check, check_with_proto_types, CheckError, CheckResult, STANDARD_LIBRARY};
10use crate::ext;
11use crate::parser::{self, ParseError, ParseResult};
12use crate::types::{CelType, FunctionDecl, ProtoTypeRegistry, SpannedExpr};
13
14use crate::ast::Ast;
15
16#[derive(Debug, Clone)]
18pub enum CompileError {
19 Parse(Vec<ParseError>),
21 Check(Vec<CheckError>),
23 NoAst,
25}
26
27impl std::fmt::Display for CompileError {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 match self {
30 CompileError::Parse(errors) => {
31 write!(f, "parse errors: ")?;
32 for (i, e) in errors.iter().enumerate() {
33 if i > 0 {
34 write!(f, "; ")?;
35 }
36 write!(f, "{}", e.message)?;
37 }
38 Ok(())
39 }
40 CompileError::Check(errors) => {
41 write!(f, "check errors: ")?;
42 for (i, e) in errors.iter().enumerate() {
43 if i > 0 {
44 write!(f, "; ")?;
45 }
46 write!(f, "{}", e.message())?;
47 }
48 Ok(())
49 }
50 CompileError::NoAst => write!(f, "no AST produced"),
51 }
52 }
53}
54
55impl std::error::Error for CompileError {}
56
57#[derive(Debug, Clone)]
78pub struct Env {
79 variables: HashMap<String, CelType>,
81 functions: HashMap<String, FunctionDecl>,
83 container: String,
85 proto_types: Option<Arc<ProtoTypeRegistry>>,
87}
88
89impl Env {
90 pub fn new() -> Self {
95 Self {
96 variables: HashMap::new(),
97 functions: HashMap::new(),
98 container: String::new(),
99 proto_types: None,
100 }
101 }
102
103 pub fn with_standard_library() -> Self {
107 let mut env = Self::new();
108
109 for func in STANDARD_LIBRARY.iter() {
111 env.functions.insert(func.name.clone(), func.clone());
112 }
113
114 env.add_type_constant("bool", CelType::Bool);
116 env.add_type_constant("int", CelType::Int);
117 env.add_type_constant("uint", CelType::UInt);
118 env.add_type_constant("double", CelType::Double);
119 env.add_type_constant("string", CelType::String);
120 env.add_type_constant("bytes", CelType::Bytes);
121 env.add_type_constant("list", CelType::list(CelType::Dyn));
122 env.add_type_constant("map", CelType::map(CelType::Dyn, CelType::Dyn));
123 env.add_type_constant("null_type", CelType::Null);
124 env.add_type_constant("type", CelType::type_of(CelType::Dyn));
125 env.add_type_constant("dyn", CelType::Dyn);
126
127 env
128 }
129
130 fn add_type_constant(&mut self, name: &str, cel_type: CelType) {
132 self.variables
133 .insert(name.to_string(), CelType::type_of(cel_type));
134 }
135
136 pub fn with_variable(mut self, name: impl Into<String>, cel_type: CelType) -> Self {
149 self.variables.insert(name.into(), cel_type);
150 self
151 }
152
153 pub fn add_variable(&mut self, name: impl Into<String>, cel_type: CelType) {
155 self.variables.insert(name.into(), cel_type);
156 }
157
158 pub fn with_function(mut self, decl: FunctionDecl) -> Self {
160 self.add_function(decl);
161 self
162 }
163
164 pub fn add_function(&mut self, decl: FunctionDecl) {
168 if let Some(existing) = self.functions.get_mut(&decl.name) {
169 existing.overloads.extend(decl.overloads);
171 } else {
172 self.functions.insert(decl.name.clone(), decl);
173 }
174 }
175
176 pub fn with_container(mut self, container: impl Into<String>) -> Self {
180 self.container = container.into();
181 self
182 }
183
184 pub fn set_container(&mut self, container: impl Into<String>) {
186 self.container = container.into();
187 }
188
189 pub fn container(&self) -> &str {
191 &self.container
192 }
193
194 pub fn with_proto_types(mut self, registry: ProtoTypeRegistry) -> Self {
198 self.proto_types = Some(Arc::new(registry));
199 self
200 }
201
202 pub fn proto_types(&self) -> Option<&ProtoTypeRegistry> {
204 self.proto_types.as_ref().map(|r| r.as_ref())
205 }
206
207 pub fn with_extension(mut self, extension: impl IntoIterator<Item = FunctionDecl>) -> Self {
222 for decl in extension {
223 self.add_function(decl);
224 }
225 self
226 }
227
228 pub fn with_all_extensions(mut self) -> Self {
245 self = self
246 .with_extension(ext::string_extension())
247 .with_extension(ext::math_extension())
248 .with_extension(ext::encoders_extension())
249 .with_extension(ext::optionals_extension());
250
251 self.add_type_constant("optional_type", CelType::optional(CelType::Dyn));
253
254 self
255 }
256
257 pub fn variables(&self) -> &HashMap<String, CelType> {
259 &self.variables
260 }
261
262 pub fn functions(&self) -> &HashMap<String, FunctionDecl> {
264 &self.functions
265 }
266
267 pub fn parse(&self, source: &str) -> ParseResult {
272 parser::parse(source)
273 }
274
275 pub fn check(&self, expr: &SpannedExpr) -> CheckResult {
280 if let Some(ref proto_types) = self.proto_types {
281 return check_with_proto_types(
282 expr,
283 &self.variables,
284 &self.functions,
285 &self.container,
286 proto_types,
287 );
288 }
289
290 check(expr, &self.variables, &self.functions, &self.container)
291 }
292
293 pub fn compile(&self, source: &str) -> Result<Ast, CompileError> {
312 let parse_result = self.parse(source);
313
314 if !parse_result.errors.is_empty() {
315 return Err(CompileError::Parse(parse_result.errors));
316 }
317
318 let expr = parse_result.ast.ok_or(CompileError::NoAst)?;
319 let check_result = self.check(&expr);
320
321 if !check_result.errors.is_empty() {
322 return Err(CompileError::Check(check_result.errors));
323 }
324
325 Ok(Ast::new_checked(expr, source, check_result))
326 }
327
328 pub fn parse_only(&self, source: &str) -> Result<Ast, CompileError> {
345 let parse_result = self.parse(source);
346
347 if !parse_result.errors.is_empty() {
348 return Err(CompileError::Parse(parse_result.errors));
349 }
350
351 let expr = parse_result.ast.ok_or(CompileError::NoAst)?;
352 Ok(Ast::new_unchecked(expr, source))
353 }
354}
355
356impl Default for Env {
357 fn default() -> Self {
358 Self::new()
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checker::CheckErrorKind;
    use crate::types::OverloadDecl;

    /// Builds a single-overload declaration `name(arg) -> bool`.
    fn unary_bool_fn(name: &str, overload_id: &str, arg: CelType) -> FunctionDecl {
        FunctionDecl::new(name).with_overload(OverloadDecl::function(
            overload_id,
            vec![arg],
            CelType::Bool,
        ))
    }

    #[test]
    fn test_new_env() {
        let env = Env::new();
        assert!(env.variables().is_empty());
        assert!(env.functions().is_empty());
        assert!(env.container().is_empty());
        assert!(env.proto_types().is_none());
    }

    #[test]
    fn test_default_is_empty() {
        let env = Env::default();
        assert!(env.variables().is_empty());
        assert!(env.functions().is_empty());
        assert!(env.container().is_empty());
    }

    #[test]
    fn test_with_standard_library() {
        let env = Env::with_standard_library();

        assert!(env.functions().contains_key("_+_"));
        assert!(env.functions().contains_key("size"));
        assert!(env.functions().contains_key("contains"));

        assert!(env.variables().contains_key("bool"));
        assert!(env.variables().contains_key("int"));
        // Type identifiers are declared with type `type(T)`.
        assert_eq!(
            env.variables().get("int"),
            Some(&CelType::type_of(CelType::Int))
        );
    }

    #[test]
    fn test_with_variable() {
        let env = Env::with_standard_library().with_variable("x", CelType::Int);

        assert!(env.variables().contains_key("x"));
        assert_eq!(env.variables().get("x"), Some(&CelType::Int));
    }

    #[test]
    fn test_add_variable_redeclaration_replaces_type() {
        let mut env = Env::new();
        env.add_variable("x", CelType::Int);
        env.add_variable("x", CelType::String);

        assert_eq!(env.variables().len(), 1);
        assert_eq!(env.variables().get("x"), Some(&CelType::String));
    }

    #[test]
    fn test_parse() {
        let env = Env::new();
        let result = env.parse("1 + 2");

        assert!(result.ast.is_some());
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_check() {
        let env = Env::with_standard_library().with_variable("x", CelType::Int);

        let parse_result = env.parse("x + 1");
        let ast = parse_result.ast.unwrap();

        let check_result = env.check(&ast);
        assert!(check_result.is_ok());
    }

    #[test]
    fn test_check_undefined_variable() {
        let env = Env::with_standard_library();

        let parse_result = env.parse("x + 1");
        let ast = parse_result.ast.unwrap();

        let check_result = env.check(&ast);
        assert!(!check_result.is_ok());
        assert!(check_result.errors.iter().any(|e| matches!(
            &e.kind,
            CheckErrorKind::UndeclaredReference { name, .. } if name == "x"
        )));
    }

    #[test]
    fn test_compile_success() {
        let env = Env::with_standard_library().with_variable("x", CelType::Int);

        let ast = env.compile("x + 1").unwrap();
        assert!(ast.is_checked());
    }

    #[test]
    fn test_compile_parse_error() {
        let env = Env::with_standard_library();

        let result = env.compile("1 +");
        assert!(
            matches!(result, Err(CompileError::Parse(_))),
            "expected a parse error: {:?}",
            result
        );
    }

    #[test]
    fn test_compile_check_error() {
        let env = Env::with_standard_library();

        let result = env.compile("x + 1");
        assert!(
            matches!(result, Err(CompileError::Check(ref errors)) if !errors.is_empty()),
            "expected a check error: {:?}",
            result
        );
    }

    #[test]
    fn test_parse_only() {
        let env = Env::new();

        // No declarations are needed because nothing is type-checked.
        let ast = env.parse_only("x + 1").unwrap();
        assert!(!ast.is_checked());

        let result = env.parse_only("1 +");
        assert!(matches!(result, Err(CompileError::Parse(_))));
    }

    #[test]
    fn test_compile_error_display() {
        assert_eq!(CompileError::NoAst.to_string(), "no AST produced");

        let env = Env::with_standard_library();
        let parse_err = env.compile("1 +").unwrap_err();
        assert!(parse_err.to_string().starts_with("parse errors: "));

        let check_err = env.compile("x + 1").unwrap_err();
        assert!(check_err.to_string().starts_with("check errors: "));
    }

    #[test]
    fn test_container() {
        let env = Env::with_standard_library().with_container("google.protobuf");

        assert_eq!(env.container(), "google.protobuf");
    }

    #[test]
    fn test_set_container() {
        let mut env = Env::new();
        env.set_container("a.b");
        env.set_container("c.d");

        assert_eq!(env.container(), "c.d");
    }

    #[test]
    fn test_add_function() {
        let mut env = Env::new();

        env.add_function(unary_bool_fn("custom", "custom_int", CelType::Int));

        assert!(env.functions().contains_key("custom"));
    }

    #[test]
    fn test_add_function_merges_overloads() {
        let env = Env::new()
            .with_function(unary_bool_fn("custom", "custom_int", CelType::Int))
            .with_function(unary_bool_fn("custom", "custom_string", CelType::String));

        assert_eq!(env.functions().len(), 1);
        assert_eq!(env.functions()["custom"].overloads.len(), 2);
    }

    #[test]
    fn test_with_extension() {
        let env = Env::with_standard_library().with_extension(ext::string_extension());

        assert!(env.functions().contains_key("charAt"));
        assert!(env.functions().contains_key("indexOf"));
        assert!(env.functions().contains_key("substring"));
    }

    #[test]
    fn test_with_all_extensions() {
        let env = Env::with_standard_library().with_all_extensions();

        assert!(env.functions().contains_key("charAt"));
        assert!(env.functions().contains_key("indexOf"));
        assert!(env.functions().contains_key("join"));
        assert!(env.functions().contains_key("strings.quote"));

        assert!(env.functions().contains_key("math.greatest"));
        assert!(env.functions().contains_key("math.least"));
        assert!(env.functions().contains_key("math.abs"));
        assert!(env.functions().contains_key("math.bitAnd"));

        assert!(env.functions().contains_key("base64.encode"));
        assert!(env.functions().contains_key("base64.decode"));

        assert!(env.functions().contains_key("optional.of"));
        assert!(env.functions().contains_key("optional.none"));
        assert!(env.functions().contains_key("optional.ofNonZeroValue"));
        assert!(env.functions().contains_key("hasValue"));
        assert!(env.functions().contains_key("value"));
        assert!(env.functions().contains_key("or"));
        assert!(env.functions().contains_key("orValue"));

        assert!(env.variables().contains_key("optional_type"));
    }

    #[test]
    fn test_encoders_extension_base64() {
        let env = Env::with_standard_library()
            .with_all_extensions()
            .with_variable("data", CelType::Bytes)
            .with_variable("encoded", CelType::String);

        let result = env.compile("base64.encode(data)");
        assert!(result.is_ok(), "base64.encode should compile: {:?}", result);

        let result = env.compile("base64.decode(encoded)");
        assert!(result.is_ok(), "base64.decode should compile: {:?}", result);
    }

    #[test]
    fn test_string_extension_char_at() {
        let env = Env::with_standard_library()
            .with_all_extensions()
            .with_variable("s", CelType::String);

        let result = env.compile("s.charAt(0)");
        assert!(result.is_ok(), "charAt should compile: {:?}", result);
    }

    #[test]
    fn test_string_extension_index_of() {
        let env = Env::with_standard_library()
            .with_all_extensions()
            .with_variable("s", CelType::String);

        assert!(env.compile("s.indexOf(\"a\")").is_ok());

        assert!(env.compile("s.indexOf(\"a\", 2)").is_ok());
    }

    #[test]
    fn test_string_extension_substring() {
        let env = Env::with_standard_library()
            .with_all_extensions()
            .with_variable("s", CelType::String);

        assert!(env.compile("s.substring(1)").is_ok());

        assert!(env.compile("s.substring(1, 5)").is_ok());
    }

    #[test]
    fn test_string_extension_split() {
        let env = Env::with_standard_library()
            .with_all_extensions()
            .with_variable("s", CelType::String);

        let ast = env.compile("s.split(\",\")").unwrap();
        assert!(ast.is_checked());
    }

    #[test]
    fn test_string_extension_join() {
        let env = Env::with_standard_library()
            .with_all_extensions()
            .with_variable("parts", CelType::list(CelType::String));

        assert!(env.compile("parts.join()").is_ok());

        assert!(env.compile("parts.join(\",\")").is_ok());
    }

    #[test]
    fn test_math_extension_greatest() {
        let env = Env::with_standard_library().with_all_extensions();

        assert!(env.compile("math.greatest(1, 2)").is_ok());

        assert!(env.compile("math.greatest(1, 2, 3)").is_ok());

        assert!(env.compile("math.greatest([1, 2, 3])").is_ok());
    }

    #[test]
    fn test_math_extension_least() {
        let env = Env::with_standard_library().with_all_extensions();
        assert!(env.compile("math.least(1, 2)").is_ok());
    }

    #[test]
    fn test_math_extension_abs() {
        let env = Env::with_standard_library()
            .with_all_extensions()
            .with_variable("x", CelType::Int);

        assert!(env.compile("math.abs(x)").is_ok());
    }

    #[test]
    fn test_math_extension_bit_operations() {
        let env = Env::with_standard_library()
            .with_all_extensions()
            .with_variable("a", CelType::Int)
            .with_variable("b", CelType::Int);

        assert!(env.compile("math.bitAnd(a, b)").is_ok());
        assert!(env.compile("math.bitOr(a, b)").is_ok());
        assert!(env.compile("math.bitNot(a)").is_ok());
    }

    #[test]
    fn test_optionals_extension_of_and_none() {
        let env = Env::with_standard_library()
            .with_all_extensions()
            .with_variable("x", CelType::Int);

        let result = env.compile("optional.of(x)");
        assert!(result.is_ok(), "optional.of should compile: {:?}", result);

        let result = env.compile("optional.none()");
        assert!(result.is_ok(), "optional.none should compile: {:?}", result);

        let result = env.compile("optional.ofNonZeroValue(x)");
        assert!(
            result.is_ok(),
            "optional.ofNonZeroValue should compile: {:?}",
            result
        );
    }

    #[test]
    fn test_cel_bind_macro() {
        let env = Env::with_standard_library().with_all_extensions();

        let result = env.compile("cel.bind(x, 10, x + 1)");
        assert!(result.is_ok(), "cel.bind should compile: {:?}", result);

        let result = env.compile("cel.bind(msg, \"hello\", msg + msg)");
        assert!(result.is_ok(), "cel.bind with string should compile: {:?}", result);

        let result = env.compile("cel.bind(x, 1, cel.bind(y, 2, x + y))");
        assert!(result.is_ok(), "nested cel.bind should compile: {:?}", result);
    }

    #[test]
    fn test_optionals_extension_methods() {
        let env = Env::with_standard_library()
            .with_all_extensions()
            .with_variable("opt", CelType::optional(CelType::Int))
            .with_variable("opt2", CelType::optional(CelType::Int));

        let result = env.compile("opt.hasValue()");
        assert!(result.is_ok(), "hasValue should compile: {:?}", result);

        let result = env.compile("opt.value()");
        assert!(result.is_ok(), "value should compile: {:?}", result);

        let result = env.compile("opt.or(opt2)");
        assert!(result.is_ok(), "or should compile: {:?}", result);

        let result = env.compile("opt.orValue(42)");
        assert!(result.is_ok(), "orValue should compile: {:?}", result);
    }
}