1use std::fmt::Debug;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use futures::StreamExt;
11use futures::stream::BoxStream;
12use serde::Serialize;
13
14use crate::error::{Error, Result};
15use crate::load::{Serializable, Serialized, SerializedConstructorData};
16
17use super::base::{DynRunnable, Runnable, RunnableLambda, RunnableSerializable};
18use super::config::{RunnableConfig, ensure_config, get_callback_manager_for_config, patch_config};
19use super::utils::{ConfigurableFieldSpec, get_unique_config_specs};
20
21pub struct RunnableBranch<I, O>
55where
56 I: Send + Sync + Clone + Debug + 'static,
57 O: Send + Sync + Clone + Debug + 'static,
58{
59 branches: Vec<(DynRunnable<I, bool>, DynRunnable<I, O>)>,
61 default: DynRunnable<I, O>,
63 name: Option<String>,
65}
66
67impl<I, O> Debug for RunnableBranch<I, O>
68where
69 I: Send + Sync + Clone + Debug + 'static,
70 O: Send + Sync + Clone + Debug + 'static,
71{
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct("RunnableBranch")
74 .field("branches_count", &self.branches.len())
75 .field("name", &self.name)
76 .finish()
77 }
78}
79
80impl<I, O> RunnableBranch<I, O>
81where
82 I: Send + Sync + Clone + Debug + 'static,
83 O: Send + Sync + Clone + Debug + 'static,
84{
85 pub fn new(
104 branches: Vec<(DynRunnable<I, bool>, DynRunnable<I, O>)>,
105 default: DynRunnable<I, O>,
106 ) -> Result<Self> {
107 if branches.is_empty() {
108 return Err(Error::Other(
109 "RunnableBranch requires at least one condition branch".to_string(),
110 ));
111 }
112
113 Ok(Self {
114 branches,
115 default,
116 name: None,
117 })
118 }
119
120 pub fn with_name(mut self, name: impl Into<String>) -> Self {
122 self.name = Some(name.into());
123 self
124 }
125
126 pub fn config_specs(&self) -> std::result::Result<Vec<ConfigurableFieldSpec>, String> {
128 let specs = self
129 .branches
130 .iter()
131 .flat_map(|(_condition, _runnable)| Vec::<ConfigurableFieldSpec>::new())
132 .collect::<Vec<_>>();
133
134 get_unique_config_specs(specs)
135 }
136}
137
138pub struct RunnableBranchBuilder<I, O>
140where
141 I: Send + Sync + Clone + Debug + 'static,
142 O: Send + Sync + Clone + Debug + 'static,
143{
144 branches: Vec<(DynRunnable<I, bool>, DynRunnable<I, O>)>,
145 _phantom: std::marker::PhantomData<(I, O)>,
146}
147
148impl<I, O> RunnableBranchBuilder<I, O>
149where
150 I: Send + Sync + Clone + Debug + 'static,
151 O: Send + Sync + Clone + Debug + 'static,
152{
153 pub fn new() -> Self {
155 Self {
156 branches: Vec::new(),
157 _phantom: std::marker::PhantomData,
158 }
159 }
160
161 pub fn branch<CF, RF>(mut self, condition: CF, runnable: RF) -> Self
163 where
164 CF: Fn(I) -> Result<bool> + Send + Sync + 'static,
165 RF: Fn(I) -> Result<O> + Send + Sync + 'static,
166 {
167 let condition_runnable: DynRunnable<I, bool> = Arc::new(RunnableLambda::new(condition));
168 let branch_runnable: DynRunnable<I, O> = Arc::new(RunnableLambda::new(runnable));
169 self.branches.push((condition_runnable, branch_runnable));
170 self
171 }
172
173 pub fn branch_arc(
175 mut self,
176 condition: DynRunnable<I, bool>,
177 runnable: DynRunnable<I, O>,
178 ) -> Self {
179 self.branches.push((condition, runnable));
180 self
181 }
182
183 pub fn default<DF>(self, default_fn: DF) -> Result<RunnableBranch<I, O>>
185 where
186 DF: Fn(I) -> Result<O> + Send + Sync + 'static,
187 {
188 let default_runnable: DynRunnable<I, O> = Arc::new(RunnableLambda::new(default_fn));
189 RunnableBranch::new(self.branches, default_runnable)
190 }
191
192 pub fn default_arc(self, default: DynRunnable<I, O>) -> Result<RunnableBranch<I, O>> {
194 RunnableBranch::new(self.branches, default)
195 }
196}
197
198impl<I, O> Default for RunnableBranchBuilder<I, O>
199where
200 I: Send + Sync + Clone + Debug + 'static,
201 O: Send + Sync + Clone + Debug + 'static,
202{
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208#[async_trait]
209impl<I, O> Runnable for RunnableBranch<I, O>
210where
211 I: Send + Sync + Clone + Debug + 'static,
212 O: Send + Sync + Clone + Debug + 'static,
213{
214 type Input = I;
215 type Output = O;
216
217 fn name(&self) -> Option<String> {
218 self.name
219 .clone()
220 .or_else(|| Some("RunnableBranch".to_string()))
221 }
222
223 fn invoke(&self, input: Self::Input, config: Option<RunnableConfig>) -> Result<Self::Output> {
224 let config = ensure_config(config);
225 let callback_manager = get_callback_manager_for_config(&config);
226 let run_manager = callback_manager.on_chain_start(
227 &std::collections::HashMap::new(),
228 &std::collections::HashMap::new(),
229 config.run_id,
230 );
231
232 let result = (|| {
233 for (idx, (condition, runnable)) in self.branches.iter().enumerate() {
234 let condition_config = patch_config(
235 Some(config.clone()),
236 Some(run_manager.get_child(Some(&format!("condition:{}", idx + 1)))),
237 None,
238 None,
239 None,
240 None,
241 );
242
243 let expression_value = condition.invoke(input.clone(), Some(condition_config))?;
244
245 if expression_value {
246 let branch_config = patch_config(
247 Some(config.clone()),
248 Some(run_manager.get_child(Some(&format!("branch:{}", idx + 1)))),
249 None,
250 None,
251 None,
252 None,
253 );
254
255 return runnable.invoke(input.clone(), Some(branch_config));
256 }
257 }
258
259 let default_config = patch_config(
260 Some(config.clone()),
261 Some(run_manager.get_child(Some("branch:default"))),
262 None,
263 None,
264 None,
265 None,
266 );
267
268 self.default.invoke(input, Some(default_config))
269 })();
270
271 match &result {
272 Ok(_) => {
273 run_manager.on_chain_end(&std::collections::HashMap::new());
274 }
275 Err(e) => {
276 run_manager.on_chain_error(e);
277 }
278 }
279
280 result
281 }
282
283 async fn ainvoke(
284 &self,
285 input: Self::Input,
286 config: Option<RunnableConfig>,
287 ) -> Result<Self::Output>
288 where
289 Self: 'static,
290 {
291 let config = ensure_config(config);
292
293 for (condition, runnable) in self.branches.iter() {
294 let expression_value = condition
295 .ainvoke(input.clone(), Some(config.clone()))
296 .await?;
297
298 if expression_value {
299 return runnable.ainvoke(input.clone(), Some(config.clone())).await;
300 }
301 }
302
303 self.default.ainvoke(input, Some(config)).await
304 }
305
306 fn stream(
307 &self,
308 input: Self::Input,
309 config: Option<RunnableConfig>,
310 ) -> BoxStream<'_, Result<Self::Output>> {
311 let config = ensure_config(config);
312
313 Box::pin(async_stream::stream! {
314 'outer: {
315 for (condition, runnable) in self.branches.iter() {
316 let expression_value = match condition.invoke(input.clone(), Some(config.clone())) {
317 Ok(v) => v,
318 Err(e) => {
319 yield Err(e);
320 break 'outer;
321 }
322 };
323
324 if expression_value {
325 let mut stream = runnable.stream(input.clone(), Some(config.clone()));
326 while let Some(chunk_result) = stream.next().await {
327 yield chunk_result;
328 }
329 break 'outer;
330 }
331 }
332
333 let mut stream = self.default.stream(input, Some(config.clone()));
334 while let Some(chunk_result) = stream.next().await {
335 yield chunk_result;
336 }
337 }
338 })
339 }
340
341 fn astream(
342 &self,
343 input: Self::Input,
344 config: Option<RunnableConfig>,
345 ) -> BoxStream<'_, Result<Self::Output>>
346 where
347 Self: 'static,
348 {
349 let config = ensure_config(config);
350
351 Box::pin(async_stream::stream! {
352 'outer: {
353 for (condition, runnable) in self.branches.iter() {
354 let expression_value = match condition.ainvoke(input.clone(), Some(config.clone())).await {
355 Ok(v) => v,
356 Err(e) => {
357 yield Err(e);
358 break 'outer;
359 }
360 };
361
362 if expression_value {
363 let mut stream = runnable.astream(input.clone(), Some(config.clone()));
364 while let Some(chunk_result) = stream.next().await {
365 yield chunk_result;
366 }
367 break 'outer;
368 }
369 }
370
371 let mut stream = self.default.astream(input, Some(config.clone()));
372 while let Some(chunk_result) = stream.next().await {
373 yield chunk_result;
374 }
375 }
376 })
377 }
378}
379
380impl<I, O> Serializable for RunnableBranch<I, O>
381where
382 I: Send + Sync + Clone + Debug + Serialize + 'static,
383 O: Send + Sync + Clone + Debug + 'static,
384{
385 fn is_lc_serializable() -> bool {
386 true
387 }
388
389 fn get_lc_namespace() -> Vec<String> {
390 vec![
391 "langchain".to_string(),
392 "schema".to_string(),
393 "runnable".to_string(),
394 ]
395 }
396
397 fn to_json(&self) -> Serialized {
398 let kwargs = std::collections::HashMap::new();
399
400 Serialized::Constructor(SerializedConstructorData {
401 lc: 1,
402 id: Self::get_lc_namespace(),
403 kwargs,
404 name: None,
405 graph: None,
406 })
407 }
408}
409
410impl<I, O> RunnableSerializable for RunnableBranch<I, O>
411where
412 I: Send + Sync + Clone + Debug + Serialize + 'static,
413 O: Send + Sync + Clone + Debug + Serialize + 'static,
414{
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// Branch that labels positive and negative integers, defaulting to "zero".
    fn sign_branch() -> RunnableBranch<i32, String> {
        RunnableBranchBuilder::new()
            .branch(|x: i32| Ok(x > 0), |x: i32| Ok(format!("positive: {}", x)))
            .branch(|x: i32| Ok(x < 0), |x: i32| Ok(format!("negative: {}", x)))
            .default(|_: i32| Ok("zero".to_string()))
            .unwrap()
    }

    /// Branch with a single "positive" condition and a "non-positive" default.
    fn positive_branch() -> RunnableBranch<i32, String> {
        RunnableBranchBuilder::new()
            .branch(|x: i32| Ok(x > 0), |x: i32| Ok(format!("positive: {}", x)))
            .default(|_: i32| Ok("non-positive".to_string()))
            .unwrap()
    }

    /// Branch with a single condition whose outputs are plain stringified numbers.
    fn stringify_branch() -> RunnableBranch<i32, String> {
        RunnableBranchBuilder::new()
            .branch(|x: i32| Ok(x > 0), |x: i32| Ok(x.to_string()))
            .default(|_: i32| Ok("default".to_string()))
            .unwrap()
    }

    #[test]
    fn test_runnable_branch_invoke_first_condition() {
        let output = sign_branch().invoke(5, None).unwrap();
        assert_eq!(output, "positive: 5");
    }

    #[test]
    fn test_runnable_branch_invoke_second_condition() {
        let output = sign_branch().invoke(-3, None).unwrap();
        assert_eq!(output, "negative: -3");
    }

    #[test]
    fn test_runnable_branch_invoke_default() {
        let output = sign_branch().invoke(0, None).unwrap();
        assert_eq!(output, "zero");
    }

    #[test]
    fn test_runnable_branch_requires_at_least_one_branch() {
        let built: Result<RunnableBranch<i32, String>> =
            RunnableBranchBuilder::new().default(|_: i32| Ok("default".to_string()));

        assert!(built.is_err());

        let message = built.unwrap_err().to_string();
        assert!(message.contains("at least one condition branch"));
    }

    #[test]
    fn test_runnable_branch_name() {
        let named = stringify_branch().with_name("my_branch");

        assert_eq!(named.name(), Some("my_branch".to_string()));
    }

    #[test]
    fn test_runnable_branch_default_name() {
        let unnamed = stringify_branch();

        assert_eq!(unnamed.name(), Some("RunnableBranch".to_string()));
    }

    #[test]
    fn test_runnable_branch_with_arc_runnables() {
        let is_big: DynRunnable<i32, bool> = Arc::new(RunnableLambda::new(|x: i32| Ok(x > 10)));
        let on_big: DynRunnable<i32, String> =
            Arc::new(RunnableLambda::new(|x: i32| Ok(format!("big: {}", x))));
        let on_small: DynRunnable<i32, String> =
            Arc::new(RunnableLambda::new(|x: i32| Ok(format!("small: {}", x))));

        let pairs = vec![(is_big, on_big)];
        let branch = RunnableBranch::new(pairs, on_small).unwrap();

        assert_eq!(branch.invoke(15, None).unwrap(), "big: 15");
        assert_eq!(branch.invoke(5, None).unwrap(), "small: 5");
    }

    #[tokio::test]
    async fn test_runnable_branch_ainvoke() {
        let branch = sign_branch();

        // One case per route: first branch, second branch, default.
        for (input, expected) in [(5, "positive: 5"), (-3, "negative: -3"), (0, "zero")] {
            let output = branch.ainvoke(input, None).await.unwrap();
            assert_eq!(output, expected);
        }
    }

    #[tokio::test]
    async fn test_runnable_branch_stream() {
        let branch = positive_branch();

        let mut chunks = branch.stream(5, None);
        let first = chunks.next().await.unwrap().unwrap();
        assert_eq!(first, "positive: 5");
    }

    #[tokio::test]
    async fn test_runnable_branch_astream() {
        let branch = positive_branch();

        let mut chunks = branch.astream(5, None);
        let first = chunks.next().await.unwrap().unwrap();
        assert_eq!(first, "positive: 5");
    }
}