mod ops;
#[cfg(target_arch = "x86_64")]
pub(crate) mod simd;
use crate::fixed_point::universal::fasc::stack_evaluator::BinaryStorage;
use crate::fixed_point::imperative::FixedPoint;
use rayon::prelude::*;
/// Fixed-point scale of a raw TQ19 weight code: 3^9 = 19 683.
// NOTE(review): presumably the raw code for 1.0 (the unit tests place it on the
// diagonal of an identity matrix) — confirm against the format spec.
pub const SCALE: i32 = 3i32.pow(9);
/// Largest raw TQ19 code: (3^10 − 1) / 2 = 29 524, the maximum value of a
/// 10-digit balanced-ternary number. The value fits in `i16`, so the cast
/// below cannot truncate.
pub const MAX_RAW: i16 = ((3i32.pow(10) - 1) / 2) as i16;
/// Smallest raw TQ19 code; the range is symmetric around zero.
pub const MIN_RAW: i16 = -MAX_RAW;
/// Lookup table mapping a packed byte to its five trits, each in `{-1, 0, +1}`.
///
/// A byte `b` is read as five base-3 digits with entry `[0]` the most
/// significant (weight 81) and entry `[4]` the least significant (weight 1).
/// Digit values `0 / 1 / 2` map to trits `-1 / 0 / +1`. Only `0..=242`
/// (3^5 − 1) are valid encodings. Because the top digit is taken modulo 3,
/// bytes `243..=255` alias the entries for bytes `0..=12`.
pub const TRIT_DECODE_TABLE: [[i8; 5]; 256] = generate_trit_decode_table();
/// Builds [`TRIT_DECODE_TABLE`] at compile time.
const fn generate_trit_decode_table() -> [[i8; 5]; 256] {
    let mut out = [[0i8; 5]; 256];
    let mut idx: usize = 0;
    while idx < 256 {
        let mut rem = idx as u8;
        // Fill from the least significant digit (slot 4) towards slot 0.
        let mut pos: usize = 5;
        while pos > 0 {
            pos -= 1;
            out[idx][pos] = (rem % 3) as i8 - 1;
            rem /= 3;
        }
        idx += 1;
    }
    out
}
/// Dense row-major matrix of raw TQ19 weight codes (`i16`).
///
/// Element `(r, c)` is stored at `data[r * cols + c]`. The matrix-vector
/// products validate activation lengths and then delegate to the kernels in
/// the `ops` submodule.
// NOTE(review): weight codes are not range-checked against
// `MIN_RAW..=MAX_RAW` here; confirm whether the `ops` kernels rely on that.
#[derive(Debug, Clone)]
pub struct TQ19Matrix {
    rows: usize,
    cols: usize,
    data: Vec<i16>,
}
impl TQ19Matrix {
    /// Returns `rows * cols`, panicking instead of silently wrapping on overflow
    /// (plain `*` wraps in release builds).
    fn checked_len(rows: usize, cols: usize) -> usize {
        rows.checked_mul(cols)
            .expect("TQ19Matrix: rows × cols overflows usize")
    }

    /// Wraps row-major `data` as a `rows × cols` matrix.
    ///
    /// # Panics
    /// Panics if `rows * cols` overflows `usize` or `data.len() != rows * cols`.
    pub fn new(rows: usize, cols: usize, data: Vec<i16>) -> Self {
        let len = Self::checked_len(rows, cols);
        assert_eq!(data.len(), len, "TQ19Matrix: data.len() must equal rows × cols");
        Self { rows, cols, data }
    }

    /// Builds a matrix by calling `f(row, col)` once per element, in row-major order.
    ///
    /// # Panics
    /// Panics if `rows * cols` overflows `usize`.
    pub fn from_fn(rows: usize, cols: usize, f: impl Fn(usize, usize) -> i16) -> Self {
        let mut data = Vec::with_capacity(Self::checked_len(rows, cols));
        for r in 0..rows {
            data.extend((0..cols).map(|c| f(r, c)));
        }
        Self { rows, cols, data }
    }

    /// Number of rows.
    #[inline]
    pub fn rows(&self) -> usize { self.rows }

    /// Number of columns.
    #[inline]
    pub fn cols(&self) -> usize { self.cols }

    /// The raw row-major weight codes.
    #[inline]
    pub fn data(&self) -> &[i16] { &self.data }

    /// The `cols` weight codes of row `row`.
    ///
    /// # Panics
    /// Panics if `row >= self.rows()`.
    #[inline]
    pub fn row_slice(&self, row: usize) -> &[i16] {
        assert!(
            row < self.rows,
            "TQ19Matrix::row_slice: row {row} out of bounds for a matrix with {} rows",
            self.rows
        );
        let start = row * self.cols;
        &self.data[start..start + self.cols]
    }

    /// The weight code at `(row, col)`.
    ///
    /// # Panics
    /// Panics if `row >= self.rows()` or `col >= self.cols()`. Without the
    /// column check, an out-of-range `col` would silently read into the next row.
    #[inline]
    pub fn get(&self, row: usize, col: usize) -> i16 {
        assert!(
            row < self.rows && col < self.cols,
            "TQ19Matrix::get: index ({row}, {col}) out of bounds for a {}×{} matrix",
            self.rows,
            self.cols
        );
        self.data[row * self.cols + col]
    }

    /// Multiplies the matrix by `activations` using `ops::tq19_matvec`.
    ///
    /// # Panics
    /// Panics if `activations.len() != self.cols()`.
    pub fn matvec(&self, activations: &[BinaryStorage]) -> Vec<BinaryStorage> {
        assert_eq!(activations.len(), self.cols, "TQ19Matrix::matvec: activation length mismatch");
        ops::tq19_matvec(&self.data, self.rows, self.cols, activations)
    }

