use super::blob_traits::{BlobHash, BlobMetadata, BlobToken};
use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(feature = "automerge-backend")]
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
#[cfg(feature = "automerge-backend")]
use tokio::sync::RwLock;
#[cfg(feature = "automerge-backend")]
use tracing::{debug, info, warn};
use uuid::Uuid;
/// Relative urgency of a blob transfer, used to order competing transfers.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TransferPriority {
    /// Highest urgency.
    Critical,
    High,
    /// Default priority for ordinary distributions.
    #[default]
    Normal,
    Low,
}
impl TransferPriority {
pub fn as_numeric(&self) -> u8 {
match self {
Self::Critical => 4,
Self::High => 3,
Self::Normal => 2,
Self::Low => 1,
}
}
}
/// Which nodes a blob should be distributed to.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DistributionScope {
    /// Every peer currently known (default).
    #[default]
    AllNodes,
    /// All nodes belonging to a specific formation.
    Formation {
        formation_id: String,
    },
    /// An explicit list of node ids.
    Nodes {
        node_ids: Vec<String>,
    },
    /// Nodes matching minimum capability requirements; a `None` field means
    /// that dimension is unconstrained.
    Capable {
        #[serde(skip_serializing_if = "Option::is_none")]
        min_gpu_gb: Option<f64>,
        #[serde(skip_serializing_if = "Option::is_none")]
        cpu_arch: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        min_storage_mb: Option<u64>,
    },
}
/// Lifecycle of a single node's transfer.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransferState {
    /// Not started yet (default).
    #[default]
    Pending,
    /// Connection being established.
    Connecting,
    /// Bytes are flowing.
    Transferring,
    /// Terminal: transfer finished successfully.
    Completed,
    /// Terminal: transfer failed (see `NodeTransferStatus::error`).
    Failed,
}
/// Per-node view of a single blob transfer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NodeTransferStatus {
    /// Id of the receiving node.
    pub node_id: String,
    /// Current lifecycle state.
    pub status: TransferState,
    /// Bytes transferred so far.
    pub progress_bytes: u64,
    /// Total bytes expected.
    pub total_bytes: u64,
    /// When the transfer started, if it has.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub started_at: Option<DateTime<Utc>>,
    /// When the transfer reached a terminal state, if it has.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<DateTime<Utc>>,
    /// Failure description when `status` is `Failed`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
