1use std::collections::HashMap;
7use std::fmt::Debug;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::StreamExt;
12use futures::stream::BoxStream;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15
16use crate::error::{Error, Result};
17use crate::load::{Serializable, Serialized, SerializedConstructorData};
18
19use super::base::{DynRunnable, Runnable, RunnableSerializable};
20use super::config::{ConfigOrList, RunnableConfig, get_config_list};
21use super::utils::{ConfigurableFieldSpec, gather_with_concurrency, get_unique_config_specs};
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct RouterInput<I> {
29 pub key: String,
31 pub input: I,
33}
34
35impl<I> RouterInput<I> {
36 pub fn new(key: impl Into<String>, input: I) -> Self {
38 Self {
39 key: key.into(),
40 input,
41 }
42 }
43}
44
/// Dispatches each [`RouterInput`] to the runnable registered under its `key`.
///
/// Routes are held in a `HashMap`, so they have no inherent ordering. An
/// optional `name` overrides the name that is otherwise derived from the keys.
pub struct RouterRunnable<I, O>
where
    I: Send + Sync + Clone + Debug + 'static,
    O: Send + Sync + Clone + Debug + 'static,
{
    /// Route table: key -> runnable that handles inputs carrying that key.
    runnables: HashMap<String, DynRunnable<I, O>>,
    /// Custom name reported by `Runnable::name`; `None` means one is generated from the keys.
    name: Option<String>,
}
74
75impl<I, O> Debug for RouterRunnable<I, O>
76where
77 I: Send + Sync + Clone + Debug + 'static,
78 O: Send + Sync + Clone + Debug + 'static,
79{
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 f.debug_struct("RouterRunnable")
82 .field("runnables", &self.runnables.keys().collect::<Vec<_>>())
83 .field("name", &self.name)
84 .finish()
85 }
86}
87
88impl<I, O> RouterRunnable<I, O>
89where
90 I: Send + Sync + Clone + Debug + 'static,
91 O: Send + Sync + Clone + Debug + 'static,
92{
93 pub fn new() -> Self {
95 Self {
96 runnables: HashMap::new(),
97 name: None,
98 }
99 }
100
101 pub fn from_runnables(runnables: HashMap<String, DynRunnable<I, O>>) -> Self {
103 Self {
104 runnables,
105 name: None,
106 }
107 }
108
109 pub fn add<R>(mut self, key: impl Into<String>, runnable: R) -> Self
111 where
112 R: Runnable<Input = I, Output = O> + Send + Sync + 'static,
113 {
114 self.runnables.insert(key.into(), Arc::new(runnable));
115 self
116 }
117
118 pub fn with_name(mut self, name: impl Into<String>) -> Self {
120 self.name = Some(name.into());
121 self
122 }
123
124 pub fn config_specs(&self) -> std::result::Result<Vec<ConfigurableFieldSpec>, String> {
126 let specs = self
127 .runnables
128 .values()
129 .flat_map(|_r| {
130 Vec::<ConfigurableFieldSpec>::new()
133 })
134 .collect::<Vec<_>>();
135
136 get_unique_config_specs(specs)
137 }
138}
139
140impl<I, O> Default for RouterRunnable<I, O>
141where
142 I: Send + Sync + Clone + Debug + 'static,
143 O: Send + Sync + Clone + Debug + 'static,
144{
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
150#[async_trait]
151impl<I, O> Runnable for RouterRunnable<I, O>
152where
153 I: Send + Sync + Clone + Debug + 'static,
154 O: Send + Sync + Clone + Debug + 'static,
155{
156 type Input = RouterInput<I>;
157 type Output = O;
158
159 fn name(&self) -> Option<String> {
160 self.name.clone().or_else(|| {
161 Some(format!(
162 "RouterRunnable<{}>",
163 self.runnables.keys().cloned().collect::<Vec<_>>().join(",")
164 ))
165 })
166 }
167
168 fn invoke(&self, input: Self::Input, config: Option<RunnableConfig>) -> Result<Self::Output> {
169 let key = &input.key;
170 let actual_input = input.input;
171
172 let runnable = self
173 .runnables
174 .get(key)
175 .ok_or_else(|| Error::Other(format!("No runnable associated with key '{}'", key)))?;
176
177 runnable.invoke(actual_input, config)
178 }
179
180 async fn ainvoke(
181 &self,
182 input: Self::Input,
183 config: Option<RunnableConfig>,
184 ) -> Result<Self::Output>
185 where
186 Self: 'static,
187 {
188 let key = &input.key;
189 let actual_input = input.input;
190
191 let runnable = self
192 .runnables
193 .get(key)
194 .ok_or_else(|| Error::Other(format!("No runnable associated with key '{}'", key)))?;
195
196 runnable.ainvoke(actual_input, config).await
197 }
198
199 fn batch(
200 &self,
201 inputs: Vec<Self::Input>,
202 config: Option<ConfigOrList>,
203 return_exceptions: bool,
204 ) -> Vec<Result<Self::Output>>
205 where
206 Self: 'static,
207 {
208 if inputs.is_empty() {
209 return Vec::new();
210 }
211
212 let keys: Vec<_> = inputs.iter().map(|i| i.key.clone()).collect();
213 let actual_inputs: Vec<_> = inputs.into_iter().map(|i| i.input).collect();
214
215 for key in &keys {
217 if !self.runnables.contains_key(key) {
218 return vec![Err(Error::Other(
219 "One or more keys do not have a corresponding runnable".to_string(),
220 ))];
221 }
222 }
223
224 let configs = get_config_list(config, keys.len());
225
226 let _ = return_exceptions; let results: Vec<Result<O>> = keys
228 .into_iter()
229 .zip(actual_inputs)
230 .zip(configs)
231 .map(|((key, input), config)| {
232 let runnable = self.runnables.get(&key).unwrap();
233 runnable.invoke(input, Some(config))
234 })
235 .collect();
236
237 results
238 }
239
240 async fn abatch(
241 &self,
242 inputs: Vec<Self::Input>,
243 config: Option<ConfigOrList>,
244 return_exceptions: bool,
245 ) -> Vec<Result<Self::Output>>
246 where
247 Self: 'static,
248 {
249 if inputs.is_empty() {
250 return Vec::new();
251 }
252
253 let keys: Vec<_> = inputs.iter().map(|i| i.key.clone()).collect();
254 let actual_inputs: Vec<_> = inputs.into_iter().map(|i| i.input).collect();
255
256 for key in &keys {
258 if !self.runnables.contains_key(key) {
259 return vec![Err(Error::Other(
260 "One or more keys do not have a corresponding runnable".to_string(),
261 ))];
262 }
263 }
264
265 let configs = get_config_list(config, keys.len());
266 let max_concurrency = configs.first().and_then(|c| c.max_concurrency);
267
268 let _ = return_exceptions; let futures: Vec<_> = keys
271 .into_iter()
272 .zip(actual_inputs)
273 .zip(configs)
274 .map(|((key, input), config)| {
275 let runnable = self.runnables.get(&key).unwrap().clone();
276 Box::pin(async move { runnable.ainvoke(input, Some(config)).await })
277 as std::pin::Pin<Box<dyn std::future::Future<Output = Result<O>> + Send>>
278 })
279 .collect();
280
281 gather_with_concurrency(max_concurrency, futures).await
282 }
283
284 fn stream(
285 &self,
286 input: Self::Input,
287 config: Option<RunnableConfig>,
288 ) -> BoxStream<'_, Result<Self::Output>> {
289 let key = input.key.clone();
290 let actual_input = input.input;
291
292 Box::pin(async_stream::stream! {
293 let runnable = match self.runnables.get(&key) {
294 Some(r) => r,
295 None => {
296 yield Err(Error::Other(format!("No runnable associated with key '{}'", key)));
297 return;
298 }
299 };
300
301 let mut stream = runnable.stream(actual_input, config);
302 while let Some(output) = stream.next().await {
303 yield output;
304 }
305 })
306 }
307
308 fn astream(
309 &self,
310 input: Self::Input,
311 config: Option<RunnableConfig>,
312 ) -> BoxStream<'_, Result<Self::Output>>
313 where
314 Self: 'static,
315 {
316 let key = input.key.clone();
317 let actual_input = input.input;
318
319 Box::pin(async_stream::stream! {
320 let runnable = match self.runnables.get(&key) {
321 Some(r) => r,
322 None => {
323 yield Err(Error::Other(format!("No runnable associated with key '{}'", key)));
324 return;
325 }
326 };
327
328 let mut stream = runnable.astream(actual_input, config);
329 while let Some(output) = stream.next().await {
330 yield output;
331 }
332 })
333 }
334}
335
336impl<I, O> Serializable for RouterRunnable<I, O>
337where
338 I: Send + Sync + Clone + Debug + Serialize + 'static,
339 O: Send + Sync + Clone + Debug + 'static,
340{
341 fn is_lc_serializable() -> bool {
342 true
343 }
344
345 fn get_lc_namespace() -> Vec<String> {
346 vec![
347 "langchain".to_string(),
348 "schema".to_string(),
349 "runnable".to_string(),
350 ]
351 }
352
353 fn to_json(&self) -> Serialized {
354 let mut kwargs = std::collections::HashMap::new();
355 kwargs.insert(
356 "runnables".to_string(),
357 serde_json::json!(self.runnables.keys().collect::<Vec<_>>()),
358 );
359
360 Serialized::Constructor(SerializedConstructorData {
361 lc: 1,
362 id: Self::get_lc_namespace(),
363 kwargs,
364 name: None,
365 graph: None,
366 })
367 }
368}
369
/// Lets a router act as a serializable runnable when both its input and output
/// types are `Serialize`. No methods are overridden; the trait's defaults apply.
impl<I, O> RunnableSerializable for RouterRunnable<I, O>
where
    I: Send + Sync + Clone + Debug + Serialize + 'static,
    O: Send + Sync + Clone + Debug + Serialize + 'static,
{
}
376
/// A router whose routes take and return dynamically-typed JSON (`serde_json::Value`).
pub type DynRouterRunnable = RouterRunnable<Value, Value>;
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runnables::RunnableLambda;

    #[test]
    fn test_router_input() {
        let input = RouterInput::new("add", 5);
        assert_eq!(input.key, "add");
        assert_eq!(input.input, 5);
    }

    #[test]
    fn test_router_input_serde_roundtrip() {
        let value = serde_json::to_value(RouterInput::new("add", 5)).unwrap();
        assert_eq!(value, serde_json::json!({"key": "add", "input": 5}));

        let back: RouterInput<i32> = serde_json::from_value(value).unwrap();
        assert_eq!(back.key, "add");
        assert_eq!(back.input, 5);
    }

    #[test]
    fn test_router_runnable_invoke() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let square = RunnableLambda::new(|x: i32| Ok(x * x));

        let router = RouterRunnable::new().add("add", add).add("square", square);

        let result = router.invoke(RouterInput::new("add", 5), None).unwrap();
        assert_eq!(result, 6);

        let result = router.invoke(RouterInput::new("square", 4), None).unwrap();
        assert_eq!(result, 16);
    }

    #[test]
    fn test_router_runnable_missing_key() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let router = RouterRunnable::new().add("add", add);

        let result = router.invoke(RouterInput::new("multiply", 5), None);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No runnable associated with key")
        );
    }

    #[test]
    fn test_router_runnable_batch() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let square = RunnableLambda::new(|x: i32| Ok(x * x));

        let router = RouterRunnable::new().add("add", add).add("square", square);

        let inputs = vec![
            RouterInput::new("add", 5),
            RouterInput::new("square", 4),
            RouterInput::new("add", 10),
        ];

        let results = router.batch(inputs, None, false);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].as_ref().unwrap(), &6);
        assert_eq!(results[1].as_ref().unwrap(), &16);
        assert_eq!(results[2].as_ref().unwrap(), &11);
    }

    #[test]
    fn test_router_runnable_empty_batch() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let router = RouterRunnable::new().add("add", add);

        let results = router.batch(Vec::new(), None, false);
        assert!(results.is_empty());
    }

    #[test]
    fn test_router_runnable_batch_missing_key() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let router = RouterRunnable::new().add("add", add);

        let inputs = vec![RouterInput::new("add", 1), RouterInput::new("missing", 2)];
        let results = router.batch(inputs, None, false);
        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.is_err()));
    }

    #[test]
    fn test_router_runnable_name() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));

        let router = RouterRunnable::new().add("add", add).with_name("my_router");

        assert_eq!(router.name(), Some("my_router".to_string()));
    }

    #[test]
    fn test_router_runnable_default_name() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let square = RunnableLambda::new(|x: i32| Ok(x * x));

        let router = RouterRunnable::new().add("add", add).add("square", square);

        let name = router.name().unwrap();
        assert!(name.starts_with("RouterRunnable<"));
        assert!(name.contains("add") || name.contains("square"));
    }

    #[tokio::test]
    async fn test_router_runnable_ainvoke() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let square = RunnableLambda::new(|x: i32| Ok(x * x));

        let router = RouterRunnable::new().add("add", add).add("square", square);

        let result = router
            .ainvoke(RouterInput::new("add", 5), None)
            .await
            .unwrap();
        assert_eq!(result, 6);

        let result = router
            .ainvoke(RouterInput::new("square", 4), None)
            .await
            .unwrap();
        assert_eq!(result, 16);
    }

    #[tokio::test]
    async fn test_router_runnable_ainvoke_missing_key() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let router = RouterRunnable::new().add("add", add);

        let err = router
            .ainvoke(RouterInput::new("missing", 1), None)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("No runnable associated with key"));
    }

    #[tokio::test]
    async fn test_router_runnable_abatch() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let square = RunnableLambda::new(|x: i32| Ok(x * x));

        let router = RouterRunnable::new().add("add", add).add("square", square);

        let inputs = vec![RouterInput::new("add", 5), RouterInput::new("square", 4)];

        let results = router.abatch(inputs, None, false).await;
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].as_ref().unwrap(), &6);
        assert_eq!(results[1].as_ref().unwrap(), &16);
    }

    #[tokio::test]
    async fn test_router_runnable_stream() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));

        let router = RouterRunnable::new().add("add", add);

        let mut stream = router.stream(RouterInput::new("add", 5), None);
        let result = stream.next().await.unwrap().unwrap();
        assert_eq!(result, 6);
    }

    #[tokio::test]
    async fn test_router_runnable_stream_missing_key() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let router = RouterRunnable::new().add("add", add);

        let mut stream = router.stream(RouterInput::new("missing", 1), None);
        let first = stream.next().await.expect("stream should yield an error item");
        assert!(first.is_err());
        assert!(stream.next().await.is_none());
    }
}