use std::marker::PhantomData;
use parking_lot::RwLock;
use std::cell::RefCell;
use std::sync::Arc;
use anyhow::anyhow;
use rand::distr::{Distribution, Uniform};
use rand_xoshiro::Xoshiro256PlusPlus;
use rand_xoshiro::rand_core::SeedableRng;
use anndists::dist::*;
use crate::facility::*;
/// Mutable state of the streaming facility algorithm driven by [`Bmor`].
///
/// The state holds the facilities opened in the current phase together with the bounds of
/// that phase. Each inserted point is either assigned to its nearest facility (adding
/// `|weight| * distance` to the phase cost) or opens a new facility at zero cost.
/// [`BmorState::update`] reports when the phase cost or the number of facilities exceeds its
/// bound; a new phase is then started with [`BmorState::reinit`].
#[derive(Clone)]
pub struct BmorState<DataId, T: Send + Sync + Clone, Dist: Distance<T>> {
    /// `k * (1 + floor(log2(nbdata)))`, multiplies the probability of opening a facility.
    oneplogn: usize,
    /// Index of the current phase, incremented by [`BmorState::reinit`].
    phase: usize,
    /// Cost scale dividing the facility-opening probability. It starts at 1 and is
    /// multiplied by `beta` at each new phase.
    li: f64,
    /// Upper bound on `total_cost` for the current phase.
    phase_cost_upper: f64,
    /// Maximum number of facilities allowed in the current phase.
    facility_bound: usize,
    /// Facilities opened in the current phase.
    centers: Facilities<DataId, T, Dist>,
    /// Accumulated absolute weight of the points inserted during the current phase.
    absolute_weight: f64,
    /// Sum of `|weight| * distance` over the points assigned to an existing facility
    /// during the current phase.
    total_cost: f64,
    /// Number of insertions since creation. It is not reset between phases.
    nb_inserted: usize,
    /// Random generator, seeded with a fixed value so that runs are reproducible.
    rng: Xoshiro256PlusPlus,
    /// Uniform distribution on `[0, 1)` used for the facility-opening test.
    unif: Uniform<f64>,
}

