#[cfg(feature = "automerge-backend")]
use anyhow::{Context, Result};
#[cfg(feature = "automerge-backend")]
use iroh::endpoint::{RecvStream, SendStream};
#[cfg(feature = "automerge-backend")]
use iroh::EndpointId;
#[cfg(feature = "automerge-backend")]
use std::collections::HashMap;
#[cfg(feature = "automerge-backend")]
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
#[cfg(feature = "automerge-backend")]
use std::sync::{Arc, RwLock};
#[cfg(feature = "automerge-backend")]
use std::time::{Duration, Instant};
#[cfg(feature = "automerge-backend")]
use tokio::sync::Mutex;
#[cfg(feature = "automerge-backend")]
use tokio::task::JoinHandle;
#[cfg(feature = "automerge-backend")]
use super::automerge_sync::{AutomergeSyncCoordinator, SyncBatch, SyncMessageType};
#[cfg(feature = "automerge-backend")]
use super::sync_transport::SyncTransport;
#[cfg(feature = "automerge-backend")]
/// Lifecycle state of a [`SyncChannel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelState {
/// Stream is open; `send()` writes directly.
Connected,
/// The receiver ended or a send found the channel degraded; the next
/// `send()` triggers a reconnect attempt.
Reconnecting,
/// Permanently closed, either explicitly or after reconnect attempts
/// were exhausted; sends fail from this state.
Closed,
}
#[cfg(feature = "automerge-backend")]
/// A persistent bidirectional sync stream to a single peer.
///
/// Owns the send half of a bi-directional stream plus a background task that
/// drains the receive half and forwards decoded batches to the
/// `AutomergeSyncCoordinator`.
pub struct SyncChannel {
/// Remote peer this channel is bound to.
peer_id: EndpointId,
/// Transport used to (re)establish connections to the peer.
transport: Arc<dyn SyncTransport>,
/// Send half of the stream; `None` after `close()` takes it.
send: Arc<Mutex<Option<SendStream>>>,
/// Handle to the background receive loop, aborted on `close()`.
recv_task: Arc<Mutex<Option<JoinHandle<()>>>>,
/// Current lifecycle state (see [`ChannelState`]).
state: Arc<RwLock<ChannelState>>,
/// Consecutive reconnect attempts; reset to 0 on success.
reconnect_attempts: AtomicU32,
/// Timestamp of the most recent successful `send()`.
last_send: Arc<RwLock<Instant>>,
/// Total wire bytes written by `send()`, framing included.
bytes_sent: AtomicU64,
/// Total batches successfully written by `send()`.
batches_sent: AtomicU64,
}
#[cfg(feature = "automerge-backend")]
impl SyncChannel {
/// Maximum consecutive failed reconnect attempts before the channel is
/// permanently closed.
const MAX_RECONNECT_ATTEMPTS: u32 = 3;
/// Pause before each reconnect attempt.
const RECONNECT_DELAY: Duration = Duration::from_millis(500);
/// Idle timeout applied to every read in the receive loop.
const RECV_TIMEOUT: Duration = Duration::from_secs(30);
/// Opens a channel to `peer_id` without a cancellation token.
/// See [`Self::connect_with_token`].
pub async fn connect(
transport: Arc<dyn SyncTransport>,
peer_id: EndpointId,
coordinator: Arc<AutomergeSyncCoordinator>,
) -> Result<Self> {
Self::connect_with_token(transport, peer_id, coordinator, None).await
}
/// Opens a bidirectional stream to `peer_id` and spawns the background
/// receive loop that forwards decoded batches to `coordinator`.
///
/// `cancel`, when supplied, stops the receive loop early.
///
/// # Errors
/// Fails if the transport cannot connect or the stream cannot be opened.
pub async fn connect_with_token(
transport: Arc<dyn SyncTransport>,
peer_id: EndpointId,
coordinator: Arc<AutomergeSyncCoordinator>,
cancel: Option<tokio_util::sync::CancellationToken>,
) -> Result<Self> {
let conn = transport.get_or_connect(&peer_id).await?;
let (send, recv) = conn
.open_bi()
.await
.context("Failed to open bidirectional stream")?;
let channel = Self {
peer_id,
transport,
send: Arc::new(Mutex::new(Some(send))),
recv_task: Arc::new(Mutex::new(None)),
state: Arc::new(RwLock::new(ChannelState::Connected)),
reconnect_attempts: AtomicU32::new(0),
last_send: Arc::new(RwLock::new(Instant::now())),
bytes_sent: AtomicU64::new(0),
batches_sent: AtomicU64::new(0),
};
channel.spawn_receiver(recv, coordinator, cancel);
tracing::debug!("Sync channel connected to peer {:?}", peer_id);
Ok(channel)
}
/// Spawns the task that drains `recv`. When the loop exits (error,
/// cancellation, or timeout) the channel is marked `Reconnecting` so the
/// next `send()` re-establishes the stream.
fn spawn_receiver(
&self,
recv: RecvStream,
coordinator: Arc<AutomergeSyncCoordinator>,
cancel: Option<tokio_util::sync::CancellationToken>,
) {
let peer_id = self.peer_id;
let state = Arc::clone(&self.state);
let recv_task = Arc::clone(&self.recv_task);
let task = tokio::spawn(async move {
tracing::debug!("Sync channel receiver started for peer {:?}", peer_id);
if let Err(e) = Self::receive_loop(recv, peer_id, coordinator, cancel).await {
tracing::warn!(
"Sync channel receiver for peer {:?} ended with error: {}",
peer_id,
e
);
}
// Poison-tolerant write: continue with the inner value even if another
// thread panicked while holding the lock.
*state.write().unwrap_or_else(|e| e.into_inner()) = ChannelState::Reconnecting;
tracing::debug!("Sync channel receiver ended for peer {:?}", peer_id);
});
// The handle is stored via a second spawned task because this method is
// not async and the tokio Mutex cannot be locked synchronously.
// NOTE(review): a `close()` racing with this store may miss the handle
// and leave the receiver running until its stream errors out — confirm
// whether this window matters in practice.
tokio::spawn(async move {
*recv_task.lock().await = Some(task);
});
}
/// Reads framed sync batches from `recv` until cancellation, timeout, or
/// stream error, handing each decoded batch to the coordinator.
///
/// Wire format consumed per message: 1-byte marker (`SyncMessageType`),
/// 4-byte big-endian batch length, then the encoded batch bytes.
/// NOTE(review): `send()` additionally writes a length-prefixed `doc_key`
/// before the marker, which this loop never consumes — confirm which side
/// matches the peer's actual framing.
async fn receive_loop(
mut recv: RecvStream,
peer_id: EndpointId,
coordinator: Arc<AutomergeSyncCoordinator>,
cancel: Option<tokio_util::sync::CancellationToken>,
) -> Result<()> {
loop {
// Cheap pre-check so a cancelled token exits before blocking on a read.
if let Some(ref token) = cancel {
if token.is_cancelled() {
tracing::debug!("Sync channel receive loop for peer {:?} cancelled", peer_id);
return Ok(());
}
}
let mut marker = [0u8; 1];
// Race the marker read against cancellation (when a token exists);
// either way the read itself is bounded by RECV_TIMEOUT.
let read_result = if let Some(ref token) = cancel {
tokio::select! {
res = tokio::time::timeout(Self::RECV_TIMEOUT, recv.read_exact(&mut marker)) => res,
() = token.cancelled() => {
tracing::debug!(
"Sync channel receive loop for peer {:?} cancelled during read",
peer_id
);
return Ok(());
}
}
} else {
tokio::time::timeout(Self::RECV_TIMEOUT, recv.read_exact(&mut marker)).await
};
match read_result {
Ok(Ok(_)) => {}
Ok(Err(e)) => {
return Err(anyhow::anyhow!("Stream read error: {}", e));
}
Err(_) => {
// Outer Err means the timeout elapsed with no data at all.
tracing::warn!(
"Sync channel receive timeout for peer {:?} (no data for {:?})",
peer_id,
Self::RECV_TIMEOUT,
);
return Err(anyhow::anyhow!(
"Receive timeout waiting for message marker"
));
}
}
// Unknown message types are logged and skipped rather than killing
// the channel.
if marker[0] != SyncMessageType::SyncBatch as u8 {
tracing::warn!(
"Unexpected message type on sync channel: 0x{:02x}",
marker[0]
);
continue;
}
let mut len_bytes = [0u8; 4];
tokio::time::timeout(Self::RECV_TIMEOUT, recv.read_exact(&mut len_bytes))
.await
.map_err(|_| {
tracing::warn!(
"Sync channel receive timeout for peer {:?} reading batch length",
peer_id,
);
anyhow::anyhow!("Receive timeout reading batch length")
})?
.context("Failed to read batch length")?;
let batch_len = u32::from_be_bytes(len_bytes) as usize;
// NOTE(review): `batch_len` comes off the wire unvalidated; a hostile
// peer can make this allocate up to ~4 GiB. Consider capping it.
let mut batch_data = vec![0u8; batch_len];
tokio::time::timeout(Self::RECV_TIMEOUT, recv.read_exact(&mut batch_data))
.await
.map_err(|_| {
tracing::warn!(
"Sync channel receive timeout for peer {:?} reading batch data ({} bytes)",
peer_id,
batch_len,
);
anyhow::anyhow!("Receive timeout reading batch data")
})?
.context("Failed to read batch data")?;
// Decode and coordinator failures are logged and skipped; only
// transport-level errors end the loop.
match SyncBatch::decode(&batch_data) {
Ok(batch) => {
let total_bytes = 1 + 4 + batch_len; if let Err(e) = coordinator
.receive_batch_message(peer_id, batch, total_bytes)
.await
{
tracing::warn!("Failed to process batch from peer {:?}: {}", peer_id, e);
}
}
Err(e) => {
tracing::warn!("Failed to decode batch from peer {:?}: {}", peer_id, e);
}
}
}
}
/// Sends one encoded batch, transparently reconnecting first when the
/// channel is in the `Reconnecting` state.
///
/// # Errors
/// Fails if the channel is `Closed`, reconnection fails, or a stream
/// write errors.
pub async fn send(&self, batch: &SyncBatch) -> Result<()> {
let needs_reconnect = {
let state = *self.state.read().unwrap_or_else(|e| e.into_inner());
match state {
ChannelState::Closed => return Err(anyhow::anyhow!("Channel is closed")),
ChannelState::Reconnecting => true,
ChannelState::Connected => false,
}
};
if needs_reconnect {
self.reconnect().await?;
}
let batch_bytes = batch.encode();
let mut send_guard = self.send.lock().await;
let send = send_guard
.as_mut()
.ok_or_else(|| anyhow::anyhow!("No send stream available"))?;
// NOTE(review): this doc_key prefix (u16 length + bytes) precedes the
// marker on the wire, but `receive_loop` starts by reading a bare
// 1-byte marker — verify the framing against the peer's reader.
let doc_key = b"batch";
send.write_all(&(doc_key.len() as u16).to_be_bytes())
.await
.context("Failed to write doc_key length")?;
send.write_all(doc_key)
.await
.context("Failed to write doc_key")?;
send.write_all(&[SyncMessageType::SyncBatch as u8])
.await
.context("Failed to write batch marker")?;
let batch_len = batch_bytes.len() as u32;
send.write_all(&batch_len.to_be_bytes())
.await
.context("Failed to write batch length")?;
send.write_all(&batch_bytes)
.await
.context("Failed to write batch data")?;
// Account for every framing byte: 2 (doc_key len) + key + 1 (marker)
// + 4 (batch len) + payload.
let total_bytes = 2 + doc_key.len() + 1 + 4 + batch_bytes.len();
self.bytes_sent
.fetch_add(total_bytes as u64, Ordering::Relaxed);
self.batches_sent.fetch_add(1, Ordering::Relaxed);
*self.last_send.write().unwrap_or_else(|e| e.into_inner()) = Instant::now();
tracing::trace!(
"Sent batch {} ({} entries, {} bytes) to peer {:?}",
batch.batch_id,
batch.len(),
total_bytes,
self.peer_id
);
Ok(())
}
/// Re-opens the stream to the peer, giving up (and closing the channel)
/// once `MAX_RECONNECT_ATTEMPTS` consecutive attempts have been used.
///
/// # Errors
/// Fails when attempts are exhausted or the transport cannot reconnect.
pub async fn reconnect(&self) -> Result<()> {
*self.state.write().unwrap_or_else(|e| e.into_inner()) = ChannelState::Reconnecting;
let attempts = self.reconnect_attempts.fetch_add(1, Ordering::Relaxed);
if attempts >= Self::MAX_RECONNECT_ATTEMPTS {
*self.state.write().unwrap_or_else(|e| e.into_inner()) = ChannelState::Closed;
return Err(anyhow::anyhow!(
"Max reconnection attempts ({}) exceeded",
Self::MAX_RECONNECT_ATTEMPTS
));
}
tracing::info!(
"Attempting reconnection to peer {:?} (attempt {})",
self.peer_id,
attempts + 1
);
tokio::time::sleep(Self::RECONNECT_DELAY).await;
let conn = self.transport.get_or_connect(&self.peer_id).await?;
let (send, mut recv) = conn
.open_bi()
.await
.context("Failed to open bidirectional stream for reconnection")?;
// The fresh receive half is stopped and dropped: no receive loop is
// respawned here, so after a reconnect this channel no longer reads
// incoming batches on the new stream.
// NOTE(review): confirm send-only-after-reconnect is intentional.
let _ = recv.stop(0u32.into());
*self.send.lock().await = Some(send);
*self.state.write().unwrap_or_else(|e| e.into_inner()) = ChannelState::Connected;
self.reconnect_attempts.store(0, Ordering::Relaxed);
tracing::info!("Reconnected sync channel to peer {:?}", self.peer_id);
Ok(())
}
/// True while the channel is in the `Connected` state.
pub fn is_connected(&self) -> bool {
*self.state.read().unwrap_or_else(|e| e.into_inner()) == ChannelState::Connected
}
/// Current lifecycle state.
pub fn state(&self) -> ChannelState {
*self.state.read().unwrap_or_else(|e| e.into_inner())
}
/// Peer this channel is bound to.
pub fn peer_id(&self) -> EndpointId {
self.peer_id
}
/// Total wire bytes written by `send()` so far.
pub fn bytes_sent(&self) -> u64 {
self.bytes_sent.load(Ordering::Relaxed)
}
/// Total batches written by `send()` so far.
pub fn batches_sent(&self) -> u64 {
self.batches_sent.load(Ordering::Relaxed)
}
/// Marks the channel `Closed`, aborts the receive task (if its handle has
/// already been stored), and finishes the send stream.
pub async fn close(&self) {
*self.state.write().unwrap_or_else(|e| e.into_inner()) = ChannelState::Closed;
if let Some(task) = self.recv_task.lock().await.take() {
task.abort();
}
if let Some(mut send) = self.send.lock().await.take() {
let _ = send.finish();
}
tracing::debug!("Sync channel to peer {:?} closed", self.peer_id);
}
}
#[cfg(feature = "automerge-backend")]
/// Owns one [`SyncChannel`] per peer and provides send/broadcast helpers
/// over them.
pub struct SyncChannelManager {
/// Cache of open channels keyed by peer.
channels: Arc<RwLock<HashMap<EndpointId, Arc<SyncChannel>>>>,
/// Transport shared by every channel for connecting to peers.
transport: Arc<dyn SyncTransport>,
/// Coordinator handed to each channel's receive loop.
coordinator: Arc<AutomergeSyncCoordinator>,
/// Cleared by `shutdown()` to mark the manager inactive.
active: Arc<std::sync::atomic::AtomicBool>,
}
#[cfg(feature = "automerge-backend")]
impl SyncChannelManager {
    /// Creates a manager with no open channels.
    ///
    /// # Arguments
    /// * `transport` - used to establish connections to peers.
    /// * `coordinator` - receives batches arriving on every channel.
    pub fn new(
        transport: Arc<dyn SyncTransport>,
        coordinator: Arc<AutomergeSyncCoordinator>,
    ) -> Self {
        Self {
            channels: Arc::new(RwLock::new(HashMap::new())),
            transport,
            coordinator,
            active: Arc::new(std::sync::atomic::AtomicBool::new(true)),
        }
    }
    /// Returns a connected channel to `peer_id`, creating a fresh one when no
    /// cached channel exists or the cached one is no longer connected.
    ///
    /// # Errors
    /// Fails if the manager has been shut down or the connection cannot be
    /// established.
    pub async fn get_channel(&self, peer_id: EndpointId) -> Result<Arc<SyncChannel>> {
        // Fix: `active` was set by `shutdown()` but never consulted, so new
        // channels could be created (and leak) after shutdown.
        if !self.active.load(Ordering::Relaxed) {
            return Err(anyhow::anyhow!("SyncChannelManager is shut down"));
        }
        // Fast path: reuse a cached, still-connected channel.
        {
            let channels = self.channels.read().unwrap_or_else(|e| e.into_inner());
            if let Some(channel) = channels.get(&peer_id) {
                if channel.is_connected() {
                    return Ok(Arc::clone(channel));
                }
            }
        }
        let channel = Arc::new(
            SyncChannel::connect(
                Arc::clone(&self.transport),
                peer_id,
                Arc::clone(&self.coordinator),
            )
            .await?,
        );
        // Fix: use the poison-tolerant lock idiom used everywhere else in
        // this file instead of a bare `unwrap()` that repanics on poison.
        let displaced = {
            let mut channels = self.channels.write().unwrap_or_else(|e| e.into_inner());
            channels.insert(peer_id, Arc::clone(&channel))
        };
        // Fix: close the stale channel we displaced so its receiver task and
        // send stream are released promptly instead of lingering until the
        // stream errors. A displaced channel that is still connected (a
        // concurrent caller won the race) is left running, matching the
        // previous drop-only behavior for that case.
        if let Some(displaced) = displaced {
            if !displaced.is_connected() {
                displaced.close().await;
            }
        }
        Ok(channel)
    }
    /// Sends `batch` to a single peer, creating the channel on demand.
    pub async fn send_to_peer(&self, peer_id: EndpointId, batch: &SyncBatch) -> Result<()> {
        let channel = self.get_channel(peer_id).await?;
        channel.send(batch).await
    }
    /// Sends `batch` to every currently connected peer. Per-peer failures
    /// are logged and do not abort the remaining sends.
    pub async fn broadcast(&self, batch: &SyncBatch) -> Result<()> {
        for peer_id in self.transport.connected_peers() {
            if let Err(e) = self.send_to_peer(peer_id, batch).await {
                tracing::warn!("Failed to send batch to peer {:?}: {}", peer_id, e);
            }
        }
        Ok(())
    }
    /// Sends a single delta-sync message for `doc_key` to `peer_id`.
    ///
    /// Returns a byte-count estimate. NOTE(review): the formula adds
    /// `payload_len` on top of the encoded batch length (which already
    /// contains the payload) — confirm the intended accounting.
    pub async fn send_delta_sync(
        &self,
        peer_id: EndpointId,
        doc_key: &str,
        message: &automerge::sync::Message,
    ) -> Result<usize> {
        let encoded = message.clone().encode();
        let payload_len = encoded.len();
        let mut batch = SyncBatch::new();
        batch.entries.push(super::automerge_sync::SyncEntry::new(
            doc_key.to_string(),
            SyncMessageType::DeltaSync,
            encoded,
        ));
        // Encode once for the size estimate instead of re-encoding inline
        // after the send.
        let batch_len = batch.encode().len();
        self.send_to_peer(peer_id, &batch).await?;
        Ok(1 + 4 + batch_len + payload_len)
    }
    /// Sends a full state snapshot for `doc_key` to `peer_id`.
    ///
    /// Returns a byte-count estimate (same accounting caveat as
    /// [`Self::send_delta_sync`]).
    pub async fn send_state_snapshot(
        &self,
        peer_id: EndpointId,
        doc_key: &str,
        state_bytes: Vec<u8>,
    ) -> Result<usize> {
        let payload_len = state_bytes.len();
        let mut batch = SyncBatch::new();
        batch.entries.push(super::automerge_sync::SyncEntry::new(
            doc_key.to_string(),
            SyncMessageType::StateSnapshot,
            state_bytes,
        ));
        let batch_len = batch.encode().len();
        self.send_to_peer(peer_id, &batch).await?;
        Ok(1 + 4 + batch_len + payload_len)
    }
    /// Sends one tombstone message to `peer_id`; returns a byte estimate
    /// (marker + length prefix + encoded batch).
    pub async fn send_tombstone(
        &self,
        peer_id: EndpointId,
        tombstone_msg: &crate::qos::TombstoneSyncMessage,
    ) -> Result<usize> {
        let mut batch = SyncBatch::new();
        batch.add_tombstone(tombstone_msg);
        let batch_bytes = batch.encode();
        self.send_to_peer(peer_id, &batch).await?;
        Ok(1 + 4 + batch_bytes.len())
    }
    /// Sends multiple tombstones to `peer_id` in one batch; returns a byte
    /// estimate (marker + length prefix + encoded batch).
    pub async fn send_tombstone_batch(
        &self,
        peer_id: EndpointId,
        tombstones: &[crate::qos::TombstoneSyncMessage],
    ) -> Result<usize> {
        let mut batch = SyncBatch::new();
        for tombstone in tombstones {
            batch.add_tombstone(tombstone);
        }
        let batch_bytes = batch.encode();
        self.send_to_peer(peer_id, &batch).await?;
        Ok(1 + 4 + batch_bytes.len())
    }
    /// Broadcasts a delta-sync message for `doc_key` to all connected peers.
    pub async fn broadcast_delta_sync(
        &self,
        doc_key: &str,
        message: &automerge::sync::Message,
    ) -> Result<()> {
        let encoded = message.clone().encode();
        let mut batch = SyncBatch::new();
        batch.entries.push(super::automerge_sync::SyncEntry::new(
            doc_key.to_string(),
            SyncMessageType::DeltaSync,
            encoded,
        ));
        self.broadcast(&batch).await
    }
    /// Broadcasts a state snapshot for `doc_key` to all connected peers.
    pub async fn broadcast_state_snapshot(
        &self,
        doc_key: &str,
        state_bytes: Vec<u8>,
    ) -> Result<()> {
        let mut batch = SyncBatch::new();
        batch.entries.push(super::automerge_sync::SyncEntry::new(
            doc_key.to_string(),
            SyncMessageType::StateSnapshot,
            state_bytes,
        ));
        self.broadcast(&batch).await
    }
    /// Broadcasts one tombstone message to all connected peers.
    pub async fn broadcast_tombstone(
        &self,
        tombstone_msg: &crate::qos::TombstoneSyncMessage,
    ) -> Result<()> {
        let mut batch = SyncBatch::new();
        batch.add_tombstone(tombstone_msg);
        self.broadcast(&batch).await
    }
    /// Removes and closes the channel to `peer_id`, if one exists.
    pub async fn remove_channel(&self, peer_id: &EndpointId) {
        // Take the channel out of the map first so the lock is released
        // before the async close.
        let channel = self
            .channels
            .write()
            .unwrap_or_else(|e| e.into_inner())
            .remove(peer_id);
        if let Some(channel) = channel {
            channel.close().await;
        }
    }
    /// Number of channels currently cached (connected or not).
    pub fn channel_count(&self) -> usize {
        self.channels
            .read()
            .unwrap_or_else(|e| e.into_inner())
            .len()
    }
    /// Aggregates per-channel counters into a [`ChannelManagerStats`].
    pub fn stats(&self) -> ChannelManagerStats {
        let channels = self.channels.read().unwrap_or_else(|e| e.into_inner());
        let mut total_bytes = 0u64;
        let mut total_batches = 0u64;
        let mut connected = 0usize;
        for channel in channels.values() {
            total_bytes += channel.bytes_sent();
            total_batches += channel.batches_sent();
            if channel.is_connected() {
                connected += 1;
            }
        }
        ChannelManagerStats {
            total_channels: channels.len(),
            connected_channels: connected,
            total_bytes_sent: total_bytes,
            total_batches_sent: total_batches,
        }
    }
    /// Marks the manager inactive and closes every channel. After this,
    /// `get_channel` refuses to create new channels.
    pub async fn shutdown(&self) {
        self.active.store(false, Ordering::Relaxed);
        // Drain under the lock, then close outside it so no lock guard is
        // held across an await.
        let channels: Vec<Arc<SyncChannel>> = {
            let mut channels = self.channels.write().unwrap_or_else(|e| e.into_inner());
            channels.drain().map(|(_, c)| c).collect()
        };
        for channel in channels {
            channel.close().await;
        }
        tracing::debug!("SyncChannelManager shutdown complete");
    }
}
#[cfg(feature = "automerge-backend")]
/// Aggregated counters across all channels, produced by
/// `SyncChannelManager::stats()`.
#[derive(Debug, Clone, Default)]
pub struct ChannelManagerStats {
/// Total cached channels, connected or not.
pub total_channels: usize,
/// Channels currently in the `Connected` state.
pub connected_channels: usize,
/// Sum of bytes sent over all channels (framing included).
pub total_bytes_sent: u64,
/// Sum of batches sent over all channels.
pub total_batches_sent: u64,
}
#[cfg(all(test, feature = "automerge-backend"))]
mod tests {
    use super::*;

    /// The three channel states must be mutually distinguishable, and a
    /// state must compare equal to itself.
    #[test]
    fn test_channel_state_enum() {
        assert!(ChannelState::Connected != ChannelState::Reconnecting);
        assert!(ChannelState::Reconnecting != ChannelState::Closed);
        assert!(ChannelState::Connected == ChannelState::Connected);
    }

    /// A default stats struct starts with every counter at zero.
    #[test]
    fn test_channel_manager_stats_default() {
        let ChannelManagerStats {
            total_channels,
            connected_channels,
            total_bytes_sent,
            total_batches_sent,
        } = ChannelManagerStats::default();
        assert_eq!((total_channels, connected_channels), (0, 0));
        assert_eq!((total_bytes_sent, total_batches_sent), (0u64, 0u64));
    }
}