1use std::sync::{Arc, RwLock};
12
13use tokio::sync::Notify;
14use tokio::task::JoinHandle;
15use tracing::{debug, error, info, instrument, warn};
16
17use crate::db::PulseDB;
18
19use super::applier::RemoteChangeApplier;
20use super::config::{SyncConfig, SyncDirection};
21use super::error::SyncError;
22use super::progress::SyncProgressCallback;
23use super::pusher::LocalChangePusher;
24use super::transport::SyncTransport;
25use super::types::{HandshakeRequest, InstanceId, PullRequest, SyncCursor, SyncStatus};
26use super::SYNC_PROTOCOL_VERSION;
27
28pub struct SyncManager {
51 db: Arc<PulseDB>,
52 transport: Arc<dyn SyncTransport>,
53 config: SyncConfig,
54 local_instance_id: InstanceId,
55 peer_instance_id: Option<InstanceId>,
56 status: Arc<RwLock<SyncStatus>>,
57 shutdown: Arc<Notify>,
58 task_handle: Option<JoinHandle<()>>,
59}
60
61impl SyncManager {
62 pub fn new(db: Arc<PulseDB>, transport: Box<dyn SyncTransport>, config: SyncConfig) -> Self {
67 let local_instance_id = db.storage_for_test().instance_id();
68 Self {
69 db,
70 transport: Arc::from(transport),
71 config,
72 local_instance_id,
73 peer_instance_id: None,
74 status: Arc::new(RwLock::new(SyncStatus::Idle)),
75 shutdown: Arc::new(Notify::new()),
76 task_handle: None,
77 }
78 }
79
80 #[instrument(skip(self), fields(instance_id = %self.local_instance_id))]
85 pub async fn start(&mut self) -> Result<(), SyncError> {
86 if self.task_handle.is_some() {
87 return Err(SyncError::transport("SyncManager already started"));
88 }
89
90 let peer_id = self.perform_handshake().await?;
92 self.peer_instance_id = Some(peer_id);
93
94 self.set_status(SyncStatus::Syncing);
95
96 let db = Arc::clone(&self.db);
98 let transport = Arc::clone(&self.transport);
99 let config = self.config.clone();
100 let local_id = self.local_instance_id;
101 let status = Arc::clone(&self.status);
102 let shutdown = Arc::clone(&self.shutdown);
103
104 let handle = tokio::spawn(async move {
105 Self::background_loop(db, transport, config, local_id, peer_id, status, shutdown).await;
106 });
107
108 self.task_handle = Some(handle);
109 info!("SyncManager started");
110 Ok(())
111 }
112
113 #[instrument(skip(self))]
115 pub async fn stop(&mut self) -> Result<(), SyncError> {
116 if let Some(handle) = self.task_handle.take() {
117 self.shutdown.notify_one();
118 handle
119 .await
120 .map_err(|e| SyncError::transport(format!("Background task panicked: {}", e)))?;
121 self.set_status(SyncStatus::Idle);
122 info!("SyncManager stopped");
123 }
124 Ok(())
125 }
126
127 #[instrument(skip(self))]
131 pub async fn sync_once(&mut self) -> Result<SyncStatus, SyncError> {
132 if self.peer_instance_id.is_none() {
134 let peer_id = self.perform_handshake().await?;
135 self.peer_instance_id = Some(peer_id);
136 }
137 let peer_id = self.peer_instance_id.unwrap();
138
139 self.set_status(SyncStatus::Syncing);
140
141 let push_seq = self.load_cursor_sequence(peer_id)?;
143 let mut pusher = LocalChangePusher::new(
144 Arc::clone(&self.db),
145 Arc::clone(&self.transport),
146 self.config.clone(),
147 self.local_instance_id,
148 peer_id,
149 push_seq,
150 );
151
152 let applier = RemoteChangeApplier::new(Arc::clone(&self.db), self.config.clone());
153
154 let pushed = if matches!(
156 self.config.direction,
157 SyncDirection::PushOnly | SyncDirection::Bidirectional
158 ) {
159 pusher.push_pending().await?
160 } else {
161 0
162 };
163
164 let pulled = if matches!(
166 self.config.direction,
167 SyncDirection::PullOnly | SyncDirection::Bidirectional
168 ) {
169 self.pull_and_apply(&applier, peer_id).await?
170 } else {
171 0
172 };
173
174 self.set_status(SyncStatus::Idle);
175
176 debug!(pushed, pulled, "sync_once complete");
177 Ok(SyncStatus::Idle)
178 }
179
180 #[instrument(skip(self, progress))]
184 pub async fn initial_sync(
185 &mut self,
186 progress: Option<Box<dyn SyncProgressCallback>>,
187 ) -> Result<(), SyncError> {
188 if self.peer_instance_id.is_none() {
190 let peer_id = self.perform_handshake().await?;
191 self.peer_instance_id = Some(peer_id);
192 }
193 let peer_id = self.peer_instance_id.unwrap();
194
195 self.set_status(SyncStatus::Syncing);
196
197 let applier = RemoteChangeApplier::new(Arc::clone(&self.db), self.config.clone());
198
199 let mut total_pulled = 0usize;
200 let mut cursor = SyncCursor {
201 instance_id: peer_id,
202 last_sequence: self.load_cursor_sequence(peer_id)?,
203 };
204
205 loop {
206 let pull_request = PullRequest {
207 cursor: cursor.clone(),
208 batch_size: self.config.batch_size,
209 collectives: self.config.collectives.clone(),
210 };
211
212 let response = self.transport.pull_changes(pull_request).await?;
213 let batch_size = response.changes.len();
214
215 if batch_size > 0 {
216 applier.apply_batch(response.changes)?;
217 }
218
219 total_pulled += batch_size;
220 cursor = response.new_cursor;
221
222 self.save_cursor(&cursor)?;
224
225 if let Some(ref cb) = progress {
226 cb.on_progress(batch_size, total_pulled, response.has_more);
227 }
228
229 if !response.has_more {
230 break;
231 }
232 }
233
234 self.set_status(SyncStatus::Idle);
235 info!(total_pulled, "Initial sync complete");
236 Ok(())
237 }
238
239 pub fn status(&self) -> SyncStatus {
241 self.status
242 .read()
243 .unwrap_or_else(|e| e.into_inner())
244 .clone()
245 }
246
247 #[instrument(skip(self))]
250 async fn perform_handshake(&self) -> Result<InstanceId, SyncError> {
251 let request = HandshakeRequest {
252 instance_id: self.local_instance_id,
253 protocol_version: SYNC_PROTOCOL_VERSION,
254 capabilities: vec!["push".into(), "pull".into()],
255 };
256
257 let response = self.transport.handshake(request).await?;
258
259 if !response.accepted {
260 return Err(SyncError::handshake(
261 response.reason.unwrap_or_else(|| "rejected".into()),
262 ));
263 }
264
265 if response.protocol_version != SYNC_PROTOCOL_VERSION {
266 return Err(SyncError::ProtocolVersion {
267 local: SYNC_PROTOCOL_VERSION,
268 remote: response.protocol_version,
269 });
270 }
271
272 debug!(peer = %response.instance_id, "Handshake accepted");
273 Ok(response.instance_id)
274 }
275
276 fn set_status(&self, status: SyncStatus) {
277 if let Ok(mut s) = self.status.write() {
278 *s = status;
279 }
280 }
281
282 fn load_cursor_sequence(&self, peer_id: InstanceId) -> Result<u64, SyncError> {
283 self.db
284 .storage_for_test()
285 .load_sync_cursor(&peer_id)
286 .map_err(|e| SyncError::transport(format!("Failed to load cursor: {}", e)))
287 .map(|opt| opt.map_or(0, |c| c.last_sequence))
288 }
289
290 fn save_cursor(&self, cursor: &SyncCursor) -> Result<(), SyncError> {
291 self.db
292 .storage_for_test()
293 .save_sync_cursor(cursor)
294 .map_err(|e| SyncError::transport(format!("Failed to save cursor: {}", e)))
295 }
296
297 async fn pull_and_apply(
299 &self,
300 applier: &RemoteChangeApplier,
301 peer_id: InstanceId,
302 ) -> Result<usize, SyncError> {
303 let cursor_seq = self.load_cursor_sequence(peer_id)?;
304 let pull_request = PullRequest {
305 cursor: SyncCursor {
306 instance_id: peer_id,
307 last_sequence: cursor_seq,
308 },
309 batch_size: self.config.batch_size,
310 collectives: self.config.collectives.clone(),
311 };
312
313 let response = self.transport.pull_changes(pull_request).await?;
314 let count = response.changes.len();
315
316 if count > 0 {
317 applier.apply_batch(response.changes)?;
318 self.save_cursor(&response.new_cursor)?;
319 }
320
321 Ok(count)
322 }
323
324 async fn background_loop(
326 db: Arc<PulseDB>,
327 transport: Arc<dyn SyncTransport>,
328 config: SyncConfig,
329 local_id: InstanceId,
330 peer_id: InstanceId,
331 status: Arc<RwLock<SyncStatus>>,
332 shutdown: Arc<Notify>,
333 ) {
334 let interval_ms = std::cmp::max(config.push_interval_ms, config.pull_interval_ms);
335 let interval = tokio::time::Duration::from_millis(interval_ms);
336
337 let mut consecutive_failures = 0u32;
338 let max_retries = config.retry.max_retries;
339 let initial_backoff = config.retry.initial_backoff_ms;
340 let max_backoff = config.retry.max_backoff_ms;
341 let multiplier = config.retry.backoff_multiplier;
342
343 loop {
344 let sleep_duration = if consecutive_failures > 0 {
345 let backoff = (initial_backoff as f64)
347 * multiplier.powi(consecutive_failures.saturating_sub(1) as i32);
348 let backoff_ms = (backoff as u64).min(max_backoff);
349 tokio::time::Duration::from_millis(backoff_ms)
350 } else {
351 interval
352 };
353
354 tokio::select! {
355 _ = shutdown.notified() => {
356 debug!("Sync background loop shutting down");
357 break;
358 }
359 _ = tokio::time::sleep(sleep_duration) => {
360 let push_seq = db
362 .storage_for_test()
363 .load_sync_cursor(&peer_id)
364 .unwrap_or(None)
365 .map_or(0, |c| c.last_sequence);
366
367 let mut pusher = LocalChangePusher::new(
368 Arc::clone(&db),
369 Arc::clone(&transport),
370 config.clone(),
371 local_id,
372 peer_id,
373 push_seq,
374 );
375 let applier = RemoteChangeApplier::new(Arc::clone(&db), config.clone());
376
377 let result = Self::run_sync_cycle(&mut pusher, &applier, &transport, &db, &config, peer_id).await;
378
379 match result {
380 Ok(_) => {
381 if consecutive_failures > 0 {
382 info!("Sync recovered after {} failures", consecutive_failures);
383 }
384 consecutive_failures = 0;
385 if let Ok(mut s) = status.write() {
386 *s = SyncStatus::Syncing;
387 }
388 }
389 Err(e) => {
390 consecutive_failures += 1;
391 if consecutive_failures > max_retries {
392 warn!(
393 failures = consecutive_failures,
394 "Sync errors exceed max_retries, continuing with backoff"
395 );
396 }
397 error!("Sync cycle failed: {}", e);
398 if let Ok(mut s) = status.write() {
399 *s = SyncStatus::Error(e.to_string());
400 }
401 }
402 }
403 }
404 }
405 }
406
407 if let Ok(mut s) = status.write() {
408 *s = SyncStatus::Idle;
409 }
410 }
411
412 async fn run_sync_cycle(
414 pusher: &mut LocalChangePusher,
415 applier: &RemoteChangeApplier,
416 transport: &Arc<dyn SyncTransport>,
417 db: &Arc<PulseDB>,
418 config: &SyncConfig,
419 peer_id: InstanceId,
420 ) -> Result<(), SyncError> {
421 if matches!(
423 config.direction,
424 SyncDirection::PushOnly | SyncDirection::Bidirectional
425 ) {
426 pusher.push_pending().await?;
427 }
428
429 if matches!(
431 config.direction,
432 SyncDirection::PullOnly | SyncDirection::Bidirectional
433 ) {
434 let cursor_seq = db
435 .storage_for_test()
436 .load_sync_cursor(&peer_id)
437 .map_err(|e| SyncError::transport(format!("cursor load: {}", e)))?
438 .map_or(0, |c| c.last_sequence);
439
440 let pull_request = PullRequest {
441 cursor: SyncCursor {
442 instance_id: peer_id,
443 last_sequence: cursor_seq,
444 },
445 batch_size: config.batch_size,
446 collectives: config.collectives.clone(),
447 };
448
449 let response = transport.pull_changes(pull_request).await?;
450 let count = response.changes.len();
451
452 if count > 0 {
453 applier.apply_batch(response.changes)?;
454 let cursor = response.new_cursor;
455 db.storage_for_test()
456 .save_sync_cursor(&cursor)
457 .map_err(|e| SyncError::transport(format!("cursor save: {}", e)))?;
458 }
459 }
460
461 Ok(())
462 }
463}