1use super::blob_traits::{BlobHash, BlobToken};
47use anyhow::Result;
48use chrono::{DateTime, Utc};
49use serde::{Deserialize, Serialize};
50use std::collections::HashMap;
51#[cfg(feature = "automerge-backend")]
52use std::sync::Arc;
53use std::time::Duration;
54use tokio::sync::broadcast;
55#[cfg(feature = "automerge-backend")]
56use tokio::sync::RwLock;
57#[cfg(feature = "automerge-backend")]
58use tracing::{debug, info, warn};
59use uuid::Uuid;
60
/// Relative urgency of a blob transfer; higher priorities rank above
/// lower ones (see [`TransferPriority::as_numeric`]).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TransferPriority {
    /// Highest urgency — transfer as soon as possible.
    Critical,
    /// Above-normal urgency.
    High,
    /// Default priority for routine transfers.
    #[default]
    Normal,
    /// Background / best-effort transfers.
    Low,
}
81
82impl TransferPriority {
83 pub fn as_numeric(&self) -> u8 {
85 match self {
86 Self::Critical => 4,
87 Self::High => 3,
88 Self::Normal => 2,
89 Self::Low => 1,
90 }
91 }
92}
93
/// Selects which nodes a blob should be distributed to.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DistributionScope {
    /// Distribute to every known peer (default).
    #[default]
    AllNodes,

    /// Distribute to the nodes of a specific formation.
    Formation {
        /// Identifier of the target formation.
        formation_id: String,
    },

    /// Distribute to an explicit list of nodes.
    Nodes {
        /// Node identifiers to target.
        node_ids: Vec<String>,
    },

    /// Distribute to nodes matching capability requirements; `None`
    /// fields impose no constraint.
    Capable {
        /// Minimum GPU memory in gigabytes, if required.
        #[serde(skip_serializing_if = "Option::is_none")]
        min_gpu_gb: Option<f64>,

        /// Required CPU architecture (e.g. "x86_64"), if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        cpu_arch: Option<String>,

        /// Minimum storage in megabytes, if required.
        #[serde(skip_serializing_if = "Option::is_none")]
        min_storage_mb: Option<u64>,
    },
}
130
/// Lifecycle state of a single node's transfer.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransferState {
    /// Not yet started (initial state).
    #[default]
    Pending,
    /// Establishing a connection to the target node.
    Connecting,
    /// Bytes are actively being transferred.
    Transferring,
    /// Transfer finished successfully.
    Completed,
    /// Transfer failed (also used for cancelled transfers).
    Failed,
}
146
/// Per-node view of one blob transfer within a distribution.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NodeTransferStatus {
    /// Identifier of the target node.
    pub node_id: String,
    /// Current lifecycle state of this node's transfer.
    pub status: TransferState,
    /// Bytes transferred so far.
    pub progress_bytes: u64,
    /// Total bytes to transfer.
    pub total_bytes: u64,
    /// When the transfer began, if it has started.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub started_at: Option<DateTime<Utc>>,
    /// When the transfer finished, if it has completed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<DateTime<Utc>>,
    /// Failure description when `status` is `Failed`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
