1use anyhow::{Context, Result};
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::fs;
11use std::path::PathBuf;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tokio::time::{interval, Duration};
15
16use crate::cloud::client::{CloudClient, PushSession, SessionMetadata};
17use crate::cloud::credentials::CredentialsStore;
18use crate::cloud::encryption::{decode_key_hex, encode_base64, encrypt_data};
19use crate::config::Config;
20use crate::storage::models::Message;
21use crate::storage::Database;
22
/// Hours between periodic background syncs (used to schedule `next_sync_at`).
const SYNC_INTERVAL_HOURS: u64 = 4;

/// Number of sessions pushed to the cloud per request batch.
const PUSH_BATCH_SIZE: usize = 3;
28
/// Persisted record of the daemon's periodic-sync progress and schedule,
/// stored as JSON at `~/.lore/daemon_state.json`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SyncState {
    // When the most recent sync attempt finished (set to `Utc::now()` in
    // `record_sync`, for both successful and failed attempts).
    pub last_sync_at: Option<DateTime<Utc>>,
    // When the next periodic sync is scheduled to run.
    pub next_sync_at: Option<DateTime<Utc>>,
    // Sessions synced by the most recent attempt (0 when it failed).
    pub last_sync_count: Option<u64>,
    // Whether the most recent sync attempt succeeded.
    pub last_sync_success: Option<bool>,
}
41
42impl SyncState {
43 fn state_path() -> Result<PathBuf> {
45 let lore_dir = dirs::home_dir()
46 .context("Could not find home directory")?
47 .join(".lore");
48 Ok(lore_dir.join("daemon_state.json"))
49 }
50
51 pub fn load_from_path(path: &std::path::Path) -> Result<Self> {
55 if !path.exists() {
56 return Ok(Self::default());
57 }
58
59 let content = fs::read_to_string(path).context("Failed to read sync state file")?;
60 let state: SyncState =
61 serde_json::from_str(&content).context("Failed to parse sync state file")?;
62 Ok(state)
63 }
64
65 pub fn load() -> Result<Self> {
69 let path = Self::state_path()?;
70 Self::load_from_path(&path)
71 }
72
73 pub fn save_to_path(&self, path: &std::path::Path) -> Result<()> {
77 if let Some(parent) = path.parent() {
78 fs::create_dir_all(parent).context("Failed to create parent directory")?;
79 }
80
81 let content = serde_json::to_string_pretty(self)?;
82
83 let temp_path = path.with_extension("json.tmp");
86 fs::write(&temp_path, &content).context("Failed to write sync state temp file")?;
87
88 #[cfg(windows)]
89 if path.exists() {
90 let _ = fs::remove_file(path);
91 }
92
93 fs::rename(&temp_path, path).context("Failed to rename sync state file")?;
94
95 Ok(())
96 }
97
98 fn save(&self) -> Result<()> {
100 let path = Self::state_path()?;
101 self.save_to_path(&path)
102 }
103
104 fn schedule_next(&mut self, next_at: DateTime<Utc>) -> Result<()> {
106 self.next_sync_at = Some(next_at);
107 self.save()
108 }
109
110 fn record_sync(&mut self, success: bool, count: u64, next_at: DateTime<Utc>) -> Result<()> {
112 self.last_sync_at = Some(Utc::now());
113 self.last_sync_success = Some(success);
114 self.last_sync_count = Some(count);
115 self.next_sync_at = Some(next_at);
116 self.save()
117 }
118}
119
/// Sync state shared between daemon tasks behind an async read/write lock.
pub type SharedSyncState = Arc<RwLock<SyncState>>;
122
123fn calculate_next_sync(state: &SyncState) -> DateTime<Utc> {
128 let interval = chrono::Duration::hours(SYNC_INTERVAL_HOURS as i64);
129
130 if let Some(last_sync) = state.last_sync_at {
131 let next = last_sync + interval;
133 let now = Utc::now();
135 if next <= now {
136 now + interval
137 } else {
138 next
139 }
140 } else {
141 Utc::now() + interval
143 }
144}
145
/// Long-running daemon task: performs a sync every `SYNC_INTERVAL_HOURS`,
/// polling once a minute for the scheduled time, until a message arrives on
/// `shutdown_rx`.
pub async fn run_periodic_sync(
    sync_state: SharedSyncState,
    mut shutdown_rx: tokio::sync::broadcast::Receiver<()>,
) {
    // Startup scheduling: honor a persisted next_sync_at that is still in the
    // future; otherwise (missing or already past) compute a fresh schedule.
    {
        let mut state = sync_state.write().await;
        let next_sync = if let Some(persisted_next) = state.next_sync_at {
            if persisted_next > Utc::now() {
                persisted_next
            } else {
                calculate_next_sync(&state)
            }
        } else {
            calculate_next_sync(&state)
        };
        if let Err(e) = state.schedule_next(next_sync) {
            tracing::warn!("Failed to save initial sync state: {e}");
        } else {
            tracing::info!(
                "Periodic sync scheduled for {}",
                next_sync.format("%Y-%m-%d %H:%M:%S UTC")
            );
        }
    }

    // Tick once a minute instead of sleeping until the deadline, so an updated
    // next_sync_at is picked up within a minute.
    let mut check_interval = interval(Duration::from_secs(60));

    loop {
        tokio::select! {
            _ = check_interval.tick() => {
                // Hold the read lock only long enough to compare timestamps.
                let should_sync = {
                    let state = sync_state.read().await;
                    if let Some(next_sync) = state.next_sync_at {
                        Utc::now() >= next_sync
                    } else {
                        false
                    }
                };

                if should_sync {
                    // The sync itself runs without holding the lock; the
                    // outcome and next schedule are recorded afterwards.
                    let result = perform_sync().await;
                    let next_sync = Utc::now() + chrono::Duration::hours(SYNC_INTERVAL_HOURS as i64);

                    let mut state = sync_state.write().await;
                    match result {
                        Ok(count) => {
                            tracing::info!("Periodic sync completed: {} sessions synced", count);
                            if let Err(e) = state.record_sync(true, count, next_sync) {
                                tracing::warn!("Failed to save sync state: {e}");
                            }
                        }
                        Err(e) => {
                            // Failures (e.g. not logged in) are expected in
                            // normal operation; log at info and retry later.
                            tracing::info!("Periodic sync skipped or failed: {e}");
                            if let Err(e) = state.record_sync(false, 0, next_sync) {
                                tracing::warn!("Failed to save sync state: {e}");
                            }
                        }
                    }
                }
            }
            _ = shutdown_rx.recv() => {
                tracing::info!("Periodic sync shutting down");
                break;
            }
        }
    }
}
218
219async fn perform_sync() -> Result<u64> {
224 tokio::task::spawn_blocking(perform_sync_blocking)
225 .await
226 .context("Sync task panicked")?
227}
228
229fn perform_sync_blocking() -> Result<u64> {
233 let config = Config::load().context("Could not load config")?;
234
235 let store = CredentialsStore::with_keychain(config.use_keychain);
236
237 let credentials = match store.load()? {
238 Some(creds) => creds,
239 None => {
240 return Err(anyhow::anyhow!("Not logged in"));
241 }
242 };
243
244 let encryption_key = match store.load_encryption_key()? {
245 Some(key_hex) => decode_key_hex(&key_hex)?,
246 None => {
247 return Err(anyhow::anyhow!("Encryption key not configured"));
248 }
249 };
250
251 let machine_id = match config.machine_id.clone() {
252 Some(id) => id,
253 None => {
254 return Err(anyhow::anyhow!("Machine ID not configured"));
255 }
256 };
257
258 let db = Database::open_default().context("Could not open database")?;
259
260 let sessions = db.get_unsynced_sessions()?;
261 if sessions.is_empty() {
262 tracing::debug!("No sessions to sync");
263 return Ok(0);
264 }
265
266 tracing::info!("Found {} sessions to sync", sessions.len());
267
268 let client = CloudClient::with_url(&credentials.cloud_url).with_api_key(&credentials.api_key);
269
270 let session_data: Vec<_> = sessions
271 .iter()
272 .filter_map(|session| match db.get_messages(&session.id) {
273 Ok(messages) => Some((session.clone(), messages)),
274 Err(e) => {
275 tracing::warn!(
276 "Failed to get messages for session {}: {}",
277 &session.id.to_string()[..8],
278 e
279 );
280 None
281 }
282 })
283 .collect();
284
285 let mut total_synced: u64 = 0;
286
287 for batch in session_data.chunks(PUSH_BATCH_SIZE) {
288 let mut push_sessions = Vec::new();
289
290 for (session, messages) in batch {
291 let encrypted = encrypt_session_messages(messages, &encryption_key)?;
292 push_sessions.push(PushSession {
293 id: session.id.to_string(),
294 machine_id: machine_id.clone(),
295 encrypted_data: encrypted,
296 metadata: SessionMetadata {
297 tool_name: session.tool.clone(),
298 project_path: session.working_directory.clone(),
299 started_at: session.started_at,
300 ended_at: session.ended_at,
301 message_count: session.message_count,
302 },
303 updated_at: session.ended_at.unwrap_or_else(Utc::now),
304 });
305 }
306
307 match client.push(push_sessions.clone()) {
308 Ok(response) => {
309 let batch_session_ids: Vec<_> = push_sessions
310 .iter()
311 .filter_map(|ps| uuid::Uuid::parse_str(&ps.id).ok())
312 .collect();
313
314 if let Err(e) = db.mark_sessions_synced(&batch_session_ids, response.server_time) {
315 tracing::warn!("Failed to mark sessions as synced: {e}");
316 }
317
318 total_synced += response.synced_count as u64;
319 }
320 Err(e) => {
321 let error_str = e.to_string();
322 if error_str.contains("quota")
323 || error_str.contains("Would exceed session limit")
324 || (error_str.contains("403") && error_str.contains("limit"))
325 {
326 tracing::debug!("Sync stopped due to quota limit");
327 break;
328 }
329 tracing::warn!("Failed to push batch: {e}");
330 }
331 }
332 }
333
334 Ok(total_synced)
335}
336
337fn encrypt_session_messages(messages: &[Message], key: &[u8]) -> Result<String> {
339 let json = serde_json::to_vec(messages)?;
340 let encrypted = encrypt_data(&json, key)?;
341 Ok(encode_base64(&encrypted))
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // A fresh state has no history and nothing scheduled.
    #[test]
    fn test_sync_state_default() {
        let state = SyncState::default();
        assert!(state.last_sync_at.is_none());
        assert!(state.next_sync_at.is_none());
        assert!(state.last_sync_count.is_none());
        assert!(state.last_sync_success.is_none());
    }

    // No prior sync: next run is one full interval from now.
    #[test]
    fn test_calculate_next_sync_no_previous() {
        let state = SyncState::default();
        let next = calculate_next_sync(&state);

        let expected = Utc::now() + chrono::Duration::hours(SYNC_INTERVAL_HOURS as i64);
        // 5-second tolerance covers the gap between the two Utc::now() calls.
        let diff = (next - expected).num_seconds().abs();
        assert!(diff < 5, "Next sync should be ~4 hours from now");
    }

    // Recent prior sync: next run is one interval after that sync.
    #[test]
    fn test_calculate_next_sync_with_recent_previous() {
        let last_sync = Utc::now() - chrono::Duration::hours(1);
        let state = SyncState {
            last_sync_at: Some(last_sync),
            ..Default::default()
        };

        let next = calculate_next_sync(&state);

        let expected = last_sync + chrono::Duration::hours(SYNC_INTERVAL_HOURS as i64);
        let diff = (next - expected).num_seconds().abs();
        assert!(diff < 5, "Next sync should be 4 hours after last sync");
    }

    // Prior sync older than the interval: schedule from now, not from it.
    #[test]
    fn test_calculate_next_sync_with_old_previous() {
        let state = SyncState {
            last_sync_at: Some(Utc::now() - chrono::Duration::hours(10)),
            ..Default::default()
        };

        let next = calculate_next_sync(&state);

        let expected = Utc::now() + chrono::Duration::hours(SYNC_INTERVAL_HOURS as i64);
        let diff = (next - expected).num_seconds().abs();
        assert!(
            diff < 5,
            "Next sync should be ~4 hours from now when last sync is old"
        );
    }

    // All fields survive a JSON round trip.
    #[test]
    fn test_sync_state_serialization() {
        let state = SyncState {
            last_sync_at: Some(Utc::now()),
            next_sync_at: Some(Utc::now() + chrono::Duration::hours(4)),
            last_sync_count: Some(10),
            last_sync_success: Some(true),
        };

        let json = serde_json::to_string(&state).unwrap();
        let parsed: SyncState = serde_json::from_str(&json).unwrap();

        assert!(parsed.last_sync_at.is_some());
        assert!(parsed.next_sync_at.is_some());
        assert_eq!(parsed.last_sync_count, Some(10));
        assert_eq!(parsed.last_sync_success, Some(true));
    }

    // save_to_path followed by load_from_path preserves all fields.
    #[test]
    fn test_sync_state_save_load_round_trip() {
        let temp_dir = TempDir::new().unwrap();
        let state_path = temp_dir.path().join("daemon_state.json");

        let state = SyncState {
            last_sync_at: Some(Utc::now()),
            next_sync_at: Some(Utc::now() + chrono::Duration::hours(4)),
            last_sync_count: Some(5),
            last_sync_success: Some(true),
        };

        state.save_to_path(&state_path).unwrap();

        let loaded = SyncState::load_from_path(&state_path).unwrap();

        assert_eq!(loaded.last_sync_count, Some(5));
        assert_eq!(loaded.last_sync_success, Some(true));
        assert!(loaded.next_sync_at.is_some());
        assert!(loaded.last_sync_at.is_some());
    }

    // save_to_path creates missing parent directories before writing.
    #[test]
    fn test_sync_state_save_creates_parent_directory() {
        let temp_dir = TempDir::new().unwrap();
        let nested_path = temp_dir
            .path()
            .join("nested")
            .join("deep")
            .join("state.json");

        let parent = nested_path.parent().unwrap();
        assert!(!parent.exists());

        let state = SyncState::default();
        state.save_to_path(&nested_path).unwrap();

        assert!(parent.exists());
        assert!(nested_path.exists());

        let loaded = SyncState::load_from_path(&nested_path).unwrap();
        assert!(loaded.last_sync_at.is_none());
    }

    // Mirrors run_periodic_sync's startup logic: a future persisted
    // next_sync_at is used as-is.
    #[test]
    fn test_persisted_next_sync_at_respected_when_future() {
        let future_time = Utc::now() + chrono::Duration::hours(2);
        let state = SyncState {
            last_sync_at: Some(Utc::now() - chrono::Duration::hours(1)),
            next_sync_at: Some(future_time),
            last_sync_count: Some(3),
            last_sync_success: Some(true),
        };

        let next_sync = if let Some(persisted_next) = state.next_sync_at {
            if persisted_next > Utc::now() {
                persisted_next
            } else {
                calculate_next_sync(&state)
            }
        } else {
            calculate_next_sync(&state)
        };

        let diff = (next_sync - future_time).num_seconds().abs();
        assert!(diff < 1, "Should use persisted next_sync_at when in future");
    }

    // Mirrors run_periodic_sync's startup logic: a stale persisted
    // next_sync_at is recalculated.
    #[test]
    fn test_persisted_next_sync_at_recalculated_when_past() {
        let past_time = Utc::now() - chrono::Duration::hours(1);
        let state = SyncState {
            last_sync_at: Some(Utc::now() - chrono::Duration::hours(2)),
            next_sync_at: Some(past_time),
            last_sync_count: Some(3),
            last_sync_success: Some(true),
        };

        let next_sync = if let Some(persisted_next) = state.next_sync_at {
            if persisted_next > Utc::now() {
                persisted_next
            } else {
                calculate_next_sync(&state)
            }
        } else {
            calculate_next_sync(&state)
        };

        assert!(
            next_sync > Utc::now(),
            "Should recalculate when persisted next_sync_at is in the past"
        );
    }

    // A second save replaces the first file's contents (temp-file + rename).
    #[test]
    fn test_sync_state_atomic_save_overwrites() {
        let temp_dir = TempDir::new().unwrap();
        let state_path = temp_dir.path().join("daemon_state.json");

        let state1 = SyncState {
            last_sync_count: Some(1),
            ..Default::default()
        };
        state1.save_to_path(&state_path).unwrap();

        let loaded1 = SyncState::load_from_path(&state_path).unwrap();
        assert_eq!(loaded1.last_sync_count, Some(1));

        let state2 = SyncState {
            last_sync_count: Some(2),
            ..Default::default()
        };
        state2.save_to_path(&state_path).unwrap();

        let loaded2 = SyncState::load_from_path(&state_path).unwrap();
        assert_eq!(loaded2.last_sync_count, Some(2));
    }
}