helios_engine/llm.rs

use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use crate::chat::ChatMessage;
use crate::config::LLMConfig;
use crate::error::{HeliosError, Result};
use crate::tools::ToolDefinition;

/// Request body for an OpenAI-compatible `/chat/completions` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
}

/// Response body returned by the chat completions endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// A single completion choice within an [`LLMResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option<String>,
}

/// Token accounting reported by the provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Abstraction over chat-completion backends, implemented by [`LLMClient`].
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Sends a single completion request and returns the raw response.
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse>;
}
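
// A minimal test-double sketch for `LLMProvider` (not part of the original
// module): it returns a canned reply supplied by the caller, so only types
// defined in this file are used. `MockProvider` is a hypothetical name, and
// this assumes `ChatMessage` is `Send + Sync` so the trait's bounds hold.
#[cfg(test)]
pub struct MockProvider {
    pub reply: ChatMessage,
}

#[cfg(test)]
#[async_trait]
impl LLMProvider for MockProvider {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        // Echo the configured reply back inside a well-formed response.
        Ok(LLMResponse {
            id: "mock".to_string(),
            object: "chat.completion".to_string(),
            created: 0,
            model: request.model,
            choices: vec![Choice {
                index: 0,
                message: self.reply.clone(),
                finish_reason: Some("stop".to_string()),
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        })
    }
}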

/// HTTP client for an OpenAI-compatible LLM API, configured via [`LLMConfig`].
pub struct LLMClient {
    config: LLMConfig,
    client: Client,
}

impl LLMClient {
    /// Creates a new client from the given configuration.
    pub fn new(config: LLMConfig) -> Self {
        Self {
            config,
            client: Client::new(),
        }
    }

    /// Returns a reference to the underlying configuration.
    pub fn config(&self) -> &LLMConfig {
        &self.config
    }
}

#[async_trait]
impl LLMProvider for LLMClient {
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse> {
        let url = format!("{}/chat/completions", self.config.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        // Surface non-2xx responses as an LLM error, keeping the body since
        // the provider usually returns a JSON error description there.
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(HeliosError::LLMError(format!(
                "LLM API request failed with status {}: {}",
                status, error_text
            )));
        }

        let llm_response: LLMResponse = response.json().await?;
        Ok(llm_response)
    }
}

impl LLMClient {
    /// Sends the message history (plus optional tool definitions) and returns
    /// the assistant's reply, i.e. the message of the first choice.
    pub async fn chat(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<ToolDefinition>>,
    ) -> Result<ChatMessage> {
        let request = LLMRequest {
            model: self.config.model_name.clone(),
            messages,
            temperature: Some(self.config.temperature),
            max_tokens: Some(self.config.max_tokens),
            tools,
            tool_choice: None,
        };

        let response = self.generate(request).await?;

        response
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message)
            .ok_or_else(|| HeliosError::LLMError("No response from LLM".to_string()))
    }

    /// Placeholder for streaming support: for simplicity this delegates to the
    /// non-streaming [`chat`](Self::chat) call without tools. A real
    /// implementation would consume server-sent events from the API.
    pub async fn chat_stream(
        &self,
        messages: Vec<ChatMessage>,
    ) -> Result<ChatMessage> {
        self.chat(messages, None).await
    }
}
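
// Usage sketch (not part of the original module): how a caller might drive the
// client end to end. Construction of `LLMConfig` and `ChatMessage` is left to
// the caller, since their definitions live in `config.rs` and `chat.rs`;
// `example_chat` is a hypothetical helper name.
#[allow(dead_code)]
async fn example_chat(
    config: LLMConfig,
    messages: Vec<ChatMessage>,
) -> Result<ChatMessage> {
    let client = LLMClient::new(config);
    // Pass `Some(tool_definitions)` instead of `None` to expose tools to the model.
    client.chat(messages, None).await
}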