use std::cmp::Ordering;
#[cfg(feature = "serde")]
use serde::{Serialize, Deserialize};
use ndarray::{Array1, Array2};
use crate::error::{DigiFiError, ErrorTitle};
use crate::utilities::{maths_utils::euclidean_distance, data_transformations::log_return_transformation};
use crate::statistics::pearson_correlation;
/// Selects how the weight of an edge is derived from the coordinates of its two nodes.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum MSTDistance {
    /// Weight is computed by the `euclidean_distance` helper over the two coordinate arrays.
    EuclideanDistance,
    /// Weight is computed by the Mantegna (correlation-based) distance of the two coordinate arrays.
    MantegnaDistance,
}
impl MSTDistance {
    /// Mantegna distance between two series: `d = sqrt(2 * (1 - rho))`, where `rho` is the
    /// Pearson correlation of the transformed inputs.
    ///
    /// Both inputs are first passed through `log_return_transformation` (presumably turning a
    /// price series into log returns; confirm against that helper), and the correlation is taken
    /// with `pearson_correlation(.., .., 0)`.
    ///
    /// The correlation itself is not a distance: perfectly correlated series must be at distance
    /// `0` and perfectly anti-correlated series at distance `2`, which is what the transformation
    /// above provides.
    ///
    /// # Errors
    /// Propagates any error returned by `pearson_correlation`.
    fn mantegna_distance(v_1: &Array1<f64>, v_2: &Array1<f64>) -> Result<f64, DigiFiError> {
        let returns_1: Array1<f64> = log_return_transformation(v_1);
        let returns_2: Array1<f64> = log_return_transformation(v_2);
        let correlation: f64 = pearson_correlation(&returns_1, &returns_2, 0)?;
        let squared_distance: f64 = 2.0 * (1.0 - correlation);
        // Floating-point rounding can leave the correlation marginally above 1, which would make
        // the radicand slightly negative. Only that case is clamped; a NaN correlation is left
        // untouched so that it still surfaces as a NaN distance.
        let squared_distance: f64 = if squared_distance < 0.0 { 0.0 } else { squared_distance };
        Ok(squared_distance.sqrt())
    }

    /// Computes the distance between `v_1` and `v_2` according to the selected variant.
    ///
    /// # Errors
    /// Propagates the error of the underlying distance routine (`euclidean_distance` or the
    /// Mantegna distance computation).
    pub fn distance(&self, v_1: &Array1<f64>, v_2: &Array1<f64>) -> Result<f64, DigiFiError> {
        match self {
            Self::EuclideanDistance => euclidean_distance(v_1.iter(), v_2.iter()),
            Self::MantegnaDistance => Self::mantegna_distance(v_1, v_2),
        }
    }
}
/// A vertex of the graph on which the minimal spanning tree is built.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct MSTNode<'x> {
    /// Label of the node; `MST::distance_matrix` matches edges to rows/columns by this name.
    pub name: String,
    /// Position of the node in the union-find tables used by `MST::kruskal_mst`.
    pub index: usize,
    /// Borrowed data vector of the node, used by `MST::compute_edge_weights` to derive edge weights.
    pub coordinate: &'x Array1<f64>,
}
/// A weighted connection between two borrowed nodes.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct MSTEdge<'a, 'b, 'x> {
    /// First endpoint of the edge.
    pub node_1: &'a MSTNode<'x>,
    /// Second endpoint of the edge.
    pub node_2: &'b MSTNode<'x>,
    /// Cost of the edge; edges are ordered by this value when the tree is built.
    pub weight: f64,
}
/// State of a minimal spanning tree computation.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct MST<'a, 'b, 'x> {
    /// Number of nodes in the graph; sizes the union-find tables and the distance matrix.
    pub n_nodes: usize,
    /// Candidate edges of the graph (sorted in place by weight when `kruskal_mst` runs).
    pub graph: Vec<MSTEdge<'a, 'b, 'x>>,
    /// Edges selected by `kruskal_mst`; empty until that method has been called.
    pub result: Vec<MSTEdge<'a, 'b, 'x>>,
}
impl<'a, 'b, 'x> MST<'a, 'b, 'x> {
    /// Creates a new instance over `n_nodes` nodes with an initial list of candidate `edges`.
    ///
    /// `result` starts out empty and is populated by [`MST::kruskal_mst`].
    pub fn new(n_nodes: usize, edges: Vec<MSTEdge<'a, 'b, 'x>>) -> Self {
        Self { n_nodes, graph: edges, result: Vec::new() }
    }

