1#[cfg(feature = "automerge-backend")]
81use anyhow::{Context, Result};
82#[cfg(feature = "automerge-backend")]
83use iroh::endpoint::{RecvStream, SendStream};
84#[cfg(feature = "automerge-backend")]
85use iroh::EndpointId;
86#[cfg(feature = "automerge-backend")]
87use std::collections::HashMap;
88#[cfg(feature = "automerge-backend")]
89use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
90#[cfg(feature = "automerge-backend")]
91use std::sync::{Arc, RwLock};
92#[cfg(feature = "automerge-backend")]
93use std::time::{Duration, Instant};
94#[cfg(feature = "automerge-backend")]
95use tokio::sync::Mutex;
96#[cfg(feature = "automerge-backend")]
97use tokio::task::JoinHandle;
98
99#[cfg(feature = "automerge-backend")]
100use super::automerge_sync::{AutomergeSyncCoordinator, SyncBatch, SyncMessageType};
101#[cfg(feature = "automerge-backend")]
102use super::sync_transport::SyncTransport;
103
/// Lifecycle state of a [`SyncChannel`].
#[cfg(feature = "automerge-backend")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelState {
    /// Bidirectional stream is open; sends go straight out.
    Connected,
    /// The receiver task ended (error, timeout, or peer close); the next
    /// `send()` will attempt a reconnect before writing.
    Reconnecting,
    /// Closed explicitly via `close()` or after exhausting reconnect
    /// attempts; the channel cannot be used again.
    Closed,
}
115
/// A persistent, reconnecting sync channel to a single peer.
///
/// Owns one half of an iroh bidirectional stream for sending batches and a
/// background task that drains the receive half into the sync coordinator.
#[cfg(feature = "automerge-backend")]
pub struct SyncChannel {
    // Remote peer this channel talks to.
    peer_id: EndpointId,
    // Transport used for the initial connect and for reconnects.
    transport: Arc<dyn SyncTransport>,
    // Send half of the stream; `None` after `close()` takes it.
    send: Arc<Mutex<Option<SendStream>>>,
    // Handle of the background receive loop, aborted on `close()`.
    recv_task: Arc<Mutex<Option<JoinHandle<()>>>>,
    // Current lifecycle state (see `ChannelState`).
    state: Arc<RwLock<ChannelState>>,
    // Consecutive reconnect attempts; reset to 0 on success.
    reconnect_attempts: AtomicU32,
    // Timestamp of the last successful send (stats/diagnostics).
    last_send: Arc<RwLock<Instant>>,
    // Total bytes written on this channel (including framing).
    bytes_sent: AtomicU64,
    // Total batches written on this channel.
    batches_sent: AtomicU64,
}
141
#[cfg(feature = "automerge-backend")]
impl SyncChannel {
    /// Give up and mark the channel `Closed` after this many consecutive
    /// failed reconnect attempts.
    const MAX_RECONNECT_ATTEMPTS: u32 = 3;
    /// Fixed delay before each reconnect attempt.
    const RECONNECT_DELAY: Duration = Duration::from_millis(500);
    /// Per-read timeout in the receive loop; a silent peer trips this.
    const RECV_TIMEOUT: Duration = Duration::from_secs(30);

    /// Opens a bidirectional stream to `peer_id` and spawns the receive loop.
    ///
    /// Convenience wrapper over [`Self::connect_with_token`] without a
    /// cancellation token.
    pub async fn connect(
        transport: Arc<dyn SyncTransport>,
        peer_id: EndpointId,
        coordinator: Arc<AutomergeSyncCoordinator>,
    ) -> Result<Self> {
        Self::connect_with_token(transport, peer_id, coordinator, None).await
    }

