// cognee_storage/local_storage.rs — local filesystem implementation of `StorageTrait`.

1use super::storage_trait::{StorageError, StorageTrait, StorageWriter};
2use async_trait::async_trait;
3use std::path::{Path, PathBuf};
4use tokio::fs;
5use tokio::io::{AsyncRead, AsyncWriteExt};
6use tracing::{debug, instrument};
7use uuid::Uuid;
/// Storage backend that keeps files on the local filesystem under
/// `base_path`.
///
/// New files are placed in randomly bucketed sub-directories
/// (`ab/cd/<file_name>`), and the locations handed back to callers are
/// relative to `base_path`.
#[derive(Debug)]
pub struct LocalStorage {
    /// Root directory that relative locations are resolved against.
    base_path: PathBuf,
}

13impl LocalStorage {
14    pub fn new(base_path: PathBuf) -> Self {
15        Self { base_path }
16    }
17
18    /// Generate a UUID-based subdirectory structure for organizing files
19    /// Returns a relative path like "ab/cd/filename.txt"
20    fn generate_storage_path(&self, file_name: &str) -> String {
21        let uuid = Uuid::new_v4();
22        let uuid_str = uuid.to_string();
23
24        // Use first 4 chars for first directory, next 4 for second
25        let dir1 = &uuid_str[..2];
26        let dir2 = &uuid_str[2..4];
27
28        format!("{dir1}/{dir2}/{file_name}")
29    }
30
31    /// Resolve a location string into a filesystem path.
32    ///
33    /// Mirrors Python's `get_data_file_path()` + `open_data_file()` which
34    /// strips the `file://` scheme and uses the resulting absolute path
35    /// directly.
36    ///
37    /// Accepted inputs:
38    /// - plain relative path: `ab/cd/file.txt`  → `base_path/ab/cd/file.txt`
39    /// - absolute `file://` URI: `file:///data/ab/cd/file.txt` → `/data/ab/cd/file.txt`
40    fn resolve_location(&self, location: &str) -> PathBuf {
41        let path_str = location.strip_prefix("file://").unwrap_or(location);
42        let path = Path::new(path_str);
43
44        if path.is_absolute() {
45            // Absolute path (from a file:// URI) — use directly, just like
46            // Python's `open_data_file` does after `get_data_file_path()`.
47            path.to_path_buf()
48        } else {
49            // Relative path (plain storage location) — join with base.
50            self.base_path.join(path)
51        }
52    }
53}
54
55#[async_trait]
56impl StorageTrait for LocalStorage {
57    async fn initialize(&self) -> Result<(), StorageError> {
58        fs::create_dir_all(&self.base_path)
59            .await
60            .map_err(|e| StorageError::IoError(format!("Failed to create base directory: {e}")))
61    }
62
63    #[instrument(name = "storage.store", skip(self, data), fields(file_name, bytes = data.len()))]
64    async fn store(&self, data: &[u8], file_name: &str) -> Result<String, StorageError> {
65        let relative_path = self.generate_storage_path(file_name);
66        let full_path = self.base_path.join(&relative_path);
67
68        // Create parent directories
69        if let Some(parent) = full_path.parent() {
70            fs::create_dir_all(parent)
71                .await
72                .map_err(|e| StorageError::IoError(format!("Failed to create directory: {e}")))?;
73        }
74
75        // Write file
76        let mut file = fs::File::create(&full_path)
77            .await
78            .map_err(|e| StorageError::IoError(format!("Failed to create file: {e}")))?;
79
80        file.write_all(data)
81            .await
82            .map_err(|e| StorageError::IoError(format!("Failed to write file: {e}")))?;
83
84        file.flush()
85            .await
86            .map_err(|e| StorageError::IoError(format!("Failed to flush file: {e}")))?;
87
88        Ok(relative_path)
89    }
90
91    #[instrument(name = "storage.store_stream", skip(self, reader), fields(file_name))]
92    async fn store_stream_dyn(
93        &self,
94        reader: &mut (dyn AsyncRead + Unpin + Send),
95        file_name: &str,
96    ) -> Result<String, StorageError> {
97        let relative_path = self.generate_storage_path(file_name);
98        let full_path = self.base_path.join(&relative_path);
99
100        // Create parent directories
101        if let Some(parent) = full_path.parent() {
102            fs::create_dir_all(parent)
103                .await
104                .map_err(|e| StorageError::IoError(format!("Failed to create directory: {e}")))?;
105        }
106
107        // Create file
108        let mut file = fs::File::create(&full_path)
109            .await
110            .map_err(|e| StorageError::IoError(format!("Failed to create file: {e}")))?;
111
112        // Stream copy from reader to file
113        tokio::io::copy(reader, &mut file)
114            .await
115            .map_err(|e| StorageError::IoError(format!("Failed to write file: {e}")))?;
116
117        file.flush()
118            .await
119            .map_err(|e| StorageError::IoError(format!("Failed to flush file: {e}")))?;
120
121        Ok(relative_path)
122    }
123
124    #[instrument(name = "storage.create_writer", skip(self), fields(file_name))]
125    async fn create_writer(&self, file_name: &str) -> Result<StorageWriter, StorageError> {
126        let relative_path = self.generate_storage_path(file_name);
127        let full_path = self.base_path.join(&relative_path);
128
129        // Create parent directories
130        if let Some(parent) = full_path.parent() {
131            fs::create_dir_all(parent)
132                .await
133                .map_err(|e| StorageError::IoError(format!("Failed to create directory: {e}")))?;
134        }
135
136        // Create file
137        let file = fs::File::create(&full_path)
138            .await
139            .map_err(|e| StorageError::IoError(format!("Failed to create file: {e}")))?;
140
141        Ok(StorageWriter::new(file, relative_path))
142    }
143
144    #[instrument(name = "storage.retrieve", skip(self), fields(location))]
145    async fn retrieve(&self, location: &str) -> Result<Vec<u8>, StorageError> {
146        let full_path = self.resolve_location(location);
147
148        let bytes = fs::read(&full_path).await.map_err(|e| {
149            if e.kind() == std::io::ErrorKind::NotFound {
150                StorageError::NotFound(format!("File not found: {location}"))
151            } else {
152                StorageError::IoError(format!("Failed to read file: {e}"))
153            }
154        })?;
155        debug!(bytes = bytes.len(), "file retrieved");
156        Ok(bytes)
157    }
158
159    async fn exists(&self, location: &str) -> Result<bool, StorageError> {
160        let full_path = self.resolve_location(location);
161
162        fs::try_exists(&full_path)
163            .await
164            .map_err(|e| StorageError::IoError(format!("Failed to check file existence: {e}")))
165    }
166
167    #[instrument(name = "storage.delete", skip(self), fields(location))]
168    async fn delete(&self, location: &str) -> Result<(), StorageError> {
169        let full_path = self.resolve_location(location);
170
171        fs::remove_file(&full_path).await.map_err(|e| {
172            if e.kind() == std::io::ErrorKind::NotFound {
173                StorageError::NotFound(format!("File not found: {location}"))
174            } else {
175                StorageError::IoError(format!("Failed to delete file: {e}"))
176            }
177        })
178    }
179
180    fn get_full_path(&self, location: &str) -> PathBuf {
181        self.resolve_location(location)
182    }
183
184    fn base_path(&self) -> &str {
185        self.base_path.to_str().unwrap_or("")
186    }
187
188    async fn remove_all(&self) -> Result<(), StorageError> {
189        let mut entries = fs::read_dir(&self.base_path).await.map_err(|e| {
190            if e.kind() == std::io::ErrorKind::NotFound {
191                // Directory doesn't exist — nothing to remove.
192                return StorageError::NotFound(format!(
193                    "Base directory not found: {}",
194                    self.base_path.display()
195                ));
196            }
197            StorageError::IoError(format!("Failed to read directory: {e}"))
198        })?;
199
200        while let Some(entry) = entries
201            .next_entry()
202            .await
203            .map_err(|e| StorageError::IoError(format!("Failed to iterate directory entry: {e}")))?
204        {
205            let path = entry.path();
206            let file_type = entry
207                .file_type()
208                .await
209                .map_err(|e| StorageError::IoError(format!("Failed to get file type: {e}")))?;
210            if file_type.is_dir() {
211                fs::remove_dir_all(&path).await.map_err(|e| {
212                    StorageError::IoError(format!(
213                        "Failed to remove directory {}: {}",
214                        path.display(),
215                        e
216                    ))
217                })?;
218            } else {
219                fs::remove_file(&path).await.map_err(|e| {
220                    StorageError::IoError(format!(
221                        "Failed to remove file {}: {}",
222                        path.display(),
223                        e
224                    ))
225                })?;
226            }
227        }
228        Ok(())
229    }
230}
231
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Create an initialized `LocalStorage` rooted in a fresh temporary
    /// directory.
    ///
    /// The `TempDir` guard is returned so the directory outlives the
    /// storage; dropping the guard deletes the directory.
    async fn temp_storage() -> (TempDir, LocalStorage) {
        let temp_dir = TempDir::new().unwrap();
        let storage = LocalStorage::new(temp_dir.path().to_path_buf());
        storage.initialize().await.unwrap();
        (temp_dir, storage)
    }

    #[tokio::test]
    async fn test_store_and_retrieve() {
        let (_temp_dir, storage) = temp_storage().await;

        let data = b"Hello, World!";
        let location = storage.store(data, "test.txt").await.unwrap();

        let retrieved = storage.retrieve(&location).await.unwrap();
        assert_eq!(data.to_vec(), retrieved);
    }

    #[tokio::test]
    async fn test_store_stream_and_retrieve() {
        let (_temp_dir, storage) = temp_storage().await;

        let data = b"streamed bytes";
        let mut reader: &[u8] = data;
        let location = storage
            .store_stream_dyn(&mut reader, "stream.txt")
            .await
            .unwrap();

        assert_eq!(storage.retrieve(&location).await.unwrap(), data.to_vec());
    }

    #[tokio::test]
    async fn test_get_full_path_points_at_stored_file() {
        let (temp_dir, storage) = temp_storage().await;

        let location = storage.store(b"payload", "full.txt").await.unwrap();
        let full_path = storage.get_full_path(&location);

        assert!(full_path.starts_with(temp_dir.path()));
        assert!(full_path.is_file());
    }

    #[tokio::test]
    async fn test_exists() {
        let (_temp_dir, storage) = temp_storage().await;

        let data = b"Test data";
        let location = storage.store(data, "exists.txt").await.unwrap();

        assert!(storage.exists(&location).await.unwrap());
        assert!(!storage.exists("nonexistent.txt").await.unwrap());
    }

    #[tokio::test]
    async fn test_retrieve_missing_is_not_found() {
        let (_temp_dir, storage) = temp_storage().await;

        let err = storage.retrieve("missing.txt").await.unwrap_err();
        assert!(matches!(err, StorageError::NotFound(_)));
    }

    #[test]
    fn resolve_plain_relative_path() {
        let storage = LocalStorage::new(PathBuf::from("/data"));
        assert_eq!(
            storage.resolve_location("ab/cd/file.txt"),
            PathBuf::from("/data/ab/cd/file.txt")
        );
    }

    #[test]
    fn resolve_absolute_file_uri() {
        // file:// URI with an absolute path: strip the scheme and use the
        // path as-is (mirrors Python's get_data_file_path for file:///abs/path).
        let storage = LocalStorage::new(PathBuf::from("/data"));
        assert_eq!(
            storage.resolve_location("file:///data/ab/cd/file.txt"),
            PathBuf::from("/data/ab/cd/file.txt")
        );
    }

    #[test]
    fn resolve_absolute_file_uri_different_base() {
        // The URI points to a directory other than base_path; it still resolves.
        let storage = LocalStorage::new(PathBuf::from("/data"));
        assert_eq!(
            storage.resolve_location("file:///other/ab/cd/file.txt"),
            PathBuf::from("/other/ab/cd/file.txt")
        );
    }

    #[tokio::test]
    async fn test_retrieve_with_file_uri() {
        let (temp_dir, storage) = temp_storage().await;

        let data = b"URI test data";
        let relative = storage.store(data, "uri_test.txt").await.unwrap();

        // Build a file:// URI the same way the ingestion pipeline does.
        let uri = format!("file://{}", temp_dir.path().join(&relative).display());

        let retrieved = storage.retrieve(&uri).await.unwrap();
        assert_eq!(data.to_vec(), retrieved);
    }

    #[tokio::test]
    async fn test_delete() {
        let (_temp_dir, storage) = temp_storage().await;

        let data = b"To be deleted";
        let location = storage.store(data, "delete.txt").await.unwrap();

        assert!(storage.exists(&location).await.unwrap());

        storage.delete(&location).await.unwrap();

        assert!(!storage.exists(&location).await.unwrap());
    }

    #[tokio::test]
    async fn test_delete_missing_is_not_found() {
        let (_temp_dir, storage) = temp_storage().await;

        let err = storage.delete("missing.txt").await.unwrap_err();
        assert!(matches!(err, StorageError::NotFound(_)));
    }

    #[tokio::test]
    async fn test_remove_all_empties_base_directory() {
        let (temp_dir, storage) = temp_storage().await;
        storage.store(b"one", "one.txt").await.unwrap();
        storage.store(b"two", "two.txt").await.unwrap();

        storage.remove_all().await.unwrap();

        // The base directory itself survives; only its contents are removed.
        assert!(temp_dir.path().is_dir());
        assert_eq!(std::fs::read_dir(temp_dir.path()).unwrap().count(), 0);
    }

    #[tokio::test]
    async fn test_remove_all_missing_base_is_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let storage = LocalStorage::new(temp_dir.path().join("never-created"));

        let err = storage.remove_all().await.unwrap_err();
        assert!(matches!(err, StorageError::NotFound(_)));
    }
}