1use super::{ExecutionError, StageExecutor};
2use crate::executor::pure_cache::PureStageCache;
3use crate::lagrange::CompositionNode;
4use crate::trace::{CompositionTrace, StageStatus, StageTrace, TraceStatus};
5use chrono::Utc;
6use noether_core::stage::StageId;
7use serde_json::Value;
8use sha2::{Digest, Sha256};
9use std::time::Instant;
10
11#[derive(Debug)]
13pub struct CompositionResult {
14 pub output: Value,
15 pub trace: CompositionTrace,
16 pub spent_cents: u64,
20}
21
22pub fn run_composition<E: StageExecutor + Sync>(
26 node: &CompositionNode,
27 input: &Value,
28 executor: &E,
29 composition_id: &str,
30) -> Result<CompositionResult, ExecutionError> {
31 run_composition_with_cache(node, input, executor, composition_id, None)
32}
33
34pub fn run_composition_with_cache<E: StageExecutor + Sync>(
36 node: &CompositionNode,
37 input: &Value,
38 executor: &E,
39 composition_id: &str,
40 cache: Option<&mut PureStageCache>,
41) -> Result<CompositionResult, ExecutionError> {
42 let start = Instant::now();
43 let mut stage_traces = Vec::new();
44 let mut step_counter = 0;
45
46 let mut owned_cache;
47 let cache_ref: &mut Option<&mut PureStageCache>;
48 let mut none_holder: Option<&mut PureStageCache> = None;
49
50 if let Some(c) = cache {
51 owned_cache = Some(c);
52 cache_ref = &mut owned_cache;
53 } else {
54 cache_ref = &mut none_holder;
55 }
56
57 let output = execute_node(
58 node,
59 input,
60 executor,
61 &mut stage_traces,
62 &mut step_counter,
63 cache_ref,
64 )?;
65
66 let duration_ms = start.elapsed().as_millis() as u64;
67 let has_failures = stage_traces
68 .iter()
69 .any(|t| matches!(t.status, StageStatus::Failed { .. }));
70
71 let trace = CompositionTrace {
72 composition_id: composition_id.into(),
73 started_at: Utc::now().to_rfc3339(),
74 duration_ms,
75 status: if has_failures {
76 TraceStatus::Failed
77 } else {
78 TraceStatus::Ok
79 },
80 stages: stage_traces,
81 security_events: Vec::new(),
82 warnings: Vec::new(),
83 };
84
85 Ok(CompositionResult {
86 output,
87 trace,
88 spent_cents: 0,
89 })
90}
91
92fn execute_node<E: StageExecutor + Sync>(
93 node: &CompositionNode,
94 input: &Value,
95 executor: &E,
96 traces: &mut Vec<StageTrace>,
97 step_counter: &mut usize,
98 cache: &mut Option<&mut PureStageCache>,
99) -> Result<Value, ExecutionError> {
100 match node {
101 CompositionNode::Stage {
102 id,
103 pinning: _, config,
105 } => {
106 let merged = if let Some(cfg) = config {
107 let mut obj = match input {
108 Value::Object(map) => map.clone(),
109 other => {
110 let mut m = serde_json::Map::new();
111 let data_key = [
112 "items", "text", "data", "input", "records", "train", "document",
113 "html", "csv", "json_str",
114 ]
115 .iter()
116 .find(|k| !cfg.contains_key(**k))
117 .unwrap_or(&"items");
118 m.insert(data_key.to_string(), other.clone());
119 m
120 }
121 };
122 for (k, v) in cfg {
123 obj.insert(k.clone(), v.clone());
124 }
125 Value::Object(obj)
126 } else {
127 input.clone()
128 };
129 execute_stage(id, &merged, executor, traces, step_counter, cache)
130 }
131 CompositionNode::Const { value } => Ok(value.clone()),
132 CompositionNode::Sequential { stages } => {
133 let mut current = input.clone();
134 for stage in stages {
135 current = execute_node(stage, ¤t, executor, traces, step_counter, cache)?;
136 }
137 Ok(current)
138 }
139 CompositionNode::Parallel { branches } => {
140 let branch_data: Vec<(&str, &CompositionNode, Value)> = branches
146 .iter()
147 .map(|(name, branch)| {
148 let branch_input = if let Value::Object(ref obj) = input {
149 obj.get(name).cloned().unwrap_or_else(|| input.clone())
150 } else {
151 input.clone()
152 };
153 (name.as_str(), branch, branch_input)
154 })
155 .collect();
156
157 let branch_results = std::thread::scope(|s| {
161 let handles: Vec<_> = branch_data
162 .iter()
163 .map(|(name, branch, branch_input)| {
164 s.spawn(move || {
165 let mut branch_traces = Vec::new();
166 let mut branch_counter = 0usize;
167 let result = execute_node(
168 branch,
169 branch_input,
170 executor,
171 &mut branch_traces,
172 &mut branch_counter,
173 &mut None,
174 );
175 (*name, result, branch_traces)
176 })
177 })
178 .collect();
179 handles
180 .into_iter()
181 .map(|h| h.join().expect("parallel branch panicked"))
182 .collect::<Vec<_>>()
183 });
184
185 let mut output_fields = serde_json::Map::new();
186 for (name, result, branch_traces) in branch_results {
187 let branch_output = result?;
188 output_fields.insert(name.to_string(), branch_output);
189 traces.extend(branch_traces);
190 }
191 Ok(Value::Object(output_fields))
192 }
193 CompositionNode::Branch {
194 predicate,
195 if_true,
196 if_false,
197 } => {
198 let pred_result =
199 execute_node(predicate, input, executor, traces, step_counter, cache)?;
200 let condition = match &pred_result {
201 Value::Bool(b) => *b,
202 _ => false,
203 };
204 if condition {
205 execute_node(if_true, input, executor, traces, step_counter, cache)
206 } else {
207 execute_node(if_false, input, executor, traces, step_counter, cache)
208 }
209 }
210 CompositionNode::Fanout { source, targets } => {
211 let source_output = execute_node(source, input, executor, traces, step_counter, cache)?;
212 let mut results = Vec::new();
213 for target in targets {
214 let result = execute_node(
215 target,
216 &source_output,
217 executor,
218 traces,
219 step_counter,
220 cache,
221 )?;
222 results.push(result);
223 }
224 Ok(Value::Array(results))
225 }
226 CompositionNode::Merge { sources, target } => {
227 let mut merged = serde_json::Map::new();
228 for (i, source) in sources.iter().enumerate() {
229 let source_input = if let Value::Object(ref obj) = input {
230 obj.get(&format!("source_{i}"))
231 .cloned()
232 .unwrap_or(Value::Null)
233 } else {
234 input.clone()
235 };
236 let result =
237 execute_node(source, &source_input, executor, traces, step_counter, cache)?;
238 merged.insert(format!("source_{i}"), result);
239 }
240 execute_node(
241 target,
242 &Value::Object(merged),
243 executor,
244 traces,
245 step_counter,
246 cache,
247 )
248 }
249 CompositionNode::Retry {
250 stage,
251 max_attempts,
252 ..
253 } => {
254 let mut last_err = None;
255 for _ in 0..*max_attempts {
256 match execute_node(stage, input, executor, traces, step_counter, cache) {
257 Ok(output) => return Ok(output),
258 Err(e) => last_err = Some(e),
259 }
260 }
261 Err(last_err.unwrap_or(ExecutionError::RetryExhausted {
262 stage_id: StageId("unknown".into()),
263 attempts: *max_attempts,
264 }))
265 }
266 CompositionNode::RemoteStage { url, .. } => execute_remote_stage(url, input),
267 CompositionNode::Let { bindings, body } => {
268 let bindings_vec: Vec<(&str, &CompositionNode)> =
271 bindings.iter().map(|(n, b)| (n.as_str(), b)).collect();
272
273 let binding_results = std::thread::scope(|s| {
274 let handles: Vec<_> = bindings_vec
275 .iter()
276 .map(|(name, node)| {
277 s.spawn(move || {
278 let mut bt = Vec::new();
279 let mut bc = 0usize;
280 let r =
281 execute_node(node, input, executor, &mut bt, &mut bc, &mut None);
282 (*name, r, bt)
283 })
284 })
285 .collect();
286 handles
287 .into_iter()
288 .map(|h| h.join().expect("Let binding panicked"))
289 .collect::<Vec<_>>()
290 });
291
292 let mut merged = match input {
294 Value::Object(map) => map.clone(),
295 _ => serde_json::Map::new(),
296 };
297 for (name, result, branch_traces) in binding_results {
298 let value = result?;
299 merged.insert(name.to_string(), value);
300 traces.extend(branch_traces);
301 }
302
303 let body_input = Value::Object(merged);
304 execute_node(body, &body_input, executor, traces, step_counter, cache)
305 }
306 }
307}
308
309fn execute_stage<E: StageExecutor + Sync>(
310 id: &StageId,
311 input: &Value,
312 executor: &E,
313 traces: &mut Vec<StageTrace>,
314 step_counter: &mut usize,
315 cache: &mut Option<&mut PureStageCache>,
316) -> Result<Value, ExecutionError> {
317 let step_index = *step_counter;
318 *step_counter += 1;
319 let start = Instant::now();
320
321 let input_hash = hash_value(input);
322
323 if let Some(ref mut c) = cache {
325 if let Some(cached_output) = c.get(id, input) {
326 let output = cached_output.clone();
327 let duration_ms = start.elapsed().as_millis() as u64;
328 traces.push(StageTrace {
329 stage_id: id.clone(),
330 step_index,
331 status: StageStatus::Ok,
332 duration_ms,
333 input_hash: Some(input_hash),
334 output_hash: Some(hash_value(&output)),
335 });
336 return Ok(output);
337 }
338 }
339
340 match executor.execute(id, input) {
341 Ok(output) => {
342 let output_hash = hash_value(&output);
343 let duration_ms = start.elapsed().as_millis() as u64;
344 traces.push(StageTrace {
345 stage_id: id.clone(),
346 step_index,
347 status: StageStatus::Ok,
348 duration_ms,
349 input_hash: Some(input_hash),
350 output_hash: Some(output_hash),
351 });
352 if let Some(ref mut c) = cache {
354 c.put(id, input, output.clone());
355 }
356 Ok(output)
357 }
358 Err(e) => {
359 let duration_ms = start.elapsed().as_millis() as u64;
360 traces.push(StageTrace {
361 stage_id: id.clone(),
362 step_index,
363 status: StageStatus::Failed {
364 code: "EXECUTION_ERROR".into(),
365 message: format!("{e}"),
366 },
367 duration_ms,
368 input_hash: Some(input_hash),
369 output_hash: None,
370 });
371 Err(e)
372 }
373 }
374}
375
376fn hash_value(value: &Value) -> String {
377 let bytes = serde_json::to_vec(value).unwrap_or_default();
378 let hash = Sha256::digest(&bytes);
379 hex::encode(hash)
380}
381
382fn execute_remote_stage(url: &str, input: &Value) -> Result<Value, ExecutionError> {
390 #[cfg(feature = "native")]
391 {
392 use reqwest::blocking::Client;
393
394 let client = Client::new();
395 let body = serde_json::json!({ "input": input });
396 let resp =
397 client
398 .post(url)
399 .json(&body)
400 .send()
401 .map_err(|e| ExecutionError::RemoteCallFailed {
402 url: url.to_string(),
403 reason: e.to_string(),
404 })?;
405
406 let resp_json: Value = resp.json().map_err(|e| ExecutionError::RemoteCallFailed {
407 url: url.to_string(),
408 reason: format!("invalid JSON response: {e}"),
409 })?;
410
411 if resp_json.get("ok") == Some(&Value::Bool(false)) {
416 let reason = resp_json
417 .get("error")
418 .and_then(|e| e.as_str())
419 .unwrap_or("remote reported ok=false without error message")
420 .to_string();
421 return Err(ExecutionError::RemoteCallFailed {
422 url: url.to_string(),
423 reason,
424 });
425 }
426 resp_json
427 .get("data")
428 .and_then(|d| d.get("output"))
429 .cloned()
430 .ok_or_else(|| ExecutionError::RemoteCallFailed {
431 url: url.to_string(),
432 reason: "response missing data.output field".to_string(),
433 })
434 }
435 #[cfg(not(feature = "native"))]
436 {
437 let _ = (url, input);
438 Err(ExecutionError::RemoteCallFailed {
439 url: url.to_string(),
440 reason: "remote calls are handled by the JS runtime in WASM builds".to_string(),
441 })
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::mock::MockExecutor;
    use serde_json::json;
    use std::collections::BTreeMap;

    /// Wraps a string literal as a [`StageId`].
    fn sid(id: &str) -> StageId {
        StageId(id.into())
    }

    /// A signature-pinned stage node with no config.
    fn stage(id: &str) -> CompositionNode {
        CompositionNode::Stage {
            id: sid(id),
            pinning: crate::lagrange::Pinning::Signature,
            config: None,
        }
    }

    /// A mock executor that answers each listed stage id with a fixed output.
    fn mock(outputs: Vec<(&str, Value)>) -> MockExecutor {
        outputs
            .into_iter()
            .fold(MockExecutor::new(), |executor, (id, output)| {
                executor.with_output(&sid(id), output)
            })
    }

    /// `if pred { yes } else { no }` over stages of those names.
    fn yes_no_branch() -> CompositionNode {
        CompositionNode::Branch {
            predicate: Box::new(stage("pred")),
            if_true: Box::new(stage("yes")),
            if_false: Box::new(stage("no")),
        }
    }

    #[test]
    fn run_single_stage() {
        let executor = mock(vec![("a", json!(42))]);
        let result =
            run_composition(&stage("a"), &json!("input"), &executor, "test_comp").unwrap();
        assert_eq!(result.output, json!(42));
        assert_eq!(result.trace.stages.len(), 1);
        assert!(matches!(result.trace.status, TraceStatus::Ok));
    }

    #[test]
    fn run_sequential() {
        let executor = mock(vec![("a", json!("mid")), ("b", json!("final"))]);
        let pipeline = CompositionNode::Sequential {
            stages: vec![stage("a"), stage("b")],
        };
        let result = run_composition(&pipeline, &json!("start"), &executor, "test").unwrap();
        assert_eq!(result.output, json!("final"));
        assert_eq!(result.trace.stages.len(), 2);
    }

    #[test]
    fn run_parallel() {
        let executor = mock(vec![("s1", json!("r1")), ("s2", json!("r2"))]);
        let fork = CompositionNode::Parallel {
            branches: BTreeMap::from([("left".into(), stage("s1")), ("right".into(), stage("s2"))]),
        };
        let result = run_composition(&fork, &json!({}), &executor, "test").unwrap();
        assert_eq!(result.output, json!({"left": "r1", "right": "r2"}));
    }

    #[test]
    fn run_branch_true() {
        let executor = mock(vec![
            ("pred", json!(true)),
            ("yes", json!("YES")),
            ("no", json!("NO")),
        ]);
        let result = run_composition(&yes_no_branch(), &json!("input"), &executor, "test").unwrap();
        assert_eq!(result.output, json!("YES"));
    }

    #[test]
    fn run_branch_false() {
        let executor = mock(vec![
            ("pred", json!(false)),
            ("yes", json!("YES")),
            ("no", json!("NO")),
        ]);
        let result = run_composition(&yes_no_branch(), &json!("input"), &executor, "test").unwrap();
        assert_eq!(result.output, json!("NO"));
    }

    #[test]
    fn run_fanout() {
        let executor = mock(vec![
            ("src", json!("data")),
            ("t1", json!("r1")),
            ("t2", json!("r2")),
        ]);
        let spread = CompositionNode::Fanout {
            source: Box::new(stage("src")),
            targets: vec![stage("t1"), stage("t2")],
        };
        let result = run_composition(&spread, &json!("in"), &executor, "test").unwrap();
        assert_eq!(result.output, json!(["r1", "r2"]));
    }

    #[test]
    fn trace_has_input_output_hashes() {
        let executor = mock(vec![("a", json!(42))]);
        let result = run_composition(&stage("a"), &json!("input"), &executor, "test").unwrap();
        let first = &result.trace.stages[0];
        assert!(first.input_hash.is_some());
        assert!(first.output_hash.is_some());
    }
}