Skip to main content

mc_minder/ai/
mod.rs

1use anyhow::{Context, Result, bail};
2use log::{debug, warn};
3use reqwest::Client;
4use serde::{Serialize, Deserialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::Semaphore;
9use parking_lot::RwLock;
10use crate::config::{AiConfig, OllamaConfig};
11
/// Upper bound on AI backend requests (OpenAI or Ollama) in flight at once;
/// further callers of `AiClient::chat` wait on the semaphore for a free slot.
const MAX_CONCURRENT_REQUESTS: usize = 3;
/// Minimum number of seconds a single player must wait between accepted
/// `AiClient::chat` requests.
const PLAYER_COOLDOWN_SECS: u64 = 2;
14
15#[derive(Debug, Clone)]
16pub struct AiClient {
17    client: Client,
18    config: AiConfig,
19    ollama_config: Option<OllamaConfig>,
20    semaphore: Arc<Semaphore>,
21    last_request: Arc<RwLock<HashMap<String, Instant>>>,
22}
23
24#[derive(Debug, Serialize)]
25struct ChatRequest {
26    model: String,
27    messages: Vec<Message>,
28    max_tokens: u32,
29    temperature: f32,
30}
31
32#[derive(Debug, Serialize, Deserialize, Clone)]
33pub struct Message {
34    pub role: String,
35    pub content: String,
36}
37
38#[derive(Debug, Deserialize)]
39struct ChatResponse {
40    choices: Vec<Choice>,
41}
42
43#[derive(Debug, Deserialize)]
44struct Choice {
45    message: Message,
46}
47
48#[derive(Debug, Serialize)]
49struct OllamaRequest {
50    model: String,
51    prompt: String,
52    stream: bool,
53}
54
55#[derive(Debug, Deserialize)]
56struct OllamaResponse {
57    response: String,
58}
59
/// Outcome of `AiClient::chat` when no transport or API error occurred.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChatResult {
    /// The backend's reply text.
    Success(String),
    /// The request was not sent because the player is still on cooldown;
    /// carries a notice telling them how long to wait.
    RateLimited(String),
}
64
65impl AiClient {
66    pub fn new(config: AiConfig, ollama_config: Option<OllamaConfig>) -> Result<Self> {
67        let client = Client::builder()
68            .timeout(Duration::from_secs(30))
69            .build()
70            .context("Failed to create HTTP client")?;
71
72        Ok(Self {
73            client,
74            config,
75            ollama_config,
76            semaphore: Arc::new(Semaphore::new(MAX_CONCURRENT_REQUESTS)),
77            last_request: Arc::new(RwLock::new(HashMap::new())),
78        })
79    }
80
81    pub async fn chat(&self, messages: Vec<Message>, player: &str) -> Result<ChatResult> {
82        {
83            let last_requests = self.last_request.read();
84            if let Some(last_time) = last_requests.get(player) {
85                let elapsed = last_time.elapsed();
86                if elapsed < Duration::from_secs(PLAYER_COOLDOWN_SECS) {
87                    let wait_secs = PLAYER_COOLDOWN_SECS - elapsed.as_secs();
88                    return Ok(ChatResult::RateLimited(format!(
89                        "Please wait {} seconds before asking again.",
90                        wait_secs
91                    )));
92                }
93            }
94        }
95
96        let _permit = self.semaphore.acquire().await;
97        
98        {
99            let mut last_requests = self.last_request.write();
100            last_requests.insert(player.to_string(), Instant::now());
101        }
102
103        let result = if let Some(ref ollama) = self.ollama_config {
104            if ollama.enabled {
105                self.chat_ollama(ollama, messages).await
106            } else {
107                self.chat_openai(messages).await
108            }
109        } else {
110            self.chat_openai(messages).await
111        };
112
113        match result {
114            Ok(response) => Ok(ChatResult::Success(response)),
115            Err(e) => {
116                warn!("AI chat error: {}", e);
117                Err(e)
118            }
119        }
120    }
121
122    async fn chat_openai(&self, messages: Vec<Message>) -> Result<String> {
123        let request = ChatRequest {
124            model: self.config.model.clone(),
125            messages,
126            max_tokens: self.config.max_tokens,
127            temperature: self.config.temperature,
128        };
129
130        debug!("Sending request to OpenAI API");
131
132        let response = self.client
133            .post(&self.config.api_url)
134            .header("Authorization", format!("Bearer {}", self.config.api_key))
135            .json(&request)
136            .send()
137            .await
138            .context("Failed to send request to OpenAI API. Please check your network connection and api_url in config.toml")?;
139
140        if !response.status().is_success() {
141            let status = response.status();
142            let body = response.text().await.unwrap_or_default();
143            
144            if status.as_u16() == 401 {
145                bail!(
146                    "OpenAI API authentication failed. \n\
147                     Please check that api_key in config.toml is correct."
148                );
149            } else if status.as_u16() == 429 {
150                bail!(
151                    "OpenAI API rate limit exceeded. Please try again later."
152                );
153            }
154            
155            warn!("OpenAI API error: {} - {}", status, body);
156            bail!("OpenAI API returned error: {}", status);
157        }
158
159        let chat_response: ChatResponse = response
160            .json()
161            .await
162            .context("Failed to parse OpenAI response")?;
163
164        chat_response
165            .choices
166            .first()
167            .map(|c| c.message.content.clone())
168            .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
169    }
170
171    async fn chat_ollama(&self, ollama: &OllamaConfig, messages: Vec<Message>) -> Result<String> {
172        let prompt = messages
173            .iter()
174            .map(|m| format!("{}: {}", m.role, m.content))
175            .collect::<Vec<_>>()
176            .join("\n");
177
178        let request = OllamaRequest {
179            model: ollama.model.clone(),
180            prompt,
181            stream: false,
182        };
183
184        debug!("Sending request to Ollama API");
185
186        let response = self.client
187            .post(&ollama.url)
188            .json(&request)
189            .send()
190            .await
191            .context("Failed to send request to Ollama API. Please ensure Ollama is running and the URL in config.toml is correct")?;
192
193        if !response.status().is_success() {
194            let status = response.status();
195            let body = response.text().await.unwrap_or_default();
196            warn!("Ollama API error: {} - {}", status, body);
197            bail!("Ollama API returned error: {}", status);
198        }
199
200        let ollama_response: OllamaResponse = response
201            .json()
202            .await
203            .context("Failed to parse Ollama response")?;
204
205        Ok(ollama_response.response)
206    }
207
208    pub fn get_trigger(&self) -> &str {
209        &self.config.trigger
210    }
211}