use crate::types::{Precision, DimensionType, IndexType, NodeId};
use crate::error::{SolverError, Result};
use alloc::{vec::Vec, collections::BTreeMap};
use core::iter;
/// Compressed Sparse Row (CSR) matrix storage.
///
/// The stored entries of row `i` occupy positions `row_ptr[i]..row_ptr[i + 1]`
/// of the parallel `col_indices` / `values` arrays, so `row_ptr` holds one more
/// element than the matrix has rows (`from_coo` allocates `rows + 1`).
/// Lookups such as `get` and `add_diagonal` `binary_search` a row's column
/// indices, so they expect each row's columns to be sorted ascending, which is
/// the order `from_coo` produces.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CSRStorage {
/// Stored entry values, grouped row by row.
pub values: Vec<Precision>,
/// Column index of each entry in `values` (same length as `values`).
pub col_indices: Vec<IndexType>,
/// Row offsets into `col_indices` / `values`: row `i` spans `row_ptr[i]..row_ptr[i + 1]`.
pub row_ptr: Vec<IndexType>,
}
/// Compressed Sparse Column (CSC) matrix storage.
///
/// The stored entries of column `j` occupy positions `col_ptr[j]..col_ptr[j + 1]`
/// of the parallel `row_indices` / `values` arrays, so `col_ptr` holds one more
/// element than the matrix has columns (`from_coo` allocates `cols + 1`).
/// `get` and `add_diagonal` `binary_search` a column's row indices, so they
/// expect them sorted ascending, which is the order `from_coo` produces.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CSCStorage {
/// Stored entry values, grouped column by column.
pub values: Vec<Precision>,
/// Row index of each entry in `values` (same length as `values`).
pub row_indices: Vec<IndexType>,
/// Column offsets into `row_indices` / `values`: column `j` spans `col_ptr[j]..col_ptr[j + 1]`.
pub col_ptr: Vec<IndexType>,
}
/// Coordinate-list (COO / triplet) matrix storage.
///
/// Entry `k` is `(row_indices[k], col_indices[k], values[k])`. The three
/// vectors are parallel and carry no ordering guarantee: `from_triplets` keeps
/// the input order and does not merge repeated coordinates.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct COOStorage {
/// Row index of each entry.
pub row_indices: Vec<IndexType>,
/// Column index of each entry.
pub col_indices: Vec<IndexType>,
/// Value of each entry.
pub values: Vec<Precision>,
}
/// Adjacency-list view of a square sparse matrix, treated as a weighted graph.
///
/// Matrix entry `(r, c, w)` is stored as an out-edge in `out_edges[r]` with
/// `target = c`. A reverse edge with `target = r` is recorded in `in_edges[c]`;
/// these back `col_iter` and `in_neighbors`. All three vectors are sized to the
/// node count given to `from_triplets`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GraphStorage {
/// `out_edges[r]`: entries of row `r`; each edge's `target` is the column.
pub out_edges: Vec<Vec<GraphEdge>>,
/// `in_edges[c]`: reverse edges for column `c`; each edge's `target` is the source row.
pub in_edges: Vec<Vec<GraphEdge>>,
/// `degrees[r]`: running sum of `|weight|` over row `r`'s entries, as built by `from_triplets`.
pub degrees: Vec<Precision>,
}
/// One weighted edge in a [`GraphStorage`] adjacency list.
///
/// In `out_edges[r]`, `target` is the column of entry `(r, target)`. In
/// `in_edges[c]`, `target` is the source row of entry `(target, c)`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GraphEdge {
/// Node on the other end of the edge (column for out-edges, row for in-edges).
pub target: NodeId,
/// Matrix value carried by the edge.
pub weight: Precision,
}
impl CSRStorage {
    /// Builds CSR storage for a `rows` x `cols` matrix from COO entries.
    ///
    /// Entries are sorted by `(row, col)`, so each row's column indices come out
    /// ascending and unique. That is the invariant `get`, `add_diagonal` and the
    /// column iterator rely on when they `binary_search` a row. While building:
    ///
    /// * duplicate `(row, col)` pairs are summed in input order, matching how the
    ///   matrix-vector products accumulate every stored entry;
    /// * zeros (given explicitly or produced by summing duplicates) are dropped;
    /// * entries outside `rows` x `cols` are ignored, the same policy
    ///   `GraphStorage::from_triplets` applies.
    ///
    /// This implementation never returns `Err`.
    pub fn from_coo(coo: &COOStorage, rows: DimensionType, cols: DimensionType) -> Result<Self> {
        let mut entries: Vec<(usize, IndexType, Precision)> = coo
            .row_indices
            .iter()
            .zip(&coo.col_indices)
            .zip(&coo.values)
            .map(|((&r, &c), &v)| (r as usize, c, v))
            .filter(|&(r, c, v)| v != 0.0 && r < rows && (c as usize) < cols)
            .collect();
        // Stable sort: duplicates keep input order, so their summation order is fixed.
        entries.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
        // `dedup_by` passes (later, earlier-kept); fold the later value into the kept one.
        entries.dedup_by(|later, kept| {
            let same = later.0 == kept.0 && later.1 == kept.1;
            if same {
                kept.2 += later.2;
            }
            same
        });
        // Duplicates may cancel out; do not store the resulting zeros.
        entries.retain(|&(_, _, v)| v != 0.0);

        let mut row_ptr = vec![0; rows + 1];
        let mut col_indices = Vec::with_capacity(entries.len());
        let mut values = Vec::with_capacity(entries.len());
        for (r, c, v) in entries {
            row_ptr[r + 1] += 1;
            col_indices.push(c);
            values.push(v);
        }
        // Turn per-row counts into cumulative offsets.
        for r in 0..rows {
            row_ptr[r + 1] += row_ptr[r];
        }
        Ok(Self {
            values,
            col_indices,
            row_ptr,
        })
    }

