// File: langchain_rust/semantic_router/route_layer/route_layer.rs
route_layer.rs1use std::{collections::HashMap, sync::Arc};
2
3use serde_json::Value;
4
5use crate::{
6 chain::{Chain, LLMChain},
7 embedding::Embedder,
8 prompt_args,
9 semantic_router::{Index, RouteLayerError, Router},
10};
11
/// Strategy used to collapse the per-utterance similarity scores of a single
/// route into one score when ranking routes against each other.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationMethod {
    /// Arithmetic mean of the scores.
    Mean,
    /// Largest score.
    Max,
    /// Sum of the scores.
    Sum,
}

impl AggregationMethod {
    /// Aggregates `values` into a single score according to `self`.
    ///
    /// An empty slice aggregates to `0.0` for every method (previously `Mean`
    /// produced NaN for empty input). `Max` skips NaN entries instead of
    /// panicking; it only yields NaN when every entry is NaN.
    pub fn aggregate(&self, values: &[f64]) -> f64 {
        // Uniform neutral result for "no scores", regardless of method.
        if values.is_empty() {
            return 0.0;
        }
        match self {
            Self::Sum => values.iter().sum(),
            Self::Mean => values.iter().sum::<f64>() / values.len() as f64,
            // `f64::max` returns the non-NaN operand, so a stray NaN cannot
            // abort routing the way `partial_cmp(..).unwrap()` did.
            Self::Max => values.iter().copied().reduce(f64::max).unwrap_or(0.0),
        }
    }
}
29
30#[derive(Debug, Clone)]
31pub struct RouteChoise {
32 pub route: String,
33 pub similarity_score: f64,
34 pub tool_input: Option<Value>,
35}
36
37pub struct RouteLayer {
38 pub(crate) embedder: Arc<dyn Embedder>,
39 pub(crate) index: Box<dyn Index>,
40 pub(crate) threshold: f64,
41 pub(crate) llm: LLMChain,
42 pub(crate) top_k: usize,
43 pub(crate) aggregation_method: AggregationMethod,
44}
45
46impl RouteLayer {
47 pub async fn add_routes(&mut self, routers: &mut [Router]) -> Result<(), RouteLayerError> {
48 for router in routers.iter_mut() {
49 if router.embedding.is_none() {
50 let embeddigns = self.embedder.embed_documents(&router.utterances).await?;
51 router.embedding = Some(embeddigns);
52 }
53 }
54 self.index.add(routers).await?;
55 Ok(())
56 }
57
58 pub async fn delete_route<S: Into<String>>(
59 &mut self,
60 route_name: S,
61 ) -> Result<(), RouteLayerError> {
62 self.index.delete(&route_name.into()).await?;
63 Ok(())
64 }
65
66 pub async fn get_routers(&self) -> Result<Vec<Router>, RouteLayerError> {
67 let routes = self.index.get_routers().await?;
68 Ok(routes)
69 }
70
71 async fn filter_similar_routes(
72 &self,
73 query_vector: &[f64],
74 ) -> Result<Vec<(String, f64)>, RouteLayerError> {
75 let similar_routes = self.index.query(query_vector, self.top_k).await?;
76
77 Ok(similar_routes
78 .into_iter()
79 .filter(|(_, score)| *score >= self.threshold)
80 .collect())
81 }
82
83 fn compute_total_scores(&self, similar_routes: &[(String, f64)]) -> HashMap<String, f64> {
84 let mut scores_by_route: HashMap<String, Vec<f64>> = HashMap::new();
85
86 for (route_name, score) in similar_routes {
87 scores_by_route
88 .entry(route_name.to_owned())
89 .or_default()
90 .push(*score);
91 }
92
93 scores_by_route
94 .into_iter()
95 .map(|(route, scores)| {
96 let aggregated_score = self.aggregation_method.aggregate(&scores);
97 (route, aggregated_score)
98 })
99 .collect()
100 }
101
102 fn find_top_route_and_scores(
103 &self,
104 total_scores: HashMap<String, f64>,
105 scores_by_route: &HashMap<String, Vec<f64>>,
106 ) -> (Option<String>, Vec<f64>) {
107 let top_route = total_scores
108 .into_iter()
109 .max_by(|a, b| a.1.total_cmp(&b.1))
110 .map(|(route, _)| route);
111
112 let mut top_scores = top_route
113 .as_ref()
114 .and_then(|route| scores_by_route.get(route))
115 .unwrap_or(&vec![])
116 .clone();
117
118 top_scores.sort_unstable_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
119 (top_route, top_scores)
120 }
121
122 pub async fn call<S: Into<String>>(
125 &self,
126 query: S,
127 ) -> Result<Option<RouteChoise>, RouteLayerError> {
128 let query: String = query.into();
129 let query_vector = self.embedder.embed_query(&query).await?;
130
131 let route_choise = self.call_embedding(&query_vector).await?;
132
133 if route_choise.is_none() {
134 return Ok(None);
135 }
136
137 let router = self
138 .index
139 .get_router(&route_choise.as_ref().unwrap().route) .await?;
141
142 if router.tool_description.is_none() {
143 return Ok(route_choise);
144 }
145
146 let tool_input = self
147 .generate_tool_input(&query, &router.tool_description.unwrap())
148 .await?;
149
150 Ok(route_choise.map(|route| RouteChoise {
151 tool_input: Some(tool_input),
152 ..route
153 }))
154 }
155
156 pub async fn call_embedding(
160 &self,
161 embedding: &[f64],
162 ) -> Result<Option<RouteChoise>, RouteLayerError> {
163 let similar_routes = self.filter_similar_routes(embedding).await?;
164
165 if similar_routes.is_empty() {
166 return Ok(None);
167 }
168
169 let mut scores_by_route: HashMap<String, Vec<f64>> = HashMap::new();
171 for (route_name, score) in &similar_routes {
172 scores_by_route
173 .entry(route_name.clone())
174 .or_default()
175 .push(*score);
176 }
177
178 let total_scores = self.compute_total_scores(&similar_routes);
179
180 let (top_route, top_scores) =
181 self.find_top_route_and_scores(total_scores, &scores_by_route);
182
183 Ok(top_route.map(|route| RouteChoise {
184 route,
185 similarity_score: top_scores[0],
186 tool_input: None,
187 }))
188 }
189
190 async fn generate_tool_input(
191 &self,
192 query: &str,
193 description: &str,
194 ) -> Result<Value, RouteLayerError> {
195 let output = self
196 .llm
197 .invoke(prompt_args! {
198 "description"=>description,
199 "query"=>query
200 })
201 .await?;
202 match serde_json::from_str::<Value>(&output) {
203 Ok(value_result) => Ok(value_result),
204 Err(_) => Ok(Value::String(output)),
205 }
206 }
207}
208
#[cfg(test)]
mod tests {

