augurs_core/
distance.rs

1use std::{fmt, ops::Index};
2
/// An error that can occur when creating a `DistanceMatrix`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DistanceMatrixError {
    /// The input matrix is not square: at least one row's length differs
    /// from the number of rows.
    InvalidDistanceMatrix,
}

impl fmt::Display for DistanceMatrixError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive match: adding a variant forces a message to be written here.
        match self {
            Self::InvalidDistanceMatrix => f.write_str("invalid distance matrix"),
        }
    }
}

impl std::error::Error for DistanceMatrixError {}
17
18/// A matrix representing the distances between pairs of items.
19#[derive(Debug, Clone)]
20#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
21pub struct DistanceMatrix {
22    matrix: Vec<Vec<f64>>,
23}
24
25impl DistanceMatrix {
26    /// Create a new `DistanceMatrix` from a square matrix.
27    ///
28    /// # Errors
29    ///
30    /// Returns an error if the input matrix is not square.
31    pub fn try_from_square(matrix: Vec<Vec<f64>>) -> Result<Self, DistanceMatrixError> {
32        if matrix.iter().all(|x| x.len() == matrix.len()) {
33            Ok(Self { matrix })
34        } else {
35            Err(DistanceMatrixError::InvalidDistanceMatrix)
36        }
37    }
38
39    /// Consumes the `DistanceMatrix` and returns the inner matrix.
40    pub fn into_inner(self) -> Vec<Vec<f64>> {
41        self.matrix
42    }
43
44    /// Returns an iterator over the rows of the matrix.
45    pub fn iter(&self) -> DistanceMatrixIter<'_> {
46        DistanceMatrixIter {
47            iter: self.matrix.iter(),
48        }
49    }
50
51    /// Returns the shape of the matrix.
52    ///
53    /// The first element is the number of rows and the second element
54    /// is the number of columns.
55    ///
56    /// The matrix is square, so the number of rows is equal to the number of columns
57    /// and the number of input series.
58    pub fn shape(&self) -> (usize, usize) {
59        (self.matrix.len(), self.matrix.len())
60    }
61}
62
63impl Index<usize> for DistanceMatrix {
64    type Output = [f64];
65    fn index(&self, index: usize) -> &Self::Output {
66        &self.matrix[index]
67    }
68}
69
70impl Index<(usize, usize)> for DistanceMatrix {
71    type Output = f64;
72    fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
73        &self.matrix[i][j]
74    }
75}
76
77impl IntoIterator for DistanceMatrix {
78    type Item = Vec<f64>;
79    type IntoIter = std::vec::IntoIter<Self::Item>;
80    fn into_iter(self) -> Self::IntoIter {
81        self.matrix.into_iter()
82    }
83}
84
/// An iterator over the rows of a `DistanceMatrix`, yielding each row by reference.
///
/// Created by `DistanceMatrix::iter`. It forwards to the underlying slice
/// iterator, so it reports its exact length and can also be walked from the back.
#[derive(Debug, Clone)]
pub struct DistanceMatrixIter<'a> {
    iter: std::slice::Iter<'a, Vec<f64>>,
}

impl<'a> Iterator for DistanceMatrixIter<'a> {
    type Item = &'a Vec<f64>;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    // Forward the exact bounds so `collect` can preallocate and `len()` is available.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

impl DoubleEndedIterator for DistanceMatrixIter<'_> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.iter.next_back()
    }
}

impl ExactSizeIterator for DistanceMatrixIter<'_> {
    fn len(&self) -> usize {
        self.iter.len()
    }
}

// `slice::Iter` is fused, so after the first `None` this keeps returning `None`.
impl std::iter::FusedIterator for DistanceMatrixIter<'_> {}