use noether_core::stage::StageId;
use noether_core::types::NType;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// How the `id` of a [`CompositionNode::Stage`] is resolved against a stage store.
///
/// - `Signature` (the default): `id` is first tried as a signature id; see
///   `resolve_stage_ref` for the fallback rules.
/// - `Both`: `id` is looked up directly as a stage id, with no lifecycle check.
///
/// Serialized in lowercase: `"signature"` / `"both"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Pinning {
    /// Resolve by signature (default; omitted from serialized `Stage` nodes).
    #[default]
    Signature,
    /// Resolve the exact stage id.
    Both,
}
impl Pinning {
    /// Returns `true` only for [`Pinning::Signature`], the default mode.
    ///
    /// Takes `&self` rather than `self` because serde's `skip_serializing_if`
    /// invokes it with a reference; that attribute is what keeps the default
    /// pinning out of serialized `Stage` nodes.
    pub fn is_signature(&self) -> bool {
        *self == Pinning::Signature
    }
}
/// A node in a composition graph.
///
/// Serialized as an internally tagged JSON object: the variant name goes under
/// the `"op"` key and the variant's fields sit next to it, e.g.
/// `{"op": "Stage", "id": "abc123"}`.
///
/// This enum only describes graph structure. The run-time behavior of each
/// operator is not defined in this file (presumably in the executor — confirm).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "op")]
pub enum CompositionNode {
    /// A reference to a stage in the stage store.
    ///
    /// How `id` is interpreted depends on `pinning`; see `resolve_stage_ref`.
    Stage {
        id: StageId,
        /// Omitted from JSON when it is the default `Signature`, and defaulted
        /// when absent, so graphs serialized without this key still load.
        #[serde(default, skip_serializing_if = "Pinning::is_signature")]
        pinning: Pinning,
        /// Optional per-node key/value settings; omitted from JSON when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        config: Option<BTreeMap<String, serde_json::Value>>,
    },
    /// A stage addressed by `url` instead of a store `StageId`, with its
    /// `input` / `output` types declared inline. Ignored by `collect_stage_ids`.
    RemoteStage {
        url: String,
        input: NType,
        output: NType,
    },
    /// A literal JSON value.
    Const { value: serde_json::Value },
    /// Child nodes in list order (presumably run one after another — confirm).
    Sequential { stages: Vec<CompositionNode> },
    /// Child nodes keyed by branch name. `BTreeMap` keeps the names sorted, so
    /// serialization and traversal order are deterministic.
    Parallel {
        branches: BTreeMap<String, CompositionNode>,
    },
    /// A `predicate` node plus the `if_true` and `if_false` alternatives.
    Branch {
        predicate: Box<CompositionNode>,
        if_true: Box<CompositionNode>,
        if_false: Box<CompositionNode>,
    },
    /// One `source` node and a list of `targets`.
    Fanout {
        source: Box<CompositionNode>,
        targets: Vec<CompositionNode>,
    },
    /// A list of `sources` and a single `target` node.
    Merge {
        sources: Vec<CompositionNode>,
        target: Box<CompositionNode>,
    },
    /// Wraps `stage` with a `max_attempts` count and an optional `delay_ms`.
    ///
    /// Unlike `config` on `Stage`, `delay_ms` has no `skip_serializing_if`, so
    /// `None` is written to JSON as `null`.
    Retry {
        stage: Box<CompositionNode>,
        max_attempts: u32,
        delay_ms: Option<u64>,
    },
    /// Named `bindings` (kept sorted by name) plus a `body` node.
    Let {
        bindings: BTreeMap<String, CompositionNode>,
        body: Box<CompositionNode>,
    },
}
impl CompositionNode {
pub fn stage(id: impl Into<String>) -> Self {
Self::Stage {
id: StageId(id.into()),
pinning: Pinning::Signature,
config: None,
}
}
pub fn stage_pinned(id: impl Into<String>) -> Self {
Self::Stage {
id: StageId(id.into()),
pinning: Pinning::Both,
config: None,
}
}
}
/// Resolves a `Stage` node's `id` to a stage in `store`, honouring `pinning`.
///
/// - [`Pinning::Signature`]: `id` is wrapped as a `SignatureId` and looked up with
///   `get_by_signature`. Only if that finds nothing, `id` is retried as a plain
///   stage id via `get`; that fallback result is accepted only when its lifecycle
///   is `StageLifecycle::Active`.
/// - [`Pinning::Both`]: `id` is looked up via `get` with no lifecycle check.
///
/// Returns `None` when nothing matches. An error from `get` is discarded and also
/// yields `None`. NOTE(review): whether `get_by_signature` itself filters by
/// lifecycle is not visible here.
pub fn resolve_stage_ref<'a, S>(
    id: &StageId,
    pinning: Pinning,
    store: &'a S,
) -> Option<&'a noether_core::stage::Stage>
where
    S: noether_store::StageStore + ?Sized,
{
    use noether_core::stage::{SignatureId, StageLifecycle};
    match pinning {
        // Exact lookup: any lifecycle is acceptable, store errors collapse to None.
        Pinning::Both => store.get(id).ok().flatten(),
        Pinning::Signature => {
            let signature = SignatureId(id.0.clone());
            store.get_by_signature(&signature).or_else(|| {
                // Fallback: treat `id` as an implementation id, but only if it is Active.
                store
                    .get(id)
                    .ok()
                    .flatten()
                    .filter(|candidate| matches!(candidate.lifecycle, StageLifecycle::Active))
            })
        }
    }
}
/// A complete composition: a root [`CompositionNode`] plus metadata.
///
/// Serialized as a plain JSON object with `description`, `root` and `version` keys.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CompositionGraph {
    /// Free-form description of the graph.
    pub description: String,
    /// Top-level node of the graph.
    pub root: CompositionNode,
    /// Version string; [`CompositionGraph::new`] sets it to `"0.1.0"`.
    /// NOTE(review): whether this versions the graph format or the graph itself
    /// is not visible here.
    pub version: String,
}
impl CompositionGraph {
    /// Version string stamped on every graph built by [`CompositionGraph::new`].
    const INITIAL_VERSION: &str = "0.1.0";

