1use crate::vertex::{Vertex, VertexId};
4use rand::{thread_rng, Rng};
5use std::collections::{HashMap, HashSet};
6use thiserror::Error;
7
8#[derive(Debug, Error)]
10pub enum TipSelectionError {
11 #[error("No valid tips available")]
13 NoValidTips,
14
15 #[error("Invalid tip reference")]
17 InvalidTip,
18
19 #[error("Selection failure")]
21 SelectionFailed,
22
23 #[error("MCMC walk failed: {0}")]
25 McmcWalkFailed(String),
26
27 #[error("Weight calculation failed")]
29 WeightCalculationFailed,
30}
31
/// Tuning knobs for [`AdvancedTipSelection`].
#[derive(Clone, Debug)]
pub struct TipSelectionConfig {
    /// How many tips a single selection aims to return; must be non-zero.
    pub tip_count: usize,

    /// Oldest acceptable tip, in seconds, measured against Unix-epoch time.
    pub max_age: u64,

    /// Minimum cumulative weight a tip needs to be considered valid.
    pub min_confidence: f64,

    /// Upper bound on the number of steps taken by one MCMC walk; must be non-zero.
    pub mcmc_walk_length: usize,

    /// Bias parameter in the exponential MCMC transition weights.
    pub alpha: f64,

