1use crate::error::{Error, ErrorKind, Result};
6use std::path::{Path, PathBuf};
7
/// Synchronously compiles the Core ML model at `source` and returns the
/// filesystem path of the compiled model reported by Core ML.
///
/// `source` must be valid UTF-8 because it is converted to an `NSString`
/// before being turned into a file URL.
///
/// NOTE(review): Apple documents the compiled output as living in a temporary
/// location; callers that need it to persist should move it — confirm for the
/// targeted OS versions.
///
/// # Errors
///
/// Returns an error of kind [`ErrorKind::ModelLoad`] if `source` is not valid
/// UTF-8, if Core ML reports a compilation failure, or if the URL returned by
/// Core ML has no filesystem path.
#[cfg(target_vendor = "apple")]
pub fn compile_model(source: impl AsRef<Path>) -> Result<PathBuf> {
    use objc2_core_ml::MLModel;

    let source = source.as_ref();
    let source_str = source.to_str().ok_or_else(|| {
        Error::new(ErrorKind::ModelLoad, "source path contains non-UTF8 characters")
    })?;

    let url = objc2_foundation::NSURL::fileURLWithPath(&crate::ffi::str_to_nsstring(source_str));

    // The synchronous `compileModelAtURL_error` binding is marked deprecated
    // (the completion-handler variant used by `compile_model_async` is the
    // recommended replacement). The lint is silenced for this call only, so any
    // other deprecated API introduced in this function still warns.
    // SAFETY: `url` is a valid, retained `NSURL` that outlives the call, and the
    // method has no other preconditions on its arguments.
    #[allow(deprecated)]
    let compiled_url = unsafe { MLModel::compileModelAtURL_error(&url) }
        .map_err(|e| Error::from_nserror(ErrorKind::ModelLoad, &e))?;

    let compiled_path = compiled_url
        .path()
        .ok_or_else(|| Error::new(ErrorKind::ModelLoad, "compiled URL has no path"))?;

    Ok(PathBuf::from(compiled_path.to_string()))
}
35
36#[cfg(not(target_vendor = "apple"))]
37pub fn compile_model(_source: impl AsRef<Path>) -> Result<PathBuf> {
38 Err(Error::new(
39 ErrorKind::UnsupportedPlatform,
40 "CoreML requires Apple platform",
41 ))
42}
43
/// Starts compiling the Core ML model at `source` and returns a future that
/// resolves to the filesystem path of the compiled model.
///
/// Compilation is handed to Core ML's completion-handler API; the result is
/// forwarded through a channel from `crate::async_bridge`.
///
/// # Errors
///
/// Returns an error of kind [`ErrorKind::ModelLoad`] immediately if `source`
/// is not valid UTF-8. The returned future resolves to an error of kind
/// [`ErrorKind::ModelLoad`] if Core ML reports a failure, returns neither a
/// URL nor an error, or returns a URL without a filesystem path.
#[cfg(target_vendor = "apple")]
pub fn compile_model_async(
    source: impl AsRef<Path>,
) -> Result<crate::async_bridge::CompletionFuture<PathBuf>> {
    use objc2_core_ml::MLModel;

    let source = source.as_ref();
    let source_str = source.to_str().ok_or_else(|| {
        Error::new(ErrorKind::ModelLoad, "source path contains non-UTF8 characters")
    })?;

    let url = objc2_foundation::NSURL::fileURLWithPath(&crate::ffi::str_to_nsstring(source_str));

    let (sender, future) = crate::async_bridge::completion_channel();
    // The block closure is `Fn`, but the sender must be consumed exactly once;
    // the `Cell<Option<_>>` lets the first invocation move it out.
    // NOTE(review): Core ML may run the handler on another queue; this relies
    // on the async_bridge sender being safe to use from that thread — confirm.
    let sender_cell = std::cell::Cell::new(Some(sender));

    let block = block2::RcBlock::new(
        move |compiled_url: *mut objc2_foundation::NSURL,
              error: *mut objc2_foundation::NSError| {
            // This runs inside an Objective-C callback, so panicking here would
            // unwind across the FFI boundary (a process abort on current Rust).
            // Core ML is expected to call the handler once; if it ever calls it
            // again, the result has already been delivered, so ignore the call.
            let sender = match sender_cell.take() {
                Some(sender) => sender,
                None => return,
            };

            let result = if !compiled_url.is_null() {
                // SAFETY: the pointer is non-null and is provided by Core ML
                // for the duration of this callback.
                let url = unsafe { &*compiled_url };
                url.path()
                    .map(|p| PathBuf::from(p.to_string()))
                    .ok_or_else(|| Error::new(ErrorKind::ModelLoad, "compiled URL has no path"))
            } else if !error.is_null() {
                // SAFETY: the pointer is non-null and is provided by Core ML
                // for the duration of this callback.
                let err = unsafe { &*error };
                Err(Error::from_nserror(ErrorKind::ModelLoad, err))
            } else {
                Err(Error::new(
                    ErrorKind::ModelLoad,
                    "compile returned null with no error",
                ))
            };

            sender.send(result);
        },
    );

    // SAFETY: `url` is a valid, retained `NSURL`. Per Objective-C convention,
    // an API that invokes a completion block later copies/retains it, so
    // dropping our `RcBlock` handle when this function returns is fine.
    unsafe {
        MLModel::compileModelAtURL_completionHandler(&url, &block);
    }

    Ok(future)
}
106
107#[cfg(not(target_vendor = "apple"))]
109pub fn compile_model_async(
110 _source: impl AsRef<Path>,
111) -> Result<crate::async_bridge::CompletionFuture<PathBuf>> {
112 Err(Error::new(
113 ErrorKind::UnsupportedPlatform,
114 "CoreML requires Apple platform",
115 ))
116}
117
#[cfg(test)]
mod tests {
    #[cfg(not(target_vendor = "apple"))]
    use crate::ErrorKind;

    /// Path handed to the non-Apple stubs; it never has to exist because those
    /// stubs return an error without looking at their argument.
    #[cfg(not(target_vendor = "apple"))]
    const MODEL_PATH: &str = "/tmp/model.mlmodel";

    /// The synchronous stub must report that the platform is unsupported.
    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn compile_fails_on_non_apple() {
        let outcome = super::compile_model(MODEL_PATH);
        let failure = outcome.unwrap_err();
        assert_eq!(*failure.kind(), ErrorKind::UnsupportedPlatform);
    }

    /// The asynchronous stub must fail up front rather than return a future.
    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn compile_async_fails_on_non_apple() {
        let outcome = super::compile_model_async(MODEL_PATH);
        let failure = outcome.unwrap_err();
        assert_eq!(*failure.kind(), ErrorKind::UnsupportedPlatform);
    }
}