use crate::config::LayerStorageConfig;
use crate::error::{LayerStorageError, Result};
use crate::snapshot::{calculate_directory_digest, create_snapshot, extract_snapshot};
use crate::types::{ContainerLayerId, LayerSnapshot, PendingUpload, SyncState};
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart as S3CompletedPart};
use aws_sdk_s3::Client as S3Client;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use sqlx::SqlitePool;
use std::collections::HashMap;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use tokio::sync::RwLock;
use tracing::{debug, info, instrument, warn};
/// Synchronizes container overlay-layer snapshots with S3-backed object
/// storage, persisting sync progress in SQLite so interrupted multipart
/// uploads can be resumed after a restart.
pub struct LayerSyncManager {
// Storage paths, bucket/key layout, part size, and compression settings.
config: LayerStorageConfig,
// Client used for all S3 operations (multipart upload, get/put object).
s3_client: S3Client,
// SQLite pool backing the persisted `sync_state` table.
pool: SqlitePool,
// In-memory cache of per-container sync state, keyed by container key;
// kept in step with the database via `save_state`.
states: Arc<RwLock<HashMap<String, SyncState>>>,
}
impl LayerSyncManager {
/// Creates a new sync manager: prepares the local staging and state-db
/// directories, builds the S3 client (honoring an optional custom
/// endpoint/region), opens the SQLite state database, and loads any
/// previously persisted sync state into memory.
///
/// # Errors
/// Returns an error if directories cannot be created, the database cannot
/// be opened or migrated, or persisted state fails to deserialize.
pub async fn new(config: LayerStorageConfig) -> Result<Self> {
    tokio::fs::create_dir_all(&config.staging_dir).await?;
    if let Some(parent) = config.state_db_path.parent() {
        tokio::fs::create_dir_all(parent).await?;
    }
    let mut aws_config_builder = aws_config::from_env();
    if let Some(region) = &config.region {
        aws_config_builder =
            aws_config_builder.region(aws_sdk_s3::config::Region::new(region.clone()));
    }
    let aws_config = aws_config_builder.load().await;
    // Custom endpoints (e.g. MinIO, LocalStack) generally require
    // path-style addressing instead of virtual-hosted-style buckets.
    let s3_config = if let Some(endpoint) = &config.endpoint_url {
        aws_sdk_s3::config::Builder::from(&aws_config)
            .endpoint_url(endpoint)
            .force_path_style(true)
            .build()
    } else {
        aws_sdk_s3::config::Builder::from(&aws_config).build()
    };
    let s3_client = S3Client::from_conf(s3_config);
    // Build connect options from the path directly rather than formatting a
    // `sqlite:...?mode=rwc` URL: paths containing URL-reserved characters
    // (`?`, `#`, spaces) would be misparsed by the URL route, and
    // `create_if_missing(true)` already covers the `rwc` behavior.
    let connect_options = SqliteConnectOptions::new()
        .filename(&config.state_db_path)
        .create_if_missing(true);
    let pool = SqlitePoolOptions::new()
        .max_connections(5)
        .connect_with(connect_options)
        .await
        .map_err(|e| LayerStorageError::Database(e.to_string()))?;
    // WAL lets concurrent readers proceed while a write is in progress.
    sqlx::query("PRAGMA journal_mode=WAL")
        .execute(&pool)
        .await
        .map_err(|e| LayerStorageError::Database(e.to_string()))?;
    sqlx::query(
        r"
CREATE TABLE IF NOT EXISTS sync_state (
container_key TEXT PRIMARY KEY NOT NULL,
state_json TEXT NOT NULL,
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
)
",
    )
    .execute(&pool)
    .await
    .map_err(|e| LayerStorageError::Database(e.to_string()))?;
    let states = Arc::new(RwLock::new(Self::load_all_states(&pool).await?));
    Ok(Self {
        config,
        s3_client,
        pool,
        states,
    })
}
/// Loads every persisted `SyncState` row from the `sync_state` table into
/// a map keyed by container key, failing on the first row whose JSON no
/// longer deserializes.
async fn load_all_states(pool: &SqlitePool) -> Result<HashMap<String, SyncState>> {
    let rows: Vec<(String, String)> =
        sqlx::query_as("SELECT container_key, state_json FROM sync_state")
            .fetch_all(pool)
            .await
            .map_err(|e| LayerStorageError::Database(e.to_string()))?;
    // Fallible collect: short-circuits on the first deserialization error.
    rows.into_iter()
        .map(|(container_key, json)| {
            let state: SyncState = serde_json::from_str(&json)?;
            Ok((container_key, state))
        })
        .collect()
}
/// Upserts the JSON-serialized `state` into the `sync_state` table, keyed
/// by the container key, refreshing the `updated_at` timestamp.
async fn save_state(&self, state: &SyncState) -> Result<()> {
    let json = serde_json::to_string(state)?;
    let container_key = state.container_id.to_key();
    sqlx::query(
        r"
INSERT OR REPLACE INTO sync_state (container_key, state_json, updated_at)
VALUES (?, ?, CURRENT_TIMESTAMP)
",
    )
    .bind(&container_key)
    .bind(&json)
    .execute(&self.pool)
    .await
    .map(|_| ())
    .map_err(|e| LayerStorageError::Database(e.to_string()))
}
/// Registers a container for layer syncing. A no-op if the container is
/// already known; otherwise a fresh `SyncState` is persisted and cached.
#[instrument(skip(self))]
pub async fn register_container(&self, container_id: ContainerLayerId) -> Result<()> {
    let key = container_id.to_key();
    let mut states = self.states.write().await;
    // Guard clause: already registered, nothing to do.
    if states.contains_key(&key) {
        return Ok(());
    }
    let state = SyncState::new(container_id);
    // Persist first so the cache never holds a state the DB does not.
    self.save_state(&state).await?;
    states.insert(key, state);
    info!("Registered new container for layer sync");
    Ok(())
}
/// Returns `true` if the on-disk upper layer differs from the last locally
/// recorded digest for this container.
///
/// # Errors
/// `NotFound` if the container was never registered; otherwise any error
/// from the digest calculation.
#[instrument(skip(self, upper_layer_path))]
pub async fn check_for_changes(
    &self,
    container_id: &ContainerLayerId,
    upper_layer_path: impl AsRef<Path>,
) -> Result<bool> {
    let key = container_id.to_key();
    // Copy the recorded digest out and release the lock BEFORE hashing:
    // digesting a directory tree can be slow, and the original code held
    // the read guard for its whole duration, stalling all writers.
    let local_digest = {
        let states = self.states.read().await;
        states
            .get(&key)
            .ok_or_else(|| LayerStorageError::NotFound(key.clone()))?
            .local_digest
            .clone()
    };
    // NOTE(review): `calculate_directory_digest` is synchronous; assumed
    // cheap enough to run inline on the executor — confirm, or move to
    // `spawn_blocking` as `sync_layer` does for snapshot creation.
    let current_digest = calculate_directory_digest(upper_layer_path)?;
    Ok(local_digest.as_ref() != Some(&current_digest))
}
/// Synchronizes the container's upper layer with remote storage.
///
/// Resumes any pending (interrupted) upload first; otherwise digests the
/// layer, skips the transfer when the remote copy already matches, or else
/// packages the layer into a compressed tarball and uploads it.
///
/// Returns `Ok(None)` when nothing needed to be uploaded, or the uploaded
/// snapshot on a successful sync.
#[instrument(skip(self, upper_layer_path), fields(container = %container_id))]
pub async fn sync_layer(
    &self,
    container_id: &ContainerLayerId,
    upper_layer_path: impl AsRef<Path>,
) -> Result<Option<LayerSnapshot>> {
    let upper_layer_path = upper_layer_path.as_ref();
    let key = container_id.to_key();
    // A crash mid-upload leaves a pending record behind; finish it first.
    {
        let states = self.states.read().await;
        if let Some(state) = states.get(&key) {
            if let Some(pending) = &state.pending_upload {
                info!("Found pending upload, attempting to resume");
                return self.resume_upload(container_id, pending.clone()).await;
            }
        }
    }
    // Digest on the blocking pool: walking and hashing a directory tree is
    // synchronous work that would otherwise stall the async executor
    // (consistent with how `create_snapshot` is dispatched below).
    let current_digest = tokio::task::spawn_blocking({
        let source = upper_layer_path.to_path_buf();
        move || calculate_directory_digest(source)
    })
    .await
    .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;
    {
        let states = self.states.read().await;
        if let Some(state) = states.get(&key) {
            if state.remote_digest.as_ref() == Some(&current_digest) {
                debug!("Layer already synced, no changes");
                return Ok(None);
            }
        }
    }
    let tarball_path = self
        .config
        .staging_dir
        .join(format!("{current_digest}.tar.zst"));
    // Tar + zstd compression is CPU-bound; run it off the executor.
    let snapshot = tokio::task::spawn_blocking({
        let source = upper_layer_path.to_path_buf();
        let output = tarball_path.clone();
        let level = self.config.compression_level;
        move || create_snapshot(source, output, level)
    })
    .await
    .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;
    self.upload_snapshot(container_id, &tarball_path, &snapshot)
        .await?;
    {
        let mut states = self.states.write().await;
        if let Some(state) = states.get_mut(&key) {
            state.local_digest = Some(snapshot.digest.clone());
            state.remote_digest = Some(snapshot.digest.clone());
            state.last_sync = Some(chrono::Utc::now());
            state.pending_upload = None;
            self.save_state(state).await?;
        }
    }
    // Upload succeeded; the staged tarball is no longer needed.
    let _ = tokio::fs::remove_file(&tarball_path).await;
    Ok(Some(snapshot))
}
/// Uploads the compressed layer tarball at `tarball_path` to S3 via
/// multipart upload, then stores the snapshot's JSON metadata as a sibling
/// object so restores can recover it without re-reading the tarball.
///
/// The pending-upload record is persisted *before* any part is sent so a
/// crash mid-transfer can be resumed by `resume_upload`.
///
/// # Errors
/// Returns an `S3` error if any S3 call fails, or a database error if the
/// pending state cannot be persisted.
#[allow(clippy::cast_possible_wrap)]
#[instrument(skip(self, tarball_path, snapshot))]
async fn upload_snapshot(
    &self,
    container_id: &ContainerLayerId,
    tarball_path: &Path,
    snapshot: &LayerSnapshot,
) -> Result<()> {
    let object_key = self.config.object_key(&snapshot.digest);
    let file_size = tokio::fs::metadata(tarball_path).await?.len();
    let part_size = self.config.part_size_bytes;
    // `.max(1)`: a zero-byte tarball would otherwise yield zero parts and
    // an empty parts list, which CompleteMultipartUpload rejects. With one
    // (empty) part the upload still completes as an empty object.
    #[allow(clippy::cast_possible_truncation)]
    let total_parts = file_size.div_ceil(part_size).max(1) as u32;
    info!(
        "Uploading {} ({} bytes) in {} parts",
        object_key, file_size, total_parts
    );
    let create_response = self
        .s3_client
        .create_multipart_upload()
        .bucket(&self.config.bucket)
        .key(&object_key)
        .content_type("application/zstd")
        .send()
        .await
        .map_err(|e| LayerStorageError::S3(e.to_string()))?;
    let upload_id = create_response
        .upload_id()
        .ok_or_else(|| LayerStorageError::S3("No upload ID returned".to_string()))?
        .to_string();
    // Persist the pending upload before transferring any bytes so an
    // interruption at any point afterwards is resumable.
    let pending = PendingUpload {
        upload_id: upload_id.clone(),
        object_key: object_key.clone(),
        total_parts,
        completed_parts: HashMap::new(),
        part_size,
        local_tarball_path: tarball_path.to_path_buf(),
        started_at: chrono::Utc::now(),
        digest: snapshot.digest.clone(),
    };
    {
        let key = container_id.to_key();
        let mut states = self.states.write().await;
        if let Some(state) = states.get_mut(&key) {
            state.pending_upload = Some(pending.clone());
            self.save_state(state).await?;
        }
    }
    // NOTE(review): per-part progress is never written back into
    // `pending.completed_parts`, so `missing_parts()` on resume cannot see
    // locally completed parts — resume relies on S3's `list_parts` instead.
    // Confirm whether incremental persistence was intended here.
    let completed_parts = self
        .upload_parts(
            tarball_path,
            &upload_id,
            &object_key,
            total_parts,
            part_size,
        )
        .await?;
    let completed_upload = CompletedMultipartUpload::builder()
        .set_parts(Some(
            completed_parts
                .into_iter()
                .map(|(num, etag)| {
                    S3CompletedPart::builder()
                        .part_number(num as i32)
                        .e_tag(etag)
                        .build()
                })
                .collect(),
        ))
        .build();
    self.s3_client
        .complete_multipart_upload()
        .bucket(&self.config.bucket)
        .key(&object_key)
        .upload_id(&upload_id)
        .multipart_upload(completed_upload)
        .send()
        .await
        .map_err(|e| LayerStorageError::S3(e.to_string()))?;
    let metadata_key = self.config.metadata_key(&snapshot.digest);
    let metadata_json = serde_json::to_vec(snapshot)?;
    self.s3_client
        .put_object()
        .bucket(&self.config.bucket)
        .key(&metadata_key)
        .body(ByteStream::from(metadata_json))
        .content_type("application/json")
        .send()
        .await
        .map_err(|e| LayerStorageError::S3(e.to_string()))?;
    info!("Upload complete: {}", object_key);
    Ok(())
}
/// Uploads all `total_parts` parts of the tarball sequentially, returning
/// each part's number and the ETag S3 assigned to it.
///
/// # Errors
/// I/O errors reading the tarball, or `S3` errors (including a missing
/// ETag in an upload-part response).
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
async fn upload_parts(
    &self,
    tarball_path: &Path,
    upload_id: &str,
    object_key: &str,
    total_parts: u32,
    part_size: u64,
) -> Result<Vec<(u32, String)>> {
    let mut completed = Vec::with_capacity(total_parts as usize);
    // Open the file once and seek per part; the original reopened it on
    // every iteration.
    let mut file = File::open(tarball_path).await?;
    for part_number in 1..=total_parts {
        let offset = (u64::from(part_number) - 1) * part_size;
        file.seek(std::io::SeekFrom::Start(offset)).await?;
        let mut buffer = vec![0u8; part_size as usize];
        // BUGFIX: a single `read` may legitimately return fewer bytes than
        // requested even before EOF; the original would then silently
        // upload a truncated part, corrupting the assembled object. Loop
        // until the part buffer is full or EOF is reached.
        let mut filled = 0;
        loop {
            let n = file.read(&mut buffer[filled..]).await?;
            if n == 0 {
                break; // EOF — only expected on the final part
            }
            filled += n;
            if filled == buffer.len() {
                break;
            }
        }
        buffer.truncate(filled);
        let response = self
            .s3_client
            .upload_part()
            .bucket(&self.config.bucket)
            .key(object_key)
            .upload_id(upload_id)
            .part_number(part_number as i32)
            .body(ByteStream::from(buffer))
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
        let etag = response
            .e_tag()
            .ok_or_else(|| LayerStorageError::S3("No ETag returned for part".to_string()))?
            .to_string();
        debug!("Uploaded part {}/{}: {}", part_number, total_parts, etag);
        completed.push((part_number, etag));
    }
    Ok(completed)
}
/// Resumes a previously interrupted multipart upload described by
/// `pending`: re-uploads any missing parts, completes the multipart upload
/// from S3's own part listing, persists the updated sync state, and
/// returns the snapshot metadata for the uploaded layer.
///
/// # Errors
/// `UploadInterrupted` if the local tarball has disappeared (the pending
/// upload is aborted and cleared first); otherwise I/O or `S3` errors.
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
#[instrument(skip(self, pending))]
async fn resume_upload(
    &self,
    container_id: &ContainerLayerId,
    pending: PendingUpload,
) -> Result<Option<LayerSnapshot>> {
    let missing = pending.missing_parts();
    if missing.is_empty() {
        info!("All parts uploaded, completing multipart upload");
    } else {
        info!("Resuming upload, {} parts remaining", missing.len());
        if !pending.local_tarball_path.exists() {
            warn!("Local tarball missing, aborting upload and starting fresh");
            // Abort server-side and clear the pending record so the next
            // sync starts a fresh upload.
            self.abort_upload(&pending).await?;
            let key = container_id.to_key();
            let mut states = self.states.write().await;
            if let Some(state) = states.get_mut(&key) {
                state.pending_upload = None;
                self.save_state(state).await?;
            }
            return Err(LayerStorageError::UploadInterrupted(
                "Local tarball missing".to_string(),
            ));
        }
        // Open the tarball once and seek per part instead of reopening it
        // for every missing part.
        let mut file = File::open(&pending.local_tarball_path).await?;
        for part_number in missing {
            let offset = (u64::from(part_number) - 1) * pending.part_size;
            file.seek(std::io::SeekFrom::Start(offset)).await?;
            let mut buffer = vec![0u8; pending.part_size as usize];
            // BUGFIX: `read` may return short even mid-file; loop until
            // the part is full or EOF so a short read cannot silently
            // upload a truncated part.
            let mut filled = 0;
            loop {
                let n = file.read(&mut buffer[filled..]).await?;
                if n == 0 {
                    break; // EOF — only expected on the final part
                }
                filled += n;
                if filled == buffer.len() {
                    break;
                }
            }
            buffer.truncate(filled);
            let response = self
                .s3_client
                .upload_part()
                .bucket(&self.config.bucket)
                .key(&pending.object_key)
                .upload_id(&pending.upload_id)
                .part_number(part_number as i32)
                .body(ByteStream::from(buffer))
                .send()
                .await
                .map_err(|e| LayerStorageError::S3(e.to_string()))?;
            let etag = response
                .e_tag()
                .ok_or_else(|| LayerStorageError::S3("No ETag returned".to_string()))?
                .to_string();
            debug!("Uploaded part {}: {}", part_number, etag);
        }
    }
    // Trust S3's own view of the uploaded parts rather than local state.
    // NOTE(review): `list_parts` pages at 1000 parts per response; uploads
    // with more parts would need pagination here — confirm the configured
    // part size keeps uploads under that limit.
    let parts_response = self
        .s3_client
        .list_parts()
        .bucket(&self.config.bucket)
        .key(&pending.object_key)
        .upload_id(&pending.upload_id)
        .send()
        .await
        .map_err(|e| LayerStorageError::S3(e.to_string()))?;
    let completed_parts: Vec<S3CompletedPart> = parts_response
        .parts()
        .iter()
        .map(|p| {
            S3CompletedPart::builder()
                .part_number(p.part_number().unwrap_or(0))
                .e_tag(p.e_tag().unwrap_or_default())
                .build()
        })
        .collect();
    let completed_upload = CompletedMultipartUpload::builder()
        .set_parts(Some(completed_parts))
        .build();
    self.s3_client
        .complete_multipart_upload()
        .bucket(&self.config.bucket)
        .key(&pending.object_key)
        .upload_id(&pending.upload_id)
        .multipart_upload(completed_upload)
        .send()
        .await
        .map_err(|e| LayerStorageError::S3(e.to_string()))?;
    let key = container_id.to_key();
    {
        let mut states = self.states.write().await;
        if let Some(state) = states.get_mut(&key) {
            state.local_digest = Some(pending.digest.clone());
            state.remote_digest = Some(pending.digest.clone());
            state.last_sync = Some(chrono::Utc::now());
            state.pending_upload = None;
            self.save_state(state).await?;
        }
    }
    let _ = tokio::fs::remove_file(&pending.local_tarball_path).await;
    info!("Upload resumed and completed successfully");
    self.get_snapshot_metadata(&pending.digest).await.map(Some)
}
/// Aborts the in-flight multipart upload tracked by `pending`, letting S3
/// discard any parts it has already stored for it.
async fn abort_upload(&self, pending: &PendingUpload) -> Result<()> {
    let request = self
        .s3_client
        .abort_multipart_upload()
        .bucket(&self.config.bucket)
        .key(&pending.object_key)
        .upload_id(&pending.upload_id);
    request
        .send()
        .await
        .map(|_| ())
        .map_err(|e| LayerStorageError::S3(e.to_string()))
}
#[instrument(skip(self, target_path))]
pub async fn restore_layer(
&self,
container_id: &ContainerLayerId,
target_path: impl AsRef<Path>,
) -> Result<LayerSnapshot> {
let target_path = target_path.as_ref();
let key = container_id.to_key();
let remote_digest = {
let states = self.states.read().await;
states
.get(&key)
.and_then(|s| s.remote_digest.clone())
.ok_or_else(|| LayerStorageError::NotFound(format!("No remote layer for {key}")))?
};
info!("Restoring layer {} from S3", remote_digest);
let tarball_path = self
.config
.staging_dir
.join(format!("{remote_digest}.tar.zst"));
let object_key = self.config.object_key(&remote_digest);
let response = self
.s3_client
.get_object()
.bucket(&self.config.bucket)
.key(&object_key)
.send()
.await
.map_err(|e| LayerStorageError::S3(e.to_string()))?;
let mut file = tokio::fs::File::create(&tarball_path).await?;
let mut stream = response.body.into_async_read();
tokio::io::copy(&mut stream, &mut file).await?;
let snapshot = self.get_snapshot_metadata(&remote_digest).await?;
tokio::task::spawn_blocking({
let tarball = tarball_path.clone();
let target = target_path.to_path_buf();
let digest = remote_digest.clone();
move || extract_snapshot(tarball, target, Some(&digest))
})
.await
.map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;
{
let mut states = self.states.write().await;
if let Some(state) = states.get_mut(&key) {
state.local_digest = Some(remote_digest);
self.save_state(state).await?;
}
}
let _ = tokio::fs::remove_file(&tarball_path).await;
info!("Layer restored successfully");
Ok(snapshot)
}
/// Fetches and deserializes the `LayerSnapshot` metadata JSON stored
/// alongside the layer object identified by `digest`.
async fn get_snapshot_metadata(&self, digest: &str) -> Result<LayerSnapshot> {
    let metadata_key = self.config.metadata_key(digest);
    let object = self
        .s3_client
        .get_object()
        .bucket(&self.config.bucket)
        .key(&metadata_key)
        .send()
        .await
        .map_err(|e| LayerStorageError::S3(e.to_string()))?;
    let body = object
        .body
        .collect()
        .await
        .map_err(|e| LayerStorageError::S3(e.to_string()))?;
    let snapshot: LayerSnapshot = serde_json::from_slice(&body.into_bytes())?;
    Ok(snapshot)
}
/// Returns the identifiers of all containers currently tracked for sync.
pub async fn list_containers(&self) -> Vec<ContainerLayerId> {
    let guard = self.states.read().await;
    let mut ids = Vec::with_capacity(guard.len());
    for state in guard.values() {
        ids.push(state.container_id.clone());
    }
    ids
}
/// Returns a copy of the cached sync state for `container_id`, or `None`
/// if the container was never registered.
pub async fn get_sync_state(&self, container_id: &ContainerLayerId) -> Option<SyncState> {
    let key = container_id.to_key();
    self.states.read().await.get(&key).cloned()
}
}
use tokio::io::AsyncSeekExt;