#![forbid(unsafe_code)]
#![allow(deprecated)]
#![deny(missing_docs)]
use sim_kernel::{
AbiVersion, DefaultFactory, Dependency, Export, Factory, Lib, LibManifest, LibTarget, Linker,
Result, Symbol, Value, Version,
};
use sim_lib_numbers_tensor::{
SpecTensor, SpecTensorDescriptor, Tensor, domains, element_count, spec_tensor_descriptor_value,
spec_tensor_symbol,
};
/// A boolean tensor whose elements are bit-packed into `u64` words.
///
/// Element `i` (in the flat order accepted by [`BitTensor::from_bools`]) is
/// stored in bit `i % 64` of `words[i / 64]`. The constructors in this module
/// leave every bit at or past `len` in the last word cleared. That keeps the
/// derived `PartialEq`/`Eq`, which compares whole words, consistent with
/// element-wise equality.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BitTensor {
    /// Logical dimensions of the tensor.
    shape: Vec<usize>,
    /// Number of logical elements (bits) stored.
    len: usize,
    /// Packed storage, `len.div_ceil(64)` words long.
    words: Vec<u64>,
}
impl BitTensor {
    /// Packs a flat slice of booleans into a tensor with the given `shape`.
    ///
    /// `bits[i]` becomes element `i` in flat order.
    ///
    /// Returns `None` when `bits.len()` differs from `element_count(&shape)`.
    pub fn from_bools(shape: Vec<usize>, bits: &[bool]) -> Option<Self> {
        let len = element_count(&shape);
        if len != bits.len() {
            return None;
        }
        // One word per 64-element chunk. Bits past the end of the final chunk
        // are never set, so padding stays zero.
        let words = bits
            .chunks(64)
            .map(|chunk| {
                chunk
                    .iter()
                    .enumerate()
                    .fold(0u64, |word, (offset, &bit)| word | (u64::from(bit) << offset))
            })
            .collect();
        Some(Self { shape, len, words })
    }

    /// Unpacks the tensor into one `bool` per element, in the same flat order
    /// accepted by [`BitTensor::from_bools`].
    pub fn to_bools(&self) -> Vec<bool> {
        (0..self.len)
            .map(|index| ((self.words[index / 64] >> (index % 64)) & 1) == 1)
            .collect()
    }

    /// Element-wise logical OR.
    ///
    /// Returns `None` if `self` and `other` have different shapes.
    pub fn bit_or(&self, other: &Self) -> Option<Self> {
        map_words(self, other, |left, right| left | right)
    }

    /// Element-wise logical XOR.
    ///
    /// Returns `None` if `self` and `other` have different shapes.
    pub fn bit_xor(&self, other: &Self) -> Option<Self> {
        map_words(self, other, |left, right| left ^ right)
    }

    /// Element-wise logical AND.
    ///
    /// Returns `None` if `self` and `other` have different shapes.
    pub fn bit_and(&self, other: &Self) -> Option<Self> {
        map_words(self, other, |left, right| left & right)
    }
}
impl SpecTensor for BitTensor {
fn shape(&self) -> &[usize] {
&self.shape
}
fn dtype(&self) -> Symbol {
domains::bool()
}
fn to_uniform(&self) -> Tensor {
Tensor {
shape: self.shape.clone(),
dtype: self.dtype(),
data: self
.to_bools()
.into_iter()
.map(bool_value)
.collect::<Option<Vec<_>>>()
.expect("bool tensor values should always encode"),
}
}
fn from_uniform(tensor: &Tensor) -> Option<Self> {
let bits = tensor
.data
.iter()
.map(parse_bool_cell)
.collect::<Option<Vec<_>>>()?;
Self::from_bools(tensor.shape.clone(), &bits)
}
}
/// Host-registered library that exports the bit-packed boolean spec tensor
/// descriptor under [`tensor_spec_symbol`].
#[derive(Clone, Copy, Debug)]
pub struct BitTensorLib;
impl BitTensorLib {
    /// Creates the library handle. It carries no state, so this is usable in
    /// `const` contexts.
    pub const fn new() -> Self {
        Self
    }
}
impl Default for BitTensorLib {
fn default() -> Self {
Self::new()
}
}
impl Lib for BitTensorLib {
    /// Describes this library: id [`tensor_lib_symbol`], the crate version,
    /// ABI 0.1, host registration, no dependencies or capabilities, and a
    /// single value export at [`tensor_spec_symbol`].
    fn manifest(&self) -> LibManifest {
        let version = Version(String::from(env!("CARGO_PKG_VERSION")));
        let exports = vec![Export::Value {
            symbol: tensor_spec_symbol(),
        }];
        LibManifest {
            id: tensor_lib_symbol(),
            version,
            abi: AbiVersion { major: 0, minor: 1 },
            target: LibTarget::HostRegistered,
            requires: Vec::<Dependency>::new(),
            capabilities: Vec::new(),
            exports,
        }
    }

