use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::time::{interval, Duration};
use crate::cloud::client::{CloudClient, PushSession, SessionMetadata};
use crate::cloud::credentials::CredentialsStore;
use crate::cloud::encryption::{decode_key_hex, encode_base64, encrypt_data};
use crate::config::Config;
use crate::storage::models::Message;
use crate::storage::Database;
/// How often the daemon attempts a cloud sync.
const SYNC_INTERVAL_HOURS: u64 = 4;
/// Number of sessions pushed per request (see `chunks` in `perform_sync_blocking`);
/// presumably kept small to bound per-request payload size — TODO confirm.
const PUSH_BATCH_SIZE: usize = 3;
/// Persisted bookkeeping for the daemon's periodic cloud sync.
///
/// Serialized to `~/.lore/daemon_state.json` so the schedule survives
/// daemon restarts. All fields are `None` before the first sync attempt.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SyncState {
    /// When the most recent sync attempt finished (success or failure).
    pub last_sync_at: Option<DateTime<Utc>>,
    /// When the next sync is due to run.
    pub next_sync_at: Option<DateTime<Utc>>,
    /// Number of sessions pushed during the most recent attempt.
    pub last_sync_count: Option<u64>,
    /// Whether the most recent attempt succeeded.
    pub last_sync_success: Option<bool>,
}
impl SyncState {
    /// Default location of the persisted state: `~/.lore/daemon_state.json`.
    fn state_path() -> Result<PathBuf> {
        let lore_dir = dirs::home_dir()
            .context("Could not find home directory")?
            .join(".lore");
        Ok(lore_dir.join("daemon_state.json"))
    }

    /// Load state from `path`, returning defaults when the file does not
    /// exist yet (first run).
    pub fn load_from_path(path: &std::path::Path) -> Result<Self> {
        if !path.exists() {
            return Ok(Self::default());
        }
        let content = fs::read_to_string(path).context("Failed to read sync state file")?;
        let state: SyncState =
            serde_json::from_str(&content).context("Failed to parse sync state file")?;
        Ok(state)
    }

    /// Load state from the default location.
    pub fn load() -> Result<Self> {
        let path = Self::state_path()?;
        Self::load_from_path(&path)
    }

