// mc_minder/ai/mod.rs

use anyhow::{Context, Result, bail};
use log::{debug, warn};
use reqwest::Client;
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;
use tokio::time::timeout;
use parking_lot::RwLock;
use crate::config::{AiConfig, OllamaConfig};

const MAX_CONCURRENT_REQUESTS: usize = 3;
const PLAYER_COOLDOWN_SECS: u64 = 2;

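/// Chat client for an OpenAI-compatible API or a local Ollama server,
/// with a per-player cooldown and a cap on concurrent in-flight requests.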
#[derive(Debug, Clone)]
pub struct AiClient {
    client: Client,
    config: AiConfig,
    ollama_config: Option<OllamaConfig>,
    semaphore: Arc<Semaphore>,
    last_request: Arc<RwLock<HashMap<String, Instant>>>,
}

#[derive(Debug, Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<Message>,
    max_tokens: u32,
    temperature: f32,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
    pub role: String,
    pub content: String,
}

#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Option<Vec<Choice>>,
    error: Option<ApiError>,
}

#[derive(Debug, Deserialize)]
struct ApiError {
    message: String,
}

#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
}

// Request format for the Ollama /api/chat endpoint
#[derive(Debug, Serialize)]
struct OllamaChatRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    stream: bool,
}

#[derive(Debug, Serialize, Clone)]
struct OllamaMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OllamaChatResponse {
    message: OllamaResponseMessage,
}

#[derive(Debug, Deserialize)]
struct OllamaResponseMessage {
    content: String,
}

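/// Outcome of a chat request: the model's reply, or a rate-limit notice
/// for the player.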
pub enum ChatResult {
    Success(String),
    RateLimited(String),
}

impl AiClient {
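    /// Builds the HTTP client (30-second request timeout) and the shared
    /// rate-limiting state.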
    pub fn new(config: AiConfig, ollama_config: Option<OllamaConfig>) -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .context("Failed to create HTTP client")?;

        Ok(Self {
            client,
            config,
            ollama_config,
            semaphore: Arc::new(Semaphore::new(MAX_CONCURRENT_REQUESTS)),
            last_request: Arc::new(RwLock::new(HashMap::new())),
        })
    }

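    /// Sends `messages` to the configured backend on behalf of `player`.
    /// Returns `RateLimited` while the player's cooldown is active; otherwise
    /// waits up to 10 seconds for a concurrency permit before dispatching.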
    pub async fn chat(&self, messages: Vec<Message>, player: &str) -> Result<ChatResult> {
        {
            let last_requests = self.last_request.read();
            if let Some(last_time) = last_requests.get(player) {
                let elapsed = last_time.elapsed();
                if elapsed < Duration::from_secs(PLAYER_COOLDOWN_SECS) {
                    let wait_secs = PLAYER_COOLDOWN_SECS - elapsed.as_secs();
                    return Ok(ChatResult::RateLimited(format!(
                        "Please wait {} seconds before asking again.",
                        wait_secs
                    )));
                }
            }
        }

        // Limit concurrency: wait up to 10 s for one of the semaphore permits.
        let _permit = timeout(Duration::from_secs(10), self.semaphore.acquire())
            .await
            .context("AI request timeout: too many concurrent requests")?
            .context("Failed to acquire AI request permit")?;

        // Record the request time only after a permit has been obtained.
        {
            let mut last_requests = self.last_request.write();
            last_requests.insert(player.to_string(), Instant::now());
        }

        // Prefer Ollama when it is configured and enabled; otherwise fall
        // back to the OpenAI-compatible endpoint.
        let result = match &self.ollama_config {
            Some(ollama) if ollama.enabled => self.chat_ollama(ollama, messages).await,
            _ => self.chat_openai(messages).await,
        };

        match result {
            Ok(response) => Ok(ChatResult::Success(response)),
            Err(e) => {
                warn!("AI chat error: {}", e);
                Err(e)
            }
        }
    }

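    /// Sends the conversation to an OpenAI-compatible chat-completions
    /// endpoint and returns the content of the first choice.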
    async fn chat_openai(&self, messages: Vec<Message>) -> Result<String> {
        let request = ChatRequest {
            model: self.config.model.clone(),
            messages,
            max_tokens: self.config.max_tokens,
            temperature: self.config.temperature,
        };

        debug!("Sending request to OpenAI API");

        let response = self.client
            .post(&self.config.api_url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .json(&request)
            .send()
            .await
            .context("Failed to send request to OpenAI API. Please check your network connection and api_url in config.toml")?;

        if !response.status().is_success() {
            let status = response.status();
            let body = match response.text().await {
                Ok(text) => text,
                Err(e) => format!("<failed to read error response: {}>", e),
            };

            if status.as_u16() == 401 {
                bail!(
                    "OpenAI API authentication failed.\n\
                     Please check that api_key in config.toml is correct.\n\
                     Response: {}", body
                );
            } else if status.as_u16() == 429 {
                bail!(
                    "OpenAI API rate limit exceeded. Please try again later.\n\
                     Response: {}", body
                );
            }

            warn!("OpenAI API error: {} - {}", status, body);
            bail!("OpenAI API returned error: {} - {}", status, body);
        }

        let chat_response: ChatResponse = response
            .json()
            .await
            .context("Failed to parse OpenAI response")?;

        // Check for an error reported in the response body
        if let Some(error) = chat_response.error {
            bail!("OpenAI API error: {}", error.message);
        }

        // Extract the content of the first choice
        chat_response
            .choices
            .and_then(|c| c.into_iter().next())
            .map(|c| c.message.content)
            .ok_or_else(|| anyhow::anyhow!("No response from AI API"))
    }

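    /// Sends the conversation to a local Ollama server via its /api/chat
    /// endpoint (non-streaming) and returns the reply content.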
    async fn chat_ollama(&self, ollama: &OllamaConfig, messages: Vec<Message>) -> Result<String> {
        // Convert to the Ollama message format (for the /api/chat endpoint)
        let ollama_messages: Vec<OllamaMessage> = messages
            .into_iter()
            .map(|m| OllamaMessage {
                role: m.role,
                content: m.content,
            })
            .collect();

        // Build the /api/chat request; the configured URL may point at
        // /api/generate, so rewrite it to the chat endpoint.
        let chat_url = ollama.url.replace("/api/generate", "/api/chat");
        let request = OllamaChatRequest {
            model: ollama.model.clone(),
            messages: ollama_messages,
            stream: false,
        };

        debug!("Sending request to Ollama /api/chat: {}", chat_url);

        let response = self.client
            .post(&chat_url)
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Ollama API. Please ensure Ollama is running and the URL in config.toml is correct")?;

        if !response.status().is_success() {
            let status = response.status();
            let body = match response.text().await {
                Ok(text) => text,
                Err(e) => format!("<failed to read error response: {}>", e),
            };
            warn!("Ollama API error: {} - {}", status, body);
            bail!("Ollama API returned error: {} - {}", status, body);
        }

        let ollama_response: OllamaChatResponse = response
            .json()
            .await
            .context("Failed to parse Ollama response")?;

        Ok(ollama_response.message.content)
    }

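    /// Returns the chat prefix that marks a player message as an AI request.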
    pub fn get_trigger(&self) -> &str {
        &self.config.trigger
    }
}
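
// Minimal usage sketch (illustrative, not part of this module). Assumes a
// Tokio runtime and that `AiConfig`/`OllamaConfig` have been loaded from
// config.toml; `load_config` is a hypothetical helper, not a real API.
//
//     let (ai_config, ollama_config) = load_config()?;
//     let client = AiClient::new(ai_config, ollama_config)?;
//     let messages = vec![Message {
//         role: "user".to_string(),
//         content: "How do I craft a piston?".to_string(),
//     }];
//     match client.chat(messages, "Steve").await? {
//         ChatResult::Success(reply) => println!("{}", reply),
//         ChatResult::RateLimited(notice) => println!("{}", notice),
//     }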