1mod embed;
2mod index;
3mod ingest;
4
5use anyhow::{Context, Result};
6use clap::{Parser, Subcommand};
7use std::path::PathBuf;
8use std::time::Instant;
9
10use std::collections::BTreeMap;
11
12use crate::embed::{select_device, EmbeddingEngine, DEFAULT_MODEL};
13use crate::index::{search_top_k, ChunkRecord, Index, IndexMeta};
14use crate::ingest::{chunk_file, discover_files, hash_files};
15
// Default chunking parameters (sizes are in characters — see the
// "chars" labels printed by `cmd_info`) and default result count.
const DEFAULT_CHUNK_SIZE: usize = 512;
const DEFAULT_CHUNK_OVERLAP: usize = 64;
const DEFAULT_TOP_K: usize = 5;
19
// Top-level CLI definition: global options shared by every subcommand.
// NOTE: plain `//` comments are used deliberately here — `///` doc comments
// on clap-derive items become `--help` text and would change CLI output.
#[derive(Parser)]
#[command(name = "rag")]
#[command(about = "Local RAG — index and semantic search your files using local embeddings")]
#[command(version)]
struct Cli {
    // Cache directory passed through to the embedding engine; can also be
    // set via the RAG_CACHE_DIR environment variable.
    #[arg(long, global = true, env = "RAG_CACHE_DIR")]
    cache_dir: Option<PathBuf>,

    #[command(subcommand)]
    command: Commands,
}
33
// Subcommands: build an index, query it, or inspect its metadata.
// NOTE: plain `//` comments on purpose — `///` on clap-derive items would
// become `--help` text and change the CLI's runtime output.
#[derive(Subcommand)]
enum Commands {
    // Build (or incrementally update) an index over a directory tree.
    Index {
        // Directory to index.
        path: PathBuf,

        // Where to write the index; defaults to Index::default_dir().
        #[arg(short, long)]
        output: Option<PathBuf>,

        // Embedding model identifier.
        #[arg(short, long, default_value = DEFAULT_MODEL)]
        model: String,

        // Chunk length in characters.
        #[arg(long, default_value_t = DEFAULT_CHUNK_SIZE)]
        chunk_size: usize,

        // Overlap between consecutive chunks, in characters.
        #[arg(long, default_value_t = DEFAULT_CHUNK_OVERLAP)]
        chunk_overlap: usize,
    },

    // Semantic search over a previously built index.
    Search {
        // Free-text query to embed and match against the index.
        query: String,

        // Index directory to load; defaults to Index::default_dir().
        #[arg(short, long)]
        index: Option<PathBuf>,

        // Number of results to print.
        #[arg(short = 'k', long, default_value_t = DEFAULT_TOP_K)]
        top_k: usize,

        // Override the embedding model recorded in the index.
        // NOTE(review): a model with a different hidden size than the one the
        // index was built with would make similarity scores meaningless —
        // presumably callers keep it compatible; verify.
        #[arg(short, long)]
        model: Option<String>,

        // Print the full chunk text instead of a 200-char preview.
        #[arg(long)]
        full: bool,

        // Emit results as a JSON array instead of formatted text.
        #[arg(long)]
        json: bool,
    },

