use anyhow::*;
use serde::{Deserialize, Serialize};
use ndarray::Array2;
use quantiles::ckms::CKMS;
use rayon::prelude::*;
use parking_lot::RwLock;
use std::sync::Arc;
use std::collections::HashMap;
use anndists::dist::*;
/// A facility: a center position together with the accumulated weight and cost of the
/// points that have been inserted into it since the last reset.
#[derive(Clone, Serialize, Deserialize)]
pub struct Facility<DataId, T: Send + Sync + Clone> {
    // identifier attached to the center (given to `new`, returned by `get_dataid`)
    d_rank: DataId,
    // position of the facility center (returned by `get_position`)
    center: Vec<T>,
    // sum of the weights of the points inserted since the last `empty`
    weight: f64,
    // sum over inserted points of `dist * weight`
    cost: f64,
}
impl<DataId: std::fmt::Debug + Clone, T: Send + Sync + Clone> Facility<DataId, T> {
    /// Builds a facility centered at a copy of `center`, tagged with `d_rank`,
    /// with zero accumulated weight and zero cost.
    pub fn new(d_rank: DataId, center: &[T]) -> Self {
        Self {
            d_rank,
            center: Vec::from(center),
            weight: 0.,
            cost: 0.,
        }
    }

    /// Borrow of the facility center.
    pub fn get_position(&self) -> &Vec<T> {
        let Self { center, .. } = self;
        center
    }

    /// Returns a clone of the identifier attached to this facility.
    pub fn get_dataid(&self) -> DataId {
        DataId::clone(&self.d_rank)
    }

    /// Total weight of the points inserted since the last reset.
    pub fn get_weight(&self) -> f64 {
        let Self { weight, .. } = *self;
        weight
    }

    /// Total cost: the sum over inserted points of `distance * weight`.
    #[cfg_attr(doc, katexit::katexit)]
    pub fn get_cost(&self) -> f64 {
        let Self { cost, .. } = *self;
        cost
    }

    /// Accounts for one point of weight `weight` located at distance `dist` from the center.
    pub(crate) fn insert(&mut self, weight: f64, dist: f32) {
        let added_cost = f64::from(dist) * weight;
        self.weight += weight;
        self.cost += added_cost;
    }

    /// Resets accumulated weight and cost to zero; the center and identifier are kept.
    pub(crate) fn empty(&mut self) {
        self.cost = 0.;
        self.weight = 0.;
    }

