Skip to main content

noether_engine/
checker.rs

1use crate::lagrange::{CompositionNode, Pinning};
2use noether_core::capability::Capability;
3use noether_core::effects::{Effect, EffectKind, EffectSet};
4use noether_core::stage::StageId;
5use noether_core::types::{is_subtype_of, IncompatibilityReason, NType, TypeCompatibility};
6use noether_store::StageStore;
7use std::collections::{BTreeMap, BTreeSet};
8use std::fmt;
9
/// The resolved input/output types of a composition node.
///
/// Produced by the graph type-checker for each node it can resolve.
#[derive(Debug, Clone)]
pub struct ResolvedType {
    /// Type the node accepts as its input.
    pub input: NType,
    /// Type the node produces as its output.
    pub output: NType,
}
16
17// ── Capability enforcement ─────────────────────────────────────────────────
18
/// Policy controlling which capabilities a composition is allowed to use.
///
/// `allowed` is empty → all capabilities permitted (default / backward-compatible).
/// `allowed` is non-empty → only the listed capabilities are permitted.
///
/// Because the empty set means "allow all", this type cannot express a
/// "deny every capability" policy.
#[derive(Debug, Clone, Default)]
pub struct CapabilityPolicy {
    /// Capabilities the caller grants. Empty set = allow all.
    pub allowed: BTreeSet<Capability>,
}
28
29impl CapabilityPolicy {
30    /// A policy that allows every capability.
31    pub fn allow_all() -> Self {
32        Self {
33            allowed: BTreeSet::new(),
34        }
35    }
36
37    /// A policy that permits only the listed capabilities.
38    pub fn restrict(caps: impl IntoIterator<Item = Capability>) -> Self {
39        Self {
40            allowed: caps.into_iter().collect(),
41        }
42    }
43
44    fn is_allowed(&self, cap: &Capability) -> bool {
45        self.allowed.is_empty() || self.allowed.contains(cap)
46    }
47}
48
/// A single capability violation found during pre-flight checking.
#[derive(Debug, Clone)]
pub struct CapabilityViolation {
    /// Identifier of the stage that declared the capability.
    pub stage_id: StageId,
    /// The capability the stage requires but the policy does not grant.
    pub required: Capability,
    /// Human-readable description of the violation.
    pub message: String,
}
56
57impl fmt::Display for CapabilityViolation {
58    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59        write!(
60            f,
61            "stage {} requires capability {:?} which is not granted",
62            self.stage_id.0, self.required
63        )
64    }
65}
66
67/// Pre-flight check: walk the graph and verify every stage's declared capabilities
68/// are within the granted policy. Returns an empty vec when all capabilities pass.
69pub fn check_capabilities(
70    node: &CompositionNode,
71    store: &(impl StageStore + ?Sized),
72    policy: &CapabilityPolicy,
73) -> Vec<CapabilityViolation> {
74    let mut violations = Vec::new();
75    collect_capability_violations(node, store, policy, &mut violations);
76    violations
77}
78
79fn collect_capability_violations(
80    node: &CompositionNode,
81    store: &(impl StageStore + ?Sized),
82    policy: &CapabilityPolicy,
83    violations: &mut Vec<CapabilityViolation>,
84) {
85    match node {
86        CompositionNode::Stage { id, .. } => {
87            if let Ok(Some(stage)) = store.get(id) {
88                for cap in &stage.capabilities {
89                    if !policy.is_allowed(cap) {
90                        violations.push(CapabilityViolation {
91                            stage_id: id.clone(),
92                            required: cap.clone(),
93                            message: format!(
94                                "stage '{}' requires {:?}; grant it with --allow-capabilities",
95                                stage.description, cap
96                            ),
97                        });
98                    }
99                }
100            }
101        }
102        CompositionNode::RemoteStage { .. } => {} // remote stages have no local capabilities
103        CompositionNode::Const { .. } => {}       // no capabilities in a constant
104        CompositionNode::Sequential { stages } => {
105            for s in stages {
106                collect_capability_violations(s, store, policy, violations);
107            }
108        }
109        CompositionNode::Parallel { branches } => {
110            for branch in branches.values() {
111                collect_capability_violations(branch, store, policy, violations);
112            }
113        }
114        CompositionNode::Branch {
115            predicate,
116            if_true,
117            if_false,
118        } => {
119            collect_capability_violations(predicate, store, policy, violations);
120            collect_capability_violations(if_true, store, policy, violations);
121            collect_capability_violations(if_false, store, policy, violations);
122        }
123        CompositionNode::Fanout { source, targets } => {
124            collect_capability_violations(source, store, policy, violations);
125            for t in targets {
126                collect_capability_violations(t, store, policy, violations);
127            }
128        }
129        CompositionNode::Merge { sources, target } => {
130            for s in sources {
131                collect_capability_violations(s, store, policy, violations);
132            }
133            collect_capability_violations(target, store, policy, violations);
134        }
135        CompositionNode::Retry { stage, .. } => {
136            collect_capability_violations(stage, store, policy, violations);
137        }
138        CompositionNode::Let { bindings, body } => {
139            for b in bindings.values() {
140                collect_capability_violations(b, store, policy, violations);
141            }
142            collect_capability_violations(body, store, policy, violations);
143        }
144    }
145}
146
147// ── Effect inference & enforcement ────────────────────────────────────────
148
/// Policy controlling which effect kinds a composition is allowed to declare.
///
/// `allowed` is empty → all effects permitted (default / backward-compatible).
/// `allowed` is non-empty → only the listed effect kinds are permitted; others
/// produce an [`EffectViolation`].
///
/// Because the empty set means "allow all", this type cannot express a
/// "deny every effect" policy.
#[derive(Debug, Clone, Default)]
pub struct EffectPolicy {
    /// Effect kinds the caller grants. Empty set = allow all.
    pub allowed: BTreeSet<EffectKind>,
}
159
160impl EffectPolicy {
161    /// A policy that allows every effect (default).
162    pub fn allow_all() -> Self {
163        Self {
164            allowed: BTreeSet::new(),
165        }
166    }
167
168    /// A policy that permits only the listed effect kinds.
169    pub fn restrict(kinds: impl IntoIterator<Item = EffectKind>) -> Self {
170        Self {
171            allowed: kinds.into_iter().collect(),
172        }
173    }
174
175    pub fn is_allowed(&self, kind: &EffectKind) -> bool {
176        self.allowed.is_empty() || self.allowed.contains(kind)
177    }
178}
179
/// A single effect violation found during pre-flight checking.
#[derive(Debug, Clone)]
pub struct EffectViolation {
    /// Identifier of the stage the violation is attributed to.
    pub stage_id: StageId,
    /// The effect that the policy does not grant.
    pub effect: Effect,
    /// Human-readable description; this is exactly what `Display` prints.
    pub message: String,
}
187
188impl fmt::Display for EffectViolation {
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        write!(f, "{}", self.message)
191    }
192}
193
194/// Walk the composition graph and return the union of all effects declared by
195/// every stage. `RemoteStage` nodes always contribute `Effect::Network`.
196/// Stages not found in the store contribute `Effect::Unknown`.
197pub fn infer_effects(node: &CompositionNode, store: &(impl StageStore + ?Sized)) -> EffectSet {
198    let mut effects: BTreeSet<Effect> = BTreeSet::new();
199    collect_effects_inner(node, store, &mut effects);
200    EffectSet::new(effects)
201}
202
203fn collect_effects_inner(
204    node: &CompositionNode,
205    store: &(impl StageStore + ?Sized),
206    effects: &mut BTreeSet<Effect>,
207) {
208    match node {
209        CompositionNode::Stage { id, .. } => match store.get(id) {
210            Ok(Some(stage)) => {
211                for e in stage.signature.effects.iter() {
212                    effects.insert(e.clone());
213                }
214            }
215            _ => {
216                effects.insert(Effect::Unknown);
217            }
218        },
219        CompositionNode::RemoteStage { .. } => {
220            effects.insert(Effect::Network);
221            effects.insert(Effect::Fallible);
222        }
223        CompositionNode::Const { .. } => {
224            effects.insert(Effect::Pure);
225        }
226        CompositionNode::Sequential { stages } => {
227            for s in stages {
228                collect_effects_inner(s, store, effects);
229            }
230        }
231        CompositionNode::Parallel { branches } => {
232            for branch in branches.values() {
233                collect_effects_inner(branch, store, effects);
234            }
235        }
236        CompositionNode::Branch {
237            predicate,
238            if_true,
239            if_false,
240        } => {
241            collect_effects_inner(predicate, store, effects);
242            collect_effects_inner(if_true, store, effects);
243            collect_effects_inner(if_false, store, effects);
244        }
245        CompositionNode::Fanout { source, targets } => {
246            collect_effects_inner(source, store, effects);
247            for t in targets {
248                collect_effects_inner(t, store, effects);
249            }
250        }
251        CompositionNode::Merge { sources, target } => {
252            for s in sources {
253                collect_effects_inner(s, store, effects);
254            }
255            collect_effects_inner(target, store, effects);
256        }
257        CompositionNode::Retry { stage, .. } => {
258            collect_effects_inner(stage, store, effects);
259        }
260        CompositionNode::Let { bindings, body } => {
261            for b in bindings.values() {
262                collect_effects_inner(b, store, effects);
263            }
264            collect_effects_inner(body, store, effects);
265        }
266    }
267}
268
269/// Pre-flight check: walk the graph and verify every stage's declared effects
270/// are within the granted policy. Returns an empty vec when all effects are allowed.
271pub fn check_effects(
272    node: &CompositionNode,
273    store: &(impl StageStore + ?Sized),
274    policy: &EffectPolicy,
275) -> Vec<EffectViolation> {
276    let mut violations = Vec::new();
277    collect_effect_violations(node, store, policy, &mut violations);
278    violations
279}
280
281fn collect_effect_violations(
282    node: &CompositionNode,
283    store: &(impl StageStore + ?Sized),
284    policy: &EffectPolicy,
285    violations: &mut Vec<EffectViolation>,
286) {
287    match node {
288        CompositionNode::Stage { id, .. } => match store.get(id) {
289            Ok(Some(stage)) => {
290                for effect in stage.signature.effects.iter() {
291                    let kind = effect.kind();
292                    if !policy.is_allowed(&kind) {
293                        violations.push(EffectViolation {
294                            stage_id: id.clone(),
295                            effect: effect.clone(),
296                            message: format!(
297                                "stage '{}' declares effect {kind}; grant it with --allow-effects {kind}",
298                                stage.description
299                            ),
300                        });
301                    }
302                }
303            }
304            _ => {
305                let kind = EffectKind::Unknown;
306                if !policy.is_allowed(&kind) {
307                    violations.push(EffectViolation {
308                        stage_id: id.clone(),
309                        effect: Effect::Unknown,
310                        message: format!(
311                            "stage {} has unknown effects (not in store); grant with --allow-effects unknown",
312                            id.0
313                        ),
314                    });
315                }
316            }
317        },
318        CompositionNode::RemoteStage { .. } => {
319            for effect in &[Effect::Network, Effect::Fallible] {
320                let kind = effect.kind();
321                if !policy.is_allowed(&kind) {
322                    violations.push(EffectViolation {
323                        stage_id: StageId("remote".into()),
324                        effect: effect.clone(),
325                        message: format!(
326                            "RemoteStage declares implicit effect {kind}; grant with --allow-effects {kind}"
327                        ),
328                    });
329                }
330            }
331        }
332        CompositionNode::Const { .. } => {}
333        CompositionNode::Sequential { stages } => {
334            for s in stages {
335                collect_effect_violations(s, store, policy, violations);
336            }
337        }
338        CompositionNode::Parallel { branches } => {
339            for branch in branches.values() {
340                collect_effect_violations(branch, store, policy, violations);
341            }
342        }
343        CompositionNode::Branch {
344            predicate,
345            if_true,
346            if_false,
347        } => {
348            collect_effect_violations(predicate, store, policy, violations);
349            collect_effect_violations(if_true, store, policy, violations);
350            collect_effect_violations(if_false, store, policy, violations);
351        }
352        CompositionNode::Fanout { source, targets } => {
353            collect_effect_violations(source, store, policy, violations);
354            for t in targets {
355                collect_effect_violations(t, store, policy, violations);
356            }
357        }
358        CompositionNode::Merge { sources, target } => {
359            for s in sources {
360                collect_effect_violations(s, store, policy, violations);
361            }
362            collect_effect_violations(target, store, policy, violations);
363        }
364        CompositionNode::Retry { stage, .. } => {
365            collect_effect_violations(stage, store, policy, violations);
366        }
367        CompositionNode::Let { bindings, body } => {
368            for b in bindings.values() {
369                collect_effect_violations(b, store, policy, violations);
370            }
371            collect_effect_violations(body, store, policy, violations);
372        }
373    }
374}
375
376// ── Signature verification ─────────────────────────────────────────────────
377
/// Why a stage's signature check failed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SignatureViolationKind {
    /// The stage has no `ed25519_signature` / `signer_public_key` — it was built unsigned.
    Missing,
    /// A signature is present but cryptographic verification failed (tampered stage).
    Invalid,
}

