use anyhow::{Context, Result, bail};
use log::{debug, warn};
use reqwest::Client;
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;
use tokio::time::timeout;
use parking_lot::RwLock;
use crate::config::{AiConfig, OllamaConfig};

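// Global throttling knobs: at most MAX_CONCURRENT_REQUESTS AI calls in flight
// at once, and a per-player cooldown between successive chat requests.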
const MAX_CONCURRENT_REQUESTS: usize = 3;
const PLAYER_COOLDOWN_SECS: u64 = 2;

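/// Client for AI chat backends. Routes requests either to an OpenAI-compatible
/// API or to a local Ollama instance, caps concurrency with a semaphore, and
/// enforces a per-player cooldown. Cloning is cheap: the semaphore and the
/// cooldown map are shared behind `Arc`s.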
#[derive(Debug, Clone)]
pub struct AiClient {
    client: Client,
    config: AiConfig,
    ollama_config: Option<OllamaConfig>,
    semaphore: Arc<Semaphore>,
    last_request: Arc<RwLock<HashMap<String, Instant>>>,
}

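// Request body for an OpenAI-compatible chat completions endpoint.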
#[derive(Debug, Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<Message>,
    max_tokens: u32,
    temperature: f32,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
    pub role: String,
    pub content: String,
}

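// Response body: one of `choices` (success) or `error` (API-level failure) is
// expected to be populated.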
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Option<Vec<Choice>>,
    error: Option<ApiError>,
}

#[derive(Debug, Deserialize)]
struct ApiError {
    message: String,
}

#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
}

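// Request/response bodies for Ollama's native `/api/chat` endpoint; `stream`
// is always sent as `false`, so the reply arrives as a single JSON object.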
#[derive(Debug, Serialize)]
struct OllamaChatRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    stream: bool,
}

#[derive(Debug, Serialize, Clone)]
struct OllamaMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OllamaChatResponse {
    message: OllamaResponseMessage,
}

#[derive(Debug, Deserialize)]
struct OllamaResponseMessage {
    content: String,
}

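/// Outcome of a chat call: either the model's reply, or a message telling the
/// player to wait out their cooldown.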
pub enum ChatResult {
    Success(String),
    RateLimited(String),
}

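/// Normalizes the configured Ollama URL to its `/api/chat` endpoint. Accepts a
/// bare base URL or one that already ends in `/api/generate` or `/api/chat`;
/// trimming trailing slashes before the suffixes is a small hardening over
/// suffix-trimming alone. Extracted as a free function so it can be unit-tested.
fn ollama_chat_url(url: &str) -> String {
    let base_url = url
        .trim_end_matches('/')
        .trim_end_matches("/api/generate")
        .trim_end_matches("/api/chat")
        .trim_end_matches('/');
    format!("{}/api/chat", base_url)
}
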
impl AiClient {
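    /// Builds the client with a 30-second HTTP timeout shared by both backends.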
    pub fn new(config: AiConfig, ollama_config: Option<OllamaConfig>) -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .context("Failed to create HTTP client")?;

        Ok(Self {
            client,
            config,
            ollama_config,
            semaphore: Arc::new(Semaphore::new(MAX_CONCURRENT_REQUESTS)),
            last_request: Arc::new(RwLock::new(HashMap::new())),
        })
    }

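    /// Sends a conversation to the configured backend on behalf of `player`,
    /// applying the per-player cooldown and the global concurrency limit. A
    /// minimal sketch of a call site (names are illustrative):
    ///
    /// ```ignore
    /// let messages = vec![Message { role: "user".into(), content: "hello".into() }];
    /// match client.chat(messages, "Steve").await? {
    ///     ChatResult::Success(reply) => println!("{}", reply),
    ///     ChatResult::RateLimited(msg) => println!("{}", msg),
    /// }
    /// ```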
    pub async fn chat(&self, messages: Vec<Message>, player: &str) -> Result<ChatResult> {
        debug!("[AI] Chat request: player='{}', messages_count={}", player, messages.len());

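        // Check the player's cooldown under a short-lived read lock; the scope
        // ensures the parking_lot guard is dropped before any await point.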
        {
            let last_requests = self.last_request.read();
            if let Some(last_time) = last_requests.get(player) {
                let elapsed = last_time.elapsed();
                if elapsed < Duration::from_secs(PLAYER_COOLDOWN_SECS) {
                    let wait_secs = PLAYER_COOLDOWN_SECS - elapsed.as_secs();
                    debug!("[AI] Player '{}' rate limited, elapsed={}ms, wait={}s", player, elapsed.as_millis(), wait_secs);
                    return Ok(ChatResult::RateLimited(format!(
                        "Please wait {} seconds before asking again.",
                        wait_secs
                    )));
                }
            }
        }

        debug!("[AI] Acquiring semaphore permit (max concurrent: {})", MAX_CONCURRENT_REQUESTS);
        // `timeout` wraps the acquire future, so two `?`s are needed: one for
        // the timeout itself, one for a closed semaphore (shutdown only).
        // Previously the inner `Result` was bound unhandled, so an acquire
        // error would have been silently ignored.
        let _permit = timeout(Duration::from_secs(10), self.semaphore.acquire())
            .await
            .context("AI request timeout: too many concurrent requests")?
            .context("AI semaphore closed")?;

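        // Record the request timestamp once a permit is held. Entries are never
        // evicted, so the map grows with the number of distinct players seen.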
        {
            let mut last_requests = self.last_request.write();
            last_requests.insert(player.to_string(), Instant::now());
        }

        debug!("[AI] Routing to backend: ollama_enabled={}",
            self.ollama_config.as_ref().is_some_and(|o| o.enabled));

        let result = if let Some(ref ollama) = self.ollama_config {
            if ollama.enabled {
                debug!("[AI] Using Ollama backend");
                self.chat_ollama(ollama, messages).await
            } else {
                debug!("[AI] Using OpenAI-compatible backend");
                self.chat_openai(messages).await
            }
        } else {
            debug!("[AI] No Ollama config, using OpenAI-compatible backend");
            self.chat_openai(messages).await
        };

        match result {
            Ok(response) => {
                debug!("[AI] Chat successful, response_length={}", response.len());
                Ok(ChatResult::Success(response))
            }
            Err(e) => {
                warn!("[AI] Chat error: {}", e);
                Err(e)
            }
        }
    }

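    /// Sends the conversation to the OpenAI-compatible endpoint configured in
    /// `config.api_url` and returns the content of the first choice.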
    async fn chat_openai(&self, messages: Vec<Message>) -> Result<String> {
        let request = ChatRequest {
            model: self.config.model.clone(),
            messages,
            max_tokens: self.config.max_tokens,
            temperature: self.config.temperature,
        };
        debug!("[AI] OpenAI request: model={}, max_tokens={}, temperature={}, api_url={}",
            self.config.model, self.config.max_tokens, self.config.temperature, self.config.api_url);

        debug!("[AI] Sending request to OpenAI API...");

        let response = self.client
            .post(&self.config.api_url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .json(&request)
            .send()
            .await
            .context("Failed to send request to OpenAI API. Please check your network connection and api_url in config.toml")?;

        let status = response.status();
        debug!("[AI] OpenAI API response status: {}", status);

        if !status.is_success() {
            let body = match response.text().await {
                Ok(text) => {
                    debug!("[AI] OpenAI error response body: {}", text);
                    text
                }
                Err(e) => format!("<failed to read error response: {}>", e),
            };

            if status.as_u16() == 401 {
                bail!(
                    "OpenAI API authentication failed.\n\
                     Please check that api_key in config.toml is correct.\n\
                     Response: {}", body
                );
            } else if status.as_u16() == 429 {
                bail!(
                    "OpenAI API rate limit exceeded. Please try again later.\n\
                     Response: {}", body
                );
            }

            warn!("[AI] OpenAI API error: {} - {}", status, body);
            bail!("OpenAI API returned error: {} - {}", status, body);
        }

        debug!("[AI] Parsing OpenAI response...");
        let chat_response: ChatResponse = response
            .json()
            .await
            .context("Failed to parse OpenAI response")?;

        if let Some(error) = chat_response.error {
            debug!("[AI] OpenAI API returned error in response: {}", error.message);
            bail!("OpenAI API error: {}", error.message);
        }

        let content = chat_response
            .choices
            .and_then(|choices| choices.into_iter().next())
            .map(|choice| choice.message.content)
            .ok_or_else(|| anyhow::anyhow!("No response from AI API"))?;

        debug!("[AI] OpenAI response parsed successfully, content_length={}", content.len());
        Ok(content)
    }

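    /// Sends the conversation to a local Ollama server's `/api/chat` endpoint,
    /// normalizing the configured URL via `ollama_chat_url` first.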
    async fn chat_ollama(&self, ollama: &OllamaConfig, messages: Vec<Message>) -> Result<String> {
        debug!("[AI] Preparing Ollama request: model={}, messages_count={}", ollama.model, messages.len());

        let ollama_messages: Vec<OllamaMessage> = messages
            .into_iter()
            .map(|m| OllamaMessage {
                role: m.role,
                content: m.content,
            })
            .collect();

        let chat_url = ollama_chat_url(&ollama.url);

        let request = OllamaChatRequest {
            model: ollama.model.clone(),
            messages: ollama_messages,
            stream: false,
        };

        debug!("[AI] Sending request to Ollama: model={}, url={}", ollama.model, chat_url);

        let response = self.client
            .post(&chat_url)
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Ollama API. Please ensure Ollama is running and the URL in config.toml is correct")?;

        let status = response.status();
        debug!("[AI] Ollama API response status: {}", status);

        if !status.is_success() {
            let body = match response.text().await {
                Ok(text) => {
                    debug!("[AI] Ollama error response body: {}", text);
                    text
                }
                Err(e) => format!("<failed to read error response: {}>", e),
            };
            warn!("[AI] Ollama API error: {} - {}", status, body);
            bail!("Ollama API returned error: {} - {}", status, body);
        }

        debug!("[AI] Parsing Ollama response...");
        let ollama_response: OllamaChatResponse = response
            .json()
            .await
            .context("Failed to parse Ollama response")?;

        debug!("[AI] Ollama response parsed successfully, content_length={}", ollama_response.message.content.len());
        Ok(ollama_response.message.content)
    }

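    /// Returns the configured chat trigger string.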
    pub fn get_trigger(&self) -> &str {
        &self.config.trigger
    }
}
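
// A small smoke test for the URL normalization; the inputs below are
// illustrative, not exhaustive.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalizes_ollama_urls() {
        for url in [
            "http://localhost:11434",
            "http://localhost:11434/",
            "http://localhost:11434/api/chat",
            "http://localhost:11434/api/generate",
            "http://localhost:11434/api/chat/",
        ] {
            assert_eq!(ollama_chat_url(url), "http://localhost:11434/api/chat");
        }
    }
}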