168
169impl NodeTransferStatus {
170 pub fn new(node_id: String, total_bytes: u64) -> Self {
172 Self {
173 node_id,
174 status: TransferState::Pending,
175 progress_bytes: 0,
176 total_bytes,
177 started_at: None,
178 completed_at: None,
179 error: None,
180 }
181 }
182
183 pub fn progress_fraction(&self) -> f64 {
185 if self.total_bytes == 0 {
186 return 1.0;
187 }
188 self.progress_bytes as f64 / self.total_bytes as f64
189 }
190}
191
/// Reference to an in-flight distribution, returned by
/// [`FileDistribution::distribute`] and used for later queries.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DistributionHandle {
    /// Unique identifier (UUIDv4) of this distribution.
    pub distribution_id: String,
    /// Hash of the blob being distributed.
    pub blob_hash: BlobHash,
    /// Node-selection scope of the distribution.
    pub scope: DistributionScope,
    /// Transfer priority.
    pub priority: TransferPriority,
    /// UTC time the distribution was started.
    pub started_at: DateTime<Utc>,
}
208
209impl DistributionHandle {
210 pub fn new(blob_hash: BlobHash, scope: DistributionScope, priority: TransferPriority) -> Self {
212 Self {
213 distribution_id: Uuid::new_v4().to_string(),
214 blob_hash,
215 scope,
216 priority,
217 started_at: Utc::now(),
218 }
219 }
220}
221
/// Aggregated progress of a distribution across all target nodes.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DistributionStatus {
    /// Handle identifying the distribution.
    pub handle: DistributionHandle,
    /// Number of target nodes.
    pub total_targets: usize,
    /// Count of nodes that completed successfully.
    pub completed: usize,
    /// Count of nodes currently connecting or transferring.
    pub in_progress: usize,
    /// Count of nodes whose transfer failed.
    pub failed: usize,
    /// Per-node transfer details, keyed by node id.
    pub node_statuses: HashMap<String, NodeTransferStatus>,
}
238
239impl DistributionStatus {
240 pub fn new(handle: DistributionHandle, target_nodes: Vec<String>, total_bytes: u64) -> Self {
242 let node_statuses: HashMap<String, NodeTransferStatus> = target_nodes
243 .into_iter()
244 .map(|id| (id.clone(), NodeTransferStatus::new(id, total_bytes)))
245 .collect();
246
247 let total_targets = node_statuses.len();
248
249 Self {
250 handle,
251 total_targets,
252 completed: 0,
253 in_progress: 0,
254 failed: 0,
255 node_statuses,
256 }
257 }
258
259 pub fn is_complete(&self) -> bool {
261 self.completed + self.failed >= self.total_targets
262 }
263
264 pub fn is_success(&self) -> bool {
266 self.completed >= self.total_targets && self.failed == 0
267 }
268
269 pub fn overall_progress(&self) -> f64 {
271 if self.total_targets == 0 {
272 return 1.0;
273 }
274 let total_bytes: u64 = self.node_statuses.values().map(|s| s.total_bytes).sum();
275 let progress_bytes: u64 = self.node_statuses.values().map(|s| s.progress_bytes).sum();
276 if total_bytes == 0 {
277 return 1.0;
278 }
279 progress_bytes as f64 / total_bytes as f64
280 }
281
282 pub fn recalculate_counts(&mut self) {
284 self.completed = 0;
285 self.in_progress = 0;
286 self.failed = 0;
287
288 for status in self.node_statuses.values() {
289 match status.status {
290 TransferState::Completed => self.completed += 1,
291 TransferState::Failed => self.failed += 1,
292 TransferState::Transferring | TransferState::Connecting => self.in_progress += 1,
293 TransferState::Pending => {}
294 }
295 }
296 }
297}
298
/// Contract for distributing a blob to a set of nodes and tracking
/// per-node transfer progress.
#[async_trait::async_trait]
pub trait FileDistribution: Send + Sync {
    /// Starts distributing the blob identified by `blob_token` to the
    /// nodes selected by `scope`, at the given `priority`. Returns a
    /// handle used to query, cancel, or await the distribution.
    async fn distribute(
        &self,
        blob_token: &BlobToken,
        scope: DistributionScope,
        priority: TransferPriority,
    ) -> Result<DistributionHandle>;

    /// Returns the current status of the distribution behind `handle`.
    async fn status(&self, handle: &DistributionHandle) -> Result<DistributionStatus>;

    /// Cancels the distribution behind `handle`.
    async fn cancel(&self, handle: &DistributionHandle) -> Result<()>;

    /// Waits until the distribution completes or `timeout` elapses,
    /// returning the final status on completion.
    async fn wait_for_completion(
        &self,
        handle: &DistributionHandle,
        timeout: Duration,
    ) -> Result<DistributionStatus>;

