use crate::error::{ReplicationError, Result};
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};

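// Retry policy for transient SQLITE_BUSY / SQLITE_LOCKED errors:
// exponential backoff starting at 10ms, capped at 500ms, at most 5 attempts.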
const SQLITE_RETRY_MAX_ATTEMPTS: u32 = 5;
const SQLITE_RETRY_BASE_DELAY_MS: u64 = 10;
const SQLITE_RETRY_MAX_DELAY_MS: u64 = 500;

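/// Returns `true` if the error is a transient SQLite busy/locked condition
/// (`SQLITE_BUSY` = 5, `SQLITE_LOCKED` = 6) that is worth retrying.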
fn is_sqlite_busy_error(e: &sqlx::Error) -> bool {
    match e {
        sqlx::Error::Database(db_err) => {
            if let Some(code) = db_err.code() {
                return code == "5" || code == "6";
            }
            let msg = db_err.message().to_lowercase();
            msg.contains("database is locked") || msg.contains("database is busy")
        }
        _ => false,
    }
}

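/// Runs `f`, retrying with exponential backoff while SQLite reports a busy or
/// locked database. Non-busy errors and the final failed attempt are returned
/// to the caller unchanged.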
async fn execute_with_retry<F, Fut, T>(
    operation_name: &str,
    mut f: F,
) -> std::result::Result<T, sqlx::Error>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = std::result::Result<T, sqlx::Error>>,
{
    let mut attempts = 0;
    let mut delay_ms = SQLITE_RETRY_BASE_DELAY_MS;

    loop {
        attempts += 1;
        match f().await {
            Ok(result) => {
                if attempts > 1 {
                    debug!(
                        operation = operation_name,
                        attempts,
                        "SQLite operation succeeded after retry"
                    );
                }
                return Ok(result);
            }
            Err(e) if is_sqlite_busy_error(&e) && attempts < SQLITE_RETRY_MAX_ATTEMPTS => {
                warn!(
                    operation = operation_name,
                    attempts,
                    max_attempts = SQLITE_RETRY_MAX_ATTEMPTS,
                    delay_ms,
                    "SQLite busy, retrying"
                );
                crate::metrics::cursor_retries_total(operation_name);
                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
                delay_ms = (delay_ms * 2).min(SQLITE_RETRY_MAX_DELAY_MS);
            }
            Err(e) => {
                if is_sqlite_busy_error(&e) {
                    warn!(
                        operation = operation_name,
                        attempts,
                        "SQLite busy, max retries exceeded"
                    );
                }
                return Err(e);
            }
        }
    }
}

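/// A single persisted cursor row: the last stream ID recorded for a peer and
/// the millisecond timestamp of its last update.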
#[derive(Debug, Clone)]
pub struct CursorEntry {
    pub peer_id: String,
    pub stream_id: String,
    pub updated_at: i64,
}

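/// SQLite-backed cursor store (WAL mode) with an in-memory cache and debounced
/// writes: `set` updates the cache and marks the peer dirty, and `flush_dirty`
/// persists the dirty entries to disk.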
pub struct CursorStore {
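    /// Connection pool for the underlying SQLite database.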
    pool: SqlitePool,
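    /// In-memory view of all cursors, keyed by peer ID.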
    cache: Arc<RwLock<HashMap<String, String>>>,
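    /// Peer IDs whose cursors have changed since the last flush.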
    dirty: Arc<RwLock<HashSet<String>>>,
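    /// Filesystem path of the SQLite database file.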
    path: String,
}

impl CursorStore {
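    /// Opens (or creates) the SQLite database at `path` in WAL mode, ensures
    /// the `cursors` table exists, and preloads all persisted cursors into the
    /// in-memory cache.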
    pub async fn new(path: impl AsRef<Path>) -> Result<Self> {
        let path_str = path.as_ref().to_string_lossy().to_string();
        info!(path = %path_str, "Initializing cursor store");

        let options = SqliteConnectOptions::from_str(&format!("sqlite://{}?mode=rwc", path_str))
            .map_err(|e| ReplicationError::Config(format!("Invalid SQLite path: {}", e)))?
            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
            .synchronous(sqlx::sqlite::SqliteSynchronous::Normal)
            .create_if_missing(true);

        let pool = SqlitePoolOptions::new()
            .max_connections(2)
            .connect_with(options)
            .await?;

        sqlx::query(
            r#"
            CREATE TABLE IF NOT EXISTS cursors (
                peer_id TEXT PRIMARY KEY,
                stream_id TEXT NOT NULL,
                updated_at INTEGER NOT NULL
            )
            "#,
        )
        .execute(&pool)
        .await?;

        let rows: Vec<(String, String)> =
            sqlx::query_as("SELECT peer_id, stream_id FROM cursors")
                .fetch_all(&pool)
                .await?;

        let mut cache = HashMap::new();
        for (peer_id, stream_id) in rows {
            debug!(peer_id = %peer_id, stream_id = %stream_id, "Loaded cursor from disk");
            cache.insert(peer_id, stream_id);
        }

        if !cache.is_empty() {
            info!(count = cache.len(), "Restored cursors from previous run");
        }

        Ok(Self {
            pool,
            cache: Arc::new(RwLock::new(cache)),
            dirty: Arc::new(RwLock::new(HashSet::new())),
            path: path_str,
        })
    }

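    /// Returns the cached cursor for `peer_id`, if one exists.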
    pub async fn get(&self, peer_id: &str) -> Option<String> {
        self.cache.read().await.get(peer_id).cloned()
    }

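    /// Returns the cached cursor for `peer_id`, falling back to `"0"` when the
    /// peer has no cursor yet.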
    pub async fn get_or_start(&self, peer_id: &str) -> String {
        self.get(peer_id).await.unwrap_or_else(|| "0".to_string())
    }

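    /// Updates the cached cursor for `peer_id` and marks the peer dirty; the
    /// change is not persisted until the next call to `flush_dirty`.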
    pub async fn set(&self, peer_id: &str, stream_id: &str) {
        {
            let mut cache = self.cache.write().await;
            cache.insert(peer_id.to_string(), stream_id.to_string());
        }

        {
            let mut dirty = self.dirty.write().await;
            dirty.insert(peer_id.to_string());
        }

        debug!(peer_id = %peer_id, stream_id = %stream_id, "Cursor updated (pending flush)");
    }

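    /// Persists every dirty cursor to SQLite, retrying transient busy errors.
    /// Peers that fail to flush are re-marked dirty for the next attempt.
    /// Returns the number of cursors written, or an error if any flush failed.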
    pub async fn flush_dirty(&self) -> Result<usize> {
        let dirty_peers: Vec<String> = {
            let mut dirty = self.dirty.write().await;
            let peers: Vec<String> = dirty.drain().collect();
            peers
        };

        if dirty_peers.is_empty() {
            return Ok(0);
        }

        let now = chrono::Utc::now().timestamp_millis();
        let cache = self.cache.read().await;
        let pool = &self.pool;

        let mut flushed = 0;
        let mut errors = 0;

        for peer_id in &dirty_peers {
            if let Some(stream_id) = cache.get(peer_id) {
                let peer_id_owned = peer_id.clone();
                let stream_id_owned = stream_id.clone();

                let result = execute_with_retry("cursor_flush", || async {
                    sqlx::query(
                        r#"
                        INSERT INTO cursors (peer_id, stream_id, updated_at)
                        VALUES (?, ?, ?)
                        ON CONFLICT(peer_id) DO UPDATE SET
                            stream_id = excluded.stream_id,
                            updated_at = excluded.updated_at
                        "#,
                    )
                    .bind(&peer_id_owned)
                    .bind(&stream_id_owned)
                    .bind(now)
                    .execute(pool)
                    .await
                })
                .await;

                match result {
                    Ok(_) => {
                        flushed += 1;
                    }
                    Err(e) => {
                        errors += 1;
                        warn!(peer_id = %peer_id, error = %e, "Failed to flush cursor");
                        self.dirty.write().await.insert(peer_id.clone());
                    }
                }
            }
        }

        if flushed > 0 {
            debug!(flushed, errors, "Flushed dirty cursors");
            crate::metrics::record_cursor_flush(flushed, errors);
        }

        if errors > 0 {
            return Err(ReplicationError::Internal(format!(
                "Failed to flush {} cursors",
                errors
            )));
        }

        Ok(flushed)
    }

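    /// Returns `true` if any cursor updates are waiting to be flushed.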
    pub async fn has_dirty(&self) -> bool {
        !self.dirty.read().await.is_empty()
    }

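    /// Returns the number of peers with unflushed cursor updates.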
    pub async fn dirty_count(&self) -> usize {
        self.dirty.read().await.len()
    }

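    /// Removes the cursor for `peer_id` from both the cache and the database.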
    pub async fn delete(&self, peer_id: &str) -> Result<()> {
        {
            let mut cache = self.cache.write().await;
            cache.remove(peer_id);
        }

        let pool = &self.pool;
        let peer_id_owned = peer_id.to_string();

        execute_with_retry("cursor_delete", || async {
            sqlx::query("DELETE FROM cursors WHERE peer_id = ?")
                .bind(&peer_id_owned)
                .execute(pool)
                .await
        })
        .await?;

        info!(peer_id = %peer_id, "Deleted cursor");
        Ok(())
    }

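    /// Returns a snapshot of all cached cursors, keyed by peer ID.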
    pub async fn get_all(&self) -> HashMap<String, String> {
        self.cache.read().await.clone()
    }

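    /// Returns the filesystem path of the underlying SQLite database.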
    pub fn path(&self) -> &str {
        &self.path
    }

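    /// Forces a WAL checkpoint (`TRUNCATE`) so the write-ahead log is folded
    /// back into the main database file.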
    pub async fn checkpoint(&self) -> Result<()> {
        let pool = &self.pool;

        execute_with_retry("cursor_checkpoint", || async {
            sqlx::query("PRAGMA wal_checkpoint(TRUNCATE)")
                .execute(pool)
                .await
        })
        .await?;

        debug!("WAL checkpoint complete");
        Ok(())
    }

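    /// Flushes any pending cursor updates, checkpoints the WAL, and closes the
    /// connection pool.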
    pub async fn close(&self) {
        if self.has_dirty().await {
            match self.flush_dirty().await {
                Ok(count) => {
                    if count > 0 {
                        info!(count, "Flushed dirty cursors on close");
                    }
                }
                Err(e) => {
                    warn!(error = %e, "Failed to flush dirty cursors on close");
                }
            }
        }

        if let Err(e) = self.checkpoint().await {
            warn!(error = %e, "Failed to checkpoint WAL on close");
        }
        self.pool.close().await;
        info!("Cursor store closed");
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[tokio::test]
    async fn test_cursor_store_basic() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_cursors.db");

        let store = CursorStore::new(&db_path).await.unwrap();

        assert!(store.get("peer1").await.is_none());
        assert_eq!(store.get_or_start("peer1").await, "0");

        store.set("peer1", "1234567890123-0").await;
        assert_eq!(store.get("peer1").await, Some("1234567890123-0".to_string()));
        assert!(store.has_dirty().await);

        store.set("peer1", "1234567890124-0").await;
        assert_eq!(store.get("peer1").await, Some("1234567890124-0".to_string()));

        let flushed = store.flush_dirty().await.unwrap();
        assert_eq!(flushed, 1);
        assert!(!store.has_dirty().await);

        store.close().await;
    }

    #[tokio::test]
    async fn test_cursor_store_persistence() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_persist.db");

        {
            let store = CursorStore::new(&db_path).await.unwrap();
            store.set("peer1", "9999-0").await;
            store.flush_dirty().await.unwrap();
            store.close().await;
        }

        {
            let store = CursorStore::new(&db_path).await.unwrap();
            assert_eq!(store.get("peer1").await, Some("9999-0".to_string()));
            store.close().await;
        }
    }

    #[tokio::test]
    async fn test_cursor_store_delete() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_delete.db");

        let store = CursorStore::new(&db_path).await.unwrap();
        store.set("peer1", "1234-0").await;
        store.set("peer2", "5678-0").await;
        store.flush_dirty().await.unwrap();

        store.delete("peer1").await.unwrap();

        assert!(store.get("peer1").await.is_none());
        assert_eq!(store.get("peer2").await, Some("5678-0".to_string()));

        store.close().await;
    }

    #[tokio::test]
    async fn test_cursor_debounce_multiple_updates() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_debounce.db");

        let store = CursorStore::new(&db_path).await.unwrap();

        store.set("peer1", "100-0").await;
        store.set("peer1", "200-0").await;
        store.set("peer1", "300-0").await;

        assert_eq!(store.dirty_count().await, 1);

        assert_eq!(store.get("peer1").await, Some("300-0".to_string()));

        let flushed = store.flush_dirty().await.unwrap();
        assert_eq!(flushed, 1);

        store.close().await;
    }

    #[tokio::test]
    async fn test_execute_with_retry_succeeds_immediately() {
        let mut attempt_count = 0;

        let result: std::result::Result<i32, sqlx::Error> =
            execute_with_retry("test_op", || {
                attempt_count += 1;
                async { Ok(42) }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempt_count, 1);
    }

    #[tokio::test]
    async fn test_execute_with_retry_fails_on_non_busy_error() {
        let mut attempt_count = 0;

        let result: std::result::Result<i32, sqlx::Error> =
            execute_with_retry("test_op", || {
                attempt_count += 1;
                async { Err(sqlx::Error::RowNotFound) }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempt_count, 1);
    }

    #[tokio::test]
    async fn test_cursor_store_get_all() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_get_all.db");

        let store = CursorStore::new(&db_path).await.unwrap();

        store.set("peer1", "100-0").await;
        store.set("peer2", "200-0").await;
        store.set("peer3", "300-0").await;

        let all = store.get_all().await;
        assert_eq!(all.len(), 3);
        assert_eq!(all.get("peer1"), Some(&"100-0".to_string()));
        assert_eq!(all.get("peer2"), Some(&"200-0".to_string()));
        assert_eq!(all.get("peer3"), Some(&"300-0".to_string()));

        store.close().await;
    }

    #[tokio::test]
    async fn test_cursor_store_path() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_path.db");

        let store = CursorStore::new(&db_path).await.unwrap();
        assert!(store.path().contains("test_path.db"));

        store.close().await;
    }

    #[tokio::test]
    async fn test_cursor_store_checkpoint() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_checkpoint.db");

        let store = CursorStore::new(&db_path).await.unwrap();
        store.set("peer1", "100-0").await;
        store.flush_dirty().await.unwrap();

        let result = store.checkpoint().await;
        assert!(result.is_ok());

        store.close().await;
    }

    #[tokio::test]
    async fn test_cursor_store_dirty_count() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_dirty_count.db");

        let store = CursorStore::new(&db_path).await.unwrap();

        assert_eq!(store.dirty_count().await, 0);
        assert!(!store.has_dirty().await);

        store.set("peer1", "100-0").await;
        assert_eq!(store.dirty_count().await, 1);
        assert!(store.has_dirty().await);

        store.set("peer2", "200-0").await;
        assert_eq!(store.dirty_count().await, 2);

        store.set("peer1", "150-0").await;
        assert_eq!(store.dirty_count().await, 2);

        store.flush_dirty().await.unwrap();
        assert_eq!(store.dirty_count().await, 0);
        assert!(!store.has_dirty().await);

        store.close().await;
    }

    #[tokio::test]
    async fn test_cursor_store_close_flushes_dirty() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_close_flush.db");

        {
            let store = CursorStore::new(&db_path).await.unwrap();
            store.set("peer1", "999-0").await;
            store.close().await;
        }

        {
            let store = CursorStore::new(&db_path).await.unwrap();
            assert_eq!(store.get("peer1").await, Some("999-0".to_string()));
            store.close().await;
        }
    }

    #[tokio::test]
    async fn test_cursor_store_get_or_start() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_get_or_start.db");

        let store = CursorStore::new(&db_path).await.unwrap();

        assert_eq!(store.get_or_start("new_peer").await, "0");

        store.set("new_peer", "123-0").await;
        assert_eq!(store.get_or_start("new_peer").await, "123-0");

        store.close().await;
    }

    #[tokio::test]
    async fn test_cursor_store_multiple_peers() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_multi_peer.db");

        let store = CursorStore::new(&db_path).await.unwrap();

        for i in 0..10 {
            store.set(&format!("peer{}", i), &format!("{}-0", i * 100)).await;
        }

        assert_eq!(store.dirty_count().await, 10);

        let flushed = store.flush_dirty().await.unwrap();
        assert_eq!(flushed, 10);

        for i in 0..10 {
            let expected = format!("{}-0", i * 100);
            assert_eq!(store.get(&format!("peer{}", i)).await, Some(expected));
        }

        store.close().await;
    }

    #[tokio::test]
    async fn test_cursor_store_delete_nonexistent() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_delete_nonexistent.db");

        let store = CursorStore::new(&db_path).await.unwrap();

        let result = store.delete("nonexistent").await;
        assert!(result.is_ok());

        store.close().await;
    }

    #[test]
    fn test_is_sqlite_busy_error_row_not_found() {
        let error = sqlx::Error::RowNotFound;
        assert!(!is_sqlite_busy_error(&error));
    }

    #[test]
    fn test_is_sqlite_busy_error_pool_timed_out() {
        let error = sqlx::Error::PoolTimedOut;
        assert!(!is_sqlite_busy_error(&error));
    }
}