use std::path::Path;
use ort::{
session::Session,
value::{Outlet, TensorElementType, ValueType},
};
use crate::{
error::{Error, Result},
options::Options,
};
#[allow(dead_code)]
pub(crate) fn build_session(graph: &Path, opts: &Options) -> Result<Session> {
if !graph.exists() {
return Err(Error::NotFound(graph.to_path_buf()));
}
let level = opts.optimization_level();
let mut builder = Session::builder()
.map_err(Error::Ort)?
.with_optimization_level(level)
.map_err(|e| Error::Ort(ort::Error::from(e)))?;
if let Some(t) = opts.thread().intra_threads() {
builder = builder
.with_intra_threads(t)
.map_err(|e| Error::Ort(ort::Error::from(e)))?;
}
if let Some(t) = opts.thread().inter_threads() {
builder = builder
.with_inter_threads(t)
.map_err(|e| Error::Ort(ort::Error::from(e)))?;
}
#[allow(unused_mut)]
let mut eps: Vec<ort::execution_providers::ExecutionProviderDispatch> = Vec::new();
#[cfg(feature = "cuda")]
{
eps.push(ort::execution_providers::CUDAExecutionProvider::default().build());
}
#[cfg(feature = "tensorrt")]
{
eps.push(ort::execution_providers::TensorRTExecutionProvider::default().build());
}
#[cfg(feature = "directml")]
{
eps.push(ort::execution_providers::DirectMLExecutionProvider::default().build());
}
#[cfg(feature = "rocm")]
{
eps.push(ort::execution_providers::ROCmExecutionProvider::default().build());
}
#[cfg(feature = "coreml")]
{
eps.push(ort::execution_providers::CoreMLExecutionProvider::default().build());
}
if !eps.is_empty() {
builder = builder
.with_execution_providers(eps)
.map_err(|e| Error::Ort(ort::Error::from(e)))?;
}
let session = builder.commit_from_file(graph).map_err(Error::Ort)?;
Ok(session)
}
#[allow(dead_code)]
pub(crate) fn check_outlet(
outlets: &[Outlet],
name: &'static str,
expected_dtype: TensorElementType,
expected_shape: &[i64],
) -> Result<()> {
let outlet = outlets
.iter()
.find(|o| o.name() == name)
.ok_or(Error::SessionShapeMismatch {
input: name,
expected: "outlet present in session",
got: vec![],
})?;
match outlet.dtype() {
ValueType::Tensor { ty, shape, .. } => {
if *ty != expected_dtype {
return Err(Error::SessionContractMismatch {
input: name,
expected: "matching tensor dtype",
got: *ty,
});
}
let actual: &[i64] = shape;
if actual.len() != expected_shape.len() {
return Err(Error::SessionShapeMismatch {
input: name,
expected: "matching tensor rank",
got: actual.to_vec(),
});
}
for (i, &want) in expected_shape.iter().enumerate() {
let got = actual[i];
if want == -1 {
if got != -1 {
return Err(Error::SessionShapeMismatch {
input: name,
expected: "dynamic axis required",
got: actual.to_vec(),
});
}
} else {
if got != -1 && got != want {
return Err(Error::SessionShapeMismatch {
input: name,
expected: "matching static dim",
got: actual.to_vec(),
});
}
}
}
Ok(())
}
_ => Err(Error::SessionShapeMismatch {
input: name,
expected: "tensor",
got: vec![],
}),
}
}
/// Checks the vision encoder session's I/O contract: pixel values, pixel
/// attention mask, and spatial shapes in; image features out.
#[allow(dead_code)]
pub(crate) fn validate_vision_session(s: &Session) -> Result<()> {
    let inputs = s.inputs();
    check_outlet(inputs, "pixel_values", TensorElementType::Float32, &[-1, -1, 768])?;
    check_outlet(inputs, "pixel_attention_mask", TensorElementType::Int64, &[-1, -1])?;
    check_outlet(inputs, "spatial_shapes", TensorElementType::Int64, &[-1, 2])?;
    check_outlet(s.outputs(), "image_features", TensorElementType::Float32, &[-1, 1024])
}
/// Checks the token-embedding session's I/O contract: `input_ids` in,
/// `inputs_embeds` out.
#[allow(dead_code)]
pub(crate) fn validate_embed_session(s: &Session) -> Result<()> {
    check_outlet(s.inputs(), "input_ids", TensorElementType::Int64, &[-1, -1])?;
    check_outlet(s.outputs(), "inputs_embeds", TensorElementType::Float32, &[-1, -1, 1024])
}
/// Validates the decoder session's full I/O contract: embeddings and
/// attention mask in, logits out, plus a hybrid conv/attention cache.
///
/// The graph must expose exactly 10 conv-cache inputs (`past_conv.<idx>`)
/// and 12 attention-cache inputs (`past_key_values.<idx>.{key,value}`,
/// i.e. 6 layers x key/value), at the sparse layer indices hard-coded
/// below, each with a matching `present_*` output of the same shape.
/// It must NOT require a `position_ids` input, since `Decoder::step`
/// never supplies one.
#[allow(dead_code)]
pub(crate) fn validate_decoder_session(s: &Session) -> Result<()> {
    // Core inputs: f32 embeddings [batch, seq, 1024], i64 mask [batch, seq].
    check_outlet(
        s.inputs(),
        "inputs_embeds",
        TensorElementType::Float32,
        &[-1, -1, 1024],
    )?;
    check_outlet(
        s.inputs(),
        "attention_mask",
        TensorElementType::Int64,
        &[-1, -1],
    )?;
    // Reject up front a graph that declares position_ids; the runtime never
    // feeds it, so it would only fail later at inference time.
    if s.inputs().iter().any(|o| o.name() == "position_ids") {
        return Err(Error::SessionShapeMismatch {
            input: "position_ids",
            expected: "must NOT be a required input (Decoder::step doesn't pass it)",
            got: vec![],
        });
    }
    let cache = collect_cache_inputs(s.inputs())?;
    // Count check first: 10 conv entries, 12 attn entries (6 layers x k/v).
    if cache.conv.len() != 10 || cache.attn.len() != 12 {
        return Err(Error::DecoderCacheMismatch {
            expected_conv: 10,
            expected_attn: 12,
            got_conv: cache.conv.len(),
            got_attn: cache.attn.len(),
        });
    }
    // Hard-coded sparse layer layout: which layer indices are conv layers
    // and which are attention layers. NOTE(review): presumably tied to one
    // specific exported model architecture — confirm against the graph.
    const EXPECTED_CONV: &[u32] = &[0, 1, 3, 4, 6, 7, 9, 11, 13, 15];
    const EXPECTED_ATTN: &[u32] = &[2, 5, 8, 10, 12, 14];
    // Unparseable names are dropped by filter_map, which then shows up as an
    // index-set mismatch below.
    let mut conv_indices: Vec<u32> = cache
        .conv
        .iter()
        .filter_map(|n| parse_conv_index(n))
        .collect();
    conv_indices.sort_unstable();
    if conv_indices != EXPECTED_CONV {
        return Err(Error::SessionShapeMismatch {
            input: "past_conv.*",
            expected: "sparse indices [0,1,3,4,6,7,9,11,13,15]",
            got: conv_indices.into_iter().map(i64::from).collect(),
        });
    }
    // key and value entries share a layer index, so dedup before comparing.
    let mut attn_indices: Vec<u32> = cache
        .attn
        .iter()
        .filter_map(|n| parse_attn_index(n))
        .collect();
    attn_indices.sort_unstable();
    attn_indices.dedup();
    if attn_indices != EXPECTED_ATTN {
        return Err(Error::SessionShapeMismatch {
            input: "past_key_values.*.{key,value}",
            expected: "sparse indices [2,5,8,10,12,14]",
            got: attn_indices.into_iter().map(i64::from).collect(),
        });
    }
    // Each conv-cache input is [1, 1024, 3] f32 with a matching
    // present_conv.<idx> output. leak_static satisfies check_outlet's
    // &'static str parameter by leaking a copy of the name — note this
    // leaks on every validation call.
    for name in &cache.conv {
        let owned: &'static str = leak_static(name);
        check_outlet(s.inputs(), owned, TensorElementType::Float32, &[1, 1024, 3])?;
        // u32::MAX fallback is unreachable in practice: an unparseable name
        // would already have failed the EXPECTED_CONV comparison above.
        let present = format!(
            "present_conv.{}",
            parse_conv_index(name).unwrap_or(u32::MAX)
        );
        let present_owned: &'static str = leak_static(&present);
        check_outlet(
            s.outputs(),
            present_owned,
            TensorElementType::Float32,
            &[1, 1024, 3],
        )?;
    }
    // Each attention-cache input is [1, 8, past_len, 64] f32 (dynamic
    // sequence axis) with a matching present.<idx>.{key,value} output.
    for name in &cache.attn {
        let owned: &'static str = leak_static(name);
        check_outlet(
            s.inputs(),
            owned,
            TensorElementType::Float32,
            &[1, 8, -1, 64],
        )?;
        if let Some(rest) = name.strip_prefix("past_key_values.") {
            let present = format!("present.{rest}");
            let present_owned: &'static str = leak_static(&present);
            check_outlet(
                s.outputs(),
                present_owned,
                TensorElementType::Float32,
                &[1, 8, -1, 64],
            )?;
        }
    }
    // Final projection: logits over a 65536-entry vocabulary.
    check_outlet(
        s.outputs(),
        "logits",
        TensorElementType::Float32,
        &[-1, -1, 65536],
    )?;
    Ok(())
}
/// Returns a `'static` copy of `s`, leaking at most one allocation per
/// distinct input string.
///
/// The previous implementation leaked a fresh copy on every call, so
/// repeatedly validating sessions leaked memory without bound. Leaked
/// strings are now memoized in a process-wide set: the first request for a
/// given name leaks it, and later requests return the same `'static` slice.
fn leak_static(s: &str) -> &'static str {
    use std::collections::HashSet;
    use std::sync::{Mutex, OnceLock};
    static CACHE: OnceLock<Mutex<HashSet<&'static str>>> = OnceLock::new();
    let mut cache = CACHE
        .get_or_init(|| Mutex::new(HashSet::new()))
        .lock()
        .expect("leak_static cache poisoned");
    if let Some(&existing) = cache.get(s) {
        return existing;
    }
    let leaked: &'static str = Box::leak(s.to_string().into_boxed_str());
    cache.insert(leaked);
    leaked
}
/// Decoder cache-related input names, split by cache kind.
#[allow(dead_code)]
pub(crate) struct CacheInputs {
    /// Conv-cache input names (`past_conv.<idx>`).
    pub(crate) conv: Vec<String>,
    /// Attention-cache input names (`past_key_values.<idx>.{key,value}`).
    pub(crate) attn: Vec<String>,
}
/// Partitions `outlets` into conv-cache and attention-cache input names by
/// prefix; every other outlet is ignored.
#[allow(dead_code)]
pub(crate) fn collect_cache_inputs(outlets: &[Outlet]) -> Result<CacheInputs> {
    let mut cache = CacheInputs {
        conv: Vec::new(),
        attn: Vec::new(),
    };
    for name in outlets.iter().map(|o| o.name()) {
        if name.starts_with("past_conv.") {
            cache.conv.push(name.to_string());
        } else if name.starts_with("past_key_values.") {
            cache.attn.push(name.to_string());
        }
    }
    Ok(cache)
}
/// Extracts the numeric layer index from a `past_conv.<idx>` name; `None`
/// when the prefix is absent or the suffix is not a `u32`.
fn parse_conv_index(name: &str) -> Option<u32> {
    match name.strip_prefix("past_conv.") {
        Some(index) => index.parse::<u32>().ok(),
        None => None,
    }
}
/// Extracts the numeric layer index from a
/// `past_key_values.<idx>.{key,value}` name; `None` when the prefix is
/// absent, there is no `.` after the index, or the index is not a `u32`.
#[allow(dead_code)]
fn parse_attn_index(name: &str) -> Option<u32> {
    let rest = name.strip_prefix("past_key_values.")?;
    let (index, _suffix) = rest.split_once('.')?;
    index.parse().ok()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_conv_index_works() {
        assert_eq!(parse_conv_index("past_conv.0"), Some(0));
        assert_eq!(parse_conv_index("past_conv.15"), Some(15));
        assert_eq!(parse_conv_index("past_kv.0"), None);
        assert_eq!(parse_conv_index("past_conv."), None);
        assert_eq!(parse_conv_index("past_conv.foo"), None);
    }

    #[test]
    fn parse_attn_index_works() {
        assert_eq!(parse_attn_index("past_key_values.2.key"), Some(2));
        assert_eq!(parse_attn_index("past_key_values.14.value"), Some(14));
        assert_eq!(parse_attn_index("past_conv.0"), None);
        assert_eq!(parse_attn_index("past_key_values.2"), None);
    }
}