
hoosh/client.rs

//! HTTP client for talking to a hoosh server.
//!
//! This is what downstream crates (tarang, daimon, consumer apps) use to
//! call hoosh over the network. OpenAI-compatible API.
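//!
//! A minimal usage sketch (crate paths are assumed here, so the example is
//! marked `ignore`; field names follow `InferenceRequest` as exercised by the
//! tests below, and the calls run inside an async fn returning `Result`):
//!
//! ```ignore
//! use hoosh::client::HooshClient;
//! use hoosh::inference::InferenceRequest;
//!
//! let client = HooshClient::new("http://localhost:8088");
//! let request = InferenceRequest {
//!     model: "llama3".into(),
//!     prompt: "Hello".into(),
//!     ..Default::default()
//! };
//! let response = client.infer(&request).await?;
//! println!("{}", response.text);
//! ```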

use serde::Deserialize;
use tokio::sync::mpsc;

use crate::error::{HooshError, Result};
use crate::inference::{InferenceRequest, InferenceResponse, ModelInfo, Role, TokenUsage};

/// HTTP client for the hoosh inference gateway.
#[derive(Debug, Clone)]
pub struct HooshClient {
    base_url: String,
    client: reqwest::Client,
}

/// Build an OpenAI-compatible chat body from an InferenceRequest.
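///
/// With only `prompt` set, the resulting body is shaped roughly like this
/// (illustrative values; optional params are added only when present):
///
/// ```text
/// {"model": "llama3", "messages": [{"role": "user", "content": "Hello"}], "stream": false}
/// ```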
fn to_chat_body(request: &InferenceRequest) -> serde_json::Value {
    let messages: Vec<serde_json::Value> = if request.messages.is_empty() {
        let mut msgs = Vec::new();
        if let Some(sys) = &request.system {
            msgs.push(serde_json::json!({"role": "system", "content": sys}));
        }
        msgs.push(serde_json::json!({"role": "user", "content": request.prompt}));
        msgs
    } else {
        request
            .messages
            .iter()
            .map(|m| {
                let role = match m.role {
                    Role::System => "system",
                    Role::User => "user",
                    Role::Assistant => "assistant",
                    Role::Tool => "tool",
                };
                serde_json::json!({"role": role, "content": m.content})
            })
            .collect()
    };

    let mut body = serde_json::json!({
        "model": request.model,
        "messages": messages,
        "stream": request.stream,
    });
    if let Some(max) = request.max_tokens {
        body["max_tokens"] = serde_json::json!(max);
    }
    if let Some(temp) = request.temperature {
        body["temperature"] = serde_json::json!(temp);
    }
    if let Some(tp) = request.top_p {
        body["top_p"] = serde_json::json!(tp);
    }
    body
}

#[derive(Deserialize)]
struct ChatCompletionResp {
    model: Option<String>,
    choices: Vec<ChatChoice>,
    usage: Option<ChatUsageResp>,
}

#[derive(Deserialize)]
struct ChatChoice {
    message: ChatMsg,
}

#[derive(Deserialize)]
struct ChatMsg {
    content: Option<String>,
}

#[derive(Deserialize)]
struct ChatUsageResp {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    total_tokens: Option<u32>,
}

#[derive(Deserialize)]
struct StreamChunk {
    choices: Vec<StreamChoice>,
}

#[derive(Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
}

#[derive(Deserialize)]
struct StreamDelta {
    content: Option<String>,
}

#[derive(Deserialize)]
struct ModelsResp {
    data: Vec<ModelObj>,
}

#[derive(Deserialize)]
struct ModelObj {
    id: String,
    owned_by: Option<String>,
}

impl HooshClient {
    /// Create a new client pointing at the given hoosh server.
    ///
    /// The client is tuned for low-latency local connections:
    /// - TCP_NODELAY disables Nagle's algorithm (avoids 40ms batching delay)
    /// - Connection pooling keeps TCP connections alive across requests
    /// - HTTP/2 adaptive window for multiplexed requests
    pub fn new(base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into().trim_end_matches('/').to_string(),
            client: reqwest::Client::builder()
                .tcp_nodelay(true)
                .tcp_keepalive(std::time::Duration::from_secs(60))
                .pool_idle_timeout(std::time::Duration::from_secs(600))
                .pool_max_idle_per_host(32)
                .http2_adaptive_window(true)
                .connect_timeout(std::time::Duration::from_secs(10))
                .build()
                .unwrap_or_default(),
        }
    }

    /// Run inference via the hoosh server.
    pub async fn infer(&self, request: &InferenceRequest) -> Result<InferenceResponse> {
        let url = format!("{}/v1/chat/completions", self.base_url);
        let body = to_chat_body(&InferenceRequest {
            stream: false,
            ..request.clone()
        });

        let resp = self
            .client
            .post(&url)
            .json(&body)
            .send()
            .await?
            .error_for_status()
            .map_err(|e| HooshError::Provider(e.to_string()))?;

        let parsed: ChatCompletionResp = resp
            .json()
            .await
            .map_err(|e| HooshError::Provider(e.to_string()))?;

        let text = parsed
            .choices
            .first()
            .and_then(|c| c.message.content.clone())
            .unwrap_or_default();

        let usage = parsed.usage.as_ref();
        Ok(InferenceResponse {
            text,
            model: parsed.model.unwrap_or_else(|| request.model.clone()),
            usage: TokenUsage {
                prompt_tokens: usage.and_then(|u| u.prompt_tokens).unwrap_or(0),
                completion_tokens: usage.and_then(|u| u.completion_tokens).unwrap_or(0),
                total_tokens: usage.and_then(|u| u.total_tokens).unwrap_or(0),
            },
            provider: "hoosh".into(),
            latency_ms: 0,
            tool_calls: Vec::new(),
        })
    }

    /// Stream inference results token by token.
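    ///
    /// Consuming the stream is a plain receiver loop. A minimal sketch,
    /// assuming a `client` and `request` as in the module-level example
    /// (`ignore` because the surrounding paths are assumed):
    ///
    /// ```ignore
    /// let mut rx = client.infer_stream(&request).await?;
    /// while let Some(token) = rx.recv().await {
    ///     print!("{}", token?);
    /// }
    /// ```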
    pub async fn infer_stream(
        &self,
        request: &InferenceRequest,
    ) -> Result<mpsc::Receiver<std::result::Result<String, HooshError>>> {
        let url = format!("{}/v1/chat/completions", self.base_url);
        let body = to_chat_body(&InferenceRequest {
            stream: true,
            ..request.clone()
        });

        let resp = self
            .client
            .post(&url)
            .json(&body)
            .send()
            .await?
            .error_for_status()
            .map_err(|e| HooshError::Provider(e.to_string()))?;

        if let Some(ct) = resp.headers().get("content-type") {
            let ct_str = ct.to_str().unwrap_or("");
            if !ct_str.contains("text/event-stream") && !ct_str.contains("application/json") {
                return Err(HooshError::Provider(format!(
                    "expected SSE stream, got Content-Type: {ct_str}"
                )));
            }
        }

        let (tx, rx) = mpsc::channel(256);

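        // The server replies with OpenAI-style Server-Sent Events: each event
        // is a line of the form
        //   data: {"choices":[{"delta":{"content":"Hel"}}]}
        // and the stream ends with `data: [DONE]`. The task below buffers raw
        // bytes, splits them on newlines, and forwards each non-empty delta
        // into the channel.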
        tokio::spawn(async move {
            use futures::StreamExt;
            let mut stream = resp.bytes_stream();
            let mut buf = String::new();

            while let Some(chunk) = stream.next().await {
                let chunk = match chunk {
                    Ok(c) => c,
                    Err(e) => {
                        let _ = tx.send(Err(HooshError::Provider(e.to_string()))).await;
                        return;
                    }
                };
                if buf.len() + chunk.len() > 1024 * 1024 {
                    let _ = tx
                        .send(Err(HooshError::Provider(
                            "SSE line exceeded 1MB limit".into(),
                        )))
                        .await;
                    return;
                }
                buf.push_str(&String::from_utf8_lossy(&chunk));

                while let Some(pos) = buf.find('\n') {
                    let line = buf[..pos].trim().to_string();
                    buf = buf[pos + 1..].to_string();

                    if line.is_empty() || line.starts_with(':') {
                        continue;
                    }
                    let data = if let Some(d) = line.strip_prefix("data: ") {
                        d.trim()
                    } else if let Some(d) = line.strip_prefix("data:") {
                        d.trim()
                    } else {
                        continue;
                    };
                    if data == "[DONE]" {
                        return;
                    }
                    if let Ok(chunk) = serde_json::from_str::<StreamChunk>(data) {
                        for choice in &chunk.choices {
                            if let Some(content) = &choice.delta.content
                                && !content.is_empty()
                                && tx.send(Ok(content.clone())).await.is_err()
                            {
                                return;
                            }
                        }
                    }
                }
            }
        });

        Ok(rx)
    }

    /// List available models.
    pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        let url = format!("{}/v1/models", self.base_url);
        let resp = self
            .client
            .get(&url)
            .send()
            .await?
            .error_for_status()
            .map_err(|e| HooshError::Provider(e.to_string()))?;

        let parsed: ModelsResp = resp
            .json()
            .await
            .map_err(|e| HooshError::Provider(e.to_string()))?;

        Ok(parsed
            .data
            .into_iter()
            .map(|m| ModelInfo {
                id: m.id.clone(),
                name: m.id,
                provider: m.owned_by.unwrap_or_else(|| "hoosh".into()),
                parameters: None,
                context_length: None,
                available: true,
            })
            .collect())
    }

    /// Health check.
    pub async fn health(&self) -> Result<bool> {
        let url = format!("{}/v1/health", self.base_url);
        match self.client.get(&url).send().await {
            Ok(resp) => Ok(resp.status().is_success()),
            Err(_) => Ok(false),
        }
    }

    /// Base URL of the hoosh server.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn client_creation() {
        let client = HooshClient::new("http://localhost:8088");
        assert_eq!(client.base_url(), "http://localhost:8088");
    }

    #[test]
    fn client_strips_trailing_slash() {
        let client = HooshClient::new("http://localhost:8088/");
        assert_eq!(client.base_url(), "http://localhost:8088");
    }

    #[test]
    fn client_strips_multiple_trailing_slashes() {
        let client = HooshClient::new("http://localhost:8088///");
        // trim_end_matches removes all trailing slashes
        assert_eq!(client.base_url(), "http://localhost:8088");
    }

    #[test]
    fn to_chat_body_with_messages() {
        let request = InferenceRequest {
            model: "llama3".into(),
            messages: vec![
                crate::inference::Message::new(Role::System, "You are a helper."),
                crate::inference::Message::new(Role::User, "Hello"),
                crate::inference::Message::new(Role::Assistant, "Hi there!"),
                crate::inference::Message::new(Role::Tool, "tool result"),
            ],
            stream: false,
            ..Default::default()
        };
        let body = to_chat_body(&request);
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0]["role"], "system");
        assert_eq!(messages[1]["role"], "user");
        assert_eq!(messages[2]["role"], "assistant");
        assert_eq!(messages[3]["role"], "tool");
    }

    #[test]
    fn to_chat_body_no_messages_uses_prompt() {
        let request = InferenceRequest {
            model: "llama3".into(),
            prompt: "Hello world".into(),
            system: None,
            messages: vec![],
            stream: false,
            ..Default::default()
        };
        let body = to_chat_body(&request);
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0]["role"], "user");
        assert_eq!(messages[0]["content"], "Hello world");
    }

    #[test]
    fn to_chat_body_no_messages_with_system() {
        let request = InferenceRequest {
            model: "llama3".into(),
            prompt: "Hello".into(),
            system: Some("You are helpful.".into()),
            messages: vec![],
            stream: false,
            ..Default::default()
        };
        let body = to_chat_body(&request);
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0]["role"], "system");
        assert_eq!(messages[0]["content"], "You are helpful.");
        assert_eq!(messages[1]["role"], "user");
    }

    #[test]
    fn to_chat_body_with_optional_params() {
        let request = InferenceRequest {
            model: "gpt-4o".into(),
            prompt: "test".into(),
            max_tokens: Some(500),
            temperature: Some(0.7),
            top_p: Some(0.9),
            stream: true,
            ..Default::default()
        };
        let body = to_chat_body(&request);
        assert_eq!(body["max_tokens"], 500);
        assert_eq!(body["temperature"], 0.7);
        assert_eq!(body["top_p"], 0.9);
        assert_eq!(body["stream"], true);
    }

    #[test]
    fn to_chat_body_without_optional_params() {
        let request = InferenceRequest {
            model: "gpt-4o".into(),
            prompt: "test".into(),
            ..Default::default()
        };
        let body = to_chat_body(&request);
        assert!(body.get("max_tokens").is_none());
        assert!(body.get("temperature").is_none());
        assert!(body.get("top_p").is_none());
    }

    #[tokio::test]
    async fn health_unreachable_server() {
        let client = HooshClient::new("http://127.0.0.1:1");
        let result = client.health().await.unwrap();
        assert!(!result);
    }

    #[tokio::test]
    async fn infer_connection_refused() {
        let client = HooshClient::new("http://127.0.0.1:1");
        let request = InferenceRequest {
            model: "test".into(),
            prompt: "hello".into(),
            ..Default::default()
        };
        let result = client.infer(&request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn list_models_connection_refused() {
        let client = HooshClient::new("http://127.0.0.1:1");
        let result = client.list_models().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn infer_stream_connection_refused() {
        let client = HooshClient::new("http://127.0.0.1:1");
        let request = InferenceRequest {
            model: "test".into(),
            prompt: "hello".into(),
            ..Default::default()
        };
        let result = client.infer_stream(&request).await;
        assert!(result.is_err());
    }
}