use super::blob_traits::{BlobHash, BlobToken};
use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(feature = "automerge-backend")]
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
#[cfg(feature = "automerge-backend")]
use tokio::sync::RwLock;
#[cfg(feature = "automerge-backend")]
use tracing::{debug, info, warn};
use uuid::Uuid;
/// Relative urgency of a blob transfer; compare priorities via
/// [`TransferPriority::as_numeric`], where a larger value is more urgent.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TransferPriority {
    /// Most urgent; should be scheduled ahead of everything else.
    Critical,
    High,
    /// Default priority for ordinary transfers.
    #[default]
    Normal,
    Low,
}
impl TransferPriority {
pub fn as_numeric(&self) -> u8 {
match self {
Self::Critical => 4,
Self::High => 3,
Self::Normal => 2,
Self::Low => 1,
}
}
}
/// Selects which nodes a blob should be distributed to.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DistributionScope {
    /// Every peer currently known to the network (the default).
    #[default]
    AllNodes,
    /// Nodes belonging to a specific formation.
    /// NOTE(review): the iroh backend does not implement this yet and
    /// falls back to all known peers — see `resolve_targets`.
    Formation {
        formation_id: String,
    },
    /// An explicit list of target node ids; unknown ids are filtered out.
    Nodes {
        node_ids: Vec<String>,
    },
    /// Nodes satisfying minimum capability requirements; any criterion
    /// left as `None` is not checked.
    /// NOTE(review): also unimplemented in the iroh backend (falls back
    /// to all known peers).
    Capable {
        #[serde(skip_serializing_if = "Option::is_none")]
        min_gpu_gb: Option<f64>,
        #[serde(skip_serializing_if = "Option::is_none")]
        cpu_arch: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        min_storage_mb: Option<u64>,
    },
}
/// Lifecycle state of a single node's transfer.
/// `Completed` and `Failed` are terminal; `Connecting` and `Transferring`
/// both count as "in progress" (see `DistributionStatus::recalculate_counts`).
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransferState {
    /// Not yet started (the default for newly created statuses).
    #[default]
    Pending,
    Connecting,
    Transferring,
    Completed,
    Failed,
}
/// Progress record for one target node within a distribution.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NodeTransferStatus {
    /// Identifier of the receiving node.
    pub node_id: String,
    /// Current lifecycle state of this node's transfer.
    pub status: TransferState,
    /// Bytes transferred to this node so far.
    pub progress_bytes: u64,
    /// Total bytes expected (the blob size).
    pub total_bytes: u64,
    /// Timestamp when the transfer started, if known.
    /// NOTE(review): never written in this file — presumably stamped by the
    /// transfer driver; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub started_at: Option<DateTime<Utc>>,
    /// Timestamp when the transfer finished, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<DateTime<Utc>>,
    /// Human-readable failure reason, set when `status` is `Failed`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
