use std::ops::{Add, Sub, Neg, Mul, Index, IndexMut};
use super::FixedPoint;
use super::linalg::{compute_tier_dot, compute_tier_dot_raw};
use crate::fixed_point::universal::fasc::stack_evaluator::BinaryStorage;
/// A dynamically sized vector whose components are [`FixedPoint`] values.
///
/// The dimension is simply the number of stored components; arithmetic
/// operators on this type assert that both operands have the same dimension.
#[derive(Debug, Clone, PartialEq)]
pub struct FixedVector {
    /// Component storage, in index order.
    data: Vec<FixedPoint>,
}
impl FixedVector {
    /// Creates a vector with `dim` components, all set to [`FixedPoint::ZERO`].
    pub fn new(dim: usize) -> Self {
        Self {
            data: vec![FixedPoint::ZERO; dim],
        }
    }

    /// Builds a vector by converting each `f32` with [`FixedPoint::from_f32`].
    pub fn from_f32_slice(values: &[f32]) -> Self {
        Self {
            data: values.iter().map(|&v| FixedPoint::from_f32(v)).collect(),
        }
    }

    /// Builds a vector by copying the given components.
    pub fn from_slice(values: &[FixedPoint]) -> Self {
        Self {
            data: values.to_vec(),
        }
    }

