use crate::flow_dispatcher::pure_shape::{run_pure_shape, PureShapeStep};
use crate::flow_dispatcher::{DispatchCtx, DispatchError, NodeOutcome};
use crate::flow_execution_event::{now_ms, FlowExecutionEvent};
use crate::ir_nodes::{
IRAggregateStep, IRAssociateStep, IRCorroborateStep, IRExploreStep, IRFocusStep,
IRForgeBlock, IRIngestStep, IRNavigateStep, IRRecallStep, IRRememberStep,
};
/// Executes a `remember` step: resolves `node.expression` and stores the
/// result under `node.memory_target`.
///
/// The expression is first treated as the name of an existing `let` binding;
/// if no such binding exists, the expression text itself is stored verbatim.
/// The value is bound in `ctx.let_bindings` and, when a PEM backend is
/// configured, also written through to it. A `StepStart` / `StepComplete`
/// pair (type `"remember"`) brackets the work.
///
/// # Errors
/// - `DispatchError::UpstreamCancelled` if the context is already cancelled.
/// - `DispatchError::ChannelClosed` if an event cannot be sent.
/// - Any error from the PEM write-through (e.g. a failed persist).
pub async fn run_remember(
    node: &IRRememberStep,
    ctx: &mut DispatchCtx,
) -> Result<NodeOutcome, DispatchError> {
    if ctx.cancel.is_cancelled() {
        return Err(DispatchError::UpstreamCancelled);
    }

    let step_index = ctx.step_counter;
    ctx.step_counter += 1;

    // The expression doubles as a binding name; fall back to its literal text.
    let value = match ctx.let_bindings.get(&node.expression) {
        Some(bound) => bound.clone(),
        None => node.expression.clone(),
    };

    let step_name = step_name_for_remember(node);
    emit_step_start(ctx, &step_name, step_index, "remember")?;

    ctx.let_bindings
        .insert(node.memory_target.clone(), value.clone());

    if let Some(backend) = ctx.pem_backend.clone() {
        write_through_pem(&backend, ctx, &node.memory_target, &value).await?;
    }

    emit_step_complete(ctx, &step_name, step_index, &value, 0)?;

    Ok(NodeOutcome::Completed {
        output: value,
        tokens_emitted: 0,
        step_index,
    })
}
/// Display name for a `remember` step: the memory target, or `"Remember"`
/// when the target is empty.
fn step_name_for_remember(node: &IRRememberStep) -> String {
    match node.memory_target.as_str() {
        "" => "Remember".to_string(),
        target => target.to_string(),
    }
}
/// Appends a `key = value` entry to the session's PEM short-term memory and
/// persists the updated state.
///
/// The session state is restored from `backend` using `ctx.session_id`. If the
/// restore fails, a fresh `CognitiveState` is created for this session, tenant
/// and flow. The value is stored as a JSON string payload. The new entry's
/// `stored_at` and the state's `last_updated_at` share one timestamp. The state
/// is then persisted with a 24-hour duration (presumably a TTL; confirm against
/// `PersistenceBackend::persist`).
///
/// # Errors
/// Returns `DispatchError::BackendError { name: "pem", .. }` when `persist`
/// fails. The message is the backend error's `Debug` representation.
async fn write_through_pem(
    backend: &std::sync::Arc<dyn crate::pem::PersistenceBackend>,
    ctx: &DispatchCtx,
    key: &str,
    value: &str,
) -> Result<(), DispatchError> {
    use crate::pem::state::{CognitiveState, MemoryEntry};
    use chrono::{Duration as ChronoDuration, Utc};

    let mut state = match backend.restore(&ctx.session_id).await {
        Ok(existing) => existing,
        // NOTE(review): *every* restore error lands here, not only "no state
        // stored yet". After a transient backend failure, the persist below
        // would presumably replace the stored session with a state holding
        // only this entry. Consider matching a not-found error specifically
        // once the backend's error type is confirmed.
        Err(_) => CognitiveState::new(&ctx.session_id, &ctx.tenant_id, &ctx.flow_name),
    };

    // Read the clock once so the entry and the state agree on when the write happened.
    let now = Utc::now();
    state.short_term_memory.push(MemoryEntry {
        key: key.to_string(),
        payload: serde_json::Value::String(value.to_string()),
        symbolic_refs: Vec::new(),
        stored_at: now,
    });
    state.last_updated_at = now;

    backend
        .persist(&ctx.session_id, &state, ChronoDuration::hours(24))
        .await
        .map_err(|e| DispatchError::BackendError {
            name: "pem".to_string(),
            message: format!("{e:?}"),
        })?;
    Ok(())
}
/// Executes a `recall` step: looks up `node.memory_source` and binds the
/// result under `node.query`.
///
/// The PEM backend (when configured) is consulted first, followed by
/// `ctx.let_bindings`. A key found in neither place yields an empty string
/// rather than an error. A `StepStart` / `StepComplete` pair (type `"recall"`)
/// brackets the work.
///
/// # Errors
/// - `DispatchError::UpstreamCancelled` if the context is already cancelled.
/// - `DispatchError::ChannelClosed` if an event cannot be sent.
pub async fn run_recall(
    node: &IRRecallStep,
    ctx: &mut DispatchCtx,
) -> Result<NodeOutcome, DispatchError> {
    if ctx.cancel.is_cancelled() {
        return Err(DispatchError::UpstreamCancelled);
    }

    let step_index = ctx.step_counter;
    ctx.step_counter += 1;
    let step_name = step_name_for_recall(node);

    emit_step_start(ctx, &step_name, step_index, "recall")?;
    let recalled = resolve_recall_value(node, ctx).await;
    ctx.let_bindings.insert(node.query.clone(), recalled.clone());
    emit_step_complete(ctx, &step_name, step_index, &recalled, 0)?;

    Ok(NodeOutcome::Completed {
        output: recalled,
        tokens_emitted: 0,
        step_index,
    })
}
/// Display name for a `recall` step: the query (binding name), or `"Recall"`
/// when the query is empty.
fn step_name_for_recall(node: &IRRecallStep) -> String {
    match node.query.as_str() {
        "" => "Recall".to_string(),
        query => query.to_string(),
    }
}
/// Resolves the value a `recall` step should produce for `node.memory_source`.
///
/// Lookup order:
/// 1. When a PEM backend is set and the session restores successfully, the
///    most recently pushed short-term memory entry with a matching key. A JSON
///    string payload is returned unquoted; any other payload is returned as its
///    JSON text.
/// 2. Otherwise, the `let` binding of the same name.
/// 3. Otherwise, an empty string.
///
/// A failed restore is not an error here; it just falls through to step 2.
async fn resolve_recall_value(node: &IRRecallStep, ctx: &DispatchCtx) -> String {
    let from_pem: Option<String> = match &ctx.pem_backend {
        Some(backend) => match backend.restore(&ctx.session_id).await {
            Ok(state) => state
                .short_term_memory
                .iter()
                // Newest entries are pushed last, so scan from the back.
                .rev()
                .find(|entry| entry.key == node.memory_source)
                .map(|entry| match &entry.payload {
                    serde_json::Value::String(text) => text.clone(),
                    other => other.to_string(),
                }),
            Err(_) => None,
        },
        None => None,
    };

    from_pem.unwrap_or_else(|| {
        ctx.let_bindings
            .get(&node.memory_source)
            .cloned()
            .unwrap_or_default()
    })
}
/// Executes a `forge` block.
///
/// The node's contents are not inspected. The handler only claims a step index
/// and emits a `StepStart` / `StepComplete` pair named `"Forge"` (type
/// `"forge"`) with empty output and zero tokens.
///
/// # Errors
/// - `DispatchError::UpstreamCancelled` if the context is already cancelled.
/// - `DispatchError::ChannelClosed` if an event cannot be sent.
pub async fn run_forge(
    _node: &IRForgeBlock,
    ctx: &mut DispatchCtx,
) -> Result<NodeOutcome, DispatchError> {
    if ctx.cancel.is_cancelled() {
        return Err(DispatchError::UpstreamCancelled);
    }

    let step_index = ctx.step_counter;
    ctx.step_counter += 1;

    let output = String::new();
    emit_step_start(ctx, "Forge", step_index, "forge")?;
    emit_step_complete(ctx, "Forge", step_index, &output, 0)?;

    Ok(NodeOutcome::Completed {
        output,
        tokens_emitted: 0,
        step_index,
    })
}
/// Executes a `focus` step as a tool-less pure-shape prompt aimed at
/// `node.expression`.
///
/// The step is named after the expression, or `"Focus"` when it is empty.
/// Execution, including cancellation handling, is delegated to
/// `run_pure_shape`.
pub async fn run_focus(
    node: &IRFocusStep,
    ctx: &mut DispatchCtx,
) -> Result<NodeOutcome, DispatchError> {
    let target = &node.expression;
    let name = if target.is_empty() {
        String::from("Focus")
    } else {
        target.clone()
    };
    let user_prompt = format!("Focus on: {target}");

    run_pure_shape(
        PureShapeStep {
            name,
            user_prompt,
            framing_addendum: Some(
                "You are focusing your attention. Narrow scope to the target; surface what matters most.".into(),
            ),
            kind_slug: "focus",
            tools: Vec::new(),
        },
        ctx,
    )
    .await
}
/// Executes an `associate` step as a tool-less pure-shape prompt linking
/// `node.left` to `node.right`, optionally via `node.using_field`.
///
/// The step is named `"left↔right"`. When `node.left` is empty it is named
/// `"Associate"`; only `left` is checked for this. Execution is delegated to
/// `run_pure_shape`.
pub async fn run_associate(
    node: &IRAssociateStep,
    ctx: &mut DispatchCtx,
) -> Result<NodeOutcome, DispatchError> {
    let (left, right) = (&node.left, &node.right);

    let name = if left.is_empty() {
        "Associate".to_string()
    } else {
        format!("{left}↔{right}")
    };
    let using_clause = match node.using_field.as_str() {
        "" => String::new(),
        field => format!(" using `{field}`"),
    };
    let user_prompt = format!("Associate {left} with {right}{using_clause}");

    run_pure_shape(
        PureShapeStep {
            name,
            user_prompt,
            framing_addendum: Some(
                "You are associating. Find the meaningful relationship; return a structured link.".into(),
            ),
            kind_slug: "associate",
            tools: Vec::new(),
        },
        ctx,
    )
    .await
}
/// Executes an `aggregate` step as a tool-less pure-shape prompt over
/// `node.target`.
///
/// The `group_by` dimensions (comma-joined in brackets) and the `alias` are
/// appended to the prompt only when non-empty. The step is named after the
/// target, or `"Aggregate"` when it is empty. Execution is delegated to
/// `run_pure_shape`.
pub async fn run_aggregate(
    node: &IRAggregateStep,
    ctx: &mut DispatchCtx,
) -> Result<NodeOutcome, DispatchError> {
    let group_clause = match node.group_by.as_slice() {
        [] => String::new(),
        dims => format!(" grouped by [{}]", dims.join(", ")),
    };
    let alias_clause = match node.alias.as_str() {
        "" => String::new(),
        alias => format!(" as `{alias}`"),
    };

    let target = &node.target;
    let name = if target.is_empty() {
        "Aggregate".to_string()
    } else {
        target.clone()
    };
    let user_prompt = format!("Aggregate {target}{group_clause}{alias_clause}");

    run_pure_shape(
        PureShapeStep {
            name,
            user_prompt,
            framing_addendum: Some(
                "You are aggregating. Group + summarize over the declared dimensions; surface the structure.".into(),
            ),
            kind_slug: "aggregate",
            tools: Vec::new(),
        },
        ctx,
    )
    .await
}
/// Executes an `explore` step as a tool-less pure-shape prompt over
/// `node.target`.
///
/// A ` (top N)` suffix is appended when `node.limit` is set. The step is named
/// after the target, or `"Explore"` when it is empty. Execution is delegated to
/// `run_pure_shape`.
pub async fn run_explore(
    node: &IRExploreStep,
    ctx: &mut DispatchCtx,
) -> Result<NodeOutcome, DispatchError> {
    let limit_clause = node
        .limit
        .as_ref()
        .map(|n| format!(" (top {n})"))
        .unwrap_or_default();

    let target = &node.target;
    let name = if target.is_empty() {
        "Explore".to_string()
    } else {
        target.clone()
    };
    let user_prompt = format!("Explore: {target}{limit_clause}");

    run_pure_shape(
        PureShapeStep {
            name,
            user_prompt,
            framing_addendum: Some(
                "You are exploring. Sample broadly; surface the most-relevant directions.".into(),
            ),
            kind_slug: "explore",
            tools: Vec::new(),
        },
        ctx,
    )
    .await
}
/// Executes an `ingest` step as a tool-less pure-shape prompt mapping
/// `node.source` into `node.target`.
///
/// The step is named after the target, or `"Ingest"` when it is empty.
/// Execution is delegated to `run_pure_shape`.
pub async fn run_ingest(
    node: &IRIngestStep,
    ctx: &mut DispatchCtx,
) -> Result<NodeOutcome, DispatchError> {
    let (source, target) = (&node.source, &node.target);

    let name = if target.is_empty() {
        "Ingest".to_string()
    } else {
        target.clone()
    };
    let user_prompt = format!("Ingest from `{source}` into `{target}`");

    run_pure_shape(
        PureShapeStep {
            name,
            user_prompt,
            framing_addendum: Some(
                "You are ingesting. Map the source's structure into the target; preserve fidelity.".into(),
            ),
            kind_slug: "ingest",
            tools: Vec::new(),
        },
        ctx,
    )
    .await
}
/// Executes a `navigate` step as a tool-less pure-shape prompt that walks
/// `node.corpus_ref` via the PIX `node.pix_ref` for `node.query`.
///
/// A ` (with trail)` suffix is appended when `node.trail_enabled` is set. The
/// step is named after `node.output_name`, or `"Navigate"` when it is empty.
/// Execution is delegated to `run_pure_shape`.
pub async fn run_navigate(
    node: &IRNavigateStep,
    ctx: &mut DispatchCtx,
) -> Result<NodeOutcome, DispatchError> {
    let corpus = &node.corpus_ref;
    let pix = &node.pix_ref;
    let query = &node.query;
    let trail_clause = if !node.trail_enabled { "" } else { " (with trail)" };

    let name = if node.output_name.is_empty() {
        "Navigate".to_string()
    } else {
        node.output_name.clone()
    };
    let user_prompt =
        format!("Navigate corpus `{corpus}` via PIX `{pix}` for query: {query}{trail_clause}");

    run_pure_shape(
        PureShapeStep {
            name,
            user_prompt,
            framing_addendum: Some(
                "You are navigating a PIX (paper §6 hidden state). Trace your reasoning path; surface the corpus regions you crossed.".into(),
            ),
            kind_slug: "navigate",
            tools: Vec::new(),
        },
        ctx,
    )
    .await
}
/// Executes a `corroborate` step as a tool-less pure-shape prompt that
/// cross-checks the navigation result named by `node.navigate_ref`.
///
/// The step is named after `node.output_name`, or `"Corroborate"` when it is
/// empty. Execution is delegated to `run_pure_shape`.
pub async fn run_corroborate(
    node: &IRCorroborateStep,
    ctx: &mut DispatchCtx,
) -> Result<NodeOutcome, DispatchError> {
    let nav_ref = &node.navigate_ref;

    let name = if node.output_name.is_empty() {
        "Corroborate".to_string()
    } else {
        node.output_name.clone()
    };
    let user_prompt = format!("Corroborate navigation result `{nav_ref}`");

    run_pure_shape(
        PureShapeStep {
            name,
            user_prompt,
            framing_addendum: Some(
                "You are corroborating. Cross-validate independently; surface agreement strength + disagreements.".into(),
            ),
            kind_slug: "corroborate",
            tools: Vec::new(),
        },
        ctx,
    )
    .await
}
/// Sends a `FlowExecutionEvent::StepStart` for the given step on `ctx.tx`,
/// stamped with the current `now_ms()`.
///
/// # Errors
/// Returns `DispatchError::ChannelClosed` if the event channel rejects the send.
fn emit_step_start(
    ctx: &mut DispatchCtx,
    step_name: &str,
    step_index: usize,
    step_type: &str,
) -> Result<(), DispatchError> {
    let event = FlowExecutionEvent::StepStart {
        step_name: step_name.to_owned(),
        step_index,
        step_type: step_type.to_owned(),
        timestamp_ms: now_ms(),
    };
    ctx.tx.send(event).map_err(|_closed| DispatchError::ChannelClosed)
}
/// Sends a successful `FlowExecutionEvent::StepComplete` for the given step on
/// `ctx.tx`, stamped with the current `now_ms()`.
///
/// `success` is always `true` and `tokens_input` is always `0`; only the output
/// token count is caller-supplied.
///
/// # Errors
/// Returns `DispatchError::ChannelClosed` if the event channel rejects the send.
fn emit_step_complete(
    ctx: &mut DispatchCtx,
    step_name: &str,
    step_index: usize,
    full_output: &str,
    tokens_output: u64,
) -> Result<(), DispatchError> {
    let event = FlowExecutionEvent::StepComplete {
        step_name: step_name.to_owned(),
        step_index,
        success: true,
        full_output: full_output.to_owned(),
        tokens_input: 0,
        tokens_output,
        timestamp_ms: now_ms(),
    };
    ctx.tx.send(event).map_err(|_closed| DispatchError::ChannelClosed)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cancel_token::CancellationFlag;
    use crate::ir_nodes::*;
    use crate::pem::InMemoryBackend;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    /// Builds a dispatch context without a PEM backend, plus the receiving end
    /// of its event channel. Callers keep the receiver alive so event sends do
    /// not fail with `ChannelClosed`.
    fn fresh_ctx() -> (DispatchCtx, mpsc::UnboundedReceiver<FlowExecutionEvent>) {
        let (tx, rx) = mpsc::unbounded_channel();
        let ctx = DispatchCtx::new("TestFlow", "stub", "", CancellationFlag::new(), tx);
        (ctx, rx)
    }

    #[tokio::test]
    async fn run_remember_literal_value_binds_to_let_bindings() {
        let (mut ctx, _rx) = fresh_ctx();
        let node = IRRememberStep {
            node_type: "remember",
            source_line: 0,
            source_column: 0,
            expression: "us-east-1".into(),
            memory_target: "region".into(),
        };
        let outcome = run_remember(&node, &mut ctx).await.unwrap();
        match outcome {
            NodeOutcome::Completed { output, tokens_emitted, .. } => {
                assert_eq!(output, "us-east-1");
                assert_eq!(tokens_emitted, 0);
            }
            other => panic!("expected Completed, got {other:?}"),
        }
        assert_eq!(ctx.let_bindings.get("region").unwrap(), "us-east-1");
    }

    #[tokio::test]
    async fn run_remember_resolves_expression_through_let_bindings() {
        let (mut ctx, _rx) = fresh_ctx();
        ctx.let_bindings.insert("upstream".into(), "computed-X".into());
        let node = IRRememberStep {
            node_type: "remember",
            source_line: 0,
            source_column: 0,
            expression: "upstream".into(),
            memory_target: "snapshot".into(),
        };
        run_remember(&node, &mut ctx).await.unwrap();
        assert_eq!(ctx.let_bindings.get("snapshot").unwrap(), "computed-X");
    }

    #[tokio::test]
    async fn run_remember_with_pem_persists_to_backend() {
        let backend: Arc<dyn crate::pem::PersistenceBackend> =
            Arc::new(InMemoryBackend::default());
        let (tx, _rx) = mpsc::unbounded_channel();
        let mut ctx = DispatchCtx::new("F", "stub", "", CancellationFlag::new(), tx)
            .with_pem(backend.clone())
            .with_session_id("session-1");
        let node = IRRememberStep {
            node_type: "remember",
            source_line: 0,
            source_column: 0,
            expression: "persisted-value".into(),
            memory_target: "key1".into(),
        };
        run_remember(&node, &mut ctx).await.unwrap();
        let state = backend.restore("session-1").await.unwrap();
        assert_eq!(state.short_term_memory.len(), 1);
        assert_eq!(state.short_term_memory[0].key, "key1");
        // The value itself must be persisted, stored as a JSON string payload.
        assert_eq!(
            state.short_term_memory[0].payload,
            serde_json::Value::String("persisted-value".into())
        );
    }

    #[tokio::test]
    async fn run_recall_from_let_bindings_when_no_pem() {
        let (mut ctx, _rx) = fresh_ctx();
        ctx.let_bindings.insert("region".into(), "us-east-1".into());
        let node = IRRecallStep {
            node_type: "recall",
            source_line: 0,
            source_column: 0,
            query: "current_region".into(),
            memory_source: "region".into(),
        };
        let outcome = run_recall(&node, &mut ctx).await.unwrap();
        match outcome {
            NodeOutcome::Completed { output, .. } => {
                assert_eq!(output, "us-east-1");
            }
            other => panic!("expected Completed, got {other:?}"),
        }
        assert_eq!(ctx.let_bindings.get("current_region").unwrap(), "us-east-1");
    }

    #[tokio::test]
    async fn run_recall_from_pem_when_backend_set() {
        let backend: Arc<dyn crate::pem::PersistenceBackend> =
            Arc::new(InMemoryBackend::default());
        let (tx, _rx) = mpsc::unbounded_channel();
        let mut ctx = DispatchCtx::new("F", "stub", "", CancellationFlag::new(), tx)
            .with_pem(backend.clone())
            .with_session_id("sess");
        run_remember(
            &IRRememberStep {
                node_type: "remember",
                source_line: 0,
                source_column: 0,
                expression: "value-from-pem".into(),
                memory_target: "pem_key".into(),
            },
            &mut ctx,
        )
        .await
        .unwrap();
        // run_remember also bound "pem_key" locally. Overwrite that binding
        // with a different value so the assertion below can only pass if the
        // value actually came from the PEM backend.
        ctx.let_bindings
            .insert("pem_key".into(), "stale-local-value".into());
        let outcome = run_recall(
            &IRRecallStep {
                node_type: "recall",
                source_line: 0,
                source_column: 0,
                query: "recalled".into(),
                memory_source: "pem_key".into(),
            },
            &mut ctx,
        )
        .await
        .unwrap();
        match outcome {
            NodeOutcome::Completed { output, .. } => {
                assert_eq!(output, "value-from-pem");
            }
            other => panic!("expected Completed, got {other:?}"),
        }
        assert_eq!(ctx.let_bindings.get("recalled").unwrap(), "value-from-pem");
    }

    #[tokio::test]
    async fn run_recall_missing_key_returns_empty_string() {
        let (mut ctx, _rx) = fresh_ctx();
        let node = IRRecallStep {
            node_type: "recall",
            source_line: 0,
            source_column: 0,
            query: "x".into(),
            memory_source: "never_set".into(),
        };
        let outcome = run_recall(&node, &mut ctx).await.unwrap();
        match outcome {
            NodeOutcome::Completed { output, .. } => assert_eq!(output, ""),
            other => panic!("expected Completed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn run_forge_emits_canonical_wire_shape() {
        let (mut ctx, mut rx) = fresh_ctx();
        let node = IRForgeBlock {
            node_type: "forge",
            source_line: 0,
            source_column: 0,
        };
        let outcome = run_forge(&node, &mut ctx).await.unwrap();
        let outcome_index = match outcome {
            NodeOutcome::Completed { output, tokens_emitted, step_index, .. } => {
                assert_eq!(output, "");
                assert_eq!(tokens_emitted, 0);
                step_index
            }
            other => panic!("expected Completed, got {other:?}"),
        };
        let mut events = Vec::new();
        while let Ok(ev) = rx.try_recv() {
            events.push(ev);
        }
        assert_eq!(events.len(), 2);
        match &events[0] {
            FlowExecutionEvent::StepStart { step_type, step_index, .. } => {
                assert_eq!(step_type, "forge");
                assert_eq!(*step_index, outcome_index);
            }
            e => panic!("expected StepStart, got {e:?}"),
        }
        // The closing event must pair with the start: same index, success, empty output.
        match &events[1] {
            FlowExecutionEvent::StepComplete { step_index, success, full_output, .. } => {
                assert_eq!(*step_index, outcome_index);
                assert!(*success);
                assert_eq!(full_output, "");
            }
            e => panic!("expected StepComplete, got {e:?}"),
        }
    }

    #[tokio::test]
    async fn run_focus_emits_focus_slug() {
        let (mut ctx, mut rx) = fresh_ctx();
        let node = IRFocusStep {
            node_type: "focus",
            source_line: 0,
            source_column: 0,
            expression: "key_insight".into(),
        };
        let _ = run_focus(&node, &mut ctx).await.unwrap();
        let ev = rx.try_recv().unwrap();
        match ev {
            FlowExecutionEvent::StepStart { step_type, .. } => {
                assert_eq!(step_type, "focus");
            }
            e => panic!("expected StepStart, got {e:?}"),
        }
    }

    #[tokio::test]
    async fn run_associate_emits_associate_slug() {
        let (mut ctx, mut rx) = fresh_ctx();
        let node = IRAssociateStep {
            node_type: "associate",
            source_line: 0,
            source_column: 0,
            left: "A".into(),
            right: "B".into(),
            using_field: "id".into(),
        };
        run_associate(&node, &mut ctx).await.unwrap();
        let ev = rx.try_recv().unwrap();
        match ev {
            FlowExecutionEvent::StepStart { step_type, .. } => {
                assert_eq!(step_type, "associate");
            }
            e => panic!("expected StepStart, got {e:?}"),
        }
    }

    #[tokio::test]
    async fn run_aggregate_emits_aggregate_slug() {
        let (mut ctx, mut rx) = fresh_ctx();
        let node = IRAggregateStep {
            node_type: "aggregate",
            source_line: 0,
            source_column: 0,
            target: "events".into(),
            group_by: vec!["region".into()],
            alias: "by_region".into(),
        };
        run_aggregate(&node, &mut ctx).await.unwrap();
        let ev = rx.try_recv().unwrap();
        match ev {
            FlowExecutionEvent::StepStart { step_type, .. } => {
                assert_eq!(step_type, "aggregate");
            }
            e => panic!("expected StepStart, got {e:?}"),
        }
    }

    #[tokio::test]
    async fn run_explore_emits_explore_slug() {
        let (mut ctx, mut rx) = fresh_ctx();
        let node = IRExploreStep {
            node_type: "explore",
            source_line: 0,
            source_column: 0,
            target: "hypothesis_space".into(),
            limit: Some(5),
        };
        run_explore(&node, &mut ctx).await.unwrap();
        let ev = rx.try_recv().unwrap();
        match ev {
            FlowExecutionEvent::StepStart { step_type, .. } => {
                assert_eq!(step_type, "explore");
            }
            e => panic!("expected StepStart, got {e:?}"),
        }
    }

    #[tokio::test]
    async fn run_ingest_emits_ingest_slug() {
        let (mut ctx, mut rx) = fresh_ctx();
        let node = IRIngestStep {
            node_type: "ingest",
            source_line: 0,
            source_column: 0,
            source: "external_api".into(),
            target: "raw".into(),
        };
        run_ingest(&node, &mut ctx).await.unwrap();
        let ev = rx.try_recv().unwrap();
        match ev {
            FlowExecutionEvent::StepStart { step_type, .. } => {
                assert_eq!(step_type, "ingest");
            }
            e => panic!("expected StepStart, got {e:?}"),
        }
    }

    #[tokio::test]
    async fn run_navigate_emits_navigate_slug() {
        let (mut ctx, mut rx) = fresh_ctx();
        let node = IRNavigateStep {
            node_type: "navigate",
            source_line: 0,
            source_column: 0,
            pix_ref: "main_pix".into(),
            corpus_ref: "law_corpus".into(),
            query: "interpret_clause".into(),
            trail_enabled: true,
            output_name: "nav_result".into(),
        };
        run_navigate(&node, &mut ctx).await.unwrap();
        let ev = rx.try_recv().unwrap();
        match ev {
            FlowExecutionEvent::StepStart { step_type, .. } => {
                assert_eq!(step_type, "navigate");
            }
            e => panic!("expected StepStart, got {e:?}"),
        }
    }

    #[tokio::test]
    async fn run_corroborate_emits_corroborate_slug() {
        let (mut ctx, mut rx) = fresh_ctx();
        let node = IRCorroborateStep {
            node_type: "corroborate",
            source_line: 0,
            source_column: 0,
            navigate_ref: "nav_result".into(),
            output_name: "validated".into(),
        };
        run_corroborate(&node, &mut ctx).await.unwrap();
        let ev = rx.try_recv().unwrap();
        match ev {
            FlowExecutionEvent::StepStart { step_type, .. } => {
                assert_eq!(step_type, "corroborate");
            }
            e => panic!("expected StepStart, got {e:?}"),
        }
    }

    #[tokio::test]
    async fn every_cognitive_handler_short_circuits_on_cancel() {
        let cancel = CancellationFlag::new();
        cancel.cancel();
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut ctx = DispatchCtx::new("F", "stub", "", cancel, tx);
        let counter_before = ctx.step_counter;

        let r = IRRememberStep {
            node_type: "remember",
            source_line: 0,
            source_column: 0,
            expression: "x".into(),
            memory_target: "y".into(),
        };
        assert!(matches!(
            run_remember(&r, &mut ctx).await,
            Err(DispatchError::UpstreamCancelled)
        ));
        let r = IRRecallStep {
            node_type: "recall",
            source_line: 0,
            source_column: 0,
            query: "q".into(),
            memory_source: "k".into(),
        };
        assert!(matches!(
            run_recall(&r, &mut ctx).await,
            Err(DispatchError::UpstreamCancelled)
        ));
        assert!(matches!(
            run_forge(
                &IRForgeBlock {
                    node_type: "forge",
                    source_line: 0,
                    source_column: 0,
                },
                &mut ctx,
            )
            .await,
            Err(DispatchError::UpstreamCancelled)
        ));

        // The locally implemented handlers check cancellation before claiming a
        // step index or emitting any event.
        assert_eq!(ctx.step_counter, counter_before);
        assert!(rx.try_recv().is_err(), "no events expected after cancellation");

        // Pure-shape handlers all delegate cancellation to run_pure_shape.
        assert!(matches!(
            run_focus(
                &IRFocusStep {
                    node_type: "focus",
                    source_line: 0,
                    source_column: 0,
                    expression: "x".into(),
                },
                &mut ctx,
            )
            .await,
            Err(DispatchError::UpstreamCancelled)
        ));
        assert!(matches!(
            run_associate(
                &IRAssociateStep {
                    node_type: "associate",
                    source_line: 0,
                    source_column: 0,
                    left: "A".into(),
                    right: "B".into(),
                    using_field: String::new(),
                },
                &mut ctx,
            )
            .await,
            Err(DispatchError::UpstreamCancelled)
        ));
        assert!(matches!(
            run_aggregate(
                &IRAggregateStep {
                    node_type: "aggregate",
                    source_line: 0,
                    source_column: 0,
                    target: "t".into(),
                    group_by: Vec::new(),
                    alias: String::new(),
                },
                &mut ctx,
            )
            .await,
            Err(DispatchError::UpstreamCancelled)
        ));
        assert!(matches!(
            run_explore(
                &IRExploreStep {
                    node_type: "explore",
                    source_line: 0,
                    source_column: 0,
                    target: "t".into(),
                    limit: None,
                },
                &mut ctx,
            )
            .await,
            Err(DispatchError::UpstreamCancelled)
        ));
        assert!(matches!(
            run_ingest(
                &IRIngestStep {
                    node_type: "ingest",
                    source_line: 0,
                    source_column: 0,
                    source: "s".into(),
                    target: "t".into(),
                },
                &mut ctx,
            )
            .await,
            Err(DispatchError::UpstreamCancelled)
        ));
        assert!(matches!(
            run_navigate(
                &IRNavigateStep {
                    node_type: "navigate",
                    source_line: 0,
                    source_column: 0,
                    pix_ref: "p".into(),
                    corpus_ref: "c".into(),
                    query: "q".into(),
                    trail_enabled: false,
                    output_name: "o".into(),
                },
                &mut ctx,
            )
            .await,
            Err(DispatchError::UpstreamCancelled)
        ));
        assert!(matches!(
            run_corroborate(
                &IRCorroborateStep {
                    node_type: "corroborate",
                    source_line: 0,
                    source_column: 0,
                    navigate_ref: "n".into(),
                    output_name: "o".into(),
                },
                &mut ctx,
            )
            .await,
            Err(DispatchError::UpstreamCancelled)
        ));
    }
}