//! DAG tip selection implementation (`qudag_dag/tip_selection.rs`).
//!
//! Supports uniform random, weighted random, and MCMC-walk parent selection.
2
3use crate::vertex::{Vertex, VertexId};
4use rand::{thread_rng, Rng};
5use std::collections::{HashMap, HashSet};
6use thiserror::Error;
7
8/// Errors that can occur during tip selection.
9#[derive(Debug, Error)]
10pub enum TipSelectionError {
11    /// No valid tips available
12    #[error("No valid tips available")]
13    NoValidTips,
14
15    /// Invalid tip reference
16    #[error("Invalid tip reference")]
17    InvalidTip,
18
19    /// Selection failure
20    #[error("Selection failure")]
21    SelectionFailed,
22
23    /// MCMC walk failed
24    #[error("MCMC walk failed: {0}")]
25    McmcWalkFailed(String),
26
27    /// Weight calculation failed
28    #[error("Weight calculation failed")]
29    WeightCalculationFailed,
30}
31
/// Tunable parameters that drive tip selection.
#[derive(Clone, Debug)]
pub struct TipSelectionConfig {
    /// How many tips (parents) to pick for a new vertex.
    pub tip_count: usize,

    /// Oldest acceptable tip age, expressed in seconds.
    pub max_age: u64,

    /// Lowest acceptable confidence score for a tip.
    pub min_confidence: f64,

    /// Upper bound on the number of steps taken by a single MCMC walk.
    pub mcmc_walk_length: usize,

    /// Alpha coefficient that scales vertex weights during MCMC walks.
    pub alpha: f64,

    /// Upper bound on selection attempts.
    pub max_attempts: usize,
}

