//! Configuration loading for `rag_rat_core` (source file `rag_rat_core/config.rs`).

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    fs,
4    path::{Path, PathBuf},
5    str::FromStr,
6};
7
8use serde::Deserialize;
9use thiserror::Error;
10
11use crate::language::{Language, LanguageError};
12
13#[derive(Debug, Clone)]
14pub struct Config {
15    pub root: PathBuf,
16    pub database: PathBuf,
17    pub targets: Vec<ResolvedTarget>,
18    pub local_ai: LocalAiConfig,
19}
20
21#[derive(Debug, Clone, Default, PartialEq, Eq)]
22pub struct LocalAiConfig {
23    pub embedding: EmbeddingConfig,
24}
25
26#[derive(Debug, Clone, Default, PartialEq, Eq)]
27pub struct EmbeddingConfig {
28    pub runtime: EmbeddingRuntimeConfig,
29}
30
/// Tuning knobs for the local embedding runtime (`[local_ai.embedding.runtime]`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmbeddingRuntimeConfig {
    /// Number of inputs handed to the embedder per batch.
    pub batch_size: u32,
    /// Thread count for the runtime (presumably ONNX Runtime, "ort" — TODO confirm).
    pub ort_threads: Option<u32>,
    /// Thread count presumably passed to OpenMP ("omp" — TODO confirm).
    pub omp_threads: Option<u32>,
    /// Character limit for a single embedding input; enforcement happens in the consumer.
    pub max_embedding_chars: usize,
}

impl Default for EmbeddingRuntimeConfig {
    /// Local-profile defaults: batches of 64, 4 ORT threads, 1 OMP thread, 4000 characters.
    fn default() -> Self {
        let batch_size = 64;
        let ort_threads = Some(4);
        let omp_threads = Some(1);
        let max_embedding_chars = 4000;
        Self { batch_size, ort_threads, omp_threads, max_embedding_chars }
    }
}
49
50#[derive(Debug, Clone, PartialEq, Eq)]
51pub struct ResolvedTarget {
52    pub name: String,
53    pub language: Language,
54    pub directories: Vec<PathBuf>,
55    pub include: Vec<String>,
56    pub exclude: Vec<String>,
57    pub kind: TargetKind,
58}
59
/// Role a target's files play in the index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TargetKind {
    /// Hand-written source code; the default for expanded targets without a `kind`.
    Source,
    /// Machine-generated code.
    Generated,
    /// Documentation; shorthand Markdown bindings get this kind.
    Docs,
    /// Test code.
    Tests,
}