impl NodeTransferStatus {
    /// Creates a fresh per-node status in the `Pending` state with zero
    /// progress.
    pub fn new(node_id: String, total_bytes: u64) -> Self {
        Self {
            node_id,
            status: TransferState::Pending,
            progress_bytes: 0,
            total_bytes,
            started_at: None,
            completed_at: None,
            error: None,
        }
    }
    /// Fraction of bytes transferred, in `[0.0, 1.0]`.
    ///
    /// A zero-byte transfer is considered complete (returns 1.0). The result
    /// is clamped so that over-reported progress (e.g. retransmitted chunks
    /// pushing `progress_bytes` past `total_bytes`) can never yield a
    /// fraction greater than 1.0.
    pub fn progress_fraction(&self) -> f64 {
        if self.total_bytes == 0 {
            return 1.0;
        }
        (self.progress_bytes as f64 / self.total_bytes as f64).min(1.0)
    }
}
/// Opaque reference to an in-flight distribution; returned by
/// `FileDistribution::distribute` and passed back for status/cancel calls.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DistributionHandle {
    /// Unique id (UUID v4) identifying this distribution.
    pub distribution_id: String,
    /// Hash of the blob being distributed.
    pub blob_hash: BlobHash,
    /// Target selection used for this distribution.
    pub scope: DistributionScope,
    /// Priority requested by the caller.
    pub priority: TransferPriority,
    /// When the distribution was initiated.
    pub started_at: DateTime<Utc>,
}
impl DistributionHandle {
    /// Builds a handle with a freshly generated UUID and the current UTC
    /// timestamp.
    pub fn new(blob_hash: BlobHash, scope: DistributionScope, priority: TransferPriority) -> Self {
        let distribution_id = Uuid::new_v4().to_string();
        let started_at = Utc::now();
        Self {
            distribution_id,
            blob_hash,
            scope,
            priority,
            started_at,
        }
    }
}
/// Aggregated progress of a distribution across all target nodes.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DistributionStatus {
    /// Handle identifying the distribution.
    pub handle: DistributionHandle,
    /// Number of target nodes.
    pub total_targets: usize,
    /// Targets that completed successfully.
    pub completed: usize,
    /// Targets currently connecting or transferring.
    pub in_progress: usize,
    /// Targets that failed.
    pub failed: usize,
    /// Per-node detail, keyed by node id.
    pub node_statuses: HashMap<String, NodeTransferStatus>,
}
impl DistributionStatus {
    /// Builds an initial status with one `Pending` entry per target node.
    /// Duplicate ids collapse into a single entry, and `total_targets`
    /// reflects the deduplicated count.
    pub fn new(handle: DistributionHandle, target_nodes: Vec<String>, total_bytes: u64) -> Self {
        let mut node_statuses = HashMap::new();
        for id in target_nodes {
            node_statuses.insert(id.clone(), NodeTransferStatus::new(id, total_bytes));
        }
        Self {
            handle,
            total_targets: node_statuses.len(),
            completed: 0,
            in_progress: 0,
            failed: 0,
            node_statuses,
        }
    }
    /// True once every target has reached a terminal state (completed or
    /// failed).
    pub fn is_complete(&self) -> bool {
        self.completed + self.failed >= self.total_targets
    }
    /// True when all targets completed and none failed.
    pub fn is_success(&self) -> bool {
        self.failed == 0 && self.completed >= self.total_targets
    }
    /// Byte-weighted progress across all targets; 1.0 when there is nothing
    /// to transfer.
    pub fn overall_progress(&self) -> f64 {
        if self.total_targets == 0 {
            return 1.0;
        }
        let (progress, total) = self
            .node_statuses
            .values()
            .fold((0u64, 0u64), |(p, t), s| (p + s.progress_bytes, t + s.total_bytes));
        if total == 0 {
            return 1.0;
        }
        progress as f64 / total as f64
    }
    /// Recomputes the aggregate counters from the per-node statuses.
    pub fn recalculate_counts(&mut self) {
        let mut completed = 0;
        let mut in_progress = 0;
        let mut failed = 0;
        for node in self.node_statuses.values() {
            match node.status {
                TransferState::Completed => completed += 1,
                TransferState::Failed => failed += 1,
                TransferState::Connecting | TransferState::Transferring => in_progress += 1,
                TransferState::Pending => {}
            }
        }
        self.completed = completed;
        self.in_progress = in_progress;
        self.failed = failed;
    }
}
/// Distributes blobs to sets of nodes and reports per-node progress.
#[async_trait::async_trait]
pub trait FileDistribution: Send + Sync {
    /// Starts distributing the blob identified by `blob_token` to the nodes
    /// selected by `scope`, returning a handle for tracking the operation.
    async fn distribute(
        &self,
        blob_token: &BlobToken,
        scope: DistributionScope,
        priority: TransferPriority,
    ) -> Result<DistributionHandle>;
    /// Returns the current aggregate status; errors if the handle is unknown.
    async fn status(&self, handle: &DistributionHandle) -> Result<DistributionStatus>;
    /// Cancels the distribution identified by `handle`.
    async fn cancel(&self, handle: &DistributionHandle) -> Result<()>;
    /// Waits until every target reaches a terminal state, or errors once
    /// `timeout` elapses.
    async fn wait_for_completion(
        &self,
        handle: &DistributionHandle,
        timeout: Duration,
    ) -> Result<DistributionStatus>;
    /// Subscribes to status snapshots emitted as progress changes.
    async fn subscribe_progress(
        &self,
        handle: &DistributionHandle,
    ) -> Result<broadcast::Receiver<DistributionStatus>>;
}
#[cfg(feature = "automerge-backend")]
use super::automerge_store::AutomergeStore;
#[cfg(feature = "automerge-backend")]
use super::iroh_blob_store::NetworkedIrohBlobStore;
/// Name of the Automerge collection holding `DistributionDocument`s.
#[cfg(feature = "automerge-backend")]
pub const IROH_DISTRIBUTION_COLLECTION: &str = "file_distributions";
#[cfg(feature = "automerge-backend")]
/// Record of a distribution stored (as JSON) in the
/// `file_distributions` Automerge collection and synced between peers.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DistributionDocument {
    /// Matches `DistributionHandle::distribution_id`.
    pub distribution_id: String,
    /// Hex-encoded blob hash.
    pub blob_hash: String,
    /// Blob size in bytes.
    pub blob_size: u64,
    /// Metadata carried along with the blob token.
    pub blob_metadata: BlobMetadata,
    /// Original target-selection scope.
    pub scope: DistributionScope,
    /// Requested transfer priority.
    pub priority: TransferPriority,
    /// Resolved target node ids at initiation time.
    pub target_nodes: Vec<String>,
    /// When the distribution was initiated.
    pub started_at: DateTime<Utc>,
    /// Overall document state; "distributing" or "cancelled" are the values
    /// written by this module.
    pub status: String,
    /// Set when the distribution was cancelled.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cancelled_at: Option<DateTime<Utc>>,
    /// Per-node progress, keyed by node id. `#[serde(default)]` keeps
    /// documents written by older schema versions deserializable.
    #[serde(default)]
    pub node_statuses: HashMap<String, NodeTransferStatus>,
}
#[cfg(feature = "automerge-backend")]
// Distribution id -> aggregate status for distributions this node initiated.
type DistributionsMap = Arc<RwLock<HashMap<String, DistributionStatus>>>;
#[cfg(feature = "automerge-backend")]
// Distribution id -> broadcast sender used to push progress snapshots.
type ProgressChannels = Arc<RwLock<HashMap<String, broadcast::Sender<DistributionStatus>>>>;
#[cfg(feature = "automerge-backend")]
/// `FileDistribution` implementation that coordinates transfers through
/// Automerge documents synced between peers via the Iroh blob store.
pub struct IrohFileDistribution {
    // Blob store used to enumerate known peers.
    blob_store: Arc<NetworkedIrohBlobStore>,
    // Document store where distribution documents are upserted.
    document_store: Arc<AutomergeStore>,
    // Aggregate status per active distribution.
    distributions: DistributionsMap,
    // Progress broadcast channels per active distribution.
    progress_channels: ProgressChannels,
    // Background watcher task; aborted on Drop.
    watcher_handle: Option<tokio::task::JoinHandle<()>>,
}
#[cfg(feature = "automerge-backend")]
impl IrohFileDistribution {
    /// Creates the distribution service and spawns a background task that
    /// watches Automerge observer events to drive progress updates.
    pub fn new(
        blob_store: Arc<NetworkedIrohBlobStore>,
        document_store: Arc<AutomergeStore>,
    ) -> Self {
        let distributions: DistributionsMap = Arc::new(RwLock::new(HashMap::new()));
        let progress_channels: ProgressChannels = Arc::new(RwLock::new(HashMap::new()));
        let watcher_handle = {
            let document_store = Arc::clone(&document_store);
            let distributions = Arc::clone(&distributions);
            let progress_channels = Arc::clone(&progress_channels);
            tokio::spawn(async move {
                watch_distribution_documents(document_store, distributions, progress_channels)
                    .await;
            })
        };
        Self {
            blob_store,
            document_store,
            distributions,
            progress_channels,
            watcher_handle: Some(watcher_handle),
        }
    }
    /// Underlying blob store.
    pub fn blob_store(&self) -> &Arc<NetworkedIrohBlobStore> {
        &self.blob_store
    }
    /// Underlying Automerge document store.
    pub fn document_store(&self) -> &Arc<AutomergeStore> {
        &self.document_store
    }
    /// Short ids of every peer the blob store currently knows about.
    async fn all_peer_ids(&self) -> Vec<String> {
        self.blob_store
            .known_peers()
            .await
            .iter()
            .map(|p| p.fmt_short().to_string())
            .collect()
    }
    /// Resolves a distribution scope to the concrete list of target node ids.
    ///
    /// `Formation` and `Capable` scopes are not implemented yet; they log a
    /// warning and fall back to all known peers.
    async fn resolve_targets(&self, scope: &DistributionScope) -> Vec<String> {
        match scope {
            DistributionScope::AllNodes => self.all_peer_ids().await,
            DistributionScope::Nodes { node_ids } => {
                // Only keep nodes the blob store can actually reach.
                let known_peers = self.all_peer_ids().await;
                node_ids
                    .iter()
                    .filter(|id| known_peers.contains(id))
                    .cloned()
                    .collect()
            }
            DistributionScope::Formation { formation_id } => {
                warn!(
                    formation_id = %formation_id,
                    "Formation-based distribution not yet implemented, distributing to all peers"
                );
                self.all_peer_ids().await
            }
            DistributionScope::Capable { .. } => {
                warn!(
                    "Capability-based distribution not yet implemented, distributing to all peers"
                );
                self.all_peer_ids().await
            }
        }
    }
    /// Serializes the distribution record and upserts it into the Automerge
    /// collection, from where it syncs to peers.
    async fn store_distribution_document(
        &self,
        handle: &DistributionHandle,
        blob_token: &BlobToken,
        target_nodes: &[String],
    ) -> Result<()> {
        // Local import keeps the `Collection` trait methods (`upsert`) in
        // scope; the allow silences the lint if the trait is also imported at
        // module level in some feature configurations.
        #[allow(unused_imports)]
        use super::traits::Collection;
        let doc_id = &handle.distribution_id;
        let distribution_doc = DistributionDocument {
            distribution_id: handle.distribution_id.clone(),
            blob_hash: blob_token.hash.as_hex().to_string(),
            blob_size: blob_token.size_bytes,
            blob_metadata: blob_token.metadata.clone(),
            scope: handle.scope.clone(),
            priority: handle.priority,
            target_nodes: target_nodes.to_vec(),
            started_at: handle.started_at,
            status: "distributing".to_string(),
            cancelled_at: None,
            node_statuses: HashMap::new(),
        };
        let bytes = serde_json::to_vec(&distribution_doc)
            .map_err(|e| anyhow::anyhow!("Failed to serialize distribution doc: {}", e))?;
        let collection = self.document_store.collection(IROH_DISTRIBUTION_COLLECTION);
        collection.upsert(doc_id, bytes)?;
        debug!(
            distribution_id = %handle.distribution_id,
            blob_hash = %blob_token.hash,
            target_count = target_nodes.len(),
            "Stored distribution document in Automerge"
        );
        Ok(())
    }
    /// Sends a status snapshot to subscribers, if a channel is registered for
    /// this distribution.
    async fn broadcast_progress(&self, distribution_id: &str, status: &DistributionStatus) {
        let channels = self.progress_channels.read().await;
        if let Some(sender) = channels.get(distribution_id) {
            // A send error only means there are no live receivers; ignore it.
            let _ = sender.send(status.clone());
        }
    }
}
#[cfg(feature = "automerge-backend")]
impl Drop for IrohFileDistribution {
    fn drop(&mut self) {
        // Abort the background watcher so it does not outlive the service and
        // keep polling the observer channel forever.
        if let Some(handle) = self.watcher_handle.take() {
            handle.abort();
        }
    }
}
#[cfg(feature = "automerge-backend")]
/// Background task: mirrors Automerge distribution-document changes into the
/// in-memory `distributions` map and fans progress snapshots out to
/// subscribers. Runs until the observer channel closes.
async fn watch_distribution_documents(
    document_store: Arc<AutomergeStore>,
    distributions: DistributionsMap,
    progress_channels: ProgressChannels,
) {
    let mut rx = document_store.subscribe_to_observer_changes();
    // Observer keys are "<collection>:<doc_id>"; only react to our collection.
    let prefix = format!("{}:", IROH_DISTRIBUTION_COLLECTION);
    loop {
        let key = match rx.recv().await {
            Ok(k) => k,
            Err(broadcast::error::RecvError::Lagged(n)) => {
                // Missed events are tolerable: each event re-reads the full
                // document, so the next one catches us up. Log and continue.
                warn!(
                    lagged = n,
                    "distribution watcher lagged on observer channel"
                );
                continue;
            }
            Err(broadcast::error::RecvError::Closed) => return,
        };
        let Some(doc_id) = key.strip_prefix(&prefix) else {
            continue;
        };
        // Only track distributions registered in the in-memory map (i.e.
        // initiated by this node).
        if !distributions.read().await.contains_key(doc_id) {
            continue;
        }
        let collection = document_store.collection(IROH_DISTRIBUTION_COLLECTION);
        let bytes = match collection.get(doc_id) {
            Ok(Some(b)) => b,
            Ok(None) => continue,
            Err(e) => {
                warn!(error = %e, doc_id, "failed to read distribution doc");
                continue;
            }
        };
        let doc: DistributionDocument = match serde_json::from_slice(&bytes) {
            Ok(d) => d,
            Err(e) => {
                warn!(error = %e, doc_id, "failed to deserialize distribution doc");
                continue;
            }
        };
        // Ignore documents already cancelled or otherwise not in flight.
        if doc.status != "distributing" {
            continue;
        }
        // Merge per-node updates under the write lock. A `continue` inside
        // this block expression skips the whole loop iteration (diverging
        // expression), releasing the lock.
        let (snapshot, complete) = {
            let mut dists = distributions.write().await;
            let Some(status) = dists.get_mut(doc_id) else {
                continue;
            };
            let mut changed = false;
            for (node_id, ns) in &doc.node_statuses {
                // Compare status, byte progress, and error to detect change.
                let differs = match status.node_statuses.get(node_id) {
                    Some(existing) => {
                        existing.status != ns.status
                            || existing.progress_bytes != ns.progress_bytes
                            || existing.error != ns.error
                    }
                    None => true,
                };
                if differs {
                    status.node_statuses.insert(node_id.clone(), ns.clone());
                    changed = true;
                }
            }
            if !changed {
                continue;
            }
            status.recalculate_counts();
            // Clone the snapshot so the write lock is released before any
            // broadcasting happens below.
            (status.clone(), status.is_complete())
        };
        {
            let channels = progress_channels.read().await;
            if let Some(sender) = channels.get(doc_id) {
                let _ = sender.send(snapshot);
            }
        }
        if complete {
            // Drop the sender so subscribers observe the channel closing.
            progress_channels.write().await.remove(doc_id);
        }
    }
}
#[cfg(feature = "automerge-backend")]
#[async_trait::async_trait]
impl FileDistribution for IrohFileDistribution {
    /// Starts a distribution: resolves targets, registers in-memory tracking
    /// state, then writes the distribution document so peers begin pulling.
    async fn distribute(
        &self,
        blob_token: &BlobToken,
        scope: DistributionScope,
        priority: TransferPriority,
    ) -> Result<DistributionHandle> {
        info!(
            blob_hash = %blob_token.hash,
            blob_size = blob_token.size_bytes,
            scope = ?scope,
            priority = ?priority,
            "Starting file distribution"
        );
        let handle = DistributionHandle::new(blob_token.hash.clone(), scope.clone(), priority);
        let target_nodes = self.resolve_targets(&scope).await;
        if target_nodes.is_empty() {
            // Zero targets means the distribution completes immediately;
            // surface it loudly so misconfigured scopes are noticed.
            warn!("No target nodes found for distribution scope");
        }
        let status =
            DistributionStatus::new(handle.clone(), target_nodes.clone(), blob_token.size_bytes);
        // Register the tracking state and progress channel BEFORE writing the
        // document: the upsert triggers the observer watcher, which ignores
        // distributions it cannot find in `self.distributions`, so the
        // reverse order could silently drop the earliest progress updates.
        {
            let mut distributions = self.distributions.write().await;
            distributions.insert(handle.distribution_id.clone(), status);
        }
        {
            let (tx, _rx) = broadcast::channel(16);
            let mut channels = self.progress_channels.write().await;
            channels.insert(handle.distribution_id.clone(), tx);
        }
        if let Err(e) = self
            .store_distribution_document(&handle, blob_token, &target_nodes)
            .await
        {
            // Roll back so a failed start does not leak tracking state.
            self.distributions
                .write()
                .await
                .remove(&handle.distribution_id);
            self.progress_channels
                .write()
                .await
                .remove(&handle.distribution_id);
            return Err(e);
        }
        info!(
            distribution_id = %handle.distribution_id,
            target_count = target_nodes.len(),
            "Distribution initiated - document synced to peers"
        );
        Ok(handle)
    }
    /// Returns the in-memory aggregate status; errors for unknown handles.
    async fn status(&self, handle: &DistributionHandle) -> Result<DistributionStatus> {
        let distributions = self.distributions.read().await;
        distributions
            .get(&handle.distribution_id)
            .cloned()
            .ok_or_else(|| anyhow::anyhow!("Distribution not found: {}", handle.distribution_id))
    }
    /// Cancels a distribution: marks every non-completed node as failed,
    /// broadcasts the final snapshot, closes the progress channel, and flips
    /// the synced document's status to "cancelled".
    async fn cancel(&self, handle: &DistributionHandle) -> Result<()> {
        info!(
            distribution_id = %handle.distribution_id,
            "Cancelling distribution"
        );
        let cancelled_status = {
            let mut distributions = self.distributions.write().await;
            distributions
                .get_mut(&handle.distribution_id)
                .map(|status| {
                    // Completed nodes keep their state; everything else is
                    // marked failed with a cancellation error.
                    for node_status in status.node_statuses.values_mut() {
                        if node_status.status != TransferState::Completed {
                            node_status.status = TransferState::Failed;
                            node_status.error = Some("Distribution cancelled".to_string());
                        }
                    }
                    status.recalculate_counts();
                    status.clone()
                })
        };
        if let Some(status) = cancelled_status {
            self.broadcast_progress(&handle.distribution_id, &status)
                .await;
            // Dropping the sender lets subscribers observe the closure.
            let mut channels = self.progress_channels.write().await;
            channels.remove(&handle.distribution_id);
        }
        // Persist the cancellation so peers stop transferring.
        #[allow(unused_imports)]
        use super::traits::Collection;
        let collection = self.document_store.collection(IROH_DISTRIBUTION_COLLECTION);
        if let Some(existing) = collection.get(&handle.distribution_id)? {
            let mut doc: DistributionDocument = serde_json::from_slice(&existing)
                .map_err(|e| anyhow::anyhow!("Failed to deserialize distribution doc: {}", e))?;
            doc.status = "cancelled".to_string();
            doc.cancelled_at = Some(Utc::now());
            let bytes = serde_json::to_vec(&doc)
                .map_err(|e| anyhow::anyhow!("Failed to serialize cancel update: {}", e))?;
            collection.upsert(&handle.distribution_id, bytes)?;
        }
        Ok(())
    }
    /// Polls the in-memory status every 500ms until all targets reach a
    /// terminal state, or errors once `timeout` elapses.
    async fn wait_for_completion(
        &self,
        handle: &DistributionHandle,
        timeout: Duration,
    ) -> Result<DistributionStatus> {
        let start = std::time::Instant::now();
        let poll_interval = Duration::from_millis(500);
        loop {
            let status = self.status(handle).await?;
            if status.is_complete() {
                return Ok(status);
            }
            if start.elapsed() >= timeout {
                return Err(anyhow::anyhow!("Distribution timeout after {:?}", timeout));
            }
            tokio::time::sleep(poll_interval).await;
        }
    }
    /// Subscribes to progress snapshots. Errors if no channel exists (unknown
    /// handle, or the distribution already completed/was cancelled).
    async fn subscribe_progress(
        &self,
        handle: &DistributionHandle,
    ) -> Result<broadcast::Receiver<DistributionStatus>> {
        let channels = self.progress_channels.read().await;
        channels
            .get(&handle.distribution_id)
            .map(|sender| sender.subscribe())
            .ok_or_else(|| anyhow::anyhow!("Distribution not found: {}", handle.distribution_id))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Numeric ranks must be strictly ordered Critical > High > Normal > Low.
    #[test]
    fn test_transfer_priority_ordering() {
        assert!(TransferPriority::Critical.as_numeric() > TransferPriority::High.as_numeric());
        assert!(TransferPriority::High.as_numeric() > TransferPriority::Normal.as_numeric());
        assert!(TransferPriority::Normal.as_numeric() > TransferPriority::Low.as_numeric());
    }
    /// A new handle carries a non-empty id and the inputs it was built from.
    #[test]
    fn test_distribution_handle_creation() {
        let hash = BlobHash::from_hex("abc123");
        let scope = DistributionScope::AllNodes;
        let priority = TransferPriority::High;
        let handle = DistributionHandle::new(hash.clone(), scope, priority);
        assert!(!handle.distribution_id.is_empty());
        assert_eq!(handle.blob_hash, hash);
        assert_eq!(handle.priority, TransferPriority::High);
    }
    /// Per-node progress fraction tracks byte counts from 0.0 to 1.0.
    #[test]
    fn test_node_transfer_status() {
        let mut status = NodeTransferStatus::new("node-1".to_string(), 1000);
        assert_eq!(status.status, TransferState::Pending);
        assert_eq!(status.progress_fraction(), 0.0);
        status.progress_bytes = 500;
        status.status = TransferState::Transferring;
        assert_eq!(status.progress_fraction(), 0.5);
        status.progress_bytes = 1000;
        status.status = TransferState::Completed;
        assert_eq!(status.progress_fraction(), 1.0);
    }
    /// Aggregate counters reflect per-node terminal states; a failure makes
    /// the distribution complete-but-not-successful.
    #[test]
    fn test_distribution_status() {
        let hash = BlobHash::from_hex("abc123");
        let handle =
            DistributionHandle::new(hash, DistributionScope::AllNodes, TransferPriority::Normal);
        let targets = vec![
            "node-1".to_string(),
            "node-2".to_string(),
            "node-3".to_string(),
        ];
        let mut status = DistributionStatus::new(handle, targets, 1000);
        assert_eq!(status.total_targets, 3);
        assert_eq!(status.completed, 0);
        assert!(!status.is_complete());
        if let Some(node_status) = status.node_statuses.get_mut("node-1") {
            node_status.status = TransferState::Completed;
            node_status.progress_bytes = 1000;
        }
        if let Some(node_status) = status.node_statuses.get_mut("node-2") {
            node_status.status = TransferState::Completed;
            node_status.progress_bytes = 1000;
        }
        if let Some(node_status) = status.node_statuses.get_mut("node-3") {
            node_status.status = TransferState::Failed;
            node_status.error = Some("Connection lost".to_string());
        }
        status.recalculate_counts();
        assert_eq!(status.completed, 2);
        assert_eq!(status.failed, 1);
        assert!(status.is_complete());
        assert!(!status.is_success());
    }
    /// JSON round trip preserves all document fields including node statuses.
    #[cfg(feature = "automerge-backend")]
    #[test]
    fn test_distribution_document_round_trip() {
        let mut node_statuses = HashMap::new();
        node_statuses.insert(
            "node-a".to_string(),
            NodeTransferStatus {
                node_id: "node-a".to_string(),
                status: TransferState::Completed,
                progress_bytes: 1024,
                total_bytes: 1024,
                started_at: None,
                completed_at: None,
                error: None,
            },
        );
        let doc = DistributionDocument {
            distribution_id: "dist-1".to_string(),
            blob_hash: "deadbeef".to_string(),
            blob_size: 1024,
            blob_metadata: BlobMetadata::default(),
            scope: DistributionScope::AllNodes,
            priority: TransferPriority::Normal,
            target_nodes: vec!["node-a".to_string()],
            started_at: Utc::now(),
            status: "distributing".to_string(),
            cancelled_at: None,
            node_statuses,
        };
        let bytes = serde_json::to_vec(&doc).expect("serialize");
        let restored: DistributionDocument = serde_json::from_slice(&bytes).expect("deserialize");
        assert_eq!(restored.distribution_id, "dist-1");
        assert_eq!(restored.target_nodes, vec!["node-a".to_string()]);
        assert_eq!(restored.node_statuses.len(), 1);
        assert_eq!(
            restored.node_statuses["node-a"].status,
            TransferState::Completed
        );
    }
    /// Documents written before `node_statuses` existed (field removed from
    /// the JSON) still deserialize, thanks to `#[serde(default)]`.
    #[cfg(feature = "automerge-backend")]
    #[test]
    fn test_distribution_document_legacy_compat() {
        let current = DistributionDocument {
            distribution_id: "dist-legacy".to_string(),
            blob_hash: "abc123".to_string(),
            blob_size: 42,
            blob_metadata: BlobMetadata::default(),
            scope: DistributionScope::AllNodes,
            priority: TransferPriority::Normal,
            target_nodes: vec!["node-x".to_string()],
            started_at: Utc::now(),
            status: "distributing".to_string(),
            cancelled_at: None,
            node_statuses: HashMap::new(),
        };
        // Fixed: this expression was corrupted to `¤t` (mojibake of
        // `&current`) and did not compile.
        let mut value = serde_json::to_value(&current).unwrap();
        value
            .as_object_mut()
            .unwrap()
            .remove("node_statuses")
            .expect("node_statuses present in current schema");
        let bytes = serde_json::to_vec(&value).unwrap();
        let restored: DistributionDocument = serde_json::from_slice(&bytes).expect("deserialize");
        assert_eq!(restored.distribution_id, "dist-legacy");
        assert!(restored.node_statuses.is_empty());
        assert!(restored.cancelled_at.is_none());
    }
    /// `Capable` scope round-trips through JSON with all optional fields set.
    #[test]
    fn test_distribution_scope_serialization() {
        let scope = DistributionScope::Capable {
            min_gpu_gb: Some(4.0),
            cpu_arch: Some("x86_64".to_string()),
            min_storage_mb: Some(1024),
        };
        let json = serde_json::to_string(&scope).unwrap();
        let restored: DistributionScope = serde_json::from_str(&json).unwrap();
        match restored {
            DistributionScope::Capable {
                min_gpu_gb,
                cpu_arch,
                min_storage_mb,
            } => {
                assert_eq!(min_gpu_gb, Some(4.0));
                assert_eq!(cpu_arch, Some("x86_64".to_string()));
                assert_eq!(min_storage_mb, Some(1024));
            }
            _ => panic!("Wrong variant"),
        }
    }
}