    /// Logs identifier, weight, cost and cost per unit weight at info level.
    /// The cost/weight ratio is not finite when the weight is zero.
    pub fn log(&self) {
        let cost_per_weight = self.cost / self.weight;
        log::info!(
            "facility , dataid : {:?} weight : {:.4e}, cost : {:.3e} cost/weight : {:.3e}",
            self.d_rank,
            self.weight,
            self.cost,
            cost_per_weight
        );
    }
}
/// Alias for a facility rank; presumably the index accepted by `Facilities::get_facility`
/// (not used in the visible code — confirm against callers).
pub type FacilityId = usize;
/// Association of a point with a facility: the facility rank, the point's distance to it,
/// and the point's weight (presumably produced by a nearest-facility dispatch; verify at caller).
#[derive(Copy, Clone, Debug)]
pub struct PointMap {
    // rank of the facility the point is mapped to
    facility: usize,
    // distance from the point to that facility
    dist_to_f: f32,
    // weight carried by the point
    weight: f32,
}
impl PointMap {
pub fn new(facility: usize, dist_to_f: f32, weight: f32) -> Self {
PointMap {
facility,
dist_to_f,
weight,
}
}
pub fn get_facility(&self) -> usize {
self.facility
}
pub fn get_dist(&self) -> f32 {
self.dist_to_f
}
pub fn get_weight(&self) -> f32 {
self.weight
}
}
#[cfg_attr(doc, katexit::katexit)]
/// A collection of facilities sharing one distance, each facility held behind an
/// `Arc<RwLock<..>>` so points can be dispatched to them from parallel (rayon) workers.
#[derive(Clone)]
pub struct Facilities<DataId, T: Send + Sync + Clone, Dist: Distance<T>> {
    // the facilities, indexed by rank
    centers: Vec<Arc<RwLock<Facility<DataId, T>>>>,
    // distance used to compare points with facility centers
    distance: Dist,
    // aggregate weight, set by `compute_weight_cost` and reset to 0 by `clear`/`empty`
    weight: f64,
    // aggregate cost, set by `compute_weight_cost` and reset to 0 by `clear`/`empty`
    cost: f64,
}
impl<
        DataId: std::fmt::Debug + Clone + Send + Sync,
        T: Send + Sync + Clone,
        Dist: Distance<T> + Send + Sync,
    > Facilities<DataId, T, Dist>
{
    /// Creates an empty set of facilities with room for `size` centers,
    /// measuring point/center proximity with `distance`.
    pub fn new(size: usize, distance: Dist) -> Self {
        Facilities {
            centers: Vec::with_capacity(size),
            distance,
            weight: 0.,
            cost: 0.,
        }
    }

    /// Number of facilities.
    pub fn len(&self) -> usize {
        self.centers.len()
    }

    /// True when there is no facility.
    pub fn is_empty(&self) -> bool {
        self.centers.is_empty()
    }

    /// Sum of the facility weights, recomputed from the facilities on each call.
    pub fn get_weight(&self) -> f64 {
        self.centers.iter().map(|f| f.read().get_weight()).sum()
    }

    /// Distance used to compare points with facility centers.
    pub fn get_distance(&self) -> &Dist {
        &self.distance
    }

    /// Sum of the facility costs, recomputed from the facilities on each call.
    pub fn get_cost(&self) -> f64 {
        self.centers.iter().map(|f| f.read().get_cost()).sum()
    }

    /// Removes every facility and resets the aggregate weight/cost.
    pub(crate) fn clear(&mut self) {
        log::debug!("clearing facilities");
        self.centers.clear();
        self.weight = 0.;
        self.cost = 0.;
    }

    /// Keeps the facilities but zeroes their weight/cost and the aggregate weight/cost.
    pub(crate) fn empty(&mut self) {
        log::debug!("emptying facilities");
        for f in &self.centers {
            f.write().empty();
        }
        self.weight = 0.;
        self.cost = 0.;
    }

    /// Borrow of the underlying facility vector.
    pub(crate) fn get_vec(&self) -> &Vec<Arc<RwLock<Facility<DataId, T>>>> {
        &self.centers
    }

    /// True if some facility center lies within `dmax` of `point` according to `distance`
    /// (note: the `distance` argument is used, not the stored one).
    pub fn match_point(&self, point: &[T], dmax: f32, distance: &Dist) -> bool {
        self.centers
            .iter()
            .any(|f| distance.eval(f.read().get_position(), point) <= dmax)
    }

    /// Appends a facility; its rank is the previous `len()`.
    pub(crate) fn insert(&mut self, facility: Facility<DataId, T>) {
        self.centers.push(Arc::new(RwLock::new(facility)));
        log::trace!(
            "Facilities: facility insertion nb facilities : {}",
            self.centers.len()
        );
    }

    /// Facility of rank `rank`, or `None` if out of range.
    pub fn get_facility(&self, rank: usize) -> Option<&Arc<RwLock<Facility<DataId, T>>>> {
        self.centers.get(rank)
    }

    /// Snapshot copy of the facility of rank `rank`, or `None` if out of range.
    pub fn get_cloned_facility(&self, rank: usize) -> Option<Facility<DataId, T>> {
        self.centers.get(rank).map(|f| f.read().clone())
    }

    /// Weight of the facility of rank `rank`.
    ///
    /// # Errors
    /// Returns an error when `rank >= len()` (previously `rank == len()` panicked).
    pub fn get_facility_weight(&self, rank: usize) -> Result<f64> {
        self.centers
            .get(rank)
            .map(|f| f.read().get_weight())
            .ok_or_else(|| anyhow!("not so many facilities , rank is {}", rank))
    }

    /// Rank of the facility nearest to `data` and the distance to it.
    /// Distances are computed with rayon when `parallel` is true. Ties keep the lowest rank;
    /// NaN distances are never selected.
    ///
    /// # Errors
    /// Returns an error when there is no facility, or when every distance is +inf or NaN.
    pub fn get_nearest_facility(&self, data: &[T], parallel: bool) -> anyhow::Result<(usize, f32)> {
        if self.centers.is_empty() {
            return Err(anyhow!("Empty facility"));
        }
        let dist_to_f = |i: usize| -> (usize, f32) {
            let f_i = self.centers[i].read();
            (i, self.distance.eval(f_i.get_position(), data))
        };
        let dist_slot: Vec<(usize, f32)> = if parallel {
            (0..self.centers.len())
                .into_par_iter()
                .map(dist_to_f)
                .collect()
        } else {
            (0..self.centers.len()).map(dist_to_f).collect()
        };
        let mut nearest: Option<(usize, f32)> = None;
        for (rank, d) in dist_slot {
            // strict `<` keeps the first minimum and skips NaN
            if d < nearest.map_or(f32::INFINITY, |(_, best)| best) {
                nearest = Some((rank, d));
            }
        }
        nearest.ok_or_else(|| {
            anyhow!("no facility at finite distance from data point (all distances infinite or NaN)")
        })
    }

    /// Adds a point of weight `weight` at distance `dist` to facility `facility`
    /// (takes that facility's write lock). Panics if `facility` is out of range.
    pub(crate) fn insert_point(&self, facility: usize, dist: f32, weight: f32) {
        self.centers[facility].write().insert(weight as f64, dist);
    }

    /// Recomputes total weight and cost over all facilities, stores them as the aggregate
    /// values and returns `(weight, cost)`.
    ///
    /// Always recomputes: the previous "only if cached weight <= 0" shortcut returned stale
    /// totals after points were dispatched without resetting the cache.
    pub fn compute_weight_cost(&mut self) -> (f64, f64) {
        let (total_weight, total_cost) =
            self.centers.iter().fold((0_f64, 0_f64), |(w, c), f| {
                let f_access = f.read();
                (w + f_access.get_weight(), c + f_access.get_cost())
            });
        self.weight = total_weight;
        self.cost = total_cost;
        (self.weight, self.cost)
    }

    /// Logs the number of facilities with their total weight and cost at info level;
    /// with `level == 1` each facility is logged as well.
    pub fn log(&self, level: usize) {
        let mut total_weight = 0.;
        let mut total_cost = 0.;
        for f in &self.centers {
            let f_access = f.read();
            if level == 1 {
                f_access.log();
            }
            total_cost += f_access.get_cost();
            total_weight += f_access.get_weight();
        }
        log::info!(
            "\n\n nb facilities : {} sum of facilities weight : {:.3e}, total cost : {:.3e}",
            self.centers.len(),
            total_weight,
            total_cost
        );
    }

    /// Resets the facilities, then assigns every `data` point (in parallel) to its nearest
    /// facility, with weight `weights[i]` or 1 when `weights` is `None`.
    /// Prints and returns the resulting global cost. `_ids` is unused.
    ///
    /// # Panics
    /// If `weights` length differs from `data` length, or a point has no nearest facility
    /// (no facility, or all distances infinite/NaN).
    pub fn dispatch_data(
        &mut self,
        data: &[&Vec<T>],
        _ids: &[usize],
        weights: Option<&Vec<f32>>,
    ) -> f64 {
        log::info!("in facilities::dispatch_data");
        if let Some(w_values) = weights {
            assert_eq!(data.len(), w_values.len());
        }
        self.empty();
        let dispatch_i = |item: usize| {
            let (facility, dist) = self
                .get_nearest_facility(data[item], false)
                .expect("dispatch_data: no nearest facility for data point");
            let weight = weights.map_or(1., |w_values| w_values[item]);
            self.insert_point(facility, dist, weight);
        };
        (0..data.len()).into_par_iter().for_each(dispatch_i);
        let (total_weight, global_cost) =
            self.centers.iter().fold((0_f64, 0_f64), |(w, c), f| {
                let f_access = f.read();
                (w + f_access.weight, c + f_access.cost)
            });
        println!(
            "\n\n total weight collected in facilities : {:.3e}, total cost : {:.3e}",
            total_weight, global_cost
        );
        println!("\n **************************************************************************");
        global_cost
    }

    /// Resets the facilities, then dispatches every point (in parallel) to its nearest
    /// facility, with weight `weights[i]` or 1 when `weights` is `None`. Meanwhile it counts,
    /// per facility, how many points of each label it received.
    ///
    /// Returns, per facility, the Shannon entropy (natural log) of its label counts and the
    /// label counts themselves. Prints the total weight/cost and the weight-averaged entropy.
    ///
    /// # Panics
    /// If `labels` or `weights` length differs from `data` length, if a point has no nearest
    /// facility, or if an entropy is significantly negative.
    pub fn dispatch_labels<L: PartialEq + Eq + Copy + std::hash::Hash + Sync + Send>(
        &mut self,
        data: &[Vec<T>],
        labels: &[L],
        weights: Option<&Vec<f32>>,
    ) -> (Vec<f64>, Vec<HashMap<L, u32>>) {
        log::info!("dispatch_labels");
        assert_eq!(data.len(), labels.len());
        if let Some(w_values) = weights {
            assert_eq!(data.len(), w_values.len());
        }
        let nb_facility = self.centers.len();
        // resets per-facility weight/cost and the aggregate cache, as dispatch_data does
        self.empty();
        let map_capacity = data.len().checked_div(2 * nb_facility).unwrap_or(0);
        let label_distribution: Vec<RwLock<HashMap<L, u32>>> = (0..nb_facility)
            .map(|_| RwLock::new(HashMap::with_capacity(map_capacity)))
            .collect();
        let dispatch_i = |i: usize| {
            let (itemf, dist) = self
                .get_nearest_facility(&data[i], false)
                .expect("dispatch_labels: no nearest facility for data point");
            // weights are per data point: index by `i`, not by facility rank
            let weight = weights.map_or(1., |w_values| f64::from(w_values[i]));
            self.centers[itemf].write().insert(weight, dist);
            *label_distribution[itemf]
                .write()
                .entry(labels[i])
                .or_insert(0) += 1;
        };
        log::info!("computing global cost and weights");
        (0..data.len()).into_par_iter().for_each(dispatch_i);
        let (total_weight, global_cost) =
            self.centers.iter().fold((0_f64, 0_f64), |(w, c), f| {
                let f_access = f.read();
                (w + f_access.weight, c + f_access.cost)
            });
        println!(
            "\n\n total weight collected in facilities : {:.3e}, total cost : {:.3e}",
            total_weight, global_cost
        );
        println!("\n **************************************************************************");
        log::info!("computing label distribution entropy");
        let mut entropies = Vec::<f64>::with_capacity(nb_facility);
        for (i, ld) in label_distribution.iter().enumerate() {
            let distribution = ld.read();
            let mut mass = 0_f64;
            let mut entropy = 0_f64;
            for &count in distribution.values() {
                assert!(count > 0);
                let c = f64::from(count);
                mass += c;
                entropy -= c * c.ln();
            }
            // -sum(p ln p) = ln(mass) - sum(c ln c) / mass; an empty map yields NaN, clamped to 0 below
            entropy = entropy / mass + mass.ln();
            if entropy < -f64::EPSILON * 10. {
                log::error!("facility {:?} entropy {:.3e}", i, entropy);
                std::panic!("negative entropy");
            }
            entropies.push(entropy.max(0.));
        }
        let mut global_entropy = 0_f64;
        let mut entropy_weight = 0_f64;
        for (f, entropy) in self.centers.iter().zip(&entropies) {
            let weight = f.read().get_weight();
            entropy_weight += weight;
            global_entropy += weight * entropy;
        }
        global_entropy /= entropy_weight;
        println!(
            "\n\n mean of entropies : {:.3e}, total weight : {:.3e}",
            global_entropy, entropy_weight
        );
        println!("\n **************************************************************************");
        let simple_label_distribution: Vec<HashMap<L, u32>> = label_distribution
            .into_iter()
            .map(|ld| ld.into_inner())
            .collect();
        (entropies, simple_label_distribution)
    }

    /// Returns one `(weight, center, id)` triple per facility, in rank order.
    pub fn into_weighted_data(&self) -> Vec<(f64, Vec<T>, DataId)> {
        log::info!("facility::into_weighted_data");
        self.centers
            .iter()
            .map(|f| {
                let facility = f.read();
                (
                    facility.get_weight(),
                    facility.get_position().clone(),
                    facility.get_dataid(),
                )
            })
            .collect()
    }

    /// Computes all ordered inter-facility distances (i != j), prints some quantiles of them,
    /// and logs the full distance matrix at debug level. Does nothing (besides logging an
    /// error) when there are fewer than two facilities.
    pub fn cross_distances(&self) {
        let nb_facility = self.centers.len();
        if nb_facility <= 1 {
            log::error!("facility::cross_distances, only one facility");
            return;
        }
        let mut distances = Array2::<f32>::zeros((nb_facility, nb_facility));
        let mut q_dist = CKMS::<f32>::new(0.01);
        for i in 0..nb_facility {
            let f_i = self.centers[i].read();
            let center_i = f_i.get_position();
            for j in 0..nb_facility {
                if i == j {
                    continue;
                }
                let f_j = self.centers[j].read();
                let d = self.distance.eval(center_i, f_j.get_position());
                distances[[i, j]] = d;
                q_dist.insert(d);
            }
        }
        println!("\n inter facility distances quantiles : ");
        println!("\n distance quantiles at 0.01 : {:.2e}, 0.05 : {:.2e}, 0.1 : {:.2e} , 0.5 : {:.2e}, 0.75 : {:.2e} ",
        q_dist.query(0.01).unwrap().1, q_dist.query(0.05).unwrap().1, q_dist.query(0.1).unwrap().1, q_dist.query(0.5).unwrap().1, q_dist.query(0.75).unwrap().1);
        log::debug!("\n cross distances : {:.3e}", distances);
    }
}