Skip to main content

kanban_persistence/
registry.rs

1use crate::{PersistenceError, PersistenceStore};
2use std::sync::Arc;
3
4pub trait StoreFactory: Send + Sync {
5    fn name(&self) -> &str;
6    fn matches_content(&self, _header: &[u8]) -> bool {
7        false
8    }
9    /// Open or create a store at the given locator path.
10    ///
11    /// Implementations that perform async work (e.g. SQLite) must call this from
12    /// a multi-thread tokio runtime — `block_in_place` will panic on a
13    /// `current_thread` runtime.
14    fn create(
15        &self,
16        locator: &str,
17    ) -> Result<Arc<dyn PersistenceStore + Send + Sync>, PersistenceError>;
18}
19
20pub struct StoreRegistry {
21    factories: Vec<Box<dyn StoreFactory>>,
22}
23
/// Reads up to `n` leading bytes of the file at `path`.
///
/// The returned buffer is shorter than `n` only when the file itself is shorter. A single
/// [`std::io::Read::read`] call may legally return fewer bytes than are available (a
/// "short read"), which would make content sniffing fail spuriously. This function
/// therefore keeps reading until `n` bytes are collected or end-of-file is reached.
///
/// # Errors
///
/// Propagates any I/O error from opening or reading the file. `read_to_end` already
/// retries on `ErrorKind::Interrupted`.
fn read_header(path: &std::path::Path, n: usize) -> std::io::Result<Vec<u8>> {
    use std::io::Read;

    // usize -> u64 is lossless on every supported target; saturate defensively anyway.
    let limit = u64::try_from(n).unwrap_or(u64::MAX);
    let mut header = Vec::with_capacity(n);
    std::fs::File::open(path)?
        .take(limit)
        .read_to_end(&mut header)?;
    Ok(header)
}
32
33impl StoreRegistry {
34    pub fn new() -> Self {
35        Self {
36            factories: Vec::new(),
37        }
38    }
39
40    pub fn register(&mut self, factory: Box<dyn StoreFactory>) {
41        self.factories.push(factory);
42    }
43
44    pub fn is_empty(&self) -> bool {
45        self.factories.is_empty()
46    }
47
48    pub fn backend_names(&self) -> Vec<&str> {
49        self.factories.iter().map(|f| f.name()).collect()
50    }
51
52    pub fn detect_backend(&self, locator: &str) -> Option<&str> {
53        let path = std::path::Path::new(locator);
54        if path.exists() {
55            if let Ok(header) = read_header(path, 32) {
56                for factory in &self.factories {
57                    if factory.matches_content(&header) {
58                        return Some(factory.name());
59                    }
60                }
61            }
62        }
63        None
64    }
65
66    pub fn create_store(
67        &self,
68        backend: &str,
69        locator: &str,
70    ) -> Result<Arc<dyn PersistenceStore + Send + Sync>, PersistenceError> {
71        for factory in &self.factories {
72            if factory.name() == backend {
73                return factory.create(locator);
74            }
75        }
76
77        let supported: Vec<String> = self
78            .factories
79            .iter()
80            .map(|f| f.name().to_string())
81            .collect();
82        Err(PersistenceError::UnsupportedLocator {
83            locator: backend.to_string(),
84            supported,
85        })
86    }
87}
88
89impl Default for StoreRegistry {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{PersistenceMetadata, StoreSnapshot};
    use async_trait::async_trait;
    use std::path::{Path, PathBuf};

    const SQLITE_INSTANCE_ID: uuid::Uuid = uuid::Uuid::from_u128(1);
    const JSON_INSTANCE_ID: uuid::Uuid = uuid::Uuid::from_u128(2);
    const CATCH_ALL_INSTANCE_ID: uuid::Uuid = uuid::Uuid::from_u128(3);

    /// Store stand-in whose only observable state is the id of the factory that built it.
    /// Its persistence methods are never called by these tests.
    struct StubStore {
        instance_id: uuid::Uuid,
        path: PathBuf,
    }

    #[async_trait]
    impl PersistenceStore for StubStore {
        async fn save(
            &self,
            _snapshot: StoreSnapshot,
        ) -> crate::PersistenceResult<PersistenceMetadata> {
            unimplemented!()
        }

        async fn load(&self) -> crate::PersistenceResult<(StoreSnapshot, PersistenceMetadata)> {
            unimplemented!()
        }

        async fn exists(&self) -> bool {
            unimplemented!()
        }

        fn path(&self) -> &Path {
            &self.path
        }

        fn instance_id(&self) -> uuid::Uuid {
            self.instance_id
        }
    }

    /// Test factory named "sqlite" that recognizes the SQLite file magic.
    struct FakeSqliteFactory;
    impl StoreFactory for FakeSqliteFactory {
        fn name(&self) -> &str {
            "sqlite"
        }
        fn matches_content(&self, header: &[u8]) -> bool {
            header.starts_with(b"SQLite format 3\0")
        }
        fn create(
            &self,
            locator: &str,
        ) -> Result<Arc<dyn PersistenceStore + Send + Sync>, PersistenceError> {
            Ok(Arc::new(StubStore {
                instance_id: SQLITE_INSTANCE_ID,
                path: PathBuf::from(locator),
            }))
        }
    }

    /// Test factory named "json" that recognizes content whose first non-whitespace byte
    /// is `{` or `[`.
    struct FakeJsonFactory;
    impl StoreFactory for FakeJsonFactory {
        fn name(&self) -> &str {
            "json"
        }
        fn matches_content(&self, header: &[u8]) -> bool {
            let trimmed = header.iter().find(|b| !b.is_ascii_whitespace());
            matches!(trimmed, Some(b'{') | Some(b'['))
        }
        fn create(
            &self,
            locator: &str,
        ) -> Result<Arc<dyn PersistenceStore + Send + Sync>, PersistenceError> {
            Ok(Arc::new(StubStore {
                instance_id: JSON_INSTANCE_ID,
                path: PathBuf::from(locator),
            }))
        }
    }

    /// Test factory that claims every header. It is used to check that earlier
    /// registrations take precedence during content sniffing.
    struct CatchAllFactory;
    impl StoreFactory for CatchAllFactory {
        fn name(&self) -> &str {
            "catch-all"
        }
        fn matches_content(&self, _header: &[u8]) -> bool {
            true
        }
        fn create(
            &self,
            locator: &str,
        ) -> Result<Arc<dyn PersistenceStore + Send + Sync>, PersistenceError> {
            Ok(Arc::new(StubStore {
                instance_id: CATCH_ALL_INSTANCE_ID,
                path: PathBuf::from(locator),
            }))
        }
    }

    fn registry_with_both_factories() -> StoreRegistry {
        let mut registry = StoreRegistry::new();
        registry.register(Box::new(FakeSqliteFactory));
        registry.register(Box::new(FakeJsonFactory));
        registry
    }

    /// Writes `contents` to `file_name` inside a fresh temporary directory. The returned
    /// `TempDir` must stay alive for as long as the file is used.
    fn temp_file(file_name: &str, contents: &[u8]) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join(file_name);
        std::fs::write(&path, contents).unwrap();
        (dir, path)
    }

    #[test]
    fn test_content_sniff_sqlite_header() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("data.db");
        std::fs::write(&path, b"SQLite format 3\0extra bytes here").unwrap();

        let header = read_header(&path, 16).unwrap();
        assert!(FakeSqliteFactory.matches_content(&header));
        assert!(!FakeJsonFactory.matches_content(&header));
    }

    #[test]
    fn test_content_sniff_json_object() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("data.json");
        std::fs::write(&path, b"{\"boards\": []}").unwrap();

        let header = read_header(&path, 32).unwrap();
        assert!(FakeJsonFactory.matches_content(&header));
        assert!(!FakeSqliteFactory.matches_content(&header));
    }

    #[test]
    fn test_content_beats_wrong_extension() {
        // A .json file with SQLite content should be detected as SQLite
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("misleading.json");
        std::fs::write(&path, b"SQLite format 3\0").unwrap();

        let header = read_header(&path, 32).unwrap();
        assert!(FakeSqliteFactory.matches_content(&header));
        assert!(!FakeJsonFactory.matches_content(&header));

        // Exercise the registry end to end, not just the individual factories.
        let registry = registry_with_both_factories();
        assert_eq!(registry.detect_backend(path.to_str().unwrap()), Some("sqlite"));
    }

    #[test]
    fn test_detect_backend_identifies_each_format() {
        let registry = registry_with_both_factories();

        let (_sqlite_dir, sqlite_path) = temp_file("data.db", b"SQLite format 3\0payload");
        assert_eq!(
            registry.detect_backend(sqlite_path.to_str().unwrap()),
            Some("sqlite")
        );

        let (_json_dir, json_path) = temp_file("data.json", b"  \n\t[1, 2, 3]");
        assert_eq!(
            registry.detect_backend(json_path.to_str().unwrap()),
            Some("json")
        );
    }

    #[test]
    fn test_detect_backend_missing_file_returns_none() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("absent.db");

        let registry = registry_with_both_factories();
        assert_eq!(registry.detect_backend(path.to_str().unwrap()), None);
    }

    #[test]
    fn test_detect_backend_unrecognized_content_returns_none() {
        let registry = registry_with_both_factories();

        let (_text_dir, text_path) = temp_file("notes.txt", b"hello world");
        assert_eq!(registry.detect_backend(text_path.to_str().unwrap()), None);

        let (_empty_dir, empty_path) = temp_file("empty.db", b"");
        assert_eq!(registry.detect_backend(empty_path.to_str().unwrap()), None);
    }

    #[test]
    fn test_detect_backend_with_empty_registry_returns_none() {
        let (_dir, path) = temp_file("data.db", b"SQLite format 3\0");
        assert_eq!(
            StoreRegistry::new().detect_backend(path.to_str().unwrap()),
            None
        );
    }

    #[test]
    fn test_detect_backend_first_registered_match_wins() {
        let mut registry = StoreRegistry::new();
        registry.register(Box::new(FakeSqliteFactory));
        registry.register(Box::new(CatchAllFactory));

        let (_sqlite_dir, sqlite_path) = temp_file("data.db", b"SQLite format 3\0");
        assert_eq!(
            registry.detect_backend(sqlite_path.to_str().unwrap()),
            Some("sqlite")
        );

        let (_json_dir, json_path) = temp_file("data.json", b"{}");
        assert_eq!(
            registry.detect_backend(json_path.to_str().unwrap()),
            Some("catch-all")
        );
    }

    #[test]
    fn test_registry_introspection() {
        let empty = StoreRegistry::default();
        assert!(empty.is_empty());
        assert!(empty.backend_names().is_empty());

        let registry = registry_with_both_factories();
        assert!(!registry.is_empty());
        assert_eq!(registry.backend_names(), vec!["sqlite", "json"]);
    }

    #[test]
    fn test_create_store_by_name_returns_correct_backend() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("data.anything");

        let registry = registry_with_both_factories();
        let store = registry
            .create_store("json", path.to_str().unwrap())
            .unwrap();
        assert_eq!(store.instance_id(), JSON_INSTANCE_ID);

        let path2 = dir.path().join("data2.anything");
        let store2 = registry
            .create_store("sqlite", path2.to_str().unwrap())
            .unwrap();
        assert_eq!(store2.instance_id(), SQLITE_INSTANCE_ID);
    }

    #[test]
    fn test_create_store_unknown_backend_returns_error() {
        let registry = registry_with_both_factories();
        let result = registry.create_store("postgres", "/tmp/test");
        match result {
            Err(PersistenceError::UnsupportedLocator { locator, supported }) => {
                // The requested backend name (not the locator path) is reported.
                assert_eq!(locator, "postgres");
                assert_eq!(supported, vec!["sqlite", "json"]);
            }
            Err(e) => panic!("expected UnsupportedLocator, got: {e:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }
}