use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::chat::ChatMessage;
use crate::config::LLMConfig;
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;

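/// Request body for an OpenAI-compatible `chat/completions` call.
/// Optional fields are omitted from the serialized JSON when unset.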
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
}

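/// Response body returned by the chat completions endpoint.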
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

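/// A single completion candidate; `finish_reason` reports why generation
/// stopped when the provider includes it.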
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option<String>,
}

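/// Token counts reported by the provider for the request and response.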
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

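/// Abstraction over LLM backends, so callers can swap providers or mock them in tests.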
#[async_trait]
pub trait LLMProvider: Send + Sync {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
}

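/// HTTP client for an OpenAI-compatible LLM API, configured via [`LLMConfig`].
///
/// Rough usage sketch (building the `LLMConfig` and `ChatMessage` values is
/// assumed to happen elsewhere in the crate, so this is illustrative only):
///
/// ```ignore
/// let client = LLMClient::new(config);
/// let reply = client.chat(messages, None).await?;
/// ```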
pub struct LLMClient {
    config: LLMConfig,
    client: Client,
}

impl LLMClient {
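    /// Creates a client from the given configuration with a default `reqwest::Client`.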
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

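    /// Returns a reference to the underlying configuration.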
    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

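        // Surface non-2xx responses as LLM errors, keeping the body for diagnostics.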
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl LLMClient {
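    /// Builds a request from the configured model, temperature, and token limit,
    /// sends it, and returns the first choice's message.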
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: Some(self.config.temperature),
            max_tokens: Some(self.config.max_tokens),
            tools,
            tool_choice: None,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

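    /// Currently delegates to the non-streaming `chat` call without tools.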
    pub async fn chat_stream(
        &self,
        messages: Vec<ChatMessage>,
    ) -> Result<ChatMessage> {
        self.chat(messages, None).await
    }
}