use crate::error::{Error, Result};
use opendal::Operator;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Static AWS-style credentials from which `FileIO` builds per-bucket S3 operators.
///
/// Only `Clone` is derived; this type has no `Debug` implementation.
#[derive(Clone)]
pub struct AwsCredentials {
/// Access key identifier passed to the S3 builder.
pub access_key_id: String,
/// Secret access key passed to the S3 builder.
pub secret_access_key: String,
/// Optional session token; only applied to the S3 builder when present.
pub session_token: Option<String>,
}
/// S3 credentials handed out by a `VendedCredentialProvider`, together with the
/// connection details (endpoint, region) and an optional expiry.
///
/// `Debug` is implemented by hand so that the secret key and the session token
/// never end up in logs or error messages.
#[derive(Clone)]
pub struct VendedCredentials {
    /// Access key identifier.
    pub access_key_id: String,
    /// Secret access key (redacted in `Debug` output).
    pub secret_access_key: String,
    /// Optional session token (redacted in `Debug` output).
    pub session_token: Option<String>,
    /// S3 endpoint to use with these credentials, if the issuer supplied one.
    pub endpoint: Option<String>,
    /// S3 region to use with these credentials, if the issuer supplied one.
    pub region: Option<String>,
    /// Expiry instant in milliseconds since the Unix epoch; `None` means the
    /// credentials are never reported as expired.
    pub expires_at_ms: Option<i64>,
}

impl std::fmt::Debug for VendedCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VendedCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            // Keep the Some/None shape visible without exposing the token itself.
            .field(
                "session_token",
                &self.session_token.as_ref().map(|_| "<redacted>"),
            )
            .field("endpoint", &self.endpoint)
            .field("region", &self.region)
            .field("expires_at_ms", &self.expires_at_ms)
            .finish()
    }
}