impl Default for TipSelectionConfig {
    /// Two tips, a one-hour age limit, a 0.5 confidence floor, 1000-step
    /// walks, alpha of 0.001 and at most 50 attempts.
    fn default() -> Self {
        let one_hour_secs = 60 * 60;
        TipSelectionConfig {
            tip_count: 2,
            max_age: one_hour_secs,
            min_confidence: 0.5,
            mcmc_walk_length: 1_000,
            alpha: 1e-3,
            max_attempts: 50,
        }
    }
}
66
/// Strategy used to choose the parents (tips) of a new vertex.
#[derive(Clone, Debug, PartialEq)]
pub enum ParentSelectionAlgorithm {
    /// Pick tips uniformly at random.
    Random,
    /// Pick tips at random, biased by each vertex's weight.
    WeightedRandom,
    /// Reach tips through a Monte Carlo Markov Chain (MCMC) walk.
    McmcWalk,
}
77
/// Weight bookkeeping kept per vertex for parent selection.
#[derive(Clone, Debug)]
pub struct VertexWeight {
    /// Weight of the vertex plus the weights of everything that approves it.
    pub cumulative_weight: f64,
    /// Weight contributed by this vertex alone.
    pub direct_weight: f64,
    /// Count of vertices directly approving this one.
    pub approvers: usize,
    /// Unix timestamp (seconds) of the most recent update.
    pub last_updated: u64,
}
90
91/// DAG tip selection trait defining the interface for tip selection algorithms.
92pub trait TipSelection {
93    /// Initialize tip selection with configuration.
94    fn init(config: TipSelectionConfig) -> Result<(), TipSelectionError>;
95
96    /// Select tips for a new vertex.
97    fn select_tips(&self) -> Result<Vec<VertexId>, TipSelectionError>;
98
99    /// Check if a vertex is eligible as a tip.
100    fn is_valid_tip(&self, vertex: &Vertex) -> bool;
101
102    /// Calculate confidence score for a tip.
103    fn calculate_confidence(&self, tip: &VertexId) -> f64;
104
105    /// Update tip pool with new vertex.
106    fn update_tips(&mut self, vertex: &Vertex) -> Result<(), TipSelectionError>;
107}
108
109/// Advanced tip selection implementation with MCMC and weighted selection
110pub struct AdvancedTipSelection {
111    /// Configuration
112    config: TipSelectionConfig,
113
114    /// Current tips
115    tips: HashSet<VertexId>,
116
117    /// Vertex weights for weighted selection
118    weights: HashMap<VertexId, VertexWeight>,
119
120    /// Vertex adjacency information
121    adjacency: HashMap<VertexId, HashSet<VertexId>>,
122
123    /// Reverse adjacency (children)
124    reverse_adjacency: HashMap<VertexId, HashSet<VertexId>>,
125
126    /// Algorithm to use
127    algorithm: ParentSelectionAlgorithm,
128}
129
130impl AdvancedTipSelection {
131    /// Create a new advanced tip selection instance
132    pub fn new(config: TipSelectionConfig, algorithm: ParentSelectionAlgorithm) -> Self {
133        Self {
134            config,
135            tips: HashSet::new(),
136            weights: HashMap::new(),
137            adjacency: HashMap::new(),
138            reverse_adjacency: HashMap::new(),
139            algorithm,
140        }
141    }
142
143    /// Add a vertex to the DAG structure
144    pub fn add_vertex(&mut self, vertex: &Vertex) -> Result<(), TipSelectionError> {
145        let vertex_id = vertex.id.clone();
146        let parents = vertex.parents();
147
148        // Add to adjacency lists
149        self.adjacency.insert(vertex_id.clone(), parents.clone());
150
151        // Update reverse adjacency
152        for parent in &parents {
153            self.reverse_adjacency
154                .entry(parent.clone())
155                .or_default()
156                .insert(vertex_id.clone());
157        }
158
159        // Remove parents from tips (they now have children)
160        for parent in &parents {
161            self.tips.remove(parent);
162        }
163
164        // Add this vertex as a new tip
165        self.tips.insert(vertex_id.clone());
166
167        // Update weights
168        self.update_vertex_weight(&vertex_id)?;
169
170        Ok(())
171    }
172
173    /// Update weight for a vertex
174    fn update_vertex_weight(&mut self, vertex_id: &VertexId) -> Result<(), TipSelectionError> {
175        let approvers = self
176            .reverse_adjacency
177            .get(vertex_id)
178            .map(|children| children.len())
179            .unwrap_or(0);
180
181        let direct_weight = 1.0;
182        let cumulative_weight = self.calculate_cumulative_weight(vertex_id)?;
183
184        let weight = VertexWeight {
185            cumulative_weight,
186            direct_weight,
187            approvers,
188            last_updated: std::time::SystemTime::now()
189                .duration_since(std::time::UNIX_EPOCH)
190                .unwrap()
191                .as_secs(),
192        };
193
194        self.weights.insert(vertex_id.clone(), weight);
195        Ok(())
196    }
197
198    /// Calculate cumulative weight using DFS
199    fn calculate_cumulative_weight(&self, vertex_id: &VertexId) -> Result<f64, TipSelectionError> {
200        let mut visited = HashSet::new();
201        self.calculate_cumulative_weight_recursive(vertex_id, &mut visited)
202    }
203
204    fn calculate_cumulative_weight_recursive(
205        &self,
206        vertex_id: &VertexId,
207        visited: &mut HashSet<VertexId>,
208    ) -> Result<f64, TipSelectionError> {
209        if visited.contains(vertex_id) {
210            return Ok(0.0); // Avoid cycles
211        }
212
213        visited.insert(vertex_id.clone());
214
215        let direct_weight = self
216            .weights
217            .get(vertex_id)
218            .map(|w| w.direct_weight)
219            .unwrap_or(1.0);
220
221        let mut cumulative = direct_weight;
222
223        if let Some(children) = self.reverse_adjacency.get(vertex_id) {
224            for child in children {
225                cumulative += self.calculate_cumulative_weight_recursive(child, visited)?;
226            }
227        }
228
229        Ok(cumulative)
230    }
231
232    /// Perform MCMC walk for tip selection
233    fn mcmc_walk(&self, start: &VertexId) -> Result<VertexId, TipSelectionError> {
234        let mut current = start.clone();
235        let mut rng = thread_rng();
236
237        for _ in 0..self.config.mcmc_walk_length {
238            // Get children of current vertex
239            let children = self
240                .reverse_adjacency
241                .get(&current)
242                .cloned()
243                .unwrap_or_default();
244
245            if children.is_empty() {
246                // Reached a tip
247                return Ok(current);
248            }
249
250            // Calculate transition probabilities based on weights
251            let mut transition_weights = Vec::new();
252            let mut candidates = Vec::new();
253
254            for child in &children {
255                let weight = self
256                    .weights
257                    .get(child)
258                    .map(|w| w.cumulative_weight)
259                    .unwrap_or(1.0);
260
261                // Apply exponential transformation for better selection
262                let transition_weight = (-self.config.alpha * weight).exp();
263                transition_weights.push(transition_weight);
264                candidates.push(child.clone());
265            }
266
267            // Select next vertex based on weights
268            let total_weight: f64 = transition_weights.iter().sum();
269            if total_weight == 0.0 {
270                // Uniform selection if all weights are zero
271                let idx = rng.gen_range(0..candidates.len());
272                current = candidates[idx].clone();
273            } else {
274                let mut cumulative = 0.0;
275                let target = rng.gen::<f64>() * total_weight;
276
277                for (i, &weight) in transition_weights.iter().enumerate() {
278                    cumulative += weight;
279                    if cumulative >= target {
280                        current = candidates[i].clone();
281                        break;
282                    }
283                }
284            }
285        }
286
287        Ok(current)
288    }
289
290    /// Weighted random selection from tips
291    fn weighted_random_selection(&self) -> Result<Vec<VertexId>, TipSelectionError> {
292        if self.tips.is_empty() {
293            return Err(TipSelectionError::NoValidTips);
294        }
295
296        let mut rng = thread_rng();
297        let mut selected = Vec::new();
298        let mut available_tips: Vec<_> = self.tips.iter().cloned().collect();
299
300        for _ in 0..self.config.tip_count.min(available_tips.len()) {
301            if available_tips.is_empty() {
302                break;
303            }
304
305            // Calculate weights for remaining tips
306            let mut weights = Vec::new();
307            for tip in &available_tips {
308                let weight = self
309                    .weights
310                    .get(tip)
311                    .map(|w| w.cumulative_weight)
312                    .unwrap_or(1.0);
313                weights.push(weight);
314            }
315
316            // Select based on weights
317            let total_weight: f64 = weights.iter().sum();
318            if total_weight == 0.0 {
319                // Uniform selection
320                let idx = rng.gen_range(0..available_tips.len());
321                selected.push(available_tips.remove(idx));
322            } else {
323                let mut cumulative = 0.0;
324                let target = rng.gen::<f64>() * total_weight;
325
326                for (i, &weight) in weights.iter().enumerate() {
327                    cumulative += weight;
328                    if cumulative >= target {
329                        selected.push(available_tips.remove(i));
330                        break;
331                    }
332                }
333            }
334        }
335
336        Ok(selected)
337    }
338
339    /// Random selection from tips
340    fn random_selection(&self) -> Result<Vec<VertexId>, TipSelectionError> {
341        if self.tips.is_empty() {
342            return Err(TipSelectionError::NoValidTips);
343        }
344
345        let mut rng = thread_rng();
346        let mut tips: Vec<_> = self.tips.iter().cloned().collect();
347
348        // Shuffle and take the required number
349        for i in 0..tips.len() {
350            let j = rng.gen_range(i..tips.len());
351            tips.swap(i, j);
352        }
353
354        Ok(tips.into_iter().take(self.config.tip_count).collect())
355    }
356}
357
358impl TipSelection for AdvancedTipSelection {
359    fn init(config: TipSelectionConfig) -> Result<(), TipSelectionError> {
360        // Validation
361        if config.tip_count == 0 {
362            return Err(TipSelectionError::SelectionFailed);
363        }
364
365        if config.mcmc_walk_length == 0 {
366            return Err(TipSelectionError::McmcWalkFailed(
367                "Walk length must be positive".to_string(),
368            ));
369        }
370
371        Ok(())
372    }
373
374    fn select_tips(&self) -> Result<Vec<VertexId>, TipSelectionError> {
375        match self.algorithm {
376            ParentSelectionAlgorithm::Random => self.random_selection(),
377            ParentSelectionAlgorithm::WeightedRandom => self.weighted_random_selection(),
378            ParentSelectionAlgorithm::McmcWalk => {
379                // For MCMC, start from genesis and walk to tips
380                if self.tips.is_empty() {
381                    return Err(TipSelectionError::NoValidTips);
382                }
383
384                let mut selected = Vec::new();
385                let mut rng = thread_rng();
386
387                for _ in 0..self.config.tip_count {
388                    // Find a genesis or low-weight vertex to start from
389                    let start_candidates: Vec<_> = self
390                        .weights
391                        .iter()
392                        .filter(|(_, w)| w.approvers == 0) // Genesis vertices
393                        .map(|(id, _)| id.clone())
394                        .collect();
395
396                    let start = if start_candidates.is_empty() {
397                        // Use random tip if no genesis found
398                        let tips: Vec<_> = self.tips.iter().collect();
399                        tips[rng.gen_range(0..tips.len())].clone()
400                    } else {
401                        start_candidates[rng.gen_range(0..start_candidates.len())].clone()
402                    };
403
404                    match self.mcmc_walk(&start) {
405                        Ok(tip) => {
406                            if !selected.contains(&tip) {
407                                selected.push(tip);
408                            }
409                        }
410                        Err(_) => {
411                            // Fallback to random selection
412                            let tips: Vec<_> = self.tips.iter().collect();
413                            let random_tip = tips[rng.gen_range(0..tips.len())].clone();
414                            if !selected.contains(&random_tip) {
415                                selected.push(random_tip);
416                            }
417                        }
418                    }
419                }
420
421                Ok(selected)
422            }
423        }
424    }
425
426    fn is_valid_tip(&self, vertex: &Vertex) -> bool {
427        let vertex_id = &vertex.id;
428
429        // Check if vertex has no children (is a tip)
430        if let Some(children) = self.reverse_adjacency.get(vertex_id) {
431            if !children.is_empty() {
432                return false;
433            }
434        }
435
436        // Check age constraint
437        let current_time = std::time::SystemTime::now()
438            .duration_since(std::time::UNIX_EPOCH)
439            .unwrap()
440            .as_secs();
441
442        if current_time - vertex.timestamp > self.config.max_age {
443            return false;
444        }
445
446        // Check confidence constraint
447        if let Some(weight) = self.weights.get(vertex_id) {
448            if weight.cumulative_weight < self.config.min_confidence {
449                return false;
450            }
451        }
452
453        true
454    }
455
456    fn calculate_confidence(&self, tip: &VertexId) -> f64 {
457        self.weights
458            .get(tip)
459            .map(|w| w.cumulative_weight)
460            .unwrap_or(0.0)
461    }
462
463    fn update_tips(&mut self, vertex: &Vertex) -> Result<(), TipSelectionError> {
464        self.add_vertex(vertex)
465    }
466}