use tensorlogic_ir::EinsumGraph;
use super::config::LowRankConfig;
use super::error::LowRankError;
use super::svd::{SvdResult, TruncatedSvd};
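
/// Low-rank approximation of dense matrices via truncated SVD.
///
/// Wraps a [`TruncatedSvd`] solver driven by a [`LowRankConfig`] and exposes
/// helpers for approximating a single matrix and a matrix product.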
pub struct LowRankApproximation {
config: LowRankConfig,
svd: TruncatedSvd,
}
impl LowRankApproximation {
pub fn new(config: LowRankConfig) -> Self {
let svd = TruncatedSvd::new(config.clone());
LowRankApproximation { config, svd }
}
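    /// Computes a truncated SVD of a dense, row-major `rows x cols` matrix,
    /// propagating any error from the underlying decomposition.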
pub fn approximate_matrix(
&self,
data: &[f64],
rows: usize,
cols: usize,
) -> Result<SvdResult, LowRankError> {
self.svd.decompose(data, rows, cols)
}
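    /// Approximates the product `C = A * B` through a rank-truncated SVD of `A`.
    ///
    /// `A` is decomposed as `U * S * V^T` with at most `config.rank` components
    /// and the product is evaluated as `(U * S) * (V^T * B)`, which costs roughly
    /// `r * b_cols * (a_rows + a_cols)` multiply-adds for rank `r` instead of
    /// `a_rows * a_cols * b_cols` for the exact product. Inputs and output are
    /// dense row-major matrices.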
pub fn approximate_matmul(
&self,
a: &[f64],
a_rows: usize,
a_cols: usize,
b: &[f64],
b_rows: usize,
b_cols: usize,
) -> Result<Vec<f64>, LowRankError> {
if a_cols != b_rows {
return Err(LowRankError::InvalidInput(format!(
"inner dimensions mismatch: a_cols={} != b_rows={}",
a_cols, b_rows
)));
}
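        // Truncated SVD of A: A ≈ U * S * V^T with at most `config.rank` components.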
let svd_result = self.svd.decompose(a, a_rows, a_cols)?;
let rank = svd_result.rank_used;
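        // Stage 1: M = V^T * B, a `rank x b_cols` intermediate built from the
        // retained right singular vectors only.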
let mut m = vec![0.0_f64; rank * b_cols];
for k in 0..rank {
for j in 0..b_cols {
let mut val = 0.0_f64;
for l in 0..b_rows {
val += svd_result.vt[k * svd_result.vt_cols + l] * b[l * b_cols + j];
}
m[k * b_cols + j] = val;
}
}
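        // Stage 2: C = U * S * M, scaling each retained component by its
        // singular value while accumulating.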
let mut c = vec![0.0_f64; a_rows * b_cols];
for i in 0..a_rows {
for j in 0..b_cols {
let mut val = 0.0_f64;
for k in 0..rank {
let u_ik = svd_result.u[i * svd_result.u_cols + k];
val += u_ik * svd_result.singular_values[k] * m[k * b_cols + j];
}
c[i * b_cols + j] = val;
}
}
Ok(c)
}
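    /// Returns `true` when a `rows x cols` matrix is large enough, per
    /// `config.min_dimension`, to be worth approximating.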
pub fn is_candidate(&self, rows: usize, cols: usize) -> bool {
rows >= self.config.min_dimension && cols >= self.config.min_dimension
}
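    /// Returns the smallest rank `k` whose leading singular values capture at
    /// least `energy_threshold` of the total energy, i.e. the smallest `k` with
    /// `sum_{i<=k} s_i^2 / sum_i s_i^2 >= energy_threshold`.
    ///
    /// Returns 0 for an empty slice and falls back to 1 when every singular
    /// value is zero.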
pub fn optimal_rank(singular_values: &[f64], energy_threshold: f64) -> usize {
if singular_values.is_empty() {
return 0;
}
let total: f64 = singular_values.iter().map(|s| s * s).sum();
if total == 0.0 {
return 1;
}
let mut cumulative = 0.0_f64;
for (k, &sv) in singular_values.iter().enumerate() {
cumulative += sv * sv;
if cumulative / total >= energy_threshold {
return k + 1;
}
}
singular_values.len()
}
}
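
/// An einsum node identified as a candidate for low-rank rewriting, with a
/// human-readable reason and a rough estimate of the relative FLOP savings.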
#[derive(Debug, Clone)]
pub struct LowRankCandidate {
pub node_index: usize,
pub reason: String,
pub estimated_savings_ratio: f64,
}
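
/// Summary statistics reported by [`LowRankInferencePass::apply_annotations`].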
#[derive(Debug, Clone, Default)]
pub struct LowRankPassStats {
pub candidates_found: usize,
pub nodes_inspected: usize,
pub estimated_total_flop_reduction: f64,
}
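
/// Graph pass that scans an [`EinsumGraph`] for matmul-like einsum nodes that
/// could benefit from a low-rank approximation.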
#[derive(Debug)]
pub struct LowRankInferencePass {
config: LowRankConfig,
}
impl LowRankInferencePass {
pub fn new(config: LowRankConfig) -> Self {
LowRankInferencePass { config }
}
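    /// Collects einsum nodes with at least two inputs whose contraction pattern
    /// looks like a matrix multiplication.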
pub fn find_candidates(&self, graph: &EinsumGraph) -> Vec<LowRankCandidate> {
let mut candidates = Vec::new();
for (idx, node) in graph.nodes.iter().enumerate() {
if let tensorlogic_ir::OpType::Einsum { spec } = &node.op {
if node.inputs.len() >= 2 && self.is_matmul_like(spec) {
let savings = self.estimate_savings(spec);
candidates.push(LowRankCandidate {
node_index: idx,
reason: format!(
"Einsum '{}' has {} inputs and matmul-like contraction",
spec,
node.inputs.len()
),
estimated_savings_ratio: savings,
});
}
}
}
candidates
}
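    /// Runs candidate detection and aggregates the results into pass-level
    /// statistics. The graph itself is left unmodified.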
pub fn apply_annotations(&self, graph: &EinsumGraph) -> LowRankPassStats {
let candidates = self.find_candidates(graph);
let estimated_total_flop_reduction: f64 =
candidates.iter().map(|c| c.estimated_savings_ratio).sum();
LowRankPassStats {
candidates_found: candidates.len(),
nodes_inspected: graph.nodes.len(),
estimated_total_flop_reduction,
}
}
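    /// Heuristic check that the first two operands of the einsum spec share at
    /// least one contracted index (present in both inputs but not the output)
    /// and that the configured rank is below the estimated contracted extent.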
fn is_matmul_like(&self, spec: &str) -> bool {
if let Some(arrow_pos) = spec.find("->") {
let inputs_part = &spec[..arrow_pos];
let operands: Vec<&str> = inputs_part.split(',').collect();
if operands.len() < 2 {
return false;
}
let a_chars: std::collections::HashSet<char> =
operands[0].chars().filter(|c| c.is_alphabetic()).collect();
let b_chars: std::collections::HashSet<char> =
operands[1].chars().filter(|c| c.is_alphabetic()).collect();
let output_chars: std::collections::HashSet<char> = spec[arrow_pos + 2..]
.chars()
.filter(|c| c.is_alphabetic())
.collect();
let contracted: std::collections::HashSet<char> = a_chars
.intersection(&b_chars)
.copied()
.filter(|c| !output_chars.contains(c))
.collect();
            return !contracted.is_empty()
                && self.config.rank < self.min_contracted_dim_estimate(spec);
}
false
}
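    /// Crude stand-in for the contracted dimension size: the spec string carries
    /// no actual extents, so this simply counts index characters in the spec.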
    fn min_contracted_dim_estimate(&self, spec: &str) -> usize {
        let index_chars = spec.chars().filter(|c| c.is_alphabetic()).count();
        index_chars.max(1)
    }
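    /// Rough relative FLOP savings from replacing an exact contraction over an
    /// estimated extent `n` with two rank-`r` multiplies: `1 - 2r/n`, clamped
    /// to `[0, 1]`.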
fn estimate_savings(&self, spec: &str) -> f64 {
let contracted_dims = self.min_contracted_dim_estimate(spec).max(1) as f64;
let rank = self.config.rank as f64;
(1.0 - (2.0 * rank) / contracted_dims).clamp(0.0, 1.0)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_config(rank: usize) -> LowRankConfig {
LowRankConfig::new(rank)
.with_tolerance(1e-8)
.with_max_iterations(300)
.with_min_dimension(8)
}
#[test]
fn test_approximation_4x4_matrix() {
let m: Vec<f64> = (1..=16).map(|x| x as f64).collect();
let cfg = make_config(2);
let approx = LowRankApproximation::new(cfg);
let svd = approx
.approximate_matrix(&m, 4, 4)
.expect("approximation should succeed for a valid 4x4 matrix");
assert!(svd.rank_used >= 1);
assert!(svd.frobenius_error.is_finite());
}
#[test]
fn test_is_candidate_small_matrix() {
let cfg = LowRankConfig::new(2).with_min_dimension(32);
let approx = LowRankApproximation::new(cfg);
assert!(!approx.is_candidate(4, 4));
}
#[test]
fn test_is_candidate_large_matrix() {
let cfg = LowRankConfig::new(4).with_min_dimension(32);
let approx = LowRankApproximation::new(cfg);
assert!(approx.is_candidate(64, 64));
}
#[test]
fn test_optimal_rank_energy_threshold() {
let svs = vec![10.0_f64, 5.0, 2.0, 1.0];
let r = LowRankApproximation::optimal_rank(&svs, 0.90);
assert_eq!(r, 2, "optimal rank for 90% energy should be 2, got {r}");
let r2 = LowRankApproximation::optimal_rank(&svs, 0.99);
assert_eq!(r2, 3, "optimal rank for 99% energy should be 3, got {r2}");
}
#[test]
fn test_inference_pass_empty_graph() {
let graph = EinsumGraph::new();
let pass = LowRankInferencePass::new(LowRankConfig::default());
let candidates = pass.find_candidates(&graph);
assert!(
candidates.is_empty(),
"empty graph should yield no candidates"
);
}
#[test]
fn test_inference_pass_stats() {
let mut graph = EinsumGraph::new();
let t0 = graph.add_tensor("A");
let t1 = graph.add_tensor("B");
let t2 = graph.add_tensor("C");
let node = tensorlogic_ir::EinsumNode::einsum("ij,jk->ik", vec![t0, t1], vec![t2]);
        graph.add_node(node).expect("adding the einsum node should succeed");
let pass = LowRankInferencePass::new(LowRankConfig::new(2));
let stats = pass.apply_annotations(&graph);
assert_eq!(stats.nodes_inspected, 1);
        assert!(stats.candidates_found <= stats.nodes_inspected);
assert!(stats.estimated_total_flop_reduction >= 0.0);
}
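    // Two additional sketch tests for `approximate_matmul`: one exercises the
    // inner-dimension check and one the output shape. Only shapes and the error
    // path are asserted, since numerical accuracy depends on the truncation
    // rank and the SVD solver's tolerance.
    #[test]
    fn test_approximate_matmul_dimension_mismatch() {
        let approx = LowRankApproximation::new(make_config(2));
        // a is 2x3 and b is 2x2, so the inner dimensions (3 vs 2) do not match.
        let a = vec![1.0_f64; 6];
        let b = vec![1.0_f64; 4];
        assert!(approx.approximate_matmul(&a, 2, 3, &b, 2, 2).is_err());
    }
    #[test]
    fn test_approximate_matmul_output_shape() {
        // Same 4x4 matrix as test_approximation_4x4_matrix, multiplied by a
        // 4x2 matrix of ones.
        let a: Vec<f64> = (1..=16).map(|x| x as f64).collect();
        let b = vec![1.0_f64; 8];
        let approx = LowRankApproximation::new(make_config(2));
        let c = approx
            .approximate_matmul(&a, 4, 4, &b, 4, 2)
            .expect("matmul approximation should succeed for compatible shapes");
        assert_eq!(c.len(), 4 * 2);
    }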
}