    /// Number of components.
    #[inline]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Number of components (alias of [`FixedVector::len`]).
    #[inline]
    pub fn dimension(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` when the vector has no components.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Dot product, computed by `compute_tier_dot`.
    ///
    /// # Panics
    /// Panics if the two vectors have different lengths.
    pub fn dot(&self, other: &FixedVector) -> FixedPoint {
        assert_eq!(self.len(), other.len(), "FixedVector::dot: dimension mismatch");
        compute_tier_dot(&self.data, &other.data)
    }

    /// Squared Euclidean length (the dot product of the vector with itself).
    pub fn length_squared(&self) -> FixedPoint {
        self.dot(self)
    }

    /// Euclidean length: `sqrt` of [`FixedVector::length_squared`].
    pub fn length(&self) -> FixedPoint {
        self.length_squared().sqrt()
    }

    /// Euclidean length computed by `super::fused::sqrt_sum_sq`.
    ///
    /// NOTE(review): how this differs numerically from [`FixedVector::length`]
    /// is defined in the `fused` module, not here — confirm before relying on it.
    pub fn length_fused(&self) -> FixedPoint {
        super::fused::sqrt_sum_sq(&self.data)
    }

    /// Euclidean distance to `other`, computed by `super::fused::euclidean_distance`.
    ///
    /// # Panics
    /// Panics if the two vectors have different lengths.
    pub fn distance_to(&self, other: &FixedVector) -> FixedPoint {
        assert_eq!(self.len(), other.len(), "FixedVector::distance_to: dimension mismatch");
        super::fused::euclidean_distance(&self.data, &other.data)
    }

    /// Scales the vector in place so that it has unit length.
    ///
    /// If the computed length is zero, the vector is left unchanged instead
    /// of dividing by zero. This covers the zero vector and vectors whose
    /// squared components underflow to zero in fixed point.
    pub fn normalize(&mut self) {
        let len = self.length();
        // Guard the divisor: a zero length has no meaningful direction.
        if len == FixedPoint::ZERO {
            return;
        }
        for v in &mut self.data {
            *v = *v / len;
        }
    }

    /// Returns a unit-length copy of this vector (see [`FixedVector::normalize`]).
    pub fn normalized(&self) -> Self {
        let mut result = self.clone();
        result.normalize();
        result
    }

    /// Returns a new vector with `f` applied to each component, in index order.
    ///
    /// Accepts any `FnMut`, so stateful closures are allowed. Every `Fn`
    /// closure accepted previously is still accepted.
    pub fn map(&self, mut f: impl FnMut(FixedPoint) -> FixedPoint) -> Self {
        Self {
            data: self.data.iter().map(|&v| f(v)).collect(),
        }
    }

    /// Iterator over shared references to the components.
    pub fn iter(&self) -> std::slice::Iter<'_, FixedPoint> {
        self.data.iter()
    }

    /// Iterator over mutable references to the components.
    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, FixedPoint> {
        self.data.iter_mut()
    }

    /// Euclidean distance to `other`, computed on raw storage.
    ///
    /// Component differences are taken in [`FixedPoint`]. Their raw values are
    /// squared and summed by `compute_tier_dot_raw`, converted back with
    /// `FixedPoint::from_raw`, and then square-rooted.
    ///
    /// NOTE(review): what "safe" guarantees relative to
    /// [`FixedVector::distance_to`] is not visible here — confirm.
    ///
    /// # Panics
    /// Panics if the two vectors have different lengths.
    pub fn metric_distance_safe(&self, other: &FixedVector) -> FixedPoint {
        assert_eq!(self.len(), other.len(), "FixedVector::metric_distance_safe: dimension mismatch");
        let diff_raw: Vec<BinaryStorage> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&a, &b)| (a - b).raw())
            .collect();
        let sum_sq = FixedPoint::from_raw(compute_tier_dot_raw(&diff_raw, &diff_raw));
        sum_sq.sqrt()
    }
}
impl Index<usize> for FixedVector {
    type Output = FixedPoint;

    /// Borrows component `idx`.
    ///
    /// # Panics
    /// Panics if `idx` is not less than the vector's length.
    #[inline]
    fn index(&self, idx: usize) -> &FixedPoint {
        let components: &[FixedPoint] = &self.data;
        &components[idx]
    }
}
impl IndexMut<usize> for FixedVector {
    /// Mutably borrows component `idx`.
    ///
    /// # Panics
    /// Panics if `idx` is not less than the vector's length.
    #[inline]
    fn index_mut(&mut self, idx: usize) -> &mut FixedPoint {
        let components: &mut [FixedPoint] = &mut self.data;
        &mut components[idx]
    }
}
impl Default for FixedVector {
    /// The empty, zero-dimensional vector.
    fn default() -> Self {
        let data = Vec::default();
        Self { data }
    }
}
impl Add for FixedVector {
    type Output = Self;

    /// Component-wise sum, consuming both operands.
    ///
    /// The left operand's buffer is reused for the result, so no new
    /// allocation is made.
    ///
    /// # Panics
    /// Panics if the two vectors have different lengths.
    fn add(mut self, rhs: Self) -> Self {
        assert_eq!(self.len(), rhs.len(), "FixedVector::add: dimension mismatch");
        for (a, &b) in self.data.iter_mut().zip(rhs.data.iter()) {
            *a = *a + b;
        }
        self
    }
}
impl<'a, 'b> Add<&'b FixedVector> for &'a FixedVector {
type Output = FixedVector;
fn add(self, rhs: &'b FixedVector) -> FixedVector {
assert_eq!(self.len(), rhs.len(), "FixedVector::add: dimension mismatch");
FixedVector {
data: self.data.iter().zip(rhs.data.iter())
.map(|(&a, &b)| a + b).collect(),
}
}
}
impl Sub for FixedVector {
    type Output = Self;

    /// Component-wise difference `self - rhs`, consuming both operands.
    ///
    /// The left operand's buffer is reused for the result, so no new
    /// allocation is made.
    ///
    /// # Panics
    /// Panics if the two vectors have different lengths.
    fn sub(mut self, rhs: Self) -> Self {
        assert_eq!(self.len(), rhs.len(), "FixedVector::sub: dimension mismatch");
        for (a, &b) in self.data.iter_mut().zip(rhs.data.iter()) {
            *a = *a - b;
        }
        self
    }
}
impl<'a, 'b> Sub<&'b FixedVector> for &'a FixedVector {
type Output = FixedVector;
fn sub(self, rhs: &'b FixedVector) -> FixedVector {
assert_eq!(self.len(), rhs.len(), "FixedVector::sub: dimension mismatch");
FixedVector {
data: self.data.iter().zip(rhs.data.iter())
.map(|(&a, &b)| a - b).collect(),
}
}
}
impl Neg for FixedVector {
    type Output = Self;

    /// Negates every component, consuming the vector.
    ///
    /// The vector's own buffer is reused for the result instead of
    /// allocating a new one.
    fn neg(mut self) -> Self {
        for v in self.data.iter_mut() {
            *v = -*v;
        }
        self
    }
}
impl Neg for &FixedVector {
type Output = FixedVector;
fn neg(self) -> FixedVector {
FixedVector {
data: self.data.iter().map(|&v| -v).collect(),
}
}
}
impl Mul<FixedVector> for FixedPoint {
type Output = FixedVector;
fn mul(self, rhs: FixedVector) -> FixedVector {
FixedVector {
data: rhs.data.iter().map(|&v| self * v).collect(),
}
}
}
impl Mul<FixedPoint> for FixedVector {
    type Output = Self;

    /// Vector-times-scalar product `self[i] * rhs` for each component.
    ///
    /// The vector is consumed and its buffer is reused for the result
    /// instead of allocating a new one.
    fn mul(mut self, rhs: FixedPoint) -> Self {
        for v in self.data.iter_mut() {
            *v = *v * rhs;
        }
        self
    }
}
impl Mul<FixedPoint> for &FixedVector {
    type Output = FixedVector;

    /// Returns a new vector with each borrowed component multiplied by `rhs`
    /// (component on the left).
    fn mul(self, rhs: FixedPoint) -> FixedVector {
        let scaled: Vec<FixedPoint> = self.data.iter().copied().map(|c| c * rhs).collect();
        FixedVector { data: scaled }
    }
}
impl FixedVector {
    /// Dot product computed by `compute_tier_dot`.
    ///
    /// NOTE(review): this currently performs the same computation as
    /// `FixedVector::dot`; only the panic message differs.
    ///
    /// # Panics
    /// Panics if the two vectors have different lengths.
    pub fn dot_precise(&self, other: &FixedVector) -> FixedPoint {
        assert_eq!(self.len(), other.len(), "FixedVector::dot_precise: dimension mismatch");
        let (lhs, rhs) = (&self.data, &other.data);
        compute_tier_dot(lhs, rhs)
    }

    /// 3-D cross product `self × other`.
    ///
    /// # Panics
    /// Panics unless both vectors have exactly three components.
    pub fn cross(&self, other: &FixedVector) -> FixedVector {
        assert_eq!(self.len(), 3, "FixedVector::cross: self must be 3D");
        assert_eq!(other.len(), 3, "FixedVector::cross: other must be 3D");
        let (ax, ay, az) = (self.data[0], self.data[1], self.data[2]);
        let (bx, by, bz) = (other.data[0], other.data[1], other.data[2]);
        let components = [
            ay * bz - az * by,
            az * bx - ax * bz,
            ax * by - ay * bx,
        ];
        FixedVector::from_slice(&components)
    }

    /// Outer product: a `self.len() × other.len()` matrix whose entry
    /// `(i, j)` is `self[i] * other[j]`.
    pub fn outer_product(&self, other: &FixedVector) -> super::FixedMatrix {
        let (rows, cols) = (self.len(), other.len());
        let mut out = super::FixedMatrix::new(rows, cols);
        for (i, &a) in self.data.iter().enumerate() {
            for (j, &b) in other.data.iter().enumerate() {
                out.set(i, j, a * b);
            }
        }
        out
    }

    /// Borrows the components as a slice (crate-internal).
    #[inline]
    pub(crate) fn as_slice(&self) -> &[FixedPoint] {
        self.data.as_slice()
    }
}