Skip to main content

de_mls/core/peer_scoring/
service.rs

//! Reference [`PeerScoringService`] — a [`PeerScoringPlugin`] implementation
//! over [`PeerScoreStorage`]. The threshold travels with [`ScoringConfig`];
//! per-event score deltas are supplied at construction.
4
5use std::collections::HashMap;
6
7use crate::core::{
8    PeerScoreStorage, PeerScoringEvent, PeerScoringPlugin, ScoreEvent, ScoreOp, ScoreSnapshot,
9    ScoringConfig,
10};
11
12/// Per-conversation, per-member score tracker. Reference [`PeerScoringPlugin`]
13/// implementation. One instance per conversation; threshold travels with
14/// [`ScoringConfig`]. Storage is abstracted via [`PeerScoreStorage`] so
15/// app-layer backends (in-memory, on-disk, …) plug in without touching
16/// this protocol logic.
17pub struct PeerScoringService<S: PeerScoreStorage> {
18    storage: S,
19    score_deltas: HashMap<ScoreEvent, i64>,
20    config: ScoringConfig,
21}
22
23impl<S: PeerScoreStorage> PeerScoringService<S> {
24    pub fn new(storage: S, score_deltas: HashMap<ScoreEvent, i64>, config: ScoringConfig) -> Self {
25        Self {
26            storage,
27            score_deltas,
28            config,
29        }
30    }
31
32    /// Signed score delta for `event`. Events not in the table contribute 0.
33    fn score_delta(&self, event: ScoreEvent) -> i64 {
34        self.score_deltas.get(&event).copied().unwrap_or(0)
35    }
36}
37
38impl<S: PeerScoreStorage> PeerScoringPlugin for PeerScoringService<S> {
39    fn add_member(&mut self, member_id: &[u8]) -> Vec<PeerScoringEvent> {
40        let default = self.config.default_score;
41        self.storage.set(member_id, default);
42        // "Untracked → tracked" treated as "above → new state" for
43        // cross-detection purposes, so an unusual config with
44        // `default_score <= threshold` still surfaces the new member as
45        // a downward cross. The standard config (default 100, threshold
46        // 0) silently produces no event.
47        if default <= self.config.threshold {
48            vec![PeerScoringEvent::ThresholdCrossedDown {
49                member_id: member_id.to_vec(),
50                score: default,
51            }]
52        } else {
53            Vec::new()
54        }
55    }
56
57    fn remove_member(&mut self, member_id: &[u8]) {
58        self.storage.remove(member_id);
59    }
60
61    fn apply_op(&mut self, op: &ScoreOp) -> Vec<PeerScoringEvent> {
62        let Some(current) = self.storage.get(&op.member_id) else {
63            return Vec::new();
64        };
65        let delta = self.score_delta(op.event);
66        let new_score = current.saturating_add(delta);
67        self.storage.set(&op.member_id, new_score);
68        cross_event(
69            &op.member_id,
70            Some(current),
71            new_score,
72            self.config.threshold,
73        )
74        .into_iter()
75        .collect()
76    }
77
78    fn apply_snapshot(&mut self, snapshot: &ScoreSnapshot) -> Vec<PeerScoringEvent> {
79        let threshold = self.config.threshold;
80        let mut events = Vec::new();
81        for (member_id, new_score) in &snapshot.diverged {
82            let prior = self.storage.get(member_id);
83            self.storage.set(member_id, *new_score);
84            if let Some(ev) = cross_event(member_id, prior, *new_score, threshold) {
85                events.push(ev);
86            }
87        }
88        events
89    }
90
91    fn snapshot(&self) -> ScoreSnapshot {
92        let default = self.config.default_score;
93        let diverged = self
94            .storage
95            .all_scores()
96            .into_iter()
97            .filter(|(_, score)| *score != default)
98            .collect();
99        ScoreSnapshot { diverged }
100    }
101
102    fn score_for(&self, member_id: &[u8]) -> Option<i64> {
103        self.storage.get(member_id)
104    }
105
106    fn members_below_threshold(&self) -> Vec<Vec<u8>> {
107        let threshold = self.config.threshold;
108        self.storage
109            .all_scores()
110            .into_iter()
111            .filter(|(_, score)| *score <= threshold)
112            .map(|(id, _)| id)
113            .collect()
114    }
115
116    fn all_members_with_scores(&self) -> Vec<(Vec<u8>, i64)> {
117        self.storage.all_scores()
118    }
119
120    fn threshold(&self) -> i64 {
121        self.config.threshold
122    }
123
124    fn set_threshold(&mut self, threshold: i64) {
125        self.config.threshold = threshold;
126    }
127
128    fn default_score(&self) -> i64 {
129        self.config.default_score
130    }
131
132    fn set_default_score(&mut self, score: i64) {
133        self.config.default_score = score;
134    }
135}
136
137/// Compute the [`PeerScoringEvent`] for a transition from `prior` to
138/// `new_score`. `prior == None` (untracked) is treated as "above
139/// threshold" so a fresh entry landing at-or-below threshold emits a
140/// downward cross. Returns `None` when no cross occurred.
141fn cross_event(
142    member_id: &[u8],
143    prior: Option<i64>,
144    new_score: i64,
145    threshold: i64,
146) -> Option<PeerScoringEvent> {
147    let was_above = prior.is_none_or(|p| p > threshold);
148    let now_below = new_score <= threshold;
149    if was_above && now_below {
150        Some(PeerScoringEvent::ThresholdCrossedDown {
151            member_id: member_id.to_vec(),
152            score: new_score,
153        })
154    } else if !was_above && new_score > threshold {
155        Some(PeerScoringEvent::ThresholdCrossedUp {
156            member_id: member_id.to_vec(),
157            score: new_score,
158        })
159    } else {
160        None
161    }
162}
163
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    // ── Fixtures and helpers ────────────────────────────────────────

    /// Bare-bones `HashMap`-backed storage used only by these tests. The
    /// production backend lives in [`crate::app::InMemoryPeerScoreStorage`].
    #[derive(Default)]
    struct TestStorage(HashMap<Vec<u8>, i64>);

    impl PeerScoreStorage for TestStorage {
        fn get(&self, member_id: &[u8]) -> Option<i64> {
            self.0.get(member_id).copied()
        }
        fn set(&mut self, member_id: &[u8], score: i64) {
            self.0.insert(member_id.to_vec(), score);
        }
        fn remove(&mut self, member_id: &[u8]) {
            self.0.remove(member_id);
        }
        fn all_scores(&self) -> Vec<(Vec<u8>, i64)> {
            self.0
                .iter()
                .map(|(id, &score)| (id.clone(), score))
                .collect()
        }
    }

    /// Service over fresh storage with the given delta table and config.
    fn service_with(
        deltas: HashMap<ScoreEvent, i64>,
        default_score: i64,
        threshold: i64,
    ) -> PeerScoringService<TestStorage> {
        PeerScoringService::new(
            TestStorage::default(),
            deltas,
            ScoringConfig {
                default_score,
                threshold,
            },
        )
    }

    /// Standard fixture: default 100, threshold 0, representative deltas.
    fn make_service() -> PeerScoringService<TestStorage> {
        let deltas = HashMap::from([
            (ScoreEvent::EmergencyNoCreator, -50),
            (ScoreEvent::EmergencyYesCreator, 20),
            (ScoreEvent::BrokenCommit, -50),
            (ScoreEvent::SuccessfulCommit, 10),
            (ScoreEvent::MisbehavingCommit, -30),
        ]);
        service_with(deltas, 100, 0)
    }

    /// Shorthand for a [`ScoreOp`] targeting `member`.
    fn op(member: &[u8], event: ScoreEvent) -> ScoreOp {
        ScoreOp {
            member_id: member.to_vec(),
            event,
        }
    }

    /// Expected downward-cross event for `member` landing on `score`.
    fn down(member: &[u8], score: i64) -> PeerScoringEvent {
        PeerScoringEvent::ThresholdCrossedDown {
            member_id: member.to_vec(),
            score,
        }
    }

    /// Expected upward-cross event for `member` landing on `score`.
    fn up(member: &[u8], score: i64) -> PeerScoringEvent {
        PeerScoringEvent::ThresholdCrossedUp {
            member_id: member.to_vec(),
            score,
        }
    }

    /// Snapshot that pins a single member to an absolute score.
    fn snapshot_with(member: &[u8], score: i64) -> ScoreSnapshot {
        ScoreSnapshot {
            diverged: vec![(member.to_vec(), score)],
        }
    }

    // ── Membership ──────────────────────────────────────────────────

    #[test]
    fn add_member_gets_default_score() {
        let mut svc = make_service();
        let events = svc.add_member(b"alice");
        assert!(events.is_empty(), "default 100 > threshold 0, no cross");
        assert_eq!(svc.score_for(b"alice"), Some(100));
    }

    #[test]
    fn add_member_with_default_below_threshold_emits_down_event() {
        let mut svc = service_with(HashMap::new(), -10, 0);
        let events = svc.add_member(b"alice");
        assert_eq!(events, vec![down(b"alice", -10)]);
    }

    #[test]
    fn unknown_member_returns_none() {
        let svc = make_service();
        assert_eq!(svc.score_for(b"unknown"), None);
    }

    #[test]
    fn remove_member_clears_score() {
        let mut svc = make_service();
        let _ = svc.add_member(b"alice");
        svc.remove_member(b"alice");
        assert_eq!(svc.score_for(b"alice"), None);
    }

    // ── Ops ─────────────────────────────────────────────────────────

    #[test]
    fn apply_event_decreases_score() {
        let mut svc = make_service();
        let _ = svc.add_member(b"alice");
        let events = svc.apply_op(&op(b"alice", ScoreEvent::EmergencyNoCreator));
        assert!(events.is_empty(), "100 → 50 stays above threshold 0");
        assert_eq!(svc.score_for(b"alice"), Some(50));
    }

    #[test]
    fn apply_event_unknown_member_returns_no_events() {
        let mut svc = make_service();
        let events = svc.apply_op(&op(b"unknown", ScoreEvent::EmergencyNoCreator));
        assert!(events.is_empty());
    }

    #[test]
    fn multiple_events_accumulate() {
        let mut svc = make_service();
        let _ = svc.add_member(b"alice");
        // 100 - 50 - 30 + 10 = 30.
        for event in [
            ScoreEvent::EmergencyNoCreator,
            ScoreEvent::MisbehavingCommit,
            ScoreEvent::SuccessfulCommit,
        ] {
            let _ = svc.apply_op(&op(b"alice", event));
        }
        assert_eq!(svc.score_for(b"alice"), Some(30));
    }

    #[test]
    fn apply_op_emits_threshold_crossed_up_on_recovery() {
        let mut svc = make_service();
        let _ = svc.add_member(b"alice");
        // 100 → 50 → 0 (down), then 0 → 20 (up via EmergencyYesCreator).
        for _ in 0..2 {
            let _ = svc.apply_op(&op(b"alice", ScoreEvent::EmergencyNoCreator));
        }
        let events = svc.apply_op(&op(b"alice", ScoreEvent::EmergencyYesCreator));
        assert_eq!(events, vec![up(b"alice", 20)]);
    }

    #[test]
    fn threshold_cross_down_emits_event() {
        let mut svc = make_service();
        let _ = svc.add_member(b"alice");

        // 100 → 50: still above threshold 0.
        let events = svc.apply_op(&op(b"alice", ScoreEvent::EmergencyNoCreator));
        assert!(events.is_empty(), "above threshold, no event");

        // 50 → 0: lands at-or-below the threshold.
        let events = svc.apply_op(&op(b"alice", ScoreEvent::EmergencyNoCreator));
        assert_eq!(events, vec![down(b"alice", 0)]);

        // 0 → -50: already below, so nothing new to report.
        let events = svc.apply_op(&op(b"alice", ScoreEvent::BrokenCommit));
        assert!(events.is_empty(), "already below threshold, no event");
    }

    #[test]
    fn apply_ops_concatenates_events_in_order() {
        let mut svc = make_service();
        let _ = svc.add_member(b"alice");
        let _ = svc.add_member(b"bob");
        // Two BrokenCommits take each member 100 → 50 → 0 (≤ threshold 0).
        // Alice's ops come first in the batch, so her event must too.
        let ops = vec![
            op(b"alice", ScoreEvent::BrokenCommit),
            op(b"alice", ScoreEvent::BrokenCommit),
            op(b"bob", ScoreEvent::BrokenCommit),
            op(b"bob", ScoreEvent::BrokenCommit),
        ];
        let events = svc.apply_ops(&ops);
        assert_eq!(events.len(), 2);
        assert!(matches!(
            events[0],
            PeerScoringEvent::ThresholdCrossedDown { ref member_id, .. } if member_id == b"alice"
        ));
        assert!(matches!(
            events[1],
            PeerScoringEvent::ThresholdCrossedDown { ref member_id, .. } if member_id == b"bob"
        ));
    }

    // ── Snapshots ───────────────────────────────────────────────────

    #[test]
    fn snapshot_includes_only_diverged_scores() {
        let mut svc = make_service();
        let members: [&[u8]; 3] = [b"alice", b"bob", b"charlie"];
        for member in members {
            let _ = svc.add_member(member);
        }
        let _ = svc.apply_op(&op(b"alice", ScoreEvent::SuccessfulCommit));
        let snap = svc.snapshot();
        let ids: Vec<&[u8]> = snap.diverged.iter().map(|(id, _)| id.as_slice()).collect();
        assert_eq!(ids, vec![b"alice".as_slice()]);
        assert_eq!(snap.diverged[0].1, 110);
    }

    #[test]
    fn apply_snapshot_emits_event_only_on_actual_cross() {
        let mut svc = make_service();
        let _ = svc.add_member(b"alice");
        let _ = svc.add_member(b"bob");
        let snap = ScoreSnapshot {
            diverged: vec![(b"alice".to_vec(), -10), (b"bob".to_vec(), 50)],
        };
        let events = svc.apply_snapshot(&snap);
        assert_eq!(events, vec![down(b"alice", -10)]);
        assert_eq!(svc.score_for(b"alice"), Some(-10));
        assert_eq!(svc.score_for(b"bob"), Some(50));
    }

    #[test]
    fn apply_snapshot_idempotent_on_repeat() {
        let mut svc = make_service();
        let _ = svc.add_member(b"alice");
        let snap = snapshot_with(b"alice", -10);
        let first = svc.apply_snapshot(&snap);
        assert_eq!(first.len(), 1, "first apply emits the cross");
        let second = svc.apply_snapshot(&snap);
        assert!(
            second.is_empty(),
            "second apply on unchanged state emits nothing"
        );
    }

    #[test]
    fn apply_snapshot_emits_threshold_crossed_up_on_recovery() {
        let mut svc = make_service();
        let _ = svc.add_member(b"alice");
        // Push alice below the threshold, then snapshot her back above it.
        let _ = svc.apply_snapshot(&snapshot_with(b"alice", -10));
        let events = svc.apply_snapshot(&snapshot_with(b"alice", 50));
        assert_eq!(events, vec![up(b"alice", 50)]);
    }

    #[test]
    fn apply_snapshot_for_untracked_below_threshold_emits_down() {
        // An untracked member counts as "above by default", so landing
        // below the threshold is a downward cross.
        let mut svc = make_service();
        let events = svc.apply_snapshot(&snapshot_with(b"newcomer", -10));
        assert_eq!(events, vec![down(b"newcomer", -10)]);
    }

    // ── Threshold queries ───────────────────────────────────────────

    #[test]
    fn members_below_threshold_filters_correctly() {
        let mut svc = make_service();
        let members: [&[u8]; 3] = [b"alice", b"bob", b"charlie"];
        for member in members {
            let _ = svc.add_member(member);
        }
        for event in [ScoreEvent::EmergencyNoCreator, ScoreEvent::BrokenCommit] {
            let _ = svc.apply_op(&op(b"alice", event));
        }
        for _ in 0..2 {
            let _ = svc.apply_op(&op(b"charlie", ScoreEvent::EmergencyNoCreator));
        }
        let below = svc.members_below_threshold();
        assert!(below.contains(&b"alice".to_vec()));
        assert!(below.contains(&b"charlie".to_vec()));
        assert!(!below.contains(&b"bob".to_vec()));
    }

    #[test]
    fn set_threshold_changes_below_threshold_set() {
        let mut svc = make_service();
        let _ = svc.add_member(b"alice");
        // A snapshot sets an absolute score, bypassing the delta table.
        let _ = svc.apply_snapshot(&snapshot_with(b"alice", -10));

        svc.set_threshold(-50);
        assert!(!svc.members_below_threshold().contains(&b"alice".to_vec()));

        svc.set_threshold(-5);
        assert!(svc.members_below_threshold().contains(&b"alice".to_vec()));
    }

    // ── Delta-table edge cases ──────────────────────────────────────

    #[test]
    fn score_saturates_no_overflow() {
        let mut svc = service_with(
            HashMap::from([(ScoreEvent::SuccessfulCommit, i64::MAX)]),
            i64::MAX,
            0,
        );
        let _ = svc.add_member(b"alice");
        let _ = svc.apply_op(&op(b"alice", ScoreEvent::SuccessfulCommit));
        assert_eq!(svc.score_for(b"alice"), Some(i64::MAX));
    }

    #[test]
    fn unknown_event_yields_zero_delta() {
        let mut svc = service_with(
            HashMap::from([(ScoreEvent::EmergencyNoCreator, -50)]),
            100,
            0,
        );
        let _ = svc.add_member(b"alice");
        let _ = svc.apply_op(&op(b"alice", ScoreEvent::SuccessfulCommit));
        assert_eq!(svc.score_for(b"alice"), Some(100));
    }
}