    /// Connects to `peer_id`, opens a bidirectional stream, and spawns the
    /// background receive loop (optionally cancellable via `cancel`).
    ///
    /// # Errors
    /// Fails if the transport cannot connect or the stream cannot be opened.
    pub async fn connect_with_token(
        transport: Arc<dyn SyncTransport>,
        peer_id: EndpointId,
        coordinator: Arc<AutomergeSyncCoordinator>,
        cancel: Option<tokio_util::sync::CancellationToken>,
    ) -> Result<Self> {
        let conn = transport.get_or_connect(&peer_id).await?;

        let (send, recv) = conn
            .open_bi()
            .await
            .context("Failed to open bidirectional stream")?;

        let channel = Self {
            peer_id,
            transport,
            send: Arc::new(Mutex::new(Some(send))),
            recv_task: Arc::new(Mutex::new(None)),
            state: Arc::new(RwLock::new(ChannelState::Connected)),
            reconnect_attempts: AtomicU32::new(0),
            last_send: Arc::new(RwLock::new(Instant::now())),
            bytes_sent: AtomicU64::new(0),
            batches_sent: AtomicU64::new(0),
        };

        channel.spawn_receiver(recv, coordinator, cancel);

        tracing::debug!("Sync channel connected to peer {:?}", peer_id);
        Ok(channel)
    }

    /// Spawns the receive loop for `recv` and records its `JoinHandle` so
    /// `close()` can abort it. When the loop ends (error or peer close) the
    /// channel state flips to `Reconnecting`.
    fn spawn_receiver(
        &self,
        recv: RecvStream,
        coordinator: Arc<AutomergeSyncCoordinator>,
        cancel: Option<tokio_util::sync::CancellationToken>,
    ) {
        let peer_id = self.peer_id;
        let state = Arc::clone(&self.state);
        let recv_task = Arc::clone(&self.recv_task);

        let task = tokio::spawn(async move {
            tracing::debug!("Sync channel receiver started for peer {:?}", peer_id);

            if let Err(e) = Self::receive_loop(recv, peer_id, coordinator, cancel).await {
                tracing::warn!(
                    "Sync channel receiver for peer {:?} ended with error: {}",
                    peer_id,
                    e
                );
            }

            // Poisoned-lock recovery: take the inner guard and proceed.
            *state.write().unwrap_or_else(|e| e.into_inner()) = ChannelState::Reconnecting;
            tracing::debug!("Sync channel receiver ended for peer {:?}", peer_id);
        });

        // This fn is synchronous but `recv_task` is a tokio Mutex, so the
        // handle is stored from a second spawned task.
        // NOTE(review): if `close()` runs before this task stores the handle,
        // the receiver is never aborted — confirm whether that race matters.
        tokio::spawn(async move {
            *recv_task.lock().await = Some(task);
        });
    }

