use super::*;
/// Undirected coupling edge between two beta blocks.
///
/// Edges produced by `BetaCouplingGraph::build` store the smaller block index in
/// `a` and the larger one in `b`.
#[derive(Debug, Clone)]
pub(crate) struct BetaEdge {
    /// First endpoint (smaller block index when built by `BetaCouplingGraph::build`).
    pub(crate) a: usize,
    /// Second endpoint (larger block index when built by `BetaCouplingGraph::build`).
    pub(crate) b: usize,
}
/// Undirected graph over beta blocks. Two blocks are linked when they are both
/// active in the same row matrix handed to `BetaCouplingGraph::build`.
///
/// Adjacency is stored twice: as a flat edge list and in CSR form.
#[derive(Debug, Clone)]
pub(crate) struct BetaCouplingGraph {
    /// Number of nodes (beta blocks).
    pub(crate) num_blocks: usize,
    /// Sorted, deduplicated edge list; each undirected edge appears once.
    pub(crate) edges: Vec<BetaEdge>,
    /// CSR row pointer with `num_blocks + 1` entries.
    pub(crate) adj_start: Vec<usize>,
    /// CSR targets: the neighbours of node `n` are
    /// `adj_targets[adj_start[n]..adj_start[n + 1]]`.
    pub(crate) adj_targets: Vec<usize>,
}
impl BetaCouplingGraph {
pub(crate) fn build(block_offsets: &[Range<usize>], htbeta_rows: &[Array2<f64>]) -> Self {
let num_blocks = block_offsets.len();
if num_blocks == 0 {
return Self {
num_blocks: 0,
edges: Vec::new(),
adj_start: vec![0],
adj_targets: Vec::new(),
};
}
let mut edge_set = Vec::<(usize, usize)>::new();
for row in htbeta_rows {
let mut active = Vec::<usize>::new();
for (block, range) in block_offsets.iter().enumerate() {
if range
.clone()
.any(|col| (0..row.nrows()).any(|axis| row[[axis, col]] != 0.0))
{
active.push(block);
}
}
for i in 0..active.len() {
for j in (i + 1)..active.len() {
edge_set.push((active[i].min(active[j]), active[i].max(active[j])));
}
}
}
edge_set.sort_unstable();
edge_set.dedup();
let edges: Vec<_> = edge_set.iter().map(|&(a, b)| BetaEdge { a, b }).collect();
let mut degree = vec![0usize; num_blocks];
for &BetaEdge { a, b } in &edges {
degree[a] += 1;
degree[b] += 1;
}
let mut adj_start = vec![0usize; num_blocks + 1];
for block in 0..num_blocks {
adj_start[block + 1] = adj_start[block] + degree[block];
}
let mut adj_targets = vec![0usize; adj_start[num_blocks]];
let mut cursor = adj_start[..num_blocks].to_vec();
for &BetaEdge { a, b } in &edges {
adj_targets[cursor[a]] = b;
cursor[a] += 1;
adj_targets[cursor[b]] = a;
cursor[b] += 1;
}
Self {
num_blocks,
edges,
adj_start,
adj_targets,
}
}
pub(crate) fn neighbours(&self, node: usize) -> &[usize] {
&self.adj_targets[self.adj_start[node]..self.adj_start[node + 1]]
}
pub(crate) fn component_partition(&self) -> Vec<Vec<usize>> {
let mut parent: Vec<usize> = (0..self.num_blocks).collect();
let mut rank = vec![0u8; self.num_blocks];
fn find(parent: &mut [usize], mut x: usize) -> usize {
while parent[x] != x {
parent[x] = parent[parent[x]];
x = parent[x];
}
x
}
for &BetaEdge { a, b } in &self.edges {
let lhs = find(&mut parent, a);
let rhs = find(&mut parent, b);
if lhs != rhs {
if rank[lhs] < rank[rhs] {
parent[lhs] = rhs;
} else if rank[lhs] > rank[rhs] {
parent[rhs] = lhs;
} else {
parent[rhs] = lhs;
rank[lhs] += 1;
}
}
}
let mut label_map = vec![usize::MAX; self.num_blocks];
let mut parts = Vec::<Vec<usize>>::new();
for block in 0..self.num_blocks {
let root = find(&mut parent, block);
let label = if label_map[root] == usize::MAX {
label_map[root] = parts.len();
parts.push(Vec::new());
label_map[root]
} else {
label_map[root]
};
parts[label].push(block);
}
parts
}
pub(crate) fn expand_one_hop(&self, seed: &[usize]) -> Vec<usize> {
let mut expanded = seed.to_vec();
for &block in seed {
expanded.extend_from_slice(self.neighbours(block));
}
expanded.sort_unstable();
expanded.dedup();
expanded
}
}
/// Index of a beta block. Used to select a range from an `offsets` slice in
/// `BetaPenaltyOp::block`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BetaBlockId(pub usize);
/// Matrix-free access to a square `dim() × dim()` penalty matrix `S` acting on the
/// flattened coefficient vector beta.
///
/// Every implementation in this file *accumulates* into its output buffers
/// (`+=`); callers zero them first when they want a plain product.
pub trait BetaPenaltyOp: Send + Sync {
    /// Side length `k` of `S`.
    fn dim(&self) -> usize;
    /// Accumulates `S · x` into `y`.
    fn matvec(&self, x: &[f64], y: &mut [f64]);
    /// Accumulates `S · beta` into `out`. In every implementation in this file this is
    /// the same product as [`matvec`](Self::matvec).
    fn gradient(&self, beta: &[f64], out: &mut [f64]);
    /// Adds the diagonal of `S` into `diag`.
    fn diagonal(&self, diag: &mut [f64]);
    /// Adds the square sub-block of `S` covering the index range `offsets[id.0]` into
    /// `out`, in local coordinates (local 0 is `offsets[id.0].start`).
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>);
    /// Materialises `S` as a dense matrix.
    fn to_dense(&self) -> Array2<f64>;
    /// Per-row sums of absolute values, `Σ_c |S[r, c]|`.
    ///
    /// The default builds [`to_dense`](Self::to_dense) and sums each row left to right.
    fn row_abs_sums(&self) -> Array1<f64> {
        let dense = self.to_dense();
        let width = dense.ncols();
        Array1::from_shape_fn(dense.nrows(), |r| {
            (0..width).fold(0.0_f64, |total, c| total + dense[[r, c]].abs())
        })
    }
    /// Feeds a description of the operator into `hasher`.
    fn fingerprint(&self, hasher: &mut Fingerprinter);
}
/// Penalty operator stored as an explicit dense matrix. Its row count is the
/// operator dimension.
pub struct DensePenaltyOp(pub Array2<f64>);
impl BetaPenaltyOp for DensePenaltyOp {
    fn dim(&self) -> usize {
        self.0.nrows()
    }
    /// `y += S · x`, each row summed left to right.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        let mat = &self.0;
        let n = mat.nrows();
        for row in 0..n {
            let dot = (0..n).fold(0.0_f64, |acc, col| acc + mat[[row, col]] * x[col]);
            y[row] += dot;
        }
    }
    /// Same product as `matvec`, applied to `beta`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }
    /// Adds `S[j, j]` for every `j < min(k, diag.len())`.
    fn diagonal(&self, diag: &mut [f64]) {
        let n = self.0.nrows().min(diag.len());
        for (j, slot) in diag[..n].iter_mut().enumerate() {
            *slot += self.0[[j, j]];
        }
    }
    /// Adds the diagonal sub-block for `offsets[id.0]`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let span = &offsets[id.0];
        let start = span.start;
        let width = span.end - start;
        for bi in 0..width {
            for bj in 0..width {
                out[[bi, bj]] += self.0[[start + bi, start + bj]];
            }
        }
    }
    fn to_dense(&self) -> Array2<f64> {
        self.0.clone()
    }
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("dense-penalty-op-v1");
        hasher.write_f64_array2(&self.0);
    }
}
/// Block-diagonal penalty. Each `(offset, local)` entry places `local` on the diagonal
/// starting at global index `offset`. Entries that overlap are added together.
pub struct BlockPenaltyOp {
    /// Global dimension of the operator.
    pub k: usize,
    /// `(offset, local)` pairs. Only `local.nrows()` is consulted for the size, so
    /// `local` is assumed square.
    pub blocks: Vec<(usize, Array2<f64>)>,
}
impl BetaPenaltyOp for BlockPenaltyOp {
    fn dim(&self) -> usize {
        self.k
    }
    /// `y += S · x`, one diagonal block at a time.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        for (base, local) in self.blocks.iter().map(|(off, m)| (*off, m)) {
            let size = local.nrows();
            for i in 0..size {
                let dot = (0..size).fold(0.0_f64, |acc, j| acc + local[[i, j]] * x[base + j]);
                y[base + i] += dot;
            }
        }
    }
    /// Same product as `matvec`, applied to `beta`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }
    fn diagonal(&self, diag: &mut [f64]) {
        for (base, local) in self.blocks.iter().map(|(off, m)| (*off, m)) {
            for j in 0..local.nrows() {
                diag[base + j] += local[[j, j]];
            }
        }
    }
    /// Adds the part of every stored block that overlaps `offsets[id.0]` (in both
    /// rows and columns) into `out`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let target = &offsets[id.0];
        let width = target.end - target.start;
        for (base, local) in self.blocks.iter().map(|(off, m)| (*off, m)) {
            let covered = base..base + local.nrows();
            // Stored block and requested range are disjoint: nothing to add.
            if covered.end <= target.start || covered.start >= target.end {
                continue;
            }
            for bi in 0..width {
                let gi = target.start + bi;
                if !covered.contains(&gi) {
                    continue;
                }
                for bj in 0..width {
                    let gj = target.start + bj;
                    if covered.contains(&gj) {
                        out[[bi, bj]] += local[[gi - base, gj - base]];
                    }
                }
            }
        }
    }
    fn to_dense(&self) -> Array2<f64> {
        let mut dense = Array2::<f64>::zeros((self.k, self.k));
        for (base, local) in self.blocks.iter().map(|(off, m)| (*off, m)) {
            let size = local.nrows();
            for i in 0..size {
                for j in 0..size {
                    dense[[base + i, base + j]] += local[[i, j]];
                }
            }
        }
        dense
    }
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("block-penalty-op-v1");
        hasher.write_usize(self.k);
        hasher.write_usize(self.blocks.len());
        for (base, local) in &self.blocks {
            hasher.write_usize(*base);
            hasher.write_f64_array2(local);
        }
    }
}
/// Penalty `A ⊗ B` placed on the contiguous index range
/// `global_offset .. global_offset + p_a·p_b`.
///
/// Global index `global_offset + i_a·p_b + i_b` corresponds to the pair `(i_a, i_b)`,
/// where `p_a = factor_a.nrows()` and `p_b = factor_b.nrows()`.
pub struct KroneckerPenaltyOp {
    /// Left (outer) Kronecker factor `A`.
    pub factor_a: Array2<f64>,
    /// Right (inner) Kronecker factor `B`.
    pub factor_b: Array2<f64>,
    /// First global index occupied by the product.
    pub global_offset: usize,
    /// Global dimension of the operator.
    pub k: usize,
}
impl BetaPenaltyOp for KroneckerPenaltyOp {
    fn dim(&self) -> usize {
        self.k
    }
    /// `y += (A ⊗ B) · x` without forming the product.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        let (pa, pb) = (self.factor_a.nrows(), self.factor_b.nrows());
        let base = self.global_offset;
        for ia in 0..pa {
            for ib in 0..pb {
                let row = base + ia * pb + ib;
                let mut sum = 0.0_f64;
                for ja in 0..pa {
                    let a = self.factor_a[[ia, ja]];
                    // A zero in A removes a whole inner row of B from the sum.
                    if a == 0.0 {
                        continue;
                    }
                    let col0 = base + ja * pb;
                    for jb in 0..pb {
                        sum += a * self.factor_b[[ib, jb]] * x[col0 + jb];
                    }
                }
                y[row] += sum;
            }
        }
    }
    /// Same product as `matvec`, applied to `beta`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }
    /// Adds `A[i_a, i_a] · B[i_b, i_b]` at each product index.
    fn diagonal(&self, diag: &mut [f64]) {
        let (pa, pb) = (self.factor_a.nrows(), self.factor_b.nrows());
        let base = self.global_offset;
        for ia in 0..pa {
            for ib in 0..pb {
                diag[base + ia * pb + ib] += self.factor_a[[ia, ia]] * self.factor_b[[ib, ib]];
            }
        }
    }
    /// Adds the part of `A ⊗ B` that overlaps `offsets[id.0]` into `out`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let target = &offsets[id.0];
        let width = target.end - target.start;
        let (pa, pb) = (self.factor_a.nrows(), self.factor_b.nrows());
        let covered = self.global_offset..self.global_offset + pa * pb;
        if covered.end <= target.start || covered.start >= target.end {
            return;
        }
        for bi in 0..width {
            let gi = target.start + bi;
            if !covered.contains(&gi) {
                continue;
            }
            let li = gi - covered.start;
            let (ia, ib) = (li / pb, li % pb);
            for bj in 0..width {
                let gj = target.start + bj;
                if !covered.contains(&gj) {
                    continue;
                }
                let lj = gj - covered.start;
                let (ja, jb) = (lj / pb, lj % pb);
                out[[bi, bj]] += self.factor_a[[ia, ja]] * self.factor_b[[ib, jb]];
            }
        }
    }
    fn to_dense(&self) -> Array2<f64> {
        let (pa, pb) = (self.factor_a.nrows(), self.factor_b.nrows());
        let base = self.global_offset;
        let mut dense = Array2::<f64>::zeros((self.k, self.k));
        for ia in 0..pa {
            for ib in 0..pb {
                let row = base + ia * pb + ib;
                for ja in 0..pa {
                    let a = self.factor_a[[ia, ja]];
                    if a == 0.0 {
                        continue;
                    }
                    for jb in 0..pb {
                        dense[[row, base + ja * pb + jb]] += a * self.factor_b[[ib, jb]];
                    }
                }
            }
        }
        dense
    }
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("kronecker-penalty-op-v1");
        hasher.write_usize(self.global_offset);
        hasher.write_usize(self.k);
        hasher.write_f64_array2(&self.factor_a);
        hasher.write_f64_array2(&self.factor_b);
    }
}
/// Penalty `A ⊗ I_p` placed at `global_offset`.
///
/// Global index `global_offset + i_a·p + c` corresponds to row `i_a` of `A` and
/// coordinate `c` of the identity factor. Only equal coordinates are coupled.
pub struct IdentityRightKroneckerPenaltyOp {
    /// Left Kronecker factor `A`.
    pub factor_a: Array2<f64>,
    /// Size of the identity factor.
    pub p: usize,
    /// First global index occupied by the product.
    pub global_offset: usize,
    /// Global dimension of the operator.
    pub k: usize,
}
impl BetaPenaltyOp for IdentityRightKroneckerPenaltyOp {
    fn dim(&self) -> usize {
        self.k
    }
    /// `y += (A ⊗ I_p) · x`.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        let pa = self.factor_a.nrows();
        let width = self.p;
        let base = self.global_offset;
        for ia in 0..pa {
            for coord in 0..width {
                let row = base + ia * width + coord;
                let mut sum = 0.0_f64;
                for ja in 0..pa {
                    let a = self.factor_a[[ia, ja]];
                    if a == 0.0 {
                        continue;
                    }
                    sum += a * x[base + ja * width + coord];
                }
                y[row] += sum;
            }
        }
    }
    /// Same product as `matvec`, applied to `beta`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }
    /// Adds `A[i_a, i_a]` to all `p` slots of row `i_a`.
    fn diagonal(&self, diag: &mut [f64]) {
        let pa = self.factor_a.nrows();
        let width = self.p;
        let base = self.global_offset;
        for ia in 0..pa {
            let a_diag = self.factor_a[[ia, ia]];
            let row0 = base + ia * width;
            for coord in 0..width {
                diag[row0 + coord] += a_diag;
            }
        }
    }
    /// Adds the part of `A ⊗ I_p` that overlaps `offsets[id.0]` into `out`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let target = &offsets[id.0];
        let span = target.end - target.start;
        let pa = self.factor_a.nrows();
        let width = self.p;
        let covered = self.global_offset..self.global_offset + pa * width;
        if covered.end <= target.start || covered.start >= target.end {
            return;
        }
        for bi in 0..span {
            let gi = target.start + bi;
            if !covered.contains(&gi) {
                continue;
            }
            let li = gi - covered.start;
            let (ia, ci) = (li / width, li % width);
            for bj in 0..span {
                let gj = target.start + bj;
                if !covered.contains(&gj) {
                    continue;
                }
                let lj = gj - covered.start;
                let (ja, cj) = (lj / width, lj % width);
                // The identity factor couples only equal coordinates.
                if ci == cj {
                    out[[bi, bj]] += self.factor_a[[ia, ja]];
                }
            }
        }
    }
    fn to_dense(&self) -> Array2<f64> {
        let pa = self.factor_a.nrows();
        let width = self.p;
        let base = self.global_offset;
        let mut dense = Array2::<f64>::zeros((self.k, self.k));
        for ia in 0..pa {
            for ja in 0..pa {
                let a = self.factor_a[[ia, ja]];
                if a == 0.0 {
                    continue;
                }
                for coord in 0..width {
                    dense[[base + ia * width + coord, base + ja * width + coord]] += a;
                }
            }
        }
        dense
    }
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("identity-right-kronecker-penalty-op-v1");
        hasher.write_usize(self.global_offset);
        hasher.write_usize(self.k);
        hasher.write_usize(self.p);
        hasher.write_f64_array2(&self.factor_a);
    }
}
/// One dense tile of a coupling matrix `G`.
///
/// `data[[li, lj]]` is the entry `G[row_off + li, col_off + lj]`.
#[derive(Debug, Clone)]
pub struct SparseGBlock {
    /// First row of `G` covered by `data`.
    pub row_off: usize,
    /// First column of `G` covered by `data`.
    pub col_off: usize,
    /// Dense tile values.
    pub data: Array2<f64>,
}
/// Penalty `G ⊗ I_p`, with `G` assembled additively from dense tiles.
///
/// Global index `g·p + c` pairs row/column `g` of `G` with identity coordinate `c`.
/// Tiles only couple equal coordinates.
pub struct SparseBlockKroneckerPenaltyOp {
    /// Size of the identity factor.
    pub p: usize,
    /// NOTE(review): only read by `fingerprint` in this file; presumably the order of `G`.
    pub dim_a: usize,
    /// Global dimension of the operator.
    pub k: usize,
    /// Tiles of `G`; overlapping tiles add.
    pub blocks: Vec<SparseGBlock>,
}
/// Plain-data description of one smooth penalty block.
///
/// NOTE(review): the fields mirror `IdentityRightKroneckerPenaltyOp`
/// (`global_offset`, `factor_a`); presumably it is a transferable form of that
/// operator — confirm against the consumer.
#[derive(Debug, Clone)]
pub struct DeviceSaeSmoothBlock {
    /// First global index of the block.
    pub global_offset: usize,
    /// Left Kronecker factor.
    pub factor_a: Array2<f64>,
}
/// Bundle of plain data handed to a PCG solver.
///
/// NOTE(review): field meanings below are inferred from names and types only — confirm
/// against the consumer.
#[derive(Debug, Clone)]
pub struct DeviceSaePcgData {
    /// Output dimension (presumably the identity-factor size `p`).
    pub p: usize,
    /// Length of the beta vector.
    pub beta_dim: usize,
    /// Sparse rows stored as `(index, value)` pairs.
    pub a_phi: Vec<Vec<(usize, f64)>>,
    /// Per-entry local Jacobian values.
    pub local_jac: Vec<Vec<f64>>,
    /// Smooth penalty blocks.
    pub smooth_blocks: Vec<DeviceSaeSmoothBlock>,
    /// Tiles of the sparse `G` coupling.
    pub sparse_g_blocks: Vec<SparseGBlock>,
}
impl DeviceSaePcgData {
    /// Returns a reference-counted copy of `a_phi`, so consumers can share it by
    /// cloning the `Arc` instead of the rows.
    pub(crate) fn a_phi_shared(&self) -> Arc<[Vec<(usize, f64)>]> {
        Arc::from(self.a_phi.as_slice())
    }
}
impl BetaPenaltyOp for SparseBlockKroneckerPenaltyOp {
    fn dim(&self) -> usize {
        self.k
    }
    /// `y += (G ⊗ I_p) · x`, visiting only nonzero tile entries.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        let p = self.p;
        for blk in &self.blocks {
            let (m_i, m_j) = blk.data.dim();
            for li in 0..m_i {
                let gi_base = (blk.row_off + li) * p;
                for lj in 0..m_j {
                    let a_ij = blk.data[[li, lj]];
                    if a_ij == 0.0 {
                        continue;
                    }
                    let gj_base = (blk.col_off + lj) * p;
                    for oc in 0..p {
                        y[gi_base + oc] += a_ij * x[gj_base + oc];
                    }
                }
            }
        }
    }
    /// Same product as `matvec`, applied to `beta`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }
    /// Adds the diagonal of `G ⊗ I_p`.
    ///
    /// A tile entry lies on the diagonal of `G` when `row_off + li == col_off + lj`.
    /// That entry is added to the `p` global diagonal slots of that `G` index.
    ///
    /// Tiles are not required to be aligned with the diagonal (`row_off == col_off`).
    /// A tile whose row and column spans merely overlap also carries diagonal entries
    /// of `G`. Those are included here so the result agrees with `to_dense`, `matvec`
    /// and `block`, which all handle arbitrary tiles.
    fn diagonal(&self, diag: &mut [f64]) {
        let p = self.p;
        for blk in &self.blocks {
            let (m_i, m_j) = blk.data.dim();
            for li in 0..m_i {
                let g_idx = blk.row_off + li;
                // Column-local index of the same `G` index, if this tile's columns cover it.
                let lj = match g_idx.checked_sub(blk.col_off) {
                    Some(lj) if lj < m_j => lj,
                    _ => continue,
                };
                let a_ii = blk.data[[li, lj]];
                let gi_base = g_idx * p;
                for oc in 0..p {
                    diag[gi_base + oc] += a_ii;
                }
            }
        }
    }
    /// Adds the part of `G ⊗ I_p` that overlaps `offsets[id.0]` (rows and columns)
    /// into `out`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let range = &offsets[id.0];
        let b = range.end - range.start;
        let p = self.p;
        for blk in &self.blocks {
            let (m_i, m_j) = blk.data.dim();
            let row_start = blk.row_off * p;
            let row_end = (blk.row_off + m_i) * p;
            let col_start = blk.col_off * p;
            let col_end = (blk.col_off + m_j) * p;
            // Tile's global footprint misses the requested range in rows or columns.
            if row_end <= range.start
                || row_start >= range.end
                || col_end <= range.start
                || col_start >= range.end
            {
                continue;
            }
            for bi in 0..b {
                let gi = range.start + bi;
                if gi < row_start || gi >= row_end {
                    continue;
                }
                let li = (gi - row_start) / p;
                let oc_i = (gi - row_start) % p;
                for bj in 0..b {
                    let gj = range.start + bj;
                    if gj < col_start || gj >= col_end {
                        continue;
                    }
                    let oc_j = (gj - col_start) % p;
                    // The identity factor couples only equal coordinates.
                    if oc_i != oc_j {
                        continue;
                    }
                    let lj = (gj - col_start) / p;
                    out[[bi, bj]] += blk.data[[li, lj]];
                }
            }
        }
    }
    fn to_dense(&self) -> Array2<f64> {
        let p = self.p;
        let mut out = Array2::<f64>::zeros((self.k, self.k));
        for blk in &self.blocks {
            let (m_i, m_j) = blk.data.dim();
            for li in 0..m_i {
                let gi_base = (blk.row_off + li) * p;
                for lj in 0..m_j {
                    let a_ij = blk.data[[li, lj]];
                    if a_ij == 0.0 {
                        continue;
                    }
                    let gj_base = (blk.col_off + lj) * p;
                    for oc in 0..p {
                        out[[gi_base + oc, gj_base + oc]] += a_ij;
                    }
                }
            }
        }
        out
    }
    /// Per-row absolute sums, computed tile by tile without materialising `to_dense`.
    ///
    /// Each tile's row sum of `|data|` is added to all `p` coordinates of that `G`
    /// row. The result is exact when no two tiles cover the same entry of `G`. When
    /// tiles overlap it is an upper bound on the true row abs sum, because
    /// `|a + b| <= |a| + |b|`.
    fn row_abs_sums(&self) -> Array1<f64> {
        let p = self.p;
        let mut out = Array1::<f64>::zeros(self.k);
        for blk in &self.blocks {
            let (m_i, m_j) = blk.data.dim();
            for li in 0..m_i {
                let gi_base = (blk.row_off + li) * p;
                let mut row_abs = 0.0_f64;
                for lj in 0..m_j {
                    row_abs += blk.data[[li, lj]].abs();
                }
                for oc in 0..p {
                    out[gi_base + oc] += row_abs;
                }
            }
        }
        out
    }
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("sparse-block-kronecker-penalty-op-v1");
        hasher.write_usize(self.p);
        hasher.write_usize(self.dim_a);
        hasher.write_usize(self.k);
        hasher.write_usize(self.blocks.len());
        for blk in &self.blocks {
            hasher.write_usize(blk.row_off);
            hasher.write_usize(blk.col_off);
            hasher.write_f64_array2(&blk.data);
        }
    }
}
/// One tile of a `FactoredFrameKroneckerOp`. The operator block between atoms
/// `atom_i` and `atom_j` is `g ⊗ w`.
#[derive(Debug, Clone)]
pub struct FactoredFrameGBlock {
    /// Row atom.
    pub atom_i: usize,
    /// Column atom.
    pub atom_j: usize,
    /// Basis-space coupling, shaped `basis_sizes[atom_i] × basis_sizes[atom_j]`
    /// (checked by `FactoredFrameKroneckerOp::new`).
    pub g: Array2<f64>,
    /// Frame-space coupling, shaped `ranks[atom_i] × ranks[atom_j]`
    /// (checked by `FactoredFrameKroneckerOp::new`).
    pub w: Array2<f64>,
}
/// Penalty made of `g ⊗ w` tiles between atoms.
///
/// As built by `FactoredFrameKroneckerOp::new`, atom `t` owns the contiguous range
/// `offsets[t]..offsets[t + 1]`. Index `offsets[t] + l·ranks[t] + a` pairs basis
/// function `l` with frame coordinate `a`.
pub struct FactoredFrameKroneckerOp {
    /// Frame rank of each atom.
    pub ranks: Vec<usize>,
    /// Number of basis functions of each atom.
    pub basis_sizes: Vec<usize>,
    /// `n_atoms + 1` prefix offsets of the per-atom ranges.
    pub offsets: Vec<usize>,
    /// Total dimension (`offsets[n_atoms]`).
    pub dim: usize,
    /// Tiles; several tiles for the same atom pair add up.
    pub blocks: Vec<FactoredFrameGBlock>,
}
/// Returns the cross-Gram matrix `W = U_iᵀ · U_j` of two frames.
///
/// `u_i` is `p × r_i` and `u_j` is `p × r_j`; the result is `r_i × r_j`. Each entry is
/// a plain left-to-right sum over the shared ambient axis.
///
/// # Panics
/// Panics if the two frames have different row counts.
pub fn frame_output_gram(u_i: ArrayView2<f64>, u_j: ArrayView2<f64>) -> Array2<f64> {
    let (p_i, r_i) = u_i.dim();
    let (p_j, r_j) = u_j.dim();
    assert_eq!(
        p_i, p_j,
        "frame_output_gram: frames live in different ambient dims ({p_i} vs {p_j})"
    );
    Array2::from_shape_fn((r_i, r_j), |(a, b)| {
        (0..p_i).fold(0.0_f64, |acc, c| acc + u_i[[c, a]] * u_j[[c, b]])
    })
}
impl FactoredFrameKroneckerOp {
    /// Validates the tiles and builds the operator.
    ///
    /// Atom `t` receives `basis_sizes[t] * ranks[t]` consecutive indices, in atom order.
    ///
    /// # Errors
    /// Returns a message when any of these checks fails:
    /// * `ranks` and `basis_sizes` differ in length;
    /// * a tile names an atom out of range;
    /// * a tile's `g` or `w` shape does not match the atoms' basis sizes or ranks.
    pub fn new(
        ranks: Vec<usize>,
        basis_sizes: Vec<usize>,
        blocks: Vec<FactoredFrameGBlock>,
    ) -> Result<Self, String> {
        if ranks.len() != basis_sizes.len() {
            return Err(format!(
                "FactoredFrameKroneckerOp: {} ranks but {} basis sizes",
                ranks.len(),
                basis_sizes.len()
            ));
        }
        let n_atoms = ranks.len();

        // Prefix sums of per-atom sizes (basis size × rank).
        let mut offsets = Vec::with_capacity(n_atoms + 1);
        offsets.push(0usize);
        for (m, r) in basis_sizes.iter().zip(ranks.iter()) {
            let end = offsets[offsets.len() - 1] + m * r;
            offsets.push(end);
        }
        let dim = offsets[n_atoms];

        for blk in &blocks {
            let (ai, aj) = (blk.atom_i, blk.atom_j);
            if ai >= n_atoms || aj >= n_atoms {
                return Err(format!(
                    "FactoredFrameKroneckerOp: block atom indices ({}, {}) out of range (n_atoms = {n_atoms})",
                    ai, aj
                ));
            }
            let want_g = (basis_sizes[ai], basis_sizes[aj]);
            if blk.g.dim() != want_g {
                return Err(format!(
                    "FactoredFrameKroneckerOp: block ({}, {}) g has shape {:?} but expected ({}, {})",
                    ai,
                    aj,
                    blk.g.dim(),
                    want_g.0,
                    want_g.1
                ));
            }
            let want_w = (ranks[ai], ranks[aj]);
            if blk.w.dim() != want_w {
                return Err(format!(
                    "FactoredFrameKroneckerOp: block ({}, {}) w has shape {:?} but expected ({}, {})",
                    ai,
                    aj,
                    blk.w.dim(),
                    want_w.0,
                    want_w.1
                ));
            }
        }

        Ok(Self {
            ranks,
            basis_sizes,
            offsets,
            dim,
            blocks,
        })
    }

