1use std::{cell::Cell, collections::BTreeMap};
23
24use serde::{Deserialize, Deserializer, Serialize, de};
25
26use crate::error::GatekeeperError;
27
/// Deepest nesting level a combinator tree may reach, with the root at depth 0.
///
/// Checked during deserialization, structural validation, and reduction.
pub const MAX_COMBINATOR_DEPTH: usize = 128;

/// Largest number of direct children a `Sequence`, `Parallel`, or `Choice` may hold.
pub const MAX_COMBINATOR_BRANCH_WIDTH: usize = 256;

/// Largest total number of nodes accepted in a single combinator tree.
pub const MAX_COMBINATOR_NODE_COUNT: usize = 10_000;

/// Upper bound accepted for `RetryPolicy::max_retries`.
pub const MAX_RETRY_ATTEMPTS: u32 = 100;
43
/// Step budget that a `Timeout` combinator installs while its inner combinator is reduced.
///
/// The budget is counted in reduction units (one per `consume` call), not in wall-clock time.
#[derive(Debug, Clone, Copy)]
struct ReductionBudget {
    /// Units granted when the budget was created; quoted in the exhaustion error.
    original_units: u64,
    /// Units still available; each successful `consume` lowers this by one.
    remaining_units: u64,
}
49
50impl ReductionBudget {
51 const fn new(duration: Duration) -> Self {
52 Self {
53 original_units: duration.0,
54 remaining_units: duration.0,
55 }
56 }
57
58 fn consume(&mut self) -> Result<(), GatekeeperError> {
59 if self.remaining_units == 0 {
60 return Err(GatekeeperError::CombinatorError(format!(
61 "timeout budget exhausted: deterministic reduction exceeded {} units",
62 self.original_units
63 )));
64 }
65 self.remaining_units -= 1;
66 Ok(())
67 }
68}
69
70fn consume_active_budgets(budgets: &mut [ReductionBudget]) -> Result<(), GatekeeperError> {
71 for budget in budgets {
72 budget.consume()?;
73 }
74 Ok(())
75}
76
/// Condition checked by a `Guard` combinator against the input fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Predicate {
    /// Label quoted in the error produced when the guard fails.
    pub name: String,
    /// Key that must be present in the input fields.
    pub required_key: String,
    /// When set, the value stored under `required_key` must equal this string exactly.
    pub expected_value: Option<String>,
}
87
88impl Predicate {
89 pub fn evaluate(&self, input: &CombinatorInput) -> bool {
90 match input.fields.get(&self.required_key) {
91 None => false,
92 Some(val) => match &self.expected_value {
93 None => true,
94 Some(expected) => val == expected,
95 },
96 }
97 }
98}
99
/// Field assignment applied by a `Transform` combinator to its inner result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformFn {
    /// Descriptive label for the transform.
    pub name: String,
    /// Key written into the output fields.
    pub output_key: String,
    /// Value stored under `output_key`.
    pub output_value: String,
}
110
/// Retry limits attached to a `Retry` combinator.
///
/// `Deserialize` is implemented by hand so that deserialized policies are validated.
#[derive(Debug, Clone, Serialize)]
pub struct RetryPolicy {
    /// Number of retries allowed after the first attempt; must not exceed `MAX_RETRY_ATTEMPTS`.
    pub max_retries: u32,
    /// Attempt counter; validation requires it to be at most `max_retries`.
    pub current_attempt: u32,
}
119
120impl RetryPolicy {
121 fn validate(&self) -> Result<(), GatekeeperError> {
122 if self.max_retries > MAX_RETRY_ATTEMPTS {
123 return Err(GatekeeperError::CombinatorError(format!(
124 "maximum retry budget exceeded: {} > {}",
125 self.max_retries, MAX_RETRY_ATTEMPTS
126 )));
127 }
128 if self.current_attempt > self.max_retries {
129 return Err(GatekeeperError::CombinatorError(format!(
130 "retry current_attempt {} exceeds max_retries {}",
131 self.current_attempt, self.max_retries
132 )));
133 }
134 Ok(())
135 }
136}
137
/// Unvalidated mirror of `RetryPolicy`, used only as the derive target during deserialization.
#[derive(Deserialize)]
struct RetryPolicyProxy {
    /// Raw `max_retries` as read from the input.
    max_retries: u32,
    /// Raw `current_attempt` as read from the input.
    current_attempt: u32,
}
143
144impl<'de> Deserialize<'de> for RetryPolicy {
145 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
146 where
147 D: Deserializer<'de>,
148 {
149 let proxy = RetryPolicyProxy::deserialize(deserializer)?;
150 let policy = Self {
151 max_retries: proxy.max_retries,
152 current_attempt: proxy.current_attempt,
153 };
154 policy.validate().map_err(de::Error::custom)?;
155 Ok(policy)
156 }
157}
158
/// Identifier that a `Checkpoint` combinator attaches to its output.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CheckpointId(pub String);
162
/// Budget carried by a `Timeout` combinator, as a plain `u64` newtype.
///
/// During reduction the value is used as a count of reduction steps and is
/// also written to the output under the `timeout_budget_ms` key.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Duration(pub u64);
171
/// Tree of composable reduction steps evaluated by `reduce`.
///
/// `Deserialize` is implemented by hand so that depth, width, and node-count
/// limits are enforced while the tree is being read.
#[derive(Debug, Clone, Serialize)]
pub enum Combinator {
    /// Produces an output holding a copy of the input fields.
    Identity,
    /// Runs the children in order, feeding each child's output fields to the next.
    Sequence(Vec<Combinator>),
    /// Reduces every child against the same input and merges the outputs in order.
    Parallel(Vec<Combinator>),
    /// Returns the output of the first child that reduces successfully.
    Choice(Vec<Combinator>),
    /// Reduces the inner combinator only when the predicate holds for the input.
    Guard(Box<Combinator>, Predicate),
    /// Reduces the inner combinator, then sets the transform's key/value on the output.
    Transform(Box<Combinator>, TransformFn),
    /// Re-reduces the inner combinator until it succeeds or the policy's retries run out.
    Retry(Box<Combinator>, RetryPolicy),
    /// Reduces the inner combinator under a step budget.
    Timeout(Box<Combinator>, Duration),
    /// Reduces the inner combinator, then tags the output with the checkpoint id.
    Checkpoint(Box<Combinator>, CheckpointId),
}
194
thread_local! {
    /// Number of `Combinator` deserializations currently in progress on this thread.
    static COMBINATOR_DESERIALIZE_DEPTH: Cell<usize> = const { Cell::new(0) };
    /// Nodes counted so far for the combinator tree being deserialized on this thread.
    static COMBINATOR_DESERIALIZE_NODE_COUNT: Cell<usize> = const { Cell::new(0) };
}
199
/// Guard whose `Drop` impl undoes the depth bookkeeping done by
/// `enter_combinator_deserialize_depth`.
struct CombinatorDeserializeDepthGuard {
    /// True when the guard was created while the depth counter was zero.
    is_root: bool,
}
203
204impl Drop for CombinatorDeserializeDepthGuard {
205 fn drop(&mut self) {
206 COMBINATOR_DESERIALIZE_DEPTH.with(|depth| {
207 depth.set(depth.get().saturating_sub(1));
208 });
209 if self.is_root {
210 COMBINATOR_DESERIALIZE_NODE_COUNT.with(|nodes| {
211 nodes.set(0);
212 });
213 }
214 }
215}
216
217fn enter_combinator_deserialize_depth<E>() -> Result<CombinatorDeserializeDepthGuard, E>
218where
219 E: de::Error,
220{
221 COMBINATOR_DESERIALIZE_DEPTH.with(|depth| {
222 let current = depth.get();
223 let is_root = current == 0;
224 if current > MAX_COMBINATOR_DEPTH {
225 return Err(de::Error::custom(format!(
226 "maximum combinator nesting depth exceeded during deserialization: {} > {}",
227 current, MAX_COMBINATOR_DEPTH
228 )));
229 }
230 COMBINATOR_DESERIALIZE_NODE_COUNT.with(|nodes| {
231 if is_root {
232 nodes.set(0);
233 }
234 let current_nodes = nodes.get();
235 if current_nodes >= MAX_COMBINATOR_NODE_COUNT {
236 return Err(de::Error::custom(format!(
237 "maximum combinator node count exceeded during deserialization: more than {}",
238 MAX_COMBINATOR_NODE_COUNT
239 )));
240 }
241 nodes.set(current_nodes + 1);
242 depth.set(current + 1);
243 Ok(CombinatorDeserializeDepthGuard { is_root })
244 })
245 })
246}
247
248struct BoundedCombinators(Vec<Combinator>);
249
250impl<'de> Deserialize<'de> for BoundedCombinators {
251 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
252 where
253 D: Deserializer<'de>,
254 {
255 deserializer.deserialize_seq(BoundedCombinatorsVisitor)
256 }
257}
258
259struct BoundedCombinatorsVisitor;
260
261impl<'de> de::Visitor<'de> for BoundedCombinatorsVisitor {
262 type Value = BoundedCombinators;
263
264 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265 formatter.write_str("a bounded combinator sequence")
266 }
267
268 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
269 where
270 A: de::SeqAccess<'de>,
271 {
272 if seq
273 .size_hint()
274 .is_some_and(|hint| hint > MAX_COMBINATOR_BRANCH_WIDTH)
275 {
276 return Err(de::Error::custom(format!(
277 "maximum combinator branch width exceeded: more than {}",
278 MAX_COMBINATOR_BRANCH_WIDTH
279 )));
280 }
281
282 let mut combinators = Vec::new();
283 while let Some(combinator) = seq.next_element()? {
284 if combinators.len() >= MAX_COMBINATOR_BRANCH_WIDTH {
285 return Err(de::Error::custom(format!(
286 "maximum combinator branch width exceeded: more than {}",
287 MAX_COMBINATOR_BRANCH_WIDTH
288 )));
289 }
290 combinators.push(combinator);
291 }
292
293 Ok(BoundedCombinators(combinators))
294 }
295}
296
/// Derive target mirroring `Combinator` variant-for-variant.
///
/// The list-carrying variants use `BoundedCombinators` so that branch width is
/// limited while the list is read; boxed children are read as `Combinator`.
#[derive(Deserialize)]
enum CombinatorProxy {
    Identity,
    Sequence(BoundedCombinators),
    Parallel(BoundedCombinators),
    Choice(BoundedCombinators),
    Guard(Box<Combinator>, Predicate),
    Transform(Box<Combinator>, TransformFn),
    Retry(Box<Combinator>, RetryPolicy),
    Timeout(Box<Combinator>, Duration),
    Checkpoint(Box<Combinator>, CheckpointId),
}
309
310impl<'de> Deserialize<'de> for Combinator {
311 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
312 where
313 D: Deserializer<'de>,
314 {
315 let _depth_guard = enter_combinator_deserialize_depth::<D::Error>()?;
316 let proxy = CombinatorProxy::deserialize(deserializer)?;
317 Ok(match proxy {
318 CombinatorProxy::Identity => Self::Identity,
319 CombinatorProxy::Sequence(BoundedCombinators(combinators)) => {
320 Self::Sequence(combinators)
321 }
322 CombinatorProxy::Parallel(BoundedCombinators(combinators)) => {
323 Self::Parallel(combinators)
324 }
325 CombinatorProxy::Choice(BoundedCombinators(combinators)) => Self::Choice(combinators),
326 CombinatorProxy::Guard(inner, predicate) => Self::Guard(inner, predicate),
327 CombinatorProxy::Transform(inner, transform) => Self::Transform(inner, transform),
328 CombinatorProxy::Retry(inner, policy) => Self::Retry(inner, policy),
329 CombinatorProxy::Timeout(inner, duration) => Self::Timeout(inner, duration),
330 CombinatorProxy::Checkpoint(inner, checkpoint_id) => {
331 Self::Checkpoint(inner, checkpoint_id)
332 }
333 })
334 }
335}
336
/// String key/value fields handed to a combinator for reduction.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CombinatorInput {
    /// Field map, ordered by key.
    pub fields: BTreeMap<String, String>,
}
347
348impl CombinatorInput {
349 #[must_use]
350 pub fn new() -> Self {
351 Self::default()
352 }
353
354 pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
355 self.fields.insert(key.into(), value.into());
356 }
357
358 #[must_use]
359 pub fn with(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
360 self.set(key, value);
361 self
362 }
363}
364
/// Result of reducing a combinator.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CombinatorOutput {
    /// Field map, ordered by key.
    pub fields: BTreeMap<String, String>,
    /// Checkpoint id attached by a `Checkpoint` combinator, if any.
    pub checkpoint: Option<CheckpointId>,
}
373
374impl CombinatorOutput {
375 #[must_use]
376 pub fn from_input(input: &CombinatorInput) -> Self {
377 Self {
378 fields: input.fields.clone(),
379 checkpoint: None,
380 }
381 }
382
383 pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
384 self.fields.insert(key.into(), value.into());
385 }
386
387 pub fn merge(&mut self, other: &CombinatorOutput) {
388 for (k, v) in &other.fields {
389 self.fields.insert(k.clone(), v.clone());
390 }
391 if other.checkpoint.is_some() {
392 self.checkpoint.clone_from(&other.checkpoint);
393 }
394 }
395}
396
397pub fn reduce(
405 combinator: &Combinator,
406 input: &CombinatorInput,
407) -> Result<CombinatorOutput, GatekeeperError> {
408 validate_combinator_structure(combinator)?;
409 let mut budgets = Vec::new();
410 reduce_inner(combinator, input, 0, &mut budgets)
411}
412
413fn validate_combinator_structure(combinator: &Combinator) -> Result<(), GatekeeperError> {
414 let mut stack = vec![(combinator, 0usize)];
415 let mut node_count = 0usize;
416
417 while let Some((current, depth)) = stack.pop() {
418 if depth > MAX_COMBINATOR_DEPTH {
419 return Err(GatekeeperError::CombinatorError(format!(
420 "maximum combinator nesting depth exceeded: {} > {}",
421 depth, MAX_COMBINATOR_DEPTH
422 )));
423 }
424
425 if node_count >= MAX_COMBINATOR_NODE_COUNT {
426 return Err(GatekeeperError::CombinatorError(format!(
427 "maximum combinator node count exceeded: more than {}",
428 MAX_COMBINATOR_NODE_COUNT
429 )));
430 }
431 node_count += 1;
432
433 match current {
434 Combinator::Identity => {}
435 Combinator::Sequence(combinators) => {
436 enforce_branch_width("Sequence", combinators.len())?;
437 for child in combinators.iter().rev() {
438 stack.push((child, depth + 1));
439 }
440 }
441 Combinator::Parallel(combinators) => {
442 enforce_branch_width("Parallel", combinators.len())?;
443 for child in combinators.iter().rev() {
444 stack.push((child, depth + 1));
445 }
446 }
447 Combinator::Choice(combinators) => {
448 enforce_branch_width("Choice", combinators.len())?;
449 for child in combinators.iter().rev() {
450 stack.push((child, depth + 1));
451 }
452 }
453 Combinator::Guard(inner, _) | Combinator::Transform(inner, _) => {
454 stack.push((inner, depth + 1));
455 }
456 Combinator::Retry(inner, policy) => {
457 policy.validate()?;
458 stack.push((inner, depth + 1));
459 }
460 Combinator::Timeout(inner, _) | Combinator::Checkpoint(inner, _) => {
461 stack.push((inner, depth + 1));
462 }
463 }
464 }
465
466 Ok(())
467}
468
469fn reduce_inner(
470 combinator: &Combinator,
471 input: &CombinatorInput,
472 depth: usize,
473 budgets: &mut Vec<ReductionBudget>,
474) -> Result<CombinatorOutput, GatekeeperError> {
475 if depth > MAX_COMBINATOR_DEPTH {
476 return Err(GatekeeperError::CombinatorError(format!(
477 "maximum combinator nesting depth exceeded: {} > {}",
478 depth, MAX_COMBINATOR_DEPTH
479 )));
480 }
481 consume_active_budgets(budgets)?;
482
483 match combinator {
484 Combinator::Identity => Ok(CombinatorOutput::from_input(input)),
485
486 Combinator::Sequence(combinators) => {
487 enforce_branch_width("Sequence", combinators.len())?;
488 let mut current_input = input.clone();
489 let mut last_output = CombinatorOutput::from_input(input);
490
491 for (i, c) in combinators.iter().enumerate() {
492 match reduce_inner(c, ¤t_input, depth + 1, budgets) {
493 Ok(output) => {
494 current_input = CombinatorInput {
496 fields: output.fields.clone(),
497 };
498 last_output = output;
499 }
500 Err(e) => {
501 return Err(GatekeeperError::CombinatorError(format!(
502 "Sequence step {} failed: {}",
503 i, e
504 )));
505 }
506 }
507 }
508 Ok(last_output)
509 }
510
511 Combinator::Parallel(combinators) => {
512 enforce_branch_width("Parallel", combinators.len())?;
513 let mut merged = CombinatorOutput::from_input(input);
514
515 for (i, c) in combinators.iter().enumerate() {
516 match reduce_inner(c, input, depth + 1, budgets) {
517 Ok(output) => {
518 merged.merge(&output);
519 }
520 Err(e) => {
521 return Err(GatekeeperError::CombinatorError(format!(
522 "Parallel branch {} failed: {}",
523 i, e
524 )));
525 }
526 }
527 }
528 Ok(merged)
529 }
530
531 Combinator::Choice(combinators) => {
532 enforce_branch_width("Choice", combinators.len())?;
533 for c in combinators {
534 match reduce_inner(c, input, depth + 1, budgets) {
535 Ok(output) => return Ok(output),
536 Err(_) => continue,
537 }
538 }
539 Err(GatekeeperError::CombinatorError(
540 "Choice: all alternatives failed".into(),
541 ))
542 }
543
544 Combinator::Guard(inner, predicate) => {
545 if !predicate.evaluate(input) {
546 return Err(GatekeeperError::CombinatorError(format!(
547 "Guard predicate '{}' failed",
548 predicate.name
549 )));
550 }
551 reduce_inner(inner, input, depth + 1, budgets)
552 }
553
554 Combinator::Transform(inner, transform) => {
555 let mut output = reduce_inner(inner, input, depth + 1, budgets)?;
556 output.set(transform.output_key.clone(), transform.output_value.clone());
557 Ok(output)
558 }
559
560 Combinator::Retry(inner, policy) => {
561 policy.validate()?;
562 let mut last_err = None;
563 for attempt in 0..=policy.max_retries {
564 match reduce_inner(inner, input, depth + 1, budgets) {
565 Ok(mut output) => {
566 output.set("retry_attempts", attempt.to_string());
567 return Ok(output);
568 }
569 Err(e) => {
570 last_err = Some(e);
571 }
572 }
573 }
574 Err(last_err
575 .unwrap_or_else(|| GatekeeperError::CombinatorError("Retry exhausted".into())))
576 }
577
578 Combinator::Timeout(inner, duration) => {
579 budgets.push(ReductionBudget::new(*duration));
580 let result = reduce_inner(inner, input, depth + 1, budgets);
581 if budgets.pop().is_none() {
582 return Err(GatekeeperError::CombinatorError(
583 "timeout budget stack underflow".into(),
584 ));
585 }
586 let mut output = result?;
587 output.set("timeout_budget_ms", duration.0.to_string());
588 Ok(output)
589 }
590
591 Combinator::Checkpoint(inner, checkpoint_id) => {
592 let mut output = reduce_inner(inner, input, depth + 1, budgets)?;
593 output.checkpoint = Some(checkpoint_id.clone());
594 Ok(output)
595 }
596 }
597}
598
599fn enforce_branch_width(kind: &str, len: usize) -> Result<(), GatekeeperError> {
600 if len > MAX_COMBINATOR_BRANCH_WIDTH {
601 return Err(GatekeeperError::CombinatorError(format!(
602 "maximum combinator branch width exceeded in {}: {} > {}",
603 kind, len, MAX_COMBINATOR_BRANCH_WIDTH
604 )));
605 }
606 Ok(())
607}
608
#[cfg(test)]
mod tests {
    use super::*;

