langchain_rust/semantic_router/route_layer/builder.rs

use std::sync::Arc;

use futures_util::future::try_join_all;

use crate::{
    chain::{LLMChain, LLMChainBuilder},
    embedding::{openai::OpenAiEmbedder, Embedder},
    language_models::llm::LLM,
    llm::openai::OpenAI,
    prompt::HumanMessagePromptTemplate,
    semantic_router::{Index, MemoryIndex, RouteLayerBuilderError, Router},
    template_jinja2,
};

use super::{AggregationMethod, RouteLayer};

/// Builder for configuring and constructing a `RouteLayer`.
pub struct RouteLayerBuilder {
    embedder: Option<Arc<dyn Embedder>>,
    routes: Vec<Router>,
    threshold: Option<f64>,
    index: Option<Box<dyn Index>>,
    llm: Option<LLMChain>,
    top_k: usize,
    aggregation_method: AggregationMethod,
}
impl Default for RouteLayerBuilder {
    fn default() -> Self {
        Self::new()
            .embedder(OpenAiEmbedder::default())
            .llm(OpenAI::default())
            .index(MemoryIndex::new())
    }
}

impl RouteLayerBuilder {
    pub fn new() -> Self {
        Self {
            embedder: None,
            routes: Vec::new(),
            threshold: None,
            llm: None,
            index: None,
            top_k: 5,
            aggregation_method: AggregationMethod::Sum,
        }
    }

    pub fn top_k(mut self, mut top_k: usize) -> Self {
        if top_k == 0 {
            log::warn!("top_k cannot be 0, setting it to 1");
            top_k = 1;
        }
        self.top_k = top_k;
        self
    }

    pub fn llm<L: LLM + 'static>(mut self, llm: L) -> Self {
        let prompt = HumanMessagePromptTemplate::new(template_jinja2!(
            "You should Generate the input for the following tool.
Tool description:{{description}}.
Input query context to generate the input for the tool :{{query}}

Tool Input:
",
            "description",
            "query"
        ));
        let chain = LLMChainBuilder::new()
            .prompt(prompt)
            .llm(llm)
            .build()
            .unwrap();
        self.llm = Some(chain);
        self
    }

    pub fn index<I: Index + 'static>(mut self, index: I) -> Self {
        self.index = Some(Box::new(index));
        self
    }

    pub fn embedder<E: Embedder + 'static>(mut self, embedder: E) -> Self {
        self.embedder = Some(Arc::new(embedder));
        self
    }

    /// Sets the similarity threshold used when routing. If not set, `build`
    /// falls back to a default of 0.82.
    pub fn threshold(mut self, threshold: f64) -> Self {
        self.threshold = Some(threshold);
        self
    }

    pub fn add_route(mut self, route: Router) -> Self {
        self.routes.push(route);
        self
    }

    pub fn aggregation_method(mut self, aggregation_method: AggregationMethod) -> Self {
        self.aggregation_method = aggregation_method;
        self
    }

    pub async fn build(mut self) -> Result<RouteLayer, RouteLayerBuilderError> {
        // An embedder, an LLM chain, and an index are all required.
        if self.embedder.is_none() {
            return Err(RouteLayerBuilderError::MissingEmbedder);
        }

        if self.llm.is_none() {
            return Err(RouteLayerBuilderError::MissingLLM);
        }

        if self.index.is_none() {
            return Err(RouteLayerBuilderError::MissingIndex);
        }

        let mut router = RouteLayer {
            embedder: self.embedder.unwrap(),
            index: self.index.unwrap(),
            llm: self.llm.unwrap(),
            threshold: self.threshold.unwrap_or(0.82),
            top_k: self.top_k,
            aggregation_method: self.aggregation_method,
        };

        // Embed the utterances of every route that does not yet have an embedding.
        let embedding_futures = self
            .routes
            .iter_mut()
            .filter_map(|route| {
                if route.embedding.is_none() {
                    Some(router.embedder.embed_documents(&route.utterances))
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();

        let embeddings = try_join_all(embedding_futures).await?;

        // Attach the computed embeddings back onto the routes, in order.
        for (route, embedding) in self
            .routes
            .iter_mut()
            .filter(|r| r.embedding.is_none())
            .zip(embeddings)
        {
            route.embedding = Some(embedding);
        }

        router.index.add(&self.routes).await?;

        Ok(router)
    }
}
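
// A minimal usage sketch, added for illustration and not part of the original
// builder. It assumes the `Router::new(name, utterances)` constructor from the
// semantic_router module and OpenAI credentials configured via the environment.
#[allow(dead_code)]
async fn example_usage() -> Result<RouteLayer, RouteLayerBuilderError> {
    // `RouteLayerBuilder::default()` already wires up the OpenAI embedder, the
    // OpenAI LLM chain, and an in-memory index, so only routes are needed here.
    let politics = Router::new(
        "politics",
        &[
            "Who should I vote for in the next election?",
            "What do you think about the current government?",
        ],
    );

    RouteLayerBuilder::default()
        .add_route(politics)
        .threshold(0.82)
        .top_k(3)
        .build()
        .await
}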