Skip to main content

radiate_core/fitness/
novelty.rs

1use crate::{
2    BatchFitnessFunction, BatchedFn, CosineDistance, EuclideanDistance, FitnessFunction,
3    HammingDistance, diversity::Distance, math::knn::KNN,
4};
5use radiate_utils::WindowBuffer;
6use std::sync::{Arc, RwLock};
7
8const DEFAULT_ARCHIVE_SIZE: usize = 1000;
9const DEFAULT_K: usize = 15;
10const DEFAULT_THRESHOLD: f32 = 0.5;
11
12pub trait Novelty<T>: Send + Sync {
13    fn description(&self, member: &T) -> Vec<f32>;
14
15    /// Compute descriptors for a whole batch. Default fans out to `description`.
16    /// Override this on your own concrete `Novelty` impl if you can vectorise the
17    /// batch path (shared setup, SIMD, GPU, etc.) — the closure blanket impl
18    /// always takes the default.
19    fn batch_description(&self, members: &[T]) -> Vec<Vec<f32>> {
20        members.iter().map(|m| self.description(m)).collect()
21    }
22}
23
24impl<T, F> Novelty<T> for F
25where
26    F: Fn(&T) -> Vec<f32> + Send + Sync,
27{
28    fn description(&self, member: &T) -> Vec<f32> {
29        self(member)
30    }
31}
32
33impl<T, F> Novelty<T> for BatchedFn<F>
34where
35    F: Fn(&[T]) -> Vec<Vec<f32>> + Send + Sync,
36{
37    fn description(&self, member: &T) -> Vec<f32> {
38        (self.0)(std::slice::from_ref(member))
39            .into_iter()
40            .next()
41            .unwrap_or_default()
42    }
43
44    fn batch_description(&self, members: &[T]) -> Vec<Vec<f32>> {
45        (self.0)(members)
46    }
47}
48
49#[derive(Clone)]
50pub struct NoveltySearch<T> {
51    pub behavior: Arc<dyn Novelty<T>>,
52    pub archive: Arc<RwLock<WindowBuffer<Vec<f32>>>>,
53    pub k: usize,
54    pub threshold: f32,
55    pub distance_fn: Arc<dyn Distance<Vec<f32>>>,
56}
57
58impl<T> NoveltySearch<T> {
59    pub fn new<N>(behavior: N) -> Self
60    where
61        N: Novelty<T> + Send + Sync + 'static,
62    {
63        NoveltySearch {
64            behavior: Arc::new(behavior),
65            archive: Arc::new(RwLock::new(WindowBuffer::with_capacity(
66                DEFAULT_ARCHIVE_SIZE,
67            ))),
68            k: DEFAULT_K,
69            threshold: DEFAULT_THRESHOLD,
70            distance_fn: Arc::new(EuclideanDistance),
71        }
72    }
73
74    /// Construct from a batch-shaped descriptor closure.
75    /// Equivalent to `NoveltySearch::new(BatchedFn(f))`.
76    pub fn from_batch_fn<F>(f: F) -> Self
77    where
78        F: Fn(&[T]) -> Vec<Vec<f32>> + Send + Sync + 'static,
79        T: 'static,
80    {
81        Self::new(BatchedFn(f))
82    }
83
84    pub fn k(mut self, k: usize) -> Self {
85        self.k = k;
86        self
87    }
88
89    pub fn threshold(mut self, threshold: f32) -> Self {
90        self.threshold = threshold;
91        self
92    }
93
94    pub fn archive_size(mut self, size: usize) -> Self {
95        self.archive = Arc::new(RwLock::new(WindowBuffer::with_capacity(size)));
96        self
97    }
98
99    pub fn cosine_distance(mut self) -> Self {
100        self.distance_fn = Arc::new(CosineDistance);
101        self
102    }
103
104    pub fn euclidean_distance(mut self) -> Self {
105        self.distance_fn = Arc::new(EuclideanDistance);
106        self
107    }
108
109    pub fn hamming_distance(mut self) -> Self {
110        self.distance_fn = Arc::new(HammingDistance);
111        self
112    }
113
114    fn novelty_score(&self, descriptor: &Vec<f32>, archive: &WindowBuffer<Vec<f32>>) -> f32 {
115        let slice = archive.values();
116        let mut knn = KNN::new(slice, Arc::clone(&self.distance_fn));
117        let query = knn.query_point(descriptor, self.k);
118
119        let min_dist = query.min_distance;
120        let max_dist = query.max_distance;
121        let range = max_dist - min_dist;
122
123        if range < f32::EPSILON {
124            return if min_dist < f32::EPSILON { 0.0 } else { 0.5 };
125        }
126
127        let avg_dist = query.average_distance();
128        (avg_dist - min_dist) / range
129    }
130
131    fn evaluate_internal(&self, individual: &T) -> f32 {
132        let description = self.behavior.description(individual);
133        let mut archive = self.archive.write().unwrap();
134
135        if archive.is_empty() {
136            archive.push(description);
137            return 0.5;
138        }
139
140        let novelty = self.novelty_score(&description, &archive);
141        if novelty > self.threshold || archive.len() < self.k {
142            archive.push(description);
143        }
144
145        novelty
146    }
147
148    fn evaluate_batch_internal(&self, individuals: &[T]) -> Vec<f32> {
149        let descriptions = self.behavior.batch_description(individuals);
150        let mut archive = self.archive.write().unwrap();
151
152        if archive.is_empty() {
153            let result = vec![0.5; descriptions.len()];
154            for desc in descriptions {
155                archive.push(desc);
156            }
157
158            return result;
159        }
160
161        // Score every descriptor against the same pre-batch archive snapshot —
162        // no archive mutations happen during scoring, so individual-N's score
163        // does not depend on its position in the batch.
164        let mut scores = Vec::with_capacity(descriptions.len());
165        for desc in descriptions.into_iter() {
166            let score = self.novelty_score(&desc, &archive);
167
168            if score > self.threshold || archive.len() < self.k {
169                archive.push(desc);
170            }
171
172            scores.push(score);
173        }
174
175        scores
176    }
177}
178
179impl<T> FitnessFunction<T, f32> for NoveltySearch<T>
180where
181    T: Send + Sync,
182{
183    fn evaluate(&self, individual: T) -> f32 {
184        self.evaluate_internal(&individual)
185    }
186}
187
188impl<T> FitnessFunction<&T, f32> for NoveltySearch<T>
189where
190    T: Send + Sync,
191{
192    fn evaluate(&self, individual: &T) -> f32 {
193        self.evaluate_internal(individual)
194    }
195}
196
197impl<T> BatchFitnessFunction<T, f32> for NoveltySearch<T>
198where
199    T: Send + Sync,
200{
201    fn evaluate(&self, individuals: Vec<T>) -> Vec<f32> {
202        self.evaluate_batch_internal(&individuals)
203    }
204}
205
206#[cfg(test)]
207mod tests {
208    use super::*;
209    use crate::{BatchFitnessFunction, FitnessFunction};
210
211    fn make_ns(k: usize, threshold: f32) -> NoveltySearch<Vec<f32>> {
212        NoveltySearch::new(|v: &Vec<f32>| v.clone())
213            .k(k)
214            .threshold(threshold)
215            .archive_size(100)
216    }
217
218    fn seed(ns: &NoveltySearch<Vec<f32>>, points: impl IntoIterator<Item = Vec<f32>>) {
219        let mut archive = ns.archive.write().unwrap();
220        for p in points {
221            archive.push(p);
222        }
223    }
224
225    fn archive_view_len(ns: &NoveltySearch<Vec<f32>>) -> usize {
226        ns.archive.read().unwrap().values().len()
227    }
228
229    fn eval(ns: &NoveltySearch<Vec<f32>>, v: Vec<f32>) -> f32 {
230        <NoveltySearch<Vec<f32>> as FitnessFunction<Vec<f32>, f32>>::evaluate(ns, v)
231    }
232
233    fn eval_batch(ns: &NoveltySearch<Vec<f32>>, vs: Vec<Vec<f32>>) -> Vec<f32> {
234        <NoveltySearch<Vec<f32>> as BatchFitnessFunction<Vec<f32>, f32>>::evaluate(ns, vs)
235    }
236
237    #[test]
238    fn empty_archive_single_eval_scores_half_and_seeds_one_entry() {
239        let ns = make_ns(3, 0.5);
240        let score = eval(&ns, vec![1.0, 0.0]);
241        assert_eq!(score, 0.5);
242        assert_eq!(archive_view_len(&ns), 1);
243    }
244
245    #[test]
246    fn empty_archive_batch_eval_scores_half_and_seeds_every_entry() {
247        let ns = make_ns(3, 0.5);
248        let scores = eval_batch(
249            &ns,
250            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
251        );
252
253        assert_eq!(scores, vec![0.5; 5]);
254        assert_eq!(archive_view_len(&ns), 5);
255    }
256
257    #[test]
258    fn identical_to_sole_archive_point_scores_zero() {
259        let ns = make_ns(3, 0.99);
260        seed(&ns, [vec![1.0, 0.0]]);
261
262        // distance is clamped to 1e-12 inside KNN, which is < f32::EPSILON,
263        // so the degenerate branch returns 0.0.
264        let score = eval(&ns, vec![1.0, 0.0]);
265        assert_eq!(score, 0.0);
266    }
267
268    #[test]
269    fn different_from_sole_archive_point_scores_half_neutral() {
270        let ns = make_ns(3, 0.99);
271        seed(&ns, [vec![0.0]]);
272
273        // single archive point => min == max, range == 0, min > epsilon => 0.5
274        let score = eval(&ns, vec![5.0]);
275        assert_eq!(score, 0.5);
276    }
277
278    #[test]
279    fn bootstrap_admits_first_k_individuals_under_strict_threshold() {
280        // threshold 0.99 means scores effectively can't pass on merit.
281        // archive.len() < k path must still admit the first k.
282        let ns = make_ns(3, 0.99);
283
284        for _ in 0..3 {
285            eval(&ns, vec![1.0]);
286        }
287        assert_eq!(archive_view_len(&ns), 3);
288
289        // 4th identical individual: archive.len() == k, score won't beat 0.99 → not added.
290        eval(&ns, vec![1.0]);
291        assert_eq!(archive_view_len(&ns), 3);
292    }
293
294    #[test]
295    fn post_bootstrap_score_below_threshold_does_not_admit() {
296        let ns = make_ns(2, 0.2);
297        // 3 seeded points → past bootstrap (archive.len() >= k=2).
298        seed(&ns, [vec![0.0], vec![5.0], vec![10.0]]);
299
300        // query=[2.0]: 2-NN are [0.0]@2 and [5.0]@3, global_max=8 (to [10.0]).
301        // avg=2.5, range=6 → score = 0.5/6 ≈ 0.0833 < 0.2 → not added.
302        let score = eval(&ns, vec![2.0]);
303        assert!(
304            (score - 0.083_333).abs() < 1e-4,
305            "expected ≈0.0833, got {score}"
306        );
307        assert_eq!(archive_view_len(&ns), 3);
308    }
309
310    #[test]
311    fn novelty_score_matches_normalization_formula() {
312        // threshold = -1 so admission never filters; we want to inspect the score itself.
313        let ns = make_ns(2, -1.0);
314        seed(&ns, [vec![0.0], vec![5.0], vec![10.0]]);
315
316        // query=[12.0]: 2-NN are [10.0]@2 and [5.0]@7, global_max=12 (to [0.0]).
317        // avg=4.5, range=10 → score = 2.5/10 = 0.25 exact.
318        let score = eval(&ns, vec![12.0]);
319        assert!((score - 0.25).abs() < 1e-5, "expected 0.25, got {score}");
320    }
321
322    #[test]
323    fn novelty_score_with_k_equal_to_archive_size_uses_all_points() {
324        let ns = make_ns(3, -1.0);
325        seed(&ns, [vec![0.0], vec![5.0], vec![10.0]]);
326
327        // k >= n branch: cluster contains all archive points sorted ascending.
328        // query=[4.0] distances: 4, 1, 6 → min=1, max=6, avg=11/3.
329        // score = (11/3 - 1) / 5 = (8/3) / 5 = 8/15 ≈ 0.5333.
330        let score = eval(&ns, vec![4.0]);
331        assert!(
332            (score - 8.0 / 15.0).abs() < 1e-5,
333            "expected 8/15 ≈ 0.5333, got {score}"
334        );
335    }
336
337    #[test]
338    fn novelty_score_always_in_unit_interval() {
339        let ns = make_ns(2, -1.0);
340        seed(&ns, [vec![0.0], vec![5.0], vec![10.0]]);
341
342        for x in [-100.0, -1.0, 0.0, 2.5, 5.0, 7.5, 10.0, 12.0, 100.0] {
343            // fresh archive each iteration so writes from earlier queries don't drift.
344            let ns = make_ns(2, -1.0);
345            seed(&ns, [vec![0.0], vec![5.0], vec![10.0]]);
346
347            let score = eval(&ns, vec![x]);
348            assert!(
349                (0.0..=1.0).contains(&score),
350                "score {score} out of [0,1] for x={x}"
351            );
352        }
353    }
354
355    #[test]
356    fn score_above_threshold_admits_to_archive() {
357        let ns = make_ns(2, 0.2);
358        seed(&ns, [vec![0.0], vec![5.0], vec![10.0]]);
359        assert_eq!(archive_view_len(&ns), 3);
360
361        // query=[12.0] → score 0.25 > 0.2 → admitted.
362        let score = eval(&ns, vec![12.0]);
363        assert!(score > 0.2, "expected > threshold, got {score}");
364        assert_eq!(archive_view_len(&ns), 4);
365    }
366
367    #[test]
368    fn archive_window_caps_at_configured_size() {
369        // archive_size=5, threshold=-1 (always admit).
370        let ns = NoveltySearch::new(|v: &Vec<f32>| v.clone())
371            .k(1)
372            .threshold(-1.0)
373            .archive_size(5);
374
375        for i in 0..40 {
376            eval(&ns, vec![i as f32 * 100.0]);
377        }
378
379        // The k-NN sees archive.values(), which is the live window.
380        let archive = ns.archive.read().unwrap();
381        assert!(
382            archive.values().len() <= 5,
383            "archive view exceeds window cap: {}",
384            archive.values().len()
385        );
386    }
387
388    #[test]
389    fn fitness_function_ref_variant_evaluates_and_admits() {
390        let ns = make_ns(1, 0.5);
391        let ind = vec![1.0, 0.0];
392        let score = eval(&ns, ind);
393        assert_eq!(score, 0.5);
394        assert_eq!(archive_view_len(&ns), 1);
395    }
396
397    #[test]
398    fn batch_eval_returns_one_score_per_individual() {
399        let ns = make_ns(3, 0.5);
400        let scores = eval_batch(&ns, vec![vec![0.0], vec![5.0], vec![10.0]]);
401        assert_eq!(scores.len(), 3);
402        for (i, &s) in scores.iter().enumerate() {
403            assert!((0.0..=1.0).contains(&s), "scores[{i}] = {s} out of [0,1]");
404        }
405    }
406
407    #[test]
408    fn batch_eval_admits_via_running_archive_size_for_bootstrap() {
409        // Scoring is against a frozen pre-batch snapshot, but the admission pass
410        // walks the batch with the live archive size so bootstrap (`archive.len()
411        // < k`) still works mid-batch when the pre-batch archive is undersized.
412        let ns = make_ns(2, -1.0); // threshold=-1 → score-based admission also passes.
413        seed(&ns, [vec![0.0], vec![10.0]]);
414        let initial = archive_view_len(&ns);
415
416        let scores = eval_batch(&ns, vec![vec![5.0], vec![20.0], vec![-5.0]]);
417        assert_eq!(scores.len(), 3);
418        assert_eq!(archive_view_len(&ns), initial + 3);
419    }
420
421    #[test]
422    fn batch_eval_does_not_score_against_intra_batch_additions() {
423        // If the batch were "online" (admitting earlier members before scoring
424        // later ones), then a duplicate entry later in the batch would see its
425        // earlier copy as a near-zero-distance neighbour and score ~0.
426        // True batch should score the duplicate against only the pre-batch
427        // archive — yielding the same score as the original.
428        let ns = make_ns(2, -1.0);
429        seed(&ns, [vec![0.0], vec![10.0]]);
430
431        let scores = eval_batch(&ns, vec![vec![5.0], vec![5.0]]);
432        assert_eq!(scores.len(), 2);
433        assert!(
434            (scores[0] - scores[1]).abs() < 1e-6,
435            "duplicate batch members should score identically: {scores:?}"
436        );
437    }
438
439    #[test]
440    fn from_batch_fn_routes_batch_through_user_closure_and_falls_back_per_item() {
441        use std::sync::atomic::{AtomicUsize, Ordering};
442
443        let batch_calls = Arc::new(AtomicUsize::new(0));
444        let total_seen = Arc::new(AtomicUsize::new(0));
445
446        let ns: NoveltySearch<Vec<f32>> = {
447            let batch_calls = Arc::clone(&batch_calls);
448            let total_seen = Arc::clone(&total_seen);
449            NoveltySearch::from_batch_fn(move |members: &[Vec<f32>]| {
450                batch_calls.fetch_add(1, Ordering::Relaxed);
451                total_seen.fetch_add(members.len(), Ordering::Relaxed);
452                members.iter().map(|v| v.clone()).collect()
453            })
454            .k(3)
455            .threshold(0.5)
456            .archive_size(100)
457        };
458
459        // Single-eval routes through the per-item fallback, which calls the
460        // batch closure with a 1-element slice.
461        let _ = eval(&ns, vec![1.0, 0.0]);
462        assert_eq!(batch_calls.load(Ordering::Relaxed), 1);
463        assert_eq!(total_seen.load(Ordering::Relaxed), 1);
464
465        // Batch eval calls the closure once with the full slice — the fast path.
466        let _ = eval_batch(&ns, vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]]);
467        assert_eq!(batch_calls.load(Ordering::Relaxed), 2);
468        assert_eq!(total_seen.load(Ordering::Relaxed), 5);
469    }
470
471    #[test]
472    fn clone_shares_archive_with_original() {
473        let ns = make_ns(3, 0.5);
474        let twin = ns.clone();
475
476        eval(&ns, vec![1.0]);
477        eval(&ns, vec![2.0]);
478
479        // Same Arc<RwLock<...>> backing both handles.
480        assert_eq!(archive_view_len(&twin), 2);
481    }
482
483    #[test]
484    fn cosine_distance_identical_direction_scores_zero_in_degenerate_case() {
485        let ns = NoveltySearch::new(|v: &Vec<f32>| v.clone())
486            .k(3)
487            .threshold(0.99)
488            .archive_size(100)
489            .cosine_distance();
490        seed(&ns, [vec![1.0, 0.0]]);
491
492        // Same direction, different magnitude → cosine distance 0 (clamped to 1e-12).
493        let score = eval(&ns, vec![100.0, 0.0]);
494        assert_eq!(score, 0.0);
495    }
496
497    #[test]
498    fn concurrent_evaluation_does_not_panic_or_deadlock() {
499        use std::thread;
500
501        let ns = Arc::new(
502            NoveltySearch::new(|v: &Vec<f32>| v.clone())
503                .k(5)
504                .threshold(0.3)
505                .archive_size(200),
506        );
507
508        let handles: Vec<_> = (0..8)
509            .map(|i| {
510                let ns = Arc::clone(&ns);
511                thread::spawn(move || {
512                    for j in 0..50 {
513                        let v = (i * 50 + j) as f32;
514                        eval(&ns, vec![v, v * 0.5]);
515                    }
516                })
517            })
518            .collect();
519
520        for h in handles {
521            h.join().expect("worker thread panicked");
522        }
523
524        // 8 threads × 50 evals = 400 attempts; capped by archive_size=200.
525        let archive = ns.archive.read().unwrap();
526        assert!(archive.values().len() <= 200);
527    }
528}