1use std::{collections::HashMap, path::Path, sync::Arc};
14
15use serde::{Deserialize, Serialize};
16use serde_json::{Map, Value, json};
17
18use crate::{
19 error::{Error, Result},
20 index::{QueryOutput, RedisConnectionInfo, SearchIndex},
21 query::{Vector, VectorRangeQuery},
22 schema::VectorDataType,
23 vectorizers::Vectorizer,
24};
25
// Field names of the documents stored in the router index.
/// Field holding the reference id (`{route_name}:{hash of reference text}`).
const ROUTER_REFERENCE_ID_FIELD: &str = "reference_id";
/// Field holding the name of the route a reference belongs to.
const ROUTER_ROUTE_NAME_FIELD: &str = "route_name";
/// Field holding the raw reference text.
const ROUTER_REFERENCE_FIELD: &str = "reference";
/// Field holding the embedding of the reference text.
const ROUTER_VECTOR_FIELD: &str = "vector";
/// Distance threshold used by routes that do not set their own.
const DEFAULT_ROUTE_DISTANCE_THRESHOLD: f32 = 0.5;
31
/// Strategy used to collapse the distances of all matched references of a
/// route into one route-level distance.
///
/// Variants are (de)serialized in lowercase (`"avg"`, `"min"`, `"sum"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DistanceAggregationMethod {
    /// Arithmetic mean of the matched distances.
    Avg,
    /// Smallest matched distance.
    Min,
    /// Sum of the matched distances.
    Sum,
}
43
/// A named route together with the reference texts that define it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Route {
    /// Route name; it is also the prefix of every reference id of this route.
    pub name: String,
    /// Reference texts that get embedded and indexed for this route.
    pub references: Vec<String>,
    /// Free-form metadata; an absent field deserializes to an empty map.
    #[serde(default)]
    pub metadata: Map<String, Value>,
    /// Optional per-route distance threshold; `None` means the default
    /// threshold is used. An absent field deserializes to `None`.
    #[serde(default)]
    pub distance_threshold: Option<f32>,
}
58
59impl Route {
60 pub fn new(name: impl Into<String>, references: Vec<String>) -> Self {
62 Self {
63 name: name.into(),
64 references,
65 metadata: Map::new(),
66 distance_threshold: None,
67 }
68 }
69
70 fn effective_threshold(&self) -> f32 {
71 self.distance_threshold
72 .unwrap_or(DEFAULT_ROUTE_DISTANCE_THRESHOLD)
73 }
74}
75
/// Result of routing a statement or vector.
///
/// Both fields are `None` when nothing matched.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RouteMatch {
    /// Name of the matched route, if any.
    pub name: Option<String>,
    /// Aggregated distance between the query and the matched route, if any.
    pub distance: Option<f32>,
}
84
85impl RouteMatch {
86 fn no_match() -> Self {
87 Self {
88 name: None,
89 distance: None,
90 }
91 }
92}
93
/// Default routing behaviour of a router.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Maximum number of matches returned by `route_many` when the caller
    /// does not pass its own limit.
    pub max_k: usize,
    /// Aggregation used by `route_many` when the caller does not pass one.
    pub aggregation_method: DistanceAggregationMethod,
}
102
103impl Default for RoutingConfig {
104 fn default() -> Self {
105 Self {
106 max_k: 1,
107 aggregation_method: DistanceAggregationMethod::Avg,
108 }
109 }
110}
111
/// Router that matches statements (or raw vectors) against named routes using
/// a vector index stored in Redis.
#[derive(Clone)]
pub struct SemanticRouter {
    /// Router name; also used as the index name and key prefix.
    pub name: String,
    /// Connection information for the backing Redis instance.
    pub connection: RedisConnectionInfo,
    /// Routes currently known to this router instance (in-memory view).
    pub routes: Vec<Route>,
    /// Defaults used when routing calls do not override them.
    pub routing_config: RoutingConfig,
    /// Datatype of the stored vectors.
    pub dtype: VectorDataType,
    /// Search index holding one document per route reference.
    pub index: SearchIndex,
    /// Vectorizer used to embed references and statements.
    vectorizer: Arc<dyn Vectorizer>,
    /// Embedding length every reference and query vector must have.
    vector_dimensions: usize,
}
130
131impl SemanticRouter {
132 pub fn new<V>(
137 name: impl Into<String>,
138 redis_url: impl Into<String>,
139 routes: Vec<Route>,
140 routing_config: RoutingConfig,
141 vectorizer: V,
142 ) -> Result<Self>
143 where
144 V: Vectorizer + 'static,
145 {
146 Self::new_with_options(
147 name,
148 redis_url,
149 routes,
150 routing_config,
151 vectorizer,
152 VectorDataType::Float32,
153 false,
154 )
155 }
156
157 #[cfg(feature = "hf-local")]
169 pub fn with_default_vectorizer(
170 name: impl Into<String>,
171 redis_url: impl Into<String>,
172 routes: Vec<Route>,
173 routing_config: RoutingConfig,
174 ) -> Result<Self> {
175 let vectorizer = crate::vectorizers::HuggingFaceTextVectorizer::new(Default::default())?;
176 Self::new(name, redis_url, routes, routing_config, vectorizer)
177 }
178
179 pub fn new_with_options<V>(
181 name: impl Into<String>,
182 redis_url: impl Into<String>,
183 routes: Vec<Route>,
184 routing_config: RoutingConfig,
185 vectorizer: V,
186 dtype: VectorDataType,
187 overwrite: bool,
188 ) -> Result<Self>
189 where
190 V: Vectorizer + 'static,
191 {
192 if routes.is_empty() {
193 return Err(Error::InvalidInput(
194 "semantic router requires at least one route".to_owned(),
195 ));
196 }
197
198 let Some(first_reference) = routes
199 .iter()
200 .flat_map(|route| route.references.iter())
201 .next()
202 .cloned()
203 else {
204 return Err(Error::InvalidInput(
205 "semantic router routes require at least one reference".to_owned(),
206 ));
207 };
208
209 let vectorizer = Arc::new(vectorizer);
210 let probe_vector = vectorizer.embed(&first_reference)?;
211 if probe_vector.is_empty() {
212 return Err(Error::InvalidInput(
213 "router vectorizer produced an empty embedding".to_owned(),
214 ));
215 }
216
217 let name = name.into();
218 let connection = RedisConnectionInfo::new(redis_url);
219 let schema = router_schema(&name, probe_vector.len(), dtype);
220 let index = SearchIndex::from_json_value(schema, connection.redis_url.clone())?;
221 let existed = index.exists().unwrap_or(false);
222
223 if !overwrite && existed {
225 let existing_index = SearchIndex::from_existing(&name, connection.redis_url.clone())?;
226 if existing_index.schema().to_json_value()? != index.schema().to_json_value()? {
227 return Err(Error::InvalidInput(format!(
228 "Existing index {name} schema does not match the user provided schema for the semantic router. \
229 If you wish to overwrite the index schema, set overwrite=true during initialization."
230 )));
231 }
232 }
233
234 index.create_with_options(overwrite, false)?;
235
236 let semantic_router = Self {
237 name,
238 connection,
239 routes,
240 routing_config,
241 dtype,
242 index,
243 vectorizer,
244 vector_dimensions: probe_vector.len(),
245 };
246
247 if !existed || overwrite {
248 semantic_router.load_routes()?;
249 }
250 semantic_router.persist_config()?;
251 Ok(semantic_router)
252 }
253
254 pub fn from_existing<V>(
263 name: impl Into<String>,
264 redis_url: impl Into<String>,
265 vectorizer: V,
266 dtype: VectorDataType,
267 ) -> Result<Self>
268 where
269 V: Vectorizer + 'static,
270 {
271 let name = name.into();
272 let connection = RedisConnectionInfo::new(redis_url);
273 let config_key = router_config_key(&name);
274
275 let client = connection.client()?;
276 let mut conn = client.get_connection()?;
277
278 let raw: Option<String> = redis::cmd("JSON.GET")
280 .arg(&config_key)
281 .arg(".")
282 .query(&mut conn)?;
283 let raw = raw.ok_or_else(|| {
284 Error::InvalidInput(format!(
285 "No valid router config found for {name}. No persisted configuration exists at key '{config_key}'."
286 ))
287 })?;
288 let config_value: Value = serde_json::from_str(&raw)?;
289 let config_obj = config_value
290 .as_object()
291 .ok_or_else(|| Error::InvalidInput("Router config is not an object".to_owned()))?;
292
293 let routes_value = config_obj
294 .get("routes")
295 .ok_or_else(|| Error::InvalidInput("Router config missing 'routes'".to_owned()))?;
296 let routes: Vec<Route> = serde_json::from_value(routes_value.clone())?;
297 let routing_config_value = config_obj.get("routing_config").ok_or_else(|| {
298 Error::InvalidInput("Router config missing 'routing_config'".to_owned())
299 })?;
300 let routing_config: RoutingConfig = serde_json::from_value(routing_config_value.clone())?;
301
302 let vectorizer = Arc::new(vectorizer);
303
304 let first_ref = routes
306 .iter()
307 .flat_map(|r| r.references.iter())
308 .next()
309 .ok_or_else(|| Error::InvalidInput("Persisted routes have no references".to_owned()))?;
310 let probe = vectorizer.embed(first_ref)?;
311 let vector_dimensions = probe.len();
312
313 let schema = router_schema(&name, vector_dimensions, dtype);
314 let index = SearchIndex::from_json_value(schema, connection.redis_url.clone())?;
315 if !index.exists().unwrap_or(false) {
317 return Err(Error::InvalidInput(format!(
318 "Index '{name}' does not exist in Redis"
319 )));
320 }
321
322 Ok(Self {
323 name,
324 connection,
325 routes,
326 routing_config,
327 dtype,
328 index,
329 vectorizer,
330 vector_dimensions,
331 })
332 }
333
334 pub fn route_names(&self) -> Vec<String> {
336 self.routes.iter().map(|route| route.name.clone()).collect()
337 }
338
339 pub fn route_thresholds(&self) -> HashMap<String, f32> {
341 self.routes
342 .iter()
343 .map(|route| (route.name.clone(), route.effective_threshold()))
344 .collect()
345 }
346
347 pub fn get(&self, route_name: &str) -> Option<&Route> {
349 self.routes.iter().find(|route| route.name == route_name)
350 }
351
352 pub fn update_routing_config(&mut self, routing_config: RoutingConfig) {
354 self.routing_config = routing_config;
355 }
356
357 pub fn update_route_thresholds(&mut self, route_thresholds: &HashMap<String, f32>) {
359 for route in &mut self.routes {
360 if let Some(distance_threshold) = route_thresholds.get(&route.name) {
361 route.distance_threshold = Some(*distance_threshold);
362 }
363 }
364 }
365
366 pub fn route(&self, statement: Option<&str>, vector: Option<&[f32]>) -> Result<RouteMatch> {
368 Ok(self
369 .route_many(statement, vector, None, None)?
370 .into_iter()
371 .next()
372 .unwrap_or_else(RouteMatch::no_match))
373 }
374
375 pub fn route_many(
377 &self,
378 statement: Option<&str>,
379 vector: Option<&[f32]>,
380 max_k: Option<usize>,
381 aggregation_method: Option<DistanceAggregationMethod>,
382 ) -> Result<Vec<RouteMatch>> {
383 let vector = self.resolve_vector(statement, vector)?;
384 let max_threshold = self
385 .routes
386 .iter()
387 .map(Route::effective_threshold)
388 .fold(DEFAULT_ROUTE_DISTANCE_THRESHOLD, f32::max);
389 let reference_count = self
390 .routes
391 .iter()
392 .map(|route| route.references.len())
393 .sum::<usize>()
394 .max(1);
395 let query = VectorRangeQuery::new(Vector::new(vector), ROUTER_VECTOR_FIELD, max_threshold)
396 .paging(0, reference_count)
397 .with_return_fields([ROUTER_ROUTE_NAME_FIELD, ROUTER_REFERENCE_ID_FIELD]);
398
399 let documents = query_output_documents(self.index.query(&query)?)?;
400 let mut grouped: HashMap<String, Vec<f32>> = HashMap::new();
401 for document in documents {
402 let Some(route_name) = document
403 .get(ROUTER_ROUTE_NAME_FIELD)
404 .and_then(Value::as_str)
405 .map(str::to_owned)
406 else {
407 continue;
408 };
409 let Some(distance) = parse_distance(document.get("vector_distance")) else {
410 continue;
411 };
412 grouped.entry(route_name).or_default().push(distance);
413 }
414
415 let aggregation_method =
416 aggregation_method.unwrap_or(self.routing_config.aggregation_method);
417 let mut matches = self
418 .routes
419 .iter()
420 .filter_map(|route| {
421 let distances = grouped.get(&route.name)?;
422 let distance = aggregate_distances(distances, aggregation_method);
423 (distance <= route.effective_threshold()).then(|| RouteMatch {
424 name: Some(route.name.clone()),
425 distance: Some(distance),
426 })
427 })
428 .collect::<Vec<_>>();
429
430 matches.sort_by(|left, right| {
431 let left = left.distance.unwrap_or(f32::INFINITY);
432 let right = right.distance.unwrap_or(f32::INFINITY);
433 left.total_cmp(&right)
434 });
435 matches.truncate(max_k.unwrap_or(self.routing_config.max_k));
436 Ok(matches)
437 }
438
439 pub fn add_routes(&mut self, routes: &[Route]) -> Result<()> {
441 for route in routes {
442 validate_route(route)?;
443 }
444 self.load_route_batch(routes)?;
445 for route in routes {
446 if self.get(&route.name).is_none() {
447 self.routes.push(route.clone());
448 }
449 }
450 Ok(())
451 }
452
453 pub fn remove_route(&mut self, route_name: &str) -> Result<()> {
455 let Some(route) = self.get(route_name).cloned() else {
456 return Ok(());
457 };
458
459 let keys = route
460 .references
461 .iter()
462 .map(|reference| self.index.key(&route_reference_id(&route.name, reference)))
463 .collect::<Vec<_>>();
464 self.index.drop_keys(&keys)?;
465 self.routes.retain(|route| route.name != route_name);
466 self.persist_config()?;
467 Ok(())
468 }
469
470 pub fn clear(&mut self) -> Result<usize> {
472 let deleted = self.index.clear()?;
473 self.routes.clear();
474 Ok(deleted)
475 }
476
477 pub fn delete(&self) -> Result<()> {
479 let config_key = router_config_key(&self.name);
481 let client = self.connection.client()?;
482 let mut conn = client.get_connection()?;
483 let _: usize = redis::cmd("DEL").arg(&config_key).query(&mut conn)?;
484 match self.index.delete(true) {
487 Ok(()) => Ok(()),
488 Err(Error::InvalidInput(msg)) if msg.contains("does not exist") => Ok(()),
489 Err(other) => Err(other),
490 }
491 }
492
493 pub fn add_route_references(
497 &mut self,
498 route_name: &str,
499 references: &[String],
500 ) -> Result<Vec<String>> {
501 if references.is_empty() {
502 return Ok(Vec::new());
503 }
504
505 if self.get(route_name).is_none() {
507 return Err(crate::Error::InvalidInput(format!(
508 "Route '{route_name}' not found in the SemanticRouter"
509 )));
510 }
511
512 let refs_str: Vec<&str> = references.iter().map(String::as_str).collect();
513 let embeddings = self.vectorizer.embed_many(&refs_str)?;
514 let mut records = Vec::with_capacity(references.len());
515 let mut keys = Vec::with_capacity(references.len());
516
517 for (reference, embedding) in references.iter().zip(embeddings) {
518 if embedding.len() != self.vector_dimensions {
519 return Err(crate::Error::InvalidInput(format!(
520 "router vector dimensions mismatch: expected {}, got {}",
521 self.vector_dimensions,
522 embedding.len()
523 )));
524 }
525 let ref_id = route_reference_id(route_name, reference);
526 keys.push(self.index.key(&ref_id));
527 records.push(json!({
528 ROUTER_REFERENCE_ID_FIELD: ref_id,
529 ROUTER_ROUTE_NAME_FIELD: route_name,
530 ROUTER_REFERENCE_FIELD: reference,
531 ROUTER_VECTOR_FIELD: embedding,
532 }));
533 }
534
535 if !records.is_empty() {
536 let _: Vec<String> = self.index.load(&records, ROUTER_REFERENCE_ID_FIELD, None)?;
537 }
538
539 if let Some(route) = self.routes.iter_mut().find(|r| r.name == route_name) {
541 route.references.extend(references.iter().cloned());
542 }
543 self.persist_config()?;
544
545 Ok(keys)
546 }
547
548 pub fn get_route_references(
552 &self,
553 route_name: Option<&str>,
554 reference_ids: Option<&[String]>,
555 ) -> Result<Vec<Map<String, Value>>> {
556 let ids_to_query: Vec<String> = if let Some(ref_ids) = reference_ids {
557 ref_ids.to_vec()
558 } else if let Some(route_name) = route_name {
559 let pattern = self.route_pattern(route_name);
560 let scanned = self.scan_keys(&pattern)?;
561 let sep = self.index.key_separator();
562 let prefix = self.index.prefix();
563 let prefix_with_sep = if prefix.ends_with(sep) {
564 prefix.to_owned()
565 } else {
566 format!("{prefix}{sep}")
567 };
568 scanned
570 .into_iter()
571 .map(|key| {
572 key.strip_prefix(&prefix_with_sep)
573 .unwrap_or(&key)
574 .to_owned()
575 })
576 .collect()
577 } else {
578 return Err(crate::Error::InvalidInput(
579 "Must provide a route name, reference ids, or keys to get references".to_owned(),
580 ));
581 };
582
583 let queries: Vec<crate::query::FilterQuery> = ids_to_query
584 .iter()
585 .map(|id| {
586 let filter = crate::filter::Tag::new(ROUTER_REFERENCE_ID_FIELD).eq(id.as_str());
587 crate::query::FilterQuery::new(filter).with_return_fields([
588 ROUTER_REFERENCE_ID_FIELD,
589 ROUTER_ROUTE_NAME_FIELD,
590 ROUTER_REFERENCE_FIELD,
591 ])
592 })
593 .collect();
594
595 let results = self.index.batch_query(queries.iter())?;
596 let mut refs = Vec::new();
597 for result in results {
598 if let QueryOutput::Documents(docs) = result {
599 for doc in docs {
600 refs.push(doc);
601 }
602 }
603 }
604 Ok(refs)
605 }
606
607 pub fn delete_route_references(
611 &mut self,
612 route_name: Option<&str>,
613 reference_ids: Option<&[String]>,
614 keys: Option<&[String]>,
615 ) -> Result<usize> {
616 let keys_to_delete: Vec<String> = if let Some(explicit_keys) = keys {
617 explicit_keys.to_vec()
618 } else if let Some(ref_ids) = reference_ids {
619 let queries: Vec<crate::query::FilterQuery> = ref_ids
621 .iter()
622 .map(|id| {
623 let filter = crate::filter::Tag::new(ROUTER_REFERENCE_ID_FIELD).eq(id.as_str());
624 crate::query::FilterQuery::new(filter).with_return_fields([
625 ROUTER_REFERENCE_ID_FIELD,
626 ROUTER_ROUTE_NAME_FIELD,
627 ROUTER_REFERENCE_FIELD,
628 ])
629 })
630 .collect();
631
632 let results = self.index.batch_query(queries.iter())?;
633 let mut found_keys = Vec::new();
634 for result in results {
635 if let QueryOutput::Documents(docs) = result {
636 for doc in docs {
637 if let Some(ref_id) =
640 doc.get(ROUTER_REFERENCE_ID_FIELD).and_then(Value::as_str)
641 {
642 found_keys.push(self.index.key(ref_id));
643 }
644 }
645 }
646 }
647 found_keys
648 } else if let Some(route_name) = route_name {
649 let pattern = self.route_pattern(route_name);
650 self.scan_keys(&pattern)?
651 } else {
652 return Err(crate::Error::InvalidInput(
653 "Must provide route_name, reference_ids, or keys to delete references".to_owned(),
654 ));
655 };
656
657 if keys_to_delete.is_empty() {
658 return Ok(0);
659 }
660
661 let sep = self.index.key_separator();
670 let prefix_raw = self.index.prefix().trim_end_matches(sep);
671 let prefix_with_sep = if prefix_raw.is_empty() {
672 String::new()
673 } else {
674 format!("{prefix_raw}{sep}")
675 };
676 for key in &keys_to_delete {
677 let id = key.strip_prefix(&prefix_with_sep).unwrap_or(key);
678 if let Some((rname, _hash)) = id.split_once(':') {
680 if let Some(route) = self.routes.iter_mut().find(|r| r.name == rname) {
681 route
683 .references
684 .retain(|ref_text| route_reference_id(rname, ref_text) != id);
685 }
686 }
687 }
688
689 let deleted = self.index.drop_keys(&keys_to_delete)?;
690 self.persist_config()?;
691 Ok(deleted)
692 }
693
694 pub fn to_json_value(&self) -> Result<Value> {
700 Ok(json!({
701 "name": self.name,
702 "routes": self.routes,
703 "routing_config": self.routing_config,
704 "vectorizer": {
705 "type": "custom"
706 }
707 }))
708 }
709
710 pub fn to_dict(&self) -> Result<Value> {
712 self.to_json_value()
713 }
714
715 pub fn to_yaml(&self, file_path: impl AsRef<Path>, overwrite: bool) -> Result<()> {
719 let path = file_path.as_ref();
720 if path.exists() && !overwrite {
721 return Err(Error::InvalidInput(format!(
722 "Schema file {} already exists.",
723 path.display()
724 )));
725 }
726 let dict = self.to_json_value()?;
727 let file = std::fs::File::create(path)
728 .map_err(|e| Error::InvalidInput(format!("Cannot create file: {e}")))?;
729 serde_yaml::to_writer(file, &dict)
730 .map_err(|e| Error::InvalidInput(format!("YAML serialization error: {e}")))?;
731 Ok(())
732 }
733
734 pub fn from_yaml<V>(
739 file_path: impl AsRef<Path>,
740 redis_url: impl Into<String>,
741 vectorizer: V,
742 dtype: VectorDataType,
743 overwrite: bool,
744 ) -> Result<Self>
745 where
746 V: Vectorizer + 'static,
747 {
748 let path = file_path.as_ref();
749 if !path.exists() {
750 return Err(Error::InvalidInput(format!(
751 "File {} does not exist",
752 path.display()
753 )));
754 }
755 let file = std::fs::File::open(path)
756 .map_err(|e| Error::InvalidInput(format!("Cannot open file: {e}")))?;
757 let dict: Value = serde_yaml::from_reader(file)
758 .map_err(|e| Error::InvalidInput(format!("YAML deserialization error: {e}")))?;
759 Self::from_dict(dict, redis_url, vectorizer, dtype, overwrite)
760 }
761
762 pub fn from_dict<V>(
767 data: Value,
768 redis_url: impl Into<String>,
769 vectorizer: V,
770 dtype: VectorDataType,
771 overwrite: bool,
772 ) -> Result<Self>
773 where
774 V: Vectorizer + 'static,
775 {
776 let obj = data
777 .as_object()
778 .ok_or_else(|| Error::InvalidInput("Router dict must be a JSON object".to_owned()))?;
779
780 let name = obj
781 .get("name")
782 .and_then(Value::as_str)
783 .ok_or_else(|| {
784 Error::InvalidInput(
785 "Unable to load semantic router from dict: missing 'name'".to_owned(),
786 )
787 })?
788 .to_owned();
789
790 let routes_value = obj.get("routes").ok_or_else(|| {
791 Error::InvalidInput(
792 "Unable to load semantic router from dict: missing 'routes'".to_owned(),
793 )
794 })?;
795 let routes: Vec<Route> = serde_json::from_value(routes_value.clone())?;
796
797 let routing_config_value = obj.get("routing_config").ok_or_else(|| {
798 Error::InvalidInput(
799 "Unable to load semantic router from dict: missing 'routing_config'".to_owned(),
800 )
801 })?;
802 let routing_config: RoutingConfig = serde_json::from_value(routing_config_value.clone())?;
803
804 Self::new_with_options(
805 name,
806 redis_url,
807 routes,
808 routing_config,
809 vectorizer,
810 dtype,
811 overwrite,
812 )
813 }
814
815 fn persist_config(&self) -> Result<()> {
817 let config_key = router_config_key(&self.name);
818 let dict = self.to_json_value()?;
819 let json_str = serde_json::to_string(&dict)?;
820 let client = self.connection.client()?;
821 let mut conn = client.get_connection()?;
822 let _: () = redis::cmd("JSON.SET")
823 .arg(&config_key)
824 .arg(".")
825 .arg(&json_str)
826 .query(&mut conn)?;
827 Ok(())
828 }
829
830 fn scan_keys(&self, pattern: &str) -> Result<Vec<String>> {
832 let client = self.connection.client()?;
833 let mut connection = client.get_connection()?;
834 let mut cursor = 0_u64;
835 let mut keys = Vec::new();
836 loop {
837 let (next_cursor, batch): (u64, Vec<String>) = redis::cmd("SCAN")
838 .arg(cursor)
839 .arg("MATCH")
840 .arg(pattern)
841 .arg("COUNT")
842 .arg(100)
843 .query(&mut connection)?;
844 keys.extend(batch);
845 if next_cursor == 0 {
846 break;
847 }
848 cursor = next_cursor;
849 }
850 Ok(keys)
851 }
852
853 fn route_pattern(&self, route_name: &str) -> String {
855 let sep = self.index.key_separator();
856 let prefix = self.index.prefix().trim_end_matches(sep);
857 if prefix.is_empty() {
858 format!("{route_name}{sep}*")
859 } else {
860 format!("{prefix}{sep}{route_name}{sep}*")
861 }
862 }
863
864 fn load_routes(&self) -> Result<()> {
865 self.load_route_batch(&self.routes)
866 }
867
868 fn load_route_batch(&self, routes: &[Route]) -> Result<()> {
869 for route in routes {
870 validate_route(route)?;
871 }
872
873 let mut records = Vec::new();
874 for route in routes {
875 let refs = route
876 .references
877 .iter()
878 .map(String::as_str)
879 .collect::<Vec<_>>();
880 let embeddings = self.vectorizer.embed_many(&refs)?;
881 for (reference, embedding) in route.references.iter().zip(embeddings) {
882 if embedding.len() != self.vector_dimensions {
883 return Err(crate::Error::InvalidInput(format!(
884 "router vector dimensions mismatch: expected {}, got {}",
885 self.vector_dimensions,
886 embedding.len()
887 )));
888 }
889 records.push(json!({
890 ROUTER_REFERENCE_ID_FIELD: route_reference_id(&route.name, reference),
891 ROUTER_ROUTE_NAME_FIELD: route.name,
892 ROUTER_REFERENCE_FIELD: reference,
893 ROUTER_VECTOR_FIELD: embedding,
894 }));
895 }
896 }
897
898 if !records.is_empty() {
899 let _: Vec<String> = self.index.load(&records, ROUTER_REFERENCE_ID_FIELD, None)?;
900 }
901 Ok(())
902 }
903
904 fn resolve_vector(&self, statement: Option<&str>, vector: Option<&[f32]>) -> Result<Vec<f32>> {
905 match (statement, vector) {
906 (_, Some(vector)) => {
907 if vector.len() != self.vector_dimensions {
908 return Err(crate::Error::InvalidInput(format!(
909 "router vector dimensions mismatch: expected {}, got {}",
910 self.vector_dimensions,
911 vector.len()
912 )));
913 }
914 Ok(vector.to_vec())
915 }
916 (Some(statement), None) => {
917 let vector = self.vectorizer.embed(statement)?;
918 if vector.len() != self.vector_dimensions {
919 return Err(crate::Error::InvalidInput(format!(
920 "router vector dimensions mismatch: expected {}, got {}",
921 self.vector_dimensions,
922 vector.len()
923 )));
924 }
925 Ok(vector)
926 }
927 (None, None) => Err(crate::Error::InvalidInput(
928 "must provide a statement or vector to the router".to_owned(),
929 )),
930 }
931 }
932}
933
934impl std::fmt::Debug for SemanticRouter {
935 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
936 formatter
937 .debug_struct("SemanticRouter")
938 .field("name", &self.name)
939 .field("routes", &self.routes)
940 .field("routing_config", &self.routing_config)
941 .field("vector_dimensions", &self.vector_dimensions)
942 .finish()
943 }
944}
945
946fn validate_route(route: &Route) -> Result<()> {
947 if route.name.trim().is_empty() {
948 return Err(crate::Error::InvalidInput(
949 "route name must not be empty".to_owned(),
950 ));
951 }
952 if route.references.is_empty() {
953 return Err(crate::Error::InvalidInput(
954 "route references must not be empty".to_owned(),
955 ));
956 }
957 if route
958 .references
959 .iter()
960 .any(|reference| reference.trim().is_empty())
961 {
962 return Err(crate::Error::InvalidInput(
963 "route references must not contain empty strings".to_owned(),
964 ));
965 }
966 let threshold = route.effective_threshold();
967 if !(0.0..=2.0).contains(&threshold) {
968 return Err(crate::Error::InvalidInput(format!(
969 "route distance threshold must be between 0 and 2, got {threshold}"
970 )));
971 }
972 Ok(())
973}
974
/// Returns the Redis key under which a router's configuration is stored:
/// `"{name}:route_config"`.
fn router_config_key(name: &str) -> String {
    [name, "route_config"].join(":")
}
978
979fn router_schema(name: &str, vector_dimensions: usize, dtype: VectorDataType) -> Value {
980 json!({
981 "index": {
982 "name": name,
983 "prefix": name,
984 "storage_type": "hash",
985 },
986 "fields": [
987 { "name": ROUTER_REFERENCE_ID_FIELD, "type": "tag" },
988 { "name": ROUTER_ROUTE_NAME_FIELD, "type": "tag" },
989 { "name": ROUTER_REFERENCE_FIELD, "type": "text" },
990 {
991 "name": ROUTER_VECTOR_FIELD,
992 "type": "vector",
993 "attrs": {
994 "algorithm": "flat",
995 "dims": vector_dimensions,
996 "datatype": dtype.as_str(),
997 "distance_metric": "cosine"
998 }
999 }
1000 ]
1001 })
1002}
1003
1004fn route_reference_id(route_name: &str, reference: &str) -> String {
1005 format!("{route_name}:{}", hashify(reference))
1006}
1007
1008fn hashify(content: &str) -> String {
1009 use sha2::{Digest, Sha256};
1010
1011 let mut hasher = Sha256::new();
1012 hasher.update(content.as_bytes());
1013 let digest = hasher.finalize();
1014 let mut output = String::with_capacity(digest.len() * 2);
1015 for byte in digest {
1016 use std::fmt::Write as _;
1017 let _ = write!(&mut output, "{byte:02x}");
1018 }
1019 output
1020}
1021
1022fn query_output_documents(output: QueryOutput) -> Result<Vec<Map<String, Value>>> {
1023 match output {
1024 QueryOutput::Documents(documents) => Ok(documents),
1025 QueryOutput::Count(_) => Err(crate::Error::InvalidInput(
1026 "router queries must return documents".to_owned(),
1027 )),
1028 }
1029}
1030
1031fn parse_distance(value: Option<&Value>) -> Option<f32> {
1032 match value {
1033 Some(Value::Number(number)) => number.as_f64().map(|value| value as f32),
1034 Some(Value::String(value)) => value.parse::<f32>().ok(),
1035 _ => None,
1036 }
1037}
1038
1039fn aggregate_distances(distances: &[f32], aggregation_method: DistanceAggregationMethod) -> f32 {
1040 match aggregation_method {
1041 DistanceAggregationMethod::Avg => distances.iter().sum::<f32>() / distances.len() as f32,
1042 DistanceAggregationMethod::Min => distances.iter().copied().fold(f32::INFINITY, f32::min),
1043 DistanceAggregationMethod::Sum => distances.iter().sum::<f32>(),
1044 }
1045}
1046
// Unit tests for the pure helpers and data types of this module. None of
// them needs a Redis connection or a vectorizer.
#[cfg(test)]
mod tests {
    use super::*;