    /// Wraps `root` in a graph with the given `description`.
    ///
    /// The `version` field is always initialised to `"0.1.0"`.
    pub fn new(description: impl Into<String>, root: CompositionNode) -> Self {
        let description: String = description.into();
        Self {
            root,
            description,
            version: Self::INITIAL_VERSION.to_owned(),
        }
    }
}
/// Returns references to every `Stage` id reachable from `node`.
///
/// Ids appear in depth-first traversal order and duplicates are kept.
/// `RemoteStage` and `Const` nodes contribute nothing.
pub fn collect_stage_ids(node: &CompositionNode) -> Vec<&StageId> {
    let mut found: Vec<&StageId> = Vec::new();
    collect_ids_recursive(node, &mut found);
    found
}
/// Appends the id of every [`CompositionNode::Stage`] under `node` to `ids`.
///
/// Walks the tree depth-first. Children are visited as follows:
/// - `Sequential` stages in list order.
/// - `Parallel` branches and `Let` bindings in key order (both are `BTreeMap`s).
/// - `Branch`: predicate, then `if_true`, then `if_false`.
/// - `Fanout`: source before targets.
/// - `Merge`: sources before target.
/// - `Let`: bindings before body.
///
/// Duplicate ids are kept. `RemoteStage` and `Const` nodes carry no `StageId` and
/// contribute nothing.
fn collect_ids_recursive<'a>(node: &'a CompositionNode, ids: &mut Vec<&'a StageId>) {
    match node {
        CompositionNode::Stage { id, .. } => ids.push(id),
        // Leaves without a local StageId: a remote stage is addressed by URL and a
        // constant is just a value. Listed explicitly (no `_`) so a new variant
        // forces this match to be revisited.
        CompositionNode::RemoteStage { .. } | CompositionNode::Const { .. } => {}
        CompositionNode::Sequential { stages } => {
            for child in stages {
                collect_ids_recursive(child, ids);
            }
        }
        CompositionNode::Parallel { branches } => {
            for child in branches.values() {
                collect_ids_recursive(child, ids);
            }
        }
        CompositionNode::Branch {
            predicate,
            if_true,
            if_false,
        } => {
            collect_ids_recursive(predicate, ids);
            collect_ids_recursive(if_true, ids);
            collect_ids_recursive(if_false, ids);
        }
        CompositionNode::Fanout { source, targets } => {
            collect_ids_recursive(source, ids);
            for child in targets {
                collect_ids_recursive(child, ids);
            }
        }
        CompositionNode::Merge { sources, target } => {
            for child in sources {
                collect_ids_recursive(child, ids);
            }
            collect_ids_recursive(target, ids);
        }
        CompositionNode::Retry { stage, .. } => {
            collect_ids_recursive(stage, ids);
        }
        CompositionNode::Let { bindings, body } => {
            for child in bindings.values() {
                collect_ids_recursive(child, ids);
            }
            collect_ids_recursive(body, ids);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Test helper: a `Stage` node with default pinning and no config.
    fn stage(id: &str) -> CompositionNode {
        CompositionNode::Stage {
            id: StageId(id.into()),
            pinning: Pinning::Signature,
            config: None,
        }
    }

    #[test]
    fn serde_stage_round_trip() {
        let node = stage("abc123");
        let json = serde_json::to_string(&node).unwrap();
        let parsed: CompositionNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node, parsed);
    }

    #[test]
    fn serde_sequential() {
        let node = CompositionNode::Sequential {
            stages: vec![stage("a"), stage("b"), stage("c")],
        };
        let json = serde_json::to_string_pretty(&node).unwrap();
        let parsed: CompositionNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node, parsed);
    }

    #[test]
    fn serde_parallel() {
        let mut branches = BTreeMap::new();
        branches.insert("left".into(), stage("a"));
        branches.insert("right".into(), stage("b"));
        let node = CompositionNode::Parallel { branches };
        let json = serde_json::to_string(&node).unwrap();
        let parsed: CompositionNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node, parsed);
    }

    #[test]
    fn serde_branch() {
        let node = CompositionNode::Branch {
            predicate: Box::new(stage("pred")),
            if_true: Box::new(stage("yes")),
            if_false: Box::new(stage("no")),
        };
        let json = serde_json::to_string(&node).unwrap();
        let parsed: CompositionNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node, parsed);
    }

    #[test]
    fn serde_retry() {
        let node = CompositionNode::Retry {
            stage: Box::new(stage("fallible")),
            max_attempts: 3,
            delay_ms: Some(500),
        };
        let json = serde_json::to_string(&node).unwrap();
        let parsed: CompositionNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node, parsed);
    }

    #[test]
    fn serde_full_graph() {
        let graph = CompositionGraph::new(
            "test pipeline",
            CompositionNode::Sequential {
                stages: vec![stage("parse"), stage("transform"), stage("output")],
            },
        );
        let json = serde_json::to_string_pretty(&graph).unwrap();
        let parsed: CompositionGraph = serde_json::from_str(&json).unwrap();
        assert_eq!(graph, parsed);
    }

    #[test]
    fn serde_nested_composition() {
        let node = CompositionNode::Sequential {
            stages: vec![
                stage("input"),
                CompositionNode::Retry {
                    stage: Box::new(CompositionNode::Sequential {
                        stages: vec![stage("a"), stage("b")],
                    }),
                    max_attempts: 2,
                    delay_ms: None,
                },
                stage("output"),
            ],
        };
        let json = serde_json::to_string(&node).unwrap();
        let parsed: CompositionNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node, parsed);
    }

    #[test]
    fn collect_stage_ids_finds_all() {
        let node = CompositionNode::Sequential {
            stages: vec![
                stage("a"),
                CompositionNode::Parallel {
                    branches: BTreeMap::from([("x".into(), stage("b")), ("y".into(), stage("c"))]),
                },
                stage("d"),
            ],
        };
        let ids = collect_stage_ids(&node);
        assert_eq!(ids.len(), 4);
    }

    #[test]
    fn json_format_is_tagged() {
        let node = stage("abc123");
        let v: serde_json::Value = serde_json::to_value(&node).unwrap();
        assert_eq!(v["op"], json!("Stage"));
        assert_eq!(v["id"], json!("abc123"));
    }

    #[test]
    fn default_pinning_omitted_from_json() {
        let node = stage("abc123");
        let v: serde_json::Value = serde_json::to_value(&node).unwrap();
        assert!(
            v.get("pinning").is_none(),
            "default Signature pinning should be omitted from JSON, got: {v}"
        );
    }

    #[test]
    fn both_pinning_serialises_explicitly() {
        let node = CompositionNode::stage_pinned("impl_abc");
        let v: serde_json::Value = serde_json::to_value(&node).unwrap();
        assert_eq!(v["pinning"], json!("both"));
    }

    #[test]
    fn legacy_graph_without_pinning_deserialises() {
        let legacy = json!({
            "op": "Stage",
            "id": "legacy_hash",
        });
        let parsed: CompositionNode = serde_json::from_value(legacy).unwrap();
        match parsed {
            CompositionNode::Stage { id, pinning, .. } => {
                assert_eq!(id.0, "legacy_hash");
                assert_eq!(pinning, Pinning::Signature);
            }
            _ => panic!("expected Stage variant"),
        }
    }

    #[test]
    fn explicit_both_pinning_deserialises() {
        let pinned = json!({
            "op": "Stage",
            "id": "impl_xyz",
            "pinning": "both",
        });
        let parsed: CompositionNode = serde_json::from_value(pinned).unwrap();
        match parsed {
            CompositionNode::Stage { pinning, .. } => {
                assert_eq!(pinning, Pinning::Both);
            }
            _ => panic!("expected Stage variant"),
        }
    }

    #[test]
    fn serde_remote_stage_round_trip() {
        let node = CompositionNode::RemoteStage {
            url: "http://localhost:8080".into(),
            input: NType::record([("count", NType::Number)]),
            output: NType::VNode,
        };
        let json = serde_json::to_string(&node).unwrap();
        let parsed: CompositionNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node, parsed);
    }

    #[test]
    fn remote_stage_json_shape() {
        let node = CompositionNode::RemoteStage {
            url: "http://api.example.com".into(),
            input: NType::Text,
            output: NType::Number,
        };
        let v: serde_json::Value = serde_json::to_value(&node).unwrap();
        assert_eq!(v["op"], json!("RemoteStage"));
        assert_eq!(v["url"], json!("http://api.example.com"));
        assert!(v["input"].is_object());
        assert!(v["output"].is_object());
    }

    #[test]
    fn collect_stage_ids_skips_remote_stage() {
        let node = CompositionNode::Sequential {
            stages: vec![
                stage("local-a"),
                CompositionNode::RemoteStage {
                    url: "http://remote".into(),
                    input: NType::Text,
                    output: NType::Text,
                },
                stage("local-b"),
            ],
        };
        let ids = collect_stage_ids(&node);
        assert_eq!(ids.len(), 2);
        assert_eq!(ids[0].0, "local-a");
        assert_eq!(ids[1].0, "local-b");
    }

    #[test]
    fn remote_stage_in_full_graph_serde() {
        let graph = CompositionGraph::new(
            "full-stack pipeline",
            CompositionNode::Sequential {
                stages: vec![
                    CompositionNode::RemoteStage {
                        url: "http://api:8080".into(),
                        input: NType::record([("query", NType::Text)]),
                        output: NType::List(Box::new(NType::Text)),
                    },
                    stage("render"),
                ],
            },
        );
        let json = serde_json::to_string_pretty(&graph).unwrap();
        let parsed: CompositionGraph = serde_json::from_str(&json).unwrap();
        assert_eq!(graph, parsed);
    }

    // Pins the default pinning mode and the predicate serde uses to omit it.
    #[test]
    fn pinning_default_is_signature() {
        assert_eq!(Pinning::default(), Pinning::Signature);
        assert!(Pinning::default().is_signature());
        assert!(!Pinning::Both.is_signature());
    }

    // The public constructors must agree with the hand-built test helper.
    #[test]
    fn stage_constructors_set_pinning_and_no_config() {
        assert_eq!(CompositionNode::stage("abc"), stage("abc"));
        match CompositionNode::stage_pinned("abc") {
            CompositionNode::Stage {
                id,
                pinning,
                config,
            } => {
                assert_eq!(id.0, "abc");
                assert_eq!(pinning, Pinning::Both);
                assert!(config.is_none());
            }
            _ => panic!("expected Stage variant"),
        }
    }

    // Const, Let, Fanout and Merge were previously never serialized in tests.
    #[test]
    fn serde_const_let_fanout_merge_round_trip() {
        let node = CompositionNode::Let {
            bindings: BTreeMap::from([(
                "greeting".into(),
                CompositionNode::Const {
                    value: json!({ "text": "hi", "flags": [true, false] }),
                },
            )]),
            body: Box::new(CompositionNode::Merge {
                sources: vec![CompositionNode::Fanout {
                    source: Box::new(stage("src")),
                    targets: vec![stage("t1"), stage("t2")],
                }],
                target: Box::new(stage("sink")),
            }),
        };
        let json = serde_json::to_string(&node).unwrap();
        let parsed: CompositionNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node, parsed);
    }

    // Pins the traversal order through every container variant.
    #[test]
    fn collect_stage_ids_visits_every_container_in_order() {
        let node = CompositionNode::Let {
            bindings: BTreeMap::from([
                ("b".into(), stage("bind-b")),
                ("a".into(), stage("bind-a")),
            ]),
            body: Box::new(CompositionNode::Sequential {
                stages: vec![
                    CompositionNode::Branch {
                        predicate: Box::new(stage("pred")),
                        if_true: Box::new(stage("yes")),
                        if_false: Box::new(stage("no")),
                    },
                    CompositionNode::Fanout {
                        source: Box::new(stage("fan-src")),
                        targets: vec![stage("fan-t1"), stage("fan-t2")],
                    },
                    CompositionNode::Merge {
                        sources: vec![stage("merge-s1"), stage("merge-s2")],
                        target: Box::new(stage("merge-t")),
                    },
                    CompositionNode::Retry {
                        stage: Box::new(stage("retried")),
                        max_attempts: 1,
                        delay_ms: None,
                    },
                    CompositionNode::Const { value: json!(null) },
                ],
            }),
        };
        let ids = collect_stage_ids(&node);
        let names: Vec<&str> = ids.iter().map(|id| id.0.as_str()).collect();
        assert_eq!(
            names,
            [
                "bind-a", "bind-b", "pred", "yes", "no", "fan-src", "fan-t1", "fan-t2",
                "merge-s1", "merge-s2", "merge-t", "retried",
            ]
        );
    }
}