use ort::session::builder::GraphOptimizationLevel;
use crate::{Result, Session};
/// Collects optional settings for the `ort` session created by [`Builder::build`].
///
/// Every field starts as `None` (via `Default`). A setting left as `None` is
/// never forwarded to the `ort` session builder, so `ort` keeps whatever
/// behavior it uses when that option is not configured.
#[derive(Debug, Default)]
#[must_use = "a `Builder` only stores settings; call `build` to create the session"]
pub struct Builder {
    /// Value handed to `with_inter_threads` on the `ort` session builder, if set.
    inter_threads: Option<usize>,
    /// Value handed to `with_intra_threads` on the `ort` session builder, if set.
    intra_threads: Option<usize>,
    /// Value handed to `with_optimization_level` on the `ort` session builder, if set.
    optimization_level: Option<GraphOptimizationLevel>,
    /// Value handed to `with_parallel_execution` on the `ort` session builder, if set.
    parallel_execution: Option<bool>,
}
impl Builder {
    /// Records the thread count that [`Builder::build`] will pass to
    /// `with_inter_threads` on the `ort` session builder.
    pub fn with_inter_threads(self, num_threads: usize) -> Self {
        Self { inter_threads: Some(num_threads), ..self }
    }

    /// Records the thread count that [`Builder::build`] will pass to
    /// `with_intra_threads` on the `ort` session builder.
    pub fn with_intra_threads(self, num_threads: usize) -> Self {
        Self { intra_threads: Some(num_threads), ..self }
    }

    /// Records the graph optimization level that [`Builder::build`] will pass
    /// to `with_optimization_level` on the `ort` session builder.
    pub fn with_optimization_level(self, opt_level: GraphOptimizationLevel) -> Self {
        Self { optimization_level: Some(opt_level), ..self }
    }

    /// Records the flag that [`Builder::build`] will pass to
    /// `with_parallel_execution` on the `ort` session builder.
    pub fn with_parallel_execution(self, parallel_execution: bool) -> Self {
        Self { parallel_execution: Some(parallel_execution), ..self }
    }

    /// Consumes the builder and creates a [`Session`] from the embedded model.
    ///
    /// Only the options that were explicitly set are applied, in this order:
    /// inter threads, intra threads, optimization level, parallel execution.
    /// The model bytes come from `model.onnx`, embedded at compile time with
    /// `include_bytes!`.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the `ort` session builder, from
    /// applying one of the configured options, or from committing the model.
    pub fn build(self) -> Result<Session> {
        let Self { inter_threads, intra_threads, optimization_level, parallel_execution } = self;

        // Each step rebinds `pending`; an unset option passes it through untouched.
        let pending = ort::session::Session::builder()?;
        let pending = match inter_threads {
            Some(count) => pending.with_inter_threads(count)?,
            None => pending,
        };
        let pending = match intra_threads {
            Some(count) => pending.with_intra_threads(count)?,
            None => pending,
        };
        let pending = match optimization_level {
            Some(level) => pending.with_optimization_level(level)?,
            None => pending,
        };
        let pending = match parallel_execution {
            Some(enabled) => pending.with_parallel_execution(enabled)?,
            None => pending,
        };

        let session = pending.commit_from_memory(include_bytes!("model.onnx"))?;
        Ok(Session { session })
    }
}