use crate::config::ObjectStoreConfig;
use anyhow::{Context, Result};
use s3::creds::Credentials;
use s3::{Bucket, Region};
use std::path::Path;
use std::sync::Arc;
use tokio::fs;
use tracing::{debug, info};
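
/// Client for persisting model checkpoints in an S3-compatible object store.
///
/// Each checkpoint lives under the key prefix `{model}/images/`, one object
/// per file in the local images directory.
///
/// Illustrative usage (the model name and local path are placeholders):
///
/// ```ignore
/// let store = CheckpointStore::new(&config)?;
/// if store.checkpoint_exists("my-model").await? {
///     store.download_checkpoint("my-model", Path::new("images")).await?;
/// }
/// ```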
pub struct CheckpointStore {
bucket: Box<Bucket>,
}
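
/// Summary of a completed checkpoint upload.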
pub struct UploadStats {
pub files: usize,
pub bytes: u64,
pub duration_secs: f64,
}
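
/// Summary of a completed checkpoint download.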
pub struct DownloadStats {
pub files: usize,
pub bytes: u64,
pub duration_secs: f64,
}

impl CheckpointStore {
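    /// Builds a client for the configured endpoint, region, and bucket.
    ///
    /// Path-style addressing is enabled, which S3-compatible stores such as
    /// MinIO typically expect.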
pub fn new(config: &ObjectStoreConfig) -> Result<Self> {
let region = Region::Custom {
region: config.region.clone(),
endpoint: config.endpoint.clone(),
};
let credentials = Credentials::new(
Some(&config.access_key),
Some(&config.secret_key),
None,
None,
None,
)
.context("Failed to create S3 credentials")?;
        let bucket = Bucket::new(&config.bucket, region, credentials)
            .context("Failed to create S3 bucket client")?
            .with_path_style();
Ok(Self { bucket })
}
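
    /// Uploads every file in `images_dir` to `{model}/images/`, replacing any
    /// checkpoint already stored for the model.
    ///
    /// Existing objects under the prefix are deleted first; files are then
    /// uploaded concurrently, with files over 1 MiB capped at 4 in-flight
    /// uploads and smaller files at 32.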
pub async fn upload_checkpoint(&self, model: &str, images_dir: &Path) -> Result<UploadStats> {
let start = std::time::Instant::now();
let prefix = format!("{}/images/", model);
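        // List and delete whatever a previous upload left behind, so the new
        // checkpoint fully replaces the old one rather than mixing with it.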
let list = self
.bucket
            .list(prefix, None)
.await
.context("Failed to list existing S3 objects for cleanup")?;
let existing: Vec<String> = list
.iter()
.flat_map(|r| r.contents.iter())
.map(|o| o.key.clone())
.collect();
if !existing.is_empty() {
info!(model, files = existing.len(), "Cleaning up previous checkpoint in S3");
for key in &existing {
self.bucket
.delete_object(key)
.await
.with_context(|| format!("Failed to delete S3 object: {}", key))?;
}
}
let mut entries = fs::read_dir(images_dir)
.await
.with_context(|| format!("Failed to read checkpoint dir: {}", images_dir.display()))?;
let mut large_files = Vec::new();
let mut small_files = Vec::new();
while let Some(entry) = entries.next_entry().await? {
let metadata = entry.metadata().await?;
if !metadata.is_file() {
continue;
}
let size = metadata.len();
let path = entry.path();
if size > 1_048_576 {
large_files.push((path, size));
} else {
small_files.push((path, size));
}
}
let total_files = large_files.len() + small_files.len();
let total_bytes: u64 = large_files
.iter()
.chain(&small_files)
.map(|(_, size)| size)
.sum();
info!(
model,
files = total_files,
size_mb = total_bytes / 1_048_576,
"Uploading checkpoint to S3"
);
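        // Cap large-file uploads at 4 in flight; each task holds its permit
        // for the full read-and-put.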
let large_sem = Arc::new(tokio::sync::Semaphore::new(4));
let mut set = tokio::task::JoinSet::new();
for (path, size) in large_files {
            let key = format!(
                "{}/images/{}",
                model,
                path.file_name()
                    .expect("read_dir entry has a file name")
                    .to_string_lossy()
            );
let bucket = self.bucket.clone();
let sem = large_sem.clone();
set.spawn(async move {
let _permit = sem.acquire().await?;
let data = fs::read(&path).await?;
bucket.put_object(&key, &data).await?;
debug!(key, size, "Uploaded large file");
Ok::<_, anyhow::Error>(())
});
}
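        // Small files are cheap per request, so allow far more in flight.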
let small_sem = Arc::new(tokio::sync::Semaphore::new(32));
for (path, size) in small_files {
            let key = format!(
                "{}/images/{}",
                model,
                path.file_name()
                    .expect("read_dir entry has a file name")
                    .to_string_lossy()
            );
let bucket = self.bucket.clone();
let sem = small_sem.clone();
set.spawn(async move {
let _permit = sem.acquire().await?;
let data = fs::read(&path).await?;
bucket.put_object(&key, &data).await?;
debug!(key, size, "Uploaded small file");
Ok::<_, anyhow::Error>(())
});
}
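        // Drain the JoinSet: the first `?` surfaces panics or cancellation,
        // the second surfaces upload errors from the tasks themselves.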
while let Some(result) = set.join_next().await {
result??;
}
let duration_secs = start.elapsed().as_secs_f64();
info!(
model,
files = total_files,
size_mb = total_bytes / 1_048_576,
duration_secs = format!("{:.1}", duration_secs),
"Checkpoint uploaded to S3"
);
Ok(UploadStats {
files: total_files,
bytes: total_bytes,
duration_secs,
})
}
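
    /// Downloads the checkpoint for `model` into `images_dir`, creating the
    /// directory if needed, with up to 8 concurrent object fetches.
    ///
    /// Fails if no objects exist under `{model}/images/`.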
pub async fn download_checkpoint(
&self,
model: &str,
images_dir: &Path,
) -> Result<DownloadStats> {
let start = std::time::Instant::now();
fs::create_dir_all(images_dir)
.await
.with_context(|| format!("Failed to create images dir: {}", images_dir.display()))?;
let prefix = format!("{}/images/", model);
let list = self
.bucket
.list(prefix.clone(), None)
.await
.context("Failed to list S3 objects")?;
let objects: Vec<_> = list.iter().flat_map(|r| r.contents.iter()).collect();
if objects.is_empty() {
anyhow::bail!("No checkpoint found in S3 for model '{}'", model);
}
let total_files = objects.len();
let total_bytes: u64 = objects.iter().map(|o| o.size).sum();
info!(
model,
files = total_files,
size_mb = total_bytes / 1_048_576,
"Downloading checkpoint from S3"
);
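        // Fetch up to 8 objects concurrently.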
let semaphore = Arc::new(tokio::sync::Semaphore::new(8));
let mut set = tokio::task::JoinSet::new();
for obj in objects {
let key = obj.key.clone();
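            // Keys have the form `{model}/images/{file_name}`; strip the
            // prefix to recover the local file name.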
            let filename = key.strip_prefix(&prefix).unwrap_or(&key).to_string();
let dest = images_dir.join(&filename);
let bucket = self.bucket.clone();
let sem = semaphore.clone();
let size = obj.size;
set.spawn(async move {
let _permit = sem.acquire().await?;
let response = bucket.get_object(&key).await?;
fs::write(&dest, response.as_slice()).await?;
debug!(key, size, "Downloaded file");
Ok::<_, anyhow::Error>(())
});
}
while let Some(result) = set.join_next().await {
result??;
}
let duration_secs = start.elapsed().as_secs_f64();
info!(
model,
files = total_files,
size_mb = total_bytes / 1_048_576,
duration_secs = format!("{:.1}", duration_secs),
"Checkpoint downloaded from S3"
);
Ok(DownloadStats {
files: total_files,
bytes: total_bytes,
duration_secs,
})
}
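
    /// Returns `true` if at least one object is stored under `{model}/images/`.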
pub async fn checkpoint_exists(&self, model: &str) -> Result<bool> {
let prefix = format!("{}/images/", model);
let list = self
.bucket
.list(prefix, None)
.await
.context("Failed to list S3 objects")?;
Ok(list.iter().any(|r| !r.contents.is_empty()))
}
}