use crate::{Error, Result};
use candle_core::{Device, Tensor};
/// Key/value tensor cache whose tail can be discarded back to a committed length.
///
/// Appended positions start out uncommitted; [`RollbackCache::commit`] marks every live
/// position as committed, and [`RollbackCache::rollback`] truncates the logical length
/// back to the committed length recorded in a [`KvSnapshot`].
#[derive(Debug)]
pub struct RollbackCache {
    /// Cached keys, concatenated along dim 2. Only the first `total_len` positions along
    /// that dim are live: a rollback may shrink `total_len` without touching the tensor,
    /// so it can physically be longer than `total_len`.
    /// NOTE(review): presumably laid out as (batch, heads, seq, head_dim) — the tests use
    /// that shape, but nothing here enforces it; confirm against callers.
    keys: Option<Tensor>,
    /// Cached values, laid out like `keys`. Invariant: `Some` exactly when `keys` is `Some`.
    values: Option<Tensor>,
    /// Number of committed positions; never exceeds `total_len`.
    committed_len: usize,
    /// Number of live positions along dim 2, committed or not.
    total_len: usize,
    /// Device passed to `RollbackCache::new`; not checked against appended tensors.
    device: Device,
}
/// Marker for a committed length, produced by [`RollbackCache::snapshot`] and consumed by
/// [`RollbackCache::rollback`].
///
/// Only the length is recorded, not the identity of the cached data, so two snapshots
/// compare equal whenever they captured the same committed length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KvSnapshot {
    /// Committed length of the cache at the time the snapshot was taken.
    committed_len: usize,
}
impl RollbackCache {
pub fn new(device: Device) -> Self {
Self {
keys: None,
values: None,
committed_len: 0,
total_len: 0,
device,
}
}
pub fn total_len(&self) -> usize {
self.total_len
}
pub fn committed_len(&self) -> usize {
self.committed_len
}
pub fn snapshot(&self) -> KvSnapshot {
KvSnapshot {
committed_len: self.committed_len,
}
}
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<()> {
let new_len = k.dim(2)?;
debug_assert_eq!(new_len, v.dim(2)?, "append: k seq_len must equal v seq_len");
match (&self.keys, &self.values) {
(None, None) => {
self.keys = Some(k.clone());
self.values = Some(v.clone());
}
(Some(prev_k), Some(prev_v)) => {
let trimmed_k = prev_k.narrow(2, 0, self.total_len)?;
let trimmed_v = prev_v.narrow(2, 0, self.total_len)?;
self.keys = Some(Tensor::cat(&[&trimmed_k, k], 2)?);
self.values = Some(Tensor::cat(&[&trimmed_v, v], 2)?);
}
_ => unreachable!("keys/values invariant: both Some or both None"),
}
self.total_len += new_len;
Ok(())
}
pub fn commit(&mut self) {
self.committed_len = self.total_len;
}
pub fn rollback(&mut self, snap: KvSnapshot) -> Result<()> {
if snap.committed_len > self.committed_len {
return Err(Error::CacheRollback(format!(
"snapshot points to length {}, but cache has only committed {}",
snap.committed_len, self.committed_len
)));
}
self.total_len = snap.committed_len;
self.committed_len = snap.committed_len;
Ok(())
}
pub fn current(&self) -> Result<Option<(Tensor, Tensor)>> {
match (&self.keys, &self.values) {
(Some(k), Some(v)) => {
let k = k.narrow(2, 0, self.total_len)?;
let v = v.narrow(2, 0, self.total_len)?;
Ok(Some((k, v)))
}
(None, None) => Ok(None),
_ => unreachable!(),
}
}
pub fn device(&self) -> &Device {
&self.device
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};

    /// Builds all-ones key/value tensors of shape (1, 2, seq, 4).
    fn fake_kv(seq: usize, device: &Device) -> (Tensor, Tensor) {
        let shape = (1usize, 2usize, seq, 4usize);
        let k = Tensor::ones(shape, DType::F32, device).unwrap();
        let v = Tensor::ones(shape, DType::F32, device).unwrap();
        (k, v)
    }

