use std::fs::{self, File};
use std::io::{IsTerminal, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use anyhow::{Context, Result, bail};
use indicatif::{ProgressBar, ProgressStyle};
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel};
use regex::Regex;
use crate::config;
pub const DEFAULT_EMBED_MODEL: &str = "embeddinggemma-300M-Q8_0.gguf";
pub const DEFAULT_RERANK_MODEL: &str = "qwen3-reranker-0.6b-q8_0.gguf";
pub const DEFAULT_GENERATE_MODEL: &str = "qmd-query-expansion-1.7B-q4_k_m.gguf";
pub const DEFAULT_EMBED_MODEL_URI: &str =
"hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf";
pub const DEFAULT_RERANK_MODEL_URI: &str =
"hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf";
pub const DEFAULT_GENERATE_MODEL_URI: &str =
"hf:tobil/qmd-query-expansion-1.7B-gguf/qmd-query-expansion-1.7B-q4_k_m.gguf";
pub const CHUNK_SIZE_TOKENS: usize = 800;
pub const CHUNK_OVERLAP_TOKENS: usize = 120;
pub const CHUNK_SIZE_CHARS: usize = 3200;
pub const CHUNK_OVERLAP_CHARS: usize = 480;
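/// Terminal progress reporting via the ConEmu/Windows Terminal OSC 9;4
/// sequence (`ESC ] 9 ; 4 ; state ; percent BEL`): state 1 sets a percentage,
/// 0 clears it, 2 signals an error, and 3 is indeterminate. Sequences are
/// emitted only when stderr is a terminal.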
#[derive(Debug, Default, Clone, Copy)]
pub struct Progress {
enabled: bool,
}
impl Progress {
#[must_use]
pub fn new() -> Self {
Self {
enabled: std::io::stderr().is_terminal(),
}
}
pub fn set(&self, percent: f64) {
if self.enabled {
eprint!("\x1b]9;4;1;{}\x07", percent.round() as u8);
}
}
pub fn clear(&self) {
if self.enabled {
eprint!("\x1b]9;4;0\x07");
}
}
pub fn indeterminate(&self) {
if self.enabled {
eprint!("\x1b]9;4;3\x07");
}
}
pub fn error(&self) {
if self.enabled {
eprint!("\x1b]9;4;2\x07");
}
}
}
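/// Hides and shows the terminal cursor with the DECTCEM escape sequences
/// (`CSI ?25l` / `CSI ?25h`), again only when stderr is a terminal.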
#[derive(Debug, Clone, Copy)]
pub struct Cursor;
impl Cursor {
pub fn hide() {
if std::io::stderr().is_terminal() {
eprint!("\x1b[?25l");
}
}
pub fn show() {
if std::io::stderr().is_terminal() {
eprint!("\x1b[?25h");
}
}
}
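/// Formats a duration in seconds as a compact ETA: `format_eta(125.0)`
/// yields `"2m 5s"`, `format_eta(3725.0)` yields `"1h 2m"`.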
#[must_use]
pub fn format_eta(seconds: f64) -> String {
if seconds < 60.0 {
format!("{}s", seconds.round() as u64)
} else if seconds < 3600.0 {
format!(
"{}m {}s",
(seconds / 60.0) as u64,
(seconds % 60.0).round() as u64
)
} else {
format!(
"{}h {}m",
(seconds / 3600.0) as u64,
((seconds % 3600.0) / 60.0).round() as u64
)
}
}
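/// Renders a fixed-width bar of filled/empty block characters:
/// `render_progress_bar(50.0, 10)` yields `"█████░░░░░"`.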
#[must_use]
pub fn render_progress_bar(percent: f64, width: usize) -> String {
let filled = ((percent / 100.0) * width as f64).round() as usize;
let empty = width.saturating_sub(filled);
format!("{}{}", "â–ˆ".repeat(filled), "â–‘".repeat(empty))
}
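/// Wraps a llama.cpp embedding model loaded through `llama_cpp_2`.
/// `dimensions` is discovered lazily from the first embedding produced.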
pub struct EmbeddingEngine {
backend: LlamaBackend,
model: Arc<LlamaModel>,
dimensions: Option<usize>,
}
impl std::fmt::Debug for EmbeddingEngine {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EmbeddingEngine")
.field("dimensions", &self.dimensions)
.finish()
}
}
#[derive(Debug, Clone)]
pub struct EmbeddingResult {
pub embedding: Vec<f32>,
pub model: String,
}
#[derive(Debug, Clone)]
pub struct Chunk {
pub text: String,
pub pos: usize,
}
impl EmbeddingEngine {
pub fn new(model_path: &Path) -> Result<Self> {
let backend = LlamaBackend::init()?;
let model_params = LlamaModelParams::default();
let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
.with_context(|| format!("Failed to load model from {}", model_path.display()))?;
Ok(Self {
backend,
model: Arc::new(model),
dimensions: None,
})
}
pub fn load_default() -> Result<Self> {
let model_path = get_model_path(DEFAULT_EMBED_MODEL)?;
Self::new(&model_path)
}
pub fn embed(&mut self, text: &str) -> Result<EmbeddingResult> {
let formatted = format_doc_for_embedding(text, None);
self.embed_raw(&formatted)
}
pub fn embed_document(&mut self, text: &str, title: Option<&str>) -> Result<EmbeddingResult> {
let formatted = format_doc_for_embedding(text, title);
self.embed_raw(&formatted)
}
pub fn embed_query(&mut self, query: &str) -> Result<EmbeddingResult> {
let formatted = format_query_for_embedding(query);
self.embed_raw(&formatted)
}
pub fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<EmbeddingResult>> {
texts.iter().map(|text| self.embed(text)).collect()
}
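/// Tokenizes the input, decodes it in a fresh context sized to the token
/// count plus headroom (minimum 512), and reads back the pooled sequence
/// embedding. Logits are requested only for the final token.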
fn embed_raw(&mut self, text: &str) -> Result<EmbeddingResult> {
let tokens = self
.model
.str_to_token(text, AddBos::Always)
.context("Failed to tokenize text")?;
if tokens.is_empty() {
bail!("Empty token sequence");
}
let n_ctx = std::cmp::max(tokens.len() + 64, 512);
let ctx_params = LlamaContextParams::default()
.with_embeddings(true)
.with_n_ctx(std::num::NonZero::new(n_ctx as u32))
.with_n_batch(n_ctx as u32)
.with_n_ubatch(n_ctx as u32);
let mut ctx = self
.model
.new_context(&self.backend, ctx_params)
.context("Failed to create context")?;
let mut batch = LlamaBatch::new(tokens.len(), 1);
for (i, token) in tokens.iter().enumerate() {
let is_last = i == tokens.len() - 1;
batch.add(*token, i as i32, &[0], is_last)?;
}
ctx.decode(&mut batch).context("Failed to decode batch")?;
let embeddings = ctx
.embeddings_seq_ith(0)
.context("Failed to get embeddings")?;
if self.dimensions.is_none() {
self.dimensions = Some(embeddings.len());
}
Ok(EmbeddingResult {
embedding: embeddings.to_vec(),
model: DEFAULT_EMBED_MODEL.to_string(),
})
}
#[must_use]
pub const fn dimensions(&self) -> Option<usize> {
self.dimensions
}
pub fn count_tokens(&self, text: &str) -> Result<usize> {
let tokens = self
.model
.str_to_token(text, AddBos::Never)
.context("Failed to tokenize")?;
Ok(tokens.len())
}
pub fn tokenize(&self, text: &str) -> Result<Vec<i32>> {
let tokens = self
.model
.str_to_token(text, AddBos::Never)
.context("Failed to tokenize")?;
Ok(tokens.iter().map(|t| t.0).collect())
}
pub fn embed_batch_with_progress<F>(
&mut self,
items: &[(String, Option<String>)], mut on_progress: F,
) -> Vec<Result<EmbeddingResult>>
where
F: FnMut(usize, usize),
{
let total = items.len();
items
.iter()
.enumerate()
.map(|(i, (text, title))| {
on_progress(i, total);
self.embed_document(text, title.as_deref())
})
.collect()
}
}
#[must_use]
pub fn format_doc_for_embedding(text: &str, title: Option<&str>) -> String {
let title_str = title.unwrap_or("none");
format!("title: {title_str} | text: {text}")
}
#[must_use]
pub fn format_query_for_embedding(query: &str) -> String {
format!("task: search result | query: {query}")
}
pub fn get_model_path(model_name: &str) -> Result<PathBuf> {
let cache_dir = config::get_model_cache_dir();
let model_path = cache_dir.join(model_name);
if !model_path.exists() {
bail!(
"Model not found: {}. Run 'qmd models pull' to download models.",
model_path.display()
);
}
Ok(model_path)
}
#[must_use]
pub fn model_exists(model_name: &str) -> bool {
let cache_dir = config::get_model_cache_dir();
cache_dir.join(model_name).exists()
}
#[must_use]
pub fn list_cached_models() -> Vec<String> {
let cache_dir = config::get_model_cache_dir();
if !cache_dir.exists() {
return Vec::new();
}
fs::read_dir(&cache_dir)
.map(|entries| {
entries
.filter_map(Result::ok)
.filter(|e| e.path().extension().is_some_and(|ext| ext == "gguf"))
.filter_map(|e| e.file_name().into_string().ok())
.collect()
})
.unwrap_or_default()
}
#[derive(Debug, Clone)]
pub struct TokenChunk {
pub text: String,
pub pos: usize,
pub tokens: usize,
pub bytes: usize,
}
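/// Greedy paragraph-packing chunker: paragraphs (split on blank lines) are
/// appended until the next one would exceed `max_tokens`; the chunk is then
/// flushed and the next one is seeded with roughly `overlap_tokens` of
/// trailing text. The two-token separator cost is an approximation of how
/// "\n\n" tokenizes.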
pub fn chunk_document_by_tokens(
engine: &EmbeddingEngine,
content: &str,
max_tokens: usize,
overlap_tokens: usize,
) -> Result<Vec<TokenChunk>> {
let total_tokens = engine.count_tokens(content)?;
if total_tokens <= max_tokens {
return Ok(vec![TokenChunk {
text: content.to_string(),
pos: 0,
tokens: total_tokens,
bytes: content.len(),
}]);
}
let mut chunks = Vec::new();
let paragraphs: Vec<&str> = content.split("\n\n").collect();
let mut current_chunk = String::new();
let mut current_tokens = 0usize;
let mut chunk_start_pos = 0usize;
let mut char_pos = 0usize;
for (para_idx, para) in paragraphs.iter().enumerate() {
let para_tokens = engine.count_tokens(para)?;
let para_with_sep = if para_idx > 0 {
format!("\n\n{para}")
} else {
(*para).to_string()
};
let sep_tokens = if para_idx > 0 { 2 } else { 0 };
if current_tokens + para_tokens + sep_tokens > max_tokens && !current_chunk.is_empty() {
let chunk_bytes = current_chunk.len();
chunks.push(TokenChunk {
text: current_chunk.clone(),
pos: chunk_start_pos,
tokens: current_tokens,
bytes: chunk_bytes,
});
let overlap_text = get_overlap_text(&current_chunk, overlap_tokens, engine)?;
current_chunk = overlap_text;
current_tokens = engine.count_tokens(&current_chunk)?;
chunk_start_pos = char_pos.saturating_sub(current_chunk.len());
}
if !current_chunk.is_empty() {
current_chunk.push_str("\n\n");
}
current_chunk.push_str(para);
current_tokens += para_tokens + sep_tokens;
char_pos += para_with_sep.len();
}
if !current_chunk.is_empty() {
chunks.push(TokenChunk {
text: current_chunk.clone(),
pos: chunk_start_pos,
tokens: current_tokens,
bytes: current_chunk.len(),
});
}
Ok(chunks)
}
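/// Picks overlap text from the tail of a chunk: prefer starting at the last
/// paragraph break within the final ~20%; otherwise fall back to roughly
/// `target_tokens / 2` trailing words (a crude words-to-tokens assumption).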
fn get_overlap_text(text: &str, target_tokens: usize, engine: &EmbeddingEngine) -> Result<String> {
// Start roughly 80% into the text, snapped forward to a char boundary so
// byte slicing cannot panic on multi-byte UTF-8.
let mut start_frac = text.len() * 4 / 5;
while !text.is_char_boundary(start_frac) {
start_frac += 1;
}
let candidate = &text[start_frac..];
if let Some(pos) = candidate.find("\n\n") {
let overlap = &candidate[pos + 2..];
let tokens = engine.count_tokens(overlap)?;
if tokens <= target_tokens * 2 {
return Ok(overlap.to_string());
}
}
let words: Vec<&str> = candidate.split_whitespace().collect();
let mut result = String::new();
for word in words.iter().rev().take(target_tokens / 2) {
if !result.is_empty() {
result = format!("{word} {result}");
} else {
result = (*word).to_string();
}
}
Ok(result)
}
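/// Tokenizer-free fallback chunker: a sliding byte window of `max_chars`
/// with `overlap_chars` of overlap, cutting at paragraph, sentence, line, or
/// word boundaries where possible (see `find_break_point`).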
#[must_use]
pub fn chunk_document(content: &str, max_chars: usize, overlap_chars: usize) -> Vec<Chunk> {
if content.len() <= max_chars {
return vec![Chunk {
text: content.to_string(),
pos: 0,
}];
}
let mut chunks = Vec::new();
let mut char_pos = 0;
while char_pos < content.len() {
// Snap the tentative window end down to a char boundary so byte slicing
// stays panic-free on multi-byte UTF-8 content.
let mut end_pos = (char_pos + max_chars).min(content.len());
while !content.is_char_boundary(end_pos) {
end_pos -= 1;
}
let actual_end = if end_pos < content.len() {
find_break_point(content, char_pos, end_pos)
} else {
end_pos
};
let actual_end = if actual_end <= char_pos {
// No usable break point: hard-cut at the window edge, snapping forward
// to a char boundary to guarantee forward progress.
let mut hard_end = (char_pos + max_chars).min(content.len());
while !content.is_char_boundary(hard_end) {
hard_end += 1;
}
hard_end
} else {
actual_end
};
chunks.push(Chunk {
text: content[char_pos..actual_end].to_string(),
pos: char_pos,
});
if actual_end >= content.len() {
break;
}
char_pos = actual_end.saturating_sub(overlap_chars);
while !content.is_char_boundary(char_pos) {
char_pos += 1;
}
if let Some(last) = chunks.last()
&& char_pos <= last.pos
{
char_pos = actual_end;
}
}
chunks
}
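/// Finds a natural cut inside `content[start..end]`, preferring (in order)
/// paragraph breaks, sentence ends, newlines, then spaces; returns `end`
/// when nothing suitable is found.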
fn find_break_point(content: &str, start: usize, end: usize) -> usize {
let slice = &content[start..end];
// Search only the last ~30% of the window, snapped to a char boundary.
let mut search_start = slice.len() * 7 / 10;
while !slice.is_char_boundary(search_start) {
search_start += 1;
}
let search_slice = &slice[search_start..];
if let Some(pos) = search_slice.rfind("\n\n") {
return start + search_start + pos + 2;
}
for pattern in &[". ", ".\n", "? ", "?\n", "! ", "!\n"] {
if let Some(pos) = search_slice.rfind(pattern) {
return start + search_start + pos + 2;
}
}
if let Some(pos) = search_slice.rfind('\n') {
return start + search_start + pos + 1;
}
if let Some(pos) = search_slice.rfind(' ') {
return start + search_start + pos + 1;
}
end
}
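/// Standard cosine similarity; returns 0.0 for mismatched lengths, empty
/// inputs, or zero-norm vectors. Identical vectors score 1.0 and orthogonal
/// ones score 0.0.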
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() || a.is_empty() {
return 0.0;
}
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
return 0.0;
}
dot / (norm_a * norm_b)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
Lex,
Vec,
Hyde,
}
#[derive(Debug, Clone)]
pub struct Queryable {
pub query_type: QueryType,
pub text: String,
}
impl Queryable {
#[must_use]
pub fn new(query_type: QueryType, text: impl Into<String>) -> Self {
Self {
query_type,
text: text.into(),
}
}
#[must_use]
pub fn lex(text: impl Into<String>) -> Self {
Self::new(QueryType::Lex, text)
}
#[must_use]
pub fn vec(text: impl Into<String>) -> Self {
Self::new(QueryType::Vec, text)
}
#[must_use]
pub fn hyde(text: impl Into<String>) -> Self {
Self::new(QueryType::Hyde, text)
}
}
#[derive(Debug, Clone)]
pub struct RerankDocument {
pub file: String,
pub text: String,
pub title: Option<String>,
}
#[derive(Debug, Clone)]
pub struct RerankResult {
pub file: String,
pub score: f32,
pub index: usize,
}
#[derive(Debug, Clone)]
struct HfRef {
repo: String,
file: String,
}
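/// Parses an `hf:owner/repo/file.gguf` URI into a repo + file reference;
/// returns `None` for anything that doesn't match that shape.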
fn parse_hf_uri(uri: &str) -> Option<HfRef> {
if !uri.starts_with("hf:") {
return None;
}
let without_prefix = &uri[3..];
let parts: Vec<&str> = without_prefix.splitn(3, '/').collect();
if parts.len() < 3 {
return None;
}
Some(HfRef {
repo: format!("{}/{}", parts[0], parts[1]),
file: parts[2].to_string(),
})
}
#[derive(Debug, Clone)]
pub struct PullResult {
pub model: String,
pub path: PathBuf,
pub size_bytes: u64,
pub refreshed: bool,
}
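/// Ensures a model is present in the local cache. Freshness is tracked with
/// an `{filename}.etag` sidecar file: the ETag from a HEAD request is
/// compared against the cached one, and the file is re-downloaded on
/// mismatch, when missing, or when `refresh` is set.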
pub fn pull_model(model_uri: &str, refresh: bool) -> Result<PullResult> {
let cache_dir = config::get_model_cache_dir();
fs::create_dir_all(&cache_dir)?;
let hf_ref = parse_hf_uri(model_uri);
let filename = if let Some(ref hf) = hf_ref {
hf.file.clone()
} else {
model_uri.to_string()
};
let local_path = cache_dir.join(&filename);
let etag_path = cache_dir.join(format!("{filename}.etag"));
let should_download = if refresh {
true
} else if !local_path.exists() {
true
} else if let Some(ref hf) = hf_ref {
let remote_etag = get_remote_etag(hf);
let local_etag = fs::read_to_string(&etag_path).ok();
remote_etag.is_some() && remote_etag != local_etag
} else {
false
};
if should_download {
if let Some(ref hf) = hf_ref {
download_from_hf(hf, &local_path, &etag_path)?;
} else {
bail!("Model not found and no HuggingFace URI provided: {model_uri}");
}
}
let size_bytes = fs::metadata(&local_path).map_or(0, |m| m.len());
Ok(PullResult {
model: model_uri.to_string(),
path: local_path,
size_bytes,
refreshed: should_download,
})
}
fn get_remote_etag(hf: &HfRef) -> Option<String> {
let url = format!(
"https://huggingface.co/{}/resolve/main/{}",
hf.repo, hf.file
);
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
.ok()?;
let resp = client.head(&url).send().ok()?;
if !resp.status().is_success() {
return None;
}
resp.headers()
.get("etag")
.and_then(|v| v.to_str().ok())
.map(|s| s.trim_matches('"').to_string())
}
fn download_from_hf(hf: &HfRef, local_path: &Path, etag_path: &Path) -> Result<()> {
let url = format!(
"https://huggingface.co/{}/resolve/main/{}",
hf.repo, hf.file
);
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(60 * 60))
.build()?;
let mut resp = client.get(&url).send()?;
if !resp.status().is_success() {
bail!("Failed to download {}: HTTP {}", url, resp.status());
}
let total_size = resp.content_length().unwrap_or(0);
let pb = ProgressBar::new(total_size);
pb.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
.expect("valid template")
.progress_chars("#>-"),
);
pb.set_message(format!("Downloading {}", hf.file));
let mut file = File::create(local_path)?;
let mut downloaded: u64 = 0;
let mut buffer = [0u8; 8192];
loop {
let bytes_read = resp.read(&mut buffer)?;
if bytes_read == 0 {
break;
}
file.write_all(&buffer[..bytes_read])?;
downloaded += bytes_read as u64;
pb.set_position(downloaded);
}
pb.finish_with_message(format!("Downloaded {}", hf.file));
if let Some(etag) = resp.headers().get("etag")
&& let Ok(etag_str) = etag.to_str()
{
let _ = fs::write(etag_path, etag_str.trim_matches('"'));
}
Ok(())
}
pub fn pull_models(models: &[&str], refresh: bool) -> Result<Vec<PullResult>> {
models.iter().map(|m| pull_model(m, refresh)).collect()
}
pub fn resolve_model(model_uri: &str) -> Result<PathBuf> {
let result = pull_model(model_uri, false)?;
Ok(result.path)
}
#[must_use]
pub fn expand_query_simple(query: &str) -> Vec<Queryable> {
let mut queries = Vec::new();
queries.push(Queryable::lex(query));
queries.push(Queryable::vec(query));
let hyde_text = format!("Information about {query}");
queries.push(Queryable::hyde(hyde_text));
queries
}
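/// Parses model output of the form `lex: ...` / `vec: ...` / `hyde: ...`
/// (one per line), keeping only expansions that share a term (3+ chars) with
/// the original query; falls back to `expand_query_simple` when nothing
/// survives.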
#[must_use]
pub fn parse_query_expansion(output: &str, original_query: &str) -> Vec<Queryable> {
let mut queries = Vec::new();
let query_lower = original_query.to_lowercase();
let line_re = Regex::new(r"^(lex|vec|hyde):\s*(.+)$").ok();
for line in output.lines() {
let line = line.trim();
if line.is_empty() {
continue;
}
if let Some(ref re) = line_re
&& let Some(caps) = re.captures(line)
{
let query_type = match &caps[1] {
"lex" => QueryType::Lex,
"vec" => QueryType::Vec,
"hyde" => QueryType::Hyde,
_ => continue,
};
let text = caps[2].trim();
let text_lower = text.to_lowercase();
let has_query_term = query_lower
.split_whitespace()
.any(|term| term.len() >= 3 && text_lower.contains(term));
if has_query_term || query_lower.len() < 3 {
queries.push(Queryable::new(query_type, text));
}
}
}
if queries.is_empty() {
return expand_query_simple(original_query);
}
queries
}
#[derive(Debug, Clone)]
pub struct RrfResult {
pub file: String,
pub display_path: String,
pub title: String,
pub body: String,
pub score: f64,
pub best_rank: usize,
}
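/// Reciprocal rank fusion across result lists: each document scores
/// `sum_i(weight_i / (k + rank_i + 1))` over the lists that contain it
/// (ranks are zero-based), plus a small bonus keyed to its best rank
/// anywhere.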
#[must_use]
pub fn reciprocal_rank_fusion(
result_lists: &[Vec<(String, String, String, String)>],
weights: Option<&[f64]>,
k: usize,
) -> Vec<RrfResult> {
use std::collections::HashMap;
let mut scores: HashMap<String, (f64, String, String, String, usize)> = HashMap::new();
for (list_idx, results) in result_lists.iter().enumerate() {
let weight = weights
.and_then(|w| w.get(list_idx))
.copied()
.unwrap_or(1.0);
for (rank, (file, display_path, title, body)) in results.iter().enumerate() {
let rrf_score = weight / (k + rank + 1) as f64;
scores
.entry(file.clone())
.and_modify(|(score, _, _, _, best_rank)| {
*score += rrf_score;
*best_rank = (*best_rank).min(rank);
})
.or_insert((
rrf_score,
display_path.clone(),
title.clone(),
body.clone(),
rank,
));
}
}
let mut results: Vec<RrfResult> = scores
.into_iter()
.map(|(file, (score, display_path, title, body, best_rank))| {
let bonus = match best_rank {
0..=2 => 0.08,
3..=9 => 0.04,
10..=19 => 0.01,
_ => 0.0,
};
RrfResult {
file,
display_path,
title,
body,
score: score + bonus,
best_rank,
}
})
.collect();
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results
}
#[derive(Debug, Clone)]
pub struct SnippetResult {
pub snippet: String,
pub line: usize,
}
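/// Extracts a line-aligned snippet of roughly `max_chars`, centered on the
/// chunk position when given, otherwise on the first query term (3+ chars)
/// found in the body.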
#[must_use]
pub fn extract_snippet(
body: &str,
query: &str,
max_chars: usize,
chunk_pos: Option<usize>,
) -> SnippetResult {
if body.len() <= max_chars {
return SnippetResult {
snippet: body.to_string(),
line: 1,
};
}
let terms: Vec<&str> = query.split_whitespace().filter(|t| t.len() >= 3).collect();
let body_lower = body.to_lowercase();
let start_pos = if let Some(pos) = chunk_pos {
pos.min(body.len().saturating_sub(max_chars))
} else {
let mut best_pos = 0;
for term in &terms {
if let Some(pos) = body_lower.find(&term.to_lowercase()) {
best_pos = pos.saturating_sub(50);
break;
}
}
best_pos
};
// Byte offsets from chunk positions or the lowercased search can land
// inside a multi-byte char; snap before slicing to avoid panics.
let mut start_pos = start_pos.min(body.len());
while !body.is_char_boundary(start_pos) {
start_pos -= 1;
}
let line_start = body[..start_pos].rfind('\n').map_or(0, |p| p + 1);
let mut end_pos = (line_start + max_chars).min(body.len());
while !body.is_char_boundary(end_pos) {
end_pos -= 1;
}
let line_end = body[end_pos..]
.find('\n')
.map_or(body.len(), |p| end_pos + p);
let line = body[..line_start].matches('\n').count() + 1;
let snippet = body[line_start..line_end].to_string();
SnippetResult { snippet, line }
}
#[derive(Debug, Clone, Copy)]
pub struct IndexHealth {
pub needs_embedding: usize,
pub total_docs: usize,
pub days_stale: Option<u64>,
}
impl IndexHealth {
#[must_use]
pub fn is_healthy(&self) -> bool {
let embedding_ok = self.needs_embedding == 0
|| (self.needs_embedding as f64 / self.total_docs.max(1) as f64) < 0.1;
let freshness_ok = self.days_stale.map_or(true, |d| d < 14);
embedding_ok && freshness_ok
}
#[must_use]
pub fn warning_message(&self) -> Option<String> {
let mut messages = Vec::new();
if self.needs_embedding > 0 {
let pct =
(self.needs_embedding as f64 / self.total_docs.max(1) as f64 * 100.0) as usize;
if pct >= 10 {
messages.push(format!(
"{} documents ({}%) need embeddings. Run 'qmd embed' for better results.",
self.needs_embedding, pct
));
}
}
if let Some(days) = self.days_stale
&& days >= 14
{
messages.push(format!(
"Index last updated {days} days ago. Run 'qmd update' to refresh."
));
}
if messages.is_empty() {
None
} else {
Some(messages.join("\n"))
}
}
}
pub struct GenerationEngine {
backend: LlamaBackend,
model: Arc<LlamaModel>,
}
impl std::fmt::Debug for GenerationEngine {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GenerationEngine").finish()
}
}
#[derive(Debug, Clone)]
pub struct GenerationResult {
pub text: String,
pub model: String,
pub done: bool,
}
impl GenerationEngine {
pub fn new(model_path: &Path) -> Result<Self> {
let backend = LlamaBackend::init()?;
let model_params = LlamaModelParams::default();
let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
.with_context(|| format!("Failed to load model from {}", model_path.display()))?;
Ok(Self {
backend,
model: Arc::new(model),
})
}
pub fn load_default() -> Result<Self> {
let model_path = get_model_path(DEFAULT_GENERATE_MODEL)?;
Self::new(&model_path)
}
pub fn is_available() -> bool {
model_exists(DEFAULT_GENERATE_MODEL)
}
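/// Decodes the prompt once, then samples one token at a time (temperature
/// 0.7, top-k 40, top-p 0.9, seeded distribution sampling) until an
/// end-of-generation token or `max_tokens` is reached.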
pub fn generate(&self, prompt: &str, max_tokens: usize) -> Result<GenerationResult> {
use llama_cpp_2::sampling::LlamaSampler;
let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZero::new(4096));
let mut ctx = self
.model
.new_context(&self.backend, ctx_params)
.context("Failed to create context")?;
let tokens = self
.model
.str_to_token(prompt, AddBos::Always)
.context("Failed to tokenize prompt")?;
let mut batch = LlamaBatch::new(tokens.len().max(512), 1);
for (i, token) in tokens.iter().enumerate() {
batch.add(*token, i as i32, &[0], i == tokens.len() - 1)?;
}
ctx.decode(&mut batch).context("Failed to decode prompt")?;
let mut sampler = LlamaSampler::chain_simple([
LlamaSampler::temp(0.7),
LlamaSampler::top_k(40),
LlamaSampler::top_p(0.9, 1),
LlamaSampler::dist(42),
]);
let mut output_text = String::new();
let mut n_cur = tokens.len();
for _ in 0..max_tokens {
let new_token = sampler.sample(&ctx, batch.n_tokens() - 1);
if self.model.is_eog_token(new_token) {
break;
}
if let Ok(piece) = self
.model
.token_to_str(new_token, llama_cpp_2::model::Special::Tokenize)
{
output_text.push_str(&piece);
}
batch.clear();
batch.add(new_token, n_cur as i32, &[0], true)?;
n_cur += 1;
ctx.decode(&mut batch)?;
}
Ok(GenerationResult {
text: output_text,
model: DEFAULT_GENERATE_MODEL.to_string(),
done: true,
})
}
pub fn expand_query(&self, query: &str, include_lexical: bool) -> Result<Vec<Queryable>> {
let prompt = format!(
r#"/no_think Expand this search query into different forms for retrieval.
Output format (one per line):
lex: keyword terms for BM25 search
vec: semantic query for vector search
hyde: hypothetical document that would answer the query
Query: {query}
"#
);
let result = self.generate(&prompt, 300)?;
let mut queries = parse_query_expansion(&result.text, query);
if !include_lexical {
queries.retain(|q| q.query_type != QueryType::Lex);
}
Ok(queries)
}
}
pub struct RerankEngine {
backend: LlamaBackend,
model: Arc<LlamaModel>,
}
impl std::fmt::Debug for RerankEngine {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RerankEngine").finish()
}
}
#[derive(Debug, Clone)]
pub struct BatchRerankResult {
pub results: Vec<RerankResult>,
pub model: String,
}
impl RerankEngine {
pub fn new(model_path: &Path) -> Result<Self> {
let backend = LlamaBackend::init()?;
let model_params = LlamaModelParams::default();
let model =
LlamaModel::load_from_file(&backend, model_path, &model_params).with_context(|| {
format!("Failed to load rerank model from {}", model_path.display())
})?;
Ok(Self {
backend,
model: Arc::new(model),
})
}
pub fn load_default() -> Result<Self> {
let model_path = get_model_path(DEFAULT_RERANK_MODEL)?;
Self::new(&model_path)
}
pub fn is_available() -> bool {
model_exists(DEFAULT_RERANK_MODEL)
}
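/// Reranks by embedding cosine similarity: the query and each document are
/// embedded with this model and scored against each other. Documents that
/// fail to embed are kept with a score of 0.0.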
pub fn rerank(
&mut self,
query: &str,
documents: &[RerankDocument],
) -> Result<BatchRerankResult> {
if documents.is_empty() {
return Ok(BatchRerankResult {
results: Vec::new(),
model: DEFAULT_RERANK_MODEL.to_string(),
});
}
let ctx_params = LlamaContextParams::default().with_embeddings(true);
let mut results: Vec<RerankResult> = Vec::new();
let query_input = format_query_for_embedding(query);
let query_embedding = self.get_embedding(&query_input, &ctx_params)?;
for (index, doc) in documents.iter().enumerate() {
let doc_input = format_doc_for_embedding(&doc.text, doc.title.as_deref());
match self.get_embedding(&doc_input, &ctx_params) {
Ok(doc_embedding) => {
let score = cosine_similarity(&query_embedding, &doc_embedding);
results.push(RerankResult {
file: doc.file.clone(),
score,
index,
});
}
Err(_) => {
results.push(RerankResult {
file: doc.file.clone(),
score: 0.0,
index,
});
}
}
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(BatchRerankResult {
results,
model: DEFAULT_RERANK_MODEL.to_string(),
})
}
fn get_embedding(&self, text: &str, ctx_params: &LlamaContextParams) -> Result<Vec<f32>> {
let mut ctx = self
.model
.new_context(&self.backend, ctx_params.clone())
.context("Failed to create context")?;
let tokens = self
.model
.str_to_token(text, AddBos::Always)
.context("Failed to tokenize")?;
if tokens.is_empty() {
bail!("Empty token sequence");
}
let mut batch = LlamaBatch::new(tokens.len(), 1);
for (i, token) in tokens.iter().enumerate() {
batch.add(*token, i as i32, &[0], i == tokens.len() - 1)?;
}
ctx.decode(&mut batch)?;
let embeddings = ctx
.embeddings_seq_ith(0)
.context("Failed to get embeddings")?;
Ok(embeddings.to_vec())
}
}
pub fn hybrid_search_rrf(
fts_results: Vec<(String, String, String, String)>,
vec_results: Vec<(String, String, String, String)>,
k: usize,
) -> Vec<RrfResult> {
reciprocal_rank_fusion(&[fts_results, vec_results], Some(&[1.0, 1.0]), k)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_format_doc_for_embedding() {
let result = format_doc_for_embedding("hello world", Some("Test Title"));
assert_eq!(result, "title: Test Title | text: hello world");
let result = format_doc_for_embedding("hello world", None);
assert_eq!(result, "title: none | text: hello world");
}
#[test]
fn test_format_query_for_embedding() {
let result = format_query_for_embedding("test query");
assert_eq!(result, "task: search result | query: test query");
}
#[test]
fn test_chunk_document_small() {
let content = "Small content";
let chunks = chunk_document(content, 100, 10);
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].text, content);
assert_eq!(chunks[0].pos, 0);
}
#[test]
fn test_chunk_document_large() {
let content = "a".repeat(500);
let chunks = chunk_document(&content, 100, 10);
assert!(chunks.len() > 1);
}
#[test]
fn test_cosine_similarity() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
let c = vec![0.0, 1.0, 0.0];
assert!(cosine_similarity(&a, &c).abs() < 0.001);
}
}