pub mod ops;
mod rewriter;
use std::collections::HashSet;
use rewriter::*;
use tract_nnef::internal::*;
// Model transforms exposed by this crate, looked up by name at runtime.
// Each "detect_*" transform pattern-matches a subgraph and rewrites it into
// the corresponding fused operator defined in `rewriter`/`ops`.
register_simple_model_transform!("detect_apply_rope", ApplyRopeTransform);
register_simple_model_transform!("detect_scaled_masked_softmax", ScaledMaskedSoftmaxTransform);
register_simple_model_transform!("detect_kv_cache", KeyValueCacheTransform);
register_simple_model_transform!(
"detect_sdpa_kv_cache_broadcast",
SdpaFuseKvCacheBroadcastTransform
);
// Inverse of "detect_kv_cache": expands the fused KV-cache op back out.
register_simple_model_transform!("unfold_kv_cache", UnfoldKeyValueCacheTransform);
// Convenience transform that applies all of the detections above.
register_simple_model_transform!("transformers_detect_all", TransformersTransform);
/// Register every tract-transformers operator into `registry`, making them
/// available for NNEF serialization and deserialization.
pub fn register(registry: &mut Registry) {
ops::apply_rope::register(registry);
ops::scaled_masked_softmax::register(registry);
ops::sdpa::register(registry);
ops::dyn_kv_cache::register(registry);
}
/// Extension trait enabling the tract-transformers NNEF registry on a framework.
pub trait WithTractTransformers {
/// Enable the transformers extension in place (also enables tract-core).
fn enable_tract_transformers(&mut self);
/// Builder-style variant of
/// [`WithTractTransformers::enable_tract_transformers`]: consumes and
/// returns `self`.
fn with_tract_transformers(self) -> Self;
}
impl WithTractTransformers for tract_nnef::framework::Nnef {
fn enable_tract_transformers(&mut self) {
// The transformers operators build on tract-core semantics, so core
// is enabled before pushing the transformers registry.
self.enable_tract_core();
self.registries.push(tract_transformers_registry());
}
fn with_tract_transformers(mut self) -> Self {
self.enable_tract_transformers();
self
}
}
/// Build the `tract_transformers` NNEF registry, documented and populated
/// with every operator this crate defines.
pub fn tract_transformers_registry() -> Registry {
    let mut registry = Registry::new("tract_transformers");
    // Registry documentation, one line per `with_doc` call.
    for doc_line in [
        "Extension `tract_transformers` extends NNEF with operators",
        "for transformer networks.",
        "",
        "Add `extension tract_transformers` to `graph.nnef`",
    ] {
        registry = registry.with_doc(doc_line);
    }
    register(&mut registry);
    registry
}
/// Heuristically identify the batch (B), sequence-length (S) and
/// past-sequence-length (P) symbols of a causal LLM `model`.
///
/// Strategy: the first integer-typed input is taken to be the token ids, and
/// the first float-typed input (if any) the externalized KV cache. Symbols
/// appearing in both shapes are classified as B, symbols only in the token
/// shape as S, symbols only in the KV shape as P. Each returned slot is
/// `None` when no symbol of that class was found.
pub fn figure_out_causal_llm_b_s_p(
model: &TypedModel,
) -> TractResult<(Option<Symbol>, Option<Symbol>, Option<Symbol>)> {
// Token ids are assumed to be the first integer input.
let token_input = model
.inputs
.iter()
.position(|i| model.outlet_fact(*i).unwrap().datum_type.is_integer())
.context("No token input found")?;
let tokens_symbols = model.input_fact(token_input)?.shape.volume().symbols();
// KV symbols come either from a float input (cache passed in explicitly)...
let kv_symbols = if let Some(kv_input) =
model.inputs.iter().position(|i| model.outlet_fact(*i).unwrap().datum_type.is_float())
{
model.input_fact(kv_input)?.shape.volume().symbols()
} else {
// ...or, when the cache is internal, from the first stateful op that
// exposes an initial tensor fact.
// NOTE(review): node id 0 is passed to `state()` for every node —
// presumably `init_tensor_fact` ignores it; confirm against the
// OpState contract.
let dummy_session_state = TurnState::default();
let mut symbols = HashSet::new();
for node in &model.nodes {
if let Some((_, fact)) =
node.op.state(&dummy_session_state, 0)?.and_then(|state| state.init_tensor_fact())
{
symbols = fact.shape.volume().symbols();
break;
}
}
symbols
};
// B is shared by both shapes; S is token-only; P is KV-only. If several
// candidates exist in a class, an arbitrary one is picked by `next()`.
let b = tokens_symbols.intersection(&kv_symbols).cloned().collect::<HashSet<_>>();
let s = tokens_symbols.difference(&b).cloned().collect::<HashSet<_>>();
let p = kv_symbols.difference(&b).cloned().collect::<HashSet<_>>();
Ok((b.into_iter().next(), s.into_iter().next(), p.into_iter().next()))
}
/// Produce symbol-value hints suitable for sizing a memory arena for a
/// causal LLM: S=1024, P=0, and B=1 when a batch symbol was identified.
///
/// Errors if the sequence-length (S) or past-sequence-length (P) symbol
/// cannot be determined from the model.
pub fn memory_arena_hints_for_causal_llm(model: &TypedModel) -> TractResult<SymbolValues> {
    let (batch, seq, past) = figure_out_causal_llm_b_s_p(model)?;
    let seq = seq.context("Could not determine sequence_len (S)")?;
    let past = past.context("Could not determine past_sequence_len (P)")?;
    let mut hints = SymbolValues::default().with(&seq, 1024).with(&past, 0);
    // The batch symbol is optional: some models have no batch dimension.
    if let Some(batch) = batch {
        hints = hints.with(&batch, 1);
    }
    Ok(hints)
}