impl<
        DataId: std::fmt::Debug + Clone + Send + Sync,
        T: Send + Sync + Clone,
        Dist: Distance<T> + Clone + Sync + Send,
    > BmorState<DataId, T, Dist>
{
    /// Creates the state for a first (or given) phase.
    ///
    /// # Arguments
    /// * `k` - target number of centers.
    /// * `nbdata` - expected number of points. It enters `oneplogn` through `log2`; 0 is
    ///   treated as 1 so that `ilog2` does not panic.
    /// * `phase` - initial phase index.
    /// * `alloc_size` - forwarded to `Facilities::new` (presumably a capacity hint; confirm).
    /// * `upper_cost` - cost bound of the initial phase.
    /// * `facility_bound` - maximum number of facilities before a constraint violation.
    /// * `distance` - distance used by the facilities.
    pub(crate) fn new(
        k: usize,
        nbdata: usize,
        phase: usize,
        alloc_size: usize,
        upper_cost: f64,
        facility_bound: usize,
        distance: Dist,
    ) -> Self {
        let centers = Facilities::<DataId, T, Dist>::new(alloc_size, distance);
        let unif = Uniform::<f64>::new(0., 1.).unwrap();
        // Fixed seed: identical input streams give identical facilities.
        let rng = Xoshiro256PlusPlus::seed_from_u64(1454691);
        // `usize::ilog2` panics on 0, so an empty expected stream is handled like a single point.
        let oneplogn = (1 + nbdata.max(1).ilog2()) as usize * k;
        log::info!("BmorState creation : facility bound : {:?}", facility_bound);
        BmorState {
            oneplogn,
            phase,
            li: 1.0f64,
            phase_cost_upper: upper_cost,
            facility_bound,
            centers,
            absolute_weight: 0.,
            total_cost: 0.,
            nb_inserted: 0,
            rng,
            unif,
        }
    }

    /// Returns the facilities of the current phase.
    pub fn get_facilities(&self) -> &Facilities<DataId, T, Dist> {
        &self.centers
    }

    /// Returns a mutable reference to the facilities of the current phase.
    pub fn get_mut_facilities(&mut self) -> &mut Facilities<DataId, T, Dist> {
        &mut self.centers
    }

    /// Returns the current phase index.
    pub fn get_phase(&self) -> usize {
        self.phase
    }

    /// Returns the current cost scale used in the facility-opening probability.
    // Not called everywhere; kept for inspection/debugging.
    #[allow(unused)]
    pub(crate) fn get_li(&self) -> f64 {
        self.li
    }

    /// Returns the number of insertions since creation (all phases).
    pub(crate) fn get_nb_inserted(&self) -> usize {
        self.nb_inserted
    }

    /// Draws a sample from the uniform distribution on `[0, 1)`.
    pub(crate) fn get_unif_sample(&mut self) -> f64 {
        self.unif.sample(&mut self.rng)
    }

    /// Returns the cost bound of the current phase.
    pub(crate) fn get_phase_cost_bound(&self) -> f64 {
        self.phase_cost_upper
    }

    /// Returns the maximum number of facilities allowed in a phase.
    // Not called everywhere; kept for inspection/debugging.
    #[allow(unused)]
    pub(crate) fn get_facility_upper_bound(&self) -> usize {
        self.facility_bound
    }

    /// Returns the accumulated absolute weight of the current phase.
    pub(crate) fn get_weight(&self) -> f64 {
        self.absolute_weight
    }

    /// Returns the accumulated assignment cost of the current phase.
    pub(crate) fn get_cost(&self) -> f64 {
        self.total_cost
    }

    /// Finds the facility nearest to `point`.
    ///
    /// Returns `(facility, rank, distance)`, or `None` when no facility exists yet.
    // The full return type is spelled out on purpose; a type alias would hide the lock type.
    #[allow(clippy::type_complexity)]
    pub fn get_nearest_center(
        &self,
        point: &[T],
    ) -> Option<(&Arc<RwLock<Facility<DataId, T>>>, usize, f32)>
    where
        T: Send + Sync,
        Dist: Sync,
    {
        if self.centers.len() == 0 {
            return None;
        }
        let (rank, dist) = self
            .centers
            .get_nearest_facility(point, false)
            .expect("non-empty facilities must have a nearest facility");
        let facility = self
            .centers
            .get_facility(rank)
            .expect("rank returned by get_nearest_facility must be valid");
        Some((facility, rank, dist))
    }

    /// Inserts one weighted point into a state that already has at least one facility.
    ///
    /// With probability `min(1, weight * d * oneplogn / li)` (where `d` is the distance to
    /// the nearest facility), the point opens a new facility. Otherwise it is added to its
    /// nearest facility and `|weight| * d` is added to the phase cost.
    ///
    /// Returns `false` when, after the insertion, the phase cost exceeds `phase_cost_upper`
    /// or the number of facilities exceeds `facility_bound`; `true` otherwise.
    ///
    /// # Panics
    /// Panics if there is no facility at all. Callers must create the first facility
    /// themselves; that is an internal invariant violation.
    fn update(&mut self, rank_id: DataId, point: &[T], weight: f64) -> bool {
        log::trace!("in BmorState::update rank_id: {:?}", rank_id);
        // Clone the Arc so that the shared borrow of `self` taken by `get_nearest_center`
        // ends before the mutating calls below.
        let (nearest_facility, dist_to_nearest) = match self.get_nearest_center(point) {
            Some((facility, _rank, dist)) => (Arc::clone(facility), dist),
            None => {
                log::error!("internal error, update did not find nearest facility");
                panic!("internal error, update did not find nearest facility");
            }
        };
        // A uniform sample in [0, 1) below this value opens a new facility, so values >= 1
        // always open one.
        let open_threshold = weight * dist_to_nearest as f64 * self.oneplogn as f64 / self.li;
        if self.get_unif_sample() < open_threshold {
            let mut new_f = Facility::<DataId, T>::new(rank_id, point);
            new_f.insert(weight, 0.);
            self.centers.insert(new_f);
        } else {
            nearest_facility.write().insert(weight, dist_to_nearest);
            self.total_cost += weight.abs() * dist_to_nearest as f64;
        }
        self.absolute_weight += weight.abs();
        self.nb_inserted += 1;
        let violated =
            self.total_cost > self.phase_cost_upper || self.centers.len() > self.facility_bound;
        if violated && log::log_enabled!(log::Level::Debug) {
            log::debug!("constraint violation");
            self.log();
        }
        !violated
    }

    /// Starts a new phase.
    ///
    /// Increments the phase index, multiplies the cost bound and the cost scale by `beta`,
    /// and clears the facilities, weight and cost. `nb_inserted` is kept.
    pub(crate) fn reinit(&mut self, beta: f64) {
        self.phase += 1;
        self.phase_cost_upper *= beta;
        self.li *= beta;
        self.centers.clear();
        self.absolute_weight = 0.;
        self.total_cost = 0.;
    }

    /// Logs a summary of the state.
    ///
    /// The insertion/phase count is logged at info level; everything else is logged at
    /// debug level.
    pub(crate) fn log(&self) {
        log::debug!("\n\n BmorState::log_state");
        log::debug!("\n nb facilities : {:?}", self.centers.len());
        log::debug!(
            "\n weight : {:.3e} cost {:.3e}",
            self.get_weight(),
            self.get_cost()
        );
        log::debug!(
            "\n nb facility max : {:?}, upper cost bound : {:.3e}",
            self.facility_bound,
            self.get_phase_cost_bound()
        );
        log::info!(
            "\n nb total insertion : {:?} nb_phases: {:?}",
            self.get_nb_inserted(),
            self.phase + 1
        );
    }
}
/// Streaming summarizer that reduces a stream of (optionally weighted) points to a set of
/// weighted facilities.
///
/// Points are fed through [`Bmor::process_data`] or [`Bmor::process_weighted_data`].
/// Each point either joins its nearest facility or opens a new one at random. When the
/// current phase exceeds its cost or facility bound, the current facilities are re-inserted
/// as weighted points into a new phase whose bounds are multiplied by `beta`.
/// [`Bmor::end_data`] returns the final facilities, optionally after one contraction pass.
///
/// NOTE(review): the name suggests the streaming algorithm of Braverman, Meyerson,
/// Ostrovsky, Roytman et al.; confirm against the crate documentation.
#[cfg_attr(doc, katexit::katexit)]
pub struct Bmor<DataId, T: Send + Sync + Clone, Dist: Distance<T>> {
    /// Target number of centers, possibly reduced by [`Bmor::new`].
    k: usize,
    /// Expected number of points in the stream.
    nbdata_expected: usize,
    /// Growth factor applied to the phase bounds at each new phase.
    beta: f64,
    /// Initial phase cost bound; also scales the facility-count bound.
    gamma: f64,
    /// Distance used to compare points and facilities.
    distance: Dist,
    /// Algorithm state; the `RefCell` lets the `&self` methods update it.
    state: RefCell<BmorState<DataId, T, Dist>>,
    _t: PhantomData<T>,
}