    // `Route::new` starts with empty metadata and no explicit threshold.
    #[test]
    fn route_new_defaults() {
        let route = Route::new("test", vec!["ref1".to_owned()]);
        assert_eq!(route.name, "test");
        assert_eq!(route.references, vec!["ref1".to_owned()]);
        assert!(route.metadata.is_empty());
        assert!(route.distance_threshold.is_none());
    }

    // Without an explicit threshold the default threshold applies.
    #[test]
    fn route_effective_threshold_default() {
        let route = Route::new("test", vec!["ref1".to_owned()]);
        assert_eq!(
            route.effective_threshold(),
            DEFAULT_ROUTE_DISTANCE_THRESHOLD
        );
    }

    // An explicit threshold overrides the default.
    #[test]
    fn route_effective_threshold_custom() {
        let route = Route {
            distance_threshold: Some(0.3),
            ..Route::new("test", vec!["ref1".to_owned()])
        };
        assert_eq!(route.effective_threshold(), 0.3);
    }

    // The "no match" value carries neither a name nor a distance.
    #[test]
    fn route_match_no_match() {
        let rm = RouteMatch::no_match();
        assert!(rm.name.is_none());
        assert!(rm.distance.is_none());
    }

    // Default routing config: one match, averaged distances.
    #[test]
    fn routing_config_default() {
        let config = RoutingConfig::default();
        assert_eq!(config.max_k, 1);
        assert_eq!(config.aggregation_method, DistanceAggregationMethod::Avg);
    }

