use serde::{Deserialize, Serialize};
use num_traits::cast::FromPrimitive;
use num_traits::Float;
use std::cmp::Ordering;
use hnsw_rs::hnsw::Neighbour;
/// Index identifying a node. It is a plain `usize` position, e.g. the index
/// used by `NodeParams::get_node_param` into `NodeParams::params`.
pub type NodeIdx = usize;
/// An outgoing edge (as stored in `NodeParam::edges`) pointing to `node` with a weight.
///
/// Note: the `PartialEq` / `PartialOrd` impls for this type compare `weight` only;
/// `node` does not take part in equality or ordering.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct OutEdge<F> {
    /// Target node of the edge.
    pub node: NodeIdx,
    /// Edge weight. When built via `From<Neighbour>` it is the neighbour's distance.
    pub weight: F,
}
impl<F> OutEdge<F> {
pub fn new(node: NodeIdx, weight: F) -> Self {
OutEdge { node, weight }
}
}
/// Two edges are equal when their weights are equal; the target `node` is ignored.
///
/// This matches the `PartialOrd` impl, which also compares weights only. For IEEE
/// floats such as `f32`/`f64`, an edge with a NaN weight equals nothing, not even itself.
impl<F: Float> PartialEq for OutEdge<F> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.weight, other.weight);
        lhs == rhs
    }
}
/// Orders edges by weight only; the target `node` is ignored.
///
/// For IEEE floats such as `f32`/`f64`, the result is `None` when either weight is NaN.
impl<F> PartialOrd for OutEdge<F>
where
    F: Float,
{
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        PartialOrd::partial_cmp(&self.weight, &other.weight)
    }
}
/// Converts an `hnsw_rs` search result into an edge.
///
/// The neighbour's `d_id` becomes the target `node`. Its `f32` `distance` becomes the
/// `weight`, converted to `F` with `FromPrimitive::from_f32`.
///
/// # Panics
///
/// Panics if `F::from_f32` returns `None` for the distance. This is not expected for
/// the standard float types (NOTE(review): confirm for any custom `F`).
impl<F> From<Neighbour> for OutEdge<F>
where
    F: Float + FromPrimitive,
{
    fn from(neighbour: Neighbour) -> OutEdge<F> {
        OutEdge {
            node: neighbour.d_id,
            weight: F::from_f32(neighbour.distance)
                .expect("OutEdge::from(Neighbour): f32 distance not representable in target float type"),
        }
    }
}
/// Description of one node: a scale value and the node's outgoing, weighted edges.
///
/// `Debug` is derived (like `OutEdge`) so node descriptions can be inspected in
/// diagnostics and test failures.
#[derive(Clone, Debug)]
pub struct NodeParam {
    /// Scale associated with the node (0.0 in `Default`).
    /// NOTE(review): the meaning (e.g. a kernel bandwidth) is set by whoever builds
    /// the `NodeParam` — confirm.
    pub(crate) scale: f32,
    /// Outgoing edges of the node. `get_perplexity` treats their weights as a
    /// probability distribution.
    pub(crate) edges: Vec<OutEdge<f32>>,
}
impl NodeParam {
    /// Creates a node description from its `scale` and outgoing `edges`.
    pub fn new(scale: f32, edges: Vec<OutEdge<f32>>) -> Self {
        NodeParam { scale, edges }
    }

    /// Returns the first outgoing edge whose target is node `i`, or `None`.
    ///
    /// This is a linear scan over the edge list.
    pub fn get_edge(&self, i: NodeIdx) -> Option<&OutEdge<f32>> {
        self.edges.iter().find(|edge| edge.node == i)
    }

    /// Returns the perplexity `exp(H)` of the edge weights.
    ///
    /// `H = -Σ w·ln(w)` is the Shannon entropy of the weights in nats.
    /// NOTE(review): the weights are presumably normalised to sum to 1; this is not
    /// checked here.
    ///
    /// Zero weights contribute nothing, following the convention `0·ln 0 = 0`.
    /// Evaluating the term directly would give `0 · -inf = NaN` and make the whole
    /// result NaN. An empty edge list gives `H = 0`, i.e. a perplexity of `1.0`.
    pub fn get_perplexity(&self) -> f32 {
        let entropy: f32 = self
            .edges
            .iter()
            .map(|edge| {
                let w = edge.weight;
                if w == 0.0 {
                    // lim_{w→0+} w·ln(w) = 0
                    0.0
                } else {
                    -w * w.ln()
                }
            })
            .sum();
        entropy.exp()
    }

    /// Returns the number of outgoing edges of this node.
    // `unused` is allowed because this accessor may have no caller inside the crate.
    #[allow(unused)]
    pub fn get_nb_edges(&self) -> usize {
        self.edges.len()
    }
}
impl Default for NodeParam {
fn default() -> Self {
NodeParam {
scale: 0f32,
edges: Vec::<OutEdge<f32>>::new(),
}
}
}
/// Collection of per-node [`NodeParam`]s, addressed by [`NodeIdx`].
pub struct NodeParams {
    /// One entry per node; node `i`'s parameters are `params[i]`.
    pub params: Vec<NodeParam>,
    /// Maximum number of neighbours.
    /// NOTE(review): presumably an upper bound on each node's edge count; it is not
    /// enforced here — confirm.
    pub max_nbng: usize,
}
impl NodeParams {
    /// Bundles the per-node parameters with the neighbour bound `max_nbng`.
    pub fn new(params: Vec<NodeParam>, max_nbng: usize) -> Self {
        Self { params, max_nbng }
    }

    /// Returns the parameters stored for `node`.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not smaller than the number of stored nodes.
    pub fn get_node_param(&self, node: NodeIdx) -> &NodeParam {
        let all = &self.params;
        &all[node]
    }

    /// Returns how many nodes have parameters stored.
    pub fn get_nb_nodes(&self) -> usize {
        self.params.as_slice().len()
    }

    /// Returns the stored `max_nbng` value.
    pub fn get_max_nbng(&self) -> usize {
        let NodeParams { max_nbng, .. } = self;
        *max_nbng
    }
}