use std::ops::MulAssign;
use crate::{IncompatibleStages, Row, RowBuf, Stage};
/// Accumulates a running product of [`Row`]s into a single total.
///
/// Alongside the total it keeps a second [`RowBuf`] as scratch space. Each new product is
/// written into the scratch row, and the scratch row and the total are then swapped.
#[derive(Debug, Clone)]
pub struct RowAccumulator {
    /// The current accumulated product.
    total: RowBuf,
    /// Scratch row that the next product is written into before being swapped with `total`.
    temp_row: RowBuf,
}
impl RowAccumulator {
    /// Creates a `RowAccumulator` whose total starts as rounds on the given [`Stage`].
    #[inline]
    pub fn rounds(stage: Stage) -> Self {
        Self::new(RowBuf::rounds(stage))
    }

    /// Creates a `RowAccumulator` whose total starts as `total`.
    ///
    /// A scratch row (initialised to rounds on the same [`Stage`] as `total`) is created
    /// alongside it. Accumulated products are written into that scratch row, which is then
    /// swapped with the total.
    #[inline]
    pub fn new(total: RowBuf) -> Self {
        Self {
            temp_row: RowBuf::rounds(total.stage()),
            total,
        }
    }

    /// Post-multiplies the total by `row`, replacing the total with `total * row`.
    ///
    /// This is the operation performed by `acc *= row`.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleStages`] if `mul_into` rejects the pair (presumably because `row`
    /// and the total have different [`Stage`]s). In that case the total is not replaced.
    #[inline]
    pub fn post_accumulate(&mut self, row: &Row) -> Result<(), IncompatibleStages> {
        self.total.mul_into(row, &mut self.temp_row)?;
        self.commit_product();
        Ok(())
    }

    /// Unchecked version of [`Self::post_accumulate`].
    ///
    /// # Safety
    ///
    /// `row` must have the same [`Stage`] as the current total. No check is performed here.
    /// This method forwards directly to `mul_into_unchecked`, so the caller takes on that
    /// method's precondition (assumed to be exactly this stage match; confirm against its
    /// docs).
    #[inline]
    pub unsafe fn post_accumulate_unchecked(&mut self, row: &Row) {
        // SAFETY: this function's contract requires the caller to guarantee that `row` and
        // `self.total` have the same stage.
        unsafe { self.total.mul_into_unchecked(row, &mut self.temp_row) };
        self.commit_product();
    }

    /// Pre-multiplies the total by `row`, replacing the total with `row * total`.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleStages`] if `mul_into` rejects the pair (presumably because `row`
    /// and the total have different [`Stage`]s). In that case the total is not replaced.
    #[inline]
    pub fn pre_accumulate(&mut self, row: &Row) -> Result<(), IncompatibleStages> {
        row.mul_into(&self.total, &mut self.temp_row)?;
        self.commit_product();
        Ok(())
    }

    /// Unchecked version of [`Self::pre_accumulate`].
    ///
    /// # Safety
    ///
    /// `row` must have the same [`Stage`] as the current total. No check is performed here.
    /// This method forwards directly to `mul_into_unchecked`, so the caller takes on that
    /// method's precondition (assumed to be exactly this stage match; confirm against its
    /// docs).
    #[inline]
    pub unsafe fn pre_accumulate_unchecked(&mut self, row: &Row) {
        // SAFETY: this function's contract requires the caller to guarantee that `row` and
        // `self.total` have the same stage.
        unsafe { row.mul_into_unchecked(&self.total, &mut self.temp_row) };
        self.commit_product();
    }

    /// Replaces the total with a copy of `row`.
    #[inline]
    pub fn set(&mut self, row: &Row) {
        // NOTE(review): `row` may have a different stage to the previous total, and the
        // scratch row is not resized here. This relies on `mul_into` fully overwriting its
        // output buffer regardless of that buffer's previous stage; confirm.
        row.copy_into(&mut self.total);
    }

    /// Returns the current accumulated total.
    #[inline]
    pub fn total(&self) -> &Row {
        &self.total
    }

    /// Consumes the accumulator and returns the accumulated total, discarding the scratch row.
    #[inline]
    pub fn into_total(self) -> RowBuf {
        self.total
    }

    /// Makes the product just written into `temp_row` the new total. The old total becomes
    /// the scratch row for the next accumulation.
    #[inline]
    fn commit_product(&mut self) {
        std::mem::swap(&mut self.total, &mut self.temp_row);
    }
}
impl MulAssign<&Row> for RowAccumulator {
    /// Post-multiplies the total by `rhs`. `acc *= row` is shorthand for
    /// `acc.post_accumulate(row).unwrap()`.
    ///
    /// # Panics
    ///
    /// Panics if `post_accumulate` reports [`IncompatibleStages`].
    #[inline]
    fn mul_assign(&mut self, rhs: &Row) {
        // `MulAssign` has no way to report failure, so an error here becomes a panic.
        let outcome = self.post_accumulate(rhs);
        outcome.unwrap();
    }
}
impl MulAssign<&RowBuf> for RowAccumulator {
#[inline]
fn mul_assign(&mut self, rhs: &RowBuf) {
self.post_accumulate(rhs).unwrap()
}
}