impl VendedCredentials {
    /// Returns `true` when the credentials have expired or will expire within
    /// the next 60 seconds.
    ///
    /// Credentials without an expiry (`expires_at_ms == None`) are never
    /// considered expired. If the system clock reads earlier than the Unix
    /// epoch, the current time is treated as `0`.
    pub fn is_expired(&self) -> bool {
        /// Safety margin so callers refresh slightly before the real expiry.
        const EXPIRY_BUFFER_MS: i64 = 60_000;

        match self.expires_at_ms {
            Some(expires_at) => {
                let now_ms = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    // `as_millis` is a u128; clamp instead of truncating.
                    .map(|d| i64::try_from(d.as_millis()).unwrap_or(i64::MAX))
                    .unwrap_or(0);
                // Saturating: an expiry near `i64::MIN` must not overflow.
                now_ms >= expires_at.saturating_sub(EXPIRY_BUFFER_MS)
            }
            None => false,
        }
    }
}
/// Supplies [`VendedCredentials`] for storage paths.
///
/// `async_trait` is applied in its default form on non-wasm targets and with
/// `?Send` on wasm targets.
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
pub trait VendedCredentialProvider: Send + Sync + std::fmt::Debug {
/// Returns credentials for `path`. `FileIO` passes the full path it was asked
/// to access.
async fn get_credentials(&self, path: &str) -> Result<VendedCredentials>;
/// Optional S3 endpoint. `FileIO` falls back to this value when the returned
/// credentials carry no `endpoint` of their own.
fn s3_endpoint(&self) -> Option<&str>;
/// Hook invoked through `FileIO::register_table`.
///
/// The default implementation ignores its arguments and returns `Ok(())`.
/// NOTE(review): presumably lets a provider associate a table location with a
/// namespace/table name for later credential requests — confirm against the
/// implementors.
fn register_table(
&self,
_table_location: &str,
_namespace: &str,
_table_name: &str,
) -> Result<()> {
Ok(()) }
}
/// File access layer over OpenDAL operators.
///
/// An operator is chosen per call in this order: `default_operator`, then static
/// `credentials`, then `vended_credential_provider`. Clones share the same
/// operator cache because it sits behind an `Arc`.
#[derive(Clone)]
pub struct FileIO {
// Static credentials used to build per-bucket S3 operators.
credentials: Option<AwsCredentials>,
// Region handed to the S3 builder when operators are built from `credentials`.
default_region: String,
// Operators already built, keyed by bucket name.
operator_cache: Arc<RwLock<HashMap<String, Operator>>>,
// When set, this operator serves every path and the other fields are not consulted.
default_operator: Option<Operator>,
// Source of vended credentials, used when neither of the above is configured.
vended_credential_provider: Option<Arc<dyn VendedCredentialProvider>>,
}
impl FileIO {
/// Creates a `FileIO` that serves every path through the supplied `operator`.
///
/// No credentials or credential provider are configured and the default
/// region is left empty.
pub fn new(operator: Operator) -> Self {
    let empty_cache = Arc::new(RwLock::new(HashMap::new()));
    Self {
        default_operator: Some(operator),
        credentials: None,
        vended_credential_provider: None,
        default_region: String::new(),
        operator_cache: empty_cache,
    }
}
/// Creates a `FileIO` that builds one S3 operator per bucket from static
/// `credentials`, using `default_region` for each of them.
pub fn from_aws_credentials(credentials: AwsCredentials, default_region: String) -> Self {
    let empty_cache = Arc::new(RwLock::new(HashMap::new()));
    Self {
        credentials: Some(credentials),
        default_region,
        default_operator: None,
        vended_credential_provider: None,
        operator_cache: empty_cache,
    }
}
/// Creates a `FileIO` that obtains credentials on demand from `provider`.
///
/// The default region is set to `"auto"`; no static credentials and no fixed
/// operator are configured.
pub fn with_vended_credentials(provider: Arc<dyn VendedCredentialProvider>) -> Self {
    let empty_cache = Arc::new(RwLock::new(HashMap::new()));
    Self {
        vended_credential_provider: Some(provider),
        credentials: None,
        default_operator: None,
        default_region: String::from("auto"),
        operator_cache: empty_cache,
    }
}
/// Builds a `FileIO` bound to a single `bucket` from one set of vended credentials.
///
/// The S3 operator is created once and installed as the fixed operator (via
/// `FileIO::new`), so the resulting instance does not fetch new credentials.
/// The region falls back to `"auto"` when `creds.region` is `None`.
///
/// # Errors
/// - `Error::InvalidInput` if `creds.endpoint` is `None`.
/// - `Error::IoError` if the S3 operator cannot be constructed.
pub fn from_vended_credentials(creds: VendedCredentials, bucket: &str) -> Result<Self> {
    use opendal::services::S3;

    let endpoint = creds.endpoint.as_deref().ok_or_else(|| {
        Error::InvalidInput("Vended credentials missing endpoint".to_string())
    })?;
    let region = creds.region.as_deref().unwrap_or("auto");

    let mut builder = S3::default()
        .bucket(bucket)
        .region(region)
        .endpoint(endpoint)
        .access_key_id(&creds.access_key_id)
        .secret_access_key(&creds.secret_access_key);
    // The session token is optional; only set it when the issuer supplied one.
    if let Some(token) = creds.session_token.as_deref() {
        builder = builder.session_token(token);
    }

    let operator = Operator::new(builder)
        .map_err(|e| Error::IoError(format!("Failed to create S3 operator: {}", e)))?
        .finish();
    Ok(Self::new(operator))
}
fn extract_bucket_from_uri(&self, path: &str) -> Result<(String, String)> {
if let Some(stripped) = path.strip_prefix("s3://") {
if let Some(slash_pos) = stripped.find('/') {
let bucket = stripped[..slash_pos].to_string();
let path = stripped[slash_pos + 1..].to_string();
return Ok((bucket, path));
} else {
return Ok((stripped.to_string(), String::new()));
}
}
Err(Error::InvalidInput(format!(
"Path does not start with s3://: {}",
path
)))
}
async fn get_operator_for_path(&self, path: &str) -> Result<Operator> {
if let Some(ref op) = self.default_operator {
return Ok(op.clone());
}
if self.credentials.is_some() {
let (bucket, _) = self.extract_bucket_from_uri(path)?;
return self.get_or_create_operator(&bucket).await;
}
if let Some(ref provider) = self.vended_credential_provider {
let (bucket, _) = self.extract_bucket_from_uri(path)?;
{
let cache = self
.operator_cache
.read()
.map_err(|e| Error::IoError(format!(
"Lock poisoned due to panic in another thread. This indicates a critical bug. Original error: {}",
e
)))?;
if let Some(op) = cache.get(&bucket) {
return Ok(op.clone());
}
}
let creds = provider.get_credentials(path).await?;
let endpoint = creds
.endpoint
.clone()
.or_else(|| provider.s3_endpoint().map(|s| s.to_string()))
.ok_or_else(|| {
Error::InvalidInput(
"No S3 endpoint available for vended credentials".to_string(),
)
})?;
let region = creds.region.clone().unwrap_or_else(|| "auto".to_string());
use opendal::services::S3;
let mut builder = S3::default()
.bucket(&bucket)
.region(®ion)
.endpoint(&endpoint)
.access_key_id(&creds.access_key_id)
.secret_access_key(&creds.secret_access_key);
if let Some(ref token) = creds.session_token {
builder = builder.session_token(token);
}
let operator = Operator::new(builder)
.map_err(|e| Error::IoError(format!("Failed to create S3 operator: {}", e)))?
.finish();
let mut cache = self
.operator_cache
.write()
.map_err(|e| Error::IoError(format!(
"Lock poisoned due to panic in another thread. This indicates a critical bug. Original error: {}",
e
)))?;
cache.insert(bucket, operator.clone());
return Ok(operator);
}
Err(Error::InvalidInput(
"FileIO not configured with operator or credentials".to_string(),
))
}
async fn get_or_create_operator(&self, bucket: &str) -> Result<Operator> {
{
let cache = self
.operator_cache
.read()
.map_err(|e| Error::IoError(format!(
"Lock poisoned due to panic in another thread. This indicates a critical bug. Original error: {}",
e
)))?;
if let Some(op) = cache.get(bucket) {
return Ok(op.clone());
}
}
let mut cache = self
.operator_cache
.write()
.map_err(|e| Error::IoError(format!(
"Lock poisoned due to panic in another thread. This indicates a critical bug. Original error: {}",
e
)))?;
if let Some(op) = cache.get(bucket) {
return Ok(op.clone());
}
let op = self.create_s3_operator(bucket, &self.default_region)?;
cache.insert(bucket.to_string(), op.clone());
Ok(op)
}
fn create_s3_operator(&self, bucket: &str, region: &str) -> Result<Operator> {
use opendal::services::S3;
let credentials = self.credentials.as_ref().ok_or_else(|| {
Error::InvalidInput("No credentials available for S3 operator creation".to_string())
})?;
let builder = S3::default()
.bucket(bucket)
.region(region)
.access_key_id(&credentials.access_key_id)
.secret_access_key(&credentials.secret_access_key);
let builder = if let Some(ref token) = credentials.session_token {
builder.session_token(token)
} else {
builder
};
Ok(Operator::new(builder)
.map_err(|e| Error::IoError(format!("Failed to create S3 operator: {}", e)))?
.finish())
}
/// Converts a caller-supplied path into the key passed to the operator.
///
/// - `s3://bucket/key` becomes `key`.
/// - `s3://bucket` (no key part) becomes the empty string, matching what
///   `extract_bucket_from_uri` reports as the key for such a URI, instead of
///   passing the scheme and bucket through as if they were an object key.
/// - Any path without the `s3://` prefix is returned unchanged.
fn normalize_path<'a>(&self, path: &'a str) -> &'a str {
    match path.strip_prefix("s3://") {
        Some(rest) => rest.split_once('/').map_or("", |(_bucket, key)| key),
        None => path,
    }
}
pub async fn read(&self, path: &str) -> Result<Vec<u8>> {
let operator = self.get_operator_for_path(path).await?;
let normalized = self.normalize_path(path);
operator
.read(normalized)
.await
.map(|b| b.to_vec())
.map_err(|e| Error::IoError(format!("Failed to read {}: {}", path, e)))
}
pub async fn read_range(&self, path: &str, offset: u64, length: u64) -> Result<Vec<u8>> {
let operator = self.get_operator_for_path(path).await?;
let normalized = self.normalize_path(path);
let start = offset;
let end = offset.saturating_add(length);
let range = start..end;
operator
.read_with(normalized)
.range(range)
.await
.map(|b| b.to_vec())
.map_err(|e| Error::IoError(format!("Failed to read range of {}: {}", path, e)))
}
pub async fn file_size(&self, path: &str) -> Result<u64> {
let operator = self.get_operator_for_path(path).await?;
let normalized = self.normalize_path(path);
operator
.stat(normalized)
.await
.map(|m| m.content_length())
.map_err(|e| Error::IoError(format!("Failed to stat {}: {}", path, e)))
}
pub async fn write(&self, path: &str, data: Vec<u8>) -> Result<()> {
let operator = self.get_operator_for_path(path).await?;
let normalized = self.normalize_path(path);
operator
.write(normalized, data)
.await
.map(|_| ()) .map_err(|e| Error::IoError(format!("Failed to write {}: {}", path, e)))
}
pub async fn exists(&self, path: &str) -> Result<bool> {
let operator = self.get_operator_for_path(path).await?;
let normalized = self.normalize_path(path);
match operator.exists(normalized).await {
Ok(exists) => Ok(exists),
Err(e) => Err(Error::IoError(format!(
"Failed to check existence of {}: {}",
path, e
))),
}
}
pub async fn delete(&self, path: &str) -> Result<()> {
let operator = self.get_operator_for_path(path).await?;
let normalized = self.normalize_path(path);
operator
.delete(normalized)
.await
.map_err(|e| Error::IoError(format!("Failed to delete {}: {}", path, e)))
}
pub fn register_table(
&self,
table_location: &str,
namespace: &str,
table_name: &str,
) -> Result<()> {
if let Some(ref provider) = self.vended_credential_provider {
provider.register_table(table_location, namespace, table_name)
} else {
Ok(()) }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a `FileIO` can be built around an in-memory operator.
    #[tokio::test]
    async fn test_file_io_creation() {
        let memory_operator = Operator::via_iter(opendal::Scheme::Memory, []).unwrap();
        let _io = FileIO::new(memory_operator);
    }
}