Skip to main content

rag_rat_core/
config.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    fs,
4    path::{Path, PathBuf},
5    str::FromStr,
6};
7
8use serde::Deserialize;
9use thiserror::Error;
10
11use crate::language::{Language, LanguageError};
12
13#[derive(Debug, Clone)]
14pub struct Config {
15    pub root: PathBuf,
16    pub database: PathBuf,
17    pub targets: Vec<ResolvedTarget>,
18    pub local_ai: LocalAiConfig,
19    pub watch: WatchConfig,
20}
21
/// Background file-watcher settings (the `[watch]` table).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WatchConfig {
    /// Whether the background file watcher runs (default `true`). `RAG_RAT_NO_WATCH` can still
    /// force it off at the call site.
    pub enabled: bool,
    /// Quiet window, in milliseconds, that must elapse before a debounced reindex pass.
    pub debounce_ms: u64,
    /// Hard cap, in milliseconds: a pass is forced after this much continuous activity so that
    /// sustained writes cannot starve the quiet-window debounce.
    pub max_latency_ms: u64,
    /// Periodic backstop, in seconds: run a pass at least this often even without events
    /// (`0` disables). Covers event-blind filesystems (NFS, WSL2 `/mnt`), a watcher that missed
    /// events, and bounds how long a wedged peer can leave the index stale.
    pub periodic_sweep_secs: u64,
}

impl Default for WatchConfig {
    /// Watcher on, 400 ms debounce, 2500 ms latency cap, 300 s periodic sweep.
    fn default() -> Self {
        WatchConfig {
            enabled: true,
            debounce_ms: 400,
            max_latency_ms: 2500,
            periodic_sweep_secs: 300,
        }
    }
}
43
44#[derive(Debug, Clone, Default, PartialEq, Eq)]
45pub struct LocalAiConfig {
46    pub embedding: EmbeddingConfig,
47}
48
49#[derive(Debug, Clone, Default, PartialEq, Eq)]
50pub struct EmbeddingConfig {
51    /// Which embedding backend to use for semantic (vector) recall. `init` picks a default based
52    /// on repo size; see [`EmbeddingBackend`].
53    pub backend: EmbeddingBackend,
54    pub runtime: EmbeddingRuntimeConfig,
55}
56
/// Embedding backend selector, configured as `[local_ai.embedding] model = "..."`.
///
/// - `FastEmbed` (the default): MiniLM transformer. Best quality, but the cold backfill is
///   CPU-bound (~10-100 chunks/sec), which makes it impractical for very large repos.
/// - `Model2Vec`: static token-vector lookup plus mean-pooling. Roughly 100-500× faster on CPU,
///   at some retrieval-quality cost (no context / word order). Meant for huge repos that still
///   want vectors.
/// - `None`: structural + BM25 only, no dense vectors; `semantic_search` degrades to BM25. The
///   cheapest choice for enormous codebases where any embedding backfill is too slow.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum EmbeddingBackend {
    #[default]
    FastEmbed,
    Model2Vec,
    None,
}

