// coreml_native/model_lifecycle.rs
//! Model lifecycle management — unload/reload for memory efficiency.
//!
//! When managing multiple large models on resource-constrained devices,
//! unloading idle models reclaims GPU/ANE memory without losing the
//! ability to quickly reload them.
//!
//! # Design
//!
//! [`ModelHandle`] is a move-based state machine wrapping [`Model`]. State
//! transitions (`unload`, `reload`) consume `self` and return a new handle,
//! so the Rust type system prevents use-after-unload at compile time.
//!
//! ```text
//!   load()       unload()       reload()
//!  --------> Loaded -------> Unloaded -------> Loaded
//!                <-----------          <-----------
//!                  reload()              unload()
//! ```

use std::path::PathBuf;

use crate::error::{Error, ErrorKind, Result};
use crate::{ComputeUnits, Model};

/// A model handle that supports unloading and reloading.
///
/// Wraps a [`Model`] with lifecycle management. Unloading releases the
/// model's GPU/ANE resources (by dropping the inner `MLModel`) while
/// retaining the filesystem path and compute-unit configuration for
/// efficient reloading.
///
/// State transitions consume `self`, so the compiler prevents calling
/// `predict` on an unloaded model or double-unloading: once `unload()`
/// has taken the handle by value, the old binding is unusable.
///
/// # Example
///
/// ```ignore
/// use coreml_native::{ComputeUnits, ModelHandle};
///
/// let handle = ModelHandle::load("model.mlmodelc", ComputeUnits::All)?;
/// let prediction = handle.predict(&[("input", &tensor)])?;
///
/// // Free GPU/ANE memory when the model is idle.
/// let handle = handle.unload()?;
/// assert!(!handle.is_loaded());
///
/// // Reload when needed again.
/// let handle = handle.reload()?;
/// let prediction = handle.predict(&[("input", &tensor)])?;
/// ```
pub enum ModelHandle {
    /// Model is loaded and ready for inference.
    Loaded {
        /// The loaded model instance.
        model: Model,
        /// The compute units used to load this model (preserved for reload).
        compute_units: ComputeUnits,
    },
    /// Model has been unloaded from memory. The path and configuration are
    /// retained so the model can be reloaded without the caller needing to
    /// remember them.
    Unloaded {
        /// Filesystem path to the compiled `.mlmodelc` bundle.
        path: PathBuf,
        /// Compute units to use when reloading.
        compute_units: ComputeUnits,
    },
}

