use std::collections::HashMap;
use std::path::{Path, PathBuf};
use tokio::fs;
use crate::error::NikaError;
use crate::io::atomic::write_atomic;
use crate::io::security::validate_artifact_path;
use crate::io::template::TemplateResolver;
use crate::OutputFormat;
/// Default upper bound on artifact content size in bytes: 10 MiB (10 × 2^20).
pub const DEFAULT_MAX_SIZE: u64 = 10 << 20;
/// Outcome of a successful artifact write.
#[derive(Debug, Clone)]
pub struct WriteResult {
    /// Path the content was written to, as returned by path validation.
    pub path: PathBuf,
    /// Size of the written content in bytes.
    pub size: u64,
    /// Output format carried over from the originating request.
    pub format: OutputFormat,
}
/// A single artifact to be written by an `ArtifactWriter`.
///
/// Build one with `WriteRequest::new` and the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct WriteRequest {
    /// Identifier of the task producing the artifact; passed to the path
    /// template resolver (e.g. usable as `{{task_id}}` in `output_path`).
    pub task_id: String,
    /// Output path relative to the artifact directory; may contain `{{...}}`
    /// template placeholders.
    pub output_path: String,
    /// Raw content to write.
    pub content: String,
    /// Output format; `OutputFormat::Json` triggers JSON validation of
    /// non-empty content.
    pub format: OutputFormat,
    /// Custom template variables available to `output_path` (e.g. `{{locale}}`).
    pub vars: HashMap<String, String>,
}
impl WriteRequest {
    /// Creates a request for `task_id` targeting the (possibly templated)
    /// `output_path`.
    ///
    /// Defaults: empty content, `OutputFormat::Text`, and no custom template
    /// variables.
    pub fn new(task_id: impl Into<String>, output_path: impl Into<String>) -> Self {
        Self {
            task_id: task_id.into(),
            output_path: output_path.into(),
            content: String::new(),
            format: OutputFormat::Text,
            vars: HashMap::new(),
        }
    }

    /// Sets the content to write, replacing any previously set content.
    #[must_use]
    pub fn with_content(mut self, content: impl Into<String>) -> Self {
        self.content = content.into();
        self
    }

    /// Sets the output format.
    #[must_use]
    pub fn with_format(mut self, format: OutputFormat) -> Self {
        self.format = format;
        self
    }

    /// Adds a single template variable, overwriting any existing value for `key`.
    #[must_use]
    pub fn with_var(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.vars.insert(key.into(), value.into());
        self
    }

    /// Merges `vars` into the existing template variables; keys already present
    /// are overwritten by the incoming values.
    #[must_use]
    pub fn with_vars(mut self, vars: HashMap<String, String>) -> Self {
        self.vars.extend(vars);
        self
    }
}
/// Writes task artifacts beneath a single root directory, resolving templated
/// output paths and enforcing a content size limit.
#[derive(Debug)]
pub struct ArtifactWriter {
    /// Root directory that every resolved artifact path is validated against.
    artifact_dir: PathBuf,
    /// Workflow name passed to the path template resolver.
    workflow_name: String,
    /// Maximum accepted content size in bytes.
    max_size: u64,
}
impl ArtifactWriter {
    /// Creates a writer that stores artifacts under `artifact_dir` for the
    /// workflow `workflow_name`.
    ///
    /// The content size limit starts at [`DEFAULT_MAX_SIZE`]; use
    /// [`ArtifactWriter::with_max_size`] to change it.
    pub fn new(artifact_dir: impl Into<PathBuf>, workflow_name: impl Into<String>) -> Self {
        Self {
            artifact_dir: artifact_dir.into(),
            workflow_name: workflow_name.into(),
            max_size: DEFAULT_MAX_SIZE,
        }
    }

    /// Overrides the maximum accepted content size, in bytes.
    #[must_use]
    pub fn with_max_size(mut self, max_size: u64) -> Self {
        self.max_size = max_size;
        self
    }

    /// Returns the root directory artifacts are written under.
    pub fn artifact_dir(&self) -> &Path {
        &self.artifact_dir
    }

