Skip to main content

leann_core/
chat.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use tracing::info;
5
6use crate::searcher::{LeannSearcher, SearcherOptions};
7use crate::settings;
8
9/// Trait for LLM chat backends.
10pub trait LlmProvider: Send + Sync {
11    /// Send a prompt and get a response.
12    fn ask(&self, prompt: &str, params: &LlmParams) -> Result<String>;
13}
14
15/// Parameters for LLM generation.
16#[derive(Debug, Clone, Default)]
17pub struct LlmParams {
18    pub temperature: Option<f64>,
19    pub max_tokens: Option<usize>,
20    pub top_p: Option<f64>,
21    pub extra: HashMap<String, serde_json::Value>,
22}
23
24/// Configuration for creating an LLM provider.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct LlmConfig {
27    #[serde(rename = "type")]
28    pub llm_type: String,
29    #[serde(default)]
30    pub model: Option<String>,
31    #[serde(default)]
32    pub api_key: Option<String>,
33    #[serde(default)]
34    pub base_url: Option<String>,
35    #[serde(default)]
36    pub host: Option<String>,
37}
38
39impl Default for LlmConfig {
40    fn default() -> Self {
41        Self {
42            llm_type: "openai".to_string(),
43            model: Some("gpt-4o".to_string()),
44            api_key: None,
45            base_url: None,
46            host: None,
47        }
48    }
49}
50
51/// Ollama LLM chat backend.
52pub struct OllamaChat {
53    model: String,
54    host: String,
55    client: reqwest::blocking::Client,
56}
57
58impl OllamaChat {
59    pub fn new(model: &str, host: Option<&str>) -> Self {
60        Self {
61            model: model.to_string(),
62            host: settings::resolve_ollama_host(host),
63            client: reqwest::blocking::Client::new(),
64        }
65    }
66}
67
68impl LlmProvider for OllamaChat {
69    fn ask(&self, prompt: &str, _params: &LlmParams) -> Result<String> {
70        let payload = serde_json::json!({
71            "model": self.model,
72            "prompt": prompt,
73            "stream": false,
74        });
75
76        let response = self
77            .client
78            .post(format!("{}/api/generate", self.host))
79            .json(&payload)
80            .send()?;
81
82        let body: serde_json::Value = response.json()?;
83        Ok(body["response"].as_str().unwrap_or("").to_string())
84    }
85}
86
87/// OpenAI LLM chat backend.
88pub struct OpenAiChat {
89    model: String,
90    api_key: String,
91    base_url: String,
92    client: reqwest::blocking::Client,
93}
94
95impl OpenAiChat {
96    pub fn new(model: &str, api_key: Option<&str>, base_url: Option<&str>) -> Result<Self> {
97        let api_key = settings::resolve_openai_api_key(api_key)
98            .ok_or_else(|| anyhow::anyhow!("OpenAI API key required"))?;
99        let base_url = settings::resolve_openai_base_url(base_url);
100
101        Ok(Self {
102            model: model.to_string(),
103            api_key,
104            base_url,
105            client: reqwest::blocking::Client::new(),
106        })
107    }
108}
109
110impl LlmProvider for OpenAiChat {
111    fn ask(&self, prompt: &str, params: &LlmParams) -> Result<String> {
112        let mut payload = serde_json::json!({
113            "model": self.model,
114            "messages": [{"role": "user", "content": prompt}],
115            "temperature": params.temperature.unwrap_or(0.7),
116        });
117
118        if let Some(max_tokens) = params.max_tokens {
119            payload["max_tokens"] = serde_json::json!(max_tokens);
120        }
121
122        let response = self
123            .client
124            .post(format!("{}/chat/completions", self.base_url))
125            .header("Authorization", format!("Bearer {}", self.api_key))
126            .json(&payload)
127            .send()?;
128
129        let body: serde_json::Value = response.json()?;
130        Ok(body["choices"][0]["message"]["content"]
131            .as_str()
132            .unwrap_or("")
133            .trim()
134            .to_string())
135    }
136}
137
138/// Anthropic LLM chat backend.
139pub struct AnthropicChat {
140    model: String,
141    api_key: String,
142    base_url: String,
143    client: reqwest::blocking::Client,
144}
145
146impl AnthropicChat {
147    pub fn new(model: &str, api_key: Option<&str>, base_url: Option<&str>) -> Result<Self> {
148        let api_key = settings::resolve_anthropic_api_key(api_key)
149            .ok_or_else(|| anyhow::anyhow!("Anthropic API key required"))?;
150        let base_url = settings::resolve_anthropic_base_url(base_url);
151
152        Ok(Self {
153            model: model.to_string(),
154            api_key,
155            base_url,
156            client: reqwest::blocking::Client::new(),
157        })
158    }
159}
160
161impl LlmProvider for AnthropicChat {
162    fn ask(&self, prompt: &str, params: &LlmParams) -> Result<String> {
163        let mut payload = serde_json::json!({
164            "model": self.model,
165            "max_tokens": params.max_tokens.unwrap_or(1000),
166            "messages": [{"role": "user", "content": prompt}],
167        });
168
169        if let Some(temp) = params.temperature {
170            payload["temperature"] = serde_json::json!(temp);
171        }
172
173        let response = self
174            .client
175            .post(format!("{}/v1/messages", self.base_url))
176            .header("x-api-key", &self.api_key)
177            .header("anthropic-version", "2023-06-01")
178            .header("content-type", "application/json")
179            .json(&payload)
180            .send()?;
181
182        let body: serde_json::Value = response.json()?;
183        Ok(body["content"][0]["text"]
184            .as_str()
185            .unwrap_or("")
186            .trim()
187            .to_string())
188    }
189}
190
191/// Google Gemini LLM chat backend.
192pub struct GeminiChat {
193    model: String,
194    api_key: String,
195    client: reqwest::blocking::Client,
196}
197
198impl GeminiChat {
199    pub fn new(model: &str, api_key: Option<&str>) -> Result<Self> {
200        let api_key = settings::resolve_gemini_api_key(api_key)
201            .ok_or_else(|| anyhow::anyhow!("Gemini API key required. Set GEMINI_API_KEY environment variable or pass api_key parameter."))?;
202
203        Ok(Self {
204            model: model.to_string(),
205            api_key,
206            client: reqwest::blocking::Client::new(),
207        })
208    }
209}
210
211impl LlmProvider for GeminiChat {
212    fn ask(&self, prompt: &str, params: &LlmParams) -> Result<String> {
213        let mut generation_config = serde_json::json!({
214            "temperature": params.temperature.unwrap_or(0.7),
215            "maxOutputTokens": params.max_tokens.unwrap_or(1000),
216        });
217
218        if let Some(top_p) = params.top_p {
219            generation_config["topP"] = serde_json::json!(top_p);
220        }
221
222        let payload = serde_json::json!({
223            "contents": [{"parts": [{"text": prompt}]}],
224            "generationConfig": generation_config,
225        });
226
227        let url = format!(
228            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
229            self.model, self.api_key
230        );
231
232        let response = self.client.post(&url).json(&payload).send()?;
233
234        let body: serde_json::Value = response.json()?;
235        Ok(body["candidates"][0]["content"]["parts"][0]["text"]
236            .as_str()
237            .unwrap_or("")
238            .trim()
239            .to_string())
240    }
241}
242
/// Stand-in backend for tests: a stateless unit type whose provider
/// implementation answers with a fixed string and performs no network I/O.
pub struct SimulatedChat;
245
246impl LlmProvider for SimulatedChat {
247    fn ask(&self, _prompt: &str, _params: &LlmParams) -> Result<String> {
248        Ok("This is a simulated answer from the LLM based on the retrieved context.".to_string())
249    }
250}
251
252/// Factory function to create an LLM provider from configuration.
253pub fn get_llm(config: &LlmConfig) -> Result<Box<dyn LlmProvider>> {
254    match config.llm_type.as_str() {
255        "ollama" => Ok(Box::new(OllamaChat::new(
256            config.model.as_deref().unwrap_or("llama3:8b"),
257            config.host.as_deref(),
258        ))),
259        "openai" => Ok(Box::new(OpenAiChat::new(
260            config.model.as_deref().unwrap_or("gpt-4o"),
261            config.api_key.as_deref(),
262            config.base_url.as_deref(),
263        )?)),
264        "anthropic" => Ok(Box::new(AnthropicChat::new(
265            config
266                .model
267                .as_deref()
268                .unwrap_or("claude-3-5-sonnet-20241022"),
269            config.api_key.as_deref(),
270            config.base_url.as_deref(),
271        )?)),
272        "gemini" => Ok(Box::new(GeminiChat::new(
273            config.model.as_deref().unwrap_or("gemini-2.5-flash"),
274            config.api_key.as_deref(),
275        )?)),
276        "simulated" => Ok(Box::new(SimulatedChat)),
277        other => anyhow::bail!("Unknown LLM type: {}", other),
278    }
279}
280
281/// High-level RAG chat interface combining search + LLM.
282pub struct LeannChat {
283    searcher: LeannSearcher,
284    llm: Box<dyn LlmProvider>,
285    #[allow(dead_code)]
286    owns_searcher: bool,
287}
288
289impl LeannChat {
290    pub fn new(searcher: LeannSearcher, llm_config: Option<&LlmConfig>) -> Result<Self> {
291        let config = llm_config.cloned().unwrap_or_default();
292        let llm = get_llm(&config)?;
293
294        Ok(Self {
295            searcher,
296            llm,
297            owns_searcher: true,
298        })
299    }
300
301    /// Create a new chat from an index path with custom searcher options.
302    pub fn new_with_options(
303        index_path: &std::path::Path,
304        llm_config: Option<&LlmConfig>,
305        searcher_options: &SearcherOptions,
306    ) -> Result<Self> {
307        let searcher = LeannSearcher::open_with_options(index_path, searcher_options)?;
308        Self::new(searcher, llm_config)
309    }
310
311    /// Ask a question using RAG (retrieve context, then generate answer).
312    pub fn ask(&self, question: &str, top_k: usize) -> Result<String> {
313        self.ask_with_params(
314            question,
315            top_k,
316            &crate::searcher::SearchConfig::default(),
317            &LlmParams::default(),
318        )
319    }
320
321    /// Ask a question using RAG with full search and LLM configuration.
322    pub fn ask_with_params(
323        &self,
324        question: &str,
325        top_k: usize,
326        config: &crate::searcher::SearchConfig,
327        llm_params: &LlmParams,
328    ) -> Result<String> {
329        let results = self.searcher.search_with_params(question, top_k, config)?;
330
331        let context: String = results
332            .iter()
333            .map(|r| r.text.as_str())
334            .collect::<Vec<_>>()
335            .join("\n\n");
336
337        let prompt = format!(
338            "Here is some retrieved context that might help answer your question:\n\n\
339             {}\n\n\
340             Question: {}\n\n\
341             Please provide the best answer you can based on this context and your knowledge.",
342            context, question
343        );
344
345        info!(
346            "Sending RAG prompt to LLM ({} context results)",
347            results.len()
348        );
349        let answer = self.llm.ask(&prompt, llm_params)?;
350        Ok(answer)
351    }
352
353    pub fn cleanup(&mut self) {
354        self.searcher.cleanup();
355    }
356}
357
358impl Drop for LeannChat {
359    fn drop(&mut self) {
360        self.cleanup();
361    }
362}