    /// Subscribes to progress updates; each received item is a full
    /// [`DistributionStatus`] snapshot.
    async fn subscribe_progress(
        &self,
        handle: &DistributionHandle,
    ) -> Result<broadcast::Receiver<DistributionStatus>>;
}
376
377#[cfg(feature = "automerge-backend")]
382use super::automerge_store::AutomergeStore;
383#[cfg(feature = "automerge-backend")]
384use super::iroh_blob_store::NetworkedIrohBlobStore;
385
/// Automerge collection name under which distribution documents are stored.
#[cfg(feature = "automerge-backend")]
const IROH_DISTRIBUTION_COLLECTION: &str = "file_distributions";
389
/// Iroh-backed [`FileDistribution`] implementation: peers are discovered
/// through the networked Iroh blob store, while distribution metadata is
/// published to peers as Automerge documents.
#[cfg(feature = "automerge-backend")]
pub struct IrohFileDistribution {
    /// Blob transport and peer discovery.
    blob_store: Arc<NetworkedIrohBlobStore>,
    /// Automerge store used to publish distribution documents.
    document_store: Arc<AutomergeStore>,
    /// In-memory status, keyed by distribution id.
    distributions: RwLock<HashMap<String, DistributionStatus>>,
    /// Progress broadcast senders, keyed by distribution id.
    progress_channels: RwLock<HashMap<String, broadcast::Sender<DistributionStatus>>>,
}
422
#[cfg(feature = "automerge-backend")]
impl IrohFileDistribution {
    /// Creates a distribution service backed by the given blob store
    /// (transport / peer discovery) and Automerge document store
    /// (metadata publication).
    pub fn new(
        blob_store: Arc<NetworkedIrohBlobStore>,
        document_store: Arc<AutomergeStore>,
    ) -> Self {
        Self {
            blob_store,
            document_store,
            distributions: RwLock::new(HashMap::new()),
            progress_channels: RwLock::new(HashMap::new()),
        }
    }

    /// Underlying blob store.
    pub fn blob_store(&self) -> &Arc<NetworkedIrohBlobStore> {
        &self.blob_store
    }

    /// Underlying Automerge document store.
    pub fn document_store(&self) -> &Arc<AutomergeStore> {
        &self.document_store
    }

    /// Short identifiers of all currently known peers. Shared by every
    /// branch of [`Self::resolve_targets`].
    async fn all_peer_ids(&self) -> Vec<String> {
        self.blob_store
            .known_peers()
            .await
            .iter()
            .map(|p| p.fmt_short().to_string())
            .collect()
    }

    /// Resolves a [`DistributionScope`] to the concrete list of target
    /// node ids.
    ///
    /// `Formation` and `Capable` scopes are not implemented yet and fall
    /// back to all known peers (logging a warning).
    async fn resolve_targets(&self, scope: &DistributionScope) -> Vec<String> {
        match scope {
            DistributionScope::AllNodes => self.all_peer_ids().await,
            DistributionScope::Nodes { node_ids } => {
                // Keep only requested ids that are actually known peers; a
                // set makes each membership check O(1) instead of a linear
                // scan per requested node.
                let known: std::collections::HashSet<String> =
                    self.all_peer_ids().await.into_iter().collect();

                node_ids
                    .iter()
                    .filter(|id| known.contains(id.as_str()))
                    .cloned()
                    .collect()
            }
            DistributionScope::Formation { formation_id } => {
                warn!(
                    formation_id = %formation_id,
                    "Formation-based distribution not yet implemented, distributing to all peers"
                );
                self.all_peer_ids().await
            }
            DistributionScope::Capable { .. } => {
                warn!(
                    "Capability-based distribution not yet implemented, distributing to all peers"
                );
                self.all_peer_ids().await
            }
        }
    }

    /// Serializes the distribution metadata to JSON and upserts it into
    /// the Automerge collection so it syncs to peers.
    #[allow(unused_imports)]
    async fn store_distribution_document(
        &self,
        handle: &DistributionHandle,
        blob_token: &BlobToken,
        target_nodes: &[String],
    ) -> Result<()> {
        use super::traits::Collection;

        let doc_id = &handle.distribution_id;

        // Self-describing document: everything a peer needs to know which
        // blob to fetch and who the targets are.
        let distribution_doc = serde_json::json!({
            "distribution_id": handle.distribution_id,
            "blob_hash": blob_token.hash.as_hex(),
            "blob_size": blob_token.size_bytes,
            "blob_metadata": blob_token.metadata,
            "scope": handle.scope,
            "priority": handle.priority,
            "target_nodes": target_nodes,
            "started_at": handle.started_at.to_rfc3339(),
            "status": "distributing"
        });

        let bytes = serde_json::to_vec(&distribution_doc)
            .map_err(|e| anyhow::anyhow!("Failed to serialize distribution doc: {}", e))?;

        let collection = self.document_store.collection(IROH_DISTRIBUTION_COLLECTION);
        collection.upsert(doc_id, bytes)?;

        debug!(
            distribution_id = %handle.distribution_id,
            blob_hash = %blob_token.hash,
            target_count = target_nodes.len(),
            "Stored distribution document in Automerge"
        );

        Ok(())
    }

