use async_trait::async_trait;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::Client;
use aws_smithy_runtime_api::client::result::SdkError;
use serde::{Deserialize, Serialize};
use crate::types::{
AnalyticsQuery, ConversationManifest, SegmentHash, StoredSegment,
};
use super::backend::{StorageBackend, StorageError, StorageResult};
/// Maximum number of read-modify-write attempts `cas_update_meta` makes on a
/// segment's `.meta` sidecar before giving up with a backend error.
pub const MAX_CAS_RETRIES: u32 = 8;
/// JSON document stored next to each segment blob at `<prefix>segments/<hash>.meta`.
///
/// It carries every `StoredSegment` field except the hash and the compressed
/// payload, so a segment can be rebuilt from the blob plus this sidecar.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SegmentMetaSidecar {
/// String form of the segment type (written via `to_string`, read back via `parse`).
segment_type: String,
/// Tokenizer identifier copied from the stored segment.
tokenizer: String,
/// Token count copied from the stored segment.
token_count: u32,
/// Raw size copied from the stored segment.
raw_size: u32,
/// Compressed size; rewritten by `replace_segment_data` with the new payload length.
compressed_size: u32,
/// Reference counter, mutated only through ETag-guarded compare-and-swap writes.
ref_count: u64,
/// Creation timestamp as an RFC 3339 string.
created_at: String,
}
/// `StorageBackend` implementation that keeps segments and manifests as S3 objects.
///
/// Object layout under `prefix`: `segments/<hash>` (payload),
/// `segments/<hash>.meta` (JSON sidecar) and `manifests/<id>.json`.
pub struct S3Backend {
/// S3 client used for every request.
client: Client,
/// Bucket that holds all objects.
bucket: String,
/// Key prefix; either empty or ending in `/` (normalised by `new`).
prefix: String,
}
impl S3Backend {
/// Builds a backend from an existing S3 client.
///
/// A non-empty `prefix` that does not already end in `/` gets one appended,
/// so key construction can concatenate it directly. An empty prefix is kept
/// empty.
pub fn new(
    client: Client,
    bucket: impl Into<String>,
    prefix: impl Into<String>,
) -> Self {
    let raw: String = prefix.into();
    let prefix = if raw.is_empty() || raw.ends_with('/') {
        raw
    } else {
        format!("{raw}/")
    };
    Self {
        client,
        bucket: bucket.into(),
        prefix,
    }
}
/// Builds a backend whose client is configured from the default AWS
/// configuration chain (`aws_config::load_defaults`).
///
/// # Errors
///
/// Currently never returns `Err`; the `Result` is part of the signature.
pub async fn from_env(
    bucket: impl Into<String>,
    prefix: impl Into<String>,
) -> Result<Self, StorageError> {
    let sdk_config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
    Ok(Self::new(Client::new(&sdk_config), bucket, prefix))
}
/// Builds a backend with an optional region override and an optional
/// custom endpoint.
///
/// When `endpoint` is given, path-style addressing is forced on the client
/// as well.
///
/// # Errors
///
/// Currently never returns `Err`; the `Result` is part of the signature.
pub async fn with_options(
    bucket: impl Into<String>,
    prefix: impl Into<String>,
    endpoint: Option<String>,
    region: Option<String>,
) -> Result<Self, StorageError> {
    use aws_sdk_s3::config::{Builder, Region};

    let loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
    let loader = match region {
        Some(name) => loader.region(Region::new(name)),
        None => loader,
    };
    let shared = loader.load().await;

    let builder = Builder::from(&shared);
    let builder = match endpoint {
        Some(url) => builder.endpoint_url(url).force_path_style(true),
        None => builder,
    };

    Ok(Self::new(Client::from_conf(builder.build()), bucket, prefix))
}
/// Returns the underlying S3 client.
pub fn client(&self) -> &Client {
&self.client
}
/// Returns the name of the bucket this backend stores objects in.
pub fn bucket(&self) -> &str {
    self.bucket.as_str()
}
/// Returns the normalised key prefix (empty, or ending in `/`).
pub fn prefix(&self) -> &str {
    self.prefix.as_str()
}
fn segment_key(&self, hash: &SegmentHash) -> String {
format!("{}segments/{}", self.prefix, hash.0)
}
fn segment_meta_key(&self, hash: &SegmentHash) -> String {
format!("{}segments/{}.meta", self.prefix, hash.0)
}
/// Object key of a conversation manifest: `<prefix>manifests/<id>.json`.
fn manifest_key(&self, id: &str) -> String {
    [self.prefix.as_str(), "manifests/", id, ".json"].concat()
}
/// Downloads the object at `key` and returns its whole body.
///
/// # Errors
///
/// Request failures go through `map_get_error` (HTTP 404 becomes
/// `SegmentNotFound(key)`); a failure while reading the body becomes
/// `BackendError`.
async fn get_object_bytes(&self, key: &str) -> StorageResult<Vec<u8>> {
    let request = self.client.get_object().bucket(&self.bucket).key(key);
    let output = match request.send().await {
        Ok(output) => output,
        Err(err) => return Err(map_get_error(err, key)),
    };
    let aggregated = match output.body.collect().await {
        Ok(aggregated) => aggregated,
        Err(err) => return Err(StorageError::BackendError(err.to_string())),
    };
    Ok(aggregated.into_bytes().to_vec())
}
/// Fetches and decodes a segment's `.meta` sidecar together with the ETag
/// of the object that was read.
///
/// If the response carries no ETag, the returned string is empty.
///
/// # Errors
///
/// Request failures go through `map_get_error`, body read failures become
/// `BackendError`, and invalid JSON becomes `SerializationError`.
async fn read_meta_with_etag(
    &self,
    hash: &SegmentHash,
) -> StorageResult<(SegmentMetaSidecar, String)> {
    let meta_key = self.segment_meta_key(hash);
    let output = self
        .client
        .get_object()
        .bucket(&self.bucket)
        .key(&meta_key)
        .send()
        .await
        .map_err(|err| map_get_error(err, &meta_key))?;

    // Copy the ETag out before the body is moved out of the response.
    let etag = output.e_tag().map(str::to_owned).unwrap_or_default();

    let raw = output
        .body
        .collect()
        .await
        .map_err(|err| StorageError::BackendError(err.to_string()))?
        .into_bytes();
    let sidecar = serde_json::from_slice::<SegmentMetaSidecar>(&raw)
        .map_err(|err| StorageError::SerializationError(err.to_string()))?;

    Ok((sidecar, etag))
}
/// Overwrites a segment's `.meta` sidecar only if its current ETag equals
/// `expected_etag` (`If-Match` conditional put).
///
/// Returns `Ok(true)` when the write succeeded and `Ok(false)` when the
/// service answered HTTP 412, meaning the precondition failed.
///
/// # Errors
///
/// `SerializationError` if `meta` cannot be encoded, `BackendError` for any
/// other request failure.
async fn write_meta_if_match(
    &self,
    hash: &SegmentHash,
    meta: &SegmentMetaSidecar,
    expected_etag: &str,
) -> StorageResult<bool> {
    let payload = serde_json::to_vec(meta)
        .map_err(|err| StorageError::SerializationError(err.to_string()))?;
    let key = self.segment_meta_key(hash);

    let outcome = self
        .client
        .put_object()
        .bucket(&self.bucket)
        .key(&key)
        .if_match(expected_etag)
        .body(ByteStream::from(payload))
        .send()
        .await;

    let err = match outcome {
        Ok(_) => return Ok(true),
        Err(err) => err,
    };
    if is_precondition_failed(&err) {
        return Ok(false);
    }
    Err(StorageError::BackendError(format!(
        "put_object {} (If-Match): {err}",
        key
    )))
}
/// Creates a segment's `.meta` sidecar only if no object exists at that key
/// yet (`If-None-Match: *` conditional put).
///
/// Returns `Ok(true)` when the sidecar was created and `Ok(false)` when the
/// service answered HTTP 412, meaning the precondition failed.
///
/// # Errors
///
/// `SerializationError` if `meta` cannot be encoded, `BackendError` for any
/// other request failure.
async fn create_meta_if_absent(
    &self,
    hash: &SegmentHash,
    meta: &SegmentMetaSidecar,
) -> StorageResult<bool> {
    let payload = serde_json::to_vec(meta)
        .map_err(|err| StorageError::SerializationError(err.to_string()))?;
    let key = self.segment_meta_key(hash);

    let outcome = self
        .client
        .put_object()
        .bucket(&self.bucket)
        .key(&key)
        .if_none_match("*")
        .body(ByteStream::from(payload))
        .send()
        .await;

    let err = match outcome {
        Ok(_) => return Ok(true),
        Err(err) => err,
    };
    if is_precondition_failed(&err) {
        return Ok(false);
    }
    Err(StorageError::BackendError(format!(
        "put_object {} (If-None-Match): {err}",
        key
    )))
}
/// Applies `f` to a segment's `.meta` sidecar using optimistic concurrency.
///
/// Each attempt re-reads the sidecar and its ETag, applies `f` to that fresh
/// copy, and writes it back with `If-Match`. When the conditional write
/// loses a race, the loop backs off (10 ms doubling up to 320 ms) and tries
/// again, for at most `MAX_CAS_RETRIES` attempts. On success the sidecar as
/// written is returned.
///
/// `f` may be invoked once per attempt, always on a freshly read sidecar.
///
/// # Errors
///
/// Propagates read and write errors. Returns `BackendError` when the read
/// carried no ETag (a conditional write would be meaningless) or when every
/// attempt lost the race.
async fn cas_update_meta<F>(
    &self,
    hash: &SegmentHash,
    mut f: F,
) -> StorageResult<SegmentMetaSidecar>
where
    F: FnMut(&mut SegmentMetaSidecar),
{
    for attempt in 0..MAX_CAS_RETRIES {
        let (mut meta, etag) = self.read_meta_with_etag(hash).await?;
        // `read_meta_with_etag` yields "" when the response had no ETag. Sending
        // an empty `If-Match` cannot guard the write, so fail loudly instead.
        if etag.is_empty() {
            return Err(StorageError::BackendError(format!(
                "get_object {} (.meta): response carried no ETag, cannot compare-and-swap",
                hash.0
            )));
        }

        f(&mut meta);
        if self.write_meta_if_match(hash, &meta, &etag).await? {
            return Ok(meta);
        }

        // Lost the race. Back off before retrying, but do not waste time
        // sleeping when no further attempt will be made.
        if attempt + 1 < MAX_CAS_RETRIES {
            let backoff_ms = 10u64 << attempt.min(5);
            tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
        }
    }
    Err(StorageError::BackendError(format!(
        "CAS retry budget exhausted on {} (.meta)",
        hash.0
    )))
}
/// Reports whether an object exists at `key`, using a `HeadObject` request.
///
/// HTTP 404 maps to `Ok(false)`; any other failure is a `BackendError`.
async fn key_exists(&self, key: &str) -> StorageResult<bool> {
    let probe = self
        .client
        .head_object()
        .bucket(&self.bucket)
        .key(key)
        .send()
        .await;
    match probe {
        Ok(_) => Ok(true),
        Err(err) if is_not_found(&err) => Ok(false),
        Err(err) => Err(StorageError::BackendError(format!(
            "head_object {}: {err}",
            key
        ))),
    }
}
/// Issues a `DeleteObject` request for `key`.
///
/// # Errors
///
/// Any request failure is reported as `BackendError`.
async fn delete_key(&self, key: &str) -> StorageResult<()> {
    let outcome = self
        .client
        .delete_object()
        .bucket(&self.bucket)
        .key(key)
        .send()
        .await;
    match outcome {
        Ok(_) => Ok(()),
        Err(err) => Err(StorageError::BackendError(format!(
            "delete_object {key}: {err}"
        ))),
    }
}
/// Lists every object key that starts with `prefix`, following
/// `ListObjectsV2` continuation tokens until the listing is exhausted.
///
/// # Errors
///
/// Any failed page request aborts the listing with `BackendError`.
async fn list_keys_with_prefix(&self, prefix: &str) -> StorageResult<Vec<String>> {
    let mut keys = Vec::new();
    let mut next_token: Option<String> = None;
    loop {
        let request = self
            .client
            .list_objects_v2()
            .bucket(&self.bucket)
            .prefix(prefix);
        let request = match next_token.take() {
            Some(token) => request.continuation_token(token),
            None => request,
        };
        let page = request.send().await.map_err(|err| {
            StorageError::BackendError(format!("list_objects_v2 {prefix}: {err}"))
        })?;

        // Entries without a key are skipped.
        keys.extend(
            page.contents()
                .iter()
                .filter_map(|obj| obj.key().map(str::to_owned)),
        );

        match page.next_continuation_token() {
            Some(token) => next_token = Some(token.to_owned()),
            None => return Ok(keys),
        }
    }
}
}
#[async_trait]
impl StorageBackend for S3Backend {
/// Stores a segment: uploads the compressed payload, then creates the
/// `.meta` sidecar if it does not exist yet.
///
/// When a sidecar is already present, the existing reference count is
/// incremented instead and the existing sidecar fields are left as they are.
///
/// # Errors
///
/// Propagates upload, sidecar creation and reference-count update failures.
async fn put_segment(&self, segment: &StoredSegment) -> StorageResult<()> {
    let data_key = self.segment_key(&segment.hash);
    let upload = self
        .client
        .put_object()
        .bucket(&self.bucket)
        .key(&data_key)
        .body(ByteStream::from(segment.compressed_data.clone()))
        .send()
        .await;
    if let Err(err) = upload {
        return Err(StorageError::BackendError(format!(
            "put_object {data_key}: {err}"
        )));
    }

    let sidecar = SegmentMetaSidecar {
        segment_type: segment.segment_type.to_string(),
        tokenizer: segment.tokenizer.clone(),
        token_count: segment.token_count,
        raw_size: segment.raw_size,
        compressed_size: segment.compressed_size,
        ref_count: segment.ref_count,
        created_at: segment.created_at.to_rfc3339(),
    };

    if self.create_meta_if_absent(&segment.hash, &sidecar).await? {
        Ok(())
    } else {
        // The segment was already known: record one more reference to it.
        self.increment_ref(&segment.hash).await
    }
}
/// Loads a segment by combining its payload object with its `.meta` sidecar.
///
/// Decoding of the sidecar is lenient: an unparseable `segment_type` falls
/// back to `SegmentType::UserTurn`, and an unparseable `created_at` falls
/// back to the current time.
///
/// # Errors
///
/// Returns `SegmentNotFound` carrying the segment hash when the payload
/// object does not exist. Errors from reading the sidecar are propagated
/// unchanged.
async fn get_segment(&self, hash: &SegmentHash) -> StorageResult<StoredSegment> {
    let seg_bytes = self
        .get_object_bytes(&self.segment_key(hash))
        .await
        .map_err(|e| match e {
            // `map_get_error` reports a 404 as `SegmentNotFound(<object key>)`;
            // rewrite it so callers receive the hash they asked for.
            StorageError::SegmentNotFound(_) | StorageError::ConversationNotFound(_) => {
                StorageError::SegmentNotFound(hash.0.clone())
            }
            other => other,
        })?;
    let (meta, _etag) = self.read_meta_with_etag(hash).await?;
    let segment_type = meta
        .segment_type
        .parse()
        .unwrap_or(crate::types::SegmentType::UserTurn);
    let created_at = chrono::DateTime::parse_from_rfc3339(&meta.created_at)
        .map(|dt| dt.with_timezone(&chrono::Utc))
        .unwrap_or_else(|_| chrono::Utc::now());
    Ok(StoredSegment {
        hash: hash.clone(),
        segment_type,
        tokenizer: meta.tokenizer,
        token_count: meta.token_count,
        compressed_data: seg_bytes,
        raw_size: meta.raw_size,
        compressed_size: meta.compressed_size,
        ref_count: meta.ref_count,
        created_at,
    })
}
/// Reports whether the payload object for `hash` exists. The `.meta`
/// sidecar is not consulted.
async fn has_segment(&self, hash: &SegmentHash) -> StorageResult<bool> {
    let data_key = self.segment_key(hash);
    self.key_exists(&data_key).await
}
/// Adds one to the segment's reference count via a compare-and-swap update
/// of its `.meta` sidecar.
async fn increment_ref(&self, hash: &SegmentHash) -> StorageResult<()> {
    self.cas_update_meta(hash, |meta| meta.ref_count += 1)
        .await
        .map(|_updated| ())
}
/// Subtracts one from the segment's reference count (saturating at zero)
/// via a compare-and-swap update of its `.meta` sidecar.
///
/// Returns `true` when the count is zero after the update.
async fn decrement_ref(&self, hash: &SegmentHash) -> StorageResult<bool> {
    self.cas_update_meta(hash, |meta| {
        meta.ref_count = meta.ref_count.saturating_sub(1);
    })
    .await
    .map(|updated| updated.ref_count == 0)
}
/// Deletes a segment's payload object and then its `.meta` sidecar.
///
/// The payload is removed first, and the sidecar is only removed once that
/// succeeded. `list_garbage` discovers segments through their `.meta` keys,
/// so a sidecar that outlives a failed payload delete keeps the segment
/// visible to a later collection run instead of orphaning the blob.
///
/// # Errors
///
/// Returns the `BackendError` of the first delete request that fails.
async fn delete_segment(&self, hash: &SegmentHash) -> StorageResult<()> {
    self.delete_key(&self.segment_key(hash)).await?;
    self.delete_key(&self.segment_meta_key(hash)).await?;
    Ok(())
}
/// Replaces a segment's compressed payload and records the new size in its
/// `.meta` sidecar.
///
/// The payload upload and the sidecar update are two separate requests, so
/// they are not atomic with respect to each other.
///
/// # Errors
///
/// Returns `SegmentNotFound` when no payload object exists for `hash`,
/// `BackendError` when `new_data` is too large for the sidecar's `u32` size
/// field or when the upload fails, and propagates sidecar update errors.
async fn replace_segment_data(
    &self,
    hash: &SegmentHash,
    new_data: Vec<u8>,
) -> StorageResult<()> {
    if !self.has_segment(hash).await? {
        return Err(StorageError::SegmentNotFound(hash.0.clone()));
    }

    // Validate the size before uploading: a silent `as u32` truncation would
    // record a wrong `compressed_size` for payloads above 4 GiB.
    let new_size = u32::try_from(new_data.len()).map_err(|_| {
        StorageError::BackendError(format!(
            "replace_segment_data {}: payload of {} bytes does not fit the u32 size field",
            hash.0,
            new_data.len()
        ))
    })?;

    let key = self.segment_key(hash);
    self.client
        .put_object()
        .bucket(&self.bucket)
        .key(&key)
        // The length is already captured, so the buffer can be moved rather than cloned.
        .body(ByteStream::from(new_data))
        .send()
        .await
        .map_err(|e| StorageError::BackendError(format!("put_object {key}: {e}")))?;

    self.cas_update_meta(hash, |meta| meta.compressed_size = new_size)
        .await?;
    Ok(())
}
/// Serialises `manifest` to JSON and writes it to
/// `<prefix>manifests/<id>.json`, replacing any existing object.
///
/// # Errors
///
/// `SerializationError` if encoding fails, `BackendError` if the upload fails.
async fn put_manifest(&self, manifest: &ConversationManifest) -> StorageResult<()> {
    let payload = serde_json::to_vec(manifest)
        .map_err(|err| StorageError::SerializationError(err.to_string()))?;
    let key = self.manifest_key(&manifest.id);
    let upload = self
        .client
        .put_object()
        .bucket(&self.bucket)
        .key(&key)
        .body(ByteStream::from(payload))
        .send()
        .await;
    match upload {
        Ok(_) => Ok(()),
        Err(err) => Err(StorageError::BackendError(format!(
            "put_object {key}: {err}"
        ))),
    }
}
/// Loads and decodes the manifest of conversation `id`.
///
/// # Errors
///
/// A missing object is reported as `ConversationNotFound(id)`; invalid JSON
/// as `SerializationError`; other failures are passed through unchanged.
async fn get_manifest(&self, id: &str) -> StorageResult<ConversationManifest> {
    let key = self.manifest_key(id);
    let bytes = match self.get_object_bytes(&key).await {
        Ok(bytes) => bytes,
        // The generic getter reports a 404 as `SegmentNotFound`; translate it
        // into the conversation-level error.
        Err(StorageError::SegmentNotFound(_)) | Err(StorageError::ConversationNotFound(_)) => {
            return Err(StorageError::ConversationNotFound(id.to_owned()));
        }
        Err(other) => return Err(other),
    };
    serde_json::from_slice(&bytes)
        .map_err(|err| StorageError::SerializationError(err.to_string()))
}
/// Deletes the manifest object of conversation `id`.
///
/// # Errors
///
/// Returns the `BackendError` from the delete request instead of silently
/// reporting success, so callers can tell when the manifest is still there.
async fn delete_manifest(&self, id: &str) -> StorageResult<()> {
    self.delete_key(&self.manifest_key(id)).await
}
/// Lists conversation ids, optionally filtered by `query`, returning the
/// page selected by `offset` and `limit`.
///
/// Ids are taken from the keys under `<prefix>manifests/` in listing order.
/// Without any filter set, only the key listing is needed. With a filter,
/// each candidate manifest is fetched and checked against `model`,
/// `application`, `date_from` and `date_to` (both date bounds inclusive).
/// Fetching stops as soon as the requested page is full.
///
/// Manifests that cannot be loaded are skipped rather than failing the
/// whole listing.
///
/// # Errors
///
/// Returns `BackendError` when the key listing itself fails.
async fn list_conversations(
    &self,
    query: &AnalyticsQuery,
    limit: u64,
    offset: u64,
) -> StorageResult<Vec<String>> {
    let mprefix = format!("{}manifests/", self.prefix);
    let keys = self.list_keys_with_prefix(&mprefix).await?;
    let ids: Vec<String> = keys
        .into_iter()
        .filter_map(|k| {
            let id = k.strip_prefix(mprefix.as_str())?.strip_suffix(".json")?;
            Some(id.to_owned())
        })
        .collect();

    let skip = offset as usize;
    let take = limit as usize;

    let filtering = query.model.is_some()
        || query.application.is_some()
        || query.date_from.is_some()
        || query.date_to.is_some();
    if !filtering {
        return Ok(ids.into_iter().skip(skip).take(take).collect());
    }

    let mut page = Vec::new();
    if take == 0 {
        // Nothing can be returned, so do not fetch any manifest.
        return Ok(page);
    }

    let mut skipped = 0usize;
    for id in ids {
        let m = match self.get_manifest(&id).await {
            Ok(m) => m,
            // Best effort: an unreadable manifest is left out of the result.
            Err(_) => continue,
        };
        if let Some(want) = &query.model {
            if &m.model != want {
                continue;
            }
        }
        if let Some(want) = &query.application {
            if m.application.as_deref() != Some(want.as_str()) {
                continue;
            }
        }
        if let Some(from) = query.date_from {
            if m.created_at < from {
                continue;
            }
        }
        if let Some(to) = query.date_to {
            if m.created_at > to {
                continue;
            }
        }

        // The first `skip` matches belong to earlier pages.
        if skipped < skip {
            skipped += 1;
            continue;
        }
        page.push(id);
        if page.len() == take {
            // Page is full; remaining manifests need not be downloaded.
            break;
        }
    }
    Ok(page)
}
/// Returns the hashes of all segments whose `.meta` sidecar records a
/// reference count of zero.
///
/// Candidates are found by listing `<prefix>segments/` and keeping keys
/// that end in `.meta`. Sidecars that cannot be read are skipped.
///
/// # Errors
///
/// Returns `BackendError` when the key listing fails.
async fn list_garbage(&self) -> StorageResult<Vec<SegmentHash>> {
    let seg_prefix = format!("{}segments/", self.prefix);
    let keys = self.list_keys_with_prefix(&seg_prefix).await?;

    let mut garbage = Vec::new();
    for key in keys {
        let digest = key
            .strip_prefix(seg_prefix.as_str())
            .and_then(|rest| rest.strip_suffix(".meta"));
        let hash = match digest {
            Some(digest) => SegmentHash(digest.to_owned()),
            None => continue,
        };
        match self.read_meta_with_etag(&hash).await {
            Ok((meta, _)) if meta.ref_count == 0 => garbage.push(hash),
            _ => {}
        }
    }
    Ok(garbage)
}
/// Deletes every segment reported by `list_garbage` and returns how many
/// `delete_segment` calls completed.
///
/// # Errors
///
/// Stops at, and returns, the first error from listing or deleting.
async fn garbage_collect(&self) -> StorageResult<u64> {
    let doomed = self.list_garbage().await?;
    let mut removed: u64 = 0;
    for hash in &doomed {
        self.delete_segment(hash).await?;
        removed += 1;
    }
    Ok(removed)
}
/// Sums the sizes of all objects under `<prefix>segments/`, leaving out
/// `.meta` sidecars, across every `ListObjectsV2` page.
///
/// # Errors
///
/// Any failed page request aborts the computation with `BackendError`.
async fn storage_size_bytes(&self) -> StorageResult<u64> {
    let seg_prefix = format!("{}segments/", self.prefix);
    let mut bytes_total: u64 = 0;
    let mut next_token: Option<String> = None;
    loop {
        let request = self
            .client
            .list_objects_v2()
            .bucket(&self.bucket)
            .prefix(&seg_prefix);
        let request = match next_token.take() {
            Some(token) => request.continuation_token(token),
            None => request,
        };
        let page = request.send().await.map_err(|err| {
            StorageError::BackendError(format!("list_objects_v2 {seg_prefix}: {err}"))
        })?;

        for obj in page.contents() {
            // Only entries whose key is known to end in ".meta" are excluded.
            let is_sidecar = obj.key().map_or(false, |k| k.ends_with(".meta"));
            if is_sidecar {
                continue;
            }
            if let Some(size) = obj.size() {
                bytes_total += size as u64;
            }
        }

        match page.next_continuation_token() {
            Some(token) => next_token = Some(token.to_owned()),
            None => break,
        }
    }
    Ok(bytes_total)
}
}
/// Converts a failed `GetObject` call into a `StorageError`.
///
/// A service response with HTTP status 404 becomes `SegmentNotFound(key)`;
/// everything else becomes a `BackendError` that names the key.
fn map_get_error<E>(e: SdkError<E, aws_smithy_runtime_api::http::Response>, key: &str) -> StorageError
where
    E: std::fmt::Display + std::fmt::Debug,
{
    let missing = matches!(
        &e,
        SdkError::ServiceError(svc) if svc.raw().status().as_u16() == 404
    );
    if missing {
        StorageError::SegmentNotFound(key.to_owned())
    } else {
        StorageError::BackendError(format!("get_object {key}: {e}"))
    }
}
/// Returns `true` when `e` is a service response with HTTP status 412.
fn is_precondition_failed<E>(e: &SdkError<E, aws_smithy_runtime_api::http::Response>) -> bool
where
    E: std::fmt::Debug,
{
    match e {
        SdkError::ServiceError(svc) => svc.raw().status().as_u16() == 412,
        _ => false,
    }
}
/// Returns `true` when `e` is a service response with HTTP status 404.
fn is_not_found<E>(e: &SdkError<E, aws_smithy_runtime_api::http::Response>) -> bool
where
    E: std::fmt::Debug,
{
    if let SdkError::ServiceError(svc) = e {
        return svc.raw().status().as_u16() == 404;
    }
    false
}
#[cfg(test)]
mod tests {
    /// Local copy of the trailing-slash rule under test. The test below
    /// exercises this copy, not a backend constructor.
    fn normalise(raw: &str) -> String {
        let mut out = raw.to_owned();
        if !(out.is_empty() || out.ends_with('/')) {
            out.push('/');
        }
        out
    }

    #[test]
    fn prefix_normalisation_appends_slash() {
        let table = [
            ("", ""),
            ("foo", "foo/"),
            ("foo/", "foo/"),
            ("nested/path", "nested/path/"),
        ];
        for &(input, expected) in &table {
            assert_eq!(
                normalise(input),
                expected,
                "prefix normalisation for {input:?}"
            );
        }
    }
}