use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use crate::chat::ChatMessage;
use crate::config::LLMConfig;
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;

/// Request body for an OpenAI-compatible `/chat/completions` endpoint.
/// Optional fields are omitted from the serialized JSON when unset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}

/// One server-sent event payload in a streaming response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: Delta,
    pub finish_reason: Option<String>,
}

/// Incremental piece of the assistant message carried by a stream chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

/// Complete (non-streaming) response from the chat-completions endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option<String>,
}

/// Token accounting reported by the server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Abstraction over a chat-completion backend, so callers can swap the real
/// HTTP client for another implementation (e.g. a canned one in tests).
#[async_trait]
pub trait LLMProvider: Send + Sync {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
}

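// A minimal sketch of an alternative provider for tests that must not touch
// the network. `CannedProvider` is hypothetical, not part of the crate's API;
// it just echoes a fixed reply in the response shape defined above.
#[cfg(test)]
pub struct CannedProvider {
    pub reply: String,
}

#[cfg(test)]
#[async_trait]
impl LLMProvider for CannedProvider {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        Ok(LLMResponse {
            id: "canned-0".to_string(),
            object: "chat.completion".to_string(),
            created: 0,
            model: request.model,
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: crate::chat::Role::from("assistant"),
                    content: self.reply.clone(),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        })
    }
}
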
/// HTTP client for an OpenAI-compatible chat-completions API.
pub struct LLMClient {
    config: LLMConfig,
    client: Client,
}

impl LLMClient {
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    /// Returns the configuration this client was built with.
    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

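// A usage sketch (hypothetical, not called anywhere in this module): build a
// client from an `LLMConfig` and run a single-turn chat. The `ChatMessage`
// field layout matches how `chat_stream` constructs messages below.
#[allow(dead_code)]
async fn example_single_turn(config: LLMConfig) -> Result<ChatMessage> {
    let client = LLMClient::new(config);
    let messages = vec![ChatMessage {
        role: crate::chat::Role::from("user"),
        content: "Hello!".to_string(),
        name: None,
        tool_calls: None,
        tool_call_id: None,
    }];
    // No tools for a plain chat turn.
    client.chat(messages, None).await
}
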
#[async_trait]
impl LLMProvider for LLMClient {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl LLMClient {
    /// Sends a single (non-streaming) chat request and returns the first choice.
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: Some(self.config.temperature),
            max_tokens: Some(self.config.max_tokens),
            tools,
            tool_choice: None,
            stream: None,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    /// Sends a streaming chat request, invoking `on_chunk` for each content
    /// delta as it arrives, and returns the assembled message at the end.
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: Some(self.config.temperature),
            max_tokens: Some(self.config.max_tokens),
            tools,
            tool_choice: None,
            stream: Some(true),
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        // Buffer raw bytes rather than a String: a multi-byte UTF-8 character
        // can be split across network chunks, and converting each chunk with
        // `from_utf8_lossy` in isolation would corrupt it.
        let mut buffer: Vec<u8> = Vec::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            buffer.extend_from_slice(&chunk);

            // Drain every complete SSE line currently in the buffer.
            while let Some(line_end) = buffer.iter().position(|&b| b == b'\n') {
                let line_bytes: Vec<u8> = buffer.drain(..=line_end).collect();
                let line = String::from_utf8_lossy(&line_bytes).trim().to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}