#[cfg(feature = "std")]
use std::vec::Vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use crate::octet::Octet;
use crate::octets::{add_assign, fused_addassign_mul_scalar, mulassign_scalar};
use crate::symbol::Symbol;
#[cfg(feature = "serde_support")]
use serde::{Deserialize, Serialize};
/// A contiguous slab of equally sized symbols stored in a single byte buffer.
///
/// Symbols are addressed by *logical* index. Without a mapping, logical index `i` is also the
/// physical index; once a mapping is installed (see `set_reorder`), logical index `i` resolves
/// to physical index `mapping[i]`. The symbol at physical index `p` occupies
/// `data[p * symbol_size..(p + 1) * symbol_size]`.
///
/// Note: the derived `PartialEq`/`Ord`/`Hash` compare the raw fields, including `mapping`, so
/// two slabs exposing the same logical symbols through different physical layouts compare
/// unequal.
#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct SymbolSlab {
    // Backing storage; the constructors size it to `count * symbol_size` bytes.
    // NOTE(review): the derived `Deserialize` does not re-check that invariant — confirm
    // whether deserialized slabs come only from trusted sources.
    data: Vec<u8>,
    // Number of symbols held in the slab (what `len()` reports).
    count: usize,
    // Length of every symbol, in bytes.
    symbol_size: usize,
    // Optional logical-to-physical index table; `None` means the identity mapping.
    mapping: Option<Vec<usize>>,
}
impl SymbolSlab {
pub fn with_zeros(count: usize, symbol_size: usize) -> Self {
SymbolSlab {
data: vec![0u8; count * symbol_size],
count,
symbol_size,
mapping: None,
}
}
#[allow(dead_code)]
pub fn from_symbols(symbols: Vec<Symbol>, symbol_size: usize) -> Self {
let count = symbols.len();
let mut data = Vec::with_capacity(count * symbol_size);
for symbol in symbols {
let bytes = symbol.into_bytes();
assert_eq!(
bytes.len(),
symbol_size,
"symbol length mismatch in SymbolSlab::from_symbols"
);
data.extend_from_slice(&bytes);
}
SymbolSlab {
data,
count,
symbol_size,
mapping: None,
}
}
#[allow(dead_code)]
pub fn into_symbols(self) -> Vec<Symbol> {
if let Some(ref mapping) = self.mapping {
mapping
.iter()
.map(|&phys| {
let start = phys * self.symbol_size;
Symbol::new(self.data[start..start + self.symbol_size].to_vec())
})
.collect()
} else {
self.data
.chunks_exact(self.symbol_size)
.map(|chunk| Symbol::new(chunk.to_vec()))
.collect()
}
}
#[inline]
pub fn len(&self) -> usize {
self.count
}
#[inline]
pub fn symbol_size(&self) -> usize {
self.symbol_size
}
#[inline(always)]
fn physical_index(&self, i: usize) -> usize {
self.mapping.as_ref().map_or(i, |m| m[i])
}
#[inline]
pub fn get(&self, i: usize) -> &[u8] {
let i = self.physical_index(i);
let start = i * self.symbol_size;
&self.data[start..start + self.symbol_size]
}
#[inline]
pub fn get_mut(&mut self, i: usize) -> &mut [u8] {
let i = self.physical_index(i);
let start = i * self.symbol_size;
&mut self.data[start..start + self.symbol_size]
}
#[inline]
pub fn get_pair_mut(&mut self, dest: usize, src: usize) -> (&mut [u8], &[u8]) {
let dest = self.physical_index(dest);
let src = self.physical_index(src);
assert_ne!(dest, src, "dest and src must differ");
assert!(dest < self.count, "dest out of range");
assert!(src < self.count, "src out of range");
let ss = self.symbol_size;
let dest_start = dest * ss;
let src_start = src * ss;
unsafe {
let ptr = self.data.as_mut_ptr();
let dest_slice = core::slice::from_raw_parts_mut(ptr.add(dest_start), ss);
let src_slice = core::slice::from_raw_parts(ptr.add(src_start), ss);
(dest_slice, src_slice)
}
}
#[inline]
pub fn add_assign(&mut self, dest: usize, src: usize) {
let (d, s) = self.get_pair_mut(dest, src);
add_assign(d, s);
}
#[inline]
pub fn mulassign_scalar(&mut self, dest: usize, scalar: &Octet) {
let d = self.get_mut(dest);
mulassign_scalar(d, scalar);
}
#[inline]
pub fn fma(&mut self, dest: usize, src: usize, scalar: &Octet) {
let (d, s) = self.get_pair_mut(dest, src);
fused_addassign_mul_scalar(d, s, scalar);
}
pub fn set_reorder(&mut self, order: Vec<usize>) {
self.mapping = Some(order);
}
#[allow(dead_code)]
pub fn copy_block_from(&mut self, dest_symbol_start: usize, source: &[u8]) {
debug_assert!(
self.mapping.is_none(),
"copy_block_from called with active mapping"
);
debug_assert_eq!(source.len() % self.symbol_size, 0);
let start = dest_symbol_start * self.symbol_size;
self.data[start..start + source.len()].copy_from_slice(source);
}
#[allow(dead_code)]
pub fn gather(&self, indices: &[usize]) -> Self {
debug_assert!(self.mapping.is_none(), "gather called with active mapping");
let ss = self.symbol_size;
let new_count = indices.len();
let mut data = vec![0u8; new_count * ss];
for (new_pos, &old_pos) in indices.iter().enumerate() {
data[new_pos * ss..(new_pos + 1) * ss]
.copy_from_slice(&self.data[old_pos * ss..(old_pos + 1) * ss]);
}
SymbolSlab {
data,
count: new_count,
symbol_size: ss,
mapping: None,
}
}
}
#[cfg(feature = "std")]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::octet::Octet;

    /// Builds a slab of one-byte symbols holding `values`, in order.
    fn one_byte_slab(values: &[u8]) -> SymbolSlab {
        let symbols = values.iter().map(|&v| Symbol::new(vec![v])).collect();
        SymbolSlab::from_symbols(symbols, 1)
    }

    #[test]
    fn roundtrip_from_into_symbols() {
        let symbols = vec![
            Symbol::new(vec![1, 2, 3, 4]),
            Symbol::new(vec![5, 6, 7, 8]),
            Symbol::new(vec![9, 10, 11, 12]),
        ];
        let slab = SymbolSlab::from_symbols(symbols.clone(), 4);
        assert_eq!(slab.len(), 3);
        assert_eq!(slab.symbol_size(), 4);
        assert_eq!(slab.get(0), &[1, 2, 3, 4]);
        assert_eq!(slab.get(1), &[5, 6, 7, 8]);
        assert_eq!(slab.get(2), &[9, 10, 11, 12]);
        let back = slab.into_symbols();
        assert_eq!(back, symbols);
    }

    #[test]
    fn with_zeros_is_zero_filled() {
        let slab = SymbolSlab::with_zeros(2, 3);
        assert_eq!(slab.len(), 2);
        assert_eq!(slab.symbol_size(), 3);
        assert_eq!(slab.get(0), &[0, 0, 0]);
        assert_eq!(slab.get(1), &[0, 0, 0]);
    }

    #[test]
    fn add_assign_xor() {
        let mut slab = SymbolSlab::from_symbols(
            vec![Symbol::new(vec![0xFF, 0x00]), Symbol::new(vec![0x0F, 0xF0])],
            2,
        );
        slab.add_assign(0, 1);
        assert_eq!(slab.get(0), &[0xF0, 0xF0]);
        assert_eq!(slab.get(1), &[0x0F, 0xF0]);
    }

    #[test]
    fn add_assign_through_mapping_both_orders() {
        // Physical layout [1, 2, 4]; logical 0 -> phys 2, 1 -> phys 0, 2 -> phys 1.
        let mut slab = one_byte_slab(&[1, 2, 4]);
        slab.set_reorder(vec![2, 0, 1]);
        // dest phys 2 > src phys 0: 4 ^ 1 = 5.
        slab.add_assign(0, 1);
        // dest phys 0 < src phys 1: 1 ^ 2 = 3.
        slab.add_assign(1, 2);
        assert_eq!(slab.get(0), &[5]);
        assert_eq!(slab.get(1), &[3]);
        assert_eq!(slab.get(2), &[2]);
    }

    #[test]
    #[should_panic(expected = "dest and src must differ")]
    fn add_assign_same_symbol_panics() {
        let mut slab = one_byte_slab(&[1, 2]);
        slab.add_assign(1, 1);
    }

    #[test]
    fn reorder_symbols() {
        let mut slab = one_byte_slab(&[0, 1, 2]);
        slab.set_reorder(vec![2, 0, 1]);
        assert_eq!(slab.get(0), &[2]);
        assert_eq!(slab.get(1), &[0]);
        assert_eq!(slab.get(2), &[1]);
    }

    #[test]
    fn into_symbols_follows_mapping() {
        let mut slab = one_byte_slab(&[0, 1, 2]);
        slab.set_reorder(vec![2, 0, 1]);
        assert_eq!(
            slab.into_symbols(),
            vec![
                Symbol::new(vec![2]),
                Symbol::new(vec![0]),
                Symbol::new(vec![1]),
            ]
        );
    }

    #[test]
    fn mulassign_scalar_only_touches_dest() {
        let mut slab = SymbolSlab::from_symbols(
            vec![Symbol::new(vec![0x01, 0x02]), Symbol::new(vec![0x03, 0x04])],
            2,
        );
        // Multiplying by 2 in GF(256) is a left shift for bytes below 0x80.
        slab.mulassign_scalar(0, &Octet::new(2));
        assert_eq!(slab.get(0), &[0x02, 0x04]);
        assert_eq!(slab.get(1), &[0x03, 0x04]);
    }

    #[test]
    fn fma_operation() {
        let mut slab =
            SymbolSlab::from_symbols(vec![Symbol::new(vec![0x01]), Symbol::new(vec![0x02])], 1);
        let scalar = Octet::new(3);
        slab.fma(0, 1, &scalar);
        // 3 * 0x02 = 0x06 in GF(256); 0x01 ^ 0x06 = 0x07.
        assert_eq!(slab.get(0), &[0x07]);
    }

    #[test]
    fn gather_copies_selected_symbols() {
        let slab = one_byte_slab(&[0, 1, 2]);
        let gathered = slab.gather(&[2, 0]);
        assert_eq!(gathered.len(), 2);
        assert_eq!(gathered.symbol_size(), 1);
        assert_eq!(gathered.get(0), &[2]);
        assert_eq!(gathered.get(1), &[0]);
    }

    #[test]
    fn copy_block_from_writes_consecutive_symbols() {
        let mut slab = SymbolSlab::with_zeros(3, 2);
        slab.copy_block_from(1, &[1, 2, 3, 4]);
        assert_eq!(slab.get(0), &[0, 0]);
        assert_eq!(slab.get(1), &[1, 2]);
        assert_eq!(slab.get(2), &[3, 4]);
    }
}