Skip to main content

mc_minder/ai/
mod.rs

1use anyhow::{Context, Result, bail};
2use log::{debug, warn};
3use reqwest::Client;
4use serde::{Serialize, Deserialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::Semaphore;
9use tokio::time::timeout;
10use parking_lot::RwLock;
11use crate::config::{AiConfig, OllamaConfig};
12
/// Maximum number of AI backend requests allowed in flight at the same time
/// (the permit count of `AiClient::semaphore`).
const MAX_CONCURRENT_REQUESTS: usize = 3;
/// Minimum number of seconds between two AI requests from the same player.
const PLAYER_COOLDOWN_SECS: u64 = 2;
15
/// HTTP client for the AI chat backends (an Ollama `/api/chat` server or an
/// OpenAI-compatible chat API), with a per-player cooldown and a global cap on
/// concurrent requests.
///
/// Clones share the same concurrency semaphore and per-player cooldown map, because
/// both are held behind `Arc`.
#[derive(Debug, Clone)]
pub struct AiClient {
    /// HTTP client used for every backend request.
    client: Client,
    /// Settings for the OpenAI-compatible backend (model, URL, API key, sampling
    /// parameters) plus the chat trigger.
    config: AiConfig,
    /// Ollama settings; when present and `enabled`, Ollama is used instead of the
    /// OpenAI-compatible backend.
    ollama_config: Option<OllamaConfig>,
    /// Limits the number of in-flight backend requests to `MAX_CONCURRENT_REQUESTS`.
    semaphore: Arc<Semaphore>,
    /// Time of each player's most recent request that passed the cooldown check,
    /// keyed by the `player` string passed to `chat`.
    last_request: Arc<RwLock<HashMap<String, Instant>>>,
}
24
/// JSON request body for an OpenAI-compatible chat completion call.
///
/// Field names are serialized as-is, so they must match the API's parameter names.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier, taken from `AiConfig::model`.
    model: String,
    /// Conversation history to send.
    messages: Vec<Message>,
    /// Upper bound on generated tokens, taken from `AiConfig::max_tokens`.
    max_tokens: u32,
    /// Sampling temperature, taken from `AiConfig::temperature`.
    temperature: f32,
}
32
/// A single chat message, used both in outgoing requests and in parsed
/// OpenAI-compatible responses.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
    /// Speaker role, forwarded verbatim to the backend
    /// (presumably "system" / "user" / "assistant" — set by callers, verify there).
    pub role: String,
    /// Message text.
    pub content: String,
}
38
/// Parsed body of an OpenAI-compatible chat completion response.
///
/// Both fields are optional: a successful response carries `choices`, while some
/// servers report failures in an `error` object instead.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Generated completions; only the first one is used.
    choices: Option<Vec<Choice>>,
    /// Error object returned in place of completions.
    error: Option<ApiError>,
}

/// Error object embedded in an OpenAI-compatible response body.
#[derive(Debug, Deserialize)]
struct ApiError {
    /// Human-readable error description from the server.
    message: String,
}

/// One completion candidate in an OpenAI-compatible response.
#[derive(Debug, Deserialize)]
struct Choice {
    /// The generated assistant message.
    message: Message,
}
54
// Request body format for the Ollama /api/chat endpoint.
/// JSON request body for Ollama's `/api/chat` endpoint.
#[derive(Debug, Serialize)]
struct OllamaChatRequest {
    /// Ollama model name, taken from `OllamaConfig::model`.
    model: String,
    /// Conversation history to send.
    messages: Vec<OllamaMessage>,
    /// Whether Ollama should stream the reply; this client always sends `false`
    /// and expects a single JSON object back.
    stream: bool,
}

/// A single chat message in Ollama's request format.
#[derive(Debug, Serialize, Clone)]
struct OllamaMessage {
    /// Speaker role, copied from `Message::role`.
    role: String,
    /// Message text, copied from `Message::content`.
    content: String,
}
68
/// Parsed body of a non-streaming Ollama `/api/chat` response.
///
/// Only the reply message is deserialized; other fields Ollama returns are ignored.
#[derive(Debug, Deserialize)]
struct OllamaChatResponse {
    /// The generated assistant message.
    message: OllamaResponseMessage,
}

