agent_chain_core/runnables/
fallbacks.rs1use std::fmt::Debug;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::StreamExt;
12use futures::stream::BoxStream;
13
14use crate::error::{Error, Result};
15
16use super::base::{DynRunnable, Runnable};
17use super::config::{
18 ConfigOrList, RunnableConfig, ensure_config, get_callback_manager_for_config, get_config_list,
19 patch_config,
20};
21use super::utils::{ConfigurableFieldSpec, get_unique_config_specs};
22
23pub struct RunnableWithFallbacks<I, O>
60where
61 I: Send + Sync + Clone + Debug + 'static,
62 O: Send + Sync + Clone + Debug + 'static,
63{
64 pub runnable: DynRunnable<I, O>,
66 pub fallbacks: Vec<DynRunnable<I, O>>,
68 pub handle_all_errors: bool,
72 name: Option<String>,
74}
75
76impl<I, O> Debug for RunnableWithFallbacks<I, O>
77where
78 I: Send + Sync + Clone + Debug + 'static,
79 O: Send + Sync + Clone + Debug + 'static,
80{
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 f.debug_struct("RunnableWithFallbacks")
83 .field("runnable", &"<runnable>")
84 .field("fallbacks_count", &self.fallbacks.len())
85 .field("handle_all_errors", &self.handle_all_errors)
86 .field("name", &self.name)
87 .finish()
88 }
89}
90
91impl<I, O> RunnableWithFallbacks<I, O>
92where
93 I: Send + Sync + Clone + Debug + 'static,
94 O: Send + Sync + Clone + Debug + 'static,
95{
96 pub fn new<R>(runnable: R, fallbacks: Vec<DynRunnable<I, O>>) -> Self
102 where
103 R: Runnable<Input = I, Output = O> + Send + Sync + 'static,
104 {
105 Self {
106 runnable: Arc::new(runnable),
107 fallbacks,
108 handle_all_errors: true,
109 name: None,
110 }
111 }
112
113 pub fn from_dyn(runnable: DynRunnable<I, O>, fallbacks: Vec<DynRunnable<I, O>>) -> Self {
115 Self {
116 runnable,
117 fallbacks,
118 handle_all_errors: true,
119 name: None,
120 }
121 }
122
123 pub fn with_handle_all_errors(mut self, handle_all_errors: bool) -> Self {
128 self.handle_all_errors = handle_all_errors;
129 self
130 }
131
132 pub fn with_name(mut self, name: impl Into<String>) -> Self {
134 self.name = Some(name.into());
135 self
136 }
137
138 pub fn runnables(&self) -> impl Iterator<Item = &DynRunnable<I, O>> {
140 std::iter::once(&self.runnable).chain(self.fallbacks.iter())
141 }
142
143 pub fn config_specs(&self) -> Result<Vec<ConfigurableFieldSpec>> {
145 let specs: Vec<ConfigurableFieldSpec> = self
146 .runnables()
147 .flat_map(|_r| {
148 Vec::<ConfigurableFieldSpec>::new()
151 })
152 .collect();
153
154 get_unique_config_specs(specs).map_err(Error::other)
155 }
156
157 fn should_fallback(&self, _error: &Error) -> bool {
159 self.handle_all_errors
162 }
163}
164
165#[async_trait]
166impl<I, O> Runnable for RunnableWithFallbacks<I, O>
167where
168 I: Send + Sync + Clone + Debug + 'static,
169 O: Send + Sync + Clone + Debug + 'static,
170{
171 type Input = I;
172 type Output = O;
173
174 fn name(&self) -> Option<String> {
175 self.name.clone()
176 }
177
178 fn invoke(&self, input: Self::Input, config: Option<RunnableConfig>) -> Result<Self::Output> {
179 let config = ensure_config(config);
180 let callback_manager = get_callback_manager_for_config(&config);
181
182 let run_manager = callback_manager.on_chain_start(
184 &std::collections::HashMap::new(),
185 &std::collections::HashMap::new(),
186 config.run_id,
187 );
188
189 let mut first_error: Option<Error> = None;
190
191 for runnable in self.runnables() {
192 let child_config = patch_config(
193 Some(config.clone()),
194 Some(run_manager.get_child(None)),
195 None,
196 None,
197 None,
198 None,
199 );
200
201 match runnable.invoke(input.clone(), Some(child_config)) {
202 Ok(output) => {
203 run_manager.on_chain_end(&std::collections::HashMap::new());
204 return Ok(output);
205 }
206 Err(e) => {
207 if self.should_fallback(&e) {
208 if first_error.is_none() {
209 first_error = Some(e);
210 }
211 } else {
212 run_manager.on_chain_error(&e);
213 return Err(e);
214 }
215 }
216 }
217 }
218
219 let error =
220 first_error.unwrap_or_else(|| Error::other("No error stored at end of fallbacks."));
221 run_manager.on_chain_error(&error);
222 Err(error)
223 }
224
225 async fn ainvoke(
226 &self,
227 input: Self::Input,
228 config: Option<RunnableConfig>,
229 ) -> Result<Self::Output>
230 where
231 Self: 'static,
232 {
233 let config = ensure_config(config);
234
235 let mut first_error: Option<Error> = None;
236
237 for runnable in self.runnables() {
238 match runnable.ainvoke(input.clone(), Some(config.clone())).await {
239 Ok(output) => {
240 return Ok(output);
241 }
242 Err(e) => {
243 if self.should_fallback(&e) {
244 if first_error.is_none() {
245 first_error = Some(e);
246 }
247 } else {
248 return Err(e);
249 }
250 }
251 }
252 }
253
254 Err(first_error.unwrap_or_else(|| Error::other("No error stored at end of fallbacks.")))
255 }
256
257 fn batch(
258 &self,
259 inputs: Vec<Self::Input>,
260 config: Option<ConfigOrList>,
261 return_exceptions: bool,
262 ) -> Vec<Result<Self::Output>>
263 where
264 Self: 'static,
265 {
266 if inputs.is_empty() {
267 return Vec::new();
268 }
269
270 let configs = get_config_list(config, inputs.len());
271 let n = inputs.len();
272
273 let mut to_return: Vec<Option<Result<Self::Output>>> = (0..n).map(|_| None).collect();
275 let mut run_again: Vec<(usize, Self::Input)> = inputs.into_iter().enumerate().collect();
276 let mut handled_exception_indices: Vec<usize> = Vec::new();
277 let mut first_to_raise: Option<Error> = None;
278
279 for runnable in self.runnables() {
280 if run_again.is_empty() {
281 break;
282 }
283
284 let batch_inputs: Vec<Self::Input> =
286 run_again.iter().map(|(_, inp)| inp.clone()).collect();
287 let batch_configs: Vec<RunnableConfig> =
288 run_again.iter().map(|(i, _)| configs[*i].clone()).collect();
289
290 let outputs = runnable.batch(
291 batch_inputs,
292 Some(ConfigOrList::List(batch_configs)),
293 true, );
295
296 let mut next_run_again = Vec::new();
297
298 for ((i, input), output) in run_again.iter().zip(outputs) {
299 match output {
300 Ok(out) => {
301 to_return[*i] = Some(Ok(out));
302 handled_exception_indices.retain(|&idx| idx != *i);
303 }
304 Err(e) => {
305 if self.should_fallback(&e) {
306 if !handled_exception_indices.contains(i) {
307 handled_exception_indices.push(*i);
308 }
309 to_return[*i] = Some(Err(e));
311 next_run_again.push((*i, input.clone()));
312 } else if return_exceptions {
313 to_return[*i] = Some(Err(e));
314 } else if first_to_raise.is_none() {
315 first_to_raise = Some(e);
316 }
317 }
318 }
319 }
320
321 if first_to_raise.is_some() {
322 let mut results = Vec::with_capacity(to_return.len());
324 let mut error_consumed = false;
325 for opt in to_return {
326 match opt {
327 Some(result) => results.push(result),
328 None => {
329 if !error_consumed {
330 results.push(Err(first_to_raise.take().unwrap()));
331 error_consumed = true;
332 } else {
333 results.push(Err(Error::other("Batch aborted due to error")));
334 }
335 }
336 }
337 }
338 return results;
339 }
340
341 run_again = next_run_again;
342 }
343
344 if !return_exceptions && !handled_exception_indices.is_empty() {
346 }
348
349 to_return
351 .into_iter()
352 .map(|opt| opt.unwrap_or_else(|| Err(Error::other("No result for index"))))
353 .collect()
354 }
355
356 async fn abatch(
357 &self,
358 inputs: Vec<Self::Input>,
359 config: Option<ConfigOrList>,
360 return_exceptions: bool,
361 ) -> Vec<Result<Self::Output>>
362 where
363 Self: 'static,
364 {
365 if inputs.is_empty() {
366 return Vec::new();
367 }
368
369 let configs = get_config_list(config, inputs.len());
370 let n = inputs.len();
371
372 let mut to_return: Vec<Option<Result<Self::Output>>> = (0..n).map(|_| None).collect();
374 let mut run_again: Vec<(usize, Self::Input)> = inputs.into_iter().enumerate().collect();
375 let mut handled_exception_indices: Vec<usize> = Vec::new();
376 let mut first_to_raise: Option<Error> = None;
377
378 for runnable in self.runnables() {
379 if run_again.is_empty() {
380 break;
381 }
382
383 let batch_inputs: Vec<Self::Input> =
385 run_again.iter().map(|(_, inp)| inp.clone()).collect();
386 let batch_configs: Vec<RunnableConfig> =
387 run_again.iter().map(|(i, _)| configs[*i].clone()).collect();
388
389 let outputs = runnable
390 .abatch(
391 batch_inputs,
392 Some(ConfigOrList::List(batch_configs)),
393 true, )
395 .await;
396
397 let mut next_run_again = Vec::new();
398
399 for ((i, input), output) in run_again.iter().zip(outputs) {
400 match output {
401 Ok(out) => {
402 to_return[*i] = Some(Ok(out));
403 handled_exception_indices.retain(|&idx| idx != *i);
404 }
405 Err(e) => {
406 if self.should_fallback(&e) {
407 if !handled_exception_indices.contains(i) {
408 handled_exception_indices.push(*i);
409 }
410 to_return[*i] = Some(Err(e));
412 next_run_again.push((*i, input.clone()));
413 } else if return_exceptions {
414 to_return[*i] = Some(Err(e));
415 } else if first_to_raise.is_none() {
416 first_to_raise = Some(e);
417 }
418 }
419 }
420 }
421
422 if first_to_raise.is_some() {
423 let mut results = Vec::with_capacity(to_return.len());
425 let mut error_consumed = false;
426 for opt in to_return {
427 match opt {
428 Some(result) => results.push(result),
429 None => {
430 if !error_consumed {
431 results.push(Err(first_to_raise.take().unwrap()));
432 error_consumed = true;
433 } else {
434 results.push(Err(Error::other("Batch aborted due to error")));
435 }
436 }
437 }
438 }
439 return results;
440 }
441
442 run_again = next_run_again;
443 }
444
445 if !return_exceptions && !handled_exception_indices.is_empty() {
447 }
449
450 to_return
452 .into_iter()
453 .map(|opt| opt.unwrap_or_else(|| Err(Error::other("No result for index"))))
454 .collect()
455 }
456
457 fn stream(
458 &self,
459 input: Self::Input,
460 config: Option<RunnableConfig>,
461 ) -> BoxStream<'_, Result<Self::Output>> {
462 let config = ensure_config(config);
463
464 Box::pin(async_stream::stream! {
465 let mut first_error: Option<Error> = None;
466
467 for runnable in self.runnables() {
468 let mut stream = runnable.stream(input.clone(), Some(config.clone()));
470
471 match stream.next().await {
472 Some(Ok(chunk)) => {
473 yield Ok(chunk);
475
476 while let Some(result) = stream.next().await {
478 yield result;
479 }
480 return;
481 }
482 Some(Err(e)) => {
483 if self.should_fallback(&e) {
484 if first_error.is_none() {
485 first_error = Some(e);
486 }
487 } else {
489 yield Err(e);
490 return;
491 }
492 }
493 None => {
494 if first_error.is_none() {
496 first_error = Some(Error::other("Empty stream from runnable"));
497 }
498 }
499 }
500 }
501
502 yield Err(first_error.unwrap_or_else(|| Error::other("No error stored at end of fallbacks.")));
504 })
505 }
506
507 fn astream(
508 &self,
509 input: Self::Input,
510 config: Option<RunnableConfig>,
511 ) -> BoxStream<'_, Result<Self::Output>>
512 where
513 Self: 'static,
514 {
515 let config = ensure_config(config);
516
517 Box::pin(async_stream::stream! {
518 let mut first_error: Option<Error> = None;
519
520 for runnable in self.runnables() {
521 let mut stream = runnable.astream(input.clone(), Some(config.clone()));
523
524 match stream.next().await {
525 Some(Ok(chunk)) => {
526 yield Ok(chunk);
528
529 while let Some(result) = stream.next().await {
531 yield result;
532 }
533 return;
534 }
535 Some(Err(e)) => {
536 if self.should_fallback(&e) {
537 if first_error.is_none() {
538 first_error = Some(e);
539 }
540 } else {
542 yield Err(e);
543 return;
544 }
545 }
546 None => {
547 if first_error.is_none() {
549 first_error = Some(Error::other("Empty stream from runnable"));
550 }
551 }
552 }
553 }
554
555 yield Err(first_error.unwrap_or_else(|| Error::other("No error stored at end of fallbacks.")));
557 })
558 }
559}
560
561pub trait RunnableWithFallbacksExt: Runnable {
563 fn with_fallbacks(
571 self,
572 fallbacks: Vec<DynRunnable<Self::Input, Self::Output>>,
573 ) -> RunnableWithFallbacks<Self::Input, Self::Output>
574 where
575 Self: Sized + Send + Sync + 'static,
576 {
577 RunnableWithFallbacks::new(self, fallbacks)
578 }
579}
580
581impl<R: Runnable> RunnableWithFallbacksExt for R {}
583
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runnables::base::RunnableLambda;

    #[test]
    fn test_fallback_on_error() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let result = with_fallbacks.invoke(5, None).unwrap();
        assert_eq!(result, 10);
    }

    #[test]
    fn test_primary_succeeds() {
        let primary = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x + 1) });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let result = with_fallbacks.invoke(5, None).unwrap();
        // The primary's output wins; the fallback is never consulted.
        assert_eq!(result, 6);
    }

    #[test]
    fn test_all_fail() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("fallback failed")) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let result = with_fallbacks.invoke(5, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_multiple_fallbacks() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback1 =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("fallback1 failed")) });

        let fallback2 = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 3) });

        let with_fallbacks =
            RunnableWithFallbacks::new(primary, vec![Arc::new(fallback1), Arc::new(fallback2)]);

        let result = with_fallbacks.invoke(5, None).unwrap();
        // Falls through both failing runnables to the third.
        assert_eq!(result, 15);
    }

    #[test]
    fn test_handle_all_errors_false_skips_fallbacks() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)])
            .with_handle_all_errors(false);

        assert!(with_fallbacks.invoke(5, None).is_err());
    }

    #[test]
    fn test_with_fallbacks_ext() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = primary.with_fallbacks(vec![Arc::new(fallback)]);

        let result = with_fallbacks.invoke(5, None).unwrap();
        assert_eq!(result, 10);
    }

    #[tokio::test]
    async fn test_fallback_async() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let result = with_fallbacks.ainvoke(5, None).await.unwrap();
        assert_eq!(result, 10);
    }

    #[tokio::test]
    async fn test_ainvoke_all_fail() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("fallback failed")) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        assert!(with_fallbacks.ainvoke(5, None).await.is_err());
    }

    #[test]
    fn test_batch_fallback() {
        let primary = RunnableLambda::new(|x: i32| -> Result<i32> {
            if x > 5 {
                Err(Error::other("too large"))
            } else {
                Ok(x + 1)
            }
        });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let results = with_fallbacks.batch(vec![3, 10, 5], None, false);

        assert_eq!(results[0].as_ref().unwrap(), &4);
        assert_eq!(results[1].as_ref().unwrap(), &20);
        assert_eq!(results[2].as_ref().unwrap(), &6);
    }

    #[test]
    fn test_batch_empty_input() {
        let primary = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x) });
        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        assert!(with_fallbacks.batch(Vec::new(), None, false).is_empty());
    }

    #[test]
    fn test_batch_all_fail_returns_errors() {
        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("fallback failed")) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let results = with_fallbacks.batch(vec![1, 2], None, true);
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.is_err()));
    }

    #[test]
    fn test_batch_non_fallback_error_stays_at_its_index() {
        let primary = RunnableLambda::new(|x: i32| -> Result<i32> {
            if x > 5 {
                Err(Error::other("too large"))
            } else {
                Ok(x + 1)
            }
        });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)])
            .with_handle_all_errors(false);

        let results = with_fallbacks.batch(vec![3, 10, 5], None, false);

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].as_ref().unwrap(), &4);
        assert!(results[1].is_err());
        assert_eq!(results[2].as_ref().unwrap(), &6);
    }

    #[tokio::test]
    async fn test_abatch_fallback() {
        let primary = RunnableLambda::new(|x: i32| -> Result<i32> {
            if x > 5 {
                Err(Error::other("too large"))
            } else {
                Ok(x + 1)
            }
        });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let results = with_fallbacks.abatch(vec![3, 10, 5], None, false).await;

        assert_eq!(results[0].as_ref().unwrap(), &4);
        assert_eq!(results[1].as_ref().unwrap(), &20);
        assert_eq!(results[2].as_ref().unwrap(), &6);
    }

    #[tokio::test]
    async fn test_stream_fallback() {
        use futures::StreamExt;

        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let mut stream = with_fallbacks.stream(5, None);
        let result = stream.next().await.unwrap().unwrap();
        assert_eq!(result, 10);
    }

    #[tokio::test]
    async fn test_stream_all_fail_yields_single_error() {
        use futures::StreamExt;

        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("fallback failed")) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let mut stream = with_fallbacks.stream(5, None);
        assert!(stream.next().await.unwrap().is_err());
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_astream_fallback() {
        use futures::StreamExt;

        let primary =
            RunnableLambda::new(|_x: i32| -> Result<i32> { Err(Error::other("primary failed")) });

        let fallback = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x * 2) });

        let with_fallbacks = RunnableWithFallbacks::new(primary, vec![Arc::new(fallback)]);

        let mut stream = with_fallbacks.astream(5, None);
        let result = stream.next().await.unwrap().unwrap();
        assert_eq!(result, 10);
    }

    #[test]
    fn test_runnables_iterator() {
        let primary = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x) });
        let fallback1 = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x) });
        let fallback2 = RunnableLambda::new(|x: i32| -> Result<i32> { Ok(x) });

        let with_fallbacks =
            RunnableWithFallbacks::new(primary, vec![Arc::new(fallback1), Arc::new(fallback2)]);

        // Primary plus two fallbacks.
        let count = with_fallbacks.runnables().count();
        assert_eq!(count, 3);
    }
}