use std::collections::HashMap;
use super::OdMatrix;
use crate::gmns::types::ZoneID;
/// Dense origin-destination matrix over a fixed, ordered set of zones.
///
/// All cells are stored row-major in one `Vec<f64>`: the value from the zone at
/// position `i` (origin, row) to the zone at position `j` (destination, column)
/// lives at `data[i * n + j]`, where `n = zone_ids.len()`.
#[derive(Debug, Clone)]
pub struct DenseOdMatrix {
/// Zones in matrix order; a zone's position is both its row and its column index.
zone_ids: Vec<ZoneID>,
/// Reverse lookup from zone ID to its position in `zone_ids`.
/// NOTE(review): the constructors build this by collecting `(id, position)`
/// pairs, so a duplicated ID maps to its *last* position; IDs are presumably
/// expected to be unique — confirm with callers.
zone_index: HashMap<ZoneID, usize>,
/// Exactly `n * n` cell values in row-major order (row = origin, column = destination).
data: Vec<f64>,
}
impl DenseOdMatrix {
    /// Creates a zero-filled matrix over `zone_ids`.
    ///
    /// The order of `zone_ids` fixes the matrix order: the zone at position `k`
    /// owns row `k` (as origin) and column `k` (as destination).
    ///
    /// Zone IDs are expected to be unique. If an ID occurs more than once, lookups
    /// by that ID resolve to its last position, and the rows/columns of earlier
    /// occurrences are reachable only through the index-based accessors.
    pub fn new(zone_ids: Vec<ZoneID>) -> Self {
        let n = zone_ids.len();
        Self::from_data(zone_ids, vec![0.0; n * n])
    }

    /// Creates a matrix over `zone_ids` from existing row-major cell values.
    ///
    /// `data[i * n + j]` is the value from `zone_ids[i]` to `zone_ids[j]`, where
    /// `n = zone_ids.len()`. The uniqueness expectation of [`DenseOdMatrix::new`]
    /// applies to `zone_ids` as well.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != n * n`.
    pub fn from_data(zone_ids: Vec<ZoneID>, data: Vec<f64>) -> Self {
        let n = zone_ids.len();
        assert_eq!(data.len(), n * n, "data length must be n*n");
        let zone_index = Self::build_index(&zone_ids);
        DenseOdMatrix {
            zone_ids,
            zone_index,
            data,
        }
    }

    /// Returns the value of cell `(i, j)` addressed by zone position (row, column).
    ///
    /// # Panics
    ///
    /// Panics if the cell lies outside the matrix storage; in debug builds it also
    /// panics whenever `i` or `j` is not below the zone count.
    #[inline]
    pub fn get_by_index(&self, i: usize, j: usize) -> f64 {
        self.data[self.offset(i, j)]
    }

    /// Overwrites cell `(i, j)` addressed by zone position (row, column).
    ///
    /// # Panics
    ///
    /// Same conditions as [`DenseOdMatrix::get_by_index`].
    #[inline]
    pub fn set_by_index(&mut self, i: usize, j: usize, value: f64) {
        let k = self.offset(i, j);
        self.data[k] = value;
    }

    /// Mutable view of the raw row-major cells.
    ///
    /// A slice cannot be resized, so the `n * n` length invariant is preserved.
    pub fn data_mut(&mut self) -> &mut [f64] {
        &mut self.data
    }

    /// Read-only view of the raw row-major cells (`n * n` values).
    pub fn data(&self) -> &[f64] {
        &self.data
    }

    /// Position of `zone_id` in the matrix order, or `None` if the zone is unknown.
    fn index_of(&self, zone_id: ZoneID) -> Option<usize> {
        self.zone_index.get(&zone_id).copied()
    }

    /// Number of zones, i.e. the side length of the square matrix.
    fn n(&self) -> usize {
        self.zone_ids.len()
    }

    /// Builds the ID -> position lookup; on duplicate IDs the last position wins.
    fn build_index(zone_ids: &[ZoneID]) -> HashMap<ZoneID, usize> {
        zone_ids.iter().enumerate().map(|(i, &z)| (z, i)).collect()
    }