    /// Appends a single candidate edge to the graph.
    pub fn add_edge(&mut self, edge: MSTEdge<'a, 'b, 'x>) {
        self.graph.push(edge);
    }

    /// Appends several candidate edges to the graph, preserving their order.
    pub fn add_edges(&mut self, edges: Vec<MSTEdge<'a, 'b, 'x>>) {
        self.graph.extend(edges);
    }

    /// Overwrites the weight of every edge in the graph with the distance between the
    /// coordinates of its two nodes, measured with `distance_type`.
    ///
    /// # Errors
    /// Returns the first error produced by the distance computation. Edges processed before the
    /// failure keep their newly assigned weights.
    pub fn compute_edge_weights(&mut self, distance_type: &MSTDistance) -> Result<(), DigiFiError> {
        for edge in &mut self.graph {
            edge.weight = distance_type.distance(edge.node_1.coordinate, edge.node_2.coordinate)?;
        }
        Ok(())
    }

    /// Union-find lookup: returns the representative of the set containing `i` and re-points
    /// every node visited on the way directly at that representative (path compression).
    ///
    /// Implemented iteratively so that long parent chains cannot exhaust the call stack.
    fn find(&self, parent: &mut [usize], i: usize) -> usize {
        // First pass: walk up to the root.
        let mut root: usize = i;
        while parent[root] != root {
            root = parent[root];
        }
        // Second pass: compress the path that was just walked.
        let mut current: usize = i;
        while parent[current] != root {
            let next: usize = parent[current];
            parent[current] = root;
            current = next;
        }
        root
    }

    /// Union-find merge by rank of the two sets whose representatives are `x` and `y`.
    ///
    /// The representative with the lower rank is attached below the other one; on a tie `y` is
    /// attached below `x` and the rank of `x` grows by one.
    fn union(&self, parent: &mut [usize], rank: &mut [usize], x: usize, y: usize) {
        match rank[x].cmp(&rank[y]) {
            Ordering::Less => parent[x] = y,
            Ordering::Greater => parent[y] = x,
            Ordering::Equal => {
                parent[y] = x;
                rank[x] += 1;
            }
        }
    }

    /// Builds the minimal spanning tree with Kruskal's algorithm and stores the selected edges
    /// in `result` (replacing any previous result).
    ///
    /// The edges in `graph` are sorted in place by ascending weight. Edges are then accepted in
    /// that order whenever they connect two components that are not connected yet, until
    /// `n_nodes - 1` edges have been accepted or the edge list is exhausted. If the graph is
    /// disconnected (or has too few edges) the result is therefore a spanning forest with fewer
    /// than `n_nodes - 1` edges; with `n_nodes == 0` the result is empty.
    ///
    /// # Panics
    /// Panics if an examined edge refers to a node whose `index` is not smaller than `n_nodes`.
    pub fn kruskal_mst(&mut self) {
        // `total_cmp` is a total order even when a weight is NaN, which `sort_by` requires.
        self.graph.sort_by(|a, b| a.weight.total_cmp(&b.weight));
        // A spanning tree over n nodes has n - 1 edges; saturate so that n == 0 does not underflow.
        let target: usize = self.n_nodes.saturating_sub(1);
        let mut parent: Vec<usize> = (0..self.n_nodes).collect();
        let mut rank: Vec<usize> = vec![0; self.n_nodes];
        let mut result: Vec<MSTEdge<'a, 'b, 'x>> = Vec::with_capacity(target.min(self.graph.len()));
        for edge in &self.graph {
            if result.len() >= target {
                break;
            }
            let x: usize = self.find(&mut parent, edge.node_1.index);
            let y: usize = self.find(&mut parent, edge.node_2.index);
            // Different representatives mean the edge joins two components without forming a cycle.
            if x != y {
                result.push(edge.clone());
                self.union(&mut parent, &mut rank, x, y);
            }
        }
        self.result = result;
    }

    /// Sum of the weights of the edges currently stored in `result` (`0.0` when it is empty).
    pub fn minimum_cost(&self) -> f64 {
        self.result.iter().fold(0.0, |total, edge| total + edge.weight)
    }