impl fmt::Display for SignatureViolationKind {
    /// Writes a short lowercase label for the violation kind.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Missing => "unsigned",
            Self::Invalid => "invalid signature",
        };
        f.write_str(label)
    }
}
395
/// A single signature violation found during pre-flight checking.
#[derive(Debug, Clone)]
pub struct SignatureViolation {
    /// Identifier of the stage whose signature check failed.
    pub stage_id: StageId,
    /// Whether the signature was missing or invalid.
    pub kind: SignatureViolationKind,
    /// Human-readable description; `Display` prints it after the stage id.
    pub message: String,
}
403
404impl fmt::Display for SignatureViolation {
405    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
406        write!(f, "stage {} — {}", self.stage_id.0, self.message)
407    }
408}
409
410/// Pre-flight check: walk the graph and verify every stage's Ed25519 signature.
411///
412/// Returns an empty vec when all signatures pass. Stages with a missing
413/// signature OR an invalid signature are both reported as violations.
414pub fn verify_signatures(
415    node: &CompositionNode,
416    store: &(impl StageStore + ?Sized),
417) -> Vec<SignatureViolation> {
418    let mut violations = Vec::new();
419    collect_signature_violations(node, store, &mut violations);
420    violations
421}
422
423fn collect_signature_violations(
424    node: &CompositionNode,
425    store: &(impl StageStore + ?Sized),
426    violations: &mut Vec<SignatureViolation>,
427) {
428    match node {
429        CompositionNode::Stage { id, .. } => {
430            if let Ok(Some(stage)) = store.get(id) {
431                match (&stage.ed25519_signature, &stage.signer_public_key) {
432                    (None, _) | (_, None) => {
433                        violations.push(SignatureViolation {
434                            stage_id: id.clone(),
435                            kind: SignatureViolationKind::Missing,
436                            message: format!(
437                                "stage '{}' has no signature — add it via the signing pipeline",
438                                stage.description
439                            ),
440                        });
441                    }
442                    (Some(sig_hex), Some(pub_hex)) => {
443                        match noether_core::stage::verify_stage_signature(id, sig_hex, pub_hex) {
444                            Ok(true) => {} // valid
445                            Ok(false) => {
446                                violations.push(SignatureViolation {
447                                    stage_id: id.clone(),
448                                    kind: SignatureViolationKind::Invalid,
449                                    message: format!(
450                                        "stage '{}' signature verification failed — possible tampering",
451                                        stage.description
452                                    ),
453                                });
454                            }
455                            Err(e) => {
456                                violations.push(SignatureViolation {
457                                    stage_id: id.clone(),
458                                    kind: SignatureViolationKind::Invalid,
459                                    message: format!(
460                                        "stage '{}' signature could not be decoded: {e}",
461                                        stage.description
462                                    ),
463                                });
464                            }
465                        }
466                    }
467                }
468            }
469            // If the stage is not in the store, the type-checker will already
470            // have reported an unknown-stage error; skip here.
471        }
472        CompositionNode::Const { .. } => {} // constants have no signature to verify
473        CompositionNode::RemoteStage { .. } => {} // remote stages have no local signature to verify
474        CompositionNode::Sequential { stages } => {
475            for s in stages {
476                collect_signature_violations(s, store, violations);
477            }
478        }
479        CompositionNode::Parallel { branches } => {
480            for branch in branches.values() {
481                collect_signature_violations(branch, store, violations);
482            }
483        }
484        CompositionNode::Branch {
485            predicate,
486            if_true,
487            if_false,
488        } => {
489            collect_signature_violations(predicate, store, violations);
490            collect_signature_violations(if_true, store, violations);
491            collect_signature_violations(if_false, store, violations);
492        }
493        CompositionNode::Fanout { source, targets } => {
494            collect_signature_violations(source, store, violations);
495            for t in targets {
496                collect_signature_violations(t, store, violations);
497            }
498        }
499        CompositionNode::Merge { sources, target } => {
500            for s in sources {
501                collect_signature_violations(s, store, violations);
502            }
503            collect_signature_violations(target, store, violations);
504        }
505        CompositionNode::Retry { stage, .. } => {
506            collect_signature_violations(stage, store, violations);
507        }
508        CompositionNode::Let { bindings, body } => {
509            for b in bindings.values() {
510                collect_signature_violations(b, store, violations);
511            }
512            collect_signature_violations(body, store, violations);
513        }
514    }
515}
516
517// ── Effect warnings ────────────────────────────────────────────────────────
518
/// Warnings about effect usage detected during graph type-checking.
///
/// These are soft issues — the graph is structurally valid but may have
/// surprising runtime behaviour. Callers decide whether to block or surface them.
#[derive(Debug, Clone)]
pub enum EffectWarning {
    /// A `Fallible` stage is not wrapped in a `Retry` node. Failures propagate.
    FallibleWithoutRetry { stage_id: StageId },
    /// A `NonDeterministic` stage's output feeds a `Pure` stage.
    /// `from` is the non-deterministic producer, `to` the pure consumer.
    NonDeterministicFeedingPure { from: StageId, to: StageId },
    /// The sum of declared `Cost` effects exceeds the given budget (in cents).
    CostBudgetExceeded { total_cents: u64, budget_cents: u64 },
}
532
533impl fmt::Display for EffectWarning {
534    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
535        match self {
536            EffectWarning::FallibleWithoutRetry { stage_id } => write!(
537                f,
538                "stage {} is Fallible but has no Retry wrapper; failures will propagate",
539                stage_id.0
540            ),
541            EffectWarning::NonDeterministicFeedingPure { from, to } => write!(
542                f,
543                "stage {} is NonDeterministic but feeds Pure stage {}; Pure caching will be bypassed",
544                from.0, to.0
545            ),
546            EffectWarning::CostBudgetExceeded { total_cents, budget_cents } => write!(
547                f,
548                "estimated composition cost ({total_cents}¢) exceeds budget ({budget_cents}¢)"
549            ),
550        }
551    }
552}
553
/// The result of a successful graph type-check: resolved types plus any effect warnings.
#[derive(Debug, Clone)]
pub struct CheckResult {
    /// Input/output types resolved for the root node of the graph.
    pub resolved: ResolvedType,
    /// Soft effect warnings collected while walking the graph; may be empty.
    pub warnings: Vec<EffectWarning>,
}
560
561// ── Errors detected during graph type checking ────────────────────────────
/// A hard error found while type-checking a composition graph.
#[derive(Debug, Clone)]
pub enum GraphTypeError {
    /// The referenced stage id could not be found in the store.
    StageNotFound {
        id: StageId,
    },
    /// In a sequential node, the output at `position` is not a subtype of
    /// the input that follows it.
    SequentialTypeMismatch {
        position: usize,
        from_output: NType,
        to_input: NType,
        reason: IncompatibilityReason,
    },
    /// A branch predicate produces `actual` instead of Bool.
    BranchPredicateNotBool {
        actual: NType,
    },
    /// The `if_true` and `if_false` arms of a branch produce incompatible outputs.
    BranchOutputMismatch {
        true_output: NType,
        false_output: NType,
        reason: IncompatibilityReason,
    },
    /// The fanout source output is not a subtype of the input of the target
    /// at `target_index`.
    FanoutInputMismatch {
        target_index: usize,
        source_output: NType,
        target_input: NType,
        reason: IncompatibilityReason,
    },
    /// The merged type of a merge node is not a subtype of its target's input.
    MergeOutputMismatch {
        merged_type: NType,
        target_input: NType,
        reason: IncompatibilityReason,
    },
    /// A composite node of kind `operator` has no children.
    EmptyNode {
        operator: String,
    },
}
596
597impl fmt::Display for GraphTypeError {
598    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
599        match self {
600            GraphTypeError::StageNotFound { id } => {
601                write!(f, "stage {} not found in store", id.0)
602            }
603            GraphTypeError::SequentialTypeMismatch {
604                position,
605                from_output,
606                to_input,
607                reason,
608            } => write!(
609                f,
610                "type mismatch at position {position}: output {from_output} is not subtype of input {to_input}: {reason}"
611            ),
612            GraphTypeError::BranchPredicateNotBool { actual } => {
613                write!(f, "branch predicate must produce Bool, got {actual}")
614            }
615            GraphTypeError::BranchOutputMismatch {
616                true_output,
617                false_output,
618                reason,
619            } => write!(
620                f,
621                "branch outputs must be compatible: if_true produces {true_output}, if_false produces {false_output}: {reason}"
622            ),
623            GraphTypeError::FanoutInputMismatch {
624                target_index,
625                source_output,
626                target_input,
627                reason,
628            } => write!(
629                f,
630                "fanout target {target_index}: source output {source_output} is not subtype of target input {target_input}: {reason}"
631            ),
632            GraphTypeError::MergeOutputMismatch {
633                merged_type,
634                target_input,
635                reason,
636            } => write!(
637                f,
638                "merge: merged type {merged_type} is not subtype of target input {target_input}: {reason}"
639            ),
640            GraphTypeError::EmptyNode { operator } => {
641                write!(f, "empty {operator} node")
642            }
643        }
644    }
645}
646
647/// Type-check a composition graph against the stage store.
648///
649/// Returns `CheckResult` (resolved types + effect warnings) on success,
650/// or a list of hard type errors on failure.
651pub fn check_graph(
652    node: &CompositionNode,
653    store: &(impl StageStore + ?Sized),
654) -> Result<CheckResult, Vec<GraphTypeError>> {
655    let mut errors = Vec::new();
656    let result = check_node(node, store, &mut errors);
657    if errors.is_empty() {
658        let resolved = result.unwrap();
659        let warnings = collect_effect_warnings(node, store, None);
660        Ok(CheckResult { resolved, warnings })
661    } else {
662        Err(errors)
663    }
664}
665
666/// Collect effect warnings by walking the graph.
667/// `cost_budget_cents` — pass `Some(n)` to enable budget enforcement.
668pub fn collect_effect_warnings(
669    node: &CompositionNode,
670    store: &(impl StageStore + ?Sized),
671    cost_budget_cents: Option<u64>,
672) -> Vec<EffectWarning> {
673    let mut warnings = Vec::new();
674    let mut total_cost: u64 = 0;
675    collect_warnings_inner(node, store, &mut warnings, &mut total_cost, false);
676    if let Some(budget) = cost_budget_cents {
677        if total_cost > budget {
678            warnings.push(EffectWarning::CostBudgetExceeded {
679                total_cents: total_cost,
680                budget_cents: budget,
681            });
682        }
683    }
684    warnings
685}
686
687fn collect_warnings_inner(
688    node: &CompositionNode,
689    store: &(impl StageStore + ?Sized),
690    warnings: &mut Vec<EffectWarning>,
691    total_cost: &mut u64,
692    _parent_is_retry: bool,
693) {
694    match node {
695        CompositionNode::Stage { id, .. } => {
696            if let Ok(Some(stage)) = store.get(id) {
697                // Accumulate cost
698                for effect in stage.signature.effects.iter() {
699                    if let Effect::Cost { cents } = effect {
700                        *total_cost = total_cost.saturating_add(*cents);
701                    }
702                }
703                // Fallible without retry is handled at the parent sequential level
704            }
705        }
706        CompositionNode::RemoteStage { .. } => {} // remote calls have no local effects to warn about
707        CompositionNode::Const { .. } => {}       // no effects in a constant
708        CompositionNode::Sequential { stages } => {
709            for (i, s) in stages.iter().enumerate() {
710                collect_warnings_inner(s, store, warnings, total_cost, false);
711
712                // Rule: Fallible stage not wrapped in Retry
713                if let CompositionNode::Stage { id, .. } = s {
714                    if let Ok(Some(stage)) = store.get(id) {
715                        if stage.signature.effects.contains(&Effect::Fallible) {
716                            warnings.push(EffectWarning::FallibleWithoutRetry {
717                                stage_id: id.clone(),
718                            });
719                        }
720                    }
721                }
722
723                // Rule: NonDeterministic output → Pure input
724                if i + 1 < stages.len() {
725                    if let (
726                        CompositionNode::Stage { id: from_id, .. },
727                        CompositionNode::Stage { id: to_id, .. },
728                    ) = (s, &stages[i + 1])
729                    {
730                        let from_nd = store
731                            .get(from_id)
732                            .ok()
733                            .flatten()
734                            .map(|s| s.signature.effects.contains(&Effect::NonDeterministic))
735                            .unwrap_or(false);
736                        let to_pure = store
737                            .get(to_id)
738                            .ok()
739                            .flatten()
740                            .map(|s| s.signature.effects.contains(&Effect::Pure))
741                            .unwrap_or(false);
742
743                        if from_nd && to_pure {
744                            warnings.push(EffectWarning::NonDeterministicFeedingPure {
745                                from: from_id.clone(),
746                                to: to_id.clone(),
747                            });
748                        }
749                    }
750                }
751            }
752        }
753        CompositionNode::Parallel { branches } => {
754            for branch in branches.values() {
755                collect_warnings_inner(branch, store, warnings, total_cost, false);
756            }
757        }
758        CompositionNode::Branch {
759            predicate,
760            if_true,
761            if_false,
762        } => {
763            collect_warnings_inner(predicate, store, warnings, total_cost, false);
764            collect_warnings_inner(if_true, store, warnings, total_cost, false);
765            collect_warnings_inner(if_false, store, warnings, total_cost, false);
766        }
767        CompositionNode::Fanout { source, targets } => {
768            collect_warnings_inner(source, store, warnings, total_cost, false);
769            for t in targets {
770                collect_warnings_inner(t, store, warnings, total_cost, false);
771            }
772        }
773        CompositionNode::Merge { sources, target } => {
774            for s in sources {
775                collect_warnings_inner(s, store, warnings, total_cost, false);
776            }
777            collect_warnings_inner(target, store, warnings, total_cost, false);
778        }
779        CompositionNode::Retry { stage, .. } => {
780            // Retry wraps Fallible — suppress FallibleWithoutRetry for direct child
781            collect_warnings_inner(stage, store, warnings, total_cost, true);
782            // Remove any FallibleWithoutRetry that was just added for the immediate child
783            if let CompositionNode::Stage { id, .. } = stage.as_ref() {
784                warnings.retain(|w| !matches!(w, EffectWarning::FallibleWithoutRetry { stage_id } if stage_id == id));
785            }
786        }
787        CompositionNode::Let { bindings, body } => {
788            for b in bindings.values() {
789                collect_warnings_inner(b, store, warnings, total_cost, false);
790            }
791            collect_warnings_inner(body, store, warnings, total_cost, false);
792        }
793    }
794}
795
796fn check_node(
797    node: &CompositionNode,
798    store: &(impl StageStore + ?Sized),
799    errors: &mut Vec<GraphTypeError>,
800) -> Option<ResolvedType> {
801    match node {
802        CompositionNode::Stage {
803            id,
804            pinning,
805            config,
806        } => {
807            let resolved = check_stage(id, *pinning, store, errors)?;
808            // When config provides fields, reduce the effective input type
809            if let Some(cfg) = config {
810                if !cfg.is_empty() {
811                    if let NType::Record(fields) = &resolved.input {
812                        let remaining: std::collections::BTreeMap<String, NType> = fields
813                            .iter()
814                            .filter(|(name, _)| !cfg.contains_key(*name))
815                            .map(|(name, ty)| (name.clone(), ty.clone()))
816                            .collect();
817                        let effective = if remaining.is_empty() || remaining.len() == 1 {
818                            NType::Any
819                        } else {
820                            NType::Record(remaining)
821                        };
822                        return Some(ResolvedType {
823                            input: effective,
824                            output: resolved.output,
825                        });
826                    }
827                }
828            }
829            Some(resolved)
830        }
831        // RemoteStage: types are declared inline — no store lookup needed.
832        // The type checker trusts the declared input/output types.
833        CompositionNode::RemoteStage { input, output, .. } => Some(ResolvedType {
834            input: input.clone(),
835            output: output.clone(),
836        }),
837        // Const: accepts Any input, emits Any output (actual type is inferred from value at runtime)
838        CompositionNode::Const { .. } => Some(ResolvedType {
839            input: NType::Any,
840            output: NType::Any,
841        }),
842        CompositionNode::Sequential { stages } => check_sequential(stages, store, errors),
843        CompositionNode::Parallel { branches } => check_parallel(branches, store, errors),
844        CompositionNode::Branch {
845            predicate,
846            if_true,
847            if_false,
848        } => check_branch(predicate, if_true, if_false, store, errors),
849        CompositionNode::Fanout { source, targets } => check_fanout(source, targets, store, errors),
850        CompositionNode::Merge { sources, target } => check_merge(sources, target, store, errors),
851        CompositionNode::Retry { stage, .. } => check_node(stage, store, errors),
852        CompositionNode::Let { bindings, body } => check_let(bindings, body, store, errors),
853    }
854}
855
856/// Type-check a `Let` node.
857///
858/// Each binding sees the **outer Let input**. The body sees an augmented
859/// record `{ ...outer-input fields, <binding>: <binding-output> }`. The
860/// Let's overall input requirement is the union of:
861///   - every binding's input field requirements (each binding sees the same
862///     outer input), and
863///   - any field the body's input requires that is *not* satisfied by a
864///     binding (those must come through from the outer input).
865///
866/// The Let's output is the body's output. When inputs are not Records (e.g.
867/// `Any`), we conservatively widen to `NType::Any` rather than failing.
868fn check_let(
869    bindings: &BTreeMap<String, CompositionNode>,
870    body: &CompositionNode,
871    store: &(impl StageStore + ?Sized),
872    errors: &mut Vec<GraphTypeError>,
873) -> Option<ResolvedType> {
874    if bindings.is_empty() {
875        errors.push(GraphTypeError::EmptyNode {
876            operator: "Let".into(),
877        });
878        return None;
879    }
880
881    // Resolve every binding's types.
882    let mut binding_outputs: BTreeMap<String, NType> = BTreeMap::new();
883    let mut required_input: BTreeMap<String, NType> = BTreeMap::new();
884    let mut any_input = false;
885
886    for (name, node) in bindings {
887        let resolved = check_node(node, store, errors)?;
888        binding_outputs.insert(name.clone(), resolved.output);
889        match resolved.input {
890            NType::Record(fields) => {
891                for (f, ty) in fields {
892                    required_input.insert(f, ty);
893                }
894            }
895            NType::Any => {
896                any_input = true;
897            }
898            other => {
899                // A binding that wants a non-Record, non-Any input doesn't
900                // compose cleanly with the Let's record-shaped input. We
901                // conservatively require the outer input to be Any.
902                let _ = other;
903                any_input = true;
904            }
905        }
906    }
907
908    // Build the body's input record by merging outer-input requirements with
909    // the binding outputs (bindings shadow outer fields with the same name).
910    let mut body_input_fields = required_input.clone();
911    for (name, out_ty) in &binding_outputs {
912        body_input_fields.insert(name.clone(), out_ty.clone());
913    }
914
915    let body_resolved = check_node(body, store, errors)?;
916
917    // Verify the body's input is satisfied by the augmented record. For each
918    // field the body requires, either it must come from a binding output (in
919    // which case the binding's output must be a subtype of the expected
920    // field) or from the outer input — in which case we add it to the Let's
921    // overall input requirement.
922    if let NType::Record(body_fields) = &body_resolved.input {
923        for (name, expected_ty) in body_fields {
924            let provided = body_input_fields.get(name).cloned();
925            match provided {
926                Some(actual) => {
927                    if let TypeCompatibility::Incompatible(reason) =
928                        is_subtype_of(&actual, expected_ty)
929                    {
930                        errors.push(GraphTypeError::SequentialTypeMismatch {
931                            position: 0,
932                            from_output: actual,
933                            to_input: expected_ty.clone(),
934                            reason,
935                        });
936                    }
937                }
938                None => {
939                    // Body needs a field neither bindings nor known outer
940                    // requirements provide. Mark it as required from outer
941                    // input.
942                    required_input.insert(name.clone(), expected_ty.clone());
943                }
944            }
945        }
946    }
947
948    let input = if any_input || required_input.is_empty() {
949        NType::Any
950    } else {
951        NType::Record(required_input)
952    };
953
954    Some(ResolvedType {
955        input,
956        output: body_resolved.output,
957    })
958}
959
960fn check_stage(
961    id: &StageId,
962    pinning: Pinning,
963    store: &(impl StageStore + ?Sized),
964    errors: &mut Vec<GraphTypeError>,
965) -> Option<ResolvedType> {
966    match crate::lagrange::resolve_stage_ref(id, pinning, store) {
967        Some(stage) => Some(ResolvedType {
968            input: stage.signature.input.clone(),
969            output: stage.signature.output.clone(),
970        }),
971        None => {
972            errors.push(GraphTypeError::StageNotFound { id: id.clone() });
973            None
974        }
975    }
976}
977
978fn check_sequential(
979    stages: &[CompositionNode],
980    store: &(impl StageStore + ?Sized),
981    errors: &mut Vec<GraphTypeError>,
982) -> Option<ResolvedType> {
983    if stages.is_empty() {
984        errors.push(GraphTypeError::EmptyNode {
985            operator: "Sequential".into(),
986        });
987        return None;
988    }
989
990    let resolved: Vec<Option<ResolvedType>> = stages
991        .iter()
992        .map(|s| check_node(s, store, errors))
993        .collect();
994
995    // Check consecutive pairs
996    for i in 0..resolved.len() - 1 {
997        if let (Some(from), Some(to)) = (&resolved[i], &resolved[i + 1]) {
998            if let TypeCompatibility::Incompatible(reason) = is_subtype_of(&from.output, &to.input)
999            {
1000                errors.push(GraphTypeError::SequentialTypeMismatch {
1001                    position: i,
1002                    from_output: from.output.clone(),
1003                    to_input: to.input.clone(),
1004                    reason,
1005                });
1006            }
1007        }
1008    }
1009
1010    let first_input = resolved
1011        .first()
1012        .and_then(|r| r.as_ref())
1013        .map(|r| r.input.clone());
1014    let last_output = resolved
1015        .last()
1016        .and_then(|r| r.as_ref())
1017        .map(|r| r.output.clone());
1018
1019    match (first_input, last_output) {
1020        (Some(input), Some(output)) => Some(ResolvedType { input, output }),
1021        _ => None,
1022    }
1023}
1024
1025fn check_parallel(
1026    branches: &BTreeMap<String, CompositionNode>,
1027    store: &(impl StageStore + ?Sized),
1028    errors: &mut Vec<GraphTypeError>,
1029) -> Option<ResolvedType> {
1030    if branches.is_empty() {
1031        errors.push(GraphTypeError::EmptyNode {
1032            operator: "Parallel".into(),
1033        });
1034        return None;
1035    }
1036
1037    let mut input_fields = BTreeMap::new();
1038    let mut output_fields = BTreeMap::new();
1039
1040    for (name, node) in branches {
1041        if let Some(resolved) = check_node(node, store, errors) {
1042            input_fields.insert(name.clone(), resolved.input);
1043            output_fields.insert(name.clone(), resolved.output);
1044        }
1045    }
1046
1047    if input_fields.len() == branches.len() {
1048        Some(ResolvedType {
1049            input: NType::Record(input_fields),
1050            output: NType::Record(output_fields),
1051        })
1052    } else {
1053        None
1054    }
1055}
1056
1057fn check_branch(
1058    predicate: &CompositionNode,
1059    if_true: &CompositionNode,
1060    if_false: &CompositionNode,
1061    store: &(impl StageStore + ?Sized),
1062    errors: &mut Vec<GraphTypeError>,
1063) -> Option<ResolvedType> {
1064    let pred = check_node(predicate, store, errors);
1065    let true_branch = check_node(if_true, store, errors);
1066    let false_branch = check_node(if_false, store, errors);
1067
1068    // Check predicate output is Bool
1069    if let Some(ref p) = pred {
1070        if let TypeCompatibility::Incompatible(_) = is_subtype_of(&p.output, &NType::Bool) {
1071            errors.push(GraphTypeError::BranchPredicateNotBool {
1072                actual: p.output.clone(),
1073            });
1074        }
1075    }
1076
1077    // Branch outputs are unioned — both paths are valid return types.
1078    // No compatibility check required between branches; the consumer
1079    // of the branch output must handle the union type.
1080    match (pred, true_branch, false_branch) {
1081        (Some(p), Some(t), Some(f)) => Some(ResolvedType {
1082            input: p.input,
1083            output: NType::union(vec![t.output, f.output]),
1084        }),
1085        _ => None,
1086    }
1087}
1088
1089fn check_fanout(
1090    source: &CompositionNode,
1091    targets: &[CompositionNode],
1092    store: &(impl StageStore + ?Sized),
1093    errors: &mut Vec<GraphTypeError>,
1094) -> Option<ResolvedType> {
1095    if targets.is_empty() {
1096        errors.push(GraphTypeError::EmptyNode {
1097            operator: "Fanout".into(),
1098        });
1099        return None;
1100    }
1101
1102    let src = check_node(source, store, errors);
1103    let tgts: Vec<Option<ResolvedType>> = targets
1104        .iter()
1105        .map(|t| check_node(t, store, errors))
1106        .collect();
1107
1108    // Check source output is subtype of each target input
1109    if let Some(ref s) = src {
1110        for (i, t) in tgts.iter().enumerate() {
1111            if let Some(ref t) = t {
1112                if let TypeCompatibility::Incompatible(reason) = is_subtype_of(&s.output, &t.input)
1113                {
1114                    errors.push(GraphTypeError::FanoutInputMismatch {
1115                        target_index: i,
1116                        source_output: s.output.clone(),
1117                        target_input: t.input.clone(),
1118                        reason,
1119                    });
1120                }
1121            }
1122        }
1123    }
1124
1125    let output_types: Vec<NType> = tgts
1126        .iter()
1127        .filter_map(|t| t.as_ref().map(|r| r.output.clone()))
1128        .collect();
1129
1130    match src {
1131        Some(s) if output_types.len() == targets.len() => Some(ResolvedType {
1132            input: s.input,
1133            output: NType::List(Box::new(if output_types.len() == 1 {
1134                output_types.into_iter().next().unwrap()
1135            } else {
1136                NType::union(output_types)
1137            })),
1138        }),
1139        _ => None,
1140    }
1141}
1142
1143fn check_merge(
1144    sources: &[CompositionNode],
1145    target: &CompositionNode,
1146    store: &(impl StageStore + ?Sized),
1147    errors: &mut Vec<GraphTypeError>,
1148) -> Option<ResolvedType> {
1149    if sources.is_empty() {
1150        errors.push(GraphTypeError::EmptyNode {
1151            operator: "Merge".into(),
1152        });
1153        return None;
1154    }
1155
1156    let srcs: Vec<Option<ResolvedType>> = sources
1157        .iter()
1158        .map(|s| check_node(s, store, errors))
1159        .collect();
1160    let tgt = check_node(target, store, errors);
1161
1162    // Build merged output record from sources
1163    let mut merged_fields = BTreeMap::new();
1164    for (i, s) in srcs.iter().enumerate() {
1165        if let Some(ref r) = s {
1166            merged_fields.insert(format!("source_{i}"), r.output.clone());
1167        }
1168    }
1169    let merged_type = NType::Record(merged_fields);
1170
1171    // Check merged type is subtype of target input
1172    if let Some(ref t) = tgt {
1173        if let TypeCompatibility::Incompatible(reason) = is_subtype_of(&merged_type, &t.input) {
1174            errors.push(GraphTypeError::MergeOutputMismatch {
1175                merged_type: merged_type.clone(),
1176                target_input: t.input.clone(),
1177                reason,
1178            });
1179        }
1180    }
1181
1182    // Overall: input is record of source inputs, output is target output
1183    let mut input_fields = BTreeMap::new();
1184    for (i, s) in srcs.iter().enumerate() {
1185        if let Some(ref r) = s {
1186            input_fields.insert(format!("source_{i}"), r.input.clone());
1187        }
1188    }
1189
1190    match tgt {
1191        Some(t) => Some(ResolvedType {
1192            input: NType::Record(input_fields),
1193            output: t.output,
1194        }),
1195        None => None,
1196    }
1197}
1198
#[cfg(test)]
mod tests {
    // Unit tests for the graph type checker, capability checks, effect
    // inference and effect policies. All of them run against an in-memory
    // store populated with small synthetic stages.

