use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::info;
use crate::searcher::LeannSearcher;
use crate::settings;
/// A chat backend that turns a single prompt into a text answer.
///
/// The `Send + Sync` bound allows a `Box<dyn LlmProvider>` to be shared
/// across threads.
pub trait LlmProvider: Send + Sync {
/// Sends `prompt` to the backend and returns the reply text.
///
/// `params` carries optional generation settings; each implementation
/// decides which of them it honours and what to use when one is unset.
///
/// # Errors
/// Returns an error if the backend call fails.
fn ask(&self, prompt: &str, params: &LlmParams) -> Result<String>;
}
/// Optional per-request generation settings passed to [`LlmProvider::ask`].
///
/// `Default` leaves every field unset (`None` / empty map); each provider
/// decides which fields it forwards and what to fall back to.
#[derive(Debug, Clone, Default)]
pub struct LlmParams {
/// Sampling temperature override, if set.
pub temperature: Option<f64>,
/// Token limit for the reply, if set.
pub max_tokens: Option<usize>,
/// `top_p` sampling value, if set.
pub top_p: Option<f64>,
/// Additional key/value settings as arbitrary JSON values.
// NOTE(review): none of the providers defined in this module read `extra`;
// confirm whether it is meant to be forwarded to the backends.
pub extra: HashMap<String, serde_json::Value>,
}
/// Serializable description of which LLM backend to build and how to reach it.
///
/// Consumed by [`get_llm`]. The `type` key selects the provider; the remaining
/// fields are optional and each is used only by some providers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
/// Provider name, (de)serialized under the key `type`. `get_llm` accepts
/// `"ollama"`, `"openai"`, `"anthropic"` and `"simulated"`.
#[serde(rename = "type")]
pub llm_type: String,
/// Model identifier; `get_llm` substitutes a per-provider default when `None`.
#[serde(default)]
pub model: Option<String>,
/// API key for the OpenAI / Anthropic providers; when `None`, resolution is
/// delegated to the `settings` module.
#[serde(default)]
pub api_key: Option<String>,
/// Base URL override for the OpenAI / Anthropic providers, resolved via `settings`.
#[serde(default)]
pub base_url: Option<String>,
/// Server host override for the Ollama provider, resolved via `settings`.
#[serde(default)]
pub host: Option<String>,
}
impl Default for LlmConfig {
fn default() -> Self {
Self {
llm_type: "openai".to_string(),
model: Some("gpt-4o".to_string()),
api_key: None,
base_url: None,
host: None,
}
}
}
pub struct OllamaChat {
model: String,
host: String,
client: reqwest::blocking::Client,
}
impl OllamaChat {
pub fn new(model: &str, host: Option<&str>) -> Self {
Self {
model: model.to_string(),
host: settings::resolve_ollama_host(host),
client: reqwest::blocking::Client::new(),
}
}
}
impl LlmProvider for OllamaChat {
fn ask(&self, prompt: &str, _params: &LlmParams) -> Result<String> {
let payload = serde_json::json!({
"model": self.model,
"prompt": prompt,
"stream": false,
});
let response = self
.client
.post(format!("{}/api/generate", self.host))
.json(&payload)
.send()?;
let body: serde_json::Value = response.json()?;
Ok(body["response"].as_str().unwrap_or("").to_string())
}
}
/// [`LlmProvider`] backed by an OpenAI-compatible `/chat/completions` endpoint.
pub struct OpenAiChat {
    /// Model name sent with every request.
    model: String,
    /// Bearer token for the `Authorization` header.
    api_key: String,
    /// Endpoint base URL, as returned by `settings::resolve_openai_base_url`.
    base_url: String,
    /// Blocking HTTP client reused across requests.
    client: reqwest::blocking::Client,
}

