use std::collections::BTreeMap;
use std::path::PathBuf;
use sha2::{Digest, Sha256};
use crate::backend::SyncBackend;
use crate::conflict;
use crate::error::SyncError;
use crate::hash::{sha256_bytes, sha256_file};
use crate::merkle::MerkleTree;
use crate::state_db::{FileState, StateDb};
/// Derives a deterministic 16-byte KDF salt from the passphrase.
///
/// Hashes a fixed domain-separation prefix followed by the passphrase with
/// SHA-256 and truncates the digest to 16 bytes, so the same passphrase
/// always produces the same salt on every device (required for the derived
/// master key to match across machines).
fn deterministic_salt(passphrase: &[u8]) -> [u8; 16] {
    let digest = Sha256::new()
        .chain_update(b"saku-sync-kek-salt-v1:")
        .chain_update(passphrase)
        .finalize();
    let mut salt = [0u8; 16];
    salt.copy_from_slice(&digest[..16]);
    salt
}
/// A single local file under sync management.
#[derive(Debug, Clone)]
pub struct TrackedFile {
    // Stable identifier for state-db lookups; formatted as
    // "{tool}/{relative_path}" (see the pull loop in `SyncEngine::sync`).
    pub file_key: String,
    // Name of the owning tool; used to group Merkle-tree leaves per tool.
    pub tool: String,
    // Path of the file relative to the tool's root on the remote.
    pub relative_path: String,
    // Location of the file on the local filesystem.
    pub local_path: PathBuf,
}
/// Configuration for a [`SyncEngine`].
pub struct SyncConfig {
    // Location of the on-disk sync-state database; ignored when the engine
    // is constructed with `new_in_memory`.
    pub db_path: PathBuf,
    // Passphrase from which the master encryption key (and its
    // deterministic salt) is derived.
    pub passphrase: Vec<u8>,
    // The set of local files to keep in sync.
    pub tracked_files: Vec<TrackedFile>,
}
/// Result of a [`SyncEngine::sync`] run.
#[derive(Debug)]
pub enum SyncOutcome {
    // Nothing was done: the backend was unreachable, or local and remote
    // Merkle roots already matched with no local changes pending.
    Skipped,
    // A sync pass ran; counts of files pushed to / pulled from the remote.
    Completed { pushed: usize, pulled: usize },
}
/// Orchestrates encrypted file synchronization between the local filesystem
/// and a [`SyncBackend`], using a [`StateDb`] to remember per-file hashes
/// between runs.
pub struct SyncEngine<B: SyncBackend> {
    config: SyncConfig,
    backend: B,
    state_db: StateDb,
}
impl<B: SyncBackend> SyncEngine<B> {
    /// Creates an engine backed by an on-disk state database at
    /// `config.db_path`.
    ///
    /// # Errors
    /// Returns an error if the state database cannot be opened.
    pub fn new(config: SyncConfig, backend: B) -> Result<Self, SyncError> {
        let state_db = StateDb::open(&config.db_path)?;
        Ok(Self {
            config,
            backend,
            state_db,
        })
    }

    /// Creates an engine with an in-memory state database
    /// (`config.db_path` is not used). Primarily for tests.
    ///
    /// # Errors
    /// Returns an error if the in-memory database cannot be created.
    pub fn new_in_memory(config: SyncConfig, backend: B) -> Result<Self, SyncError> {
        let state_db = StateDb::open_in_memory()?;
        Ok(Self {
            config,
            backend,
            state_db,
        })
    }