impl TargetKind {
    /// Canonical lowercase name of this kind (also accepted when parsing).
    pub fn as_str(self) -> &'static str {
        match self {
            TargetKind::Tests => "tests",
            TargetKind::Docs => "docs",
            TargetKind::Generated => "generated",
            TargetKind::Source => "source",
        }
    }
}
78
79impl FromStr for TargetKind {
80    type Err = ConfigError;
81
82    fn from_str(value: &str) -> Result<Self, Self::Err> {
83        match value.trim().to_ascii_lowercase().as_str() {
84            "source" => Ok(Self::Source),
85            "generated" => Ok(Self::Generated),
86            "docs" => Ok(Self::Docs),
87            "tests" | "test" => Ok(Self::Tests),
88            other => Err(ConfigError::UnknownTargetKind(other.to_string())),
89        }
90    }
91}
92
93impl Config {
94    pub fn load(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
95        let path = path.as_ref();
96        let text = fs::read_to_string(path)?;
97        let raw: RawConfig = toml::from_str(&text)?;
98        let config_dir = path.parent().unwrap_or_else(|| Path::new("."));
99        let root = config_dir.join(raw.index.root.unwrap_or_else(|| ".".to_string()));
100        let root = normalize_existing_dir(&root)?;
101        let database =
102            root.join(raw.index.database.unwrap_or_else(|| ".rag-rat/index.sqlite".to_string()));
103        let targets = resolve_targets(&root, raw.target_bindings, raw.target)?;
104        let local_ai = raw.local_ai.into();
105
106        Ok(Self { root, database, targets, local_ai })
107    }
108}
109
110fn resolve_targets(
111    root: &Path,
112    simple: BTreeMap<String, Vec<String>>,
113    expanded: Vec<RawTarget>,
114) -> Result<Vec<ResolvedTarget>, ConfigError> {
115    let mut names = BTreeSet::new();
116    let mut targets = Vec::new();
117
118    for (language_name, directories) in simple {
119        let language = Language::from_str(&language_name)?;
120        let kind =
121            if language == Language::Markdown { TargetKind::Docs } else { TargetKind::Source };
122        let name = language.as_str().to_string();
123        push_target(
124            root,
125            &mut names,
126            &mut targets,
127            ResolvedTarget {
128                include: language
129                    .simple_extensions()
130                    .iter()
131                    .map(|ext| format!("**/*.{ext}"))
132                    .collect(),
133                exclude: Vec::new(),
134                name,
135                language,
136                directories: directories.into_iter().map(PathBuf::from).collect(),
137                kind,
138            },
139        )?;
140    }
141
142    for target in expanded {
143        let language = Language::from_str(&target.language)?;
144        let kind = target
145            .kind
146            .as_deref()
147            .map(TargetKind::from_str)
148            .transpose()?
149            .unwrap_or(TargetKind::Source);
150        push_target(
151            root,
152            &mut names,
153            &mut targets,
154            ResolvedTarget {
155                name: target.name,
156                language,
157                directories: target.directories.into_iter().map(PathBuf::from).collect(),
158                include: target.include.unwrap_or_else(|| {
159                    language.simple_extensions().iter().map(|ext| format!("**/*.{ext}")).collect()
160                }),
161                exclude: target.exclude.unwrap_or_default(),
162                kind,
163            },
164        )?;
165    }
166
167    Ok(targets)
168}
169
170fn push_target(
171    root: &Path,
172    names: &mut BTreeSet<String>,
173    targets: &mut Vec<ResolvedTarget>,
174    target: ResolvedTarget,
175) -> Result<(), ConfigError> {
176    if !names.insert(target.name.clone()) {
177        return Err(ConfigError::DuplicateTarget(target.name));
178    }
179    for directory in &target.directories {
180        let full_path = root.join(directory);
181        if !full_path.is_dir() {
182            return Err(ConfigError::MissingDirectory(directory.clone()));
183        }
184    }
185    targets.push(target);
186    Ok(())
187}
188
189fn normalize_existing_dir(path: &Path) -> Result<PathBuf, ConfigError> {
190    let absolute =
191        if path.is_absolute() { path.to_path_buf() } else { std::env::current_dir()?.join(path) };
192    let canonical = absolute.canonicalize()?;
193    if !canonical.is_dir() {
194        return Err(ConfigError::MissingDirectory(canonical));
195    }
196    Ok(canonical)
197}
198
199#[derive(Debug, Deserialize)]
200struct RawConfig {
201    #[serde(default)]
202    index: RawIndex,
203    #[serde(default)]
204    local_ai: RawLocalAi,
205    #[serde(default)]
206    target_bindings: BTreeMap<String, Vec<String>>,
207    #[serde(default, rename = "target")]
208    target: Vec<RawTarget>,
209}
210
211#[derive(Debug, Default, Deserialize)]
212struct RawIndex {
213    root: Option<String>,
214    database: Option<String>,
215}
216
217#[derive(Debug, Default, Deserialize)]
218struct RawLocalAi {
219    #[serde(default)]
220    embedding: RawEmbedding,
221}
222
223impl From<RawLocalAi> for LocalAiConfig {
224    fn from(raw: RawLocalAi) -> Self {
225        Self { embedding: raw.embedding.into() }
226    }
227}
228
229#[derive(Debug, Default, Deserialize)]
230struct RawEmbedding {
231    #[serde(default)]
232    runtime: RawEmbeddingRuntime,
233}
234
235impl From<RawEmbedding> for EmbeddingConfig {
236    fn from(raw: RawEmbedding) -> Self {
237        Self { runtime: raw.runtime.into() }
238    }
239}
240
241#[derive(Debug, Default, Deserialize)]
242struct RawEmbeddingRuntime {
243    batch_size: Option<u32>,
244    ort_threads: Option<u32>,
245    omp_threads: Option<u32>,
246    max_embedding_chars: Option<usize>,
247}
248
249impl From<RawEmbeddingRuntime> for EmbeddingRuntimeConfig {
250    fn from(raw: RawEmbeddingRuntime) -> Self {
251        let default = EmbeddingRuntimeConfig::default();
252        Self {
253            batch_size: raw.batch_size.unwrap_or(default.batch_size),
254            ort_threads: raw.ort_threads.or(default.ort_threads),
255            omp_threads: raw.omp_threads.or(default.omp_threads),
256            max_embedding_chars: raw.max_embedding_chars.unwrap_or(default.max_embedding_chars),
257        }
258    }
259}
260
261#[derive(Debug, Deserialize)]
262struct RawTarget {
263    name: String,
264    language: String,
265    directories: Vec<String>,
266    kind: Option<String>,
267    include: Option<Vec<String>>,
268    exclude: Option<Vec<String>>,
269}
270
271#[derive(Debug, Error)]
272pub enum ConfigError {
273    #[error("failed to read config: {0}")]
274    Io(#[from] std::io::Error),
275    #[error("failed to parse config TOML: {0}")]
276    Toml(#[from] toml::de::Error),
277    #[error("{0}")]
278    Language(#[from] LanguageError),
279    #[error("unknown target kind `{0}`")]
280    UnknownTargetKind(String),
281    #[error("duplicate target name `{0}`")]
282    DuplicateTarget(String),
283    #[error("configured directory does not exist: {0}")]
284    MissingDirectory(PathBuf),
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_simple_and_expanded_targets() {
        let root = std::env::current_dir().unwrap();
        let simple = BTreeMap::from([("rust".to_string(), vec![".".to_string()])]);
        let expanded = vec![RawTarget {
            name: "generated-ts".to_string(),
            language: "typescript".to_string(),
            directories: vec![".".to_string()],
            kind: Some("generated".to_string()),
            include: Some(vec!["**/*.ts".to_string()]),
            exclude: Some(vec!["**/*.map".to_string()]),
        }];

        let targets = resolve_targets(&root, simple, expanded).unwrap();

        assert_eq!(targets.len(), 2);
        assert_eq!(targets[0].language, Language::Rust);
        assert_eq!(targets[1].kind, TargetKind::Generated);
    }

    #[test]
    fn embedding_runtime_defaults_match_local_profile() {
        let runtime = EmbeddingRuntimeConfig::default();

        assert_eq!(runtime.batch_size, 64);
        assert_eq!(runtime.ort_threads, Some(4));
        assert_eq!(runtime.omp_threads, Some(1));
        assert_eq!(runtime.max_embedding_chars, 4000);
    }

    #[test]
    fn parses_embedding_runtime_overrides() {
        let raw: RawConfig = toml::from_str(
            r#"
            [index]
            root = "."
            database = ".rag-rat/index.sqlite"

            [local_ai.embedding.runtime]
            batch_size = 128
            ort_threads = 2
            omp_threads = 1
            max_embedding_chars = 5000
            "#,
        )
        .unwrap();

        let local_ai: LocalAiConfig = raw.local_ai.into();

        assert_eq!(
            local_ai.embedding.runtime,
            EmbeddingRuntimeConfig {
                batch_size: 128,
                ort_threads: Some(2),
                omp_threads: Some(1),
                max_embedding_chars: 5000,
            }
        );
    }

    #[test]
    fn partial_runtime_override_keeps_other_defaults() {
        let raw: RawConfig = toml::from_str(
            r#"
            [local_ai.embedding.runtime]
            batch_size = 8
            "#,
        )
        .unwrap();

        let runtime = LocalAiConfig::from(raw.local_ai).embedding.runtime;

        assert_eq!(
            runtime,
            EmbeddingRuntimeConfig { batch_size: 8, ..EmbeddingRuntimeConfig::default() }
        );
    }

    #[test]
    fn empty_config_uses_defaults_everywhere() {
        let raw: RawConfig = toml::from_str("").unwrap();

        assert!(raw.target_bindings.is_empty());
        assert!(raw.target.is_empty());
        assert_eq!(LocalAiConfig::from(raw.local_ai), LocalAiConfig::default());
    }

    #[test]
    fn parses_target_kinds_leniently() {
        assert_eq!(" Docs ".parse::<TargetKind>().unwrap(), TargetKind::Docs);
        assert_eq!("TEST".parse::<TargetKind>().unwrap(), TargetKind::Tests);
        for kind in [TargetKind::Source, TargetKind::Generated, TargetKind::Docs, TargetKind::Tests]
        {
            assert_eq!(kind.as_str().parse::<TargetKind>().unwrap(), kind);
        }

        let err = "bogus".parse::<TargetKind>().unwrap_err();
        assert!(err.to_string().contains("unknown target kind"));
    }

    #[test]
    fn rejects_unknown_language() {
        let root = std::env::current_dir().unwrap();
        let simple = BTreeMap::from([("python".to_string(), vec![".".to_string()])]);

        let err = resolve_targets(&root, simple, Vec::new()).unwrap_err();

        assert!(err.to_string().contains("unknown language"));
    }
}