Skip to main content

diskann_utils/sampling/
medoid.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use crate::views::MatrixView;
7use diskann_vector::{conversion::CastFromSlice, distance::SquaredL2, PureDistanceFunction};
8use half::f16;
9
10/// Return the row in `data` that is closest to the medoid of all rows.
11pub trait ComputeMedoid: Sized {
12    fn compute_medoid(data: MatrixView<Self>) -> Vec<Self>;
13}
14
15impl ComputeMedoid for f32 {
16    fn compute_medoid(data: MatrixView<Self>) -> Vec<Self> {
17        if data.ncols() == 0 {
18            return vec![];
19        }
20
21        let mut sum = vec![0.0f64; data.ncols()];
22        data.row_iter().for_each(|r| {
23            std::iter::zip(sum.iter_mut(), r.iter()).for_each(|(o, i)| {
24                let i: f64 = (*i).into();
25                *o += i;
26            });
27        });
28
29        let m: Vec<f32> = sum
30            .iter()
31            .map(|s| (s / data.nrows() as f64) as f32)
32            .collect();
33
34        let mut min_dist: f32 = f32::MAX;
35        let mut medoid = None;
36        data.row_iter().for_each(|r| {
37            let d = SquaredL2::evaluate(m.as_slice(), r);
38            if d < min_dist {
39                min_dist = d;
40                medoid = Some(r);
41            }
42        });
43
44        medoid
45            .map(|x| x.into())
46            .unwrap_or(vec![0.0f32; data.ncols()])
47    }
48}
49
50impl ComputeMedoid for f16 {
51    fn compute_medoid(data: MatrixView<Self>) -> Vec<Self> {
52        if data.ncols() == 0 {
53            return vec![];
54        }
55
56        let mut sum = vec![0.0f64; data.ncols()];
57        let mut buffer = vec![0.0f32; data.ncols()];
58        data.row_iter().for_each(|r| {
59            buffer.cast_from_slice(r);
60            std::iter::zip(sum.iter_mut(), buffer.iter()).for_each(|(o, i)| {
61                let i: f64 = (*i).into();
62                *o += i;
63            });
64        });
65
66        std::iter::zip(buffer.iter_mut(), sum.iter()).for_each(|(o, i)| {
67            *o = (*i / data.nrows() as f64) as f32;
68        });
69
70        let mut min_dist: f32 = f32::MAX;
71        let mut medoid = None;
72        data.row_iter().for_each(|r| {
73            let d = SquaredL2::evaluate(buffer.as_slice(), r);
74            if d < min_dist {
75                min_dist = d;
76                medoid = Some(r);
77            }
78        });
79
80        medoid
81            .map(|x| x.into())
82            .unwrap_or(vec![f16::default(); data.ncols()])
83    }
84}
85
86impl ComputeMedoid for u8 {
87    fn compute_medoid(data: MatrixView<Self>) -> Vec<Self> {
88        if data.ncols() == 0 {
89            return vec![];
90        }
91
92        let mut sum = vec![0.0f64; data.ncols()];
93        data.row_iter().for_each(|r| {
94            std::iter::zip(sum.iter_mut(), r.iter()).for_each(|(o, i)| {
95                let i: f64 = (*i).into();
96                *o += i;
97            });
98        });
99
100        let m: Vec<f32> = sum
101            .iter()
102            .map(|s| (s / data.nrows() as f64) as f32)
103            .collect();
104
105        let mut min_dist: f32 = f32::MAX;
106        let mut medoid = None;
107        let mut as_float = vec![0.0f32; data.ncols()];
108        data.row_iter().for_each(|r| {
109            std::iter::zip(as_float.iter_mut(), r.iter())
110                .for_each(|(dst, src)| *dst = (*src).into());
111            let d = SquaredL2::evaluate(m.as_slice(), &*as_float);
112            if d < min_dist {
113                min_dist = d;
114                medoid = Some(r);
115            }
116        });
117
118        medoid.map(|x| x.into()).unwrap_or(vec![0u8; data.ncols()])
119    }
120}
121
122impl ComputeMedoid for i8 {
123    fn compute_medoid(data: MatrixView<Self>) -> Vec<Self> {
124        if data.ncols() == 0 {
125            return vec![];
126        }
127
128        let mut sum = vec![0.0f64; data.ncols()];
129        data.row_iter().for_each(|r| {
130            std::iter::zip(sum.iter_mut(), r.iter()).for_each(|(o, i)| {
131                let i: f64 = (*i).into();
132                *o += i;
133            });
134        });
135
136        let m: Vec<f32> = sum
137            .iter()
138            .map(|s| (s / data.nrows() as f64) as f32)
139            .collect();
140
141        let mut min_dist: f32 = f32::MAX;
142        let mut medoid = None;
143        let mut as_float = vec![0.0f32; data.ncols()];
144        data.row_iter().for_each(|r| {
145            std::iter::zip(as_float.iter_mut(), r.iter())
146                .for_each(|(dst, src)| *dst = (*src).into());
147            let d = SquaredL2::evaluate(m.as_slice(), &*as_float);
148            if d < min_dist {
149                min_dist = d;
150                medoid = Some(r);
151            }
152        });
153
154        medoid.map(|x| x.into()).unwrap_or(vec![0i8; data.ncols()])
155    }
156}
157
158///////////
159// Tests //
160///////////
161
162#[cfg(test)]
163mod tests {
164    use crate::views::{Init, Matrix};
165    use rand::{
166        distr::{Distribution, StandardUniform},
167        rngs::StdRng,
168        SeedableRng,
169    };
170
171    use super::*;
172
173    fn example_dataset() -> (Matrix<f32>, Vec<f32>) {
174        let data: Vec<f32> = vec![
175            // row 0
176            0.203688,
177            0.841956,
178            0.855665,
179            0.801917,
180            0.754536,
181            // row 1
182            0.312881,
183            0.217382,
184            0.0644115,
185            0.348708,
186            0.999495,
187            // row 2
188            0.657741,
189            0.914681,
190            0.555228,
191            0.13253,
192            0.118615,
193            // row 3
194            0.356464,
195            0.207449,
196            0.452471,
197            0.925219,
198            0.508498,
199            // row 4
200            0.749786,
201            0.90786,
202            0.129618,
203            0.597719,
204            0.000622153,
205            // row 5 -- this is the medoid
206            0.569517,
207            0.435447,
208            0.558136,
209            0.480974,
210            0.711425,
211            // row 6
212            0.896353,
213            0.275053,
214            0.0427179,
215            0.660916,
216            0.464851,
217            // row 7
218            0.558689,
219            0.596543,
220            0.740983,
221            0.122136,
222            0.453822,
223            // row 8
224            0.526895,
225            0.492643,
226            0.0951115,
227            0.495487,
228            0.446127,
229            // row 9
230            0.454093,
231            0.160239,
232            0.924585,
233            0.901708,
234            0.329328,
235        ];
236
237        let data = Matrix::<f32>::try_from(data.into(), 10, 5).unwrap();
238        let expected: Vec<f32> = data.row(5).into();
239        (data, expected)
240    }
241
242    #[test]
243    fn test_f32() {
244        // No Rows
245        let x = Matrix::<f32>::new(0.0f32, 0, 10);
246        assert_eq!(f32::compute_medoid(x.as_view()), vec![0.0; x.ncols()]);
247
248        // No Cols
249        let x = Matrix::<f32>::new(0.0f32, 10, 0);
250        assert_eq!(f32::compute_medoid(x.as_view()), Vec::<f32>::new());
251
252        let mut rng = StdRng::seed_from_u64(0xaf2f5fa0b5161acf);
253
254        // One row
255        let dist = StandardUniform;
256        for dim in 1..20 {
257            let x = Matrix::<f32>::new(Init(|| dist.sample(&mut rng)), 1, dim);
258            assert_eq!(&*f32::compute_medoid(x.as_view()), x.row(0));
259        }
260
261        // Example dataset
262        let (data, expected) = example_dataset();
263        let m = f32::compute_medoid(data.as_view());
264        assert_eq!(m, expected);
265    }
266
267    #[test]
268    fn test_f16() {
269        // No Rows
270        let x = Matrix::<f16>::new(f16::default(), 0, 10);
271        assert_eq!(
272            f16::compute_medoid(x.as_view()),
273            vec![f16::default(); x.ncols()]
274        );
275
276        // No Cols
277        let x = Matrix::<f16>::new(f16::default(), 10, 0);
278        assert_eq!(f16::compute_medoid(x.as_view()), Vec::<f16>::new());
279
280        let mut rng = StdRng::seed_from_u64(0x88e2f7096fc9b90e);
281
282        // One row
283        let dist = StandardUniform;
284        for dim in 1..20 {
285            let x = Matrix::<f16>::new(Init(|| f16::from_f32(dist.sample(&mut rng))), 1, dim);
286            assert_eq!(&*f16::compute_medoid(x.as_view()), x.row(0));
287        }
288
289        // Example dataset
290        let (data, expected) = example_dataset();
291        let mut data_f16 = Matrix::<f16>::new(f16::default(), data.nrows(), data.ncols());
292        data_f16.as_mut_slice().cast_from_slice(data.as_slice());
293
294        let mut expected_f16 = vec![f16::default(); expected.len()];
295        expected_f16.cast_from_slice(expected.as_slice());
296
297        let m = f16::compute_medoid(data_f16.as_view());
298        assert_eq!(m, expected_f16);
299    }
300
301    fn example_dataset_u8() -> (Matrix<u8>, Vec<u8>) {
302        let data: Vec<u8> = vec![
303            52, 215, 218, 204, 192, // row 0
304            79, 55, 16, 89, 255, // row 1
305            167, 233, 141, 33, 30, // row 2
306            91, 53, 115, 236, 130, // row 3
307            191, 232, 33, 152, 1, // row 4
308            145, 111, 142, 122, 181, // row 5 -- this is the medoid
309        ];
310
311        let data = Matrix::<u8>::try_from(data.into(), 6, 5).unwrap();
312        let expected: Vec<u8> = data.row(5).into();
313        (data, expected)
314    }
315
316    #[test]
317    fn test_u8() {
318        // No Rows
319        let x = Matrix::<u8>::new(0u8, 0, 10);
320        assert_eq!(u8::compute_medoid(x.as_view()), vec![0u8; x.ncols()]);
321
322        // No Cols
323        let x = Matrix::<u8>::new(0u8, 10, 0);
324        assert_eq!(u8::compute_medoid(x.as_view()), Vec::<u8>::new());
325        let mut rng = StdRng::seed_from_u64(0x8f2f5fa0b5161acf);
326
327        // One row
328        let dist = StandardUniform;
329        for dim in 1..20 {
330            let x = Matrix::<u8>::new(Init(|| dist.sample(&mut rng)), 1, dim);
331            assert_eq!(&*u8::compute_medoid(x.as_view()), x.row(0));
332        }
333
334        // Example dataset
335        let (data, expected) = example_dataset_u8();
336        let m = u8::compute_medoid(data.as_view());
337        assert_eq!(m, expected);
338    }
339
340    // This is a test for the i8 medoid function. Each entry is between -128 and 127.
341    fn example_dataset_i8() -> (Matrix<i8>, Vec<i8>) {
342        let data: Vec<i8> = vec![
343            -76, 87, 90, 76, 64, // row 0
344            -49, -73, -112, -39, 127, // row 1
345            39, 105, 13, -95, -98, // row 2
346            -37, -75, -13, 108, 2, // row 3
347            -37, -75, -13, 108, 2, // row 4
348            17, -17, 14, -6, 53, // row 5 -- this is the medoid
349        ];
350
351        let data = Matrix::<i8>::try_from(data.into(), 6, 5).unwrap();
352        let expected: Vec<i8> = data.row(5).into();
353        (data, expected)
354    }
355
356    #[test]
357    fn test_i8() {
358        // No Rows
359        let x = Matrix::<i8>::new(0i8, 0, 10);
360        assert_eq!(i8::compute_medoid(x.as_view()), vec![0i8; x.ncols()]);
361
362        // No Cols
363        let x = Matrix::<i8>::new(0i8, 10, 0);
364        assert_eq!(i8::compute_medoid(x.as_view()), Vec::<i8>::new());
365
366        let mut rng = StdRng::seed_from_u64(0x8f2f5fa0b5161acf);
367
368        // One row
369        let dist = StandardUniform;
370        for dim in 1..20 {
371            let x = Matrix::<i8>::new(Init(|| dist.sample(&mut rng)), 1, dim);
372            assert_eq!(&*i8::compute_medoid(x.as_view()), x.row(0));
373        }
374
375        // Example dataset
376        let (data, expected) = example_dataset_i8();
377        let m = i8::compute_medoid(data.as_view());
378        assert_eq!(m, expected);
379    }
380}