1use crate::errors::{Result, VisionError};
7use std::path::Path;
8use std::sync::Arc;
9
/// Selects which private loading routine `ModelLoader::load_model` dispatches to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LoadingStrategy {
    /// Routed to `ModelLoader::load_standard`; the returned handle is marked as loaded.
    Standard,
    /// Routed to `ModelLoader::load_memory_mapped`; the returned handle is marked as
    /// loaded. This is the strategy picked by `ModelLoadingConfig::default()`.
    MemoryMapped,
    /// Routed to `ModelLoader::load_lazy`; the returned handle starts out *not* loaded.
    Lazy,
}
20
21#[derive(Debug, Clone)]
23pub struct ModelLoadingConfig {
24 pub strategy: LoadingStrategy,
26 pub enable_sharing: bool,
28 pub prefetch: bool,
30 pub use_huge_pages: bool,
32}
33
34impl Default for ModelLoadingConfig {
35 fn default() -> Self {
36 Self {
37 strategy: LoadingStrategy::MemoryMapped,
38 enable_sharing: true,
39 prefetch: true,
40 use_huge_pages: false,
41 }
42 }
43}
44
45pub struct ModelLoader {
47 config: ModelLoadingConfig,
48 #[allow(dead_code)]
49 cache: Arc<ModelCache>,
50}
51
52impl ModelLoader {
53 pub fn new() -> Self {
55 Self {
56 config: ModelLoadingConfig::default(),
57 cache: Arc::new(ModelCache::new()),
58 }
59 }
60
61 pub fn with_config(config: ModelLoadingConfig) -> Self {
63 Self {
64 config,
65 cache: Arc::new(ModelCache::new()),
66 }
67 }
68
69 pub fn load_model(&self, model_path: &Path) -> Result<ModelHandle> {
75 if !model_path.exists() {
76 return Err(VisionError::config(format!(
77 "Model file not found: {}",
78 model_path.display()
79 )));
80 }
81
82 match self.config.strategy {
83 LoadingStrategy::Standard => self.load_standard(model_path),
84 LoadingStrategy::MemoryMapped => self.load_memory_mapped(model_path),
85 LoadingStrategy::Lazy => self.load_lazy(model_path),
86 }
87 }
88
89 fn load_standard(&self, model_path: &Path) -> Result<ModelHandle> {
91 let file_size = std::fs::metadata(model_path)
92 .map_err(|e| VisionError::config(format!("Failed to read model metadata: {}", e)))?
93 .len();
94
95 Ok(ModelHandle {
96 path: model_path.to_path_buf(),
97 size_bytes: file_size,
98 strategy: LoadingStrategy::Standard,
99 is_loaded: true,
100 })
101 }
102
103 fn load_memory_mapped(&self, model_path: &Path) -> Result<ModelHandle> {
105 let file_size = std::fs::metadata(model_path)
106 .map_err(|e| VisionError::config(format!("Failed to read model metadata: {}", e)))?
107 .len();
108
109 Ok(ModelHandle {
116 path: model_path.to_path_buf(),
117 size_bytes: file_size,
118 strategy: LoadingStrategy::MemoryMapped,
119 is_loaded: true,
120 })
121 }
122
123 fn load_lazy(&self, model_path: &Path) -> Result<ModelHandle> {
125 let file_size = std::fs::metadata(model_path)
126 .map_err(|e| VisionError::config(format!("Failed to read model metadata: {}", e)))?
127 .len();
128
129 Ok(ModelHandle {
130 path: model_path.to_path_buf(),
131 size_bytes: file_size,
132 strategy: LoadingStrategy::Lazy,
133 is_loaded: false,
134 })
135 }
136
137 pub fn memory_stats(&self) -> MemoryStats {
139 MemoryStats {
140 total_mapped_bytes: 0,
141 active_models: 0,
142 cache_hits: 0,
143 cache_misses: 0,
144 }
145 }
146}
147
148impl Default for ModelLoader {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154#[derive(Debug, Clone)]
156pub struct ModelHandle {
157 pub path: std::path::PathBuf,
159 pub size_bytes: u64,
161 pub strategy: LoadingStrategy,
163 pub is_loaded: bool,
165}
166
167impl ModelHandle {
168 pub fn size_mb(&self) -> f64 {
170 self.size_bytes as f64 / (1024.0 * 1024.0)
171 }
172
173 pub fn unload(&mut self) -> Result<()> {
175 self.is_loaded = false;
176 Ok(())
177 }
178
179 pub fn reload(&mut self) -> Result<()> {
181 self.is_loaded = true;
182 Ok(())
183 }
184}
185
/// Counters describing mapped memory and cache effectiveness.
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Number of bytes counted as mapped; see `total_mapped_mb` for the MiB view.
    pub total_mapped_bytes: u64,
    /// Number of models counted as active (how "active" is determined is not
    /// visible in this file).
    pub active_models: usize,
    /// Number of cache lookups that were hits.
    pub cache_hits: u64,
    /// Number of cache lookups that were misses.
    pub cache_misses: u64,
}