    /// Runs one full sync pass.
    ///
    /// Steps, in order:
    /// 1. Return [`SyncOutcome::Skipped`] if the backend is unreachable.
    /// 2. Hash every tracked file and mark files whose hash differs from
    ///    the recorded state (or that have no record) as dirty.
    /// 3. Fetch the remote Merkle tree; if its root matches the local tree
    ///    and nothing is dirty, skip.
    /// 4. For each tool whose subtree differs, pull changed remote files:
    ///    `.json` files are structurally merged with local content, other
    ///    existing files get a conflict copy, and locally-missing files are
    ///    written directly.
    /// 5. Push every dirty file (encrypted) and record it as clean.
    /// 6. Rebuild the Merkle tree from final on-disk hashes and publish it.
    ///
    /// # Errors
    /// Propagates state-db, crypto, filesystem, and backend errors. A failed
    /// fetch of an individual remote file is skipped rather than fatal.
    pub fn sync(&mut self) -> Result<SyncOutcome, SyncError> {
        if !self.backend.is_reachable() {
            return Ok(SyncOutcome::Skipped);
        }
        // Deterministic salt: same passphrase => same master key everywhere.
        let salt = deterministic_salt(&self.config.passphrase);
        let master_key = saku_crypto::kdf::derive_master_key(&self.config.passphrase, &salt)?;
        // Pass 1: detect local changes since the last recorded state.
        let mut local_dirty_keys: Vec<String> = Vec::new();
        for tracked in &self.config.tracked_files {
            // A missing local file hashes as the empty byte string.
            let current_hash = if tracked.local_path.exists() {
                sha256_file(&tracked.local_path)?
            } else {
                sha256_bytes(b"")
            };
            let prev_state = self.state_db.get_file_state(&tracked.file_key)?;
            // Files with no prior record count as dirty so they are pushed
            // on first sync.
            let is_dirty = match &prev_state {
                Some(state) => state.local_hash != current_hash,
                None => true,
            };
            if is_dirty {
                local_dirty_keys.push(tracked.file_key.clone());
            }
            let now_ms = jiff::Timestamp::now().as_millisecond();
            self.state_db.upsert_file_state(&FileState {
                file_key: tracked.file_key.clone(),
                local_hash: current_hash,
                // Preserve the last-seen remote hash (empty if none yet).
                remote_hash: prev_state.map(|s| s.remote_hash).unwrap_or_default(),
                status: if is_dirty {
                    "dirty".to_string()
                } else {
                    "clean".to_string()
                },
                updated_at_ms: now_ms,
            })?;
        }
        // `None` means the remote has never been synced to.
        let remote_merkle_data = self.backend.fetch_merkle()?;
        let remote_merkle = match &remote_merkle_data {
            Some(data) => Some(MerkleTree::from_json(data)?),
            None => None,
        };
        // Build the local Merkle tree from current on-disk hashes, with
        // leaves grouped per tool.
        let mut pre_hashes: BTreeMap<String, Vec<(String, String)>> = BTreeMap::new();
        for tracked in &self.config.tracked_files {
            let hash = if tracked.local_path.exists() {
                sha256_file(&tracked.local_path)?
            } else {
                sha256_bytes(b"")
            };
            pre_hashes
                .entry(tracked.tool.clone())
                .or_default()
                .push((tracked.relative_path.clone(), hash));
        }
        let local_merkle = MerkleTree::build(pre_hashes);
        // Fast path: roots match and nothing is locally dirty.
        if let Some(ref rm) = remote_merkle
            && local_merkle.same_root(rm)
            && local_dirty_keys.is_empty()
        {
            return Ok(SyncOutcome::Skipped);
        }
        let mut pushed = 0;
        let mut pulled = 0;
        // Pass 2: pull remote changes for every tool whose subtree differs.
        if let Some(ref rm) = remote_merkle {
            let differing = local_merkle.differing_tools(rm);
            for tool_name in differing {
                if let Some(remote_tool) = rm.tools.iter().find(|t| t.tool == tool_name) {
                    for file_leaf in &remote_tool.files {
                        let file_key = format!("{}/{}", tool_name, file_leaf.path);
                        let local_state = self.state_db.get_file_state(&file_key)?;
                        // Pull only when the remote leaf hash differs from
                        // the remote hash we last recorded (or no record).
                        let needs_pull = match &local_state {
                            Some(state) => state.remote_hash != file_leaf.hash,
                            None => true,
                        };
                        if needs_pull {
                            // An unfetchable file is skipped, not fatal.
                            let encrypted = match self.backend.fetch(tool_name, &file_leaf.path) {
                                Ok(data) => data,
                                Err(_) => continue,
                            };
                            let decrypted = saku_crypto::decrypt(&encrypted, &master_key)?;
                            // Only write files we actually track locally.
                            if let Some(tracked) = self
                                .config
                                .tracked_files
                                .iter()
                                .find(|f| f.file_key == file_key)
                            {
                                if file_leaf.path.ends_with(".json") && tracked.local_path.exists()
                                {
                                    // JSON stores: structural three-way-ish
                                    // merge of local and remote content.
                                    let local_data = std::fs::read(&tracked.local_path)?;
                                    let local_json: serde_json::Value =
                                        serde_json::from_slice(&local_data)?;
                                    let remote_json: serde_json::Value =
                                        serde_json::from_slice(&decrypted)?;
                                    let merged =
                                        conflict::merge_store_json(&local_json, &remote_json);
                                    let merged_bytes = serde_json::to_vec_pretty(&merged)?;
                                    std::fs::write(&tracked.local_path, &merged_bytes)?;
                                } else if tracked.local_path.exists() {
                                    // Non-JSON file exists locally: keep the
                                    // local version and save the remote copy
                                    // alongside as a conflict file.
                                    let device_id = saku_storage::device::get_or_create_device_id()
                                        .unwrap_or_else(|_| "unknown".to_string());
                                    conflict::write_conflict_copy(
                                        &tracked.local_path,
                                        &decrypted,
                                        &device_id,
                                    )?;
                                } else {
                                    // No local copy: take the remote content.
                                    if let Some(parent) = tracked.local_path.parent() {
                                        std::fs::create_dir_all(parent)?;
                                    }
                                    std::fs::write(&tracked.local_path, &decrypted)?;
                                }
                                pulled += 1;
                                let now_ms = jiff::Timestamp::now().as_millisecond();
                                let new_hash = sha256_file(&tracked.local_path)?;
                                // Record the pulled remote hash but mark the
                                // file dirty so the (possibly merged or kept)
                                // local content is pushed back below.
                                self.state_db.upsert_file_state(&FileState {
                                    file_key: file_key.clone(),
                                    local_hash: new_hash.clone(),
                                    remote_hash: file_leaf.hash.clone(),
                                    status: "dirty".to_string(), updated_at_ms: now_ms,
                                })?;
                                if !local_dirty_keys.contains(&file_key) {
                                    local_dirty_keys.push(file_key.clone());
                                }
                            }
                        }
                    }
                }
            }
        }
        // Pass 3: push every dirty file, encrypted, and record it clean.
        for tracked in &self.config.tracked_files {
            if !local_dirty_keys.contains(&tracked.file_key) {
                continue;
            }
            if !tracked.local_path.exists() {
                continue;
            }
            let local_data = std::fs::read(&tracked.local_path)?;
            let encrypted = saku_crypto::encrypt(&local_data, &master_key, &salt)?;
            self.backend
                .push(&tracked.tool, &tracked.relative_path, &encrypted)?;
            pushed += 1;
            let now_ms = jiff::Timestamp::now().as_millisecond();
            let hash = sha256_bytes(&local_data);
            // Local and remote now agree, so both hashes are the same.
            self.state_db.upsert_file_state(&FileState {
                file_key: tracked.file_key.clone(),
                local_hash: hash.clone(),
                remote_hash: hash,
                status: "clean".to_string(),
                updated_at_ms: now_ms,
            })?;
        }
        // NOTE(review): presumably drops queued per-file operation records
        // now that these files are reconciled — confirm against StateDb.
        for key in &local_dirty_keys {
            self.state_db.clear_ops_for_file(key)?;
        }
        // Rebuild the Merkle tree from the final on-disk state (pulls may
        // have changed local files since `pre_hashes`) and publish it.
        let mut updated_hashes: BTreeMap<String, Vec<(String, String)>> = BTreeMap::new();
        for tracked in &self.config.tracked_files {
            let hash = if tracked.local_path.exists() {
                sha256_file(&tracked.local_path)?
            } else {
                sha256_bytes(b"")
            };
            updated_hashes
                .entry(tracked.tool.clone())
                .or_default()
                .push((tracked.relative_path.clone(), hash));
        }
        let final_merkle = MerkleTree::build(updated_hashes);
        let merkle_json = final_merkle.to_json()?;
        self.backend.push_merkle(&merkle_json)?;
        Ok(SyncOutcome::Completed { pushed, pulled })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::local_fs::LocalFsSyncBackend;

    /// Builds a config tracking a single `tdo/store.json` file.
    fn make_config(store_path: PathBuf, db_path: PathBuf) -> SyncConfig {
        let tracked = TrackedFile {
            file_key: "tdo/store.json".to_string(),
            tool: "tdo".to_string(),
            relative_path: "store.json".to_string(),
            local_path: store_path,
        };
        SyncConfig {
            db_path,
            passphrase: b"test-passphrase".to_vec(),
            tracked_files: vec![tracked],
        }
    }

    /// Writes `content` to `path`, creating parent directories as needed.
    fn write_test_store(path: &PathBuf, content: &str) {
        if let Some(dir) = path.parent() {
            std::fs::create_dir_all(dir).unwrap();
        }
        std::fs::write(path, content).unwrap();
    }

    #[test]
    fn sync_skipped_when_unreachable() {
        let tmp = tempfile::tempdir().unwrap();
        let store = tmp.path().join("store.json");
        write_test_store(&store, "{}");
        let config = make_config(store, tmp.path().join("sync.db"));
        let backend = LocalFsSyncBackend::new(std::path::Path::new("/nonexistent/remote"));
        let mut engine = SyncEngine::new_in_memory(config, backend).unwrap();
        let outcome = engine.sync().unwrap();
        assert!(
            matches!(&outcome, SyncOutcome::Skipped),
            "Expected Skipped, got {:?}",
            outcome
        );
    }

    #[test]
    fn basic_push_sync() {
        let local_dir = tempfile::tempdir().unwrap();
        let remote_dir = tempfile::tempdir().unwrap();
        let store_path = local_dir.path().join("store.json");
        write_test_store(&store_path, r#"{"version":9,"entries":{}}"#);
        let config = make_config(store_path, local_dir.path().join("sync.db"));
        let backend = LocalFsSyncBackend::new(remote_dir.path());
        let mut engine = SyncEngine::new_in_memory(config, backend).unwrap();
        let outcome = engine.sync().unwrap();
        if let SyncOutcome::Completed { pushed, .. } = outcome {
            assert!(pushed > 0, "Should have pushed at least one file");
        } else {
            panic!("Expected Completed, got {:?}", outcome);
        }
        // The encrypted payload and the Merkle manifest must both land
        // on the remote.
        assert!(remote_dir.path().join("tdo/store.json.enc").exists());
        assert!(remote_dir.path().join("merkle.json").exists());
    }

    #[test]
    fn second_sync_is_skipped_when_unchanged() {
        let local_dir = tempfile::tempdir().unwrap();
        let remote_dir = tempfile::tempdir().unwrap();
        let store_path = local_dir.path().join("store.json");
        write_test_store(&store_path, r#"{"version":9,"entries":{}}"#);

        // First engine seeds the remote.
        let config = make_config(store_path.clone(), local_dir.path().join("sync.db"));
        let backend = LocalFsSyncBackend::new(remote_dir.path());
        SyncEngine::new_in_memory(config, backend).unwrap().sync().unwrap();

        // A second engine with a fresh state db must complete without error;
        // either outcome is acceptable (a fresh db may trigger a re-push).
        let config2 = make_config(store_path, local_dir.path().join("sync2.db"));
        let backend2 = LocalFsSyncBackend::new(remote_dir.path());
        let mut engine2 = SyncEngine::new_in_memory(config2, backend2).unwrap();
        engine2.sync().unwrap();
    }
}