Skip to main content

dstack_memory/
file.rs

1use std::path::PathBuf;
2
3use async_trait::async_trait;
4
5use crate::{Field, MemoryProvider, Result};
6
/// JSON file-based memory backend.
///
/// Each field is stored as a separate `.json` file under `root`. The field path
/// maps directly to the filesystem: `"projects/ehb/learnings/tags"` becomes
/// `{root}/projects/ehb/learnings/tags.json`. Empty and `..` path segments are
/// dropped when the path is mapped.
#[derive(Debug, Clone)]
pub struct FileProvider {
    /// Directory under which every field file is stored.
    root: PathBuf,
}
14
15impl FileProvider {
16    pub fn new(root: PathBuf) -> Self {
17        Self { root }
18    }
19
20    /// Convert a logical path to a filesystem path, sanitizing dangerous components.
21    fn field_path(&self, path: &str) -> PathBuf {
22        let sanitized: PathBuf = path
23            .split('/')
24            .filter(|seg| !seg.is_empty() && *seg != ".." && !seg.contains('\0'))
25            .collect();
26        let mut full = self.root.join(sanitized);
27        full.set_extension("json");
28        full
29    }
30
31    /// Recursively walk the root directory collecting all `.json` files.
32    fn all_json_files(&self) -> Vec<PathBuf> {
33        let mut files = Vec::new();
34        if self.root.is_dir() {
35            self.walk_dir(&self.root, &mut files);
36        }
37        files
38    }
39
40    fn walk_dir(&self, dir: &std::path::Path, out: &mut Vec<PathBuf>) {
41        let entries = match std::fs::read_dir(dir) {
42            Ok(e) => e,
43            Err(_) => return,
44        };
45        for entry in entries.flatten() {
46            let path = entry.path();
47            if path.is_dir() {
48                self.walk_dir(&path, out);
49            } else if path.extension().and_then(|e| e.to_str()) == Some("json") {
50                out.push(path);
51            }
52        }
53    }
54
55    /// Read and deserialize a single JSON file into a Field.
56    fn read_field(&self, path: &std::path::Path) -> Result<Field> {
57        let data = std::fs::read_to_string(path)?;
58        let field: Field = serde_json::from_str(&data)?;
59        Ok(field)
60    }
61}
62
63#[async_trait]
64impl MemoryProvider for FileProvider {
65    async fn load(&self, path: &str) -> Result<Vec<Field>> {
66        let files = self.all_json_files();
67        let prefix = if path.ends_with('/') {
68            path.to_string()
69        } else {
70            format!("{}/", path)
71        };
72
73        let mut fields: Vec<Field> = files
74            .into_iter()
75            .filter_map(|f| self.read_field(&f).ok())
76            .filter(|field| field.path.starts_with(&prefix) || field.path == path)
77            .collect();
78
79        fields.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
80        Ok(fields)
81    }
82
83    async fn write(&self, field: &Field) -> Result<()> {
84        let file_path = self.field_path(&field.path);
85        if let Some(parent) = file_path.parent() {
86            std::fs::create_dir_all(parent)?;
87        }
88        let json = serde_json::to_string_pretty(field)?;
89        std::fs::write(&file_path, json)?;
90        Ok(())
91    }
92
93    async fn search(&self, query: &str) -> Result<Vec<Field>> {
94        let query_lower = query.to_lowercase();
95        let files = self.all_json_files();
96
97        let fields: Vec<Field> = files
98            .into_iter()
99            .filter_map(|f| self.read_field(&f).ok())
100            .filter(|field| {
101                field.value.to_lowercase().contains(&query_lower)
102                    || field.path.to_lowercase().contains(&query_lower)
103            })
104            .collect();
105
106        Ok(fields)
107    }
108
109    async fn delete(&self, path: &str) -> Result<()> {
110        let file_path = self.field_path(path);
111        if file_path.exists() {
112            std::fs::remove_file(&file_path)?;
113        }
114        Ok(())
115    }
116
117    async fn export_all(&self) -> Result<Vec<Field>> {
118        let files = self.all_json_files();
119        let fields: Vec<Field> = files
120            .into_iter()
121            .filter_map(|f| self.read_field(&f).ok())
122            .collect();
123        Ok(fields)
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Field;
    use tempfile::TempDir;

    /// Build a provider rooted in a brand-new temporary directory.
    ///
    /// Keep the returned `TempDir` bound for the whole test; tempfile removes
    /// the directory when it is dropped.
    fn fresh_provider() -> (FileProvider, TempDir) {
        let tmp = TempDir::new().unwrap();
        let provider = FileProvider::new(tmp.path().to_path_buf());
        (provider, tmp)
    }

    /// Write a field tagged `"test"` at `path`, panicking if the write fails.
    async fn store(provider: &FileProvider, path: &str, value: &str) {
        provider
            .write(&Field::new(path, value, "test"))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn write_and_load() {
        let (provider, _tmp) = fresh_provider();
        store(&provider, "project/test/key1", "hello world").await;

        let loaded = provider.load("project/test").await.unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].value, "hello world");
    }

    #[tokio::test]
    async fn search_finds_match() {
        let (provider, _tmp) = fresh_provider();
        let seeds = [("a/b", "rust compiler error"), ("a/c", "python import issue")];
        for &(path, value) in &seeds {
            store(&provider, path, value).await;
        }

        let hits = provider.search("rust").await.unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].path, "a/b");
    }

    #[tokio::test]
    async fn delete_removes_field() {
        let (provider, _tmp) = fresh_provider();
        store(&provider, "x/y", "val").await;

        provider.delete("x/y").await.unwrap();

        assert!(provider.load("x").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn load_empty_returns_empty() {
        let (provider, _tmp) = fresh_provider();

        let loaded = provider.load("nonexistent").await.unwrap();

        assert!(loaded.is_empty());
    }

    #[tokio::test]
    async fn export_all_returns_everything() {
        let (provider, _tmp) = fresh_provider();
        store(&provider, "a/1", "v1").await;
        store(&provider, "b/2", "v2").await;

        let everything = provider.export_all().await.unwrap();

        assert_eq!(everything.len(), 2);
    }
}