use crate::config::Config;
use crate::error::{Error, Result};
use crate::store::dblock::DbLock;
use crate::store::s3::{parse_s3_bucket_and_key, s3_local_cache_dir_for_dsn, SyncDb, SyncState};
use crate::store::sqlite::SqliteTables;
use crate::store::{DbOpFuture, DbTables};
use async_trait::async_trait;
use object_store::aws::AmazonS3Builder;
use object_store::path::Path as ObjectPath;
use object_store::{ObjectStore, PutMode, UpdateVersion};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
/// SQLite-backed store whose database file is snapshotted to and from a single S3 object.
///
/// Reads and writes go to a local SQLite revision file in the DSN's cache directory.
/// `snapshot` downloads the remote object into a new local revision, and `sync` uploads
/// local changes with an ETag-conditioned put. Clones share `inner` and `write_gate`
/// (both are `Arc`s), so every clone observes the same sync state.
#[derive(Clone)]
pub struct SnapshotDb {
    /// Mutable sync/database state, shared by all clones.
    inner: Arc<RwLock<SnapshotState>>,
    /// Mutex exposed to the crate via `write_gate()`; it is not acquired anywhere in this
    /// file. NOTE(review): presumably callers hold it to serialize writes with `sync` —
    /// confirm against call sites.
    write_gate: Arc<Mutex<()>>,
    /// Copy of the caller's config with `sqlite.use_wal` forced to `false` by `new`.
    config: Config,
}
/// State guarded by `SnapshotDb::inner`.
struct SnapshotState {
    /// S3 client built by `build_object_store_from_env`.
    object_store: Arc<dyn ObjectStore>,
    /// Key of the snapshot object within the bucket, parsed from the DSN.
    object_key: String,
    /// ETag of the remote revision the local file was last downloaded from or uploaded as.
    /// It is `None` until the first successful `snapshot`/`sync`, and it also selects the
    /// local file name via `sqlite_path_for_revision`.
    last_etag: Option<String>,
    /// Set by every successful `with_write`; cleared by `snapshot` and `sync`.
    is_dirty: bool,
    /// Saturating count of successful `with_write` calls. It is compared before and after
    /// an upload to detect writes that raced with `sync`.
    write_version: u64,
    /// SQLite tables opened on the local file for the current `last_etag` revision.
    sqlite_tables: SqliteTables,
}
impl SnapshotDb {
pub async fn new(config: &Config) -> Result<Self> {
let (bucket, object_key) = parse_s3_bucket_and_key(&config.dsn)?;
let object_store = build_object_store_from_env(&bucket)?;
let mut config = config.clone();
config.sqlite.use_wal = false;
let sqlite_dsn = sqlite_dsn_for_revision(&config.dsn, None)?;
let sqlite_tables = SqliteTables::new(&sqlite_dsn, &config).await?;
Ok(Self {
inner: Arc::new(RwLock::new(SnapshotState {
object_store,
object_key,
last_etag: None,
is_dirty: false,
write_version: 0,
sqlite_tables,
})),
write_gate: Arc::new(Mutex::new(())),
config,
})
}
pub(crate) fn write_gate(&self) -> &Arc<Mutex<()>> {
&self.write_gate
}
pub async fn state(&self) -> Result<SyncState> {
let (object_store, object_key, local_etag, is_dirty) = {
let guard = self.inner.read().await;
(
guard.object_store.clone(),
guard.object_key.clone(),
guard.last_etag.clone(),
guard.is_dirty,
)
};
let sqlite_path = sqlite_path_for_revision(&self.config.dsn, local_etag.as_deref())?;
if !sqlite_path.exists() {
return Ok(SyncState::LocalMissing);
}
let object_path = ObjectPath::from(object_key.as_str());
let remote_head = object_store.head(&object_path).await;
match remote_head {
Ok(head) => {
let remote_etag = head.e_tag;
if remote_etag == local_etag && !is_dirty {
Ok(SyncState::InSync)
} else {
Ok(SyncState::Diverged {
local_dirty: is_dirty,
})
}
}
Err(object_store::Error::NotFound { .. }) => Ok(SyncState::RemoteMissing {
local_dirty: is_dirty,
}),
Err(e) => Err(map_object_store_error("head", &object_key, &e)),
}
}
}
/// Builds an S3-backed [`ObjectStore`] for `bucket` from the process environment.
///
/// Starts from `AmazonS3Builder::from_env()`, forces path-style requests (virtual-hosted
/// style disabled), then applies trimmed, non-empty overrides for `AWS_REGION` and
/// `AWS_ENDPOINT_URL`. Plain HTTP is allowed only when the endpoint URL uses the `http`
/// scheme (matched case-insensitively).
///
/// Side effect: if the endpoint does not look like a local emulator and `SSL_CERT_FILE`
/// is unset, `SSL_CERT_FILE` is pointed at the first CA bundle found in a few well-known
/// system locations.
///
/// # Errors
///
/// Returns [`Error::InvalidConfig`] (field `dsn`) if the builder rejects the configuration.
pub(crate) fn build_object_store_from_env(bucket: &str) -> Result<Arc<dyn ObjectStore>> {
    // Unset, non-UTF-8, and whitespace-only values are all treated as "not configured".
    let non_empty_env = |name: &str| {
        std::env::var(name)
            .ok()
            .map(|v| v.trim().to_string())
            .filter(|v| !v.is_empty())
    };
    let region = non_empty_env("AWS_REGION");
    let endpoint = non_empty_env("AWS_ENDPOINT_URL");

    // URL schemes and host names are case-insensitive (RFC 3986), so classify a
    // lowercased copy; previously `HTTP://...` endpoints never got `allow_http`.
    let endpoint_lc = endpoint.as_deref().map(str::to_ascii_lowercase);
    let endpoint_is_local = endpoint_lc.as_deref().is_some_and(|ep| {
        ep.contains("localhost") || ep.contains("127.0.0.1") || ep.contains("localstack")
    });
    let endpoint_is_http = endpoint_lc
        .as_deref()
        .is_some_and(|ep| ep.starts_with("http://"));

    if !endpoint_is_local && std::env::var_os("SSL_CERT_FILE").is_none() {
        const FALLBACK_CA_BUNDLES: [&str; 3] = [
            "/etc/ssl/certs/ca-certificates.crt",
            "/etc/pki/tls/certs/ca-bundle.crt",
            "/etc/ssl/cert.pem",
        ];
        if let Some(path) = FALLBACK_CA_BUNDLES
            .iter()
            .copied()
            .find(|path| std::path::Path::new(path).exists())
        {
            // NOTE(review): `set_var` mutates process-global state and is not thread-safe
            // on every platform (it is `unsafe` in the 2024 edition). This runs from an
            // async constructor, so other threads may read the environment concurrently.
            std::env::set_var("SSL_CERT_FILE", path);
        }
    }

    let mut builder = AmazonS3Builder::from_env()
        .with_bucket_name(bucket)
        .with_virtual_hosted_style_request(false);
    if let Some(region) = region {
        builder = builder.with_region(region);
    }
    if let Some(ep) = endpoint {
        if endpoint_is_http {
            builder = builder.with_allow_http(true);
        }
        builder = builder.with_endpoint(ep);
    }
    let store = builder.build().map_err(|e| Error::InvalidConfig {
        field: "dsn".to_string(),
        message: format!("Failed to build AmazonS3 object store: {}", e),
    })?;
    Ok(Arc::new(store))
}
#[async_trait]
impl DbLock for SnapshotDb {
    /// Returns the configuration captured by `SnapshotDb::new` (with WAL disabled).
    fn config(&self) -> &Config {
        &self.config
    }

