radiate_core/fitness/
novelty.rs1use crate::{BatchFitnessFunction, FitnessFunction};
2use std::{
3 collections::VecDeque,
4 sync::{Arc, RwLock},
5};
6
/// Maps a member of the population to a behaviour descriptor.
///
/// Novelty scoring compares these descriptors rather than the members
/// themselves, so two members with equal descriptors are equally novel.
pub trait Novelty<T> {
    /// Returns the behaviour descriptor of `member`.
    fn description(&self, member: &T) -> Vec<f32>;
}
10
11impl<T, F> Novelty<T> for F
12where
13 F: Fn(&T) -> Vec<f32>,
14{
15 fn description(&self, member: &T) -> Vec<f32> {
16 self(member)
17 }
18}
19
20#[derive(Clone)]
21pub struct NoveltySearch<T> {
22 pub behavior: Arc<dyn Novelty<T> + Send + Sync>,
23 pub archive: Arc<RwLock<VecDeque<Vec<f32>>>>,
24 pub k: usize,
25 pub threshold: f32,
26 pub max_archive_size: usize,
27 pub distance_fn: Arc<dyn Fn(&[f32], &[f32]) -> f32 + Send + Sync>,
28}
29
30impl<T> NoveltySearch<T> {
31 pub fn new<N>(behavior: N, k: usize, threshold: f32) -> Self
32 where
33 N: Novelty<T> + Send + Sync + 'static,
34 {
35 NoveltySearch {
36 behavior: Arc::new(behavior),
37 archive: Arc::new(RwLock::new(VecDeque::new())),
38 k,
39 threshold,
40 max_archive_size: 1000,
41 distance_fn: Arc::new(|a, b| {
42 if a.len() != b.len() {
43 return f32::INFINITY;
44 }
45 a.iter()
46 .zip(b.iter())
47 .map(|(x, y)| (x - y).powi(2))
48 .sum::<f32>()
49 .sqrt()
50 }),
51 }
52 }
53
54 pub fn with_max_archive_size(mut self, size: usize) -> Self {
55 self.max_archive_size = size;
56 self
57 }
58
59 pub fn cosine_distance(mut self) -> Self {
60 self.distance_fn = Arc::new(|a, b| {
61 let dot_product = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>();
62 let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
63 let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
64 1.0 - (dot_product / (norm_a * norm_b))
65 });
66 self
67 }
68
69 pub fn euclidean_distance(mut self) -> Self {
70 self.distance_fn = Arc::new(|a, b| {
71 if a.len() != b.len() {
72 return f32::INFINITY;
73 }
74 a.iter()
75 .zip(b.iter())
76 .map(|(x, y)| (x - y).powi(2))
77 .sum::<f32>()
78 .sqrt()
79 });
80 self
81 }
82
83 pub fn hamming_distance(mut self) -> Self {
84 self.distance_fn = Arc::new(|a, b| {
85 if a.len() != b.len() {
86 return f32::INFINITY;
87 }
88 a.iter().zip(b.iter()).filter(|(x, y)| x != y).count() as f32
89 });
90 self
91 }
92
93 fn normalized_novelty_score(&self, descriptor: &Vec<f32>, archive: &VecDeque<Vec<f32>>) -> f32 {
94 if archive.is_empty() {
95 return 0.5;
96 }
97
98 let mut min_distance = f32::INFINITY;
99 let mut max_distance = f32::NEG_INFINITY;
100 let mut distances = archive
101 .iter()
102 .map(|archived| (self.distance_fn)(&descriptor, archived))
103 .inspect(|&d| {
104 max_distance = max_distance.max(d);
105 min_distance = min_distance.min(d);
106 })
107 .collect::<Vec<f32>>();
108
109 if max_distance == min_distance {
110 if min_distance == 0.0 {
111 return 0.0;
112 }
113
114 if min_distance > 0.0 {
115 return 0.5;
116 }
117
118 return 0.0;
119 }
120
121 distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
122 let k = std::cmp::min(self.k, distances.len());
123 let k_nearest_distances = &distances[..k];
124 let avg_distance = k_nearest_distances.iter().sum::<f32>() / k as f32;
125
126 (avg_distance - min_distance) / (max_distance - min_distance)
127 }
128
129 fn evaluate_internal(&self, individual: &T) -> f32 {
130 let description = self.behavior.description(individual);
131
132 let is_empty = {
133 let archive = self.archive.read().unwrap();
134 archive.is_empty()
135 };
136
137 if is_empty {
138 let mut writer = self.archive.write().unwrap();
139 writer.push_back(description);
140 return 0.5;
141 }
142
143 let (novelty, should_add) = {
144 let archive = self.archive.read().unwrap();
145 let result = self.normalized_novelty_score(&description, &archive);
146 let should_add = result > self.threshold || archive.len() < self.k;
147
148 (result, should_add)
149 };
150
151 let mut writer = self.archive.write().unwrap();
152
153 if should_add {
154 writer.push_back(description);
155 while writer.len() > self.max_archive_size {
156 writer.pop_front();
157 }
158 }
159
160 novelty
161 }
162}
163
164impl<T> FitnessFunction<T, f32> for NoveltySearch<T>
165where
166 T: Send + Sync,
167{
168 fn evaluate(&self, individual: T) -> f32 {
169 self.evaluate_internal(&individual)
170 }
171}
172
173impl<T> FitnessFunction<&T, f32> for NoveltySearch<T>
174where
175 T: Send + Sync,
176{
177 fn evaluate(&self, individual: &T) -> f32 {
178 self.evaluate_internal(individual)
179 }
180}
181
182impl<T> BatchFitnessFunction<T, f32> for NoveltySearch<T>
183where
184 T: Send + Sync,
185{
186 fn evaluate(&self, individuals: &[T]) -> Vec<f32> {
187 individuals
188 .into_iter()
189 .map(|ind| self.evaluate_internal(ind))
190 .collect()
191 }
192}
193
194impl<T> BatchFitnessFunction<&T, f32> for NoveltySearch<T>
195where
196 T: Send + Sync,
197{
198 fn evaluate(&self, individuals: &[&T]) -> Vec<f32> {
199 individuals
200 .into_iter()
201 .map(|ind| self.evaluate_internal(ind))
202 .collect()
203 }
204}