    use super::*;
    use noether_core::capability::Capability;
    use noether_core::effects::EffectSet;
    use noether_core::stage::{CostEstimate, Stage, StageSignature};
    use noether_store::MemoryStore;
    use std::collections::BTreeSet;

    /// Build a minimal pure stage with the given id and signature types.
    fn make_stage(id: &str, input: NType, output: NType) -> Stage {
        Stage {
            id: StageId(id.into()),
            signature_id: None,
            signature: StageSignature {
                input,
                output,
                effects: EffectSet::pure(),
                implementation_hash: format!("impl_{id}"),
            },
            capabilities: BTreeSet::new(),
            cost: CostEstimate {
                time_ms_p50: Some(10),
                tokens_est: None,
                memory_mb: None,
            },
            description: format!("test stage {id}"),
            examples: vec![],
            lifecycle: noether_core::stage::StageLifecycle::Active,
            ed25519_signature: None,
            signer_public_key: None,
            implementation_code: None,
            implementation_language: None,
            ui_style: None,
            tags: vec![],
            aliases: vec![],
            name: None,
            properties: Vec::new(),
        }
    }

    /// Store holding the handful of stages shared by the type-check tests.
    fn test_store() -> MemoryStore {
        let mut store = MemoryStore::new();
        for (id, input, output) in [
            ("text_to_num", NType::Text, NType::Number),
            ("num_to_bool", NType::Number, NType::Bool),
            ("text_to_text", NType::Text, NType::Text),
            ("bool_pred", NType::Text, NType::Bool),
            ("any_to_text", NType::Any, NType::Text),
        ] {
            store.put(make_stage(id, input, output)).unwrap();
        }
        store
    }

