coreset/
bmor.rs

//! This module implements building blocks used as a black box in coreset constructions.
//! The algorithm computes an (alpha, beta) k-median approximation that is used as input
//! to coreset computations.
4//!
//! Adaptation of *Streaming k-means on well-clustered data*.  
//! Braverman, Meyerson, Ostrovsky, Roytman et al., ACM-SIAM 2011 [braverman-2](https://dl.acm.org/doi/10.5555/2133036.2133039)
7//!
8//! **This algorithm can run in a streaming context**.  
9//!
//! We do not constrain the clustering output to have exactly k clusters: the number of clusters is
//! whatever the main algorithm produces.   
12//!   
//! **The Bmor algorithm dispatches points on the fly, so the cost it computes is an upper bound**.  
//! **It is possible to [dispatch_data](crate::facility::Facilities::dispatch_data()) explicitly afterwards**
15//!
//! This algorithm can process the Fashion-MNIST data in 1 second on an i9 laptop (without requiring heavy multithreading).
17//!
18//!
19
20use std::marker::PhantomData;
21
22use parking_lot::RwLock;
23use std::cell::RefCell;
24use std::sync::Arc;
25
26use anyhow::anyhow;
27
28use rand::distr::{Distribution, Uniform};
29use rand_xoshiro::Xoshiro256PlusPlus;
30use rand_xoshiro::rand_core::SeedableRng;
31
32use anndists::dist::*;
33
34use crate::facility::*;
35
36/// This structure stores the state of Bmor algorithm through iterations.
37/// In particular it stores allocated facilities.
38#[derive(Clone)]
39pub struct BmorState<DataId, T: Send + Sync + Clone, Dist: Distance<T>> {
40    // (1+logn)k
41    oneplogn: usize,
42    // nb iterations (phases)
43    phase: usize,
44    // initial cost factor
45    li: f64,
46    // at each phase we have an upper bound for cost.
47    phase_cost_upper: f64,
48    // upper bound on number of facilities
49    facility_bound: usize,
50    // current centers, associated to rank in stream (or in data) and weight (nb points in facility)
51    centers: Facilities<DataId, T, Dist>,
52    // sum of absolute value (some algos have <0 weights) of inserted weight
53    absolute_weight: f64,
54    // total cost
55    total_cost: f64,
56    //
57    nb_inserted: usize,
58    //
59    rng: Xoshiro256PlusPlus,
60    //
61    unif: Uniform<f64>,
62} // end of
63
64impl<
65    DataId: std::fmt::Debug + Clone + Send + Sync,
66    T: Send + Sync + Clone,
67    Dist: Distance<T> + Clone + Sync + Send,
68> BmorState<DataId, T, Dist>
69{
70    pub(crate) fn new(
71        k: usize,
72        nbdata: usize,
73        phase: usize,
74        alloc_size: usize,
75        upper_cost: f64,
76        facility_bound: usize,
77        distance: Dist,
78    ) -> Self {
79        let centers = Facilities::<DataId, T, Dist>::new(alloc_size, distance);
80        let unif = Uniform::<f64>::new(0., 1.).unwrap();
81        let rng = Xoshiro256PlusPlus::seed_from_u64(1454691);
82        let oneplogn = (1 + nbdata.ilog2()) as usize * k;
83        let li = 1.0f64;
84        //
85        log::info!("BmorState creation : facility bound : {:?}", facility_bound);
86        //
87        BmorState {
88            oneplogn,
89            phase,
90            li,
91            phase_cost_upper: upper_cost,
92            facility_bound,
93            centers,
94            absolute_weight: 0.,
95            total_cost: 0.,
96            nb_inserted: 0,
97            rng,
98            unif,
99        }
100    }
101
102    /// returns facilities as computed by the algorithm
103    pub fn get_facilities(&self) -> &Facilities<DataId, T, Dist> {
104        &self.centers
105    }
106
107    /// returns a mutable reference to facilities (useful for calling [dispatch_labesl](crate::facility::Facilities::dispatch_data())).
108    pub fn get_mut_facilities(&mut self) -> &mut Facilities<DataId, T, Dist> {
109        &mut self.centers
110    }
111
112    // get current phase num of processing
113    pub fn get_phase(&self) -> usize {
114        self.phase
115    }
116
117    #[allow(unused)]
118    pub(crate) fn get_li(&self) -> f64 {
119        self.li
120    }
121
122    pub(crate) fn get_nb_inserted(&self) -> usize {
123        self.nb_inserted
124    }
125
126    pub(crate) fn get_unif_sample(&mut self) -> f64 {
127        self.unif.sample(&mut self.rng)
128    }
129
130    pub(crate) fn get_phase_cost_bound(&self) -> f64 {
131        self.phase_cost_upper
132    }
133
134    /// get upper bound  for number of facilities
135    #[allow(unused)]
136    pub(crate) fn get_facility_upper_bound(&self) -> usize {
137        self.facility_bound
138    }
139
140    /// get sum of absolute value of weights inserted
141    pub(crate) fn get_weight(&self) -> f64 {
142        self.absolute_weight
143    }
144
145    /// get sum of absolute value of weights inserted
146    pub(crate) fn get_cost(&self) -> f64 {
147        self.total_cost
148    }
149
150    /// get nearest center/facility of a point, its rank and distance to facility
151    #[allow(clippy::type_complexity)]
152    pub fn get_nearest_center(
153        &self,
154        point: &[T],
155    ) -> Option<(&Arc<RwLock<Facility<DataId, T>>>, usize, f32)>
156    where
157        T: Send + Sync,
158        Dist: Sync,
159    {
160        //
161        let nb_facility = self.centers.len();
162        //
163        if nb_facility == 0 {
164            return None;
165        }
166        // get nearest facilty
167        let (rank, dist) = self.centers.get_nearest_facility(point, false).unwrap();
168        //
169        Some((self.centers.get_facility(rank).unwrap(), rank, dist))
170    } // end of get_nearest_center
171
172    /// insert into an already existing facility
173    /// return true if all is OK, false if costs or number of facilities got too large
174    fn update(&mut self, rank_id: DataId, point: &[T], weight: f64) -> bool {
175        //
176        log::trace!("in BmorState::update rank_id: {:?}", rank_id);
177        //
178        let dist_to_nearest: f32;
179        let nearest_facility: Arc<RwLock<Facility<DataId, T>>>;
180        {
181            let nearest_facility_res = self.get_nearest_center(point);
182            if nearest_facility_res.is_none() {
183                log::error!("internal error, update did not find nearest facility");
184                std::process::exit(1);
185            }
186            let nearest_center = nearest_facility_res.unwrap();
187            dist_to_nearest = nearest_center.2;
188            nearest_facility = nearest_center.0.clone();
189        }
190        // take into account f factor
191        if self.get_unif_sample()
192            < (weight * dist_to_nearest as f64 * self.oneplogn as f64 / self.li)
193        {
194            // we create a new facility. No cost increment
195            let mut new_f = Facility::<DataId, T>::new(rank_id, point);
196            new_f.insert(weight, 0.);
197            self.centers.insert(new_f);
198            // log::debug!("in BmorState::update  creating new facility around {}, nb_facilities : {}", rank_id, self.centers.len());
199        } else {
200            // log::debug!("in BmorState::update rank_id: {:?}, inserting in old facility dist : {:.3e}", rank_id, dist_to_nearest);
201            nearest_facility.write().insert(weight, dist_to_nearest);
202            self.total_cost += weight.abs() * dist_to_nearest as f64;
203        }
204        // we increments weight monitoring and number of insertions
205        self.absolute_weight += weight.abs();
206        self.nb_inserted += 1;
207        // check if we are above constraints
208        if self.total_cost > self.phase_cost_upper || self.centers.len() > self.facility_bound {
209            if log::log_enabled!(log::Level::Debug) {
210                log::debug!("constraint violation");
211                self.log();
212            }
213            false
214        } else {
215            true
216        }
217    } // end of update
218
219    // reinitialization. (upper cost rescaling)
220    pub(crate) fn reinit(&mut self, beta: f64) {
221        self.phase += 1;
222        self.phase_cost_upper *= beta;
223        self.li *= beta;
224        self.centers.clear();
225        self.absolute_weight = 0.;
226        self.total_cost = 0.;
227    }
228
229    pub(crate) fn log(&self) {
230        log::debug!("\n\n BmorState::log_state");
231        log::debug!("\n nb facilities : {:?}", self.centers.len());
232        log::debug!(
233            "\n weight : {:.3e}   cost {:.3e}",
234            self.get_weight(),
235            self.get_cost()
236        );
237        log::debug!(
238            "\n nb facility max : {:?}, upper cost bound : {:.3e}",
239            self.facility_bound,
240            self.get_phase_cost_bound()
241        );
242        log::info!(
243            "\n nb total insertion : {:?}  nb_phases: {:?}",
244            self.get_nb_inserted(),
245            self.phase + 1
246        );
247    }
248} // end of impl block BmorState
249
250#[cfg_attr(doc, katexit::katexit)]
251/// This structure gathers all parameters defining Bmor algorithm.  
252/// The algorithm do iterations with at each step an acceptable upper bound cost and upper bound on number
253/// facilities. The upper bounds are increased if iteration constraints are not satisfied, exisiting facilities are recycled as
254/// old points and the algortitm can go on with new incoming points in a streaming way.
255///
256/// These upper bounds are defined using 2 parameters : $ \beta $ and $ \gamma $.  
257///
258/// let $k$ be the number of expected facilities (centers),  the upper bound on number facilities is
259/// defined by : $ (\gamma −1) \space k \space (1+ \log_2 n)$.  
260/// At each iteration $i$ the upper bound of cost $C_{i}$ is defined  by $ \beta * C_{i-1} $ and the allocation of a facility
261/// is relaxed in a coherent way.  
262/// As for large n the resulting number of allocated facilities can be larger than k it is possible to ask for an end step [end_step](Self::end_data()) that
263/// will reduce the number of facilities to less than $ (\gamma −1) \space k \space (1+ \log_2 nbfacility)$
264///
265/// The data are affected to a facility on the fly (useful in streaming).
266/// But it is possible for a point to be nearer to a facility opened later with data arriving after it.  
267/// So the dispatching cost can be optimized a posteriori (in a second pass on the data) with method [dispatch_data](crate::facility::Facilities::dispatch_data())
268///   
269///
270/// $\beta$ and $\gamma$ can be initialized by 2.
271pub struct Bmor<DataId, T: Send + Sync + Clone, Dist: Distance<T>> {
272    // base number of centers expected
273    k: usize,
274    //
275    nbdata_expected: usize,
276    // cost multiplicative factor for upper bound of accepted cost at each phase.
277    beta: f64,
278    //  slackness parameters for cost and number of centers accepted
279    gamma: f64,
280    //
281    distance: Dist,
282    // store computation state
283    state: RefCell<BmorState<DataId, T, Dist>>,
284    //
285    _t: PhantomData<T>,
286} // end of struct Bmor
287
288impl<DataId, T: Send + Sync + Clone, Dist> Bmor<DataId, T, Dist>
289where
290    Dist: Distance<T> + Clone + Sync + Send,
291    DataId: std::fmt::Debug + Clone + Send + Sync,
292{
293    #[allow(clippy::doc_lazy_continuation)]
294    /// Args are:  
295    /// - k: number of centers.  
296    /// - nbdata : nb data expected.     
297    ///  As this algorithm can be used in streaming (successive calls to methods [process_data](Self::process_data()) or [process_weighted_data](Self::process_weighted_data()) the exact number of data can be larger than the length or arguments passed to these methods.
298    ///
299    /// - beta : upper cost multiplicative factor.  
300    /// - gamma : slackness factor for number facilities upper bound.  
301    /// - end_step : if true a second step is done to further reduce the number of facilities.
302    ///         
303    pub fn new(
304        k_arg: usize,
305        nbdata_expected: usize,
306        beta: f64,
307        gamma: f64,
308        distance: Dist,
309    ) -> Self {
310        // We restrict k to be adjusted to nbdata_expected to avoid k too large compared to nb_data !
311        let k = if k_arg > (nbdata_expected as f64).sqrt().trunc() as usize {
312            let kmax = k_arg.min((1. + nbdata_expected as f64).sqrt() as usize);
313            log::info!("resetting number of centers to : {}", kmax);
314            kmax
315        } else {
316            k_arg
317        };
318        // This is orginal formula of the paper
319        let nb_centers_bound =
320            ((gamma - 1.) * (1. + nbdata_expected.ilog2() as f64) * k as f64).trunc() as usize;
321        let upper_cost = gamma;
322        let state = BmorState::<DataId, T, Dist>::new(
323            k,
324            nbdata_expected,
325            0,
326            nb_centers_bound as usize,
327            upper_cost,
328            nb_centers_bound,
329            distance.clone(),
330        );
331        //
332        Bmor {
333            k,
334            nbdata_expected,
335            beta,
336            gamma,
337            distance,
338            state: RefCell::new(state),
339            _t: PhantomData::<T>,
340        }
341    }
342
343    /// return expected number of facilities (clusters)
344    pub fn get_k(&self) -> usize {
345        self.k
346    }
347
348    /// get_beta
349    pub fn get_beta(&self) -> f64 {
350        self.beta
351    }
352
353    /// get gamma
354    pub fn get_gamma(&self) -> f64 {
355        self.gamma
356    }
357
358    /// treat unweighted data.
359    /// **This method can be called many times in case of data streaming, passing data by blocks**.  
360    /// It returns the number of facilities created up to this call.
361    /// id are data id (anything identifying data point)
362    pub fn process_data(&mut self, data: &[Vec<T>], id: &[DataId]) -> anyhow::Result<usize> {
363        //
364        let weighted_data: Vec<(f64, &Vec<T>, DataId)> = (0..data.len())
365            .map(|i| (1., &data[i], id[i].clone()))
366            .collect();
367        self.process_weighted_block(&weighted_data);
368        //
369        let state = self.state.borrow();
370        state.log();
371        if log::log_enabled!(log::Level::Debug) {
372            state.get_facilities().log(1);
373        }
374        //
375        Ok(state.get_facilities().len())
376    } // end of process_data
377
378    //
379    #[allow(clippy::let_and_return)]
380    /// declare end of streaming data.
381    /// This method returns the facilities created.
382    /// if contraction flag is set to true, a final pass of the bmor algorithm will be used to try to reduce the
383    /// number of facilities created by previous call to [process_data](Self::process_data()) or [process_weighted_data](Self::process_weighted_data())
384    pub fn end_data(&self, contraction: bool) -> Facilities<DataId, T, Dist> {
385        let facilities = match contraction {
386            false => {
387                let facilities_ret = self.state.borrow().get_facilities().clone();
388                facilities_ret.log(0);
389                facilities_ret
390            }
391            true => {
392                log::info!("\n\n bmor doing final bmor pass ...");
393                // note that state_2 is not saved anywhere, but this last step is easy to do by hans as the caller has the facilties.
394                let res = self.bmor_contraction();
395                if res.is_err() {
396                    std::panic!("bmor_contraction failed");
397                }
398                //
399                let state_2 = res.unwrap();
400                state_2.log();
401                //
402                let facilities = state_2.get_facilities();
403                facilities.clone()
404            }
405        };
406        facilities
407    } // end of end_data
408
409    /// treat data with weights attached.
410    /// **This method can be called many times in case of data streaming, passing data by blocks**.  
411    /// It returns the number of facilities created up to this call.
412    /// a data trplet consists in a weight , data vector and data id  (anything identifying data point)
413    pub fn process_weighted_data(
414        &self,
415        weighted_data: &[(f64, &Vec<T>, DataId)],
416    ) -> anyhow::Result<usize> {
417        //
418        self.process_weighted_block(weighted_data);
419        //
420        let state = self.state.borrow();
421        //
422        state.log();
423        if log::log_enabled!(log::Level::Debug) {
424            state.get_facilities().log(1);
425        }
426        //
427        Ok(state.get_facilities().len())
428    } // end of process_weighted_data
429
430    // We recur (once) to reduce number of facilities. To go from $1 + k * logn$ to $1 + k * log(log(n))$
431    // (We tried to reduce with imp algo but not better)
432    pub(crate) fn bmor_contraction(&self) -> anyhow::Result<BmorState<DataId, T, Dist>> {
433        //
434        log::info!("\n bmor recurring");
435        // extract weighted data
436        let facility_data = self.state.borrow().get_facilities().into_weighted_data();
437        //
438        // allocate another Bmor state. TODO: change some parameters gamma ?
439        //
440        log::info!(
441            "bmor_recur , nb facilities received : {:?}",
442            facility_data.len()
443        );
444        //
445        let weighted_data: Vec<(f64, &Vec<T>, DataId)> = (0..facility_data.len())
446            .map(|i| {
447                (
448                    facility_data[i].0,
449                    &facility_data[i].1,
450                    facility_data[i].2.clone(),
451                )
452            })
453            .collect();
454        let _bound_2 = self.nbdata_expected.ilog2() as usize;
455        // we could to adapt to number of facilities and we could impose a log reduction in input size for each step.
456        // by bounding nb_expected_data with min(_bound_2)
457        // let nb_expected_data = weighted_data.len().min(_bound_2);
458        let nb_expected_data = weighted_data.len();
459        if self.state.borrow().get_nb_inserted() > self.k * (1 + nb_expected_data.ilog2() as usize)
460        {
461            log::debug!(
462                "reducing number of facilities: setting expected nb data : {:?}",
463                nb_expected_data
464            );
465            let bmor_algo_2: Bmor<DataId, T, Dist> = Bmor::new(
466                self.get_k(),
467                nb_expected_data,
468                self.get_beta(),
469                self.get_gamma(),
470                self.distance.clone(),
471            );
472            //
473            let res = bmor_algo_2.process_weighted_data(&weighted_data);
474            if res.is_err() {
475                return Err(anyhow!("constraction failed"));
476            }
477            let state_2 = bmor_algo_2.state.borrow();
478            state_2.get_facilities().log(0);
479            Ok(state_2.clone())
480        } else {
481            let state = self.state.borrow();
482            state.log();
483            state.get_facilities().log(0);
484            Ok(state.clone())
485        }
486    } // end of bmor_recur
487
488    // This method is the real working method.
489    // It inserts data, update state, and drive recurrence
490    // args is a vecotr of triplets (weight, data, data_id)
491    fn process_weighted_block(&self, data: &[(f64, &Vec<T>, DataId)]) {
492        //
493        log::debug!(
494            "entering process_weighted_block, phase : {:?}, nb data : {}",
495            self.state.borrow().get_phase(),
496            data.len()
497        );
498        //
499        for d in data {
500            // TODO: now we use rank as rank_id (sufficicent for ordered ids)
501            log::trace!("treating rank_id : {:?}, weight : {:.4e}", d.2, d.0);
502            let add_res = self.add_data(d.2.clone(), d.1, d.0);
503            if !add_res {
504                // allocate new state
505                log::debug!(
506                    "recycling facilities, incrementing upper bound for cost, nb_facilities : {:?}",
507                    self.state.borrow().get_facilities().len()
508                );
509                // recycle facilitites in process adding them
510                let weighted_data: Vec<(f64, Vec<T>, DataId)> = self
511                    .state
512                    .borrow()
513                    .centers
514                    .get_vec()
515                    .iter()
516                    .map(|f| {
517                        (
518                            f.read().get_weight(),
519                            f.read().get_position().clone(),
520                            f.read().get_dataid(),
521                        )
522                    })
523                    .collect();
524                assert!(!weighted_data.is_empty());
525                let weighted_ref_data: Vec<(f64, &Vec<T>, DataId)> = weighted_data
526                    .iter()
527                    .map(|wd| (wd.0, &wd.1, wd.2.clone()))
528                    .collect();
529                assert!(!weighted_ref_data.is_empty());
530                self.state.borrow_mut().reinit(self.beta);
531                self.process_weighted_block(&weighted_ref_data);
532            }
533        }
534    } // end of process_weighted_block
535
536    // This function return true except if we got beyond bound for cost or number of facilities
537    // The data added can be a facility extracted during a preceding phase
538    pub(crate) fn add_data(&self, rank_id: DataId, data: &[T], weight: f64) -> bool {
539        //
540        let mut state = self.state.borrow_mut();
541        let facilities = state.get_mut_facilities();
542        // get nearest facility or open facility
543        if facilities.is_empty() {
544            log::debug!(
545                "Bmor::add_data creating facility rank_id : {:?} with weight : {:.3e}",
546                rank_id,
547                weight
548            );
549            let mut new_f = Facility::<DataId, T>::new(rank_id, data);
550            new_f.insert(weight, 0.);
551            facilities.insert(new_f);
552            // we update global state here in facility creation case
553            state.nb_inserted += 1;
554            state.absolute_weight += weight;
555            return true;
556        }
557        // we already have a facility we update state
558        state.update(rank_id, data, weight)
559    } // end of add_data
560
561    pub fn log(&self) {
562        self.state.borrow().log();
563    }
564} // end of impl block Bmor