use std::fmt;
use std::sync::Arc;
use tract_data::internal::*;
use super::BlockQuant;
use super::BlockQuantFact;
#[derive(Clone, PartialEq, Eq)]
pub struct BlockQuantStorage {
format: Box<dyn BlockQuant>,
data: Arc<Blob>,
}
impl BlockQuantStorage {
fn expected_bytes(format: &dyn BlockQuant, m: usize, k: usize) -> usize {
m * k / format.block_len() * format.block_bytes()
}
pub fn new(
format: Box<dyn BlockQuant>,
m: usize,
k: usize,
data: Arc<Blob>,
) -> TractResult<Self> {
let expected = Self::expected_bytes(&*format, m, k);
ensure!(
data.len() == expected,
"BlockQuantStorage::new: blob length {} does not match expected {} (m={}, k={}, format={})",
data.len(),
expected,
m,
k,
format,
);
Ok(Self { format, data })
}
pub fn format(&self) -> &dyn BlockQuant {
&*self.format
}
pub fn value(&self) -> &Arc<Blob> {
&self.data
}
pub fn into_tensor_with_shape(self, dt: DatumType, shape: &[usize]) -> Tensor {
Tensor::from_storage(dt, shape, self)
}
}
pub fn block_quant_slice<'a>(
data: &'a [u8],
format: &dyn BlockQuant,
m_per_group: usize,
k: usize,
g: usize,
) -> &'a [u8] {
let row_bytes = k / format.block_len() * format.block_bytes();
let group_bytes = m_per_group * row_bytes;
let start = g * group_bytes;
&data[start..start + group_bytes]
}
impl fmt::Debug for BlockQuantStorage {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "BlockQuantStorage({}, bytes={})", self.format, self.data.len())
}
}
impl fmt::Display for BlockQuantStorage {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "BlockQuantStorage({}, bytes={})", self.format, self.data.len())
}
}
impl TensorStorage for BlockQuantStorage {
fn byte_len(&self) -> usize {
self.data.len()
}
fn is_empty(&self) -> bool {
self.data.is_empty()
}
fn deep_clone(&self) -> Box<dyn TensorStorage> {
Box::new(self.clone())
}
fn as_plain(&self) -> Option<&PlainStorage> {
None
}
fn as_plain_mut(&mut self) -> Option<&mut PlainStorage> {
None
}
fn into_plain(self: Box<Self>) -> Option<PlainStorage> {
None
}
fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
state.write_u8(1);
self.format.dyn_hash(state);
state.write(self.data.as_bytes());
}
fn exotic_fact(&self, shape: &[usize]) -> TractResult<Option<Box<dyn ExoticFact>>> {
Ok(Some(Box::new(BlockQuantFact::new(dyn_clone::clone_box(&*self.format), shape.into()))))
}
}