    // A routing config survives a JSON round trip.
    #[test]
    fn routing_config_serde_roundtrip() {
        let config = RoutingConfig {
            max_k: 5,
            aggregation_method: DistanceAggregationMethod::Min,
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: RoutingConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.max_k, 5);
        assert_eq!(
            deserialized.aggregation_method,
            DistanceAggregationMethod::Min
        );
    }

    // --- validate_route ---------------------------------------------------

    // A named route with one reference and an in-range threshold is valid.
    #[test]
    fn validate_route_ok() {
        let route = Route {
            distance_threshold: Some(0.3),
            ..Route::new("greeting", vec!["hello".to_owned()])
        };
        assert!(validate_route(&route).is_ok());
    }

    // An empty name is rejected.
    #[test]
    fn validate_route_empty_name() {
        let route = Route::new("", vec!["hello".to_owned()]);
        assert!(validate_route(&route).is_err());
    }

    // A whitespace-only name is rejected as well.
    #[test]
    fn validate_route_whitespace_name() {
        let route = Route::new(" ", vec!["hello".to_owned()]);
        assert!(validate_route(&route).is_err());
    }

    // A route without references is rejected.
    #[test]
    fn validate_route_empty_references() {
        let route = Route::new("test", vec![]);
        assert!(validate_route(&route).is_err());
    }