    #[test]
    fn append_extends_total_len() {
        let dev = Device::Cpu;
        let mut c = RollbackCache::new(dev.clone());
        let (k, v) = fake_kv(3, &dev);
        c.append(&k, &v).unwrap();
        assert_eq!(c.total_len(), 3);
        assert_eq!(c.committed_len(), 0);
    }

    #[test]
    fn commit_advances_committed_len() {
        let dev = Device::Cpu;
        let mut c = RollbackCache::new(dev.clone());
        let (k, v) = fake_kv(3, &dev);
        c.append(&k, &v).unwrap();
        c.commit();
        assert_eq!(c.committed_len(), 3);
        assert_eq!(c.total_len(), 3);
    }

    #[test]
    fn rollback_truncates_to_snapshot() {
        let dev = Device::Cpu;
        let mut c = RollbackCache::new(dev.clone());
        let (k1, v1) = fake_kv(3, &dev);
        c.append(&k1, &v1).unwrap();
        c.commit();
        let snap = c.snapshot();
        let (k2, v2) = fake_kv(5, &dev);
        c.append(&k2, &v2).unwrap();
        assert_eq!(c.total_len(), 8);
        c.rollback(snap).unwrap();
        assert_eq!(c.total_len(), 3);
        assert_eq!(c.committed_len(), 3);
    }

    #[test]
    fn rollback_to_future_snapshot_is_error() {
        let dev = Device::Cpu;
        let mut c = RollbackCache::new(dev.clone());
        let bogus = KvSnapshot { committed_len: 99 };
        let (k, v) = fake_kv(2, &dev);
        c.append(&k, &v).unwrap();
        c.commit();
        assert!(c.rollback(bogus).is_err());
    }

    #[test]
    fn current_is_none_for_fresh_cache() {
        let c = RollbackCache::new(Device::Cpu);
        assert!(c.current().unwrap().is_none());
    }

    #[test]
    fn current_is_narrowed_to_total_len_after_rollback() {
        let dev = Device::Cpu;
        let mut c = RollbackCache::new(dev.clone());
        let (k1, v1) = fake_kv(3, &dev);
        c.append(&k1, &v1).unwrap();
        c.commit();
        let snap = c.snapshot();
        let (k2, v2) = fake_kv(5, &dev);
        c.append(&k2, &v2).unwrap();
        c.rollback(snap).unwrap();
        let (k, v) = c
            .current()
            .unwrap()
            .expect("committed data must survive the rollback");
        assert_eq!(k.dims(), &[1, 2, 3, 4]);
        assert_eq!(v.dims(), &[1, 2, 3, 4]);
    }

    #[test]
    fn append_after_rollback_discards_stale_tail() {
        let dev = Device::Cpu;
        let mut c = RollbackCache::new(dev.clone());
        let (k1, v1) = fake_kv(3, &dev);
        c.append(&k1, &v1).unwrap();
        c.commit();
        let snap = c.snapshot();
        // Speculative positions are zeros so any leak into the live prefix is visible.
        let zeros = Tensor::zeros((1usize, 2usize, 5usize, 4usize), DType::F32, &dev).unwrap();
        c.append(&zeros, &zeros).unwrap();
        c.rollback(snap).unwrap();
        let (k2, v2) = fake_kv(2, &dev);
        c.append(&k2, &v2).unwrap();
        assert_eq!(c.total_len(), 5);
        let (k, v) = c.current().unwrap().expect("cache holds data");
        assert_eq!(k.dims(), &[1, 2, 5, 4]);
        // Every live element must be a one: 2 * 5 * 4 = 40.
        let expected = (2 * 5 * 4) as f32;
        assert_eq!(k.sum_all().unwrap().to_scalar::<f32>().unwrap(), expected);
        assert_eq!(v.sum_all().unwrap().to_scalar::<f32>().unwrap(), expected);
    }
}