    /// Converts CSC storage to CSR by round-tripping through triplets and COO.
    ///
    /// # Errors
    /// Propagates any error from the intermediate conversions.
    pub fn from_csc(csc: &CSCStorage, rows: DimensionType, cols: DimensionType) -> Result<Self> {
        let triplets = csc.to_triplets()?;
        let coo = COOStorage::from_triplets(triplets)?;
        Self::from_coo(&coo, rows, cols)
    }

    /// Returns the stored value at `(row, col)`, or `None` when the row is out of
    /// range or the entry is not stored. Assumes the row's column indices are
    /// sorted, as `from_coo` produces.
    pub fn get(&self, row: usize, col: usize) -> Option<Precision> {
        if row >= self.row_ptr.len().saturating_sub(1) {
            return None;
        }
        let start = self.row_ptr[row] as usize;
        let end = self.row_ptr[row + 1] as usize;
        self.col_indices[start..end]
            .binary_search(&(col as IndexType))
            .ok()
            .map(|pos| self.values[start + pos])
    }

    /// Iterates the `(column, value)` pairs stored in `row`. An out-of-range row
    /// yields nothing.
    pub fn row_iter(&self, row: usize) -> CSRRowIter<'_> {
        if row >= self.row_ptr.len().saturating_sub(1) {
            return CSRRowIter {
                col_indices: &[],
                values: &[],
                pos: 0,
            };
        }
        let start = self.row_ptr[row] as usize;
        let end = self.row_ptr[row + 1] as usize;
        CSRRowIter {
            col_indices: &self.col_indices[start..end],
            values: &self.values[start..end],
            pos: 0,
        }
    }

    /// Iterates the `(row, value)` pairs stored in column `col`. Each step
    /// binary-searches successive rows, so a full pass touches every row.
    pub fn col_iter(&self, col: usize) -> CSRColIter<'_> {
        CSRColIter {
            storage: self,
            col: col as IndexType,
            row: 0,
        }
    }

    /// Computes `result = A * x`, overwriting `result`.
    pub fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) {
        result.fill(0.0);
        self.multiply_vector_add(x, result);
    }

    /// Computes `result += A * x`.
    ///
    /// Rows are paired with `result` elements in order. If `result` is shorter
    /// than the row count the remaining rows are skipped; extra `result`
    /// elements are left untouched.
    ///
    /// # Panics
    /// Panics if a stored column index is not a valid index into `x`.
    pub fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) {
        for (bounds, out) in self.row_ptr.windows(2).zip(result.iter_mut()) {
            let (start, end) = (bounds[0] as usize, bounds[1] as usize);
            for (&col, &value) in self.col_indices[start..end]
                .iter()
                .zip(&self.values[start..end])
            {
                *out += value * x[col as usize];
            }
        }
    }

    /// Number of stored entries.
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Returns all stored entries as `(row, col, value)` in row-major order.
    /// This implementation never returns `Err`.
    pub fn to_triplets(&self) -> Result<Vec<(usize, usize, Precision)>> {
        let mut triplets = Vec::with_capacity(self.nnz());
        for row in 0..self.row_ptr.len().saturating_sub(1) {
            let start = self.row_ptr[row] as usize;
            let end = self.row_ptr[row + 1] as usize;
            for i in start..end {
                triplets.push((row, self.col_indices[i] as usize, self.values[i]));
            }
        }
        Ok(triplets)
    }

    /// Multiplies every stored value by `factor` in place.
    pub fn scale(&mut self, factor: Precision) {
        for value in &mut self.values {
            *value *= factor;
        }
    }

    /// Adds `alpha` to each diagonal entry that is already stored.
    ///
    /// Diagonal positions with no stored entry are NOT created, because CSR
    /// storage cannot grow a row in place. On a matrix with structurally
    /// missing diagonal entries this is therefore not `A + alpha * I`.
    pub fn add_diagonal(&mut self, alpha: Precision) {
        for row in 0..self.row_ptr.len().saturating_sub(1) {
            let start = self.row_ptr[row] as usize;
            let end = self.row_ptr[row + 1] as usize;
            if let Ok(pos) = self.col_indices[start..end].binary_search(&(row as IndexType)) {
                self.values[start + pos] += alpha;
            }
        }
    }
}
/// Iterator over the stored `(column, value)` pairs of a single CSR row.
///
/// `col_indices` and `values` are the row's slices of the parent storage and
/// are walked in lockstep, starting at `pos`.
pub struct CSRRowIter<'a> {
    col_indices: &'a [IndexType],
    values: &'a [Precision],
    pos: usize,
}