    /// Persist state to `path`, creating parent directories as needed and
    /// writing via a temp file + rename so a crash never leaves a torn file.
    pub fn save_to_path(&self, path: &std::path::Path) -> Result<()> {
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).context("Failed to create parent directory")?;
        }
        let content =
            serde_json::to_string_pretty(self).context("Failed to serialize sync state")?;
        let temp_path = path.with_extension("json.tmp");
        fs::write(&temp_path, &content).context("Failed to write sync state temp file")?;
        // Windows refuses to rename onto an existing file, so remove the
        // target first; best-effort, since rename reports the real failure.
        #[cfg(windows)]
        if path.exists() {
            let _ = fs::remove_file(path);
        }
        fs::rename(&temp_path, path).context("Failed to rename sync state file")?;
        Ok(())
    }

    /// Persist state to the default location.
    fn save(&self) -> Result<()> {
        let path = Self::state_path()?;
        self.save_to_path(&path)
    }

    /// Record the next due time and persist immediately.
    fn schedule_next(&mut self, next_at: DateTime<Utc>) -> Result<()> {
        self.next_sync_at = Some(next_at);
        self.save()
    }

    /// Record the outcome of a sync attempt (timestamped now), schedule the
    /// next one, and persist immediately.
    fn record_sync(&mut self, success: bool, count: u64, next_at: DateTime<Utc>) -> Result<()> {
        self.last_sync_at = Some(Utc::now());
        self.last_sync_success = Some(success);
        self.last_sync_count = Some(count);
        self.next_sync_at = Some(next_at);
        self.save()
    }
}
/// Sync state shared between the periodic task and status readers, guarded
/// by an async `RwLock`.
pub type SharedSyncState = Arc<RwLock<SyncState>>;
/// Compute when the next sync should run.
///
/// If a previous sync exists and its natural slot (`last + interval`) is
/// still in the future, keep that slot; otherwise — never synced, or the
/// slot has already passed — schedule a full interval from now.
fn calculate_next_sync(state: &SyncState) -> DateTime<Utc> {
    let step = chrono::Duration::hours(SYNC_INTERVAL_HOURS as i64);
    let now = Utc::now();
    state
        .last_sync_at
        .map(|last| last + step)
        .filter(|due| *due > now)
        .unwrap_or(now + step)
}
/// Long-running task: syncs on a fixed schedule until shutdown is signalled.
///
/// On startup, a persisted `next_sync_at` that is still in the future is
/// honored; otherwise a fresh deadline is derived. The loop then polls once
/// a minute and fires a sync whenever the deadline has passed, recording the
/// outcome (and the next deadline) back into the shared, persisted state.
pub async fn run_periodic_sync(
    sync_state: SharedSyncState,
    mut shutdown_rx: tokio::sync::broadcast::Receiver<()>,
) {
    // Resolve and persist the first deadline.
    {
        let mut state = sync_state.write().await;
        let next_sync = match state.next_sync_at {
            Some(persisted) if persisted > Utc::now() => persisted,
            _ => calculate_next_sync(&state),
        };
        match state.schedule_next(next_sync) {
            Err(e) => tracing::warn!("Failed to save initial sync state: {e}"),
            Ok(()) => tracing::info!(
                "Periodic sync scheduled for {}",
                next_sync.format("%Y-%m-%d %H:%M:%S UTC")
            ),
        }
    }

    // Poll the deadline once a minute rather than sleeping until it, so a
    // deadline changed elsewhere is picked up promptly.
    let mut ticker = interval(Duration::from_secs(60));
    loop {
        tokio::select! {
            _ = ticker.tick() => {
                // Read lock only long enough to test the deadline.
                let due = {
                    let state = sync_state.read().await;
                    state.next_sync_at.map_or(false, |at| Utc::now() >= at)
                };
                if !due {
                    continue;
                }

                // Sync without holding the lock, then record the outcome.
                let outcome = perform_sync().await;
                let next_sync =
                    Utc::now() + chrono::Duration::hours(SYNC_INTERVAL_HOURS as i64);
                let mut state = sync_state.write().await;
                let (success, count) = match outcome {
                    Ok(count) => {
                        tracing::info!("Periodic sync completed: {} sessions synced", count);
                        (true, count)
                    }
                    Err(e) => {
                        tracing::info!("Periodic sync skipped or failed: {e}");
                        (false, 0)
                    }
                };
                if let Err(e) = state.record_sync(success, count, next_sync) {
                    tracing::warn!("Failed to save sync state: {e}");
                }
            }
            _ = shutdown_rx.recv() => {
                tracing::info!("Periodic sync shutting down");
                break;
            }
        }
    }
}
/// Run one sync pass on the blocking thread pool.
///
/// The outer error only occurs if the blocking task panics; failures from
/// `perform_sync_blocking` itself are propagated as the inner result.
async fn perform_sync() -> Result<u64> {
    let joined = tokio::task::spawn_blocking(perform_sync_blocking).await;
    joined.context("Sync task panicked")?
}
/// One synchronous sync pass: push all unsynced sessions to the cloud in
/// batches of `PUSH_BATCH_SIZE`, encrypting message payloads client-side.
/// Returns the number of sessions the server reported as synced.
///
/// Errors early (before touching the network) when the user is not logged
/// in, the encryption key is missing, or the machine ID is not configured.
fn perform_sync_blocking() -> Result<u64> {
    let config = Config::load().context("Could not load config")?;
    let store = CredentialsStore::with_keychain(config.use_keychain);
    // `Option::context` (anyhow) turns a missing value into an error with
    // the same message the old match-to-Err blocks produced.
    let credentials = store.load()?.context("Not logged in")?;
    let key_hex = store
        .load_encryption_key()?
        .context("Encryption key not configured")?;
    let encryption_key = decode_key_hex(&key_hex)?;
    let machine_id = config
        .machine_id
        .clone()
        .context("Machine ID not configured")?;

    let db = Database::open_default().context("Could not open database")?;
    let sessions = db.get_unsynced_sessions()?;
    if sessions.is_empty() {
        tracing::debug!("No sessions to sync");
        return Ok(0);
    }
    tracing::info!("Found {} sessions to sync", sessions.len());

    let client = CloudClient::with_url(&credentials.cloud_url).with_api_key(&credentials.api_key);

    // Pair each session with its messages; a session whose messages cannot
    // be read is skipped with a warning rather than failing the whole sync.
    let session_data: Vec<_> = sessions
        .iter()
        .filter_map(|session| match db.get_messages(&session.id) {
            Ok(messages) => Some((session.clone(), messages)),
            Err(e) => {
                tracing::warn!(
                    "Failed to get messages for session {}: {}",
                    &session.id.to_string()[..8],
                    e
                );
                None
            }
        })
        .collect();

    let mut total_synced: u64 = 0;
    for batch in session_data.chunks(PUSH_BATCH_SIZE) {
        let mut push_sessions = Vec::new();
        for (session, messages) in batch {
            let encrypted = encrypt_session_messages(messages, &encryption_key)?;
            push_sessions.push(PushSession {
                id: session.id.to_string(),
                machine_id: machine_id.clone(),
                encrypted_data: encrypted,
                metadata: SessionMetadata {
                    tool_name: session.tool.clone(),
                    project_path: session.working_directory.clone(),
                    started_at: session.started_at,
                    ended_at: session.ended_at,
                    message_count: session.message_count,
                },
                // Sessions still in progress have no end time; treat "now"
                // as their last update.
                updated_at: session.ended_at.unwrap_or_else(Utc::now),
            });
        }
        // Collect the IDs before pushing so the (potentially large)
        // encrypted payloads can be moved into `push` instead of cloned.
        let batch_session_ids: Vec<_> = push_sessions
            .iter()
            .filter_map(|ps| uuid::Uuid::parse_str(&ps.id).ok())
            .collect();
        match client.push(push_sessions) {
            Ok(response) => {
                // A failure to mark locally is non-fatal: the sessions will
                // simply be re-pushed on the next pass.
                if let Err(e) = db.mark_sessions_synced(&batch_session_ids, response.server_time) {
                    tracing::warn!("Failed to mark sessions as synced: {e}");
                }
                total_synced += response.synced_count as u64;
            }
            Err(e) => {
                // Quota exhaustion is expected for capped accounts: stop
                // pushing quietly instead of warning for every batch.
                // NOTE(review): string matching on error text is fragile —
                // a typed error variant from CloudClient would be better.
                let error_str = e.to_string();
                if error_str.contains("quota")
                    || error_str.contains("Would exceed session limit")
                    || (error_str.contains("403") && error_str.contains("limit"))
                {
                    tracing::debug!("Sync stopped due to quota limit");
                    break;
                }
                tracing::warn!("Failed to push batch: {e}");
            }
        }
    }
    Ok(total_synced)
}
/// Serialize `messages` to JSON, encrypt the bytes with `key`, and return
/// the ciphertext base64-encoded for transport.
fn encrypt_session_messages(messages: &[Message], key: &[u8]) -> Result<String> {
    let plaintext = serde_json::to_vec(messages)?;
    let ciphertext = encrypt_data(&plaintext, key)?;
    Ok(encode_base64(&ciphertext))
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // A fresh default state has every field unset.
    #[test]
    fn test_sync_state_default() {
        let state = SyncState::default();
        assert!(state.last_sync_at.is_none());
        assert!(state.next_sync_at.is_none());
        assert!(state.last_sync_count.is_none());
        assert!(state.last_sync_success.is_none());
    }

    // With no history, the next sync is one full interval from now.
    #[test]
    fn test_calculate_next_sync_no_previous() {
        let state = SyncState::default();
        let next = calculate_next_sync(&state);
        let expected = Utc::now() + chrono::Duration::hours(SYNC_INTERVAL_HOURS as i64);
        let diff = (next - expected).num_seconds().abs();
        assert!(diff < 5, "Next sync should be ~4 hours from now");
    }

    // A recent sync keeps its natural slot: last sync + interval.
    #[test]
    fn test_calculate_next_sync_with_recent_previous() {
        let last_sync = Utc::now() - chrono::Duration::hours(1);
        let state = SyncState {
            last_sync_at: Some(last_sync),
            ..Default::default()
        };
        let next = calculate_next_sync(&state);
        let expected = last_sync + chrono::Duration::hours(SYNC_INTERVAL_HOURS as i64);
        let diff = (next - expected).num_seconds().abs();
        assert!(diff < 5, "Next sync should be 4 hours after last sync");
    }

    // A stale sync (slot already passed) reschedules a full interval from now.
    #[test]
    fn test_calculate_next_sync_with_old_previous() {
        let state = SyncState {
            last_sync_at: Some(Utc::now() - chrono::Duration::hours(10)),
            ..Default::default()
        };
        let next = calculate_next_sync(&state);
        let expected = Utc::now() + chrono::Duration::hours(SYNC_INTERVAL_HOURS as i64);
        let diff = (next - expected).num_seconds().abs();
        assert!(
            diff < 5,
            "Next sync should be ~4 hours from now when last sync is old"
        );
    }

    // Serde round trip preserves all four fields.
    #[test]
    fn test_sync_state_serialization() {
        let state = SyncState {
            last_sync_at: Some(Utc::now()),
            next_sync_at: Some(Utc::now() + chrono::Duration::hours(4)),
            last_sync_count: Some(10),
            last_sync_success: Some(true),
        };
        let json = serde_json::to_string(&state).unwrap();
        let parsed: SyncState = serde_json::from_str(&json).unwrap();
        assert!(parsed.last_sync_at.is_some());
        assert!(parsed.next_sync_at.is_some());
        assert_eq!(parsed.last_sync_count, Some(10));
        assert_eq!(parsed.last_sync_success, Some(true));
    }

    // save_to_path followed by load_from_path reproduces the state on disk.
    #[test]
    fn test_sync_state_save_load_round_trip() {
        let temp_dir = TempDir::new().unwrap();
        let state_path = temp_dir.path().join("daemon_state.json");
        let state = SyncState {
            last_sync_at: Some(Utc::now()),
            next_sync_at: Some(Utc::now() + chrono::Duration::hours(4)),
            last_sync_count: Some(5),
            last_sync_success: Some(true),
        };
        state.save_to_path(&state_path).unwrap();
        let loaded = SyncState::load_from_path(&state_path).unwrap();
        assert_eq!(loaded.last_sync_count, Some(5));
        assert_eq!(loaded.last_sync_success, Some(true));
        assert!(loaded.next_sync_at.is_some());
        assert!(loaded.last_sync_at.is_some());
    }

    // Saving into a missing nested directory creates the parents first.
    #[test]
    fn test_sync_state_save_creates_parent_directory() {
        let temp_dir = TempDir::new().unwrap();
        let nested_path = temp_dir
            .path()
            .join("nested")
            .join("deep")
            .join("state.json");
        let parent = nested_path.parent().unwrap();
        assert!(!parent.exists());
        let state = SyncState::default();
        state.save_to_path(&nested_path).unwrap();
        assert!(parent.exists());
        assert!(nested_path.exists());
        let loaded = SyncState::load_from_path(&nested_path).unwrap();
        assert!(loaded.last_sync_at.is_none());
    }

    // Mirrors the startup logic in run_periodic_sync: a persisted deadline
    // that is still in the future is honored as-is.
    #[test]
    fn test_persisted_next_sync_at_respected_when_future() {
        let future_time = Utc::now() + chrono::Duration::hours(2);
        let state = SyncState {
            last_sync_at: Some(Utc::now() - chrono::Duration::hours(1)),
            next_sync_at: Some(future_time),
            last_sync_count: Some(3),
            last_sync_success: Some(true),
        };
        let next_sync = if let Some(persisted_next) = state.next_sync_at {
            if persisted_next > Utc::now() {
                persisted_next
            } else {
                calculate_next_sync(&state)
            }
        } else {
            calculate_next_sync(&state)
        };
        let diff = (next_sync - future_time).num_seconds().abs();
        assert!(diff < 1, "Should use persisted next_sync_at when in future");
    }

    // Mirrors the startup logic in run_periodic_sync: a persisted deadline
    // already in the past is recalculated instead of firing immediately.
    #[test]
    fn test_persisted_next_sync_at_recalculated_when_past() {
        let past_time = Utc::now() - chrono::Duration::hours(1);
        let state = SyncState {
            last_sync_at: Some(Utc::now() - chrono::Duration::hours(2)),
            next_sync_at: Some(past_time),
            last_sync_count: Some(3),
            last_sync_success: Some(true),
        };
        let next_sync = if let Some(persisted_next) = state.next_sync_at {
            if persisted_next > Utc::now() {
                persisted_next
            } else {
                calculate_next_sync(&state)
            }
        } else {
            calculate_next_sync(&state)
        };
        assert!(
            next_sync > Utc::now(),
            "Should recalculate when persisted next_sync_at is in the past"
        );
    }

    // The temp-file + rename save must cleanly overwrite an existing file.
    #[test]
    fn test_sync_state_atomic_save_overwrites() {
        let temp_dir = TempDir::new().unwrap();
        let state_path = temp_dir.path().join("daemon_state.json");
        let state1 = SyncState {
            last_sync_count: Some(1),
            ..Default::default()
        };
        state1.save_to_path(&state_path).unwrap();
        let loaded1 = SyncState::load_from_path(&state_path).unwrap();
        assert_eq!(loaded1.last_sync_count, Some(1));
        let state2 = SyncState {
            last_sync_count: Some(2),
            ..Default::default()
        };
        state2.save_to_path(&state_path).unwrap();
        let loaded2 = SyncState::load_from_path(&state_path).unwrap();
        assert_eq!(loaded2.last_sync_count, Some(2));
    }
}