    // An empty reference string is rejected.
    #[test]
    fn validate_route_empty_reference_string() {
        let route = Route::new("test", vec!["".to_owned()]);
        assert!(validate_route(&route).is_err());
    }

    // Thresholds below 0 or above 2 are rejected.
    #[test]
    fn validate_route_bad_threshold() {
        let route = Route {
            distance_threshold: Some(-0.1),
            ..Route::new("test", vec!["hello".to_owned()])
        };
        assert!(validate_route(&route).is_err());

        let route = Route {
            distance_threshold: Some(2.5),
            ..Route::new("test", vec!["hello".to_owned()])
        };
        assert!(validate_route(&route).is_err());
    }

    // --- reference ids and hashing -----------------------------------------

    // The same route and reference always give the same id.
    #[test]
    fn route_reference_id_deterministic() {
        let id1 = route_reference_id("greeting", "hello");
        let id2 = route_reference_id("greeting", "hello");
        assert_eq!(id1, id2);
    }

    // Different reference texts give different ids.
    #[test]
    fn route_reference_id_different_for_different_refs() {
        let id1 = route_reference_id("greeting", "hello");
        let id2 = route_reference_id("greeting", "hi");
        assert_ne!(id1, id2);
    }

    // The same text under different routes gives different ids.
    #[test]
    fn route_reference_id_different_for_different_routes() {
        let id1 = route_reference_id("greeting", "hello");
        let id2 = route_reference_id("farewell", "hello");
        assert_ne!(id1, id2);
    }