impl<'a> Iterator for CSRRowIter<'a> {
    type Item = (IndexType, Precision);

    fn next(&mut self) -> Option<Self::Item> {
        // `get` yields None once `pos` walks past the row, ending iteration.
        let column = *self.col_indices.get(self.pos)?;
        let entry = self.values[self.pos];
        self.pos += 1;
        Some((column, entry))
    }
}
/// Iterator over the `(row, value)` pairs stored in one column of a CSR matrix.
///
/// CSR has no column index, so each step binary-searches successive rows
/// (starting at `row`) for the column `col`.
pub struct CSRColIter<'a> {
    storage: &'a CSRStorage,
    col: IndexType,
    row: usize,
}

impl<'a> Iterator for CSRColIter<'a> {
    type Item = (IndexType, Precision);

    fn next(&mut self) -> Option<Self::Item> {
        let csr = self.storage;
        let row_count = csr.row_ptr.len() - 1;
        while self.row < row_count {
            let current = self.row;
            self.row += 1;
            let lo = csr.row_ptr[current] as usize;
            let hi = csr.row_ptr[current + 1] as usize;
            if let Ok(offset) = csr.col_indices[lo..hi].binary_search(&self.col) {
                return Some((current as IndexType, csr.values[lo + offset]));
            }
        }
        None
    }
}
impl CSCStorage {
    /// Builds CSC storage for a `rows` x `cols` matrix from COO entries.
    ///
    /// Entries are sorted by `(col, row)`, so each column's row indices come out
    /// ascending and unique. That is the invariant `get`, `add_diagonal` and the
    /// row iterator rely on when they `binary_search` a column. While building:
    ///
    /// * duplicate `(row, col)` pairs are summed in input order, matching how the
    ///   matrix-vector products accumulate every stored entry;
    /// * zeros (given explicitly or produced by summing duplicates) are dropped;
    /// * entries outside `rows` x `cols` are ignored, the same policy
    ///   `GraphStorage::from_triplets` applies.
    ///
    /// This implementation never returns `Err`.
    pub fn from_coo(coo: &COOStorage, rows: DimensionType, cols: DimensionType) -> Result<Self> {
        let mut entries: Vec<(IndexType, usize, Precision)> = coo
            .row_indices
            .iter()
            .zip(&coo.col_indices)
            .zip(&coo.values)
            .map(|((&r, &c), &v)| (r, c as usize, v))
            .filter(|&(r, c, v)| v != 0.0 && (r as usize) < rows && c < cols)
            .collect();
        // Stable sort: duplicates keep input order, so their summation order is fixed.
        entries.sort_by(|a, b| a.1.cmp(&b.1).then(a.0.cmp(&b.0)));
        // `dedup_by` passes (later, earlier-kept); fold the later value into the kept one.
        entries.dedup_by(|later, kept| {
            let same = later.0 == kept.0 && later.1 == kept.1;
            if same {
                kept.2 += later.2;
            }
            same
        });
        // Duplicates may cancel out; do not store the resulting zeros.
        entries.retain(|&(_, _, v)| v != 0.0);

        let mut col_ptr = vec![0; cols + 1];
        let mut row_indices = Vec::with_capacity(entries.len());
        let mut values = Vec::with_capacity(entries.len());
        for (r, c, v) in entries {
            col_ptr[c + 1] += 1;
            row_indices.push(r);
            values.push(v);
        }
        // Turn per-column counts into cumulative offsets.
        for c in 0..cols {
            col_ptr[c + 1] += col_ptr[c];
        }
        Ok(Self {
            values,
            row_indices,
            col_ptr,
        })
    }

