1use anyhow::{Context, Result, bail};
2use log::{debug, warn};
3use reqwest::Client;
4use serde::{Serialize, Deserialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::Semaphore;
9use tokio::time::timeout;
10use parking_lot::RwLock;
11use crate::config::{AiConfig, OllamaConfig};
12
13const MAX_CONCURRENT_REQUESTS: usize = 3;
14const PLAYER_COOLDOWN_SECS: u64 = 2;
15
/// Cheaply cloneable AI chat client shared across request handlers.
///
/// Wraps a `reqwest::Client`, backend configuration, a global concurrency
/// semaphore, and a per-player cooldown table. All shared state sits behind
/// `Arc`, so clones observe the same semaphore and cooldown map.
///
/// NOTE(review): `#[derive(Debug)]` will format the embedded `AiConfig`,
/// which may expose `api_key` in logs — confirm `AiConfig` redacts it.
#[derive(Debug, Clone)]
pub struct AiClient {
    // HTTP client with a 30s request timeout (set in `new`).
    client: Client,
    // OpenAI-compatible backend settings (model, url, key, sampling params).
    config: AiConfig,
    // Optional Ollama backend; used only when present AND `enabled`.
    ollama_config: Option<OllamaConfig>,
    // Caps in-flight requests at MAX_CONCURRENT_REQUESTS across all clones.
    semaphore: Arc<Semaphore>,
    // Player name -> time of their last accepted request (cooldown tracking).
    last_request: Arc<RwLock<HashMap<String, Instant>>>,
}
24
/// Request body for an OpenAI-compatible chat-completions endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<Message>,
    max_tokens: u32,
    temperature: f32,
}
32
/// A single chat message in OpenAI wire format.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
    // Conversation role, e.g. "system" / "user" / "assistant".
    pub role: String,
    pub content: String,
}
38
/// Top-level OpenAI-compatible response. Both fields are optional because
/// the API returns either `choices` on success or `error` on failure.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Option<Vec<Choice>>,
    error: Option<ApiError>,
}
44
/// In-body error object returned by the OpenAI-compatible API.
#[derive(Debug, Deserialize)]
struct ApiError {
    message: String,
}
49
/// One completion candidate; only the message content is consumed here.
#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
}
54
/// Request body for Ollama's `/api/chat` endpoint.
#[derive(Debug, Serialize)]
struct OllamaChatRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    // Always sent as `false` here so the reply arrives as one JSON object.
    stream: bool,
}
62
/// Chat message in Ollama wire format (mirrors `Message` field-for-field).
#[derive(Debug, Serialize, Clone)]
struct OllamaMessage {
    role: String,
    content: String,
}
68
/// Non-streaming Ollama `/api/chat` response; only `message` is consumed.
#[derive(Debug, Deserialize)]
struct OllamaChatResponse {
    message: OllamaResponseMessage,
}
73
/// The assistant message inside an Ollama chat response.
#[derive(Debug, Deserialize)]
struct OllamaResponseMessage {
    content: String,
}
78
/// Outcome of a chat attempt that did not fail with an error.
pub enum ChatResult {
    // The backend answered; payload is the assistant's reply text.
    Success(String),
    // The player is still in cooldown; payload is a user-facing wait message.
    RateLimited(String),
}
83
84impl AiClient {
85 pub fn new(config: AiConfig, ollama_config: Option<OllamaConfig>) -> Result<Self> {
86 let client = Client::builder()
87 .timeout(Duration::from_secs(30))
88 .build()
89 .context("Failed to create HTTP client")?;
90
91 Ok(Self {
92 client,
93 config,
94 ollama_config,
95 semaphore: Arc::new(Semaphore::new(MAX_CONCURRENT_REQUESTS)),
96 last_request: Arc::new(RwLock::new(HashMap::new())),
97 })
98 }
99
100 pub async fn chat(&self, messages: Vec<Message>, player: &str) -> Result<ChatResult> {
101 {
102 let last_requests = self.last_request.read();
103 if let Some(last_time) = last_requests.get(player) {
104 let elapsed = last_time.elapsed();
105 if elapsed < Duration::from_secs(PLAYER_COOLDOWN_SECS) {
106 let wait_secs = PLAYER_COOLDOWN_SECS - elapsed.as_secs();
107 return Ok(ChatResult::RateLimited(format!(
108 "Please wait {} seconds before asking again.",
109 wait_secs
110 )));
111 }
112 }
113 }
114
115 let _permit = timeout(Duration::from_secs(10), self.semaphore.acquire())
116 .await
117 .context("AI request timeout: too many concurrent requests")?;
118
119 {
120 let mut last_requests = self.last_request.write();
121 last_requests.insert(player.to_string(), Instant::now());
122 }
123
124 let result = if let Some(ref ollama) = self.ollama_config {
125 if ollama.enabled {
126 self.chat_ollama(ollama, messages).await
127 } else {
128 self.chat_openai(messages).await
129 }
130 } else {
131 self.chat_openai(messages).await
132 };
133
134 match result {
135 Ok(response) => Ok(ChatResult::Success(response)),
136 Err(e) => {
137 warn!("AI chat error: {}", e);
138 Err(e)
139 }
140 }
141 }
142
143 async fn chat_openai(&self, messages: Vec<Message>) -> Result<String> {
144 let request = ChatRequest {
145 model: self.config.model.clone(),
146 messages,
147 max_tokens: self.config.max_tokens,
148 temperature: self.config.temperature,
149 };
150
151 debug!("Sending request to OpenAI API");
152
153 let response = self.client
154 .post(&self.config.api_url)
155 .header("Authorization", format!("Bearer {}", self.config.api_key))
156 .json(&request)
157 .send()
158 .await
159 .context("Failed to send request to OpenAI API. Please check your network connection and api_url in config.toml")?;
160
161 if !response.status().is_success() {
162 let status = response.status();
163 let body = match response.text().await {
164 Ok(text) => text,
165 Err(e) => format!("<failed to read error response: {}>", e),
166 };
167
168 if status.as_u16() == 401 {
169 bail!(
170 "OpenAI API authentication failed. \n\
171 Please check that api_key in config.toml is correct.\n\
172 Response: {}", body
173 );
174 } else if status.as_u16() == 429 {
175 bail!(
176 "OpenAI API rate limit exceeded. Please try again later.\n\
177 Response: {}", body
178 );
179 }
180
181 warn!("OpenAI API error: {} - {}", status, body);
182 bail!("OpenAI API returned error: {} - {}", status, body);
183 }
184
185 let chat_response: ChatResponse = response
186 .json()
187 .await
188 .context("Failed to parse OpenAI response")?;
189
190 if let Some(error) = chat_response.error {
192 bail!("OpenAI API error: {}", error.message);
193 }
194
195 chat_response
197 .choices
198 .and_then(|c| c.first().map(|c| c.message.content.clone()))
199 .ok_or_else(|| anyhow::anyhow!("No response from AI API"))
200 }
201
202 async fn chat_ollama(&self, ollama: &OllamaConfig, messages: Vec<Message>) -> Result<String> {
203 let ollama_messages: Vec<OllamaMessage> = messages
205 .into_iter()
206 .map(|m| OllamaMessage {
207 role: m.role,
208 content: m.content,
209 })
210 .collect();
211
212 let chat_url = ollama.url.replace("/api/generate", "/api/chat");
214 let request = OllamaChatRequest {
215 model: ollama.model.clone(),
216 messages: ollama_messages,
217 stream: false,
218 };
219
220 debug!("Sending request to Ollama /api/chat: {}", chat_url);
221
222 let response = self.client
223 .post(&chat_url)
224 .json(&request)
225 .send()
226 .await
227 .context("Failed to send request to Ollama API. Please ensure Ollama is running and the URL in config.toml is correct")?;
228
229 if !response.status().is_success() {
230 let status = response.status();
231 let body = match response.text().await {
232 Ok(text) => text,
233 Err(e) => format!("<failed to read error response: {}>", e),
234 };
235 warn!("Ollama API error: {} - {}", status, body);
236 bail!("Ollama API returned error: {} - {}", status, body);
237 }
238
239 let ollama_response: OllamaChatResponse = response
240 .json()
241 .await
242 .context("Failed to parse Ollama response")?;
243
244 Ok(ollama_response.message.content)
245 }
246
247 pub fn get_trigger(&self) -> &str {
248 &self.config.trigger
249 }
250}