69impl ModelHandle {
70    /// Load a compiled CoreML model and wrap it in a lifecycle handle.
71    ///
72    /// This is equivalent to [`Model::load`] but returns a `ModelHandle`
73    /// that supports later unloading and reloading.
74    ///
75    /// # Errors
76    ///
77    /// Returns an error if the model cannot be loaded (invalid path,
78    /// corrupt model, or non-Apple platform).
79    pub fn load(
80        path: impl AsRef<std::path::Path>,
81        compute_units: ComputeUnits,
82    ) -> Result<Self> {
83        let model = Model::load(&path, compute_units)?;
84        Ok(Self::Loaded {
85            model,
86            compute_units,
87        })
88    }
89
90    /// Wrap an already-loaded [`Model`] in a lifecycle handle.
91    ///
92    /// Use this when you already have a `Model` instance (e.g. loaded via
93    /// [`Model::load_async`]) and want to add lifecycle management.
94    pub fn from_model(model: Model, compute_units: ComputeUnits) -> Self {
95        Self::Loaded {
96            model,
97            compute_units,
98        }
99    }
100
101    /// Returns `true` if the model is currently loaded and ready for
102    /// inference.
103    pub fn is_loaded(&self) -> bool {
104        matches!(self, Self::Loaded { .. })
105    }
106
107    /// Returns the filesystem path this model was (or will be) loaded from.
108    pub fn path(&self) -> &std::path::Path {
109        match self {
110            Self::Loaded { model, .. } => model.path(),
111            Self::Unloaded { path, .. } => path,
112        }
113    }
114
115    /// Returns the compute-unit configuration.
116    pub fn compute_units(&self) -> ComputeUnits {
117        match self {
118            Self::Loaded { compute_units, .. } | Self::Unloaded { compute_units, .. } => {
119                *compute_units
120            }
121        }
122    }
123
124    /// Get a reference to the loaded model.
125    ///
126    /// # Errors
127    ///
128    /// Returns an error if the model is currently unloaded.
129    pub fn model(&self) -> Result<&Model> {
130        match self {
131            Self::Loaded { model, .. } => Ok(model),
132            Self::Unloaded { .. } => Err(Error::new(
133                ErrorKind::ModelLoad,
134                "model is unloaded; call reload() first",
135            )),
136        }
137    }
138
139    /// Unload the model from memory, releasing GPU/ANE resources.
140    ///
141    /// The path and compute-unit configuration are preserved so the model
142    /// can be reloaded later via [`reload`](Self::reload).
143    ///
144    /// This method consumes `self` and returns a new `ModelHandle` in the
145    /// `Unloaded` state. The inner `MLModel` is dropped, which tells CoreML
146    /// to release its GPU and Neural Engine allocations.
147    ///
148    /// # Errors
149    ///
150    /// Returns an error if the model is already unloaded.
151    pub fn unload(self) -> Result<Self> {
152        match self {
153            Self::Loaded {
154                model,
155                compute_units,
156            } => {
157                let path = model.path().to_path_buf();
158                // Dropping the Model releases the Retained<MLModel> and its
159                // associated GPU/ANE resources.
160                drop(model);
161                Ok(Self::Unloaded {
162                    path,
163                    compute_units,
164                })
165            }
166            Self::Unloaded { .. } => Err(Error::new(
167                ErrorKind::ModelLoad,
168                "model is already unloaded",
169            )),
170        }
171    }
172
173    /// Reload a previously unloaded model from its original path.
174    ///
175    /// Uses the same compute-unit configuration that was active when the
176    /// model was first loaded.
177    ///
178    /// # Errors
179    ///
180    /// Returns an error if the model is already loaded, or if reloading
181    /// fails (e.g. the model file was deleted while unloaded).
182    pub fn reload(self) -> Result<Self> {
183        match self {
184            Self::Unloaded {
185                path,
186                compute_units,
187            } => {
188                let model = Model::load(&path, compute_units)?;
189                Ok(Self::Loaded {
190                    model,
191                    compute_units,
192                })
193            }
194            Self::Loaded { .. } => Err(Error::new(
195                ErrorKind::ModelLoad,
196                "model is already loaded",
197            )),
198        }
199    }
200
201    /// Run a prediction on the loaded model.
202    ///
203    /// This is a convenience method that delegates to [`Model::predict`].
204    ///
205    /// # Errors
206    ///
207    /// Returns an error if the model is unloaded, or if prediction fails.
208    pub fn predict(
209        &self,
210        inputs: &[(&str, &dyn crate::tensor::AsMultiArray)],
211    ) -> Result<crate::Prediction> {
212        self.model()?.predict(inputs)
213    }
214
215    /// Get descriptions of all model inputs.
216    ///
217    /// # Errors
218    ///
219    /// Returns an error if the model is unloaded.
220    pub fn inputs(&self) -> Result<Vec<crate::FeatureDescription>> {
221        Ok(self.model()?.inputs())
222    }
223
224    /// Get descriptions of all model outputs.
225    ///
226    /// # Errors
227    ///
228    /// Returns an error if the model is unloaded.
229    pub fn outputs(&self) -> Result<Vec<crate::FeatureDescription>> {
230        Ok(self.model()?.outputs())
231    }
232
233    /// Get model metadata (author, description, version, license).
234    ///
235    /// # Errors
236    ///
237    /// Returns an error if the model is unloaded.
238    pub fn metadata(&self) -> Result<crate::ModelMetadata> {
239        Ok(self.model()?.metadata())
240    }
241}
242
243impl std::fmt::Debug for ModelHandle {
244    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245        match self {
246            Self::Loaded {
247                model,
248                compute_units,
249            } => f
250                .debug_struct("ModelHandle")
251                .field("state", &"Loaded")
252                .field("path", &model.path())
253                .field("compute_units", compute_units)
254                .finish(),
255            Self::Unloaded {
256                path,
257                compute_units,
258            } => f
259                .debug_struct("ModelHandle")
260                .field("state", &"Unloaded")
261                .field("path", path)
262                .field("compute_units", compute_units)
263                .finish(),
264        }
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Construct an `Unloaded` handle without touching the filesystem.
    fn unloaded(path: &str, compute_units: ComputeUnits) -> ModelHandle {
        ModelHandle::Unloaded {
            path: PathBuf::from(path),
            compute_units,
        }
    }

    #[test]
    fn unloaded_handle_is_not_loaded() {
        assert!(!unloaded("/test.mlmodelc", ComputeUnits::All).is_loaded());
    }

    #[test]
    fn unloaded_handle_preserves_path() {
        let handle = unloaded("/models/my_model.mlmodelc", ComputeUnits::CpuAndGpu);
        assert_eq!(
            handle.path(),
            std::path::Path::new("/models/my_model.mlmodelc")
        );
    }

    #[test]
    fn unloaded_handle_preserves_compute_units() {
        let handle = unloaded("/test.mlmodelc", ComputeUnits::CpuAndNeuralEngine);
        assert_eq!(handle.compute_units(), ComputeUnits::CpuAndNeuralEngine);
    }

    #[test]
    fn unloaded_handle_rejects_model_access() {
        let err = unloaded("/test.mlmodelc", ComputeUnits::All)
            .model()
            .unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::ModelLoad);
        assert!(err.message().contains("unloaded"));
    }

    #[test]
    fn unloaded_handle_rejects_double_unload() {
        let err = unloaded("/test.mlmodelc", ComputeUnits::All)
            .unload()
            .unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::ModelLoad);
        assert!(err.message().contains("already unloaded"));
    }

    #[test]
    fn load_nonexistent_model_fails() {
        assert!(ModelHandle::load("/nonexistent.mlmodelc", ComputeUnits::All).is_err());
    }

    #[test]
    fn debug_format_unloaded() {
        let debug = format!("{:?}", unloaded("/test.mlmodelc", ComputeUnits::All));
        assert!(debug.contains("Unloaded"));
        assert!(debug.contains("/test.mlmodelc"));
    }

    #[test]
    fn unloaded_handle_rejects_inputs() {
        assert!(unloaded("/test.mlmodelc", ComputeUnits::All).inputs().is_err());
    }

    #[test]
    fn unloaded_handle_rejects_outputs() {
        assert!(unloaded("/test.mlmodelc", ComputeUnits::All).outputs().is_err());
    }

    #[test]
    fn unloaded_handle_rejects_metadata() {
        assert!(unloaded("/test.mlmodelc", ComputeUnits::All).metadata().is_err());
    }

    /// Tests that require a real CoreML runtime and the compiled fixture.
    #[cfg(target_vendor = "apple")]
    mod apple_tests {
        use super::*;

        /// Path to the compiled fixture model, or `None` when the fixture
        /// has not been generated in this checkout (tests then skip).
        fn fixture_path() -> Option<std::path::PathBuf> {
            let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
                .join("tests/fixtures/test_linear.mlmodelc");
            if path.exists() {
                Some(path)
            } else {
                None
            }
        }

        #[test]
        fn load_unload_reload_cycle() {
            let model_path = match fixture_path() {
                Some(p) => p,
                None => return, // fixture unavailable; skip
            };

            let handle = ModelHandle::load(&model_path, ComputeUnits::All).unwrap();
            assert!(handle.is_loaded());
            assert!(handle.model().is_ok());

            // Unload releases GPU/ANE resources.
            let handle = handle.unload().unwrap();
            assert!(!handle.is_loaded());
            assert!(handle.model().is_err());
            assert_eq!(handle.path(), model_path);

            // Reload brings the model back.
            let handle = handle.reload().unwrap();
            assert!(handle.is_loaded());
            assert!(handle.model().is_ok());
        }

        #[test]
        fn loaded_handle_rejects_double_reload() {
            let model_path = match fixture_path() {
                Some(p) => p,
                None => return, // fixture unavailable; skip
            };

            let handle = ModelHandle::load(&model_path, ComputeUnits::All).unwrap();
            let err = handle.reload().unwrap_err();
            assert_eq!(err.kind(), &ErrorKind::ModelLoad);
            assert!(err.message().contains("already loaded"));
        }

        #[test]
        fn from_model_wraps_existing() {
            let model_path = match fixture_path() {
                Some(p) => p,
                None => return, // fixture unavailable; skip
            };

            let model = Model::load(&model_path, ComputeUnits::All).unwrap();
            let handle = ModelHandle::from_model(model, ComputeUnits::All);
            assert!(handle.is_loaded());
            assert_eq!(handle.compute_units(), ComputeUnits::All);
        }

        #[test]
        fn debug_format_loaded() {
            let model_path = match fixture_path() {
                Some(p) => p,
                None => return, // fixture unavailable; skip
            };

            let handle = ModelHandle::load(&model_path, ComputeUnits::All).unwrap();
            assert!(format!("{:?}", handle).contains("Loaded"));
        }
    }
}