langchain_rust/semantic_router/index/
memory_index.rs1use std::collections::HashMap;
2
3use async_trait::async_trait;
4
5use crate::semantic_router::{utils::cosine_similarity, IndexError, Router};
6
7use super::Index;
8
9pub struct MemoryIndex {
10 routers: HashMap<String, Router>,
11}
12impl MemoryIndex {
13 pub fn new() -> Self {
14 Self {
15 routers: HashMap::new(),
16 }
17 }
18}
19
20#[async_trait]
21impl Index for MemoryIndex {
22 async fn add(&mut self, routers: &[Router]) -> Result<(), IndexError> {
23 for router in routers {
24 if router.embedding.is_none() {
25 return Err(IndexError::MissingEmbedding(router.name.clone()));
26 }
27 if self.routers.contains_key(&router.name) {
28 log::warn!("Router {} already exists in the index", router.name);
29 }
30 self.routers.insert(router.name.clone(), router.clone());
31 }
32
33 Ok(())
34 }
35
36 async fn delete(&mut self, router_name: &str) -> Result<(), IndexError> {
37 if self.routers.remove(router_name).is_none() {
38 log::warn!("Router {} not found in the index", router_name);
39 }
40 Ok(())
41 }
42
43 async fn query(&self, vector: &[f64], top_k: usize) -> Result<Vec<(String, f64)>, IndexError> {
44 let mut all_similarities: Vec<(String, f64)> = Vec::new();
45
46 for (name, router) in &self.routers {
48 if let Some(embeddings) = &router.embedding {
49 for embedding in embeddings {
50 let similarity = cosine_similarity(vector, embedding);
51 all_similarities.push((name.clone(), similarity));
52 }
53 }
54 }
55
56 all_similarities
58 .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
59
60 let top_similarities: Vec<(String, f64)> =
62 all_similarities.into_iter().take(top_k).collect();
63
64 Ok(top_similarities)
65 }
66
67 async fn get_routers(&self) -> Result<Vec<Router>, IndexError> {
68 let routes = self.routers.values().cloned().collect();
69 Ok(routes)
70 }
71
72 async fn get_router(&self, route_name: &str) -> Result<Router, IndexError> {
73 return self
74 .routers
75 .get(route_name)
76 .cloned()
77 .ok_or(IndexError::RouterNotFound(route_name.into()));
78 }
79
80 async fn delete_index(&mut self) -> Result<(), IndexError> {
81 self.routers.clear();
82 Ok(())
83 }
84}