    /// Reads sync batches off `recv` until cancellation, timeout, or stream
    /// error, handing each decoded batch to the coordinator.
    ///
    /// Expected wire format per message: 1-byte marker (`SyncBatch`),
    /// 4-byte big-endian payload length, then the encoded batch.
    ///
    /// NOTE(review): `send()` on this type additionally writes a
    /// `[u16 len][doc_key]` prefix before the marker, which this loop does
    /// not read — presumably the accepting side consumes that prefix to
    /// route the stream before the loop starts. TODO confirm against the
    /// accept path.
    async fn receive_loop(
        mut recv: RecvStream,
        peer_id: EndpointId,
        coordinator: Arc<AutomergeSyncCoordinator>,
        cancel: Option<tokio_util::sync::CancellationToken>,
    ) -> Result<()> {
        loop {
            // Cheap pre-check so we exit promptly between messages.
            if let Some(ref token) = cancel {
                if token.is_cancelled() {
                    tracing::debug!("Sync channel receive loop for peer {:?} cancelled", peer_id);
                    return Ok(());
                }
            }

            // Read the 1-byte message-type marker, racing cancellation
            // against a timed read when a token is present.
            let mut marker = [0u8; 1];
            let read_result = if let Some(ref token) = cancel {
                tokio::select! {
                    res = tokio::time::timeout(Self::RECV_TIMEOUT, recv.read_exact(&mut marker)) => res,
                    () = token.cancelled() => {
                        tracing::debug!(
                            "Sync channel receive loop for peer {:?} cancelled during read",
                            peer_id
                        );
                        return Ok(());
                    }
                }
            } else {
                tokio::time::timeout(Self::RECV_TIMEOUT, recv.read_exact(&mut marker)).await
            };

            match read_result {
                Ok(Ok(_)) => {}
                Ok(Err(e)) => {
                    return Err(anyhow::anyhow!("Stream read error: {}", e));
                }
                // Outer `Err` is the timeout elapsing.
                Err(_) => {
                    tracing::warn!(
                        "Sync channel receive timeout for peer {:?} (no data for {:?})",
                        peer_id,
                        Self::RECV_TIMEOUT,
                    );
                    return Err(anyhow::anyhow!(
                        "Receive timeout waiting for message marker"
                    ));
                }
            }

            // Unknown message types are skipped, not fatal.
            // NOTE(review): after an unexpected marker the stream position is
            // mid-message, so subsequent reads may misparse — confirm peers
            // never interleave other message types on this stream.
            if marker[0] != SyncMessageType::SyncBatch as u8 {
                tracing::warn!(
                    "Unexpected message type on sync channel: 0x{:02x}",
                    marker[0]
                );
                continue;
            }

            // 4-byte big-endian batch length.
            let mut len_bytes = [0u8; 4];
            tokio::time::timeout(Self::RECV_TIMEOUT, recv.read_exact(&mut len_bytes))
                .await
                .map_err(|_| {
                    tracing::warn!(
                        "Sync channel receive timeout for peer {:?} reading batch length",
                        peer_id,
                    );
                    anyhow::anyhow!("Receive timeout reading batch length")
                })?
                .context("Failed to read batch length")?;
            let batch_len = u32::from_be_bytes(len_bytes) as usize;

            // NOTE(review): `batch_len` is peer-controlled and allocated
            // unchecked (up to 4 GiB) — consider a sanity cap.
            let mut batch_data = vec![0u8; batch_len];
            tokio::time::timeout(Self::RECV_TIMEOUT, recv.read_exact(&mut batch_data))
                .await
                .map_err(|_| {
                    tracing::warn!(
                        "Sync channel receive timeout for peer {:?} reading batch data ({} bytes)",
                        peer_id,
                        batch_len,
                    );
                    anyhow::anyhow!("Receive timeout reading batch data")
                })?
                .context("Failed to read batch data")?;

            // Decode/processing failures are logged and skipped so one bad
            // batch does not kill the channel.
            match SyncBatch::decode(&batch_data) {
                Ok(batch) => {
                    // marker (1) + length prefix (4) + payload.
                    let total_bytes = 1 + 4 + batch_len;
                    if let Err(e) = coordinator
                        .receive_batch_message(peer_id, batch, total_bytes)
                        .await
                    {
                        tracing::warn!("Failed to process batch from peer {:?}: {}", peer_id, e);
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to decode batch from peer {:?}: {}", peer_id, e);
                }
            }
        }
    }

    /// Sends one encoded batch, reconnecting first if the channel is in the
    /// `Reconnecting` state.
    ///
    /// Wire format written: `[u16 doc_key len][doc_key][marker][u32 batch
    /// len][batch bytes]`, all lengths big-endian.
    ///
    /// # Errors
    /// Fails if the channel is closed, reconnection fails, the send stream
    /// is gone, or any write fails.
    pub async fn send(&self, batch: &SyncBatch) -> Result<()> {
        // Snapshot state under the read lock, then drop it before awaiting.
        let needs_reconnect = {
            let state = *self.state.read().unwrap_or_else(|e| e.into_inner());
            match state {
                ChannelState::Closed => return Err(anyhow::anyhow!("Channel is closed")),
                ChannelState::Reconnecting => true,
                ChannelState::Connected => false,
            }
        };

        if needs_reconnect {
            self.reconnect().await?;
        }

        let batch_bytes = batch.encode();
        let mut send_guard = self.send.lock().await;

        let send = send_guard
            .as_mut()
            .ok_or_else(|| anyhow::anyhow!("No send stream available"))?;

        // Fixed routing key; see the framing note on `receive_loop`.
        let doc_key = b"batch";
        send.write_all(&(doc_key.len() as u16).to_be_bytes())
            .await
            .context("Failed to write doc_key length")?;

        send.write_all(doc_key)
            .await
            .context("Failed to write doc_key")?;

        send.write_all(&[SyncMessageType::SyncBatch as u8])
            .await
            .context("Failed to write batch marker")?;

        let batch_len = batch_bytes.len() as u32;
        send.write_all(&batch_len.to_be_bytes())
            .await
            .context("Failed to write batch length")?;

        send.write_all(&batch_bytes)
            .await
            .context("Failed to write batch data")?;

        // Stats include all framing bytes, not just the payload.
        let total_bytes = 2 + doc_key.len() + 1 + 4 + batch_bytes.len();
        self.bytes_sent
            .fetch_add(total_bytes as u64, Ordering::Relaxed);
        self.batches_sent.fetch_add(1, Ordering::Relaxed);
        *self.last_send.write().unwrap_or_else(|e| e.into_inner()) = Instant::now();

        tracing::trace!(
            "Sent batch {} ({} entries, {} bytes) to peer {:?}",
            batch.batch_id,
            batch.len(),
            total_bytes,
            self.peer_id
        );

        Ok(())
    }

    /// Re-establishes the send stream after a failure, up to
    /// `MAX_RECONNECT_ATTEMPTS` consecutive tries.
    ///
    /// NOTE(review): the fresh recv stream is immediately stopped and no new
    /// receive loop is spawned, so after a reconnect this channel is
    /// effectively send-only — confirm this is intentional.
    ///
    /// # Errors
    /// Fails when attempts are exhausted (channel becomes `Closed`) or the
    /// transport/stream setup fails.
    pub async fn reconnect(&self) -> Result<()> {
        *self.state.write().unwrap_or_else(|e| e.into_inner()) = ChannelState::Reconnecting;

        let attempts = self.reconnect_attempts.fetch_add(1, Ordering::Relaxed);
        if attempts >= Self::MAX_RECONNECT_ATTEMPTS {
            *self.state.write().unwrap_or_else(|e| e.into_inner()) = ChannelState::Closed;
            return Err(anyhow::anyhow!(
                "Max reconnection attempts ({}) exceeded",
                Self::MAX_RECONNECT_ATTEMPTS
            ));
        }

        tracing::info!(
            "Attempting reconnection to peer {:?} (attempt {})",
            self.peer_id,
            attempts + 1
        );

        tokio::time::sleep(Self::RECONNECT_DELAY).await;

        let conn = self.transport.get_or_connect(&self.peer_id).await?;

        let (send, mut recv) = conn
            .open_bi()
            .await
            .context("Failed to open bidirectional stream for reconnection")?;

        // Signal the peer we will not read from this stream.
        let _ = recv.stop(0u32.into());

        *self.send.lock().await = Some(send);
        *self.state.write().unwrap_or_else(|e| e.into_inner()) = ChannelState::Connected;
        self.reconnect_attempts.store(0, Ordering::Relaxed);

        tracing::info!("Reconnected sync channel to peer {:?}", self.peer_id);
        Ok(())
    }

    /// True while the channel is in the `Connected` state.
    pub fn is_connected(&self) -> bool {
        *self.state.read().unwrap_or_else(|e| e.into_inner()) == ChannelState::Connected
    }

    /// Current lifecycle state.
    pub fn state(&self) -> ChannelState {
        *self.state.read().unwrap_or_else(|e| e.into_inner())
    }

    /// Remote peer this channel is bound to.
    pub fn peer_id(&self) -> EndpointId {
        self.peer_id
    }

    /// Total bytes sent on this channel, including framing.
    pub fn bytes_sent(&self) -> u64 {
        self.bytes_sent.load(Ordering::Relaxed)
    }

    /// Total batches sent on this channel.
    pub fn batches_sent(&self) -> u64 {
        self.batches_sent.load(Ordering::Relaxed)
    }

    /// Marks the channel `Closed`, aborts the receive task, and finishes the
    /// send stream. Safe to call more than once.
    pub async fn close(&self) {
        *self.state.write().unwrap_or_else(|e| e.into_inner()) = ChannelState::Closed;

        if let Some(task) = self.recv_task.lock().await.take() {
            task.abort();
        }

        if let Some(mut send) = self.send.lock().await.take() {
            // Best-effort graceful finish; errors are irrelevant at close.
            let _ = send.finish();
        }

        tracing::debug!("Sync channel to peer {:?} closed", self.peer_id);
    }
}
506
/// Caches and multiplexes [`SyncChannel`]s, one per peer.
#[cfg(feature = "automerge-backend")]
pub struct SyncChannelManager {
    // Per-peer channel cache; disconnected entries are replaced on demand.
    channels: Arc<RwLock<HashMap<EndpointId, Arc<SyncChannel>>>>,
    // Transport handed to every channel for connect/reconnect.
    transport: Arc<dyn SyncTransport>,
    // Coordinator that receives inbound batches from every channel.
    coordinator: Arc<AutomergeSyncCoordinator>,
    // Set to `false` by `shutdown()`.
    active: Arc<std::sync::atomic::AtomicBool>,
}
521
#[cfg(feature = "automerge-backend")]
impl SyncChannelManager {
    /// Creates an empty, active manager.
    pub fn new(
        transport: Arc<dyn SyncTransport>,
        coordinator: Arc<AutomergeSyncCoordinator>,
    ) -> Self {
        Self {
            channels: Arc::new(RwLock::new(HashMap::new())),
            transport,
            coordinator,
            active: Arc::new(std::sync::atomic::AtomicBool::new(true)),
        }
    }

