use crate::lagrange::ast::{CompositionNode, Pinning};
use noether_core::stage::{SignatureId, StageId, StageLifecycle};
use noether_store::StageStore;
/// Reasons [`resolve_pinning`] can fail to map a stage reference onto an
/// Active implementation held by the store.
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum ResolutionError {
    /// A signature-pinned node's id matched no signature in the store and
    /// was not an Active implementation id either.
    #[error(
        "stage node with pinning=signature has id `{signature_id}` — no Active stage \
         in the store matches that signature"
    )]
    SignatureNotFound { signature_id: String },

    /// A `both`-pinned node's id could not be fetched from the store.
    #[error(
        "stage node with pinning=both has id `{implementation_id}` — no stage in the \
         store has that implementation ID"
    )]
    ImplementationNotFound { implementation_id: String },

    /// A `both`-pinned node names a stored implementation that is not Active.
    #[error(
        "stage node with pinning=both has id `{implementation_id}` — the stage exists \
         but its lifecycle is {lifecycle:?}; only Active stages may be referenced"
    )]
    ImplementationNotActive {
        /// Implementation id exactly as written on the node.
        implementation_id: String,
        /// Lifecycle the store reported for that implementation.
        lifecycle: StageLifecycle,
    },
}
/// Resolves the pinning of every `Stage` node reachable from `node`,
/// rewriting stage ids in place.
///
/// * Signature-pinned nodes are rewritten to the implementation id the store
///   returns for that signature. An id that already names an Active
///   implementation is accepted as-is.
/// * `both`-pinned nodes must name an existing, Active implementation.
/// * Remote stages and constants are left untouched.
///
/// The returned [`ResolutionReport`] lists every id that changed and every
/// signature that had more than one Active implementation.
///
/// # Errors
///
/// Returns the first [`ResolutionError`] hit during the depth-first walk.
/// Nodes resolved before that point keep their rewritten ids, and the
/// partially built report is dropped.
pub fn resolve_pinning<S>(
    node: &mut CompositionNode,
    store: &S,
) -> Result<ResolutionReport, ResolutionError>
where
    S: StageStore + ?Sized,
{
    let mut collected = ResolutionReport::default();
    resolve_recursive(node, store, &mut collected).map(|()| collected)
}
/// Summary of what one [`resolve_pinning`] pass changed and flagged.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ResolutionReport {
    /// One entry per `Stage` node whose id was replaced.
    pub rewrites: Vec<Rewrite>,
    /// One entry per signature-pinned `Stage` node whose signature had more
    /// than one Active implementation in the store.
    pub warnings: Vec<MultiActiveWarning>,
}
/// A single in-place id replacement performed on a `Stage` node.
#[derive(Clone, Debug, PartialEq)]
pub struct Rewrite {
    /// Id the node carried before resolution.
    pub before: String,
    /// Implementation id the node carries after resolution.
    pub after: String,
    /// Pinning of the node; resolution rewrites the id but not the pinning.
    pub pinning: Pinning,
}
/// Raised when a signature-pinned node's signature has several Active
/// implementations in the store.
#[derive(Clone, Debug, PartialEq)]
pub struct MultiActiveWarning {
    /// Signature id as written on the node.
    pub signature_id: String,
    /// Ids of every Active implementation the store listed for the signature.
    pub active_implementation_ids: Vec<String>,
    /// Implementation id reported as the selected one.
    pub chosen: String,
}
fn resolve_recursive<S>(
node: &mut CompositionNode,
store: &S,
report: &mut ResolutionReport,
) -> Result<(), ResolutionError>
where
S: StageStore + ?Sized,
{
match node {
CompositionNode::Stage { id, pinning, .. } => {
let before = id.0.clone();
if matches!(*pinning, Pinning::Signature) {
let sig = SignatureId(id.0.clone());
let matches = store.active_stages_with_signature(&sig);
if matches.len() > 1 {
report.warnings.push(MultiActiveWarning {
signature_id: id.0.clone(),
active_implementation_ids: matches.iter().map(|s| s.id.0.clone()).collect(),
chosen: matches[0].id.0.clone(),
});
}
}
let resolved = resolve_single(id, *pinning, store)?;
if resolved.0 != before {
report.rewrites.push(Rewrite {
before,
after: resolved.0.clone(),
pinning: *pinning,
});
*id = resolved;
}
Ok(())
}
CompositionNode::RemoteStage { .. } | CompositionNode::Const { .. } => Ok(()),
CompositionNode::Sequential { stages } => {
for s in stages {
resolve_recursive(s, store, report)?;
}
Ok(())
}
CompositionNode::Parallel { branches } => {
for b in branches.values_mut() {
resolve_recursive(b, store, report)?;
}
Ok(())
}
CompositionNode::Branch {
predicate,
if_true,
if_false,
} => {
resolve_recursive(predicate, store, report)?;
resolve_recursive(if_true, store, report)?;
resolve_recursive(if_false, store, report)?;
Ok(())
}
CompositionNode::Fanout { source, targets } => {
resolve_recursive(source, store, report)?;
for t in targets {
resolve_recursive(t, store, report)?;
}
Ok(())
}
CompositionNode::Merge { sources, target } => {
for s in sources {
resolve_recursive(s, store, report)?;
}
resolve_recursive(target, store, report)?;
Ok(())
}
CompositionNode::Retry { stage, .. } => resolve_recursive(stage, store, report),
CompositionNode::Let { bindings, body } => {
for b in bindings.values_mut() {
resolve_recursive(b, store, report)?;
}
resolve_recursive(body, store, report)
}
}
}
/// Maps one stage reference onto the id of an Active implementation.
///
/// * `Pinning::Signature`: `id` is looked up as a signature id first. If the
///   store has no match, `id` is tried as an implementation id (legacy flows)
///   and accepted only when that stage is Active.
/// * `Pinning::Both`: `id` must name a stored implementation whose lifecycle
///   is Active.
///
/// An `Err` from `StageStore::get` is not surfaced separately; it is reported
/// the same way as a missing stage.
fn resolve_single<S>(id: &StageId, pinning: Pinning, store: &S) -> Result<StageId, ResolutionError>
where
    S: StageStore + ?Sized,
{
    match pinning {
        Pinning::Signature => {
            let signature = SignatureId(id.0.clone());
            if let Some(found) = store.get_by_signature(&signature) {
                return Ok(found.id.clone());
            }
            match store.get(id) {
                // Legacy flows may still carry an implementation id under
                // signature pinning; accept it only if it is Active.
                Ok(Some(legacy)) if matches!(legacy.lifecycle, StageLifecycle::Active) => {
                    Ok(legacy.id.clone())
                }
                _ => Err(ResolutionError::SignatureNotFound {
                    signature_id: id.0.clone(),
                }),
            }
        }
        Pinning::Both => {
            let stage = match store.get(id) {
                Ok(Some(stage)) => stage,
                _ => {
                    return Err(ResolutionError::ImplementationNotFound {
                        implementation_id: id.0.clone(),
                    })
                }
            };
            if matches!(stage.lifecycle, StageLifecycle::Active) {
                Ok(stage.id.clone())
            } else {
                Err(ResolutionError::ImplementationNotActive {
                    implementation_id: id.0.clone(),
                    lifecycle: stage.lifecycle.clone(),
                })
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use noether_core::effects::EffectSet;
    use noether_core::stage::{CostEstimate, SignatureId, Stage, StageSignature};
    use noether_core::types::NType;
    use noether_store::MemoryStore;
    use std::collections::{BTreeMap, BTreeSet};

    /// Builds a minimal pure `Text -> Number` stage with the given
    /// implementation id, optional signature id and lifecycle.
    fn make_stage(impl_id: &str, sig_id: Option<&str>, lifecycle: StageLifecycle) -> Stage {
        Stage {
            id: StageId(impl_id.into()),
            signature_id: sig_id.map(|s| SignatureId(s.into())),
            signature: StageSignature {
                input: NType::Text,
                output: NType::Number,
                effects: EffectSet::pure(),
                implementation_hash: format!("impl_{impl_id}"),
            },
            capabilities: BTreeSet::new(),
            cost: CostEstimate {
                time_ms_p50: None,
                tokens_est: None,
                memory_mb: None,
            },
            description: "test".into(),
            examples: vec![],
            lifecycle,
            ed25519_signature: None,
            signer_public_key: None,
            implementation_code: None,
            implementation_language: None,
            ui_style: None,
            tags: vec![],
            aliases: vec![],
            name: None,
            properties: vec![],
        }
    }

    /// A store holding a single Active stage `impl_id` under signature `sig_id`.
    fn store_with_impl(impl_id: &str, sig_id: &str) -> MemoryStore {
        let mut store = MemoryStore::new();
        store
            .put(make_stage(impl_id, Some(sig_id), StageLifecycle::Active))
            .unwrap();
        store
    }

    #[test]
    fn signature_pinning_rewrites_to_impl_id() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let mut node = CompositionNode::Stage {
            id: StageId("sig_xyz".into()),
            pinning: Pinning::Signature,
            config: None,
        };
        let report = resolve_pinning(&mut node, &store).unwrap();
        match &node {
            CompositionNode::Stage { id, pinning, .. } => {
                assert_eq!(id.0, "impl_abc", "id should be rewritten to impl hash");
                assert_eq!(*pinning, Pinning::Signature);
            }
            _ => panic!("expected Stage"),
        }
        assert_eq!(report.rewrites.len(), 1);
        assert_eq!(report.rewrites[0].before, "sig_xyz");
        assert_eq!(report.rewrites[0].after, "impl_abc");
    }

    #[test]
    fn both_pinning_accepts_matching_impl_id() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let mut node = CompositionNode::Stage {
            id: StageId("impl_abc".into()),
            pinning: Pinning::Both,
            config: None,
        };
        let report = resolve_pinning(&mut node, &store).unwrap();
        assert!(report.rewrites.is_empty());
    }

    #[test]
    fn both_pinning_rejects_missing_impl() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let mut node = CompositionNode::Stage {
            id: StageId("impl_does_not_exist".into()),
            pinning: Pinning::Both,
            config: None,
        };
        let err = resolve_pinning(&mut node, &store).unwrap_err();
        assert!(matches!(
            err,
            ResolutionError::ImplementationNotFound { .. }
        ));
    }

    #[test]
    fn both_pinning_rejects_deprecated_impl() {
        let mut store = MemoryStore::new();
        store
            .put(make_stage(
                "impl_old",
                Some("sig_xyz"),
                StageLifecycle::Active,
            ))
            .unwrap();
        // Putting a second stage with the same signature deprecates the first;
        // the assertion below pins that precondition before exercising resolution.
        store
            .put(make_stage(
                "impl_new",
                Some("sig_xyz"),
                StageLifecycle::Active,
            ))
            .unwrap();
        assert!(matches!(
            store
                .get(&StageId("impl_old".into()))
                .unwrap()
                .unwrap()
                .lifecycle,
            StageLifecycle::Deprecated { .. }
        ));
        let mut node = CompositionNode::Stage {
            id: StageId("impl_old".into()),
            pinning: Pinning::Both,
            config: None,
        };
        let err = resolve_pinning(&mut node, &store).unwrap_err();
        assert!(matches!(
            err,
            ResolutionError::ImplementationNotActive { .. }
        ));
    }

    #[test]
    fn signature_pinning_rejects_missing_signature() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let mut node = CompositionNode::Stage {
            id: StageId("sig_does_not_exist".into()),
            pinning: Pinning::Signature,
            config: None,
        };
        let err = resolve_pinning(&mut node, &store).unwrap_err();
        assert!(matches!(err, ResolutionError::SignatureNotFound { .. }));
    }

    #[test]
    fn signature_pinning_falls_back_to_impl_id_for_legacy_flows() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let mut node = CompositionNode::Stage {
            id: StageId("impl_abc".into()),
            pinning: Pinning::Signature,
            config: None,
        };
        let report = resolve_pinning(&mut node, &store).unwrap();
        assert!(report.rewrites.is_empty());
    }

    #[test]
    fn walks_into_nested_sequential() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let mut node = CompositionNode::Sequential {
            stages: vec![
                CompositionNode::Stage {
                    id: StageId("sig_xyz".into()),
                    pinning: Pinning::Signature,
                    config: None,
                },
                CompositionNode::Stage {
                    id: StageId("sig_xyz".into()),
                    pinning: Pinning::Signature,
                    config: None,
                },
            ],
        };
        let report = resolve_pinning(&mut node, &store).unwrap();
        assert_eq!(report.rewrites.len(), 2);
    }

    #[test]
    fn walks_into_parallel_branches() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let mut branches = BTreeMap::new();
        branches.insert(
            "a".into(),
            CompositionNode::Stage {
                id: StageId("sig_xyz".into()),
                pinning: Pinning::Signature,
                config: None,
            },
        );
        branches.insert(
            "b".into(),
            CompositionNode::Stage {
                id: StageId("sig_xyz".into()),
                pinning: Pinning::Signature,
                config: None,
            },
        );
        let mut node = CompositionNode::Parallel { branches };
        let report = resolve_pinning(&mut node, &store).unwrap();
        assert_eq!(report.rewrites.len(), 2);
    }

    #[test]
    fn walks_into_branch_predicate_and_arms() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let sig = || CompositionNode::Stage {
            id: StageId("sig_xyz".into()),
            pinning: Pinning::Signature,
            config: None,
        };
        let mut node = CompositionNode::Branch {
            predicate: Box::new(sig()),
            if_true: Box::new(sig()),
            if_false: Box::new(sig()),
        };
        let report = resolve_pinning(&mut node, &store).unwrap();
        assert_eq!(report.rewrites.len(), 3);
    }

    #[test]
    fn walks_into_fanout_source_and_targets() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let sig = || CompositionNode::Stage {
            id: StageId("sig_xyz".into()),
            pinning: Pinning::Signature,
            config: None,
        };
        let mut node = CompositionNode::Fanout {
            source: Box::new(sig()),
            targets: vec![sig(), sig(), sig()],
        };
        let report = resolve_pinning(&mut node, &store).unwrap();
        assert_eq!(report.rewrites.len(), 4);
    }

    #[test]
    fn walks_into_merge_sources_and_target() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let sig = || CompositionNode::Stage {
            id: StageId("sig_xyz".into()),
            pinning: Pinning::Signature,
            config: None,
        };
        let mut node = CompositionNode::Merge {
            sources: vec![sig(), sig()],
            target: Box::new(sig()),
        };
        let report = resolve_pinning(&mut node, &store).unwrap();
        assert_eq!(report.rewrites.len(), 3);
    }

    #[test]
    fn walks_into_let_bindings_and_body() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let sig = || CompositionNode::Stage {
            id: StageId("sig_xyz".into()),
            pinning: Pinning::Signature,
            config: None,
        };
        let mut bindings = BTreeMap::new();
        bindings.insert("a".into(), sig());
        bindings.insert("b".into(), sig());
        let mut node = CompositionNode::Let {
            bindings,
            body: Box::new(sig()),
        };
        let report = resolve_pinning(&mut node, &store).unwrap();
        assert_eq!(report.rewrites.len(), 3);
    }

    #[test]
    fn walks_into_retry_inner_stage() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let mut node = CompositionNode::Retry {
            stage: Box::new(CompositionNode::Stage {
                id: StageId("sig_xyz".into()),
                pinning: Pinning::Signature,
                config: None,
            }),
            max_attempts: 3,
            delay_ms: None,
        };
        let report = resolve_pinning(&mut node, &store).unwrap();
        assert_eq!(report.rewrites.len(), 1);
    }

    #[test]
    fn stops_at_first_error_leaves_partial_rewrites() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let mut node = CompositionNode::Sequential {
            stages: vec![
                CompositionNode::Stage {
                    id: StageId("sig_xyz".into()),
                    pinning: Pinning::Signature,
                    config: None,
                },
                CompositionNode::Stage {
                    id: StageId("sig_missing".into()),
                    pinning: Pinning::Signature,
                    config: None,
                },
            ],
        };
        let err = resolve_pinning(&mut node, &store).unwrap_err();
        assert!(matches!(err, ResolutionError::SignatureNotFound { .. }));
        // The first stage was resolved before the failure and keeps its rewrite.
        match &node {
            CompositionNode::Sequential { stages } => match &stages[0] {
                CompositionNode::Stage { id, .. } => assert_eq!(id.0, "impl_abc"),
                _ => panic!(),
            },
            _ => panic!(),
        }
    }

    #[test]
    fn idempotent_on_already_resolved_graph() {
        let store = store_with_impl("impl_abc", "sig_xyz");
        let mut node = CompositionNode::Stage {
            id: StageId("sig_xyz".into()),
            pinning: Pinning::Signature,
            config: None,
        };
        let first = resolve_pinning(&mut node, &store).unwrap();
        let second = resolve_pinning(&mut node, &store).unwrap();
        assert_eq!(first.rewrites.len(), 1);
        assert!(second.rewrites.is_empty());
    }

    #[test]
    fn multi_active_signature_emits_warning() {
        let mut store = MemoryStore::new();
        store
            .put(make_stage(
                "impl_a",
                Some("shared_sig"),
                StageLifecycle::Active,
            ))
            .unwrap();
        // A second `put` with the same signature would deprecate `impl_a`, so
        // the second Active implementation is injected directly.
        let extra = make_stage("impl_b", Some("shared_sig"), StageLifecycle::Active);
        store.inject_raw_for_testing(extra);
        let mut node = CompositionNode::Stage {
            id: StageId("shared_sig".into()),
            pinning: Pinning::Signature,
            config: None,
        };
        let report = resolve_pinning(&mut node, &store).unwrap();
        assert_eq!(report.warnings.len(), 1);
        let w = &report.warnings[0];
        assert_eq!(w.signature_id, "shared_sig");
        let mut active = w.active_implementation_ids.clone();
        active.sort();
        assert_eq!(active, ["impl_a", "impl_b"]);
        assert!(w.active_implementation_ids.contains(&w.chosen));
        // The warning must name the implementation the node was actually
        // rewritten to, not merely the first one the store happened to list.
        match &node {
            CompositionNode::Stage { id, .. } => assert_eq!(id.0, w.chosen),
            _ => panic!("expected Stage"),
        }
    }
}