1use anyhow::{Context, Result, bail};
2use log::{debug, warn};
3use reqwest::Client;
4use serde::{Serialize, Deserialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::Semaphore;
9use parking_lot::RwLock;
10use crate::config::{AiConfig, OllamaConfig};
11
/// Global cap on in-flight AI backend requests across all players.
const MAX_CONCURRENT_REQUESTS: usize = 3;
/// Minimum number of seconds a single player must wait between requests.
const PLAYER_COOLDOWN_SECS: u64 = 2;
14
/// Cloneable client for an OpenAI-compatible chat API or a local Ollama
/// instance, with a per-player cooldown and a global concurrency cap.
///
/// Clones share the same semaphore and cooldown map via `Arc`.
///
/// NOTE(review): `Debug` is derived, so formatting this struct also formats
/// `config`; if `AiConfig`'s `Debug` output includes `api_key`, the key could
/// leak into logs — confirm against the `AiConfig` definition.
#[derive(Debug, Clone)]
pub struct AiClient {
    // Reusable HTTP client; built in `new` with a 30-second timeout.
    client: Client,
    // OpenAI-compatible settings: api_url, api_key, model, sampling params, trigger.
    config: AiConfig,
    // Optional Ollama settings; used instead of OpenAI when present and enabled.
    ollama_config: Option<OllamaConfig>,
    // Limits concurrent backend requests to MAX_CONCURRENT_REQUESTS.
    semaphore: Arc<Semaphore>,
    // Per-player timestamp of the most recent request, for cooldown checks.
    last_request: Arc<RwLock<HashMap<String, Instant>>>,
}
23
/// Request body sent to the OpenAI-compatible chat completions endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
    // Model identifier, taken from `AiConfig::model`.
    model: String,
    // Conversation history in role/content form.
    messages: Vec<Message>,
    // Upper bound on tokens in the completion.
    max_tokens: u32,
    // Sampling temperature.
    temperature: f32,
}
31
/// A single chat message, serialized in the OpenAI role/content format.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
    // Speaker role (e.g. "system", "user", "assistant" per the OpenAI
    // convention — callers supply this; not validated here).
    pub role: String,
    // Message text.
    pub content: String,
}
37
/// Minimal deserialization target for an OpenAI-compatible chat completion
/// response; only the fields this module reads are declared.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}
42
/// One completion choice; only the message payload is read.
#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
}
47
/// Request body for Ollama's generate-style endpoint (single flattened
/// prompt rather than a message list).
#[derive(Debug, Serialize)]
struct OllamaRequest {
    // Model identifier, taken from `OllamaConfig::model`.
    model: String,
    // Full conversation flattened into one "role: content" per line.
    prompt: String,
    // Always false here: request the complete response in one body.
    stream: bool,
}
54
/// Minimal deserialization target for a non-streaming Ollama response.
#[derive(Debug, Deserialize)]
struct OllamaResponse {
    response: String,
}
59
/// Outcome of a chat attempt that did not fail outright.
pub enum ChatResult {
    /// The backend replied; carries the assistant's message text.
    Success(String),
    /// The player is still in cooldown; carries a user-facing wait message.
    RateLimited(String),
}
64
65impl AiClient {
66 pub fn new(config: AiConfig, ollama_config: Option<OllamaConfig>) -> Result<Self> {
67 let client = Client::builder()
68 .timeout(Duration::from_secs(30))
69 .build()
70 .context("Failed to create HTTP client")?;
71
72 Ok(Self {
73 client,
74 config,
75 ollama_config,
76 semaphore: Arc::new(Semaphore::new(MAX_CONCURRENT_REQUESTS)),
77 last_request: Arc::new(RwLock::new(HashMap::new())),
78 })
79 }
80
81 pub async fn chat(&self, messages: Vec<Message>, player: &str) -> Result<ChatResult> {
82 {
83 let last_requests = self.last_request.read();
84 if let Some(last_time) = last_requests.get(player) {
85 let elapsed = last_time.elapsed();
86 if elapsed < Duration::from_secs(PLAYER_COOLDOWN_SECS) {
87 let wait_secs = PLAYER_COOLDOWN_SECS - elapsed.as_secs();
88 return Ok(ChatResult::RateLimited(format!(
89 "Please wait {} seconds before asking again.",
90 wait_secs
91 )));
92 }
93 }
94 }
95
96 let _permit = self.semaphore.acquire().await;
97
98 {
99 let mut last_requests = self.last_request.write();
100 last_requests.insert(player.to_string(), Instant::now());
101 }
102
103 let result = if let Some(ref ollama) = self.ollama_config {
104 if ollama.enabled {
105 self.chat_ollama(ollama, messages).await
106 } else {
107 self.chat_openai(messages).await
108 }
109 } else {
110 self.chat_openai(messages).await
111 };
112
113 match result {
114 Ok(response) => Ok(ChatResult::Success(response)),
115 Err(e) => {
116 warn!("AI chat error: {}", e);
117 Err(e)
118 }
119 }
120 }
121
122 async fn chat_openai(&self, messages: Vec<Message>) -> Result<String> {
123 let request = ChatRequest {
124 model: self.config.model.clone(),
125 messages,
126 max_tokens: self.config.max_tokens,
127 temperature: self.config.temperature,
128 };
129
130 debug!("Sending request to OpenAI API");
131
132 let response = self.client
133 .post(&self.config.api_url)
134 .header("Authorization", format!("Bearer {}", self.config.api_key))
135 .json(&request)
136 .send()
137 .await
138 .context("Failed to send request to OpenAI API. Please check your network connection and api_url in config.toml")?;
139
140 if !response.status().is_success() {
141 let status = response.status();
142 let body = response.text().await.unwrap_or_default();
143
144 if status.as_u16() == 401 {
145 bail!(
146 "OpenAI API authentication failed. \n\
147 Please check that api_key in config.toml is correct."
148 );
149 } else if status.as_u16() == 429 {
150 bail!(
151 "OpenAI API rate limit exceeded. Please try again later."
152 );
153 }
154
155 warn!("OpenAI API error: {} - {}", status, body);
156 bail!("OpenAI API returned error: {}", status);
157 }
158
159 let chat_response: ChatResponse = response
160 .json()
161 .await
162 .context("Failed to parse OpenAI response")?;
163
164 chat_response
165 .choices
166 .first()
167 .map(|c| c.message.content.clone())
168 .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
169 }
170
171 async fn chat_ollama(&self, ollama: &OllamaConfig, messages: Vec<Message>) -> Result<String> {
172 let prompt = messages
173 .iter()
174 .map(|m| format!("{}: {}", m.role, m.content))
175 .collect::<Vec<_>>()
176 .join("\n");
177
178 let request = OllamaRequest {
179 model: ollama.model.clone(),
180 prompt,
181 stream: false,
182 };
183
184 debug!("Sending request to Ollama API");
185
186 let response = self.client
187 .post(&ollama.url)
188 .json(&request)
189 .send()
190 .await
191 .context("Failed to send request to Ollama API. Please ensure Ollama is running and the URL in config.toml is correct")?;
192
193 if !response.status().is_success() {
194 let status = response.status();
195 let body = response.text().await.unwrap_or_default();
196 warn!("Ollama API error: {} - {}", status, body);
197 bail!("Ollama API returned error: {}", status);
198 }
199
200 let ollama_response: OllamaResponse = response
201 .json()
202 .await
203 .context("Failed to parse Ollama response")?;
204
205 Ok(ollama_response.response)
206 }
207
208 pub fn get_trigger(&self) -> &str {
209 &self.config.trigger
210 }
211}