Skip to main content

graphmind/
tenant_store.rs

//! Multi-tenant graph store manager
//!
//! Maintains a separate GraphStore per tenant/graph name, providing
//! true data isolation between tenants.
5
6use crate::graph::GraphStore;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11/// Thread-safe multi-tenant store that holds one GraphStore per graph name.
12#[derive(Clone)]
13pub struct TenantStoreManager {
14    stores: Arc<RwLock<HashMap<String, Arc<RwLock<GraphStore>>>>>,
15}
16
17impl TenantStoreManager {
18    pub fn new() -> Self {
19        let mut stores = HashMap::new();
20        stores.insert(
21            "default".to_string(),
22            Arc::new(RwLock::new(GraphStore::new())),
23        );
24        Self {
25            stores: Arc::new(RwLock::new(stores)),
26        }
27    }
28
29    /// Create a TenantStoreManager with an existing store as the "default" tenant.
30    pub fn with_default(store: Arc<RwLock<GraphStore>>) -> Self {
31        let mut stores = HashMap::new();
32        stores.insert("default".to_string(), store);
33        Self {
34            stores: Arc::new(RwLock::new(stores)),
35        }
36    }
37
38    /// Get or create a store for a tenant/graph name.
39    pub async fn get_store(&self, graph: &str) -> Arc<RwLock<GraphStore>> {
40        // Fast path: read lock
41        {
42            let stores = self.stores.read().await;
43            if let Some(store) = stores.get(graph) {
44                return Arc::clone(store);
45            }
46        }
47        // Slow path: write lock, create new store
48        let mut stores = self.stores.write().await;
49        stores
50            .entry(graph.to_string())
51            .or_insert_with(|| Arc::new(RwLock::new(GraphStore::new())))
52            .clone()
53    }
54
55    /// List all tenant/graph names
56    pub async fn list_graphs(&self) -> Vec<String> {
57        let stores = self.stores.read().await;
58        stores.keys().cloned().collect()
59    }
60
61    /// Delete a tenant/graph (returns true if it existed)
62    pub async fn delete_graph(&self, graph: &str) -> bool {
63        if graph == "default" {
64            return false;
65        } // protect default
66        let mut stores = self.stores.write().await;
67        stores.remove(graph).is_some()
68    }
69
70    /// Get stats for all tenants
71    pub async fn stats(&self) -> Vec<(String, usize, usize)> {
72        let stores = self.stores.read().await;
73        let mut result = Vec::new();
74        for (name, store) in stores.iter() {
75            let s = store.read().await;
76            result.push((name.clone(), s.node_count(), s.edge_count()));
77        }
78        result
79    }
80}
81
82impl Default for TenantStoreManager {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh manager always exposes the "default" graph.
    #[tokio::test]
    async fn test_new_has_default() {
        let mgr = TenantStoreManager::new();
        let graphs = mgr.list_graphs().await;
        assert!(graphs.contains(&"default".to_string()));
    }

    /// Unknown graph names are created empty on first access and then listed.
    #[tokio::test]
    async fn test_get_store_creates_on_demand() {
        let mgr = TenantStoreManager::new();
        let store = mgr.get_store("tenant_a").await;
        let guard = store.read().await;
        assert_eq!(guard.node_count(), 0);

        let graphs = mgr.list_graphs().await;
        assert!(graphs.contains(&"tenant_a".to_string()));
    }

    /// Looking up the same name twice yields the same shared store.
    #[tokio::test]
    async fn test_get_store_returns_same_instance() {
        let mgr = TenantStoreManager::new();
        let s1 = mgr.get_store("mydb").await;
        let s2 = mgr.get_store("mydb").await;
        assert!(Arc::ptr_eq(&s1, &s2));
    }

    /// Concurrent first-time lookups of one name converge on a single store.
    #[tokio::test]
    async fn test_concurrent_get_store_returns_same_instance() {
        let mgr = TenantStoreManager::new();
        let (s1, s2, s3) = tokio::join!(
            mgr.get_store("shared"),
            mgr.get_store("shared"),
            mgr.get_store("shared")
        );
        assert!(Arc::ptr_eq(&s1, &s2));
        assert!(Arc::ptr_eq(&s2, &s3));
        let graphs = mgr.list_graphs().await;
        assert_eq!(graphs.iter().filter(|g| g.as_str() == "shared").count(), 1);
    }

    /// Writes to one tenant are invisible to another.
    #[tokio::test]
    async fn test_tenant_isolation() {
        let mgr = TenantStoreManager::new();

        // Write to tenant_a
        {
            let store = mgr.get_store("tenant_a").await;
            let mut guard = store.write().await;
            guard.create_node("Person");
        }

        // tenant_b should be empty
        {
            let store = mgr.get_store("tenant_b").await;
            let guard = store.read().await;
            assert_eq!(guard.node_count(), 0);
        }

        // tenant_a should have 1 node
        {
            let store = mgr.get_store("tenant_a").await;
            let guard = store.read().await;
            assert_eq!(guard.node_count(), 1);
        }
    }

    /// Deleting reports whether the graph existed.
    #[tokio::test]
    async fn test_delete_graph() {
        let mgr = TenantStoreManager::new();
        mgr.get_store("temp").await;
        assert!(mgr.delete_graph("temp").await);
        assert!(!mgr.delete_graph("temp").await); // already gone
    }

    /// After deletion, the name can be reused and maps to a new, empty store.
    /// Handles taken before deletion keep their data but are detached.
    #[tokio::test]
    async fn test_get_store_after_delete_is_fresh() {
        let mgr = TenantStoreManager::new();
        let old = mgr.get_store("temp").await;
        {
            let mut guard = old.write().await;
            guard.create_node("Person");
        }
        assert!(mgr.delete_graph("temp").await);
        assert!(!mgr.list_graphs().await.contains(&"temp".to_string()));

        let fresh = mgr.get_store("temp").await;
        assert!(!Arc::ptr_eq(&old, &fresh));
        assert_eq!(fresh.read().await.node_count(), 0);
        assert_eq!(old.read().await.node_count(), 1);
    }

    /// The default graph cannot be deleted and stays listed afterwards.
    #[tokio::test]
    async fn test_cannot_delete_default() {
        let mgr = TenantStoreManager::new();
        assert!(!mgr.delete_graph("default").await);
        assert!(mgr.list_graphs().await.contains(&"default".to_string()));
    }

    /// Node/edge counts are reported per tenant.
    #[tokio::test]
    async fn test_stats() {
        let mgr = TenantStoreManager::new();
        {
            let store = mgr.get_store("default").await;
            let mut guard = store.write().await;
            let a = guard.create_node("Person");
            let b = guard.create_node("Person");
            guard.create_edge(a, b, "KNOWS").unwrap();
        }
        let stats = mgr.stats().await;
        let default_stats = stats.iter().find(|(name, _, _)| name == "default").unwrap();
        assert_eq!(default_stats.1, 2); // 2 nodes
        assert_eq!(default_stats.2, 1); // 1 edge
    }

    /// Every tenant, including ones created on demand, appears in `stats`.
    #[tokio::test]
    async fn test_stats_covers_every_tenant() {
        let mgr = TenantStoreManager::new();
        {
            let store = mgr.get_store("tenant_a").await;
            let mut guard = store.write().await;
            guard.create_node("Person");
        }
        let stats = mgr.stats().await;
        assert_eq!(stats.len(), 2);
        let tenant_a = stats.iter().find(|(name, _, _)| name == "tenant_a").unwrap();
        assert_eq!((tenant_a.1, tenant_a.2), (1, 0));
        let default_stats = stats.iter().find(|(name, _, _)| name == "default").unwrap();
        assert_eq!((default_stats.1, default_stats.2), (0, 0));
    }

    /// A store passed to `with_default` is served as the "default" tenant.
    #[tokio::test]
    async fn test_with_default() {
        let store = Arc::new(RwLock::new(GraphStore::new()));
        {
            let mut guard = store.write().await;
            guard.create_node("Test");
        }
        let mgr = TenantStoreManager::with_default(store);
        let default_store = mgr.get_store("default").await;
        let guard = default_store.read().await;
        assert_eq!(guard.node_count(), 1);
    }
}