    /// Sends a status snapshot to any progress subscribers; a send error
    /// (no active receivers) is deliberately ignored.
    #[allow(dead_code)]
    async fn broadcast_progress(&self, distribution_id: &str, status: &DistributionStatus) {
        let channels = self.progress_channels.read().await;
        if let Some(sender) = channels.get(distribution_id) {
            let _ = sender.send(status.clone());
        }
    }
}
562
#[cfg(feature = "automerge-backend")]
#[async_trait::async_trait]
impl FileDistribution for IrohFileDistribution {
    /// Resolves the scope to target nodes, persists the distribution
    /// document (which Automerge syncs to peers), and registers in-memory
    /// status tracking plus a progress broadcast channel.
    async fn distribute(
        &self,
        blob_token: &BlobToken,
        scope: DistributionScope,
        priority: TransferPriority,
    ) -> Result<DistributionHandle> {
        info!(
            blob_hash = %blob_token.hash,
            blob_size = blob_token.size_bytes,
            scope = ?scope,
            priority = ?priority,
            "Starting file distribution"
        );

        let handle = DistributionHandle::new(blob_token.hash.clone(), scope.clone(), priority);

        let target_nodes = self.resolve_targets(&scope).await;

        // An empty target list is not an error: the distribution is still
        // recorded and immediately reads as complete.
        if target_nodes.is_empty() {
            warn!("No target nodes found for distribution scope");
        }

        let status =
            DistributionStatus::new(handle.clone(), target_nodes.clone(), blob_token.size_bytes);

        // Persist first so peers can observe the distribution document;
        // failure here aborts before any in-memory state is registered.
        self.store_distribution_document(&handle, blob_token, &target_nodes)
            .await?;

        // Register in-memory tracking (block scope drops the write lock).
        {
            let mut distributions = self.distributions.write().await;
            distributions.insert(handle.distribution_id.clone(), status.clone());
        }

        // Create the broadcast channel future subscribers will attach to;
        // the initial receiver is dropped (subscribers call subscribe()).
        {
            let (tx, _rx) = broadcast::channel(16);
            let mut channels = self.progress_channels.write().await;
            channels.insert(handle.distribution_id.clone(), tx);
        }

        info!(
            distribution_id = %handle.distribution_id,
            target_count = target_nodes.len(),
            "Distribution initiated - document synced to peers"
        );

        Ok(handle)
    }

    /// Returns a snapshot of the in-memory status; errors when the
    /// distribution id is unknown.
    async fn status(&self, handle: &DistributionHandle) -> Result<DistributionStatus> {
        let distributions = self.distributions.read().await;
        distributions
            .get(&handle.distribution_id)
            .cloned()
            .ok_or_else(|| anyhow::anyhow!("Distribution not found: {}", handle.distribution_id))
    }

    /// Marks every unfinished node transfer as failed (already-completed
    /// nodes are left untouched) and records the cancellation in the
    /// distribution's Automerge document.
    async fn cancel(&self, handle: &DistributionHandle) -> Result<()> {
        info!(
            distribution_id = %handle.distribution_id,
            "Cancelling distribution"
        );

        // Block scope drops the write lock before touching the doc store.
        {
            let mut distributions = self.distributions.write().await;
            if let Some(status) = distributions.get_mut(&handle.distribution_id) {
                for node_status in status.node_statuses.values_mut() {
                    if node_status.status != TransferState::Completed {
                        node_status.status = TransferState::Failed;
                        node_status.error = Some("Distribution cancelled".to_string());
                    }
                }
                status.recalculate_counts();
            }
        }

        #[allow(unused_imports)]
        use super::traits::Collection;

        let cancel_update = serde_json::json!({
            "status": "cancelled",
            "cancelled_at": Utc::now().to_rfc3339()
        });

        let bytes = serde_json::to_vec(&cancel_update)
            .map_err(|e| anyhow::anyhow!("Failed to serialize cancel update: {}", e))?;

        let collection = self.document_store.collection(IROH_DISTRIBUTION_COLLECTION);
        collection.upsert(&handle.distribution_id, bytes)?;

        Ok(())
    }

    /// Polls `status` every 500 ms until the distribution settles or the
    /// timeout elapses (in which case an error is returned).
    async fn wait_for_completion(
        &self,
        handle: &DistributionHandle,
        timeout: Duration,
    ) -> Result<DistributionStatus> {
        let start = std::time::Instant::now();
        let poll_interval = Duration::from_millis(500);

        loop {
            let status = self.status(handle).await?;

            if status.is_complete() {
                return Ok(status);
            }

            if start.elapsed() >= timeout {
                return Err(anyhow::anyhow!("Distribution timeout after {:?}", timeout));
            }

            tokio::time::sleep(poll_interval).await;
        }
    }