    /// Stage node referring to `id`, signature-pinned and without config.
    fn stage(id: &str) -> CompositionNode {
        CompositionNode::Stage {
            id: StageId(id.into()),
            pinning: Pinning::Signature,
            config: None,
        }
    }

    #[test]
    fn check_single_stage() {
        let store = test_store();
        let result = check_graph(&stage("text_to_num"), &store);
        let check = result.unwrap();
        assert_eq!(check.resolved.input, NType::Text);
        assert_eq!(check.resolved.output, NType::Number);
    }

    #[test]
    fn check_missing_stage() {
        let store = test_store();
        let result = check_graph(&stage("nonexistent"), &store);
        assert!(result.is_err());
        let errors = result.unwrap_err();
        assert!(matches!(errors[0], GraphTypeError::StageNotFound { .. }));
    }

    #[test]
    fn check_valid_sequential() {
        let store = test_store();
        let node = CompositionNode::Sequential {
            stages: vec![stage("text_to_num"), stage("num_to_bool")],
        };
        let result = check_graph(&node, &store);
        let check = result.unwrap();
        assert_eq!(check.resolved.input, NType::Text);
        assert_eq!(check.resolved.output, NType::Bool);
    }

    #[test]
    fn check_invalid_sequential() {
        let store = test_store();
        // Bool output cannot feed Text input
        let node = CompositionNode::Sequential {
            stages: vec![stage("num_to_bool"), stage("text_to_num")],
        };
        let result = check_graph(&node, &store);
        assert!(result.is_err());
        let errors = result.unwrap_err();
        assert!(matches!(
            errors[0],
            GraphTypeError::SequentialTypeMismatch { .. }
        ));
    }