    // Show metadata about an existing index.
    Info {
        // Index directory to inspect; defaults to Index::default_dir().
        #[arg(short, long)]
        index: Option<PathBuf>,
    },
}
91
92pub fn run() -> Result<()> {
94 let cli = Cli::parse();
95 let cache_dir = cli.cache_dir.as_deref();
96
97 match cli.command {
98 Commands::Index {
99 path,
100 output,
101 model,
102 chunk_size,
103 chunk_overlap,
104 } => cmd_index(
105 &path,
106 output.as_deref(),
107 &model,
108 chunk_size,
109 chunk_overlap,
110 cache_dir,
111 ),
112 Commands::Search {
113 query,
114 index,
115 top_k,
116 model,
117 full,
118 json,
119 } => cmd_search(
120 &query,
121 index.as_deref(),
122 top_k,
123 model.as_deref(),
124 full,
125 json,
126 cache_dir,
127 ),
128 Commands::Info { index } => cmd_info(index.as_deref()),
129 }
130}
131
132fn cmd_index(
133 path: &PathBuf,
134 output: Option<&std::path::Path>,
135 model_id: &str,
136 chunk_size: usize,
137 chunk_overlap: usize,
138 cache_dir: Option<&std::path::Path>,
139) -> Result<()> {
140 let start = Instant::now();
141
142 let root = path
143 .canonicalize()
144 .with_context(|| format!("Directory not found: {}", path.display()))?;
145
146 if !root.is_dir() {
147 anyhow::bail!("{} is not a directory", root.display());
148 }
149
150 let index_dir = output.map(PathBuf::from).unwrap_or_else(Index::default_dir);
151
152 eprintln!("Indexing: {}", root.display());
154 let files = discover_files(&root)?;
155 eprintln!("Found {} text files", files.len());
156
157 if files.is_empty() {
158 anyhow::bail!("No text files found in {}", root.display());
159 }
160
161 let current_hashes = hash_files(&files, &root)?;
162
163 let prev_index = Index::load(&index_dir).ok();
165 let can_reuse = prev_index.as_ref().is_some_and(|prev| {
166 prev.meta.model_id == model_id
167 && prev.meta.chunk_size == chunk_size
168 && prev.meta.chunk_overlap == chunk_overlap
169 });
170
171 let (chunks, file_hashes, hidden_size) = if can_reuse {
172 let prev = prev_index.unwrap();
173 incremental_index(
174 &root,
175 &files,
176 ¤t_hashes,
177 &prev,
178 model_id,
179 chunk_size,
180 chunk_overlap,
181 cache_dir,
182 )?
183 } else {
184 if prev_index.is_some() {
185 eprintln!("Settings changed, performing full re-index");
186 }
187 full_index(
188 &root,
189 &files,
190 ¤t_hashes,
191 model_id,
192 chunk_size,
193 chunk_overlap,
194 cache_dir,
195 )?
196 };
197
198 if chunks.is_empty() {
199 anyhow::bail!("No text chunks produced. Check the directory contents.");
200 }
201
202 let meta = IndexMeta {
204 model_id: model_id.to_string(),
205 hidden_size,
206 num_chunks: chunks.len(),
207 root_dir: root.to_string_lossy().to_string(),
208 created_at: chrono_now(),
209 chunk_size,
210 chunk_overlap,
211 file_hashes,
212 };
213
214 let index = Index::new(meta, chunks);
215 index.save(&index_dir)?;
216
217 let elapsed = start.elapsed();
218 eprintln!(
219 "Index saved to {} ({} chunks, {:.1}s)",
220 index_dir.display(),
221 index.meta.num_chunks,
222 elapsed.as_secs_f64()
223 );
224
225 Ok(())
226}
227
228fn full_index(
230 root: &std::path::Path,
231 files: &[PathBuf],
232 current_hashes: &BTreeMap<String, String>,
233 model_id: &str,
234 chunk_size: usize,
235 chunk_overlap: usize,
236 cache_dir: Option<&std::path::Path>,
237) -> Result<(Vec<ChunkRecord>, BTreeMap<String, String>, usize)> {
238 let mut all_chunks = Vec::new();
239 for file in files {
240 match chunk_file(file, root, chunk_size, chunk_overlap) {
241 Ok(chunks) => all_chunks.extend(chunks),
242 Err(e) => eprintln!(" Skipping {}: {e}", file.display()),
243 }
244 }
245
246 eprintln!("Embedding {} chunks...", all_chunks.len());
247 let device = select_device()?;
248 let engine = EmbeddingEngine::load(Some(model_id), &device, cache_dir)?;
249 let hidden_size = engine.hidden_size();
250
251 let texts: Vec<String> = all_chunks.iter().map(|c| c.text.clone()).collect();
252 let embeddings = engine.embed_batch_progress(&texts)?;
253
254 let chunks: Vec<ChunkRecord> = all_chunks
255 .into_iter()
256 .zip(embeddings)
257 .map(|(tc, emb)| ChunkRecord {
258 source: tc.source,
259 byte_offset: tc.byte_offset,
260 text: tc.text,
261 embedding: emb,
262 })
263 .collect();
264
265 Ok((chunks, current_hashes.clone(), hidden_size))
266}
267
268fn incremental_index(
270 root: &std::path::Path,
271 files: &[PathBuf],
272 current_hashes: &BTreeMap<String, String>,
273 prev: &Index,
274 model_id: &str,
275 chunk_size: usize,
276 chunk_overlap: usize,
277 cache_dir: Option<&std::path::Path>,
278) -> Result<(Vec<ChunkRecord>, BTreeMap<String, String>, usize)> {
279 let mut unchanged: Vec<&str> = Vec::new();
281 let mut dirty_files: Vec<&PathBuf> = Vec::new();
282
283 for file in files {
284 let relative = file
285 .strip_prefix(root)
286 .unwrap_or(file)
287 .to_string_lossy()
288 .to_string();
289
290 let cur_hash = current_hashes.get(&relative);
291 let prev_hash = prev.meta.file_hashes.get(&relative);
292
293 if cur_hash.is_some() && cur_hash == prev_hash {
294 unchanged.push(
295 prev.meta
296 .file_hashes
297 .get_key_value(&relative)
298 .unwrap()
299 .0
300 .as_str(),
301 );
302 } else {
303 dirty_files.push(file);
304 }
305 }
306
307 let deleted: Vec<&str> = prev
308 .meta
309 .file_hashes
310 .keys()
311 .filter(|k| !current_hashes.contains_key(k.as_str()))
312 .map(|k| k.as_str())
313 .collect();
314
315 eprintln!(
316 "Incremental: {} unchanged, {} changed/new, {} deleted",
317 unchanged.len(),
318 dirty_files.len(),
319 deleted.len(),
320 );
321
322 let mut chunks: Vec<ChunkRecord> = prev
324 .chunks
325 .iter()
326 .filter(|c| unchanged.contains(&c.source.as_str()))
327 .cloned()
328 .collect();
329
330 let hidden_size = if !dirty_files.is_empty() {
332 let mut new_text_chunks = Vec::new();
333 for file in &dirty_files {
334 match chunk_file(file, root, chunk_size, chunk_overlap) {
335 Ok(cs) => new_text_chunks.extend(cs),
336 Err(e) => eprintln!(" Skipping {}: {e}", file.display()),
337 }
338 }
339
340 if !new_text_chunks.is_empty() {
341 eprintln!("Embedding {} new/changed chunks...", new_text_chunks.len());
342 let device = select_device()?;
343 let engine = EmbeddingEngine::load(Some(model_id), &device, cache_dir)?;
344 let hs = engine.hidden_size();
345
346 let texts: Vec<String> = new_text_chunks.iter().map(|c| c.text.clone()).collect();
347 let embeddings = engine.embed_batch_progress(&texts)?;
348
349 let new_chunks: Vec<ChunkRecord> = new_text_chunks
350 .into_iter()
351 .zip(embeddings)
352 .map(|(tc, emb)| ChunkRecord {
353 source: tc.source,
354 byte_offset: tc.byte_offset,
355 text: tc.text,
356 embedding: emb,
357 })
358 .collect();
359
360 chunks.extend(new_chunks);
361 hs
362 } else {
363 prev.meta.hidden_size
364 }
365 } else {
366 eprintln!("Everything up to date, nothing to embed");
367 prev.meta.hidden_size
368 };
369
370 Ok((chunks, current_hashes.clone(), hidden_size))
371}
372
// Shape of one search hit when `--json` output is requested; mirrors the
// fields printed in the human-readable output of `cmd_search`.
#[derive(serde::Serialize)]
struct JsonResult {
    // Source file the chunk came from (as stored in the chunk record).
    source: String,
    // Similarity score produced by `search_top_k`.
    score: f32,
    // Byte offset of the chunk within its source file.
    byte_offset: usize,
    // Full chunk text (JSON output is never truncated).
    text: String,
}
381
382fn cmd_search(
383 query: &str,
384 index_dir: Option<&std::path::Path>,
385 top_k: usize,
386 model_override: Option<&str>,
387 full: bool,
388 json: bool,
389 cache_dir: Option<&std::path::Path>,
390) -> Result<()> {
391 let start = Instant::now();
392
393 let index_dir = index_dir
394 .map(PathBuf::from)
395 .unwrap_or_else(Index::default_dir);
396
397 let index = Index::load(&index_dir).with_context(|| {
398 format!(
399 "No index found at {}. Run `rag index <path>` first.",
400 index_dir.display()
401 )
402 })?;
403
404 let model_id = model_override.unwrap_or(&index.meta.model_id);
405
406 let device = select_device()?;
408 let engine = EmbeddingEngine::load(Some(model_id), &device, cache_dir)?;
409 let query_embedding = engine.embed_one(query)?;
410
411 let embed_time = start.elapsed();
412
413 let results = search_top_k(&query_embedding, &index.chunks, top_k);
415
416 let search_time = start.elapsed();
417
418 if json {
419 let json_results: Vec<JsonResult> = results
420 .iter()
421 .map(|r| JsonResult {
422 source: r.chunk.source.clone(),
423 score: r.score,
424 byte_offset: r.chunk.byte_offset,
425 text: r.chunk.text.clone(),
426 })
427 .collect();
428 println!("{}", serde_json::to_string(&json_results)?);
429 } else {
430 println!();
432 println!("Query: {query}");
433 println!("─────────────────────────────────────────");
434
435 if results.is_empty() {
436 println!("No results found.");
437 } else {
438 for (i, result) in results.iter().enumerate() {
439 let preview = if full {
440 result.chunk.text.clone()
441 } else {
442 truncate_text(&result.chunk.text, 200)
443 };
444
445 println!();
446 println!(
447 " [{rank}] {source} (score: {score:.4})",
448 rank = i + 1,
449 source = result.chunk.source,
450 score = result.score
451 );
452 println!(" offset: {} bytes", result.chunk.byte_offset);
453 println!();
454 for line in preview.lines() {
455 println!(" {line}");
456 }
457 }
458 }
459
460 println!();
461 println!("─────────────────────────────────────────");
462 println!(
463 " {} results in {:.1}ms (embed: {:.1}ms)",
464 results.len(),
465 search_time.as_secs_f64() * 1000.0,
466 embed_time.as_secs_f64() * 1000.0,
467 );
468 }
469
470 Ok(())
471}
472
473fn cmd_info(index_dir: Option<&std::path::Path>) -> Result<()> {
474 let index_dir = index_dir
475 .map(PathBuf::from)
476 .unwrap_or_else(Index::default_dir);
477
478 let index = Index::load(&index_dir).with_context(|| {
479 format!(
480 "No index found at {}. Run `rag index <path>` first.",
481 index_dir.display()
482 )
483 })?;
484
485 let m = &index.meta;
486
487 let mut sources: Vec<&str> = index.chunks.iter().map(|c| c.source.as_str()).collect();
488 sources.sort();
489 sources.dedup();
490
491 let index_path = index_dir.join("index.bin");
492 let size = std::fs::metadata(&index_path).map(|m| m.len()).unwrap_or(0);
493
494 println!("RAG Index Info");
495 println!("─────────────────────────────────────────");
496 println!(" Index path: {}", index_dir.display());
497 println!(" Root dir: {}", m.root_dir);
498 println!(" Model: {}", m.model_id);
499 println!(" Hidden size: {}", m.hidden_size);
500 println!(" Chunks: {}", m.num_chunks);
501 println!(" Source files: {}", sources.len());
502 println!(" Chunk size: {} chars", m.chunk_size);
503 println!(" Chunk overlap: {} chars", m.chunk_overlap);
504 println!(" Created: {}", m.created_at);
505 println!(" Index size: {}", format_bytes(size));
506
507 Ok(())
508}
509
/// Truncate `text` to at most `max_chars` characters, appending "..."
/// when anything was cut off.
///
/// Bug fix: the previous implementation compared `text.len()` (a BYTE
/// count) against `max_chars`, so multi-byte UTF-8 text was truncated
/// well short of the requested character count. Counting with
/// `char_indices` honors the name's contract, and the byte offset it
/// yields is always a valid slice boundary, so no manual boundary
/// scanning is needed.
fn truncate_text(text: &str, max_chars: usize) -> String {
    match text.char_indices().nth(max_chars) {
        // Fewer than or exactly `max_chars` characters: return unchanged.
        None => text.to_string(),
        // `idx` is the byte offset of the first character past the limit.
        Some((idx, _)) => format!("{}...", &text[..idx]),
    }
}
521
/// Render a byte count as a short human-readable size (B / KB / MB / GB).
fn format_bytes(bytes: u64) -> String {
    const KIB: u64 = 1024;
    const MIB: u64 = KIB * 1024;
    const GIB: u64 = MIB * 1024;

    match bytes {
        b if b < KIB => format!("{b} B"),
        b if b < MIB => format!("{:.1} KB", b as f64 / KIB as f64),
        b if b < GIB => format!("{:.1} MB", b as f64 / MIB as f64),
        b => format!("{:.2} GB", b as f64 / GIB as f64),
    }
}
533
/// Current time as an ISO-8601 timestamp, e.g. "2024-01-31T12:34:56+0000".
///
/// Bug fix: the previous implementation shelled out to the Unix `date`
/// binary, which spawns a process per call and silently degrades to
/// "unknown" on systems without `date` (e.g. Windows). This computes the
/// timestamp directly from the system clock using only the standard
/// library. The timestamp is now always UTC (offset "+0000"), whereas the
/// old code emitted local time with `%z`.
fn chrono_now() -> String {
    let secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0); // clock set before 1970: fall back to the epoch
    format_epoch_secs(secs)
}

/// Format seconds since the Unix epoch as "YYYY-MM-DDTHH:MM:SS+0000" (UTC).
fn format_epoch_secs(secs: u64) -> String {
    let days = (secs / 86_400) as i64;
    let rem = secs % 86_400;
    let (hh, mm, ss) = (rem / 3_600, (rem % 3_600) / 60, rem % 60);

    // Civil-from-days (Howard Hinnant's public-domain date algorithm):
    // converts a day count relative to 1970-01-01 into year/month/day.
    let z = days + 719_468;
    let era = z.div_euclid(146_097);
    let doe = z.rem_euclid(146_097); // day of 400-year era [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of year (March-based)
    let mp = (5 * doy + 2) / 153; // month index with March = 0
    let day = doy - (153 * mp + 2) / 5 + 1;
    let month = if mp < 10 { mp + 3 } else { mp - 9 };
    // Jan/Feb belong to the next civil year in the March-based calendar.
    let year = era * 400 + yoe + i64::from(month <= 2);

    format!("{year:04}-{month:02}-{day:02}T{hh:02}:{mm:02}:{ss:02}+0000")
}