    // ---- shared builders -------------------------------------------------

    /// Two-field input used by most tests.
    fn sample_input() -> CombinatorInput {
        let mut input = CombinatorInput::new();
        input.set("name", "alice");
        input.set("role", "judge");
        input
    }

    /// Builds a predicate on `key`, optionally requiring an exact value.
    fn predicate(name: &str, key: &str, expected: Option<&str>) -> Predicate {
        Predicate {
            name: name.into(),
            required_key: key.into(),
            expected_value: expected.map(str::to_owned),
        }
    }

    /// `Guard` wrapped around `Identity`.
    fn identity_guard(name: &str, key: &str, expected: Option<&str>) -> Combinator {
        Combinator::Guard(
            Box::new(Combinator::Identity),
            predicate(name, key, expected),
        )
    }

    /// Guard that can never pass for `sample_input`.
    fn impossible_guard() -> Combinator {
        identity_guard("impossible", "nonexistent", None)
    }

    /// `Transform` wrapped around `Identity` that sets `key` to `value`.
    fn stamp(name: &str, key: &str, value: &str) -> Combinator {
        Combinator::Transform(
            Box::new(Combinator::Identity),
            TransformFn {
                name: name.into(),
                output_key: key.into(),
                output_value: value.into(),
            },
        )
    }