    /// Converts CSR storage to CSC by round-tripping through triplets and COO.
    ///
    /// # Errors
    /// Propagates any error from the intermediate conversions.
    pub fn from_csr(csr: &CSRStorage, rows: DimensionType, cols: DimensionType) -> Result<Self> {
        let triplets = csr.to_triplets()?;
        let coo = COOStorage::from_triplets(triplets)?;
        Self::from_coo(&coo, rows, cols)
    }

    /// Returns the stored value at `(row, col)`, or `None` when the column is out
    /// of range or the entry is not stored. Assumes the column's row indices are
    /// sorted, as `from_coo` produces.
    pub fn get(&self, row: usize, col: usize) -> Option<Precision> {
        if col >= self.col_ptr.len().saturating_sub(1) {
            return None;
        }
        let start = self.col_ptr[col] as usize;
        let end = self.col_ptr[col + 1] as usize;
        self.row_indices[start..end]
            .binary_search(&(row as IndexType))
            .ok()
            .map(|pos| self.values[start + pos])
    }

    /// Iterates the `(column, value)` pairs stored in `row`. Each step
    /// binary-searches successive columns, so a full pass touches every column.
    pub fn row_iter(&self, row: usize) -> CSCRowIter<'_> {
        CSCRowIter {
            storage: self,
            row: row as IndexType,
            col: 0,
        }
    }

    /// Iterates the `(row, value)` pairs stored in column `col`. An out-of-range
    /// column yields nothing.
    pub fn col_iter(&self, col: usize) -> CSCColIter<'_> {
        if col >= self.col_ptr.len().saturating_sub(1) {
            return CSCColIter {
                row_indices: &[],
                values: &[],
                pos: 0,
            };
        }
        let start = self.col_ptr[col] as usize;
        let end = self.col_ptr[col + 1] as usize;
        CSCColIter {
            row_indices: &self.row_indices[start..end],
            values: &self.values[start..end],
            pos: 0,
        }
    }

    /// Computes `result = A * x`, overwriting `result`.
    pub fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) {
        result.fill(0.0);
        self.multiply_vector_add(x, result);
    }

    /// Computes `result += A * x` column by column. Columns whose `x` entry is
    /// exactly `0.0` are skipped.
    ///
    /// # Panics
    /// Panics if `x` is shorter than the column count, or if a stored row index
    /// is not a valid index into `result`.
    pub fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) {
        for col in 0..self.col_ptr.len().saturating_sub(1) {
            let x_col = x[col];
            if x_col == 0.0 {
                continue;
            }
            let start = self.col_ptr[col] as usize;
            let end = self.col_ptr[col + 1] as usize;
            for i in start..end {
                let row = self.row_indices[i] as usize;
                result[row] += self.values[i] * x_col;
            }
        }
    }

    /// Number of stored entries.
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Returns all stored entries as `(row, col, value)` in column-major order.
    /// This implementation never returns `Err`.
    pub fn to_triplets(&self) -> Result<Vec<(usize, usize, Precision)>> {
        let mut triplets = Vec::with_capacity(self.nnz());
        for col in 0..self.col_ptr.len().saturating_sub(1) {
            let start = self.col_ptr[col] as usize;
            let end = self.col_ptr[col + 1] as usize;
            for i in start..end {
                triplets.push((self.row_indices[i] as usize, col, self.values[i]));
            }
        }
        Ok(triplets)
    }

    /// Multiplies every stored value by `factor` in place.
    pub fn scale(&mut self, factor: Precision) {
        for value in &mut self.values {
            *value *= factor;
        }
    }

    /// Adds `alpha` to each diagonal entry that is already stored.
    ///
    /// Diagonal positions with no stored entry are NOT created, because CSC
    /// storage cannot grow a column in place. On a matrix with structurally
    /// missing diagonal entries this is therefore not `A + alpha * I`.
    pub fn add_diagonal(&mut self, alpha: Precision) {
        for col in 0..self.col_ptr.len().saturating_sub(1) {
            let start = self.col_ptr[col] as usize;
            let end = self.col_ptr[col + 1] as usize;
            if let Ok(pos) = self.row_indices[start..end].binary_search(&(col as IndexType)) {
                self.values[start + pos] += alpha;
            }
        }
    }
}
/// Iterator over the `(column, value)` pairs stored in one row of a CSC matrix.
///
/// CSC has no row index, so each step binary-searches successive columns
/// (starting at `col`) for the row `row`.
pub struct CSCRowIter<'a> {
    storage: &'a CSCStorage,
    row: IndexType,
    col: usize,
}