    #[test]
    fn check_parallel() {
        let store = test_store();
        let node = CompositionNode::Parallel {
            branches: BTreeMap::from([
                ("nums".into(), stage("text_to_num")),
                ("bools".into(), stage("bool_pred")),
            ]),
        };
        let result = check_graph(&node, &store);
        let check = result.unwrap();
        // Input is Record { bools: Text, nums: Text }
        // Output is Record { bools: Bool, nums: Number }
        assert!(matches!(check.resolved.input, NType::Record(_)));
        assert!(matches!(check.resolved.output, NType::Record(_)));
    }

    #[test]
    fn check_branch_valid() {
        let store = test_store();
        let node = CompositionNode::Branch {
            predicate: Box::new(stage("bool_pred")),
            if_true: Box::new(stage("text_to_num")),
            if_false: Box::new(stage("text_to_text")),
        };
        // Predicate: Text -> Bool ✓
        // The branch's resolved input is the predicate's input (Text); both
        // arms in this fixture also take Text.
        // Outputs are Number and Text, which union into Number | Text
        let result = check_graph(&node, &store);
        let check = result.unwrap();
        assert_eq!(check.resolved.input, NType::Text);
    }

    #[test]
    fn check_retry_transparent() {
        let store = test_store();
        let node = CompositionNode::Retry {
            stage: Box::new(stage("text_to_num")),
            max_attempts: 3,
            delay_ms: Some(100),
        };
        let result = check_graph(&node, &store);
        let check = result.unwrap();
        assert_eq!(check.resolved.input, NType::Text);
        assert_eq!(check.resolved.output, NType::Number);
    }

