// offline_intelligence/model_runtime/safetensors_runtime.rs

use async_trait::async_trait;
use super::runtime_trait::*;
use std::process::{Child, Command, Stdio};
use std::time::Duration;
use tracing::{info, warn};
use tokio::time::sleep;

/// Runtime that serves a Safetensors-format model by spawning an external
/// HTTP inference server (e.g. a Candle-based binary) and proxying
/// OpenAI-style chat-completion requests to it.
pub struct SafetensorsRuntime {
    // Config captured at `initialize`; `None` until then and after `shutdown`.
    config: Option<RuntimeConfig>,
    // Handle to the spawned server process; killed on `shutdown` and in `Drop`.
    server_process: Option<Child>,
    // Shared client for health checks and inference calls.
    http_client: reqwest::Client,
    // "http://host:port" of the spawned server; empty when not initialized.
    base_url: String,
}
18
19impl SafetensorsRuntime {
20 pub fn new() -> Self {
21 Self {
22 config: None,
23 server_process: None,
24 http_client: reqwest::Client::builder()
25 .timeout(Duration::from_secs(600))
26 .build()
27 .unwrap_or_default(),
28 base_url: String::new(),
29 }
30 }
31
32 async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
33 let binary_path = config.runtime_binary.as_ref()
34 .ok_or_else(|| anyhow::anyhow!("Safetensors runtime requires runtime_binary path (e.g., candle-server)"))?;
35
36 if !binary_path.exists() {
37 return Err(anyhow::anyhow!(
38 "Safetensors server binary not found at: {}. Use Candle or similar framework.",
39 binary_path.display()
40 ));
41 }
42
43 info!("Starting Safetensors server for model: {}", config.model_path.display());
44
45 let mut cmd = Command::new(binary_path);
46 cmd.arg("--model").arg(&config.model_path)
47 .arg("--host").arg(&config.host)
48 .arg("--port").arg(config.port.to_string())
49 .stdout(Stdio::piped())
50 .stderr(Stdio::piped());
51
52 let child = cmd.spawn()
53 .map_err(|e| anyhow::anyhow!("Failed to spawn Safetensors server: {}", e))?;
54
55 self.server_process = Some(child);
56 self.base_url = format!("http://{}:{}", config.host, config.port);
57
58 for attempt in 1..=15 {
59 sleep(Duration::from_secs(2)).await;
60 if self.is_ready().await {
61 info!("✅ Safetensors runtime ready after {} seconds", attempt * 2);
62 return Ok(());
63 }
64 }
65
66 Err(anyhow::anyhow!("Safetensors server failed to start within 30 seconds"))
67 }
68}
69
70impl Default for SafetensorsRuntime {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76#[async_trait]
77impl ModelRuntime for SafetensorsRuntime {
78 fn supported_format(&self) -> ModelFormat {
79 ModelFormat::Safetensors
80 }
81
82 async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
83 info!("Initializing Safetensors runtime");
84
85 if config.format != ModelFormat::Safetensors {
86 return Err(anyhow::anyhow!("Safetensors runtime received wrong format: {:?}", config.format));
87 }
88
89 self.config = Some(config.clone());
90 self.start_server(&config).await?;
91 Ok(())
92 }
93
94 async fn is_ready(&self) -> bool {
95 if self.base_url.is_empty() {
96 return false;
97 }
98
99 let health_url = format!("{}/health", self.base_url);
100 match self.http_client.get(&health_url).send().await {
101 Ok(resp) => resp.status().is_success(),
102 Err(_) => false,
103 }
104 }
105
106 async fn health_check(&self) -> anyhow::Result<String> {
107 if self.base_url.is_empty() {
108 return Err(anyhow::anyhow!("Runtime not initialized"));
109 }
110
111 let health_url = format!("{}/health", self.base_url);
112 let resp = self.http_client.get(&health_url).send().await
113 .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
114
115 if resp.status().is_success() {
116 Ok("healthy".to_string())
117 } else {
118 Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
119 }
120 }
121
122 fn base_url(&self) -> String {
123 self.base_url.clone()
124 }
125
126 async fn generate(&self, request: InferenceRequest) -> anyhow::Result<InferenceResponse> {
127 let url = self.completions_url();
128
129 let payload = serde_json::json!({
130 "model": "safetensors-llm",
131 "messages": request.messages,
132 "max_tokens": request.max_tokens,
133 "temperature": request.temperature,
134 "stream": false,
135 });
136
137 let resp = self.http_client.post(&url).json(&payload).send().await
138 .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
139
140 if !resp.status().is_success() {
141 let status = resp.status();
142 let body = resp.text().await.unwrap_or_default();
143 return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
144 }
145
146 let response: serde_json::Value = resp.json().await
147 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
148
149 let content = response["choices"][0]["message"]["content"].as_str().unwrap_or("").to_string();
150 let finish_reason = response["choices"][0]["finish_reason"].as_str().map(|s| s.to_string());
151
152 Ok(InferenceResponse { content, finish_reason })
153 }
154
155 async fn generate_stream(
156 &self,
157 request: InferenceRequest,
158 ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
159 use futures_util::StreamExt;
160
161 let url = self.completions_url();
162 let payload = serde_json::json!({
163 "model": "safetensors-llm",
164 "messages": request.messages,
165 "max_tokens": request.max_tokens,
166 "temperature": request.temperature,
167 "stream": true,
168 });
169
170 let resp = self.http_client.post(&url).json(&payload).send().await
171 .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
172
173 if !resp.status().is_success() {
174 let status = resp.status();
175 let body = resp.text().await.unwrap_or_default();
176 return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
177 }
178
179 let byte_stream = resp.bytes_stream();
180 let sse_stream = async_stream::try_stream! {
181 let mut buffer = String::new();
182 futures_util::pin_mut!(byte_stream);
183
184 while let Some(chunk_result) = byte_stream.next().await {
185 let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
186 buffer.push_str(&String::from_utf8_lossy(&chunk));
187
188 while let Some(newline_pos) = buffer.find('\n') {
189 let line = buffer[..newline_pos].trim().to_string();
190 buffer = buffer[newline_pos + 1..].to_string();
191
192 if line.is_empty() || !line.starts_with("data: ") {
193 continue;
194 }
195
196 let data = &line[6..];
197 if data == "[DONE]" {
198 return;
199 }
200
201 yield format!("data: {}\n\n", data);
202 }
203 }
204 };
205
206 Ok(Box::new(Box::pin(sse_stream)))
207 }
208
209 async fn shutdown(&mut self) -> anyhow::Result<()> {
210 info!("Shutting down Safetensors runtime");
211
212 if let Some(mut child) = self.server_process.take() {
213 match child.kill() {
214 Ok(_) => {
215 info!("Safetensors server process killed successfully");
216 let _ = child.wait();
217 }
218 Err(e) => {
219 warn!("Failed to kill Safetensors server: {}", e);
220 }
221 }
222 }
223
224 self.config = None;
225 self.base_url.clear();
226 Ok(())
227 }
228
229 fn metadata(&self) -> RuntimeMetadata {
230 RuntimeMetadata {
231 format: ModelFormat::Safetensors,
232 runtime_name: "Safetensors (Candle)".to_string(),
233 version: "latest".to_string(),
234 supports_gpu: true,
235 supports_streaming: true,
236 }
237 }
238}
239
240impl Drop for SafetensorsRuntime {
241 fn drop(&mut self) {
242 if let Some(mut child) = self.server_process.take() {
243 let _ = child.kill();
244 let _ = child.wait();
245 }
246 }
247}