    /// Returns the names of `nodes` (in the given order) together with an `n_nodes x n_nodes`
    /// symmetric matrix that holds `Some(weight)` for every pair of nodes joined by an edge of
    /// `result`, and `None` everywhere else.
    ///
    /// Rows and columns follow the order of `nodes`; an edge is located by the first position at
    /// which the name of each of its endpoints occurs in `nodes`.
    ///
    /// # Errors
    /// Returns `DigiFiError::NotFound` if an endpoint of a result edge has a name that does not
    /// occur in `nodes`.
    ///
    /// # Panics
    /// Panics if a matching name is found at a position that is not smaller than `n_nodes`
    /// (possible only when `nodes` holds more than `n_nodes` entries).
    pub fn distance_matrix(&self, nodes: &Vec<MSTNode>) -> Result<(Vec<String>, Array2<Option<f64>>), DigiFiError> {
        let node_names: Vec<String> = nodes.iter().map(|node| node.name.clone()).collect();
        let mut distances: Array2<Option<f64>> = Array2::from_elem((self.n_nodes, self.n_nodes), None);
        // The error is only constructed when a name is actually missing.
        let position_of = |name: &str| -> Result<usize, DigiFiError> {
            node_names.iter()
                .position(|candidate| candidate.as_str() == name)
                .ok_or_else(|| DigiFiError::NotFound { title: Self::error_title(), data: "matching edge name".to_owned(), })
        };
        for edge in &self.result {
            let i: usize = position_of(&edge.node_1.name)?;
            let j: usize = position_of(&edge.node_2.name)?;
            // The graph is undirected, so the matrix is filled symmetrically.
            distances[[i, j]] = Some(edge.weight);
            distances[[j, i]] = Some(edge.weight);
        }
        Ok((node_names, distances))
    }
}
impl ErrorTitle for MST<'_, '_, '_> {
    /// Title attached to the errors raised by the minimal spanning tree routines.
    fn error_title() -> String {
        "Minimal-Spanning Tree".to_owned()
    }
}
#[cfg(test)]
mod tests {
    use ndarray::{arr1, arr2, Array2};

    /// Runs Kruskal's algorithm on a four-node, five-edge graph and checks the total cost of
    /// the tree as well as the resulting distance matrix.
    #[test]
    fn test_mst() {
        use crate::utilities::minimal_spanning_tree::{MSTNode, MSTEdge, MST};
        // Edge weights are given explicitly below, so every node can share one coordinate vector.
        let coordinate = arr1(&[1.0, 2.0, 3.0]);
        let nodes: Vec<MSTNode> = ["First", "Second", "Third", "Fourth"]
            .iter()
            .enumerate()
            .map(|(index, name)| MSTNode { name: (*name).to_owned(), index, coordinate: &coordinate, })
            .collect();
        // (first endpoint, second endpoint, weight)
        let edge_table: [(usize, usize, f64); 5] = [
            (0, 1, 10.0),
            (0, 2, 6.0),
            (0, 3, 5.0),
            (1, 3, 15.0),
            (2, 3, 4.0),
        ];
        let edges: Vec<MSTEdge> = edge_table
            .iter()
            .map(|&(from, to, weight)| MSTEdge { node_1: &nodes[from], node_2: &nodes[to], weight })
            .collect();
        let mut mst: MST = MST::new(nodes.len(), edges);
        mst.kruskal_mst();
        println!("Edges in the constructed MST:");
        for edge in &mst.result {
            println!("{} -- {} == {}", edge.node_1.name, edge.node_2.name, edge.weight);
        }
        // Selected edges: Third-Fourth (4), First-Fourth (5) and First-Second (10).
        assert_eq!(mst.minimum_cost(), 19.0);
        let (node_names, distances) = mst.distance_matrix(&nodes).unwrap();
        let expected_names: Vec<String> = vec!["First".to_owned(), "Second".to_owned(), "Third".to_owned(), "Fourth".to_owned()];
        assert_eq!(node_names, expected_names);
        let expected_distances: Array2<Option<f64>> = arr2(&[
            [None, Some(10.0), None, Some(5.0)],
            [Some(10.0), None, None, None],
            [None, None, None, Some(4.0)],
            [Some(5.0), None, Some(4.0), None],
        ]);
        assert_eq!(distances, expected_distances);
    }
}