helios_engine/llm.rs

use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use futures::stream::StreamExt;
use crate::chat::ChatMessage;
use crate::config::LLMConfig;
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;

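/// Request body for a chat completions call. Optional fields are omitted from
/// the serialized JSON when unset.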
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}

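/// A single server-sent-event chunk of a streaming completions response.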
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

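/// One choice within a stream chunk.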
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: Delta,
    pub finish_reason: Option<String>,
}

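/// Incremental role/content fragment carried by a stream choice.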
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

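/// Complete (non-streaming) chat completions response.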
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

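/// One candidate completion within a response.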
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option<String>,
}

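/// Token counts reported by the API for a single request.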
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

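/// Abstraction over any backend that can service an `LLMRequest`.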
#[async_trait]
pub trait LLMProvider: Send + Sync {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
}

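/// HTTP client for an OpenAI-compatible chat completions API.
///
/// A minimal usage sketch (assuming an `LLMConfig` and a `Vec<ChatMessage>`
/// have been built elsewhere):
///
/// ```ignore
/// let client = LLMClient::new(config);
/// let reply = client.chat(messages, None).await?;
/// println!("{}", reply.content);
/// ```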
pub struct LLMClient {
    config: LLMConfig,
    client: Client,
}

impl LLMClient {
    /// Creates a new client from the given configuration.
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    /// Returns the configuration this client was built with.
    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        // Surface non-2xx responses as an LLMError that carries the response body.
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl LLMClient {
    /// Sends the conversation (and optional tool definitions) to the configured
    /// model and returns the first choice's message.
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: Some(self.config.temperature),
            max_tokens: Some(self.config.max_tokens),
            tools,
            tool_choice: None,
            stream: None,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

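    /// Streaming variant of `chat`: sends the request with `stream: true`,
    /// parses the server-sent-event response line by line, invokes `on_chunk`
    /// with each content fragment as it arrives, and returns the assembled
    /// message once the stream ends.
    ///
    /// A minimal usage sketch (assuming `client` and `messages` are built
    /// elsewhere):
    ///
    /// ```ignore
    /// let reply = client
    ///     .chat_stream(messages, None, |chunk| print!("{}", chunk))
    ///     .await?;
    /// ```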
    pub async fn chat_stream<F>(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
        mut on_chunk: F,
    ) -> Result<ChatMessage>
    where
        F: FnMut(&str) + Send,
    {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: Some(self.config.temperature),
            max_tokens: Some(self.config.max_tokens),
            tools,
            tool_choice: None,
            stream: Some(true),
        };

        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let mut stream = response.bytes_stream();
        let mut full_content = String::new();
        let mut role = None;
        let mut buffer = String::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            // Network chunks can split SSE lines arbitrarily, so buffer the raw
            // bytes and only process complete lines.
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();

                if line.is_empty() || line == "data: [DONE]" {
                    continue;
                }

                if let Some(data) = line.strip_prefix("data: ") {
                    match serde_json::from_str::<StreamChunk>(data) {
                        Ok(stream_chunk) => {
                            if let Some(choice) = stream_chunk.choices.first() {
                                if let Some(r) = &choice.delta.role {
                                    role = Some(r.clone());
                                }
                                if let Some(content) = &choice.delta.content {
                                    full_content.push_str(content);
                                    on_chunk(content);
                                }
                            }
                        }
                        Err(e) => {
                            tracing::debug!("Failed to parse stream chunk: {} - Data: {}", e, data);
                        }
                    }
                }
            }
        }

        Ok(ChatMessage {
            role: crate::chat::Role::from(role.as_deref().unwrap_or("assistant")),
            content: full_content,
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }
}