    /// Builds the operator from per-atom output frames and basis-space tiles.
    ///
    /// `frames[t]` is either `Some(U_t)`, shaped `p × r_t` with `r_t <= p`, or `None`,
    /// which stands for the `p × p` identity (rank `p`). Each entry `(i, j) → g` of
    /// `g_blocks` becomes the tile `g ⊗ (U_iᵀ U_j)`, computed with `frame_output_gram`.
    ///
    /// # Errors
    /// Returns a message when any of these checks fails:
    /// * frames and basis sizes differ in length;
    /// * a frame's row count differs from `p`, or its rank exceeds `p`;
    /// * an atom index in `g_blocks` is out of range;
    /// * `FactoredFrameKroneckerOp::new` rejects the result.
    pub fn from_frames_and_blocks(
        frames: &[Option<Array2<f64>>],
        basis_sizes: &[usize],
        p: usize,
        g_blocks: &std::collections::BTreeMap<(usize, usize), Array2<f64>>,
    ) -> Result<Self, String> {
        if frames.len() != basis_sizes.len() {
            return Err(format!(
                "FactoredFrameKroneckerOp::from_frames_and_blocks: {} frames but {} basis sizes",
                frames.len(),
                basis_sizes.len()
            ));
        }
        let n_atoms = frames.len();

        let ranks = frames
            .iter()
            .enumerate()
            .map(|(k, frame)| match frame {
                None => Ok(p),
                Some(u) => {
                    let (pr, r) = u.dim();
                    if pr != p {
                        Err(format!(
                            "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has {pr} rows but ambient dim is {p}"
                        ))
                    } else if r > p {
                        Err(format!(
                            "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has rank {r} > ambient dim {p}"
                        ))
                    } else {
                        Ok(r)
                    }
                }
            })
            .collect::<Result<Vec<usize>, String>>()?;

        // `None` frames are represented by an explicit identity.
        let identity = Array2::<f64>::eye(p);
        let view_of = |t: usize| -> ArrayView2<f64> {
            match &frames[t] {
                Some(u) => u.view(),
                None => identity.view(),
            }
        };

        let blocks = g_blocks
            .iter()
            .map(|(&(atom_i, atom_j), g)| {
                if atom_i >= n_atoms || atom_j >= n_atoms {
                    return Err(format!(
                        "FactoredFrameKroneckerOp::from_frames_and_blocks: block atom indices ({atom_i}, {atom_j}) out of range (n_atoms = {n_atoms})"
                    ));
                }
                Ok(FactoredFrameGBlock {
                    atom_i,
                    atom_j,
                    g: g.clone(),
                    w: frame_output_gram(view_of(atom_i), view_of(atom_j)),
                })
            })
            .collect::<Result<Vec<FactoredFrameGBlock>, String>>()?;

        Self::new(ranks, basis_sizes.to_vec(), blocks)
    }
}
impl BetaPenaltyOp for FactoredFrameKroneckerOp {
    fn dim(&self) -> usize {
        self.dim
    }
    /// `y += S · x`, where every tile contributes `(g ⊗ w)` applied to the column
    /// atom's slice of `x`.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        for tile in &self.blocks {
            let (ri, rj) = (self.ranks[tile.atom_i], self.ranks[tile.atom_j]);
            let (oi, oj) = (self.offsets[tile.atom_i], self.offsets[tile.atom_j]);
            let (rows, cols) = tile.g.dim();
            for li in 0..rows {
                let y0 = oi + li * ri;
                for lj in 0..cols {
                    let g = tile.g[[li, lj]];
                    if g == 0.0 {
                        continue;
                    }
                    let x0 = oj + lj * rj;
                    for a in 0..ri {
                        let wx = (0..rj).fold(0.0_f64, |acc, b| acc + tile.w[[a, b]] * x[x0 + b]);
                        y[y0 + a] += g * wx;
                    }
                }
            }
        }
    }
    /// Same product as `matvec`, applied to `beta`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }
    /// Adds `g[l, l] · w[a, a]` for every same-atom tile. Tiles between different
    /// atoms live in disjoint index ranges and never touch the diagonal.
    fn diagonal(&self, diag: &mut [f64]) {
        for tile in self.blocks.iter().filter(|t| t.atom_i == t.atom_j) {
            let rank = self.ranks[tile.atom_i];
            let origin = self.offsets[tile.atom_i];
            let (rows, cols) = tile.g.dim();
            for l in 0..rows.min(cols) {
                let g_ll = tile.g[[l, l]];
                let row0 = origin + l * rank;
                for a in 0..rank {
                    diag[row0 + a] += g_ll * tile.w[[a, a]];
                }
            }
        }
    }
    /// Adds every tile entry whose global row and column both fall in
    /// `offsets[id.0]` into `out` (local coordinates).
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let target = &offsets[id.0];
        let width = target.end - target.start;
        for tile in &self.blocks {
            let (ri, rj) = (self.ranks[tile.atom_i], self.ranks[tile.atom_j]);
            let (oi, oj) = (self.offsets[tile.atom_i], self.offsets[tile.atom_j]);
            let (rows, cols) = tile.g.dim();
            for li in 0..rows {
                for a in 0..ri {
                    let gi = oi + li * ri + a;
                    if !target.contains(&gi) {
                        continue;
                    }
                    let bi = gi - target.start;
                    for lj in 0..cols {
                        let g = tile.g[[li, lj]];
                        if g == 0.0 {
                            continue;
                        }
                        for b in 0..rj {
                            let gj = oj + lj * rj + b;
                            if !target.contains(&gj) {
                                continue;
                            }
                            let bj = gj - target.start;
                            // Defensive bound on the local output index.
                            if bi < width && bj < width {
                                out[[bi, bj]] += g * tile.w[[a, b]];
                            }
                        }
                    }
                }
            }
        }
    }
    fn to_dense(&self) -> Array2<f64> {
        let mut dense = Array2::<f64>::zeros((self.dim, self.dim));
        for tile in &self.blocks {
            let (ri, rj) = (self.ranks[tile.atom_i], self.ranks[tile.atom_j]);
            let (oi, oj) = (self.offsets[tile.atom_i], self.offsets[tile.atom_j]);
            let (rows, cols) = tile.g.dim();
            for li in 0..rows {
                for lj in 0..cols {
                    let g = tile.g[[li, lj]];
                    if g == 0.0 {
                        continue;
                    }
                    for a in 0..ri {
                        let row = oi + li * ri + a;
                        for b in 0..rj {
                            dense[[row, oj + lj * rj + b]] += g * tile.w[[a, b]];
                        }
                    }
                }
            }
        }
        dense
    }
    /// NOTE(review): `ranks` and `basis_sizes` are written without a length prefix.
    /// Writing `ranks.len()` first would make the encoding unambiguous, at the cost of
    /// changing existing fingerprints.
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("factored-frame-kronecker-op-v1");
        hasher.write_usize(self.dim);
        self.ranks.iter().for_each(|&r| hasher.write_usize(r));
        self.basis_sizes.iter().for_each(|&m| hasher.write_usize(m));
        hasher.write_usize(self.blocks.len());
        for tile in &self.blocks {
            hasher.write_usize(tile.atom_i);
            hasher.write_usize(tile.atom_j);
            hasher.write_f64_array2(&tile.g);
            hasher.write_f64_array2(&tile.w);
        }
    }
}
/// Sum of several penalty operators over the same coefficient vector.
pub struct CompositePenaltyOp {
    /// Global dimension. `to_dense` adds each child's dense matrix into a `k × k`
    /// buffer, so children are expected to be `k × k`.
    pub k: usize,
    /// Summands, applied in order.
    pub ops: Vec<Arc<dyn BetaPenaltyOp>>,
}
impl BetaPenaltyOp for CompositePenaltyOp {
    fn dim(&self) -> usize {
        self.k
    }
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        for child in self.ops.iter() {
            child.matvec(x, y);
        }
    }
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        for child in self.ops.iter() {
            child.gradient(beta, out);
        }
    }
    fn diagonal(&self, diag: &mut [f64]) {
        for child in self.ops.iter() {
            child.diagonal(diag);
        }
    }
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        for child in self.ops.iter() {
            child.block(id, offsets, out);
        }
    }
    fn to_dense(&self) -> Array2<f64> {
        self.ops
            .iter()
            .fold(Array2::<f64>::zeros((self.k, self.k)), |mut total, child| {
                total += &child.to_dense();
                total
            })
    }
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("composite-penalty-op-v1");
        hasher.write_usize(self.k);
        hasher.write_usize(self.ops.len());
        for child in self.ops.iter() {
            child.fingerprint(hasher);
        }
    }
}
/// Matrix-free penalty defined by a shared matvec callback plus a separately
/// supplied diagonal.
pub struct MatvecDiagPenaltyOp {
    /// Operator dimension.
    pub(crate) k: usize,
    /// Product callback. The trait impl always gives it a freshly zeroed length-`k`
    /// output buffer.
    pub(crate) matvec: SharedBetaMatvec,
    /// Diagonal reported by `diagonal`; not checked against the callback.
    pub(crate) diagonal_vec: Array1<f64>,
}
impl MatvecDiagPenaltyOp {
    /// Wraps a matvec callback together with its diagonal.
    ///
    /// # Panics
    /// Panics if `diagonal_vec.len() != k`.
    pub fn new(k: usize, matvec: SharedBetaMatvec, diagonal_vec: Array1<f64>) -> Self {
        assert_eq!(diagonal_vec.len(), k);
        Self {
            k,
            matvec,
            diagonal_vec,
        }
    }
}
impl BetaPenaltyOp for MatvecDiagPenaltyOp {
fn dim(&self) -> usize {
self.k
}
fn matvec(&self, x: &[f64], y: &mut [f64]) {
let x_arr = Array1::from_iter(x.iter().copied());
let mut out = Array1::<f64>::zeros(self.k);
(self.matvec)(x_arr.view(), &mut out);
for a in 0..self.k {
y[a] += out[a];
}
}
fn gradient(&self, beta: &[f64], out: &mut [f64]) {
let beta_arr = Array1::from_iter(beta.iter().copied());
let mut hb = Array1::<f64>::zeros(self.k);
(self.matvec)(beta_arr.view(), &mut hb);
for a in 0..self.k {
out[a] += hb[a];
}
}
fn diagonal(&self, diag: &mut [f64]) {
for j in 0..self.k.min(diag.len()) {
diag[j] += self.diagonal_vec[j];
}
}
fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
let range = &offsets[id.0];
let b = range.end - range.start;
let mut probe = Array1::<f64>::zeros(self.k);
for bj in 0..b {
probe.fill(0.0);
probe[range.start + bj] = 1.0;
let mut col = Array1::<f64>::zeros(self.k);
(self.matvec)(probe.view(), &mut col);
for bi in 0..b {
out[[bi, bj]] += col[range.start + bi];
}
}
}
fn to_dense(&self) -> Array2<f64> {
let k = self.k;
let mut out = Array2::<f64>::zeros((k, k));
let mut probe = Array1::<f64>::zeros(k);
for j in 0..k {
probe.fill(0.0);
probe[j] = 1.0;
let mut col = Array1::<f64>::zeros(k);
(self.matvec)(probe.view(), &mut col);
for i in 0..k {
out[[i, j]] = col[i];
}
}
out
}
fn fingerprint(&self, hasher: &mut Fingerprinter) {
hasher.write_str("matvec-diag-penalty-op-v1");
hasher.write_usize(self.k);
for &value in self.diagonal_vec.iter() {
hasher.write_f64(value);
}
}
}