ipfrs_transport/
graphsync.rs

//! GraphSync protocol for DAG traversal
//!
//! Implements efficient DAG traversal with:
//! - IPLD selector parsing and execution
//! - Incremental response streaming
//! - Resume capability from partial transfers
//! - Breadth-first and depth-first traversal
//!
//! # Example
//!
//! ```
//! use ipfrs_transport::Selector;
//!
//! // Create a selector for recursive depth-limited traversal
//! let selector = Selector::RecursiveDepth { max_depth: 5 };
//!
//! // Validate the selector
//! assert!(selector.validate().is_ok());
//!
//! // Create a selector for specific fields
//! let field_selector = Selector::Fields {
//!     fields: vec!["data".to_string(), "links".to_string()]
//! };
//!
//! // Parse from JSON
//! let json = r#"{"type": "recursivedepth", "max_depth": 3}"#;
//! let parsed = Selector::from_json(json).unwrap();
//! ```

use ipfrs_core::error::{Error, Result};
use ipfrs_core::{Block, Cid};
use ipfrs_storage::traits::BlockStore;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use tokio::sync::RwLock;

38/// IPLD Selector
39///
40/// Selectors specify which parts of a DAG to traverse
41#[derive(Debug, Clone, Serialize, Deserialize)]
42#[serde(tag = "type", rename_all = "lowercase")]
43#[derive(Default)]
44pub enum Selector {
45    /// Match everything
46    #[default]
47    All,
48    /// Match specific fields by name
49    Fields { fields: Vec<String> },
50    /// Recursively traverse to a depth limit
51    RecursiveDepth { max_depth: usize },
52    /// Recursively traverse all links
53    RecursiveAll,
54    /// Match based on index
55    Index { index: usize },
56    /// Sequence of selectors
57    Sequence { selectors: Vec<Selector> },
58    /// Match the current node
59    Matcher,
60}
61
62impl Selector {
63    /// Parse a selector from JSON
64    pub fn from_json(json: &str) -> Result<Self> {
65        serde_json::from_str(json)
66            .map_err(|e| Error::InvalidInput(format!("Failed to parse selector: {}", e)))
67    }
68
69    /// Validate the selector
70    pub fn validate(&self) -> Result<()> {
71        match self {
72            Selector::RecursiveDepth { max_depth } => {
73                if *max_depth == 0 {
74                    return Err(Error::InvalidInput(
75                        "max_depth must be greater than 0".to_string(),
76                    ));
77                }
78            }
79            Selector::Sequence { selectors } => {
80                for sel in selectors {
81                    sel.validate()?;
82                }
83            }
84            _ => {}
85        }
86        Ok(())
87    }
88
89    /// Check if this selector matches all
90    pub fn matches_all(&self) -> bool {
91        matches!(self, Selector::All | Selector::RecursiveAll)
92    }
93}
94
/// Order in which a DAG traversal visits queued nodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraversalMode {
    /// Visit nodes level by level (FIFO queue discipline).
    BreadthFirst,
    /// Follow each branch to its end before backtracking (LIFO).
    DepthFirst,
}
103
104/// DAG traversal state
105#[derive(Debug, Clone)]
106pub struct TraversalState {
107    /// Root CID
108    pub root: Cid,
109    /// Visited CIDs
110    pub visited: HashSet<Cid>,
111    /// Queue of CIDs to visit (for BFS) or stack (for DFS)
112    pub queue: VecDeque<(Cid, usize)>, // (CID, depth)
113    /// Current depth
114    pub current_depth: usize,
115    /// Maximum depth (if limited)
116    pub max_depth: Option<usize>,
117    /// Blocks fetched so far
118    pub blocks_fetched: usize,
119    /// Bytes fetched so far
120    pub bytes_fetched: u64,
121}
122
123impl TraversalState {
124    /// Create a new traversal state
125    pub fn new(root: Cid, max_depth: Option<usize>) -> Self {
126        let mut queue = VecDeque::new();
127        queue.push_back((root, 0));
128
129        Self {
130            root,
131            visited: HashSet::new(),
132            queue,
133            current_depth: 0,
134            max_depth,
135            blocks_fetched: 0,
136            bytes_fetched: 0,
137        }
138    }
139
140    /// Check if traversal is complete
141    pub fn is_complete(&self) -> bool {
142        self.queue.is_empty()
143    }
144
145    /// Get next CID to visit
146    pub fn next(&mut self, mode: TraversalMode) -> Option<(Cid, usize)> {
147        match mode {
148            TraversalMode::BreadthFirst => self.queue.pop_front(),
149            TraversalMode::DepthFirst => self.queue.pop_back(),
150        }
151    }
152
153    /// Add a CID to the queue
154    pub fn enqueue(&mut self, cid: Cid, depth: usize) {
155        if let Some(max) = self.max_depth {
156            if depth > max {
157                return;
158            }
159        }
160
161        if !self.visited.contains(&cid) {
162            self.queue.push_back((cid, depth));
163        }
164    }
165
166    /// Mark a CID as visited
167    pub fn mark_visited(&mut self, cid: Cid, size: u64) {
168        self.visited.insert(cid);
169        self.blocks_fetched += 1;
170        self.bytes_fetched += size;
171    }
172
173    /// Save checkpoint for resume
174    pub fn checkpoint(&self) -> TraversalCheckpoint {
175        TraversalCheckpoint {
176            root: self.root,
177            visited: self.visited.clone(),
178            queue: self.queue.clone(),
179            max_depth: self.max_depth,
180            blocks_fetched: self.blocks_fetched,
181            bytes_fetched: self.bytes_fetched,
182        }
183    }
184
185    /// Restore from checkpoint
186    pub fn from_checkpoint(checkpoint: TraversalCheckpoint) -> Self {
187        Self {
188            root: checkpoint.root,
189            visited: checkpoint.visited,
190            queue: checkpoint.queue,
191            current_depth: 0,
192            max_depth: checkpoint.max_depth,
193            blocks_fetched: checkpoint.blocks_fetched,
194            bytes_fetched: checkpoint.bytes_fetched,
195        }
196    }
197}
198
199/// Checkpoint for resuming traversal
200#[derive(Debug, Clone)]
201pub struct TraversalCheckpoint {
202    /// Root CID
203    pub root: Cid,
204    /// Visited CIDs
205    pub visited: HashSet<Cid>,
206    /// Queue state
207    pub queue: VecDeque<(Cid, usize)>,
208    /// Maximum depth
209    pub max_depth: Option<usize>,
210    /// Blocks fetched
211    pub blocks_fetched: usize,
212    /// Bytes fetched
213    pub bytes_fetched: u64,
214}
215
216impl TraversalCheckpoint {
217    /// Serialize to JSON using CID strings
218    pub fn to_json(&self) -> Result<String> {
219        #[derive(Serialize)]
220        struct SerializableCheckpoint {
221            root: String,
222            visited: Vec<String>,
223            queue: Vec<(String, usize)>,
224            max_depth: Option<usize>,
225            blocks_fetched: usize,
226            bytes_fetched: u64,
227        }
228
229        let serializable = SerializableCheckpoint {
230            root: self.root.to_string(),
231            visited: self.visited.iter().map(|c| c.to_string()).collect(),
232            queue: self
233                .queue
234                .iter()
235                .map(|(c, d)| (c.to_string(), *d))
236                .collect(),
237            max_depth: self.max_depth,
238            blocks_fetched: self.blocks_fetched,
239            bytes_fetched: self.bytes_fetched,
240        };
241
242        serde_json::to_string(&serializable)
243            .map_err(|e| Error::Internal(format!("Failed to serialize checkpoint: {}", e)))
244    }
245
246    /// Deserialize from JSON
247    pub fn from_json(json: &str) -> Result<Self> {
248        #[derive(Deserialize)]
249        struct SerializableCheckpoint {
250            root: String,
251            visited: Vec<String>,
252            queue: Vec<(String, usize)>,
253            max_depth: Option<usize>,
254            blocks_fetched: usize,
255            bytes_fetched: u64,
256        }
257
258        let serializable: SerializableCheckpoint = serde_json::from_str(json)
259            .map_err(|e| Error::Internal(format!("Failed to deserialize checkpoint: {}", e)))?;
260
261        let root: Cid = serializable
262            .root
263            .parse()
264            .map_err(|e| Error::InvalidInput(format!("Invalid root CID: {}", e)))?;
265
266        let visited: Result<HashSet<Cid>> = serializable
267            .visited
268            .iter()
269            .map(|s| {
270                s.parse()
271                    .map_err(|e| Error::InvalidInput(format!("Invalid CID: {}", e)))
272            })
273            .collect();
274
275        let queue: Result<VecDeque<(Cid, usize)>> = serializable
276            .queue
277            .iter()
278            .map(|(s, d)| {
279                s.parse()
280                    .map(|c| (c, *d))
281                    .map_err(|e| Error::InvalidInput(format!("Invalid CID: {}", e)))
282            })
283            .collect();
284
285        Ok(Self {
286            root,
287            visited: visited?,
288            queue: queue?,
289            max_depth: serializable.max_depth,
290            blocks_fetched: serializable.blocks_fetched,
291            bytes_fetched: serializable.bytes_fetched,
292        })
293    }
294}
295
296/// DAG traversal engine
297pub struct DagTraversal<S: BlockStore> {
298    /// Block store
299    store: Arc<S>,
300    /// Traversal mode
301    mode: TraversalMode,
302    /// Selector
303    #[allow(dead_code)]
304    selector: Selector,
305    /// Traversal state
306    state: Arc<RwLock<TraversalState>>,
307}
308
309impl<S: BlockStore> DagTraversal<S> {
310    /// Create a new DAG traversal
311    pub fn new(store: Arc<S>, root: Cid, selector: Selector, mode: TraversalMode) -> Result<Self> {
312        selector.validate()?;
313
314        let max_depth = match &selector {
315            Selector::RecursiveDepth { max_depth } => Some(*max_depth),
316            _ => None,
317        };
318
319        let state = TraversalState::new(root, max_depth);
320
321        Ok(Self {
322            store,
323            mode,
324            selector,
325            state: Arc::new(RwLock::new(state)),
326        })
327    }
328
329    /// Resume from a checkpoint
330    pub fn from_checkpoint(
331        store: Arc<S>,
332        checkpoint: TraversalCheckpoint,
333        selector: Selector,
334        mode: TraversalMode,
335    ) -> Result<Self> {
336        selector.validate()?;
337        let state = TraversalState::from_checkpoint(checkpoint);
338
339        Ok(Self {
340            store,
341            mode,
342            selector,
343            state: Arc::new(RwLock::new(state)),
344        })
345    }
346
347    /// Get the next block in the traversal
348    pub async fn next_block(&self) -> Result<Option<Block>> {
349        let mut state = self.state.write().await;
350
351        // Get next CID to visit
352        let (cid, depth) = match state.next(self.mode) {
353            Some(item) => item,
354            None => return Ok(None),
355        };
356
357        // Fetch the block
358        let block = match self.store.get(&cid).await? {
359            Some(b) => b,
360            None => return Err(Error::NotFound(format!("Block not found for CID: {}", cid))),
361        };
362
363        // Mark as visited
364        state.mark_visited(cid, block.data().len() as u64);
365        state.current_depth = depth;
366
367        // Extract links from the block and add to queue
368        if let Ok(links) = self.extract_links(&block) {
369            for link_cid in links {
370                state.enqueue(link_cid, depth + 1);
371            }
372        }
373
374        Ok(Some(block))
375    }
376
377    /// Extract CID links from a block
378    fn extract_links(&self, _block: &Block) -> Result<Vec<Cid>> {
379        // Simple link extraction - in a real implementation, this would parse IPLD
380        // and extract CIDs based on the selector
381
382        // For now, we'll just return an empty vector
383        // In a real implementation, you would:
384        // 1. Parse the block data as IPLD
385        // 2. Apply the selector to determine which fields to follow
386        // 3. Extract CID links from those fields
387
388        Ok(Vec::new())
389    }
390
391    /// Check if traversal is complete
392    pub async fn is_complete(&self) -> bool {
393        self.state.read().await.is_complete()
394    }
395
396    /// Get traversal statistics
397    pub async fn stats(&self) -> TraversalStats {
398        let state = self.state.read().await;
399        TraversalStats {
400            blocks_fetched: state.blocks_fetched,
401            bytes_fetched: state.bytes_fetched,
402            blocks_remaining: state.queue.len(),
403            current_depth: state.current_depth,
404        }
405    }
406
407    /// Create a checkpoint for resume
408    pub async fn checkpoint(&self) -> TraversalCheckpoint {
409        self.state.read().await.checkpoint()
410    }
411
412    /// Traverse all and collect blocks
413    pub async fn collect_all(&self) -> Result<Vec<Block>> {
414        let mut blocks = Vec::new();
415
416        while let Some(block) = self.next_block().await? {
417            blocks.push(block);
418        }
419
420        Ok(blocks)
421    }
422}
423
/// Point-in-time statistics for a DAG traversal.
#[derive(Debug, Clone)]
pub struct TraversalStats {
    /// Number of blocks fetched so far.
    pub blocks_fetched: usize,
    /// Number of bytes fetched so far.
    pub bytes_fetched: u64,
    /// Number of blocks still waiting in the queue.
    pub blocks_remaining: usize,
    /// Depth of the most recently visited node.
    pub current_depth: usize,
}
436
437/// GraphSync protocol handler
438pub struct GraphSync<S: BlockStore> {
439    /// Block store
440    store: Arc<S>,
441    /// Active traversals
442    traversals: Arc<RwLock<HashMap<Cid, Arc<DagTraversal<S>>>>>,
443}
444
445impl<S: BlockStore> GraphSync<S> {
446    /// Create a new GraphSync instance
447    pub fn new(store: Arc<S>) -> Result<Self> {
448        Ok(Self {
449            store,
450            traversals: Arc::new(RwLock::new(HashMap::new())),
451        })
452    }
453
454    /// Start a new traversal
455    pub async fn start_traversal(
456        &self,
457        root: Cid,
458        selector: Selector,
459        mode: TraversalMode,
460    ) -> Result<Arc<DagTraversal<S>>> {
461        let traversal = Arc::new(DagTraversal::new(self.store.clone(), root, selector, mode)?);
462
463        let mut traversals = self.traversals.write().await;
464        traversals.insert(root, traversal.clone());
465
466        Ok(traversal)
467    }
468
469    /// Resume a traversal from checkpoint
470    pub async fn resume_traversal(
471        &self,
472        checkpoint: TraversalCheckpoint,
473        selector: Selector,
474        mode: TraversalMode,
475    ) -> Result<Arc<DagTraversal<S>>> {
476        let root = checkpoint.root;
477        let traversal = Arc::new(DagTraversal::from_checkpoint(
478            self.store.clone(),
479            checkpoint,
480            selector,
481            mode,
482        )?);
483
484        let mut traversals = self.traversals.write().await;
485        traversals.insert(root, traversal.clone());
486
487        Ok(traversal)
488    }
489
490    /// Get an active traversal
491    pub async fn get_traversal(&self, root: &Cid) -> Option<Arc<DagTraversal<S>>> {
492        self.traversals.read().await.get(root).cloned()
493    }
494
495    /// Remove a completed traversal
496    pub async fn remove_traversal(&self, root: &Cid) {
497        self.traversals.write().await.remove(root);
498    }
499
500    /// Get number of active traversals
501    pub async fn active_count(&self) -> usize {
502        self.traversals.read().await.len()
503    }
504}
505
506/// Gradient message for federated learning
507#[derive(Debug, Clone, Serialize, Deserialize)]
508pub struct GradientMessage {
509    /// Gradient identifier (e.g., layer name or tensor CID)
510    pub id: String,
511    /// Gradient data (compressed)
512    pub data: Vec<u8>,
513    /// Shape of the gradient tensor
514    pub shape: Vec<usize>,
515    /// Data type (f32, f16, etc.)
516    pub dtype: String,
517    /// Checksum for verification
518    pub checksum: u64,
519    /// Metadata (e.g., learning rate, batch size)
520    pub metadata: HashMap<String, String>,
521}
522
523impl GradientMessage {
524    /// Create a new gradient message
525    pub fn new(
526        id: impl Into<String>,
527        data: Vec<u8>,
528        shape: Vec<usize>,
529        dtype: impl Into<String>,
530    ) -> Self {
531        let checksum = Self::compute_checksum(&data);
532        Self {
533            id: id.into(),
534            data,
535            shape,
536            dtype: dtype.into(),
537            checksum,
538            metadata: HashMap::new(),
539        }
540    }
541
542    /// Compute checksum for data verification
543    fn compute_checksum(data: &[u8]) -> u64 {
544        // Simple checksum using FNV-1a hash
545        let mut hash: u64 = 0xcbf29ce484222325;
546        for &byte in data {
547            hash ^= byte as u64;
548            hash = hash.wrapping_mul(0x100000001b3);
549        }
550        hash
551    }
552
553    /// Verify checksum
554    pub fn verify(&self) -> bool {
555        Self::compute_checksum(&self.data) == self.checksum
556    }
557
558    /// Add metadata
559    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
560        self.metadata.insert(key.into(), value.into());
561        self
562    }
563
564    /// Get total elements in gradient
565    pub fn num_elements(&self) -> usize {
566        self.shape.iter().product()
567    }
568
569    /// Estimate size in bytes
570    pub fn size_bytes(&self) -> usize {
571        self.data.len()
572    }
573}
574
/// Strategy used to combine gradients from multiple contributors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationStrategy {
    /// Unweighted element-wise mean.
    Average,
    /// Mean weighted by per-contributor sample counts.
    WeightedAverage,
    /// Element-wise median (robust to outliers).
    Median,
    /// Federated averaging (FedAvg).
    FederatedAvg,
}
587
588/// Gradient aggregator for federated learning
589pub struct GradientAggregator {
590    /// Aggregation strategy
591    strategy: AggregationStrategy,
592    /// Accumulated gradients per layer
593    gradients: Arc<RwLock<HashMap<String, Vec<GradientMessage>>>>,
594    /// Expected number of contributors
595    expected_contributors: usize,
596    /// Verification enabled
597    verify_checksums: bool,
598}
599
600impl GradientAggregator {
601    /// Create a new gradient aggregator
602    pub fn new(strategy: AggregationStrategy, expected_contributors: usize) -> Self {
603        Self {
604            strategy,
605            gradients: Arc::new(RwLock::new(HashMap::new())),
606            expected_contributors,
607            verify_checksums: true,
608        }
609    }
610
611    /// Add a gradient to the aggregator
612    pub async fn add_gradient(&self, gradient: GradientMessage) -> Result<()> {
613        // Verify checksum if enabled
614        if self.verify_checksums && !gradient.verify() {
615            return Err(Error::InvalidInput(format!(
616                "Gradient checksum verification failed for {}",
617                gradient.id
618            )));
619        }
620
621        // Verify dimensions
622        if gradient.num_elements() == 0 {
623            return Err(Error::InvalidInput(
624                "Gradient has zero elements".to_string(),
625            ));
626        }
627
628        let mut gradients = self.gradients.write().await;
629        gradients
630            .entry(gradient.id.clone())
631            .or_insert_with(Vec::new)
632            .push(gradient);
633
634        Ok(())
635    }
636
637    /// Check if ready to aggregate (all contributors submitted)
638    pub async fn is_ready(&self, layer_id: &str) -> bool {
639        let gradients = self.gradients.read().await;
640        gradients
641            .get(layer_id)
642            .map(|g| g.len() >= self.expected_contributors)
643            .unwrap_or(false)
644    }
645
646    /// Aggregate gradients for a specific layer
647    pub async fn aggregate(&self, layer_id: &str) -> Result<GradientMessage> {
648        let gradients = self.gradients.read().await;
649        let layer_gradients = gradients
650            .get(layer_id)
651            .ok_or_else(|| Error::NotFound(format!("No gradients for layer: {}", layer_id)))?;
652
653        if layer_gradients.is_empty() {
654            return Err(Error::InvalidInput("No gradients to aggregate".to_string()));
655        }
656
657        // Verify all gradients have same shape
658        let first_shape = &layer_gradients[0].shape;
659        for grad in layer_gradients.iter().skip(1) {
660            if &grad.shape != first_shape {
661                return Err(Error::InvalidInput("Gradient shape mismatch".to_string()));
662            }
663        }
664
665        match self.strategy {
666            AggregationStrategy::Average | AggregationStrategy::FederatedAvg => {
667                self.aggregate_average(layer_id, layer_gradients)
668            }
669            AggregationStrategy::WeightedAverage => {
670                self.aggregate_weighted(layer_id, layer_gradients)
671            }
672            AggregationStrategy::Median => self.aggregate_median(layer_id, layer_gradients),
673        }
674    }
675
676    /// Simple averaging aggregation
677    fn aggregate_average(
678        &self,
679        layer_id: &str,
680        gradients: &[GradientMessage],
681    ) -> Result<GradientMessage> {
682        let n = gradients.len();
683        let size = gradients[0].data.len();
684
685        // Sum all gradients (treating as bytes for now)
686        let mut sum = vec![0u8; size];
687        for grad in gradients {
688            for (i, &byte) in grad.data.iter().enumerate() {
689                sum[i] = sum[i].saturating_add(byte / n as u8);
690            }
691        }
692
693        Ok(GradientMessage::new(
694            layer_id,
695            sum,
696            gradients[0].shape.clone(),
697            gradients[0].dtype.clone(),
698        ))
699    }
700
701    /// Weighted average aggregation
702    fn aggregate_weighted(
703        &self,
704        layer_id: &str,
705        gradients: &[GradientMessage],
706    ) -> Result<GradientMessage> {
707        // Extract weights from metadata (sample counts)
708        let weights: Vec<f32> = gradients
709            .iter()
710            .map(|g| {
711                g.metadata
712                    .get("samples")
713                    .and_then(|s| s.parse::<f32>().ok())
714                    .unwrap_or(1.0)
715            })
716            .collect();
717
718        let total_weight: f32 = weights.iter().sum();
719        let size = gradients[0].data.len();
720
721        // Weighted sum
722        let mut weighted_sum = vec![0u8; size];
723        for (grad, &weight) in gradients.iter().zip(weights.iter()) {
724            let normalized_weight = weight / total_weight;
725            for (i, &byte) in grad.data.iter().enumerate() {
726                weighted_sum[i] =
727                    weighted_sum[i].saturating_add((byte as f32 * normalized_weight) as u8);
728            }
729        }
730
731        Ok(GradientMessage::new(
732            layer_id,
733            weighted_sum,
734            gradients[0].shape.clone(),
735            gradients[0].dtype.clone(),
736        ))
737    }
738
739    /// Median aggregation (robust to outliers)
740    fn aggregate_median(
741        &self,
742        _layer_id: &str,
743        _gradients: &[GradientMessage],
744    ) -> Result<GradientMessage> {
745        // Median aggregation would require parsing the actual float values
746        // This is a placeholder implementation
747        Err(Error::NotImplemented(
748            "Median aggregation not yet implemented".to_string(),
749        ))
750    }
751
752    /// Clear gradients for a layer after aggregation
753    pub async fn clear(&self, layer_id: &str) {
754        let mut gradients = self.gradients.write().await;
755        gradients.remove(layer_id);
756    }
757
758    /// Get statistics
759    pub async fn stats(&self) -> GradientAggregatorStats {
760        let gradients = self.gradients.read().await;
761        let total_gradients: usize = gradients.values().map(|v| v.len()).sum();
762        let layers_count = gradients.len();
763
764        GradientAggregatorStats {
765            total_gradients,
766            layers_count,
767            expected_contributors: self.expected_contributors,
768        }
769    }
770}
771
/// Snapshot of the aggregator's current contents.
#[derive(Debug, Clone)]
pub struct GradientAggregatorStats {
    /// Total gradients received across all layers.
    pub total_gradients: usize,
    /// Number of distinct layers with pending gradients.
    pub layers_count: usize,
    /// Configured number of expected contributors per layer.
    pub expected_contributors: usize,
}
782
783/// Bidirectional gradient stream
784pub struct GradientStream {
785    /// Gradient aggregator
786    aggregator: Arc<GradientAggregator>,
787    /// Outgoing gradient queue
788    outgoing: Arc<RwLock<VecDeque<GradientMessage>>>,
789    /// Maximum queue size
790    max_queue_size: usize,
791}
792
793impl GradientStream {
794    /// Create a new gradient stream
795    pub fn new(aggregator: Arc<GradientAggregator>, max_queue_size: usize) -> Self {
796        Self {
797            aggregator,
798            outgoing: Arc::new(RwLock::new(VecDeque::new())),
799            max_queue_size,
800        }
801    }
802
803    /// Push a gradient to send
804    pub async fn push_gradient(&self, gradient: GradientMessage) -> Result<()> {
805        let mut outgoing = self.outgoing.write().await;
806        if outgoing.len() >= self.max_queue_size {
807            return Err(Error::Internal("Gradient queue is full".to_string()));
808        }
809        outgoing.push_back(gradient);
810        Ok(())
811    }
812
813    /// Pop a gradient to send
814    pub async fn pop_gradient(&self) -> Option<GradientMessage> {
815        self.outgoing.write().await.pop_front()
816    }
817
818    /// Receive a gradient
819    pub async fn receive_gradient(&self, gradient: GradientMessage) -> Result<()> {
820        self.aggregator.add_gradient(gradient).await
821    }
822
823    /// Get queue size
824    pub async fn queue_size(&self) -> usize {
825        self.outgoing.read().await.len()
826    }
827}
828
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_selector_parse() {
        let all = Selector::from_json(r#"{"type":"all"}"#).unwrap();
        assert!(all.matches_all());

        let recursive = Selector::from_json(r#"{"type":"recursivedepth","max_depth":5}"#).unwrap();
        match recursive {
            Selector::RecursiveDepth { max_depth } => assert_eq!(max_depth, 5),
            _ => panic!("Wrong selector type"),
        }
    }

    #[test]
    fn test_selector_validate() {
        // A zero depth limit is rejected; any positive limit is accepted.
        assert!(Selector::RecursiveDepth { max_depth: 0 }.validate().is_err());
        assert!(Selector::RecursiveDepth { max_depth: 5 }.validate().is_ok());
    }

    #[test]
    fn test_traversal_state() {
        let cid: Cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse()
            .unwrap();

        let mut state = TraversalState::new(cid, Some(3));
        assert!(!state.is_complete());

        // The root is pre-queued at depth 0.
        let (popped, depth) = state.next(TraversalMode::BreadthFirst).unwrap();
        assert_eq!(popped, cid);
        assert_eq!(depth, 0);

        state.mark_visited(cid, 1024);
        assert_eq!(state.blocks_fetched, 1);
        assert_eq!(state.bytes_fetched, 1024);
        assert!(state.is_complete());
    }

    #[test]
    fn test_checkpoint() {
        let cid: Cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse()
            .unwrap();

        let mut state = TraversalState::new(cid, Some(3));
        state.mark_visited(cid, 1024);

        let snapshot = state.checkpoint();
        assert_eq!(snapshot.root, cid);
        assert_eq!(snapshot.blocks_fetched, 1);
        assert_eq!(snapshot.bytes_fetched, 1024);

        // Restoring preserves the fetch counters.
        let restored = TraversalState::from_checkpoint(snapshot);
        assert_eq!(restored.blocks_fetched, 1);
        assert_eq!(restored.bytes_fetched, 1024);
    }

    #[test]
    fn test_gradient_message() {
        let payload = vec![1, 2, 3, 4, 5];
        let dims = vec![5];
        let gradient = GradientMessage::new("layer1", payload.clone(), dims.clone(), "f32");

        assert_eq!(gradient.id, "layer1");
        assert_eq!(gradient.data, payload);
        assert_eq!(gradient.shape, dims);
        assert_eq!(gradient.dtype, "f32");
        assert!(gradient.verify());
        assert_eq!(gradient.num_elements(), 5);
        assert_eq!(gradient.size_bytes(), 5);
    }

    #[test]
    fn test_gradient_checksum() {
        let mut gradient = GradientMessage::new("layer1", vec![1, 2, 3, 4, 5], vec![5], "f32");
        assert!(gradient.verify());

        // Corrupting the payload must break verification.
        gradient.data[0] = 99;
        assert!(!gradient.verify());
    }

    #[tokio::test]
    async fn test_gradient_aggregator() {
        let aggregator = GradientAggregator::new(AggregationStrategy::Average, 2);

        let first = GradientMessage::new("layer1", vec![10, 20, 30], vec![3], "f32");
        let second = GradientMessage::new("layer1", vec![20, 30, 40], vec![3], "f32");

        aggregator.add_gradient(first).await.unwrap();
        aggregator.add_gradient(second).await.unwrap();
        assert!(aggregator.is_ready("layer1").await);

        let combined = aggregator.aggregate("layer1").await.unwrap();
        assert_eq!(combined.shape, vec![3]);
        assert_eq!(combined.id, "layer1");
    }

    #[tokio::test]
    async fn test_gradient_stream() {
        let aggregator = Arc::new(GradientAggregator::new(AggregationStrategy::Average, 1));
        let stream = GradientStream::new(aggregator, 10);

        let grad = GradientMessage::new("layer1", vec![1, 2, 3], vec![3], "f32");

        stream.push_gradient(grad.clone()).await.unwrap();
        assert_eq!(stream.queue_size().await, 1);

        let popped = stream.pop_gradient().await.unwrap();
        assert_eq!(popped.id, "layer1");
        assert_eq!(stream.queue_size().await, 0);

        stream.receive_gradient(grad).await.unwrap();
    }
}