1use std::{
2 collections::{BTreeMap, BTreeSet},
3 fs,
4 path::{Path, PathBuf},
5 str::FromStr,
6};
7
8use serde::Deserialize;
9use thiserror::Error;
10
11use crate::language::{Language, LanguageError};
12
/// Fully resolved configuration, as produced by [`Config::load`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Canonicalized index root directory.
    pub root: PathBuf,
    /// Index database path: an absolute `[index] database` value as given, otherwise the
    /// configured (or default) relative path joined onto the shared worktree base.
    pub database: PathBuf,
    /// Targets from `[target_bindings]` followed by those from `[[target]]` tables.
    pub targets: Vec<ResolvedTarget>,
    /// Settings from the `[local_ai]` section.
    pub local_ai: LocalAiConfig,
    /// Settings from the `[watch]` section.
    pub watch: WatchConfig,
}
21
/// File-watcher settings from the `[watch]` section.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WatchConfig {
    /// Whether the watcher is enabled; on by default.
    pub enabled: bool,
    /// Debounce interval in milliseconds — presumably the quiet period before changes are
    /// processed; confirm against the watcher implementation.
    pub debounce_ms: u64,
    /// Upper bound in milliseconds — presumably on how long a change may wait before it is
    /// processed; confirm against the watcher implementation.
    pub max_latency_ms: u64,
    /// Interval in seconds for a periodic full sweep. NOTE(review): `0` may mean "disabled";
    /// confirm against the watcher implementation.
    pub periodic_sweep_secs: u64,
}

impl Default for WatchConfig {
    /// Enabled, 400 ms debounce, 2.5 s max latency, 5 minute sweep.
    fn default() -> Self {
        let (debounce_ms, max_latency_ms) = (400, 2_500);
        let periodic_sweep_secs = 5 * 60;
        WatchConfig { enabled: true, debounce_ms, max_latency_ms, periodic_sweep_secs }
    }
}
43
/// Resolved `[local_ai]` section.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LocalAiConfig {
    /// Resolved `[local_ai.embedding]` settings.
    pub embedding: EmbeddingConfig,
}
48
/// Resolved `[local_ai.embedding]` section.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EmbeddingConfig {
    /// Backend parsed from the `model` key; [`EmbeddingBackend::default`] when absent.
    pub backend: EmbeddingBackend,
    /// Batch and threading settings from `[local_ai.embedding.runtime]`.
    pub runtime: EmbeddingRuntimeConfig,
}
56
/// Embedding backend selected by `[local_ai.embedding] model`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum EmbeddingBackend {
    /// Config name `minilm`, model id `fastembed-all-minilm-l6-v2`; the default.
    #[default]
    FastEmbed,
    /// Config name `model2vec`, model id `model2vec-potion-retrieval-32m`.
    Model2Vec,
    /// Config name `none`; no embedding model (`model_id` is `None`).
    None,
}

impl EmbeddingBackend {
    /// Canonical config name of this backend.
    pub fn as_str(self) -> &'static str {
        match self {
            EmbeddingBackend::None => "none",
            EmbeddingBackend::Model2Vec => "model2vec",
            EmbeddingBackend::FastEmbed => "minilm",
        }
    }

    /// Identifier of the model this backend uses, or `None` when no model is used.
    pub fn model_id(self) -> Option<&'static str> {
        let id = match self {
            EmbeddingBackend::None => return None,
            EmbeddingBackend::Model2Vec => "model2vec-potion-retrieval-32m",
            EmbeddingBackend::FastEmbed => "fastembed-all-minilm-l6-v2",
        };
        Some(id)
    }
}
94
95impl FromStr for EmbeddingBackend {
96 type Err = ConfigError;
97
98 fn from_str(value: &str) -> Result<Self, Self::Err> {
99 match value.trim().to_ascii_lowercase().as_str() {
100 "minilm" | "fastembed" | "minilm-l6" => Ok(Self::FastEmbed),
101 "model2vec" | "potion" | "static" => Ok(Self::Model2Vec),
102 "none" | "off" | "bm25" => Ok(Self::None),
103 other => Err(ConfigError::UnknownEmbeddingBackend(other.to_string())),
104 }
105 }
106}
107
/// Batch and threading settings from `[local_ai.embedding.runtime]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmbeddingRuntimeConfig {
    /// Number of inputs per embedding batch.
    pub batch_size: u32,
    /// Thread count presumably handed to ONNX Runtime — confirm at the use site.
    pub ort_threads: Option<u32>,
    /// Thread count presumably used for OpenMP — confirm at the use site.
    pub omp_threads: Option<u32>,
    /// Character limit per embedded text — presumably inputs are truncated to it; confirm.
    pub max_embedding_chars: usize,
}

