prefrontal/
runtime.rs

1use ort::session::builder::{GraphOptimizationLevel, SessionBuilder};
2use ort::session::Session;
3use ort::Result as OrtResult;
4use std::sync::Once;
5
/// Process-wide flag tracking one-time initialization of the ONNX Runtime environment.
/// Used by [`ensure_initialized`].
static INIT: Once = Once::new();
7
8#[derive(Debug)]
9pub struct RuntimeConfig {
10    /// Number of threads to use for parallel model execution (0 = let ONNX Runtime decide)
11    /// 
12    /// This controls inter-op parallelism, which is parallelism between independent nodes in the model graph.
13    /// Setting this to 0 allows ONNX Runtime to automatically choose based on system resources.
14    pub inter_threads: usize,
15
16    /// Number of threads to use for parallel computation within nodes (0 = let ONNX Runtime decide)
17    /// 
18    /// This controls intra-op parallelism, which is parallelism within individual operations like matrix multiplication.
19    /// Setting this to 0 allows ONNX Runtime to automatically choose based on system resources.
20    pub intra_threads: usize,
21
22    /// The level of graph optimization to apply
23    /// 
24    /// Higher levels perform more aggressive optimizations but may take longer to compile.
25    /// Level3 (maximum) is recommended for production use.
26    pub optimization_level: GraphOptimizationLevel,
27}
28
29impl Default for RuntimeConfig {
30    fn default() -> Self {
31        Self {
32            inter_threads: 0, // Let ONNX Runtime decide
33            intra_threads: 0, // Let ONNX Runtime decide
34            optimization_level: GraphOptimizationLevel::Level3,
35        }
36    }
37}
38
39impl Clone for RuntimeConfig {
40    fn clone(&self) -> Self {
41        Self {
42            inter_threads: self.inter_threads,
43            intra_threads: self.intra_threads,
44            optimization_level: match self.optimization_level {
45                GraphOptimizationLevel::Level1 => GraphOptimizationLevel::Level1,
46                GraphOptimizationLevel::Level2 => GraphOptimizationLevel::Level2,
47                GraphOptimizationLevel::Level3 => GraphOptimizationLevel::Level3,
48                GraphOptimizationLevel::Disable => GraphOptimizationLevel::Disable,
49            },
50        }
51    }
52}
53
54fn init_onnx_environment() -> OrtResult<()> {
55    ort::init()
56        .with_name("prefrontal")
57        .commit()?;
58    Ok(())
59}
60
61pub fn ensure_initialized() -> OrtResult<()> {
62    INIT.call_once(|| {
63        init_onnx_environment().expect("Failed to initialize ONNX Runtime environment");
64    });
65    Ok(())
66}
67
68pub fn create_session_builder(config: &RuntimeConfig) -> OrtResult<SessionBuilder> {
69    ensure_initialized()?;
70    let mut builder = Session::builder()?;
71
72    // Configure threading
73    if config.inter_threads > 0 {
74        builder = builder.with_inter_threads(config.inter_threads)?;
75    }
76    if config.intra_threads > 0 {
77        builder = builder.with_intra_threads(config.intra_threads)?;
78    }
79
80    // Set optimization level
81    let opt_level = match config.optimization_level {
82        GraphOptimizationLevel::Level1 => GraphOptimizationLevel::Level1,
83        GraphOptimizationLevel::Level2 => GraphOptimizationLevel::Level2,
84        GraphOptimizationLevel::Level3 => GraphOptimizationLevel::Level3,
85        GraphOptimizationLevel::Disable => GraphOptimizationLevel::Disable,
86    };
87    builder = builder.with_optimization_level(opt_level)?;
88
89    Ok(builder)
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_environment_initialization() {
        ensure_initialized().expect("first initialization should succeed");
        // Repeated calls must be no-ops rather than re-initializing or failing.
        ensure_initialized().expect("second initialization should be a no-op");
    }

    #[test]
    fn test_session_builder_config() {
        let config = RuntimeConfig {
            inter_threads: 2,
            intra_threads: 2,
            optimization_level: GraphOptimizationLevel::Level1,
        };
        // `expect` surfaces the ONNX Runtime error message if building fails.
        let _builder =
            create_session_builder(&config).expect("builder should accept an explicit config");
    }

    #[test]
    fn test_session_builder_default_config() {
        let _builder = create_session_builder(&RuntimeConfig::default())
            .expect("builder should accept the default config");
    }

    #[test]
    fn test_default_config() {
        let config = RuntimeConfig::default();
        assert_eq!(config.inter_threads, 0);
        assert_eq!(config.intra_threads, 0);
        assert!(
            matches!(config.optimization_level, GraphOptimizationLevel::Level3),
            "Expected GraphOptimizationLevel::Level3, got {:?}",
            config.optimization_level
        );
    }

    #[test]
    fn test_clone_preserves_fields() {
        let original = RuntimeConfig {
            inter_threads: 3,
            intra_threads: 5,
            optimization_level: GraphOptimizationLevel::Disable,
        };
        let copy = original.clone();
        assert_eq!(copy.inter_threads, 3);
        assert_eq!(copy.intra_threads, 5);
        assert!(matches!(
            copy.optimization_level,
            GraphOptimizationLevel::Disable
        ));
    }
}
123}