impl<'a> Iterator for CSCRowIter<'a> {
    type Item = (IndexType, Precision);

    fn next(&mut self) -> Option<Self::Item> {
        let csc = self.storage;
        let col_count = csc.col_ptr.len() - 1;
        while self.col < col_count {
            let current = self.col;
            self.col += 1;
            let lo = csc.col_ptr[current] as usize;
            let hi = csc.col_ptr[current + 1] as usize;
            if let Ok(offset) = csc.row_indices[lo..hi].binary_search(&self.row) {
                return Some((current as IndexType, csc.values[lo + offset]));
            }
        }
        None
    }
}
/// Iterator over the stored `(row, value)` pairs of a single CSC column.
///
/// `row_indices` and `values` are the column's slices of the parent storage
/// and are walked in lockstep, starting at `pos`.
pub struct CSCColIter<'a> {
    row_indices: &'a [IndexType],
    values: &'a [Precision],
    pos: usize,
}

impl<'a> Iterator for CSCColIter<'a> {
    type Item = (IndexType, Precision);

    fn next(&mut self) -> Option<Self::Item> {
        // `get` yields None once `pos` walks past the column, ending iteration.
        let row = *self.row_indices.get(self.pos)?;
        let entry = self.values[self.pos];
        self.pos += 1;
        Some((row, entry))
    }
}
impl COOStorage {
pub fn from_triplets(triplets: Vec<(usize, usize, Precision)>) -> Result<Self> {
let mut row_indices = Vec::new();
let mut col_indices = Vec::new();
let mut values = Vec::new();
for (row, col, value) in triplets {
if value != 0.0 {
row_indices.push(row as IndexType);
col_indices.push(col as IndexType);
values.push(value);
}
}
Ok(Self {
row_indices,
col_indices,
values,
})
}
pub fn is_empty(&self) -> bool {
self.values.is_empty()
}
pub fn get(&self, row: usize, col: usize) -> Option<Precision> {
for i in 0..self.values.len() {
if self.row_indices[i] as usize == row && self.col_indices[i] as usize == col {
return Some(self.values[i]);
}
}
None
}
pub fn row_iter(&self, row: usize) -> COORowIter {
COORowIter {
storage: self,
target_row: row as IndexType,
pos: 0,
}
}
pub fn col_iter(&self, col: usize) -> COOColIter {
COOColIter {
storage: self,
target_col: col as IndexType,
pos: 0,
}
}
pub fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) {
result.fill(0.0);
self.multiply_vector_add(x, result);
}
pub fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) {
for i in 0..self.values.len() {
let row = self.row_indices[i] as usize;
let col = self.col_indices[i] as usize;
result[row] += self.values[i] * x[col];
}
}
pub fn nnz(&self) -> usize {
self.values.len()
}
pub fn to_triplets(&self) -> Vec<(usize, usize, Precision)> {
self.row_indices.iter()
.zip(&self.col_indices)
.zip(&self.values)
.map(|((&r, &c), &v)| (r as usize, c as usize, v))
.collect()
}
pub fn scale(&mut self, factor: Precision) {
for value in &mut self.values {
*value *= factor;
}
}
pub fn add_diagonal(&mut self, alpha: Precision, rows: DimensionType) {
for i in 0..self.values.len() {
if self.row_indices[i] == self.col_indices[i] {
self.values[i] += alpha;
}
}
}
}
/// Iterator over the `(column, value)` pairs stored for one row of a
/// [`COOStorage`].
///
/// COO entries carry no ordering guarantee, so each call scans forward from
/// `pos` until it meets an entry whose row equals `target_row`.
pub struct COORowIter<'a> {
    storage: &'a COOStorage,
    target_row: IndexType,
    pos: usize,
}

