use std::collections::HashMap;
use crate::opcode::{OpCode, UNUSED};
/// Symmetric sparsity pattern over `dim` variables, stored in coordinate form.
///
/// Entry `k` is the pair `(rows[k], cols[k])`. [`SparsityPattern::contains`]
/// normalises its query so that the row is the larger index, i.e. entries are
/// expected to live in the lower triangle (diagonal included).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SparsityPattern {
    /// Side length of the square pattern.
    pub dim: usize,
    /// Row index of every stored entry.
    pub rows: Vec<u32>,
    /// Column index of every stored entry, parallel to `rows`.
    pub cols: Vec<u32>,
}

impl SparsityPattern {
    /// Number of stored entries (the length of `rows`).
    #[must_use]
    pub fn nnz(&self) -> usize {
        self.rows.len()
    }

    /// `true` when no entry is stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.nnz() == 0
    }

    /// Tests whether the unordered pair `{i, j}` is present.
    ///
    /// The query is mirrored into the lower triangle first, so
    /// `contains(i, j)` and `contains(j, i)` always agree. The lookup is a
    /// linear scan over the stored entries.
    #[must_use]
    pub fn contains(&self, i: usize, j: usize) -> bool {
        let wanted = (i.max(j), i.min(j));
        self.rows
            .iter()
            .zip(&self.cols)
            .any(|(&row, &col)| (row as usize, col as usize) == wanted)
    }
}
/// Walks the operation list once and collects the pairs of inputs that
/// interact through a nonlinear operation.
///
/// Every slot `i` owns a bitset (`num_inputs` bits, packed into `u64` words)
/// of the inputs it depends on. `Input` slots set their own bit (inputs are
/// numbered in order of appearance), `Const` slots stay empty, and every other
/// slot unions the sets of its operands. Depending on [`classify_op`]:
///
/// * linear / zero-derivative ops only propagate dependencies;
/// * unary nonlinear ops record every pair (including `(x, x)`) among the
///   result's dependencies;
/// * binary nonlinear ops record every cross pair between the two operands'
///   dependencies and, unless the op is `Mul`, every pair within each operand;
/// * a `Custom` op without an entry in `custom_second_args` is handled like a
///   unary nonlinear op.
///
/// # Parameters
/// * `opcodes` / `arg_indices` - parallel per-slot opcode and operand indices.
/// * `custom_second_args` - second operand of `Custom` ops, keyed by slot index.
/// * `num_inputs` - number of inputs; becomes the `dim` of the result.
/// * `num_vars` - number of dependency sets to allocate; slots are indexed
///   into this allocation, so indexing panics if it is too small.
///
/// # Returns
/// The recorded pairs as `(row, col)` with `row >= col`, sorted and
/// de-duplicated.
pub(crate) fn detect_sparsity_impl(
    opcodes: &[OpCode],
    arg_indices: &[[u32; 2]],
    custom_second_args: &HashMap<u32, u32>,
    num_inputs: usize,
    num_vars: usize,
) -> SparsityPattern {
    // Records every pair (diagonal included) drawn from a single input set.
    fn push_all_pairs(out: &mut Vec<(u32, u32)>, bits: &[u32]) {
        for (pos, &x) in bits.iter().enumerate() {
            for &y in &bits[..=pos] {
                out.push((x.max(y), x.min(y)));
            }
        }
    }

    // Records every pair with one member from `lhs` and one from `rhs`.
    fn push_cross_pairs(out: &mut Vec<(u32, u32)>, lhs: &[u32], rhs: &[u32]) {
        for &x in lhs {
            for &y in rhs {
                out.push((x.max(y), x.min(y)));
            }
        }
    }

    let words_per_set = num_inputs.div_ceil(64);
    let mut deps: Vec<Vec<u64>> = vec![vec![0u64; words_per_set]; num_vars];
    let mut pairs: Vec<(u32, u32)> = Vec::new();
    let mut next_input = 0u32;

    for (i, &opcode) in opcodes.iter().enumerate() {
        match opcode {
            OpCode::Input => {
                let slot = next_input as usize;
                deps[i][slot / 64] |= 1u64 << (slot % 64);
                next_input += 1;
            }
            OpCode::Const => {}
            op => {
                let [a_idx, b_idx] = arg_indices[i];
                let a = a_idx as usize;
                match classify_op(op) {
                    // No second-order contribution: only forward the dependencies.
                    OpClass::Linear | OpClass::ZeroDerivative => {
                        union_into(&mut deps, i, a);
                        if is_binary_op(op) && b_idx != UNUSED {
                            union_into(&mut deps, i, b_idx as usize);
                        }
                    }
                    OpClass::UnaryNonlinear => {
                        union_into(&mut deps, i, a);
                        let bits = extract_bits(&deps[i], num_inputs);
                        push_all_pairs(&mut pairs, &bits);
                    }
                    OpClass::BinaryNonlinear => {
                        let second = if op == OpCode::Custom {
                            custom_second_args.get(&(i as u32)).map(|&v| v as usize)
                        } else {
                            Some(b_idx as usize)
                        };
                        match second {
                            Some(b) => {
                                // Snapshot both operand sets before merging them
                                // into the result slot.
                                let bits_a = extract_bits(&deps[a], num_inputs);
                                let bits_b = extract_bits(&deps[b], num_inputs);
                                union_into(&mut deps, i, a);
                                union_into(&mut deps, i, b);
                                push_cross_pairs(&mut pairs, &bits_a, &bits_b);
                                // `Mul` only couples the two operands with each
                                // other; every other binary op also couples each
                                // operand with itself.
                                if op != OpCode::Mul {
                                    push_all_pairs(&mut pairs, &bits_a);
                                    push_all_pairs(&mut pairs, &bits_b);
                                }
                            }
                            None => {
                                // Custom op with no registered second operand.
                                union_into(&mut deps, i, a);
                                let bits = extract_bits(&deps[i], num_inputs);
                                push_all_pairs(&mut pairs, &bits);
                            }
                        }
                    }
                }
            }
        }
    }

    pairs.sort_unstable();
    pairs.dedup();
    let (rows, cols): (Vec<u32>, Vec<u32>) = pairs.into_iter().unzip();
    SparsityPattern {
        dim: num_inputs,
        rows,
        cols,
    }
}
/// Coarse classification of an opcode, produced by `classify_op` and consumed
/// by `detect_sparsity_impl` to decide which input pairs an operation records.
#[derive(Debug, Clone, Copy)]
enum OpClass {
/// Dependencies are propagated; no input pairs are recorded.
Linear,
/// One operand; every pair among the result's dependencies is recorded.
UnaryNonlinear,
/// Two operands; cross pairs between the operands' dependencies are recorded
/// (plus pairs within each operand for everything except `Mul`).
BinaryNonlinear,
/// Handled exactly like `Linear` by `detect_sparsity_impl`: dependencies are
/// propagated and no input pairs are recorded.
ZeroDerivative,
}
/// Maps an opcode to the [`OpClass`] that drives sparsity detection.
///
/// # Panics
/// `Input` and `Const` have no class; passing either hits `unreachable!()`.
fn classify_op(op: OpCode) -> OpClass {
    match op {
        // Callers deal with these two before classifying.
        OpCode::Input | OpCode::Const => unreachable!(),

        OpCode::Add | OpCode::Sub | OpCode::Neg | OpCode::Fract => OpClass::Linear,

        OpCode::Abs
        | OpCode::Signum
        | OpCode::Floor
        | OpCode::Ceil
        | OpCode::Round
        | OpCode::Trunc
        | OpCode::Max
        | OpCode::Min
        | OpCode::Rem => OpClass::ZeroDerivative,

        OpCode::Mul
        | OpCode::Div
        | OpCode::Powf
        | OpCode::Atan2
        | OpCode::Hypot
        | OpCode::Custom => OpClass::BinaryNonlinear,

        // Powers and roots.
        OpCode::Recip
        | OpCode::Sqrt
        | OpCode::Cbrt
        | OpCode::Powi
        // Exponentials and logarithms.
        | OpCode::Exp
        | OpCode::Exp2
        | OpCode::ExpM1
        | OpCode::Ln
        | OpCode::Log2
        | OpCode::Log10
        | OpCode::Ln1p
        // Trigonometric functions and their inverses.
        | OpCode::Sin
        | OpCode::Cos
        | OpCode::Tan
        | OpCode::Asin
        | OpCode::Acos
        | OpCode::Atan
        // Hyperbolic functions and their inverses.
        | OpCode::Sinh
        | OpCode::Cosh
        | OpCode::Tanh
        | OpCode::Asinh
        | OpCode::Acosh
        | OpCode::Atanh => OpClass::UnaryNonlinear,
    }
}
/// Returns `true` for the opcodes that are treated as taking a second operand.
///
/// Every opcode not listed here yields `false`.
pub(crate) fn is_binary_op(op: OpCode) -> bool {
    matches!(
        op,
        // Arithmetic.
        OpCode::Add | OpCode::Sub | OpCode::Mul | OpCode::Div | OpCode::Rem
            // Two-argument functions.
            | OpCode::Powf | OpCode::Atan2 | OpCode::Hypot
            // Selection.
            | OpCode::Max | OpCode::Min
            // Custom operations.
            | OpCode::Custom
    )
}
/// ORs the bitset at `deps[src]` into the bitset at `deps[dst]`, word by word.
///
/// A call with `dst == src` is a no-op. The slice is split around the larger
/// index so that one set can be borrowed mutably and the other immutably.
///
/// # Panics
/// Panics if either index is out of bounds, or if the source set has fewer
/// words than the destination set.
pub(crate) fn union_into(deps: &mut [Vec<u64>], dst: usize, src: usize) {
    use std::cmp::Ordering;

    let (target, source): (&mut Vec<u64>, &[u64]) = match dst.cmp(&src) {
        Ordering::Equal => return,
        Ordering::Less => {
            let (head, tail) = deps.split_at_mut(src);
            (&mut head[dst], tail[0].as_slice())
        }
        Ordering::Greater => {
            let (head, tail) = deps.split_at_mut(dst);
            (&mut tail[0], head[src].as_slice())
        }
    };

    // Exactly `target.len()` words are merged; a shorter source is an error.
    let source = &source[..target.len()];
    for (into, from) in target.iter_mut().zip(source) {
        *into |= *from;
    }
}
/// Lists the positions of the set bits in `bitset`, in ascending order.
///
/// Bit `b` of word `w` has position `w * 64 + b`. Positions greater than or
/// equal to `max_bits` are dropped.
pub(crate) fn extract_bits(bitset: &[u64], max_bits: usize) -> Vec<u32> {
    bitset
        .iter()
        .enumerate()
        .flat_map(|(word_idx, &word)| {
            let mut remaining = word;
            std::iter::from_fn(move || {
                if remaining == 0 {
                    return None;
                }
                let bit = remaining.trailing_zeros() as usize;
                // Clear the lowest set bit so the next call sees the following one.
                remaining &= remaining - 1;
                Some(word_idx * 64 + bit)
            })
        })
        .filter(|&pos| pos < max_bits)
        .map(|pos| pos as u32)
        .collect()
}
/// Greedily colours the `pattern.dim` variables so that no two variables within
/// two hops of each other share a colour.
///
/// Off-diagonal entries of `pattern` are treated as undirected edges. For every
/// vertex the set of direct neighbours plus neighbours-of-neighbours (excluding
/// the vertex itself) is built, and that expanded graph is handed to
/// [`greedy_distance1_coloring`].
///
/// # Returns
/// `(colors, num_colors)` where `colors[v]` is the colour of variable `v`.
/// An empty pattern dimension yields `(vec![], 0)`.
#[must_use]
pub fn greedy_coloring(pattern: &SparsityPattern) -> (Vec<u32>, u32) {
    let n = pattern.dim;
    if n == 0 {
        return (Vec::new(), 0);
    }

    // Undirected adjacency built from the off-diagonal entries.
    let mut neighbors: Vec<Vec<u32>> = vec![Vec::new(); n];
    for (&row, &col) in pattern.rows.iter().zip(&pattern.cols) {
        if row == col {
            continue;
        }
        neighbors[row as usize].push(col);
        neighbors[col as usize].push(row);
    }

    // Everything reachable in one or two hops, per vertex.
    let within_two_hops: Vec<Vec<u32>> = (0..n)
        .map(|v| {
            let direct = &neighbors[v];
            let two_hop = direct
                .iter()
                .flat_map(|&u| neighbors[u as usize].iter().copied())
                .filter(|&w| w as usize != v);
            let mut reach: Vec<u32> = direct.iter().copied().chain(two_hop).collect();
            reach.sort_unstable();
            reach.dedup();
            reach
        })
        .collect();

    greedy_distance1_coloring(&within_two_hops, n)
}
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CsrPattern {
pub dim: usize,
pub row_ptr: Vec<u32>,
pub col_ind: Vec<u32>,
}
impl CsrPattern {
#[must_use]
pub fn nnz(&self) -> usize {
self.col_ind.len()
}
pub fn reorder_values<F: Copy>(&self, coo: &SparsityPattern, coo_vals: &[F]) -> Vec<F> {
assert_eq!(coo_vals.len(), coo.nnz());
assert_eq!(self.nnz(), coo.nnz());
let mut result = Vec::with_capacity(self.nnz());
for row in 0..self.dim {
let start = self.row_ptr[row] as usize;
let end = self.row_ptr[row + 1] as usize;
for csr_idx in start..end {
let col = self.col_ind[csr_idx];
let coo_idx = coo
.rows
.iter()
.zip(coo.cols.iter())
.position(|(&r, &c)| r == row as u32 && c == col)
.expect("CSR entry not found in COO pattern");
result.push(coo_vals[coo_idx]);
}
}
result
}
}
impl SparsityPattern {
    /// Converts the stored entries, as they are, into CSR layout.
    ///
    /// Entries are bucketed by row with a counting pass followed by a prefix
    /// sum; within a row the columns keep the order in which they appear in
    /// `rows` / `cols`.
    #[must_use]
    pub fn to_csr_lower(&self) -> CsrPattern {
        let dim = self.dim;

        // Count entries per row (shifted by one), then accumulate into offsets.
        let mut row_ptr = vec![0u32; dim + 1];
        for &row in &self.rows {
            row_ptr[row as usize + 1] += 1;
        }
        let mut running = 0u32;
        for slot in &mut row_ptr {
            running += *slot;
            *slot = running;
        }

        let mut col_ind = vec![0u32; self.nnz()];
        // Next free position of every row.
        let mut cursor: Vec<u32> = row_ptr[..dim].to_vec();
        for (k, &row) in self.rows.iter().enumerate() {
            let slot = &mut cursor[row as usize];
            col_ind[*slot as usize] = self.cols[k];
            *slot += 1;
        }

        CsrPattern {
            dim,
            row_ptr,
            col_ind,
        }
    }