impl OpenAiChat {
    /// Creates a provider for `model`.
    ///
    /// `api_key` and `base_url` are passed to the corresponding `settings`
    /// resolvers, which determine the values actually used.
    ///
    /// # Errors
    /// Returns an error if no API key can be resolved.
    pub fn new(model: &str, api_key: Option<&str>, base_url: Option<&str>) -> Result<Self> {
        let api_key = settings::resolve_openai_api_key(api_key)
            .ok_or_else(|| anyhow::anyhow!("OpenAI API key required"))?;
        let base_url = settings::resolve_openai_base_url(base_url);
        Ok(Self {
            model: model.to_string(),
            api_key,
            base_url,
            client: reqwest::blocking::Client::new(),
        })
    }
}

impl LlmProvider for OpenAiChat {
    /// Sends `prompt` as a single user message and returns the first choice's
    /// content, trimmed.
    ///
    /// `temperature` is always sent (0.7 when unset); `max_tokens` and `top_p`
    /// are sent only when set. `params.extra` is not forwarded.
    ///
    /// # Errors
    /// Fails if the request cannot be sent, if the server replies with a non-2xx
    /// status (the response body is included in the error), or if the body is
    /// not JSON. A successful reply without message content yields `""`.
    fn ask(&self, prompt: &str, params: &LlmParams) -> Result<String> {
        let mut payload = serde_json::json!({
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": params.temperature.unwrap_or(0.7),
        });
        if let Some(max_tokens) = params.max_tokens {
            payload["max_tokens"] = serde_json::json!(max_tokens);
        }
        if let Some(top_p) = params.top_p {
            payload["top_p"] = serde_json::json!(top_p);
        }

        // Tolerate a configured base URL that ends with '/'.
        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
        let response = self
            .client
            .post(url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&payload)
            .send()?;

        // Surface HTTP failures (bad key, unknown model, rate limit, ...) instead
        // of silently returning an empty answer.
        let status = response.status();
        if !status.is_success() {
            let detail = response.text().unwrap_or_default();
            anyhow::bail!("OpenAI request failed with status {}: {}", status, detail);
        }

        let body: serde_json::Value = response.json()?;
        Ok(body["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("")
            .trim()
            .to_string())
    }
}
/// [`LlmProvider`] backed by Anthropic's `/v1/messages` endpoint.
pub struct AnthropicChat {
    /// Model name sent with every request.
    model: String,
    /// Key sent in the `x-api-key` header.
    api_key: String,
    /// Endpoint base URL, as returned by `settings::resolve_anthropic_base_url`.
    base_url: String,
    /// Blocking HTTP client reused across requests.
    client: reqwest::blocking::Client,
}

impl AnthropicChat {
    /// Creates a provider for `model`.
    ///
    /// `api_key` and `base_url` are passed to the corresponding `settings`
    /// resolvers, which determine the values actually used.
    ///
    /// # Errors
    /// Returns an error if no API key can be resolved.
    pub fn new(model: &str, api_key: Option<&str>, base_url: Option<&str>) -> Result<Self> {
        let api_key = settings::resolve_anthropic_api_key(api_key)
            .ok_or_else(|| anyhow::anyhow!("Anthropic API key required"))?;
        let base_url = settings::resolve_anthropic_base_url(base_url);
        Ok(Self {
            model: model.to_string(),
            api_key,
            base_url,
            client: reqwest::blocking::Client::new(),
        })
    }
}

impl LlmProvider for AnthropicChat {
    /// Sends `prompt` as a single user message and returns the text of the
    /// first content block, trimmed.
    ///
    /// `max_tokens` is always sent (1000 when unset); `temperature` and `top_p`
    /// are sent only when set. `params.extra` is not forwarded.
    ///
    /// # Errors
    /// Fails if the request cannot be sent, if the server replies with a non-2xx
    /// status (the response body is included in the error), or if the body is
    /// not JSON. A successful reply without text in the first block yields `""`.
    fn ask(&self, prompt: &str, params: &LlmParams) -> Result<String> {
        let mut payload = serde_json::json!({
            "model": self.model,
            "max_tokens": params.max_tokens.unwrap_or(1000),
            "messages": [{"role": "user", "content": prompt}],
        });
        if let Some(temp) = params.temperature {
            payload["temperature"] = serde_json::json!(temp);
        }
        if let Some(top_p) = params.top_p {
            payload["top_p"] = serde_json::json!(top_p);
        }

        // Tolerate a configured base URL that ends with '/'.
        let url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
        let response = self
            .client
            .post(url)
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&payload)
            .send()?;