impl<'a> Iterator for COORowIter<'a> {
    type Item = (IndexType, Precision);

    fn next(&mut self) -> Option<Self::Item> {
        let coo = self.storage;
        let wanted = self.target_row;
        let limit = coo.values.len();
        match (self.pos..limit).find(|&k| coo.row_indices[k] == wanted) {
            Some(k) => {
                self.pos = k + 1;
                Some((coo.col_indices[k], coo.values[k]))
            }
            None => {
                // Rest at the end of the entries, as stepping one at a time would.
                self.pos = self.pos.max(limit);
                None
            }
        }
    }
}
/// Iterator over the `(row, value)` pairs stored for one column of a
/// [`COOStorage`].
///
/// COO entries carry no ordering guarantee, so each call scans forward from
/// `pos` until it meets an entry whose column equals `target_col`.
pub struct COOColIter<'a> {
    storage: &'a COOStorage,
    target_col: IndexType,
    pos: usize,
}

impl<'a> Iterator for COOColIter<'a> {
    type Item = (IndexType, Precision);

    fn next(&mut self) -> Option<Self::Item> {
        let coo = self.storage;
        let wanted = self.target_col;
        let limit = coo.values.len();
        match (self.pos..limit).find(|&k| coo.col_indices[k] == wanted) {
            Some(k) => {
                self.pos = k + 1;
                Some((coo.row_indices[k], coo.values[k]))
            }
            None => {
                // Rest at the end of the entries, as stepping one at a time would.
                self.pos = self.pos.max(limit);
                None
            }
        }
    }
}
impl GraphStorage {
    /// Builds adjacency-list storage for a `nodes` x `nodes` matrix from triplets.
    ///
    /// Each nonzero `(row, col, weight)` becomes an out-edge `row -> col` plus a
    /// mirrored in-edge recorded under `col`, and `degrees[row]` accumulates
    /// `|weight|`. Self-loops (`row == col`) are mirrored too, so `col_iter(i)`
    /// and `in_neighbors(i)` see the diagonal entry `(i, i)` just as
    /// `row_iter(i)` does. Zero weights and out-of-range indices are dropped.
    ///
    /// This implementation never returns `Err`.
    pub fn from_triplets(triplets: Vec<(usize, usize, Precision)>, nodes: DimensionType) -> Result<Self> {
        let mut out_edges = vec![Vec::new(); nodes];
        let mut in_edges = vec![Vec::new(); nodes];
        let mut degrees = vec![0.0; nodes];
        for (row, col, weight) in triplets {
            if weight == 0.0 || row >= nodes || col >= nodes {
                continue;
            }
            out_edges[row].push(GraphEdge {
                target: col as NodeId,
                weight,
            });
            in_edges[col].push(GraphEdge {
                target: row as NodeId,
                weight,
            });
            degrees[row] += weight.abs();
        }
        Ok(Self {
            out_edges,
            in_edges,
            degrees,
        })
    }

