use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::numeric::{Float, SparseElement};
use std::fmt::{self, Debug};
use std::ops::{Add, Div, Mul, Sub};
use crate::coo_array::CooArray;
use crate::csr_array::CsrArray;
use crate::error::{SparseError, SparseResult};
use crate::sparray::{SparseArray, SparseSum};
/// Row-based "list of lists" (LIL) sparse matrix.
///
/// Row `r` is described by two parallel vectors: `indices[r]` holds the
/// column positions of the stored entries and `data[r]` holds the matching
/// values. The constructors and `set` keep each `indices[r]` in ascending
/// column order so lookups can binary-search it.
#[derive(Clone)]
pub struct LilArray<T: SparseElement + Div<Output = T> + Float + 'static> {
    /// Stored values, one vector per row (parallel to `indices`).
    data: Vec<Vec<T>>,
    /// Column index of every stored value, one vector per row.
    indices: Vec<Vec<usize>>,
    /// Matrix dimensions as `(rows, cols)`.
    shape: (usize, usize),
}
impl<T> LilArray<T>
where
    T: SparseElement + Div<Output = T> + Float + 'static,
{
    /// Creates an empty matrix of the given `(rows, cols)` shape.
    ///
    /// Every row starts with no stored entries, so `nnz()` is zero.
    pub fn new(shape: (usize, usize)) -> Self {
        // Only the row count decides how many per-row lists are needed; the
        // column count is enforced by the bounds checks in `set`.
        let rows = shape.0;
        Self {
            data: vec![Vec::new(); rows],
            indices: vec![Vec::new(); rows],
            shape,
        }
    }

    /// Builds a matrix from per-row value and column-index lists.
    ///
    /// `data[r]` and `indices[r]` describe the stored entries of row `r` and
    /// may be given in any column order; each row is sorted by column here.
    /// If a row names the same column more than once, those values are summed
    /// (in their input order) into a single entry. Explicit zeros are kept as
    /// stored entries.
    ///
    /// # Errors
    ///
    /// * [`SparseError::InconsistentData`] if `data` or `indices` does not have
    ///   `shape.0` rows, or if a row's value and index lists differ in length.
    /// * [`SparseError::IndexOutOfBounds`] if a column index is `>= shape.1`
    ///   (reported with the row and its largest column index).
    pub fn from_lists(
        data: Vec<Vec<T>>,
        indices: Vec<Vec<usize>>,
        shape: (usize, usize),
    ) -> SparseResult<Self> {
        if data.len() != shape.0 || indices.len() != shape.0 {
            return Err(SparseError::InconsistentData {
                reason: "Number of rows in data and indices must match the shape".to_string(),
            });
        }
        // Validate every row before transforming anything.
        for (i, (row_data, row_indices)) in data.iter().zip(indices.iter()).enumerate() {
            if row_data.len() != row_indices.len() {
                return Err(SparseError::InconsistentData {
                    reason: format!("Row {i}: data and indices have different lengths"),
                });
            }
            if let Some(&max_col) = row_indices.iter().max() {
                if max_col >= shape.1 {
                    return Err(SparseError::IndexOutOfBounds {
                        index: (i, max_col),
                        shape,
                    });
                }
            }
        }

        let mut new_data = Vec::with_capacity(shape.0);
        let mut new_indices = Vec::with_capacity(shape.0);
        for (row_data, row_indices) in data.into_iter().zip(indices) {
            let mut pairs: Vec<(usize, T)> = row_indices.into_iter().zip(row_data).collect();
            // Stable sort: duplicates stay in input order, so they are summed
            // in a deterministic order below.
            pairs.sort_by_key(|&(col, _)| col);

            let mut sorted_indices = Vec::with_capacity(pairs.len());
            let mut sorted_data: Vec<T> = Vec::with_capacity(pairs.len());
            for (col, value) in pairs {
                if sorted_indices.last() == Some(&col) {
                    // Repeated column: fold into the entry already emitted so
                    // every row keeps strictly increasing column indices, which
                    // `get`/`set` rely on for binary search.
                    if let Some(acc) = sorted_data.last_mut() {
                        *acc = *acc + value;
                    }
                } else {
                    sorted_indices.push(col);
                    sorted_data.push(value);
                }
            }
            new_data.push(sorted_data);
            new_indices.push(sorted_indices);
        }
        Ok(Self {
            data: new_data,
            indices: new_indices,
            shape,
        })
    }

    /// Builds a matrix from coordinate triplets `(rows[k], cols[k], data[k])`.
    ///
    /// Triplets are applied in order through `set`. A later triplet for the
    /// same position overwrites an earlier one, and a zero value removes the
    /// entry or is never stored.
    ///
    /// # Errors
    ///
    /// * [`SparseError::InconsistentData`] if the three slices differ in length.
    /// * [`SparseError::IndexOutOfBounds`] carrying the first triplet whose row
    ///   or column lies outside `shape`.
    pub fn from_triplets(
        rows: &[usize],
        cols: &[usize],
        data: &[T],
        shape: (usize, usize),
    ) -> SparseResult<Self> {
        if rows.len() != cols.len() || rows.len() != data.len() {
            return Err(SparseError::InconsistentData {
                reason: "rows, cols, and data must have the same length".to_string(),
            });
        }
        let (num_rows, num_cols) = shape;
        // Report a coordinate that actually occurs in the input rather than
        // a synthetic `(max_row, 0)` / `(0, max_col)` pair.
        if let Some((&row, &col)) = rows
            .iter()
            .zip(cols)
            .find(|&(&r, &c)| r >= num_rows || c >= num_cols)
        {
            return Err(SparseError::IndexOutOfBounds {
                index: (row, col),
                shape,
            });
        }
        let mut lil = LilArray::new(shape);
        for ((&row, &col), &value) in rows.iter().zip(cols).zip(data) {
            lil.set(row, col, value)?;
        }
        Ok(lil)
    }

    /// Returns the per-row stored values (parallel to [`Self::get_indices`]).
    pub fn get_data(&self) -> &Vec<Vec<T>> {
        &self.data
    }

    /// Returns the per-row column indices (parallel to [`Self::get_data`]).
    pub fn get_indices(&self) -> &Vec<Vec<usize>> {
        &self.indices
    }

    /// Sorts every row's entries by ascending column index, moving each value
    /// together with its column.
    ///
    /// The sort is stable: entries sharing a column keep their relative order
    /// and are not merged.
    pub fn sort_indices(&mut self) {
        for (row_indices, row_data) in self.indices.iter_mut().zip(self.data.iter_mut()) {
            if row_data.is_empty() {
                continue;
            }
            let mut pairs: Vec<(usize, T)> = row_indices
                .iter()
                .copied()
                .zip(row_data.iter().copied())
                .collect();
            pairs.sort_by_key(|&(col, _)| col);
            row_indices.clear();
            row_data.clear();
            for (col, value) in pairs {
                row_indices.push(col);
                row_data.push(value);
            }
        }
    }
}
impl<T> SparseArray<T> for LilArray<T>
where
T: SparseElement + Div<Output = T> + Float + 'static,
{
fn shape(&self) -> (usize, usize) {
self.shape
}
fn nnz(&self) -> usize {
self.indices.iter().map(|row| row.len()).sum()
}
fn dtype(&self) -> &str {
"float" }
fn to_array(&self) -> Array2<T> {
let (rows, cols) = self.shape;
let mut result = Array2::zeros((rows, cols));
for row in 0..rows {
for (idx, &col) in self.indices[row].iter().enumerate() {
result[[row, col]] = self.data[row][idx];
}
}
result
}
fn toarray(&self) -> Array2<T> {
self.to_array()
}
fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
let nnz = self.nnz();
let mut row_indices = Vec::with_capacity(nnz);
let mut col_indices = Vec::with_capacity(nnz);
let mut values = Vec::with_capacity(nnz);
for row in 0..self.shape.0 {
for (idx, &col) in self.indices[row].iter().enumerate() {
row_indices.push(row);
col_indices.push(col);
values.push(self.data[row][idx]);
}
}
CooArray::from_triplets(&row_indices, &col_indices, &values, self.shape, false)
.map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
}
fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
let (rows, cols) = self.shape;
let nnz = self.nnz();
let mut data = Vec::with_capacity(nnz);
let mut indices = Vec::with_capacity(nnz);
let mut indptr = Vec::with_capacity(rows + 1);
indptr.push(0);
for row in 0..rows {
for (idx, &col) in self.indices[row].iter().enumerate() {
indices.push(col);
data.push(self.data[row][idx]);
}
indptr.push(indptr.last().expect("Operation failed") + self.indices[row].len());
}
CsrArray::new(
Array1::from_vec(data),
Array1::from_vec(indices),
Array1::from_vec(indptr),
self.shape,
)
.map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
}
fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
self.to_coo()?.to_csc()
}
fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
self.to_coo()?.to_dok()
}
fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
Ok(Box::new(self.clone()))
}
fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
self.to_coo()?.to_dia()
}
fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
self.to_coo()?.to_bsr()
}
fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
let self_csr = self.to_csr()?;
self_csr.add(other)
}
fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
let self_csr = self.to_csr()?;
self_csr.sub(other)
}
fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
let self_csr = self.to_csr()?;
self_csr.mul(other)
}
fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
let self_csr = self.to_csr()?;
self_csr.div(other)
}
fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
let self_csr = self.to_csr()?;
self_csr.dot(other)
}
fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
let (rows, cols) = self.shape;
if cols != other.len() {
return Err(SparseError::DimensionMismatch {
expected: cols,
found: other.len(),
});
}
let mut result = Array1::zeros(rows);
for row in 0..rows {
for (idx, &col) in self.indices[row].iter().enumerate() {
result[row] = result[row] + self.data[row][idx] * other[col];
}
}
Ok(result)
}
fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
self.to_coo()?.transpose()
}
fn copy(&self) -> Box<dyn SparseArray<T>> {
Box::new(self.clone())
}
fn get(&self, i: usize, j: usize) -> T {
if i >= self.shape.0 || j >= self.shape.1 {
return T::sparse_zero();
}
match self.indices[i].binary_search(&j) {
Ok(pos) => self.data[i][pos],
Err(_) => T::sparse_zero(),
}
}
fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
if i >= self.shape.0 || j >= self.shape.1 {
return Err(SparseError::IndexOutOfBounds {
index: (i, j),
shape: self.shape,
});
}
match self.indices[i].binary_search(&j) {
Ok(pos) => {
if SparseElement::is_zero(&value) {
self.data[i].remove(pos);
self.indices[i].remove(pos);
} else {
self.data[i][pos] = value;
}
}
Err(pos) => {
if !SparseElement::is_zero(&value) {
self.data[i].insert(pos, value);
self.indices[i].insert(pos, j);
}
}
}
Ok(())
}
fn eliminate_zeros(&mut self) {
for row in 0..self.shape.0 {
let mut new_data = Vec::new();
let mut new_indices = Vec::new();
for (idx, &value) in self.data[row].iter().enumerate() {
if !SparseElement::is_zero(&value) {
new_data.push(value);
new_indices.push(self.indices[row][idx]);
}
}
self.data[row] = new_data;
self.indices[row] = new_indices;
}
}
fn sort_indices(&mut self) {
LilArray::sort_indices(self);
}
fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
let mut sorted = self.clone();
sorted.sort_indices();
Box::new(sorted)
}
fn has_sorted_indices(&self) -> bool {
for row in 0..self.shape.0 {
if self.indices[row].len() > 1 {
for i in 1..self.indices[row].len() {
if self.indices[row][i - 1] >= self.indices[row][i] {
return false;
}
}
}
}
true
}
fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
match axis {
None => {
let mut sum = T::sparse_zero();
for row in 0..self.shape.0 {
for &val in self.data[row].iter() {
sum = sum + val;
}
}
Ok(SparseSum::Scalar(sum))
}
Some(0) => {
let (_, cols) = self.shape;
let mut result = Array1::<T>::zeros(cols);
for row in 0..self.shape.0 {
for (idx, &col) in self.indices[row].iter().enumerate() {
result[col] = result[col] + self.data[row][idx];
}
}
let mut lil = LilArray::new((1, cols));
for (col, &val) in result.iter().enumerate() {
if !SparseElement::is_zero(&val) {
lil.set(0, col, val)?;
}
}
Ok(SparseSum::SparseArray(Box::new(lil)))
}
Some(1) => {
let (rows, _) = self.shape;
let mut result = Array1::<T>::zeros(rows);
for row in 0..rows {
for &val in self.data[row].iter() {
result[row] = result[row] + val;
}
}
let mut lil = LilArray::new((rows, 1));
for (row, &val) in result.iter().enumerate() {
if !SparseElement::is_zero(&val) {
lil.set(row, 0, val)?;
}
}
Ok(SparseSum::SparseArray(Box::new(lil)))
}
_ => Err(SparseError::InvalidAxis),
}
}
fn max(&self) -> T {
if self.nnz() == 0 {
return T::sparse_zero();
}
let mut max_val = T::sparse_zero();
let mut found_value = false;
for row in 0..self.shape.0 {
for &val in self.data[row].iter() {
if !found_value {
max_val = val;
found_value = true;
} else if val > max_val {
max_val = val;
}
}
}
let zero = T::sparse_zero();
if max_val < zero && self.nnz() < self.shape.0 * self.shape.1 {
zero
} else {
max_val
}
}
fn min(&self) -> T {
if self.nnz() == 0 {
return T::sparse_zero();
}
let mut min_val = T::sparse_zero();
for row in 0..self.shape.0 {
for &val in self.data[row].iter() {
if val < min_val {
min_val = val;
}
}
}
if min_val > T::sparse_zero() && self.nnz() < self.shape.0 * self.shape.1 {
T::sparse_zero()
} else {
min_val
}
}
fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
let nnz = self.nnz();
let mut rows = Vec::with_capacity(nnz);
let mut cols = Vec::with_capacity(nnz);
let mut values = Vec::with_capacity(nnz);
for row in 0..self.shape.0 {
for (idx, &col) in self.indices[row].iter().enumerate() {
rows.push(row);
cols.push(col);
values.push(self.data[row][idx]);
}
}
(
Array1::from_vec(rows),
Array1::from_vec(cols),
Array1::from_vec(values),
)
}
fn slice(
&self,
row_range: (usize, usize),
col_range: (usize, usize),
) -> SparseResult<Box<dyn SparseArray<T>>> {
let (start_row, end_row) = row_range;
let (start_col, end_col) = col_range;
if start_row >= self.shape.0
|| end_row > self.shape.0
|| start_col >= self.shape.1
|| end_col > self.shape.1
|| start_row >= end_row
|| start_col >= end_col
{
return Err(SparseError::InvalidSliceRange);
}
let mut new_data = vec![Vec::new(); end_row - start_row];
let mut new_indices = vec![Vec::new(); end_row - start_row];
for row in start_row..end_row {
for (idx, &col) in self.indices[row].iter().enumerate() {
if col >= start_col && col < end_col {
new_data[row - start_row].push(self.data[row][idx]);
new_indices[row - start_row].push(col - start_col);
}
}
}
Ok(Box::new(LilArray {
data: new_data,
indices: new_indices,
shape: (end_row - start_row, end_col - start_col),
}))
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
impl<T> fmt::Debug for LilArray<T>
where
    T: SparseElement + Div<Output = T> + Float + 'static,
{
    /// Writes a compact summary such as `LilArray<3x3, nnz=5>`; the stored
    /// values themselves are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (rows, cols) = self.shape;
        let nnz = self.nnz();
        write!(f, "LilArray<{rows}x{cols}, nnz={nnz}>")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Shared 3x3 fixture:
    //   [1 0 2]
    //   [0 3 0]
    //   [4 5 0]
    fn sample() -> LilArray<f64> {
        let rows = [0, 0, 1, 2, 2];
        let cols = [0, 2, 1, 0, 1];
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];
        LilArray::from_triplets(&rows, &cols, &data, (3, 3)).expect("Operation failed")
    }

    // Asserts that `dense` equals the fixture returned by `sample`.
    fn assert_sample_dense(dense: &Array2<f64>) {
        let expected = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 5.0, 0.0]];
        assert_eq!(dense.shape(), &[3, 3]);
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_eq!(dense[[i, j]], value, "mismatch at ({i}, {j})");
            }
        }
    }

    #[test]
    fn test_lil_array_construction() {
        let shape = (3, 3);
        let lil = LilArray::<f64>::new(shape);
        assert_eq!(lil.shape(), (3, 3));
        assert_eq!(lil.nnz(), 0);

        let data = vec![vec![1.0, 2.0], vec![3.0], vec![4.0, 5.0]];
        let indices = vec![vec![0, 2], vec![1], vec![0, 1]];
        let lil = LilArray::from_lists(data, indices, shape).expect("Operation failed");
        assert_eq!(lil.shape(), (3, 3));
        assert_eq!(lil.nnz(), 5);
        assert_eq!(lil.get(0, 0), 1.0);
        assert_eq!(lil.get(0, 2), 2.0);
        assert_eq!(lil.get(1, 1), 3.0);
        assert_eq!(lil.get(2, 0), 4.0);
        assert_eq!(lil.get(2, 1), 5.0);
        assert_eq!(lil.get(0, 1), 0.0);
    }

    #[test]
    fn test_lil_from_lists_sorts_rows() {
        let lil = LilArray::from_lists(vec![vec![2.0, 1.0]], vec![vec![2, 0]], (1, 3))
            .expect("Operation failed");
        assert!(lil.has_sorted_indices());
        assert_eq!(lil.get_indices()[0], vec![0, 2]);
        assert_eq!(lil.get_data()[0], vec![1.0, 2.0]);
        assert_eq!(lil.get(0, 0), 1.0);
        assert_eq!(lil.get(0, 2), 2.0);
    }

    #[test]
    fn test_lil_from_lists_rejects_bad_input() {
        // Wrong number of rows.
        assert!(LilArray::<f64>::from_lists(vec![], vec![], (1, 2)).is_err());
        // Row value/index lengths differ.
        assert!(LilArray::<f64>::from_lists(vec![vec![1.0]], vec![vec![0, 1]], (1, 2)).is_err());
        // Column index out of bounds.
        assert!(LilArray::<f64>::from_lists(vec![vec![1.0]], vec![vec![5]], (1, 2)).is_err());
    }

    #[test]
    fn test_lil_from_triplets() {
        let lil = sample();
        assert_eq!(lil.shape(), (3, 3));
        assert_eq!(lil.nnz(), 5);
        assert_eq!(lil.get(0, 0), 1.0);
        assert_eq!(lil.get(0, 2), 2.0);
        assert_eq!(lil.get(1, 1), 3.0);
        assert_eq!(lil.get(2, 0), 4.0);
        assert_eq!(lil.get(2, 1), 5.0);
        assert_eq!(lil.get(0, 1), 0.0);
    }

    #[test]
    fn test_lil_from_triplets_rejects_bad_input() {
        assert!(LilArray::<f64>::from_triplets(&[0, 1], &[0], &[1.0], (3, 3)).is_err());
        assert!(LilArray::<f64>::from_triplets(&[3], &[0], &[1.0], (3, 3)).is_err());
        assert!(LilArray::<f64>::from_triplets(&[0], &[3], &[1.0], (3, 3)).is_err());
    }

    #[test]
    fn test_lil_array_to_array() {
        assert_sample_dense(&sample().to_array());
    }

    #[test]
    fn test_lil_set_get() {
        let mut lil = LilArray::<f64>::new((3, 3));
        lil.set(0, 0, 1.0).expect("Operation failed");
        lil.set(0, 2, 2.0).expect("Operation failed");
        lil.set(1, 1, 3.0).expect("Operation failed");
        lil.set(2, 0, 4.0).expect("Operation failed");
        lil.set(2, 1, 5.0).expect("Operation failed");
        assert_eq!(lil.get(0, 0), 1.0);
        assert_eq!(lil.get(0, 2), 2.0);
        assert_eq!(lil.get(1, 1), 3.0);
        assert_eq!(lil.get(2, 0), 4.0);
        assert_eq!(lil.get(2, 1), 5.0);
        assert_eq!(lil.get(0, 1), 0.0);

        // Overwrite, then remove by writing zero.
        lil.set(0, 0, 6.0).expect("Operation failed");
        assert_eq!(lil.get(0, 0), 6.0);
        lil.set(0, 0, 0.0).expect("Operation failed");
        assert_eq!(lil.get(0, 0), 0.0);
        assert_eq!(lil.nnz(), 4);

        // Out-of-range access: `set` errors, `get` reads as zero.
        assert!(lil.set(3, 0, 1.0).is_err());
        assert_eq!(lil.get(3, 0), 0.0);
    }

    #[test]
    fn test_lil_to_csr() {
        let csr = sample().to_csr().expect("Operation failed");
        assert_sample_dense(&csr.to_array());
    }

    #[test]
    fn test_lil_to_coo() {
        let coo = sample().to_coo().expect("Operation failed");
        assert_sample_dense(&coo.to_array());
    }

    #[test]
    fn test_lil_dot_vector() {
        let lil = sample();
        let vector = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = lil.dot_vector(&vector.view()).expect("Operation failed");
        assert_eq!(result.len(), 3);
        assert_relative_eq!(result[0], 7.0, epsilon = 1e-10);
        assert_relative_eq!(result[1], 6.0, epsilon = 1e-10);
        assert_relative_eq!(result[2], 14.0, epsilon = 1e-10);
    }

    #[test]
    fn test_lil_eliminate_zeros() {
        let mut lil = LilArray::<f64>::new((2, 2));
        lil.set(0, 0, 1.0).expect("Operation failed");
        // Writing zero to an empty slot stores nothing.
        lil.set(0, 1, 0.0).expect("Operation failed");
        lil.set(1, 0, 2.0).expect("Operation failed");
        lil.set(1, 1, 3.0).expect("Operation failed");
        lil.set(0, 1, 4.0).expect("Operation failed");
        lil.set(0, 1, 0.0).expect("Operation failed");
        // Force an explicit stored zero by editing the storage directly.
        lil.data[1][0] = 0.0;
        assert_eq!(lil.nnz(), 3);

        lil.eliminate_zeros();
        assert_eq!(lil.nnz(), 2);
        assert_eq!(lil.get(0, 0), 1.0);
        assert_eq!(lil.get(1, 1), 3.0);
    }

    #[test]
    fn test_lil_sort_indices() {
        let mut lil = LilArray::<f64>::new((2, 4));
        lil.set(0, 3, 1.0).expect("Operation failed");
        lil.set(0, 1, 2.0).expect("Operation failed");
        lil.set(1, 2, 3.0).expect("Operation failed");
        lil.set(1, 0, 4.0).expect("Operation failed");

        // Scramble row 0 by hand; `set` always leaves it sorted.
        assert_eq!(lil.data[0].len(), 2);
        lil.data[0].swap(0, 1);
        lil.indices[0].swap(0, 1);
        assert!(!lil.has_sorted_indices());

        lil.sort_indices();
        assert!(lil.has_sorted_indices());
        assert_eq!(lil.get(0, 1), 2.0);
        assert_eq!(lil.get(0, 3), 1.0);
        assert_eq!(lil.get(1, 0), 4.0);
        assert_eq!(lil.get(1, 2), 3.0);
        assert_eq!(lil.indices[0], vec![1, 3]);
        assert_eq!(lil.data[0], vec![2.0, 1.0]);
        assert_eq!(lil.indices[1], vec![0, 2]);
        assert_eq!(lil.data[1], vec![4.0, 3.0]);
    }

    #[test]
    fn test_lil_slice() {
        let lil = sample();
        let slice = lil.slice((1, 3), (0, 2)).expect("Operation failed");
        assert_eq!(slice.shape(), (2, 2));
        assert_eq!(slice.get(0, 1), 3.0);
        assert_eq!(slice.get(1, 0), 4.0);
        assert_eq!(slice.get(1, 1), 5.0);
        assert_eq!(slice.get(0, 0), 0.0);
    }

    #[test]
    fn test_lil_sum() {
        let lil = sample();
        match lil.sum(None).expect("Operation failed") {
            SparseSum::Scalar(total) => assert_relative_eq!(total, 15.0, epsilon = 1e-10),
            _ => panic!("expected a scalar sum"),
        }
        match lil.sum(Some(0)).expect("Operation failed") {
            SparseSum::SparseArray(col_sums) => {
                assert_eq!(col_sums.shape(), (1, 3));
                assert_relative_eq!(col_sums.get(0, 0), 5.0, epsilon = 1e-10);
                assert_relative_eq!(col_sums.get(0, 1), 8.0, epsilon = 1e-10);
                assert_relative_eq!(col_sums.get(0, 2), 2.0, epsilon = 1e-10);
            }
            _ => panic!("expected a sparse column-sum array"),
        }
        match lil.sum(Some(1)).expect("Operation failed") {
            SparseSum::SparseArray(row_sums) => {
                assert_eq!(row_sums.shape(), (3, 1));
                assert_relative_eq!(row_sums.get(0, 0), 3.0, epsilon = 1e-10);
                assert_relative_eq!(row_sums.get(1, 0), 3.0, epsilon = 1e-10);
                assert_relative_eq!(row_sums.get(2, 0), 9.0, epsilon = 1e-10);
            }
            _ => panic!("expected a sparse row-sum array"),
        }
        assert!(lil.sum(Some(2)).is_err());
    }

    #[test]
    fn test_lil_max_min_with_implicit_zeros() {
        let mut negatives = LilArray::<f64>::new((2, 2));
        assert_eq!(negatives.max(), 0.0);
        assert_eq!(negatives.min(), 0.0);
        negatives.set(0, 0, -1.0).expect("Operation failed");
        negatives.set(1, 1, -3.0).expect("Operation failed");
        // Unstored positions are zeros, which beat every negative value.
        assert_eq!(negatives.max(), 0.0);
        assert_eq!(negatives.min(), -3.0);

        let mut positives = LilArray::<f64>::new((2, 2));
        positives.set(0, 0, 1.0).expect("Operation failed");
        positives.set(1, 1, 3.0).expect("Operation failed");
        assert_eq!(positives.max(), 3.0);
        assert_eq!(positives.min(), 0.0);
    }
}