    /// Returns the cached channel for `peer_id` if still connected,
    /// otherwise connects a fresh one and caches it.
    ///
    /// # Errors
    /// Fails if the manager has been shut down or the connect fails.
    pub async fn get_channel(&self, peer_id: EndpointId) -> Result<Arc<SyncChannel>> {
        // Fix: `active` was set by shutdown() but never read anywhere, so
        // channels could silently be re-created after shutdown. Refuse now.
        if !self.active.load(Ordering::Relaxed) {
            return Err(anyhow::anyhow!("SyncChannelManager is shut down"));
        }

        // Fast path: reuse an existing, still-connected channel. The read
        // guard is dropped before any await.
        {
            let channels = self.channels.read().unwrap_or_else(|e| e.into_inner());
            if let Some(channel) = channels.get(&peer_id) {
                if channel.is_connected() {
                    return Ok(Arc::clone(channel));
                }
            }
        }

        let channel = SyncChannel::connect(
            Arc::clone(&self.transport),
            peer_id,
            Arc::clone(&self.coordinator),
        )
        .await?;

        let channel = Arc::new(channel);
        // Fix: recover from lock poisoning like every other lock use in this
        // module; this was previously a bare `unwrap()` that could panic.
        self.channels
            .write()
            .unwrap_or_else(|e| e.into_inner())
            .insert(peer_id, Arc::clone(&channel));

        Ok(channel)
    }

