Skip to main content

adk_graph/
deferred.rs

1//! Deferred node (fan-in barrier) support for graph workflows.
2//!
3//! Provides fan-in barrier semantics for nodes that wait on multiple upstream
4//! parallel paths before executing. This enables scatter-gather patterns where
5//! work is distributed across parallel branches and then collected at a single
6//! synchronization point.
7//!
8//! # Overview
9//!
10//! A deferred node is declared with a [`DeferredNodeConfig`] that specifies:
11//! - [`MergeStrategy`]: How upstream outputs are combined (collect, merge maps, first, or custom).
12//! - `fan_in_timeout`: Optional maximum wait duration for all upstream paths.
13//!
14//! The [`FanInTracker`] tracks which upstream paths have completed and merges
15//! their outputs according to the configured strategy.
16//!
17//! # Example
18//!
19//! ```rust
20//! use std::time::Duration;
21//! use adk_graph::deferred::{DeferredNodeConfig, FanInTracker, MergeStrategy};
22//! use serde_json::json;
23//!
24//! // Configure a deferred node that collects all upstream outputs
25//! let config = DeferredNodeConfig {
26//!     merge_strategy: MergeStrategy::Collect,
27//!     fan_in_timeout: Some(Duration::from_secs(30)),
28//! };
29//!
30//! // Track upstream completions
31//! let mut tracker = FanInTracker::new(vec!["branch_a", "branch_b", "branch_c"]);
32//!
33//! tracker.record("branch_a", json!({"result": 1}));
34//! tracker.record("branch_b", json!({"result": 2}));
35//! assert!(!tracker.is_ready());
36//!
37//! tracker.record("branch_c", json!({"result": 3}));
38//! assert!(tracker.is_ready());
39//!
40//! // Merge outputs using the configured strategy
41//! let merged = tracker.merge(&config.merge_strategy);
42//! assert_eq!(merged, json!([{"result": 1}, {"result": 2}, {"result": 3}]));
43//! ```
44
45use std::collections::{HashMap, HashSet};
46use std::fmt;
47use std::sync::Arc;
48use std::time::Duration;
49
50use serde_json::Value;
51
52/// How to combine outputs from multiple upstream parallel paths.
53///
54/// The merge strategy determines how the collected outputs from all upstream
55/// branches are combined into a single value for the deferred node's input.
56///
57/// # Example
58///
59/// ```rust
60/// use adk_graph::deferred::MergeStrategy;
61///
62/// // Default strategy collects all outputs into a Vec
63/// let strategy = MergeStrategy::default();
64/// assert!(matches!(strategy, MergeStrategy::Collect));
65///
66/// // MergeMap combines all output maps with last-write-wins
67/// let strategy = MergeStrategy::MergeMap;
68/// ```
69#[derive(Clone, Default)]
70pub enum MergeStrategy {
71    /// Collect all outputs into a `Vec<Value>`.
72    ///
73    /// Outputs are ordered by the insertion order of source nodes
74    /// (the order in which they were recorded).
75    #[default]
76    Collect,
77
78    /// Merge all output maps into a single map (last-write-wins on key conflict).
79    ///
80    /// Each upstream output is expected to be a JSON object. Non-object outputs
81    /// are skipped. When multiple outputs contain the same key, the value from
82    /// the later-recorded source wins.
83    MergeMap,
84
85    /// Use only the first completed output.
86    ///
87    /// Returns the output from whichever upstream path completed first
88    /// (i.e., was recorded first).
89    First,
90
91    /// Custom merge function.
92    ///
93    /// Accepts a closure that takes all collected outputs and produces a
94    /// single merged value.
95    ///
96    /// # Example
97    ///
98    /// ```rust
99    /// use std::sync::Arc;
100    /// use adk_graph::deferred::MergeStrategy;
101    /// use serde_json::{json, Value};
102    ///
103    /// let strategy = MergeStrategy::Custom(Arc::new(|outputs: Vec<Value>| {
104    ///     json!({ "count": outputs.len() })
105    /// }));
106    /// ```
107    Custom(Arc<dyn Fn(Vec<Value>) -> Value + Send + Sync>),
108}
109
110impl fmt::Debug for MergeStrategy {
111    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112        match self {
113            Self::Collect => write!(f, "Collect"),
114            Self::MergeMap => write!(f, "MergeMap"),
115            Self::First => write!(f, "First"),
116            Self::Custom(_) => write!(f, "Custom(<fn>)"),
117        }
118    }
119}
120
121/// Configuration for a deferred (fan-in) node.
122///
123/// A deferred node waits for all upstream parallel paths to complete before
124/// executing. The configuration controls how outputs are merged and how long
125/// the node waits.
126///
127/// # Example
128///
129/// ```rust
130/// use std::time::Duration;
131/// use adk_graph::deferred::{DeferredNodeConfig, MergeStrategy};
132///
133/// let config = DeferredNodeConfig {
134///     merge_strategy: MergeStrategy::MergeMap,
135///     fan_in_timeout: Some(Duration::from_secs(60)),
136/// };
137/// ```
138#[derive(Debug, Clone, Default)]
139pub struct DeferredNodeConfig {
140    /// Strategy for combining upstream outputs.
141    pub merge_strategy: MergeStrategy,
142
143    /// Maximum time to wait for all upstream paths to complete.
144    ///
145    /// - `None`: Wait indefinitely for all upstream paths.
146    /// - `Some(duration)`: If the timeout expires and some paths have completed,
147    ///   proceed with partial results. If zero paths have completed, return
148    ///   `GraphError::FanInTimedOut`.
149    pub fan_in_timeout: Option<Duration>,
150}
151
152/// Tracks which upstream paths have completed for a deferred node.
153///
154/// The tracker maintains a set of expected source nodes and records their
155/// outputs as they arrive. Once all expected sources have reported, the
156/// tracker is ready and outputs can be merged.
157///
158/// # Example
159///
160/// ```rust
161/// use adk_graph::deferred::{FanInTracker, MergeStrategy};
162/// use serde_json::json;
163///
164/// let mut tracker = FanInTracker::new(vec!["node_a", "node_b"]);
165///
166/// assert!(!tracker.is_ready());
167/// assert_eq!(tracker.received_count(), 0);
168/// assert_eq!(tracker.expected_count(), 2);
169///
170/// tracker.record("node_a", json!("output_a"));
171/// assert!(!tracker.is_ready());
172///
173/// tracker.record("node_b", json!("output_b"));
174/// assert!(tracker.is_ready());
175///
176/// let merged = tracker.merge(&MergeStrategy::Collect);
177/// assert_eq!(merged, json!(["output_a", "output_b"]));
178/// ```
179pub struct FanInTracker {
180    /// The set of source node names we expect to receive output from.
181    expected: HashSet<String>,
182    /// Outputs received so far, keyed by source node name.
183    received: HashMap<String, Value>,
184    /// Insertion order of received outputs (for deterministic merge ordering).
185    insertion_order: Vec<String>,
186}
187
188impl FanInTracker {
189    /// Create a new tracker expecting outputs from the given source nodes.
190    ///
191    /// # Arguments
192    ///
193    /// * `expected_sources` - Names of upstream nodes that must complete
194    ///   before this deferred node can execute.
195    ///
196    /// # Example
197    ///
198    /// ```rust
199    /// use adk_graph::deferred::FanInTracker;
200    ///
201    /// let tracker = FanInTracker::new(vec!["branch_1", "branch_2", "branch_3"]);
202    /// assert_eq!(tracker.expected_count(), 3);
203    /// assert!(!tracker.is_ready());
204    /// ```
205    pub fn new(expected_sources: Vec<&str>) -> Self {
206        Self {
207            expected: expected_sources.iter().map(|s| (*s).to_string()).collect(),
208            received: HashMap::new(),
209            insertion_order: Vec::new(),
210        }
211    }
212
213    /// Returns `true` when all expected sources have reported their output.
214    ///
215    /// # Example
216    ///
217    /// ```rust
218    /// use adk_graph::deferred::FanInTracker;
219    /// use serde_json::json;
220    ///
221    /// let mut tracker = FanInTracker::new(vec!["a"]);
222    /// assert!(!tracker.is_ready());
223    ///
224    /// tracker.record("a", json!(42));
225    /// assert!(tracker.is_ready());
226    /// ```
227    pub fn is_ready(&self) -> bool {
228        self.expected.iter().all(|s| self.received.contains_key(s))
229    }
230
231    /// Record the output from a source node.
232    ///
233    /// If the source has already been recorded, the previous value is
234    /// overwritten (last-write-wins). Recording a source that is not in
235    /// the expected set is a no-op for readiness but the value is still stored.
236    ///
237    /// # Arguments
238    ///
239    /// * `source_node` - The name of the upstream node that produced the output.
240    /// * `output` - The output value from the source node.
241    ///
242    /// # Example
243    ///
244    /// ```rust
245    /// use adk_graph::deferred::FanInTracker;
246    /// use serde_json::json;
247    ///
248    /// let mut tracker = FanInTracker::new(vec!["worker_1", "worker_2"]);
249    /// tracker.record("worker_1", json!({"status": "done"}));
250    /// assert_eq!(tracker.received_count(), 1);
251    /// ```
252    pub fn record(&mut self, source_node: &str, output: Value) {
253        let key = source_node.to_string();
254        if !self.received.contains_key(&key) {
255            self.insertion_order.push(key.clone());
256        }
257        self.received.insert(key, output);
258    }
259
260    /// Merge all received outputs according to the given strategy.
261    ///
262    /// The merge operation combines all recorded outputs into a single
263    /// [`Value`] based on the [`MergeStrategy`]:
264    ///
265    /// - [`MergeStrategy::Collect`]: Returns a JSON array of all outputs in
266    ///   insertion order.
267    /// - [`MergeStrategy::MergeMap`]: Merges all JSON object outputs into a
268    ///   single object (last-write-wins). Non-object outputs are skipped.
269    /// - [`MergeStrategy::First`]: Returns the first recorded output.
270    /// - [`MergeStrategy::Custom`]: Invokes the custom function with all outputs.
271    ///
272    /// # Arguments
273    ///
274    /// * `strategy` - The merge strategy to apply.
275    ///
276    /// # Example
277    ///
278    /// ```rust
279    /// use adk_graph::deferred::{FanInTracker, MergeStrategy};
280    /// use serde_json::json;
281    ///
282    /// let mut tracker = FanInTracker::new(vec!["a", "b"]);
283    /// tracker.record("a", json!({"x": 1}));
284    /// tracker.record("b", json!({"y": 2}));
285    ///
286    /// // Collect strategy
287    /// let result = tracker.merge(&MergeStrategy::Collect);
288    /// assert_eq!(result, json!([{"x": 1}, {"y": 2}]));
289    ///
290    /// // MergeMap strategy
291    /// let result = tracker.merge(&MergeStrategy::MergeMap);
292    /// assert_eq!(result, json!({"x": 1, "y": 2}));
293    /// ```
294    pub fn merge(&self, strategy: &MergeStrategy) -> Value {
295        match strategy {
296            MergeStrategy::Collect => {
297                let outputs: Vec<Value> = self
298                    .insertion_order
299                    .iter()
300                    .filter_map(|key| self.received.get(key).cloned())
301                    .collect();
302                Value::Array(outputs)
303            }
304            MergeStrategy::MergeMap => {
305                let mut merged = serde_json::Map::new();
306                for key in &self.insertion_order {
307                    if let Some(Value::Object(map)) = self.received.get(key) {
308                        for (k, v) in map {
309                            merged.insert(k.clone(), v.clone());
310                        }
311                    }
312                }
313                Value::Object(merged)
314            }
315            MergeStrategy::First => self
316                .insertion_order
317                .first()
318                .and_then(|key| self.received.get(key).cloned())
319                .unwrap_or(Value::Null),
320            MergeStrategy::Custom(f) => {
321                let outputs: Vec<Value> = self
322                    .insertion_order
323                    .iter()
324                    .filter_map(|key| self.received.get(key).cloned())
325                    .collect();
326                f(outputs)
327            }
328        }
329    }
330
331    /// Returns the number of outputs received so far.
332    pub fn received_count(&self) -> usize {
333        self.received.len()
334    }
335
336    /// Returns the number of expected source nodes.
337    pub fn expected_count(&self) -> usize {
338        self.expected.len()
339    }
340
341    /// Returns the names of sources that have not yet reported.
342    pub fn pending_sources(&self) -> Vec<&str> {
343        self.expected
344            .iter()
345            .filter(|s| !self.received.contains_key(*s))
346            .map(|s| s.as_str())
347            .collect()
348    }
349
350    /// Returns the names of sources that have reported.
351    pub fn completed_sources(&self) -> Vec<&str> {
352        self.insertion_order.iter().map(|s| s.as_str()).collect()
353    }
354}
355
356impl fmt::Debug for FanInTracker {
357    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358        f.debug_struct("FanInTracker")
359            .field("expected", &self.expected)
360            .field("received_keys", &self.insertion_order)
361            .field("is_ready", &self.is_ready())
362            .finish()
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_tracker_new_empty_not_ready() {
        let tracker = FanInTracker::new(vec!["a", "b", "c"]);
        assert!(!tracker.is_ready());
        assert_eq!(tracker.expected_count(), 3);
        assert_eq!(tracker.received_count(), 0);
    }

    #[test]
    fn test_tracker_ready_when_all_received() {
        let mut tracker = FanInTracker::new(vec!["a", "b"]);
        tracker.record("a", json!(1));
        assert!(!tracker.is_ready());
        tracker.record("b", json!(2));
        assert!(tracker.is_ready());
    }

    #[test]
    fn test_merge_collect() {
        let mut tracker = FanInTracker::new(vec!["x", "y", "z"]);
        tracker.record("x", json!("first"));
        tracker.record("y", json!("second"));
        tracker.record("z", json!("third"));

        let result = tracker.merge(&MergeStrategy::Collect);
        assert_eq!(result, json!(["first", "second", "third"]));
    }

    #[test]
    fn test_merge_map_combines_objects() {
        let mut tracker = FanInTracker::new(vec!["a", "b"]);
        tracker.record("a", json!({"key1": "val1", "shared": "from_a"}));
        tracker.record("b", json!({"key2": "val2", "shared": "from_b"}));

        let result = tracker.merge(&MergeStrategy::MergeMap);
        assert_eq!(result, json!({"key1": "val1", "key2": "val2", "shared": "from_b"}));
    }

    #[test]
    fn test_merge_map_skips_non_objects() {
        let mut tracker = FanInTracker::new(vec!["a", "b"]);
        tracker.record("a", json!(42)); // Not an object, skipped
        tracker.record("b", json!({"key": "value"}));

        let result = tracker.merge(&MergeStrategy::MergeMap);
        assert_eq!(result, json!({"key": "value"}));
    }

    #[test]
    fn test_merge_first() {
        let mut tracker = FanInTracker::new(vec!["a", "b", "c"]);
        tracker.record("b", json!("first_to_arrive"));
        tracker.record("a", json!("second_to_arrive"));
        tracker.record("c", json!("third_to_arrive"));

        let result = tracker.merge(&MergeStrategy::First);
        assert_eq!(result, json!("first_to_arrive"));
    }

    #[test]
    fn test_merge_first_empty() {
        let tracker = FanInTracker::new(vec!["a"]);
        let result = tracker.merge(&MergeStrategy::First);
        assert_eq!(result, Value::Null);
    }

    #[test]
    fn test_merge_custom() {
        let mut tracker = FanInTracker::new(vec!["a", "b"]);
        tracker.record("a", json!(10));
        tracker.record("b", json!(20));

        let strategy = MergeStrategy::Custom(Arc::new(|outputs| {
            let sum: i64 = outputs.iter().filter_map(|v| v.as_i64()).sum();
            json!(sum)
        }));

        let result = tracker.merge(&strategy);
        assert_eq!(result, json!(30));
    }

    #[test]
    fn test_record_overwrites_previous() {
        let mut tracker = FanInTracker::new(vec!["a"]);
        tracker.record("a", json!("first"));
        tracker.record("a", json!("second"));

        assert!(tracker.is_ready());
        assert_eq!(tracker.received_count(), 1);

        let result = tracker.merge(&MergeStrategy::First);
        assert_eq!(result, json!("second"));
    }

    #[test]
    fn test_pending_and_completed_sources() {
        let mut tracker = FanInTracker::new(vec!["a", "b", "c"]);
        tracker.record("b", json!(1));

        let mut pending = tracker.pending_sources();
        pending.sort();
        assert_eq!(pending, vec!["a", "c"]);
        assert_eq!(tracker.completed_sources(), vec!["b"]);
    }

    #[test]
    fn test_empty_tracker_is_ready_and_merges_to_empty_values() {
        let tracker = FanInTracker::new(vec![]);
        assert!(tracker.is_ready());
        assert_eq!(tracker.expected_count(), 0);
        assert!(tracker.pending_sources().is_empty());
        assert_eq!(tracker.merge(&MergeStrategy::Collect), json!([]));
        assert_eq!(tracker.merge(&MergeStrategy::MergeMap), json!({}));
        assert_eq!(tracker.merge(&MergeStrategy::First), Value::Null);
    }

    #[test]
    fn test_unexpected_source_is_stored_but_not_required() {
        let mut tracker = FanInTracker::new(vec!["a"]);
        tracker.record("stray", json!(1));

        // An unexpected source never satisfies readiness...
        assert!(!tracker.is_ready());
        assert_eq!(tracker.pending_sources(), vec!["a"]);
        // ...but it is still counted and merged.
        assert_eq!(tracker.received_count(), 1);

        tracker.record("a", json!(2));
        assert!(tracker.is_ready());
        assert_eq!(tracker.merge(&MergeStrategy::Collect), json!([1, 2]));
    }

    #[test]
    fn test_rerecord_keeps_original_position() {
        let mut tracker = FanInTracker::new(vec!["a", "b"]);
        tracker.record("a", json!({"k": 1}));
        tracker.record("b", json!({"k": 2}));
        tracker.record("a", json!({"k": 3}));

        assert_eq!(tracker.completed_sources(), vec!["a", "b"]);
        // "a" keeps its first position, so "b" is folded last and wins the conflict.
        assert_eq!(tracker.merge(&MergeStrategy::MergeMap), json!({"k": 2}));
        assert_eq!(
            tracker.merge(&MergeStrategy::Collect),
            json!([{"k": 3}, {"k": 2}])
        );
    }

    #[test]
    fn test_custom_receives_outputs_in_recorded_order() {
        let mut tracker = FanInTracker::new(vec!["a", "b"]);
        tracker.record("b", json!("b"));
        tracker.record("a", json!("a"));

        let strategy = MergeStrategy::Custom(Arc::new(|outputs: Vec<Value>| {
            Value::Array(outputs.into_iter().rev().collect())
        }));
        assert_eq!(tracker.merge(&strategy), json!(["a", "b"]));
    }

    #[test]
    fn test_default_config() {
        let config = DeferredNodeConfig::default();
        assert!(matches!(config.merge_strategy, MergeStrategy::Collect));
        assert!(config.fan_in_timeout.is_none());
    }

    #[test]
    fn test_merge_strategy_debug() {
        assert_eq!(format!("{:?}", MergeStrategy::Collect), "Collect");
        assert_eq!(format!("{:?}", MergeStrategy::MergeMap), "MergeMap");
        assert_eq!(format!("{:?}", MergeStrategy::First), "First");
        let custom = MergeStrategy::Custom(Arc::new(Value::Array));
        assert_eq!(format!("{:?}", custom), "Custom(<fn>)");
    }
}