1use crate::views::MatrixView;
7use diskann_vector::{conversion::CastFromSlice, distance::SquaredL2, PureDistanceFunction};
8use half::f16;
9
10pub trait ComputeMedoid: Sized {
12 fn compute_medoid(data: MatrixView<Self>) -> Vec<Self>;
13}
14
15impl ComputeMedoid for f32 {
16 fn compute_medoid(data: MatrixView<Self>) -> Vec<Self> {
17 if data.ncols() == 0 {
18 return vec![];
19 }
20
21 let mut sum = vec![0.0f64; data.ncols()];
22 data.row_iter().for_each(|r| {
23 std::iter::zip(sum.iter_mut(), r.iter()).for_each(|(o, i)| {
24 let i: f64 = (*i).into();
25 *o += i;
26 });
27 });
28
29 let m: Vec<f32> = sum
30 .iter()
31 .map(|s| (s / data.nrows() as f64) as f32)
32 .collect();
33
34 let mut min_dist: f32 = f32::MAX;
35 let mut medoid = None;
36 data.row_iter().for_each(|r| {
37 let d = SquaredL2::evaluate(m.as_slice(), r);
38 if d < min_dist {
39 min_dist = d;
40 medoid = Some(r);
41 }
42 });
43
44 medoid
45 .map(|x| x.into())
46 .unwrap_or(vec![0.0f32; data.ncols()])
47 }
48}
49
50impl ComputeMedoid for f16 {
51 fn compute_medoid(data: MatrixView<Self>) -> Vec<Self> {
52 if data.ncols() == 0 {
53 return vec![];
54 }
55
56 let mut sum = vec![0.0f64; data.ncols()];
57 let mut buffer = vec![0.0f32; data.ncols()];
58 data.row_iter().for_each(|r| {
59 buffer.cast_from_slice(r);
60 std::iter::zip(sum.iter_mut(), buffer.iter()).for_each(|(o, i)| {
61 let i: f64 = (*i).into();
62 *o += i;
63 });
64 });
65
66 std::iter::zip(buffer.iter_mut(), sum.iter()).for_each(|(o, i)| {
67 *o = (*i / data.nrows() as f64) as f32;
68 });
69
70 let mut min_dist: f32 = f32::MAX;
71 let mut medoid = None;
72 data.row_iter().for_each(|r| {
73 let d = SquaredL2::evaluate(buffer.as_slice(), r);
74 if d < min_dist {
75 min_dist = d;
76 medoid = Some(r);
77 }
78 });
79
80 medoid
81 .map(|x| x.into())
82 .unwrap_or(vec![f16::default(); data.ncols()])
83 }
84}
85
86impl ComputeMedoid for u8 {
87 fn compute_medoid(data: MatrixView<Self>) -> Vec<Self> {
88 if data.ncols() == 0 {
89 return vec![];
90 }
91
92 let mut sum = vec![0.0f64; data.ncols()];
93 data.row_iter().for_each(|r| {
94 std::iter::zip(sum.iter_mut(), r.iter()).for_each(|(o, i)| {
95 let i: f64 = (*i).into();
96 *o += i;
97 });
98 });
99
100 let m: Vec<f32> = sum
101 .iter()
102 .map(|s| (s / data.nrows() as f64) as f32)
103 .collect();
104
105 let mut min_dist: f32 = f32::MAX;
106 let mut medoid = None;
107 let mut as_float = vec![0.0f32; data.ncols()];
108 data.row_iter().for_each(|r| {
109 std::iter::zip(as_float.iter_mut(), r.iter())
110 .for_each(|(dst, src)| *dst = (*src).into());
111 let d = SquaredL2::evaluate(m.as_slice(), &*as_float);
112 if d < min_dist {
113 min_dist = d;
114 medoid = Some(r);
115 }
116 });
117
118 medoid.map(|x| x.into()).unwrap_or(vec![0u8; data.ncols()])
119 }
120}
121
122impl ComputeMedoid for i8 {
123 fn compute_medoid(data: MatrixView<Self>) -> Vec<Self> {
124 if data.ncols() == 0 {
125 return vec![];
126 }
127
128 let mut sum = vec![0.0f64; data.ncols()];
129 data.row_iter().for_each(|r| {
130 std::iter::zip(sum.iter_mut(), r.iter()).for_each(|(o, i)| {
131 let i: f64 = (*i).into();
132 *o += i;
133 });
134 });
135
136 let m: Vec<f32> = sum
137 .iter()
138 .map(|s| (s / data.nrows() as f64) as f32)
139 .collect();
140
141 let mut min_dist: f32 = f32::MAX;
142 let mut medoid = None;
143 let mut as_float = vec![0.0f32; data.ncols()];
144 data.row_iter().for_each(|r| {
145 std::iter::zip(as_float.iter_mut(), r.iter())
146 .for_each(|(dst, src)| *dst = (*src).into());
147 let d = SquaredL2::evaluate(m.as_slice(), &*as_float);
148 if d < min_dist {
149 min_dist = d;
150 medoid = Some(r);
151 }
152 });
153
154 medoid.map(|x| x.into()).unwrap_or(vec![0i8; data.ncols()])
155 }
156}
157
#[cfg(test)]
mod tests {
    use crate::views::{Init, Matrix};
    use rand::{
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;

    /// A 10x5 `f32` dataset together with its expected medoid, which is row 5.
    fn example_dataset() -> (Matrix<f32>, Vec<f32>) {
        let values: Vec<f32> = vec![
            0.203688, 0.841956, 0.855665, 0.801917, 0.754536, // row 0
            0.312881, 0.217382, 0.0644115, 0.348708, 0.999495, // row 1
            0.657741, 0.914681, 0.555228, 0.13253, 0.118615, // row 2
            0.356464, 0.207449, 0.452471, 0.925219, 0.508498, // row 3
            0.749786, 0.90786, 0.129618, 0.597719, 0.000622153, // row 4
            0.569517, 0.435447, 0.558136, 0.480974, 0.711425, // row 5
            0.896353, 0.275053, 0.0427179, 0.660916, 0.464851, // row 6
            0.558689, 0.596543, 0.740983, 0.122136, 0.453822, // row 7
            0.526895, 0.492643, 0.0951115, 0.495487, 0.446127, // row 8
            0.454093, 0.160239, 0.924585, 0.901708, 0.329328, // row 9
        ];

        let matrix = Matrix::<f32>::try_from(values.into(), 10, 5).unwrap();
        let medoid: Vec<f32> = matrix.row(5).into();
        (matrix, medoid)
    }

    #[test]
    fn test_f32() {
        // No rows: the result is an all-zero vector of the matrix width.
        let no_rows = Matrix::<f32>::new(0.0f32, 0, 10);
        assert_eq!(
            f32::compute_medoid(no_rows.as_view()),
            vec![0.0; no_rows.ncols()]
        );

        // No columns: the result is empty.
        let no_cols = Matrix::<f32>::new(0.0f32, 10, 0);
        assert_eq!(f32::compute_medoid(no_cols.as_view()), Vec::<f32>::new());

        // A single row is always its own medoid.
        let mut rng = StdRng::seed_from_u64(0xaf2f5fa0b5161acf);
        let dist = StandardUniform;
        for dim in 1..20 {
            let single = Matrix::<f32>::new(Init(|| dist.sample(&mut rng)), 1, dim);
            assert_eq!(&*f32::compute_medoid(single.as_view()), single.row(0));
        }

        let (data, expected) = example_dataset();
        assert_eq!(f32::compute_medoid(data.as_view()), expected);
    }

    #[test]
    fn test_f16() {
        // No rows: the result is filled with `f16::default()`.
        let no_rows = Matrix::<f16>::new(f16::default(), 0, 10);
        assert_eq!(
            f16::compute_medoid(no_rows.as_view()),
            vec![f16::default(); no_rows.ncols()]
        );

        // No columns: the result is empty.
        let no_cols = Matrix::<f16>::new(f16::default(), 10, 0);
        assert_eq!(f16::compute_medoid(no_cols.as_view()), Vec::<f16>::new());

        // A single row is always its own medoid.
        let mut rng = StdRng::seed_from_u64(0x88e2f7096fc9b90e);
        let dist = StandardUniform;
        for dim in 1..20 {
            let single =
                Matrix::<f16>::new(Init(|| f16::from_f32(dist.sample(&mut rng))), 1, dim);
            assert_eq!(&*f16::compute_medoid(single.as_view()), single.row(0));
        }

        // Reuse the f32 example dataset, converted element-wise to f16.
        let (data, expected) = example_dataset();
        let mut half_data = Matrix::<f16>::new(f16::default(), data.nrows(), data.ncols());
        half_data.as_mut_slice().cast_from_slice(data.as_slice());

        let mut half_expected = vec![f16::default(); expected.len()];
        half_expected.cast_from_slice(expected.as_slice());

        assert_eq!(f16::compute_medoid(half_data.as_view()), half_expected);
    }

    /// A 6x5 `u8` dataset together with its expected medoid, which is row 5.
    fn example_dataset_u8() -> (Matrix<u8>, Vec<u8>) {
        let values: Vec<u8> = vec![
            52, 215, 218, 204, 192, // row 0
            79, 55, 16, 89, 255, // row 1
            167, 233, 141, 33, 30, // row 2
            91, 53, 115, 236, 130, // row 3
            191, 232, 33, 152, 1, // row 4
            145, 111, 142, 122, 181, // row 5
        ];

        let matrix = Matrix::<u8>::try_from(values.into(), 6, 5).unwrap();
        let medoid: Vec<u8> = matrix.row(5).into();
        (matrix, medoid)
    }

    #[test]
    fn test_u8() {
        // No rows: the result is an all-zero vector of the matrix width.
        let no_rows = Matrix::<u8>::new(0u8, 0, 10);
        assert_eq!(u8::compute_medoid(no_rows.as_view()), vec![0u8; no_rows.ncols()]);

        // No columns: the result is empty.
        let no_cols = Matrix::<u8>::new(0u8, 10, 0);
        assert_eq!(u8::compute_medoid(no_cols.as_view()), Vec::<u8>::new());

        // A single row is always its own medoid.
        let mut rng = StdRng::seed_from_u64(0x8f2f5fa0b5161acf);
        let dist = StandardUniform;
        for dim in 1..20 {
            let single = Matrix::<u8>::new(Init(|| dist.sample(&mut rng)), 1, dim);
            assert_eq!(&*u8::compute_medoid(single.as_view()), single.row(0));
        }

        let (data, expected) = example_dataset_u8();
        assert_eq!(u8::compute_medoid(data.as_view()), expected);
    }

    /// A 6x5 `i8` dataset (rows 3 and 4 are duplicates) together with its expected medoid,
    /// which is row 5.
    fn example_dataset_i8() -> (Matrix<i8>, Vec<i8>) {
        let values: Vec<i8> = vec![
            -76, 87, 90, 76, 64, // row 0
            -49, -73, -112, -39, 127, // row 1
            39, 105, 13, -95, -98, // row 2
            -37, -75, -13, 108, 2, // row 3
            -37, -75, -13, 108, 2, // row 4
            17, -17, 14, -6, 53, // row 5
        ];

        let matrix = Matrix::<i8>::try_from(values.into(), 6, 5).unwrap();
        let medoid: Vec<i8> = matrix.row(5).into();
        (matrix, medoid)
    }

    #[test]
    fn test_i8() {
        // No rows: the result is an all-zero vector of the matrix width.
        let no_rows = Matrix::<i8>::new(0i8, 0, 10);
        assert_eq!(i8::compute_medoid(no_rows.as_view()), vec![0i8; no_rows.ncols()]);

        // No columns: the result is empty.
        let no_cols = Matrix::<i8>::new(0i8, 10, 0);
        assert_eq!(i8::compute_medoid(no_cols.as_view()), Vec::<i8>::new());

        // A single row is always its own medoid.
        let mut rng = StdRng::seed_from_u64(0x8f2f5fa0b5161acf);
        let dist = StandardUniform;
        for dim in 1..20 {
            let single = Matrix::<i8>::new(Init(|| dist.sample(&mut rng)), 1, dim);
            assert_eq!(&*i8::compute_medoid(single.as_view()), single.row(0));
        }

        let (data, expected) = example_dataset_i8();
        assert_eq!(i8::compute_medoid(data.as_view()), expected);
    }
}