    /// Returns a fresh receiver on the distribution's progress channel;
    /// errors when the distribution id is unknown.
    async fn subscribe_progress(
        &self,
        handle: &DistributionHandle,
    ) -> Result<broadcast::Receiver<DistributionStatus>> {
        let channels = self.progress_channels.read().await;
        channels
            .get(&handle.distribution_id)
            .map(|sender| sender.subscribe())
            .ok_or_else(|| anyhow::anyhow!("Distribution not found: {}", handle.distribution_id))
    }
}
707
#[cfg(test)]
mod tests {
    use super::*;

    /// Priorities must rank strictly: Critical > High > Normal > Low.
    #[test]
    fn test_transfer_priority_ordering() {
        let ranks = [
            TransferPriority::Low.as_numeric(),
            TransferPriority::Normal.as_numeric(),
            TransferPriority::High.as_numeric(),
            TransferPriority::Critical.as_numeric(),
        ];
        assert!(ranks[3] > ranks[2]);
        assert!(ranks[2] > ranks[1]);
        assert!(ranks[1] > ranks[0]);
    }

    /// A new handle carries a non-empty id and the given hash/priority.
    #[test]
    fn test_distribution_handle_creation() {
        let hash = BlobHash::from_hex("abc123");
        let handle = DistributionHandle::new(
            hash.clone(),
            DistributionScope::AllNodes,
            TransferPriority::High,
        );

        assert!(!handle.distribution_id.is_empty());
        assert_eq!(handle.blob_hash, hash);
        assert_eq!(handle.priority, TransferPriority::High);
    }

    /// Progress fraction tracks bytes through the transfer lifecycle.
    #[test]
    fn test_node_transfer_status() {
        let mut node = NodeTransferStatus::new("node-1".to_string(), 1000);

        // Fresh status: pending, zero progress.
        assert_eq!(node.status, TransferState::Pending);
        assert_eq!(node.progress_fraction(), 0.0);

        // Halfway through.
        node.progress_bytes = 500;
        node.status = TransferState::Transferring;
        assert_eq!(node.progress_fraction(), 0.5);

        // Done.
        node.progress_bytes = 1000;
        node.status = TransferState::Completed;
        assert_eq!(node.progress_fraction(), 1.0);
    }

    /// Aggregate counters reflect per-node outcomes after recalculation.
    #[test]
    fn test_distribution_status() {
        let handle = DistributionHandle::new(
            BlobHash::from_hex("abc123"),
            DistributionScope::AllNodes,
            TransferPriority::Normal,
        );
        let targets: Vec<String> = ["node-1", "node-2", "node-3"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let mut status = DistributionStatus::new(handle, targets, 1000);

        assert_eq!(status.total_targets, 3);
        assert_eq!(status.completed, 0);
        assert!(!status.is_complete());

        // Two nodes finish; the third fails.
        for id in ["node-1", "node-2"] {
            if let Some(node) = status.node_statuses.get_mut(id) {
                node.status = TransferState::Completed;
                node.progress_bytes = 1000;
            }
        }
        if let Some(node) = status.node_statuses.get_mut("node-3") {
            node.status = TransferState::Failed;
            node.error = Some("Connection lost".to_string());
        }

        status.recalculate_counts();

        assert_eq!(status.completed, 2);
        assert_eq!(status.failed, 1);
        assert!(status.is_complete());
        assert!(!status.is_success());
    }

    /// A `Capable` scope survives a JSON round trip intact.
    #[test]
    fn test_distribution_scope_serialization() {
        let original = DistributionScope::Capable {
            min_gpu_gb: Some(4.0),
            cpu_arch: Some("x86_64".to_string()),
            min_storage_mb: Some(1024),
        };

        let json = serde_json::to_string(&original).unwrap();
        let restored: DistributionScope = serde_json::from_str(&json).unwrap();

        if let DistributionScope::Capable {
            min_gpu_gb,
            cpu_arch,
            min_storage_mb,
        } = restored
        {
            assert_eq!(min_gpu_gb, Some(4.0));
            assert_eq!(cpu_arch, Some("x86_64".to_string()));
            assert_eq!(min_storage_mb, Some(1024));
        } else {
            panic!("Wrong variant");
        }
    }
}