agent_chain_core/runnables/
fallbacks.rs

1//! Runnable that can fallback to other Runnables if it fails.
2//!
3//! This module provides `RunnableWithFallbacks`, a Runnable that tries a primary
4//! runnable first and falls back to alternative runnables if the primary fails.
5//! This mirrors `langchain_core.runnables.fallbacks`.
6
7use std::fmt::Debug;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::StreamExt;
12use futures::stream::BoxStream;
13
14use crate::error::{Error, Result};
15
16use super::base::{DynRunnable, Runnable};
17use super::config::{
18    ConfigOrList, RunnableConfig, ensure_config, get_callback_manager_for_config, get_config_list,
19    patch_config,
20};
21use super::utils::{ConfigurableFieldSpec, get_unique_config_specs};
22
23/// A `Runnable` that can fallback to other `Runnable`s if it fails.
24///
25/// External APIs (e.g., APIs for a language model) may at times experience
26/// degraded performance or even downtime.
27///
28/// In these cases, it can be useful to have a fallback `Runnable` that can be
29/// used in place of the original `Runnable` (e.g., fallback to another LLM provider).
30///
31/// Fallbacks can be defined at the level of a single `Runnable`, or at the level
32/// of a chain of `Runnable`s. Fallbacks are tried in order until one succeeds or
33/// all fail.
34///
35/// While you can instantiate a `RunnableWithFallbacks` directly, it is usually
36/// more convenient to use the `with_fallbacks` method on a `Runnable`.
37///
38/// # Example
39///
40/// ```ignore
41/// use agent_chain_core::runnables::{RunnableLambda, RunnableWithFallbacks};
42///
43/// // Create a primary runnable that might fail
44/// let primary = RunnableLambda::new(|x: i32| {
45///     if x > 5 { Err(Error::other("too large")) }
46///     else { Ok(x * 2) }
47/// });
48///
49/// // Create a fallback runnable
50/// let fallback = RunnableLambda::new(|x: i32| Ok(x));
51///
52/// // Combine them with fallbacks
53/// let with_fallbacks = RunnableWithFallbacks::new(primary, vec![fallback]);
54///
55/// // Will use primary for x <= 5, fallback for x > 5
56/// assert_eq!(with_fallbacks.invoke(3, None).unwrap(), 6);
57/// assert_eq!(with_fallbacks.invoke(10, None).unwrap(), 10);
58/// ```
59pub struct RunnableWithFallbacks<I, O>
60where
61    I: Send + Sync + Clone + Debug + 'static,
62    O: Send + Sync + Clone + Debug + 'static,
63{
64    /// The `Runnable` to run first.
65    pub runnable: DynRunnable<I, O>,
66    /// A sequence of fallbacks to try.
67    pub fallbacks: Vec<DynRunnable<I, O>>,
68    /// Whether to handle all errors (true) or specific error types.
69    /// In Rust, we simplify the Python's `exceptions_to_handle` to a boolean flag
70    /// since we don't have the same exception hierarchy.
71    pub handle_all_errors: bool,
72    /// Optional name for this runnable.
73    name: Option<String>,
74}
75
76impl<I, O> Debug for RunnableWithFallbacks<I, O>
77where
78    I: Send + Sync + Clone + Debug + 'static,
79    O: Send + Sync + Clone + Debug + 'static,
80{
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        f.debug_struct("RunnableWithFallbacks")
83            .field("runnable", &"<runnable>")
84            .field("fallbacks_count", &self.fallbacks.len())
85            .field("handle_all_errors", &self.handle_all_errors)
86            .field("name", &self.name)
87            .finish()
88    }
89}
90
91impl<I, O> RunnableWithFallbacks<I, O>
92where
93    I: Send + Sync + Clone + Debug + 'static,
94    O: Send + Sync + Clone + Debug + 'static,
95{
96    /// Create a new RunnableWithFallbacks.
97    ///
98    /// # Arguments
99    /// * `runnable` - The primary runnable to try first
100    /// * `fallbacks` - A list of fallback runnables to try if the primary fails
101    pub fn new<R>(runnable: R, fallbacks: Vec<DynRunnable<I, O>>) -> Self
102    where
103        R: Runnable<Input = I, Output = O> + Send + Sync + 'static,
104    {
105        Self {
106            runnable: Arc::new(runnable),
107            fallbacks,
108            handle_all_errors: true,
109            name: None,
110        }
111    }
112
113    /// Create a new RunnableWithFallbacks from a DynRunnable.
114    pub fn from_dyn(runnable: DynRunnable<I, O>, fallbacks: Vec<DynRunnable<I, O>>) -> Self {
115        Self {
116            runnable,
117            fallbacks,
118            handle_all_errors: true,
119            name: None,
120        }
121    }
122
123    /// Set whether to handle all errors.
124    ///
125    /// If true, any error will trigger fallback.
126    /// If false, only certain errors will trigger fallback (default: true).
127    pub fn with_handle_all_errors(mut self, handle_all_errors: bool) -> Self {
128        self.handle_all_errors = handle_all_errors;
129        self
130    }
131
132    /// Set the name of this runnable.
133    pub fn with_name(mut self, name: impl Into<String>) -> Self {
134        self.name = Some(name.into());
135        self
136    }
137
138    /// Get an iterator over all runnables (primary + fallbacks).
139    pub fn runnables(&self) -> impl Iterator<Item = &DynRunnable<I, O>> {
140        std::iter::once(&self.runnable).chain(self.fallbacks.iter())
141    }
142
143    /// Get the config specs from all runnables.
144    pub fn config_specs(&self) -> Result<Vec<ConfigurableFieldSpec>> {
145        let specs: Vec<ConfigurableFieldSpec> = self
146            .runnables()
147            .flat_map(|_r| {
148                // In a full implementation, we would get config specs from each runnable
149                // For now, return empty as the trait doesn't expose config_specs
150                Vec::<ConfigurableFieldSpec>::new()
151            })
152            .collect();
153
154        get_unique_config_specs(specs).map_err(Error::other)
155    }
156
157    /// Check if an error should trigger a fallback.
158    fn should_fallback(&self, _error: &Error) -> bool {
159        // In the Python version, this checks if the error is an instance of
160        // exceptions_to_handle. In Rust, we simplify to a boolean flag.
161        self.handle_all_errors
162    }
163}
164
165#[async_trait]
166impl<I, O> Runnable for RunnableWithFallbacks<I, O>
167where
168    I: Send + Sync + Clone + Debug + 'static,
169    O: Send + Sync + Clone + Debug + 'static,
170{
171    type Input = I;
172    type Output = O;
173
174    fn name(&self) -> Option<String> {
175        self.name.clone()
176    }
177
178    fn invoke(&self, input: Self::Input, config: Option<RunnableConfig>) -> Result<Self::Output> {
179        let config = ensure_config(config);
180        let callback_manager = get_callback_manager_for_config(&config);
181
182        // Start the root run
183        let run_manager = callback_manager.on_chain_start(
184            &std::collections::HashMap::new(),
185            &std::collections::HashMap::new(),
186            config.run_id,
187        );
188
189        let mut first_error: Option<Error> = None;
190
191        for runnable in self.runnables() {
192            let child_config = patch_config(
193                Some(config.clone()),
194                Some(run_manager.get_child(None)),
195                None,
196                None,
197                None,
198                None,
199            );
200
201            match runnable.invoke(input.clone(), Some(child_config)) {
202                Ok(output) => {
203                    run_manager.on_chain_end(&std::collections::HashMap::new());
204                    return Ok(output);
205                }
206                Err(e) => {
207                    if self.should_fallback(&e) {
208                        if first_error.is_none() {
209                            first_error = Some(e);
210                        }
211                    } else {
212                        run_manager.on_chain_error(&e);
213                        return Err(e);
214                    }
215                }
216            }
217        }
218
219        let error =
220            first_error.unwrap_or_else(|| Error::other("No error stored at end of fallbacks."));
221        run_manager.on_chain_error(&error);
222        Err(error)
223    }
224
225    async fn ainvoke(
226        &self,
227        input: Self::Input,
228        config: Option<RunnableConfig>,
229    ) -> Result<Self::Output>
230    where
231        Self: 'static,
232    {
233        let config = ensure_config(config);
234
235        let mut first_error: Option<Error> = None;
236
237        for runnable in self.runnables() {
238            match runnable.ainvoke(input.clone(), Some(config.clone())).await {
239                Ok(output) => {
240                    return Ok(output);
241                }
242                Err(e) => {
243                    if self.should_fallback(&e) {
244                        if first_error.is_none() {
245                            first_error = Some(e);
246                        }
247                    } else {
248                        return Err(e);
249                    }
250                }
251            }
252        }
253
254        Err(first_error.unwrap_or_else(|| Error::other("No error stored at end of fallbacks.")))
255    }
256
257    fn batch(
258        &self,
259        inputs: Vec<Self::Input>,
260        config: Option<ConfigOrList>,
261        return_exceptions: bool,
262    ) -> Vec<Result<Self::Output>>
263    where
264        Self: 'static,
265    {
266        if inputs.is_empty() {
267            return Vec::new();
268        }
269
270        let configs = get_config_list(config, inputs.len());
271        let n = inputs.len();
272
273        // Track which inputs still need to be processed
274        let mut to_return: Vec<Option<Result<Self::Output>>> = (0..n).map(|_| None).collect();
275        let mut run_again: Vec<(usize, Self::Input)> = inputs.into_iter().enumerate().collect();
276        let mut handled_exception_indices: Vec<usize> = Vec::new();
277        let mut first_to_raise: Option<Error> = None;
278
279        for runnable in self.runnables() {
280            if run_again.is_empty() {
281                break;
282            }
283
284            // Get inputs and configs for items that need to be run again
285            let batch_inputs: Vec<Self::Input> =
286                run_again.iter().map(|(_, inp)| inp.clone()).collect();
287            let batch_configs: Vec<RunnableConfig> =
288                run_again.iter().map(|(i, _)| configs[*i].clone()).collect();
289
290            let outputs = runnable.batch(
291                batch_inputs,
292                Some(ConfigOrList::List(batch_configs)),
293                true, // Always return exceptions to handle them ourselves
294            );
295
296            let mut next_run_again = Vec::new();
297
298            for ((i, input), output) in run_again.iter().zip(outputs) {
299                match output {
300                    Ok(out) => {
301                        to_return[*i] = Some(Ok(out));
302                        handled_exception_indices.retain(|&idx| idx != *i);
303                    }
304                    Err(e) => {
305                        if self.should_fallback(&e) {
306                            if !handled_exception_indices.contains(i) {
307                                handled_exception_indices.push(*i);
308                            }
309                            // Store the error for this index
310                            to_return[*i] = Some(Err(e));
311                            next_run_again.push((*i, input.clone()));
312                        } else if return_exceptions {
313                            to_return[*i] = Some(Err(e));
314                        } else if first_to_raise.is_none() {
315                            first_to_raise = Some(e);
316                        }
317                    }
318                }
319            }
320
321            if first_to_raise.is_some() {
322                // Return early with the first non-fallback error
323                let mut results = Vec::with_capacity(to_return.len());
324                let mut error_consumed = false;
325                for opt in to_return {
326                    match opt {
327                        Some(result) => results.push(result),
328                        None => {
329                            if !error_consumed {
330                                results.push(Err(first_to_raise.take().unwrap()));
331                                error_consumed = true;
332                            } else {
333                                results.push(Err(Error::other("Batch aborted due to error")));
334                            }
335                        }
336                    }
337                }
338                return results;
339            }
340
341            run_again = next_run_again;
342        }
343
344        // All fallbacks exhausted - errors are already stored in to_return
345        if !return_exceptions && !handled_exception_indices.is_empty() {
346            // Return all results as-is, errors from the last fallback attempt are stored
347        }
348
349        // Return results, filling in errors for items that never had any result
350        to_return
351            .into_iter()
352            .map(|opt| opt.unwrap_or_else(|| Err(Error::other("No result for index"))))
353            .collect()
354    }
355
356    async fn abatch(
357        &self,
358        inputs: Vec<Self::Input>,
359        config: Option<ConfigOrList>,
360        return_exceptions: bool,
361    ) -> Vec<Result<Self::Output>>
362    where
363        Self: 'static,
364    {
365        if inputs.is_empty() {
366            return Vec::new();
367        }
368
369        let configs = get_config_list(config, inputs.len());
370        let n = inputs.len();
371
372        // Track which inputs still need to be processed
373        let mut to_return: Vec<Option<Result<Self::Output>>> = (0..n).map(|_| None).collect();
374        let mut run_again: Vec<(usize, Self::Input)> = inputs.into_iter().enumerate().collect();
375        let mut handled_exception_indices: Vec<usize> = Vec::new();
376        let mut first_to_raise: Option<Error> = None;
377
378        for runnable in self.runnables() {
379            if run_again.is_empty() {
380                break;
381            }
382
383            // Get inputs and configs for items that need to be run again
384            let batch_inputs: Vec<Self::Input> =
385                run_again.iter().map(|(_, inp)| inp.clone()).collect();
386            let batch_configs: Vec<RunnableConfig> =
387                run_again.iter().map(|(i, _)| configs[*i].clone()).collect();
388
389            let outputs = runnable
390                .abatch(
391                    batch_inputs,
392                    Some(ConfigOrList::List(batch_configs)),
393                    true, // Always return exceptions to handle them ourselves
394                )
395                .await;
396
397            let mut next_run_again = Vec::new();
398
399            for ((i, input), output) in run_again.iter().zip(outputs) {
400                match output {
401                    Ok(out) => {
402                        to_return[*i] = Some(Ok(out));
403                        handled_exception_indices.retain(|&idx| idx != *i);
404                    }
405                    Err(e) => {
406                        if self.should_fallback(&e) {
407                            if !handled_exception_indices.contains(i) {
408                                handled_exception_indices.push(*i);
409                            }
410                            // Store the error for this index
411                            to_return[*i] = Some(Err(e));
412                            next_run_again.push((*i, input.clone()));
413                        } else if return_exceptions {
414                            to_return[*i] = Some(Err(e));
415                        } else if first_to_raise.is_none() {
416                            first_to_raise = Some(e);
417                        }
418                    }
419                }
420            }
421
422            if first_to_raise.is_some() {
423                // Return early with the first non-fallback error
424                let mut results = Vec::with_capacity(to_return.len());
425                let mut error_consumed = false;
426                for opt in to_return {
427                    match opt {
428                        Some(result) => results.push(result),
429                        None => {
430                            if !error_consumed {
431                                results.push(Err(first_to_raise.take().unwrap()));
432                                error_consumed = true;
433                            } else {
434                                results.push(Err(Error::other("Batch aborted due to error")));
435                            }
436                        }
437                    }
438                }
439                return results;
440            }
441
442            run_again = next_run_again;
443        }
444
445        // All fallbacks exhausted - errors are already stored in to_return
446        if !return_exceptions && !handled_exception_indices.is_empty() {
447            // Return all results as-is, errors from the last fallback attempt are stored
448        }
449
450        // Return results, filling in errors for items that never had any result
451        to_return
452            .into_iter()
453            .map(|opt| opt.unwrap_or_else(|| Err(Error::other("No result for index"))))
454            .collect()
455    }
456
457    fn stream(
458        &self,
459        input: Self::Input,
460        config: Option<RunnableConfig>,
461    ) -> BoxStream<'_, Result<Self::Output>> {
462        let config = ensure_config(config);
463
464        Box::pin(async_stream::stream! {
465            let mut first_error: Option<Error> = None;
466
467            for runnable in self.runnables() {
468                // Try to get the first chunk from this runnable's stream
469                let mut stream = runnable.stream(input.clone(), Some(config.clone()));
470
471                match stream.next().await {
472                    Some(Ok(chunk)) => {
473                        // Success! Yield this chunk and continue streaming
474                        yield Ok(chunk);
475
476                        // Stream remaining chunks
477                        while let Some(result) = stream.next().await {
478                            yield result;
479                        }
480                        return;
481                    }
482                    Some(Err(e)) => {
483                        if self.should_fallback(&e) {
484                            if first_error.is_none() {
485                                first_error = Some(e);
486                            }
487                            // Try next fallback
488                        } else {
489                            yield Err(e);
490                            return;
491                        }
492                    }
493                    None => {
494                        // Empty stream, try next fallback
495                        if first_error.is_none() {
496                            first_error = Some(Error::other("Empty stream from runnable"));
497                        }
498                    }
499                }
500            }
501
502            // All fallbacks exhausted
503            yield Err(first_error.unwrap_or_else(|| Error::other("No error stored at end of fallbacks.")));
504        })
505    }
506
507    fn astream(
508        &self,
509        input: Self::Input,
510        config: Option<RunnableConfig>,
511    ) -> BoxStream<'_, Result<Self::Output>>
512    where
513        Self: 'static,
514    {
515        let config = ensure_config(config);
516
517        Box::pin(async_stream::stream! {
518            let mut first_error: Option<Error> = None;
519
520            for runnable in self.runnables() {
521                // Try to get the first chunk from this runnable's stream
522                let mut stream = runnable.astream(input.clone(), Some(config.clone()));
523
524                match stream.next().await {
525                    Some(Ok(chunk)) => {
526                        // Success! Yield this chunk and continue streaming
527                        yield Ok(chunk);
528
529                        // Stream remaining chunks
530                        while let Some(result) = stream.next().await {
531                            yield result;
532                        }
533                        return;
534                    }
535                    Some(Err(e)) => {
536                        if self.should_fallback(&e) {
537                            if first_error.is_none() {
538                                first_error = Some(e);
539                            }
540                            // Try next fallback
541                        } else {
542                            yield Err(e);
543                            return;
544                        }
545                    }
546                    None => {
547                        // Empty stream, try next fallback
548                        if first_error.is_none() {
549                            first_error = Some(Error::other("Empty stream from runnable"));
550                        }
551                    }
552                }
553            }
554
555            // All fallbacks exhausted
556            yield Err(first_error.unwrap_or_else(|| Error::other("No error stored at end of fallbacks.")));
557        })
558    }
559}
560
561/// Extension trait to add `with_fallbacks` method to any Runnable.
562pub trait RunnableWithFallbacksExt: Runnable {
563    /// Create a new Runnable that tries this runnable first, then falls back to others.
564    ///
565    /// # Arguments
566    /// * `fallbacks` - A list of fallback runnables to try if this one fails
567    ///
568    /// # Returns
569    /// A new `RunnableWithFallbacks` instance
570    fn with_fallbacks(
571        self,
572        fallbacks: Vec<DynRunnable<Self::Input, Self::Output>>,
573    ) -> RunnableWithFallbacks<Self::Input, Self::Output>
574    where
575        Self: Sized + Send + Sync + 'static,
576    {
577        RunnableWithFallbacks::new(self, fallbacks)
578    }
579}
580
581// Implement the extension trait for all Runnables
582impl<R: Runnable> RunnableWithFallbacksExt for R {}
583
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runnables::base::RunnableLambda;

    /// Primary that fails for inputs above 5, used by the batch tests.
    fn fails_above_five() -> RunnableLambda<i32, i32> {
        RunnableLambda::new(|x: i32| -> Result<i32> {
            if x > 5 {
                Err(Error::other("too large"))
            } else {
                Ok(x + 1)
            }
        })
    }

    #[test]
    fn test_fallback_on_error() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let result = with_fallbacks.invoke(5, None).unwrap();
        assert_eq!(result, 10);
    }

    #[test]
    fn test_primary_succeeds() {
        let primary = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x + 1) });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let result = with_fallbacks.invoke(5, None).unwrap();
        assert_eq!(result, 6); // Primary succeeded, not fallback
    }

    #[test]
    fn test_all_fail() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("fallback failed")) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let result = with_fallbacks.invoke(5, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_multiple_fallbacks() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback1 =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("fallback1 failed")) });

        let fallback2 = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 3) });

        let with_fallbacks =
            RunnableWithFallbacks::new(primary, vec![Arc::new(fallback1), Arc::new(fallback2)]);

        let result = with_fallbacks.invoke(5, None).unwrap();
        assert_eq!(result, 15); // Second fallback succeeded
    }

    #[test]
    fn test_handle_all_errors_false_skips_fallbacks() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });
        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)])
            .with_handle_all_errors(false);

        // The fallback would succeed, but errors are not allowed to trigger it.
        assert!(with_fallbacks.invoke(5, None).is_err());
    }

    #[test]
    fn test_with_fallbacks_ext() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = primary.with_fallbacks(vec![Arc::new(fallback)]);

        let result = with_fallbacks.invoke(5, None).unwrap();
        assert_eq!(result, 10);
    }

    #[tokio::test]
    async fn test_fallback_async() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let result = with_fallbacks.ainvoke(5, None).await.unwrap();
        assert_eq!(result, 10);
    }

    #[test]
    fn test_batch_fallback() {
        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks =
            RunnableWithFallbacks::new(fails_above_five(), vec![Arc::new(fallback)]);

        let results = with_fallbacks.batch(vec![3, 10, 5], None, false);

        // 3 -> primary succeeds -> 4
        // 10 -> primary fails, fallback succeeds -> 20
        // 5 -> primary succeeds -> 6
        assert_eq!(results[0].as_ref().unwrap(), &4);
        assert_eq!(results[1].as_ref().unwrap(), &20);
        assert_eq!(results[2].as_ref().unwrap(), &6);
    }

    #[test]
    fn test_batch_empty_input() {
        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });
        let with_fallbacks =
            RunnableWithFallbacks::new(fails_above_five(), vec![Arc::new(fallback)]);

        assert!(with_fallbacks.batch(Vec::new(), None, false).is_empty());
    }

    #[test]
    fn test_batch_all_fail_keeps_errors_in_place() {
        let fallback =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("fallback failed")) });
        let with_fallbacks =
            RunnableWithFallbacks::new(fails_above_five(), vec![Arc::new(fallback)]);

        let results = with_fallbacks.batch(vec![3, 10], None, true);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].as_ref().unwrap(), &4);
        assert!(results[1].is_err());
    }

    #[test]
    fn test_batch_without_error_handling_reports_error_at_its_index() {
        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });
        let with_fallbacks =
            RunnableWithFallbacks::new(fails_above_five(), vec![Arc::new(fallback)])
                .with_handle_all_errors(false);

        let results = with_fallbacks.batch(vec![3, 10, 5], None, false);

        // The failing item is not retried on the fallback; the others succeed.
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].as_ref().unwrap(), &4);
        assert!(results[1].is_err());
        assert_eq!(results[2].as_ref().unwrap(), &6);
    }

    #[tokio::test]
    async fn test_abatch_fallback() {
        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });
        let with_fallbacks =
            RunnableWithFallbacks::new(fails_above_five(), vec![Arc::new(fallback)]);

        let results = with_fallbacks.abatch(vec![3, 10, 5], None, false).await;

        assert_eq!(results[0].as_ref().unwrap(), &4);
        assert_eq!(results[1].as_ref().unwrap(), &20);
        assert_eq!(results[2].as_ref().unwrap(), &6);
    }

    #[tokio::test]
    async fn test_stream_fallback() {
        use futures::StreamExt;

        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let mut stream = with_fallbacks.stream(5, None);
        let result = stream.next().await.unwrap().unwrap();
        assert_eq!(result, 10);
    }

    #[test]
    fn test_runnables_iterator() {
        let primary = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x) });
        let fallback1 = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x) });
        let fallback2 = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x) });

        let with_fallbacks =
            RunnableWithFallbacks::new(primary, vec![Arc::new(fallback1), Arc::new(fallback2)]);

        let count = with_fallbacks.runnables().count();
        assert_eq!(count, 3); // primary + 2 fallbacks
    }
}