    /// Validates `request`, resolves its output path and writes its content atomically.
    ///
    /// The content checks (size limit, JSON syntax) run before any template
    /// resolution or filesystem access, so a request rejected by them leaves
    /// the disk untouched. Missing parent directories of the target are created.
    ///
    /// # Errors
    ///
    /// - [`NikaError::ArtifactSizeExceeded`] if the content is larger than the
    ///   configured maximum.
    /// - [`NikaError::ArtifactWriteError`] if JSON-formatted content does not
    ///   parse, the parent directories cannot be created, or the atomic write fails.
    /// - Any error returned by template resolution or path validation (for
    ///   example, a path that escapes the artifact directory).
    pub async fn write(&self, request: WriteRequest) -> Result<WriteResult, NikaError> {
        let content_size = self.check_content(&request)?;

        // `request` is owned, so its variables are moved into the resolver
        // instead of being cloned.
        let resolver = TemplateResolver::new(&request.task_id, &self.workflow_name)
            .with_vars(request.vars)?;
        let resolved_path = resolver.resolve(&request.output_path)?;
        let full_path = validate_artifact_path(&self.artifact_dir, Path::new(&resolved_path))?;

        if let Some(parent) = full_path.parent() {
            fs::create_dir_all(parent)
                .await
                .map_err(|e| NikaError::ArtifactWriteError {
                    path: parent.display().to_string(),
                    reason: format!("Failed to create parent directories: {}", e),
                })?;
        }

        // Validate again now that the parent directories exist on disk.
        // NOTE(review): presumably this guards against the freshly created
        // directory state (e.g. symlinks) changing the validation outcome —
        // confirm against `validate_artifact_path`.
        let final_path = validate_artifact_path(&self.artifact_dir, Path::new(&resolved_path))?;
        write_atomic(&final_path, request.content.as_bytes())
            .await
            .map_err(|e| NikaError::ArtifactWriteError {
                path: final_path.display().to_string(),
                reason: format!("Atomic write failed: {}", e),
            })?;

        Ok(WriteResult {
            path: final_path,
            size: content_size,
            format: request.format,
        })
    }

    /// Enforces the per-request content rules: the size limit and, for
    /// `OutputFormat::Json`, that non-empty content parses as JSON.
    ///
    /// Returns the content size in bytes on success.
    fn check_content(&self, request: &WriteRequest) -> Result<u64, NikaError> {
        // `usize` is at most 64 bits on supported targets, so this cast is lossless.
        let content_size = request.content.len() as u64;
        if content_size > self.max_size {
            return Err(NikaError::ArtifactSizeExceeded {
                path: request.output_path.clone(),
                size: content_size,
                max_size: self.max_size,
            });
        }

        // Empty content skips JSON validation.
        if matches!(request.format, OutputFormat::Json) && !request.content.is_empty() {
            if let Err(e) = serde_json::from_str::<serde_json::Value>(&request.content) {
                return Err(NikaError::ArtifactWriteError {
                    path: request.output_path.clone(),
                    reason: format!("Invalid JSON content: {}", e),
                });
            }
        }

        Ok(content_size)
    }

