agent_chain_core/runnables/
router.rs

1//! Runnable that routes to a set of Runnables.
2//!
3//! This module provides `RouterRunnable` which routes to different Runnables
4//! based on a key in the input, mirroring `langchain_core.runnables.router`.
5
6use std::collections::HashMap;
7use std::fmt::Debug;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::StreamExt;
12use futures::stream::BoxStream;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15
16use crate::error::{Error, Result};
17use crate::load::{Serializable, Serialized, SerializedConstructorData};
18
19use super::base::{DynRunnable, Runnable, RunnableSerializable};
20use super::config::{ConfigOrList, RunnableConfig, get_config_list};
21use super::utils::{ConfigurableFieldSpec, gather_with_concurrency, get_unique_config_specs};
22
23/// Router input.
24///
25/// This struct represents the input to a RouterRunnable, containing
26/// the key to route on and the actual input to pass to the selected Runnable.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct RouterInput<I> {
29    /// The key to route on.
30    pub key: String,
31    /// The input to pass to the selected `Runnable`.
32    pub input: I,
33}
34
35impl<I> RouterInput<I> {
36    /// Create a new RouterInput.
37    pub fn new(key: impl Into<String>, input: I) -> Self {
38        Self {
39            key: key.into(),
40            input,
41        }
42    }
43}
44
45/// A `Runnable` that routes to a set of `Runnable` based on `Input['key']`.
46///
47/// Returns the output of the selected Runnable.
48///
49/// # Example
50///
51/// ```ignore
52/// use agent_chain_core::runnables::{RouterRunnable, RunnableLambda, RouterInput};
53///
54/// let add = RunnableLambda::new(|x: i32| Ok(x + 1));
55/// let square = RunnableLambda::new(|x: i32| Ok(x * x));
56///
57/// let router = RouterRunnable::new()
58///     .add("add", add)
59///     .add("square", square);
60///
61/// let result = router.invoke(RouterInput::new("square", 3), None)?;
62/// assert_eq!(result, 9);
63/// ```
64pub struct RouterRunnable<I, O>
65where
66    I: Send + Sync + Clone + Debug + 'static,
67    O: Send + Sync + Clone + Debug + 'static,
68{
69    /// The mapping of keys to Runnables.
70    runnables: HashMap<String, DynRunnable<I, O>>,
71    /// Optional name for this router.
72    name: Option<String>,
73}
74
75impl<I, O> Debug for RouterRunnable<I, O>
76where
77    I: Send + Sync + Clone + Debug + 'static,
78    O: Send + Sync + Clone + Debug + 'static,
79{
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        f.debug_struct("RouterRunnable")
82            .field("runnables", &self.runnables.keys().collect::<Vec<_>>())
83            .field("name", &self.name)
84            .finish()
85    }
86}
87
88impl<I, O> RouterRunnable<I, O>
89where
90    I: Send + Sync + Clone + Debug + 'static,
91    O: Send + Sync + Clone + Debug + 'static,
92{
93    /// Create a new empty RouterRunnable.
94    pub fn new() -> Self {
95        Self {
96            runnables: HashMap::new(),
97            name: None,
98        }
99    }
100
101    /// Create a new RouterRunnable from a HashMap of runnables.
102    pub fn from_runnables(runnables: HashMap<String, DynRunnable<I, O>>) -> Self {
103        Self {
104            runnables,
105            name: None,
106        }
107    }
108
109    /// Add a runnable to the router.
110    pub fn add<R>(mut self, key: impl Into<String>, runnable: R) -> Self
111    where
112        R: Runnable<Input = I, Output = O> + Send + Sync + 'static,
113    {
114        self.runnables.insert(key.into(), Arc::new(runnable));
115        self
116    }
117
118    /// Set the name of this router.
119    pub fn with_name(mut self, name: impl Into<String>) -> Self {
120        self.name = Some(name.into());
121        self
122    }
123
124    /// Get the configurable field specs from all contained runnables.
125    pub fn config_specs(&self) -> std::result::Result<Vec<ConfigurableFieldSpec>, String> {
126        let specs = self
127            .runnables
128            .values()
129            .flat_map(|_r| {
130                // For now, return empty specs since DynRunnable doesn't expose config_specs
131                // In a full implementation, this would need to be part of the Runnable trait
132                Vec::<ConfigurableFieldSpec>::new()
133            })
134            .collect::<Vec<_>>();
135
136        get_unique_config_specs(specs)
137    }
138}
139
140impl<I, O> Default for RouterRunnable<I, O>
141where
142    I: Send + Sync + Clone + Debug + 'static,
143    O: Send + Sync + Clone + Debug + 'static,
144{
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
150#[async_trait]
151impl<I, O> Runnable for RouterRunnable<I, O>
152where
153    I: Send + Sync + Clone + Debug + 'static,
154    O: Send + Sync + Clone + Debug + 'static,
155{
156    type Input = RouterInput<I>;
157    type Output = O;
158
159    fn name(&self) -> Option<String> {
160        self.name.clone().or_else(|| {
161            Some(format!(
162                "RouterRunnable<{}>",
163                self.runnables.keys().cloned().collect::<Vec<_>>().join(",")
164            ))
165        })
166    }
167
168    fn invoke(&self, input: Self::Input, config: Option<RunnableConfig>) -> Result<Self::Output> {
169        let key = &input.key;
170        let actual_input = input.input;
171
172        let runnable = self
173            .runnables
174            .get(key)
175            .ok_or_else(|| Error::Other(format!("No runnable associated with key '{}'", key)))?;
176
177        runnable.invoke(actual_input, config)
178    }
179
180    async fn ainvoke(
181        &self,
182        input: Self::Input,
183        config: Option<RunnableConfig>,
184    ) -> Result<Self::Output>
185    where
186        Self: 'static,
187    {
188        let key = &input.key;
189        let actual_input = input.input;
190
191        let runnable = self
192            .runnables
193            .get(key)
194            .ok_or_else(|| Error::Other(format!("No runnable associated with key '{}'", key)))?;
195
196        runnable.ainvoke(actual_input, config).await
197    }
198
199    fn batch(
200        &self,
201        inputs: Vec<Self::Input>,
202        config: Option<ConfigOrList>,
203        return_exceptions: bool,
204    ) -> Vec<Result<Self::Output>>
205    where
206        Self: 'static,
207    {
208        if inputs.is_empty() {
209            return Vec::new();
210        }
211
212        let keys: Vec<_> = inputs.iter().map(|i| i.key.clone()).collect();
213        let actual_inputs: Vec<_> = inputs.into_iter().map(|i| i.input).collect();
214
215        // Check if all keys have corresponding runnables
216        for key in &keys {
217            if !self.runnables.contains_key(key) {
218                return vec![Err(Error::Other(
219                    "One or more keys do not have a corresponding runnable".to_string(),
220                ))];
221            }
222        }
223
224        let configs = get_config_list(config, keys.len());
225
226        let _ = return_exceptions; // Used for API compatibility, not yet implemented
227        let results: Vec<Result<O>> = keys
228            .into_iter()
229            .zip(actual_inputs)
230            .zip(configs)
231            .map(|((key, input), config)| {
232                let runnable = self.runnables.get(&key).unwrap();
233                runnable.invoke(input, Some(config))
234            })
235            .collect();
236
237        results
238    }
239
240    async fn abatch(
241        &self,
242        inputs: Vec<Self::Input>,
243        config: Option<ConfigOrList>,
244        return_exceptions: bool,
245    ) -> Vec<Result<Self::Output>>
246    where
247        Self: 'static,
248    {
249        if inputs.is_empty() {
250            return Vec::new();
251        }
252
253        let keys: Vec<_> = inputs.iter().map(|i| i.key.clone()).collect();
254        let actual_inputs: Vec<_> = inputs.into_iter().map(|i| i.input).collect();
255
256        // Check if all keys have corresponding runnables
257        for key in &keys {
258            if !self.runnables.contains_key(key) {
259                return vec![Err(Error::Other(
260                    "One or more keys do not have a corresponding runnable".to_string(),
261                ))];
262            }
263        }
264
265        let configs = get_config_list(config, keys.len());
266        let max_concurrency = configs.first().and_then(|c| c.max_concurrency);
267
268        let _ = return_exceptions; // Used for API compatibility, not yet implemented
269        // Create futures for each invocation
270        let futures: Vec<_> = keys
271            .into_iter()
272            .zip(actual_inputs)
273            .zip(configs)
274            .map(|((key, input), config)| {
275                let runnable = self.runnables.get(&key).unwrap().clone();
276                Box::pin(async move { runnable.ainvoke(input, Some(config)).await })
277                    as std::pin::Pin<Box<dyn std::future::Future<Output = Result<O>> + Send>>
278            })
279            .collect();
280
281        gather_with_concurrency(max_concurrency, futures).await
282    }
283
284    fn stream(
285        &self,
286        input: Self::Input,
287        config: Option<RunnableConfig>,
288    ) -> BoxStream<'_, Result<Self::Output>> {
289        let key = input.key.clone();
290        let actual_input = input.input;
291
292        Box::pin(async_stream::stream! {
293            let runnable = match self.runnables.get(&key) {
294                Some(r) => r,
295                None => {
296                    yield Err(Error::Other(format!("No runnable associated with key '{}'", key)));
297                    return;
298                }
299            };
300
301            let mut stream = runnable.stream(actual_input, config);
302            while let Some(output) = stream.next().await {
303                yield output;
304            }
305        })
306    }
307
308    fn astream(
309        &self,
310        input: Self::Input,
311        config: Option<RunnableConfig>,
312    ) -> BoxStream<'_, Result<Self::Output>>
313    where
314        Self: 'static,
315    {
316        let key = input.key.clone();
317        let actual_input = input.input;
318
319        Box::pin(async_stream::stream! {
320            let runnable = match self.runnables.get(&key) {
321                Some(r) => r,
322                None => {
323                    yield Err(Error::Other(format!("No runnable associated with key '{}'", key)));
324                    return;
325                }
326            };
327
328            let mut stream = runnable.astream(actual_input, config);
329            while let Some(output) = stream.next().await {
330                yield output;
331            }
332        })
333    }
334}
335
336impl<I, O> Serializable for RouterRunnable<I, O>
337where
338    I: Send + Sync + Clone + Debug + Serialize + 'static,
339    O: Send + Sync + Clone + Debug + 'static,
340{
341    fn is_lc_serializable() -> bool {
342        true
343    }
344
345    fn get_lc_namespace() -> Vec<String> {
346        vec![
347            "langchain".to_string(),
348            "schema".to_string(),
349            "runnable".to_string(),
350        ]
351    }
352
353    fn to_json(&self) -> Serialized {
354        let mut kwargs = std::collections::HashMap::new();
355        kwargs.insert(
356            "runnables".to_string(),
357            serde_json::json!(self.runnables.keys().collect::<Vec<_>>()),
358        );
359
360        Serialized::Constructor(SerializedConstructorData {
361            lc: 1,
362            id: Self::get_lc_namespace(),
363            kwargs,
364            name: None,
365            graph: None,
366        })
367    }
368}
369
370impl<I, O> RunnableSerializable for RouterRunnable<I, O>
371where
372    I: Send + Sync + Clone + Debug + Serialize + 'static,
373    O: Send + Sync + Clone + Debug + Serialize + 'static,
374{
375}
376
377/// Type alias for a RouterRunnable with Value input and output.
378///
379/// This is useful when the types are not known at compile time.
380pub type DynRouterRunnable = RouterRunnable<Value, Value>;
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runnables::RunnableLambda;

    #[test]
    fn test_router_input() {
        let input = RouterInput::new("add", 5);
        assert_eq!(input.key, "add");
        assert_eq!(input.input, 5);
    }

    #[test]
    fn test_router_runnable_invoke() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let square = RunnableLambda::new(|x: i32| Ok(x * x));

        let router = RouterRunnable::new().add("add", add).add("square", square);

        let result = router.invoke(RouterInput::new("add", 5), None).unwrap();
        assert_eq!(result, 6);

        let result = router.invoke(RouterInput::new("square", 4), None).unwrap();
        assert_eq!(result, 16);
    }

    #[test]
    fn test_router_runnable_missing_key() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let router = RouterRunnable::new().add("add", add);

        let result = router.invoke(RouterInput::new("multiply", 5), None);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No runnable associated with key")
        );
    }

    #[test]
    fn test_router_runnable_batch() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let square = RunnableLambda::new(|x: i32| Ok(x * x));

        let router = RouterRunnable::new().add("add", add).add("square", square);

        let inputs = vec![
            RouterInput::new("add", 5),
            RouterInput::new("square", 4),
            RouterInput::new("add", 10),
        ];

        let results = router.batch(inputs, None, false);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].as_ref().unwrap(), &6);
        assert_eq!(results[1].as_ref().unwrap(), &16);
        assert_eq!(results[2].as_ref().unwrap(), &11);
    }

    #[test]
    fn test_router_runnable_batch_empty() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let router = RouterRunnable::new().add("add", add);

        let results = router.batch(Vec::new(), None, false);
        assert!(results.is_empty());
    }

    #[test]
    fn test_router_runnable_name() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));

        let router = RouterRunnable::new().add("add", add).with_name("my_router");

        assert_eq!(router.name(), Some("my_router".to_string()));
    }

    #[test]
    fn test_router_runnable_default_name() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let square = RunnableLambda::new(|x: i32| Ok(x * x));

        let router = RouterRunnable::new().add("add", add).add("square", square);

        let name = router.name().unwrap();
        assert!(name.starts_with("RouterRunnable<"));
        assert!(name.contains("add") || name.contains("square"));
    }

    #[tokio::test]
    async fn test_router_runnable_ainvoke() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let square = RunnableLambda::new(|x: i32| Ok(x * x));

        let router = RouterRunnable::new().add("add", add).add("square", square);

        let result = router
            .ainvoke(RouterInput::new("add", 5), None)
            .await
            .unwrap();
        assert_eq!(result, 6);

        let result = router
            .ainvoke(RouterInput::new("square", 4), None)
            .await
            .unwrap();
        assert_eq!(result, 16);
    }

    #[tokio::test]
    async fn test_router_runnable_ainvoke_missing_key() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let router = RouterRunnable::new().add("add", add);

        let result = router.ainvoke(RouterInput::new("multiply", 5), None).await;
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No runnable associated with key")
        );
    }

    #[tokio::test]
    async fn test_router_runnable_abatch() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let square = RunnableLambda::new(|x: i32| Ok(x * x));

        let router = RouterRunnable::new().add("add", add).add("square", square);

        let inputs = vec![RouterInput::new("add", 5), RouterInput::new("square", 4)];

        let results = router.abatch(inputs, None, false).await;
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].as_ref().unwrap(), &6);
        assert_eq!(results[1].as_ref().unwrap(), &16);
    }

    #[tokio::test]
    async fn test_router_runnable_stream() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));

        let router = RouterRunnable::new().add("add", add);

        let mut stream = router.stream(RouterInput::new("add", 5), None);
        let result = stream.next().await.unwrap().unwrap();
        assert_eq!(result, 6);
    }

    #[tokio::test]
    async fn test_router_runnable_stream_missing_key() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let router = RouterRunnable::new().add("add", add);

        let mut stream = router.stream(RouterInput::new("multiply", 5), None);
        let first = stream.next().await.unwrap();
        assert!(
            first
                .unwrap_err()
                .to_string()
                .contains("No runnable associated with key")
        );
        // The stream terminates right after reporting the unknown key.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_router_runnable_astream() {
        let add = RunnableLambda::new(|x: i32| Ok(x + 1));
        let router = RouterRunnable::new().add("add", add);

        let mut stream = router.astream(RouterInput::new("add", 5), None);
        let result = stream.next().await.unwrap().unwrap();
        assert_eq!(result, 6);
    }
}