    /// Returns the weight of the first out-edge `row -> col`, or `None` if there
    /// is none or `row` is out of range. This is a linear scan of `row`'s edges.
    pub fn get(&self, row: usize, col: usize) -> Option<Precision> {
        self.out_edges
            .get(row)?
            .iter()
            .find(|edge| edge.target as usize == col)
            .map(|edge| edge.weight)
    }

    /// Iterates `(column, weight)` over `row`'s out-edges. An out-of-range row
    /// yields nothing.
    pub fn row_iter(&self, row: usize) -> GraphRowIter<'_> {
        GraphRowIter {
            edges: self.out_neighbors(row),
            pos: 0,
        }
    }

    /// Iterates `(row, weight)` over `col`'s in-edges. An out-of-range column
    /// yields nothing.
    pub fn col_iter(&self, col: usize) -> GraphColIter<'_> {
        GraphColIter {
            edges: self.in_neighbors(col),
            pos: 0,
        }
    }

    /// Computes `result = A * x`, overwriting `result`.
    pub fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) {
        result.fill(0.0);
        self.multiply_vector_add(x, result);
    }

    /// Computes `result += A * x`. Edges whose target is not a valid index
    /// into `x` are skipped.
    ///
    /// # Panics
    /// Panics if `result` is shorter than the node count.
    pub fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) {
        for (row, edges) in self.out_edges.iter().enumerate() {
            for edge in edges {
                let col = edge.target as usize;
                if col < x.len() {
                    result[row] += edge.weight * x[col];
                }
            }
        }
    }

    /// Number of stored out-edges (matrix entries).
    pub fn nnz(&self) -> usize {
        self.out_edges.iter().map(Vec::len).sum()
    }

    /// Returns every out-edge as `(row, col, weight)`, ordered by row.
    /// This implementation never returns `Err`.
    pub fn to_triplets(&self) -> Result<Vec<(usize, usize, Precision)>> {
        let mut triplets = Vec::with_capacity(self.nnz());
        for (row, edges) in self.out_edges.iter().enumerate() {
            for edge in edges {
                triplets.push((row, edge.target as usize, edge.weight));
            }
        }
        Ok(triplets)
    }

    /// Multiplies every edge weight (both directions) by `factor` and scales the
    /// absolute-weight degrees by `|factor|`.
    pub fn scale(&mut self, factor: Precision) {
        for edge in self.out_edges.iter_mut().chain(self.in_edges.iter_mut()).flatten() {
            edge.weight *= factor;
        }
        for degree in &mut self.degrees {
            *degree *= factor.abs();
        }
    }

    /// Adds `alpha` to every diagonal entry, i.e. `A + alpha * I`.
    ///
    /// An existing self-loop has its weight shifted by `alpha`. A missing one is
    /// created with weight `alpha`, unless `alpha` is zero. The mirrored in-edge
    /// is set to the same weight (created if absent). `degrees[node]` is adjusted
    /// by the real change in `|weight|`, which is not simply `|alpha|`.
    pub fn add_diagonal(&mut self, alpha: Precision) {
        for node in 0..self.out_edges.len() {
            let out = &mut self.out_edges[node];
            let (old, new) = match out.iter().position(|e| e.target as usize == node) {
                Some(i) => {
                    let old = out[i].weight;
                    out[i].weight += alpha;
                    (old, out[i].weight)
                }
                None if alpha != 0.0 => {
                    out.push(GraphEdge {
                        target: node as NodeId,
                        weight: alpha,
                    });
                    (0.0, alpha)
                }
                None => continue,
            };

            // Keep the column view consistent with the row view.
            let incoming = &mut self.in_edges[node];
            match incoming.iter().position(|e| e.target as usize == node) {
                Some(i) => incoming[i].weight = new,
                None => incoming.push(GraphEdge {
                    target: node as NodeId,
                    weight: new,
                }),
            }

            self.degrees[node] += new.abs() - old.abs();
        }
    }

    /// Out-edges of `node`, or an empty slice if `node` is out of range.
    pub fn out_neighbors(&self, node: usize) -> &[GraphEdge] {
        self.out_edges.get(node).map_or(&[], Vec::as_slice)
    }

    /// In-edges of `node` (each `target` is a source node), or an empty slice if
    /// `node` is out of range.
    pub fn in_neighbors(&self, node: usize) -> &[GraphEdge] {
        self.in_edges.get(node).map_or(&[], Vec::as_slice)
    }

    /// Sum of absolute out-edge weights of `node` as tracked in `degrees`, or
    /// `0.0` if `node` is out of range.
    pub fn degree(&self, node: usize) -> Precision {
        self.degrees.get(node).copied().unwrap_or(0.0)
    }
}
pub struct GraphRowIter<'a> {
edges: &'a [GraphEdge],
pos: usize,
}
impl<'a> Iterator for GraphRowIter<'a> {
type Item = (IndexType, Precision);
fn next(&mut self) -> Option<Self::Item> {
if self.pos < self.edges.len() {
let edge = self.edges[self.pos];
self.pos += 1;
Some((edge.target, edge.weight))
} else {
None
}
}
}
pub struct GraphColIter<'a> {
edges: &'a [GraphEdge],
pos: usize,
}
impl<'a> Iterator for GraphColIter<'a> {
type Item = (IndexType, Precision);
fn next(&mut self) -> Option<Self::Item> {
if self.pos < self.edges.len() {
let edge = self.edges[self.pos];
self.pos += 1;
Some((edge.target, edge.weight))
} else {
None
}
}
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// Test helper: builds COO storage from triplets, panicking on failure.
    fn coo_of(triplets: Vec<(usize, usize, Precision)>) -> COOStorage {
        COOStorage::from_triplets(triplets).unwrap()
    }