impl Default for EmbeddingRuntimeConfig {
    /// Batches of 64, 4 ORT threads, 1 OMP thread, 4000 characters per input.
    fn default() -> Self {
        let batch_size = 64;
        let max_embedding_chars = 4_000;
        EmbeddingRuntimeConfig {
            batch_size,
            ort_threads: Some(4),
            omp_threads: Some(1),
            max_embedding_chars,
        }
    }
}
126
/// A validated target: a language plus the directories and globs it covers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedTarget {
    /// Unique target name; for `[target_bindings]` entries this is the language name.
    pub name: String,
    /// Language of this target.
    pub language: Language,
    /// Directories as configured (relative to the index root unless absolute); each was
    /// checked to exist during resolution.
    pub directories: Vec<PathBuf>,
    /// Glob patterns of files to include.
    pub include: Vec<String>,
    /// Glob patterns of files to exclude; empty unless configured.
    pub exclude: Vec<String>,
    /// Category of content in this target.
    pub kind: TargetKind,
}
136
/// Category of content a target holds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TargetKind {
    /// Hand-written source code (the default for targets).
    Source,
    /// Generated code.
    Generated,
    /// Documentation (used for Markdown `[target_bindings]` entries).
    Docs,
    /// Test code.
    Tests,
}

impl TargetKind {
    /// Canonical lowercase config name of this kind.
    pub fn as_str(self) -> &'static str {
        match self {
            TargetKind::Tests => "tests",
            TargetKind::Docs => "docs",
            TargetKind::Generated => "generated",
            TargetKind::Source => "source",
        }
    }
}
155
156impl FromStr for TargetKind {
157 type Err = ConfigError;
158
159 fn from_str(value: &str) -> Result<Self, Self::Err> {
160 match value.trim().to_ascii_lowercase().as_str() {
161 "source" => Ok(Self::Source),
162 "generated" => Ok(Self::Generated),
163 "docs" => Ok(Self::Docs),
164 "tests" | "test" => Ok(Self::Tests),
165 other => Err(ConfigError::UnknownTargetKind(other.to_string())),
166 }
167 }
168}
169
170impl Config {
171 pub fn load(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
172 let path = path.as_ref();
173 let text = fs::read_to_string(path)?;
174 let raw: RawConfig = toml::from_str(&text)?;
175 let config_dir = path.parent().unwrap_or_else(|| Path::new("."));
176 let root = config_dir.join(raw.index.root.unwrap_or_else(|| ".".to_string()));
177 let root = normalize_existing_dir(&root)?;
178 let database = match raw.index.database {
183 Some(db) if Path::new(&db).is_absolute() => PathBuf::from(db),
184 other => {
185 let relative = other.unwrap_or_else(|| ".rag-rat/index.sqlite".to_string());
186 shared_db_base(&root).join(relative)
187 },
188 };
189 let targets = resolve_targets(&root, raw.target_bindings, raw.target)?;
190 let local_ai = LocalAiConfig::try_from(raw.local_ai)?;
191 let watch = raw.watch.into();
192
193 Ok(Self { root, database, targets, local_ai, watch })
194 }
195}
196
197fn resolve_targets(
198 root: &Path,
199 simple: BTreeMap<String, Vec<String>>,
200 expanded: Vec<RawTarget>,
201) -> Result<Vec<ResolvedTarget>, ConfigError> {
202 let mut names = BTreeSet::new();
203 let mut targets = Vec::new();
204
205 for (language_name, directories) in simple {
206 let language = Language::from_str(&language_name)?;
207 let kind =
208 if language == Language::Markdown { TargetKind::Docs } else { TargetKind::Source };
209 let name = language.as_str().to_string();
210 push_target(
211 root,
212 &mut names,
213 &mut targets,
214 ResolvedTarget {
215 include: language
216 .simple_extensions()
217 .iter()
218 .map(|ext| format!("**/*.{ext}"))
219 .collect(),
220 exclude: Vec::new(),
221 name,
222 language,
223 directories: directories.into_iter().map(PathBuf::from).collect(),
224 kind,
225 },
226 )?;
227 }
228
229 for target in expanded {
230 let language = Language::from_str(&target.language)?;
231 let kind = target
232 .kind
233 .as_deref()
234 .map(TargetKind::from_str)
235 .transpose()?
236 .unwrap_or(TargetKind::Source);
237 push_target(
238 root,
239 &mut names,
240 &mut targets,
241 ResolvedTarget {
242 name: target.name,
243 language,
244 directories: target.directories.into_iter().map(PathBuf::from).collect(),
245 include: target.include.unwrap_or_else(|| {
246 language.simple_extensions().iter().map(|ext| format!("**/*.{ext}")).collect()
247 }),
248 exclude: target.exclude.unwrap_or_default(),
249 kind,
250 },
251 )?;
252 }
253
254 Ok(targets)
255}
256
257fn push_target(
258 root: &Path,
259 names: &mut BTreeSet<String>,
260 targets: &mut Vec<ResolvedTarget>,
261 target: ResolvedTarget,
262) -> Result<(), ConfigError> {
263 if !names.insert(target.name.clone()) {
264 return Err(ConfigError::DuplicateTarget(target.name));
265 }
266 for directory in &target.directories {
267 let full_path = root.join(directory);
268 if !full_path.is_dir() {
269 return Err(ConfigError::MissingDirectory(directory.clone()));
270 }
271 }
272 targets.push(target);
273 Ok(())
274}
275
276fn shared_db_base(root: &Path) -> PathBuf {
280 match main_worktree_root(root) {
281 Some(main_root) if main_root != root => main_root,
282 _ => root.to_path_buf(),
283 }
284}
285
/// Locates the main worktree of the git repository containing `root`.
///
/// Runs `git -C <root> rev-parse --git-common-dir` and joins its trimmed output onto `root`
/// (a no-op for absolute output), then canonicalizes the result. If that directory is named
/// `.git`, its parent is returned.
///
/// Returns `None` in any of these cases:
/// - git cannot be spawned or exits unsuccessfully;
/// - the output is not UTF-8 or is empty;
/// - canonicalization fails;
/// - the common dir is not named `.git`;
/// - the parent is not a directory.
fn main_worktree_root(root: &Path) -> Option<PathBuf> {
    let mut command = std::process::Command::new("git");
    command.arg("-C").arg(root).args(["rev-parse", "--git-common-dir"]);
    let output = command.output().ok().filter(|out| out.status.success())?;

    let stdout = String::from_utf8(output.stdout).ok()?;
    let reported = stdout.trim();
    if reported.is_empty() {
        return None;
    }

    let common_dir = root.join(reported).canonicalize().ok()?;
    // Only a common dir literally named `.git` has a checkout as its parent; any other
    // layout (presumably bare repositories, among others) yields no main worktree.
    let named_dot_git = common_dir.file_name().and_then(|name| name.to_str()) == Some(".git");
    if !named_dot_git {
        return None;
    }

    common_dir.parent().filter(|parent| parent.is_dir()).map(Path::to_path_buf)
}
311
312fn normalize_existing_dir(path: &Path) -> Result<PathBuf, ConfigError> {
313 let absolute =
314 if path.is_absolute() { path.to_path_buf() } else { std::env::current_dir()?.join(path) };
315 let canonical = absolute.canonicalize()?;
316 if !canonical.is_dir() {
317 return Err(ConfigError::MissingDirectory(canonical));
318 }
319 Ok(canonical)
320}
321
/// Deserialized config file, before names, languages and paths are validated.
/// Every section is optional and falls back to its default when absent.
#[derive(Debug, Deserialize)]
struct RawConfig {
    /// `[index]`: index root and database location.
    #[serde(default)]
    index: RawIndex,
    /// `[local_ai]`: embedding settings.
    #[serde(default)]
    local_ai: RawLocalAi,
    /// `[watch]`: file-watcher settings.
    #[serde(default)]
    watch: RawWatch,
    /// `[target_bindings]`: shorthand `language = ["dir", ...]` targets.
    #[serde(default)]
    target_bindings: BTreeMap<String, Vec<String>>,
    /// `[[target]]` tables: fully specified targets.
    #[serde(default, rename = "target")]
    target: Vec<RawTarget>,
}
335
336#[derive(Debug, Default, Deserialize)]
337struct RawWatch {
338 enabled: Option<bool>,
339 debounce_ms: Option<u64>,
340 max_latency_ms: Option<u64>,
341 periodic_sweep_secs: Option<u64>,
342}
343
344impl From<RawWatch> for WatchConfig {
345 fn from(raw: RawWatch) -> Self {
346 let default = WatchConfig::default();
347 Self {
348 enabled: raw.enabled.unwrap_or(default.enabled),
349 debounce_ms: raw.debounce_ms.unwrap_or(default.debounce_ms),
350 max_latency_ms: raw.max_latency_ms.unwrap_or(default.max_latency_ms),
351 periodic_sweep_secs: raw.periodic_sweep_secs.unwrap_or(default.periodic_sweep_secs),
352 }
353 }
354}
355
/// `[index]` section as written.
#[derive(Debug, Default, Deserialize)]
struct RawIndex {
    /// Index root, joined onto the config file's directory; `Config::load` uses `"."`
    /// when unset.
    root: Option<String>,
    /// Database path. Absolute values are used as-is. Relative values (default
    /// `.rag-rat/index.sqlite`) are joined onto the shared worktree base.
    database: Option<String>,
}
361
362#[derive(Debug, Default, Deserialize)]
363struct RawLocalAi {
364 #[serde(default)]
365 embedding: RawEmbedding,
366}
367
368impl TryFrom<RawLocalAi> for LocalAiConfig {
369 type Error = ConfigError;
370
371 fn try_from(raw: RawLocalAi) -> Result<Self, Self::Error> {
372 Ok(Self { embedding: EmbeddingConfig::try_from(raw.embedding)? })
373 }
374}
375
376#[derive(Debug, Default, Deserialize)]
377struct RawEmbedding {
378 model: Option<String>,
380 #[serde(default)]
381 runtime: RawEmbeddingRuntime,
382}
383
384impl TryFrom<RawEmbedding> for EmbeddingConfig {
385 type Error = ConfigError;
386
387 fn try_from(raw: RawEmbedding) -> Result<Self, Self::Error> {
388 let backend = match raw.model.as_deref() {
389 Some(value) => value.parse()?,
390 None => EmbeddingBackend::default(),
391 };
392 Ok(Self { backend, runtime: raw.runtime.into() })
393 }
394}
395
396#[derive(Debug, Default, Deserialize)]
397struct RawEmbeddingRuntime {
398 batch_size: Option<u32>,
399 ort_threads: Option<u32>,
400 omp_threads: Option<u32>,
401 max_embedding_chars: Option<usize>,
402}
403
404impl From<RawEmbeddingRuntime> for EmbeddingRuntimeConfig {
405 fn from(raw: RawEmbeddingRuntime) -> Self {
406 let default = EmbeddingRuntimeConfig::default();
407 Self {
408 batch_size: raw.batch_size.unwrap_or(default.batch_size),
409 ort_threads: raw.ort_threads.or(default.ort_threads),
410 omp_threads: raw.omp_threads.or(default.omp_threads),
411 max_embedding_chars: raw.max_embedding_chars.unwrap_or(default.max_embedding_chars),
412 }
413 }
414}
415
/// One `[[target]]` table as written.
#[derive(Debug, Deserialize)]
struct RawTarget {
    /// Target name; must be unique across all targets.
    name: String,
    /// Language name, parsed with `Language::from_str`.
    language: String,
    /// Directories relative to the index root (absolute paths are used as-is).
    directories: Vec<String>,
    /// Target kind name; `source` when omitted.
    kind: Option<String>,
    /// Include globs; the language's extension globs when omitted.
    include: Option<Vec<String>>,
    /// Exclude globs; none when omitted.
    exclude: Option<Vec<String>>,
}
425
426#[derive(Debug, Error)]
427pub enum ConfigError {
428 #[error("failed to read config: {0}")]
429 Io(#[from] std::io::Error),
430 #[error("failed to parse config TOML: {0}")]
431 Toml(#[from] toml::de::Error),
432 #[error("{0}")]
433 Language(#[from] LanguageError),
434 #[error("unknown target kind `{0}`")]
435 UnknownTargetKind(String),
436 #[error("unknown embedding backend `{0}` (expected `minilm`, `model2vec`, or `none`)")]
437 UnknownEmbeddingBackend(String),
438 #[error("duplicate target name `{0}`")]
439 DuplicateTarget(String),
440 #[error("configured directory does not exist: {0}")]
441 MissingDirectory(PathBuf),
442}
443
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU64, Ordering};

    use super::*;

    /// Per-process counter that keeps the temp directories of concurrently running tests
    /// apart.
    static CFG_TEMP: AtomicU64 = AtomicU64::new(0);

    /// Deletes the wrapped directory tree when dropped, so git fixtures are cleaned up even
    /// when an assertion fails partway through a test.
    struct TempTree(PathBuf);

    impl Drop for TempTree {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Runs `git -C <dir> <args...>` and panics with git's stderr if it fails. A broken
    /// fixture step is then reported where it happens, not as a confusing failure further
    /// down the test.
    fn git(dir: &Path, args: &[&str]) {
        let output = std::process::Command::new("git")
            .arg("-C")
            .arg(dir)
            .args(args)
            .output()
            .expect("failed to run git");
        assert!(
            output.status.success(),
            "`git {}` failed in {}: {}",
            args.join(" "),
            dir.display(),
            String::from_utf8_lossy(&output.stderr),
        );
    }

    /// Sets up a git fixture under a fresh `<temp>/<prefix>-<pid>-<n>` directory.
    ///
    /// Writes `files` into `main`, commits them to a new repository there, and adds a
    /// detached linked worktree at `wt`. Returns the cleanup guard for the whole tree plus
    /// the `main` and `wt` paths.
    fn repo_with_linked_worktree(
        prefix: &str,
        files: &[(&str, &str)],
    ) -> (TempTree, PathBuf, PathBuf) {
        let id = CFG_TEMP.fetch_add(1, Ordering::Relaxed);
        // Create the guard first so the tree is removed even if a git step panics.
        let tree =
            TempTree(std::env::temp_dir().join(format!("{prefix}-{}-{id}", std::process::id())));
        let main = tree.0.join("main");
        std::fs::create_dir_all(&main).unwrap();
        for &(relative, contents) in files {
            let file = main.join(relative);
            std::fs::create_dir_all(file.parent().unwrap()).unwrap();
            std::fs::write(&file, contents).unwrap();
        }
        git(&main, &["init", "-q"]);
        git(&main, &["config", "user.email", "t@example.com"]);
        git(&main, &["config", "user.name", "t"]);
        git(&main, &["add", "-A"]);
        git(&main, &["commit", "-qm", "seed"]);
        let linked = tree.0.join("wt");
        git(&main, &["worktree", "add", "--detach", "-q", linked.to_str().unwrap()]);
        (tree, main, linked)
    }

    #[test]
    fn config_load_resolves_main_and_linked_worktrees_to_one_database() {
        let (_tree, main, linked) = repo_with_linked_worktree(
            "ragrat-cfgload",
            &[
                ("src/lib.rs", "pub fn a() {}\n"),
                ("rag-rat.toml", "[index]\nroot = \".\"\n[target_bindings]\nrust = [\"src\"]\n"),
            ],
        );

        let from_main = Config::load(main.join("rag-rat.toml")).unwrap();
        let from_linked = Config::load(linked.join("rag-rat.toml")).unwrap();
        assert_eq!(
            from_main.database, from_linked.database,
            "main and linked worktrees must share one index database",
        );
        assert_eq!(from_main.database, main.canonicalize().unwrap().join(".rag-rat/index.sqlite"));
    }

    #[test]
    fn shared_db_base_shares_one_db_across_worktrees() {
        let (tree, main, linked) = repo_with_linked_worktree("ragrat-cfg", &[("seed.txt", "x")]);
        let main_c = main.canonicalize().unwrap();
        let linked_c = linked.canonicalize().unwrap();

        // The main worktree is its own base; a linked worktree resolves to the main one.
        assert_eq!(shared_db_base(&main_c), main_c);
        assert_eq!(shared_db_base(&linked_c), main_c);

        // A directory outside the repository is its own base.
        let plain = tree.0.join("plain");
        std::fs::create_dir_all(&plain).unwrap();
        let plain_c = plain.canonicalize().unwrap();
        assert_eq!(shared_db_base(&plain_c), plain_c);
    }

    #[test]
    fn parses_simple_and_expanded_targets() {
        let root = std::env::current_dir().unwrap();
        let simple = BTreeMap::from([("rust".to_string(), vec![".".to_string()])]);
        let expanded = vec![RawTarget {
            name: "generated-ts".to_string(),
            language: "typescript".to_string(),
            directories: vec![".".to_string()],
            kind: Some("generated".to_string()),
            include: Some(vec!["**/*.ts".to_string()]),
            exclude: Some(vec!["**/*.map".to_string()]),
        }];

        let targets = resolve_targets(&root, simple, expanded).unwrap();

        assert_eq!(targets.len(), 2);
        assert_eq!(targets[0].language, Language::Rust);
        assert_eq!(targets[1].kind, TargetKind::Generated);
    }

    #[test]
    fn embedding_runtime_defaults_match_local_profile() {
        let runtime = EmbeddingRuntimeConfig::default();

        assert_eq!(runtime.batch_size, 64);
        assert_eq!(runtime.ort_threads, Some(4));
        assert_eq!(runtime.omp_threads, Some(1));
        assert_eq!(runtime.max_embedding_chars, 4000);
    }

    #[test]
    fn parses_embedding_runtime_overrides() {
        let raw: RawConfig = toml::from_str(
            r#"
            [index]
            root = "."
            database = ".rag-rat/index.sqlite"

            [local_ai.embedding.runtime]
            batch_size = 128
            ort_threads = 2
            omp_threads = 1
            max_embedding_chars = 5000
            "#,
        )
        .unwrap();

        let local_ai = LocalAiConfig::try_from(raw.local_ai).unwrap();

        assert_eq!(
            local_ai.embedding.runtime,
            EmbeddingRuntimeConfig {
                batch_size: 128,
                ort_threads: Some(2),
                omp_threads: Some(1),
                max_embedding_chars: 5000,
            }
        );
    }

    #[test]
    fn watch_config_defaults_on_and_parses_overrides() {
        let default: WatchConfig = RawWatch::default().into();
        assert!(default.enabled, "watcher is on by default");
        assert_eq!(default.debounce_ms, 400);
        assert_eq!(default.max_latency_ms, 2500);
        assert_eq!(default.periodic_sweep_secs, 300);

        let raw: RawConfig = toml::from_str(
            r#"
            [index]
            root = "."

            [watch]
            enabled = false
            debounce_ms = 750
            max_latency_ms = 4000
            periodic_sweep_secs = 0
            "#,
        )
        .unwrap();
        let watch: WatchConfig = raw.watch.into();
        assert_eq!(
            watch,
            WatchConfig {
                enabled: false,
                debounce_ms: 750,
                max_latency_ms: 4000,
                periodic_sweep_secs: 0,
            }
        );
    }

    #[test]
    fn rejects_unknown_language() {
        let root = std::env::current_dir().unwrap();
        let simple = BTreeMap::from([("python".to_string(), vec![".".to_string()])]);

        let err = resolve_targets(&root, simple, Vec::new()).unwrap_err();

        assert!(err.to_string().contains("unknown language"));
    }
}
525 let root = std::env::current_dir().unwrap();
526 let simple = BTreeMap::from([("rust".to_string(), vec![".".to_string()])]);
527 let expanded = vec![RawTarget {
528 name: "generated-ts".to_string(),
529 language: "typescript".to_string(),
530 directories: vec![".".to_string()],
531 kind: Some("generated".to_string()),
532 include: Some(vec!["**/*.ts".to_string()]),
533 exclude: Some(vec!["**/*.map".to_string()]),
534 }];
535
536 let targets = resolve_targets(&root, simple, expanded).unwrap();
537
538 assert_eq!(targets.len(), 2);
539 assert_eq!(targets[0].language, Language::Rust);
540 assert_eq!(targets[1].kind, TargetKind::Generated);
541 }
542
543 #[test]
544 fn embedding_runtime_defaults_match_local_profile() {
545 let runtime = EmbeddingRuntimeConfig::default();
546
547 assert_eq!(runtime.batch_size, 64);
548 assert_eq!(runtime.ort_threads, Some(4));
549 assert_eq!(runtime.omp_threads, Some(1));
550 assert_eq!(runtime.max_embedding_chars, 4000);
551 }
552
553 #[test]
554 fn parses_embedding_runtime_overrides() {
555 let raw: RawConfig = toml::from_str(
556 r#"
557 [index]
558 root = "."
559 database = ".rag-rat/index.sqlite"
560
561 [local_ai.embedding.runtime]
562 batch_size = 128
563 ort_threads = 2
564 omp_threads = 1
565 max_embedding_chars = 5000
566 "#,
567 )
568 .unwrap();
569
570 let local_ai = LocalAiConfig::try_from(raw.local_ai).unwrap();
571
572 assert_eq!(
573 local_ai.embedding.runtime,
574 EmbeddingRuntimeConfig {
575 batch_size: 128,
576 ort_threads: Some(2),
577 omp_threads: Some(1),
578 max_embedding_chars: 5000,
579 }
580 );
581 }
582
583 #[test]
584 fn watch_config_defaults_on_and_parses_overrides() {
585 let default: WatchConfig = RawWatch::default().into();
586 assert!(default.enabled, "watcher is on by default");
587 assert_eq!(default.debounce_ms, 400);
588 assert_eq!(default.max_latency_ms, 2500);
589 assert_eq!(default.periodic_sweep_secs, 300);
590
591 let raw: RawConfig = toml::from_str(
592 r#"
593 [index]
594 root = "."
595
596 [watch]
597 enabled = false
598 debounce_ms = 750
599 max_latency_ms = 4000
600 periodic_sweep_secs = 0
601 "#,
602 )
603 .unwrap();
604 let watch: WatchConfig = raw.watch.into();
605 assert_eq!(
606 watch,
607 WatchConfig {
608 enabled: false,
609 debounce_ms: 750,
610 max_latency_ms: 4000,
611 periodic_sweep_secs: 0,
612 }
613 );
614 }
615
616 #[test]
617 fn rejects_unknown_language() {
618 let root = std::env::current_dir().unwrap();
619 let simple = BTreeMap::from([("python".to_string(), vec![".".to_string()])]);
620
621 let err = resolve_targets(&root, simple, Vec::new()).unwrap_err();
622
623 assert!(err.to_string().contains("unknown language"));
624 }
625}