Skip to main content

offline_intelligence/model_runtime/
coreml_runtime.rs

1//! CoreML Runtime Adapter (macOS only)
2//! For Apple Silicon optimized models
3
4use async_trait::async_trait;
5use super::runtime_trait::*;
6use std::process::{Child, Command, Stdio};
7use std::time::Duration;
8use tracing::{info, warn};
9use tokio::time::sleep;
10
11pub struct CoreMLRuntime {
12    config: Option<RuntimeConfig>,
13    server_process: Option<Child>,
14    http_client: reqwest::Client,
15    base_url: String,
16}
17
18impl CoreMLRuntime {
19    pub fn new() -> Self {
20        Self {
21            config: None,
22            server_process: None,
23            http_client: reqwest::Client::builder()
24                .timeout(Duration::from_secs(600))
25                .build()
26                .unwrap_or_default(),
27            base_url: String::new(),
28        }
29    }
30
31    async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
32        #[cfg(not(target_os = "macos"))]
33        {
34            return Err(anyhow::anyhow!("CoreML runtime is only supported on macOS"));
35        }
36
37        #[cfg(target_os = "macos")]
38        {
39            let binary_path = config.runtime_binary.as_ref()
40                .ok_or_else(|| anyhow::anyhow!("CoreML runtime requires runtime_binary path"))?;
41
42            if !binary_path.exists() {
43                return Err(anyhow::anyhow!(
44                    "CoreML server binary not found at: {}",
45                    binary_path.display()
46                ));
47            }
48
49            info!("Starting CoreML server for model: {}", config.model_path.display());
50            
51            let mut cmd = Command::new(binary_path);
52            cmd.arg("--model").arg(&config.model_path)
53                .arg("--host").arg(&config.host)
54                .arg("--port").arg(config.port.to_string())
55                .stdout(Stdio::piped())
56                .stderr(Stdio::piped());
57
58            let child = cmd.spawn()
59                .map_err(|e| anyhow::anyhow!("Failed to spawn CoreML server: {}", e))?;
60
61            self.server_process = Some(child);
62            self.base_url = format!("http://{}:{}", config.host, config.port);
63
64            for attempt in 1..=15 {
65                sleep(Duration::from_secs(2)).await;
66                if self.is_ready().await {
67                    info!("✅ CoreML runtime ready after {} seconds", attempt * 2);
68                    return Ok(());
69                }
70            }
71
72            Err(anyhow::anyhow!("CoreML server failed to start within 30 seconds"))
73        }
74    }
75}
76
77impl Default for CoreMLRuntime {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
83#[async_trait]
84impl ModelRuntime for CoreMLRuntime {
85    fn supported_format(&self) -> ModelFormat {
86        ModelFormat::CoreML
87    }
88
89    async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
90        info!("Initializing CoreML runtime");
91        
92        if config.format != ModelFormat::CoreML {
93            return Err(anyhow::anyhow!("CoreML runtime received wrong format: {:?}", config.format));
94        }
95
96        self.config = Some(config.clone());
97        self.start_server(&config).await?;
98        Ok(())
99    }
100
101    async fn is_ready(&self) -> bool {
102        if self.base_url.is_empty() {
103            return false;
104        }
105
106        let health_url = format!("{}/health", self.base_url);
107        match self.http_client.get(&health_url).send().await {
108            Ok(resp) => resp.status().is_success(),
109            Err(_) => false,
110        }
111    }
112
113    async fn health_check(&self) -> anyhow::Result<String> {
114        if self.base_url.is_empty() {
115            return Err(anyhow::anyhow!("Runtime not initialized"));
116        }
117
118        let health_url = format!("{}/health", self.base_url);
119        let resp = self.http_client.get(&health_url).send().await
120            .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
121
122        if resp.status().is_success() {
123            Ok("healthy".to_string())
124        } else {
125            Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
126        }
127    }
128
129    fn base_url(&self) -> String {
130        self.base_url.clone()
131    }
132
133    async fn generate(&self, request: InferenceRequest) -> anyhow::Result<InferenceResponse> {
134        let url = self.completions_url();
135        
136        let payload = serde_json::json!({
137            "model": "coreml-llm",
138            "messages": request.messages,
139            "max_tokens": request.max_tokens,
140            "temperature": request.temperature,
141            "stream": false,
142        });
143
144        let resp = self.http_client.post(&url).json(&payload).send().await
145            .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
146
147        if !resp.status().is_success() {
148            let status = resp.status();
149            let body = resp.text().await.unwrap_or_default();
150            return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
151        }
152
153        let response: serde_json::Value = resp.json().await
154            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
155
156        let content = response["choices"][0]["message"]["content"].as_str().unwrap_or("").to_string();
157        let finish_reason = response["choices"][0]["finish_reason"].as_str().map(|s| s.to_string());
158
159        Ok(InferenceResponse { content, finish_reason })
160    }
161
162    async fn generate_stream(
163        &self,
164        request: InferenceRequest,
165    ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
166        use futures_util::StreamExt;
167        
168        let url = self.completions_url();
169        let payload = serde_json::json!({
170            "model": "coreml-llm",
171            "messages": request.messages,
172            "max_tokens": request.max_tokens,
173            "temperature": request.temperature,
174            "stream": true,
175        });
176
177        let resp = self.http_client.post(&url).json(&payload).send().await
178            .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
179
180        if !resp.status().is_success() {
181            let status = resp.status();
182            let body = resp.text().await.unwrap_or_default();
183            return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
184        }
185
186        let byte_stream = resp.bytes_stream();
187        let sse_stream = async_stream::try_stream! {
188            let mut buffer = String::new();
189            futures_util::pin_mut!(byte_stream);
190
191            while let Some(chunk_result) = byte_stream.next().await {
192                let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
193                buffer.push_str(&String::from_utf8_lossy(&chunk));
194
195                while let Some(newline_pos) = buffer.find('\n') {
196                    let line = buffer[..newline_pos].trim().to_string();
197                    buffer = buffer[newline_pos + 1..].to_string();
198
199                    if line.is_empty() || !line.starts_with("data: ") {
200                        continue;
201                    }
202
203                    let data = &line[6..];
204                    if data == "[DONE]" {
205                        return;
206                    }
207
208                    yield format!("data: {}\n\n", data);
209                }
210            }
211        };
212
213        Ok(Box::new(Box::pin(sse_stream)))
214    }
215
216    async fn shutdown(&mut self) -> anyhow::Result<()> {
217        info!("Shutting down CoreML runtime");
218        
219        if let Some(mut child) = self.server_process.take() {
220            match child.kill() {
221                Ok(_) => {
222                    info!("CoreML server process killed successfully");
223                    let _ = child.wait();
224                }
225                Err(e) => {
226                    warn!("Failed to kill CoreML server: {}", e);
227                }
228            }
229        }
230
231        self.config = None;
232        self.base_url.clear();
233        Ok(())
234    }
235
236    fn metadata(&self) -> RuntimeMetadata {
237        RuntimeMetadata {
238            format: ModelFormat::CoreML,
239            runtime_name: "CoreML (Apple)".to_string(),
240            version: "latest".to_string(),
241            supports_gpu: true,
242            supports_streaming: true,
243        }
244    }
245}
246
247impl Drop for CoreMLRuntime {
248    fn drop(&mut self) {
249        if let Some(mut child) = self.server_process.take() {
250            let _ = child.kill();
251            let _ = child.wait();
252        }
253    }
254}