    /// Looks up an output field as `&str`.
    fn field<'a>(output: &'a CombinatorOutput, key: &str) -> Option<&'a str> {
        output.fields.get(key).map(String::as_str)
    }

    /// Two-level `Sequence` of `Identity` leaves where no branch exceeds the width limit.
    fn branch_bounded_sequence(total_leaves: usize) -> Combinator {
        let full_branches = total_leaves / MAX_COMBINATOR_BRANCH_WIDTH;
        let leftover = total_leaves % MAX_COMBINATOR_BRANCH_WIDTH;

        let full = Combinator::Sequence(vec![Combinator::Identity; MAX_COMBINATOR_BRANCH_WIDTH]);
        let mut branches = vec![full; full_branches];
        if leftover > 0 {
            branches.push(Combinator::Sequence(vec![Combinator::Identity; leftover]));
        }
        Combinator::Sequence(branches)
    }

    /// JSON encoding of the tree built by `branch_bounded_sequence`.
    fn branch_bounded_sequence_json(total_leaves: usize) -> String {
        let mut groups = Vec::new();
        let mut remaining = total_leaves;
        while remaining > 0 {
            let chunk = remaining.min(MAX_COMBINATOR_BRANCH_WIDTH);
            let leaves = vec!["\"Identity\""; chunk].join(",");
            groups.push(format!("{{\"Sequence\":[{leaves}]}}"));
            remaining -= chunk;
        }
        format!("{{\"Sequence\":[{}]}}", groups.join(","))
    }

    // ---- Identity / Sequence ---------------------------------------------

    #[test]
    fn identity_passes_through() {
        let input = sample_input();
        let output = reduce(&Combinator::Identity, &input).unwrap();
        assert_eq!(output.fields, input.fields);
    }

    #[test]
    fn sequence_empty_returns_input() {
        let input = sample_input();
        let output = reduce(&Combinator::Sequence(Vec::new()), &input).unwrap();
        assert_eq!(output.fields, input.fields);
    }

    #[test]
    fn sequence_chains_results() {
        let seq = Combinator::Sequence(vec![
            stamp("add_step1", "step1", "done"),
            stamp("add_step2", "step2", "done"),
        ]);
        let output = reduce(&seq, &sample_input()).unwrap();
        assert_eq!(field(&output, "step1"), Some("done"));
        assert_eq!(field(&output, "step2"), Some("done"));
    }

    #[test]
    fn sequence_fails_if_any_step_fails() {
        let seq = Combinator::Sequence(vec![
            Combinator::Identity,
            identity_guard("requires_admin", "admin", None),
        ]);
        assert!(reduce(&seq, &sample_input()).is_err());
    }

    // ---- Parallel --------------------------------------------------------

    #[test]
    fn parallel_merges_results() {
        let par = Combinator::Parallel(vec![
            stamp("branch_a", "a", "1"),
            stamp("branch_b", "b", "2"),
        ]);
        let output = reduce(&par, &sample_input()).unwrap();
        assert_eq!(field(&output, "a"), Some("1"));
        assert_eq!(field(&output, "b"), Some("2"));
    }

    #[test]
    fn parallel_fails_if_any_branch_fails() {
        let par = Combinator::Parallel(vec![Combinator::Identity, impossible_guard()]);
        assert!(reduce(&par, &sample_input()).is_err());
    }

    // ---- Choice ----------------------------------------------------------

    #[test]
    fn choice_returns_first_success() {
        let blocked_first = Combinator::Guard(
            Box::new(stamp("fail_branch", "branch", "first")),
            predicate("impossible", "nonexistent", None),
        );
        let choice = Combinator::Choice(vec![
            blocked_first,
            stamp("success_branch", "branch", "second"),
        ]);
        let output = reduce(&choice, &sample_input()).unwrap();
        assert_eq!(field(&output, "branch"), Some("second"));
    }

    #[test]
    fn choice_fails_if_all_alternatives_fail() {
        let alternatives = ["x", "y", "z"]
            .iter()
            .map(|key| identity_guard("fail", key, None))
            .collect();
        let choice = Combinator::Choice(alternatives);
        assert!(reduce(&choice, &sample_input()).is_err());
    }

    // ---- Guard -----------------------------------------------------------

    #[test]
    fn guard_passes_when_predicate_holds() {
        let guarded = identity_guard("has_name", "name", None);
        assert!(reduce(&guarded, &sample_input()).is_ok());
    }

    #[test]
    fn guard_fails_when_predicate_does_not_hold() {
        let guarded = identity_guard("has_admin", "admin", None);
        assert!(reduce(&guarded, &sample_input()).is_err());
    }

    #[test]
    fn guard_checks_expected_value() {
        let input = sample_input();

        let guarded = identity_guard("name_is_alice", "name", Some("alice"));
        assert!(reduce(&guarded, &input).is_ok());

        let guarded_wrong = identity_guard("name_is_bob", "name", Some("bob"));
        assert!(reduce(&guarded_wrong, &input).is_err());
    }

    // ---- Transform -------------------------------------------------------

    #[test]
    fn transform_adds_key_to_output() {
        let transformed = stamp("add_status", "status", "verified");
        let output = reduce(&transformed, &sample_input()).unwrap();
        assert_eq!(field(&output, "status"), Some("verified"));
        assert_eq!(field(&output, "name"), Some("alice"));
    }

    // ---- Retry -----------------------------------------------------------

    #[test]
    fn retry_succeeds_on_first_attempt_for_identity() {
        let retried = Combinator::Retry(
            Box::new(Combinator::Identity),
            RetryPolicy {
                max_retries: 3,
                current_attempt: 0,
            },
        );
        let output = reduce(&retried, &sample_input()).unwrap();
        assert_eq!(field(&output, "retry_attempts"), Some("0"));
    }

    #[test]
    fn retry_exhausts_on_permanent_failure() {
        let retried = Combinator::Retry(
            Box::new(impossible_guard()),
            RetryPolicy {
                max_retries: 2,
                current_attempt: 0,
            },
        );
        assert!(reduce(&retried, &sample_input()).is_err());
    }

    #[test]
    fn retry_rejects_excessive_retry_budget_before_looping() {
        let retried = Combinator::Retry(
            Box::new(impossible_guard()),
            RetryPolicy {
                max_retries: 101,
                current_attempt: 0,
            },
        );

        let err = match reduce(&retried, &sample_input()) {
            Err(err) => err,
            Ok(output) => panic!("excessive retries must fail fast: {output:?}"),
        };
        assert!(
            err.to_string().contains("maximum retry"),
            "unexpected error: {err}"
        );
    }

    // ---- structural limits -----------------------------------------------

    #[test]
    fn reduce_rejects_excessive_combinator_depth() {
        let combinator = (0..129).fold(Combinator::Identity, |inner, _| {
            Combinator::Timeout(Box::new(inner), Duration(1))
        });

        let err = match reduce(&combinator, &sample_input()) {
            Err(err) => err,
            Ok(output) => panic!("excessive depth must fail: {output:?}"),
        };
        assert!(
            err.to_string().contains("maximum combinator nesting depth"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn reduce_rejects_excessive_branch_width() {
        let combinator = Combinator::Sequence(vec![Combinator::Identity; 257]);

        let err = match reduce(&combinator, &sample_input()) {
            Err(err) => err,
            Ok(output) => panic!("excessive branch width must fail: {output:?}"),
        };
        assert!(
            err.to_string().contains("maximum combinator branch width"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn reduce_rejects_excessive_total_node_count() {
        let combinator = branch_bounded_sequence(MAX_COMBINATOR_NODE_COUNT + 1);

        let Err(err) = reduce(&combinator, &sample_input()) else {
            panic!("excessive node count must fail");
        };
        assert!(
            err.to_string().contains("maximum combinator node count"),
            "unexpected error: {err}"
        );
    }

    // ---- deserialization limits --------------------------------------------

    #[test]
    fn combinator_deserialization_is_not_directly_derived() {
        let source = include_str!("combinator.rs");
        let derived_directly = source
            .contains("#[derive(Debug, Clone, Serialize, Deserialize)]\npub enum Combinator");
        assert!(
            !derived_directly,
            "Combinator deserialization must enforce structural limits"
        );
    }

    #[test]
    fn deserialization_rejects_excessive_branch_width() {
        let leaves = vec!["\"Identity\""; 257].join(",");
        let json = format!("{{\"Sequence\":[{leaves}]}}");

        let err = match serde_json::from_str::<Combinator>(&json) {
            Err(err) => err,
            Ok(combinator) => panic!("wide sequence must be rejected: {combinator:?}"),
        };
        assert!(
            err.to_string().contains("maximum combinator branch width"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn deserialization_rejects_excessive_total_node_count() {
        let json = branch_bounded_sequence_json(MAX_COMBINATOR_NODE_COUNT + 1);

        let Err(err) = serde_json::from_str::<Combinator>(&json) else {
            panic!("oversized combinator tree must be rejected");
        };
        assert!(
            err.to_string().contains("maximum combinator node count"),
            "unexpected error: {err}"
        );
    }

    // ---- Timeout ---------------------------------------------------------

    #[test]
    fn timeout_runs_inner_and_records_budget() {
        let timed = Combinator::Timeout(Box::new(Combinator::Identity), Duration(5000));
        let output = reduce(&timed, &sample_input()).unwrap();
        assert_eq!(field(&output, "timeout_budget_ms"), Some("5000"));
    }

    #[test]
    fn timeout_rejects_inner_reduction_over_deterministic_budget() {
        let inner = Combinator::Sequence(vec![Combinator::Identity, Combinator::Identity]);
        let timed = Combinator::Timeout(Box::new(inner), Duration(2));

        let err = reduce(&timed, &sample_input())
            .expect_err("over-budget timeout must fail closed");
        assert!(
            err.to_string().contains("timeout budget exhausted"),
            "unexpected timeout error: {err}"
        );
    }

    #[test]
    fn timeout_allows_inner_reduction_within_deterministic_budget() {
        let inner = Combinator::Sequence(vec![Combinator::Identity, Combinator::Identity]);
        let timed = Combinator::Timeout(Box::new(inner), Duration(3));

        let output = reduce(&timed, &sample_input())
            .expect("three-node inner reduction should fit budget");
        assert_eq!(field(&output, "timeout_budget_ms"), Some("3"));
    }

    // ---- Checkpoint ------------------------------------------------------

    #[test]
    fn checkpoint_records_id_in_output() {
        let cp = Combinator::Checkpoint(
            Box::new(Combinator::Identity),
            CheckpointId("cp-001".into()),
        );
        let output = reduce(&cp, &sample_input()).unwrap();
        assert_eq!(output.checkpoint, Some(CheckpointId("cp-001".into())));
    }

    // ---- determinism and composition -------------------------------------

    #[test]
    fn reduction_is_deterministic() {
        let input = sample_input();
        let combinator = Combinator::Sequence(vec![
            stamp("step1", "x", "1"),
            Combinator::Checkpoint(Box::new(Combinator::Identity), CheckpointId("cp".into())),
        ]);

        let first = reduce(&combinator, &input).unwrap();
        let second = reduce(&combinator, &input).unwrap();
        assert_eq!(first.fields, second.fields);
        assert_eq!(first.checkpoint, second.checkpoint);
    }

    #[test]
    fn complex_composition() {
        let input = CombinatorInput::new()
            .with("authorized", "true")
            .with("user", "alice");

        let program = Combinator::Sequence(vec![
            identity_guard("is_authorized", "authorized", Some("true")),
            Combinator::Parallel(vec![
                stamp("audit", "audited", "yes"),
                stamp("log", "logged", "yes"),
            ]),
            Combinator::Checkpoint(Box::new(Combinator::Identity), CheckpointId("final".into())),
        ]);

        let output = reduce(&program, &input).unwrap();
        assert_eq!(field(&output, "audited"), Some("yes"));
        assert_eq!(field(&output, "logged"), Some("yes"));
        assert_eq!(output.checkpoint, Some(CheckpointId("final".into())));
    }

    // ---- CombinatorInput / CombinatorOutput ------------------------------

    #[test]
    fn combinator_input_new_is_empty() {
        assert!(CombinatorInput::new().fields.is_empty());
    }

    #[test]
    fn combinator_input_with_chaining() {
        let input = CombinatorInput::new().with("a", "1").with("b", "2");
        assert_eq!(input.fields.len(), 2);
    }

    #[test]
    fn combinator_output_merge() {
        let mut target = CombinatorOutput::default();
        target.set("a", "1");

        let mut source = CombinatorOutput::default();
        source.set("b", "2");
        source.checkpoint = Some(CheckpointId("cp".into()));

        target.merge(&source);
        assert_eq!(field(&target, "a"), Some("1"));
        assert_eq!(field(&target, "b"), Some("2"));
        assert_eq!(target.checkpoint, Some(CheckpointId("cp".into())));
    }

    // ---- Predicate -------------------------------------------------------

    #[test]
    fn predicate_evaluate_missing_key() {
        let input = CombinatorInput::new();
        let pred = predicate("test", "missing", None);
        assert!(!pred.evaluate(&input));
    }

    #[test]
    fn predicate_evaluate_key_exists_no_value_check() {
        let input = CombinatorInput::new().with("key", "anything");
        let pred = predicate("test", "key", None);
        assert!(pred.evaluate(&input));
    }

    #[test]
    fn predicate_evaluate_value_mismatch() {
        let input = CombinatorInput::new().with("key", "actual");
        let pred = predicate("test", "key", Some("expected"));
        assert!(!pred.evaluate(&input));
    }

    // ---- edge cases ------------------------------------------------------

    #[test]
    fn parallel_empty_returns_input() {
        let input = sample_input();
        let output = reduce(&Combinator::Parallel(Vec::new()), &input).unwrap();
        assert_eq!(output.fields, input.fields);
    }

    #[test]
    fn choice_single_returns_it() {
        let input = sample_input();
        let choice = Combinator::Choice(vec![Combinator::Identity]);
        let output = reduce(&choice, &input).unwrap();
        assert_eq!(output.fields, input.fields);
    }
}