    /// Resolves `output_path` for `task_id` and validates it against the
    /// artifact directory without touching the filesystem's contents.
    ///
    /// Unlike [`ArtifactWriter::write`], no custom template variables are given
    /// to the resolver here, so only its built-in placeholders can be used
    /// (presumably such as `{{task_id}}` — confirm in `TemplateResolver`).
    ///
    /// # Errors
    ///
    /// Returns any error from template resolution or path validation.
    pub fn validate_path(&self, task_id: &str, output_path: &str) -> Result<PathBuf, NikaError> {
        let resolver = TemplateResolver::new(task_id, &self.workflow_name);
        let resolved_path = resolver.resolve(output_path)?;
        validate_artifact_path(&self.artifact_dir, Path::new(&resolved_path))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a writer rooted at a fresh, canonicalized `artifacts/` directory
    /// with the given content size limit.
    ///
    /// The returned `TempDir` owns the directory tree; keep it alive for the
    /// whole test, since dropping it deletes the directory.
    fn test_writer_with_max_size(max_size: u64) -> (ArtifactWriter, tempfile::TempDir) {
        let temp = tempdir().unwrap();
        let artifact_dir = temp.path().join("artifacts");
        std::fs::create_dir_all(&artifact_dir).unwrap();
        // Canonicalize the root so it is directly comparable with validated
        // paths (presumably also canonical).
        let canonical_dir = artifact_dir.canonicalize().unwrap();
        let writer = ArtifactWriter::new(canonical_dir, "test-workflow").with_max_size(max_size);
        (writer, temp)
    }

    /// Builds a writer with the default size limit.
    fn test_writer() -> (ArtifactWriter, tempfile::TempDir) {
        test_writer_with_max_size(DEFAULT_MAX_SIZE)
    }

    #[tokio::test]
    async fn test_write_simple() {
        let (writer, _temp) = test_writer();
        let content = r#"{"key": "value"}"#;
        let request = WriteRequest::new("task1", "output.json")
            .with_content(content)
            .with_format(OutputFormat::Json);
        let result = writer.write(request).await.unwrap();
        assert!(result.path.ends_with("output.json"));
        assert!(result.path.starts_with(writer.artifact_dir()));
        assert_eq!(result.size, 16);
        assert!(matches!(result.format, OutputFormat::Json));
        // The content must actually land on disk, not merely be reported.
        assert_eq!(std::fs::read_to_string(&result.path).unwrap(), content);
    }

    #[tokio::test]
    async fn test_write_with_template() {
        let (writer, _temp) = test_writer();
        let request = WriteRequest::new("generate_page", "{{task_id}}/output.json")
            .with_content("test content");
        let result = writer.write(request).await.unwrap();
        assert!(result.path.to_string_lossy().contains("generate_page"));
    }

    #[tokio::test]
    async fn test_write_nested_path() {
        let (writer, _temp) = test_writer();
        let request =
            WriteRequest::new("task1", "deep/nested/path/output.txt").with_content("hello");
        let result = writer.write(request).await.unwrap();
        assert!(result.path.ends_with("deep/nested/path/output.txt"));
        // Missing parent directories are created before the write.
        assert_eq!(std::fs::read_to_string(&result.path).unwrap(), "hello");
    }

    #[tokio::test]
    async fn test_write_size_exceeded() {
        let (writer, _temp) = test_writer_with_max_size(10);
        let request = WriteRequest::new("task1", "output.txt")
            .with_content("this content is longer than 10 bytes");
        let result = writer.write(request).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, NikaError::ArtifactSizeExceeded { .. }));
        // Oversized content is rejected before anything is written.
        assert!(!writer.artifact_dir().join("output.txt").exists());
    }

    #[tokio::test]
    async fn test_write_path_traversal_blocked() {
        let (writer, _temp) = test_writer();
        let request =
            WriteRequest::new("task1", "../../../etc/passwd").with_content("malicious content");
        let result = writer.write(request).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, NikaError::ArtifactPathError { .. }));
    }

    #[tokio::test]
    async fn test_write_absolute_path_blocked() {
        let (writer, _temp) = test_writer();
        let request = WriteRequest::new("task1", "/etc/passwd").with_content("test");
        let result = writer.write(request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_write_custom_vars() {
        let (writer, _temp) = test_writer();
        let request = WriteRequest::new("task1", "locales/{{locale}}/{{entity}}.json")
            .with_content("{}")
            .with_var("locale", "fr-FR")
            .with_var("entity", "qr-code");
        let result = writer.write(request).await.unwrap();
        assert!(result.path.to_string_lossy().contains("fr-FR"));
        assert!(result.path.to_string_lossy().contains("qr-code"));
    }

    #[tokio::test]
    async fn test_write_invalid_json_rejected() {
        let (writer, _temp) = test_writer();
        let request = WriteRequest::new("task1", "output.json")
            .with_content("{ invalid json }")
            .with_format(OutputFormat::Json);
        let result = writer.write(request).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        if let NikaError::ArtifactWriteError { reason, .. } = err {
            assert!(reason.contains("Invalid JSON"));
        } else {
            panic!("Expected ArtifactWriteError");
        }
        // Invalid JSON is rejected before anything is written.
        assert!(!writer.artifact_dir().join("output.json").exists());
    }

    #[tokio::test]
    async fn test_write_valid_json_accepted() {
        let (writer, _temp) = test_writer();
        let request = WriteRequest::new("task1", "output.json")
            .with_content(r#"{"valid": true, "nested": {"key": 123}}"#)
            .with_format(OutputFormat::Json);
        let result = writer.write(request).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_write_empty_json_allowed() {
        // Empty content is exempt from JSON validation.
        let (writer, _temp) = test_writer();
        let request = WriteRequest::new("task1", "empty.json").with_format(OutputFormat::Json);
        let result = writer.write(request).await.unwrap();
        assert_eq!(result.size, 0);
    }

    #[tokio::test]
    async fn test_write_var_path_traversal_blocked() {
        let (writer, _temp) = test_writer();
        let request = WriteRequest::new("task1", "{{entity}}/output.json")
            .with_content("{}")
            .with_var("entity", "../escape");
        let result = writer.write(request).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, NikaError::TemplateError { .. }));
    }

    #[test]
    fn test_validate_path() {
        let (writer, _temp) = test_writer();
        let result = writer.validate_path("task1", "output.json");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_path_traversal() {
        let (writer, _temp) = test_writer();
        let result = writer.validate_path("task1", "../escape.txt");
        assert!(result.is_err());
    }

    #[test]
    fn test_writer_max_size() {
        let temp = tempdir().unwrap();
        let default_writer = ArtifactWriter::new(temp.path(), "test");
        assert_eq!(default_writer.max_size, DEFAULT_MAX_SIZE);
        let writer = ArtifactWriter::new(temp.path(), "test").with_max_size(1024);
        assert_eq!(writer.max_size, 1024);
    }

    #[test]
    fn test_write_request_builder() {
        let mut vars = HashMap::new();
        vars.insert("key1".to_string(), "val1".to_string());
        vars.insert("key2".to_string(), "val2".to_string());
        let request = WriteRequest::new("task", "path.txt")
            .with_content("content")
            .with_format(OutputFormat::Json)
            .with_vars(vars);
        assert_eq!(request.task_id, "task");
        assert_eq!(request.output_path, "path.txt");
        assert_eq!(request.content, "content");
        assert!(matches!(request.format, OutputFormat::Json));
        assert_eq!(request.vars.len(), 2);
    }
}