use crate::PlcConfig;
use crate::ans::ProbModel;
use crate::mixer::ContextMixer;
use crate::predictor::TransformerPredictor;
pub mod paged;
use crate::loom::paged::PagedLoom;
#[cfg(feature = "std")]
use alloc::boxed::Box;
/// Source of [`ProbModel`]s computed from a byte history.
pub trait LoomPredictor {
    /// Returns a [`ProbModel`] for the given byte `history`.
    fn predict(&self, history: &[u8]) -> ProbModel;

    /// Runs [`LoomPredictor::predict`] on every history and returns the models
    /// in the same order as `histories`.
    ///
    /// Implementors may override this with a genuinely batched path.
    fn predict_batch(&self, histories: &[&[u8]]) -> Vec<ProbModel> {
        let mut models = Vec::with_capacity(histories.len());
        for &history in histories {
            models.push(self.predict(history));
        }
        models
    }
}
/// Sink for observed (history, next byte) pairs.
pub trait LoomWeaver {
    /// Records that byte `next` followed `history`.
    fn weave(&mut self, history: &[u8], next: u8);

    /// Applies [`LoomWeaver::weave`] to each `(context, next)` pair in order.
    ///
    /// Implementors may override this with a batched path.
    fn weave_batch(&mut self, ops: Vec<(Vec<u8>, u8)>) {
        ops.into_iter()
            .for_each(|(context, symbol)| self.weave(&context, symbol));
    }
}
/// Hook for shrinking a loom's internal state.
pub trait LoomPruner {
    /// Performs one pruning pass.
    fn prune(&mut self);

    /// Batched pruning entry point. By default it performs exactly one
    /// [`LoomPruner::prune`] pass.
    fn prune_batch(&mut self) {
        self.prune()
    }
}
/// Loom that pairs a [`ContextMixer`] with a [`TransformerPredictor`].
///
/// `predict` passes the mixer's prediction for the same history into the
/// transformer predictor and returns the predictor's result. `weave` updates
/// only the mixer, and only when `config.use_mixer` is set.
pub struct StandardLoom {
    // Final-stage predictor; receives the mixer's output as an extra input.
    predictor: TransformerPredictor,
    // Context mixer built from `mixer_orders`, `max_nodes` and `prune_aggression`.
    mixer: ContextMixer,
    // Configuration this loom was built from (read for `use_mixer`).
    config: PlcConfig,
}
impl StandardLoom {
    /// Builds a loom from `config`.
    ///
    /// The mixer is created from `config.mixer_orders`, capped at
    /// `config.max_nodes`, and given the prune divisor derived from
    /// `config.prune_aggression`. The transformer predictor gets its own clone
    /// of the configuration, and the original is kept on the loom.
    pub fn new(config: PlcConfig) -> Self {
        let orders = config.mixer_orders.clone();
        let mixer = ContextMixer::new(orders)
            .with_node_limit(config.max_nodes)
            .with_prune_divisor(config.prune_aggression.prune_divisor());

        let predictor_config = config.clone();
        let predictor = TransformerPredictor::new(predictor_config);

        Self {
            config,
            mixer,
            predictor,
        }
    }
}
impl LoomPredictor for StandardLoom {
    /// Asks the mixer for its prediction on `history`, then lets the
    /// transformer predictor produce the final model from `history` plus that
    /// mixer prediction.
    fn predict(&self, history: &[u8]) -> ProbModel {
        let mixed = self.mixer.predict(history);
        self.predictor.predict(history, mixed)
    }
}
impl LoomWeaver for StandardLoom {
    /// Feeds `(history, next)` into the mixer. This does nothing when
    /// `config.use_mixer` is disabled. The transformer predictor is never
    /// updated here.
    fn weave(&mut self, history: &[u8], next: u8) {
        if !self.config.use_mixer {
            return;
        }
        self.mixer.update(history, next);
    }
}
impl LoomPruner for StandardLoom {
    /// Deliberately a no-op.
    ///
    /// NOTE(review): the mixer is configured with a node limit and prune
    /// divisor in `StandardLoom::new`, so it presumably bounds its own size
    /// during updates. This is not visible here. Confirm that, or forward to a
    /// mixer prune method if one exists.
    fn prune(&mut self) {}
}
/// Loom implementation chosen at runtime from `PlcConfig::loom_mode`.
///
/// Constructed by `LoomEnum::new`; each trait method dispatches to the wrapped
/// variant with a `match`.
pub enum LoomEnum {
    // Selected by `LoomMode::Standard`.
    Standard(StandardLoom),
    // Selected by `LoomMode::Paged`; constructed with `config.paged_loom_max_ram`.
    Paged(PagedLoom),
}
impl LoomEnum {
    /// Builds the loom variant selected by `config.loom_mode`.
    ///
    /// # Errors
    ///
    /// Returns `Err("Loom is disabled")` when the mode is `LoomMode::Off`.
    pub fn new(config: PlcConfig) -> Result<Self, &'static str> {
        use crate::LoomMode;