    /// Multiplies the matrix by every vector in `batch` using `ops::tq19_matvec_batch`.
    ///
    /// # Panics
    /// Panics if any vector's length differs from `self.cols()`.
    pub fn matvec_batch(&self, batch: &[&[BinaryStorage]]) -> Vec<Vec<BinaryStorage>> {
        for (i, v) in batch.iter().enumerate() {
            assert_eq!(v.len(), self.cols, "TQ19Matrix::matvec_batch: activation[{i}] length mismatch");
        }
        ops::tq19_matvec_batch(&self.data, self.rows, self.cols, batch)
    }

    /// Like [`TQ19Matrix::matvec`], but wraps each raw output with `FixedPoint::from_raw`.
    pub fn matvec_fp(&self, activations: &[BinaryStorage]) -> Vec<FixedPoint> {
        self.matvec(activations).into_iter().map(FixedPoint::from_raw).collect()
    }

    /// Variant of [`TQ19Matrix::matvec`] that delegates to `ops::tq19_matvec_par`
    /// (presumably rayon-parallel — confirm in `ops`).
    ///
    /// # Panics
    /// Panics if `activations.len() != self.cols()`.
    pub fn matvec_par(&self, activations: &[BinaryStorage]) -> Vec<BinaryStorage> {
        assert_eq!(activations.len(), self.cols, "TQ19Matrix::matvec_par: activation length mismatch");
        ops::tq19_matvec_par(&self.data, self.rows, self.cols, activations)
    }

    /// Variant of [`TQ19Matrix::matvec_batch`] that delegates to
    /// `ops::tq19_matvec_batch_par` (presumably rayon-parallel — confirm in `ops`).
    ///
    /// # Panics
    /// Panics if any vector's length differs from `self.cols()`.
    pub fn matvec_batch_par(&self, batch: &[&[BinaryStorage]]) -> Vec<Vec<BinaryStorage>> {
        for (i, v) in batch.iter().enumerate() {
            assert_eq!(v.len(), self.cols, "TQ19Matrix::matvec_batch_par: activation[{i}] length mismatch");
        }
        ops::tq19_matvec_batch_par(&self.data, self.rows, self.cols, batch)
    }
}
#[inline]
pub fn tq19_dot(weights: &[i16], activations: &[BinaryStorage]) -> BinaryStorage {
ops::tq19_dot(weights, activations)
}
#[inline]
pub fn trit_dot(trits: &[i8], activations: &[BinaryStorage]) -> BinaryStorage {
ops::trit_dot(trits, activations)
}
pub fn packed_trit_dot(
packed: &[u8],
count: usize,
activations: &[BinaryStorage],
scale: BinaryStorage,
) -> BinaryStorage {
ops::packed_trit_dot(packed, count, activations, scale)
}
pub fn packed_trit_matvec(
packed_trits: &[u8],
rows: usize,
cols: usize,
activations: &[BinaryStorage],
scales: &[BinaryStorage],
) -> Vec<BinaryStorage> {
ops::packed_trit_matvec(packed_trits, rows, cols, activations, scales)
}
pub fn packed_trit_matvec_par(
packed_trits: &[u8],
rows: usize,
cols: usize,
activations: &[BinaryStorage],
scales: &[BinaryStorage],
) -> Vec<BinaryStorage> {
ops::packed_trit_matvec_par(packed_trits, rows, cols, activations, scales)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks a handful of known byte → trit decodings.
    #[test]
    fn trit_decode_table_spot_check() {
        let expectations: [(usize, [i8; 5]); 6] = [
            (0, [-1, -1, -1, -1, -1]),
            (121, [0, 0, 0, 0, 0]),
            (242, [1, 1, 1, 1, 1]),
            (1, [-1, -1, -1, -1, 0]),
            (2, [-1, -1, -1, -1, 1]),
            (3, [-1, -1, -1, 0, -1]),
        ];
        for &(index, expected) in expectations.iter() {
            assert_eq!(TRIT_DECODE_TABLE[index], expected);
        }
    }

    /// Re-encoding every valid byte's trits must reproduce the byte.
    #[test]
    fn trit_decode_roundtrip() {
        for byte in 0u8..=242 {
            // Horner evaluation, most significant trit first; the peak value is
            // 242, so the u8 accumulator cannot overflow.
            let re_encoded = TRIT_DECODE_TABLE[byte as usize]
                .iter()
                .fold(0u8, |acc, &trit| acc * 3 + (trit + 1) as u8);
            assert_eq!(re_encoded, byte, "roundtrip failed for byte {byte}");
        }
    }

    #[test]
    fn tq19_matrix_construction() {
        let matrix = TQ19Matrix::new(2, 3, vec![1, 2, 3, 4, 5, 6]);
        assert_eq!((matrix.rows(), matrix.cols()), (2, 3));
        assert_eq!(matrix.get(0, 0), 1);
        assert_eq!(matrix.get(1, 2), 6);
        assert_eq!(matrix.row_slice(0), &[1, 2, 3]);
        assert_eq!(matrix.row_slice(1), &[4, 5, 6]);
    }

    #[test]
    fn tq19_matrix_from_fn() {
        let one = SCALE as i16;
        let identity = TQ19Matrix::from_fn(3, 3, |r, c| if r == c { one } else { 0 });
        assert_eq!(identity.get(0, 0), one);
        assert_eq!(identity.get(0, 1), 0);
        assert_eq!(identity.get(1, 1), one);
    }

    #[test]
    #[should_panic(expected = "data.len() must equal rows × cols")]
    fn tq19_matrix_size_mismatch() {
        let _ = TQ19Matrix::new(2, 3, vec![0; 5]);
    }
}