1use std::collections::HashMap;
7use std::fmt::Debug;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::StreamExt;
12use futures::stream::BoxStream;
13use serde::Serialize;
14use serde_json::Value;
15
16use crate::error::Result;
17use crate::load::{Serializable, Serialized};
18
19use super::config::{
20 ConfigOrList, RunnableConfig, ensure_config, get_callback_manager_for_config, get_config_list,
21 merge_configs, patch_config,
22};
23
/// Number of generic type parameters a runnable is described by — presumably the
/// two associated types `Input` and `Output` of [`Runnable`].
///
/// `dead_code` is allowed because the constant is not read anywhere in this
/// module; it appears to be kept for parity with the upstream design (confirm
/// before removing).
#[allow(dead_code)]
const RUNNABLE_GENERIC_NUM_ARGS: usize = 2;

/// Minimum number of steps in a sequence of runnables.
///
/// `RunnableSequence` here is always exactly `first` + `last`, so the value is
/// informational only. `dead_code` is allowed because it is not read anywhere in
/// this module.
#[allow(dead_code)]
const RUNNABLE_SEQUENCE_MIN_STEPS: usize = 2;
31
32#[async_trait]
48pub trait Runnable: Send + Sync + Debug {
49 type Input: Send + Sync + Clone + Debug + 'static;
51 type Output: Send + Sync + Clone + Debug + 'static;
53
54 fn get_name(&self, suffix: Option<&str>, name: Option<&str>) -> String {
56 let name_ = name
57 .map(|s| s.to_string())
58 .or_else(|| self.name())
59 .unwrap_or_else(|| self.type_name().to_string());
60
61 match suffix {
62 Some(s) if !name_.is_empty() && name_.chars().next().unwrap().is_uppercase() => {
63 format!("{}{}", name_, to_title_case(s))
64 }
65 Some(s) => format!("{}_{}", name_, s.to_lowercase()),
66 None => name_,
67 }
68 }
69
70 fn name(&self) -> Option<String> {
72 None
73 }
74
75 fn type_name(&self) -> &'static str {
77 std::any::type_name::<Self>()
78 }
79
80 fn invoke(&self, input: Self::Input, config: Option<RunnableConfig>) -> Result<Self::Output>;
91
92 async fn ainvoke(
96 &self,
97 input: Self::Input,
98 config: Option<RunnableConfig>,
99 ) -> Result<Self::Output>
100 where
101 Self: 'static,
102 {
103 self.invoke(input, config)
105 }
106
107 fn batch(
111 &self,
112 inputs: Vec<Self::Input>,
113 config: Option<ConfigOrList>,
114 _return_exceptions: bool,
115 ) -> Vec<Result<Self::Output>>
116 where
117 Self: 'static,
118 {
119 if inputs.is_empty() {
120 return Vec::new();
121 }
122
123 let configs = get_config_list(config, inputs.len());
124
125 if inputs.len() == 1 {
127 return vec![self.invoke(
128 inputs.into_iter().next().unwrap(),
129 Some(configs.into_iter().next().unwrap()),
130 )];
131 }
132
133 inputs
135 .into_iter()
136 .zip(configs)
137 .map(|(input, config)| self.invoke(input, Some(config)))
138 .collect()
139 }
140
141 async fn abatch(
145 &self,
146 inputs: Vec<Self::Input>,
147 config: Option<ConfigOrList>,
148 _return_exceptions: bool,
149 ) -> Vec<Result<Self::Output>>
150 where
151 Self: 'static,
152 {
153 if inputs.is_empty() {
154 return Vec::new();
155 }
156
157 let configs = get_config_list(config, inputs.len());
158 let mut results = Vec::with_capacity(inputs.len());
159
160 for (input, config) in inputs.into_iter().zip(configs) {
161 let result = self.ainvoke(input, Some(config)).await;
162 results.push(result);
163 }
164
165 results
166 }
167
168 fn stream(
172 &self,
173 input: Self::Input,
174 config: Option<RunnableConfig>,
175 ) -> BoxStream<'_, Result<Self::Output>> {
176 let result = self.invoke(input, config);
177 Box::pin(futures::stream::once(async move { result }))
178 }
179
180 fn astream(
184 &self,
185 input: Self::Input,
186 config: Option<RunnableConfig>,
187 ) -> BoxStream<'_, Result<Self::Output>>
188 where
189 Self: 'static,
190 {
191 Box::pin(futures::stream::once(async move {
192 self.ainvoke(input, config).await
193 }))
194 }
195
196 fn transform<'a>(
200 &'a self,
201 input: BoxStream<'a, Self::Input>,
202 config: Option<RunnableConfig>,
203 ) -> BoxStream<'a, Result<Self::Output>> {
204 Box::pin(async_stream::stream! {
205 let mut final_input: Option<Self::Input> = None;
206 let mut input = input;
207
208 while let Some(ichunk) = input.next().await {
209 if let Some(ref mut current) = final_input {
210 *current = ichunk;
213 } else {
214 final_input = Some(ichunk);
215 }
216 }
217
218 if let Some(input) = final_input {
219 let mut stream = self.stream(input, config);
220 while let Some(output) = stream.next().await {
221 yield output;
222 }
223 }
224 })
225 }
226
227 fn atransform<'a>(
231 &'a self,
232 input: BoxStream<'a, Self::Input>,
233 config: Option<RunnableConfig>,
234 ) -> BoxStream<'a, Result<Self::Output>>
235 where
236 Self: 'static,
237 {
238 Box::pin(async_stream::stream! {
239 let mut final_input: Option<Self::Input> = None;
240 let mut input = input;
241
242 while let Some(ichunk) = input.next().await {
243 if let Some(ref mut current) = final_input {
244 *current = ichunk;
245 } else {
246 final_input = Some(ichunk);
247 }
248 }
249
250 if let Some(input) = final_input {
251 let mut stream = self.astream(input, config);
252 while let Some(output) = stream.next().await {
253 yield output;
254 }
255 }
256 })
257 }
258
259 fn bind(self, kwargs: HashMap<String, Value>) -> RunnableBinding<Self>
261 where
262 Self: Sized,
263 {
264 RunnableBinding::new(self, kwargs, None)
265 }
266
267 fn with_config(self, config: RunnableConfig) -> RunnableBinding<Self>
269 where
270 Self: Sized,
271 {
272 RunnableBinding::new(self, HashMap::new(), Some(config))
273 }
274
275 fn with_retry(
281 self,
282 max_attempts: usize,
283 wait_exponential_jitter: bool,
284 ) -> super::retry::RunnableRetry<Self>
285 where
286 Self: Sized,
287 {
288 super::retry::RunnableRetry::with_simple(self, max_attempts, wait_exponential_jitter)
289 }
290
291 fn map(self) -> RunnableEach<Self>
293 where
294 Self: Sized,
295 {
296 RunnableEach::new(self)
297 }
298}
299
/// Upper-cases the first character of `s` and leaves the rest untouched.
///
/// Multi-character uppercase mappings are honoured (`"ß"` becomes `"SS"`). An
/// empty input yields an empty string.
fn to_title_case(s: &str) -> String {
    let mut titled = String::with_capacity(s.len());
    let mut remaining = s.chars();
    if let Some(head) = remaining.next() {
        titled.extend(head.to_uppercase());
        titled.extend(remaining);
    }
    titled
}
308
309pub trait RunnableSerializable: Runnable + Serializable {
311 fn to_json_runnable(&self) -> Serialized
313 where
314 Self: Sized + Serialize,
315 {
316 <Self as Serializable>::to_json(self)
317 }
318}
319
320pub struct RunnableLambda<F, I, O>
330where
331 F: Fn(I) -> Result<O> + Send + Sync,
332 I: Send + Sync + Clone + Debug + 'static,
333 O: Send + Sync + Clone + Debug + 'static,
334{
335 func: F,
336 name: Option<String>,
337 _phantom: std::marker::PhantomData<(I, O)>,
338}
339
340impl<F, I, O> Debug for RunnableLambda<F, I, O>
341where
342 F: Fn(I) -> Result<O> + Send + Sync,
343 I: Send + Sync + Clone + Debug + 'static,
344 O: Send + Sync + Clone + Debug + 'static,
345{
346 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347 f.debug_struct("RunnableLambda")
348 .field("name", &self.name)
349 .finish()
350 }
351}
352
353impl<F, I, O> RunnableLambda<F, I, O>
354where
355 F: Fn(I) -> Result<O> + Send + Sync,
356 I: Send + Sync + Clone + Debug + 'static,
357 O: Send + Sync + Clone + Debug + 'static,
358{
359 pub fn new(func: F) -> Self {
361 Self {
362 func,
363 name: None,
364 _phantom: std::marker::PhantomData,
365 }
366 }
367
368 pub fn with_name(mut self, name: impl Into<String>) -> Self {
370 self.name = Some(name.into());
371 self
372 }
373}
374
375#[async_trait]
376impl<F, I, O> Runnable for RunnableLambda<F, I, O>
377where
378 F: Fn(I) -> Result<O> + Send + Sync,
379 I: Send + Sync + Clone + Debug + 'static,
380 O: Send + Sync + Clone + Debug + 'static,
381{
382 type Input = I;
383 type Output = O;
384
385 fn name(&self) -> Option<String> {
386 self.name.clone()
387 }
388
389 fn invoke(&self, input: Self::Input, _config: Option<RunnableConfig>) -> Result<Self::Output> {
390 (self.func)(input)
391 }
392}
393
394pub fn runnable_lambda<F, I, O>(func: F) -> RunnableLambda<F, I, O>
396where
397 F: Fn(I) -> Result<O> + Send + Sync,
398 I: Send + Sync + Clone + Debug + 'static,
399 O: Send + Sync + Clone + Debug + 'static,
400{
401 RunnableLambda::new(func)
402}
403
404pub struct RunnableSequence<R1, R2>
413where
414 R1: Runnable,
415 R2: Runnable<Input = R1::Output>,
416{
417 first: R1,
418 last: R2,
419 name: Option<String>,
420}
421
422impl<R1, R2> Debug for RunnableSequence<R1, R2>
423where
424 R1: Runnable,
425 R2: Runnable<Input = R1::Output>,
426{
427 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
428 f.debug_struct("RunnableSequence")
429 .field("first", &self.first)
430 .field("last", &self.last)
431 .field("name", &self.name)
432 .finish()
433 }
434}
435
436impl<R1, R2> RunnableSequence<R1, R2>
437where
438 R1: Runnable,
439 R2: Runnable<Input = R1::Output>,
440{
441 pub fn new(first: R1, last: R2) -> Self {
443 Self {
444 first,
445 last,
446 name: None,
447 }
448 }
449
450 pub fn with_name(mut self, name: impl Into<String>) -> Self {
452 self.name = Some(name.into());
453 self
454 }
455}
456
457#[async_trait]
458impl<R1, R2> Runnable for RunnableSequence<R1, R2>
459where
460 R1: Runnable + 'static,
461 R2: Runnable<Input = R1::Output> + 'static,
462{
463 type Input = R1::Input;
464 type Output = R2::Output;
465
466 fn name(&self) -> Option<String> {
467 self.name
468 .clone()
469 .or_else(|| self.first.name())
470 .or_else(|| self.last.name())
471 }
472
473 fn invoke(&self, input: Self::Input, config: Option<RunnableConfig>) -> Result<Self::Output> {
474 let config = ensure_config(config);
475 let callback_manager = get_callback_manager_for_config(&config);
476
477 let run_manager =
479 callback_manager.on_chain_start(&HashMap::new(), &HashMap::new(), config.run_id);
480
481 let first_config = patch_config(
483 Some(config.clone()),
484 Some(run_manager.get_child(Some("seq:step:1"))),
485 None,
486 None,
487 None,
488 None,
489 );
490 let intermediate = match self.first.invoke(input, Some(first_config)) {
491 Ok(output) => output,
492 Err(e) => {
493 run_manager.on_chain_error(&e);
494 return Err(e);
495 }
496 };
497
498 let last_config = patch_config(
500 Some(config),
501 Some(run_manager.get_child(Some("seq:step:2"))),
502 None,
503 None,
504 None,
505 None,
506 );
507 let result = match self.last.invoke(intermediate, Some(last_config)) {
508 Ok(output) => output,
509 Err(e) => {
510 run_manager.on_chain_error(&e);
511 return Err(e);
512 }
513 };
514
515 run_manager.on_chain_end(&HashMap::new());
516 Ok(result)
517 }
518
519 async fn ainvoke(
520 &self,
521 input: Self::Input,
522 config: Option<RunnableConfig>,
523 ) -> Result<Self::Output>
524 where
525 Self: 'static,
526 {
527 let config = ensure_config(config);
528
529 let intermediate = self.first.ainvoke(input, Some(config.clone())).await?;
531
532 self.last.ainvoke(intermediate, Some(config)).await
534 }
535
536 fn stream(
537 &self,
538 input: Self::Input,
539 config: Option<RunnableConfig>,
540 ) -> BoxStream<'_, Result<Self::Output>> {
541 Box::pin(async_stream::stream! {
542 let config = ensure_config(config);
543
544 let intermediate = match self.first.invoke(input, Some(config.clone())) {
546 Ok(output) => output,
547 Err(e) => {
548 yield Err(e);
549 return;
550 }
551 };
552
553 let mut stream = self.last.stream(intermediate, Some(config));
555 while let Some(output) = stream.next().await {
556 yield output;
557 }
558 })
559 }
560}
561
562pub fn pipe<R1, R2>(first: R1, second: R2) -> RunnableSequence<R1, R2>
564where
565 R1: Runnable,
566 R2: Runnable<Input = R1::Output>,
567{
568 RunnableSequence::new(first, second)
569}
570
571pub struct RunnableParallel<I>
579where
580 I: Send + Sync + Clone + Debug + 'static,
581{
582 steps: HashMap<String, Arc<dyn Runnable<Input = I, Output = Value> + Send + Sync>>,
583 name: Option<String>,
584}
585
586impl<I> Debug for RunnableParallel<I>
587where
588 I: Send + Sync + Clone + Debug + 'static,
589{
590 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
591 f.debug_struct("RunnableParallel")
592 .field("steps", &self.steps.keys().collect::<Vec<_>>())
593 .field("name", &self.name)
594 .finish()
595 }
596}
597
598impl<I> RunnableParallel<I>
599where
600 I: Send + Sync + Clone + Debug + 'static,
601{
602 pub fn new() -> Self {
604 Self {
605 steps: HashMap::new(),
606 name: None,
607 }
608 }
609
610 pub fn add<R>(mut self, key: impl Into<String>, runnable: R) -> Self
612 where
613 R: Runnable<Input = I, Output = Value> + Send + Sync + 'static,
614 {
615 self.steps.insert(key.into(), Arc::new(runnable));
616 self
617 }
618
619 pub fn with_name(mut self, name: impl Into<String>) -> Self {
621 self.name = Some(name.into());
622 self
623 }
624}
625
626impl<I> Default for RunnableParallel<I>
627where
628 I: Send + Sync + Clone + Debug + 'static,
629{
630 fn default() -> Self {
631 Self::new()
632 }
633}
634
635#[async_trait]
636impl<I> Runnable for RunnableParallel<I>
637where
638 I: Send + Sync + Clone + Debug + 'static,
639{
640 type Input = I;
641 type Output = HashMap<String, Value>;
642
643 fn name(&self) -> Option<String> {
644 self.name.clone().or_else(|| {
645 Some(format!(
646 "RunnableParallel<{}>",
647 self.steps.keys().cloned().collect::<Vec<_>>().join(",")
648 ))
649 })
650 }
651
652 fn invoke(&self, input: Self::Input, config: Option<RunnableConfig>) -> Result<Self::Output> {
653 let config = ensure_config(config);
654 let mut results = HashMap::new();
655
656 for (key, step) in &self.steps {
657 let result = step.invoke(input.clone(), Some(config.clone()))?;
658 results.insert(key.clone(), result);
659 }
660
661 Ok(results)
662 }
663
664 async fn ainvoke(
665 &self,
666 input: Self::Input,
667 config: Option<RunnableConfig>,
668 ) -> Result<Self::Output>
669 where
670 Self: 'static,
671 {
672 let config = ensure_config(config);
673 let mut results = HashMap::new();
674
675 for (key, step) in &self.steps {
677 let result = step.ainvoke(input.clone(), Some(config.clone())).await?;
678 results.insert(key.clone(), result);
679 }
680
681 Ok(results)
682 }
683}
684
685pub struct RunnableBinding<R>
691where
692 R: Runnable,
693{
694 bound: R,
695 kwargs: HashMap<String, Value>,
696 config: Option<RunnableConfig>,
697}
698
699impl<R> Debug for RunnableBinding<R>
700where
701 R: Runnable,
702{
703 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
704 f.debug_struct("RunnableBinding")
705 .field("bound", &self.bound)
706 .field("kwargs", &self.kwargs)
707 .field("config", &self.config)
708 .finish()
709 }
710}
711
712impl<R> RunnableBinding<R>
713where
714 R: Runnable,
715{
716 pub fn new(bound: R, kwargs: HashMap<String, Value>, config: Option<RunnableConfig>) -> Self {
718 Self {
719 bound,
720 kwargs,
721 config,
722 }
723 }
724
725 fn merge_configs(&self, config: Option<RunnableConfig>) -> RunnableConfig {
727 merge_configs(vec![self.config.clone(), config])
728 }
729}
730
731#[async_trait]
732impl<R> Runnable for RunnableBinding<R>
733where
734 R: Runnable + 'static,
735{
736 type Input = R::Input;
737 type Output = R::Output;
738
739 fn name(&self) -> Option<String> {
740 self.bound.name()
741 }
742
743 fn invoke(&self, input: Self::Input, config: Option<RunnableConfig>) -> Result<Self::Output> {
744 self.bound.invoke(input, Some(self.merge_configs(config)))
745 }
746
747 async fn ainvoke(
748 &self,
749 input: Self::Input,
750 config: Option<RunnableConfig>,
751 ) -> Result<Self::Output>
752 where
753 Self: 'static,
754 {
755 self.bound
756 .ainvoke(input, Some(self.merge_configs(config)))
757 .await
758 }
759
760 fn stream(
761 &self,
762 input: Self::Input,
763 config: Option<RunnableConfig>,
764 ) -> BoxStream<'_, Result<Self::Output>> {
765 self.bound.stream(input, Some(self.merge_configs(config)))
766 }
767}
768
769pub struct RunnableEach<R>
775where
776 R: Runnable,
777{
778 bound: R,
779}
780
781impl<R> Debug for RunnableEach<R>
782where
783 R: Runnable,
784{
785 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
786 f.debug_struct("RunnableEach")
787 .field("bound", &self.bound)
788 .finish()
789 }
790}
791
792impl<R> RunnableEach<R>
793where
794 R: Runnable,
795{
796 pub fn new(bound: R) -> Self {
798 Self { bound }
799 }
800}
801
802#[async_trait]
803impl<R> Runnable for RunnableEach<R>
804where
805 R: Runnable + 'static,
806{
807 type Input = Vec<R::Input>;
808 type Output = Vec<R::Output>;
809
810 fn name(&self) -> Option<String> {
811 self.bound.name().map(|n| format!("RunnableEach<{}>", n))
812 }
813
814 fn invoke(&self, inputs: Self::Input, config: Option<RunnableConfig>) -> Result<Self::Output> {
815 let config = ensure_config(config);
816 let configs: Vec<_> = inputs.iter().map(|_| config.clone()).collect();
817
818 let results: Vec<Result<R::Output>> = inputs
819 .into_iter()
820 .zip(configs)
821 .map(|(input, config)| self.bound.invoke(input, Some(config)))
822 .collect();
823
824 results.into_iter().collect()
826 }
827
828 async fn ainvoke(
829 &self,
830 inputs: Self::Input,
831 config: Option<RunnableConfig>,
832 ) -> Result<Self::Output>
833 where
834 Self: 'static,
835 {
836 let config = ensure_config(config);
837 let mut results = Vec::with_capacity(inputs.len());
838
839 for input in inputs {
840 results.push(self.bound.ainvoke(input, Some(config.clone())).await?);
841 }
842
843 Ok(results)
844 }
845}
846
847pub type DynRunnable<I, O> = Arc<dyn Runnable<Input = I, Output = O> + Send + Sync>;
853
854pub fn to_dyn<R>(runnable: R) -> DynRunnable<R::Input, R::Output>
856where
857 R: Runnable + Send + Sync + 'static,
858{
859 Arc::new(runnable)
860}
861
862pub fn coerce_to_runnable<F, I, O>(func: F) -> RunnableLambda<F, I, O>
868where
869 F: Fn(I) -> Result<O> + Send + Sync,
870 I: Send + Sync + Clone + Debug + 'static,
871 O: Send + Sync + Clone + Debug + 'static,
872{
873 RunnableLambda::new(func)
874}
875
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runnables::passthrough::RunnablePassthrough;
    use crate::runnables::utils::AddableDict;

    /// Two named `Value`-producing branches used by the `RunnableParallel` tests.
    fn parallel_fixture() -> RunnableParallel<i32> {
        RunnableParallel::new()
            .add("a", RunnableLambda::new(|x: i32| Ok(serde_json::json!(x + 1))))
            .add("b", RunnableLambda::new(|x: i32| Ok(serde_json::json!(x * 10))))
    }

    #[test]
    fn test_runnable_lambda() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x + 1));
        let result = runnable.invoke(1, None).unwrap();
        assert_eq!(result, 2);
    }

    #[test]
    fn test_runnable_lambda_with_name() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x + 1)).with_name("add_one");
        assert_eq!(runnable.name(), Some("add_one".to_string()));
    }

    #[test]
    fn test_get_name_with_suffix() {
        let upper = RunnableLambda::new(|x: i32| Ok(x)).with_name("Adder");
        assert_eq!(upper.get_name(Some("input"), None), "AdderInput");

        let lower = RunnableLambda::new(|x: i32| Ok(x)).with_name("adder");
        assert_eq!(lower.get_name(Some("Output"), None), "adder_output");
        assert_eq!(lower.get_name(None, Some("custom")), "custom");
    }

    #[test]
    fn test_runnable_sequence() {
        let first = RunnableLambda::new(|x: i32| Ok(x + 1));
        let second = RunnableLambda::new(|x: i32| Ok(x * 2));
        let sequence = RunnableSequence::new(first, second);

        let result = sequence.invoke(1, None).unwrap();
        assert_eq!(result, 4);
    }

    #[test]
    fn test_runnable_each() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x * 2));
        let each = RunnableEach::new(runnable);

        let result = each.invoke(vec![1, 2, 3], None).unwrap();
        assert_eq!(result, vec![2, 4, 6]);
    }

    #[test]
    fn test_runnable_parallel() {
        let result = parallel_fixture().invoke(2, None).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result.get("a"), Some(&serde_json::json!(3)));
        assert_eq!(result.get("b"), Some(&serde_json::json!(20)));
    }

    #[test]
    fn test_runnable_binding() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x + 1));
        let config = RunnableConfig::new().with_tags(vec!["test".to_string()]);
        let bound = RunnableBinding::new(runnable, HashMap::new(), Some(config));

        let result = bound.invoke(1, None).unwrap();
        assert_eq!(result, 2);
    }

    #[test]
    fn test_runnable_passthrough() {
        let runnable: RunnablePassthrough<i32> = RunnablePassthrough::new();
        let result = runnable.invoke(42, None).unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_runnable_retry() {
        use crate::runnables::retry::RunnableRetry;

        let runnable = RunnableLambda::new(|x: i32| Ok(x + 1));
        let retry = RunnableRetry::with_simple(runnable, 3, false);

        let result = retry.invoke(1, None).unwrap();
        assert_eq!(result, 2);
    }

    #[test]
    fn test_addable_dict() {
        let mut dict1 = AddableDict::new();
        dict1.insert("a", serde_json::json!(1));

        let mut dict2 = AddableDict::new();
        dict2.insert("b", serde_json::json!(2));

        let combined = dict1 + dict2;
        assert_eq!(combined.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(combined.get("b"), Some(&serde_json::json!(2)));
    }

    #[test]
    fn test_pipe() {
        let first = RunnableLambda::new(|x: i32| Ok(x + 1));
        let second = RunnableLambda::new(|x: i32| Ok(x * 2));
        let sequence = pipe(first, second);

        let result = sequence.invoke(1, None).unwrap();
        assert_eq!(result, 4);
    }

    #[tokio::test]
    async fn test_runnable_lambda_async() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x + 1));
        let result = runnable.ainvoke(1, None).await.unwrap();
        assert_eq!(result, 2);
    }

    #[tokio::test]
    async fn test_runnable_sequence_async() {
        let first = RunnableLambda::new(|x: i32| Ok(x + 1));
        let second = RunnableLambda::new(|x: i32| Ok(x * 2));
        let sequence = RunnableSequence::new(first, second);

        let result = sequence.ainvoke(1, None).await.unwrap();
        assert_eq!(result, 4);
    }

    #[tokio::test]
    async fn test_runnable_each_async() {
        let each = RunnableEach::new(RunnableLambda::new(|x: i32| Ok(x * 2)));
        let result = each.ainvoke(vec![1, 2, 3], None).await.unwrap();
        assert_eq!(result, vec![2, 4, 6]);
    }

    #[tokio::test]
    async fn test_runnable_parallel_async() {
        let result = parallel_fixture().ainvoke(2, None).await.unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result.get("a"), Some(&serde_json::json!(3)));
        assert_eq!(result.get("b"), Some(&serde_json::json!(20)));
    }

    #[tokio::test]
    async fn test_default_stream_yields_single_output() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x + 1));
        let outputs: Vec<_> = runnable.stream(1, None).collect().await;
        assert_eq!(outputs.len(), 1);
        assert_eq!(*outputs[0].as_ref().unwrap(), 2);
    }

    #[tokio::test]
    async fn test_runnable_sequence_stream() {
        let first = RunnableLambda::new(|x: i32| Ok(x + 1));
        let second = RunnableLambda::new(|x: i32| Ok(x * 2));
        let sequence = RunnableSequence::new(first, second);

        let outputs: Vec<_> = sequence.stream(1, None).collect().await;
        assert_eq!(outputs.len(), 1);
        assert_eq!(*outputs[0].as_ref().unwrap(), 4);
    }
}