    /// Sends `batch` to `peer_id`, establishing a channel if needed.
    pub async fn send_to_peer(&self, peer_id: EndpointId, batch: &SyncBatch) -> Result<()> {
        let channel = self.get_channel(peer_id).await?;
        channel.send(batch).await
    }

    /// Best-effort broadcast of `batch` to every currently connected peer.
    /// Per-peer failures are logged, not propagated; always returns `Ok`.
    pub async fn broadcast(&self, batch: &SyncBatch) -> Result<()> {
        let peer_ids = self.transport.connected_peers();

        for peer_id in peer_ids {
            if let Err(e) = self.send_to_peer(peer_id, batch).await {
                tracing::warn!("Failed to send batch to peer {:?}: {}", peer_id, e);
            }
        }

        Ok(())
    }

    /// Wraps an automerge sync `message` for `doc_key` in a single-entry
    /// batch and sends it to `peer_id`, returning an estimated byte count.
    pub async fn send_delta_sync(
        &self,
        peer_id: EndpointId,
        doc_key: &str,
        message: &automerge::sync::Message,
    ) -> Result<usize> {
        let encoded = message.clone().encode();
        let payload_len = encoded.len();

        let mut batch = SyncBatch::new();
        batch.entries.push(super::automerge_sync::SyncEntry::new(
            doc_key.to_string(),
            SyncMessageType::DeltaSync,
            encoded,
        ));

        // Encode once up front instead of re-encoding after the send.
        let batch_len = batch.encode().len();

        self.send_to_peer(peer_id, &batch).await?;

        // NOTE(review): adding `payload_len` on top of the encoded batch
        // presumably double-counts the payload (the batch encoding should
        // already contain it) — kept unchanged for caller compatibility;
        // TODO confirm against SyncBatch::encode.
        Ok(1 + 4 + batch_len + payload_len)
    }

