langchain_rust/semantic_router/route_layer/
route_layer.rs

1use std::{collections::HashMap, sync::Arc};
2
3use serde_json::Value;
4
5use crate::{
6    chain::{Chain, LLMChain},
7    embedding::Embedder,
8    prompt_args,
9    semantic_router::{Index, RouteLayerError, Router},
10};
11
/// Strategy used to collapse several similarity scores into a single number.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationMethod {
    /// Arithmetic mean of the scores.
    Mean,
    /// Largest of the scores.
    Max,
    /// Sum of the scores.
    Sum,
}

impl AggregationMethod {
    /// Aggregates `values` according to the selected method.
    ///
    /// An empty slice yields zero for every method, so callers never receive
    /// the NaN that a `0.0 / 0` mean would produce.
    ///
    /// `Max` never panics: NaN entries are skipped whenever at least one
    /// ordinary number is present (a slice made only of NaN yields NaN).
    pub fn aggregate(&self, values: &[f64]) -> f64 {
        match self {
            AggregationMethod::Sum => values.iter().sum(),
            AggregationMethod::Mean => {
                if values.is_empty() {
                    // Keep the empty case consistent with `Sum` and `Max`.
                    0.0
                } else {
                    values.iter().sum::<f64>() / values.len() as f64
                }
            }
            // `f64::max` returns the non-NaN operand when the other is NaN.
            AggregationMethod::Max => values.iter().copied().reduce(f64::max).unwrap_or(0.0),
        }
    }
}
29
30#[derive(Debug, Clone)]
31pub struct RouteChoise {
32    pub route: String,
33    pub similarity_score: f64,
34    pub tool_input: Option<Value>,
35}
36
/// Routes a query to the best-matching [`Router`] by embedding similarity.
pub struct RouteLayer {
    /// Produces embeddings for incoming queries and for route utterances.
    pub(crate) embedder: Arc<dyn Embedder>,
    /// Stores the registered routers and answers similarity queries.
    pub(crate) index: Box<dyn Index>,
    /// Minimum similarity score a match needs in order to be considered.
    pub(crate) threshold: f64,
    /// Chain invoked to turn a query into tool input for routes that carry a
    /// tool description.
    pub(crate) llm: LLMChain,
    /// Number of matches requested from the index per query.
    pub(crate) top_k: usize,
    /// How the scores belonging to one route are combined when ranking routes.
    pub(crate) aggregation_method: AggregationMethod,
}
45
46impl RouteLayer {
47    pub async fn add_routes(&mut self, routers: &mut [Router]) -> Result<(), RouteLayerError> {
48        for router in routers.iter_mut() {
49            if router.embedding.is_none() {
50                let embeddigns = self.embedder.embed_documents(&router.utterances).await?;
51                router.embedding = Some(embeddigns);
52            }
53        }
54        self.index.add(routers).await?;
55        Ok(())
56    }
57
58    pub async fn delete_route<S: Into<String>>(
59        &mut self,
60        route_name: S,
61    ) -> Result<(), RouteLayerError> {
62        self.index.delete(&route_name.into()).await?;
63        Ok(())
64    }
65
66    pub async fn get_routers(&self) -> Result<Vec<Router>, RouteLayerError> {
67        let routes = self.index.get_routers().await?;
68        Ok(routes)
69    }
70
71    async fn filter_similar_routes(
72        &self,
73        query_vector: &[f64],
74    ) -> Result<Vec<(String, f64)>, RouteLayerError> {
75        let similar_routes = self.index.query(query_vector, self.top_k).await?;
76
77        Ok(similar_routes
78            .into_iter()
79            .filter(|(_, score)| *score >= self.threshold)
80            .collect())
81    }
82
83    fn compute_total_scores(&self, similar_routes: &[(String, f64)]) -> HashMap<String, f64> {
84        let mut scores_by_route: HashMap<String, Vec<f64>> = HashMap::new();
85
86        for (route_name, score) in similar_routes {
87            scores_by_route
88                .entry(route_name.to_owned())
89                .or_default()
90                .push(*score);
91        }
92
93        scores_by_route
94            .into_iter()
95            .map(|(route, scores)| {
96                let aggregated_score = self.aggregation_method.aggregate(&scores);
97                (route, aggregated_score)
98            })
99            .collect()
100    }
101
102    fn find_top_route_and_scores(
103        &self,
104        total_scores: HashMap<String, f64>,
105        scores_by_route: &HashMap<String, Vec<f64>>,
106    ) -> (Option<String>, Vec<f64>) {
107        let top_route = total_scores
108            .into_iter()
109            .max_by(|a, b| a.1.total_cmp(&b.1))
110            .map(|(route, _)| route);
111
112        let mut top_scores = top_route
113            .as_ref()
114            .and_then(|route| scores_by_route.get(route))
115            .unwrap_or(&vec![])
116            .clone();
117
118        top_scores.sort_unstable_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
119        (top_route, top_scores)
120    }
121
122    /// Call the route layer with a query and return the best route choise
123    /// If route has a tool description, it will also return the tool input
124    pub async fn call<S: Into<String>>(
125        &self,
126        query: S,
127    ) -> Result<Option<RouteChoise>, RouteLayerError> {
128        let query: String = query.into();
129        let query_vector = self.embedder.embed_query(&query).await?;
130
131        let route_choise = self.call_embedding(&query_vector).await?;
132
133        if route_choise.is_none() {
134            return Ok(None);
135        }
136
137        let router = self
138            .index
139            .get_router(&route_choise.as_ref().unwrap().route) //safe to unwrap
140            .await?;
141
142        if router.tool_description.is_none() {
143            return Ok(route_choise);
144        }
145
146        let tool_input = self
147            .generate_tool_input(&query, &router.tool_description.unwrap())
148            .await?;
149
150        Ok(route_choise.map(|route| RouteChoise {
151            tool_input: Some(tool_input),
152            ..route
153        }))
154    }
155
156    /// Call the route layer with a query and return the best route choise
157    /// If route has a tool description, it will not return the tool input,
158    /// this just returns the route
159    pub async fn call_embedding(
160        &self,
161        embedding: &[f64],
162    ) -> Result<Option<RouteChoise>, RouteLayerError> {
163        let similar_routes = self.filter_similar_routes(embedding).await?;
164
165        if similar_routes.is_empty() {
166            return Ok(None);
167        }
168
169        // Correctly collect scores by route manually
170        let mut scores_by_route: HashMap<String, Vec<f64>> = HashMap::new();
171        for (route_name, score) in &similar_routes {
172            scores_by_route
173                .entry(route_name.clone())
174                .or_default()
175                .push(*score);
176        }
177
178        let total_scores = self.compute_total_scores(&similar_routes);
179
180        let (top_route, top_scores) =
181            self.find_top_route_and_scores(total_scores, &scores_by_route);
182
183        Ok(top_route.map(|route| RouteChoise {
184            route,
185            similarity_score: top_scores[0],
186            tool_input: None,
187        }))
188    }
189
190    async fn generate_tool_input(
191        &self,
192        query: &str,
193        description: &str,
194    ) -> Result<Value, RouteLayerError> {
195        let output = self
196            .llm
197            .invoke(prompt_args! {
198                "description"=>description,
199                "query"=>query
200            })
201            .await?;
202        match serde_json::from_str::<Value>(&output) {
203            Ok(value_result) => Ok(value_result),
204            Err(_) => Ok(Value::String(output)),
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{embedding::openai::OpenAiEmbedder, semantic_router::RouteLayerBuilder};

    /// Builds a two-route layer and checks that a weather question is routed
    /// to the `temperature` route.
    ///
    /// Marked `#[ignore]`; it builds an `OpenAiEmbedder`, so it presumably
    /// needs network access and credentials (verify before enabling).
    #[tokio::test]
    #[ignore]
    async fn test_route_layer_builder() {
        let search_tool_description = String::from(
            r#""A wrapper around Google Search. "
	"Useful for when you need to answer questions about current events. "
	"Always one of the first options when you need to find information on internet"
	"Input should be a search query."#,
        );

        let capital_route = Router::new(
            "captial",
            &[
                "Capital of France is Paris.",
                "What is the captial of France?",
            ],
        );

        let weather_route = Router::new(
            "temperature",
            &[
                "What is the temperature?",
                "Is it raining?",
                "Is it cloudy?",
            ],
        )
        .with_tool_description(search_tool_description);

        let layer = RouteLayerBuilder::default()
            .embedder(OpenAiEmbedder::default())
            .add_route(capital_route)
            .add_route(weather_route)
            .aggregation_method(AggregationMethod::Sum)
            .build()
            .await
            .unwrap();

        let choice = layer
            .call("What is the temperature in Peru?")
            .await
            .unwrap();

        println!("{:?}", choice);
        assert_eq!(choice.unwrap().route, "temperature");
    }
}