impl EmbeddingBackend {
    /// Canonical config spelling of this backend: `minilm`, `model2vec` or `none`.
    pub fn as_str(self) -> &'static str {
        match self {
            EmbeddingBackend::None => "none",
            EmbeddingBackend::Model2Vec => "model2vec",
            EmbeddingBackend::FastEmbed => "minilm",
        }
    }

    /// Persisted embedding-model id that this backend installs/activates, or `None` when
    /// embeddings are switched off. It is a plain string so that the model ids in
    /// `rag-rat-core::index::ai` remain the single source of truth without `config` having to
    /// depend on `index`.
    pub fn model_id(self) -> Option<&'static str> {
        let id = match self {
            EmbeddingBackend::None => return None,
            EmbeddingBackend::FastEmbed => "fastembed-all-minilm-l6-v2",
            EmbeddingBackend::Model2Vec => "model2vec-potion-retrieval-32m",
        };
        Some(id)
    }
}
94
95impl FromStr for EmbeddingBackend {
96    type Err = ConfigError;
97
98    fn from_str(value: &str) -> Result<Self, Self::Err> {
99        match value.trim().to_ascii_lowercase().as_str() {
100            "minilm" | "fastembed" | "minilm-l6" => Ok(Self::FastEmbed),
101            "model2vec" | "potion" | "static" => Ok(Self::Model2Vec),
102            "none" | "off" | "bm25" => Ok(Self::None),
103            other => Err(ConfigError::UnknownEmbeddingBackend(other.to_string())),
104        }
105    }
106}
107
/// Runtime tuning values from `[local_ai.embedding.runtime]`.
///
/// How each value is applied is decided by the consumer of this struct, which is not part of
/// this module. When loaded from a config file, a missing key falls back to the matching
/// [`Default`] value, so both thread fields end up `Some(_)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmbeddingRuntimeConfig {
    /// Default `64`.
    pub batch_size: u32,
    /// Default `Some(4)`.
    pub ort_threads: Option<u32>,
    /// Default `Some(1)`.
    pub omp_threads: Option<u32>,
    /// Default `4000`.
    pub max_embedding_chars: usize,
}

impl Default for EmbeddingRuntimeConfig {
    /// Batch size 64, 4 `ort` threads, 1 `omp` thread, 4000 embedding characters.
    fn default() -> Self {
        EmbeddingRuntimeConfig {
            max_embedding_chars: 4000,
            omp_threads: Some(1),
            ort_threads: Some(4),
            batch_size: 64,
        }
    }
}
126
127#[derive(Debug, Clone, PartialEq, Eq)]
128pub struct ResolvedTarget {
129    pub name: String,
130    pub language: Language,
131    pub directories: Vec<PathBuf>,
132    pub include: Vec<String>,
133    pub exclude: Vec<String>,
134    pub kind: TargetKind,
135}
136
/// Classification attached to every [`ResolvedTarget`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TargetKind {
    Source,
    Generated,
    Docs,
    Tests,
}

