1use super::error::LlmError;
2use super::stream::SseStream;
3use super::types::{ChatRequest, ChatResponse};
4
5#[derive(Debug)]
7pub struct LlmClient {
8 http_client: reqwest::Client,
9 base_url: String,
10 api_key: String,
11}
12
13impl LlmClient {
14 pub fn new(api_base: &str, api_key: &str) -> Self {
16 Self {
17 http_client: reqwest::Client::new(),
18 base_url: api_base.trim_end_matches('/').to_string(),
19 api_key: api_key.to_string(),
20 }
21 }
22
23 const CHAT_COMPLETIONS_PATH: &'static str = "/chat/completions";
24
25 fn endpoint(&self) -> String {
26 format!("{}{}", self.base_url, Self::CHAT_COMPLETIONS_PATH)
27 }
28
29 pub async fn chat_completion(&self, request: &ChatRequest) -> Result<ChatResponse, LlmError> {
31 let resp = self
32 .http_client
33 .post(self.endpoint())
34 .bearer_auth(&self.api_key)
35 .json(request)
36 .send()
37 .await?;
38
39 let status = resp.status();
40 if !status.is_success() {
41 let body = resp.text().await.unwrap_or_default();
42 return Err(LlmError::Api {
43 status: status.as_u16(),
44 body,
45 });
46 }
47
48 let body = resp.text().await?;
49 serde_json::from_str::<ChatResponse>(&body).map_err(|e| {
50 LlmError::Deserialize(format!("Failed to parse response: {} | body: {}", e, body))
51 })
52 }
53
54 pub async fn chat_completion_stream(
56 &self,
57 request: &ChatRequest,
58 ) -> Result<SseStream, LlmError> {
59 let mut body = serde_json::to_value(request)
61 .map_err(|e| LlmError::RequestBuild(format!("Failed to serialize request: {}", e)))?;
62 body.as_object_mut()
65 .expect("ChatRequest must serialize to a JSON object")
66 .insert("stream".to_string(), serde_json::Value::Bool(true));
67
68 let resp = self
69 .http_client
70 .post(self.endpoint())
71 .bearer_auth(&self.api_key)
72 .header("Content-Type", "application/json")
73 .body(body.to_string())
74 .send()
75 .await?;
76
77 let status = resp.status();
78 if !status.is_success() {
79 let resp_body = resp.text().await.unwrap_or_default();
80 return Err(LlmError::Api {
81 status: status.as_u16(),
82 body: resp_body,
83 });
84 }
85
86 Ok(SseStream::new(resp))
87 }
88}