1use crate::error::{Error, ErrorKind, Result};
21use crate::{ComputeUnits, Model};
22use std::path::PathBuf;
23
24pub enum ModelHandle {
51 Loaded {
53 model: Model,
55 compute_units: ComputeUnits,
57 },
58 Unloaded {
62 path: PathBuf,
64 compute_units: ComputeUnits,
66 },
67}
68
69impl ModelHandle {
70 pub fn load(
80 path: impl AsRef<std::path::Path>,
81 compute_units: ComputeUnits,
82 ) -> Result<Self> {
83 let model = Model::load(&path, compute_units)?;
84 Ok(Self::Loaded {
85 model,
86 compute_units,
87 })
88 }
89
90 pub fn from_model(model: Model, compute_units: ComputeUnits) -> Self {
95 Self::Loaded {
96 model,
97 compute_units,
98 }
99 }
100
101 pub fn is_loaded(&self) -> bool {
104 matches!(self, Self::Loaded { .. })
105 }
106
107 pub fn path(&self) -> &std::path::Path {
109 match self {
110 Self::Loaded { model, .. } => model.path(),
111 Self::Unloaded { path, .. } => path,
112 }
113 }
114
115 pub fn compute_units(&self) -> ComputeUnits {
117 match self {
118 Self::Loaded { compute_units, .. } | Self::Unloaded { compute_units, .. } => {
119 *compute_units
120 }
121 }
122 }
123
124 pub fn model(&self) -> Result<&Model> {
130 match self {
131 Self::Loaded { model, .. } => Ok(model),
132 Self::Unloaded { .. } => Err(Error::new(
133 ErrorKind::ModelLoad,
134 "model is unloaded; call reload() first",
135 )),
136 }
137 }
138
139 pub fn unload(self) -> Result<Self> {
152 match self {
153 Self::Loaded {
154 model,
155 compute_units,
156 } => {
157 let path = model.path().to_path_buf();
158 drop(model);
161 Ok(Self::Unloaded {
162 path,
163 compute_units,
164 })
165 }
166 Self::Unloaded { .. } => Err(Error::new(
167 ErrorKind::ModelLoad,
168 "model is already unloaded",
169 )),
170 }
171 }
172
173 pub fn reload(self) -> Result<Self> {
183 match self {
184 Self::Unloaded {
185 path,
186 compute_units,
187 } => {
188 let model = Model::load(&path, compute_units)?;
189 Ok(Self::Loaded {
190 model,
191 compute_units,
192 })
193 }
194 Self::Loaded { .. } => Err(Error::new(
195 ErrorKind::ModelLoad,
196 "model is already loaded",
197 )),
198 }
199 }
200
201 pub fn predict(
209 &self,
210 inputs: &[(&str, &dyn crate::tensor::AsMultiArray)],
211 ) -> Result<crate::Prediction> {
212 self.model()?.predict(inputs)
213 }
214
215 pub fn inputs(&self) -> Result<Vec<crate::FeatureDescription>> {
221 Ok(self.model()?.inputs())
222 }
223
224 pub fn outputs(&self) -> Result<Vec<crate::FeatureDescription>> {
230 Ok(self.model()?.outputs())
231 }
232
233 pub fn metadata(&self) -> Result<crate::ModelMetadata> {
239 Ok(self.model()?.metadata())
240 }
241}
242
243impl std::fmt::Debug for ModelHandle {
244 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245 match self {
246 Self::Loaded {
247 model,
248 compute_units,
249 } => f
250 .debug_struct("ModelHandle")
251 .field("state", &"Loaded")
252 .field("path", &model.path())
253 .field("compute_units", compute_units)
254 .finish(),
255 Self::Unloaded {
256 path,
257 compute_units,
258 } => f
259 .debug_struct("ModelHandle")
260 .field("state", &"Unloaded")
261 .field("path", path)
262 .field("compute_units", compute_units)
263 .finish(),
264 }
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Unloaded` handle directly; nothing on disk is touched.
    fn unloaded_at(path: &str, compute_units: ComputeUnits) -> ModelHandle {
        ModelHandle::Unloaded {
            path: PathBuf::from(path),
            compute_units,
        }
    }

    /// The default unloaded handle shared by most tests.
    fn unloaded() -> ModelHandle {
        unloaded_at("/test.mlmodelc", ComputeUnits::All)
    }

    #[test]
    fn unloaded_handle_is_not_loaded() {
        assert!(!unloaded().is_loaded());
    }

    #[test]
    fn unloaded_handle_preserves_path() {
        let handle = unloaded_at("/models/my_model.mlmodelc", ComputeUnits::CpuAndGpu);
        assert_eq!(
            handle.path(),
            std::path::Path::new("/models/my_model.mlmodelc")
        );
    }

    #[test]
    fn unloaded_handle_preserves_compute_units() {
        let handle = unloaded_at("/test.mlmodelc", ComputeUnits::CpuAndNeuralEngine);
        assert_eq!(handle.compute_units(), ComputeUnits::CpuAndNeuralEngine);
    }

    #[test]
    fn unloaded_handle_rejects_model_access() {
        let handle = unloaded();
        let err = handle.model().unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::ModelLoad);
        assert!(err.message().contains("unloaded"));
    }

    #[test]
    fn unloaded_handle_rejects_double_unload() {
        let err = unloaded().unload().unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::ModelLoad);
        assert!(err.message().contains("already unloaded"));
    }

    #[test]
    fn load_nonexistent_model_fails() {
        assert!(ModelHandle::load("/nonexistent.mlmodelc", ComputeUnits::All).is_err());
    }

    #[test]
    fn debug_format_unloaded() {
        let rendered = format!("{:?}", unloaded());
        assert!(rendered.contains("Unloaded"));
        assert!(rendered.contains("/test.mlmodelc"));
    }

    #[test]
    fn unloaded_handle_rejects_inputs() {
        assert!(unloaded().inputs().is_err());
    }

    #[test]
    fn unloaded_handle_rejects_outputs() {
        assert!(unloaded().outputs().is_err());
    }

    #[test]
    fn unloaded_handle_rejects_metadata() {
        assert!(unloaded().metadata().is_err());
    }

    #[cfg(target_vendor = "apple")]
    mod apple_tests {
        use super::*;

        /// Path to the compiled test model, or `None` when the fixture is missing.
        /// Callers return early (and so pass) in that case.
        fn fixture() -> Option<PathBuf> {
            let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
                .join("tests/fixtures/test_linear.mlmodelc");
            if path.exists() {
                Some(path)
            } else {
                None
            }
        }

        #[test]
        fn load_unload_reload_cycle() {
            let model_path = match fixture() {
                Some(path) => path,
                None => return,
            };

            let handle = ModelHandle::load(&model_path, ComputeUnits::All).unwrap();
            assert!(handle.is_loaded());
            assert!(handle.model().is_ok());

            let handle = handle.unload().unwrap();
            assert!(!handle.is_loaded());
            assert!(handle.model().is_err());
            assert_eq!(handle.path(), model_path);

            let handle = handle.reload().unwrap();
            assert!(handle.is_loaded());
            assert!(handle.model().is_ok());
        }

        #[test]
        fn loaded_handle_rejects_double_reload() {
            let model_path = match fixture() {
                Some(path) => path,
                None => return,
            };

            let err = ModelHandle::load(&model_path, ComputeUnits::All)
                .unwrap()
                .reload()
                .unwrap_err();
            assert_eq!(err.kind(), &ErrorKind::ModelLoad);
            assert!(err.message().contains("already loaded"));
        }

        #[test]
        fn from_model_wraps_existing() {
            let model_path = match fixture() {
                Some(path) => path,
                None => return,
            };

            let model = Model::load(&model_path, ComputeUnits::All).unwrap();
            let handle = ModelHandle::from_model(model, ComputeUnits::All);
            assert!(handle.is_loaded());
            assert_eq!(handle.compute_units(), ComputeUnits::All);
        }

        #[test]
        fn debug_format_loaded() {
            let model_path = match fixture() {
                Some(path) => path,
                None => return,
            };

            let handle = ModelHandle::load(&model_path, ComputeUnits::All).unwrap();
            assert!(format!("{:?}", handle).contains("Loaded"));
        }
    }
}