impl TargetKind {
    /// Lowercase label for this kind: `source`, `generated`, `docs` or `tests`.
    pub fn as_str(self) -> &'static str {
        match self {
            TargetKind::Tests => "tests",
            TargetKind::Docs => "docs",
            TargetKind::Generated => "generated",
            TargetKind::Source => "source",
        }
    }
}
155
156impl FromStr for TargetKind {
157    type Err = ConfigError;
158
159    fn from_str(value: &str) -> Result<Self, Self::Err> {
160        match value.trim().to_ascii_lowercase().as_str() {
161            "source" => Ok(Self::Source),
162            "generated" => Ok(Self::Generated),
163            "docs" => Ok(Self::Docs),
164            "tests" | "test" => Ok(Self::Tests),
165            other => Err(ConfigError::UnknownTargetKind(other.to_string())),
166        }
167    }
168}
169
170impl Config {
171    pub fn load(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
172        let path = path.as_ref();
173        let text = fs::read_to_string(path)?;
174        let raw: RawConfig = toml::from_str(&text)?;
175        let config_dir = path.parent().unwrap_or_else(|| Path::new("."));
176        let root = config_dir.join(raw.index.root.unwrap_or_else(|| ".".to_string()));
177        let root = normalize_existing_dir(&root)?;
178        // One database per repo: resolve a relative database path against the *main* worktree so
179        // all linked worktrees of a repo share one index (the commit/worktree overlay is built for
180        // exactly this). An absolute path is honored as-is. The main worktree resolves against its
181        // own root (unchanged), so single-worktree users see no change.
182        let database = match raw.index.database {
183            Some(db) if Path::new(&db).is_absolute() => PathBuf::from(db),
184            other => {
185                let relative = other.unwrap_or_else(|| ".rag-rat/index.sqlite".to_string());
186                shared_db_base(&root).join(relative)
187            },
188        };
189        let targets = resolve_targets(&root, raw.target_bindings, raw.target)?;
190        let local_ai = LocalAiConfig::try_from(raw.local_ai)?;
191        let watch = raw.watch.into();
192
193        Ok(Self { root, database, targets, local_ai, watch })
194    }
195}
196
197fn resolve_targets(
198    root: &Path,
199    simple: BTreeMap<String, Vec<String>>,
200    expanded: Vec<RawTarget>,
201) -> Result<Vec<ResolvedTarget>, ConfigError> {
202    let mut names = BTreeSet::new();
203    let mut targets = Vec::new();
204
205    for (language_name, directories) in simple {
206        let language = Language::from_str(&language_name)?;
207        let kind =
208            if language == Language::Markdown { TargetKind::Docs } else { TargetKind::Source };
209        let name = language.as_str().to_string();
210        push_target(
211            root,
212            &mut names,
213            &mut targets,
214            ResolvedTarget {
215                include: language
216                    .simple_extensions()
217                    .iter()
218                    .map(|ext| format!("**/*.{ext}"))
219                    .collect(),
220                exclude: Vec::new(),
221                name,
222                language,
223                directories: directories.into_iter().map(PathBuf::from).collect(),
224                kind,
225            },
226        )?;
227    }
228
229    for target in expanded {
230        let language = Language::from_str(&target.language)?;
231        let kind = target
232            .kind
233            .as_deref()
234            .map(TargetKind::from_str)
235            .transpose()?
236            .unwrap_or(TargetKind::Source);
237        push_target(
238            root,
239            &mut names,
240            &mut targets,
241            ResolvedTarget {
242                name: target.name,
243                language,
244                directories: target.directories.into_iter().map(PathBuf::from).collect(),
245                include: target.include.unwrap_or_else(|| {
246                    language.simple_extensions().iter().map(|ext| format!("**/*.{ext}")).collect()
247                }),
248                exclude: target.exclude.unwrap_or_default(),
249                kind,
250            },
251        )?;
252    }
253
254    Ok(targets)
255}
256
257fn push_target(
258    root: &Path,
259    names: &mut BTreeSet<String>,
260    targets: &mut Vec<ResolvedTarget>,
261    target: ResolvedTarget,
262) -> Result<(), ConfigError> {
263    if !names.insert(target.name.clone()) {
264        return Err(ConfigError::DuplicateTarget(target.name));
265    }
266    for directory in &target.directories {
267        let full_path = root.join(directory);
268        if !full_path.is_dir() {
269            return Err(ConfigError::MissingDirectory(directory.clone()));
270        }
271    }
272    targets.push(target);
273    Ok(())
274}
275
276/// Base directory a relative `database` path resolves against. For a **linked** git worktree this
277/// is the **main** worktree root (so all worktrees share one index DB); for the main worktree or a
278/// non-git dir it is `root` unchanged — single-worktree setups keep their existing DB location.
279fn shared_db_base(root: &Path) -> PathBuf {
280    match main_worktree_root(root) {
281        Some(main_root) if main_root != root => main_root,
282        _ => root.to_path_buf(),
283    }
284}
285
/// Main worktree root, derived from the git common dir (`<main>/.git`).
///
/// Runs `git -C <root> rev-parse --git-common-dir`, canonicalizes the reported path (joined onto
/// `root`, since git may print it relative), and returns its parent directory. Yields `None`
/// outside a standard git repo (bare repo, custom `GIT_DIR`, git unavailable) so that
/// resolution falls back to `root` — never guess.
fn main_worktree_root(root: &Path) -> Option<PathBuf> {
    let mut rev_parse = std::process::Command::new("git");
    rev_parse.arg("-C").arg(root).arg("rev-parse").arg("--git-common-dir");

    // A git that cannot be spawned and a git that exits non-zero are treated alike.
    let output = rev_parse.output().ok().filter(|out| out.status.success())?;
    let stdout = String::from_utf8(output.stdout).ok()?;
    let reported = stdout.trim();
    if reported.is_empty() {
        return None;
    }

    let common_dir = root.join(reported).canonicalize().ok()?;
    // Only the standard `<main>/.git` layout maps cleanly to a main worktree root.
    let is_dot_git = common_dir.file_name().and_then(|name| name.to_str()) == Some(".git");
    if !is_dot_git {
        return None;
    }
    common_dir.parent().map(Path::to_path_buf).filter(|main_root| main_root.is_dir())
}
311
312fn normalize_existing_dir(path: &Path) -> Result<PathBuf, ConfigError> {
313    let absolute =
314        if path.is_absolute() { path.to_path_buf() } else { std::env::current_dir()?.join(path) };
315    let canonical = absolute.canonicalize()?;
316    if !canonical.is_dir() {
317        return Err(ConfigError::MissingDirectory(canonical));
318    }
319    Ok(canonical)
320}
321
322#[derive(Debug, Deserialize)]
323struct RawConfig {
324    #[serde(default)]
325    index: RawIndex,
326    #[serde(default)]
327    local_ai: RawLocalAi,
328    #[serde(default)]
329    watch: RawWatch,
330    #[serde(default)]
331    target_bindings: BTreeMap<String, Vec<String>>,
332    #[serde(default, rename = "target")]
333    target: Vec<RawTarget>,
334}
335
336#[derive(Debug, Default, Deserialize)]
337struct RawWatch {
338    enabled: Option<bool>,
339    debounce_ms: Option<u64>,
340    max_latency_ms: Option<u64>,
341    periodic_sweep_secs: Option<u64>,
342}
343
344impl From<RawWatch> for WatchConfig {
345    fn from(raw: RawWatch) -> Self {
346        let default = WatchConfig::default();
347        Self {
348            enabled: raw.enabled.unwrap_or(default.enabled),
349            debounce_ms: raw.debounce_ms.unwrap_or(default.debounce_ms),
350            max_latency_ms: raw.max_latency_ms.unwrap_or(default.max_latency_ms),
351            periodic_sweep_secs: raw.periodic_sweep_secs.unwrap_or(default.periodic_sweep_secs),
352        }
353    }
354}
355
356#[derive(Debug, Default, Deserialize)]
357struct RawIndex {
358    root: Option<String>,
359    database: Option<String>,
360}
361
362#[derive(Debug, Default, Deserialize)]
363struct RawLocalAi {
364    #[serde(default)]
365    embedding: RawEmbedding,
366}
367
368impl TryFrom<RawLocalAi> for LocalAiConfig {
369    type Error = ConfigError;
370
371    fn try_from(raw: RawLocalAi) -> Result<Self, Self::Error> {
372        Ok(Self { embedding: EmbeddingConfig::try_from(raw.embedding)? })
373    }
374}
375
376#[derive(Debug, Default, Deserialize)]
377struct RawEmbedding {
378    /// `model = "minilm" | "model2vec" | "none"` — the embedding backend selector.
379    model: Option<String>,
380    #[serde(default)]
381    runtime: RawEmbeddingRuntime,
382}
383
384impl TryFrom<RawEmbedding> for EmbeddingConfig {
385    type Error = ConfigError;
386
387    fn try_from(raw: RawEmbedding) -> Result<Self, Self::Error> {
388        let backend = match raw.model.as_deref() {
389            Some(value) => value.parse()?,
390            None => EmbeddingBackend::default(),
391        };
392        Ok(Self { backend, runtime: raw.runtime.into() })
393    }
394}
395
396#[derive(Debug, Default, Deserialize)]
397struct RawEmbeddingRuntime {
398    batch_size: Option<u32>,
399    ort_threads: Option<u32>,
400    omp_threads: Option<u32>,
401    max_embedding_chars: Option<usize>,
402}
403
404impl From<RawEmbeddingRuntime> for EmbeddingRuntimeConfig {
405    fn from(raw: RawEmbeddingRuntime) -> Self {
406        let default = EmbeddingRuntimeConfig::default();
407        Self {
408            batch_size: raw.batch_size.unwrap_or(default.batch_size),
409            ort_threads: raw.ort_threads.or(default.ort_threads),
410            omp_threads: raw.omp_threads.or(default.omp_threads),
411            max_embedding_chars: raw.max_embedding_chars.unwrap_or(default.max_embedding_chars),
412        }
413    }
414}
415
416#[derive(Debug, Deserialize)]
417struct RawTarget {
418    name: String,
419    language: String,
420    directories: Vec<String>,
421    kind: Option<String>,
422    include: Option<Vec<String>>,
423    exclude: Option<Vec<String>>,
424}
425
426#[derive(Debug, Error)]
427pub enum ConfigError {
428    #[error("failed to read config: {0}")]
429    Io(#[from] std::io::Error),
430    #[error("failed to parse config TOML: {0}")]
431    Toml(#[from] toml::de::Error),
432    #[error("{0}")]
433    Language(#[from] LanguageError),
434    #[error("unknown target kind `{0}`")]
435    UnknownTargetKind(String),
436    #[error("unknown embedding backend `{0}` (expected `minilm`, `model2vec`, or `none`)")]
437    UnknownEmbeddingBackend(String),
438    #[error("duplicate target name `{0}`")]
439    DuplicateTarget(String),
440    #[error("configured directory does not exist: {0}")]
441    MissingDirectory(PathBuf),
442}
443
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU64, Ordering};

    use super::*;

    /// Per-process counter that keeps scratch directory names unique across tests.
    static CFG_TEMP: AtomicU64 = AtomicU64::new(0);

    /// Scratch directory under the system temp dir. It is removed on drop, so a failing
    /// assertion does not leave the tree behind.
    struct TempTree(PathBuf);

    impl TempTree {
        fn new(prefix: &str) -> Self {
            let id = CFG_TEMP.fetch_add(1, Ordering::Relaxed);
            let path = std::env::temp_dir().join(format!("{prefix}-{}-{id}", std::process::id()));
            std::fs::create_dir_all(&path).unwrap();
            Self(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempTree {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Runs `git -C <dir> <args>` and fails the test immediately, with git's stderr, when the
    /// command cannot be spawned or exits non-zero. A broken fixture therefore never surfaces
    /// later as an unrelated assertion failure.
    fn git(dir: &Path, args: &[&str]) {
        let output = std::process::Command::new("git")
            .arg("-C")
            .arg(dir)
            .args(args)
            .output()
            .expect("failed to spawn `git`");
        assert!(
            output.status.success(),
            "`git {}` failed in {}: {}",
            args.join(" "),
            dir.display(),
            String::from_utf8_lossy(&output.stderr),
        );
    }

    /// Turns `main` into a git repo holding one commit of its current contents, then adds a
    /// detached linked worktree at `linked`.
    fn commit_and_link(main: &Path, linked: &Path) {
        git(main, &["init", "-q"]);
        git(main, &["config", "user.email", "t@example.com"]);
        git(main, &["config", "user.name", "t"]);
        git(main, &["add", "-A"]);
        git(main, &["commit", "-qm", "seed"]);
        git(main, &["worktree", "add", "--detach", "-q", linked.to_str().unwrap()]);
    }

    #[test]
    fn config_load_resolves_main_and_linked_worktrees_to_one_database() {
        // The actual guarantee (review item 1): Config::load from the main worktree and from a
        // linked worktree of the same repo produce the *same* database path — not two DBs.
        let tmp = TempTree::new("ragrat-cfgload");
        let main = tmp.path().join("main");
        std::fs::create_dir_all(main.join("src")).unwrap();
        std::fs::write(main.join("src/lib.rs"), "pub fn a() {}\n").unwrap();
        std::fs::write(
            main.join("rag-rat.toml"),
            "[index]\nroot = \".\"\n[target_bindings]\nrust = [\"src\"]\n",
        )
        .unwrap();
        let linked = tmp.path().join("wt");
        commit_and_link(&main, &linked);

        let from_main = Config::load(main.join("rag-rat.toml")).unwrap();
        let from_linked = Config::load(linked.join("rag-rat.toml")).unwrap();
        assert_eq!(
            from_main.database, from_linked.database,
            "main and linked worktrees must share one index database",
        );
        assert_eq!(from_main.database, main.canonicalize().unwrap().join(".rag-rat/index.sqlite"));
    }

    #[test]
    fn shared_db_base_shares_one_db_across_worktrees() {
        let tmp = TempTree::new("ragrat-cfg");
        let main = tmp.path().join("main");
        std::fs::create_dir_all(&main).unwrap();
        std::fs::write(main.join("seed.txt"), "x").unwrap();
        let linked = tmp.path().join("wt");
        commit_and_link(&main, &linked);

        let main_c = main.canonicalize().unwrap();
        let linked_c = linked.canonicalize().unwrap();

        // Main worktree resolves to itself (no redirect → existing DB location preserved).
        assert_eq!(shared_db_base(&main_c), main_c);
        // Linked worktree redirects to the main worktree → one shared DB.
        assert_eq!(shared_db_base(&linked_c), main_c);

        // A non-git directory falls back to itself.
        let plain = tmp.path().join("plain");
        std::fs::create_dir_all(&plain).unwrap();
        let plain_c = plain.canonicalize().unwrap();
        assert_eq!(shared_db_base(&plain_c), plain_c);
    }

    #[test]
    fn parses_simple_and_expanded_targets() {
        let root = std::env::current_dir().unwrap();
        let simple = BTreeMap::from([("rust".to_string(), vec![".".to_string()])]);
        let expanded = vec![RawTarget {
            name: "generated-ts".to_string(),
            language: "typescript".to_string(),
            directories: vec![".".to_string()],
            kind: Some("generated".to_string()),
            include: Some(vec!["**/*.ts".to_string()]),
            exclude: Some(vec!["**/*.map".to_string()]),
        }];

        let targets = resolve_targets(&root, simple, expanded).unwrap();

        assert_eq!(targets.len(), 2);
        assert_eq!(targets[0].language, Language::Rust);
        assert_eq!(targets[1].kind, TargetKind::Generated);
    }

    #[test]
    fn embedding_runtime_defaults_match_local_profile() {
        let runtime = EmbeddingRuntimeConfig::default();

        assert_eq!(runtime.batch_size, 64);
        assert_eq!(runtime.ort_threads, Some(4));
        assert_eq!(runtime.omp_threads, Some(1));
        assert_eq!(runtime.max_embedding_chars, 4000);
    }

    #[test]
    fn parses_embedding_runtime_overrides() {
        let raw: RawConfig = toml::from_str(
            r#"
            [index]
            root = "."
            database = ".rag-rat/index.sqlite"

            [local_ai.embedding.runtime]
            batch_size = 128
            ort_threads = 2
            omp_threads = 1
            max_embedding_chars = 5000
            "#,
        )
        .unwrap();

        let local_ai = LocalAiConfig::try_from(raw.local_ai).unwrap();

        assert_eq!(
            local_ai.embedding.runtime,
            EmbeddingRuntimeConfig {
                batch_size: 128,
                ort_threads: Some(2),
                omp_threads: Some(1),
                max_embedding_chars: 5000,
            }
        );
    }

    #[test]
    fn watch_config_defaults_on_and_parses_overrides() {
        let default: WatchConfig = RawWatch::default().into();
        assert!(default.enabled, "watcher is on by default");
        assert_eq!(default.debounce_ms, 400);
        assert_eq!(default.max_latency_ms, 2500);
        assert_eq!(default.periodic_sweep_secs, 300);

        let raw: RawConfig = toml::from_str(
            r#"
            [index]
            root = "."

            [watch]
            enabled = false
            debounce_ms = 750
            max_latency_ms = 4000
            periodic_sweep_secs = 0
            "#,
        )
        .unwrap();
        let watch: WatchConfig = raw.watch.into();
        assert_eq!(
            watch,
            WatchConfig {
                enabled: false,
                debounce_ms: 750,
                max_latency_ms: 4000,
                periodic_sweep_secs: 0,
            }
        );
    }

    #[test]
    fn rejects_unknown_language() {
        let root = std::env::current_dir().unwrap();
        let simple = BTreeMap::from([("python".to_string(), vec![".".to_string()])]);

        let err = resolve_targets(&root, simple, Vec::new()).unwrap_err();

        assert!(err.to_string().contains("unknown language"));
    }
}