use core::{num::NonZeroUsize, ops::Deref};
use crate::{
context::{anymap::Map, ContextOperator, Value},
Operator,
};
use super::Layer;
/// A [`Layer`] that feeds the environments of earlier steps back into the
/// operator it wraps.
///
/// Before every call to the wrapped operator, the environment recorded at the
/// end of the previous call is inserted into the current environment as a
/// [`Previous`] entry. Older environments are reachable through nested
/// `Previous` entries. `length` bounds how many earlier environments the
/// wrapped operator can reach.
#[derive(Debug, Clone, Copy)]
pub struct Cache {
    /// Maximum number of earlier environments reachable through [`Previous`].
    length: NonZeroUsize,
}

impl Cache {
    /// Creates a cache that exposes only the most recent earlier environment.
    pub fn new() -> Self {
        Self::with_length(NonZeroUsize::new(1).unwrap())
    }

    /// Creates a cache that exposes up to `length` earlier environments.
    pub fn with_length(length: NonZeroUsize) -> Self {
        Cache { length }
    }
}

impl Default for Cache {
    /// Same as [`Cache::new`]: a history length of one.
    fn default() -> Self {
        Cache::new()
    }
}
impl<T, P> Layer<T, P> for Cache
where
    P: ContextOperator<T>,
{
    type Operator = CacheOperator<P>;
    type Out = P::Out;

    /// Wraps `inner` in a [`CacheOperator`] that starts with an empty history.
    fn layer(&self, inner: P) -> Self::Operator {
        // `length` is a `NonZeroUsize`, so subtracting one cannot underflow.
        // `limit` is the number of nested `Previous` levels kept between calls.
        let limit = self.length.get() - 1;
        let previous = Previous::default();
        CacheOperator {
            inner,
            previous,
            limit,
        }
    }
}
/// The environment recorded for an earlier step.
///
/// Wraps the [`Map`] of a past step. That map may itself contain another
/// `Previous`, forming a chain that runs from the most recent step towards
/// older ones.
#[derive(Debug, Default)]
pub struct Previous(Map);

impl Previous {
    /// Moves the recorded map out, leaving `self` holding a default map.
    fn take(&mut self) -> Self {
        Self(core::mem::take(&mut self.0))
    }

    /// Calls `f` on this map, then on each map reachable through nested
    /// `Previous` entries (newest first). Stops at the first map that has no
    /// `Previous` entry.
    pub fn backward<F>(&self, mut f: F)
    where
        F: FnMut(&Map),
    {
        let mut cursor = Some(self);
        while let Some(node) = cursor {
            f(&node.0);
            cursor = node.0.get::<Previous>();
        }
    }

    /// Mutable counterpart of [`Previous::backward`].
    ///
    /// The nested `Previous` is looked up only after `f` has run on a map, so
    /// `f` can cut the chain by removing that entry.
    fn backward_mut<F>(&mut self, mut f: F)
    where
        F: FnMut(&mut Map),
    {
        let mut cursor = Some(self);
        while let Some(node) = cursor {
            f(&mut node.0);
            cursor = node.0.get_mut::<Previous>();
        }
    }
}

impl Deref for Previous {
    type Target = Map;

    /// Gives read access to the recorded map.
    fn deref(&self) -> &Map {
        let Previous(map) = self;
        map
    }
}
/// Operator produced by the `Cache` layer.
///
/// Runs `inner` with access to earlier environments and records the current
/// environment for the next call.
#[derive(Debug)]
pub struct CacheOperator<P> {
    /// The wrapped context operator.
    inner: P,
    /// Environment recorded at the end of the last call (default before the
    /// first call).
    previous: Previous,
    /// Number of nested `Previous` levels kept in the recorded environment.
    /// The `Cache` layer sets this to its length minus one.
    limit: usize,
}

impl<T, P> Operator<Value<T>> for CacheOperator<P>
where
    P: ContextOperator<T>,
{
    type Output = Value<P::Out>;

    /// Processes one input. The steps are:
    ///
    /// 1. Insert the recorded environment into `input`'s environment as a
    ///    [`Previous`] entry, so `inner` can read earlier steps.
    /// 2. Run `inner`.
    /// 3. Trim the `Previous` chain in the output's environment to `limit`
    ///    nested levels. With `limit == 0`, remove it entirely.
    /// 4. Move the whole output environment into the recorder. The returned
    ///    value's environment is replaced by its `Default` value.
    ///
    /// If `inner` removed the `Previous` entry, there is nothing to trim and
    /// the step proceeds normally instead of panicking.
    fn next(&mut self, mut input: Value<T>) -> Self::Output {
        input.context_mut().env_mut().insert(self.previous.take());
        let mut output = self.inner.next(input);
        if self.limit == 0 {
            // Only the current environment is carried forward.
            output.context_mut().env_mut().remove::<Previous>();
        } else if let Some(previous) = output.context_mut().env_mut().get_mut::<Previous>() {
            // Walk down the chain and cut it after `limit` levels. Once the
            // nested entry is removed, `backward_mut` has nowhere further to go.
            let limit = self.limit;
            let mut depth = 1;
            previous.backward_mut(|ctx| {
                if depth >= limit {
                    ctx.remove::<Previous>();
                }
                depth += 1;
            });
        }
        // `self.previous` was emptied by `take` above, so this records exactly
        // the environment produced by this step.
        self.previous
            .0
            .extend(core::mem::take(output.context_mut().env_mut()));
        output
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        context::{input, layer_fn, ContextOperatorExt},
        IndicatorIteratorExt, OperatorExt,
    };

    use super::*;

    /// `Square` computes `x_t = x_{t-1}^2 + v_t`, reading `x_{t-1}` from the
    /// `Previous` environment exposed by `Cache`.
    ///
    /// The test checks the recurrence results. It also checks that a cache of
    /// length 2 exposes between one and two earlier environments.
    #[test]
    fn square_cache() {
        struct Square<P>(P);
        impl<P> Operator<Value<i32>> for Square<P>
        where
            P: ContextOperator<i32>,
        {
            type Output = Value<P::Out>;
            fn next(&mut self, mut input: Value<i32>) -> Self::Output {
                input.apply(|v, ctx| {
                    let prev = ctx
                        .env()
                        .get::<Previous>()
                        .and_then(|prev| prev.get::<i32>().copied())
                        .unwrap_or(0);
                    ctx.env_mut().insert(prev.pow(2) + *v);
                });
                self.0.next(input)
            }
        }
        let op = input()
            .map(|input| {
                let previous = input.context().env().get::<Previous>().unwrap();
                let mut count = 0;
                previous.backward(|ctx| {
                    if let Some(v) = ctx.get::<i32>() {
                        println!("{count}: {v}");
                    }
                    count += 1;
                });
                // A cache of length 2 exposes at most two earlier environments.
                // On the first step it exposes a single empty one.
                assert!(
                    (1..=2).contains(&count),
                    "unexpected history depth {count}"
                );
                input.map(|_, ctx| ctx.env().get::<i32>().copied().unwrap())
            })
            .with(layer_fn(Square))
            .with(Cache::with_length(2.try_into().unwrap()))
            .finish();
        let data = [1, 2, 3, 4, 5];
        let outputs: Vec<String> = data
            .into_iter()
            .indicator(op)
            .map(|v| {
                println!("current: {v}");
                v.to_string()
            })
            .collect();
        // 1 = 0^2 + 1, 3 = 1^2 + 2, 12 = 3^2 + 3, 148 = 12^2 + 4, 21909 = 148^2 + 5
        assert_eq!(outputs, ["1", "3", "12", "148", "21909"]);
    }
}