    /// Expands the stored entries into a full symmetric CSR pattern.
    ///
    /// Every off-diagonal entry `(r, c)` is emitted twice, once in row `r` and
    /// once mirrored in row `c`; diagonal entries are emitted once. The column
    /// indices of each row are sorted ascending.
    #[must_use]
    pub fn to_csr(&self) -> CsrPattern {
        let dim = self.dim;

        // Count entries per row, including the mirrored copies.
        let mut row_ptr = vec![0u32; dim + 1];
        for (&r, &c) in self.rows.iter().zip(&self.cols) {
            row_ptr[r as usize + 1] += 1;
            if r != c {
                row_ptr[c as usize + 1] += 1;
            }
        }
        let mut running = 0u32;
        for slot in &mut row_ptr {
            running += *slot;
            *slot = running;
        }

        let mut col_ind = vec![0u32; row_ptr[dim] as usize];
        // Next free position of every row.
        let mut cursor: Vec<u32> = row_ptr[..dim].to_vec();
        let mut place = |row: u32, col: u32| {
            let slot = &mut cursor[row as usize];
            col_ind[*slot as usize] = col;
            *slot += 1;
        };
        for (&r, &c) in self.rows.iter().zip(&self.cols) {
            place(r, c);
            if r != c {
                place(c, r);
            }
        }

        for bounds in row_ptr.windows(2) {
            col_ind[bounds[0] as usize..bounds[1] as usize].sort_unstable();
        }

        CsrPattern {
            dim,
            row_ptr,
            col_ind,
        }
    }
}
/// Rectangular sparsity pattern (`num_outputs` rows by `num_inputs` columns)
/// stored in coordinate form.
///
/// Entry `k` is the pair `(rows[k], cols[k])`, i.e. (output index, input index).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct JacobianSparsityPattern {
    /// Number of rows of the pattern.
    pub num_outputs: usize,
    /// Number of columns of the pattern.
    pub num_inputs: usize,
    /// Output (row) index of every stored entry.
    pub rows: Vec<u32>,
    /// Input (column) index of every stored entry, parallel to `rows`.
    pub cols: Vec<u32>,
}

