adk_graph/deferred.rs
1//! Deferred node (fan-in barrier) support for graph workflows.
2//!
3//! Provides fan-in barrier semantics for nodes that wait on multiple upstream
4//! parallel paths before executing. This enables scatter-gather patterns where
5//! work is distributed across parallel branches and then collected at a single
6//! synchronization point.
7//!
8//! # Overview
9//!
10//! A deferred node is declared with a [`DeferredNodeConfig`] that specifies:
11//! - [`MergeStrategy`]: How upstream outputs are combined (collect, merge maps, first, or custom).
12//! - `fan_in_timeout`: Optional maximum wait duration for all upstream paths.
13//!
14//! The [`FanInTracker`] tracks which upstream paths have completed and merges
15//! their outputs according to the configured strategy.
16//!
17//! # Example
18//!
19//! ```rust
20//! use std::time::Duration;
21//! use adk_graph::deferred::{DeferredNodeConfig, FanInTracker, MergeStrategy};
22//! use serde_json::json;
23//!
24//! // Configure a deferred node that collects all upstream outputs
25//! let config = DeferredNodeConfig {
26//! merge_strategy: MergeStrategy::Collect,
27//! fan_in_timeout: Some(Duration::from_secs(30)),
28//! };
29//!
30//! // Track upstream completions
31//! let mut tracker = FanInTracker::new(vec!["branch_a", "branch_b", "branch_c"]);
32//!
33//! tracker.record("branch_a", json!({"result": 1}));
34//! tracker.record("branch_b", json!({"result": 2}));
35//! assert!(!tracker.is_ready());
36//!
37//! tracker.record("branch_c", json!({"result": 3}));
38//! assert!(tracker.is_ready());
39//!
40//! // Merge outputs using the configured strategy
41//! let merged = tracker.merge(&config.merge_strategy);
42//! assert_eq!(merged, json!([{"result": 1}, {"result": 2}, {"result": 3}]));
43//! ```
44
45use std::collections::{HashMap, HashSet};
46use std::fmt;
47use std::sync::Arc;
48use std::time::Duration;
49
50use serde_json::Value;
51
52/// How to combine outputs from multiple upstream parallel paths.
53///
54/// The merge strategy determines how the collected outputs from all upstream
55/// branches are combined into a single value for the deferred node's input.
56///
57/// # Example
58///
59/// ```rust
60/// use adk_graph::deferred::MergeStrategy;
61///
62/// // Default strategy collects all outputs into a Vec
63/// let strategy = MergeStrategy::default();
64/// assert!(matches!(strategy, MergeStrategy::Collect));
65///
66/// // MergeMap combines all output maps with last-write-wins
67/// let strategy = MergeStrategy::MergeMap;
68/// ```
69#[derive(Clone, Default)]
70pub enum MergeStrategy {
71 /// Collect all outputs into a `Vec<Value>`.
72 ///
73 /// Outputs are ordered by the insertion order of source nodes
74 /// (the order in which they were recorded).
75 #[default]
76 Collect,
77
78 /// Merge all output maps into a single map (last-write-wins on key conflict).
79 ///
80 /// Each upstream output is expected to be a JSON object. Non-object outputs
81 /// are skipped. When multiple outputs contain the same key, the value from
82 /// the later-recorded source wins.
83 MergeMap,
84
85 /// Use only the first completed output.
86 ///
87 /// Returns the output from whichever upstream path completed first
88 /// (i.e., was recorded first).
89 First,
90
91 /// Custom merge function.
92 ///
93 /// Accepts a closure that takes all collected outputs and produces a
94 /// single merged value.
95 ///
96 /// # Example
97 ///
98 /// ```rust
99 /// use std::sync::Arc;
100 /// use adk_graph::deferred::MergeStrategy;
101 /// use serde_json::{json, Value};
102 ///
103 /// let strategy = MergeStrategy::Custom(Arc::new(|outputs: Vec<Value>| {
104 /// json!({ "count": outputs.len() })
105 /// }));
106 /// ```
107 Custom(Arc<dyn Fn(Vec<Value>) -> Value + Send + Sync>),
108}
109
110impl fmt::Debug for MergeStrategy {
111 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112 match self {
113 Self::Collect => write!(f, "Collect"),
114 Self::MergeMap => write!(f, "MergeMap"),
115 Self::First => write!(f, "First"),
116 Self::Custom(_) => write!(f, "Custom(<fn>)"),
117 }
118 }
119}
120
121/// Configuration for a deferred (fan-in) node.
122///
123/// A deferred node waits for all upstream parallel paths to complete before
124/// executing. The configuration controls how outputs are merged and how long
125/// the node waits.
126///
127/// # Example
128///
129/// ```rust
130/// use std::time::Duration;
131/// use adk_graph::deferred::{DeferredNodeConfig, MergeStrategy};
132///
133/// let config = DeferredNodeConfig {
134/// merge_strategy: MergeStrategy::MergeMap,
135/// fan_in_timeout: Some(Duration::from_secs(60)),
136/// };
137/// ```
138#[derive(Debug, Clone, Default)]
139pub struct DeferredNodeConfig {
140 /// Strategy for combining upstream outputs.
141 pub merge_strategy: MergeStrategy,
142
143 /// Maximum time to wait for all upstream paths to complete.
144 ///
145 /// - `None`: Wait indefinitely for all upstream paths.
146 /// - `Some(duration)`: If the timeout expires and some paths have completed,
147 /// proceed with partial results. If zero paths have completed, return
148 /// `GraphError::FanInTimedOut`.
149 pub fan_in_timeout: Option<Duration>,
150}
151
152/// Tracks which upstream paths have completed for a deferred node.
153///
154/// The tracker maintains a set of expected source nodes and records their
155/// outputs as they arrive. Once all expected sources have reported, the
156/// tracker is ready and outputs can be merged.
157///
158/// # Example
159///
160/// ```rust
161/// use adk_graph::deferred::{FanInTracker, MergeStrategy};
162/// use serde_json::json;
163///
164/// let mut tracker = FanInTracker::new(vec!["node_a", "node_b"]);
165///
166/// assert!(!tracker.is_ready());
167/// assert_eq!(tracker.received_count(), 0);
168/// assert_eq!(tracker.expected_count(), 2);
169///
170/// tracker.record("node_a", json!("output_a"));
171/// assert!(!tracker.is_ready());
172///
173/// tracker.record("node_b", json!("output_b"));
174/// assert!(tracker.is_ready());
175///
176/// let merged = tracker.merge(&MergeStrategy::Collect);
177/// assert_eq!(merged, json!(["output_a", "output_b"]));
178/// ```
179pub struct FanInTracker {
180 /// The set of source node names we expect to receive output from.
181 expected: HashSet<String>,
182 /// Outputs received so far, keyed by source node name.
183 received: HashMap<String, Value>,
184 /// Insertion order of received outputs (for deterministic merge ordering).
185 insertion_order: Vec<String>,
186}
187
188impl FanInTracker {
189 /// Create a new tracker expecting outputs from the given source nodes.
190 ///
191 /// # Arguments
192 ///
193 /// * `expected_sources` - Names of upstream nodes that must complete
194 /// before this deferred node can execute.
195 ///
196 /// # Example
197 ///
198 /// ```rust
199 /// use adk_graph::deferred::FanInTracker;
200 ///
201 /// let tracker = FanInTracker::new(vec!["branch_1", "branch_2", "branch_3"]);
202 /// assert_eq!(tracker.expected_count(), 3);
203 /// assert!(!tracker.is_ready());
204 /// ```
205 pub fn new(expected_sources: Vec<&str>) -> Self {
206 Self {
207 expected: expected_sources.iter().map(|s| (*s).to_string()).collect(),
208 received: HashMap::new(),
209 insertion_order: Vec::new(),
210 }
211 }
212
213 /// Returns `true` when all expected sources have reported their output.
214 ///
215 /// # Example
216 ///
217 /// ```rust
218 /// use adk_graph::deferred::FanInTracker;
219 /// use serde_json::json;
220 ///
221 /// let mut tracker = FanInTracker::new(vec!["a"]);
222 /// assert!(!tracker.is_ready());
223 ///
224 /// tracker.record("a", json!(42));
225 /// assert!(tracker.is_ready());
226 /// ```
227 pub fn is_ready(&self) -> bool {
228 self.expected.iter().all(|s| self.received.contains_key(s))
229 }
230
231 /// Record the output from a source node.
232 ///
233 /// If the source has already been recorded, the previous value is
234 /// overwritten (last-write-wins). Recording a source that is not in
235 /// the expected set is a no-op for readiness but the value is still stored.
236 ///
237 /// # Arguments
238 ///
239 /// * `source_node` - The name of the upstream node that produced the output.
240 /// * `output` - The output value from the source node.
241 ///
242 /// # Example
243 ///
244 /// ```rust
245 /// use adk_graph::deferred::FanInTracker;
246 /// use serde_json::json;
247 ///
248 /// let mut tracker = FanInTracker::new(vec!["worker_1", "worker_2"]);
249 /// tracker.record("worker_1", json!({"status": "done"}));
250 /// assert_eq!(tracker.received_count(), 1);
251 /// ```
252 pub fn record(&mut self, source_node: &str, output: Value) {
253 let key = source_node.to_string();
254 if !self.received.contains_key(&key) {
255 self.insertion_order.push(key.clone());
256 }
257 self.received.insert(key, output);
258 }
259
260 /// Merge all received outputs according to the given strategy.
261 ///
262 /// The merge operation combines all recorded outputs into a single
263 /// [`Value`] based on the [`MergeStrategy`]:
264 ///
265 /// - [`MergeStrategy::Collect`]: Returns a JSON array of all outputs in
266 /// insertion order.
267 /// - [`MergeStrategy::MergeMap`]: Merges all JSON object outputs into a
268 /// single object (last-write-wins). Non-object outputs are skipped.
269 /// - [`MergeStrategy::First`]: Returns the first recorded output.
270 /// - [`MergeStrategy::Custom`]: Invokes the custom function with all outputs.
271 ///
272 /// # Arguments
273 ///
274 /// * `strategy` - The merge strategy to apply.
275 ///
276 /// # Example
277 ///
278 /// ```rust
279 /// use adk_graph::deferred::{FanInTracker, MergeStrategy};
280 /// use serde_json::json;
281 ///
282 /// let mut tracker = FanInTracker::new(vec!["a", "b"]);
283 /// tracker.record("a", json!({"x": 1}));
284 /// tracker.record("b", json!({"y": 2}));
285 ///
286 /// // Collect strategy
287 /// let result = tracker.merge(&MergeStrategy::Collect);
288 /// assert_eq!(result, json!([{"x": 1}, {"y": 2}]));
289 ///
290 /// // MergeMap strategy
291 /// let result = tracker.merge(&MergeStrategy::MergeMap);
292 /// assert_eq!(result, json!({"x": 1, "y": 2}));
293 /// ```
294 pub fn merge(&self, strategy: &MergeStrategy) -> Value {
295 match strategy {
296 MergeStrategy::Collect => {
297 let outputs: Vec<Value> = self
298 .insertion_order
299 .iter()
300 .filter_map(|key| self.received.get(key).cloned())
301 .collect();
302 Value::Array(outputs)
303 }
304 MergeStrategy::MergeMap => {
305 let mut merged = serde_json::Map::new();
306 for key in &self.insertion_order {
307 if let Some(Value::Object(map)) = self.received.get(key) {
308 for (k, v) in map {
309 merged.insert(k.clone(), v.clone());
310 }
311 }
312 }
313 Value::Object(merged)
314 }
315 MergeStrategy::First => self
316 .insertion_order
317 .first()
318 .and_then(|key| self.received.get(key).cloned())
319 .unwrap_or(Value::Null),
320 MergeStrategy::Custom(f) => {
321 let outputs: Vec<Value> = self
322 .insertion_order
323 .iter()
324 .filter_map(|key| self.received.get(key).cloned())
325 .collect();
326 f(outputs)
327 }
328 }
329 }
330
331 /// Returns the number of outputs received so far.
332 pub fn received_count(&self) -> usize {
333 self.received.len()
334 }
335
336 /// Returns the number of expected source nodes.
337 pub fn expected_count(&self) -> usize {
338 self.expected.len()
339 }
340
341 /// Returns the names of sources that have not yet reported.
342 pub fn pending_sources(&self) -> Vec<&str> {
343 self.expected
344 .iter()
345 .filter(|s| !self.received.contains_key(*s))
346 .map(|s| s.as_str())
347 .collect()
348 }
349
350 /// Returns the names of sources that have reported.
351 pub fn completed_sources(&self) -> Vec<&str> {
352 self.insertion_order.iter().map(|s| s.as_str()).collect()
353 }
354}
355
356impl fmt::Debug for FanInTracker {
357 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358 f.debug_struct("FanInTracker")
359 .field("expected", &self.expected)
360 .field("received_keys", &self.insertion_order)
361 .field("is_ready", &self.is_ready())
362 .finish()
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a tracker for `expected` and records each `(source, output)`
    /// pair in the order given.
    fn tracker_with(expected: &[&str], records: Vec<(&str, Value)>) -> FanInTracker {
        let mut tracker = FanInTracker::new(expected.to_vec());
        for (source, output) in records {
            tracker.record(source, output);
        }
        tracker
    }

    #[test]
    fn test_tracker_new_empty_not_ready() {
        let fresh = tracker_with(&["a", "b", "c"], Vec::new());
        assert!(!fresh.is_ready());
        assert_eq!(fresh.expected_count(), 3);
        assert_eq!(fresh.received_count(), 0);
    }

    #[test]
    fn test_tracker_ready_when_all_received() {
        let mut fan_in = FanInTracker::new(vec!["a", "b"]);

        fan_in.record("a", json!(1));
        assert!(!fan_in.is_ready());

        fan_in.record("b", json!(2));
        assert!(fan_in.is_ready());
    }

    #[test]
    fn test_merge_collect() {
        let fan_in = tracker_with(
            &["x", "y", "z"],
            vec![
                ("x", json!("first")),
                ("y", json!("second")),
                ("z", json!("third")),
            ],
        );

        let collected = fan_in.merge(&MergeStrategy::Collect);
        assert_eq!(collected, json!(["first", "second", "third"]));
    }

    #[test]
    fn test_merge_map_combines_objects() {
        let fan_in = tracker_with(
            &["a", "b"],
            vec![
                ("a", json!({"key1": "val1", "shared": "from_a"})),
                ("b", json!({"key2": "val2", "shared": "from_b"})),
            ],
        );

        let combined = fan_in.merge(&MergeStrategy::MergeMap);
        assert_eq!(
            combined,
            json!({"key1": "val1", "key2": "val2", "shared": "from_b"})
        );
    }

    #[test]
    fn test_merge_map_skips_non_objects() {
        // The bare number from "a" is not an object and must be ignored.
        let fan_in = tracker_with(
            &["a", "b"],
            vec![("a", json!(42)), ("b", json!({"key": "value"}))],
        );

        let combined = fan_in.merge(&MergeStrategy::MergeMap);
        assert_eq!(combined, json!({"key": "value"}));
    }

    #[test]
    fn test_merge_first() {
        let fan_in = tracker_with(
            &["a", "b", "c"],
            vec![
                ("b", json!("first_to_arrive")),
                ("a", json!("second_to_arrive")),
                ("c", json!("third_to_arrive")),
            ],
        );

        let winner = fan_in.merge(&MergeStrategy::First);
        assert_eq!(winner, json!("first_to_arrive"));
    }

    #[test]
    fn test_merge_first_empty() {
        let fan_in = tracker_with(&["a"], Vec::new());
        assert_eq!(fan_in.merge(&MergeStrategy::First), Value::Null);
    }

    #[test]
    fn test_merge_custom() {
        let fan_in = tracker_with(&["a", "b"], vec![("a", json!(10)), ("b", json!(20))]);

        let summing = MergeStrategy::Custom(Arc::new(|outputs: Vec<Value>| {
            let total: i64 = outputs.iter().filter_map(|v| v.as_i64()).sum();
            json!(total)
        }));

        assert_eq!(fan_in.merge(&summing), json!(30));
    }

    #[test]
    fn test_record_overwrites_previous() {
        let fan_in = tracker_with(&["a"], vec![("a", json!("first")), ("a", json!("second"))]);

        assert!(fan_in.is_ready());
        assert_eq!(fan_in.received_count(), 1);
        assert_eq!(fan_in.merge(&MergeStrategy::First), json!("second"));
    }

    #[test]
    fn test_pending_and_completed_sources() {
        let fan_in = tracker_with(&["a", "b", "c"], vec![("b", json!(1))]);

        // Pending names come back in arbitrary order, so sort before comparing.
        let mut waiting = fan_in.pending_sources();
        waiting.sort();

        assert_eq!(waiting, vec!["a", "c"]);
        assert_eq!(fan_in.completed_sources(), vec!["b"]);
    }

    #[test]
    fn test_default_config() {
        let defaults = DeferredNodeConfig::default();
        assert!(matches!(defaults.merge_strategy, MergeStrategy::Collect));
        assert!(defaults.fan_in_timeout.is_none());
    }

    #[test]
    fn test_merge_strategy_debug() {
        let cases = [
            (MergeStrategy::Collect, "Collect"),
            (MergeStrategy::MergeMap, "MergeMap"),
            (MergeStrategy::First, "First"),
            (MergeStrategy::Custom(Arc::new(Value::Array)), "Custom(<fn>)"),
        ];

        for (strategy, rendered) in &cases {
            assert_eq!(format!("{:?}", strategy), *rendered);
        }
    }
}
491}