impl<DataId, T: Send + Sync + Clone, Dist> Bmor<DataId, T, Dist>
where
    Dist: Distance<T> + Clone + Sync + Send,
    DataId: std::fmt::Debug + Clone + Send + Sync,
{
    /// Creates a new summarizer.
    ///
    /// # Arguments
    /// * `k_arg` - requested number of centers. If it exceeds `trunc(sqrt(nbdata_expected))`,
    ///   it is reduced to `min(k_arg, trunc(sqrt(1 + nbdata_expected)))`.
    /// * `nbdata_expected` - expected number of points. 0 is treated as 1 when computing
    ///   logarithms.
    /// * `beta` - factor applied to the phase cost bound at each new phase. With
    ///   `beta <= 1` the bounds do not grow after a violation (a warning is logged).
    /// * `gamma` - initial phase cost bound. The facility bound is
    ///   `trunc((gamma - 1) * (1 + floor(log2(nbdata_expected))) * k)`, so `gamma <= 1`
    ///   gives a bound of 0 (a warning is logged).
    /// * `distance` - distance used to compare points and facilities.
    pub fn new(
        k_arg: usize,
        nbdata_expected: usize,
        beta: f64,
        gamma: f64,
        distance: Dist,
    ) -> Self {
        let k = if k_arg > (nbdata_expected as f64).sqrt().trunc() as usize {
            let kmax = k_arg.min((1. + nbdata_expected as f64).sqrt() as usize);
            log::info!("resetting number of centers to : {}", kmax);
            kmax
        } else {
            k_arg
        };
        if beta <= 1. {
            log::warn!("Bmor::new beta = {:.3e} <= 1, phase bounds will not grow", beta);
        }
        if gamma <= 1. {
            log::warn!("Bmor::new gamma = {:.3e} <= 1, facility bound will be 0", gamma);
        }
        // `usize::ilog2` panics on 0: an empty expected stream is handled like a single point.
        let nbdata_log = nbdata_expected.max(1);
        let log_factor = 1. + nbdata_log.ilog2() as f64;
        let nb_centers_bound = ((gamma - 1.) * log_factor * k as f64).trunc() as usize;
        let state = BmorState::<DataId, T, Dist>::new(
            k,
            nbdata_log,
            0,
            nb_centers_bound,
            gamma,
            nb_centers_bound,
            distance.clone(),
        );
        Bmor {
            k,
            nbdata_expected,
            beta,
            gamma,
            distance,
            state: RefCell::new(state),
            _t: PhantomData::<T>,
        }
    }

    /// Returns the (possibly reduced) target number of centers.
    pub fn get_k(&self) -> usize {
        self.k
    }

    /// Returns the phase-bound growth factor.
    pub fn get_beta(&self) -> f64 {
        self.beta
    }

    /// Returns the initial phase cost bound.
    pub fn get_gamma(&self) -> f64 {
        self.gamma
    }

    /// Inserts unweighted points (each with weight 1), `data[i]` being identified by `id[i]`.
    ///
    /// Returns the current number of facilities.
    ///
    /// # Errors
    /// Returns an error if `data` and `id` have different lengths.
    pub fn process_data(&mut self, data: &[Vec<T>], id: &[DataId]) -> anyhow::Result<usize> {
        if data.len() != id.len() {
            return Err(anyhow!(
                "process_data: data and id lengths differ ({} vs {})",
                data.len(),
                id.len()
            ));
        }
        let weighted_data: Vec<(f64, &Vec<T>, DataId)> = data
            .iter()
            .zip(id)
            .map(|(point, dataid)| (1., point, dataid.clone()))
            .collect();
        self.process_weighted_data(&weighted_data)
    }

    /// Ends the stream and returns the facilities.
    ///
    /// With `contraction == false`, the current facilities are returned as is. Otherwise
    /// [`Bmor::bmor_contraction`] is run first and its facilities are returned.
    ///
    /// # Panics
    /// Panics if the contraction pass fails.
    pub fn end_data(&self, contraction: bool) -> Facilities<DataId, T, Dist> {
        if contraction {
            log::info!("\n\n bmor doing final bmor pass ...");
            let state_2 = self
                .bmor_contraction()
                .unwrap_or_else(|e| panic!("bmor_contraction failed: {:?}", e));
            state_2.log();
            state_2.get_facilities().clone()
        } else {
            let facilities = self.state.borrow().get_facilities().clone();
            facilities.log(0);
            facilities
        }
    }

    /// Inserts weighted points given as `(weight, point, id)` triples.
    ///
    /// Returns the current number of facilities.
    ///
    /// # Errors
    /// Currently never fails; the `Result` is kept for API stability.
    pub fn process_weighted_data(
        &self,
        weighted_data: &[(f64, &Vec<T>, DataId)],
    ) -> anyhow::Result<usize> {
        self.process_weighted_block(weighted_data);
        let state = self.state.borrow();
        state.log();
        if log::log_enabled!(log::Level::Debug) {
            state.get_facilities().log(1);
        }
        Ok(state.get_facilities().len())
    }

    /// Runs one more pass over the current facilities to reduce their number.
    ///
    /// The facilities are re-inserted as weighted points into a fresh [`Bmor`] sized for
    /// that many points. This happens only when the number of insertions exceeds
    /// `k * (1 + floor(log2(nb_facilities)))`; otherwise, or when there is no facility,
    /// a copy of the current state is returned.
    ///
    /// # Errors
    /// Propagates an error from the inner processing pass, with context added.
    pub(crate) fn bmor_contraction(&self) -> anyhow::Result<BmorState<DataId, T, Dist>> {
        log::info!("\n bmor recurring");
        let facility_data = self.state.borrow().get_facilities().into_weighted_data();
        log::info!(
            "bmor_recur , nb facilities received : {:?}",
            facility_data.len()
        );
        let nb_expected_data = facility_data.len();
        // `ilog2(0)` panics. With no facility there is nothing to contract.
        let needs_contraction = nb_expected_data > 0
            && self.state.borrow().get_nb_inserted()
                > self.k * (1 + nb_expected_data.ilog2() as usize);
        if !needs_contraction {
            let state = self.state.borrow();
            state.log();
            state.get_facilities().log(0);
            return Ok(state.clone());
        }
        log::debug!(
            "reducing number of facilities: setting expected nb data : {:?}",
            nb_expected_data
        );
        let weighted_data: Vec<(f64, &Vec<T>, DataId)> = facility_data
            .iter()
            .map(|fd| (fd.0, &fd.1, fd.2.clone()))
            .collect();
        let bmor_algo_2: Bmor<DataId, T, Dist> = Bmor::new(
            self.get_k(),
            nb_expected_data,
            self.get_beta(),
            self.get_gamma(),
            self.distance.clone(),
        );
        bmor_algo_2
            .process_weighted_data(&weighted_data)
            .map_err(|e| e.context("contraction failed"))?;
        let state_2 = bmor_algo_2.state.borrow();
        state_2.get_facilities().log(0);
        Ok(state_2.clone())
    }

    /// Inserts a block of weighted points.
    ///
    /// Whenever an insertion violates the phase bounds, the current facilities are
    /// snapshotted as weighted points, a new phase is started (bounds times `beta`), and the
    /// snapshot is re-inserted recursively before processing continues with the next point.
    fn process_weighted_block(&self, data: &[(f64, &Vec<T>, DataId)]) {
        log::debug!(
            "entering process_weighted_block, phase : {:?}, nb data : {}",
            self.state.borrow().get_phase(),
            data.len()
        );
        for &(weight, point, ref dataid) in data {
            log::trace!("treating rank_id : {:?}, weight : {:.4e}", dataid, weight);
            if self.add_data(dataid.clone(), point, weight) {
                continue;
            }
            log::debug!(
                "recycling facilities, incrementing upper bound for cost, nb_facilities : {:?}",
                self.state.borrow().get_facilities().len()
            );
            // Snapshot the facilities before `reinit` clears them; each lock is taken once.
            let recycled: Vec<(f64, Vec<T>, DataId)> = self
                .state
                .borrow()
                .centers
                .get_vec()
                .iter()
                .map(|f| {
                    let facility = f.read();
                    (
                        facility.get_weight(),
                        facility.get_position().clone(),
                        facility.get_dataid(),
                    )
                })
                .collect();
            assert!(
                !recycled.is_empty(),
                "constraint violation with no facility to recycle"
            );
            let recycled_refs: Vec<(f64, &Vec<T>, DataId)> = recycled
                .iter()
                .map(|(w, p, id)| (*w, p, id.clone()))
                .collect();
            self.state.borrow_mut().reinit(self.beta);
            self.process_weighted_block(&recycled_refs);
        }
    }

    /// Inserts one weighted point.
    ///
    /// If there is no facility yet, the point opens the first one unconditionally and
    /// `true` is returned. Otherwise the insertion is delegated to `BmorState::update`,
    /// whose result (`false` on a phase-bound violation) is returned.
    pub(crate) fn add_data(&self, rank_id: DataId, data: &[T], weight: f64) -> bool {
        let mut state = self.state.borrow_mut();
        if !state.get_facilities().is_empty() {
            return state.update(rank_id, data, weight);
        }
        log::debug!(
            "Bmor::add_data creating facility rank_id : {:?} with weight : {:.3e}",
            rank_id,
            weight
        );
        let mut new_f = Facility::<DataId, T>::new(rank_id, data);
        new_f.insert(weight, 0.);
        state.get_mut_facilities().insert(new_f);
        state.nb_inserted += 1;
        // Same accumulation as `BmorState::update`: the field tracks absolute weight.
        state.absolute_weight += weight.abs();
        true
    }

    /// Logs a summary of the current state.
    pub fn log(&self) {
        self.state.borrow().log();
    }
}