Skip to main content

offline_intelligence/model_runtime/
safetensors_runtime.rs

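//! Safetensors Runtime Adapter
//!
//! Adapter for Safetensors-format models. Weights are not loaded in-process: the adapter spawns
//! an external inference server binary (e.g. a Candle-based `candle-server`) and forwards
//! chat-completion requests to it over HTTP.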
4
5use async_trait::async_trait;
6use super::runtime_trait::*;
7use std::process::{Child, Command, Stdio};
8use std::time::Duration;
9use tracing::{info, warn};
10use tokio::time::sleep;
11
12pub struct SafetensorsRuntime {
13    config: Option<RuntimeConfig>,
14    server_process: Option<Child>,
15    http_client: reqwest::Client,
16    base_url: String,
17}
18
19impl SafetensorsRuntime {
20    pub fn new() -> Self {
21        Self {
22            config: None,
23            server_process: None,
24            http_client: reqwest::Client::builder()
25                .timeout(Duration::from_secs(600))
26                .build()
27                .unwrap_or_default(),
28            base_url: String::new(),
29        }
30    }
31
32    async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
33        let binary_path = config.runtime_binary.as_ref()
34            .ok_or_else(|| anyhow::anyhow!("Safetensors runtime requires runtime_binary path (e.g., candle-server)"))?;
35
36        if !binary_path.exists() {
37            return Err(anyhow::anyhow!(
38                "Safetensors server binary not found at: {}. Use Candle or similar framework.",
39                binary_path.display()
40            ));
41        }
42
43        info!("Starting Safetensors server for model: {}", config.model_path.display());
44        
45        let mut cmd = Command::new(binary_path);
46        cmd.arg("--model").arg(&config.model_path)
47            .arg("--host").arg(&config.host)
48            .arg("--port").arg(config.port.to_string())
49            .stdout(Stdio::piped())
50            .stderr(Stdio::piped());
51
52        let child = cmd.spawn()
53            .map_err(|e| anyhow::anyhow!("Failed to spawn Safetensors server: {}", e))?;
54
55        self.server_process = Some(child);
56        self.base_url = format!("http://{}:{}", config.host, config.port);
57
58        // Wait for server to be ready (up to 120 seconds) with exponential backoff.
59        let _start = std::time::Instant::now();
60        let mut delay_ms: u64 = 100;
61        let mut last_log_secs: u64 = 0;
62        loop {
63            sleep(Duration::from_millis(delay_ms)).await;
64            if self.is_ready().await {
65                info!("✅ Safetensors runtime ready after {:.1}s", _start.elapsed().as_secs_f64());
66                return Ok(());
67            }
68            let elapsed_secs = _start.elapsed().as_secs();
69            if elapsed_secs >= 120 {
70                break;
71            }
72            if elapsed_secs >= last_log_secs + 10 {
73                info!("Still waiting for Safetensors server... ({}/120s)", elapsed_secs);
74                last_log_secs = elapsed_secs;
75            }
76            delay_ms = (delay_ms * 2).min(2_000);
77        }
78
79        Err(anyhow::anyhow!("Safetensors server failed to start within 120 seconds"))
80    }
81}
82
83impl Default for SafetensorsRuntime {
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
89#[async_trait]
90impl ModelRuntime for SafetensorsRuntime {
91    fn supported_format(&self) -> ModelFormat {
92        ModelFormat::Safetensors
93    }
94
95    async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
96        info!("Initializing Safetensors runtime");
97        
98        if config.format != ModelFormat::Safetensors {
99            return Err(anyhow::anyhow!("Safetensors runtime received wrong format: {:?}", config.format));
100        }
101
102        self.config = Some(config.clone());
103        self.start_server(&config).await?;
104        Ok(())
105    }
106
107    async fn is_ready(&self) -> bool {
108        if self.base_url.is_empty() {
109            return false;
110        }
111
112        let health_url = format!("{}/health", self.base_url);
113        match self.http_client.get(&health_url).send().await {
114            Ok(resp) => resp.status().is_success(),
115            Err(_) => false,
116        }
117    }
118
119    async fn health_check(&self) -> anyhow::Result<String> {
120        if self.base_url.is_empty() {
121            return Err(anyhow::anyhow!("Runtime not initialized"));
122        }
123
124        let health_url = format!("{}/health", self.base_url);
125        let resp = self.http_client.get(&health_url).send().await
126            .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
127
128        if resp.status().is_success() {
129            Ok("healthy".to_string())
130        } else {
131            Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
132        }
133    }
134
135    fn base_url(&self) -> String {
136        self.base_url.clone()
137    }
138
139    async fn generate(&self, request: InferenceRequest) -> anyhow::Result<InferenceResponse> {
140        let url = self.completions_url();
141        
142        let payload = serde_json::json!({
143            "model": "safetensors-llm",
144            "messages": request.messages,
145            "max_tokens": request.max_tokens,
146            "temperature": request.temperature,
147            "stream": false,
148        });
149
150        let resp = self.http_client.post(&url).json(&payload).send().await
151            .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
152
153        if !resp.status().is_success() {
154            let status = resp.status();
155            let body = resp.text().await.unwrap_or_default();
156            return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
157        }
158
159        let response: serde_json::Value = resp.json().await
160            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
161
162        let content = response["choices"][0]["message"]["content"].as_str().unwrap_or("").to_string();
163        let finish_reason = response["choices"][0]["finish_reason"].as_str().map(|s| s.to_string());
164
165        Ok(InferenceResponse { content, finish_reason })
166    }
167
168    async fn generate_stream(
169        &self,
170        request: InferenceRequest,
171    ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
172        use futures_util::StreamExt;
173        
174        let url = self.completions_url();
175        let payload = serde_json::json!({
176            "model": "safetensors-llm",
177            "messages": request.messages,
178            "max_tokens": request.max_tokens,
179            "temperature": request.temperature,
180            "stream": true,
181        });
182
183        let resp = self.http_client.post(&url).json(&payload).send().await
184            .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
185
186        if !resp.status().is_success() {
187            let status = resp.status();
188            let body = resp.text().await.unwrap_or_default();
189            return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
190        }
191
192        let byte_stream = resp.bytes_stream();
193        let sse_stream = async_stream::try_stream! {
194            let mut buffer = String::new();
195            futures_util::pin_mut!(byte_stream);
196
197            while let Some(chunk_result) = byte_stream.next().await {
198                let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
199                buffer.push_str(&String::from_utf8_lossy(&chunk));
200
201                while let Some(newline_pos) = buffer.find('\n') {
202                    let line = buffer[..newline_pos].trim().to_string();
203                    buffer = buffer[newline_pos + 1..].to_string();
204
205                    if line.is_empty() || !line.starts_with("data: ") {
206                        continue;
207                    }
208
209                    let data = &line[6..];
210                    if data == "[DONE]" {
211                        return;
212                    }
213
214                    yield format!("data: {}\n\n", data);
215                }
216            }
217        };
218
219        Ok(Box::new(Box::pin(sse_stream)))
220    }
221
222    async fn shutdown(&mut self) -> anyhow::Result<()> {
223        info!("Shutting down Safetensors runtime");
224        
225        if let Some(mut child) = self.server_process.take() {
226            match child.kill() {
227                Ok(_) => {
228                    info!("Safetensors server process killed successfully");
229                    let _ = child.wait();
230                }
231                Err(e) => {
232                    warn!("Failed to kill Safetensors server: {}", e);
233                }
234            }
235        }
236
237        self.config = None;
238        self.base_url.clear();
239        Ok(())
240    }
241
242    fn metadata(&self) -> RuntimeMetadata {
243        RuntimeMetadata {
244            format: ModelFormat::Safetensors,
245            runtime_name: "Safetensors (Candle)".to_string(),
246            version: "latest".to_string(),
247            supports_gpu: true,
248            supports_streaming: true,
249        }
250    }
251}
252
253impl Drop for SafetensorsRuntime {
254    fn drop(&mut self) {
255        if let Some(mut child) = self.server_process.take() {
256            let _ = child.kill();
257            let _ = child.wait();
258        }
259    }
260}