use crate::pattern::SparsityPattern;
use std::iter;
/// Computes the sparsity pattern of the sum `A + B`.
///
/// Lane `i` of the result is the merged union of lane `i` of `a` and lane `i` of `b`.
/// An index present in both lanes is emitted once. Lanes are assumed to be sorted in
/// ascending order, which is what the merge in `iterate_union` relies on.
///
/// # Panics
///
/// Panics if `a` and `b` differ in their major or minor dimension.
pub fn spadd_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
    assert_eq!(
        a.major_dim(),
        b.major_dim(),
        "Patterns must have identical major dimensions."
    );
    assert_eq!(
        a.minor_dim(),
        b.minor_dim(),
        "Patterns must have identical minor dimensions."
    );

    // The number of offsets (one per lane plus the leading zero) is known up front.
    let mut offsets = Vec::with_capacity(a.major_dim() + 1);
    let mut indices = Vec::new();
    offsets.push(0);

    for lane_idx in 0..a.major_dim() {
        indices.extend(iterate_union(a.lane(lane_idx), b.lane(lane_idx)));
        offsets.push(indices.len());
    }

    SparsityPattern::try_from_offsets_and_indices(a.major_dim(), a.minor_dim(), offsets, indices)
        .expect("Internal error: Pattern must be valid by definition")
}
/// Computes the sparsity pattern of the product `C = A * B` when every pattern is CSC.
///
/// The CSC pattern of a matrix is the CSR pattern of its transpose. Since
/// `C^T = B^T * A^T`, this forwards to [`spmm_csr_pattern`] with the operands swapped.
///
/// # Panics
///
/// Panics (inside [`spmm_csr_pattern`]) if `b.minor_dim() != a.major_dim()`.
pub fn spmm_csc_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
    // The reversed operand order is intentional: it applies the transpose identity above.
    let (lhs_transposed, rhs_transposed) = (b, a);
    spmm_csr_pattern(lhs_transposed, rhs_transposed)
}
/// Computes the sparsity pattern of the product `C = A * B` when every pattern is CSR
/// (lanes are rows, minor indices are columns).
///
/// Lane `i` of the result is the set of every minor index `j` that appears in lane `k`
/// of `b`, for some `k` in lane `i` of `a`. The set is stored in ascending order.
///
/// # Panics
///
/// Panics if `a.minor_dim() != b.major_dim()`.
pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
    assert_eq!(
        a.minor_dim(),
        b.major_dim(),
        "a and b must have compatible dimensions"
    );

    // The number of offsets (one per output lane plus the leading zero) is known up front.
    let mut offsets = Vec::with_capacity(a.major_dim() + 1);
    let mut indices = Vec::new();
    offsets.push(0);

    // Dense marker array over the output's minor dimension. It deduplicates indices within
    // a single output lane. After each lane, only the entries that were set get cleared,
    // so resetting costs time proportional to the output size, not to `b.minor_dim()`
    // per lane.
    let mut visited = vec![false; b.minor_dim()];

    for i in 0..a.major_dim() {
        // The entries of output lane `i` start at the current end of `indices`.
        let c_lane_i_offset = indices.len();
        for &k in a.lane(i) {
            for &j in b.lane(k) {
                if !visited[j] {
                    visited[j] = true;
                    indices.push(j);
                }
            }
        }

        // Indices were collected in discovery order; sort so the lane is ascending.
        let c_lane_i = &mut indices[c_lane_i_offset..];
        c_lane_i.sort_unstable();
        for &j in c_lane_i.iter() {
            visited[j] = false;
        }
        offsets.push(indices.len());
    }

    SparsityPattern::try_from_offsets_and_indices(a.major_dim(), b.minor_dim(), offsets, indices)
        .expect("Internal error: Invalid pattern during matrix multiplication pattern construction")
}
/// Lazily merges two ascending index slices into their sorted union.
///
/// Whenever the heads of the two slices are equal, the value is yielded once and both
/// slices advance. Once one slice is exhausted, the remainder of the other is yielded
/// unchanged.
fn iterate_union<'a>(
    mut sorted_a: &'a [usize],
    mut sorted_b: &'a [usize],
) -> impl Iterator<Item = usize> + 'a {
    iter::from_fn(move || match (sorted_a.split_first(), sorted_b.split_first()) {
        (None, None) => None,
        (Some((&head, rest)), None) => {
            sorted_a = rest;
            Some(head)
        }
        (None, Some((&head, rest))) => {
            sorted_b = rest;
            Some(head)
        }
        (Some((&head_a, rest_a)), Some((&head_b, rest_b))) => {
            // Advance whichever side holds the smaller head. On a tie both sides
            // advance, so the shared value is produced only once.
            if head_a <= head_b {
                sorted_a = rest_a;
            }
            if head_b <= head_a {
                sorted_b = rest_b;
            }
            Some(head_a.min(head_b))
        }
    })
}