Skip to main content

coreml_native/
compile.rs

1//! Model compilation (.mlmodel/.mlpackage -> .mlmodelc).
2//!
3//! Covers FR-7.1, FR-7.2.
4
5use crate::error::{Error, ErrorKind, Result};
6use std::path::{Path, PathBuf};
7
/// Compile a `.mlmodel` or `.mlpackage` into a `.mlmodelc` directory.
///
/// This goes through CoreML's synchronous compile entry point, so the calling
/// thread waits for the compiled model's URL before this function returns.
///
/// CoreML writes the compiled model into a temporary directory of its own
/// choosing; copy the returned directory to a permanent location if you need
/// to keep it.
///
/// # Errors
///
/// Returns an error of kind [`ErrorKind::ModelLoad`] when `source` is not
/// valid UTF-8, when CoreML reports a compilation failure, or when the URL
/// CoreML hands back has no filesystem path.
#[cfg(target_vendor = "apple")]
#[allow(deprecated)] // CoreML deprecated the synchronous compile call; kept to offer a blocking API
pub fn compile_model(source: impl AsRef<Path>) -> Result<PathBuf> {
    use objc2_core_ml::MLModel;

    let Some(path_str) = source.as_ref().to_str() else {
        return Err(Error::new(
            ErrorKind::ModelLoad,
            "source path contains non-UTF8 characters",
        ));
    };

    let ns_path = crate::ffi::str_to_nsstring(path_str);
    let source_url = objc2_foundation::NSURL::fileURLWithPath(&ns_path);

    // SAFETY: `source_url` is a live, retained NSURL for the whole call, and
    // the objc2 binding surfaces the NSError out-parameter as the `Err` variant.
    let output_url = match unsafe { MLModel::compileModelAtURL_error(&source_url) } {
        Ok(compiled) => compiled,
        Err(ns_err) => return Err(Error::from_nserror(ErrorKind::ModelLoad, &ns_err)),
    };

    output_url
        .path()
        .map(|ns_str| PathBuf::from(ns_str.to_string()))
        .ok_or_else(|| Error::new(ErrorKind::ModelLoad, "compiled URL has no path"))
}
35
36#[cfg(not(target_vendor = "apple"))]
37pub fn compile_model(_source: impl AsRef<Path>) -> Result<PathBuf> {
38    Err(Error::new(
39        ErrorKind::UnsupportedPlatform,
40        "CoreML requires Apple platform",
41    ))
42}
43
/// Compile a `.mlmodel` or `.mlpackage` asynchronously.
///
/// Returns a [`CompletionFuture`](crate::async_bridge::CompletionFuture) that
/// resolves to the path of the compiled `.mlmodelc` directory once CoreML
/// invokes its completion handler.
///
/// The compiled model is placed in a temporary directory by CoreML;
/// you should copy it to a permanent location.
///
/// Requires macOS 14.4+ / iOS 17.4+.
///
/// # Errors
///
/// Returns an error of kind [`ErrorKind::ModelLoad`] immediately if `source`
/// is not valid UTF-8. Compilation failures are delivered through the returned
/// future, also as [`ErrorKind::ModelLoad`].
#[cfg(target_vendor = "apple")]
pub fn compile_model_async(
    source: impl AsRef<Path>,
) -> Result<crate::async_bridge::CompletionFuture<PathBuf>> {
    use objc2_core_ml::MLModel;

    let source = source.as_ref();
    let source_str = source.to_str().ok_or_else(|| {
        Error::new(ErrorKind::ModelLoad, "source path contains non-UTF8 characters")
    })?;

    let url = objc2_foundation::NSURL::fileURLWithPath(&crate::ffi::str_to_nsstring(source_str));

    let (sender, future) = crate::async_bridge::completion_channel();
    // The block closure must be callable through a shared reference, but
    // `send` consumes the sender, so it is parked in a `Cell` and taken out
    // on the first invocation.
    let sender_cell = std::cell::Cell::new(Some(sender));

    let block = block2::RcBlock::new(
        move |compiled_url: *mut objc2_foundation::NSURL,
              error: *mut objc2_foundation::NSError| {
            // Never panic here: this closure is invoked from Objective-C, and
            // unwinding out of it would cross the FFI boundary. A repeated
            // invocation has nowhere to deliver a result, so it is ignored.
            let Some(sender) = sender_cell.take() else {
                return;
            };

            let outcome = if !compiled_url.is_null() {
                // SAFETY: the pointer is non-null and refers to the NSURL that
                // CoreML passes to this handler; it is only borrowed for the
                // duration of this call.
                let compiled = unsafe { &*compiled_url };
                compiled
                    .path()
                    .map(|p| PathBuf::from(p.to_string()))
                    .ok_or_else(|| Error::new(ErrorKind::ModelLoad, "compiled URL has no path"))
            } else if !error.is_null() {
                // SAFETY: the pointer is non-null and refers to the NSError that
                // CoreML passes to this handler; it is only borrowed here.
                let ns_error = unsafe { &*error };
                Err(Error::from_nserror(ErrorKind::ModelLoad, ns_error))
            } else {
                Err(Error::new(
                    ErrorKind::ModelLoad,
                    "compile returned null with no error",
                ))
            };

            sender.send(outcome);
        },
    );

    // SAFETY: `url` and `block` are valid for the duration of the call. Per
    // Objective-C convention, the callee copies an escaping completion block,
    // so dropping our `RcBlock` reference afterwards is fine.
    unsafe {
        MLModel::compileModelAtURL_completionHandler(&url, &block);
    }

    Ok(future)
}
106
107/// Compile a `.mlmodel` or `.mlpackage` asynchronously (stub for non-Apple platforms).
108#[cfg(not(target_vendor = "apple"))]
109pub fn compile_model_async(
110    _source: impl AsRef<Path>,
111) -> Result<crate::async_bridge::CompletionFuture<PathBuf>> {
112    Err(Error::new(
113        ErrorKind::UnsupportedPlatform,
114        "CoreML requires Apple platform",
115    ))
116}
117
#[cfg(test)]
mod tests {
    /// On non-Apple targets, both entry points must fail with
    /// `UnsupportedPlatform` rather than attempting any work.
    #[cfg(not(target_vendor = "apple"))]
    mod non_apple {
        use super::super::{compile_model, compile_model_async};
        use crate::ErrorKind;

        /// Any path works: the stubs never touch the filesystem.
        const MODEL_PATH: &str = "/tmp/model.mlmodel";

        #[test]
        fn compile_fails_on_non_apple() {
            let failure = compile_model(MODEL_PATH).unwrap_err();
            assert_eq!(failure.kind(), &ErrorKind::UnsupportedPlatform);
        }

        #[test]
        fn compile_async_fails_on_non_apple() {
            let failure = compile_model_async(MODEL_PATH).unwrap_err();
            assert_eq!(failure.kind(), &ErrorKind::UnsupportedPlatform);
        }
    }
}