    #[test]
    fn test_csr_creation() {
        let coo = coo_of(vec![(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0), (2, 0, 4.0), (2, 2, 5.0)]);
        let csr = CSRStorage::from_coo(&coo, 3, 3).unwrap();
        assert_eq!(csr.nnz(), 5);

        let probes = [(0, 0, Some(1.0)), (0, 2, Some(2.0)), (1, 1, Some(3.0)), (0, 1, None)];
        for &(row, col, expected) in probes.iter() {
            assert_eq!(csr.get(row, col), expected);
        }
    }

    #[test]
    fn test_csr_matrix_vector_multiply() {
        let coo = coo_of(vec![(0, 0, 2.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)]);
        let csr = CSRStorage::from_coo(&coo, 2, 2).unwrap();
        let mut product = vec![0.0; 2];
        csr.multiply_vector(&[1.0, 2.0], &mut product);
        assert_eq!(product, vec![4.0, 7.0]);
    }

    #[test]
    fn test_graph_storage() {
        let edges = vec![(0, 1, 0.5), (1, 0, 0.3), (1, 2, 0.7), (2, 1, 0.2)];
        let graph = GraphStorage::from_triplets(edges, 3).unwrap();
        assert_eq!(graph.nnz(), 4);
        assert_eq!(graph.out_neighbors(1).len(), 2);
        assert_eq!(graph.in_neighbors(1).len(), 2);
        assert!(graph.degree(1) > 0.0);
    }

    #[test]
    fn test_format_conversions() {
        let source = vec![(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)];
        let csr = CSRStorage::from_coo(&coo_of(source.clone()), 2, 3).unwrap();
        let csc = CSCStorage::from_csr(&csr, 2, 3).unwrap();
        let mut round_trip = csc.to_triplets().unwrap();
        let mut expected = source;

        // Total order on triplets so both lists can be compared after sorting.
        let order = |a: &(usize, usize, f64), b: &(usize, usize, f64)| {
            a.0.cmp(&b.0)
                .then_with(|| a.1.cmp(&b.1))
                .then_with(|| a.2.total_cmp(&b.2))
        };
        expected.sort_by(order);
        round_trip.sort_by(order);
        assert_eq!(expected, round_trip);
    }
}