    // Hashing is deterministic and distinguishes different inputs.
    #[test]
    fn hashify_deterministic() {
        assert_eq!(hashify("test"), hashify("test"));
        assert_ne!(hashify("test"), hashify("other"));
    }

    // --- aggregate_distances -------------------------------------------------

    // Avg: mean of the distances.
    #[test]
    fn aggregate_distances_avg() {
        let result = aggregate_distances(&[0.1, 0.3], DistanceAggregationMethod::Avg);
        assert!((result - 0.2).abs() < 1e-6);
    }

    // Min: smallest distance.
    #[test]
    fn aggregate_distances_min() {
        let result = aggregate_distances(&[0.3, 0.1, 0.5], DistanceAggregationMethod::Min);
        assert!((result - 0.1).abs() < 1e-6);
    }

    // Sum: total of the distances.
    #[test]
    fn aggregate_distances_sum() {
        let result = aggregate_distances(&[0.1, 0.2, 0.3], DistanceAggregationMethod::Sum);
        assert!((result - 0.6).abs() < 1e-6);
    }

    // --- parse_distance ------------------------------------------------------

    // JSON numbers are accepted.
    #[test]
    fn parse_distance_number() {
        let val = Value::Number(serde_json::Number::from_f64(0.5).unwrap());
        assert_eq!(parse_distance(Some(&val)), Some(0.5));
    }

