Skip to main content

offline_intelligence/model_runtime/
coreml_runtime.rs

1//! CoreML Runtime Adapter (macOS only)
2//! For Apple Silicon optimized models
3
4#[allow(unused_imports)]
5use async_trait::async_trait;
6use super::runtime_trait::*;
7use std::process::Child;
8#[cfg(target_os = "macos")]
9use std::process::{Command, Stdio};
10use std::time::Duration;
11use tracing::{info, warn};
12#[cfg(target_os = "macos")]
13use tokio::time::sleep;
14
/// Runtime adapter that manages an external CoreML inference server
/// process (macOS only) and communicates with it over HTTP.
pub struct CoreMLRuntime {
    // Config captured at initialize(); None until then.
    config: Option<RuntimeConfig>,
    // Handle to the spawned server process, if one is running.
    server_process: Option<Child>,
    // Shared HTTP client used for health checks and inference requests.
    http_client: reqwest::Client,
    // Base URL of the running server, e.g. "http://127.0.0.1:8080";
    // empty string until start_server() succeeds in spawning.
    base_url: String,
}
21
22impl CoreMLRuntime {
23    pub fn new() -> Self {
24        Self {
25            config: None,
26            server_process: None,
27            http_client: reqwest::Client::builder()
28                .timeout(Duration::from_secs(600))
29                .build()
30                .unwrap_or_default(),
31            base_url: String::new(),
32        }
33    }
34
35    #[allow(unused_variables)]
36    async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
37        #[cfg(not(target_os = "macos"))]
38        {
39            return Err(anyhow::anyhow!("CoreML runtime is only supported on macOS"));
40        }
41
42        #[cfg(target_os = "macos")]
43        {
44            let binary_path = config.runtime_binary.as_ref()
45                .ok_or_else(|| anyhow::anyhow!("CoreML runtime requires runtime_binary path"))?;
46
47            if !binary_path.exists() {
48                return Err(anyhow::anyhow!(
49                    "CoreML server binary not found at: {}",
50                    binary_path.display()
51                ));
52            }
53
54            info!("Starting CoreML server for model: {}", config.model_path.display());
55            
56            let mut cmd = Command::new(binary_path);
57            cmd.arg("--model").arg(&config.model_path)
58                .arg("--host").arg(&config.host)
59                .arg("--port").arg(config.port.to_string())
60                .stdout(Stdio::piped())
61                .stderr(Stdio::piped());
62
63            let child = cmd.spawn()
64                .map_err(|e| anyhow::anyhow!("Failed to spawn CoreML server: {}", e))?;
65
66            self.server_process = Some(child);
67            self.base_url = format!("http://{}:{}", config.host, config.port);
68
69            // Wait for server to be ready (up to 120 seconds) with exponential backoff.
70            let _start = std::time::Instant::now();
71            let mut delay_ms: u64 = 100;
72            let mut last_log_secs: u64 = 0;
73            loop {
74                sleep(Duration::from_millis(delay_ms)).await;
75                if self.is_ready().await {
76                    info!("✅ CoreML runtime ready after {:.1}s", _start.elapsed().as_secs_f64());
77                    return Ok(());
78                }
79                let elapsed_secs = _start.elapsed().as_secs();
80                if elapsed_secs >= 120 {
81                    break;
82                }
83                if elapsed_secs >= last_log_secs + 10 {
84                    info!("Still waiting for CoreML server... ({}/120s)", elapsed_secs);
85                    last_log_secs = elapsed_secs;
86                }
87                delay_ms = (delay_ms * 2).min(2_000);
88            }
89
90            Err(anyhow::anyhow!("CoreML server failed to start within 120 seconds"))
91        }
92    }
93}
94
impl Default for CoreMLRuntime {
    /// Equivalent to [`CoreMLRuntime::new`]: an uninitialized runtime
    /// with no server process and an empty base URL.
    fn default() -> Self {
        Self::new()
    }
}
100
101#[async_trait]
102impl ModelRuntime for CoreMLRuntime {
103    fn supported_format(&self) -> ModelFormat {
104        ModelFormat::CoreML
105    }
106
107    async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
108        info!("Initializing CoreML runtime");
109        
110        if config.format != ModelFormat::CoreML {
111            return Err(anyhow::anyhow!("CoreML runtime received wrong format: {:?}", config.format));
112        }
113
114        self.config = Some(config.clone());
115        self.start_server(&config).await?;
116        Ok(())
117    }
118
119    async fn is_ready(&self) -> bool {
120        if self.base_url.is_empty() {
121            return false;
122        }
123
124        let health_url = format!("{}/health", self.base_url);
125        match self.http_client.get(&health_url).send().await {
126            Ok(resp) => resp.status().is_success(),
127            Err(_) => false,
128        }
129    }
130
131    async fn health_check(&self) -> anyhow::Result<String> {
132        if self.base_url.is_empty() {
133            return Err(anyhow::anyhow!("Runtime not initialized"));
134        }
135
136        let health_url = format!("{}/health", self.base_url);
137        let resp = self.http_client.get(&health_url).send().await
138            .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
139
140        if resp.status().is_success() {
141            Ok("healthy".to_string())
142        } else {
143            Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
144        }
145    }
146
147    fn base_url(&self) -> String {
148        self.base_url.clone()
149    }
150
151    async fn generate(&self, request: InferenceRequest) -> anyhow::Result<InferenceResponse> {
152        let url = self.completions_url();
153        
154        let payload = serde_json::json!({
155            "model": "coreml-llm",
156            "messages": request.messages,
157            "max_tokens": request.max_tokens,
158            "temperature": request.temperature,
159            "stream": false,
160        });
161
162        let resp = self.http_client.post(&url).json(&payload).send().await
163            .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
164
165        if !resp.status().is_success() {
166            let status = resp.status();
167            let body = resp.text().await.unwrap_or_default();
168            return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
169        }
170
171        let response: serde_json::Value = resp.json().await
172            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
173
174        let content = response["choices"][0]["message"]["content"].as_str().unwrap_or("").to_string();
175        let finish_reason = response["choices"][0]["finish_reason"].as_str().map(|s| s.to_string());
176
177        Ok(InferenceResponse { content, finish_reason })
178    }
179
180    async fn generate_stream(
181        &self,
182        request: InferenceRequest,
183    ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
184        use futures_util::StreamExt;
185        
186        let url = self.completions_url();
187        let payload = serde_json::json!({
188            "model": "coreml-llm",
189            "messages": request.messages,
190            "max_tokens": request.max_tokens,
191            "temperature": request.temperature,
192            "stream": true,
193        });
194
195        let resp = self.http_client.post(&url).json(&payload).send().await
196            .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
197
198        if !resp.status().is_success() {
199            let status = resp.status();
200            let body = resp.text().await.unwrap_or_default();
201            return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
202        }
203
204        let byte_stream = resp.bytes_stream();
205        let sse_stream = async_stream::try_stream! {
206            let mut buffer = String::new();
207            futures_util::pin_mut!(byte_stream);
208
209            while let Some(chunk_result) = byte_stream.next().await {
210                let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
211                buffer.push_str(&String::from_utf8_lossy(&chunk));
212
213                while let Some(newline_pos) = buffer.find('\n') {
214                    let line = buffer[..newline_pos].trim().to_string();
215                    buffer = buffer[newline_pos + 1..].to_string();
216
217                    if line.is_empty() || !line.starts_with("data: ") {
218                        continue;
219                    }
220
221                    let data = &line[6..];
222                    if data == "[DONE]" {
223                        return;
224                    }
225
226                    yield format!("data: {}\n\n", data);
227                }
228            }
229        };
230
231        Ok(Box::new(Box::pin(sse_stream)))
232    }
233
234    async fn shutdown(&mut self) -> anyhow::Result<()> {
235        info!("Shutting down CoreML runtime");
236        
237        if let Some(mut child) = self.server_process.take() {
238            match child.kill() {
239                Ok(_) => {
240                    info!("CoreML server process killed successfully");
241                    let _ = child.wait();
242                }
243                Err(e) => {
244                    warn!("Failed to kill CoreML server: {}", e);
245                }
246            }
247        }
248
249        self.config = None;
250        self.base_url.clear();
251        Ok(())
252    }
253
254    fn metadata(&self) -> RuntimeMetadata {
255        RuntimeMetadata {
256            format: ModelFormat::CoreML,
257            runtime_name: "CoreML (Apple)".to_string(),
258            version: "latest".to_string(),
259            supports_gpu: true,
260            supports_streaming: true,
261        }
262    }
263}
264
265impl Drop for CoreMLRuntime {
266    fn drop(&mut self) {
267        if let Some(mut child) = self.server_process.take() {
268            let _ = child.kill();
269            let _ = child.wait();
270        }
271    }
272}