use std::{collections::HashMap, path::Path};
use ort::{
memory::Allocator,
session::Session,
value::{Tensor, TensorRef},
};
use smol_str::SmolStr;
use crate::{
error::{Error, Result},
options::Options,
runtime::session::{build_session, collect_cache_inputs, validate_decoder_session},
};
/// Decoder wrapper: an `ort` [`Session`] together with the [`KvCacheTemplate`]
/// derived from that session.
///
/// Built by `Decoder::from_path` or `Decoder::from_session`, both of which run
/// `validate_decoder_session` and `build_template` before returning a value.
#[allow(dead_code)]
pub(crate) struct Decoder {
/// Session executed on every `Decoder::step` call.
session: Session,
/// Cache input names/shapes and the output-to-input name mapping for `session`.
template: KvCacheTemplate,
}
/// Description of the cache tensors a decoder session consumes, plus the
/// mapping used to carry cache outputs of one run into the inputs of the next.
#[allow(dead_code)]
#[derive(Clone)]
pub(crate) struct KvCacheTemplate {
/// `(input name, shape)` for each `conv` cache input reported by
/// `collect_cache_inputs`; `build_template` assigns the shape `[1, 1024, 3]`.
conv: Vec<(SmolStr, Vec<i64>)>,
/// `(input name, shape)` for each `attn` cache input reported by
/// `collect_cache_inputs`; `build_template` assigns the shape `[1, 8, -1, 64]`.
/// `Decoder::new_cache` replaces negative dimensions with 0.
attn: Vec<(SmolStr, Vec<i64>)>,
/// Session output name -> cache input name, as produced by
/// `build_present_to_past` (`present_conv.*` -> `past_conv.*`,
/// `present.*` -> `past_key_values.*`).
present_to_past: HashMap<SmolStr, SmolStr>,
}
/// Mutable cache state passed to `Decoder::step`.
///
/// Created by `Decoder::new_cache`; its tensors are replaced by `advance_cache`
/// after each successful session run.
#[allow(dead_code)]
pub(crate) struct KvCache {
/// Tensors keyed by the `past_conv.*` input name they are fed under.
conv: HashMap<SmolStr, Tensor<f32>>,
/// Tensors keyed by the remaining cache input names (everything that does
/// not start with `past_conv.`).
attn: HashMap<SmolStr, Tensor<f32>>,
/// Starts at 0 in `Decoder::new_cache`; `Decoder::step` adds `seq_len` to it
/// after each successful call and uses it to size the attention mask.
pub(crate) past_len: usize,
}
impl Decoder {
#[allow(dead_code)]
pub(crate) fn from_path(path: &Path, opts: &Options) -> Result<Self> {
let session = build_session(path, opts)?;
validate_decoder_session(&session)?;
let template = build_template(&session)?;
Ok(Self { session, template })
}
#[allow(dead_code)]
pub(crate) fn from_session(session: Session) -> Result<Self> {
validate_decoder_session(&session)?;
let template = build_template(&session)?;
Ok(Self { session, template })
}
#[allow(dead_code)]
pub(crate) fn new_cache(&self) -> Result<KvCache> {
let alloc = Allocator::default();
let mut conv = HashMap::with_capacity(self.template.conv.len());
for (name, shape_i64) in &self.template.conv {
let shape: Vec<usize> = shape_i64.iter().map(|&d| d as usize).collect();
let total: usize = shape.iter().product();
let tensor = Tensor::from_array((shape.as_slice(), vec![0f32; total])).map_err(Error::Ort)?;
conv.insert(name.clone(), tensor);
}
let mut attn = HashMap::with_capacity(self.template.attn.len());
for (name, shape_i64) in &self.template.attn {
let shape: Vec<usize> = shape_i64
.iter()
.map(|&d| if d < 0 { 0 } else { d as usize })
.collect();
let tensor = Tensor::<f32>::new(&alloc, shape.as_slice()).map_err(Error::Ort)?;
attn.insert(name.clone(), tensor);
}
Ok(KvCache {
conv,
attn,
past_len: 0,
})
}
#[allow(dead_code)]
pub(crate) fn step(
&mut self,
cache: &mut KvCache,
inputs_embeds: &[f32],
seq_len: usize,
) -> Result<Vec<f32>> {
let total_len = cache.past_len + seq_len;
let attn_mask: Vec<i64> = vec![1i64; total_len];
let inputs_shape = [1usize, seq_len, 1024usize];
let mask_shape = [1usize, total_len];
let embeds_ref =
TensorRef::from_array_view((inputs_shape, inputs_embeds)).map_err(Error::Ort)?;
let mask_ref =
TensorRef::from_array_view((mask_shape, attn_mask.as_slice())).map_err(Error::Ort)?;
let mut my_inputs = ort::inputs![
"inputs_embeds" => embeds_ref,
"attention_mask" => mask_ref,
];
for (name, tensor) in cache.conv.iter().chain(cache.attn.iter()) {
my_inputs.push((name.as_str().into(), tensor.into()));
}
let outputs = self.session.run(my_inputs).map_err(Error::Ort)?;
let logits_out = outputs.get("logits").ok_or(Error::SessionShapeMismatch {
input: "logits",
expected: "output present in session run",
got: vec![],
})?;
let (shape, data) = logits_out.try_extract_tensor::<f32>().map_err(Error::Ort)?;
if shape.len() != 3 || shape[0] < 1 || shape[1] < 1 || shape[2] != 65536 {
return Err(Error::SessionShapeMismatch {
input: "logits",
expected: "[batch>=1, seq>=1, 65536]",
got: shape.to_vec(),
});
}
let last_pos = (shape[1] - 1) as usize;
let vocab = shape[2] as usize;
let logits = data[last_pos * vocab..(last_pos + 1) * vocab].to_vec();
if logits.iter().any(|v| !v.is_finite()) {
return Err(Error::SessionNonFiniteOutput { stage: "decoder" });
}
advance_cache(cache, &outputs, &self.template.present_to_past)?;
cache.past_len = total_len;
Ok(logits)
}
}
/// Derives the [`KvCacheTemplate`] for `session`.
///
/// Cache input names come from `collect_cache_inputs(session.inputs())`; every
/// `conv` input gets the shape `[1, 1024, 3]` and every `attn` input the shape
/// `[1, 8, -1, 64]`. The output-to-input mapping is computed from the session's
/// output names by `build_present_to_past`.
///
/// # Errors
/// * Propagates failures of `collect_cache_inputs`.
/// * `Error::SessionShapeMismatch` when at least one cache input has no output
///   mapped onto it; `got` then holds the name length of each such input.
fn build_template(session: &Session) -> Result<KvCacheTemplate> {
    let inputs = collect_cache_inputs(session.inputs())?;

    let conv: Vec<(SmolStr, Vec<i64>)> = inputs
        .conv
        .into_iter()
        .map(|name| (SmolStr::from(name), vec![1i64, 1024, 3]))
        .collect();
    let attn: Vec<(SmolStr, Vec<i64>)> = inputs
        .attn
        .into_iter()
        .map(|name| (SmolStr::from(name), vec![1i64, 8, -1, 64]))
        .collect();

    let output_names: Vec<SmolStr> = session
        .outputs()
        .iter()
        .map(|o| SmolStr::from(o.name()))
        .collect();
    let present_to_past = build_present_to_past(&output_names);

    // Every cache input must be the target of some mapped output.
    let mapped_pasts: std::collections::HashSet<&SmolStr> = present_to_past.values().collect();
    let missing: Vec<i64> = conv
        .iter()
        .chain(attn.iter())
        .filter(|(past_name, _)| !mapped_pasts.contains(past_name))
        .map(|(past_name, _)| past_name.len() as i64)
        .collect();
    if !missing.is_empty() {
        return Err(Error::SessionShapeMismatch {
            input: "present_*",
            expected: "one present_* output per past_* cache input",
            got: missing,
        });
    }

    Ok(KvCacheTemplate {
        conv,
        attn,
        present_to_past,
    })
}
/// Maps cache output names to the cache input names they feed on the next run.
///
/// * `present_conv.<rest>` becomes `past_conv.<rest>`
/// * `present.<rest>` becomes `past_key_values.<rest>`
///
/// Names carrying neither prefix (for example `logits`) are left out of the map.
fn build_present_to_past(present_names: &[SmolStr]) -> HashMap<SmolStr, SmolStr> {
    // (output prefix, replacement input prefix), tried in order.
    const RENAMES: [(&str, &str); 2] = [
        ("present_conv.", "past_conv."),
        ("present.", "past_key_values."),
    ];

    present_names
        .iter()
        .filter_map(|present| {
            RENAMES.iter().find_map(|&(from, to)| {
                present
                    .strip_prefix(from)
                    .map(|rest| (present.clone(), SmolStr::from(format!("{to}{rest}"))))
            })
        })
        .collect()
}
fn advance_cache(
cache: &mut KvCache,
outputs: &ort::session::SessionOutputs<'_>,
present_to_past: &HashMap<SmolStr, SmolStr>,
) -> Result<()> {
for (present_name, past_name) in present_to_past {
let Some(out) = outputs.get(present_name.as_str()) else {
return Err(Error::SessionShapeMismatch {
input: "present_*",
expected: "every mapped present_* output present in session.run() result",
got: vec![present_name.len() as i64],
});
};
let (shape, data) = out.try_extract_tensor::<f32>().map_err(Error::Ort)?;
let shape_usize: Vec<usize> = shape.iter().map(|&v| v as usize).collect();
let new_tensor =
Tensor::from_array((shape_usize.as_slice(), data.to_vec())).map_err(Error::Ort)?;
if past_name.starts_with("past_conv.") {
cache.conv.insert(past_name.clone(), new_tensor);
} else {
cache.attn.insert(past_name.clone(), new_tensor);
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `present_conv.*` and `present.*` names are renamed to their cache input
    /// counterparts, while an unrelated output such as `logits` is not mapped.
    #[test]
    fn build_present_to_past_maps_conv_and_attn() {
        let names: Vec<SmolStr> = [
            "present_conv.0",
            "present_conv.13",
            "present.2.key",
            "present.14.value",
            "logits",
        ]
        .iter()
        .map(|&name| SmolStr::from(name))
        .collect();

        let map = build_present_to_past(&names);

        let expected = [
            ("present_conv.0", "past_conv.0"),
            ("present_conv.13", "past_conv.13"),
            ("present.2.key", "past_key_values.2.key"),
            ("present.14.value", "past_key_values.14.value"),
        ];
        for (present, past) in expected {
            assert_eq!(
                map.get(&SmolStr::from(present)).map(SmolStr::as_str),
                Some(past)
            );
        }
        assert_eq!(map.get(&SmolStr::from("logits")), None);
    }
}