use ndarray::ArrayD;
use std::collections::HashMap;
use std::path::Path;
use super::types::ExecutorResult;
use crate::runtime_adapter::AdapterError;
/// A loaded model that exposes its tensor names and can be executed on named inputs.
///
/// The `Send + Sync` supertraits let a session be moved to, and shared between, threads.
pub trait InferenceSession: Send + Sync {
    /// Names of the session's input tensors.
    fn input_names(&self) -> &[String];

    /// Names of the session's output tensors.
    fn output_names(&self) -> &[String];

    /// Executes the session.
    ///
    /// `inputs` maps an input name to an `f32` tensor of dynamic rank; the result maps
    /// output names to the produced tensors. NOTE(review): whether every name from
    /// [`InferenceSession::input_names`] must be present is implementation-defined.
    ///
    /// # Errors
    ///
    /// Returns an implementation-defined error when execution fails.
    fn run(
        &self,
        inputs: HashMap<String, ArrayD<f32>>,
    ) -> ExecutorResult<HashMap<String, ArrayD<f32>>>;
}
/// Produces [`InferenceSession`]s from a model path.
///
/// Factories are `Send + Sync` so a single factory can be shared between threads.
pub trait SessionFactory: Send + Sync {
    /// The concrete session type handed out by [`SessionFactory::create`].
    type Session: InferenceSession;

    /// Creates a new session for the model located at `model_path`.
    ///
    /// # Errors
    ///
    /// Returns an implementation-defined error if the session cannot be created.
    fn create(&self, model_path: &Path) -> ExecutorResult<Self::Session>;
}
use crate::runtime_adapter::onnx::{ExecutionProviderKind, ONNXSession, SessionOptions};
/// [`InferenceSession`] backed by an [`ONNXSession`].
///
/// Every trait method forwards directly to the wrapped session. The field is private,
/// so outside this module instances come from [`OnnxSessionFactory`].
pub struct OnnxInferenceSession {
    session: ONNXSession,
}

impl InferenceSession for OnnxInferenceSession {
    fn input_names(&self) -> &[String] {
        self.session.input_names()
    }

    fn output_names(&self) -> &[String] {
        self.session.output_names()
    }