        // Surface HTTP failures instead of silently returning an empty answer.
        let status = response.status();
        if !status.is_success() {
            let detail = response.text().unwrap_or_default();
            anyhow::bail!("Anthropic request failed with status {}: {}", status, detail);
        }

        let body: serde_json::Value = response.json()?;
        Ok(body["content"][0]["text"]
            .as_str()
            .unwrap_or("")
            .trim()
            .to_string())
    }
}
/// Offline [`LlmProvider`] that never contacts a backend: it ignores the
/// prompt and parameters and always returns the same fixed answer.
pub struct SimulatedChat;

impl LlmProvider for SimulatedChat {
    fn ask(&self, _prompt: &str, _params: &LlmParams) -> Result<String> {
        const CANNED_ANSWER: &str =
            "This is a simulated answer from the LLM based on the retrieved context.";
        Ok(String::from(CANNED_ANSWER))
    }
}
/// Builds the provider selected by `config.llm_type`.
///
/// Default models when `config.model` is `None`: `llama3:8b` (ollama),
/// `gpt-4o` (openai), `claude-3-5-sonnet-20241022` (anthropic).
/// `"simulated"` uses no other config fields.
///
/// # Errors
/// Returns an error for an unrecognised `llm_type`, or when the OpenAI /
/// Anthropic constructor fails (e.g. no API key can be resolved).
pub fn get_llm(config: &LlmConfig) -> Result<Box<dyn LlmProvider>> {
    let model = config.model.as_deref();
    let api_key = config.api_key.as_deref();
    let base_url = config.base_url.as_deref();

    let provider: Box<dyn LlmProvider> = match config.llm_type.as_str() {
        "ollama" => {
            let chosen = model.unwrap_or("llama3:8b");
            Box::new(OllamaChat::new(chosen, config.host.as_deref()))
        }
        "openai" => {
            let chosen = model.unwrap_or("gpt-4o");
            Box::new(OpenAiChat::new(chosen, api_key, base_url)?)
        }
        "anthropic" => {
            let chosen = model.unwrap_or("claude-3-5-sonnet-20241022");
            Box::new(AnthropicChat::new(chosen, api_key, base_url)?)
        }
        "simulated" => Box::new(SimulatedChat),
        other => return Err(anyhow::anyhow!("Unknown LLM type: {}", other)),
    };
    Ok(provider)
}
pub struct LeannChat {
searcher: LeannSearcher,
llm: Box<dyn LlmProvider>,
#[allow(dead_code)]
owns_searcher: bool,
}
impl LeannChat {
pub fn new(searcher: LeannSearcher, llm_config: Option<&LlmConfig>) -> Result<Self> {
let config = llm_config.cloned().unwrap_or_default();
let llm = get_llm(&config)?;
Ok(Self {
searcher,
llm,
owns_searcher: true,
})
}
pub fn ask(&self, question: &str, top_k: usize) -> Result<String> {
let results = self.searcher.search(question, top_k)?;
let context: String = results
.iter()
.map(|r| r.text.as_str())
.collect::<Vec<_>>()
.join("\n\n");
let prompt = format!(
"Here is some retrieved context that might help answer your question:\n\n\
{}\n\n\
Question: {}\n\n\
Please provide the best answer you can based on this context and your knowledge.",
context, question
);
info!(
"Sending RAG prompt to LLM ({} context results)",
results.len()
);
let answer = self.llm.ask(&prompt, &LlmParams::default())?;
Ok(answer)
}
pub fn cleanup(&mut self) {
}
}