    #[test]
    fn capability_policy_allow_all_passes() {
        let mut store = test_store();
        let mut stage_net = make_stage("net_stage", NType::Text, NType::Text);
        stage_net.capabilities.insert(Capability::Network);
        store.put(stage_net).unwrap();

        let policy = CapabilityPolicy::allow_all();
        let violations = check_capabilities(&stage("net_stage"), &store, &policy);
        assert!(violations.is_empty());
    }

    #[test]
    fn capability_policy_restrict_blocks_network() {
        let mut store = test_store();
        let mut stage_net = make_stage("net_stage2", NType::Text, NType::Text);
        stage_net.capabilities.insert(Capability::Network);
        store.put(stage_net).unwrap();

        let policy = CapabilityPolicy::restrict([Capability::FsRead]);
        let violations = check_capabilities(&stage("net_stage2"), &store, &policy);
        assert_eq!(violations.len(), 1);
        assert!(matches!(violations[0].required, Capability::Network));
    }

    #[test]
    fn capability_policy_restrict_allows_declared() {
        let mut store = test_store();
        let mut stage_net = make_stage("net_stage3", NType::Text, NType::Text);
        stage_net.capabilities.insert(Capability::Network);
        store.put(stage_net).unwrap();

        let policy = CapabilityPolicy::restrict([Capability::Network]);
        let violations = check_capabilities(&stage("net_stage3"), &store, &policy);
        assert!(violations.is_empty());
    }

