Skip to main content

faiss_next/
clustering.rs

1use std::ptr;
2
3use faiss_next_sys::{self, FaissClustering, FaissClusteringParameters};
4
5use crate::error::{check_return_code, Result};
6use crate::index::Index;
7
8pub struct Clustering {
9    inner: *mut FaissClustering,
10}
11
12impl Clustering {
13    pub fn new(d: u32, k: u32) -> Result<Self> {
14        unsafe {
15            let mut inner = ptr::null_mut();
16            check_return_code(faiss_next_sys::faiss_Clustering_new(
17                &mut inner, d as i32, k as i32,
18            ))?;
19            Ok(Self { inner })
20        }
21    }
22
23    pub fn new_with_params(d: u32, k: u32, params: &ClusteringParameters) -> Result<Self> {
24        unsafe {
25            let mut inner = ptr::null_mut();
26            check_return_code(faiss_next_sys::faiss_Clustering_new_with_params(
27                &mut inner,
28                d as i32,
29                k as i32,
30                params.inner,
31            ))?;
32            Ok(Self { inner })
33        }
34    }
35
36    pub fn train(&mut self, n: u64, x: &[f32], index: &mut impl Index) -> Result<()> {
37        check_return_code(unsafe {
38            faiss_next_sys::faiss_Clustering_train(
39                self.inner,
40                n as i64,
41                x.as_ptr(),
42                index.inner_ptr(),
43            )
44        })
45    }
46
47    pub fn niter(&self) -> i32 {
48        unsafe { faiss_next_sys::faiss_Clustering_niter(self.inner) }
49    }
50
51    pub fn k(&self) -> usize {
52        unsafe { faiss_next_sys::faiss_Clustering_k(self.inner) }
53    }
54
55    pub fn d(&self) -> usize {
56        unsafe { faiss_next_sys::faiss_Clustering_d(self.inner) }
57    }
58
59    pub fn centroids(&self) -> Vec<f32> {
60        unsafe {
61            let mut ptr = ptr::null_mut();
62            let mut size = 0usize;
63            faiss_next_sys::faiss_Clustering_centroids(self.inner, &mut ptr, &mut size);
64            if ptr.is_null() || size == 0 {
65                Vec::new()
66            } else {
67                std::slice::from_raw_parts(ptr, size).to_vec()
68            }
69        }
70    }
71
72    pub fn verbose(&self) -> bool {
73        unsafe { faiss_next_sys::faiss_Clustering_verbose(self.inner) != 0 }
74    }
75
76    pub fn seed(&self) -> i32 {
77        unsafe { faiss_next_sys::faiss_Clustering_seed(self.inner) }
78    }
79}
80
81impl Drop for Clustering {
82    fn drop(&mut self) {
83        if !self.inner.is_null() {
84            unsafe {
85                faiss_next_sys::faiss_Clustering_free(self.inner);
86            }
87        }
88    }
89}
90
91pub struct ClusteringParameters {
92    inner: *mut FaissClusteringParameters,
93}
94
95impl ClusteringParameters {
96    pub fn new() -> Result<Self> {
97        unsafe {
98            let mut inner = Box::new(std::mem::zeroed::<FaissClusteringParameters>());
99            faiss_next_sys::faiss_ClusteringParameters_init(inner.as_mut() as *mut _);
100            Ok(Self {
101                inner: Box::into_raw(inner),
102            })
103        }
104    }
105
106    pub fn niter(&mut self, niter: i32) -> &mut Self {
107        unsafe { (*self.inner).niter = niter }
108        self
109    }
110
111    pub fn verbose(&mut self, verbose: bool) -> &mut Self {
112        unsafe { (*self.inner).verbose = verbose as i32 }
113        self
114    }
115
116    pub fn spherical(&mut self, spherical: bool) -> &mut Self {
117        unsafe { (*self.inner).spherical = spherical as i32 }
118        self
119    }
120
121    pub fn min_points_per_centroid(&mut self, n: i32) -> &mut Self {
122        unsafe { (*self.inner).min_points_per_centroid = n }
123        self
124    }
125
126    pub fn max_points_per_centroid(&mut self, n: i32) -> &mut Self {
127        unsafe { (*self.inner).max_points_per_centroid = n }
128        self
129    }
130
131    pub fn seed(&mut self, seed: i32) -> &mut Self {
132        unsafe { (*self.inner).seed = seed }
133        self
134    }
135
136    pub fn nredo(&mut self, nredo: i32) -> &mut Self {
137        unsafe { (*self.inner).nredo = nredo }
138        self
139    }
140}
141
142impl Default for ClusteringParameters {
143    fn default() -> Self {
144        Self::new().expect("failed to create ClusteringParameters")
145    }
146}
147
148impl Drop for ClusteringParameters {
149    fn drop(&mut self) {
150        if !self.inner.is_null() {
151            unsafe {
152                let _ = Box::from_raw(self.inner);
153            }
154        }
155    }
156}