// radiate_core/fitness/novelty.rs

1use crate::{BatchFitnessFunction, FitnessFunction};
2use std::{
3    collections::VecDeque,
4    sync::{Arc, RwLock},
5};
6
/// Maps a value of type `T` to a numeric descriptor vector.
///
/// Implementors decide which features of `member` end up in the returned
/// `Vec<f32>`; the trait itself places no constraint on its length.
pub trait Novelty<T> {
    /// Returns the descriptor vector computed for `member`.
    fn description(&self, member: &T) -> Vec<f32>;
}
10
11impl<T, F> Novelty<T> for F
12where
13    F: Fn(&T) -> Vec<f32>,
14{
15    fn description(&self, member: &T) -> Vec<f32> {
16        self(member)
17    }
18}
19
20#[derive(Clone)]
21pub struct NoveltySearch<T> {
22    pub behavior: Arc<dyn Novelty<T> + Send + Sync>,
23    pub archive: Arc<RwLock<VecDeque<Vec<f32>>>>,
24    pub k: usize,
25    pub threshold: f32,
26    pub max_archive_size: usize,
27    pub distance_fn: Arc<dyn Fn(&[f32], &[f32]) -> f32 + Send + Sync>,
28}
29
30impl<T> NoveltySearch<T> {
31    pub fn new<N>(behavior: N, k: usize, threshold: f32) -> Self
32    where
33        N: Novelty<T> + Send + Sync + 'static,
34    {
35        NoveltySearch {
36            behavior: Arc::new(behavior),
37            archive: Arc::new(RwLock::new(VecDeque::new())),
38            k,
39            threshold,
40            max_archive_size: 1000,
41            distance_fn: Arc::new(|a, b| {
42                if a.len() != b.len() {
43                    return f32::INFINITY;
44                }
45                a.iter()
46                    .zip(b.iter())
47                    .map(|(x, y)| (x - y).powi(2))
48                    .sum::<f32>()
49                    .sqrt()
50            }),
51        }
52    }
53
54    pub fn with_max_archive_size(mut self, size: usize) -> Self {
55        self.max_archive_size = size;
56        self
57    }
58
59    pub fn cosine_distance(mut self) -> Self {
60        self.distance_fn = Arc::new(|a, b| {
61            let dot_product = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>();
62            let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
63            let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
64            1.0 - (dot_product / (norm_a * norm_b))
65        });
66        self
67    }
68
69    pub fn euclidean_distance(mut self) -> Self {
70        self.distance_fn = Arc::new(|a, b| {
71            if a.len() != b.len() {
72                return f32::INFINITY;
73            }
74            a.iter()
75                .zip(b.iter())
76                .map(|(x, y)| (x - y).powi(2))
77                .sum::<f32>()
78                .sqrt()
79        });
80        self
81    }
82
83    pub fn hamming_distance(mut self) -> Self {
84        self.distance_fn = Arc::new(|a, b| {
85            if a.len() != b.len() {
86                return f32::INFINITY;
87            }
88            a.iter().zip(b.iter()).filter(|(x, y)| x != y).count() as f32
89        });
90        self
91    }
92
93    fn normalized_novelty_score(&self, descriptor: &Vec<f32>, archive: &VecDeque<Vec<f32>>) -> f32 {
94        if archive.is_empty() {
95            return 0.5;
96        }
97
98        let mut min_distance = f32::INFINITY;
99        let mut max_distance = f32::NEG_INFINITY;
100        let mut distances = archive
101            .iter()
102            .map(|archived| (self.distance_fn)(&descriptor, archived))
103            .inspect(|&d| {
104                max_distance = max_distance.max(d);
105                min_distance = min_distance.min(d);
106            })
107            .collect::<Vec<f32>>();
108
109        if max_distance == min_distance {
110            if min_distance == 0.0 {
111                return 0.0;
112            }
113
114            if min_distance > 0.0 {
115                return 0.5;
116            }
117
118            return 0.0;
119        }
120
121        distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
122        let k = std::cmp::min(self.k, distances.len());
123        let k_nearest_distances = &distances[..k];
124        let avg_distance = k_nearest_distances.iter().sum::<f32>() / k as f32;
125
126        (avg_distance - min_distance) / (max_distance - min_distance)
127    }
128
129    fn evaluate_internal(&self, individual: &T) -> f32 {
130        let description = self.behavior.description(individual);
131
132        let is_empty = {
133            let archive = self.archive.read().unwrap();
134            archive.is_empty()
135        };
136
137        if is_empty {
138            let mut writer = self.archive.write().unwrap();
139            writer.push_back(description);
140            return 0.5;
141        }
142
143        let (novelty, should_add) = {
144            let archive = self.archive.read().unwrap();
145            let result = self.normalized_novelty_score(&description, &archive);
146            let should_add = result > self.threshold || archive.len() < self.k;
147
148            (result, should_add)
149        };
150
151        let mut writer = self.archive.write().unwrap();
152
153        if should_add {
154            writer.push_back(description);
155            while writer.len() > self.max_archive_size {
156                writer.pop_front();
157            }
158        }
159
160        novelty
161    }
162}
163
164impl<T> FitnessFunction<T, f32> for NoveltySearch<T>
165where
166    T: Send + Sync,
167{
168    fn evaluate(&self, individual: T) -> f32 {
169        self.evaluate_internal(&individual)
170    }
171}
172
173impl<T> FitnessFunction<&T, f32> for NoveltySearch<T>
174where
175    T: Send + Sync,
176{
177    fn evaluate(&self, individual: &T) -> f32 {
178        self.evaluate_internal(individual)
179    }
180}
181
182impl<T> BatchFitnessFunction<T, f32> for NoveltySearch<T>
183where
184    T: Send + Sync,
185{
186    fn evaluate(&self, individuals: &[T]) -> Vec<f32> {
187        individuals
188            .into_iter()
189            .map(|ind| self.evaluate_internal(ind))
190            .collect()
191    }
192}
193
194impl<T> BatchFitnessFunction<&T, f32> for NoveltySearch<T>
195where
196    T: Send + Sync,
197{
198    fn evaluate(&self, individuals: &[&T]) -> Vec<f32> {
199        individuals
200            .into_iter()
201            .map(|ind| self.evaluate_internal(ind))
202            .collect()
203    }
204}