langchain_rust/semantic_router/route_layer/
builder.rs

1use std::sync::Arc;
2
3use futures_util::future::try_join_all;
4
5use crate::{
6    chain::{LLMChain, LLMChainBuilder},
7    embedding::{openai::OpenAiEmbedder, Embedder},
8    language_models::llm::LLM,
9    llm::openai::OpenAI,
10    prompt::HumanMessagePromptTemplate,
11    semantic_router::{Index, MemoryIndex, RouteLayerBuilderError, Router},
12    template_jinja2,
13};
14
15use super::{AggregationMethod, RouteLayer};
16
17/// A builder for creating a `RouteLayer`.
18///```rust,ignore
19/// let captial_route = Router::new(
20///     "captial",
21///     &[
22///         "Capital of France is Paris.",
23///         "What is the captial of France?",
24///     ],
25/// );
26/// let weather_route = Router::new(
27///     "temperature",
28///     &[
29///         "What is the temperature?",
30///         "Is it raining?",
31///         "Is it cloudy?",
32///     ],
33/// );
34/// let router_layer = RouteLayerBuilder::default()
35///     .embedder(OpenAiEmbedder::default())
36///     .add_route(captial_route)
37///     .add_route(weather_route)
38///     .aggregation_method(AggregationMethod::Sum)
39///     .threshold(0.82)
40///     .build()
41///     .await
42///     .unwrap();
43/// ```
44pub struct RouteLayerBuilder {
45    embedder: Option<Arc<dyn Embedder>>,
46    routes: Vec<Router>,
47    threshold: Option<f64>,
48    index: Option<Box<dyn Index>>,
49    llm: Option<LLMChain>,
50    top_k: usize,
51    aggregation_method: AggregationMethod,
52}
53impl Default for RouteLayerBuilder {
54    fn default() -> Self {
55        Self::new()
56            .embedder(OpenAiEmbedder::default())
57            .llm(OpenAI::default())
58            .index(MemoryIndex::new())
59    }
60}
61
62impl RouteLayerBuilder {
63    pub fn new() -> Self {
64        Self {
65            embedder: None,
66            routes: Vec::new(),
67            threshold: None,
68            llm: None,
69            index: None,
70            top_k: 5,
71            aggregation_method: AggregationMethod::Sum,
72        }
73    }
74
75    pub fn top_k(mut self, top_k: usize) -> Self {
76        let mut top_k = top_k;
77        if top_k == 0 {
78            log::warn!("top_k cannot be 0, setting it to 1");
79            top_k = 1;
80        }
81        self.top_k = top_k;
82        self
83    }
84
85    pub fn llm<L: LLM + 'static>(mut self, llm: L) -> Self {
86        let prompt = HumanMessagePromptTemplate::new(template_jinja2!(
87            "You should Generate the input for the following tool.
88Tool description:{{description}}.
89Input query context to generate the input for the tool :{{query}}
90
91Tool Input:
92",
93            "description",
94            "query"
95        ));
96        let chain = LLMChainBuilder::new()
97            .prompt(prompt)
98            .llm(llm)
99            .build()
100            .unwrap(); //safe to unwrap
101        self.llm = Some(chain);
102        self
103    }
104
105    pub fn index<I: Index + 'static>(mut self, index: I) -> Self {
106        self.index = Some(Box::new(index));
107        self
108    }
109
110    pub fn embedder<E: Embedder + 'static>(mut self, embedder: E) -> Self {
111        self.embedder = Some(Arc::new(embedder));
112        self
113    }
114
115    /// The threshold is the minimum similarity score that a route must have to be considered.
116    /// This depends on the similarity metric used by the embedder.
117    /// For open ai text-embedding-ada-002, the best threshold is 0.82
118    pub fn threshold(mut self, threshold: f64) -> Self {
119        self.threshold = Some(threshold);
120        self
121    }
122
123    pub fn add_route(mut self, route: Router) -> Self {
124        self.routes.push(route);
125        self
126    }
127
128    pub fn aggregation_method(mut self, aggregation_method: AggregationMethod) -> Self {
129        self.aggregation_method = aggregation_method;
130        self
131    }
132
133    pub async fn build(mut self) -> Result<RouteLayer, RouteLayerBuilderError> {
134        // Check if any routers lack an embedding and there's no global embedder provided.
135        if self.embedder.is_none() {
136            return Err(RouteLayerBuilderError::MissingEmbedder);
137        }
138
139        if self.llm.is_none() {
140            return Err(RouteLayerBuilderError::MissingLLM);
141        }
142
143        if self.index.is_none() {
144            return Err(RouteLayerBuilderError::MissingIndex);
145        }
146
147        let mut router = RouteLayer {
148            embedder: self.embedder.unwrap(), //it's safe to unwrap here because we checked for None above
149            index: self.index.unwrap(),
150            llm: self.llm.unwrap(),
151            threshold: self.threshold.unwrap_or(0.82),
152            top_k: self.top_k,
153            aggregation_method: self.aggregation_method,
154        };
155
156        let embedding_futures = self
157            .routes
158            .iter_mut()
159            .filter_map(|route| {
160                if route.embedding.is_none() {
161                    Some(router.embedder.embed_documents(&route.utterances))
162                } else {
163                    None
164                }
165            })
166            .collect::<Vec<_>>();
167
168        let embeddings = try_join_all(embedding_futures).await?;
169
170        for (route, embedding) in self
171            .routes
172            .iter_mut()
173            .filter(|r| r.embedding.is_none())
174            .zip(embeddings)
175        {
176            route.embedding = Some(embedding);
177        }
178
179        // Add routes to the index.
180        router.index.add(&self.routes).await?;
181
182        Ok(router)
183    }
184}