use std::collections::hash_map::Entry;
use std::collections::HashMap;
use super::OdMatrix;
use crate::gmns::types::ZoneID;
/// Origin–destination (OD) matrix that stores only explicitly set cells.
///
/// Pairs with no stored entry read as `0.0` through the [`OdMatrix`] impl,
/// and `set`/`add` drop a cell whose value becomes exactly `0.0`, so memory
/// grows with the number of non-zero cells rather than with `zones²`.
#[derive(Debug, Clone)]
pub struct SparseOdMatrix {
// Every zone this matrix is defined over, in the order passed to `new`.
zone_ids: Vec<ZoneID>,
// Stored cell values keyed by `(origin, destination)`.
data: HashMap<(ZoneID, ZoneID), f64>,
}
impl SparseOdMatrix {
pub fn new(zone_ids: Vec<ZoneID>) -> Self {
SparseOdMatrix {
zone_ids,
data: HashMap::new(),
}
}
pub fn nnz(&self) -> usize {
self.data.len()
}
}
impl OdMatrix for SparseOdMatrix {
fn get(&self, origin: ZoneID, destination: ZoneID) -> f64 {
self.data
.get(&(origin, destination))
.copied()
.unwrap_or(0.0)
}
fn set(&mut self, origin: ZoneID, destination: ZoneID, value: f64) {
if value == 0.0 {
self.data.remove(&(origin, destination));
} else {
self.data.insert((origin, destination), value);
}
}
fn add(&mut self, origin: ZoneID, destination: ZoneID, delta: f64) {
let entry = self.data.entry((origin, destination)).or_insert(0.0);
*entry += delta;
if *entry == 0.0 {
self.data.remove(&(origin, destination));
}
}
fn zone_ids(&self) -> &[ZoneID] {
&self.zone_ids
}
fn row_sum(&self, origin: ZoneID) -> f64 {
self.data
.iter()
.filter(|&(&(o, _), _)| o == origin)
.map(|(_, &v)| v)
.sum()
}
fn col_sum(&self, destination: ZoneID) -> f64 {
self.data
.iter()
.filter(|&(&(_, d), _)| d == destination)
.map(|(_, &v)| v)
.sum()
}
fn total(&self) -> f64 {
self.data.values().sum()
}
fn iter(&self) -> Vec<(ZoneID, ZoneID, f64)> {
let mut result: Vec<(ZoneID, ZoneID, f64)> =
self.data.iter().map(|(&(o, d), &v)| (o, d, v)).collect();
result.sort_by_key(|&(o, d, _)| (o, d));
result
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `SparseOdMatrix`, exercised through its inherent methods
    // and the `OdMatrix` trait.
    use super::*;

    const EPS: f64 = 1e-10;

    #[test]
    fn new_matrix_is_empty() {
        let od = SparseOdMatrix::new(vec![1, 2, 3]);
        assert_eq!(od.nnz(), 0);
        assert_eq!(od.total(), 0.0);
        assert_eq!(od.get(1, 2), 0.0);
    }

    #[test]
    fn set_and_get() {
        let mut od = SparseOdMatrix::new(vec![1, 2]);
        od.set(1, 2, 500.0);
        assert_eq!(od.get(1, 2), 500.0);
        assert_eq!(od.nnz(), 1);
    }

    #[test]
    fn set_overwrites_existing_value() {
        let mut od = SparseOdMatrix::new(vec![1, 2]);
        od.set(1, 2, 10.0);
        od.set(1, 2, 25.0);
        assert_eq!(od.get(1, 2), 25.0);
        assert_eq!(od.nnz(), 1);
    }

    #[test]
    fn set_zero_removes_entry() {
        let mut od = SparseOdMatrix::new(vec![1, 2]);
        od.set(1, 2, 100.0);
        assert_eq!(od.nnz(), 1);
        od.set(1, 2, 0.0);
        assert_eq!(od.nnz(), 0);
        assert_eq!(od.get(1, 2), 0.0);
    }

    #[test]
    fn add_accumulates() {
        let mut od = SparseOdMatrix::new(vec![1, 2]);
        od.add(1, 2, 60.0);
        od.add(1, 2, 40.0);
        assert_eq!(od.get(1, 2), 100.0);
        assert_eq!(od.nnz(), 1);
    }

    #[test]
    fn add_to_zero_removes_entry() {
        let mut od = SparseOdMatrix::new(vec![1, 2]);
        od.add(1, 2, 50.0);
        od.add(1, 2, -50.0);
        assert_eq!(od.nnz(), 0);
    }

    #[test]
    fn add_zero_to_absent_entry_stores_nothing() {
        let mut od = SparseOdMatrix::new(vec![1, 2]);
        od.add(1, 2, 0.0);
        assert_eq!(od.nnz(), 0);
        assert_eq!(od.get(1, 2), 0.0);
    }

    #[test]
    fn add_negative_to_absent_entry_starts_from_zero() {
        let mut od = SparseOdMatrix::new(vec![1, 2]);
        od.add(2, 1, -5.0);
        assert_eq!(od.get(2, 1), -5.0);
        assert_eq!(od.nnz(), 1);
    }

    #[test]
    fn row_sum() {
        let mut od = SparseOdMatrix::new(vec![1, 2, 3]);
        od.set(1, 2, 30.0);
        od.set(1, 3, 70.0);
        od.set(2, 3, 50.0);
        assert!((od.row_sum(1) - 100.0).abs() < EPS);
        assert!((od.row_sum(2) - 50.0).abs() < EPS);
        assert!((od.row_sum(3) - 0.0).abs() < EPS);
    }

    #[test]
    fn col_sum() {
        let mut od = SparseOdMatrix::new(vec![1, 2, 3]);
        od.set(1, 3, 30.0);
        od.set(2, 3, 70.0);
        assert!((od.col_sum(3) - 100.0).abs() < EPS);
        assert!((od.col_sum(1) - 0.0).abs() < EPS);
    }

    #[test]
    fn total() {
        let mut od = SparseOdMatrix::new(vec![1, 2]);
        od.set(1, 2, 100.0);
        od.set(2, 1, 200.0);
        assert_eq!(od.total(), 300.0);
    }

    #[test]
    fn zone_ids_and_count() {
        let od = SparseOdMatrix::new(vec![10, 20, 30]);
        assert_eq!(od.zone_ids(), &[10, 20, 30]);
        assert_eq!(od.zone_count(), 3);
    }

    #[test]
    fn iter_returns_only_nonzero() {
        let mut od = SparseOdMatrix::new(vec![1, 2, 3]);
        od.set(1, 2, 50.0);
        od.set(3, 1, 25.0);
        let entries = od.iter();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains(&(1, 2, 50.0)));
        assert!(entries.contains(&(3, 1, 25.0)));
    }

    #[test]
    fn iter_is_sorted_by_origin_then_destination() {
        // Insert out of order so the test fails if `iter` stops sorting.
        let mut od = SparseOdMatrix::new(vec![1, 2, 3]);
        od.set(3, 1, 1.0);
        od.set(1, 3, 2.0);
        od.set(2, 2, 3.0);
        od.set(1, 2, 4.0);
        assert_eq!(
            od.iter(),
            vec![(1, 2, 4.0), (1, 3, 2.0), (2, 2, 3.0), (3, 1, 1.0)]
        );
    }
}