Skip to main content

rag_cli/
lib.rs

1mod embed;
2mod index;
3mod ingest;
4
5use anyhow::{Context, Result};
6use clap::{Parser, Subcommand};
7use std::path::PathBuf;
8use std::time::Instant;
9
10use std::collections::BTreeMap;
11
12use crate::embed::{select_device, EmbeddingEngine, DEFAULT_MODEL};
13use crate::index::{search_top_k, ChunkRecord, Index, IndexMeta};
14use crate::ingest::{chunk_file, discover_files, hash_files};
15
/// Default maximum length of one text chunk, in characters (`rag index --chunk-size`).
const DEFAULT_CHUNK_SIZE: usize = 512;
/// Default number of characters shared by consecutive chunks (`rag index --chunk-overlap`).
const DEFAULT_CHUNK_OVERLAP: usize = 64;
/// Default number of results returned by a search (`rag search -k/--top-k`).
const DEFAULT_TOP_K: usize = 5;
19
20#[derive(Parser)]
21#[command(name = "rag")]
22#[command(about = "Local RAG — index and semantic search your files using local embeddings")]
23#[command(version)]
24struct Cli {
25    /// Override HuggingFace model cache directory.
26    /// Default: $HF_HOME/hub or ~/.cache/huggingface/hub
27    #[arg(long, global = true, env = "RAG_CACHE_DIR")]
28    cache_dir: Option<PathBuf>,
29
30    #[command(subcommand)]
31    command: Commands,
32}
33
34#[derive(Subcommand)]
35enum Commands {
36    /// Index a directory of text files for semantic search.
37    Index {
38        /// Directory to index.
39        path: PathBuf,
40
41        /// Where to store the index (default: .rag in current directory).
42        #[arg(short, long)]
43        output: Option<PathBuf>,
44
45        /// HuggingFace model ID for embeddings.
46        #[arg(short, long, default_value = DEFAULT_MODEL)]
47        model: String,
48
49        /// Chunk size in characters.
50        #[arg(long, default_value_t = DEFAULT_CHUNK_SIZE)]
51        chunk_size: usize,
52
53        /// Chunk overlap in characters.
54        #[arg(long, default_value_t = DEFAULT_CHUNK_OVERLAP)]
55        chunk_overlap: usize,
56    },
57
58    /// Search the index with a natural language query.
59    Search {
60        /// The search query.
61        query: String,
62
63        /// Index directory (default: .rag).
64        #[arg(short, long)]
65        index: Option<PathBuf>,
66
67        /// Number of results to return.
68        #[arg(short = 'k', long, default_value_t = DEFAULT_TOP_K)]
69        top_k: usize,
70
71        /// HuggingFace model ID (must match the one used for indexing).
72        #[arg(short, long)]
73        model: Option<String>,
74
75        /// Show full chunk text instead of truncated preview.
76        #[arg(long)]
77        full: bool,
78
79        /// Output results as compact JSON (for piping to LLMs or other tools).
80        #[arg(long)]
81        json: bool,
82    },
83
84    /// Show index metadata and statistics.
85    Info {
86        /// Index directory (default: .rag).
87        #[arg(short, long)]
88        index: Option<PathBuf>,
89    },
90}
91
92/// Entry point for the CLI. Call this from `main()`.
93pub fn run() -> Result<()> {
94    let cli = Cli::parse();
95    let cache_dir = cli.cache_dir.as_deref();
96
97    match cli.command {
98        Commands::Index {
99            path,
100            output,
101            model,
102            chunk_size,
103            chunk_overlap,
104        } => cmd_index(
105            &path,
106            output.as_deref(),
107            &model,
108            chunk_size,
109            chunk_overlap,
110            cache_dir,
111        ),
112        Commands::Search {
113            query,
114            index,
115            top_k,
116            model,
117            full,
118            json,
119        } => cmd_search(
120            &query,
121            index.as_deref(),
122            top_k,
123            model.as_deref(),
124            full,
125            json,
126            cache_dir,
127        ),
128        Commands::Info { index } => cmd_info(index.as_deref()),
129    }
130}
131
132fn cmd_index(
133    path: &PathBuf,
134    output: Option<&std::path::Path>,
135    model_id: &str,
136    chunk_size: usize,
137    chunk_overlap: usize,
138    cache_dir: Option<&std::path::Path>,
139) -> Result<()> {
140    let start = Instant::now();
141
142    let root = path
143        .canonicalize()
144        .with_context(|| format!("Directory not found: {}", path.display()))?;
145
146    if !root.is_dir() {
147        anyhow::bail!("{} is not a directory", root.display());
148    }
149
150    let index_dir = output.map(PathBuf::from).unwrap_or_else(Index::default_dir);
151
152    // 1. Discover files and hash them
153    eprintln!("Indexing: {}", root.display());
154    let files = discover_files(&root)?;
155    eprintln!("Found {} text files", files.len());
156
157    if files.is_empty() {
158        anyhow::bail!("No text files found in {}", root.display());
159    }
160
161    let current_hashes = hash_files(&files, &root)?;
162
163    // 2. Try incremental indexing against an existing index
164    let prev_index = Index::load(&index_dir).ok();
165    let can_reuse = prev_index.as_ref().is_some_and(|prev| {
166        prev.meta.model_id == model_id
167            && prev.meta.chunk_size == chunk_size
168            && prev.meta.chunk_overlap == chunk_overlap
169    });
170
171    let (chunks, file_hashes, hidden_size) = if can_reuse {
172        let prev = prev_index.unwrap();
173        incremental_index(
174            &root,
175            &files,
176            &current_hashes,
177            &prev,
178            model_id,
179            chunk_size,
180            chunk_overlap,
181            cache_dir,
182        )?
183    } else {
184        if prev_index.is_some() {
185            eprintln!("Settings changed, performing full re-index");
186        }
187        full_index(
188            &root,
189            &files,
190            &current_hashes,
191            model_id,
192            chunk_size,
193            chunk_overlap,
194            cache_dir,
195        )?
196    };
197
198    if chunks.is_empty() {
199        anyhow::bail!("No text chunks produced. Check the directory contents.");
200    }
201
202    // 3. Save index
203    let meta = IndexMeta {
204        model_id: model_id.to_string(),
205        hidden_size,
206        num_chunks: chunks.len(),
207        root_dir: root.to_string_lossy().to_string(),
208        created_at: chrono_now(),
209        chunk_size,
210        chunk_overlap,
211        file_hashes,
212    };
213
214    let index = Index::new(meta, chunks);
215    index.save(&index_dir)?;
216
217    let elapsed = start.elapsed();
218    eprintln!(
219        "Index saved to {} ({} chunks, {:.1}s)",
220        index_dir.display(),
221        index.meta.num_chunks,
222        elapsed.as_secs_f64()
223    );
224
225    Ok(())
226}
227
228/// Full re-index: chunk and embed every file from scratch.
229fn full_index(
230    root: &std::path::Path,
231    files: &[PathBuf],
232    current_hashes: &BTreeMap<String, String>,
233    model_id: &str,
234    chunk_size: usize,
235    chunk_overlap: usize,
236    cache_dir: Option<&std::path::Path>,
237) -> Result<(Vec<ChunkRecord>, BTreeMap<String, String>, usize)> {
238    let mut all_chunks = Vec::new();
239    for file in files {
240        match chunk_file(file, root, chunk_size, chunk_overlap) {
241            Ok(chunks) => all_chunks.extend(chunks),
242            Err(e) => eprintln!("  Skipping {}: {e}", file.display()),
243        }
244    }
245
246    eprintln!("Embedding {} chunks...", all_chunks.len());
247    let device = select_device()?;
248    let engine = EmbeddingEngine::load(Some(model_id), &device, cache_dir)?;
249    let hidden_size = engine.hidden_size();
250
251    let texts: Vec<String> = all_chunks.iter().map(|c| c.text.clone()).collect();
252    let embeddings = engine.embed_batch_progress(&texts)?;
253
254    let chunks: Vec<ChunkRecord> = all_chunks
255        .into_iter()
256        .zip(embeddings)
257        .map(|(tc, emb)| ChunkRecord {
258            source: tc.source,
259            byte_offset: tc.byte_offset,
260            text: tc.text,
261            embedding: emb,
262        })
263        .collect();
264
265    Ok((chunks, current_hashes.clone(), hidden_size))
266}
267
268/// Incremental re-index: only re-chunk and re-embed changed/new files.
269fn incremental_index(
270    root: &std::path::Path,
271    files: &[PathBuf],
272    current_hashes: &BTreeMap<String, String>,
273    prev: &Index,
274    model_id: &str,
275    chunk_size: usize,
276    chunk_overlap: usize,
277    cache_dir: Option<&std::path::Path>,
278) -> Result<(Vec<ChunkRecord>, BTreeMap<String, String>, usize)> {
279    // Partition files into unchanged vs. changed/new
280    let mut unchanged: Vec<&str> = Vec::new();
281    let mut dirty_files: Vec<&PathBuf> = Vec::new();
282
283    for file in files {
284        let relative = file
285            .strip_prefix(root)
286            .unwrap_or(file)
287            .to_string_lossy()
288            .to_string();
289
290        let cur_hash = current_hashes.get(&relative);
291        let prev_hash = prev.meta.file_hashes.get(&relative);
292
293        if cur_hash.is_some() && cur_hash == prev_hash {
294            unchanged.push(
295                prev.meta
296                    .file_hashes
297                    .get_key_value(&relative)
298                    .unwrap()
299                    .0
300                    .as_str(),
301            );
302        } else {
303            dirty_files.push(file);
304        }
305    }
306
307    let deleted: Vec<&str> = prev
308        .meta
309        .file_hashes
310        .keys()
311        .filter(|k| !current_hashes.contains_key(k.as_str()))
312        .map(|k| k.as_str())
313        .collect();
314
315    eprintln!(
316        "Incremental: {} unchanged, {} changed/new, {} deleted",
317        unchanged.len(),
318        dirty_files.len(),
319        deleted.len(),
320    );
321
322    // Keep chunks from unchanged files
323    let mut chunks: Vec<ChunkRecord> = prev
324        .chunks
325        .iter()
326        .filter(|c| unchanged.contains(&c.source.as_str()))
327        .cloned()
328        .collect();
329
330    // Chunk and embed dirty files
331    let hidden_size = if !dirty_files.is_empty() {
332        let mut new_text_chunks = Vec::new();
333        for file in &dirty_files {
334            match chunk_file(file, root, chunk_size, chunk_overlap) {
335                Ok(cs) => new_text_chunks.extend(cs),
336                Err(e) => eprintln!("  Skipping {}: {e}", file.display()),
337            }
338        }
339
340        if !new_text_chunks.is_empty() {
341            eprintln!("Embedding {} new/changed chunks...", new_text_chunks.len());
342            let device = select_device()?;
343            let engine = EmbeddingEngine::load(Some(model_id), &device, cache_dir)?;
344            let hs = engine.hidden_size();
345
346            let texts: Vec<String> = new_text_chunks.iter().map(|c| c.text.clone()).collect();
347            let embeddings = engine.embed_batch_progress(&texts)?;
348
349            let new_chunks: Vec<ChunkRecord> = new_text_chunks
350                .into_iter()
351                .zip(embeddings)
352                .map(|(tc, emb)| ChunkRecord {
353                    source: tc.source,
354                    byte_offset: tc.byte_offset,
355                    text: tc.text,
356                    embedding: emb,
357                })
358                .collect();
359
360            chunks.extend(new_chunks);
361            hs
362        } else {
363            prev.meta.hidden_size
364        }
365    } else {
366        eprintln!("Everything up to date, nothing to embed");
367        prev.meta.hidden_size
368    };
369
370    Ok((chunks, current_hashes.clone(), hidden_size))
371}
372
373/// A single search result for JSON output.
374#[derive(serde::Serialize)]
375struct JsonResult {
376    source: String,
377    score: f32,
378    byte_offset: usize,
379    text: String,
380}
381
382fn cmd_search(
383    query: &str,
384    index_dir: Option<&std::path::Path>,
385    top_k: usize,
386    model_override: Option<&str>,
387    full: bool,
388    json: bool,
389    cache_dir: Option<&std::path::Path>,
390) -> Result<()> {
391    let start = Instant::now();
392
393    let index_dir = index_dir
394        .map(PathBuf::from)
395        .unwrap_or_else(Index::default_dir);
396
397    let index = Index::load(&index_dir).with_context(|| {
398        format!(
399            "No index found at {}. Run `rag index <path>` first.",
400            index_dir.display()
401        )
402    })?;
403
404    let model_id = model_override.unwrap_or(&index.meta.model_id);
405
406    // Load model and embed the query
407    let device = select_device()?;
408    let engine = EmbeddingEngine::load(Some(model_id), &device, cache_dir)?;
409    let query_embedding = engine.embed_one(query)?;
410
411    let embed_time = start.elapsed();
412
413    // Search
414    let results = search_top_k(&query_embedding, &index.chunks, top_k);
415
416    let search_time = start.elapsed();
417
418    if json {
419        let json_results: Vec<JsonResult> = results
420            .iter()
421            .map(|r| JsonResult {
422                source: r.chunk.source.clone(),
423                score: r.score,
424                byte_offset: r.chunk.byte_offset,
425                text: r.chunk.text.clone(),
426            })
427            .collect();
428        println!("{}", serde_json::to_string(&json_results)?);
429    } else {
430        // Print results
431        println!();
432        println!("Query: {query}");
433        println!("─────────────────────────────────────────");
434
435        if results.is_empty() {
436            println!("No results found.");
437        } else {
438            for (i, result) in results.iter().enumerate() {
439                let preview = if full {
440                    result.chunk.text.clone()
441                } else {
442                    truncate_text(&result.chunk.text, 200)
443                };
444
445                println!();
446                println!(
447                    "  [{rank}] {source} (score: {score:.4})",
448                    rank = i + 1,
449                    source = result.chunk.source,
450                    score = result.score
451                );
452                println!("      offset: {} bytes", result.chunk.byte_offset);
453                println!();
454                for line in preview.lines() {
455                    println!("      {line}");
456                }
457            }
458        }
459
460        println!();
461        println!("─────────────────────────────────────────");
462        println!(
463            "  {} results in {:.1}ms (embed: {:.1}ms)",
464            results.len(),
465            search_time.as_secs_f64() * 1000.0,
466            embed_time.as_secs_f64() * 1000.0,
467        );
468    }
469
470    Ok(())
471}
472
473fn cmd_info(index_dir: Option<&std::path::Path>) -> Result<()> {
474    let index_dir = index_dir
475        .map(PathBuf::from)
476        .unwrap_or_else(Index::default_dir);
477
478    let index = Index::load(&index_dir).with_context(|| {
479        format!(
480            "No index found at {}. Run `rag index <path>` first.",
481            index_dir.display()
482        )
483    })?;
484
485    let m = &index.meta;
486
487    let mut sources: Vec<&str> = index.chunks.iter().map(|c| c.source.as_str()).collect();
488    sources.sort();
489    sources.dedup();
490
491    let index_path = index_dir.join("index.bin");
492    let size = std::fs::metadata(&index_path).map(|m| m.len()).unwrap_or(0);
493
494    println!("RAG Index Info");
495    println!("─────────────────────────────────────────");
496    println!("  Index path:    {}", index_dir.display());
497    println!("  Root dir:      {}", m.root_dir);
498    println!("  Model:         {}", m.model_id);
499    println!("  Hidden size:   {}", m.hidden_size);
500    println!("  Chunks:        {}", m.num_chunks);
501    println!("  Source files:  {}", sources.len());
502    println!("  Chunk size:    {} chars", m.chunk_size);
503    println!("  Chunk overlap: {} chars", m.chunk_overlap);
504    println!("  Created:       {}", m.created_at);
505    println!("  Index size:    {}", format_bytes(size));
506
507    Ok(())
508}
509
/// Shorten `text` to at most `max_chars` characters, appending `"..."` when
/// anything was cut off.
///
/// Text that already fits is returned unchanged, with no ellipsis. Length is
/// counted in `char`s (Unicode scalar values), not bytes. Multi-byte text
/// (accents, CJK) therefore gets the same budget as ASCII, and the cut always
/// lands on a character boundary.
fn truncate_text(text: &str, max_chars: usize) -> String {
    match text.char_indices().nth(max_chars) {
        // The `max_chars`-th char (0-based) is the first one that does not fit;
        // its byte index is exactly where to cut.
        Some((cut, _)) => format!("{}...", &text[..cut]),
        None => text.to_string(),
    }
}
521
/// Render a byte count with binary (1024-based) units: plain bytes below 1 KB,
/// one decimal for KB and MB, and two decimals for GB and above.
fn format_bytes(bytes: u64) -> String {
    const KIB: u64 = 1024;
    const MIB: u64 = KIB * 1024;
    const GIB: u64 = MIB * 1024;

    match bytes {
        n if n < KIB => format!("{n} B"),
        n if n < MIB => format!("{:.1} KB", n as f64 / KIB as f64),
        n if n < GIB => format!("{:.1} MB", n as f64 / MIB as f64),
        n => format!("{:.2} GB", n as f64 / GIB as f64),
    }
}
533
/// Current timestamp as `YYYY-MM-DDTHH:MM:SS±hhmm`, stored as the index's `created_at`.
///
/// Prefers the system `date` command so the timestamp carries the local UTC
/// offset. If `date` cannot be spawned (e.g. on Windows), exits unsuccessfully,
/// or prints nothing, the function falls back to the current UTC time computed
/// from [`std::time::SystemTime`] with a `+0000` offset. `"unknown"` is returned
/// only when the system clock reads earlier than the Unix epoch.
fn chrono_now() -> String {
    use std::process::Command;
    use std::time::{SystemTime, UNIX_EPOCH};

    let from_date_cmd = Command::new("date")
        .arg("+%Y-%m-%dT%H:%M:%S%z")
        .output()
        .ok()
        // `output()` is `Ok` even when `date` fails, so check the exit status
        // rather than trusting (possibly empty) stdout.
        .filter(|o| o.status.success())
        .and_then(|o| String::from_utf8(o.stdout).ok())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty());

    from_date_cmd.unwrap_or_else(|| {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| format_unix_utc(d.as_secs()))
            .unwrap_or_else(|_| "unknown".to_string())
    })
}

/// Format `secs` seconds since the Unix epoch as `YYYY-MM-DDTHH:MM:SS+0000` (UTC).
///
/// Uses Howard Hinnant's `civil_from_days` algorithm on the proleptic
/// Gregorian calendar. Inputs are unsigned, so only dates from 1970 onward occur.
fn format_unix_utc(secs: u64) -> String {
    let days = secs / 86_400;
    let rem = secs % 86_400;
    let (hour, minute, second) = (rem / 3_600, rem % 3_600 / 60, rem % 60);

    // Re-base day 0 to 0000-03-01 so the leap day falls at the end of each year.
    let z = days + 719_468;
    let era = z / 146_097; // 400-year cycles
    let doe = z - era * 146_097; // day of era, [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era, [0, 399]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month, [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1; // [1, 31]
    let month = if mp < 10 { mp + 3 } else { mp - 9 }; // [1, 12]
    // January and February belong to the following civil year.
    let year = era * 400 + yoe + u64::from(month <= 2);

    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}+0000")
}