    #[test]
    fn remote_stage_resolves_declared_types() {
        let store = test_store();
        let node = CompositionNode::RemoteStage {
            url: "http://api.example.com".into(),
            input: NType::Text,
            output: NType::Number,
        };
        let result = check_graph(&node, &store).unwrap();
        assert_eq!(result.resolved.input, NType::Text);
        assert_eq!(result.resolved.output, NType::Number);
    }

    #[test]
    fn remote_stage_in_sequential_type_flows() {
        let mut store = test_store();
        store
            .put(make_stage("num_render", NType::Number, NType::Text))
            .unwrap();

        // Text -> RemoteStage(Text->Number) -> num_render(Number->Text) = Text->Text
        let node = CompositionNode::Sequential {
            stages: vec![
                CompositionNode::RemoteStage {
                    url: "http://api:8080".into(),
                    input: NType::Text,
                    output: NType::Number,
                },
                stage("num_render"),
            ],
        };
        let result = check_graph(&node, &store).unwrap();
        assert_eq!(result.resolved.input, NType::Text);
        assert_eq!(result.resolved.output, NType::Text);
    }

    #[test]
    fn remote_stage_type_mismatch_is_detected() {
        let store = test_store();
        // RemoteStage outputs Bool, but the next stage expects Text
        let node = CompositionNode::Sequential {
            stages: vec![
                CompositionNode::RemoteStage {
                    url: "http://api:8080".into(),
                    input: NType::Text,
                    output: NType::Bool,
                },
                stage("text_to_num"),
            ],
        };
        let result = check_graph(&node, &store);
        assert!(result.is_err());
        let errors = result.unwrap_err();
        assert!(errors
            .iter()
            .any(|e| matches!(e, GraphTypeError::SequentialTypeMismatch { .. })));
    }