impl MemoryStats {
    /// Returns `total_mapped_bytes` expressed in mebibytes (bytes / 1024²).
    pub fn total_mapped_mb(&self) -> f64 {
        self.total_mapped_bytes as f64 / (1024.0 * 1024.0)
    }

    /// Returns the fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Yields `0.0` when no lookups have been recorded, so callers never see a
    /// `NaN` from a zero denominator.
    pub fn cache_hit_rate(&self) -> f64 {
        // Widen before adding: summing the two counters in `u64` panics in debug
        // builds and silently wraps in release builds once the total exceeds
        // `u64::MAX`. For totals that fit in `u64` the result is unchanged.
        let total = u128::from(self.cache_hits) + u128::from(self.cache_misses);
        if total == 0 {
            return 0.0;
        }
        self.cache_hits as f64 / total as f64
    }
}
215
/// Placeholder cache type held by `ModelLoader`; it currently has no fields.
struct ModelCache {}

impl ModelCache {
    /// Creates the (empty) cache value.
    fn new() -> Self {
        ModelCache {}
    }
}
227
228#[cfg(unix)]
230mod platform {
231 use super::*;
232
233 #[allow(dead_code)]
235 pub fn create_mmap(_path: &Path, _size: u64) -> Result<()> {
236 Ok(())
238 }
239
240 #[allow(dead_code)]
242 pub fn prefetch_pages(_addr: *const u8, _size: usize) -> Result<()> {
243 Ok(())
245 }
246
247 #[allow(dead_code)]
249 pub fn enable_huge_pages(_addr: *const u8, _size: usize) -> Result<()> {
250 Ok(())
252 }
253}
254
/// Windows build of the platform hooks.
///
/// Every function in this module is currently a no-op that ignores its
/// arguments and reports success; none of them is called from this file,
/// hence the `dead_code` allowances.
///
/// NOTE(review): the page-size hook is named `enable_large_pages` here but
/// `enable_huge_pages` in the unix module, so a caller cannot use one name on
/// both platforms without its own `cfg` gate.
#[cfg(windows)]
mod platform {
    use super::*;

    /// Stub for creating a memory mapping of `_path`; does nothing and returns `Ok(())`.
    #[allow(dead_code)]
    pub fn create_mmap(_path: &Path, _size: u64) -> Result<()> {
        Ok(())
    }

    /// Stub for prefetching pages of a region; the pointer is never dereferenced.
    /// Does nothing and returns `Ok(())`.
    #[allow(dead_code)]
    pub fn prefetch_pages(_addr: *const u8, _size: usize) -> Result<()> {
        Ok(())
    }

