use crate::error::{Error, Result};
use opendal::Operator;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// AWS credentials used to build per-bucket S3 operators.
///
/// `Debug` is implemented by hand so that the secret key and session token
/// never end up in logs or panic messages.
#[derive(Clone)]
pub struct AwsCredentials {
    /// AWS access key id.
    pub access_key_id: String,
    /// AWS secret access key.
    pub secret_access_key: String,
    /// Optional session token; when present it is passed to the S3 builder.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    /// Formats the credentials with the secret key and session token redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            // Keep the Some/None shape visible without exposing the token itself.
            .field(
                "session_token",
                &self.session_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
/// File access facade over `opendal` operators.
///
/// An instance is built either around a single pre-built operator (`new`) or
/// around AWS credentials (`from_aws_credentials`), in which case one S3
/// operator per bucket is created on demand and cached.
#[derive(Clone)]
pub struct FileIO {
/// Credentials used to create S3 operators; `None` when built via `new`.
credentials: Option<AwsCredentials>,
/// Region passed to every S3 operator created from `credentials`.
default_region: String,
/// Operators created so far, keyed by bucket name. Held in an `Arc`, so
/// clones of this `FileIO` share the same cache.
operator_cache: Arc<RwLock<HashMap<String, Operator>>>,
/// When set, this operator is used for every path regardless of its bucket.
default_operator: Option<Operator>,
}
impl FileIO {
/// Creates a `FileIO` that routes every operation through `operator`.
///
/// No credentials are stored and the region is left empty; the operator
/// cache starts out empty.
pub fn new(operator: Operator) -> Self {
    let operator_cache = Arc::new(RwLock::new(HashMap::new()));
    Self {
        default_operator: Some(operator),
        operator_cache,
        default_region: String::new(),
        credentials: None,
    }
}
/// Creates a `FileIO` that builds S3 operators on demand from `credentials`.
///
/// `default_region` is the region given to every operator created this way.
/// No default operator is set, so paths must be `s3://` URIs.
pub fn from_aws_credentials(credentials: AwsCredentials, default_region: String) -> Self {
    let operator_cache = Arc::new(RwLock::new(HashMap::new()));
    Self {
        default_operator: None,
        operator_cache,
        default_region,
        credentials: Some(credentials),
    }
}
fn extract_bucket_from_uri(&self, path: &str) -> Result<(String, String)> {
if let Some(stripped) = path.strip_prefix("s3://") {
if let Some(slash_pos) = stripped.find('/') {
let bucket = stripped[..slash_pos].to_string();
let path = stripped[slash_pos + 1..].to_string();
return Ok((bucket, path));
} else {
return Ok((stripped.to_string(), String::new()));
}
}
Err(Error::InvalidInput(format!(
"Path does not start with s3://: {}",
path
)))
}
async fn get_operator_for_path(&self, path: &str) -> Result<Operator> {
if let Some(ref op) = self.default_operator {
return Ok(op.clone());
}
if self.credentials.is_some() {
let (bucket, _) = self.extract_bucket_from_uri(path)?;
return self.get_or_create_operator(&bucket).await;
}
Err(Error::InvalidInput(
"FileIO not configured with operator or credentials".to_string(),
))
}
async fn get_or_create_operator(&self, bucket: &str) -> Result<Operator> {
{
let cache = self
.operator_cache
.read()
.map_err(|e| Error::IoError(format!("Failed to acquire read lock: {}", e)))?;
if let Some(op) = cache.get(bucket) {
return Ok(op.clone());
}
}
let mut cache = self
.operator_cache
.write()
.map_err(|e| Error::IoError(format!("Failed to acquire write lock: {}", e)))?;
if let Some(op) = cache.get(bucket) {
return Ok(op.clone());
}
let op = self.create_s3_operator(bucket, &self.default_region)?;
cache.insert(bucket.to_string(), op.clone());
Ok(op)
}
fn create_s3_operator(&self, bucket: &str, region: &str) -> Result<Operator> {
use opendal::services::S3;
let credentials = self.credentials.as_ref().ok_or_else(|| {
Error::InvalidInput("No credentials available for S3 operator creation".to_string())
})?;
let builder = S3::default()
.bucket(bucket)
.region(region)
.access_key_id(&credentials.access_key_id)
.secret_access_key(&credentials.secret_access_key);
let builder = if let Some(ref token) = credentials.session_token {
builder.session_token(token)
} else {
builder
};
Ok(Operator::new(builder)
.map_err(|e| Error::IoError(format!("Failed to create S3 operator: {}", e)))?
.finish())
}
/// Reduces an `s3://bucket/key` URI to its `key` part.
///
/// Anything else is returned unchanged: paths without the `s3://` prefix,
/// and `s3://bucket` URIs that have no `/` after the bucket name.
fn normalize_path<'a>(&self, path: &'a str) -> &'a str {
    match path.strip_prefix("s3://") {
        Some(rest) => match rest.split_once('/') {
            Some((_bucket, key)) => key,
            None => path,
        },
        None => path,
    }
}
pub async fn read(&self, path: &str) -> Result<Vec<u8>> {
let operator = self.get_operator_for_path(path).await?;
let normalized = self.normalize_path(path);
operator
.read(normalized)
.await
.map(|b| b.to_vec())
.map_err(|e| Error::IoError(format!("Failed to read {}: {}", path, e)))
}
pub async fn write(&self, path: &str, data: Vec<u8>) -> Result<()> {
let operator = self.get_operator_for_path(path).await?;
let normalized = self.normalize_path(path);
operator
.write(normalized, data)
.await
.map(|_| ()) .map_err(|e| Error::IoError(format!("Failed to write {}: {}", path, e)))
}
pub async fn exists(&self, path: &str) -> Result<bool> {
let operator = self.get_operator_for_path(path).await?;
let normalized = self.normalize_path(path);
match operator.exists(normalized).await {
Ok(exists) => Ok(exists),
Err(e) => Err(Error::IoError(format!(
"Failed to check existence of {}: {}",
path, e
))),
}
}
pub async fn delete(&self, path: &str) -> Result<()> {
let operator = self.get_operator_for_path(path).await?;
let normalized = self.normalize_path(path);
operator
.delete(normalized)
.await
.map_err(|e| Error::IoError(format!("Failed to delete {}: {}", path, e)))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `FileIO` backed by an in-memory operator.
    fn memory_file_io() -> FileIO {
        let op = Operator::via_iter(opendal::Scheme::Memory, []).unwrap();
        FileIO::new(op)
    }

    #[tokio::test]
    async fn test_file_io_creation() {
        let _file_io = memory_file_io();
    }

    /// Exercises write, read, exists and delete against the memory backend.
    #[tokio::test]
    async fn test_write_read_exists_delete_roundtrip() {
        let file_io = memory_file_io();
        let path = "dir/file.txt";

        assert_eq!(file_io.exists(path).await.ok(), Some(false));
        assert!(file_io.write(path, b"hello".to_vec()).await.is_ok());
        assert_eq!(file_io.exists(path).await.ok(), Some(true));
        assert_eq!(file_io.read(path).await.ok(), Some(b"hello".to_vec()));

        // With a default operator, an s3:// URI is reduced to its key, so it
        // addresses the same object.
        assert_eq!(
            file_io.read("s3://bucket/dir/file.txt").await.ok(),
            Some(b"hello".to_vec())
        );

        assert!(file_io.delete(path).await.is_ok());
        assert_eq!(file_io.exists(path).await.ok(), Some(false));
    }

    /// Checks the URI helpers on well-formed and non-S3 inputs.
    #[tokio::test]
    async fn test_s3_uri_helpers() {
        let file_io = memory_file_io();

        assert_eq!(
            file_io.extract_bucket_from_uri("s3://bucket/a/b").ok(),
            Some(("bucket".to_string(), "a/b".to_string()))
        );
        assert_eq!(
            file_io.extract_bucket_from_uri("s3://bucket").ok(),
            Some(("bucket".to_string(), String::new()))
        );
        assert!(file_io.extract_bucket_from_uri("/local/path").is_err());

        assert_eq!(file_io.normalize_path("s3://bucket/a/b"), "a/b");
        assert_eq!(file_io.normalize_path("a/b"), "a/b");
    }
}