langchain_rust/semantic_router/index/
memory_index.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4
5use crate::semantic_router::{utils::cosine_similarity, IndexError, Router};
6
7use super::Index;
8
9pub struct MemoryIndex {
10    routers: HashMap<String, Router>,
11}
12impl MemoryIndex {
13    pub fn new() -> Self {
14        Self {
15            routers: HashMap::new(),
16        }
17    }
18}
19
20#[async_trait]
21impl Index for MemoryIndex {
22    async fn add(&mut self, routers: &[Router]) -> Result<(), IndexError> {
23        for router in routers {
24            if router.embedding.is_none() {
25                return Err(IndexError::MissingEmbedding(router.name.clone()));
26            }
27            if self.routers.contains_key(&router.name) {
28                log::warn!("Router {} already exists in the index", router.name);
29            }
30            self.routers.insert(router.name.clone(), router.clone());
31        }
32
33        Ok(())
34    }
35
36    async fn delete(&mut self, router_name: &str) -> Result<(), IndexError> {
37        if self.routers.remove(router_name).is_none() {
38            log::warn!("Router {} not found in the index", router_name);
39        }
40        Ok(())
41    }
42
43    async fn query(&self, vector: &[f64], top_k: usize) -> Result<Vec<(String, f64)>, IndexError> {
44        let mut all_similarities: Vec<(String, f64)> = Vec::new();
45
46        // Compute similarity for each embedding of each router
47        for (name, router) in &self.routers {
48            if let Some(embeddings) = &router.embedding {
49                for embedding in embeddings {
50                    let similarity = cosine_similarity(vector, embedding);
51                    all_similarities.push((name.clone(), similarity));
52                }
53            }
54        }
55
56        // Sort all similarities by descending similarity score
57        all_similarities
58            .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
59
60        // Only keep the top_k similarities
61        let top_similarities: Vec<(String, f64)> =
62            all_similarities.into_iter().take(top_k).collect();
63
64        Ok(top_similarities)
65    }
66
67    async fn get_routers(&self) -> Result<Vec<Router>, IndexError> {
68        let routes = self.routers.values().cloned().collect();
69        Ok(routes)
70    }
71
72    async fn get_router(&self, route_name: &str) -> Result<Router, IndexError> {
73        return self
74            .routers
75            .get(route_name)
76            .cloned()
77            .ok_or(IndexError::RouterNotFound(route_name.into()));
78    }
79
80    async fn delete_index(&mut self) -> Result<(), IndexError> {
81        self.routers.clear();
82        Ok(())
83    }
84}