    /// Stub for requesting large pages for a region; the pointer is never
    /// dereferenced. Does nothing and returns `Ok(())`.
    #[allow(dead_code)]
    pub fn enable_large_pages(_addr: *const u8, _size: usize) -> Result<()> {
        Ok(())
    }
}
280
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Writes `contents` to a file in the OS temp directory whose name combines
    /// `tag` with the current process id, and returns its path. Callers remove
    /// the file themselves once they are done with it.
    fn write_temp_model(tag: &str, contents: &[u8]) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "model_loader_{}_{}.bin",
            tag,
            std::process::id()
        ));
        std::fs::write(&path, contents).expect("failed to create temporary model file");
        path
    }

    #[test]
    fn test_loading_config_default() {
        let config = ModelLoadingConfig::default();
        assert_eq!(config.strategy, LoadingStrategy::MemoryMapped);
        assert!(config.enable_sharing);
        assert!(config.prefetch);
        assert!(!config.use_huge_pages);
    }

    #[test]
    fn test_model_handle_size_mb() {
        let handle = ModelHandle {
            path: PathBuf::from("/test/model.onnx"),
            size_bytes: 100 * 1024 * 1024,
            strategy: LoadingStrategy::MemoryMapped,
            is_loaded: true,
        };

        assert_eq!(handle.size_mb(), 100.0);
    }

    #[test]
    fn test_memory_stats_hit_rate() {
        let stats = MemoryStats {
            total_mapped_bytes: 1024 * 1024 * 1024,
            active_models: 2,
            cache_hits: 80,
            cache_misses: 20,
        };

        assert_eq!(stats.cache_hit_rate(), 0.8);
        assert_eq!(stats.total_mapped_mb(), 1024.0);
    }

    #[test]
    fn test_memory_stats_hit_rate_without_lookups() {
        // A zero denominator must yield 0.0 rather than NaN.
        let stats = MemoryStats {
            total_mapped_bytes: 0,
            active_models: 0,
            cache_hits: 0,
            cache_misses: 0,
        };

        assert_eq!(stats.cache_hit_rate(), 0.0);
    }

    #[test]
    fn test_model_handle_unload_reload() {
        let mut handle = ModelHandle {
            path: PathBuf::from("/test/model.onnx"),
            size_bytes: 1024,
            strategy: LoadingStrategy::Standard,
            is_loaded: true,
        };

        assert!(handle.is_loaded);

        handle.unload().unwrap();
        assert!(!handle.is_loaded);

        handle.reload().unwrap();
        assert!(handle.is_loaded);
    }

    #[test]
    fn test_model_loader_creation() {
        let loader = ModelLoader::new();
        assert_eq!(loader.config.strategy, LoadingStrategy::MemoryMapped);

        let custom_config = ModelLoadingConfig {
            strategy: LoadingStrategy::Standard,
            enable_sharing: false,
            prefetch: false,
            use_huge_pages: false,
        };

        let custom_loader = ModelLoader::with_config(custom_config);
        assert_eq!(custom_loader.config.strategy, LoadingStrategy::Standard);
        assert!(!custom_loader.config.enable_sharing);
    }

    #[test]
    fn test_load_model_missing_file() {
        // The parent directory is never created, so the path cannot exist.
        let missing = std::env::temp_dir()
            .join(format!("model_loader_missing_{}", std::process::id()))
            .join("model.onnx");

        assert!(ModelLoader::new().load_model(&missing).is_err());
    }

    #[test]
    fn test_load_model_reports_size_and_strategy() {
        let contents = b"not a real model";
        let path = write_temp_model("strategies", contents);

        let cases = [
            (LoadingStrategy::Standard, true),
            (LoadingStrategy::MemoryMapped, true),
            (LoadingStrategy::Lazy, false),
        ];

        // Load with every strategy first and delete the file before asserting,
        // so a failing assertion does not leave the temp file behind.
        let results: Vec<_> = cases
            .iter()
            .map(|&(strategy, _)| {
                let loader = ModelLoader::with_config(ModelLoadingConfig {
                    strategy,
                    ..ModelLoadingConfig::default()
                });
                loader.load_model(&path)
            })
            .collect();
        let _ = std::fs::remove_file(&path);

        for (result, &(strategy, expect_loaded)) in results.into_iter().zip(cases.iter()) {
            let handle = result.expect("loading an existing file should succeed");
            assert_eq!(handle.path, path);
            assert_eq!(handle.size_bytes, contents.len() as u64);
            assert_eq!(handle.strategy, strategy);
            assert_eq!(handle.is_loaded, expect_loaded);
        }
    }
}