use super::super::CpuRuntime;
use crate::dtype::Element;
use crate::error::{Error, Result};
use crate::tensor::Tensor;
pub(crate) use crate::runtime::common::sparse_utils::zero_tolerance;
/// How the sparsity patterns of the two operands are combined during a merge.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum MergeStrategy {
    /// Visit every coordinate stored in either operand.
    Union,
    /// Visit only coordinates stored in both operands.
    Intersection,
}

/// The arithmetic operation a merge implements.
///
/// `handle_empty_compressed` inspects this to decide the result when one operand has no
/// stored values.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum OperationSemantics {
    /// `A + B`.
    Add,
    /// `A - B`.
    Subtract,
    /// Element-wise `A * B`.
    Multiply,
    /// Element-wise `A / B`.
    Divide,
}
fn handle_empty_compressed<T: Element>(
a_nnz: usize,
b_nnz: usize,
shape: [usize; 2],
device: &<CpuRuntime as crate::runtime::Runtime>::Device,
a_ptrs: &Tensor<CpuRuntime>,
a_indices: &Tensor<CpuRuntime>,
a_values: &Tensor<CpuRuntime>,
b_ptrs: &Tensor<CpuRuntime>,
b_indices: &Tensor<CpuRuntime>,
b_values: &Tensor<CpuRuntime>,
semantics: OperationSemantics,
format_is_csr: bool,
) -> Option<Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>, Tensor<CpuRuntime>)>> {
let ptr_dim = if format_is_csr { shape[0] } else { shape[1] };
let empty_result = || {
let empty_ptrs = Tensor::from_slice(&vec![0i64; ptr_dim + 1], &[ptr_dim + 1], device);
let empty_indices = Tensor::from_slice(&Vec::<i64>::new(), &[0], device);
let empty_vals = Tensor::from_slice(&Vec::<T>::new(), &[0], device);
(empty_ptrs, empty_indices, empty_vals)
};
match (a_nnz, b_nnz, semantics) {
(0, 0, _) => Some(Ok(empty_result())),
(0, _, OperationSemantics::Add) => {
Some(Ok((b_ptrs.clone(), b_indices.clone(), b_values.clone())))
}
(0, _, OperationSemantics::Subtract) => {
let b_vals: Vec<T> = b_values.to_vec();
let negated_vals: Vec<T> = b_vals.iter().map(|&v| T::from_f64(-v.to_f64())).collect();
let out_vals = Tensor::from_slice(&negated_vals, &[negated_vals.len()], device);
Some(Ok((b_ptrs.clone(), b_indices.clone(), out_vals)))
}
(0, _, OperationSemantics::Multiply) => {
Some(Ok(empty_result()))
}
(0, _, OperationSemantics::Divide) => {
Some(Ok(empty_result()))
}
(_, 0, OperationSemantics::Add) => {
Some(Ok((a_ptrs.clone(), a_indices.clone(), a_values.clone())))
}
(_, 0, OperationSemantics::Subtract) => {
Some(Ok((a_ptrs.clone(), a_indices.clone(), a_values.clone())))
}
(_, 0, OperationSemantics::Multiply) => {
Some(Ok(empty_result()))
}
(_, 0, OperationSemantics::Divide) => {
Some(Err(Error::Internal(
"Division by zero - B matrix is empty".to_string(),
)))
}
(_, _, _) => None,
}
}
pub(crate) fn merge_csr_impl<T: Element, F, FA, FB>(
a_row_ptrs: &Tensor<CpuRuntime>,
a_col_indices: &Tensor<CpuRuntime>,
a_values: &Tensor<CpuRuntime>,
b_row_ptrs: &Tensor<CpuRuntime>,
b_col_indices: &Tensor<CpuRuntime>,
b_values: &Tensor<CpuRuntime>,
shape: [usize; 2],
strategy: MergeStrategy,
semantics: OperationSemantics,
op: F,
only_a_op: FA,
only_b_op: FB,
) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>, Tensor<CpuRuntime>)>
where
F: Fn(T, T) -> T,
FA: Fn(T) -> T,
FB: Fn(T) -> T,
{
let [nrows, _ncols] = shape;
let device = a_values.device();
if let Some(result) = handle_empty_compressed::<T>(
a_values.numel(),
b_values.numel(),
shape,
device,
a_row_ptrs,
a_col_indices,
a_values,
b_row_ptrs,
b_col_indices,
b_values,
semantics,
true, ) {
return result;
}
let a_row_ptrs_data: Vec<i64> = a_row_ptrs.to_vec();
let a_col_indices_data: Vec<i64> = a_col_indices.to_vec();
let a_values_data: Vec<T> = a_values.to_vec();
let b_row_ptrs_data: Vec<i64> = b_row_ptrs.to_vec();
let b_col_indices_data: Vec<i64> = b_col_indices.to_vec();
let b_values_data: Vec<T> = b_values.to_vec();
let mut out_row_ptrs: Vec<i64> = Vec::with_capacity(nrows + 1);
let mut out_col_indices: Vec<i64> = Vec::new();
let mut out_values: Vec<T> = Vec::new();
out_row_ptrs.push(0);
for row in 0..nrows {
let a_start = a_row_ptrs_data[row] as usize;
let a_end = a_row_ptrs_data[row + 1] as usize;
let b_start = b_row_ptrs_data[row] as usize;
let b_end = b_row_ptrs_data[row + 1] as usize;
let mut i = a_start;
let mut j = b_start;
match strategy {
MergeStrategy::Union => {
while i < a_end || j < b_end {
let a_col = if i < a_end {
a_col_indices_data[i]
} else {
i64::MAX
};
let b_col = if j < b_end {
b_col_indices_data[j]
} else {
i64::MAX
};
if a_col < b_col {
let result = only_a_op(a_values_data[i]);
if result.to_f64().abs() > zero_tolerance::<T>() {
out_col_indices.push(a_col);
out_values.push(result);
}
i += 1;
} else if a_col > b_col {
let result = only_b_op(b_values_data[j]);
if result.to_f64().abs() > zero_tolerance::<T>() {
out_col_indices.push(b_col);
out_values.push(result);
}
j += 1;
} else {
let result = op(a_values_data[i], b_values_data[j]);
if result.to_f64().abs() > zero_tolerance::<T>() {
out_col_indices.push(a_col);
out_values.push(result);
}
i += 1;
j += 1;
}
}
}
MergeStrategy::Intersection => {
while i < a_end && j < b_end {
let a_col = a_col_indices_data[i];
let b_col = b_col_indices_data[j];
if a_col < b_col {
i += 1;
} else if a_col > b_col {
j += 1;
} else {
let result = op(a_values_data[i], b_values_data[j]);
if result.to_f64().abs() > zero_tolerance::<T>() {
out_col_indices.push(a_col);
out_values.push(result);
}
i += 1;
j += 1;
}
}
}
}
out_row_ptrs.push(out_col_indices.len() as i64);
}
let result_row_ptrs = Tensor::from_slice(&out_row_ptrs, &[nrows + 1], device);
let result_col_indices = Tensor::from_slice(&out_col_indices, &[out_col_indices.len()], device);
let result_values = Tensor::from_slice(&out_values, &[out_values.len()], device);
Ok((result_row_ptrs, result_col_indices, result_values))
}
pub(crate) fn merge_csc_impl<T: Element, F, FA, FB>(
a_col_ptrs: &Tensor<CpuRuntime>,
a_row_indices: &Tensor<CpuRuntime>,
a_values: &Tensor<CpuRuntime>,
b_col_ptrs: &Tensor<CpuRuntime>,
b_row_indices: &Tensor<CpuRuntime>,
b_values: &Tensor<CpuRuntime>,
shape: [usize; 2],
strategy: MergeStrategy,
semantics: OperationSemantics,
op: F,
only_a_op: FA,
only_b_op: FB,
) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>, Tensor<CpuRuntime>)>
where
F: Fn(T, T) -> T,
FA: Fn(T) -> T,
FB: Fn(T) -> T,
{
let [_nrows, ncols] = shape;
let device = a_values.device();
if let Some(result) = handle_empty_compressed::<T>(
a_values.numel(),
b_values.numel(),
shape,
device,
a_col_ptrs,
a_row_indices,
a_values,
b_col_ptrs,
b_row_indices,
b_values,
semantics,
false, ) {
return result;
}
let a_col_ptrs_data: Vec<i64> = a_col_ptrs.to_vec();
let a_row_indices_data: Vec<i64> = a_row_indices.to_vec();
let a_values_data: Vec<T> = a_values.to_vec();
let b_col_ptrs_data: Vec<i64> = b_col_ptrs.to_vec();
let b_row_indices_data: Vec<i64> = b_row_indices.to_vec();
let b_values_data: Vec<T> = b_values.to_vec();
let mut out_col_ptrs: Vec<i64> = Vec::with_capacity(ncols + 1);
let mut out_row_indices: Vec<i64> = Vec::new();
let mut out_values: Vec<T> = Vec::new();
out_col_ptrs.push(0);
for col in 0..ncols {
let a_start = a_col_ptrs_data[col] as usize;
let a_end = a_col_ptrs_data[col + 1] as usize;
let b_start = b_col_ptrs_data[col] as usize;
let b_end = b_col_ptrs_data[col + 1] as usize;
let mut i = a_start;
let mut j = b_start;
match strategy {
MergeStrategy::Union => {
while i < a_end || j < b_end {
let a_row = if i < a_end {
a_row_indices_data[i]
} else {
i64::MAX
};
let b_row = if j < b_end {
b_row_indices_data[j]
} else {
i64::MAX
};
if a_row < b_row {
let result = only_a_op(a_values_data[i]);
if result.to_f64().abs() > zero_tolerance::<T>() {
out_row_indices.push(a_row);
out_values.push(result);
}
i += 1;
} else if a_row > b_row {
let result = only_b_op(b_values_data[j]);
if result.to_f64().abs() > zero_tolerance::<T>() {
out_row_indices.push(b_row);
out_values.push(result);
}
j += 1;
} else {
let result = op(a_values_data[i], b_values_data[j]);
if result.to_f64().abs() > zero_tolerance::<T>() {
out_row_indices.push(a_row);
out_values.push(result);
}
i += 1;
j += 1;
}
}
}
MergeStrategy::Intersection => {
while i < a_end && j < b_end {
let a_row = a_row_indices_data[i];
let b_row = b_row_indices_data[j];
if a_row < b_row {
i += 1;
} else if a_row > b_row {
j += 1;
} else {
let result = op(a_values_data[i], b_values_data[j]);
if result.to_f64().abs() > zero_tolerance::<T>() {
out_row_indices.push(a_row);
out_values.push(result);
}
i += 1;
j += 1;
}
}
}
}
out_col_ptrs.push(out_row_indices.len() as i64);
}
let result_col_ptrs = Tensor::from_slice(&out_col_ptrs, &[ncols + 1], device);
let result_row_indices = Tensor::from_slice(&out_row_indices, &[out_row_indices.len()], device);
let result_values = Tensor::from_slice(&out_values, &[out_values.len()], device);
Ok((result_col_ptrs, result_row_indices, result_values))
}
pub(crate) fn merge_coo_impl<T: Element, F, FA, FB>(
a_row_indices: &Tensor<CpuRuntime>,
a_col_indices: &Tensor<CpuRuntime>,
a_values: &Tensor<CpuRuntime>,
b_row_indices: &Tensor<CpuRuntime>,
b_col_indices: &Tensor<CpuRuntime>,
b_values: &Tensor<CpuRuntime>,
_semantics: OperationSemantics,
op: F,
only_a_op: FA,
only_b_op: FB,
) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>, Tensor<CpuRuntime>)>
where
F: Fn(T, T) -> T,
FA: Fn(T) -> T,
FB: Fn(T) -> T,
{
let device = a_values.device();
let a_nnz = a_values.numel();
let b_nnz = b_values.numel();
if a_nnz == 0 && b_nnz == 0 {
let empty_rows = Tensor::from_slice(&Vec::<i64>::new(), &[0], device);
let empty_cols = Tensor::from_slice(&Vec::<i64>::new(), &[0], device);
let empty_vals = Tensor::from_slice(&Vec::<T>::new(), &[0], device);
return Ok((empty_rows, empty_cols, empty_vals));
} else if a_nnz == 0 {
let b_vals: Vec<T> = b_values.to_vec();
let transformed_vals: Vec<T> = b_vals.iter().map(|&v| only_b_op(v)).collect();
let out_vals = Tensor::from_slice(&transformed_vals, &[transformed_vals.len()], device);
return Ok((b_row_indices.clone(), b_col_indices.clone(), out_vals));
} else if b_nnz == 0 {
let a_vals: Vec<T> = a_values.to_vec();
let transformed_vals: Vec<T> = a_vals.iter().map(|&v| only_a_op(v)).collect();
let out_vals = Tensor::from_slice(&transformed_vals, &[transformed_vals.len()], device);
return Ok((a_row_indices.clone(), a_col_indices.clone(), out_vals));
}
let a_rows: Vec<i64> = a_row_indices.to_vec();
let a_cols: Vec<i64> = a_col_indices.to_vec();
let a_vals: Vec<T> = a_values.to_vec();
let b_rows: Vec<i64> = b_row_indices.to_vec();
let b_cols: Vec<i64> = b_col_indices.to_vec();
let b_vals: Vec<T> = b_values.to_vec();
let mut triplets: Vec<(i64, i64, T, bool)> = Vec::new();
for i in 0..a_rows.len() {
triplets.push((a_rows[i], a_cols[i], a_vals[i], true));
}
for i in 0..b_rows.len() {
triplets.push((b_rows[i], b_cols[i], b_vals[i], false));
}
triplets.sort_by_key(|&(r, c, _, _)| (r, c));
let mut result_rows: Vec<i64> = Vec::new();
let mut result_cols: Vec<i64> = Vec::new();
let mut result_vals: Vec<T> = Vec::new();
if triplets.is_empty() {
let empty_rows = Tensor::from_slice(&result_rows, &[0], device);
let empty_cols = Tensor::from_slice(&result_cols, &[0], device);
let empty_vals = Tensor::from_slice(&result_vals, &[0], device);
return Ok((empty_rows, empty_cols, empty_vals));
}
let mut current_row = triplets[0].0;
let mut current_col = triplets[0].1;
let mut current_val = triplets[0].2;
let mut current_from_a = triplets[0].3;
let mut current_merged = false;
for i in 1..triplets.len() {
let (row, col, val, from_a) = triplets[i];
if row == current_row && col == current_col {
current_val = op(current_val, val);
current_merged = true; } else {
let final_val = if current_merged {
current_val
} else {
if current_from_a {
only_a_op(current_val)
} else {
only_b_op(current_val)
}
};
if final_val.to_f64().abs() > zero_tolerance::<T>() {
result_rows.push(current_row);
result_cols.push(current_col);
result_vals.push(final_val);
}
current_row = row;
current_col = col;
current_val = val;
current_from_a = from_a;
current_merged = false;
}
}
let final_val = if current_merged {
current_val
} else {
if current_from_a {
only_a_op(current_val)
} else {
only_b_op(current_val)
}
};
if final_val.to_f64().abs() > zero_tolerance::<T>() {
result_rows.push(current_row);
result_cols.push(current_col);
result_vals.push(final_val);
}
let out_rows = Tensor::from_slice(&result_rows, &[result_rows.len()], device);
let out_cols = Tensor::from_slice(&result_cols, &[result_cols.len()], device);
let out_vals = Tensor::from_slice(&result_vals, &[result_vals.len()], device);
Ok((out_rows, out_cols, out_vals))
}
/// Intersection-style element-wise merge of two COO matrices, returning `(rows, cols, values)`.
///
/// If either operand stores no values, three zero-length tensors are returned. Otherwise the
/// function proceeds in three steps:
/// 1. Each operand's entries are stably sorted by `(row, col)`.
/// 2. The two sorted lists are walked together. Every coordinate present in both yields
///    `op(a, b)`, and each stored entry takes part in at most one pairing.
/// 3. Results with `to_f64().abs()` not above `zero_tolerance::<T>()` are dropped; NaN is
///    dropped as well.
///
/// The output is ordered by `(row, col)`.
pub(crate) fn intersect_coo_impl<T: Element, F>(
    a_row_indices: &Tensor<CpuRuntime>,
    a_col_indices: &Tensor<CpuRuntime>,
    a_values: &Tensor<CpuRuntime>,
    b_row_indices: &Tensor<CpuRuntime>,
    b_col_indices: &Tensor<CpuRuntime>,
    b_values: &Tensor<CpuRuntime>,
    op: F,
) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>, Tensor<CpuRuntime>)>
where
    F: Fn(T, T) -> T,
{
    let device = a_values.device();
    if a_values.numel() == 0 || b_values.numel() == 0 {
        return Ok((
            Tensor::from_slice(&Vec::<i64>::new(), &[0], device),
            Tensor::from_slice(&Vec::<i64>::new(), &[0], device),
            Tensor::from_slice(&Vec::<T>::new(), &[0], device),
        ));
    }

    // Pairs each stored value with its coordinate and stably sorts by (row, col).
    let sorted_entries = |rows: &Tensor<CpuRuntime>,
                          cols: &Tensor<CpuRuntime>,
                          vals: &Tensor<CpuRuntime>| {
        let rows: Vec<i64> = rows.to_vec();
        let cols: Vec<i64> = cols.to_vec();
        let vals: Vec<T> = vals.to_vec();
        let mut entries: Vec<(i64, i64, T)> = rows
            .into_iter()
            .zip(cols)
            .zip(vals)
            .map(|((r, c), v)| (r, c, v))
            .collect();
        entries.sort_by_key(|&(r, c, _)| (r, c));
        entries
    };
    let lhs = sorted_entries(a_row_indices, a_col_indices, a_values);
    let rhs = sorted_entries(b_row_indices, b_col_indices, b_values);

    let tolerance = zero_tolerance::<T>();
    let mut kept_rows: Vec<i64> = Vec::new();
    let mut kept_cols: Vec<i64> = Vec::new();
    let mut kept_vals: Vec<T> = Vec::new();

    let (mut p, mut q) = (0usize, 0usize);
    while p < lhs.len() && q < rhs.len() {
        let (row, col, lhs_val) = lhs[p];
        let (rhs_row, rhs_col, rhs_val) = rhs[q];
        // Tuples compare lexicographically: row first, then column.
        match (row, col).cmp(&(rhs_row, rhs_col)) {
            std::cmp::Ordering::Less => p += 1,
            std::cmp::Ordering::Greater => q += 1,
            std::cmp::Ordering::Equal => {
                let value = op(lhs_val, rhs_val);
                if value.to_f64().abs() > tolerance {
                    kept_rows.push(row);
                    kept_cols.push(col);
                    kept_vals.push(value);
                }
                p += 1;
                q += 1;
            }
        }
    }

    let out_rows = Tensor::from_slice(&kept_rows, &[kept_rows.len()], device);
    let out_cols = Tensor::from_slice(&kept_cols, &[kept_cols.len()], device);
    let out_vals = Tensor::from_slice(&kept_vals, &[kept_vals.len()], device);
    Ok((out_rows, out_cols, out_vals))
}