use tract_nnef::internal::*;
use tract_nnef::prelude::tract_itertools::Itertools;
use tract_nnef::tract_core::ops::array::TypedConcat;
use tract_nnef::tract_core::ops::source::TypedSource;
use tract_nnef::tract_core::ops::OpStateFreeze;
use crate::rule_ensure;
use super::next_node;
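/// Per-session state behind [`DynKeyValueCache`]: holds the key/value tensor
/// accumulated so far, which grows along `axis` at every evaluation.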
#[derive(Debug, Clone)]
pub struct DynKeyValueCacheState {
name: String,
axis: usize,
past_sequence_fact: TypedFact,
kv_cache: Option<TValue>,
}
impl DynKeyValueCacheState {
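    /// Binds the unresolved symbols of `fact` (at most one is allowed,
    /// typically the past sequence length) to the matching dimension of
    /// `concrete_shape`, or to 0 when no cache tensor exists yet. Also lets
    /// the symbol scope guess the running scenario if none is set.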
pub fn resolve_symbols(
state: &mut SessionState,
fact: TypedFact,
concrete_shape: Option<&[usize]>,
) -> TractResult<()> {
let unresolved = fact
.shape
.iter()
.enumerate()
.filter_map(|(ax, symb)| match symb {
TDim::Sym(s) if state.resolved_symbols.get(s).is_none() => Some((ax, s)),
_ => None,
})
.collect_vec();
if unresolved.is_empty() {
return Ok(());
}
        ensure!(
            unresolved.len() == 1,
            "Expected a single unresolved symbol in {:?}, found {:?}",
            fact,
            unresolved
        );
let (ax, sym) = unresolved[0];
if let Some(shape) = concrete_shape {
ensure!(ax < shape.len());
state.resolved_symbols.set(sym, shape[ax] as i64);
} else {
state.resolved_symbols.set(sym, 0);
}
if state.scenario.is_none() {
state.scenario = sym.scope().unwrap().guess_scenario(&state.resolved_symbols)?;
}
Ok(())
}
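    /// Shrinks the cached tensor to its first `len` entries along the cache
    /// axis. Fails if the cache was never initialized.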
pub fn truncate(&mut self, len: usize) -> TractResult<()> {
if let Some(t) = self.kv_cache.as_mut() {
*t = t.slice(self.axis, 0, len)?.into_tvalue();
} else {
bail!("Can not truncate a zero-len kv-cache value");
}
Ok(())
}
}
impl OpState for DynKeyValueCacheState {
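    /// Installs an externally supplied cache tensor and resolves the
    /// past-sequence symbol against its concrete shape.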
fn load_from(&mut self, state: &mut SessionState, states: &mut Vec<TValue>) -> TractResult<()> {
let kv_cache_init = states.remove(0);
Self::resolve_symbols(state, self.past_sequence_fact.clone(), Some(kv_cache_init.shape()))?;
self.kv_cache = Some(kv_cache_init);
Ok(())
}
fn save_to(&self, states: &mut Vec<TValue>) -> TractResult<()> {
if let Some(kv_cache) = &self.kv_cache {
states.push(kv_cache.clone());
Ok(())
} else {
bail!("KV cache {} was never initialized", self.name)
}
}
fn init_tensor_fact(&self) -> Option<TypedFact> {
Some(self.past_sequence_fact.clone())
}
fn resolve_symbols(&mut self, state: &mut SessionState) -> TractResult<()> {
let shape = self.kv_cache.as_ref().map(|kv_cache| kv_cache.shape());
Self::resolve_symbols(state, self.past_sequence_fact.clone(), shape)
}
fn eval(
&mut self,
_state: &mut SessionState,
_op: &dyn Op,
inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>> {
let input = args_1!(inputs);
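        // Append the new tokens to the cache along `axis`; on the first call
        // the input itself becomes the initial cache content.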
let output = if let Some(curr) = self.kv_cache.take() {
TypedConcat { axis: self.axis }.eval(tvec![curr, input])?.remove(0)
} else {
input
};
self.kv_cache = Some(output.clone());
Ok(tvec!(output))
}
}
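/// Stateful concatenation along `axis`: each evaluation appends its input to
/// the session-held cache and returns the whole cache, implementing the
/// dynamic key/value cache of transformer-style attention.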
#[derive(Clone, Debug)]
pub struct DynKeyValueCache {
pub name: String,
pub axis: usize,
pub past_sequence_fact: TypedFact,
pub input_sequence_fact: TypedFact,
}
impl Op for DynKeyValueCache {
fn name(&self) -> StaticName {
"DynamicKeyValueCache".to_string().into()
}
op_as_typed_op!();
}
impl EvalOp for DynKeyValueCache {
fn is_stateless(&self) -> bool {
false
}
fn state(
&self,
_session: &mut SessionState,
_node_id: usize,
) -> TractResult<Option<Box<dyn OpState>>> {
Ok(Some(Box::new(DynKeyValueCacheState {
name: self.name.clone(),
axis: self.axis,
past_sequence_fact: self.past_sequence_fact.clone(),
kv_cache: None,
})))
}
}
impl TypedOp for DynKeyValueCache {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
ensure!(inputs.len() == 1);
let input = inputs[0];
let mut fact = input.without_value();
fact.shape.set(
self.axis,
self.past_sequence_fact.shape.dims()[self.axis].clone()
+ self.input_sequence_fact.shape.dims()[self.axis].clone(),
);
Ok(tvec!(fact))
}
fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
let token_volume = self
.past_sequence_fact
.shape
.iter()
.enumerate()
.filter(|(axis, _d)| *axis != self.axis)
.map(|(_axis, d)| d)
.product::<TDim>();
Ok(tvec!((Cost::Custom(false, "KVCacheValuesPerToken".to_string()), token_volume)))
}
as_op!();
}
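/// Frozen counterpart of [`DynKeyValueCacheState`]: the same data with the
/// cache stored as a plain [`Tensor`], used to snapshot and restore sessions.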
#[derive(Debug, Clone)]
pub struct FrozenDynKeyValueCacheState {
name: String,
axis: usize,
past_sequence_fact: TypedFact,
kv_cache: Option<Tensor>,
}
impl OpStateFreeze for DynKeyValueCacheState {
fn freeze(&self) -> Box<dyn FrozenOpState> {
Box::new(FrozenDynKeyValueCacheState {
name: self.name.clone(),
axis: self.axis,
past_sequence_fact: self.past_sequence_fact.clone(),
kv_cache: self.kv_cache.clone().map(|t| t.into_tensor()),
})
}
}
impl FrozenOpState for FrozenDynKeyValueCacheState {
fn unfreeze(&self) -> Box<dyn OpState> {
Box::new(DynKeyValueCacheState {
axis: self.axis,
name: self.name.clone(),
past_sequence_fact: self.past_sequence_fact.clone(),
kv_cache: self.kv_cache.clone().map(|t| t.into_tvalue()),
})
}
}
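/// Rewrites the pattern `TypedSource -> TypedConcat -> model output` into a
/// single stateful [`DynKeyValueCache`] node.
///
/// The source (incoming cache) and the model output (updated cache) disappear
/// from the model interface: the cache is carried in the op state across
/// calls instead. Returns `Ok(Some(()))` once the rewrite has been applied,
/// `Ok(None)` when the pattern does not match.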
pub fn replace_kv_cache(target: &mut TypedModel, source_node_id: usize) -> TractResult<Option<()>> {
    ensure!(
        target.node(source_node_id).op_is::<TypedSource>(),
        "replace_kv_cache expects a TypedSource node"
    );
let (concat_node_id, non_source_input_id, axis, input_facts) = {
let concat_node = if let Some(n_node) = next_node(target, target.node(source_node_id)) {
n_node
} else {
return Ok(None);
};
rule_ensure!(
concat_node.op_is::<TypedConcat>()
&& concat_node.inputs.len() == 2
&& concat_node.outputs.len() == 1
&& target.outputs.contains(&concat_node.id.into())
);
let concat_in_facts = target.node_input_facts(concat_node.id)?;
let concat_in_shapes = [concat_in_facts[0].shape.dims(), concat_in_facts[1].shape.dims()];
let rank = concat_in_shapes[0].len();
let axes = (0..rank)
.filter(|ax| concat_in_shapes[0][*ax] != concat_in_shapes[1][*ax])
.collect_vec();
        rule_ensure!(axes.len() == 1);
let axis = axes[0];
rule_ensure!(
matches!(concat_in_shapes[0][axis], TDim::Sym(_))
&& matches!(concat_in_shapes[1][axis], TDim::Sym(_))
);
let mut facts = [concat_in_facts[0].clone(), concat_in_facts[1].clone()];
if concat_node.inputs[0].node == source_node_id {
(concat_node.id, concat_node.inputs[1].node, axis, facts)
} else if concat_node.inputs[1].node == source_node_id {
facts.swap(0, 1);
(concat_node.id, concat_node.inputs[0].node, axis, facts)
} else {
return Ok(None);
}
};
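    // Replace the concat op in place with the stateful KV-cache op, reusing
    // the source node's name, and cut the edge coming from the source.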
{
        let name = target.node(source_node_id).name.clone();
let concat_node = target.node_mut(concat_node_id);
concat_node.op = Box::new(DynKeyValueCache {
name: name.clone(),
axis,
past_sequence_fact: input_facts[0].clone(),
input_sequence_fact: input_facts[1].clone(),
});
concat_node.name = name;
concat_node.inputs.retain(|input| input != &source_node_id.into());
}
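    // Neutralize the now-disconnected source node with a dummy op.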
{
let dummy_op = target.create_dummy();
let source_node = target.node_mut(source_node_id);
source_node.outputs[0].successors.clear();
source_node.op = dummy_op;
}
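    // Whatever fed the concat's other slot must now feed the op's single
    // remaining input slot 0.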
{
let non_source_input = target.node_mut(non_source_input_id);
non_source_input.outputs.iter_mut().for_each(|output| {
output.successors.iter_mut().for_each(|succ| {
if succ.node == concat_node_id {
succ.slot = 0
}
})
});
}
target.outputs.retain(|output| output.node != concat_node_id);
target.inputs.retain(|input| input.node != source_node_id);
target.outlet_labels.remove(&concat_node_id.into());
    Ok(Some(()))
}
#[cfg(test)]
mod tests {
use super::*;
use tract_num_traits::AsPrimitive;
use tract_num_traits::Zero;
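    /// Feeds `input_shapes[0]` as the initial cache, then evaluates one fresh
    /// input per shape and checks that the saved cache matches a plain concat
    /// of the initializer and all inputs.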
fn run_test_case<F: Datum + Zero + Copy>(
input_shapes: &[Vec<usize>],
axis: usize,
) -> TractResult<()>
where
usize: AsPrimitive<F>,
{
let first_shape = &input_shapes[0];
        // All inputs must agree on every dimension except the concatenation axis.
        ensure!(input_shapes.iter().all(|shape| {
            shape.len() == first_shape.len()
                && shape[..axis] == first_shape[..axis]
                && (axis == shape.len() - 1 || shape[axis + 1..] == first_shape[axis + 1..])
        }));
let op_name = "test".to_string();
let dummy_model = TypedModel::default();
        let make_shape = |sym: &str| {
            input_shapes[0]
                .iter()
                .enumerate()
                .map(|(i, &dim)| {
                    if i == axis {
                        TDim::Sym(dummy_model.sym(sym))
                    } else {
                        TDim::Val(dim as _)
                    }
                })
                .collect::<TVec<TDim>>()
        };
let past_shape = make_shape("P");
let input_shape = make_shape("S");
let op = DynKeyValueCache {
name: op_name.clone(),
past_sequence_fact: TypedFact::dt_shape(F::datum_type(), past_shape),
input_sequence_fact: TypedFact::dt_shape(F::datum_type(), input_shape),
axis,
};
let mut session_state = SessionState::default();
let mut state = op.state(&mut session_state, 0)?.unwrap();
let mut inputs = tvec![];
let shape = &input_shapes[0];
let len = shape.iter().product::<usize>();
let input = Tensor::from_shape(shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
inputs.push(input.clone().into_tvalue());
let mut state_initializers = vec![input.into()];
state.load_from(&mut session_state, &mut state_initializers)?;
        for shape in input_shapes {
            let len = shape.iter().product::<usize>();
            let input = Tensor::from_shape(shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
            inputs.push(input.clone().into_tvalue());
            // Drive the state; the accumulated cache lives inside the state itself.
            state.eval(&mut session_state, &op, tvec!(input.into()))?;
        }
let mut curr_states = vec![];
state.save_to(&mut curr_states)?;
let output = curr_states.remove(0);
let reference = &TypedConcat { axis }.eval(inputs)?[0];
output.close_enough(&reference.clone().into_tensor(), Approximation::Close)?;
Ok(())
}
#[test]
fn test_dyn_kv_cache() -> TractResult<()> {
run_test_case::<f32>(&[vec![2, 2]], 0)?;
run_test_case::<f32>(&[vec![2, 2], vec![4, 2]], 0)?;
run_test_case::<f32>(&[vec![2, 2], vec![2, 1], vec![2, 3]], 1)?;
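        // Also exercise an integer datum and concatenation along the last axis.
        run_test_case::<i64>(&[vec![2, 3, 1], vec![2, 3, 2]], 2)?;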
Ok(())
}
}