/// Assistant message inside an Ollama `/api/chat` response.
#[derive(Debug, Deserialize)]
struct OllamaResponseMessage {
    /// Generated reply text.
    content: String,
}
78
/// Outcome of [`AiClient::chat`] when no error occurred.
///
/// Derives `Debug`, `Clone` and `PartialEq`/`Eq` so callers can log, store and
/// compare results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChatResult {
    /// The backend replied; holds the reply text.
    Success(String),
    /// The player is still in their cooldown window; holds a user-facing message
    /// telling them how long to wait. No backend request was made.
    RateLimited(String),
}
83
84impl AiClient {
85    pub fn new(config: AiConfig, ollama_config: Option<OllamaConfig>) -> Result<Self> {
86        let client = Client::builder()
87            .timeout(Duration::from_secs(30))
88            .build()
89            .context("Failed to create HTTP client")?;
90
91        Ok(Self {
92            client,
93            config,
94            ollama_config,
95            semaphore: Arc::new(Semaphore::new(MAX_CONCURRENT_REQUESTS)),
96            last_request: Arc::new(RwLock::new(HashMap::new())),
97        })
98    }
99
100    pub async fn chat(&self, messages: Vec<Message>, player: &str) -> Result<ChatResult> {
101        debug!("[AI] Chat request: player='{}', messages_count={}", player, messages.len());
102        
103        {
104            let last_requests = self.last_request.read();
105            if let Some(last_time) = last_requests.get(player) {
106                let elapsed = last_time.elapsed();
107                if elapsed < Duration::from_secs(PLAYER_COOLDOWN_SECS) {
108                    let wait_secs = PLAYER_COOLDOWN_SECS - elapsed.as_secs();
109                    debug!("[AI] Player '{}' rate limited, elapsed={}ms, wait={}s", player, elapsed.as_millis(), wait_secs);
110                    return Ok(ChatResult::RateLimited(format!(
111                        "Please wait {} seconds before asking again.",
112                        wait_secs
113                    )));
114                }
115            }
116        }
117
118        debug!("[AI] Acquiring semaphore permit (max concurrent: {})", MAX_CONCURRENT_REQUESTS);
119        let _permit = timeout(Duration::from_secs(10), self.semaphore.acquire())
120            .await
121            .context("AI request timeout: too many concurrent requests")?;
122        
123        {
124            let mut last_requests = self.last_request.write();
125            last_requests.insert(player.to_string(), Instant::now());
126        }
127
128        debug!("[AI] Routing to backend: ollama_enabled={}", 
129            self.ollama_config.as_ref().is_some_and(|o| o.enabled));
130        
131        let result = if let Some(ref ollama) = self.ollama_config {
132            if ollama.enabled {
133                debug!("[AI] Using Ollama backend");
134                self.chat_ollama(ollama, messages).await
135            } else {
136                debug!("[AI] Using OpenAI-compatible backend");
137                self.chat_openai(messages).await
138            }
139        } else {
140            debug!("[AI] No Ollama config, using OpenAI-compatible backend");
141            self.chat_openai(messages).await
142        };
143
144        match result {
145            Ok(response) => {
146                debug!("[AI] Chat successful, response_length={}", response.len());
147                Ok(ChatResult::Success(response))
148            }
149            Err(e) => {
150                warn!("[AI] Chat error: {}", e);
151                Err(e)
152            }
153        }
154    }
155
156    async fn chat_openai(&self, messages: Vec<Message>) -> Result<String> {
157        let request = ChatRequest {
158            model: self.config.model.clone(),
159            messages,
160            max_tokens: self.config.max_tokens,
161            temperature: self.config.temperature,
162        };
163        debug!("[AI] OpenAI request: model={}, max_tokens={}, temperature={}, api_url={}", 
164            self.config.model, self.config.max_tokens, self.config.temperature, self.config.api_url);
165
166        debug!("[AI] Sending request to OpenAI API...");
167
168        let response = self.client
169            .post(&self.config.api_url)
170            .header("Authorization", format!("Bearer {}", self.config.api_key))
171            .json(&request)
172            .send()
173            .await
174            .context("Failed to send request to OpenAI API. Please check your network connection and api_url in config.toml")?;
175
176        let status = response.status();
177        debug!("[AI] OpenAI API response status: {}", status);
178
179        if !response.status().is_success() {
180            let body = match response.text().await {
181                Ok(text) => {
182                    debug!("[AI] OpenAI error response body: {}", text);
183                    text
184                },
185                Err(e) => format!("<failed to read error response: {}>", e),
186            };
187            
188            if status.as_u16() == 401 {
189                bail!(
190                    "OpenAI API authentication failed. \n\
191                     Please check that api_key in config.toml is correct.\n\
192                     Response: {}", body
193                );
194            } else if status.as_u16() == 429 {
195                bail!(
196                    "OpenAI API rate limit exceeded. Please try again later.\n\
197                     Response: {}", body
198                );
199            }
200            
201            warn!("[AI] OpenAI API error: {} - {}", status, body);
202            bail!("OpenAI API returned error: {} - {}", status, body);
203        }
204
205        debug!("[AI] Parsing OpenAI response...");
206        let chat_response: ChatResponse = response
207            .json()
208            .await
209            .context("Failed to parse OpenAI response")?;
210
211        // 检查 API 错误
212        if let Some(error) = chat_response.error {
213            debug!("[AI] OpenAI API returned error in response: {}", error.message);
214            bail!("OpenAI API error: {}", error.message);
215        }
216
217        // 获取响应内容
218        let content = chat_response
219            .choices
220            .and_then(|c| c.first().map(|c| c.message.content.clone()))
221            .ok_or_else(|| anyhow::anyhow!("No response from AI API"))?;
222        
223        debug!("[AI] OpenAI response parsed successfully, content_length={}", content.len());
224        Ok(content)
225    }
226
227    async fn chat_ollama(&self, ollama: &OllamaConfig, messages: Vec<Message>) -> Result<String> {
228        debug!("[AI] Preparing Ollama request: model={}, messages_count={}", ollama.model, messages.len());
229        
230        // 转换为 Ollama 格式(使用 /api/chat 端点)
231        let ollama_messages: Vec<OllamaMessage> = messages
232            .into_iter()
233            .map(|m| OllamaMessage {
234                role: m.role,
235                content: m.content,
236            })
237            .collect();
238
239        // 构造 /api/chat 请求
240        let chat_url = ollama.url.replace("/api/generate", "/api/chat");
241        let request = OllamaChatRequest {
242            model: ollama.model.clone(),
243            messages: ollama_messages,
244            stream: false,
245        };
246
247        debug!("[AI] Sending request to Ollama /api/chat: {}", chat_url);
248        debug!("[AI] Ollama request: model={}, url={}", ollama.model, chat_url);
249
250        let response = self.client
251            .post(&chat_url)
252            .json(&request)
253            .send()
254            .await
255            .context("Failed to send request to Ollama API. Please ensure Ollama is running and the URL in config.toml is correct")?;
256
257        let status = response.status();
258        debug!("[AI] Ollama API response status: {}", status);
259
260        if !response.status().is_success() {
261            let body = match response.text().await {
262                Ok(text) => {
263                    debug!("[AI] Ollama error response body: {}", text);
264                    text
265                },
266                Err(e) => format!("<failed to read error response: {}>", e),
267            };
268            warn!("[AI] Ollama API error: {} - {}", status, body);
269            bail!("Ollama API returned error: {} - {}", status, body);
270        }
271
272        debug!("[AI] Parsing Ollama response...");
273        let ollama_response: OllamaChatResponse = response
274            .json()
275            .await
276            .context("Failed to parse Ollama response")?;
277
278        debug!("[AI] Ollama response parsed successfully, content_length={}", ollama_response.message.content.len());
279        Ok(ollama_response.message.content)
280    }
281
282    pub fn get_trigger(&self) -> &str {
283        &self.config.trigger
284    }
285}