agent_chain_core/runnables/
branch.rs

1//! Runnable that selects which branch to run based on a condition.
2//!
3//! This module provides `RunnableBranch` which selects and runs one of several
4//! branches based on conditions, mirroring `langchain_core.runnables.branch`.
5
6use std::fmt::Debug;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use futures::StreamExt;
11use futures::stream::BoxStream;
12use serde::Serialize;
13
14use crate::error::{Error, Result};
15use crate::load::{Serializable, Serialized, SerializedConstructorData};
16
17use super::base::{DynRunnable, Runnable, RunnableLambda, RunnableSerializable};
18use super::config::{RunnableConfig, ensure_config, get_callback_manager_for_config, patch_config};
19use super::utils::{ConfigurableFieldSpec, get_unique_config_specs};
20
21/// A `Runnable` that selects which branch to run based on a condition.
22///
23/// The `Runnable` is initialized with a list of `(condition, Runnable)` pairs and
24/// a default branch.
25///
26/// When operating on an input, the first condition that evaluates to `true` is
27/// selected, and the corresponding `Runnable` is run on the input.
28///
29/// If no condition evaluates to `true`, the default branch is run on the input.
30///
31/// # Example
32///
33/// ```ignore
34/// use agent_chain_core::runnables::{RunnableBranch, RunnableLambda};
35/// use std::sync::Arc;
36///
37/// let branch = RunnableBranch::new(
38///     vec![
39///         (
40///             Arc::new(RunnableLambda::new(|x: i32| Ok(x > 0))),
41///             Arc::new(RunnableLambda::new(|x: i32| Ok(format!("positive: {}", x)))),
42///         ),
43///         (
44///             Arc::new(RunnableLambda::new(|x: i32| Ok(x < 0))),
45///             Arc::new(RunnableLambda::new(|x: i32| Ok(format!("negative: {}", x)))),
46///         ),
47///     ],
48///     Arc::new(RunnableLambda::new(|_: i32| Ok("zero".to_string()))),
49/// ).unwrap();
50///
51/// let result = branch.invoke(5, None).unwrap();
52/// assert_eq!(result, "positive: 5");
53/// ```
54pub struct RunnableBranch<I, O>
55where
56    I: Send + Sync + Clone + Debug + 'static,
57    O: Send + Sync + Clone + Debug + 'static,
58{
59    /// A list of `(condition, Runnable)` pairs.
60    branches: Vec<(DynRunnable<I, bool>, DynRunnable<I, O>)>,
61    /// A `Runnable` to run if no condition is met.
62    default: DynRunnable<I, O>,
63    /// Optional name for this branch.
64    name: Option<String>,
65}
66
67impl<I, O> Debug for RunnableBranch<I, O>
68where
69    I: Send + Sync + Clone + Debug + 'static,
70    O: Send + Sync + Clone + Debug + 'static,
71{
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("RunnableBranch")
74            .field("branches_count", &self.branches.len())
75            .field("name", &self.name)
76            .finish()
77    }
78}
79
80impl<I, O> RunnableBranch<I, O>
81where
82    I: Send + Sync + Clone + Debug + 'static,
83    O: Send + Sync + Clone + Debug + 'static,
84{
85    /// Create a new RunnableBranch.
86    ///
87    /// # Arguments
88    ///
89    /// * `branches` - A list of `(condition, runnable)` pairs. The condition is a
90    ///   Runnable that returns a boolean, and the runnable is executed if the
91    ///   condition returns true.
92    /// * `default` - The runnable to execute if no condition returns true.
93    ///
94    /// # Returns
95    ///
96    /// A Result containing the RunnableBranch, or an error if fewer than one
97    /// condition branch is provided.
98    ///
99    /// # Errors
100    ///
101    /// Returns an error if the number of branches is less than 1 (meaning at
102    /// least one condition branch plus the default is required).
103    pub fn new(
104        branches: Vec<(DynRunnable<I, bool>, DynRunnable<I, O>)>,
105        default: DynRunnable<I, O>,
106    ) -> Result<Self> {
107        if branches.is_empty() {
108            return Err(Error::Other(
109                "RunnableBranch requires at least one condition branch".to_string(),
110            ));
111        }
112
113        Ok(Self {
114            branches,
115            default,
116            name: None,
117        })
118    }
119
120    /// Set the name of this branch.
121    pub fn with_name(mut self, name: impl Into<String>) -> Self {
122        self.name = Some(name.into());
123        self
124    }
125
126    /// Get the configurable field specs from all contained runnables.
127    pub fn config_specs(&self) -> std::result::Result<Vec<ConfigurableFieldSpec>, String> {
128        let specs = self
129            .branches
130            .iter()
131            .flat_map(|(_condition, _runnable)| Vec::<ConfigurableFieldSpec>::new())
132            .collect::<Vec<_>>();
133
134        get_unique_config_specs(specs)
135    }
136}
137
138/// Builder for creating RunnableBranch with closures.
139pub struct RunnableBranchBuilder<I, O>
140where
141    I: Send + Sync + Clone + Debug + 'static,
142    O: Send + Sync + Clone + Debug + 'static,
143{
144    branches: Vec<(DynRunnable<I, bool>, DynRunnable<I, O>)>,
145    _phantom: std::marker::PhantomData<(I, O)>,
146}
147
148impl<I, O> RunnableBranchBuilder<I, O>
149where
150    I: Send + Sync + Clone + Debug + 'static,
151    O: Send + Sync + Clone + Debug + 'static,
152{
153    /// Create a new builder.
154    pub fn new() -> Self {
155        Self {
156            branches: Vec::new(),
157            _phantom: std::marker::PhantomData,
158        }
159    }
160
161    /// Add a branch with closures.
162    pub fn branch<CF, RF>(mut self, condition: CF, runnable: RF) -> Self
163    where
164        CF: Fn(I) -> Result<bool> + Send + Sync + 'static,
165        RF: Fn(I) -> Result<O> + Send + Sync + 'static,
166    {
167        let condition_runnable: DynRunnable<I, bool> = Arc::new(RunnableLambda::new(condition));
168        let branch_runnable: DynRunnable<I, O> = Arc::new(RunnableLambda::new(runnable));
169        self.branches.push((condition_runnable, branch_runnable));
170        self
171    }
172
173    /// Add a branch with Arc runnables.
174    pub fn branch_arc(
175        mut self,
176        condition: DynRunnable<I, bool>,
177        runnable: DynRunnable<I, O>,
178    ) -> Self {
179        self.branches.push((condition, runnable));
180        self
181    }
182
183    /// Build the RunnableBranch with a default closure.
184    pub fn default<DF>(self, default_fn: DF) -> Result<RunnableBranch<I, O>>
185    where
186        DF: Fn(I) -> Result<O> + Send + Sync + 'static,
187    {
188        let default_runnable: DynRunnable<I, O> = Arc::new(RunnableLambda::new(default_fn));
189        RunnableBranch::new(self.branches, default_runnable)
190    }
191
192    /// Build the RunnableBranch with a default Arc runnable.
193    pub fn default_arc(self, default: DynRunnable<I, O>) -> Result<RunnableBranch<I, O>> {
194        RunnableBranch::new(self.branches, default)
195    }
196}
197
198impl<I, O> Default for RunnableBranchBuilder<I, O>
199where
200    I: Send + Sync + Clone + Debug + 'static,
201    O: Send + Sync + Clone + Debug + 'static,
202{
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208#[async_trait]
209impl<I, O> Runnable for RunnableBranch<I, O>
210where
211    I: Send + Sync + Clone + Debug + 'static,
212    O: Send + Sync + Clone + Debug + 'static,
213{
214    type Input = I;
215    type Output = O;
216
217    fn name(&self) -> Option<String> {
218        self.name
219            .clone()
220            .or_else(|| Some("RunnableBranch".to_string()))
221    }
222
223    fn invoke(&self, input: Self::Input, config: Option<RunnableConfig>) -> Result<Self::Output> {
224        let config = ensure_config(config);
225        let callback_manager = get_callback_manager_for_config(&config);
226        let run_manager = callback_manager.on_chain_start(
227            &std::collections::HashMap::new(),
228            &std::collections::HashMap::new(),
229            config.run_id,
230        );
231
232        let result = (|| {
233            for (idx, (condition, runnable)) in self.branches.iter().enumerate() {
234                let condition_config = patch_config(
235                    Some(config.clone()),
236                    Some(run_manager.get_child(Some(&format!("condition:{}", idx + 1)))),
237                    None,
238                    None,
239                    None,
240                    None,
241                );
242
243                let expression_value = condition.invoke(input.clone(), Some(condition_config))?;
244
245                if expression_value {
246                    let branch_config = patch_config(
247                        Some(config.clone()),
248                        Some(run_manager.get_child(Some(&format!("branch:{}", idx + 1)))),
249                        None,
250                        None,
251                        None,
252                        None,
253                    );
254
255                    return runnable.invoke(input.clone(), Some(branch_config));
256                }
257            }
258
259            let default_config = patch_config(
260                Some(config.clone()),
261                Some(run_manager.get_child(Some("branch:default"))),
262                None,
263                None,
264                None,
265                None,
266            );
267
268            self.default.invoke(input, Some(default_config))
269        })();
270
271        match &result {
272            Ok(_) => {
273                run_manager.on_chain_end(&std::collections::HashMap::new());
274            }
275            Err(e) => {
276                run_manager.on_chain_error(e);
277            }
278        }
279
280        result
281    }
282
283    async fn ainvoke(
284        &self,
285        input: Self::Input,
286        config: Option<RunnableConfig>,
287    ) -> Result<Self::Output>
288    where
289        Self: 'static,
290    {
291        let config = ensure_config(config);
292
293        for (condition, runnable) in self.branches.iter() {
294            let expression_value = condition
295                .ainvoke(input.clone(), Some(config.clone()))
296                .await?;
297
298            if expression_value {
299                return runnable.ainvoke(input.clone(), Some(config.clone())).await;
300            }
301        }
302
303        self.default.ainvoke(input, Some(config)).await
304    }
305
306    fn stream(
307        &self,
308        input: Self::Input,
309        config: Option<RunnableConfig>,
310    ) -> BoxStream<'_, Result<Self::Output>> {
311        let config = ensure_config(config);
312
313        Box::pin(async_stream::stream! {
314            'outer: {
315                for (condition, runnable) in self.branches.iter() {
316                    let expression_value = match condition.invoke(input.clone(), Some(config.clone())) {
317                        Ok(v) => v,
318                        Err(e) => {
319                            yield Err(e);
320                            break 'outer;
321                        }
322                    };
323
324                    if expression_value {
325                        let mut stream = runnable.stream(input.clone(), Some(config.clone()));
326                        while let Some(chunk_result) = stream.next().await {
327                            yield chunk_result;
328                        }
329                        break 'outer;
330                    }
331                }
332
333                let mut stream = self.default.stream(input, Some(config.clone()));
334                while let Some(chunk_result) = stream.next().await {
335                    yield chunk_result;
336                }
337            }
338        })
339    }
340
341    fn astream(
342        &self,
343        input: Self::Input,
344        config: Option<RunnableConfig>,
345    ) -> BoxStream<'_, Result<Self::Output>>
346    where
347        Self: 'static,
348    {
349        let config = ensure_config(config);
350
351        Box::pin(async_stream::stream! {
352            'outer: {
353                for (condition, runnable) in self.branches.iter() {
354                    let expression_value = match condition.ainvoke(input.clone(), Some(config.clone())).await {
355                        Ok(v) => v,
356                        Err(e) => {
357                            yield Err(e);
358                            break 'outer;
359                        }
360                    };
361
362                    if expression_value {
363                        let mut stream = runnable.astream(input.clone(), Some(config.clone()));
364                        while let Some(chunk_result) = stream.next().await {
365                            yield chunk_result;
366                        }
367                        break 'outer;
368                    }
369                }
370
371                let mut stream = self.default.astream(input, Some(config.clone()));
372                while let Some(chunk_result) = stream.next().await {
373                    yield chunk_result;
374                }
375            }
376        })
377    }
378}
379
380impl<I, O> Serializable for RunnableBranch<I, O>
381where
382    I: Send + Sync + Clone + Debug + Serialize + 'static,
383    O: Send + Sync + Clone + Debug + 'static,
384{
385    fn is_lc_serializable() -> bool {
386        true
387    }
388
389    fn get_lc_namespace() -> Vec<String> {
390        vec![
391            "langchain".to_string(),
392            "schema".to_string(),
393            "runnable".to_string(),
394        ]
395    }
396
397    fn to_json(&self) -> Serialized {
398        let kwargs = std::collections::HashMap::new();
399
400        Serialized::Constructor(SerializedConstructorData {
401            lc: 1,
402            id: Self::get_lc_namespace(),
403            kwargs,
404            name: None,
405            graph: None,
406        })
407    }
408}
409
410impl<I, O> RunnableSerializable for RunnableBranch<I, O>
411where
412    I: Send + Sync + Clone + Debug + Serialize + 'static,
413    O: Send + Sync + Clone + Debug + Serialize + 'static,
414{
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// Branch labelling positive numbers, negative numbers and zero.
    fn sign_branch() -> RunnableBranch<i32, String> {
        RunnableBranchBuilder::new()
            .branch(|x: i32| Ok(x > 0), |x: i32| Ok(format!("positive: {}", x)))
            .branch(|x: i32| Ok(x < 0), |x: i32| Ok(format!("negative: {}", x)))
            .default(|_: i32| Ok("zero".to_string()))
            .unwrap()
    }

    /// Branch whose only condition fails with an error for an input of zero.
    fn failing_condition_branch() -> RunnableBranch<i32, String> {
        RunnableBranchBuilder::new()
            .branch(
                |x: i32| {
                    if x == 0 {
                        Err(Error::Other("condition failed".to_string()))
                    } else {
                        Ok(x > 0)
                    }
                },
                |x: i32| Ok(format!("positive: {}", x)),
            )
            .default(|_: i32| Ok("negative".to_string()))
            .unwrap()
    }

    #[test]
    fn test_runnable_branch_invoke_first_condition() {
        assert_eq!(sign_branch().invoke(5, None).unwrap(), "positive: 5");
    }

    #[test]
    fn test_runnable_branch_invoke_second_condition() {
        assert_eq!(sign_branch().invoke(-3, None).unwrap(), "negative: -3");
    }

    #[test]
    fn test_runnable_branch_invoke_default() {
        assert_eq!(sign_branch().invoke(0, None).unwrap(), "zero");
    }

    #[test]
    fn test_runnable_branch_first_matching_condition_wins() {
        let branch = RunnableBranchBuilder::new()
            .branch(|x: i32| Ok(x > 0), |_: i32| Ok("first".to_string()))
            .branch(|x: i32| Ok(x > 10), |_: i32| Ok("second".to_string()))
            .default(|_: i32| Ok("default".to_string()))
            .unwrap();

        assert_eq!(branch.invoke(20, None).unwrap(), "first");
    }

    #[test]
    fn test_runnable_branch_invoke_propagates_condition_error() {
        let err = failing_condition_branch().invoke(0, None).unwrap_err();
        assert!(err.to_string().contains("condition failed"));
    }

    #[test]
    fn test_runnable_branch_requires_at_least_one_branch() {
        let result: Result<RunnableBranch<i32, String>> =
            RunnableBranchBuilder::new().default(|_: i32| Ok("default".to_string()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("at least one condition branch")
        );
    }

    #[test]
    fn test_runnable_branch_name() {
        let branch = RunnableBranchBuilder::new()
            .branch(|x: i32| Ok(x > 0), |x: i32| Ok(x.to_string()))
            .default(|_: i32| Ok("default".to_string()))
            .unwrap()
            .with_name("my_branch");

        assert_eq!(branch.name(), Some("my_branch".to_string()));
    }

    #[test]
    fn test_runnable_branch_default_name() {
        let branch = RunnableBranchBuilder::new()
            .branch(|x: i32| Ok(x > 0), |x: i32| Ok(x.to_string()))
            .default(|_: i32| Ok("default".to_string()))
            .unwrap();

        assert_eq!(branch.name(), Some("RunnableBranch".to_string()));
    }

    #[test]
    fn test_runnable_branch_debug_reports_branch_count_and_name() {
        let rendered = format!("{:?}", sign_branch().with_name("signs"));
        assert!(rendered.contains("branches_count: 2"));
        assert!(rendered.contains("signs"));
    }

    #[test]
    fn test_runnable_branch_with_arc_runnables() {
        let condition: DynRunnable<i32, bool> = Arc::new(RunnableLambda::new(|x: i32| Ok(x > 10)));
        let branch_runnable: DynRunnable<i32, String> =
            Arc::new(RunnableLambda::new(|x: i32| Ok(format!("big: {}", x))));
        let default: DynRunnable<i32, String> =
            Arc::new(RunnableLambda::new(|x: i32| Ok(format!("small: {}", x))));

        let branch = RunnableBranch::new(vec![(condition, branch_runnable)], default).unwrap();

        assert_eq!(branch.invoke(15, None).unwrap(), "big: 15");
        assert_eq!(branch.invoke(5, None).unwrap(), "small: 5");
    }

    #[test]
    fn test_runnable_branch_builder_arc_methods() {
        let condition: DynRunnable<i32, bool> = Arc::new(RunnableLambda::new(|x: i32| Ok(x % 2 == 0)));
        let even: DynRunnable<i32, String> =
            Arc::new(RunnableLambda::new(|x: i32| Ok(format!("even: {}", x))));
        let odd: DynRunnable<i32, String> =
            Arc::new(RunnableLambda::new(|x: i32| Ok(format!("odd: {}", x))));

        let branch = RunnableBranchBuilder::new()
            .branch_arc(condition, even)
            .default_arc(odd)
            .unwrap();

        assert_eq!(branch.invoke(4, None).unwrap(), "even: 4");
        assert_eq!(branch.invoke(7, None).unwrap(), "odd: 7");
    }

    #[tokio::test]
    async fn test_runnable_branch_ainvoke() {
        let branch = sign_branch();

        assert_eq!(branch.ainvoke(5, None).await.unwrap(), "positive: 5");
        assert_eq!(branch.ainvoke(-3, None).await.unwrap(), "negative: -3");
        assert_eq!(branch.ainvoke(0, None).await.unwrap(), "zero");
    }

    #[tokio::test]
    async fn test_runnable_branch_ainvoke_propagates_condition_error() {
        let branch = failing_condition_branch();
        let err = branch.ainvoke(0, None).await.unwrap_err();
        assert!(err.to_string().contains("condition failed"));
    }

    #[tokio::test]
    async fn test_runnable_branch_stream() {
        let branch = sign_branch();

        let mut stream = branch.stream(5, None);
        assert_eq!(stream.next().await.unwrap().unwrap(), "positive: 5");
    }

    #[tokio::test]
    async fn test_runnable_branch_stream_default() {
        let branch = sign_branch();

        let mut stream = branch.stream(0, None);
        assert_eq!(stream.next().await.unwrap().unwrap(), "zero");
    }

    #[tokio::test]
    async fn test_runnable_branch_stream_condition_error_ends_stream() {
        let branch = failing_condition_branch();

        let chunks: Vec<Result<String>> = branch.stream(0, None).collect().await;
        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].is_err());
    }

    #[tokio::test]
    async fn test_runnable_branch_astream() {
        let branch = sign_branch();

        let mut stream = branch.astream(5, None);
        assert_eq!(stream.next().await.unwrap().unwrap(), "positive: 5");
    }

    #[tokio::test]
    async fn test_runnable_branch_astream_default() {
        let branch = sign_branch();

        let mut stream = branch.astream(0, None);
        assert_eq!(stream.next().await.unwrap().unwrap(), "zero");
    }

    #[tokio::test]
    async fn test_runnable_branch_astream_condition_error_ends_stream() {
        let branch = failing_condition_branch();

        let chunks: Vec<Result<String>> = branch.astream(0, None).collect().await;
        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].is_err());
    }
}