1use crate::{
2 BatchFitnessFunction, BatchedFn, CosineDistance, EuclideanDistance, FitnessFunction,
3 HammingDistance, diversity::Distance, math::knn::KNN,
4};
5use radiate_utils::WindowBuffer;
6use std::sync::{Arc, RwLock};
7
/// Capacity of the descriptor archive created by `NoveltySearch::new`.
const DEFAULT_ARCHIVE_SIZE: usize = 1000;
/// Neighbour count (`k`) that `NoveltySearch::new` starts with.
const DEFAULT_K: usize = 15;
/// Admission threshold that `NoveltySearch::new` starts with.
const DEFAULT_THRESHOLD: f32 = 0.5;
11
/// Turns a member of type `T` into a behavior descriptor (a vector of `f32`).
pub trait Novelty<T>: Send + Sync {
    /// Returns the descriptor for a single `member`.
    fn description(&self, member: &T) -> Vec<f32>;

    /// Returns one descriptor per entry of `members`, in the same order.
    ///
    /// The default implementation calls [`Novelty::description`] once per member.
    fn batch_description(&self, members: &[T]) -> Vec<Vec<f32>> {
        let mut descriptors = Vec::with_capacity(members.len());
        for member in members {
            descriptors.push(self.description(member));
        }
        descriptors
    }
}
23
24impl<T, F> Novelty<T> for F
25where
26 F: Fn(&T) -> Vec<f32> + Send + Sync,
27{
28 fn description(&self, member: &T) -> Vec<f32> {
29 self(member)
30 }
31}
32
33impl<T, F> Novelty<T> for BatchedFn<F>
34where
35 F: Fn(&[T]) -> Vec<Vec<f32>> + Send + Sync,
36{
37 fn description(&self, member: &T) -> Vec<f32> {
38 (self.0)(std::slice::from_ref(member))
39 .into_iter()
40 .next()
41 .unwrap_or_default()
42 }
43
44 fn batch_description(&self, members: &[T]) -> Vec<Vec<f32>> {
45 (self.0)(members)
46 }
47}
48
/// Novelty-based fitness evaluator: maps individuals to behavior descriptors and
/// scores them against an archive of previously admitted descriptors.
///
/// The archive is held behind `Arc<RwLock<..>>`, so clones of this value share
/// the same archive.
#[derive(Clone)]
pub struct NoveltySearch<T> {
    /// Produces the behavior descriptor for an individual.
    pub behavior: Arc<dyn Novelty<T>>,
    /// Shared archive of admitted descriptors.
    pub archive: Arc<RwLock<WindowBuffer<Vec<f32>>>>,
    /// Neighbour count passed to the KNN query when scoring.
    pub k: usize,
    /// A descriptor whose score is strictly greater than this value is admitted
    /// to the archive.
    pub threshold: f32,
    /// Distance measure between two descriptors, handed to the KNN query.
    pub distance_fn: Arc<dyn Distance<Vec<f32>>>,
}
57
58impl<T> NoveltySearch<T> {
59 pub fn new<N>(behavior: N) -> Self
60 where
61 N: Novelty<T> + Send + Sync + 'static,
62 {
63 NoveltySearch {
64 behavior: Arc::new(behavior),
65 archive: Arc::new(RwLock::new(WindowBuffer::with_capacity(
66 DEFAULT_ARCHIVE_SIZE,
67 ))),
68 k: DEFAULT_K,
69 threshold: DEFAULT_THRESHOLD,
70 distance_fn: Arc::new(EuclideanDistance),
71 }
72 }
73
74 pub fn from_batch_fn<F>(f: F) -> Self
77 where
78 F: Fn(&[T]) -> Vec<Vec<f32>> + Send + Sync + 'static,
79 T: 'static,
80 {
81 Self::new(BatchedFn(f))
82 }
83
84 pub fn k(mut self, k: usize) -> Self {
85 self.k = k;
86 self
87 }
88
89 pub fn threshold(mut self, threshold: f32) -> Self {
90 self.threshold = threshold;
91 self
92 }
93
94 pub fn archive_size(mut self, size: usize) -> Self {
95 self.archive = Arc::new(RwLock::new(WindowBuffer::with_capacity(size)));
96 self
97 }
98
99 pub fn cosine_distance(mut self) -> Self {
100 self.distance_fn = Arc::new(CosineDistance);
101 self
102 }
103
104 pub fn euclidean_distance(mut self) -> Self {
105 self.distance_fn = Arc::new(EuclideanDistance);
106 self
107 }
108
109 pub fn hamming_distance(mut self) -> Self {
110 self.distance_fn = Arc::new(HammingDistance);
111 self
112 }
113
114 fn novelty_score(&self, descriptor: &Vec<f32>, archive: &WindowBuffer<Vec<f32>>) -> f32 {
115 let slice = archive.values();
116 let mut knn = KNN::new(slice, Arc::clone(&self.distance_fn));
117 let query = knn.query_point(descriptor, self.k);
118
119 let min_dist = query.min_distance;
120 let max_dist = query.max_distance;
121 let range = max_dist - min_dist;
122
123 if range < f32::EPSILON {
124 return if min_dist < f32::EPSILON { 0.0 } else { 0.5 };
125 }
126
127 let avg_dist = query.average_distance();
128 (avg_dist - min_dist) / range
129 }
130
131 fn evaluate_internal(&self, individual: &T) -> f32 {
132 let description = self.behavior.description(individual);
133 let mut archive = self.archive.write().unwrap();
134
135 if archive.is_empty() {
136 archive.push(description);
137 return 0.5;
138 }
139
140 let novelty = self.novelty_score(&description, &archive);
141 if novelty > self.threshold || archive.len() < self.k {
142 archive.push(description);
143 }
144
145 novelty
146 }
147
148 fn evaluate_batch_internal(&self, individuals: &[T]) -> Vec<f32> {
149 let descriptions = self.behavior.batch_description(individuals);
150 let mut archive = self.archive.write().unwrap();
151
152 if archive.is_empty() {
153 let result = vec![0.5; descriptions.len()];
154 for desc in descriptions {
155 archive.push(desc);
156 }
157
158 return result;
159 }
160
161 let mut scores = Vec::with_capacity(descriptions.len());
165 for desc in descriptions.into_iter() {
166 let score = self.novelty_score(&desc, &archive);
167
168 if score > self.threshold || archive.len() < self.k {
169 archive.push(desc);
170 }
171
172 scores.push(score);
173 }
174
175 scores
176 }
177}
178
179impl<T> FitnessFunction<T, f32> for NoveltySearch<T>
180where
181 T: Send + Sync,
182{
183 fn evaluate(&self, individual: T) -> f32 {
184 self.evaluate_internal(&individual)
185 }
186}
187
188impl<T> FitnessFunction<&T, f32> for NoveltySearch<T>
189where
190 T: Send + Sync,
191{
192 fn evaluate(&self, individual: &T) -> f32 {
193 self.evaluate_internal(individual)
194 }
195}
196
197impl<T> BatchFitnessFunction<T, f32> for NoveltySearch<T>
198where
199 T: Send + Sync,
200{
201 fn evaluate(&self, individuals: Vec<T>) -> Vec<f32> {
202 self.evaluate_batch_internal(&individuals)
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BatchFitnessFunction, FitnessFunction};

    /// Identity-descriptor novelty search with a 100-slot archive.
    fn make_ns(k: usize, threshold: f32) -> NoveltySearch<Vec<f32>> {
        NoveltySearch::new(|v: &Vec<f32>| v.clone())
            .k(k)
            .threshold(threshold)
            .archive_size(100)
    }

    /// Pushes `points` straight into the archive, bypassing scoring.
    fn seed(ns: &NoveltySearch<Vec<f32>>, points: impl IntoIterator<Item = Vec<f32>>) {
        let mut archive = ns.archive.write().unwrap();
        for p in points {
            archive.push(p);
        }
    }

    /// Number of descriptors currently visible in the archive.
    fn archive_view_len(ns: &NoveltySearch<Vec<f32>>) -> usize {
        ns.archive.read().unwrap().values().len()
    }

    /// Evaluates through the owned-value `FitnessFunction` impl.
    fn eval(ns: &NoveltySearch<Vec<f32>>, v: Vec<f32>) -> f32 {
        <NoveltySearch<Vec<f32>> as FitnessFunction<Vec<f32>, f32>>::evaluate(ns, v)
    }

    /// Evaluates through the `BatchFitnessFunction` impl.
    fn eval_batch(ns: &NoveltySearch<Vec<f32>>, vs: Vec<Vec<f32>>) -> Vec<f32> {
        <NoveltySearch<Vec<f32>> as BatchFitnessFunction<Vec<f32>, f32>>::evaluate(ns, vs)
    }

    #[test]
    fn empty_archive_single_eval_scores_half_and_seeds_one_entry() {
        let ns = make_ns(3, 0.5);
        let score = eval(&ns, vec![1.0, 0.0]);
        assert_eq!(score, 0.5);
        assert_eq!(archive_view_len(&ns), 1);
    }

    #[test]
    fn empty_archive_batch_eval_scores_half_and_seeds_every_entry() {
        let ns = make_ns(3, 0.5);
        let scores = eval_batch(
            &ns,
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
        );

        assert_eq!(scores, vec![0.5; 5]);
        assert_eq!(archive_view_len(&ns), 5);
    }

    #[test]
    fn identical_to_sole_archive_point_scores_zero() {
        let ns = make_ns(3, 0.99);
        seed(&ns, [vec![1.0, 0.0]]);

        let score = eval(&ns, vec![1.0, 0.0]);
        assert_eq!(score, 0.0);
    }

    #[test]
    fn different_from_sole_archive_point_scores_half_neutral() {
        let ns = make_ns(3, 0.99);
        seed(&ns, [vec![0.0]]);

        let score = eval(&ns, vec![5.0]);
        assert_eq!(score, 0.5);
    }

    #[test]
    fn bootstrap_admits_first_k_individuals_under_strict_threshold() {
        let ns = make_ns(3, 0.99);

        // While the archive holds fewer than k entries, everything is admitted.
        for _ in 0..3 {
            eval(&ns, vec![1.0]);
        }
        assert_eq!(archive_view_len(&ns), 3);

        // Once k entries exist, a duplicate no longer clears the threshold.
        eval(&ns, vec![1.0]);
        assert_eq!(archive_view_len(&ns), 3);
    }

    #[test]
    fn post_bootstrap_score_below_threshold_does_not_admit() {
        let ns = make_ns(2, 0.2);
        seed(&ns, [vec![0.0], vec![5.0], vec![10.0]]);

        let score = eval(&ns, vec![2.0]);
        assert!(
            (score - 0.083_333).abs() < 1e-4,
            "expected ≈0.0833, got {score}"
        );
        assert_eq!(archive_view_len(&ns), 3);
    }

    #[test]
    fn novelty_score_matches_normalization_formula() {
        // Negative threshold: admission never gates the score under test.
        let ns = make_ns(2, -1.0);
        seed(&ns, [vec![0.0], vec![5.0], vec![10.0]]);

        let score = eval(&ns, vec![12.0]);
        assert!((score - 0.25).abs() < 1e-5, "expected 0.25, got {score}");
    }

    #[test]
    fn novelty_score_with_k_equal_to_archive_size_uses_all_points() {
        let ns = make_ns(3, -1.0);
        seed(&ns, [vec![0.0], vec![5.0], vec![10.0]]);

        let score = eval(&ns, vec![4.0]);
        assert!(
            (score - 8.0 / 15.0).abs() < 1e-5,
            "expected 8/15 ≈ 0.5333, got {score}"
        );
    }

    #[test]
    fn novelty_score_always_in_unit_interval() {
        for x in [-100.0, -1.0, 0.0, 2.5, 5.0, 7.5, 10.0, 12.0, 100.0] {
            // Fresh archive per probe so earlier admissions cannot influence
            // later scores.
            let ns = make_ns(2, -1.0);
            seed(&ns, [vec![0.0], vec![5.0], vec![10.0]]);

            let score = eval(&ns, vec![x]);
            assert!(
                (0.0..=1.0).contains(&score),
                "score {score} out of [0,1] for x={x}"
            );
        }
    }

    #[test]
    fn score_above_threshold_admits_to_archive() {
        let ns = make_ns(2, 0.2);
        seed(&ns, [vec![0.0], vec![5.0], vec![10.0]]);
        assert_eq!(archive_view_len(&ns), 3);

        let score = eval(&ns, vec![12.0]);
        assert!(score > 0.2, "expected > threshold, got {score}");
        assert_eq!(archive_view_len(&ns), 4);
    }

    #[test]
    fn archive_window_caps_at_configured_size() {
        let ns = NoveltySearch::new(|v: &Vec<f32>| v.clone())
            .k(1)
            .threshold(-1.0)
            .archive_size(5);

        for i in 0..40 {
            eval(&ns, vec![i as f32 * 100.0]);
        }

        let archive = ns.archive.read().unwrap();
        assert!(
            archive.values().len() <= 5,
            "archive view exceeds window cap: {}",
            archive.values().len()
        );
    }

    #[test]
    fn fitness_function_ref_variant_evaluates_and_admits() {
        let ns = make_ns(1, 0.5);
        let ind = vec![1.0, 0.0];

        // Dispatch explicitly through the `&T` impl; the `eval` helper would
        // exercise the owned-value impl instead.
        let score =
            <NoveltySearch<Vec<f32>> as FitnessFunction<&Vec<f32>, f32>>::evaluate(&ns, &ind);

        assert_eq!(score, 0.5);
        assert_eq!(archive_view_len(&ns), 1);
    }

    #[test]
    fn batch_eval_returns_one_score_per_individual() {
        let ns = make_ns(3, 0.5);
        let scores = eval_batch(&ns, vec![vec![0.0], vec![5.0], vec![10.0]]);
        assert_eq!(scores.len(), 3);
        for (i, &s) in scores.iter().enumerate() {
            assert!((0.0..=1.0).contains(&s), "scores[{i}] = {s} out of [0,1]");
        }
    }

    #[test]
    fn batch_eval_admits_via_running_archive_size_for_bootstrap() {
        let ns = make_ns(2, -1.0);
        seed(&ns, [vec![0.0], vec![10.0]]);
        let initial = archive_view_len(&ns);

        let scores = eval_batch(&ns, vec![vec![5.0], vec![20.0], vec![-5.0]]);
        assert_eq!(scores.len(), 3);
        assert_eq!(archive_view_len(&ns), initial + 3);
    }

    #[test]
    fn batch_eval_does_not_score_against_intra_batch_additions() {
        // NOTE(review): the batch path pushes admitted descriptors while it is
        // still iterating, so the second member is scored against an archive
        // that already contains the first. With these inputs both orderings
        // appear to yield 0.5 for each member, so this assertion probably does
        // not distinguish the two behaviours — confirm the intended contract
        // and pick inputs that separate them.
        let ns = make_ns(2, -1.0);
        seed(&ns, [vec![0.0], vec![10.0]]);

        let scores = eval_batch(&ns, vec![vec![5.0], vec![5.0]]);
        assert_eq!(scores.len(), 2);
        assert!(
            (scores[0] - scores[1]).abs() < 1e-6,
            "duplicate batch members should score identically: {scores:?}"
        );
    }

    #[test]
    fn from_batch_fn_routes_batch_through_user_closure_and_falls_back_per_item() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let batch_calls = Arc::new(AtomicUsize::new(0));
        let total_seen = Arc::new(AtomicUsize::new(0));

        let ns: NoveltySearch<Vec<f32>> = {
            let batch_calls = Arc::clone(&batch_calls);
            let total_seen = Arc::clone(&total_seen);
            NoveltySearch::from_batch_fn(move |members: &[Vec<f32>]| {
                batch_calls.fetch_add(1, Ordering::Relaxed);
                total_seen.fetch_add(members.len(), Ordering::Relaxed);
                members.to_vec()
            })
            .k(3)
            .threshold(0.5)
            .archive_size(100)
        };

        // Single evaluation: the batch closure is called once with one member.
        let _ = eval(&ns, vec![1.0, 0.0]);
        assert_eq!(batch_calls.load(Ordering::Relaxed), 1);
        assert_eq!(total_seen.load(Ordering::Relaxed), 1);

        // Batch evaluation: one more call, carrying all four members.
        let _ = eval_batch(&ns, vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]]);
        assert_eq!(batch_calls.load(Ordering::Relaxed), 2);
        assert_eq!(total_seen.load(Ordering::Relaxed), 5);
    }

    #[test]
    fn clone_shares_archive_with_original() {
        let ns = make_ns(3, 0.5);
        let twin = ns.clone();

        eval(&ns, vec![1.0]);
        eval(&ns, vec![2.0]);

        assert_eq!(archive_view_len(&twin), 2);
    }

    #[test]
    fn cosine_distance_identical_direction_scores_zero_in_degenerate_case() {
        let ns = NoveltySearch::new(|v: &Vec<f32>| v.clone())
            .k(3)
            .threshold(0.99)
            .archive_size(100)
            .cosine_distance();
        seed(&ns, [vec![1.0, 0.0]]);

        let score = eval(&ns, vec![100.0, 0.0]);
        assert_eq!(score, 0.0);
    }

    #[test]
    fn concurrent_evaluation_does_not_panic_or_deadlock() {
        use std::thread;

        let ns = Arc::new(
            NoveltySearch::new(|v: &Vec<f32>| v.clone())
                .k(5)
                .threshold(0.3)
                .archive_size(200),
        );

        let handles: Vec<_> = (0..8)
            .map(|i| {
                let ns = Arc::clone(&ns);
                thread::spawn(move || {
                    for j in 0..50 {
                        let v = (i * 50 + j) as f32;
                        eval(&ns, vec![v, v * 0.5]);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().expect("worker thread panicked");
        }

        let archive = ns.archive.read().unwrap();
        assert!(archive.values().len() <= 200);
    }
}