    use crate::{embedding::openai::OpenAiEmbedder, semantic_router::RouteLayerBuilder};

    use super::*;

    /// End-to-end check that a weather question is routed to `temperature`.
    ///
    /// Ignored by default: it uses `OpenAiEmbedder`, and because the
    /// `temperature` route has a tool description, `call` also invokes the
    /// layer's LLM chain.
    #[tokio::test]
    #[ignore]
    async fn test_route_layer_builder() {
        let capital_route = Router::new(
            "capital",
            &[
                "Capital of France is Paris.",
                "What is the capital of France?",
            ],
        );
        // One sentence per literal; `concat!` joins them into a single
        // description without the stray quotes/newlines a raw string would keep.
        let description = String::from(concat!(
            "A wrapper around Google Search. ",
            "Useful for when you need to answer questions about current events. ",
            "Always one of the first options when you need to find information on the internet. ",
            "Input should be a search query."
        ));

        let weather_route = Router::new(
            "temperature",
            &[
                "What is the temperature?",
                "Is it raining?",
                "Is it cloudy?",
            ],
        )
        .with_tool_description(description);
        let route_layer = RouteLayerBuilder::default()
            .embedder(OpenAiEmbedder::default())
            .add_route(capital_route)
            .add_route(weather_route)
            .aggregation_method(AggregationMethod::Sum)
            .build()
            .await
            .unwrap();
        let route = route_layer
            .call("What is the temperature in Peru?")
            .await
            .unwrap();

        println!("{:?}", route);
        assert_eq!(
            route.expect("query should match a route").route,
            "temperature"
        );
    }
}