impl JacobianSparsityPattern {
    /// Number of stored entries (the length of `rows`).
    #[must_use]
    pub fn nnz(&self) -> usize {
        self.rows.len()
    }

    /// `true` when no entry is stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.nnz() == 0
    }

    /// Tests whether the entry `(output_idx, input_idx)` is stored.
    ///
    /// Unlike the symmetric pattern, the pair is ordered; the lookup is a
    /// linear scan over the stored entries.
    #[must_use]
    pub fn contains(&self, output_idx: usize, input_idx: usize) -> bool {
        let wanted = (output_idx, input_idx);
        self.rows
            .iter()
            .zip(&self.cols)
            .any(|(&row, &col)| (row as usize, col as usize) == wanted)
    }
}
/// Walks the operation list once and reports which inputs each output
/// depends on.
///
/// Every slot owns a bitset of `num_inputs` bits. `Input` slots set their own
/// bit (inputs are numbered in order of appearance), `Const` slots stay empty,
/// and every other slot unions the sets of its operands regardless of the kind
/// of operation. For `Custom` ops the second operand comes from
/// `custom_second_args` (keyed by slot index) when present; for the other
/// binary ops it is `arg_indices[i][1]` unless that equals `UNUSED`.
///
/// # Parameters
/// * `opcodes` / `arg_indices` - parallel per-slot opcode and operand indices.
/// * `custom_second_args` - second operand of `Custom` ops, keyed by slot index.
/// * `num_inputs` - number of inputs (columns of the result).
/// * `num_vars` - number of dependency sets to allocate; slots are indexed
///   into this allocation, so indexing panics if it is too small.
/// * `output_indices` - slot index of every output, in row order.
///
/// # Returns
/// One `(row, col)` entry per (output, input-it-depends-on) pair, grouped by
/// output row with ascending input columns inside each row.
pub(crate) fn detect_jacobian_sparsity_impl(
    opcodes: &[OpCode],
    arg_indices: &[[u32; 2]],
    custom_second_args: &HashMap<u32, u32>,
    num_inputs: usize,
    num_vars: usize,
    output_indices: &[u32],
) -> JacobianSparsityPattern {
    let words_per_set = num_inputs.div_ceil(64);
    let mut deps: Vec<Vec<u64>> = vec![vec![0u64; words_per_set]; num_vars];
    let mut next_input = 0u32;

    for (i, &opcode) in opcodes.iter().enumerate() {
        match opcode {
            OpCode::Input => {
                let slot = next_input as usize;
                deps[i][slot / 64] |= 1u64 << (slot % 64);
                next_input += 1;
            }
            OpCode::Const => {}
            op => {
                let [first, second] = arg_indices[i];
                union_into(&mut deps, i, first as usize);

                // Work out whether there is a second operand to merge in.
                let extra = if op == OpCode::Custom {
                    custom_second_args.get(&(i as u32)).copied()
                } else if is_binary_op(op) && second != UNUSED {
                    Some(second)
                } else {
                    None
                };
                if let Some(b) = extra {
                    union_into(&mut deps, i, b as usize);
                }
            }
        }
    }

    let mut rows = Vec::new();
    let mut cols = Vec::new();
    for (out_row, &out_idx) in output_indices.iter().enumerate() {
        let touched = extract_bits(&deps[out_idx as usize], num_inputs);
        rows.extend(std::iter::repeat(out_row as u32).take(touched.len()));
        cols.extend(touched);
    }

    JacobianSparsityPattern {
        num_outputs: output_indices.len(),
        num_inputs,
        rows,
        cols,
    }
}
#[must_use]
pub fn column_coloring(pattern: &JacobianSparsityPattern) -> (Vec<u32>, u32) {
intersection_graph_coloring(
&pattern.rows,
&pattern.cols,
pattern.num_outputs,
pattern.num_inputs,
)
}
#[must_use]
pub fn row_coloring(pattern: &JacobianSparsityPattern) -> (Vec<u32>, u32) {
intersection_graph_coloring(
&pattern.cols,
&pattern.rows,
pattern.num_inputs,
pattern.num_outputs,
)
}
/// Colours the items identified by `color_keys` so that two items that appear
/// in the same group never share a colour.
///
/// `group_keys[k]` and `color_keys[k]` describe one entry: the item
/// `color_keys[k]` belongs to group `group_keys[k]`. All items of a group are
/// connected pairwise, the adjacency lists are sorted and de-duplicated, and
/// the resulting graph is passed to [`greedy_distance1_coloring`].
///
/// # Parameters
/// * `group_dim` - number of groups; every group key must be below it.
/// * `color_dim` - number of items to colour; every colour key must be below it.
///
/// # Returns
/// `(colors, num_colors)`; `(vec![], 0)` when `color_dim` is zero.
fn intersection_graph_coloring(
    group_keys: &[u32],
    color_keys: &[u32],
    group_dim: usize,
    color_dim: usize,
) -> (Vec<u32>, u32) {
    if color_dim == 0 {
        return (Vec::new(), 0);
    }

    // Bucket the items by the group they occur in.
    let mut buckets: Vec<Vec<u32>> = vec![Vec::new(); group_dim];
    for (&group, &item) in group_keys.iter().zip(color_keys) {
        buckets[group as usize].push(item);
    }

    // Any two items sharing a bucket conflict with each other.
    let mut conflicts: Vec<Vec<u32>> = vec![Vec::new(); color_dim];
    for bucket in &buckets {
        for (idx, &first) in bucket.iter().enumerate() {
            for &second in &bucket[idx + 1..] {
                conflicts[first as usize].push(second);
                conflicts[second as usize].push(first);
            }
        }
    }
    conflicts.iter_mut().for_each(|list| {
        list.sort_unstable();
        list.dedup();
    });

    greedy_distance1_coloring(&conflicts, color_dim)
}
/// Greedy vertex colouring of the graph given by the adjacency lists `adj`.
///
/// Vertices `0..n` are visited in order of decreasing adjacency-list length
/// (ties keep ascending vertex order) and each one receives the smallest
/// colour not already used by a coloured neighbour.
///
/// Colours below 64 are tracked in a single `u64` mask. Only when a neighbour
/// already carries a colour of 64 or more are the neighbour colours collected
/// and sorted to find the first gap.
///
/// # Returns
/// `(colors, num_colors)` where `colors[v]` is the colour of vertex `v` and
/// `num_colors` is one more than the largest colour assigned (0 when `n == 0`).
fn greedy_distance1_coloring(adj: &[Vec<u32>], n: usize) -> (Vec<u32>, u32) {
    const UNCOLORED: u32 = u32::MAX;

    // Stable sort: highest degree first, ties in ascending vertex order.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by_key(|&v| std::cmp::Reverse(adj[v].len()));

    let mut colors = vec![UNCOLORED; n];
    let mut num_colors = 0u32;

    for v in order {
        let mut low_mask: u64 = 0;
        let mut has_high_color = false;
        for &neighbor in &adj[v] {
            match colors[neighbor as usize] {
                UNCOLORED => {}
                c if c < 64 => low_mask |= 1u64 << c,
                _ => has_high_color = true,
            }
        }

        let chosen = if has_high_color {
            let mut taken: Vec<u32> = adj[v]
                .iter()
                .map(|&neighbor| colors[neighbor as usize])
                .filter(|&c| c != UNCOLORED)
                .collect();
            taken.sort_unstable();
            taken.dedup();
            // In a sorted, unique list the first missing colour is the length
            // of the longest prefix that reads 0, 1, 2, ...
            taken
                .iter()
                .zip(0u32..)
                .take_while(|&(&used, expected)| used == expected)
                .count() as u32
        } else {
            // Lowest clear bit; this is 64 when all of 0..=63 are taken.
            (!low_mask).trailing_zeros()
        };

        colors[v] = chosen;
        num_colors = num_colors.max(chosen + 1);
    }

    (colors, num_colors)
}