impl NodeTransferStatus {
    /// Creates a fresh per-node record in the `Pending` state with no
    /// progress and no timestamps recorded yet.
    pub fn new(node_id: String, total_bytes: u64) -> Self {
        Self {
            node_id,
            total_bytes,
            status: TransferState::Pending,
            progress_bytes: 0,
            started_at: None,
            completed_at: None,
            error: None,
        }
    }
    /// Fraction of the transfer completed, as `progress / total`.
    /// An empty blob (zero total bytes) is reported as fully complete.
    pub fn progress_fraction(&self) -> f64 {
        match self.total_bytes {
            0 => 1.0,
            total => self.progress_bytes as f64 / total as f64,
        }
    }
}
/// Opaque reference to an in-flight distribution, returned by
/// `FileDistribution::distribute` and accepted by the other trait methods.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DistributionHandle {
    /// Unique id (a UUID string) identifying this distribution.
    pub distribution_id: String,
    /// Content hash of the blob being distributed.
    pub blob_hash: BlobHash,
    /// Node-selection scope this distribution was started with.
    pub scope: DistributionScope,
    /// Scheduling priority this distribution was started with.
    pub priority: TransferPriority,
    /// UTC time at which the distribution was created.
    pub started_at: DateTime<Utc>,
}
impl DistributionHandle {
    /// Builds a handle with a freshly generated UUID v4 id and the current
    /// UTC time as the start timestamp.
    pub fn new(blob_hash: BlobHash, scope: DistributionScope, priority: TransferPriority) -> Self {
        let distribution_id = Uuid::new_v4().to_string();
        let started_at = Utc::now();
        Self {
            distribution_id,
            started_at,
            blob_hash,
            scope,
            priority,
        }
    }
}
/// Aggregate snapshot of a distribution across all of its target nodes.
/// The `completed`/`in_progress`/`failed` counters are derived from
/// `node_statuses` via `recalculate_counts`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DistributionStatus {
    /// The handle this status describes.
    pub handle: DistributionHandle,
    /// Number of target nodes (equals `node_statuses.len()`).
    pub total_targets: usize,
    /// Nodes whose transfer reached `Completed`.
    pub completed: usize,
    /// Nodes currently `Connecting` or `Transferring`.
    pub in_progress: usize,
    /// Nodes whose transfer reached `Failed`.
    pub failed: usize,
    /// Per-node progress, keyed by node id.
    pub node_statuses: HashMap<String, NodeTransferStatus>,
}
impl DistributionStatus {
    /// Seeds a status with one `Pending` entry per target node, each
    /// expecting the same total byte count. Duplicate node ids collapse to
    /// a single entry, and `total_targets` reflects the deduplicated count.
    pub fn new(handle: DistributionHandle, target_nodes: Vec<String>, total_bytes: u64) -> Self {
        let mut node_statuses = HashMap::new();
        for id in target_nodes {
            node_statuses.insert(id.clone(), NodeTransferStatus::new(id, total_bytes));
        }
        Self {
            handle,
            total_targets: node_statuses.len(),
            completed: 0,
            in_progress: 0,
            failed: 0,
            node_statuses,
        }
    }
    /// True once every target reached a terminal state (completed or failed).
    pub fn is_complete(&self) -> bool {
        self.completed + self.failed >= self.total_targets
    }
    /// True when every target completed and none failed.
    pub fn is_success(&self) -> bool {
        self.failed == 0 && self.completed >= self.total_targets
    }
    /// Byte-weighted overall progress in `[0, 1]`. A distribution with no
    /// targets, or whose targets expect zero bytes, counts as complete.
    pub fn overall_progress(&self) -> f64 {
        if self.total_targets == 0 {
            return 1.0;
        }
        let (done, total) = self
            .node_statuses
            .values()
            .fold((0u64, 0u64), |(d, t), s| (d + s.progress_bytes, t + s.total_bytes));
        if total == 0 {
            1.0
        } else {
            done as f64 / total as f64
        }
    }
    /// Re-derives the aggregate counters from the per-node states.
    /// Call after mutating any entry in `node_statuses`.
    pub fn recalculate_counts(&mut self) {
        let mut completed = 0;
        let mut in_progress = 0;
        let mut failed = 0;
        for node in self.node_statuses.values() {
            match node.status {
                TransferState::Pending => {}
                TransferState::Connecting | TransferState::Transferring => in_progress += 1,
                TransferState::Completed => completed += 1,
                TransferState::Failed => failed += 1,
            }
        }
        self.completed = completed;
        self.in_progress = in_progress;
        self.failed = failed;
    }
}
/// Abstraction over distributing a blob to a set of nodes and tracking the
/// per-node transfer progress.
#[async_trait::async_trait]
pub trait FileDistribution: Send + Sync {
    /// Starts distributing the blob identified by `blob_token` to the nodes
    /// selected by `scope`, at the given `priority`. Returns a handle used
    /// by the other methods to refer to this distribution.
    async fn distribute(
        &self,
        blob_token: &BlobToken,
        scope: DistributionScope,
        priority: TransferPriority,
    ) -> Result<DistributionHandle>;
    /// Current snapshot of the distribution's aggregate and per-node state.
    async fn status(&self, handle: &DistributionHandle) -> Result<DistributionStatus>;
    /// Aborts the distribution; unfinished node transfers are abandoned.
    async fn cancel(&self, handle: &DistributionHandle) -> Result<()>;
    /// Waits until the distribution reaches a terminal state, returning an
    /// error if `timeout` elapses first.
    async fn wait_for_completion(
        &self,
        handle: &DistributionHandle,
        timeout: Duration,
    ) -> Result<DistributionStatus>;
    /// Subscribes to status snapshots broadcast as the distribution advances.
    async fn subscribe_progress(
        &self,
        handle: &DistributionHandle,
    ) -> Result<broadcast::Receiver<DistributionStatus>>;
}
#[cfg(feature = "automerge-backend")]
use super::automerge_store::AutomergeStore;
#[cfg(feature = "automerge-backend")]
use super::iroh_blob_store::NetworkedIrohBlobStore;
#[cfg(feature = "automerge-backend")]
/// Automerge collection name under which distribution documents are stored.
const IROH_DISTRIBUTION_COLLECTION: &str = "file_distributions";
/// `FileDistribution` implementation that uses an iroh blob store for peer
/// discovery/transfer and an Automerge document store to record and sync
/// distribution metadata to peers.
#[cfg(feature = "automerge-backend")]
pub struct IrohFileDistribution {
    /// Networked blob store; also the source of the known-peer list.
    blob_store: Arc<NetworkedIrohBlobStore>,
    /// CRDT document store where distribution records are upserted.
    document_store: Arc<AutomergeStore>,
    /// In-memory status, keyed by distribution id.
    distributions: RwLock<HashMap<String, DistributionStatus>>,
    /// Progress broadcast senders, keyed by distribution id.
    progress_channels: RwLock<HashMap<String, broadcast::Sender<DistributionStatus>>>,
}
#[cfg(feature = "automerge-backend")]
impl IrohFileDistribution {
    /// Creates a distribution service backed by the given blob store (peer
    /// discovery and transfer) and Automerge store (metadata sync).
    pub fn new(
        blob_store: Arc<NetworkedIrohBlobStore>,
        document_store: Arc<AutomergeStore>,
    ) -> Self {
        Self {
            blob_store,
            document_store,
            distributions: RwLock::new(HashMap::new()),
            progress_channels: RwLock::new(HashMap::new()),
        }
    }
    /// Shared handle to the underlying blob store.
    pub fn blob_store(&self) -> &Arc<NetworkedIrohBlobStore> {
        &self.blob_store
    }
    /// Shared handle to the underlying Automerge document store.
    pub fn document_store(&self) -> &Arc<AutomergeStore> {
        &self.document_store
    }
    /// Short-form ids of every peer currently known to the blob store.
    /// (Extracted helper: this mapping was previously duplicated in all
    /// four arms of `resolve_targets`.)
    async fn all_peer_ids(&self) -> Vec<String> {
        self.blob_store
            .known_peers()
            .await
            .iter()
            .map(|p| p.fmt_short().to_string())
            .collect()
    }
    /// Expands a [`DistributionScope`] into the concrete list of peer node
    /// ids to target. `Formation` and `Capable` scopes are not implemented
    /// yet and fall back to all known peers (with a warning).
    async fn resolve_targets(&self, scope: &DistributionScope) -> Vec<String> {
        match scope {
            DistributionScope::AllNodes => self.all_peer_ids().await,
            DistributionScope::Nodes { node_ids } => {
                // Keep only the requested ids that are actually reachable.
                let known_peers = self.all_peer_ids().await;
                node_ids
                    .iter()
                    .filter(|id| known_peers.contains(id))
                    .cloned()
                    .collect()
            }
            DistributionScope::Formation { formation_id } => {
                warn!(
                    formation_id = %formation_id,
                    "Formation-based distribution not yet implemented, distributing to all peers"
                );
                self.all_peer_ids().await
            }
            DistributionScope::Capable { .. } => {
                warn!(
                    "Capability-based distribution not yet implemented, distributing to all peers"
                );
                self.all_peer_ids().await
            }
        }
    }
    /// Serializes the distribution metadata to JSON and upserts it into the
    /// Automerge collection so it syncs to peers.
    #[allow(unused_imports)]
    async fn store_distribution_document(
        &self,
        handle: &DistributionHandle,
        blob_token: &BlobToken,
        target_nodes: &[String],
    ) -> Result<()> {
        use super::traits::Collection;
        let doc_id = &handle.distribution_id;
        let distribution_doc = serde_json::json!({
            "distribution_id": handle.distribution_id,
            "blob_hash": blob_token.hash.as_hex(),
            "blob_size": blob_token.size_bytes,
            "blob_metadata": blob_token.metadata,
            "scope": handle.scope,
            "priority": handle.priority,
            "target_nodes": target_nodes,
            "started_at": handle.started_at.to_rfc3339(),
            "status": "distributing"
        });
        let bytes = serde_json::to_vec(&distribution_doc)
            .map_err(|e| anyhow::anyhow!("Failed to serialize distribution doc: {}", e))?;
        let collection = self.document_store.collection(IROH_DISTRIBUTION_COLLECTION);
        collection.upsert(doc_id, bytes)?;
        debug!(
            distribution_id = %handle.distribution_id,
            blob_hash = %blob_token.hash,
            target_count = target_nodes.len(),
            "Stored distribution document in Automerge"
        );
        Ok(())
    }
    /// Sends a status snapshot to any subscribers of the given distribution;
    /// a missing channel or absent receivers is silently ignored.
    #[allow(dead_code)]
    async fn broadcast_progress(&self, distribution_id: &str, status: &DistributionStatus) {
        let channels = self.progress_channels.read().await;
        if let Some(sender) = channels.get(distribution_id) {
            let _ = sender.send(status.clone());
        }
    }
}
#[cfg(feature = "automerge-backend")]
#[async_trait::async_trait]
impl FileDistribution for IrohFileDistribution {
    /// Resolves the scope to concrete peers, records the distribution in
    /// memory and in the Automerge store, and opens a progress channel.
    /// The Automerge document sync is what informs peers of the blob.
    async fn distribute(
        &self,
        blob_token: &BlobToken,
        scope: DistributionScope,
        priority: TransferPriority,
    ) -> Result<DistributionHandle> {
        info!(
            blob_hash = %blob_token.hash,
            blob_size = blob_token.size_bytes,
            scope = ?scope,
            priority = ?priority,
            "Starting file distribution"
        );
        let handle = DistributionHandle::new(blob_token.hash.clone(), scope.clone(), priority);
        let target_nodes = self.resolve_targets(&scope).await;
        if target_nodes.is_empty() {
            // Not an error: the status will simply be complete immediately.
            warn!("No target nodes found for distribution scope");
        }
        let status =
            DistributionStatus::new(handle.clone(), target_nodes.clone(), blob_token.size_bytes);
        self.store_distribution_document(&handle, blob_token, &target_nodes)
            .await?;
        {
            let mut distributions = self.distributions.write().await;
            // Move the status in directly (previously cloned redundantly).
            distributions.insert(handle.distribution_id.clone(), status);
        }
        {
            // Capacity 16: snapshots are cheap to drop; slow subscribers
            // just miss intermediate updates.
            let (tx, _rx) = broadcast::channel(16);
            let mut channels = self.progress_channels.write().await;
            channels.insert(handle.distribution_id.clone(), tx);
        }
        info!(
            distribution_id = %handle.distribution_id,
            target_count = target_nodes.len(),
            "Distribution initiated - document synced to peers"
        );
        Ok(handle)
    }
    /// Returns the in-memory snapshot for the handle, or an error if the
    /// distribution id is unknown.
    async fn status(&self, handle: &DistributionHandle) -> Result<DistributionStatus> {
        let distributions = self.distributions.read().await;
        distributions
            .get(&handle.distribution_id)
            .cloned()
            .ok_or_else(|| anyhow::anyhow!("Distribution not found: {}", handle.distribution_id))
    }
    /// Marks every unfinished node transfer as failed, notifies progress
    /// subscribers, and records the cancellation in the Automerge store.
    /// Cancelling an unknown distribution id is a no-op in memory but the
    /// cancellation document is still written.
    async fn cancel(&self, handle: &DistributionHandle) -> Result<()> {
        info!(
            distribution_id = %handle.distribution_id,
            "Cancelling distribution"
        );
        // Mutate under the write lock and snapshot the result so the
        // broadcast happens outside the lock.
        let snapshot = {
            let mut distributions = self.distributions.write().await;
            distributions.get_mut(&handle.distribution_id).map(|status| {
                for node_status in status.node_statuses.values_mut() {
                    if node_status.status != TransferState::Completed {
                        node_status.status = TransferState::Failed;
                        node_status.error = Some("Distribution cancelled".to_string());
                    }
                }
                status.recalculate_counts();
                status.clone()
            })
        };
        // Fix: notify subscribers of the cancellation — previously they were
        // never told and could wait on the channel indefinitely.
        if let Some(status) = snapshot {
            self.broadcast_progress(&handle.distribution_id, &status)
                .await;
        }
        #[allow(unused_imports)]
        use super::traits::Collection;
        let cancel_update = serde_json::json!({
            "status": "cancelled",
            "cancelled_at": Utc::now().to_rfc3339()
        });
        let bytes = serde_json::to_vec(&cancel_update)
            .map_err(|e| anyhow::anyhow!("Failed to serialize cancel update: {}", e))?;
        let collection = self.document_store.collection(IROH_DISTRIBUTION_COLLECTION);
        collection.upsert(&handle.distribution_id, bytes)?;
        Ok(())
    }
    /// Polls `status` every 500 ms until the distribution is complete or
    /// `timeout` elapses, whichever comes first.
    async fn wait_for_completion(
        &self,
        handle: &DistributionHandle,
        timeout: Duration,
    ) -> Result<DistributionStatus> {
        let start = std::time::Instant::now();
        let poll_interval = Duration::from_millis(500);
        loop {
            let status = self.status(handle).await?;
            if status.is_complete() {
                return Ok(status);
            }
            let elapsed = start.elapsed();
            if elapsed >= timeout {
                return Err(anyhow::anyhow!("Distribution timeout after {:?}", timeout));
            }
            // Fix: cap the sleep to the time remaining so a timeout shorter
            // than the poll interval is honored promptly instead of
            // overshooting by up to 500 ms.
            tokio::time::sleep(poll_interval.min(timeout - elapsed)).await;
        }
    }
    /// Subscribes to progress snapshots for an active distribution; errors
    /// if the distribution id has no channel registered.
    async fn subscribe_progress(
        &self,
        handle: &DistributionHandle,
    ) -> Result<broadcast::Receiver<DistributionStatus>> {
        let channels = self.progress_channels.read().await;
        channels
            .get(&handle.distribution_id)
            .map(|sender| sender.subscribe())
            .ok_or_else(|| anyhow::anyhow!("Distribution not found: {}", handle.distribution_id))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Priorities must be strictly ordered Low < Normal < High < Critical
    /// on the numeric scale.
    #[test]
    fn test_transfer_priority_ordering() {
        let ascending = [
            TransferPriority::Low,
            TransferPriority::Normal,
            TransferPriority::High,
            TransferPriority::Critical,
        ];
        for pair in ascending.windows(2) {
            assert!(pair[0].as_numeric() < pair[1].as_numeric());
        }
    }
    /// A new handle gets a non-empty id and carries its inputs through.
    #[test]
    fn test_distribution_handle_creation() {
        let hash = BlobHash::from_hex("abc123");
        let handle = DistributionHandle::new(
            hash.clone(),
            DistributionScope::AllNodes,
            TransferPriority::High,
        );
        assert!(!handle.distribution_id.is_empty());
        assert_eq!(handle.blob_hash, hash);
        assert_eq!(handle.priority, TransferPriority::High);
    }
    /// Progress fraction tracks bytes through the transfer lifecycle.
    #[test]
    fn test_node_transfer_status() {
        let mut status = NodeTransferStatus::new("node-1".to_string(), 1000);
        assert_eq!(status.status, TransferState::Pending);
        assert_eq!(status.progress_fraction(), 0.0);
        status.status = TransferState::Transferring;
        status.progress_bytes = 500;
        assert_eq!(status.progress_fraction(), 0.5);
        status.status = TransferState::Completed;
        status.progress_bytes = 1000;
        assert_eq!(status.progress_fraction(), 1.0);
    }
    /// Aggregate counters follow per-node state changes after a recalc.
    #[test]
    fn test_distribution_status() {
        let handle = DistributionHandle::new(
            BlobHash::from_hex("abc123"),
            DistributionScope::AllNodes,
            TransferPriority::Normal,
        );
        let targets: Vec<String> = ["node-1", "node-2", "node-3"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let mut status = DistributionStatus::new(handle, targets, 1000);
        assert_eq!(status.total_targets, 3);
        assert_eq!(status.completed, 0);
        assert!(!status.is_complete());
        for node in ["node-1", "node-2"] {
            if let Some(node_status) = status.node_statuses.get_mut(node) {
                node_status.status = TransferState::Completed;
                node_status.progress_bytes = 1000;
            }
        }
        if let Some(node_status) = status.node_statuses.get_mut("node-3") {
            node_status.status = TransferState::Failed;
            node_status.error = Some("Connection lost".to_string());
        }
        status.recalculate_counts();
        assert_eq!(status.completed, 2);
        assert_eq!(status.failed, 1);
        assert!(status.is_complete());
        assert!(!status.is_success());
    }
    /// A `Capable` scope survives a JSON round-trip with all fields intact.
    #[test]
    fn test_distribution_scope_serialization() {
        let scope = DistributionScope::Capable {
            min_gpu_gb: Some(4.0),
            cpu_arch: Some("x86_64".to_string()),
            min_storage_mb: Some(1024),
        };
        let json = serde_json::to_string(&scope).unwrap();
        let restored: DistributionScope = serde_json::from_str(&json).unwrap();
        let DistributionScope::Capable {
            min_gpu_gb,
            cpu_arch,
            min_storage_mb,
        } = restored
        else {
            panic!("Wrong variant");
        };
        assert_eq!(min_gpu_gb, Some(4.0));
        assert_eq!(cpu_arch, Some("x86_64".to_string()));
        assert_eq!(min_storage_mb, Some(1024));
    }
}