    /// Intended cap on selection attempts.
    /// NOTE(review): confirm every selection strategy honours it.
    pub max_attempts: usize,
}
53
54impl Default for TipSelectionConfig {
55 fn default() -> Self {
56 Self {
57 tip_count: 2,
58 max_age: 3600, min_confidence: 0.5,
60 mcmc_walk_length: 1000,
61 alpha: 0.001,
62 max_attempts: 50,
63 }
64 }
65}
66
/// Strategy used by [`AdvancedTipSelection`] to pick parent tips.
#[derive(Clone, Debug, PartialEq)]
pub enum ParentSelectionAlgorithm {
    /// Uniformly shuffle the tip set and take the first `tip_count`.
    Random,
    /// Draw tips without replacement, proportionally to cumulative weight.
    WeightedRandom,
    /// Walk the DAG from a start vertex toward the tips, one approver at a time.
    McmcWalk,
}
77
/// Weight bookkeeping stored per vertex.
#[derive(Clone, Debug)]
pub struct VertexWeight {
    /// Own direct weight plus the direct weight of every vertex that
    /// approves this one, directly or transitively (each counted once).
    pub cumulative_weight: f64,
    /// The vertex's own weight; the selector assigns 1.0.
    pub direct_weight: f64,
    /// Number of direct approvers (children) when this record was written.
    pub approvers: usize,
    /// Unix time, in seconds, when this record was last written.
    pub last_updated: u64,
}
90
91pub trait TipSelection {
93 fn init(config: TipSelectionConfig) -> Result<(), TipSelectionError>;
95
96 fn select_tips(&self) -> Result<Vec<VertexId>, TipSelectionError>;
98
99 fn is_valid_tip(&self, vertex: &Vertex) -> bool;
101
102 fn calculate_confidence(&self, tip: &VertexId) -> f64;
104
105 fn update_tips(&mut self, vertex: &Vertex) -> Result<(), TipSelectionError>;
107}
108
109pub struct AdvancedTipSelection {
111 config: TipSelectionConfig,
113
114 tips: HashSet<VertexId>,
116
117 weights: HashMap<VertexId, VertexWeight>,
119
120 adjacency: HashMap<VertexId, HashSet<VertexId>>,
122
123 reverse_adjacency: HashMap<VertexId, HashSet<VertexId>>,
125
126 algorithm: ParentSelectionAlgorithm,
128}
129
130impl AdvancedTipSelection {
131 pub fn new(config: TipSelectionConfig, algorithm: ParentSelectionAlgorithm) -> Self {
133 Self {
134 config,
135 tips: HashSet::new(),
136 weights: HashMap::new(),
137 adjacency: HashMap::new(),
138 reverse_adjacency: HashMap::new(),
139 algorithm,
140 }
141 }
142
143 pub fn add_vertex(&mut self, vertex: &Vertex) -> Result<(), TipSelectionError> {
145 let vertex_id = vertex.id.clone();
146 let parents = vertex.parents();
147
148 self.adjacency.insert(vertex_id.clone(), parents.clone());
150
151 for parent in &parents {
153 self.reverse_adjacency
154 .entry(parent.clone())
155 .or_default()
156 .insert(vertex_id.clone());
157 }
158
159 for parent in &parents {
161 self.tips.remove(parent);
162 }
163
164 self.tips.insert(vertex_id.clone());
166
167 self.update_vertex_weight(&vertex_id)?;
169
170 Ok(())
171 }
172
173 fn update_vertex_weight(&mut self, vertex_id: &VertexId) -> Result<(), TipSelectionError> {
175 let approvers = self
176 .reverse_adjacency
177 .get(vertex_id)
178 .map(|children| children.len())
179 .unwrap_or(0);
180
181 let direct_weight = 1.0;
182 let cumulative_weight = self.calculate_cumulative_weight(vertex_id)?;
183
184 let weight = VertexWeight {
185 cumulative_weight,
186 direct_weight,
187 approvers,
188 last_updated: std::time::SystemTime::now()
189 .duration_since(std::time::UNIX_EPOCH)
190 .unwrap()
191 .as_secs(),
192 };
193
194 self.weights.insert(vertex_id.clone(), weight);
195 Ok(())
196 }
197
198 fn calculate_cumulative_weight(&self, vertex_id: &VertexId) -> Result<f64, TipSelectionError> {
200 let mut visited = HashSet::new();
201 self.calculate_cumulative_weight_recursive(vertex_id, &mut visited)
202 }
203
204 fn calculate_cumulative_weight_recursive(
205 &self,
206 vertex_id: &VertexId,
207 visited: &mut HashSet<VertexId>,
208 ) -> Result<f64, TipSelectionError> {
209 if visited.contains(vertex_id) {
210 return Ok(0.0); }
212
213 visited.insert(vertex_id.clone());
214
215 let direct_weight = self
216 .weights
217 .get(vertex_id)
218 .map(|w| w.direct_weight)
219 .unwrap_or(1.0);
220
221 let mut cumulative = direct_weight;
222
223 if let Some(children) = self.reverse_adjacency.get(vertex_id) {
224 for child in children {
225 cumulative += self.calculate_cumulative_weight_recursive(child, visited)?;
226 }
227 }
228
229 Ok(cumulative)
230 }
231
232 fn mcmc_walk(&self, start: &VertexId) -> Result<VertexId, TipSelectionError> {
234 let mut current = start.clone();
235 let mut rng = thread_rng();
236
237 for _ in 0..self.config.mcmc_walk_length {
238 let children = self
240 .reverse_adjacency
241 .get(¤t)
242 .cloned()
243 .unwrap_or_default();
244
245 if children.is_empty() {
246 return Ok(current);
248 }
249
250 let mut transition_weights = Vec::new();
252 let mut candidates = Vec::new();
253
254 for child in &children {
255 let weight = self
256 .weights
257 .get(child)
258 .map(|w| w.cumulative_weight)
259 .unwrap_or(1.0);
260
261 let transition_weight = (-self.config.alpha * weight).exp();
263 transition_weights.push(transition_weight);
264 candidates.push(child.clone());
265 }
266
267 let total_weight: f64 = transition_weights.iter().sum();
269 if total_weight == 0.0 {
270 let idx = rng.gen_range(0..candidates.len());
272 current = candidates[idx].clone();
273 } else {
274 let mut cumulative = 0.0;
275 let target = rng.gen::<f64>() * total_weight;
276
277 for (i, &weight) in transition_weights.iter().enumerate() {
278 cumulative += weight;
279 if cumulative >= target {
280 current = candidates[i].clone();
281 break;
282 }
283 }
284 }
285 }
286
287 Ok(current)
288 }
289
290 fn weighted_random_selection(&self) -> Result<Vec<VertexId>, TipSelectionError> {
292 if self.tips.is_empty() {
293 return Err(TipSelectionError::NoValidTips);
294 }
295
296 let mut rng = thread_rng();
297 let mut selected = Vec::new();
298 let mut available_tips: Vec<_> = self.tips.iter().cloned().collect();
299
300 for _ in 0..self.config.tip_count.min(available_tips.len()) {
301 if available_tips.is_empty() {
302 break;
303 }
304
305 let mut weights = Vec::new();
307 for tip in &available_tips {
308 let weight = self
309 .weights
310 .get(tip)
311 .map(|w| w.cumulative_weight)
312 .unwrap_or(1.0);
313 weights.push(weight);
314 }
315
316 let total_weight: f64 = weights.iter().sum();
318 if total_weight == 0.0 {
319 let idx = rng.gen_range(0..available_tips.len());
321 selected.push(available_tips.remove(idx));
322 } else {
323 let mut cumulative = 0.0;
324 let target = rng.gen::<f64>() * total_weight;
325
326 for (i, &weight) in weights.iter().enumerate() {
327 cumulative += weight;
328 if cumulative >= target {
329 selected.push(available_tips.remove(i));
330 break;
331 }
332 }
333 }
334 }
335
336 Ok(selected)
337 }
338
339 fn random_selection(&self) -> Result<Vec<VertexId>, TipSelectionError> {
341 if self.tips.is_empty() {
342 return Err(TipSelectionError::NoValidTips);
343 }
344
345 let mut rng = thread_rng();
346 let mut tips: Vec<_> = self.tips.iter().cloned().collect();
347
348 for i in 0..tips.len() {
350 let j = rng.gen_range(i..tips.len());
351 tips.swap(i, j);
352 }
353
354 Ok(tips.into_iter().take(self.config.tip_count).collect())
355 }
356}
357
358impl TipSelection for AdvancedTipSelection {
359 fn init(config: TipSelectionConfig) -> Result<(), TipSelectionError> {
360 if config.tip_count == 0 {
362 return Err(TipSelectionError::SelectionFailed);
363 }
364
365 if config.mcmc_walk_length == 0 {
366 return Err(TipSelectionError::McmcWalkFailed(
367 "Walk length must be positive".to_string(),
368 ));
369 }
370
371 Ok(())
372 }
373
374 fn select_tips(&self) -> Result<Vec<VertexId>, TipSelectionError> {
375 match self.algorithm {
376 ParentSelectionAlgorithm::Random => self.random_selection(),
377 ParentSelectionAlgorithm::WeightedRandom => self.weighted_random_selection(),
378 ParentSelectionAlgorithm::McmcWalk => {
379 if self.tips.is_empty() {
381 return Err(TipSelectionError::NoValidTips);
382 }
383
384 let mut selected = Vec::new();
385 let mut rng = thread_rng();
386
387 for _ in 0..self.config.tip_count {
388 let start_candidates: Vec<_> = self
390 .weights
391 .iter()
392 .filter(|(_, w)| w.approvers == 0) .map(|(id, _)| id.clone())
394 .collect();
395
396 let start = if start_candidates.is_empty() {
397 let tips: Vec<_> = self.tips.iter().collect();
399 tips[rng.gen_range(0..tips.len())].clone()
400 } else {
401 start_candidates[rng.gen_range(0..start_candidates.len())].clone()
402 };
403
404 match self.mcmc_walk(&start) {
405 Ok(tip) => {
406 if !selected.contains(&tip) {
407 selected.push(tip);
408 }
409 }
410 Err(_) => {
411 let tips: Vec<_> = self.tips.iter().collect();
413 let random_tip = tips[rng.gen_range(0..tips.len())].clone();
414 if !selected.contains(&random_tip) {
415 selected.push(random_tip);
416 }
417 }
418 }
419 }
420
421 Ok(selected)
422 }
423 }
424 }
425
426 fn is_valid_tip(&self, vertex: &Vertex) -> bool {
427 let vertex_id = &vertex.id;
428
429 if let Some(children) = self.reverse_adjacency.get(vertex_id) {
431 if !children.is_empty() {
432 return false;
433 }
434 }
435
436 let current_time = std::time::SystemTime::now()
438 .duration_since(std::time::UNIX_EPOCH)
439 .unwrap()
440 .as_secs();
441
442 if current_time - vertex.timestamp > self.config.max_age {
443 return false;
444 }
445
446 if let Some(weight) = self.weights.get(vertex_id) {
448 if weight.cumulative_weight < self.config.min_confidence {
449 return false;
450 }
451 }
452
453 true
454 }
455
456 fn calculate_confidence(&self, tip: &VertexId) -> f64 {
457 self.weights
458 .get(tip)
459 .map(|w| w.cumulative_weight)
460 .unwrap_or(0.0)
461 }
462
463 fn update_tips(&mut self, vertex: &Vertex) -> Result<(), TipSelectionError> {
464 self.add_vertex(vertex)
465 }
466}