offline_intelligence/model_runtime/safetensors_runtime.rs

use async_trait::async_trait;

use super::runtime_trait::*;
use std::process::{Child, Command, Stdio};
use std::time::Duration;
use tracing::{info, warn};
use tokio::time::sleep;

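/// Runs Safetensors models by spawning an external inference server binary
/// (e.g. a Candle-based server) and proxying OpenAI-style HTTP requests to it.
///
/// A minimal usage sketch; the `RuntimeConfig` construction is illustrative
/// and shows only the fields this module actually reads:
///
/// ```ignore
/// let mut runtime = SafetensorsRuntime::new();
/// let config = RuntimeConfig {
///     format: ModelFormat::Safetensors,
///     model_path: "/models/model.safetensors".into(),
///     runtime_binary: Some("/usr/local/bin/candle-server".into()),
///     host: "127.0.0.1".to_string(),
///     port: 8080,
///     // ...plus any remaining RuntimeConfig fields
/// };
/// runtime.initialize(config).await?;
/// assert!(runtime.is_ready().await);
/// runtime.shutdown().await?;
/// ```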
pub struct SafetensorsRuntime {
    config: Option<RuntimeConfig>,
    server_process: Option<Child>,
    http_client: reqwest::Client,
    base_url: String,
}

impl SafetensorsRuntime {
    pub fn new() -> Self {
        Self {
            config: None,
            server_process: None,
            // Generous timeout: local inference on large models can take
            // minutes per request. Falls back to the default client if the
            // builder fails.
            http_client: reqwest::Client::builder()
                .timeout(Duration::from_secs(600))
                .build()
                .unwrap_or_default(),
            base_url: String::new(),
        }
    }

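    /// Spawns the configured server binary and polls its `/health` endpoint
    /// with exponential backoff (100ms, doubling, capped at 2s) until it
    /// responds or the 120-second deadline expires.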
    async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
        let binary_path = config.runtime_binary.as_ref()
            .ok_or_else(|| anyhow::anyhow!("Safetensors runtime requires runtime_binary path (e.g., candle-server)"))?;

        if !binary_path.exists() {
            return Err(anyhow::anyhow!(
                "Safetensors server binary not found at: {}. Use Candle or a similar framework.",
                binary_path.display()
            ));
        }

        info!("Starting Safetensors server for model: {}", config.model_path.display());

        let mut cmd = Command::new(binary_path);
        // NOTE: stdout/stderr are piped but never drained; a server that logs
        // heavily can fill the pipe buffer and stall. Drain these (or use
        // Stdio::null()) if the child produces significant output.
        cmd.arg("--model").arg(&config.model_path)
            .arg("--host").arg(&config.host)
            .arg("--port").arg(config.port.to_string())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let child = cmd.spawn()
            .map_err(|e| anyhow::anyhow!("Failed to spawn Safetensors server: {}", e))?;

        self.server_process = Some(child);
        self.base_url = format!("http://{}:{}", config.host, config.port);

        // Poll the health endpoint with exponential backoff until the server
        // answers or the deadline expires.
        let start = std::time::Instant::now();
        let mut delay_ms: u64 = 100;
        let mut last_log_secs: u64 = 0;
        loop {
            sleep(Duration::from_millis(delay_ms)).await;
            if self.is_ready().await {
                info!("✅ Safetensors runtime ready after {:.1}s", start.elapsed().as_secs_f64());
                return Ok(());
            }
            let elapsed_secs = start.elapsed().as_secs();
            if elapsed_secs >= 120 {
                break;
            }
            if elapsed_secs >= last_log_secs + 10 {
                info!("Still waiting for Safetensors server... ({}/120s)", elapsed_secs);
                last_log_secs = elapsed_secs;
            }
            delay_ms = (delay_ms * 2).min(2_000);
        }

        Err(anyhow::anyhow!("Safetensors server failed to start within 120 seconds"))
    }
}

impl Default for SafetensorsRuntime {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ModelRuntime for SafetensorsRuntime {
    fn supported_format(&self) -> ModelFormat {
        ModelFormat::Safetensors
    }

    async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
        info!("Initializing Safetensors runtime");

        if config.format != ModelFormat::Safetensors {
            return Err(anyhow::anyhow!("Safetensors runtime received wrong format: {:?}", config.format));
        }

        self.config = Some(config.clone());
        self.start_server(&config).await?;
        Ok(())
    }

    async fn is_ready(&self) -> bool {
        if self.base_url.is_empty() {
            return false;
        }

        let health_url = format!("{}/health", self.base_url);
        match self.http_client.get(&health_url).send().await {
            Ok(resp) => resp.status().is_success(),
            Err(_) => false,
        }
    }

    async fn health_check(&self) -> anyhow::Result<String> {
        if self.base_url.is_empty() {
            return Err(anyhow::anyhow!("Runtime not initialized"));
        }

        let health_url = format!("{}/health", self.base_url);
        let resp = self.http_client.get(&health_url).send().await
            .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;

        if resp.status().is_success() {
            Ok("healthy".to_string())
        } else {
            Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
        }
    }

    fn base_url(&self) -> String {
        self.base_url.clone()
    }

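    /// Sends a non-streaming chat-completion request. The payload follows the
    /// OpenAI chat completions schema, so the backing server is expected to
    /// speak that protocol; `model` is a fixed placeholder because the server
    /// is already bound to a single model at spawn time.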
    async fn generate(&self, request: InferenceRequest) -> anyhow::Result<InferenceResponse> {
        let url = self.completions_url();

        let payload = serde_json::json!({
            "model": "safetensors-llm",
            "messages": request.messages,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "stream": false,
        });

        let resp = self.http_client.post(&url).json(&payload).send().await
            .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
        }

        let response: serde_json::Value = resp.json().await
            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;

        let content = response["choices"][0]["message"]["content"].as_str().unwrap_or("").to_string();
        let finish_reason = response["choices"][0]["finish_reason"].as_str().map(|s| s.to_string());

        Ok(InferenceResponse { content, finish_reason })
    }

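    /// Streams a chat-completion response as Server-Sent Events. The server's
    /// SSE output is re-framed line by line: each `data: ...` payload is
    /// forwarded as its own `data: ...\n\n` event, and `[DONE]` ends the
    /// stream without being forwarded.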
    async fn generate_stream(
        &self,
        request: InferenceRequest,
    ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
        use futures_util::StreamExt;

        let url = self.completions_url();
        let payload = serde_json::json!({
            "model": "safetensors-llm",
            "messages": request.messages,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "stream": true,
        });

        let resp = self.http_client.post(&url).json(&payload).send().await
            .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
        }

        let byte_stream = resp.bytes_stream();
        let sse_stream = async_stream::try_stream! {
            // Buffer raw bytes and split on b'\n': chunk boundaries do not
            // necessarily align with SSE lines, and splitting at the byte
            // level avoids mangling multi-byte UTF-8 sequences that straddle
            // chunk boundaries.
            let mut buffer: Vec<u8> = Vec::new();
            futures_util::pin_mut!(byte_stream);

            while let Some(chunk_result) = byte_stream.next().await {
                let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
                buffer.extend_from_slice(&chunk);

                while let Some(newline_pos) = buffer.iter().position(|&b| b == b'\n') {
                    let line_bytes: Vec<u8> = buffer.drain(..=newline_pos).collect();
                    let line = String::from_utf8_lossy(&line_bytes).trim().to_string();

                    if line.is_empty() || !line.starts_with("data: ") {
                        continue;
                    }

                    let data = &line[6..];
                    if data == "[DONE]" {
                        return;
                    }

                    yield format!("data: {}\n\n", data);
                }
            }
        };

        Ok(Box::new(Box::pin(sse_stream)))
    }

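    /// Kills the spawned server process, if any, and clears runtime state.
    /// `Child::kill` is forceful (SIGKILL on Unix), so the server gets no
    /// chance to clean up.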
    async fn shutdown(&mut self) -> anyhow::Result<()> {
        info!("Shutting down Safetensors runtime");

        if let Some(mut child) = self.server_process.take() {
            match child.kill() {
                Ok(_) => {
                    info!("Safetensors server process killed successfully");
                    // Reap the process so it does not linger as a zombie.
                    let _ = child.wait();
                }
                Err(e) => {
                    warn!("Failed to kill Safetensors server: {}", e);
                }
            }
        }

        self.config = None;
        self.base_url.clear();
        Ok(())
    }

    fn metadata(&self) -> RuntimeMetadata {
        RuntimeMetadata {
            format: ModelFormat::Safetensors,
            runtime_name: "Safetensors (Candle)".to_string(),
            version: "latest".to_string(),
            supports_gpu: true,
            supports_streaming: true,
        }
    }
}

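// Safety net: if shutdown() was never called, make sure the spawned server
// process does not outlive this runtime.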
impl Drop for SafetensorsRuntime {
    fn drop(&mut self) {
        if let Some(mut child) = self.server_process.take() {
            let _ = child.kill();
            let _ = child.wait();
        }
    }
}