        let loom = match config.loom_mode {
            LoomMode::Off => return Err("Loom is disabled"),
            LoomMode::Standard => LoomEnum::Standard(StandardLoom::new(config)),
            LoomMode::Paged => {
                let paged = PagedLoom::new(config.clone(), config.paged_loom_max_ram);
                LoomEnum::Paged(paged)
            }
        };
        Ok(loom)
    }
}
impl LoomPredictor for LoomEnum {
    /// Delegates to the `predict` of whichever variant is active.
    fn predict(&self, history: &[u8]) -> ProbModel {
        match *self {
            Self::Standard(ref loom) => loom.predict(history),
            Self::Paged(ref loom) => loom.predict(history),
        }
    }
}
impl LoomWeaver for LoomEnum {
    /// Delegates to the `weave` of whichever variant is active.
    fn weave(&mut self, history: &[u8], next: u8) {
        match *self {
            Self::Standard(ref mut loom) => loom.weave(history, next),
            Self::Paged(ref mut loom) => loom.weave(history, next),
        }
    }
}
impl LoomPruner for LoomEnum {
    /// Delegates to the `prune` of whichever variant is active.
    fn prune(&mut self) {
        match *self {
            Self::Standard(ref mut loom) => loom.prune(),
            Self::Paged(ref mut loom) => loom.prune(),
        }
    }
}
#[cfg(feature = "std")]
use svalinn::vault::SvalinnAead;
/// Wrapper that stores an inner loom alongside a 16-byte key and, with the
/// `std` feature, an optional `SvalinnAead` cipher.
///
/// NOTE(review): the trait impls visible in this file never use `_key` or
/// `_cipher`. Every call is forwarded unchanged to `inner`, so no encryption
/// is applied here. No zeroization of the key is visible either. Confirm
/// whether this is a placeholder.
pub struct EncryptedLoom<L> {
    // The wrapped loom that all trait calls are forwarded to.
    inner: L,
    // Raw key as passed to `EncryptedLoom::new`.
    _key: [u8; 16],
    // Cipher built in `new` via `SvalinnAead::new_chacha`; `None` if that
    // construction returned an error. Only present with the `std` feature.
    #[cfg(feature = "std")]
    _cipher: Option<Box<SvalinnAead>>,
}
impl<L> EncryptedLoom<L> {
pub fn new(inner: L, key: [u8; 16]) -> Self {
#[cfg(feature = "std")]
let mut k32 = [0u8; 32];
k32[..16].copy_from_slice(&key);
k32[16..].copy_from_slice(&key);
let cipher_result = SvalinnAead::new_chacha(&k32);
let cipher = cipher_result.ok().map(Box::new);
#[cfg(not(feature = "std"))]
let cipher = None;
Self {
inner,
_key: key,
#[cfg(feature = "std")]
_cipher: cipher,
}
}
}
/// Transparent forwarding to the inner loom.
///
/// NOTE(review): neither the key nor the cipher is used; predictions pass
/// through unchanged.
impl<L: LoomPredictor> LoomPredictor for EncryptedLoom<L> {
    fn predict(&self, history: &[u8]) -> ProbModel {
        self.inner.predict(history)
    }

    // Forward the whole batch so a specialised `predict_batch` on `L` is not
    // bypassed by the trait's per-item default.
    fn predict_batch(&self, histories: &[&[u8]]) -> Vec<ProbModel> {
        self.inner.predict_batch(histories)
    }
}
/// Transparent forwarding to the inner loom.
///
/// NOTE(review): neither the key nor the cipher is used; updates pass through
/// unchanged.
impl<L: LoomWeaver> LoomWeaver for EncryptedLoom<L> {
    fn weave(&mut self, history: &[u8], next: u8) {
        self.inner.weave(history, next)
    }

    // Forward the whole batch so a specialised `weave_batch` on `L` is not
    // bypassed by the trait's per-item default.
    fn weave_batch(&mut self, ops: Vec<(Vec<u8>, u8)>) {
        self.inner.weave_batch(ops)
    }
}
/// Transparent forwarding to the inner loom.
impl<L: LoomPruner> LoomPruner for EncryptedLoom<L> {
    fn prune(&mut self) {
        self.inner.prune()
    }

    // Forward so a specialised `prune_batch` on `L` is honoured instead of the
    // trait default, which just calls `prune` once.
    fn prune_batch(&mut self) {
        self.inner.prune_batch()
    }
}