    /// Runs the wrapped ONNX session on `feeds` and returns its result unchanged.
    fn run(
        &self,
        feeds: HashMap<String, ArrayD<f32>>,
    ) -> ExecutorResult<HashMap<String, ArrayD<f32>>> {
        self.session.run(feeds)
    }
}
/// Stateless [`SessionFactory`] that builds ONNX-backed sessions.
///
/// It carries no configuration: its `SessionFactory::create` uses the CPU execution
/// provider with default `SessionOptions`. Call `OnnxSessionFactory::create_session`
/// directly to choose a different provider or options.
// Zero-sized and stateless, so the common traits are free and make the factory easy
// to store, log, and pass around by value.
#[derive(Debug, Clone, Copy, Default)]
pub struct OnnxSessionFactory;
impl OnnxSessionFactory {
    /// Builds a raw [`ONNXSession`] for the model at `model_path`.
    ///
    /// `SessionFactory::create` always uses the CPU provider with default options.
    /// This function instead lets the caller pick `execution_provider` and `options`,
    /// and returns the unwrapped session.
    ///
    /// # Errors
    ///
    /// - `AdapterError::InvalidInput("Invalid model path encoding")`, converted into the
    ///   executor error type, when `model_path` is not valid UTF-8.
    /// - Any error returned by `ONNXSession::build`, converted the same way.
    pub fn create_session(
        model_path: &Path,
        execution_provider: ExecutionProviderKind,
        options: SessionOptions,
    ) -> ExecutorResult<ONNXSession> {
        // The path is handed to `ONNXSession::build` as a string, so a path that is
        // not valid UTF-8 is rejected before any runtime work happens.
        let utf8_path = match model_path.to_str() {
            Some(text) => text,
            None => {
                let err = AdapterError::InvalidInput("Invalid model path encoding".to_string());
                return Err(err.into());
            }
        };
        let built = ONNXSession::build(utf8_path, execution_provider, options)?;
        Ok(built)
    }
}
impl SessionFactory for OnnxSessionFactory {
type Session = OnnxInferenceSession;
fn create(&self, model_path: &Path) -> ExecutorResult<Self::Session> {
let session = Self::create_session(
model_path,
ExecutionProviderKind::Cpu,
SessionOptions::default(),
)?;
Ok(OnnxInferenceSession { session })
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// Test double for [`InferenceSession`]: ignores its inputs and returns a clone
    /// of a fixed set of output tensors.
    #[derive(Debug, Clone)]
    pub struct MockSession {
        /// Names reported by `InferenceSession::input_names`.
        pub input_names: Vec<String>,
        /// Names reported by `InferenceSession::output_names`.
        pub output_names: Vec<String>,
        /// Tensors returned (cloned) by every `InferenceSession::run` call.
        pub outputs: HashMap<String, ArrayD<f32>>,
    }

    impl MockSession {
        /// Creates a mock with one input named `"input"`, one output named `"output"`,
        /// and no canned output tensors.
        pub fn new() -> Self {
            Self {
                input_names: vec!["input".to_string()],
                output_names: vec!["output".to_string()],
                outputs: HashMap::new(),
            }
        }

        /// Adds (or replaces) the canned tensor returned under `name`.
        pub fn with_output(mut self, name: &str, tensor: ArrayD<f32>) -> Self {
            self.outputs.insert(name.to_string(), tensor);
            self
        }

        /// Replaces the reported input names.
        pub fn with_input_names(mut self, names: Vec<String>) -> Self {
            self.input_names = names;
            self
        }

        /// Replaces the reported output names.
        pub fn with_output_names(mut self, names: Vec<String>) -> Self {
            self.output_names = names;
            self
        }
    }

    impl InferenceSession for MockSession {
        fn run(
            &self,
            _inputs: HashMap<String, ArrayD<f32>>,
        ) -> ExecutorResult<HashMap<String, ArrayD<f32>>> {
            Ok(self.outputs.clone())
        }

        fn output_names(&self) -> &[String] {
            &self.output_names
        }

        fn input_names(&self) -> &[String] {
            &self.input_names
        }
    }

    /// [`SessionFactory`] that hands out independent clones of a preconfigured
    /// [`MockSession`], ignoring the model path.
    pub struct MockSessionFactory {
        pub session_template: MockSession,
    }

    impl MockSessionFactory {
        /// Creates a factory whose template is [`MockSession::new`].
        pub fn new() -> Self {
            Self {
                session_template: MockSession::new(),
            }
        }

        /// Adds (or replaces) a canned output tensor on the template session.
        pub fn with_output(mut self, name: &str, tensor: ArrayD<f32>) -> Self {
            self.session_template
                .outputs
                .insert(name.to_string(), tensor);
            self
        }

        /// Replaces the input names reported by created sessions.
        pub fn with_input_names(mut self, names: Vec<String>) -> Self {
            self.session_template.input_names = names;
            self
        }

        /// Replaces the output names reported by created sessions.
        pub fn with_output_names(mut self, names: Vec<String>) -> Self {
            self.session_template.output_names = names;
            self
        }
    }

    impl SessionFactory for MockSessionFactory {
        type Session = MockSession;

        fn create(&self, _model_path: &Path) -> ExecutorResult<Self::Session> {
            // Cloning the whole template stays correct if `MockSession` gains fields,
            // unlike copying each field by hand.
            Ok(self.session_template.clone())
        }
    }

    #[test]
    fn test_mock_session_run() {
        let output_tensor = Array1::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
        let session = MockSession::new().with_output("output", output_tensor.clone());
        let inputs = HashMap::new();
        let result = session.run(inputs).unwrap();
        assert_eq!(result.get("output").unwrap(), &output_tensor);
    }

    #[test]
    fn test_mock_session_names() {
        let session = MockSession::new()
            .with_input_names(vec!["tokens".to_string(), "mask".to_string()])
            .with_output_names(vec!["logits".to_string()]);
        assert_eq!(session.input_names(), &["tokens", "mask"]);
        assert_eq!(session.output_names(), &["logits"]);
    }

    #[test]
    fn test_mock_factory_creates_session() {
        let output_tensor = Array1::from_vec(vec![4.0, 5.0, 6.0]).into_dyn();
        let factory = MockSessionFactory::new().with_output("result", output_tensor.clone());
        let session = factory.create(Path::new("/fake/model.onnx")).unwrap();
        let result = session.run(HashMap::new()).unwrap();
        assert_eq!(result.get("result").unwrap(), &output_tensor);
    }

    #[test]
    fn test_mock_factory_propagates_names() {
        let factory = MockSessionFactory::new()
            .with_input_names(vec!["ids".to_string()])
            .with_output_names(vec!["scores".to_string(), "labels".to_string()]);
        let session = factory.create(Path::new("/fake/model.onnx")).unwrap();
        assert_eq!(session.input_names(), &["ids"]);
        assert_eq!(session.output_names(), &["scores", "labels"]);
    }

    // Non-UTF-8 paths can only be constructed portably from raw bytes on Unix.
    #[cfg(unix)]
    #[test]
    fn test_onnx_factory_rejects_non_utf8_path() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        let path = Path::new(OsStr::from_bytes(b"/tmp/\xffmodel.onnx"));
        // Rejected before `ONNXSession::build` is reached, so no model file is needed.
        assert!(OnnxSessionFactory.create(path).is_err());
    }
}