use std::path::PathBuf;
use async_trait::async_trait;
use tokio::fs;
use super::{now_secs, Artifact, ArtifactError, ArtifactMetadata, ArtifactService};
pub struct FileArtifactService {
root_dir: PathBuf,
}
impl FileArtifactService {
pub fn new(root_dir: impl Into<PathBuf>) -> Result<Self, ArtifactError> {
let root = root_dir.into();
std::fs::create_dir_all(&root)
.map_err(|e| ArtifactError::Storage(format!("Failed to create root dir: {}", e)))?;
Ok(Self { root_dir: root })
}
fn artifact_dir(&self, session_id: &str, name: &str) -> PathBuf {
let safe_session = sanitize_path_component(session_id);
let safe_name = sanitize_path_component(name);
self.root_dir.join(&safe_session).join(&safe_name)
}
fn version_dir(&self, session_id: &str, name: &str, version: u32) -> PathBuf {
self.artifact_dir(session_id, name)
.join(format!("v{}", version))
}
async fn next_version(&self, session_id: &str, name: &str) -> u32 {
let dir = self.artifact_dir(session_id, name);
if !dir.exists() {
return 1;
}
let mut max_version = 0u32;
if let Ok(mut entries) = fs::read_dir(&dir).await {
while let Ok(Some(entry)) = entries.next_entry().await {
if let Some(name) = entry.file_name().to_str() {
if let Some(v) = name.strip_prefix('v') {
if let Ok(version) = v.parse::<u32>() {
max_version = max_version.max(version);
}
}
}
}
}
max_version + 1
}
}
fn sanitize_path_component(s: &str) -> String {
s.replace(['/', '\\', '.'], "_")
}
#[async_trait]
impl ArtifactService for FileArtifactService {
async fn save(
&self,
session_id: &str,
artifact: Artifact,
) -> Result<ArtifactMetadata, ArtifactError> {
let version = self.next_version(session_id, &artifact.metadata.name).await;
let ver_dir = self.version_dir(session_id, &artifact.metadata.name, version);
fs::create_dir_all(&ver_dir)
.await
.map_err(|e| ArtifactError::Storage(e.to_string()))?;
fs::write(ver_dir.join("data"), &artifact.data)
.await
.map_err(|e| ArtifactError::Storage(e.to_string()))?;
let mut metadata = artifact.metadata;
metadata.version = version;
metadata.updated_at = now_secs();
if version == 1 {
metadata.created_at = metadata.updated_at;
}
let metadata_json = serde_json::to_string_pretty(&metadata)
.map_err(|e| ArtifactError::Storage(e.to_string()))?;
fs::write(ver_dir.join("metadata.json"), metadata_json)
.await
.map_err(|e| ArtifactError::Storage(e.to_string()))?;
Ok(metadata)
}
async fn load(&self, session_id: &str, name: &str) -> Result<Option<Artifact>, ArtifactError> {
let latest = self.next_version(session_id, name).await;
if latest == 1 {
return Ok(None);
}
self.load_version(session_id, name, latest - 1).await
}
async fn load_version(
&self,
session_id: &str,
name: &str,
version: u32,
) -> Result<Option<Artifact>, ArtifactError> {
let ver_dir = self.version_dir(session_id, name, version);
if !ver_dir.exists() {
return Ok(None);
}
let data = fs::read(ver_dir.join("data"))
.await
.map_err(|e| ArtifactError::Storage(e.to_string()))?;
let metadata_str = fs::read_to_string(ver_dir.join("metadata.json"))
.await
.map_err(|e| ArtifactError::Storage(e.to_string()))?;
let metadata: ArtifactMetadata = serde_json::from_str(&metadata_str)
.map_err(|e| ArtifactError::Storage(e.to_string()))?;
Ok(Some(Artifact { metadata, data }))
}
async fn list(&self, session_id: &str) -> Result<Vec<ArtifactMetadata>, ArtifactError> {
let session_dir = self.root_dir.join(sanitize_path_component(session_id));
if !session_dir.exists() {
return Ok(vec![]);
}
let mut result = vec![];
let mut entries = fs::read_dir(&session_dir)
.await
.map_err(|e| ArtifactError::Storage(e.to_string()))?;
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| ArtifactError::Storage(e.to_string()))?
{
if entry.file_type().await.map(|t| t.is_dir()).unwrap_or(false) {
let name = entry.file_name().to_string_lossy().to_string();
if let Ok(Some(artifact)) = self.load(session_id, &name).await {
result.push(artifact.metadata);
}
}
}
Ok(result)
}
async fn delete(&self, session_id: &str, name: &str) -> Result<(), ArtifactError> {
let dir = self.artifact_dir(session_id, name);
if dir.exists() {
fs::remove_dir_all(&dir)
.await
.map_err(|e| ArtifactError::Storage(e.to_string()))?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
fn test_dir() -> PathBuf {
static COUNTER: AtomicU32 = AtomicU32::new(0);
let id = COUNTER.fetch_add(1, Ordering::SeqCst);
let dir = std::env::temp_dir()
.join("rs_adk_file_artifact_tests")
.join(format!("test_{}_{}", std::process::id(), id));
let _ = std::fs::remove_dir_all(&dir);
dir
}
#[tokio::test]
async fn save_and_load_round_trip() {
let dir = test_dir();
let svc = FileArtifactService::new(&dir).unwrap();
let artifact = Artifact::text("notes", "Hello, world!");
let meta = svc.save("session1", artifact).await.unwrap();
assert_eq!(meta.name, "notes");
assert_eq!(meta.version, 1);
let loaded = svc.load("session1", "notes").await.unwrap().unwrap();
assert_eq!(std::str::from_utf8(&loaded.data).unwrap(), "Hello, world!");
assert_eq!(loaded.metadata.version, 1);
assert_eq!(loaded.metadata.mime_type, "text/plain");
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn versioning_increments_and_load_gets_latest() {
let dir = test_dir();
let svc = FileArtifactService::new(&dir).unwrap();
let m1 = svc
.save("s1", Artifact::text("doc", "version 1"))
.await
.unwrap();
assert_eq!(m1.version, 1);
let m2 = svc
.save("s1", Artifact::text("doc", "version 2"))
.await
.unwrap();
assert_eq!(m2.version, 2);
let m3 = svc
.save("s1", Artifact::text("doc", "version 3"))
.await
.unwrap();
assert_eq!(m3.version, 3);
let latest = svc.load("s1", "doc").await.unwrap().unwrap();
assert_eq!(latest.metadata.version, 3);
assert_eq!(std::str::from_utf8(&latest.data).unwrap(), "version 3");
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn load_specific_version() {
let dir = test_dir();
let svc = FileArtifactService::new(&dir).unwrap();
svc.save("s1", Artifact::text("doc", "v1 data"))
.await
.unwrap();
svc.save("s1", Artifact::text("doc", "v2 data"))
.await
.unwrap();
svc.save("s1", Artifact::text("doc", "v3 data"))
.await
.unwrap();
let v1 = svc.load_version("s1", "doc", 1).await.unwrap().unwrap();
assert_eq!(std::str::from_utf8(&v1.data).unwrap(), "v1 data");
assert_eq!(v1.metadata.version, 1);
let v2 = svc.load_version("s1", "doc", 2).await.unwrap().unwrap();
assert_eq!(std::str::from_utf8(&v2.data).unwrap(), "v2 data");
let v3 = svc.load_version("s1", "doc", 3).await.unwrap().unwrap();
assert_eq!(std::str::from_utf8(&v3.data).unwrap(), "v3 data");
let v99 = svc.load_version("s1", "doc", 99).await.unwrap();
assert!(v99.is_none());
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn list_artifacts() {
let dir = test_dir();
let svc = FileArtifactService::new(&dir).unwrap();
svc.save("s1", Artifact::text("alpha", "data"))
.await
.unwrap();
svc.save("s1", Artifact::text("beta", "data"))
.await
.unwrap();
svc.save("s2", Artifact::text("gamma", "data"))
.await
.unwrap();
let list = svc.list("s1").await.unwrap();
assert_eq!(list.len(), 2);
let names: Vec<&str> = list.iter().map(|m| m.name.as_str()).collect();
assert!(names.contains(&"alpha"));
assert!(names.contains(&"beta"));
let list2 = svc.list("s2").await.unwrap();
assert_eq!(list2.len(), 1);
assert_eq!(list2[0].name, "gamma");
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn delete_artifact() {
let dir = test_dir();
let svc = FileArtifactService::new(&dir).unwrap();
svc.save("s1", Artifact::text("notes", "data"))
.await
.unwrap();
svc.save("s1", Artifact::text("notes", "v2")).await.unwrap();
svc.delete("s1", "notes").await.unwrap();
let result = svc.load("s1", "notes").await.unwrap();
assert!(result.is_none());
svc.delete("s1", "notes").await.unwrap();
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn load_nonexistent_returns_none() {
let dir = test_dir();
let svc = FileArtifactService::new(&dir).unwrap();
let result = svc.load("no_session", "no_artifact").await.unwrap();
assert!(result.is_none());
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn path_traversal_prevention() {
let dir = test_dir();
let svc = FileArtifactService::new(&dir).unwrap();
let artifact = Artifact::text("../../../etc/passwd", "malicious");
let meta = svc.save("../../hack", artifact).await.unwrap();
assert_eq!(meta.version, 1);
let sanitized_session = sanitize_path_component("../../hack");
let sanitized_name = sanitize_path_component("../../../etc/passwd");
assert!(!sanitized_session.contains('/'));
assert!(!sanitized_session.contains('\\'));
assert!(!sanitized_session.contains('.'));
assert!(!sanitized_name.contains('/'));
assert!(!sanitized_name.contains('\\'));
assert!(!sanitized_name.contains('.'));
let loaded = svc.load("../../hack", "../../../etc/passwd").await.unwrap();
assert!(loaded.is_some());
assert_eq!(
std::str::from_utf8(&loaded.unwrap().data).unwrap(),
"malicious"
);
assert!(dir.exists());
std::fs::remove_dir_all(&dir).ok();
}
#[test]
fn sanitize_removes_dangerous_chars() {
assert_eq!(sanitize_path_component("normal"), "normal");
assert_eq!(sanitize_path_component(".."), "__");
assert_eq!(sanitize_path_component("a/b"), "a_b");
assert_eq!(sanitize_path_component("a\\b"), "a_b");
assert_eq!(sanitize_path_component("../../etc"), "______etc");
assert_eq!(sanitize_path_component("file.txt"), "file_txt");
}
}