    // Numeric strings are accepted.
    #[test]
    fn parse_distance_string() {
        let val = Value::String("0.25".to_owned());
        assert_eq!(parse_distance(Some(&val)), Some(0.25));
    }

    // A missing value yields `None`.
    #[test]
    fn parse_distance_none() {
        assert_eq!(parse_distance(None), None);
    }

    // A non-numeric string yields `None`.
    #[test]
    fn parse_distance_invalid() {
        let val = Value::String("not_a_number".to_owned());
        assert_eq!(parse_distance(Some(&val)), None);
    }

    // --- schema, serde and key helpers ---------------------------------------

    // The schema names/prefixes the index after the router and declares the
    // four router fields, including the vector field attributes.
    #[test]
    fn router_schema_structure() {
        let schema = router_schema("my_router", 64, VectorDataType::Float32);
        assert_eq!(schema["index"]["name"], "my_router");
        assert_eq!(schema["index"]["prefix"], "my_router");
        assert_eq!(schema["index"]["storage_type"], "hash");

        let fields = schema["fields"].as_array().unwrap();
        let field_names: Vec<&str> = fields.iter().filter_map(|f| f["name"].as_str()).collect();
        assert!(field_names.contains(&"reference_id"));
        assert!(field_names.contains(&"route_name"));
        assert!(field_names.contains(&"reference"));
        assert!(field_names.contains(&"vector"));

        let vector_field = fields
            .iter()
            .find(|f| f["name"].as_str() == Some("vector"))
            .unwrap();
        assert_eq!(vector_field["attrs"]["dims"], 64);
        assert_eq!(vector_field["attrs"]["datatype"], "float32");
    }