    /// Flat offset of cell `(i, j)` into `data`.
    ///
    /// A column `j >= n` would otherwise silently alias a cell in a later row
    /// (`i * n + j` can still be `< n * n`), so both indices are checked in debug
    /// builds; in release an out-of-range cell is still caught by slice indexing
    /// unless it aliases.
    #[inline]
    fn offset(&self, i: usize, j: usize) -> usize {
        let n = self.n();
        debug_assert!(
            i < n && j < n,
            "cell index ({i}, {j}) out of range for {n} zones"
        );
        i * n + j
    }
}
impl OdMatrix for DenseOdMatrix {
    /// Value from `origin` to `destination`; `0.0` if either zone is unknown.
    fn get(&self, origin: ZoneID, destination: ZoneID) -> f64 {
        let (Some(i), Some(j)) = (self.index_of(origin), self.index_of(destination)) else {
            return 0.0;
        };
        self.data[i * self.n() + j]
    }

    /// Overwrites the `origin -> destination` cell; silently does nothing if either
    /// zone is unknown.
    fn set(&mut self, origin: ZoneID, destination: ZoneID, value: f64) {
        let (Some(i), Some(j)) = (self.index_of(origin), self.index_of(destination)) else {
            return;
        };
        let n = self.n();
        self.data[i * n + j] = value;
    }

    /// Adds `delta` to the `origin -> destination` cell; silently does nothing if
    /// either zone is unknown.
    fn add(&mut self, origin: ZoneID, destination: ZoneID, delta: f64) {
        let (Some(i), Some(j)) = (self.index_of(origin), self.index_of(destination)) else {
            return;
        };
        let n = self.n();
        self.data[i * n + j] += delta;
    }

    /// Zones in matrix (row/column) order.
    fn zone_ids(&self) -> &[ZoneID] {
        &self.zone_ids
    }

    /// Sum of the row for `origin`; `0.0` if the zone is unknown.
    fn row_sum(&self, origin: ZoneID) -> f64 {
        let Some(i) = self.index_of(origin) else {
            return 0.0;
        };
        let n = self.n();
        self.data[i * n..(i + 1) * n].iter().sum()
    }

    /// Sum of the column for `destination`; `0.0` if the zone is unknown.
    fn col_sum(&self, destination: ZoneID) -> f64 {
        let Some(j) = self.index_of(destination) else {
            return 0.0;
        };
        // `j` was found, so there is at least one zone and `chunks_exact` gets a
        // non-zero chunk size. Rows are visited in order 0..n, matching the
        // summation order of an index loop.
        self.data
            .chunks_exact(self.n())
            .map(|row| row[j])
            .sum()
    }

    /// Sum of every cell in the matrix.
    fn total(&self) -> f64 {
        self.data.iter().sum()
    }

    /// Every cell as `(origin, destination, value)`, in row-major order,
    /// including zero-valued cells (`n * n` entries).
    fn iter(&self) -> Vec<(ZoneID, ZoneID, f64)> {
        let n = self.n();
        let mut result = Vec::with_capacity(n * n);
        for (i, &origin) in self.zone_ids.iter().enumerate() {
            for (j, &destination) in self.zone_ids.iter().enumerate() {
                result.push((origin, destination, self.data[i * n + j]));
            }
        }
        result
    }

