1use ort::session::builder::{GraphOptimizationLevel, SessionBuilder};
2use ort::session::Session;
3use ort::Result as OrtResult;
4use std::sync::Once;
5
/// Guards the one-time ONNX Runtime environment setup that `ensure_initialized`
/// runs through `INIT.call_once`.
static INIT: Once = Once::new();
7
8#[derive(Debug)]
9pub struct RuntimeConfig {
10 pub inter_threads: usize,
15
16 pub intra_threads: usize,
21
22 pub optimization_level: GraphOptimizationLevel,
27}
28
29impl Default for RuntimeConfig {
30 fn default() -> Self {
31 Self {
32 inter_threads: 0, intra_threads: 0, optimization_level: GraphOptimizationLevel::Level3,
35 }
36 }
37}
38
39impl Clone for RuntimeConfig {
40 fn clone(&self) -> Self {
41 Self {
42 inter_threads: self.inter_threads,
43 intra_threads: self.intra_threads,
44 optimization_level: match self.optimization_level {
45 GraphOptimizationLevel::Level1 => GraphOptimizationLevel::Level1,
46 GraphOptimizationLevel::Level2 => GraphOptimizationLevel::Level2,
47 GraphOptimizationLevel::Level3 => GraphOptimizationLevel::Level3,
48 GraphOptimizationLevel::Disable => GraphOptimizationLevel::Disable,
49 },
50 }
51 }
52}
53
54fn init_onnx_environment() -> OrtResult<()> {
55 ort::init()
56 .with_name("prefrontal")
57 .commit()?;
58 Ok(())
59}
60
61pub fn ensure_initialized() -> OrtResult<()> {
62 INIT.call_once(|| {
63 init_onnx_environment().expect("Failed to initialize ONNX Runtime environment");
64 });
65 Ok(())
66}
67
68pub fn create_session_builder(config: &RuntimeConfig) -> OrtResult<SessionBuilder> {
69 ensure_initialized()?;
70 let mut builder = Session::builder()?;
71
72 if config.inter_threads > 0 {
74 builder = builder.with_inter_threads(config.inter_threads)?;
75 }
76 if config.intra_threads > 0 {
77 builder = builder.with_intra_threads(config.intra_threads)?;
78 }
79
80 let opt_level = match config.optimization_level {
82 GraphOptimizationLevel::Level1 => GraphOptimizationLevel::Level1,
83 GraphOptimizationLevel::Level2 => GraphOptimizationLevel::Level2,
84 GraphOptimizationLevel::Level3 => GraphOptimizationLevel::Level3,
85 GraphOptimizationLevel::Disable => GraphOptimizationLevel::Disable,
86 };
87 builder = builder.with_optimization_level(opt_level)?;
88
89 Ok(builder)
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Initialization succeeds, and a repeated call is also `Ok`.
    #[test]
    fn test_environment_initialization() {
        for _ in 0..2 {
            assert!(ensure_initialized().is_ok());
        }
    }

    /// Non-default thread counts and optimization level still yield a builder.
    #[test]
    fn test_session_builder_config() {
        let builder = create_session_builder(&RuntimeConfig {
            inter_threads: 2,
            intra_threads: 2,
            optimization_level: GraphOptimizationLevel::Level1,
        });
        assert!(builder.is_ok());
    }

    /// `Default` leaves both thread counts at 0 and selects `Level3`.
    #[test]
    fn test_default_config() {
        let RuntimeConfig {
            inter_threads,
            intra_threads,
            optimization_level,
        } = RuntimeConfig::default();
        assert_eq!(inter_threads, 0);
        assert_eq!(intra_threads, 0);
        assert!(
            matches!(optimization_level, GraphOptimizationLevel::Level3),
            "Expected GraphOptimizationLevel::Level3"
        );
    }
}