Skip to main content

omnigraph/
storage.rs

1use std::env;
2use std::fmt::Debug;
3use std::path::{Path, PathBuf};
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::TryStreamExt;
8use object_store::aws::AmazonS3Builder;
9use object_store::path::Path as ObjectPath;
10use object_store::{DynObjectStore, ObjectStore, PutPayload};
11use url::Url;
12
13use crate::error::{OmniError, Result};
14
/// Prefix identifying a local file addressed as a `file://` URL.
const FILE_SCHEME_PREFIX: &str = "file://";
/// Prefix that routes a URI to the S3 backend.
const S3_SCHEME_PREFIX: &str = "s3://";
17
18#[async_trait]
19pub trait StorageAdapter: Debug + Send + Sync {
20    async fn read_text(&self, uri: &str) -> Result<String>;
21    async fn write_text(&self, uri: &str, contents: &str) -> Result<()>;
22    async fn exists(&self, uri: &str) -> Result<bool>;
23}
24
/// Which storage backend a URI is routed to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StorageKind {
    /// Local filesystem: plain paths and `file://` URIs.
    Local,
    /// Amazon S3 (or an S3-compatible endpoint): `s3://` URIs.
    S3,
}
30
/// Storage backend for the local filesystem; accepts plain paths and `file://` URIs.
#[derive(Default, Debug)]
pub struct LocalStorageAdapter;
33
34#[derive(Debug)]
35pub struct S3StorageAdapter {
36    bucket: String,
37    store: Arc<DynObjectStore>,
38}
39
/// An `s3://bucket/key` URI split into its bucket name and object key.
#[derive(Debug, Clone, PartialEq, Eq)]
struct S3Location {
    /// Bucket name, taken from the URI host.
    bucket: String,
    /// URI path with its leading `/` removed (still percent-encoded, as `Url::path` returns it).
    key: String,
}
45
46#[async_trait]
47impl StorageAdapter for LocalStorageAdapter {
48    async fn read_text(&self, uri: &str) -> Result<String> {
49        let path = local_path_from_uri(uri)?;
50        Ok(tokio::fs::read_to_string(&path).await?)
51    }
52
53    async fn write_text(&self, uri: &str, contents: &str) -> Result<()> {
54        let path = local_path_from_uri(uri)?;
55        tokio::fs::write(&path, contents).await?;
56        Ok(())
57    }
58
59    async fn exists(&self, uri: &str) -> Result<bool> {
60        Ok(local_path_from_uri(uri)?.exists())
61    }
62}
63
64#[async_trait]
65impl StorageAdapter for S3StorageAdapter {
66    async fn read_text(&self, uri: &str) -> Result<String> {
67        let location = self.object_path(uri)?;
68        let bytes = self
69            .store
70            .get(&location)
71            .await
72            .map_err(|err| storage_backend_error("read", uri, err))?
73            .bytes()
74            .await
75            .map_err(|err| storage_backend_error("read", uri, err))?;
76
77        String::from_utf8(bytes.to_vec()).map_err(|err| {
78            OmniError::manifest_internal(format!("storage read failed for '{}': {}", uri, err))
79        })
80    }
81
82    async fn write_text(&self, uri: &str, contents: &str) -> Result<()> {
83        let location = self.object_path(uri)?;
84        self.store
85            .put(&location, PutPayload::from(contents.as_bytes().to_vec()))
86            .await
87            .map_err(|err| storage_backend_error("write", uri, err))?;
88        Ok(())
89    }
90
91    async fn exists(&self, uri: &str) -> Result<bool> {
92        let location = self.object_path(uri)?;
93        match self.store.head(&location).await {
94            Ok(_) => Ok(true),
95            Err(object_store::Error::NotFound { .. }) => {
96                let mut entries = self.store.list(Some(&location));
97                let has_prefix_entries = entries
98                    .try_next()
99                    .await
100                    .map_err(|err| storage_backend_error("exists", uri, err))?
101                    .is_some();
102                Ok(has_prefix_entries)
103            }
104            Err(err) => Err(storage_backend_error("exists", uri, err)),
105        }
106    }
107}
108
109impl S3StorageAdapter {
110    fn from_root_uri(root_uri: &str) -> Result<Self> {
111        let location = parse_s3_uri(root_uri)?;
112        let mut builder = AmazonS3Builder::from_env().with_bucket_name(&location.bucket);
113
114        if let Some(endpoint) = env::var("AWS_ENDPOINT_URL_S3")
115            .ok()
116            .or_else(|| env::var("AWS_ENDPOINT_URL").ok())
117        {
118            builder = builder.with_endpoint(&endpoint);
119            if endpoint.starts_with("http://") || env_var_truthy("AWS_ALLOW_HTTP") {
120                builder = builder.with_allow_http(true);
121            }
122        }
123
124        if env_var_truthy("AWS_S3_FORCE_PATH_STYLE") {
125            builder = builder.with_virtual_hosted_style_request(false);
126        }
127
128        let store = builder.build().map_err(|err| {
129            OmniError::manifest_internal(format!(
130                "failed to initialize s3 storage for '{}': {}",
131                root_uri, err
132            ))
133        })?;
134
135        Ok(Self {
136            bucket: location.bucket,
137            store: Arc::new(store),
138        })
139    }
140
141    fn object_path(&self, uri: &str) -> Result<ObjectPath> {
142        let location = parse_s3_uri(uri)?;
143        if location.bucket != self.bucket {
144            return Err(OmniError::manifest_internal(format!(
145                "s3 storage bucket mismatch for '{}': expected '{}', found '{}'",
146                uri, self.bucket, location.bucket
147            )));
148        }
149        if location.key.is_empty() {
150            return Err(OmniError::manifest_internal(format!(
151                "s3 storage path is empty for '{}'",
152                uri
153            )));
154        }
155        ObjectPath::parse(&location.key).map_err(|err| {
156            OmniError::manifest_internal(format!("invalid s3 object path for '{}': {}", uri, err))
157        })
158    }
159}
160
161pub fn storage_kind_for_uri(uri: &str) -> StorageKind {
162    if uri.starts_with(S3_SCHEME_PREFIX) {
163        StorageKind::S3
164    } else {
165        StorageKind::Local
166    }
167}
168
169pub fn storage_for_uri(uri: &str) -> Result<Arc<dyn StorageAdapter>> {
170    match storage_kind_for_uri(uri) {
171        StorageKind::Local => Ok(Arc::new(LocalStorageAdapter)),
172        StorageKind::S3 => Ok(Arc::new(S3StorageAdapter::from_root_uri(uri)?)),
173    }
174}
175
176pub fn normalize_root_uri(uri: &str) -> Result<String> {
177    match storage_kind_for_uri(uri) {
178        StorageKind::Local => {
179            let path = local_path_from_uri(uri)?;
180            Ok(normalize_local_path(&path))
181        }
182        StorageKind::S3 => Ok(trim_trailing_slashes(uri)),
183    }
184}
185
186pub fn join_uri(root_uri: &str, relative_path: &str) -> String {
187    let relative_path = relative_path.trim_start_matches('/');
188    match storage_kind_for_uri(root_uri) {
189        StorageKind::S3 => {
190            let root = trim_trailing_slashes(root_uri);
191            if root.is_empty() {
192                relative_path.to_string()
193            } else {
194                format!("{}/{}", root, relative_path)
195            }
196        }
197        StorageKind::Local => {
198            let root = if root_uri.starts_with(FILE_SCHEME_PREFIX) {
199                local_path_from_file_uri(root_uri)
200                    .map(|path| normalize_local_path(&path))
201                    .unwrap_or_else(|_| trim_trailing_slashes(root_uri))
202            } else {
203                normalize_local_path(Path::new(root_uri))
204            };
205            let joined = Path::new(&root).join(relative_path);
206            normalize_local_path(&joined)
207        }
208    }
209}
210
211fn local_path_from_uri(uri: &str) -> Result<PathBuf> {
212    if uri.starts_with(FILE_SCHEME_PREFIX) {
213        return local_path_from_file_uri(uri);
214    }
215    Ok(PathBuf::from(uri))
216}
217
218fn local_path_from_file_uri(uri: &str) -> Result<PathBuf> {
219    let url = Url::parse(uri).map_err(|err| {
220        OmniError::manifest_internal(format!("invalid file uri '{}': {}", uri, err))
221    })?;
222    url.to_file_path()
223        .map_err(|_| OmniError::manifest_internal(format!("invalid file uri '{}'", uri)))
224}
225
226fn parse_s3_uri(uri: &str) -> Result<S3Location> {
227    let url = Url::parse(uri).map_err(|err| {
228        OmniError::manifest_internal(format!("invalid s3 uri '{}': {}", uri, err))
229    })?;
230    if url.scheme() != "s3" {
231        return Err(OmniError::manifest_internal(format!(
232            "unsupported s3 uri '{}'",
233            uri
234        )));
235    }
236    let bucket = url
237        .host_str()
238        .ok_or_else(|| OmniError::manifest_internal(format!("missing s3 bucket in '{}'", uri)))?;
239    Ok(S3Location {
240        bucket: bucket.to_string(),
241        key: url.path().trim_start_matches('/').to_string(),
242    })
243}
244
245fn storage_backend_error(action: &str, uri: &str, err: impl std::fmt::Display) -> OmniError {
246    OmniError::manifest_internal(format!("storage {} failed for '{}': {}", action, uri, err))
247}
248
249fn normalize_local_path(path: &Path) -> String {
250    let raw = path.as_os_str().to_string_lossy();
251    if raw == "/" {
252        return raw.to_string();
253    }
254    trim_trailing_slashes(&raw)
255}
256
/// Removes trailing `/` characters, unless that would leave nothing.
///
/// Input consisting only of slashes (or empty input) is returned as-is, so `/` stays `/`.
fn trim_trailing_slashes(value: &str) -> String {
    match value.trim_end_matches('/') {
        "" => value.to_owned(),
        trimmed => trimmed.to_owned(),
    }
}
265
/// Returns `true` when environment variable `key` holds a recognised "enabled" value.
///
/// Unset variables, non-Unicode values, and unrecognised values all count as `false`; see
/// [`is_truthy_flag`] for the accepted spellings.
fn env_var_truthy(key: &str) -> bool {
    match env::var(key) {
        Ok(value) => is_truthy_flag(&value),
        Err(_) => false,
    }
}

