// offline_intelligence/model_runtime/safetensors_runtime.rs

//! Safetensors Runtime Adapter
//!
//! Adapter for Safetensors format models using Candle framework.
use std::process::{Child, Command, Stdio};
use std::time::Duration;

use async_trait::async_trait;
use tokio::time::sleep;
use tracing::{info, warn};

use super::runtime_trait::*;
/// Runtime adapter that serves Safetensors-format models by spawning an
/// external OpenAI-compatible HTTP server (e.g. a Candle-based binary) and
/// proxying inference requests to it over HTTP.
pub struct SafetensorsRuntime {
    // Config captured by `initialize`; `None` until then.
    config: Option<RuntimeConfig>,
    // Handle to the spawned server process; `None` when not running.
    server_process: Option<Child>,
    // HTTP client reused for health checks and inference calls.
    http_client: reqwest::Client,
    // "http://host:port" of the running server; empty before initialization.
    base_url: String,
}
18
19impl SafetensorsRuntime {
20    pub fn new() -> Self {
21        Self {
22            config: None,
23            server_process: None,
24            http_client: reqwest::Client::builder()
25                .timeout(Duration::from_secs(600))
26                .build()
27                .unwrap_or_default(),
28            base_url: String::new(),
29        }
30    }
31
32    async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
33        let binary_path = config.runtime_binary.as_ref()
34            .ok_or_else(|| anyhow::anyhow!("Safetensors runtime requires runtime_binary path (e.g., candle-server)"))?;
35
36        if !binary_path.exists() {
37            return Err(anyhow::anyhow!(
38                "Safetensors server binary not found at: {}. Use Candle or similar framework.",
39                binary_path.display()
40            ));
41        }
42
43        info!("Starting Safetensors server for model: {}", config.model_path.display());
44        
45        let mut cmd = Command::new(binary_path);
46        cmd.arg("--model").arg(&config.model_path)
47            .arg("--host").arg(&config.host)
48            .arg("--port").arg(config.port.to_string())
49            .stdout(Stdio::piped())
50            .stderr(Stdio::piped());
51
52        let child = cmd.spawn()
53            .map_err(|e| anyhow::anyhow!("Failed to spawn Safetensors server: {}", e))?;
54
55        self.server_process = Some(child);
56        self.base_url = format!("http://{}:{}", config.host, config.port);
57
58        for attempt in 1..=15 {
59            sleep(Duration::from_secs(2)).await;
60            if self.is_ready().await {
61                info!("✅ Safetensors runtime ready after {} seconds", attempt * 2);
62                return Ok(());
63            }
64        }
65
66        Err(anyhow::anyhow!("Safetensors server failed to start within 30 seconds"))
67    }
68}
69
70impl Default for SafetensorsRuntime {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76#[async_trait]
77impl ModelRuntime for SafetensorsRuntime {
78    fn supported_format(&self) -> ModelFormat {
79        ModelFormat::Safetensors
80    }
81
82    async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
83        info!("Initializing Safetensors runtime");
84        
85        if config.format != ModelFormat::Safetensors {
86            return Err(anyhow::anyhow!("Safetensors runtime received wrong format: {:?}", config.format));
87        }
88
89        self.config = Some(config.clone());
90        self.start_server(&config).await?;
91        Ok(())
92    }
93
94    async fn is_ready(&self) -> bool {
95        if self.base_url.is_empty() {
96            return false;
97        }
98
99        let health_url = format!("{}/health", self.base_url);
100        match self.http_client.get(&health_url).send().await {
101            Ok(resp) => resp.status().is_success(),
102            Err(_) => false,
103        }
104    }
105
106    async fn health_check(&self) -> anyhow::Result<String> {
107        if self.base_url.is_empty() {
108            return Err(anyhow::anyhow!("Runtime not initialized"));
109        }
110
111        let health_url = format!("{}/health", self.base_url);
112        let resp = self.http_client.get(&health_url).send().await
113            .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
114
115        if resp.status().is_success() {
116            Ok("healthy".to_string())
117        } else {
118            Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
119        }
120    }
121
122    fn base_url(&self) -> String {
123        self.base_url.clone()
124    }
125
126    async fn generate(&self, request: InferenceRequest) -> anyhow::Result<InferenceResponse> {
127        let url = self.completions_url();
128        
129        let payload = serde_json::json!({
130            "model": "safetensors-llm",
131            "messages": request.messages,
132            "max_tokens": request.max_tokens,
133            "temperature": request.temperature,
134            "stream": false,
135        });
136
137        let resp = self.http_client.post(&url).json(&payload).send().await
138            .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
139
140        if !resp.status().is_success() {
141            let status = resp.status();
142            let body = resp.text().await.unwrap_or_default();
143            return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
144        }
145
146        let response: serde_json::Value = resp.json().await
147            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
148
149        let content = response["choices"][0]["message"]["content"].as_str().unwrap_or("").to_string();
150        let finish_reason = response["choices"][0]["finish_reason"].as_str().map(|s| s.to_string());
151
152        Ok(InferenceResponse { content, finish_reason })
153    }
154
155    async fn generate_stream(
156        &self,
157        request: InferenceRequest,
158    ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
159        use futures_util::StreamExt;
160        
161        let url = self.completions_url();
162        let payload = serde_json::json!({
163            "model": "safetensors-llm",
164            "messages": request.messages,
165            "max_tokens": request.max_tokens,
166            "temperature": request.temperature,
167            "stream": true,
168        });
169
170        let resp = self.http_client.post(&url).json(&payload).send().await
171            .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
172
173        if !resp.status().is_success() {
174            let status = resp.status();
175            let body = resp.text().await.unwrap_or_default();
176            return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
177        }
178
179        let byte_stream = resp.bytes_stream();
180        let sse_stream = async_stream::try_stream! {
181            let mut buffer = String::new();
182            futures_util::pin_mut!(byte_stream);
183
184            while let Some(chunk_result) = byte_stream.next().await {
185                let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
186                buffer.push_str(&String::from_utf8_lossy(&chunk));
187
188                while let Some(newline_pos) = buffer.find('\n') {
189                    let line = buffer[..newline_pos].trim().to_string();
190                    buffer = buffer[newline_pos + 1..].to_string();
191
192                    if line.is_empty() || !line.starts_with("data: ") {
193                        continue;
194                    }
195
196                    let data = &line[6..];
197                    if data == "[DONE]" {
198                        return;
199                    }
200
201                    yield format!("data: {}\n\n", data);
202                }
203            }
204        };
205
206        Ok(Box::new(Box::pin(sse_stream)))
207    }
208
209    async fn shutdown(&mut self) -> anyhow::Result<()> {
210        info!("Shutting down Safetensors runtime");
211        
212        if let Some(mut child) = self.server_process.take() {
213            match child.kill() {
214                Ok(_) => {
215                    info!("Safetensors server process killed successfully");
216                    let _ = child.wait();
217                }
218                Err(e) => {
219                    warn!("Failed to kill Safetensors server: {}", e);
220                }
221            }
222        }
223
224        self.config = None;
225        self.base_url.clear();
226        Ok(())
227    }
228
229    fn metadata(&self) -> RuntimeMetadata {
230        RuntimeMetadata {
231            format: ModelFormat::Safetensors,
232            runtime_name: "Safetensors (Candle)".to_string(),
233            version: "latest".to_string(),
234            supports_gpu: true,
235            supports_streaming: true,
236        }
237    }
238}
239
240impl Drop for SafetensorsRuntime {
241    fn drop(&mut self) {
242        if let Some(mut child) = self.server_process.take() {
243            let _ = child.kill();
244            let _ = child.wait();
245        }
246    }
247}