    /// Value of cell `(i, j)` addressed by zone position (row, column).
    ///
    /// In debug builds an out-of-range `i` or `j` panics; without the check a
    /// column `j >= n` would silently read a cell from a later row.
    #[inline]
    fn get_by_index(&self, i: usize, j: usize) -> f64 {
        let n = self.n();
        debug_assert!(
            i < n && j < n,
            "cell index ({i}, {j}) out of range for {n} zones"
        );
        self.data[i * n + j]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point sum comparisons.
    const EPS: f64 = 1e-10;

    #[test]
    fn new_matrix_is_zero() {
        let od = DenseOdMatrix::new(vec![1, 2, 3]);
        assert_eq!(od.total(), 0.0);
        assert_eq!(od.get(1, 2), 0.0);
    }

    #[test]
    fn set_and_get() {
        let mut od = DenseOdMatrix::new(vec![10, 20]);
        od.set(10, 20, 500.0);
        assert_eq!(od.get(10, 20), 500.0);
        assert_eq!(od.get(20, 10), 0.0);
    }

    #[test]
    fn add_accumulates() {
        let mut od = DenseOdMatrix::new(vec![1, 2]);
        od.add(1, 2, 30.0);
        od.add(1, 2, 70.0);
        assert_eq!(od.get(1, 2), 100.0);
    }

    #[test]
    fn unknown_zone_returns_zero() {
        let od = DenseOdMatrix::new(vec![1, 2]);
        assert_eq!(od.get(1, 999), 0.0);
        assert_eq!(od.get(999, 1), 0.0);
    }

    #[test]
    fn set_unknown_zone_is_noop() {
        let mut od = DenseOdMatrix::new(vec![1, 2]);
        od.set(999, 1, 100.0);
        assert_eq!(od.total(), 0.0);
    }

    #[test]
    fn add_unknown_zone_is_noop() {
        let mut od = DenseOdMatrix::new(vec![1, 2]);
        od.add(1, 999, 5.0);
        od.add(999, 2, 5.0);
        assert_eq!(od.total(), 0.0);
    }

    #[test]
    fn from_data() {
        let od = DenseOdMatrix::from_data(vec![1, 2], vec![0.0, 100.0, 200.0, 0.0]);
        assert_eq!(od.get(1, 2), 100.0);
        assert_eq!(od.get(2, 1), 200.0);
        assert_eq!(od.total(), 300.0);
    }

    #[test]
    #[should_panic]
    fn from_data_wrong_length_panics() {
        let _ = DenseOdMatrix::from_data(vec![1, 2], vec![1.0, 2.0]);
    }

    #[test]
    fn row_sum() {
        let mut od = DenseOdMatrix::new(vec![1, 2, 3]);
        od.set(1, 2, 50.0);
        od.set(1, 3, 30.0);
        od.set(2, 3, 100.0);
        assert!((od.row_sum(1) - 80.0).abs() < EPS);
        assert!((od.row_sum(2) - 100.0).abs() < EPS);
        assert!((od.row_sum(3) - 0.0).abs() < EPS);
    }

    #[test]
    fn col_sum() {
        let mut od = DenseOdMatrix::new(vec![1, 2, 3]);
        od.set(1, 3, 30.0);
        od.set(2, 3, 70.0);
        assert!((od.col_sum(3) - 100.0).abs() < EPS);
        assert!((od.col_sum(1) - 0.0).abs() < EPS);
    }

    #[test]
    fn sums_for_unknown_zone_are_zero() {
        let mut od = DenseOdMatrix::new(vec![1, 2]);
        od.set(1, 2, 10.0);
        assert_eq!(od.row_sum(999), 0.0);
        assert_eq!(od.col_sum(999), 0.0);
    }

    #[test]
    fn zone_ids_and_count() {
        let od = DenseOdMatrix::new(vec![10, 20, 30]);
        assert_eq!(od.zone_ids(), &[10, 20, 30]);
        assert_eq!(od.zone_count(), 3);
    }

    #[test]
    fn data_and_data_mut() {
        let mut od = DenseOdMatrix::new(vec![1, 2]);
        od.set(1, 2, 42.0);
        assert_eq!(od.data(), &[0.0, 42.0, 0.0, 0.0]);
        od.data_mut()[0] = 10.0;
        assert_eq!(od.get(1, 1), 10.0);
    }

    #[test]
    fn index_accessors_match_id_accessors() {
        let mut od = DenseOdMatrix::new(vec![7, 8, 9]);
        // Position 2 is zone 9 (origin), position 0 is zone 7 (destination).
        od.set_by_index(2, 0, 12.5);
        assert_eq!(od.get_by_index(2, 0), 12.5);
        assert_eq!(od.get(9, 7), 12.5);
        // Row-major layout: cell (2, 0) sits at 2 * n + 0 with n = 3.
        assert_eq!(od.data()[6], 12.5);
    }

    #[test]
    fn empty_matrix() {
        let od = DenseOdMatrix::new(vec![]);
        assert!(od.data().is_empty());
        assert_eq!(od.total(), 0.0);
        assert_eq!(od.get(1, 2), 0.0);
        assert!(od.iter().is_empty());
    }

    #[test]
    fn iter_returns_all_cells() {
        let mut od = DenseOdMatrix::new(vec![1, 2]);
        od.set(1, 2, 50.0);
        let entries = od.iter();
        assert_eq!(entries.len(), 4);
        assert!(entries.contains(&(1, 2, 50.0)));
        assert!(entries.contains(&(1, 1, 0.0)));
    }
}