/// Interprets a flag value: `1`, `true`, `yes`, or `on`, ignoring ASCII case and surrounding
/// whitespace (so `Yes`, `On`, and ` true ` are accepted as well as the exact casings).
fn is_truthy_flag(value: &str) -> bool {
    let value = value.trim();
    ["1", "true", "yes", "on"]
        .iter()
        .any(|accepted| value.eq_ignore_ascii_case(accepted))
}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn storage_backend_selection_is_scheme_aware() {
        assert_eq!(storage_kind_for_uri("/tmp/repo"), StorageKind::Local);
        assert_eq!(storage_kind_for_uri("file:///tmp/repo"), StorageKind::Local);
        assert_eq!(
            storage_kind_for_uri("s3://omnigraph-preview/repo"),
            StorageKind::S3
        );
    }

    #[test]
    fn normalize_root_uri_preserves_local_and_s3_shapes() {
        assert_eq!(
            normalize_root_uri("/tmp/omnigraph/").unwrap(),
            "/tmp/omnigraph"
        );
        assert_eq!(
            normalize_root_uri("file:///tmp/omnigraph/").unwrap(),
            "/tmp/omnigraph"
        );
        assert_eq!(
            normalize_root_uri("s3://bucket/prefix/").unwrap(),
            "s3://bucket/prefix"
        );
    }

    #[test]
    fn normalize_root_uri_keeps_filesystem_root() {
        assert_eq!(normalize_root_uri("/").unwrap(), "/");
    }

    #[test]
    fn join_uri_handles_local_file_and_s3_roots() {
        assert_eq!(
            join_uri("/tmp/omnigraph", "_schema.pg"),
            "/tmp/omnigraph/_schema.pg"
        );
        assert_eq!(
            join_uri("file:///tmp/omnigraph", "_schema.pg"),
            "/tmp/omnigraph/_schema.pg"
        );
        assert_eq!(
            join_uri("s3://bucket/prefix", "_schema.pg"),
            "s3://bucket/prefix/_schema.pg"
        );
    }

    #[test]
    fn join_uri_ignores_trailing_root_and_leading_relative_slashes() {
        assert_eq!(
            join_uri("/tmp/omnigraph/", "/_schema.pg"),
            "/tmp/omnigraph/_schema.pg"
        );
        assert_eq!(
            join_uri("s3://bucket/prefix/", "//_schema.pg"),
            "s3://bucket/prefix/_schema.pg"
        );
    }

    #[test]
    fn trim_trailing_slashes_keeps_all_slash_values() {
        assert_eq!(trim_trailing_slashes("s3://bucket/prefix//"), "s3://bucket/prefix");
        assert_eq!(trim_trailing_slashes("/"), "/");
        assert_eq!(trim_trailing_slashes(""), "");
    }

    #[test]
    fn parse_s3_uri_splits_bucket_and_key() {
        let location = parse_s3_uri("s3://bucket/repo/_schema.pg").unwrap();
        assert_eq!(location.bucket, "bucket");
        assert_eq!(location.key, "repo/_schema.pg");
    }

    #[test]
    fn parse_s3_uri_allows_bucket_only_roots() {
        let location = parse_s3_uri("s3://bucket").unwrap();
        assert_eq!(location.bucket, "bucket");
        assert_eq!(location.key, "");
    }

    #[test]
    fn parse_s3_uri_rejects_other_schemes_and_garbage() {
        assert!(parse_s3_uri("file:///tmp/repo").is_err());
        assert!(parse_s3_uri("not a uri").is_err());
    }
}