    /// Registers the spec tensor descriptor value with the linker.
    ///
    /// # Errors
    ///
    /// Propagates failures from building the descriptor value or from the
    /// linker registration.
    fn load(&self, _cx: &mut sim_kernel::LoadCx, linker: &mut Linker<'_>) -> Result<()> {
        let descriptor = SpecTensorDescriptor {
            symbol: tensor_spec_symbol(),
            dtype: domains::bool(),
            implementation: "BitTensor",
            storage: "bit-packed u64 words",
        };
        let descriptor_value = spec_tensor_descriptor_value(&DefaultFactory, descriptor)?;
        linker.value(tensor_spec_symbol(), descriptor_value)
    }
}
/// Returns the library id reported in [`BitTensorLib`]'s manifest: the
/// `"tensor-bit"` domain symbol.
pub fn tensor_lib_symbol() -> Symbol {
    domains::domain("tensor-bit")
}
/// Returns the symbol under which [`BitTensorLib`] exports its spec tensor
/// descriptor. It is the spec tensor symbol for the name `"bit"`.
pub fn tensor_spec_symbol() -> Symbol {
    spec_tensor_symbol("bit")
}
/// Encodes `value` as a `bool`-domain number literal through [`DefaultFactory`],
/// using `"true"`/`"false"` as the canonical text.
///
/// Returns `None` if the factory rejects the literal.
fn bool_value(value: bool) -> Option<Value> {
    let canonical = if value { "true" } else { "false" };
    match DefaultFactory.number_literal(domains::bool(), canonical.to_string()) {
        Ok(encoded) => Some(encoded),
        Err(_) => None,
    }
}
/// Decodes a uniform-tensor cell back into a `bool`.
///
/// Returns `None` in any of these cases:
/// - the cell is not a number value;
/// - it has no literal form;
/// - its literal lies outside the `bool` domain;
/// - its canonical text does not parse as a `bool`.
fn parse_bool_cell(value: &Value) -> Option<bool> {
    use std::sync::Arc;

    // A fresh evaluation context is built on every call, i.e. once per cell
    // when used from `from_uniform`.
    let mut cx = sim_kernel::Cx::new(
        Arc::new(sim_kernel::NoopEvalPolicy),
        Arc::new(DefaultFactory),
    );
    // Kept as a single expression so any temporaries returned by `object()`
    // live until the literal has been extracted.
    let literal = value
        .object()
        .as_number_value()?
        .number_literal(&mut cx)
        .ok()??;
    if literal.domain != domains::bool() {
        return None;
    }
    literal.canonical.parse::<bool>().ok()
}
/// Combines two tensors word-by-word with `f`, producing a tensor with
/// `left`'s shape and length.
///
/// Returns `None` when the shapes differ. Equal shapes imply equal word
/// counts, so the zip covers every word.
/// NOTE(review): `f` should map `(0, 0)` to `0` so padding bits stay cleared.
/// OR/XOR/AND do.
fn map_words(
    left: &BitTensor,
    right: &BitTensor,
    f: impl Fn(u64, u64) -> u64,
) -> Option<BitTensor> {
    if left.shape != right.shape {
        return None;
    }
    let mut combined = Vec::with_capacity(left.words.len());
    for (&lhs, &rhs) in left.words.iter().zip(&right.words) {
        combined.push(f(lhs, rhs));
    }
    Some(BitTensor {
        shape: left.shape.clone(),
        len: left.len,
        words: combined,
    })
}
#[cfg(test)]
mod tests {
    use sim_kernel::Lib;

    use super::{BitTensor, BitTensorLib, SpecTensor, domains, tensor_spec_symbol};

    #[test]
    fn bit_tensor_and_matches_bool_and() {
        let left = BitTensor::from_bools(vec![4], &[true, false, true, true]).unwrap();
        let right = BitTensor::from_bools(vec![4], &[true, true, false, true]).unwrap();
        let out = left.bit_and(&right).unwrap();
        assert_eq!(out.to_bools(), vec![true, false, false, true]);
        let uniform = out.to_uniform();
        assert_eq!(uniform.shape, vec![4]);
        assert_eq!(uniform.data.len(), 4);
        assert_eq!(uniform.dtype, domains::bool());
    }

    #[test]
    fn bit_tensor_or_and_xor_match_bool_ops() {
        let left = BitTensor::from_bools(vec![4], &[true, false, true, false]).unwrap();
        let right = BitTensor::from_bools(vec![4], &[true, true, false, false]).unwrap();
        assert_eq!(
            left.bit_or(&right).unwrap().to_bools(),
            vec![true, true, true, false]
        );
        assert_eq!(
            left.bit_xor(&right).unwrap().to_bools(),
            vec![false, true, true, false]
        );
    }

    #[test]
    fn bit_ops_span_multiple_words() {
        // 130 elements spans three u64 words, including a partial last word.
        let left_bits: Vec<bool> = (0..130).map(|i| i % 3 == 0).collect();
        let right_bits: Vec<bool> = (0..130).map(|i| i % 2 == 0).collect();
        let left = BitTensor::from_bools(vec![130], &left_bits).unwrap();
        let right = BitTensor::from_bools(vec![130], &right_bits).unwrap();
        assert_eq!(left.to_bools(), left_bits);
        let expected: Vec<bool> = left_bits
            .iter()
            .zip(&right_bits)
            .map(|(a, b)| *a && *b)
            .collect();
        assert_eq!(left.bit_and(&right).unwrap().to_bools(), expected);
    }

    #[test]
    fn from_bools_rejects_length_mismatch() {
        assert!(BitTensor::from_bools(vec![3], &[true, false]).is_none());
    }

    #[test]
    fn bit_ops_reject_shape_mismatch() {
        let short = BitTensor::from_bools(vec![2], &[true, false]).unwrap();
        let long = BitTensor::from_bools(vec![3], &[true, false, true]).unwrap();
        assert!(short.bit_and(&long).is_none());
        assert!(short.bit_or(&long).is_none());
        assert!(short.bit_xor(&long).is_none());
    }

    #[test]
    fn lib_exports_spec_tensor_descriptor() {
        assert_eq!(
            BitTensorLib::new().manifest().exports[0].symbol(),
            &tensor_spec_symbol()
        );
    }
}