1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use tracing::info;
5
6use crate::searcher::{LeannSearcher, SearcherOptions};
7use crate::settings;
8
9pub trait LlmProvider: Send + Sync {
11 fn ask(&self, prompt: &str, params: &LlmParams) -> Result<String>;
13}
14
15#[derive(Debug, Clone, Default)]
17pub struct LlmParams {
18 pub temperature: Option<f64>,
19 pub max_tokens: Option<usize>,
20 pub top_p: Option<f64>,
21 pub extra: HashMap<String, serde_json::Value>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct LlmConfig {
27 #[serde(rename = "type")]
28 pub llm_type: String,
29 #[serde(default)]
30 pub model: Option<String>,
31 #[serde(default)]
32 pub api_key: Option<String>,
33 #[serde(default)]
34 pub base_url: Option<String>,
35 #[serde(default)]
36 pub host: Option<String>,
37}
38
39impl Default for LlmConfig {
40 fn default() -> Self {
41 Self {
42 llm_type: "openai".to_string(),
43 model: Some("gpt-4o".to_string()),
44 api_key: None,
45 base_url: None,
46 host: None,
47 }
48 }
49}
50
51pub struct OllamaChat {
53 model: String,
54 host: String,
55 client: reqwest::blocking::Client,
56}
57
58impl OllamaChat {
59 pub fn new(model: &str, host: Option<&str>) -> Self {
60 Self {
61 model: model.to_string(),
62 host: settings::resolve_ollama_host(host),
63 client: reqwest::blocking::Client::new(),
64 }
65 }
66}
67
68impl LlmProvider for OllamaChat {
69 fn ask(&self, prompt: &str, _params: &LlmParams) -> Result<String> {
70 let payload = serde_json::json!({
71 "model": self.model,
72 "prompt": prompt,
73 "stream": false,
74 });
75
76 let response = self
77 .client
78 .post(format!("{}/api/generate", self.host))
79 .json(&payload)
80 .send()?;
81
82 let body: serde_json::Value = response.json()?;
83 Ok(body["response"].as_str().unwrap_or("").to_string())
84 }
85}
86
87pub struct OpenAiChat {
89 model: String,
90 api_key: String,
91 base_url: String,
92 client: reqwest::blocking::Client,
93}
94
95impl OpenAiChat {
96 pub fn new(model: &str, api_key: Option<&str>, base_url: Option<&str>) -> Result<Self> {
97 let api_key = settings::resolve_openai_api_key(api_key)
98 .ok_or_else(|| anyhow::anyhow!("OpenAI API key required"))?;
99 let base_url = settings::resolve_openai_base_url(base_url);
100
101 Ok(Self {
102 model: model.to_string(),
103 api_key,
104 base_url,
105 client: reqwest::blocking::Client::new(),
106 })
107 }
108}
109
110impl LlmProvider for OpenAiChat {
111 fn ask(&self, prompt: &str, params: &LlmParams) -> Result<String> {
112 let mut payload = serde_json::json!({
113 "model": self.model,
114 "messages": [{"role": "user", "content": prompt}],
115 "temperature": params.temperature.unwrap_or(0.7),
116 });
117
118 if let Some(max_tokens) = params.max_tokens {
119 payload["max_tokens"] = serde_json::json!(max_tokens);
120 }
121
122 let response = self
123 .client
124 .post(format!("{}/chat/completions", self.base_url))
125 .header("Authorization", format!("Bearer {}", self.api_key))
126 .json(&payload)
127 .send()?;
128
129 let body: serde_json::Value = response.json()?;
130 Ok(body["choices"][0]["message"]["content"]
131 .as_str()
132 .unwrap_or("")
133 .trim()
134 .to_string())
135 }
136}
137
138pub struct AnthropicChat {
140 model: String,
141 api_key: String,
142 base_url: String,
143 client: reqwest::blocking::Client,
144}
145
146impl AnthropicChat {
147 pub fn new(model: &str, api_key: Option<&str>, base_url: Option<&str>) -> Result<Self> {
148 let api_key = settings::resolve_anthropic_api_key(api_key)
149 .ok_or_else(|| anyhow::anyhow!("Anthropic API key required"))?;
150 let base_url = settings::resolve_anthropic_base_url(base_url);
151
152 Ok(Self {
153 model: model.to_string(),
154 api_key,
155 base_url,
156 client: reqwest::blocking::Client::new(),
157 })
158 }
159}
160
161impl LlmProvider for AnthropicChat {
162 fn ask(&self, prompt: &str, params: &LlmParams) -> Result<String> {
163 let mut payload = serde_json::json!({
164 "model": self.model,
165 "max_tokens": params.max_tokens.unwrap_or(1000),
166 "messages": [{"role": "user", "content": prompt}],
167 });
168
169 if let Some(temp) = params.temperature {
170 payload["temperature"] = serde_json::json!(temp);
171 }
172
173 let response = self
174 .client
175 .post(format!("{}/v1/messages", self.base_url))
176 .header("x-api-key", &self.api_key)
177 .header("anthropic-version", "2023-06-01")
178 .header("content-type", "application/json")
179 .json(&payload)
180 .send()?;
181
182 let body: serde_json::Value = response.json()?;
183 Ok(body["content"][0]["text"]
184 .as_str()
185 .unwrap_or("")
186 .trim()
187 .to_string())
188 }
189}
190
191pub struct GeminiChat {
193 model: String,
194 api_key: String,
195 client: reqwest::blocking::Client,
196}
197
198impl GeminiChat {
199 pub fn new(model: &str, api_key: Option<&str>) -> Result<Self> {
200 let api_key = settings::resolve_gemini_api_key(api_key)
201 .ok_or_else(|| anyhow::anyhow!("Gemini API key required. Set GEMINI_API_KEY environment variable or pass api_key parameter."))?;
202
203 Ok(Self {
204 model: model.to_string(),
205 api_key,
206 client: reqwest::blocking::Client::new(),
207 })
208 }
209}
210
211impl LlmProvider for GeminiChat {
212 fn ask(&self, prompt: &str, params: &LlmParams) -> Result<String> {
213 let mut generation_config = serde_json::json!({
214 "temperature": params.temperature.unwrap_or(0.7),
215 "maxOutputTokens": params.max_tokens.unwrap_or(1000),
216 });
217
218 if let Some(top_p) = params.top_p {
219 generation_config["topP"] = serde_json::json!(top_p);
220 }
221
222 let payload = serde_json::json!({
223 "contents": [{"parts": [{"text": prompt}]}],
224 "generationConfig": generation_config,
225 });
226
227 let url = format!(
228 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
229 self.model, self.api_key
230 );
231
232 let response = self.client.post(&url).json(&payload).send()?;
233
234 let body: serde_json::Value = response.json()?;
235 Ok(body["candidates"][0]["content"]["parts"][0]["text"]
236 .as_str()
237 .unwrap_or("")
238 .trim()
239 .to_string())
240 }
241}
242
/// Offline stand-in [`LlmProvider`] that returns a fixed answer without
/// making any network request.
pub struct SimulatedChat;
245
246impl LlmProvider for SimulatedChat {
247 fn ask(&self, _prompt: &str, _params: &LlmParams) -> Result<String> {
248 Ok("This is a simulated answer from the LLM based on the retrieved context.".to_string())
249 }
250}
251
252pub fn get_llm(config: &LlmConfig) -> Result<Box<dyn LlmProvider>> {
254 match config.llm_type.as_str() {
255 "ollama" => Ok(Box::new(OllamaChat::new(
256 config.model.as_deref().unwrap_or("llama3:8b"),
257 config.host.as_deref(),
258 ))),
259 "openai" => Ok(Box::new(OpenAiChat::new(
260 config.model.as_deref().unwrap_or("gpt-4o"),
261 config.api_key.as_deref(),
262 config.base_url.as_deref(),
263 )?)),
264 "anthropic" => Ok(Box::new(AnthropicChat::new(
265 config
266 .model
267 .as_deref()
268 .unwrap_or("claude-3-5-sonnet-20241022"),
269 config.api_key.as_deref(),
270 config.base_url.as_deref(),
271 )?)),
272 "gemini" => Ok(Box::new(GeminiChat::new(
273 config.model.as_deref().unwrap_or("gemini-2.5-flash"),
274 config.api_key.as_deref(),
275 )?)),
276 "simulated" => Ok(Box::new(SimulatedChat)),
277 other => anyhow::bail!("Unknown LLM type: {}", other),
278 }
279}
280
281pub struct LeannChat {
283 searcher: LeannSearcher,
284 llm: Box<dyn LlmProvider>,
285 #[allow(dead_code)]
286 owns_searcher: bool,
287}
288
289impl LeannChat {
290 pub fn new(searcher: LeannSearcher, llm_config: Option<&LlmConfig>) -> Result<Self> {
291 let config = llm_config.cloned().unwrap_or_default();
292 let llm = get_llm(&config)?;
293
294 Ok(Self {
295 searcher,
296 llm,
297 owns_searcher: true,
298 })
299 }
300
301 pub fn new_with_options(
303 index_path: &std::path::Path,
304 llm_config: Option<&LlmConfig>,
305 searcher_options: &SearcherOptions,
306 ) -> Result<Self> {
307 let searcher = LeannSearcher::open_with_options(index_path, searcher_options)?;
308 Self::new(searcher, llm_config)
309 }
310
311 pub fn ask(&self, question: &str, top_k: usize) -> Result<String> {
313 self.ask_with_params(
314 question,
315 top_k,
316 &crate::searcher::SearchConfig::default(),
317 &LlmParams::default(),
318 )
319 }
320
321 pub fn ask_with_params(
323 &self,
324 question: &str,
325 top_k: usize,
326 config: &crate::searcher::SearchConfig,
327 llm_params: &LlmParams,
328 ) -> Result<String> {
329 let results = self.searcher.search_with_params(question, top_k, config)?;
330
331 let context: String = results
332 .iter()
333 .map(|r| r.text.as_str())
334 .collect::<Vec<_>>()
335 .join("\n\n");
336
337 let prompt = format!(
338 "Here is some retrieved context that might help answer your question:\n\n\
339 {}\n\n\
340 Question: {}\n\n\
341 Please provide the best answer you can based on this context and your knowledge.",
342 context, question
343 );
344
345 info!(
346 "Sending RAG prompt to LLM ({} context results)",
347 results.len()
348 );
349 let answer = self.llm.ask(&prompt, llm_params)?;
350 Ok(answer)
351 }
352
353 pub fn cleanup(&mut self) {
354 self.searcher.cleanup();
355 }
356}
357
358impl Drop for LeannChat {
359 fn drop(&mut self) {
360 self.cleanup();
361 }
362}