    // ── Effect inference ────────────────────────────────────────────────────

    /// Any -> Any stage whose signature declares the given effects.
    fn make_stage_with_effects(id: &str, effects: EffectSet) -> Stage {
        let mut s = make_stage(id, NType::Any, NType::Any);
        s.signature.effects = effects;
        s
    }

    #[test]
    fn infer_effects_pure_stage() {
        let mut store = MemoryStore::new();
        store
            .put(make_stage_with_effects("pure1", EffectSet::pure()))
            .unwrap();
        let effects = infer_effects(&stage("pure1"), &store);
        assert!(effects.contains(&Effect::Pure));
        assert!(!effects.contains(&Effect::Network));
    }

    #[test]
    fn infer_effects_union_sequential() {
        let mut store = MemoryStore::new();
        store
            .put(make_stage_with_effects("a", EffectSet::new([Effect::Pure])))
            .unwrap();
        store
            .put(make_stage_with_effects(
                "b",
                EffectSet::new([Effect::Network]),
            ))
            .unwrap();
        let node = CompositionNode::Sequential {
            stages: vec![stage("a"), stage("b")],
        };
        let effects = infer_effects(&node, &store);
        assert!(effects.contains(&Effect::Pure));
        assert!(effects.contains(&Effect::Network));
    }

    #[test]
    fn infer_effects_remote_stage_adds_network() {
        let store = MemoryStore::new();
        let node = CompositionNode::RemoteStage {
            url: "http://localhost:8080".into(),
            input: NType::Any,
            output: NType::Any,
        };
        let effects = infer_effects(&node, &store);
        assert!(effects.contains(&Effect::Network));
        assert!(effects.contains(&Effect::Fallible));
    }

    #[test]
    fn infer_effects_missing_stage_adds_unknown() {
        let store = MemoryStore::new();
        let effects = infer_effects(&stage("missing"), &store);
        assert!(effects.contains(&Effect::Unknown));
    }

    // ── Effect policy ───────────────────────────────────────────────────────

    #[test]
    fn effect_policy_allow_all_never_violates() {
        let mut store = MemoryStore::new();
        store
            .put(make_stage_with_effects(
                "net",
                EffectSet::new([Effect::Network, Effect::Fallible]),
            ))
            .unwrap();
        let policy = EffectPolicy::allow_all();
        assert!(check_effects(&stage("net"), &store, &policy).is_empty());
    }

    #[test]
    fn effect_policy_restrict_blocks_network() {
        let mut store = MemoryStore::new();
        store
            .put(make_stage_with_effects(
                "net",
                EffectSet::new([Effect::Network]),
            ))
            .unwrap();
        let policy = EffectPolicy::restrict([EffectKind::Pure]);
        let violations = check_effects(&stage("net"), &store, &policy);
        assert!(!violations.is_empty());
        assert!(violations[0].message.contains("network"));
    }

    #[test]
    fn effect_policy_restrict_allows_matching_effect() {
        let mut store = MemoryStore::new();
        store
            .put(make_stage_with_effects(
                "llm",
                EffectSet::new([Effect::Llm {
                    model: "gpt-4o".into(),
                }]),
            ))
            .unwrap();
        let policy = EffectPolicy::restrict([EffectKind::Llm]);
        assert!(check_effects(&stage("llm"), &store, &policy).is_empty());
    }
}