use crate::error::{SparseError, SparseResult};
use scirs2_core::numeric::{SparseElement, Zero};
use std::fmt::Debug;
/// Sparse tensor stored in Compressed Sparse Fiber (CSF) format.
///
/// Nonzeros form a tree with one level per mode. Level `l` stores coordinates
/// along mode `mode_order[l]`, and the children of every node are sorted by
/// coordinate (lookups binary-search each level).
#[derive(Debug, Clone)]
pub struct CsfTensor<T> {
    /// Size of each mode, indexed by original mode number.
    pub shape: Vec<usize>,
    /// Expected to be a permutation of `0..ndim`: tree level `l` stores mode `mode_order[l]`.
    pub mode_order: Vec<usize>,
    /// `ndim - 1` pointer arrays. The children of node `k` at level `l` occupy
    /// `fib_idx[l + 1][fib_ptr[l][k]..fib_ptr[l][k + 1]]`; each array ends with a
    /// sentinel equal to `fib_idx[l + 1].len()`.
    pub fib_ptr: Vec<Vec<usize>>,
    /// `ndim` coordinate arrays; `fib_idx[l][k]` is the coordinate of node `k` at level `l`.
    pub fib_idx: Vec<Vec<usize>>,
    /// One value per leaf node, aligned index-for-index with `fib_idx[ndim - 1]`.
    pub values: Vec<T>,
}
/// One coordinate-format nonzero used while building a [`CsfTensor`].
#[derive(Debug, Clone)]
struct CooEntry<T: Copy> {
    /// Coordinates already permuted into tree-level order: `coords[l]` is the
    /// index along mode `mode_order[l]`, so lexicographic sorting groups fibers.
    coords: Vec<usize>,
    /// The stored value.
    value: T,
}
impl<T> CsfTensor<T>
where
T: Clone + Copy + Zero + SparseElement + Debug,
{
/// Builds a CSF tensor from coordinate (COO) data.
///
/// * `indices` - one index array per mode; `indices[m][k]` is the coordinate of
///   the `k`-th nonzero along mode `m`.
/// * `values` - the nonzero values, parallel to every `indices[m]`.
/// * `shape` - size of each mode; its length is the number of modes.
/// * `mode_order` - optional permutation of `0..ndim` choosing which mode is
///   stored at each tree level (level 0 first). Defaults to `0, 1, ..., ndim - 1`.
///
/// Entries sharing the same coordinates are combined by summing their values,
/// so every coordinate appears at most once in the result.
///
/// # Errors
///
/// Returns [`SparseError::ValueError`] if `shape` is empty, if the number of
/// index arrays differs from `shape.len()`, if any index array's length differs
/// from `values.len()`, if a coordinate is out of bounds for its mode, or if
/// `mode_order` is not a permutation of `0..ndim`.
pub fn from_coo(
    indices: &[Vec<usize>],
    values: &[T],
    shape: &[usize],
    mode_order: Option<&[usize]>,
) -> SparseResult<Self> {
    let ndim = shape.len();
    // A 0-mode tensor has no tree levels: building would compute `ndim - 1` and
    // lookups would index `fib_idx[0]`, both of which panic.
    if ndim == 0 {
        return Err(SparseError::ValueError(
            "shape must have at least one mode".to_string(),
        ));
    }
    if indices.len() != ndim {
        return Err(SparseError::ValueError(format!(
            "indices length {} != ndim {}",
            indices.len(),
            ndim
        )));
    }
    let nnz = values.len();
    // Check every mode, not only the first: a short `indices[m]` would otherwise
    // panic when the entries are gathered below.
    if indices.iter().any(|idx| idx.len() != nnz) {
        return Err(SparseError::ValueError(
            "indices and values length mismatch".to_string(),
        ));
    }
    for (mode, (idx, &dim)) in indices.iter().zip(shape).enumerate() {
        if let Some(&bad) = idx.iter().find(|&&c| c >= dim) {
            return Err(SparseError::ValueError(format!(
                "index {} out of bounds for mode {} of size {}",
                bad, mode, dim
            )));
        }
    }
    let order: Vec<usize> = match mode_order {
        Some(o) => {
            if o.len() != ndim {
                return Err(SparseError::ValueError(
                    "mode_order length must match ndim".to_string(),
                ));
            }
            let mut sorted = o.to_vec();
            sorted.sort_unstable();
            for (i, &v) in sorted.iter().enumerate() {
                if v != i {
                    return Err(SparseError::ValueError(
                        "mode_order must be a permutation of 0..ndim".to_string(),
                    ));
                }
            }
            o.to_vec()
        }
        None => (0..ndim).collect(),
    };
    if nnz == 0 {
        return Ok(Self {
            shape: shape.to_vec(),
            mode_order: order,
            // With no nodes, every pointer array holds only its closing sentinel.
            fib_ptr: vec![vec![0usize]; ndim - 1],
            fib_idx: vec![Vec::new(); ndim],
            values: Vec::new(),
        });
    }
    // Permute each coordinate tuple into level order so a lexicographic sort
    // groups entries by level-0 coordinate, then level-1, and so on.
    let mut entries: Vec<CooEntry<T>> = (0..nnz)
        .map(|i| CooEntry {
            coords: order.iter().map(|&m| indices[m][i]).collect(),
            value: values[i],
        })
        .collect();
    entries.sort_by(|a, b| a.coords.cmp(&b.coords));
    // Duplicates are now adjacent. Merge them by summation, otherwise the leaf
    // level would repeat a coordinate and `get` would return an arbitrary copy.
    entries.dedup_by(|later, kept| {
        let same = later.coords == kept.coords;
        if same {
            kept.value = kept.value + later.value;
        }
        same
    });
    let mut fib_ptr: Vec<Vec<usize>> = vec![Vec::new(); ndim - 1];
    let mut fib_idx: Vec<Vec<usize>> = vec![Vec::new(); ndim];
    let mut leaf_values: Vec<T> = Vec::with_capacity(entries.len());
    Self::build_levels(
        &entries,
        &mut fib_ptr,
        &mut fib_idx,
        &mut leaf_values,
        0,
        ndim,
    );
    // Close each pointer array with a sentinel so the children of the last node
    // at level `l` span `fib_ptr[l][k]..fib_ptr[l][k + 1]` like every other node.
    for (l, ptrs) in fib_ptr.iter_mut().enumerate() {
        ptrs.push(fib_idx[l + 1].len());
    }
    Ok(Self {
        shape: shape.to_vec(),
        mode_order: order,
        fib_ptr,
        fib_idx,
        values: leaf_values,
    })
}
/// Recursively appends the subtree spanned by `entries` (sorted in level order)
/// to the CSF arrays, starting at tree level `level`.
///
/// At the last level every entry becomes a leaf: its coordinate goes into
/// `fib_idx[level]` and its value into `values`. At inner levels each run of
/// entries sharing `coords[level]` becomes one node. The node records its
/// coordinate and the offset of its first child in `fib_idx[level + 1]`, then
/// the run is recursed into.
fn build_levels(
    entries: &[CooEntry<T>],
    fib_ptr: &mut Vec<Vec<usize>>,
    fib_idx: &mut Vec<Vec<usize>>,
    values: &mut Vec<T>,
    level: usize,
    ndim: usize,
) {
    if entries.is_empty() {
        return;
    }
    if level == ndim - 1 {
        fib_idx[level].extend(entries.iter().map(|e| e.coords[level]));
        values.extend(entries.iter().map(|e| e.value));
        return;
    }
    let mut rest = entries;
    while let Some(head) = rest.first() {
        let node_coord = head.coords[level];
        // Length of the run of entries that share this level's coordinate.
        let run_len = rest
            .iter()
            .position(|e| e.coords[level] != node_coord)
            .unwrap_or(rest.len());
        let (run, tail) = rest.split_at(run_len);
        fib_idx[level].push(node_coord);
        let first_child = fib_idx[level + 1].len();
        fib_ptr[level].push(first_child);
        Self::build_levels(run, fib_ptr, fib_idx, values, level + 1, ndim);
        rest = tail;
    }
}
pub fn ndim(&self) -> usize {
self.shape.len()
}
pub fn nnz(&self) -> usize {
self.values.len()
}
/// Returns the value at `indices`, given in original mode order.
///
/// Returns `None` if `indices.len()` differs from the number of modes, if the
/// tensor has no modes, or if any index is out of bounds for its mode. In-bounds
/// positions with no stored entry yield `Some(T::sparse_zero())`.
pub fn get(&self, indices: &[usize]) -> Option<T> {
    let ndim = self.ndim();
    // A 0-mode tensor has no tree to search (`search_tree` would index `fib_idx[0]`).
    if ndim == 0 || indices.len() != ndim {
        return None;
    }
    // Positions outside `shape` are not part of the tensor, so report them as
    // absent rather than as an implicit zero.
    if indices.iter().zip(&self.shape).any(|(&i, &dim)| i >= dim) {
        return None;
    }
    let ordered: Vec<usize> = self.mode_order.iter().map(|&m| indices[m]).collect();
    self.search_tree(&ordered, 0, 0)
}
/// Looks up `ordered` (coordinates in tree-level order) by walking down from
/// `level`, starting under the parent node `fiber_idx`.
///
/// Each level's child range is binary-searched for `ordered[depth]`. A miss at
/// any level yields `Some(T::sparse_zero())`. A full match returns the leaf's
/// value, or `None` if that leaf has no slot in `values`.
fn search_tree(&self, ordered: &[usize], level: usize, fiber_idx: usize) -> Option<T> {
    let mut depth = level;
    let mut parent = fiber_idx;
    loop {
        let (lo, hi) = if depth == 0 {
            (0, self.fib_idx[0].len())
        } else {
            let ptrs = &self.fib_ptr[depth - 1];
            let fallback = self.fib_idx[depth].len();
            (
                ptrs.get(parent).copied().unwrap_or(0),
                ptrs.get(parent + 1).copied().unwrap_or(fallback),
            )
        };
        let wanted = ordered[depth];
        let hit = match self.fib_idx[depth][lo..hi].binary_search(&wanted) {
            Ok(offset) => lo + offset,
            Err(_) => return Some(T::sparse_zero()),
        };
        if depth == self.ndim() - 1 {
            return self.values.get(hit).copied();
        }
        depth += 1;
        parent = hit;
    }
}
/// Extracts the fiber along `free_mode` selected by `fixed_indices`.
///
/// `fixed_indices` supplies one coordinate for every mode except `free_mode`,
/// listed in increasing mode order. Returns the stored entries of that fiber as
/// `(coordinate along free_mode, value)` pairs in traversal order. Implicit
/// zeros are omitted.
///
/// # Errors
///
/// Returns [`SparseError::ValueError`] if `free_mode >= ndim` or if
/// `fixed_indices.len() != ndim - 1`.
pub fn fiber(
    &self,
    free_mode: usize,
    fixed_indices: &[usize],
) -> SparseResult<Vec<(usize, T)>> {
    let ndim = self.ndim();
    if free_mode >= ndim {
        return Err(SparseError::ValueError(format!(
            "free_mode {} >= ndim {}",
            free_mode, ndim
        )));
    }
    let expected = ndim - 1;
    if fixed_indices.len() != expected {
        return Err(SparseError::ValueError(format!(
            "fixed_indices length {} != ndim-1 = {}",
            fixed_indices.len(),
            expected
        )));
    }
    let mut path: Vec<usize> = Vec::with_capacity(ndim);
    let mut found = Vec::new();
    self.collect_fiber(0, 0, free_mode, fixed_indices, &mut path, &mut found);
    Ok(found)
}
/// Recursive worker for [`Self::fiber`].
///
/// Visits the children of node `fiber_idx` at `level - 1`, or the whole root
/// level when `level == 0`. Children on the free mode are always followed.
/// Children on a fixed mode are followed only when their coordinate equals the
/// `fixed_indices` slot chosen by `fixed_index_for_mode`. `coord_stack` holds the
/// current root-to-node path. At the leaf level, each surviving entry is pushed
/// onto `result` as `(free-mode coordinate, value)`.
fn collect_fiber(
    &self,
    level: usize,
    fiber_idx: usize,
    free_mode: usize,
    fixed_indices: &[usize],
    coord_stack: &mut Vec<usize>,
    result: &mut Vec<(usize, T)>,
) {
    let ndim = self.ndim();
    let (lo, hi) = if level > 0 {
        let ptrs = &self.fib_ptr[level - 1];
        let fallback = self.fib_idx[level].len();
        (
            ptrs.get(fiber_idx).copied().unwrap_or(0),
            ptrs.get(fiber_idx + 1).copied().unwrap_or(fallback),
        )
    } else {
        (0, self.fib_idx[0].len())
    };
    let level_mode = self.mode_order[level];
    let is_free_level = level_mode == free_mode;
    // On a fixed level: the only coordinate a child may have to be followed.
    let required = if is_free_level {
        None
    } else {
        let wanted = self
            .fixed_index_for_mode(level_mode, free_mode)
            .and_then(|slot| fixed_indices.get(slot).copied());
        if wanted.is_none() {
            // No usable fixed coordinate for this mode, so no child can match.
            return;
        }
        wanted
    };
    for i in lo..hi {
        let coord = match self.fib_idx[level].get(i) {
            Some(&c) => c,
            None => break,
        };
        if matches!(required, Some(want) if want != coord) {
            continue;
        }
        coord_stack.push(coord);
        if level == ndim - 1 {
            if self.check_fixed_coords(coord_stack, free_mode, fixed_indices) {
                if let Some(&val) = self.values.get(i) {
                    let free_coord = if is_free_level {
                        Some(coord)
                    } else {
                        self.find_free_coord(coord_stack, free_mode)
                    };
                    if let Some(fc) = free_coord {
                        result.push((fc, val));
                    }
                }
            }
        } else {
            self.collect_fiber(level + 1, i, free_mode, fixed_indices, coord_stack, result);
        }
        coord_stack.pop();
    }
}
/// Maps a fixed `mode` to its slot in a `fixed_indices` slice. The slot is the
/// mode's position within `0..ndim` once `free_mode` has been removed.
///
/// Returns `None` when `mode` is the free mode or is not below `ndim`.
fn fixed_index_for_mode(&self, mode: usize, free_mode: usize) -> Option<usize> {
    if mode == free_mode || mode >= self.ndim() {
        return None;
    }
    // Modes above the free mode shift down one slot because it is skipped.
    Some(if mode > free_mode { mode - 1 } else { mode })
}
/// Returns `true` when every fixed-mode coordinate on `coord_stack` (a path of
/// coordinates in tree-level order) equals the corresponding entry of
/// `fixed_indices`.
///
/// `fixed_indices` is laid out in original mode order with `free_mode` removed,
/// so each level's mode is mapped to its slot via `fixed_index_for_mode`. The
/// original code compared slots in tree-level order instead. That is wrong
/// whenever `mode_order` is not the identity, and it silently dropped valid
/// fiber entries.
fn check_fixed_coords(
    &self,
    coord_stack: &[usize],
    free_mode: usize,
    fixed_indices: &[usize],
) -> bool {
    coord_stack
        .iter()
        .zip(&self.mode_order)
        .filter(|&(_, &mode)| mode != free_mode)
        .all(|(&coord, &mode)| {
            self.fixed_index_for_mode(mode, free_mode)
                .and_then(|slot| fixed_indices.get(slot))
                == Some(&coord)
        })
}
/// Returns the coordinate on `coord_stack` (tree-level order) at the first level
/// whose mode is `free_mode`, or `None` if the path never reaches that mode.
fn find_free_coord(&self, coord_stack: &[usize], free_mode: usize) -> Option<usize> {
    coord_stack
        .iter()
        .zip(self.mode_order.iter())
        .find(|&(_, &m)| m == free_mode)
        .map(|(&c, _)| c)
}
/// Unfolds the tensor along `mode` into COO matrix triplets `(rows, cols, vals)`.
///
/// An entry's row is its coordinate along `mode`. Its column is the row-major
/// linear index of the remaining coordinates, taken in increasing mode order
/// with the highest-numbered mode varying fastest. One triplet is emitted per
/// stored entry, in CSF traversal order.
///
/// # Errors
///
/// Returns [`SparseError::ValueError`] if `mode >= ndim`, or if the number of
/// columns (the product of the other modes' sizes) does not fit in `usize`.
pub fn matricize(&self, mode: usize) -> SparseResult<(Vec<usize>, Vec<usize>, Vec<T>)> {
    let ndim = self.ndim();
    if mode >= ndim {
        return Err(SparseError::ValueError(format!(
            "mode {} >= ndim {}",
            mode, ndim
        )));
    }
    let other_modes: Vec<usize> = (0..ndim).filter(|&m| m != mode).collect();
    // Fill strides from the fastest-varying (last) mode backwards. An overflowing
    // product must be an error: a saturated stride would later make the column
    // arithmetic wrap (release) or panic (debug).
    let mut col_strides = vec![0usize; other_modes.len()];
    let mut stride = 1usize;
    for (slot, &m) in other_modes.iter().enumerate().rev() {
        col_strides[slot] = stride;
        stride = stride.checked_mul(self.shape[m]).ok_or_else(|| {
            SparseError::ValueError(format!(
                "matricize along mode {} overflows the usize column count",
                mode
            ))
        })?;
    }
    let nnz = self.nnz();
    let mut rows = Vec::with_capacity(nnz);
    let mut cols = Vec::with_capacity(nnz);
    let mut vals = Vec::with_capacity(nnz);
    let mut coord_stack: Vec<usize> = vec![0; ndim];
    self.traverse_for_matricize(
        0,
        0,
        &mut coord_stack,
        mode,
        &other_modes,
        &col_strides,
        &mut rows,
        &mut cols,
        &mut vals,
    );
    Ok((rows, cols, vals))
}
/// Depth-first walk used by [`Self::matricize`].
///
/// Visits the children of node `fiber_idx` at `level - 1`, or the whole root
/// level when `level == 0`, and records each coordinate in `coord_stack[level]`.
/// At the leaf level it maps the level-ordered path back to original mode order
/// and appends one `(row, col, value)` triplet.
// Output buffers and precomputed layout are threaded through the recursion
// rather than bundled into a struct, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
fn traverse_for_matricize(
    &self,
    level: usize,
    fiber_idx: usize,
    coord_stack: &mut Vec<usize>,
    mode: usize,
    other_modes: &[usize],
    col_strides: &[usize],
    rows: &mut Vec<usize>,
    cols: &mut Vec<usize>,
    vals: &mut Vec<T>,
) {
    let ndim = self.ndim();
    let (lo, hi) = if level > 0 {
        let ptrs = &self.fib_ptr[level - 1];
        let fallback = self.fib_idx[level].len();
        (
            ptrs.get(fiber_idx).copied().unwrap_or(0),
            ptrs.get(fiber_idx + 1).copied().unwrap_or(fallback),
        )
    } else {
        (0, self.fib_idx[0].len())
    };
    // Coordinates in original mode order; reused across leaves to avoid an
    // allocation per entry.
    let mut original: Vec<usize> = Vec::new();
    for i in lo..hi {
        let coord = match self.fib_idx[level].get(i) {
            Some(&c) => c,
            None => break,
        };
        coord_stack[level] = coord;
        if level != ndim - 1 {
            self.traverse_for_matricize(
                level + 1,
                i,
                coord_stack,
                mode,
                other_modes,
                col_strides,
                rows,
                cols,
                vals,
            );
            continue;
        }
        if let Some(&val) = self.values.get(i) {
            original.clear();
            original.resize(ndim, 0);
            for (lvl, &c) in coord_stack.iter().enumerate().take(ndim) {
                original[self.mode_order[lvl]] = c;
            }
            let row = original[mode];
            let col: usize = other_modes
                .iter()
                .enumerate()
                .map(|(slot, &m)| original[m] * col_strides[slot])
                .sum();
            rows.push(row);
            cols.push(col);
            vals.push(val);
        }
    }
}
/// Approximate number of bytes held by the tensor's element data.
///
/// Counts `len() * size_of` for the pointer arrays, coordinate arrays, values,
/// shape and mode order. `Vec` headers and unused capacity are not included.
pub fn memory_usage(&self) -> usize {
    let word = std::mem::size_of::<usize>();
    let pointer_words: usize = self.fib_ptr.iter().map(Vec::len).sum();
    let index_words: usize = self.fib_idx.iter().map(Vec::len).sum();
    let meta_words = self.shape.len() + self.mode_order.len();
    (pointer_words + index_words + meta_words) * word
        + self.values.len() * std::mem::size_of::<T>()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Point lookups on a small 3-D tensor, including an implicit zero.
    #[test]
    fn test_csf_3d_construction_and_access() {
        let indices = vec![
            vec![0, 0, 0, 1, 1],
            vec![0, 1, 2, 0, 2],
            vec![0, 1, 3, 2, 0],
        ];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![2, 3, 4];
        let csf = CsfTensor::from_coo(&indices, &values, &shape, None).expect("csf");
        assert_eq!(csf.ndim(), 3);
        assert_eq!(csf.nnz(), 5);
        assert_relative_eq!(csf.get(&[0, 0, 0]).unwrap_or(0.0), 1.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[0, 1, 1]).unwrap_or(0.0), 2.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[0, 2, 3]).unwrap_or(0.0), 3.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[1, 0, 2]).unwrap_or(0.0), 4.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[1, 2, 0]).unwrap_or(0.0), 5.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[0, 0, 1]).unwrap_or(0.0), 0.0, epsilon = 1e-12);
    }

    /// Row and column fibers of a 3x3 matrix.
    #[test]
    fn test_csf_fiber_extraction() {
        let indices = vec![vec![0, 0, 1, 2, 2], vec![0, 2, 1, 0, 2]];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![3, 3];
        let csf = CsfTensor::from_coo(&indices, &values, &shape, None).expect("csf");

        // Row 0 (mode 1 free, mode 0 fixed at 0).
        let fiber = csf.fiber(1, &[0]).expect("fiber");
        assert_eq!(fiber.len(), 2);
        let fiber_map: std::collections::HashMap<usize, f64> = fiber.into_iter().collect();
        assert_relative_eq!(*fiber_map.get(&0).unwrap_or(&0.0), 1.0, epsilon = 1e-12);
        assert_relative_eq!(*fiber_map.get(&2).unwrap_or(&0.0), 2.0, epsilon = 1e-12);

        // Column 2 (mode 0 free, mode 1 fixed at 2).
        let fiber = csf.fiber(0, &[2]).expect("fiber");
        assert_eq!(fiber.len(), 2);
        let fiber_map: std::collections::HashMap<usize, f64> = fiber.into_iter().collect();
        assert_relative_eq!(*fiber_map.get(&0).unwrap_or(&0.0), 2.0, epsilon = 1e-12);
        assert_relative_eq!(*fiber_map.get(&2).unwrap_or(&0.0), 5.0, epsilon = 1e-12);
    }

    /// Mode-0 unfolding of a 2x3x2 tensor. Columns are row-major over the
    /// remaining modes (mode 2 fastest), i.e. `col = j * 2 + k`.
    #[test]
    fn test_csf_matricize() {
        let indices = vec![vec![0, 0, 1, 1], vec![0, 1, 0, 2], vec![0, 1, 0, 1]];
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let shape = vec![2, 3, 2];
        let csf = CsfTensor::from_coo(&indices, &values, &shape, None).expect("csf");
        let (rows, cols, vals) = csf.matricize(0).expect("matricize");
        assert_eq!(rows, vec![0, 0, 1, 1]);
        assert_eq!(cols, vec![0, 3, 0, 5]);
        assert_eq!(vals.len(), 4);
        for (&got, &want) in vals.iter().zip([1.0, 2.0, 3.0, 4.0].iter()) {
            assert_relative_eq!(got, want, epsilon = 1e-12);
        }
    }

    /// A tensor with no nonzeros still answers lookups and traversals.
    #[test]
    fn test_csf_empty() {
        let indices: Vec<Vec<usize>> = vec![Vec::new(), Vec::new()];
        let values: Vec<f64> = Vec::new();
        let shape = vec![3, 4];
        let csf = CsfTensor::from_coo(&indices, &values, &shape, None).expect("csf");
        assert_eq!(csf.nnz(), 0);
        assert_eq!(csf.ndim(), 2);
        assert_relative_eq!(csf.get(&[2, 3]).unwrap_or(1.0), 0.0, epsilon = 1e-12);
        assert!(csf.fiber(1, &[0]).expect("fiber").is_empty());
        let (rows, cols, vals) = csf.matricize(0).expect("matricize");
        assert!(rows.is_empty() && cols.is_empty() && vals.is_empty());
    }

    /// Storing mode 1 first must not change lookups or fibers.
    #[test]
    fn test_csf_with_mode_order() {
        let indices = vec![vec![0, 0, 1], vec![0, 1, 0]];
        let values = vec![1.0, 2.0, 3.0];
        let shape = vec![2, 2];
        let csf = CsfTensor::from_coo(&indices, &values, &shape, Some(&[1, 0])).expect("csf");
        assert_eq!(csf.nnz(), 3);
        assert_relative_eq!(csf.get(&[0, 0]).unwrap_or(0.0), 1.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[0, 1]).unwrap_or(0.0), 2.0, epsilon = 1e-12);
        assert_relative_eq!(csf.get(&[1, 0]).unwrap_or(0.0), 3.0, epsilon = 1e-12);
        assert!(csf.get(&[0]).is_none());

        // Column 0: rows 0 and 1.
        let column = csf.fiber(0, &[0]).expect("fiber");
        let coords: Vec<usize> = column.iter().map(|&(c, _)| c).collect();
        assert_eq!(coords, vec![0, 1]);
        assert_relative_eq!(column[0].1, 1.0, epsilon = 1e-12);
        assert_relative_eq!(column[1].1, 3.0, epsilon = 1e-12);

        // Row 0: columns 0 and 1.
        let row = csf.fiber(1, &[0]).expect("fiber");
        let coords: Vec<usize> = row.iter().map(|&(c, _)| c).collect();
        assert_eq!(coords, vec![0, 1]);
        assert_relative_eq!(row[0].1, 1.0, epsilon = 1e-12);
        assert_relative_eq!(row[1].1, 2.0, epsilon = 1e-12);
    }

    #[test]
    fn test_csf_memory_usage() {
        let indices = vec![vec![0, 1], vec![0, 1]];
        let values = vec![1.0, 2.0];
        let shape = vec![2, 2];
        let csf = CsfTensor::from_coo(&indices, &values, &shape, None).expect("csf");
        assert!(csf.memory_usage() > 0);
    }

    /// Non-permutations and wrong-length mode orders are rejected.
    #[test]
    fn test_csf_invalid_mode_order() {
        let indices = vec![vec![0], vec![0]];
        let values = vec![1.0];
        let shape = vec![2, 2];
        assert!(CsfTensor::from_coo(&indices, &values, &shape, Some(&[0, 0])).is_err());
        assert!(CsfTensor::from_coo(&indices, &values, &shape, Some(&[0])).is_err());
    }
}