#![allow(missing_docs)]
use crate::backends::gliner2_fastino::errors::Error;
use crate::backends::hf_loader;
use std::path::Path;
use std::sync::{Arc, Mutex};
/// The eight ONNX sessions that together make up one loaded model.
///
/// Every field is loaded from a file named after that field (for example
/// `encoder_fp32.onnx`) inside the first complete candidate subdirectory
/// found by `Sessions::from_dir_with_cfg_mode`.
pub struct Sessions {
/// Session for the `encoder` graph.
pub encoder: SessionSlot,
/// Session for the `token_gather` graph.
pub token_gather: SessionSlot,
/// Session for the `span_rep` graph.
pub span_rep: SessionSlot,
/// Session for the `schema_gather` graph.
pub schema_gather: SessionSlot,
/// Session for the `count_pred_argmax` graph.
pub count_pred_argmax: SessionSlot,
/// Session for the `count_lstm_fixed` graph.
pub count_lstm_fixed: SessionSlot,
/// Session for the `scorer` graph.
pub scorer: SessionSlot,
/// Session for the `classifier` graph.
pub classifier: SessionSlot,
}
/// Handle to a single ONNX session stored behind `Arc<Mutex<..>>`.
///
/// All access goes through `SessionSlot::with_session`, which takes the
/// mutex, so callers of the same slot are serialised.
#[derive(Debug)]
pub struct SessionSlot {
// The session itself; mutable access requires holding the lock.
inner: Arc<Mutex<ort::session::Session>>,
}
impl SessionSlot {
pub fn from_path_with_cfg(
model_path: &Path,
cfg: hf_loader::OnnxSessionConfig,
) -> Result<Self, Error> {
let session = hf_loader::create_onnx_session(model_path, cfg)
.map_err(|e| Error::Tokenizer(format!("session {}: {e}", model_path.display())))?;
Ok(Self {
inner: Arc::new(Mutex::new(session)),
})
}
#[allow(dead_code)] pub fn from_path(model_path: &Path) -> Result<Self, Error> {
Self::from_path_with_cfg(model_path, hf_loader::OnnxSessionConfig::default())
}
pub fn with_session<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut ort::session::Session) -> R,
{
let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
f(&mut guard)
}
}
impl std::fmt::Debug for Sessions {
    // Summarises the bundle as `Sessions { count: 8 }` instead of dumping
    // every session's internals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const SESSION_COUNT: usize = 8;
        let mut summary = f.debug_struct("Sessions");
        summary.field("count", &SESSION_COUNT);
        summary.finish()
    }
}
impl Sessions {
#[allow(dead_code)] pub fn from_dir_with_cfg(
model_dir: &Path,
cfg: hf_loader::OnnxSessionConfig,
) -> Result<(Self, std::path::PathBuf), Error> {
Self::from_dir_with_cfg_mode(model_dir, cfg, super::ExecutionMode::Standard)
}
pub fn from_dir_with_cfg_mode(
model_dir: &Path,
cfg: hf_loader::OnnxSessionConfig,
mode: super::ExecutionMode,
) -> Result<(Self, std::path::PathBuf), Error> {
for (subdir, suffix) in [
("fp32_v2", "_fp32.onnx"),
("fp16_v2", "_fp16.onnx"),
("fp32", "_fp32.onnx"), ("fp16", "_fp16.onnx"), ] {
let try_dir = model_dir.join(subdir);
if !try_dir.is_dir() {
continue;
}
let resolve = |name: &str| -> std::path::PathBuf {
if matches!(mode, super::ExecutionMode::IoBinding) {
let io = try_dir.join(format!("{name}_iobinding{suffix}"));
if io.exists() {
return io;
}
}
try_dir.join(format!("{name}{suffix}"))
};
let all_present = [
"encoder",
"token_gather",
"span_rep",
"schema_gather",
"count_pred_argmax",
"count_lstm_fixed",
"scorer",
"classifier",
]
.iter()
.all(|n| resolve(n).exists());
if !all_present {
continue;
}
return Ok((
Self {
encoder: SessionSlot::from_path_with_cfg(&resolve("encoder"), cfg.clone())?,
token_gather: SessionSlot::from_path_with_cfg(
&resolve("token_gather"),
cfg.clone(),
)?,
span_rep: SessionSlot::from_path_with_cfg(&resolve("span_rep"), cfg.clone())?,
schema_gather: SessionSlot::from_path_with_cfg(
&resolve("schema_gather"),
cfg.clone(),
)?,
count_pred_argmax: SessionSlot::from_path_with_cfg(
&resolve("count_pred_argmax"),
cfg.clone(),
)?,
count_lstm_fixed: SessionSlot::from_path_with_cfg(
&resolve("count_lstm_fixed"),
cfg.clone(),
)?,
scorer: SessionSlot::from_path_with_cfg(&resolve("scorer"), cfg.clone())?,
classifier: SessionSlot::from_path_with_cfg(
&resolve("classifier"),
cfg.clone(),
)?,
},
try_dir,
));
}
Err(Error::Tokenizer(format!(
"no complete v2 session set found under {} (looked in fp32_v2/, fp16_v2/, fp32/, fp16/)",
model_dir.display()
)))
}
#[allow(dead_code)] pub fn from_dir(model_dir: &Path) -> Result<(Self, std::path::PathBuf), Error> {
Self::from_dir_with_cfg(model_dir, hf_loader::OnnxSessionConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Every graph file stem the loader requires.
    const GRAPHS: [&str; 8] = [
        "encoder",
        "token_gather",
        "span_rep",
        "schema_gather",
        "count_pred_argmax",
        "count_lstm_fixed",
        "scorer",
        "classifier",
    ];

    /// Creates an empty file `{graph}{tail}` in `dir` for every graph.
    fn touch_all(dir: &Path, tail: &str) {
        for graph in GRAPHS {
            std::fs::write(dir.join(format!("{graph}{tail}")), b"").unwrap();
        }
    }

    /// Runs the mode-aware loader with a default config and expects failure.
    fn load_err(root: &Path, mode: super::super::ExecutionMode) -> Error {
        Sessions::from_dir_with_cfg_mode(root, hf_loader::OnnxSessionConfig::default(), mode)
            .unwrap_err()
    }

    #[test]
    fn from_dir_fails_clearly_on_empty_dir() {
        let root = tempdir().unwrap();
        let msg = Sessions::from_dir(root.path()).unwrap_err().to_string();
        assert!(msg.contains("no complete v2 session set"), "got: {msg}");
        assert!(msg.contains("fp32") || msg.contains("fp16"), "got: {msg}");
    }

    #[test]
    fn from_dir_fails_clearly_on_partial_layout() {
        let root = tempdir().unwrap();
        let v2 = root.path().join("fp32_v2");
        std::fs::create_dir_all(&v2).unwrap();
        std::fs::write(v2.join("encoder_fp32.onnx"), b"").unwrap();
        let failure = Sessions::from_dir(root.path()).unwrap_err();
        assert!(failure.to_string().contains("no complete v2 session set"));
    }

    #[test]
    fn iobinding_mode_prefers_iobinding_variant_when_present() {
        let root = tempdir().unwrap();
        let v2 = root.path().join("fp32_v2");
        std::fs::create_dir_all(&v2).unwrap();
        touch_all(&v2, "_iobinding_fp32.onnx");

        let err_std = load_err(root.path(), super::super::ExecutionMode::Standard);
        assert!(
            err_std.to_string().contains("no complete v2 session set"),
            "Standard mode should not see iobinding-only variants. Got: {err_std}"
        );

        let msg = load_err(root.path(), super::super::ExecutionMode::IoBinding).to_string();
        assert!(
            !msg.contains("no complete v2 session set"),
            "IoBinding mode should resolve iobinding variants and advance past 'no complete' check. Got: {msg}"
        );
    }

    #[test]
    fn iobinding_mode_falls_back_to_standard_when_iobinding_missing() {
        let root = tempdir().unwrap();
        let v2 = root.path().join("fp32_v2");
        std::fs::create_dir_all(&v2).unwrap();
        touch_all(&v2, "_fp32.onnx");

        let msg = load_err(root.path(), super::super::ExecutionMode::IoBinding).to_string();
        assert!(
            !msg.contains("no complete v2 session set"),
            "IoBinding mode should fall back to standard variants. Got: {msg}"
        );
    }
}