    /// Always reports `ConcurrencyModel::SingleProcess` for this backend.
    fn concurrency_model(&self) -> crate::store::ConcurrencyModel {
        crate::store::ConcurrencyModel::SingleProcess
    }

    /// Runs `f` against the SQLite tables while holding the shared (read) side of `inner`.
    async fn with_read<R, F>(&self, f: F) -> Result<R>
    where
        R: Send,
        F: for<'a> FnOnce(&'a dyn DbTables) -> DbOpFuture<'a, R> + Send,
    {
        let state = self.inner.read().await;
        f(&state.sqlite_tables).await
    }

    /// Runs `f` against the SQLite tables while holding the exclusive (write) side of
    /// `inner`.
    ///
    /// On success the store is marked dirty and its write version is bumped (saturating),
    /// so a later `sync` knows to upload. On error the flags are left untouched and the
    /// error is returned as-is.
    async fn with_write<R, F>(&self, f: F) -> Result<R>
    where
        R: Send,
        F: for<'a> FnOnce(&'a dyn DbTables) -> DbOpFuture<'a, R> + Send,
    {
        let mut state = self.inner.write().await;
        let outcome = f(&state.sqlite_tables).await;
        if outcome.is_ok() {
            state.write_version = state.write_version.saturating_add(1);
            state.is_dirty = true;
        }
        outcome
    }
}
#[async_trait]
impl SyncDb for SnapshotDb {
async fn snapshot(&mut self) -> Result<()> {
let (object_store, object_key, local_etag, is_dirty) = {
let guard = self.inner.read().await;
(
guard.object_store.clone(),
guard.object_key.clone(),
guard.last_etag.clone(),
guard.is_dirty,
)
};
if is_dirty {
return Err(Error::Conflict {
message: format!(
"Snapshot refused for key '{}': local store has unsynced writes",
object_key
),
});
}
let object_path = ObjectPath::from(object_key.as_str());
let head = object_store
.head(&object_path)
.await
.map_err(|e| map_object_store_error("head", &object_key, &e))?;
let remote_etag = head.e_tag.clone();
if remote_etag == local_etag {
return Ok(());
}
let remote = object_store
.get(&object_path)
.await
.map_err(|e| map_object_store_error("get", &object_key, &e))?;
let remote_bytes = remote.bytes().await.map_err(|e| Error::Internal {
message: format!("snapshot get bytes failed for key '{}': {}", object_key, e),
})?;
let sqlite_path = sqlite_path_for_revision(&self.config.dsn, remote_etag.as_deref())?;
std::fs::write(&sqlite_path, remote_bytes.as_ref()).map_err(|e| Error::Internal {
message: format!(
"Failed writing snapshot db {}: {}",
sqlite_path.display(),
e
),
})?;
let sqlite_dsn = sqlite_dsn_for_revision(&self.config.dsn, remote_etag.as_deref())?;
let new_tables = SqliteTables::new(&sqlite_dsn, &self.config)
.await
.map_err(|e| Error::Internal {
message: format!("snapshot reopen sqlite store failed: {e}"),
})?;
let mut guard = self.inner.write().await;
if guard.last_etag != local_etag {
return Err(Error::Conflict {
message: format!(
"Snapshot CAS failed for key '{}': local etag changed from {:?} to {:?}",
object_key, local_etag, guard.last_etag
),
});
}
guard.last_etag = remote_etag;
guard.is_dirty = false;
guard.sqlite_tables = new_tables;
Ok(())
}
async fn sync(&mut self) -> Result<()> {
let (object_store, object_key, start_etag, is_dirty, start_write_version) = {
let guard = self.inner.read().await;
(
guard.object_store.clone(),
guard.object_key.clone(),
guard.last_etag.clone(),
guard.is_dirty,
guard.write_version,
)
};
if !is_dirty {
return Ok(());
}
let sqlite_path = sqlite_path_for_revision(&self.config.dsn, start_etag.as_deref())?;
let payload = std::fs::read(&sqlite_path).map_err(|e| Error::Internal {
message: format!(
"Failed reading snapshot db {}: {}",
sqlite_path.display(),
e
),
})?;
let payload_for_next_revision = payload.clone();
let object_path = ObjectPath::from(object_key.as_str());
let put_result = match &start_etag {
Some(last_etag) => {
let version = UpdateVersion {
e_tag: Some(last_etag.clone()),
version: None,
};
match object_store
.put_opts(
&object_path,
payload.clone().into(),
PutMode::Update(version).into(),
)
.await
{
Ok(put_result) => put_result,
Err(object_store::Error::NotFound { .. }) => object_store
.put_opts(&object_path, payload.into(), PutMode::Create.into())
.await
.map_err(|e| map_object_store_error("put", &object_key, &e))?,
Err(e) => return Err(map_object_store_error("put", &object_key, &e)),
}
}
None => object_store
.put_opts(&object_path, payload.into(), PutMode::Create.into())
.await
.map_err(|e| map_object_store_error("put", &object_key, &e))?,
};
let next_etag = put_result.e_tag;
let next_path = sqlite_path_for_revision(&self.config.dsn, next_etag.as_deref())?;
let previous_path = sqlite_path_for_revision(&self.config.dsn, start_etag.as_deref())?;
if previous_path != next_path {
std::fs::write(&next_path, &payload_for_next_revision).map_err(|e| {
Error::Internal {
message: format!(
"Failed to materialize snapshot db {} after sync from {}: {}",
next_path.display(),
previous_path.display(),
e
),
}
})?;
}
let sqlite_dsn = sqlite_dsn_for_revision(&self.config.dsn, next_etag.as_deref())?;
let new_tables = SqliteTables::new(&sqlite_dsn, &self.config)
.await
.map_err(|e| Error::Internal {
message: format!("sync reopen sqlite store failed: {e}"),
})?;
let mut guard = self.inner.write().await;
if guard.last_etag != start_etag
|| guard.write_version != start_write_version
|| !guard.is_dirty
{
return Err(Error::Conflict {
message: format!(
"Sync CAS failed for key '{}': local state changed during sync",
object_key
),
});
}
guard.last_etag = next_etag;
guard.is_dirty = false;
guard.sqlite_tables = new_tables;
Ok(())
}
}
fn map_object_store_error(operation: &str, key: &str, err: &object_store::Error) -> Error {
match err {
object_store::Error::NotFound { .. } => Error::NotFound {
entity: "object".to_string(),
id: key.to_string(),
},
object_store::Error::Precondition { .. } | object_store::Error::AlreadyExists { .. } => {
Error::Conflict {
message: format!(
"ObjectStore {} conflict for key '{}': {}",
operation, key, err
),
}
}
object_store::Error::Generic { source, .. } => {
let msg = source.to_string().to_ascii_lowercase();
if msg.contains("412")
|| msg.contains("precondition")
|| msg.contains("if-match")
|| msg.contains("condition")
{
Error::Conflict {
message: format!(
"ObjectStore {} conflict for key '{}': {}",
operation, key, source
),
}
} else if msg.contains("timeout")
|| msg.contains("timed out")
|| msg.contains("connection refused")
{
Error::Timeout {
operation: format!("object_store:{} key={}", operation, key),
}
} else {
Error::Internal {
message: format!(
"ObjectStore {} failed for key '{}': {}",
operation, key, source
),
}
}
}
_ => Error::Internal {
message: format!(
"ObjectStore {} failed for key '{}': {}",
operation, key, err
),
},
}
}
/// Builds the SQLite connection URL for the revision file chosen by
/// `sqlite_path_for_revision`, in the form `sqlite://<path>?mode=rwc`.
///
/// `mode=rwc` is the SQLite URI open mode for read-write, creating the file if missing.
/// NOTE(review): the path is embedded without escaping, so a cache directory containing
/// `?` or `#` would yield a malformed URL.
fn sqlite_dsn_for_revision(dsn: &str, etag: Option<&str>) -> Result<String> {
    sqlite_path_for_revision(dsn, etag)
        .map(|revision_file| format!("sqlite://{}?mode=rwc", revision_file.display()))
}
/// Returns the local cache file that holds the database revision identified by `etag`.
///
/// Files live in the DSN's local cache directory (from `s3_local_cache_dir_for_dsn`).
/// A present, non-blank ETag maps to `<sanitized-etag>.sqlite`. A missing or
/// whitespace-only ETag maps to `bootstrap.sqlite`, the file opened before any remote
/// revision is known.
fn sqlite_path_for_revision(dsn: &str, etag: Option<&str>) -> Result<PathBuf> {
    let cache_dir = s3_local_cache_dir_for_dsn(dsn)?;
    let file_name = etag
        .filter(|tag| !tag.trim().is_empty())
        .map_or_else(
            || "bootstrap.sqlite".to_string(),
            |tag| format!("{}.sqlite", sanitize_filename_component(tag)),
        );
    Ok(cache_dir.join(file_name))
}
/// Makes `input` safe to use as a single file-name component.
///
/// ASCII letters, digits, `-`, `_` and `.` are kept. Every other character (including
/// path separators and non-ASCII characters) is replaced by one `_`. An empty input
/// yields `"_"`, so the result is never empty.
fn sanitize_filename_component(input: &str) -> String {
    if input.is_empty() {
        return "_".to_string();
    }
    input
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_' | '.') {
                ch
            } else {
                '_'
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    //! Unit tests for DSN parsing and local revision-file naming.
    use super::{
        parse_s3_bucket_and_key, sanitize_filename_component, sqlite_dsn_for_revision,
        sqlite_path_for_revision,
    };

    #[test]
    fn parse_s3_bucket_and_key_accepts_valid_s3_url() {
        let (bucket, key) = parse_s3_bucket_and_key("s3://my-bucket/path/to/queue.db").unwrap();
        assert_eq!(bucket, "my-bucket");
        assert_eq!(key, "path/to/queue.db");
    }

    #[test]
    fn parse_s3_bucket_and_key_rejects_missing_bucket() {
        let err = parse_s3_bucket_and_key("s3:///queue.db").unwrap_err();
        assert!(err.to_string().contains("missing bucket"));
    }

    #[test]
    fn parse_s3_bucket_and_key_rejects_missing_key() {
        let err = parse_s3_bucket_and_key("s3://my-bucket").unwrap_err();
        assert!(err.to_string().contains("missing object key"));
    }

    #[test]
    fn parse_s3_bucket_and_key_rejects_wrong_scheme() {
        let err = parse_s3_bucket_and_key("sqlite://my-bucket/queue.db").unwrap_err();
        assert!(err.to_string().contains("Invalid S3 DSN"));
    }

    #[test]
    fn sqlite_revision_paths_use_bootstrap_before_first_sync() {
        let path = sqlite_path_for_revision("s3://bucket/queue.db", None).unwrap();
        assert_eq!(
            path.file_name().unwrap().to_string_lossy(),
            "bootstrap.sqlite"
        );
        let dsn = sqlite_dsn_for_revision("s3://bucket/queue.db", None).unwrap();
        assert!(dsn.ends_with("bootstrap.sqlite?mode=rwc"));
    }

    /// A whitespace-only ETag must not produce a `_.sqlite`-style file; it falls back to
    /// the bootstrap file just like a missing ETag.
    #[test]
    fn sqlite_revision_paths_treat_blank_etag_as_bootstrap() {
        let path = sqlite_path_for_revision("s3://bucket/queue.db", Some("   ")).unwrap();
        assert_eq!(
            path.file_name().unwrap().to_string_lossy(),
            "bootstrap.sqlite"
        );
    }

    #[test]
    fn sqlite_revision_paths_use_sanitized_etag_filename() {
        let path =
            sqlite_path_for_revision("s3://bucket/nested/queue.db", Some("\"etag:1/2\"")).unwrap();
        assert_eq!(
            path.file_name().unwrap().to_string_lossy(),
            "_etag_1_2_.sqlite"
        );
    }

    #[test]
    fn sanitize_filename_component_replaces_non_filename_chars() {
        assert_eq!(sanitize_filename_component("a/b:c"), "a_b_c");
        assert_eq!(sanitize_filename_component(""), "_");
    }

    /// Allowed ASCII characters pass through unchanged; each non-ASCII character becomes
    /// exactly one `_`.
    #[test]
    fn sanitize_filename_component_keeps_allowed_and_replaces_unicode() {
        assert_eq!(sanitize_filename_component("Ab-9_z.x"), "Ab-9_z.x");
        assert_eq!(sanitize_filename_component("é ü"), "___");
    }
}