1use super::{ExecutionError, StageExecutor};
2use crate::executor::pure_cache::PureStageCache;
3use crate::lagrange::CompositionNode;
4use crate::trace::{CompositionTrace, StageStatus, StageTrace, TraceStatus};
5use chrono::Utc;
6use noether_core::stage::StageId;
7use serde_json::Value;
8use sha2::{Digest, Sha256};
9use std::time::Instant;
10
/// Outcome of a successful composition run.
///
/// Only produced on success; when execution fails the caller receives the
/// `ExecutionError` instead and no trace is returned.
#[derive(Debug)]
pub struct CompositionResult {
    /// JSON value produced by the root node of the composition.
    pub output: Value,
    /// Execution trace: composition id, timing, overall status, and one
    /// `StageTrace` per locally executed stage (cache hits and failed retry
    /// attempts included).
    pub trace: CompositionTrace,
    /// Spend attributed to this run, in cents. The executor in this module
    /// always reports 0.
    pub spent_cents: u64,
}
21
22pub fn run_composition<E: StageExecutor + Sync>(
26 node: &CompositionNode,
27 input: &Value,
28 executor: &E,
29 composition_id: &str,
30) -> Result<CompositionResult, ExecutionError> {
31 run_composition_with_cache(node, input, executor, composition_id, None)
32}
33
34pub fn run_composition_with_cache<E: StageExecutor + Sync>(
36 node: &CompositionNode,
37 input: &Value,
38 executor: &E,
39 composition_id: &str,
40 cache: Option<&mut PureStageCache>,
41) -> Result<CompositionResult, ExecutionError> {
42 let start = Instant::now();
43 let mut stage_traces = Vec::new();
44 let mut step_counter = 0;
45
46 let mut owned_cache;
47 let cache_ref: &mut Option<&mut PureStageCache>;
48 let mut none_holder: Option<&mut PureStageCache> = None;
49
50 if let Some(c) = cache {
51 owned_cache = Some(c);
52 cache_ref = &mut owned_cache;
53 } else {
54 cache_ref = &mut none_holder;
55 }
56
57 let output = execute_node(
58 node,
59 input,
60 executor,
61 &mut stage_traces,
62 &mut step_counter,
63 cache_ref,
64 )?;
65
66 let duration_ms = start.elapsed().as_millis() as u64;
67 let has_failures = stage_traces
68 .iter()
69 .any(|t| matches!(t.status, StageStatus::Failed { .. }));
70
71 let trace = CompositionTrace {
72 composition_id: composition_id.into(),
73 started_at: Utc::now().to_rfc3339(),
74 duration_ms,
75 status: if has_failures {
76 TraceStatus::Failed
77 } else {
78 TraceStatus::Ok
79 },
80 stages: stage_traces,
81 security_events: Vec::new(),
82 warnings: Vec::new(),
83 };
84
85 Ok(CompositionResult {
86 output,
87 trace,
88 spent_cents: 0,
89 })
90}
91
92fn execute_node<E: StageExecutor + Sync>(
93 node: &CompositionNode,
94 input: &Value,
95 executor: &E,
96 traces: &mut Vec<StageTrace>,
97 step_counter: &mut usize,
98 cache: &mut Option<&mut PureStageCache>,
99) -> Result<Value, ExecutionError> {
100 match node {
101 CompositionNode::Stage { id, config } => {
102 let merged = if let Some(cfg) = config {
103 let mut obj = match input {
104 Value::Object(map) => map.clone(),
105 other => {
106 let mut m = serde_json::Map::new();
107 let data_key = [
108 "items", "text", "data", "input", "records", "train", "document",
109 "html", "csv", "json_str",
110 ]
111 .iter()
112 .find(|k| !cfg.contains_key(**k))
113 .unwrap_or(&"items");
114 m.insert(data_key.to_string(), other.clone());
115 m
116 }
117 };
118 for (k, v) in cfg {
119 obj.insert(k.clone(), v.clone());
120 }
121 Value::Object(obj)
122 } else {
123 input.clone()
124 };
125 execute_stage(id, &merged, executor, traces, step_counter, cache)
126 }
127 CompositionNode::Const { value } => Ok(value.clone()),
128 CompositionNode::Sequential { stages } => {
129 let mut current = input.clone();
130 for stage in stages {
131 current = execute_node(stage, ¤t, executor, traces, step_counter, cache)?;
132 }
133 Ok(current)
134 }
135 CompositionNode::Parallel { branches } => {
136 let branch_data: Vec<(&str, &CompositionNode, Value)> = branches
142 .iter()
143 .map(|(name, branch)| {
144 let branch_input = if let Value::Object(ref obj) = input {
145 obj.get(name).cloned().unwrap_or_else(|| input.clone())
146 } else {
147 input.clone()
148 };
149 (name.as_str(), branch, branch_input)
150 })
151 .collect();
152
153 let branch_results = std::thread::scope(|s| {
157 let handles: Vec<_> = branch_data
158 .iter()
159 .map(|(name, branch, branch_input)| {
160 s.spawn(move || {
161 let mut branch_traces = Vec::new();
162 let mut branch_counter = 0usize;
163 let result = execute_node(
164 branch,
165 branch_input,
166 executor,
167 &mut branch_traces,
168 &mut branch_counter,
169 &mut None,
170 );
171 (*name, result, branch_traces)
172 })
173 })
174 .collect();
175 handles
176 .into_iter()
177 .map(|h| h.join().expect("parallel branch panicked"))
178 .collect::<Vec<_>>()
179 });
180
181 let mut output_fields = serde_json::Map::new();
182 for (name, result, branch_traces) in branch_results {
183 let branch_output = result?;
184 output_fields.insert(name.to_string(), branch_output);
185 traces.extend(branch_traces);
186 }
187 Ok(Value::Object(output_fields))
188 }
189 CompositionNode::Branch {
190 predicate,
191 if_true,
192 if_false,
193 } => {
194 let pred_result =
195 execute_node(predicate, input, executor, traces, step_counter, cache)?;
196 let condition = match &pred_result {
197 Value::Bool(b) => *b,
198 _ => false,
199 };
200 if condition {
201 execute_node(if_true, input, executor, traces, step_counter, cache)
202 } else {
203 execute_node(if_false, input, executor, traces, step_counter, cache)
204 }
205 }
206 CompositionNode::Fanout { source, targets } => {
207 let source_output = execute_node(source, input, executor, traces, step_counter, cache)?;
208 let mut results = Vec::new();
209 for target in targets {
210 let result = execute_node(
211 target,
212 &source_output,
213 executor,
214 traces,
215 step_counter,
216 cache,
217 )?;
218 results.push(result);
219 }
220 Ok(Value::Array(results))
221 }
222 CompositionNode::Merge { sources, target } => {
223 let mut merged = serde_json::Map::new();
224 for (i, source) in sources.iter().enumerate() {
225 let source_input = if let Value::Object(ref obj) = input {
226 obj.get(&format!("source_{i}"))
227 .cloned()
228 .unwrap_or(Value::Null)
229 } else {
230 input.clone()
231 };
232 let result =
233 execute_node(source, &source_input, executor, traces, step_counter, cache)?;
234 merged.insert(format!("source_{i}"), result);
235 }
236 execute_node(
237 target,
238 &Value::Object(merged),
239 executor,
240 traces,
241 step_counter,
242 cache,
243 )
244 }
245 CompositionNode::Retry {
246 stage,
247 max_attempts,
248 ..
249 } => {
250 let mut last_err = None;
251 for _ in 0..*max_attempts {
252 match execute_node(stage, input, executor, traces, step_counter, cache) {
253 Ok(output) => return Ok(output),
254 Err(e) => last_err = Some(e),
255 }
256 }
257 Err(last_err.unwrap_or(ExecutionError::RetryExhausted {
258 stage_id: StageId("unknown".into()),
259 attempts: *max_attempts,
260 }))
261 }
262 CompositionNode::RemoteStage { url, .. } => execute_remote_stage(url, input),
263 CompositionNode::Let { bindings, body } => {
264 let bindings_vec: Vec<(&str, &CompositionNode)> =
267 bindings.iter().map(|(n, b)| (n.as_str(), b)).collect();
268
269 let binding_results = std::thread::scope(|s| {
270 let handles: Vec<_> = bindings_vec
271 .iter()
272 .map(|(name, node)| {
273 s.spawn(move || {
274 let mut bt = Vec::new();
275 let mut bc = 0usize;
276 let r =
277 execute_node(node, input, executor, &mut bt, &mut bc, &mut None);
278 (*name, r, bt)
279 })
280 })
281 .collect();
282 handles
283 .into_iter()
284 .map(|h| h.join().expect("Let binding panicked"))
285 .collect::<Vec<_>>()
286 });
287
288 let mut merged = match input {
290 Value::Object(map) => map.clone(),
291 _ => serde_json::Map::new(),
292 };
293 for (name, result, branch_traces) in binding_results {
294 let value = result?;
295 merged.insert(name.to_string(), value);
296 traces.extend(branch_traces);
297 }
298
299 let body_input = Value::Object(merged);
300 execute_node(body, &body_input, executor, traces, step_counter, cache)
301 }
302 }
303}
304
305fn execute_stage<E: StageExecutor + Sync>(
306 id: &StageId,
307 input: &Value,
308 executor: &E,
309 traces: &mut Vec<StageTrace>,
310 step_counter: &mut usize,
311 cache: &mut Option<&mut PureStageCache>,
312) -> Result<Value, ExecutionError> {
313 let step_index = *step_counter;
314 *step_counter += 1;
315 let start = Instant::now();
316
317 let input_hash = hash_value(input);
318
319 if let Some(ref mut c) = cache {
321 if let Some(cached_output) = c.get(id, input) {
322 let output = cached_output.clone();
323 let duration_ms = start.elapsed().as_millis() as u64;
324 traces.push(StageTrace {
325 stage_id: id.clone(),
326 step_index,
327 status: StageStatus::Ok,
328 duration_ms,
329 input_hash: Some(input_hash),
330 output_hash: Some(hash_value(&output)),
331 });
332 return Ok(output);
333 }
334 }
335
336 match executor.execute(id, input) {
337 Ok(output) => {
338 let output_hash = hash_value(&output);
339 let duration_ms = start.elapsed().as_millis() as u64;
340 traces.push(StageTrace {
341 stage_id: id.clone(),
342 step_index,
343 status: StageStatus::Ok,
344 duration_ms,
345 input_hash: Some(input_hash),
346 output_hash: Some(output_hash),
347 });
348 if let Some(ref mut c) = cache {
350 c.put(id, input, output.clone());
351 }
352 Ok(output)
353 }
354 Err(e) => {
355 let duration_ms = start.elapsed().as_millis() as u64;
356 traces.push(StageTrace {
357 stage_id: id.clone(),
358 step_index,
359 status: StageStatus::Failed {
360 code: "EXECUTION_ERROR".into(),
361 message: format!("{e}"),
362 },
363 duration_ms,
364 input_hash: Some(input_hash),
365 output_hash: None,
366 });
367 Err(e)
368 }
369 }
370}
371
372fn hash_value(value: &Value) -> String {
373 let bytes = serde_json::to_vec(value).unwrap_or_default();
374 let hash = Sha256::digest(&bytes);
375 hex::encode(hash)
376}
377
378fn execute_remote_stage(url: &str, input: &Value) -> Result<Value, ExecutionError> {
386 #[cfg(feature = "native")]
387 {
388 use reqwest::blocking::Client;
389
390 let client = Client::new();
391 let body = serde_json::json!({ "input": input });
392 let resp =
393 client
394 .post(url)
395 .json(&body)
396 .send()
397 .map_err(|e| ExecutionError::RemoteCallFailed {
398 url: url.to_string(),
399 reason: e.to_string(),
400 })?;
401
402 let resp_json: Value = resp.json().map_err(|e| ExecutionError::RemoteCallFailed {
403 url: url.to_string(),
404 reason: format!("invalid JSON response: {e}"),
405 })?;
406
407 if resp_json.get("ok") == Some(&Value::Bool(false)) {
412 let reason = resp_json
413 .get("error")
414 .and_then(|e| e.as_str())
415 .unwrap_or("remote reported ok=false without error message")
416 .to_string();
417 return Err(ExecutionError::RemoteCallFailed {
418 url: url.to_string(),
419 reason,
420 });
421 }
422 resp_json
423 .get("data")
424 .and_then(|d| d.get("output"))
425 .cloned()
426 .ok_or_else(|| ExecutionError::RemoteCallFailed {
427 url: url.to_string(),
428 reason: "response missing data.output field".to_string(),
429 })
430 }
431 #[cfg(not(feature = "native"))]
432 {
433 let _ = (url, input);
434 Err(ExecutionError::RemoteCallFailed {
435 url: url.to_string(),
436 reason: "remote calls are handled by the JS runtime in WASM builds".to_string(),
437 })
438 }
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::mock::MockExecutor;
    use serde_json::json;
    use std::collections::BTreeMap;

    /// Builds a config-less `Stage` node for `id`.
    fn stage(id: &str) -> CompositionNode {
        CompositionNode::Stage {
            id: StageId(id.into()),
            config: None,
        }
    }

    #[test]
    fn run_single_stage() {
        let executor = MockExecutor::new().with_output(&StageId("a".into()), json!(42));
        let result = run_composition(&stage("a"), &json!("input"), &executor, "test_comp").unwrap();
        assert_eq!(result.output, json!(42));
        assert_eq!(result.trace.stages.len(), 1);
        assert!(matches!(result.trace.status, TraceStatus::Ok));
    }

    #[test]
    fn run_sequential() {
        let executor = MockExecutor::new()
            .with_output(&StageId("a".into()), json!("mid"))
            .with_output(&StageId("b".into()), json!("final"));
        let node = CompositionNode::Sequential {
            stages: vec![stage("a"), stage("b")],
        };
        let result = run_composition(&node, &json!("start"), &executor, "test").unwrap();
        assert_eq!(result.output, json!("final"));
        assert_eq!(result.trace.stages.len(), 2);
    }

    #[test]
    fn sequential_step_indices_follow_execution_order() {
        let executor = MockExecutor::new()
            .with_output(&StageId("a".into()), json!("mid"))
            .with_output(&StageId("b".into()), json!("final"));
        let node = CompositionNode::Sequential {
            stages: vec![stage("a"), stage("b")],
        };
        let result = run_composition(&node, &json!("start"), &executor, "test").unwrap();
        let indices: Vec<usize> = result.trace.stages.iter().map(|t| t.step_index).collect();
        assert_eq!(indices, vec![0, 1]);
    }

    #[test]
    fn run_parallel() {
        let executor = MockExecutor::new()
            .with_output(&StageId("s1".into()), json!("r1"))
            .with_output(&StageId("s2".into()), json!("r2"));
        let node = CompositionNode::Parallel {
            branches: BTreeMap::from([("left".into(), stage("s1")), ("right".into(), stage("s2"))]),
        };
        let result = run_composition(&node, &json!({}), &executor, "test").unwrap();
        assert_eq!(result.output, json!({"left": "r1", "right": "r2"}));
        // Each branch's stage trace is merged into the composition trace.
        assert_eq!(result.trace.stages.len(), 2);
    }

    #[test]
    fn run_branch_true() {
        let executor = MockExecutor::new()
            .with_output(&StageId("pred".into()), json!(true))
            .with_output(&StageId("yes".into()), json!("YES"))
            .with_output(&StageId("no".into()), json!("NO"));
        let node = CompositionNode::Branch {
            predicate: Box::new(stage("pred")),
            if_true: Box::new(stage("yes")),
            if_false: Box::new(stage("no")),
        };
        let result = run_composition(&node, &json!("input"), &executor, "test").unwrap();
        assert_eq!(result.output, json!("YES"));
        // Predicate plus the selected arm only.
        assert_eq!(result.trace.stages.len(), 2);
    }

    #[test]
    fn run_branch_false() {
        let executor = MockExecutor::new()
            .with_output(&StageId("pred".into()), json!(false))
            .with_output(&StageId("yes".into()), json!("YES"))
            .with_output(&StageId("no".into()), json!("NO"));
        let node = CompositionNode::Branch {
            predicate: Box::new(stage("pred")),
            if_true: Box::new(stage("yes")),
            if_false: Box::new(stage("no")),
        };
        let result = run_composition(&node, &json!("input"), &executor, "test").unwrap();
        assert_eq!(result.output, json!("NO"));
        assert_eq!(result.trace.stages.len(), 2);
    }

    #[test]
    fn run_fanout() {
        let executor = MockExecutor::new()
            .with_output(&StageId("src".into()), json!("data"))
            .with_output(&StageId("t1".into()), json!("r1"))
            .with_output(&StageId("t2".into()), json!("r2"));
        let node = CompositionNode::Fanout {
            source: Box::new(stage("src")),
            targets: vec![stage("t1"), stage("t2")],
        };
        let result = run_composition(&node, &json!("in"), &executor, "test").unwrap();
        assert_eq!(result.output, json!(["r1", "r2"]));
        // Source plus one entry per target.
        assert_eq!(result.trace.stages.len(), 3);
    }

    #[test]
    fn run_const_returns_value_without_tracing() {
        let executor = MockExecutor::new();
        let node = CompositionNode::Const {
            value: json!({"k": 1}),
        };
        let result = run_composition(&node, &json!("ignored"), &executor, "test").unwrap();
        assert_eq!(result.output, json!({"k": 1}));
        assert!(result.trace.stages.is_empty());
        assert!(matches!(result.trace.status, TraceStatus::Ok));
    }

    #[test]
    fn trace_has_input_output_hashes() {
        let executor = MockExecutor::new().with_output(&StageId("a".into()), json!(42));
        let result = run_composition(&stage("a"), &json!("input"), &executor, "test").unwrap();
        assert!(result.trace.stages[0].input_hash.is_some());
        assert!(result.trace.stages[0].output_hash.is_some());
    }
}