use crate::fixed_point::imperative::{FixedPoint, FixedMatrix};
use crate::fixed_point::universal::fasc::stack_evaluator::{StackValue, StackEvaluator, evaluate};
use crate::fixed_point::universal::fasc::lazy_expr::LazyExpr;
use crate::fixed_point::universal::ugod::DomainType;
use crate::fixed_point::core_types::errors::OverflowDetected;
use crate::deployment_profiles::DeploymentProfile;
use core::cell::RefCell;
/// Row-major matrix whose entries are [`StackValue`]s, so individual entries
/// may belong to different numeric domains.
#[derive(Debug)]
#[derive(Clone)]
pub struct DomainMatrix {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Entries in row-major order (entry `(r, c)` lives at `r * cols + c`);
    /// the constructors keep its length equal to `rows * cols`.
    data: Vec<StackValue>,
}
impl DomainMatrix {
    /// Creates a `rows` x `cols` matrix from row-major `data`.
    ///
    /// # Panics
    /// Panics if `data.len() != rows * cols`.
    pub fn from_values(rows: usize, cols: usize, data: Vec<StackValue>) -> Self {
        assert_eq!(data.len(), rows * cols, "DomainMatrix: data length mismatch");
        Self { rows, cols, data }
    }

    /// Creates a `rows` x `cols` matrix by evaluating each row-major literal with
    /// [`evaluate`]; each entry gets whatever domain `evaluate` assigns it.
    ///
    /// # Errors
    /// Returns the first error produced while evaluating a literal.
    ///
    /// # Panics
    /// Panics if `values.len() != rows * cols`.
    pub fn from_strings(rows: usize, cols: usize, values: &[&'static str]) -> Result<Self, OverflowDetected> {
        assert_eq!(values.len(), rows * cols, "DomainMatrix: values length mismatch");
        let data = values
            .iter()
            .map(|s| evaluate(&LazyExpr::Literal(s)))
            .collect::<Result<Vec<StackValue>, OverflowDetected>>()?;
        Ok(Self { rows, cols, data })
    }

    /// Copies every entry of `m` (via `to_stack_value`) into a new row-major matrix.
    pub fn from_fixed_matrix(m: &FixedMatrix) -> Self {
        let mut data = Vec::with_capacity(m.rows() * m.cols());
        for r in 0..m.rows() {
            for c in 0..m.cols() {
                data.push(m.get(r, c).to_stack_value());
            }
        }
        Self { rows: m.rows(), cols: m.cols(), data }
    }

    /// Converts every entry to a binary [`FixedPoint`] and returns a [`FixedMatrix`].
    ///
    /// Entries that already expose binary storage are used directly. Other entries
    /// go through the thread-local evaluator's `to_binary_value` and then
    /// `materialize_compute`.
    ///
    /// # Errors
    /// Propagates evaluator errors. Returns `OverflowDetected::InvalidInput` if a
    /// materialized value still has no binary storage.
    pub fn to_fixed_matrix(&self) -> Result<FixedMatrix, OverflowDetected> {
        with_evaluator(|eval| {
            let mut data = Vec::with_capacity(self.rows * self.cols);
            for sv in &self.data {
                match sv.as_binary_storage() {
                    Some(raw) => data.push(FixedPoint::from_raw(raw)),
                    None => {
                        let binary_sv = eval.to_binary_value(sv)?;
                        let materialized = eval.materialize_compute(binary_sv)?;
                        match materialized.as_binary_storage() {
                            Some(raw) => data.push(FixedPoint::from_raw(raw)),
                            None => return Err(OverflowDetected::InvalidInput),
                        }
                    }
                }
            }
            Ok(FixedMatrix::from_fn(self.rows, self.cols, |r, c| data[r * self.cols + c]))
        })
    }

    /// Number of rows.
    pub fn rows(&self) -> usize { self.rows }

    /// Number of columns.
    pub fn cols(&self) -> usize { self.cols }

    /// Returns a reference to the entry at (`row`, `col`).
    ///
    /// # Panics
    /// Panics if `row >= rows()` or `col >= cols()`.
    pub fn get(&self, row: usize, col: usize) -> &StackValue {
        &self.data[self.index_of(row, col)]
    }

    /// Replaces the entry at (`row`, `col`) with `val`.
    ///
    /// # Panics
    /// Panics if `row >= rows()` or `col >= cols()`.
    pub fn set(&mut self, row: usize, col: usize, val: StackValue) {
        let idx = self.index_of(row, col);
        self.data[idx] = val;
    }

    /// Row-major flat index of (`row`, `col`), checking both coordinates.
    ///
    /// Checking only the flat index is not enough: `col >= cols` would silently
    /// alias into the next row.
    fn index_of(&self, row: usize, col: usize) -> usize {
        assert!(
            row < self.rows && col < self.cols,
            "DomainMatrix: index ({}, {}) out of bounds for {}x{} matrix",
            row,
            col,
            self.rows,
            self.cols
        );
        row * self.cols + col
    }

    /// Returns `true` when every entry reports the same `domain_type()`
    /// (vacuously `true` for an empty matrix).
    pub fn is_uniform_domain(&self) -> bool {
        let mut domains = self.data.iter().map(|sv| sv.domain_type());
        match domains.next() {
            None => true,
            Some(first) => domains.all(|d| d == first),
        }
    }

    /// Returns the most frequent domain among Binary, Decimal, Ternary and Symbolic.
    ///
    /// Entries whose `domain_type()` is not one of those four are not counted.
    /// Returns `None` for an empty matrix or when no entry was counted. Ties go to
    /// the later domain in the order Binary, Decimal, Ternary, Symbolic, because
    /// `max_by_key` keeps the last maximum.
    pub fn dominant_domain(&self) -> Option<DomainType> {
        let mut counts = [0usize; 4];
        for sv in &self.data {
            let slot = match sv.domain_type() {
                Some(DomainType::Binary) => 0,
                Some(DomainType::Decimal) => 1,
                Some(DomainType::Ternary) => 2,
                Some(DomainType::Symbolic) => 3,
                _ => continue,
            };
            counts[slot] += 1;
        }
        let (max_idx, &max_count) = counts
            .iter()
            .enumerate()
            .max_by_key(|&(_, count)| *count)?;
        if max_count == 0 {
            // Nothing was tallied: do not report an arbitrary domain.
            return None;
        }
        match max_idx {
            0 => Some(DomainType::Binary),
            1 => Some(DomainType::Decimal),
            2 => Some(DomainType::Ternary),
            3 => Some(DomainType::Symbolic),
            _ => None,
        }
    }

    /// `n` x `n` identity matrix built from [`FixedMatrix::identity`].
    pub fn identity_binary(n: usize) -> Self {
        Self::from_fixed_matrix(&FixedMatrix::identity(n))
    }

    /// Returns the transpose (a `cols` x `rows` matrix with cloned entries).
    pub fn transpose(&self) -> Self {
        let mut data = Vec::with_capacity(self.rows * self.cols);
        for c in 0..self.cols {
            for r in 0..self.rows {
                data.push(self.data[r * self.cols + c].clone());
            }
        }
        Self { rows: self.cols, cols: self.rows, data }
    }

    /// Element-wise sum computed with the thread-local evaluator's `add_values`.
    ///
    /// # Errors
    /// Propagates the first evaluator error.
    ///
    /// # Panics
    /// Panics if the shapes differ.
    pub fn add(&self, other: &DomainMatrix) -> Result<DomainMatrix, OverflowDetected> {
        assert_eq!(self.rows, other.rows, "DomainMatrix add: row mismatch");
        assert_eq!(self.cols, other.cols, "DomainMatrix add: col mismatch");
        with_evaluator(|eval| {
            let data = self
                .data
                .iter()
                .zip(other.data.iter())
                .map(|(a, b)| eval.add_values(a.clone(), b.clone()))
                .collect::<Result<Vec<_>, _>>()?;
            Ok(DomainMatrix { rows: self.rows, cols: self.cols, data })
        })
    }

    /// Element-wise difference `self - other`, computed with `subtract_values`.
    ///
    /// # Errors
    /// Propagates the first evaluator error.
    ///
    /// # Panics
    /// Panics if the shapes differ.
    pub fn sub(&self, other: &DomainMatrix) -> Result<DomainMatrix, OverflowDetected> {
        assert_eq!(self.rows, other.rows, "DomainMatrix sub: row mismatch");
        assert_eq!(self.cols, other.cols, "DomainMatrix sub: col mismatch");
        with_evaluator(|eval| {
            let data = self
                .data
                .iter()
                .zip(other.data.iter())
                .map(|(a, b)| eval.subtract_values(a.clone(), b.clone()))
                .collect::<Result<Vec<_>, _>>()?;
            Ok(DomainMatrix { rows: self.rows, cols: self.cols, data })
        })
    }

    /// Matrix product `self * other`.
    ///
    /// Each entry is accumulated left to right with `multiply_values` and
    /// `add_values`. When the inner dimension is 0, every entry is an empty sum,
    /// so the result is filled with the value of the literal `"0"`.
    ///
    /// # Errors
    /// Propagates the first evaluator error.
    ///
    /// # Panics
    /// Panics if `self.cols() != other.rows()`.
    pub fn mat_mul(&self, other: &DomainMatrix) -> Result<DomainMatrix, OverflowDetected> {
        assert_eq!(self.cols, other.rows, "DomainMatrix matmul: dimension mismatch");
        let k = self.cols;
        if k == 0 {
            let zero = Self::zero_value()?;
            return Ok(DomainMatrix {
                rows: self.rows,
                cols: other.cols,
                data: vec![zero; self.rows * other.cols],
            });
        }
        with_evaluator(|eval| {
            let mut data = Vec::with_capacity(self.rows * other.cols);
            for r in 0..self.rows {
                for c in 0..other.cols {
                    // Seed with the first product instead of a zero so each entry
                    // keeps the domain produced by the evaluator's arithmetic.
                    let mut acc = eval.multiply_values(
                        self.data[r * k].clone(),
                        other.data[c].clone(),
                    )?;
                    for m in 1..k {
                        let prod = eval.multiply_values(
                            self.data[r * k + m].clone(),
                            other.data[m * other.cols + c].clone(),
                        )?;
                        acc = eval.add_values(acc, prod)?;
                    }
                    data.push(acc);
                }
            }
            Ok(DomainMatrix { rows: self.rows, cols: other.cols, data })
        })
    }

    /// Element-wise negation via `negate_value`.
    ///
    /// # Errors
    /// Propagates the first evaluator error.
    pub fn neg(&self) -> Result<DomainMatrix, OverflowDetected> {
        with_evaluator(|eval| {
            let data = self
                .data
                .iter()
                .map(|sv| eval.negate_value(sv.clone()))
                .collect::<Result<Vec<_>, _>>()?;
            Ok(DomainMatrix { rows: self.rows, cols: self.cols, data })
        })
    }

    /// Multiplies every entry by `s` (computed as `s * entry` with `multiply_values`).
    ///
    /// # Errors
    /// Propagates the first evaluator error.
    pub fn scalar_mul(&self, s: &StackValue) -> Result<DomainMatrix, OverflowDetected> {
        with_evaluator(|eval| {
            let data = self
                .data
                .iter()
                .map(|sv| eval.multiply_values(s.clone(), sv.clone()))
                .collect::<Result<Vec<_>, _>>()?;
            Ok(DomainMatrix { rows: self.rows, cols: self.cols, data })
        })
    }

    /// Sum of the main diagonal.
    ///
    /// A 0x0 matrix has an empty diagonal, so its trace is the value of the
    /// literal `"0"`.
    ///
    /// # Errors
    /// Propagates the first evaluator error.
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    pub fn trace(&self) -> Result<StackValue, OverflowDetected> {
        assert_eq!(self.rows, self.cols, "DomainMatrix trace: not square");
        if self.rows == 0 {
            return Self::zero_value();
        }
        with_evaluator(|eval| {
            let mut acc = self.data[0].clone();
            for i in 1..self.rows {
                acc = eval.add_values(acc, self.data[i * self.cols + i].clone())?;
            }
            Ok(acc)
        })
    }

    /// Additive identity, produced the same way `from_strings` evaluates literals.
    fn zero_value() -> Result<StackValue, OverflowDetected> {
        evaluate(&LazyExpr::Literal("0"))
    }
}
thread_local! {
    /// Per-thread evaluator used by every `DomainMatrix` operation on this thread.
    /// It is configured with the profile returned by `compile_time_profile`.
    static DOMAIN_EVALUATOR: RefCell<StackEvaluator> = {
        let profile = compile_time_profile();
        RefCell::new(StackEvaluator::new(profile))
    };
}
/// Chooses the `DeploymentProfile` for the per-thread `DOMAIN_EVALUATOR` from
/// the `table_format` cfg value supplied at build time.
///
/// Mapping: `q256_256` -> Scientific, `q128_128` -> Balanced, `q64_64` -> Embedded,
/// `q32_32` -> Compact, `q16_16` -> Realtime.
///
/// Exactly one `table_format` value is expected to be active. If none is active,
/// the function has no tail expression and will not type-check.
// NOTE(review): consider adding a `#[cfg(not(any(...)))]` arm with `compile_error!`
// so that a missing `table_format` produces a clear message instead of a type mismatch.
const fn compile_time_profile() -> DeploymentProfile {
    #[cfg(table_format = "q256_256")]
    { DeploymentProfile::Scientific }
    #[cfg(table_format = "q128_128")]
    { DeploymentProfile::Balanced }
    #[cfg(table_format = "q64_64")]
    { DeploymentProfile::Embedded }
    #[cfg(table_format = "q32_32")]
    { DeploymentProfile::Compact }
    #[cfg(table_format = "q16_16")]
    { DeploymentProfile::Realtime }
}
/// Resets this thread's `DOMAIN_EVALUATOR` and then runs `op` with mutable access to it.
///
/// The evaluator stays mutably borrowed (`RefCell::borrow_mut`) for the whole
/// call. `op` must therefore not call `with_evaluator` again, because the nested
/// borrow would panic.
fn with_evaluator<T>(
    op: impl FnOnce(&mut StackEvaluator) -> Result<T, OverflowDetected>,
) -> Result<T, OverflowDetected> {
    DOMAIN_EVALUATOR.with(|cell| {
        let mut guard = cell.borrow_mut();
        guard.reset();
        op(&mut *guard)
    })
}
impl core::fmt::Display for DomainMatrix {
    /// Writes `DomainMatrix(<rows>x<cols>, <dominant_domain() in Debug form>)`.
    fn fmt(&self, out: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let domain = self.dominant_domain();
        write!(out, "DomainMatrix({}x{}, {:?})", self.rows, self.cols, domain)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `(rows, cols)` of a matrix, so that a shape check reads as one assertion.
    fn shape(matrix: &DomainMatrix) -> (usize, usize) {
        (matrix.rows(), matrix.cols())
    }

    #[test]
    fn test_domain_matrix_from_strings_binary() {
        let matrix = DomainMatrix::from_strings(2, 2, &["1", "2", "3", "4"]).unwrap();
        assert_eq!(shape(&matrix), (2, 2));
        assert!(matrix.is_uniform_domain());
    }

    #[test]
    fn test_domain_matrix_from_strings_decimal() {
        let cells = ["0.10", "0.20", "0.30", "0.40"];
        let matrix = DomainMatrix::from_strings(2, 2, &cells).unwrap();
        assert_eq!(matrix.rows(), 2);
        assert!(matrix.is_uniform_domain());
        assert_eq!(matrix.dominant_domain(), Some(DomainType::Decimal));
    }

    #[test]
    fn test_domain_matrix_transpose() {
        let original = DomainMatrix::from_strings(2, 3, &["1", "2", "3", "4", "5", "6"]).unwrap();
        let flipped = original.transpose();
        assert_eq!(shape(&flipped), (3, 2));
    }

    #[test]
    fn test_domain_matrix_identity() {
        let identity = DomainMatrix::identity_binary(3);
        assert_eq!(shape(&identity), (3, 3));
        assert!(identity.is_uniform_domain());
    }

    #[test]
    fn test_domain_matrix_add() {
        let lhs = DomainMatrix::from_strings(2, 2, &["1", "2", "3", "4"]).unwrap();
        let rhs = DomainMatrix::from_strings(2, 2, &["10", "20", "30", "40"]).unwrap();
        let sum = lhs.add(&rhs).unwrap();
        assert_eq!(shape(&sum), (2, 2));
    }

    #[test]
    fn test_domain_matrix_decimal_matmul() {
        let eye = DomainMatrix::from_strings(2, 2, &["1.00", "0.00", "0.00", "1.00"]).unwrap();
        let column = DomainMatrix::from_strings(2, 1, &["0.10", "0.20"]).unwrap();
        let product = eye.mat_mul(&column).unwrap();
        assert_eq!(shape(&product), (2, 1));
    }
}