Skip to main content

mockforge_registry_server/
storage.rs

1//! Plugin binary storage (S3-compatible with local filesystem fallback)
2
3use anyhow::{Context, Result};
4use aws_config::BehaviorVersion;
5use aws_sdk_s3::{
6    config::{Credentials, Region},
7    Client as S3Client,
8};
9use std::path::{Path, PathBuf};
10
11use crate::config::Config;
12
/// Physical location of stored artifacts. Chosen once in
/// [`PluginStorage::new`] and fixed for the lifetime of the handle.
#[derive(Clone)]
enum StorageBackend {
    /// S3-compatible object storage (AWS S3 or a custom endpoint such as MinIO).
    S3 { client: S3Client, bucket: String },
    /// Local filesystem fallback; all keys are resolved relative to `base_dir`.
    Local { base_dir: PathBuf },
}
18
/// Storage facade for plugin, template, scenario, spec, and snapshot
/// artifacts. Clonable so it can be shared across request handlers.
#[derive(Clone)]
pub struct PluginStorage {
    // The backend selected at startup (S3 if reachable, else local disk).
    backend: StorageBackend,
}
23
24impl PluginStorage {
    /// Build a `PluginStorage`, preferring S3 when it appears usable and
    /// otherwise falling back to the local filesystem.
    ///
    /// Decision flow:
    /// 1. Heuristically decide whether S3 is worth attempting (custom
    ///    endpoint requires explicit env credentials; plain AWS requires a
    ///    non-default bucket name or a credential-related env var).
    /// 2. If attempting S3, load AWS config and probe the bucket with
    ///    `HeadBucket`; on success return the S3 backend.
    /// 3. On any S3 failure, fall back to a local directory taken from
    ///    `STORAGE_PATH` (default `./data/storage`).
    ///
    /// # Errors
    /// In practice only fails if the local fallback directory cannot be
    /// created, or if the credential env vars disappear between the
    /// heuristic check and the later reads (TOCTOU on the env vars).
    pub async fn new(config: &Config) -> Result<Self> {
        // Determine whether S3 is usable: we need a real bucket name AND either
        // an explicit endpoint with credentials, or default AWS credentials.
        let use_s3 = if config.s3_endpoint.is_some() {
            // Custom endpoint requires explicit credentials
            std::env::var("AWS_ACCESS_KEY_ID")
                .ok()
                .filter(|v| !v.trim().is_empty())
                .is_some()
                && std::env::var("AWS_SECRET_ACCESS_KEY")
                    .ok()
                    .filter(|v| !v.trim().is_empty())
                    .is_some()
        } else {
            // For AWS S3, check if the bucket is the default placeholder and
            // if AWS credentials are likely available
            // NOTE(review): a non-default bucket name alone triggers the S3
            // attempt even without credentials; the HeadBucket probe below is
            // what actually catches misconfiguration and degrades to local.
            config.s3_bucket != "mockforge-plugins"
                || std::env::var("AWS_ACCESS_KEY_ID").is_ok()
                || std::env::var("AWS_PROFILE").is_ok()
                || std::env::var("AWS_ROLE_ARN").is_ok()
        };

        if use_s3 {
            let aws_config = if let Some(endpoint) = &config.s3_endpoint {
                // These reads should succeed because use_s3 already verified
                // both vars are present and non-empty.
                let access_key_id = std::env::var("AWS_ACCESS_KEY_ID")
                    .context("AWS_ACCESS_KEY_ID is required when using custom S3 endpoint")?;
                let secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY")
                    .context("AWS_SECRET_ACCESS_KEY is required when using custom S3 endpoint")?;

                tracing::info!("Using custom S3 endpoint: {} with explicit credentials", endpoint);

                let credentials =
                    Credentials::new(access_key_id, secret_access_key, None, None, "static");

                aws_config::defaults(BehaviorVersion::latest())
                    .region(Region::new(config.s3_region.clone()))
                    .credentials_provider(credentials)
                    .endpoint_url(endpoint)
                    .load()
                    .await
            } else {
                tracing::info!(
                    "Using AWS S3 with default credentials provider chain (region: {})",
                    config.s3_region
                );

                aws_config::defaults(BehaviorVersion::latest())
                    .region(Region::new(config.s3_region.clone()))
                    .load()
                    .await
            };

            let client = S3Client::new(&aws_config);
            let bucket = config.s3_bucket.clone();

            // Validate S3 connectivity — fall back to local storage if unreachable
            match client.head_bucket().bucket(&bucket).send().await {
                Ok(_) => {
                    tracing::info!("S3 storage verified (bucket: {})", bucket);
                    return Ok(Self {
                        backend: StorageBackend::S3 { client, bucket },
                    });
                }
                Err(e) => {
                    tracing::warn!(
                        "S3 health check failed (bucket: {}): {}. Falling back to local storage.",
                        bucket,
                        e
                    );
                    // Fall through to local storage
                }
            }
        }

        // Local filesystem fallback
        let base_dir = PathBuf::from(
            std::env::var("STORAGE_PATH").unwrap_or_else(|_| "./data/storage".to_string()),
        );

        // Ensure base directory exists
        tokio::fs::create_dir_all(&base_dir).await.with_context(|| {
            format!("Failed to create local storage directory: {}", base_dir.display())
        })?;

        tracing::info!("Using local filesystem storage at: {}", base_dir.display());

        Ok(Self {
            backend: StorageBackend::Local { base_dir },
        })
    }
115
116    /// Sanitize a name/version for use in S3 keys or local file paths
117    /// Removes dangerous characters and path traversal attempts
118    fn sanitize_key_component(component: &str) -> String {
119        component
120            .chars()
121            .filter(|c| c.is_ascii_alphanumeric() || *c == '-' || *c == '_' || *c == '.')
122            .take(100) // Limit length
123            .collect::<String>()
124            .trim_matches('.')
125            .trim_matches('-')
126            .trim_matches('_')
127            .to_lowercase()
128    }
129
130    /// Write data to a local file path, creating parent directories as needed
131    async fn local_write(base_dir: &Path, key: &str, data: Vec<u8>) -> Result<String> {
132        let file_path = base_dir.join(key);
133        if let Some(parent) = file_path.parent() {
134            tokio::fs::create_dir_all(parent)
135                .await
136                .with_context(|| format!("Failed to create directory: {}", parent.display()))?;
137        }
138        tokio::fs::write(&file_path, &data)
139            .await
140            .with_context(|| format!("Failed to write file: {}", file_path.display()))?;
141
142        // Return relative path as the URL (the registry server can serve these)
143        Ok(format!("/storage/{key}"))
144    }
145
146    /// Read data from a local file path
147    async fn local_read(base_dir: &Path, key: &str) -> Result<Vec<u8>> {
148        let file_path = base_dir.join(key);
149        tokio::fs::read(&file_path)
150            .await
151            .with_context(|| format!("Failed to read file: {}", file_path.display()))
152    }
153
154    /// Delete a local file
155    async fn local_delete(base_dir: &Path, key: &str) -> Result<()> {
156        let file_path = base_dir.join(key);
157        if file_path.exists() {
158            tokio::fs::remove_file(&file_path)
159                .await
160                .with_context(|| format!("Failed to delete file: {}", file_path.display()))?;
161        }
162        Ok(())
163    }
164
165    /// Build an S3 download URL for the given bucket and key
166    fn s3_url(bucket: &str, key: &str) -> String {
167        if let Ok(endpoint) = std::env::var("S3_ENDPOINT") {
168            format!("{}/{}/{}", endpoint, bucket, key)
169        } else {
170            format!("https://{}.s3.amazonaws.com/{}", bucket, key)
171        }
172    }
173
174    pub async fn upload_plugin(
175        &self,
176        plugin_name: &str,
177        version: &str,
178        data: Vec<u8>,
179    ) -> Result<String> {
180        let safe_name = Self::sanitize_key_component(plugin_name);
181        let safe_version = Self::sanitize_key_component(version);
182
183        if safe_name.is_empty() {
184            anyhow::bail!("Plugin name cannot be empty after sanitization");
185        }
186        if safe_version.is_empty() {
187            anyhow::bail!("Version cannot be empty after sanitization");
188        }
189
190        let key = format!("plugins/{}/{}.wasm", safe_name, safe_version);
191
192        match &self.backend {
193            StorageBackend::S3 { client, bucket } => {
194                client
195                    .put_object()
196                    .bucket(bucket)
197                    .key(&key)
198                    .body(data.into())
199                    .content_type("application/wasm")
200                    .send()
201                    .await?;
202                Ok(Self::s3_url(bucket, &key))
203            }
204            StorageBackend::Local { base_dir } => Self::local_write(base_dir, &key, data).await,
205        }
206    }
207
208    pub async fn upload_template(
209        &self,
210        template_name: &str,
211        version: &str,
212        data: Vec<u8>,
213    ) -> Result<String> {
214        let safe_name = Self::sanitize_key_component(template_name);
215        let safe_version = Self::sanitize_key_component(version);
216
217        if safe_name.is_empty() {
218            anyhow::bail!("Template name cannot be empty after sanitization");
219        }
220        if safe_version.is_empty() {
221            anyhow::bail!("Version cannot be empty after sanitization");
222        }
223
224        let extension = if data.len() >= 2 && data[0] == 0x1F && data[1] == 0x8B {
225            "tar.gz"
226        } else if data.len() >= 4
227            && data[0] == 0x50
228            && data[1] == 0x4B
229            && (data[2] == 0x03 || data[2] == 0x05 || data[2] == 0x07)
230            && (data[3] == 0x04 || data[3] == 0x06 || data[3] == 0x08)
231        {
232            "zip"
233        } else {
234            "tar.gz"
235        };
236
237        let key = format!("templates/{}/{}.{}", safe_name, safe_version, extension);
238
239        match &self.backend {
240            StorageBackend::S3 { client, bucket } => {
241                client
242                    .put_object()
243                    .bucket(bucket)
244                    .key(&key)
245                    .body(data.into())
246                    .content_type(if extension == "zip" {
247                        "application/zip"
248                    } else {
249                        "application/gzip"
250                    })
251                    .send()
252                    .await?;
253                Ok(Self::s3_url(bucket, &key))
254            }
255            StorageBackend::Local { base_dir } => Self::local_write(base_dir, &key, data).await,
256        }
257    }
258
259    pub async fn upload_scenario(
260        &self,
261        scenario_name: &str,
262        version: &str,
263        data: Vec<u8>,
264    ) -> Result<String> {
265        let safe_name = Self::sanitize_key_component(scenario_name);
266        let safe_version = Self::sanitize_key_component(version);
267
268        if safe_name.is_empty() {
269            anyhow::bail!("Scenario name cannot be empty after sanitization");
270        }
271        if safe_version.is_empty() {
272            anyhow::bail!("Version cannot be empty after sanitization");
273        }
274
275        let extension = if data.len() >= 2 && data[0] == 0x1F && data[1] == 0x8B {
276            "tar.gz"
277        } else if data.len() >= 4
278            && data[0] == 0x50
279            && data[1] == 0x4B
280            && (data[2] == 0x03 || data[2] == 0x05 || data[2] == 0x07)
281            && (data[3] == 0x04 || data[3] == 0x06 || data[3] == 0x08)
282        {
283            "zip"
284        } else {
285            "tar.gz"
286        };
287
288        let key = format!("scenarios/{}/{}.{}", safe_name, safe_version, extension);
289
290        match &self.backend {
291            StorageBackend::S3 { client, bucket } => {
292                client
293                    .put_object()
294                    .bucket(bucket)
295                    .key(&key)
296                    .body(data.into())
297                    .content_type(if extension == "zip" {
298                        "application/zip"
299                    } else {
300                        "application/gzip"
301                    })
302                    .send()
303                    .await?;
304                Ok(Self::s3_url(bucket, &key))
305            }
306            StorageBackend::Local { base_dir } => Self::local_write(base_dir, &key, data).await,
307        }
308    }
309
310    /// Resolve the storage key for a plugin version using the same
311    /// sanitization rules as `upload_plugin`. Public so background workers can
312    /// fetch artifacts for scanning without duplicating the scheme.
313    pub fn plugin_object_key(plugin_name: &str, version: &str) -> Result<String> {
314        let safe_name = Self::sanitize_key_component(plugin_name);
315        let safe_version = Self::sanitize_key_component(version);
316        if safe_name.is_empty() {
317            anyhow::bail!("Plugin name cannot be empty after sanitization");
318        }
319        if safe_version.is_empty() {
320            anyhow::bail!("Version cannot be empty after sanitization");
321        }
322        Ok(format!("plugins/{}/{}.wasm", safe_name, safe_version))
323    }
324
325    pub async fn download_plugin(&self, key: &str) -> Result<Vec<u8>> {
326        match &self.backend {
327            StorageBackend::S3 { client, bucket } => {
328                let response = client.get_object().bucket(bucket).key(key).send().await?;
329                let bytes = response.body.collect().await?;
330                Ok(bytes.to_vec())
331            }
332            StorageBackend::Local { base_dir } => Self::local_read(base_dir, key).await,
333        }
334    }
335
336    pub async fn delete_plugin(&self, key: &str) -> Result<()> {
337        match &self.backend {
338            StorageBackend::S3 { client, bucket } => {
339                client.delete_object().bucket(bucket).key(key).send().await?;
340                Ok(())
341            }
342            StorageBackend::Local { base_dir } => Self::local_delete(base_dir, key).await,
343        }
344    }
345
346    /// Upload an OpenAPI spec file for a hosted mock deployment
347    pub async fn upload_spec(
348        &self,
349        org_id: &str,
350        spec_name: &str,
351        data: Vec<u8>,
352    ) -> Result<String> {
353        let safe_org = Self::sanitize_key_component(org_id);
354        let safe_name = Self::sanitize_key_component(spec_name);
355
356        if safe_org.is_empty() {
357            anyhow::bail!("Org ID cannot be empty after sanitization");
358        }
359        if safe_name.is_empty() {
360            anyhow::bail!("Spec name cannot be empty after sanitization");
361        }
362
363        let key = format!("specs/{}/{}.json", safe_org, safe_name);
364
365        match &self.backend {
366            StorageBackend::S3 { client, bucket } => {
367                client
368                    .put_object()
369                    .bucket(bucket)
370                    .key(&key)
371                    .body(data.into())
372                    .content_type("application/json")
373                    .send()
374                    .await?;
375                Ok(Self::s3_url(bucket, &key))
376            }
377            StorageBackend::Local { base_dir } => Self::local_write(base_dir, &key, data).await,
378        }
379    }
380
381    /// Health check - verify storage connectivity
382    pub async fn health_check(&self) -> Result<()> {
383        match &self.backend {
384            StorageBackend::S3 { client, bucket } => {
385                client
386                    .head_bucket()
387                    .bucket(bucket)
388                    .send()
389                    .await
390                    .context("S3 bucket health check failed")?;
391                Ok(())
392            }
393            StorageBackend::Local { base_dir } => {
394                // Verify directory exists and is writable
395                let test_file = base_dir.join(".health_check");
396                tokio::fs::write(&test_file, b"ok")
397                    .await
398                    .context("Local storage health check failed: cannot write")?;
399                tokio::fs::remove_file(&test_file)
400                    .await
401                    .context("Local storage health check failed: cannot delete")?;
402                Ok(())
403            }
404        }
405    }
406
407    /// Upload a workspace snapshot manifest blob (#10). Snapshots
408    /// inline their manifest in `snapshots.manifest` for small
409    /// workspaces; for larger ones the registry serializes the JSON
410    /// once + uploads here, then stores the returned URL in
411    /// `snapshots.storage_url`.
412    ///
413    /// Key format: `snapshots/{workspace_id}/{snapshot_id}.json` —
414    /// keyed by workspace so a stray scan stays within tenant
415    /// boundaries (the snapshots row is workspace-scoped on the DB
416    /// side too, no org_id column).
417    pub async fn upload_snapshot_blob(
418        &self,
419        workspace_id: uuid::Uuid,
420        snapshot_id: uuid::Uuid,
421        data: Vec<u8>,
422    ) -> Result<String> {
423        let key = format!("snapshots/{workspace_id}/{snapshot_id}.json");
424        match &self.backend {
425            StorageBackend::S3 { client, bucket } => {
426                client
427                    .put_object()
428                    .bucket(bucket)
429                    .key(&key)
430                    .body(data.into())
431                    .content_type("application/json")
432                    .send()
433                    .await
434                    .context("snapshot upload to S3 failed")?;
435                Ok(Self::s3_url(bucket, &key))
436            }
437            StorageBackend::Local { base_dir } => Self::local_write(base_dir, &key, data).await,
438        }
439    }
440
441    /// Read back a snapshot manifest blob the registry previously
442    /// uploaded. Returns the raw bytes — caller deserializes.
443    pub async fn read_snapshot_blob(
444        &self,
445        workspace_id: uuid::Uuid,
446        snapshot_id: uuid::Uuid,
447    ) -> Result<Vec<u8>> {
448        let key = format!("snapshots/{workspace_id}/{snapshot_id}.json");
449        match &self.backend {
450            StorageBackend::S3 { client, bucket } => {
451                let resp = client
452                    .get_object()
453                    .bucket(bucket)
454                    .key(&key)
455                    .send()
456                    .await
457                    .context("snapshot read from S3 failed")?;
458                let bytes =
459                    resp.body.collect().await.context("snapshot S3 body read failed")?.into_bytes();
460                Ok(bytes.to_vec())
461            }
462            StorageBackend::Local { base_dir } => Self::local_read(base_dir, &key).await,
463        }
464    }
465
466    /// Delete a snapshot's blob (called from snapshot retention
467    /// worker once the row is marked `expired`). Idempotent — missing
468    /// keys are not an error.
469    pub async fn delete_snapshot_blob(
470        &self,
471        workspace_id: uuid::Uuid,
472        snapshot_id: uuid::Uuid,
473    ) -> Result<()> {
474        let key = format!("snapshots/{workspace_id}/{snapshot_id}.json");
475        match &self.backend {
476            StorageBackend::S3 { client, bucket } => {
477                client
478                    .delete_object()
479                    .bucket(bucket)
480                    .key(&key)
481                    .send()
482                    .await
483                    .context("snapshot delete from S3 failed")?;
484                Ok(())
485            }
486            StorageBackend::Local { base_dir } => Self::local_delete(base_dir, &key).await,
487        }
488    }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    // These tests cover the pure key-sanitization helper and the local
    // filesystem backend; the S3 code paths need live credentials/network
    // and are not exercised here.

    #[test]
    fn test_sanitize_key_component() {
        // Normal names should be lowercased
        assert_eq!(PluginStorage::sanitize_key_component("MyPlugin"), "myplugin");

        // Alphanumeric with hyphens, underscores, and dots should be preserved
        assert_eq!(PluginStorage::sanitize_key_component("my-plugin_v1.0"), "my-plugin_v1.0");

        // Dangerous characters should be removed
        assert_eq!(PluginStorage::sanitize_key_component("my/plugin"), "myplugin");
        assert_eq!(PluginStorage::sanitize_key_component("../evil"), "evil");
        assert_eq!(PluginStorage::sanitize_key_component("plugin<script>"), "pluginscript");

        // Path traversal attempts should be sanitized
        assert_eq!(PluginStorage::sanitize_key_component("../../etc/passwd"), "etcpasswd");

        // Special characters should be removed
        assert_eq!(PluginStorage::sanitize_key_component("plugin@#$%"), "plugin");

        // Leading/trailing dots, hyphens, underscores should be trimmed
        assert_eq!(PluginStorage::sanitize_key_component("...plugin..."), "plugin");
        assert_eq!(PluginStorage::sanitize_key_component("---plugin---"), "plugin");
        assert_eq!(PluginStorage::sanitize_key_component("___plugin___"), "plugin");

        // Mixed case with special chars
        assert_eq!(
            PluginStorage::sanitize_key_component("My!Super@Plugin#2024"),
            "mysuperplugin2024"
        );

        // Long strings should be truncated to 100 characters
        let long_name = "a".repeat(150);
        assert_eq!(PluginStorage::sanitize_key_component(&long_name).len(), 100);

        // Empty after sanitization
        assert_eq!(PluginStorage::sanitize_key_component("@#$%^&*()"), "");
    }

    #[test]
    fn test_sanitize_key_component_versions() {
        // Semantic versions should be preserved
        assert_eq!(PluginStorage::sanitize_key_component("1.0.0"), "1.0.0");
        assert_eq!(PluginStorage::sanitize_key_component("2.3.4-alpha"), "2.3.4-alpha");
        assert_eq!(PluginStorage::sanitize_key_component("1.0.0-beta.1"), "1.0.0-beta.1");

        // Version with invalid characters (slashes removed, dots preserved)
        assert_eq!(PluginStorage::sanitize_key_component("1.0.0/../../etc"), "1.0.0....etc");
    }

    #[test]
    fn test_sanitize_key_component_unicode() {
        // Unicode should be removed (only ASCII alphanumeric allowed)
        // Trailing hyphen is also trimmed
        assert_eq!(PluginStorage::sanitize_key_component("plugin-中文"), "plugin");
        assert_eq!(PluginStorage::sanitize_key_component("émoji-😀"), "moji");
    }

    #[test]
    fn test_sanitize_key_component_edge_cases() {
        // Empty string
        assert_eq!(PluginStorage::sanitize_key_component(""), "");

        // Only special characters
        assert_eq!(PluginStorage::sanitize_key_component("!@#$%^&*()"), "");

        // Whitespace should be removed
        assert_eq!(PluginStorage::sanitize_key_component("my plugin"), "myplugin");

        // Tabs and newlines should be removed
        assert_eq!(PluginStorage::sanitize_key_component("my\tplugin\n"), "myplugin");
    }

    #[test]
    fn test_sanitize_key_component_security() {
        // Path traversal attempts
        assert_eq!(PluginStorage::sanitize_key_component("../"), "");
        assert_eq!(PluginStorage::sanitize_key_component("..\\"), "");
        assert_eq!(
            PluginStorage::sanitize_key_component("../../../../../../etc/passwd"),
            "etcpasswd"
        );

        // Null bytes
        assert_eq!(PluginStorage::sanitize_key_component("plugin\0evil"), "pluginevil");

        // Windows path separators
        assert_eq!(
            PluginStorage::sanitize_key_component("C:\\Windows\\System32"),
            "cwindowssystem32"
        );
    }

    // Full write -> read -> delete lifecycle against a throwaway temp dir.
    #[tokio::test]
    async fn test_local_storage_roundtrip() {
        let temp_dir = tempfile::tempdir().unwrap();
        let base_dir = temp_dir.path().to_path_buf();

        let data = b"test plugin data".to_vec();
        let key = "plugins/test-plugin/1.0.0.wasm";

        // Write
        let url = PluginStorage::local_write(&base_dir, key, data.clone()).await.unwrap();
        assert_eq!(url, format!("/storage/{key}"));

        // Read
        let read_data = PluginStorage::local_read(&base_dir, key).await.unwrap();
        assert_eq!(read_data, data);

        // Delete
        PluginStorage::local_delete(&base_dir, key).await.unwrap();
        assert!(PluginStorage::local_read(&base_dir, key).await.is_err());
    }

    #[tokio::test]
    async fn test_local_storage_health_check() {
        let temp_dir = tempfile::tempdir().unwrap();
        let storage = PluginStorage {
            backend: StorageBackend::Local {
                base_dir: temp_dir.path().to_path_buf(),
            },
        };
        assert!(storage.health_check().await.is_ok());
    }

    #[tokio::test]
    async fn test_local_upload_spec() {
        let temp_dir = tempfile::tempdir().unwrap();
        let storage = PluginStorage {
            backend: StorageBackend::Local {
                base_dir: temp_dir.path().to_path_buf(),
            },
        };

        let spec_data = br#"{"openapi":"3.0.0","info":{"title":"Test","version":"1.0"}}"#.to_vec();
        let url = storage.upload_spec("org123", "my-api", spec_data.clone()).await.unwrap();
        assert!(url.contains("specs/org123/my-api.json"));

        // Verify file was written
        let read_back =
            tokio::fs::read(temp_dir.path().join("specs/org123/my-api.json")).await.unwrap();
        assert_eq!(read_back, spec_data);
    }

    #[tokio::test]
    async fn test_local_upload_plugin() {
        let temp_dir = tempfile::tempdir().unwrap();
        let storage = PluginStorage {
            backend: StorageBackend::Local {
                base_dir: temp_dir.path().to_path_buf(),
            },
        };

        let plugin_data = vec![0u8; 100];
        let url = storage.upload_plugin("my-plugin", "1.0.0", plugin_data).await.unwrap();
        assert!(url.contains("plugins/my-plugin/1.0.0.wasm"));
    }

    // Exercises the public key-based download/delete API on the local backend.
    #[tokio::test]
    async fn test_local_download_and_delete_plugin() {
        let temp_dir = tempfile::tempdir().unwrap();
        let storage = PluginStorage {
            backend: StorageBackend::Local {
                base_dir: temp_dir.path().to_path_buf(),
            },
        };

        let plugin_data = vec![42u8; 50];
        storage.upload_plugin("test-dl", "2.0.0", plugin_data.clone()).await.unwrap();

        let key = "plugins/test-dl/2.0.0.wasm";
        let downloaded = storage.download_plugin(key).await.unwrap();
        assert_eq!(downloaded, plugin_data);

        storage.delete_plugin(key).await.unwrap();
        assert!(storage.download_plugin(key).await.is_err());
    }
}
671}