Skip to main content

omnigraph/
storage.rs

1use std::env;
2use std::fmt::Debug;
3use std::path::{Path, PathBuf};
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::TryStreamExt;
8use object_store::aws::AmazonS3Builder;
9use object_store::path::Path as ObjectPath;
10use object_store::{DynObjectStore, ObjectStore, PutMode, PutPayload};
11use tokio::io::AsyncWriteExt;
12use url::Url;
13
14use crate::error::{OmniError, Result};
15
/// URI prefix of `file://` URIs handled by the local filesystem adapter.
const FILE_SCHEME_PREFIX: &str = "file://";
/// URI prefix that routes a URI to the S3 adapter.
const S3_SCHEME_PREFIX: &str = "s3://";
18
19#[async_trait]
20pub trait StorageAdapter: Debug + Send + Sync {
21    async fn read_text(&self, uri: &str) -> Result<String>;
22    async fn write_text(&self, uri: &str, contents: &str) -> Result<()>;
23    /// Write a text object only if no object exists at `uri`.
24    ///
25    /// Returns `Ok(true)` when this call created the object, `Ok(false)`
26    /// when the object already existed, and propagates every other storage
27    /// error. Callers use this to establish ownership before running
28    /// best-effort cleanup on partial failure.
29    async fn write_text_if_absent(&self, uri: &str, contents: &str) -> Result<bool>;
30    async fn exists(&self, uri: &str) -> Result<bool>;
31    /// Move a file from `from_uri` to `to_uri`, replacing any existing file at
32    /// `to_uri`. Atomic on local POSIX; on S3 implemented as copy + delete
33    /// (NOT atomic — callers that depend on atomicity for crash recovery must
34    /// tolerate "both source and destination exist after a crash").
35    async fn rename_text(&self, from_uri: &str, to_uri: &str) -> Result<()>;
36    /// Remove a file. Returns Ok(()) if the file does not exist.
37    async fn delete(&self, uri: &str) -> Result<()>;
38    /// List all files (non-recursively, files only) directly under `dir_uri`.
39    /// Returns full URIs (same scheme as `dir_uri`). The result is unordered.
40    /// Returns Ok(empty) if the directory does not exist or is empty.
41    async fn list_dir(&self, dir_uri: &str) -> Result<Vec<String>>;
42}
43
/// Backend family a URI is served by (see `storage_kind_for_uri`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StorageKind {
    /// Local filesystem: plain paths and `file://` URIs.
    Local,
    /// S3 or an S3-compatible endpoint: `s3://bucket/key` URIs.
    S3,
}
49
/// Stateless storage adapter for the local filesystem; every call resolves
/// its URI afresh, either as a plain path or as a `file://` URI.
#[derive(Default, Debug)]
pub struct LocalStorageAdapter;
52
53#[derive(Debug)]
54pub struct S3StorageAdapter {
55    bucket: String,
56    store: Arc<DynObjectStore>,
57}
58
/// An `s3://` URI split into its bucket and object key.
#[derive(Clone, Debug, Eq, PartialEq)]
struct S3Location {
    /// Bucket name (the URI host).
    bucket: String,
    /// Object key without its leading '/'; empty for a bare bucket URI.
    key: String,
}
64
65#[async_trait]
66impl StorageAdapter for LocalStorageAdapter {
67    async fn read_text(&self, uri: &str) -> Result<String> {
68        let path = local_path_from_uri(uri)?;
69        Ok(tokio::fs::read_to_string(&path).await?)
70    }
71
72    async fn write_text(&self, uri: &str, contents: &str) -> Result<()> {
73        let path = local_path_from_uri(uri)?;
74        // Ensure parent directory exists. S3 has no equivalent (PutObject
75        // is path-agnostic). For local fs, callers like the recovery
76        // sidecar protocol expect transparent directory creation under
77        // the graph root (the `__recovery/` directory doesn't pre-exist;
78        // first sidecar write creates it).
79        if let Some(parent) = path.parent() {
80            if !parent.as_os_str().is_empty() {
81                tokio::fs::create_dir_all(parent).await?;
82            }
83        }
84        tokio::fs::write(&path, contents).await?;
85        Ok(())
86    }
87
88    async fn write_text_if_absent(&self, uri: &str, contents: &str) -> Result<bool> {
89        let path = local_path_from_uri(uri)?;
90        if let Some(parent) = path.parent() {
91            if !parent.as_os_str().is_empty() {
92                tokio::fs::create_dir_all(parent).await?;
93            }
94        }
95        let mut file = match tokio::fs::OpenOptions::new()
96            .write(true)
97            .create_new(true)
98            .open(&path)
99            .await
100        {
101            Ok(file) => file,
102            Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => return Ok(false),
103            Err(err) => return Err(err.into()),
104        };
105        if let Err(err) = file.write_all(contents.as_bytes()).await {
106            let _ = tokio::fs::remove_file(&path).await;
107            return Err(err.into());
108        }
109        Ok(true)
110    }
111
112    async fn exists(&self, uri: &str) -> Result<bool> {
113        Ok(local_path_from_uri(uri)?.exists())
114    }
115
116    async fn rename_text(&self, from_uri: &str, to_uri: &str) -> Result<()> {
117        let from = local_path_from_uri(from_uri)?;
118        let to = local_path_from_uri(to_uri)?;
119        tokio::fs::rename(&from, &to).await?;
120        Ok(())
121    }
122
123    async fn delete(&self, uri: &str) -> Result<()> {
124        let path = local_path_from_uri(uri)?;
125        match tokio::fs::remove_file(&path).await {
126            Ok(()) => Ok(()),
127            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
128            Err(err) => Err(err.into()),
129        }
130    }
131
132    async fn list_dir(&self, dir_uri: &str) -> Result<Vec<String>> {
133        let path = local_path_from_uri(dir_uri)?;
134        let mut out = Vec::new();
135        let mut entries = match tokio::fs::read_dir(&path).await {
136            Ok(e) => e,
137            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(out),
138            Err(err) => return Err(err.into()),
139        };
140        let dir_str = dir_uri.trim_end_matches('/');
141        while let Some(entry) = entries.next_entry().await? {
142            let ft = entry.file_type().await?;
143            if !ft.is_file() {
144                continue;
145            }
146            if let Some(name) = entry.file_name().to_str() {
147                out.push(format!("{}/{}", dir_str, name));
148            }
149        }
150        Ok(out)
151    }
152}
153
154#[async_trait]
155impl StorageAdapter for S3StorageAdapter {
156    async fn read_text(&self, uri: &str) -> Result<String> {
157        let location = self.object_path(uri)?;
158        let bytes = self
159            .store
160            .get(&location)
161            .await
162            .map_err(|err| storage_backend_error("read", uri, err))?
163            .bytes()
164            .await
165            .map_err(|err| storage_backend_error("read", uri, err))?;
166
167        String::from_utf8(bytes.to_vec()).map_err(|err| {
168            OmniError::manifest_internal(format!("storage read failed for '{}': {}", uri, err))
169        })
170    }
171
172    async fn write_text(&self, uri: &str, contents: &str) -> Result<()> {
173        let location = self.object_path(uri)?;
174        self.store
175            .put(&location, PutPayload::from(contents.as_bytes().to_vec()))
176            .await
177            .map_err(|err| storage_backend_error("write", uri, err))?;
178        Ok(())
179    }
180
181    async fn write_text_if_absent(&self, uri: &str, contents: &str) -> Result<bool> {
182        let location = self.object_path(uri)?;
183        match self
184            .store
185            .put_opts(
186                &location,
187                PutPayload::from(contents.as_bytes().to_vec()),
188                PutMode::Create.into(),
189            )
190            .await
191        {
192            Ok(_) => Ok(true),
193            Err(object_store::Error::AlreadyExists { .. })
194            | Err(object_store::Error::Precondition { .. }) => Ok(false),
195            Err(err) => Err(storage_backend_error("write_if_absent", uri, err)),
196        }
197    }
198
199    async fn exists(&self, uri: &str) -> Result<bool> {
200        let location = self.object_path(uri)?;
201        match self.store.head(&location).await {
202            Ok(_) => Ok(true),
203            Err(object_store::Error::NotFound { .. }) => {
204                let mut entries = self.store.list(Some(&location));
205                let has_prefix_entries = entries
206                    .try_next()
207                    .await
208                    .map_err(|err| storage_backend_error("exists", uri, err))?
209                    .is_some();
210                Ok(has_prefix_entries)
211            }
212            Err(err) => Err(storage_backend_error("exists", uri, err)),
213        }
214    }
215
216    async fn rename_text(&self, from_uri: &str, to_uri: &str) -> Result<()> {
217        // S3 has no atomic rename. Copy then delete; if the copy succeeds and
218        // the delete fails (or the process crashes between them), both
219        // source and destination exist with the same content. Recovery code
220        // must tolerate this case — see schema_state::recover_schema_state_files.
221        let from = self.object_path(from_uri)?;
222        let to = self.object_path(to_uri)?;
223        self.store
224            .copy(&from, &to)
225            .await
226            .map_err(|err| storage_backend_error("rename:copy", from_uri, err))?;
227        self.store
228            .delete(&from)
229            .await
230            .map_err(|err| storage_backend_error("rename:delete", from_uri, err))?;
231        Ok(())
232    }
233
234    async fn delete(&self, uri: &str) -> Result<()> {
235        let location = self.object_path(uri)?;
236        match self.store.delete(&location).await {
237            Ok(()) => Ok(()),
238            Err(object_store::Error::NotFound { .. }) => Ok(()),
239            Err(err) => Err(storage_backend_error("delete", uri, err)),
240        }
241    }
242
243    async fn list_dir(&self, dir_uri: &str) -> Result<Vec<String>> {
244        // Normalize: ensure the URI describes a directory (trailing '/') so
245        // we don't match sibling paths with a shared prefix
246        // (e.g. listing `__recovery` shouldn't match `__recovery_log/...`).
247        let dir_with_slash = if dir_uri.ends_with('/') {
248            dir_uri.to_string()
249        } else {
250            format!("{}/", dir_uri)
251        };
252        // object_store::Path strips the trailing '/'; re-add it for filtering.
253        let prefix_loc = self.object_path(&dir_with_slash)?;
254        let prefix_with_slash = format!("{}/", prefix_loc.as_ref());
255
256        let mut entries = self.store.list(Some(&prefix_loc));
257        let mut out = Vec::new();
258        let bucket_root = format!("{}{}/", S3_SCHEME_PREFIX, self.bucket);
259        while let Some(meta) = entries
260            .try_next()
261            .await
262            .map_err(|err| storage_backend_error("list_dir", dir_uri, err))?
263        {
264            let key_str = meta.location.as_ref();
265            // Require the directory boundary to filter out sibling-prefix
266            // matches (object_store's `list` is prefix-based, not dir-based).
267            if !key_str.starts_with(&prefix_with_slash) {
268                continue;
269            }
270            let suffix = &key_str[prefix_with_slash.len()..];
271            // Non-recursive: skip anything inside a sub-directory.
272            if suffix.contains('/') {
273                continue;
274            }
275            out.push(format!("{}{}", bucket_root, key_str));
276        }
277        Ok(out)
278    }
279}
280
281impl S3StorageAdapter {
282    fn from_root_uri(root_uri: &str) -> Result<Self> {
283        let location = parse_s3_uri(root_uri)?;
284        let mut builder = AmazonS3Builder::from_env().with_bucket_name(&location.bucket);
285
286        if let Some(endpoint) = env::var("AWS_ENDPOINT_URL_S3")
287            .ok()
288            .or_else(|| env::var("AWS_ENDPOINT_URL").ok())
289        {
290            builder = builder.with_endpoint(&endpoint);
291            if endpoint.starts_with("http://") || env_var_truthy("AWS_ALLOW_HTTP") {
292                builder = builder.with_allow_http(true);
293            }
294        }
295
296        if env_var_truthy("AWS_S3_FORCE_PATH_STYLE") {
297            builder = builder.with_virtual_hosted_style_request(false);
298        }
299
300        let store = builder.build().map_err(|err| {
301            OmniError::manifest_internal(format!(
302                "failed to initialize s3 storage for '{}': {}",
303                root_uri, err
304            ))
305        })?;
306
307        Ok(Self {
308            bucket: location.bucket,
309            store: Arc::new(store),
310        })
311    }
312
313    fn object_path(&self, uri: &str) -> Result<ObjectPath> {
314        let location = parse_s3_uri(uri)?;
315        if location.bucket != self.bucket {
316            return Err(OmniError::manifest_internal(format!(
317                "s3 storage bucket mismatch for '{}': expected '{}', found '{}'",
318                uri, self.bucket, location.bucket
319            )));
320        }
321        if location.key.is_empty() {
322            return Err(OmniError::manifest_internal(format!(
323                "s3 storage path is empty for '{}'",
324                uri
325            )));
326        }
327        ObjectPath::parse(&location.key).map_err(|err| {
328            OmniError::manifest_internal(format!("invalid s3 object path for '{}': {}", uri, err))
329        })
330    }
331}
332
333pub fn storage_kind_for_uri(uri: &str) -> StorageKind {
334    if uri.starts_with(S3_SCHEME_PREFIX) {
335        StorageKind::S3
336    } else {
337        StorageKind::Local
338    }
339}
340
341pub fn storage_for_uri(uri: &str) -> Result<Arc<dyn StorageAdapter>> {
342    match storage_kind_for_uri(uri) {
343        StorageKind::Local => Ok(Arc::new(LocalStorageAdapter)),
344        StorageKind::S3 => Ok(Arc::new(S3StorageAdapter::from_root_uri(uri)?)),
345    }
346}
347
348pub fn normalize_root_uri(uri: &str) -> Result<String> {
349    match storage_kind_for_uri(uri) {
350        StorageKind::Local => {
351            let path = local_path_from_uri(uri)?;
352            Ok(normalize_local_path(&path))
353        }
354        StorageKind::S3 => Ok(trim_trailing_slashes(uri)),
355    }
356}
357
358pub fn join_uri(root_uri: &str, relative_path: &str) -> String {
359    let relative_path = relative_path.trim_start_matches('/');
360    match storage_kind_for_uri(root_uri) {
361        StorageKind::S3 => {
362            let root = trim_trailing_slashes(root_uri);
363            if root.is_empty() {
364                relative_path.to_string()
365            } else {
366                format!("{}/{}", root, relative_path)
367            }
368        }
369        StorageKind::Local => {
370            let root = if root_uri.starts_with(FILE_SCHEME_PREFIX) {
371                local_path_from_file_uri(root_uri)
372                    .map(|path| normalize_local_path(&path))
373                    .unwrap_or_else(|_| trim_trailing_slashes(root_uri))
374            } else {
375                normalize_local_path(Path::new(root_uri))
376            };
377            let joined = Path::new(&root).join(relative_path);
378            normalize_local_path(&joined)
379        }
380    }
381}
382
383fn local_path_from_uri(uri: &str) -> Result<PathBuf> {
384    if uri.starts_with(FILE_SCHEME_PREFIX) {
385        return local_path_from_file_uri(uri);
386    }
387    Ok(PathBuf::from(uri))
388}
389
390fn local_path_from_file_uri(uri: &str) -> Result<PathBuf> {
391    let url = Url::parse(uri).map_err(|err| {
392        OmniError::manifest_internal(format!("invalid file uri '{}': {}", uri, err))
393    })?;
394    url.to_file_path()
395        .map_err(|_| OmniError::manifest_internal(format!("invalid file uri '{}'", uri)))
396}
397
398fn parse_s3_uri(uri: &str) -> Result<S3Location> {
399    let url = Url::parse(uri).map_err(|err| {
400        OmniError::manifest_internal(format!("invalid s3 uri '{}': {}", uri, err))
401    })?;
402    if url.scheme() != "s3" {
403        return Err(OmniError::manifest_internal(format!(
404            "unsupported s3 uri '{}'",
405            uri
406        )));
407    }
408    let bucket = url
409        .host_str()
410        .ok_or_else(|| OmniError::manifest_internal(format!("missing s3 bucket in '{}'", uri)))?;
411    Ok(S3Location {
412        bucket: bucket.to_string(),
413        key: url.path().trim_start_matches('/').to_string(),
414    })
415}
416
417fn storage_backend_error(action: &str, uri: &str, err: impl std::fmt::Display) -> OmniError {
418    OmniError::manifest_internal(format!("storage {} failed for '{}': {}", action, uri, err))
419}
420
/// Render `path` as a string without trailing '/' separators.
///
/// A path made only of slashes (such as the root `/`) is returned unchanged so
/// it never collapses to an empty string. Non-UTF-8 components are rendered
/// lossily.
fn normalize_local_path(path: &Path) -> String {
    let raw = path.as_os_str().to_string_lossy();
    let trimmed = raw.trim_end_matches('/');
    if trimmed.is_empty() {
        raw.to_string()
    } else {
        trimmed.to_string()
    }
}
428
/// Strip trailing '/' characters from `value`, unless that would leave
/// nothing, in which case `value` is returned as-is (so `/` stays `/`).
fn trim_trailing_slashes(value: &str) -> String {
    match value.trim_end_matches('/') {
        "" => value.to_owned(),
        rest => rest.to_owned(),
    }
}
437
/// Return whether environment variable `key` holds a truthy value.
///
/// Accepts `1`, `true`, `yes`, `y` and `on` in any ASCII letter case. This set
/// is chosen to match object_store's own boolean config parsing for the
/// `AWS_*` settings read by `AmazonS3Builder::from_env`, so a value means the
/// same thing to both (NOTE(review): confirm against the pinned object_store
/// version). Unset, non-Unicode, or any other value is `false`.
fn env_var_truthy(key: &str) -> bool {
    const TRUTHY: [&str; 5] = ["1", "true", "yes", "y", "on"];
    env::var(key)
        .map(|value| TRUTHY.iter().any(|truthy| value.eq_ignore_ascii_case(truthy)))
        .unwrap_or(false)
}
444
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn storage_backend_selection_is_scheme_aware() {
        assert_eq!(storage_kind_for_uri("/tmp/graph"), StorageKind::Local);
        assert_eq!(
            storage_kind_for_uri("file:///tmp/graph"),
            StorageKind::Local
        );
        assert_eq!(
            storage_kind_for_uri("s3://omnigraph-preview/graph"),
            StorageKind::S3
        );
    }

    #[test]
    fn normalize_root_uri_preserves_local_and_s3_shapes() {
        assert_eq!(
            normalize_root_uri("/tmp/omnigraph/").unwrap(),
            "/tmp/omnigraph"
        );
        assert_eq!(
            normalize_root_uri("file:///tmp/omnigraph/").unwrap(),
            "/tmp/omnigraph"
        );
        assert_eq!(
            normalize_root_uri("s3://bucket/prefix/").unwrap(),
            "s3://bucket/prefix"
        );
    }

    #[test]
    fn join_uri_handles_local_file_and_s3_roots() {
        assert_eq!(
            join_uri("/tmp/omnigraph", "_schema.pg"),
            "/tmp/omnigraph/_schema.pg"
        );
        assert_eq!(
            join_uri("file:///tmp/omnigraph", "_schema.pg"),
            "/tmp/omnigraph/_schema.pg"
        );
        assert_eq!(
            join_uri("s3://bucket/prefix", "_schema.pg"),
            "s3://bucket/prefix/_schema.pg"
        );
    }

    // Trailing slashes on the root and a leading slash on the relative part
    // must not produce doubled separators.
    #[test]
    fn join_uri_ignores_leading_and_trailing_separators() {
        assert_eq!(
            join_uri("s3://bucket/prefix/", "/_schema.pg"),
            "s3://bucket/prefix/_schema.pg"
        );
        assert_eq!(
            join_uri("/tmp/omnigraph/", "/_schema.pg"),
            "/tmp/omnigraph/_schema.pg"
        );
    }

    #[test]
    fn parse_s3_uri_splits_bucket_and_key() {
        let location = parse_s3_uri("s3://bucket/graph/_schema.pg").unwrap();
        assert_eq!(location.bucket, "bucket");
        assert_eq!(location.key, "graph/_schema.pg");
    }

    #[test]
    fn parse_s3_uri_rejects_other_schemes_and_garbage() {
        assert!(parse_s3_uri("gs://bucket/graph").is_err());
        assert!(parse_s3_uri("not a uri").is_err());
    }

    #[tokio::test]
    async fn local_write_text_if_absent_creates_once_without_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        let uri = dir.path().join("claim.txt");
        let uri = uri.to_str().unwrap();
        let storage = LocalStorageAdapter;

        assert!(storage.write_text_if_absent(uri, "first").await.unwrap());
        assert!(!storage.write_text_if_absent(uri, "second").await.unwrap());
        assert_eq!(storage.read_text(uri).await.unwrap(), "first");
    }

    #[tokio::test]
    async fn local_write_text_if_absent_creates_missing_parent_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let uri = dir.path().join("nested").join("deeper").join("claim.txt");
        let uri = uri.to_str().unwrap();
        let storage = LocalStorageAdapter;

        assert!(storage.write_text_if_absent(uri, "owner").await.unwrap());
        assert_eq!(storage.read_text(uri).await.unwrap(), "owner");
    }

    #[tokio::test]
    async fn local_list_dir_is_non_recursive_and_tolerates_missing_dir() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path().to_str().unwrap();
        let storage = LocalStorageAdapter;

        let missing = join_uri(root, "missing");
        assert!(storage.list_dir(&missing).await.unwrap().is_empty());

        storage.write_text(&join_uri(root, "a.txt"), "a").await.unwrap();
        storage
            .write_text(&join_uri(root, "nested/b.txt"), "b")
            .await
            .unwrap();

        let listed = storage.list_dir(root).await.unwrap();
        assert_eq!(listed.len(), 1, "unexpected listing: {:?}", listed);
        assert!(listed[0].ends_with("a.txt"));
    }

    #[tokio::test]
    async fn local_delete_is_idempotent_and_rename_replaces_destination() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path().to_str().unwrap();
        let storage = LocalStorageAdapter;
        let from = join_uri(root, "from.txt");
        let to = join_uri(root, "to.txt");

        // Deleting something that never existed is not an error.
        storage.delete(&from).await.unwrap();

        storage.write_text(&from, "new").await.unwrap();
        storage.write_text(&to, "old").await.unwrap();
        storage.rename_text(&from, &to).await.unwrap();
        assert!(!storage.exists(&from).await.unwrap());
        assert_eq!(storage.read_text(&to).await.unwrap(), "new");

        storage.delete(&to).await.unwrap();
        assert!(!storage.exists(&to).await.unwrap());
        storage.delete(&to).await.unwrap();
    }
}
512}