    /// Sends a full document snapshot for `doc_key` to `peer_id` as a
    /// single-entry batch, returning an estimated byte count.
    pub async fn send_state_snapshot(
        &self,
        peer_id: EndpointId,
        doc_key: &str,
        state_bytes: Vec<u8>,
    ) -> Result<usize> {
        let payload_len = state_bytes.len();

        let mut batch = SyncBatch::new();
        batch.entries.push(super::automerge_sync::SyncEntry::new(
            doc_key.to_string(),
            SyncMessageType::StateSnapshot,
            state_bytes,
        ));

        // Encode once up front instead of re-encoding after the send.
        let batch_len = batch.encode().len();

        self.send_to_peer(peer_id, &batch).await?;

        // NOTE(review): same possible payload double-count as in
        // `send_delta_sync`; kept unchanged for caller compatibility.
        Ok(1 + 4 + batch_len + payload_len)
    }

    /// Sends a single tombstone message to `peer_id`, returning an
    /// estimated byte count (marker + length prefix + encoded batch).
    pub async fn send_tombstone(
        &self,
        peer_id: EndpointId,
        tombstone_msg: &crate::qos::TombstoneSyncMessage,
    ) -> Result<usize> {
        let mut batch = SyncBatch::new();
        batch.add_tombstone(tombstone_msg);

        let batch_bytes = batch.encode();
        self.send_to_peer(peer_id, &batch).await?;

        Ok(1 + 4 + batch_bytes.len())
    }

    /// Sends several tombstones to `peer_id` in one batch, returning an
    /// estimated byte count.
    pub async fn send_tombstone_batch(
        &self,
        peer_id: EndpointId,
        tombstones: &[crate::qos::TombstoneSyncMessage],
    ) -> Result<usize> {
        let mut batch = SyncBatch::new();
        for tombstone in tombstones {
            batch.add_tombstone(tombstone);
        }

        let batch_bytes = batch.encode();
        self.send_to_peer(peer_id, &batch).await?;

        Ok(1 + 4 + batch_bytes.len())
    }

    /// Broadcasts an automerge delta-sync message for `doc_key` to all
    /// connected peers (best effort; see [`Self::broadcast`]).
    pub async fn broadcast_delta_sync(
        &self,
        doc_key: &str,
        message: &automerge::sync::Message,
    ) -> Result<()> {
        let encoded = message.clone().encode();

        let mut batch = SyncBatch::new();
        batch.entries.push(super::automerge_sync::SyncEntry::new(
            doc_key.to_string(),
            SyncMessageType::DeltaSync,
            encoded,
        ));

        self.broadcast(&batch).await
    }

    /// Broadcasts a full state snapshot for `doc_key` to all connected
    /// peers (best effort).
    pub async fn broadcast_state_snapshot(
        &self,
        doc_key: &str,
        state_bytes: Vec<u8>,
    ) -> Result<()> {
        let mut batch = SyncBatch::new();
        batch.entries.push(super::automerge_sync::SyncEntry::new(
            doc_key.to_string(),
            SyncMessageType::StateSnapshot,
            state_bytes,
        ));

        self.broadcast(&batch).await
    }

