// offline_intelligence/model_runtime/coreml_runtime.rs

use async_trait::async_trait;
use super::runtime_trait::*;
use std::process::{Child, Command, Stdio};
use std::time::Duration;
use tracing::{info, warn};
use tokio::time::sleep;

/// Runtime backend that spawns an external CoreML inference server process
/// (macOS only) and communicates with it over HTTP.
pub struct CoreMLRuntime {
    /// Config captured at `initialize` time; `None` until then.
    config: Option<RuntimeConfig>,
    /// Handle to the spawned server process, killed/reaped on shutdown and drop.
    server_process: Option<Child>,
    /// Shared HTTP client used for health checks and inference calls.
    http_client: reqwest::Client,
    /// "http://{host}:{port}" once initialized; empty string means "not started".
    base_url: String,
}
17
18impl CoreMLRuntime {
19 pub fn new() -> Self {
20 Self {
21 config: None,
22 server_process: None,
23 http_client: reqwest::Client::builder()
24 .timeout(Duration::from_secs(600))
25 .build()
26 .unwrap_or_default(),
27 base_url: String::new(),
28 }
29 }
30
31 async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
32 #[cfg(not(target_os = "macos"))]
33 {
34 return Err(anyhow::anyhow!("CoreML runtime is only supported on macOS"));
35 }
36
37 #[cfg(target_os = "macos")]
38 {
39 let binary_path = config.runtime_binary.as_ref()
40 .ok_or_else(|| anyhow::anyhow!("CoreML runtime requires runtime_binary path"))?;
41
42 if !binary_path.exists() {
43 return Err(anyhow::anyhow!(
44 "CoreML server binary not found at: {}",
45 binary_path.display()
46 ));
47 }
48
49 info!("Starting CoreML server for model: {}", config.model_path.display());
50
51 let mut cmd = Command::new(binary_path);
52 cmd.arg("--model").arg(&config.model_path)
53 .arg("--host").arg(&config.host)
54 .arg("--port").arg(config.port.to_string())
55 .stdout(Stdio::piped())
56 .stderr(Stdio::piped());
57
58 let child = cmd.spawn()
59 .map_err(|e| anyhow::anyhow!("Failed to spawn CoreML server: {}", e))?;
60
61 self.server_process = Some(child);
62 self.base_url = format!("http://{}:{}", config.host, config.port);
63
64 for attempt in 1..=15 {
65 sleep(Duration::from_secs(2)).await;
66 if self.is_ready().await {
67 info!("✅ CoreML runtime ready after {} seconds", attempt * 2);
68 return Ok(());
69 }
70 }
71
72 Err(anyhow::anyhow!("CoreML server failed to start within 30 seconds"))
73 }
74 }
75}
76
77impl Default for CoreMLRuntime {
78 fn default() -> Self {
79 Self::new()
80 }
81}
82
83#[async_trait]
84impl ModelRuntime for CoreMLRuntime {
85 fn supported_format(&self) -> ModelFormat {
86 ModelFormat::CoreML
87 }
88
89 async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
90 info!("Initializing CoreML runtime");
91
92 if config.format != ModelFormat::CoreML {
93 return Err(anyhow::anyhow!("CoreML runtime received wrong format: {:?}", config.format));
94 }
95
96 self.config = Some(config.clone());
97 self.start_server(&config).await?;
98 Ok(())
99 }
100
101 async fn is_ready(&self) -> bool {
102 if self.base_url.is_empty() {
103 return false;
104 }
105
106 let health_url = format!("{}/health", self.base_url);
107 match self.http_client.get(&health_url).send().await {
108 Ok(resp) => resp.status().is_success(),
109 Err(_) => false,
110 }
111 }
112
113 async fn health_check(&self) -> anyhow::Result<String> {
114 if self.base_url.is_empty() {
115 return Err(anyhow::anyhow!("Runtime not initialized"));
116 }
117
118 let health_url = format!("{}/health", self.base_url);
119 let resp = self.http_client.get(&health_url).send().await
120 .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
121
122 if resp.status().is_success() {
123 Ok("healthy".to_string())
124 } else {
125 Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
126 }
127 }
128
129 fn base_url(&self) -> String {
130 self.base_url.clone()
131 }
132
133 async fn generate(&self, request: InferenceRequest) -> anyhow::Result<InferenceResponse> {
134 let url = self.completions_url();
135
136 let payload = serde_json::json!({
137 "model": "coreml-llm",
138 "messages": request.messages,
139 "max_tokens": request.max_tokens,
140 "temperature": request.temperature,
141 "stream": false,
142 });
143
144 let resp = self.http_client.post(&url).json(&payload).send().await
145 .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
146
147 if !resp.status().is_success() {
148 let status = resp.status();
149 let body = resp.text().await.unwrap_or_default();
150 return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
151 }
152
153 let response: serde_json::Value = resp.json().await
154 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
155
156 let content = response["choices"][0]["message"]["content"].as_str().unwrap_or("").to_string();
157 let finish_reason = response["choices"][0]["finish_reason"].as_str().map(|s| s.to_string());
158
159 Ok(InferenceResponse { content, finish_reason })
160 }
161
162 async fn generate_stream(
163 &self,
164 request: InferenceRequest,
165 ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
166 use futures_util::StreamExt;
167
168 let url = self.completions_url();
169 let payload = serde_json::json!({
170 "model": "coreml-llm",
171 "messages": request.messages,
172 "max_tokens": request.max_tokens,
173 "temperature": request.temperature,
174 "stream": true,
175 });
176
177 let resp = self.http_client.post(&url).json(&payload).send().await
178 .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
179
180 if !resp.status().is_success() {
181 let status = resp.status();
182 let body = resp.text().await.unwrap_or_default();
183 return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
184 }
185
186 let byte_stream = resp.bytes_stream();
187 let sse_stream = async_stream::try_stream! {
188 let mut buffer = String::new();
189 futures_util::pin_mut!(byte_stream);
190
191 while let Some(chunk_result) = byte_stream.next().await {
192 let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
193 buffer.push_str(&String::from_utf8_lossy(&chunk));
194
195 while let Some(newline_pos) = buffer.find('\n') {
196 let line = buffer[..newline_pos].trim().to_string();
197 buffer = buffer[newline_pos + 1..].to_string();
198
199 if line.is_empty() || !line.starts_with("data: ") {
200 continue;
201 }
202
203 let data = &line[6..];
204 if data == "[DONE]" {
205 return;
206 }
207
208 yield format!("data: {}\n\n", data);
209 }
210 }
211 };
212
213 Ok(Box::new(Box::pin(sse_stream)))
214 }
215
216 async fn shutdown(&mut self) -> anyhow::Result<()> {
217 info!("Shutting down CoreML runtime");
218
219 if let Some(mut child) = self.server_process.take() {
220 match child.kill() {
221 Ok(_) => {
222 info!("CoreML server process killed successfully");
223 let _ = child.wait();
224 }
225 Err(e) => {
226 warn!("Failed to kill CoreML server: {}", e);
227 }
228 }
229 }
230
231 self.config = None;
232 self.base_url.clear();
233 Ok(())
234 }
235
236 fn metadata(&self) -> RuntimeMetadata {
237 RuntimeMetadata {
238 format: ModelFormat::CoreML,
239 runtime_name: "CoreML (Apple)".to_string(),
240 version: "latest".to_string(),
241 supports_gpu: true,
242 supports_streaming: true,
243 }
244 }
245}
246
247impl Drop for CoreMLRuntime {
248 fn drop(&mut self) {
249 if let Some(mut child) = self.server_process.take() {
250 let _ = child.kill();
251 let _ = child.wait();
252 }
253 }
254}