    // Aggregation methods serialize in lowercase and round-trip.
    #[test]
    fn distance_aggregation_method_serde() {
        let json = serde_json::to_string(&DistanceAggregationMethod::Min).unwrap();
        assert_eq!(json, "\"min\"");
        let deserialized: DistanceAggregationMethod = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, DistanceAggregationMethod::Min);
    }

    // A fully populated route survives a JSON round trip.
    #[test]
    fn route_serde_roundtrip() {
        let route = Route {
            name: "test".to_owned(),
            references: vec!["hello".to_owned(), "hi".to_owned()],
            metadata: serde_json::Map::from_iter([("type".to_owned(), json!("greeting"))]),
            distance_threshold: Some(0.3),
        };
        let json = serde_json::to_string(&route).unwrap();
        let deserialized: Route = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.name, "test");
        assert_eq!(deserialized.references, vec!["hello", "hi"]);
        assert_eq!(deserialized.distance_threshold, Some(0.3));
    }

    // The config key is the router name followed by ":route_config".
    #[test]
    fn router_config_key_format() {
        assert_eq!(router_config_key("my_router"), "my_router:route_config");
    }

    // The vector field's datatype follows the requested `VectorDataType`.
    #[test]
    fn router_schema_respects_dtype() {
        let schema_f64 = router_schema("my_router", 64, VectorDataType::Float64);
        let fields = schema_f64["fields"].as_array().unwrap();
        let vector_field = fields
            .iter()
            .find(|f| f["name"].as_str() == Some("vector"))
            .unwrap();
        assert_eq!(vector_field["attrs"]["datatype"], "float64");

        let schema_bfloat16 = router_schema("my_router", 64, VectorDataType::Bfloat16);
        let fields = schema_bfloat16["fields"].as_array().unwrap();
        let vector_field = fields
            .iter()
            .find(|f| f["name"].as_str() == Some("vector"))
            .unwrap();
        assert_eq!(vector_field["attrs"]["datatype"], "bfloat16");

        let schema_float16 = router_schema("my_router", 64, VectorDataType::Float16);
        let fields = schema_float16["fields"].as_array().unwrap();
        let vector_field = fields
            .iter()
            .find(|f| f["name"].as_str() == Some("vector"))
            .unwrap();
        assert_eq!(vector_field["attrs"]["datatype"], "float16");
    }
}