    /// Broadcasts a single tombstone to all connected peers (best effort).
    pub async fn broadcast_tombstone(
        &self,
        tombstone_msg: &crate::qos::TombstoneSyncMessage,
    ) -> Result<()> {
        let mut batch = SyncBatch::new();
        batch.add_tombstone(tombstone_msg);

        self.broadcast(&batch).await
    }

    /// Removes and closes the channel for `peer_id`, if any. The lock is
    /// released before awaiting `close()`.
    pub async fn remove_channel(&self, peer_id: &EndpointId) {
        let channel = self
            .channels
            .write()
            .unwrap_or_else(|e| e.into_inner())
            .remove(peer_id);
        if let Some(channel) = channel {
            channel.close().await;
        }
    }

    /// Number of cached channels (connected or not).
    pub fn channel_count(&self) -> usize {
        self.channels
            .read()
            .unwrap_or_else(|e| e.into_inner())
            .len()
    }

    /// Aggregated send statistics across all cached channels.
    pub fn stats(&self) -> ChannelManagerStats {
        let channels = self.channels.read().unwrap_or_else(|e| e.into_inner());
        let mut total_bytes = 0u64;
        let mut total_batches = 0u64;
        let mut connected = 0usize;

        for channel in channels.values() {
            total_bytes += channel.bytes_sent();
            total_batches += channel.batches_sent();
            if channel.is_connected() {
                connected += 1;
            }
        }

        ChannelManagerStats {
            total_channels: channels.len(),
            connected_channels: connected,
            total_bytes_sent: total_bytes,
            total_batches_sent: total_batches,
        }
    }

    /// Marks the manager inactive and closes every cached channel. After
    /// this, `get_channel` refuses to create new channels.
    pub async fn shutdown(&self) {
        self.active.store(false, Ordering::Relaxed);

        // Drain under the lock, close outside it — a std lock guard must
        // not be held across an await.
        let channels: Vec<Arc<SyncChannel>> = {
            let mut channels = self.channels.write().unwrap_or_else(|e| e.into_inner());
            channels.drain().map(|(_, c)| c).collect()
        };

        for channel in channels {
            channel.close().await;
        }

        tracing::debug!("SyncChannelManager shutdown complete");
    }
}
768
/// Point-in-time aggregate of send statistics over all managed channels,
/// as produced by [`SyncChannelManager::stats`].
#[cfg(feature = "automerge-backend")]
#[derive(Debug, Clone, Default)]
pub struct ChannelManagerStats {
    /// Total cached channels, connected or not.
    pub total_channels: usize,
    /// Channels currently in the `Connected` state.
    pub connected_channels: usize,
    /// Sum of bytes sent across all channels (includes framing).
    pub total_bytes_sent: u64,
    /// Sum of batches sent across all channels.
    pub total_batches_sent: u64,
}
782
#[cfg(all(test, feature = "automerge-backend"))]
mod tests {
    use super::*;

    /// All three channel states must be pairwise distinct and self-equal.
    #[test]
    fn test_channel_state_enum() {
        let states = [
            ChannelState::Connected,
            ChannelState::Reconnecting,
            ChannelState::Closed,
        ];
        for (i, a) in states.iter().enumerate() {
            for (j, b) in states.iter().enumerate() {
                assert_eq!(i == j, a == b, "{a:?} vs {b:?}");
            }
        }
    }

    /// `Default` must produce all-zero statistics.
    #[test]
    fn test_channel_manager_stats_default() {
        let ChannelManagerStats {
            total_channels,
            connected_channels,
            total_bytes_sent,
            total_batches_sent,
        } = ChannelManagerStats::default();
        assert_eq!(
            (total_channels, connected_channels, total_bytes_sent, total_batches_sent),
            (0, 0, 0, 0)
        );
    }
}