1use anyhow::{Context, Result};
4use aws_config::BehaviorVersion;
5use aws_sdk_s3::{
6 config::{Credentials, Region},
7 Client as S3Client,
8};
9use std::path::{Path, PathBuf};
10
11use crate::config::Config;
12
/// Physical location of stored artifacts.
#[derive(Clone)]
enum StorageBackend {
    /// AWS S3 (or an S3-compatible service) — `bucket` is the target bucket name.
    S3 { client: S3Client, bucket: String },
    /// Plain files on the local filesystem, rooted at `base_dir`.
    Local { base_dir: PathBuf },
}
18
/// Storage facade for plugin, template, scenario, spec, and snapshot blobs.
///
/// Wraps either S3 or local-filesystem storage (see `StorageBackend`); the
/// backend is chosen once at construction time by `PluginStorage::new`.
#[derive(Clone)]
pub struct PluginStorage {
    backend: StorageBackend,
}
23
24impl PluginStorage {
    /// Select and initialize a storage backend from `config` and the
    /// process environment.
    ///
    /// S3 is attempted when either:
    /// - a custom `s3_endpoint` is configured AND both static credential
    ///   env vars (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`) are set
    ///   to non-empty values, or
    /// - no custom endpoint is set, but the bucket differs from the
    ///   default name or any common AWS credential env var is present.
    ///
    /// If S3 is not selected, or the bucket fails a `HeadBucket` probe,
    /// this falls back to local filesystem storage rooted at the
    /// `STORAGE_PATH` env var (default `./data/storage`).
    ///
    /// # Errors
    /// Fails if a custom endpoint is configured without static
    /// credentials, or if the local storage directory cannot be created.
    pub async fn new(config: &Config) -> Result<Self> {
        let use_s3 = if config.s3_endpoint.is_some() {
            // Custom endpoint (e.g. MinIO): only usable with explicit
            // static credentials, so require both env vars to be non-empty.
            std::env::var("AWS_ACCESS_KEY_ID")
                .ok()
                .filter(|v| !v.trim().is_empty())
                .is_some()
                && std::env::var("AWS_SECRET_ACCESS_KEY")
                    .ok()
                    .filter(|v| !v.trim().is_empty())
                    .is_some()
        } else {
            // Real AWS: heuristic — a non-default bucket name, or any hint
            // that the default credential provider chain can authenticate.
            config.s3_bucket != "mockforge-plugins"
                || std::env::var("AWS_ACCESS_KEY_ID").is_ok()
                || std::env::var("AWS_PROFILE").is_ok()
                || std::env::var("AWS_ROLE_ARN").is_ok()
        };

        if use_s3 {
            let aws_config = if let Some(endpoint) = &config.s3_endpoint {
                let access_key_id = std::env::var("AWS_ACCESS_KEY_ID")
                    .context("AWS_ACCESS_KEY_ID is required when using custom S3 endpoint")?;
                let secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY")
                    .context("AWS_SECRET_ACCESS_KEY is required when using custom S3 endpoint")?;

                tracing::info!("Using custom S3 endpoint: {} with explicit credentials", endpoint);

                let credentials =
                    Credentials::new(access_key_id, secret_access_key, None, None, "static");

                aws_config::defaults(BehaviorVersion::latest())
                    .region(Region::new(config.s3_region.clone()))
                    .credentials_provider(credentials)
                    .endpoint_url(endpoint)
                    .load()
                    .await
            } else {
                tracing::info!(
                    "Using AWS S3 with default credentials provider chain (region: {})",
                    config.s3_region
                );

                aws_config::defaults(BehaviorVersion::latest())
                    .region(Region::new(config.s3_region.clone()))
                    .load()
                    .await
            };

            let client = S3Client::new(&aws_config);
            let bucket = config.s3_bucket.clone();

            // Probe the bucket up front so a misconfigured S3 degrades to
            // local storage at startup rather than failing on first upload.
            match client.head_bucket().bucket(&bucket).send().await {
                Ok(_) => {
                    tracing::info!("S3 storage verified (bucket: {})", bucket);
                    return Ok(Self {
                        backend: StorageBackend::S3 { client, bucket },
                    });
                }
                Err(e) => {
                    tracing::warn!(
                        "S3 health check failed (bucket: {}): {}. Falling back to local storage.",
                        bucket,
                        e
                    );
                }
            }
        }

        // Local fallback (also the default when S3 was never selected).
        let base_dir = PathBuf::from(
            std::env::var("STORAGE_PATH").unwrap_or_else(|_| "./data/storage".to_string()),
        );

        tokio::fs::create_dir_all(&base_dir).await.with_context(|| {
            format!("Failed to create local storage directory: {}", base_dir.display())
        })?;

        tracing::info!("Using local filesystem storage at: {}", base_dir.display());

        Ok(Self {
            backend: StorageBackend::Local { base_dir },
        })
    }
115
116 fn sanitize_key_component(component: &str) -> String {
119 component
120 .chars()
121 .filter(|c| c.is_ascii_alphanumeric() || *c == '-' || *c == '_' || *c == '.')
122 .take(100) .collect::<String>()
124 .trim_matches('.')
125 .trim_matches('-')
126 .trim_matches('_')
127 .to_lowercase()
128 }
129
130 async fn local_write(base_dir: &Path, key: &str, data: Vec<u8>) -> Result<String> {
132 let file_path = base_dir.join(key);
133 if let Some(parent) = file_path.parent() {
134 tokio::fs::create_dir_all(parent)
135 .await
136 .with_context(|| format!("Failed to create directory: {}", parent.display()))?;
137 }
138 tokio::fs::write(&file_path, &data)
139 .await
140 .with_context(|| format!("Failed to write file: {}", file_path.display()))?;
141
142 Ok(format!("/storage/{key}"))
144 }
145
146 async fn local_read(base_dir: &Path, key: &str) -> Result<Vec<u8>> {
148 let file_path = base_dir.join(key);
149 tokio::fs::read(&file_path)
150 .await
151 .with_context(|| format!("Failed to read file: {}", file_path.display()))
152 }
153
154 async fn local_delete(base_dir: &Path, key: &str) -> Result<()> {
156 let file_path = base_dir.join(key);
157 if file_path.exists() {
158 tokio::fs::remove_file(&file_path)
159 .await
160 .with_context(|| format!("Failed to delete file: {}", file_path.display()))?;
161 }
162 Ok(())
163 }
164
165 fn s3_url(bucket: &str, key: &str) -> String {
167 if let Ok(endpoint) = std::env::var("S3_ENDPOINT") {
168 format!("{}/{}/{}", endpoint, bucket, key)
169 } else {
170 format!("https://{}.s3.amazonaws.com/{}", bucket, key)
171 }
172 }
173
174 pub async fn upload_plugin(
175 &self,
176 plugin_name: &str,
177 version: &str,
178 data: Vec<u8>,
179 ) -> Result<String> {
180 let safe_name = Self::sanitize_key_component(plugin_name);
181 let safe_version = Self::sanitize_key_component(version);
182
183 if safe_name.is_empty() {
184 anyhow::bail!("Plugin name cannot be empty after sanitization");
185 }
186 if safe_version.is_empty() {
187 anyhow::bail!("Version cannot be empty after sanitization");
188 }
189
190 let key = format!("plugins/{}/{}.wasm", safe_name, safe_version);
191
192 match &self.backend {
193 StorageBackend::S3 { client, bucket } => {
194 client
195 .put_object()
196 .bucket(bucket)
197 .key(&key)
198 .body(data.into())
199 .content_type("application/wasm")
200 .send()
201 .await?;
202 Ok(Self::s3_url(bucket, &key))
203 }
204 StorageBackend::Local { base_dir } => Self::local_write(base_dir, &key, data).await,
205 }
206 }
207
208 pub async fn upload_template(
209 &self,
210 template_name: &str,
211 version: &str,
212 data: Vec<u8>,
213 ) -> Result<String> {
214 let safe_name = Self::sanitize_key_component(template_name);
215 let safe_version = Self::sanitize_key_component(version);
216
217 if safe_name.is_empty() {
218 anyhow::bail!("Template name cannot be empty after sanitization");
219 }
220 if safe_version.is_empty() {
221 anyhow::bail!("Version cannot be empty after sanitization");
222 }
223
224 let extension = if data.len() >= 2 && data[0] == 0x1F && data[1] == 0x8B {
225 "tar.gz"
226 } else if data.len() >= 4
227 && data[0] == 0x50
228 && data[1] == 0x4B
229 && (data[2] == 0x03 || data[2] == 0x05 || data[2] == 0x07)
230 && (data[3] == 0x04 || data[3] == 0x06 || data[3] == 0x08)
231 {
232 "zip"
233 } else {
234 "tar.gz"
235 };
236
237 let key = format!("templates/{}/{}.{}", safe_name, safe_version, extension);
238
239 match &self.backend {
240 StorageBackend::S3 { client, bucket } => {
241 client
242 .put_object()
243 .bucket(bucket)
244 .key(&key)
245 .body(data.into())
246 .content_type(if extension == "zip" {
247 "application/zip"
248 } else {
249 "application/gzip"
250 })
251 .send()
252 .await?;
253 Ok(Self::s3_url(bucket, &key))
254 }
255 StorageBackend::Local { base_dir } => Self::local_write(base_dir, &key, data).await,
256 }
257 }
258
259 pub async fn upload_scenario(
260 &self,
261 scenario_name: &str,
262 version: &str,
263 data: Vec<u8>,
264 ) -> Result<String> {
265 let safe_name = Self::sanitize_key_component(scenario_name);
266 let safe_version = Self::sanitize_key_component(version);
267
268 if safe_name.is_empty() {
269 anyhow::bail!("Scenario name cannot be empty after sanitization");
270 }
271 if safe_version.is_empty() {
272 anyhow::bail!("Version cannot be empty after sanitization");
273 }
274
275 let extension = if data.len() >= 2 && data[0] == 0x1F && data[1] == 0x8B {
276 "tar.gz"
277 } else if data.len() >= 4
278 && data[0] == 0x50
279 && data[1] == 0x4B
280 && (data[2] == 0x03 || data[2] == 0x05 || data[2] == 0x07)
281 && (data[3] == 0x04 || data[3] == 0x06 || data[3] == 0x08)
282 {
283 "zip"
284 } else {
285 "tar.gz"
286 };
287
288 let key = format!("scenarios/{}/{}.{}", safe_name, safe_version, extension);
289
290 match &self.backend {
291 StorageBackend::S3 { client, bucket } => {
292 client
293 .put_object()
294 .bucket(bucket)
295 .key(&key)
296 .body(data.into())
297 .content_type(if extension == "zip" {
298 "application/zip"
299 } else {
300 "application/gzip"
301 })
302 .send()
303 .await?;
304 Ok(Self::s3_url(bucket, &key))
305 }
306 StorageBackend::Local { base_dir } => Self::local_write(base_dir, &key, data).await,
307 }
308 }
309
310 pub fn plugin_object_key(plugin_name: &str, version: &str) -> Result<String> {
314 let safe_name = Self::sanitize_key_component(plugin_name);
315 let safe_version = Self::sanitize_key_component(version);
316 if safe_name.is_empty() {
317 anyhow::bail!("Plugin name cannot be empty after sanitization");
318 }
319 if safe_version.is_empty() {
320 anyhow::bail!("Version cannot be empty after sanitization");
321 }
322 Ok(format!("plugins/{}/{}.wasm", safe_name, safe_version))
323 }
324
325 pub async fn download_plugin(&self, key: &str) -> Result<Vec<u8>> {
326 match &self.backend {
327 StorageBackend::S3 { client, bucket } => {
328 let response = client.get_object().bucket(bucket).key(key).send().await?;
329 let bytes = response.body.collect().await?;
330 Ok(bytes.to_vec())
331 }
332 StorageBackend::Local { base_dir } => Self::local_read(base_dir, key).await,
333 }
334 }
335
336 pub async fn delete_plugin(&self, key: &str) -> Result<()> {
337 match &self.backend {
338 StorageBackend::S3 { client, bucket } => {
339 client.delete_object().bucket(bucket).key(key).send().await?;
340 Ok(())
341 }
342 StorageBackend::Local { base_dir } => Self::local_delete(base_dir, key).await,
343 }
344 }
345
346 pub async fn upload_spec(
348 &self,
349 org_id: &str,
350 spec_name: &str,
351 data: Vec<u8>,
352 ) -> Result<String> {
353 let safe_org = Self::sanitize_key_component(org_id);
354 let safe_name = Self::sanitize_key_component(spec_name);
355
356 if safe_org.is_empty() {
357 anyhow::bail!("Org ID cannot be empty after sanitization");
358 }
359 if safe_name.is_empty() {
360 anyhow::bail!("Spec name cannot be empty after sanitization");
361 }
362
363 let key = format!("specs/{}/{}.json", safe_org, safe_name);
364
365 match &self.backend {
366 StorageBackend::S3 { client, bucket } => {
367 client
368 .put_object()
369 .bucket(bucket)
370 .key(&key)
371 .body(data.into())
372 .content_type("application/json")
373 .send()
374 .await?;
375 Ok(Self::s3_url(bucket, &key))
376 }
377 StorageBackend::Local { base_dir } => Self::local_write(base_dir, &key, data).await,
378 }
379 }
380
381 pub async fn health_check(&self) -> Result<()> {
383 match &self.backend {
384 StorageBackend::S3 { client, bucket } => {
385 client
386 .head_bucket()
387 .bucket(bucket)
388 .send()
389 .await
390 .context("S3 bucket health check failed")?;
391 Ok(())
392 }
393 StorageBackend::Local { base_dir } => {
394 let test_file = base_dir.join(".health_check");
396 tokio::fs::write(&test_file, b"ok")
397 .await
398 .context("Local storage health check failed: cannot write")?;
399 tokio::fs::remove_file(&test_file)
400 .await
401 .context("Local storage health check failed: cannot delete")?;
402 Ok(())
403 }
404 }
405 }
406
407 pub async fn upload_snapshot_blob(
418 &self,
419 workspace_id: uuid::Uuid,
420 snapshot_id: uuid::Uuid,
421 data: Vec<u8>,
422 ) -> Result<String> {
423 let key = format!("snapshots/{workspace_id}/{snapshot_id}.json");
424 match &self.backend {
425 StorageBackend::S3 { client, bucket } => {
426 client
427 .put_object()
428 .bucket(bucket)
429 .key(&key)
430 .body(data.into())
431 .content_type("application/json")
432 .send()
433 .await
434 .context("snapshot upload to S3 failed")?;
435 Ok(Self::s3_url(bucket, &key))
436 }
437 StorageBackend::Local { base_dir } => Self::local_write(base_dir, &key, data).await,
438 }
439 }
440
441 pub async fn read_snapshot_blob(
444 &self,
445 workspace_id: uuid::Uuid,
446 snapshot_id: uuid::Uuid,
447 ) -> Result<Vec<u8>> {
448 let key = format!("snapshots/{workspace_id}/{snapshot_id}.json");
449 match &self.backend {
450 StorageBackend::S3 { client, bucket } => {
451 let resp = client
452 .get_object()
453 .bucket(bucket)
454 .key(&key)
455 .send()
456 .await
457 .context("snapshot read from S3 failed")?;
458 let bytes =
459 resp.body.collect().await.context("snapshot S3 body read failed")?.into_bytes();
460 Ok(bytes.to_vec())
461 }
462 StorageBackend::Local { base_dir } => Self::local_read(base_dir, &key).await,
463 }
464 }
465
466 pub async fn delete_snapshot_blob(
470 &self,
471 workspace_id: uuid::Uuid,
472 snapshot_id: uuid::Uuid,
473 ) -> Result<()> {
474 let key = format!("snapshots/{workspace_id}/{snapshot_id}.json");
475 match &self.backend {
476 StorageBackend::S3 { client, bucket } => {
477 client
478 .delete_object()
479 .bucket(bucket)
480 .key(&key)
481 .send()
482 .await
483 .context("snapshot delete from S3 failed")?;
484 Ok(())
485 }
486 StorageBackend::Local { base_dir } => Self::local_delete(base_dir, &key).await,
487 }
488 }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanitization: case folding, allowed charset, traversal removal,
    /// edge trimming, and the 100-character cap.
    #[test]
    fn test_sanitize_key_component() {
        // Upper-case letters fold to lower-case.
        assert_eq!(PluginStorage::sanitize_key_component("MyPlugin"), "myplugin");

        // Allowed characters (alnum, '-', '_', '.') pass through untouched.
        assert_eq!(PluginStorage::sanitize_key_component("my-plugin_v1.0"), "my-plugin_v1.0");

        // Path separators and markup characters are stripped.
        assert_eq!(PluginStorage::sanitize_key_component("my/plugin"), "myplugin");
        assert_eq!(PluginStorage::sanitize_key_component("../evil"), "evil");
        assert_eq!(PluginStorage::sanitize_key_component("plugin<script>"), "pluginscript");

        // Traversal sequences collapse to their alphanumeric residue.
        assert_eq!(PluginStorage::sanitize_key_component("../../etc/passwd"), "etcpasswd");

        assert_eq!(PluginStorage::sanitize_key_component("plugin@#$%"), "plugin");

        // Leading/trailing '.', '-', '_' are trimmed.
        assert_eq!(PluginStorage::sanitize_key_component("...plugin..."), "plugin");
        assert_eq!(PluginStorage::sanitize_key_component("---plugin---"), "plugin");
        assert_eq!(PluginStorage::sanitize_key_component("___plugin___"), "plugin");

        assert_eq!(
            PluginStorage::sanitize_key_component("My!Super@Plugin#2024"),
            "mysuperplugin2024"
        );

        // Output is capped at 100 characters.
        let long_name = "a".repeat(150);
        assert_eq!(PluginStorage::sanitize_key_component(&long_name).len(), 100);

        // All-symbol input sanitizes to empty; callers must reject this.
        assert_eq!(PluginStorage::sanitize_key_component("@#$%^&*()"), "");
    }

    /// Semver-style version strings survive sanitization intact.
    #[test]
    fn test_sanitize_key_component_versions() {
        assert_eq!(PluginStorage::sanitize_key_component("1.0.0"), "1.0.0");
        assert_eq!(PluginStorage::sanitize_key_component("2.3.4-alpha"), "2.3.4-alpha");
        assert_eq!(PluginStorage::sanitize_key_component("1.0.0-beta.1"), "1.0.0-beta.1");

        // Slashes are dropped but interior dots are kept (harmless residue).
        assert_eq!(PluginStorage::sanitize_key_component("1.0.0/../../etc"), "1.0.0....etc");
    }

    /// Non-ASCII characters (CJK, accents, emoji) are dropped entirely.
    #[test]
    fn test_sanitize_key_component_unicode() {
        assert_eq!(PluginStorage::sanitize_key_component("plugin-中文"), "plugin");
        assert_eq!(PluginStorage::sanitize_key_component("émoji-😀"), "moji");
    }

    /// Degenerate inputs: empty, all-symbol, and whitespace-bearing names.
    #[test]
    fn test_sanitize_key_component_edge_cases() {
        assert_eq!(PluginStorage::sanitize_key_component(""), "");

        assert_eq!(PluginStorage::sanitize_key_component("!@#$%^&*()"), "");

        // Whitespace (space, tab, newline) is removed, not replaced.
        assert_eq!(PluginStorage::sanitize_key_component("my plugin"), "myplugin");

        assert_eq!(PluginStorage::sanitize_key_component("my\tplugin\n"), "myplugin");
    }

    /// Security-focused inputs: traversal, NUL bytes, Windows paths.
    #[test]
    fn test_sanitize_key_component_security() {
        assert_eq!(PluginStorage::sanitize_key_component("../"), "");
        assert_eq!(PluginStorage::sanitize_key_component("..\\"), "");
        assert_eq!(
            PluginStorage::sanitize_key_component("../../../../../../etc/passwd"),
            "etcpasswd"
        );

        // Embedded NUL bytes are stripped.
        assert_eq!(PluginStorage::sanitize_key_component("plugin\0evil"), "pluginevil");

        // Windows drive/path separators are stripped and the rest folded.
        assert_eq!(
            PluginStorage::sanitize_key_component("C:\\Windows\\System32"),
            "cwindowssystem32"
        );
    }

    /// Local backend: write -> read -> delete round trip on a temp dir.
    #[tokio::test]
    async fn test_local_storage_roundtrip() {
        let temp_dir = tempfile::tempdir().unwrap();
        let base_dir = temp_dir.path().to_path_buf();

        let data = b"test plugin data".to_vec();
        let key = "plugins/test-plugin/1.0.0.wasm";

        // Write returns the public /storage/... URL.
        let url = PluginStorage::local_write(&base_dir, key, data.clone()).await.unwrap();
        assert_eq!(url, format!("/storage/{key}"));

        let read_data = PluginStorage::local_read(&base_dir, key).await.unwrap();
        assert_eq!(read_data, data);

        // After deletion, reads must fail.
        PluginStorage::local_delete(&base_dir, key).await.unwrap();
        assert!(PluginStorage::local_read(&base_dir, key).await.is_err());
    }

    /// Local backend health check passes on a writable temp dir.
    #[tokio::test]
    async fn test_local_storage_health_check() {
        let temp_dir = tempfile::tempdir().unwrap();
        let storage = PluginStorage {
            backend: StorageBackend::Local {
                base_dir: temp_dir.path().to_path_buf(),
            },
        };
        assert!(storage.health_check().await.is_ok());
    }

    /// upload_spec writes to specs/{org}/{name}.json with intact contents.
    #[tokio::test]
    async fn test_local_upload_spec() {
        let temp_dir = tempfile::tempdir().unwrap();
        let storage = PluginStorage {
            backend: StorageBackend::Local {
                base_dir: temp_dir.path().to_path_buf(),
            },
        };

        let spec_data = br#"{"openapi":"3.0.0","info":{"title":"Test","version":"1.0"}}"#.to_vec();
        let url = storage.upload_spec("org123", "my-api", spec_data.clone()).await.unwrap();
        assert!(url.contains("specs/org123/my-api.json"));

        // Verify the bytes landed on disk unmodified.
        let read_back =
            tokio::fs::read(temp_dir.path().join("specs/org123/my-api.json")).await.unwrap();
        assert_eq!(read_back, spec_data);
    }

    /// upload_plugin produces a plugins/{name}/{version}.wasm URL.
    #[tokio::test]
    async fn test_local_upload_plugin() {
        let temp_dir = tempfile::tempdir().unwrap();
        let storage = PluginStorage {
            backend: StorageBackend::Local {
                base_dir: temp_dir.path().to_path_buf(),
            },
        };

        let plugin_data = vec![0u8; 100];
        let url = storage.upload_plugin("my-plugin", "1.0.0", plugin_data).await.unwrap();
        assert!(url.contains("plugins/my-plugin/1.0.0.wasm"));
    }

    /// download_plugin/delete_plugin round trip against local storage.
    #[tokio::test]
    async fn test_local_download_and_delete_plugin() {
        let temp_dir = tempfile::tempdir().unwrap();
        let storage = PluginStorage {
            backend: StorageBackend::Local {
                base_dir: temp_dir.path().to_path_buf(),
            },
        };

        let plugin_data = vec![42u8; 50];
        storage.upload_plugin("test-dl", "2.0.0", plugin_data.clone()).await.unwrap();

        let key = "plugins/test-dl/2.0.0.wasm";
        let downloaded = storage.download_plugin(key).await.unwrap();
        assert_eq!(downloaded, plugin_data);

        storage.delete_plugin(key).await.unwrap();
        assert!(storage.download_plugin(key).await.is_err());
    }
}