1use serde::Deserialize;
7use tokio::sync::mpsc;
8
9use crate::error::{HooshError, Result};
10use crate::inference::{InferenceRequest, InferenceResponse, ModelInfo, Role, TokenUsage};
11
/// HTTP client for a hoosh inference server exposing an OpenAI-compatible
/// REST API (`/v1/chat/completions`, `/v1/models`, `/v1/health`).
#[derive(Debug, Clone)]
pub struct HooshClient {
    // Server root URL; trailing slashes are stripped in `new` so endpoint
    // paths can be appended with a single `/`.
    base_url: String,
    // Pooled HTTP client used for all requests.
    client: reqwest::Client,
}
18
19fn to_chat_body(request: &InferenceRequest) -> serde_json::Value {
21 let messages: Vec<serde_json::Value> = if request.messages.is_empty() {
22 let mut msgs = Vec::new();
23 if let Some(sys) = &request.system {
24 msgs.push(serde_json::json!({"role": "system", "content": sys}));
25 }
26 msgs.push(serde_json::json!({"role": "user", "content": request.prompt}));
27 msgs
28 } else {
29 request
30 .messages
31 .iter()
32 .map(|m| {
33 let role = match m.role {
34 Role::System => "system",
35 Role::User => "user",
36 Role::Assistant => "assistant",
37 Role::Tool => "tool",
38 };
39 serde_json::json!({"role": role, "content": m.content})
40 })
41 .collect()
42 };
43
44 let mut body = serde_json::json!({
45 "model": request.model,
46 "messages": messages,
47 "stream": request.stream,
48 });
49 if let Some(max) = request.max_tokens {
50 body["max_tokens"] = serde_json::json!(max);
51 }
52 if let Some(temp) = request.temperature {
53 body["temperature"] = serde_json::json!(temp);
54 }
55 if let Some(tp) = request.top_p {
56 body["top_p"] = serde_json::json!(tp);
57 }
58 body
59}
60
/// Subset of an OpenAI-style non-streaming `/v1/chat/completions` response.
#[derive(Deserialize)]
struct ChatCompletionResp {
    // Model the server reports having used; optional because some servers
    // omit it (callers fall back to the requested model).
    model: Option<String>,
    // Completion choices; only the first is consumed by `infer`.
    choices: Vec<ChatChoice>,
    // Token accounting; optional because servers report it inconsistently.
    usage: Option<ChatUsageResp>,
}
67
/// One entry of `ChatCompletionResp::choices`.
#[derive(Deserialize)]
struct ChatChoice {
    // The generated message for this choice.
    message: ChatMsg,
}
72
/// Message payload inside a completion choice.
#[derive(Deserialize)]
struct ChatMsg {
    // Generated text; optional because servers may omit it entirely.
    content: Option<String>,
}
77
/// Token-usage block of a completion response. Every field is optional
/// because servers report usage inconsistently; missing values map to 0.
#[derive(Deserialize)]
struct ChatUsageResp {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    total_tokens: Option<u32>,
}
84
/// One SSE `data:` payload of a streaming completion response.
#[derive(Deserialize)]
struct StreamChunk {
    choices: Vec<StreamChoice>,
}
89
/// One entry of `StreamChunk::choices`.
#[derive(Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
}
94
/// Incremental message delta carried by a stream chunk.
#[derive(Deserialize)]
struct StreamDelta {
    // New text fragment; absent or empty on chunks that carry no text.
    content: Option<String>,
}
99
/// Response shape of `GET /v1/models`.
#[derive(Deserialize)]
struct ModelsResp {
    data: Vec<ModelObj>,
}
104
/// One model entry from `GET /v1/models`.
#[derive(Deserialize)]
struct ModelObj {
    // Model identifier; `list_models` also uses it as the display name.
    id: String,
    // Owning organization; mapped to `ModelInfo::provider`, defaulting
    // to "hoosh" when absent.
    owned_by: Option<String>,
}
110
111impl HooshClient {
112 pub fn new(base_url: impl Into<String>) -> Self {
119 Self {
120 base_url: base_url.into().trim_end_matches('/').to_string(),
121 client: reqwest::Client::builder()
122 .tcp_nodelay(true)
123 .tcp_keepalive(std::time::Duration::from_secs(60))
124 .pool_idle_timeout(std::time::Duration::from_secs(600))
125 .pool_max_idle_per_host(32)
126 .http2_adaptive_window(true)
127 .connect_timeout(std::time::Duration::from_secs(10))
128 .build()
129 .unwrap_or_default(),
130 }
131 }
132
133 pub async fn infer(&self, request: &InferenceRequest) -> Result<InferenceResponse> {
135 let url = format!("{}/v1/chat/completions", self.base_url);
136 let body = to_chat_body(&InferenceRequest {
137 stream: false,
138 ..request.clone()
139 });
140
141 let resp = self
142 .client
143 .post(&url)
144 .json(&body)
145 .send()
146 .await?
147 .error_for_status()
148 .map_err(|e| HooshError::Provider(e.to_string()))?;
149
150 let parsed: ChatCompletionResp = resp
151 .json()
152 .await
153 .map_err(|e| HooshError::Provider(e.to_string()))?;
154
155 let text = parsed
156 .choices
157 .first()
158 .and_then(|c| c.message.content.clone())
159 .unwrap_or_default();
160
161 let usage = parsed.usage.as_ref();
162 Ok(InferenceResponse {
163 text,
164 model: parsed.model.unwrap_or_else(|| request.model.clone()),
165 usage: TokenUsage {
166 prompt_tokens: usage.and_then(|u| u.prompt_tokens).unwrap_or(0),
167 completion_tokens: usage.and_then(|u| u.completion_tokens).unwrap_or(0),
168 total_tokens: usage.and_then(|u| u.total_tokens).unwrap_or(0),
169 },
170 provider: "hoosh".into(),
171 latency_ms: 0,
172 tool_calls: Vec::new(),
173 })
174 }
175
176 pub async fn infer_stream(
178 &self,
179 request: &InferenceRequest,
180 ) -> Result<mpsc::Receiver<std::result::Result<String, HooshError>>> {
181 let url = format!("{}/v1/chat/completions", self.base_url);
182 let body = to_chat_body(&InferenceRequest {
183 stream: true,
184 ..request.clone()
185 });
186
187 let resp = self
188 .client
189 .post(&url)
190 .json(&body)
191 .send()
192 .await?
193 .error_for_status()
194 .map_err(|e| HooshError::Provider(e.to_string()))?;
195
196 if let Some(ct) = resp.headers().get("content-type") {
197 let ct_str = ct.to_str().unwrap_or("");
198 if !ct_str.contains("text/event-stream") && !ct_str.contains("application/json") {
199 return Err(HooshError::Provider(format!(
200 "expected SSE stream, got Content-Type: {ct_str}"
201 )));
202 }
203 }
204
205 let (tx, rx) = mpsc::channel(256);
206
207 tokio::spawn(async move {
208 use futures::StreamExt;
209 let mut stream = resp.bytes_stream();
210 let mut buf = String::new();
211
212 while let Some(chunk) = stream.next().await {
213 let chunk = match chunk {
214 Ok(c) => c,
215 Err(e) => {
216 let _ = tx.send(Err(HooshError::Provider(e.to_string()))).await;
217 return;
218 }
219 };
220 if buf.len() + chunk.len() > 1024 * 1024 {
221 let _ = tx
222 .send(Err(HooshError::Provider(
223 "SSE line exceeded 1MB limit".into(),
224 )))
225 .await;
226 return;
227 }
228 buf.push_str(&String::from_utf8_lossy(&chunk));
229
230 while let Some(pos) = buf.find('\n') {
231 let line = buf[..pos].trim().to_string();
232 buf = buf[pos + 1..].to_string();
233
234 if line.is_empty() || line.starts_with(':') {
235 continue;
236 }
237 let data = if let Some(d) = line.strip_prefix("data: ") {
238 d.trim()
239 } else if let Some(d) = line.strip_prefix("data:") {
240 d.trim()
241 } else {
242 continue;
243 };
244 if data == "[DONE]" {
245 return;
246 }
247 if let Ok(chunk) = serde_json::from_str::<StreamChunk>(data) {
248 for choice in &chunk.choices {
249 if let Some(content) = &choice.delta.content
250 && !content.is_empty()
251 && tx.send(Ok(content.clone())).await.is_err()
252 {
253 return;
254 }
255 }
256 }
257 }
258 }
259 });
260
261 Ok(rx)
262 }
263
264 pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
266 let url = format!("{}/v1/models", self.base_url);
267 let resp = self
268 .client
269 .get(&url)
270 .send()
271 .await?
272 .error_for_status()
273 .map_err(|e| HooshError::Provider(e.to_string()))?;
274
275 let parsed: ModelsResp = resp
276 .json()
277 .await
278 .map_err(|e| HooshError::Provider(e.to_string()))?;
279
280 Ok(parsed
281 .data
282 .into_iter()
283 .map(|m| ModelInfo {
284 id: m.id.clone(),
285 name: m.id,
286 provider: m.owned_by.unwrap_or_else(|| "hoosh".into()),
287 parameters: None,
288 context_length: None,
289 available: true,
290 })
291 .collect())
292 }
293
294 pub async fn health(&self) -> Result<bool> {
296 let url = format!("{}/v1/health", self.base_url);
297 match self.client.get(&url).send().await {
298 Ok(resp) => Ok(resp.status().is_success()),
299 Err(_) => Ok(false),
300 }
301 }
302
303 pub fn base_url(&self) -> &str {
305 &self.base_url
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn client_creation() {
        let c = HooshClient::new("http://localhost:8088");
        assert_eq!(c.base_url(), "http://localhost:8088");
    }

    #[test]
    fn client_strips_trailing_slash() {
        assert_eq!(
            HooshClient::new("http://localhost:8088/").base_url(),
            "http://localhost:8088"
        );
    }

    #[test]
    fn client_strips_multiple_trailing_slashes() {
        assert_eq!(
            HooshClient::new("http://localhost:8088///").base_url(),
            "http://localhost:8088"
        );
    }

    #[test]
    fn to_chat_body_with_messages() {
        use crate::inference::Message;
        let req = InferenceRequest {
            model: "llama3".into(),
            messages: vec![
                Message::new(Role::System, "You are a helper."),
                Message::new(Role::User, "Hello"),
                Message::new(Role::Assistant, "Hi there!"),
                Message::new(Role::Tool, "tool result"),
            ],
            stream: false,
            ..Default::default()
        };
        let payload = to_chat_body(&req);
        let msgs = payload["messages"].as_array().unwrap();
        assert_eq!(msgs.len(), 4);
        // Roles must come through in conversation order.
        for (msg, role) in msgs.iter().zip(["system", "user", "assistant", "tool"]) {
            assert_eq!(msg["role"], role);
        }
    }

    #[test]
    fn to_chat_body_no_messages_uses_prompt() {
        let req = InferenceRequest {
            model: "llama3".into(),
            prompt: "Hello world".into(),
            system: None,
            messages: vec![],
            stream: false,
            ..Default::default()
        };
        let payload = to_chat_body(&req);
        let msgs = payload["messages"].as_array().unwrap();
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0]["role"], "user");
        assert_eq!(msgs[0]["content"], "Hello world");
    }

    #[test]
    fn to_chat_body_no_messages_with_system() {
        let req = InferenceRequest {
            model: "llama3".into(),
            prompt: "Hello".into(),
            system: Some("You are helpful.".into()),
            messages: vec![],
            stream: false,
            ..Default::default()
        };
        let payload = to_chat_body(&req);
        let msgs = payload["messages"].as_array().unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0]["role"], "system");
        assert_eq!(msgs[0]["content"], "You are helpful.");
        assert_eq!(msgs[1]["role"], "user");
    }

    #[test]
    fn to_chat_body_with_optional_params() {
        let req = InferenceRequest {
            model: "gpt-4o".into(),
            prompt: "test".into(),
            max_tokens: Some(500),
            temperature: Some(0.7),
            top_p: Some(0.9),
            stream: true,
            ..Default::default()
        };
        let payload = to_chat_body(&req);
        assert_eq!(payload["stream"], true);
        assert_eq!(payload["max_tokens"], 500);
        assert_eq!(payload["temperature"], 0.7);
        assert_eq!(payload["top_p"], 0.9);
    }

    #[test]
    fn to_chat_body_without_optional_params() {
        let payload = to_chat_body(&InferenceRequest {
            model: "gpt-4o".into(),
            prompt: "test".into(),
            ..Default::default()
        });
        // Unset tuning knobs must be absent, not null.
        for key in ["max_tokens", "temperature", "top_p"] {
            assert!(payload.get(key).is_none());
        }
    }

    #[tokio::test]
    async fn health_unreachable_server() {
        // Port 1 is essentially guaranteed to refuse connections.
        let c = HooshClient::new("http://127.0.0.1:1");
        assert!(!c.health().await.unwrap());
    }

    #[tokio::test]
    async fn infer_connection_refused() {
        let c = HooshClient::new("http://127.0.0.1:1");
        let req = InferenceRequest {
            model: "test".into(),
            prompt: "hello".into(),
            ..Default::default()
        };
        assert!(c.infer(&req).await.is_err());
    }

    #[tokio::test]
    async fn list_models_connection_refused() {
        let c = HooshClient::new("http://127.0.0.1:1");
        assert!(c.list_models().await.is_err());
    }

    #[tokio::test]
    async fn infer_stream_connection_refused() {
        let c = HooshClient::new("http://127.0.0.1:1");
        let req = InferenceRequest {
            model: "test".into(),
            prompt: "hello".into(),
            ..Default::default()
        };
        assert!(c.infer_stream(&req).await.is_err());
    }
}