// courier/config/loader.rs

1use std::path::{Path, PathBuf};
2
3use anyhow::{Context, Result};
4
5use super::parse::parse_by_extension;
6use super::redact::redact_secret;
7use super::types::Config;
8
9impl Config {
10    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
11        let path = path.as_ref();
12        if path.is_dir() {
13            Self::load_dir(path)
14        } else {
15            Self::load_file(path)
16        }
17    }
18
19    fn load_file(path: &Path) -> Result<Self> {
20        let content = std::fs::read_to_string(path)
21            .with_context(|| format!("failed to read config file {}", path.display()))?;
22        parse_by_extension(path, &content, path.parent())
23            .with_context(|| format!("failed to load config file {}", path.display()))
24    }
25
26    fn load_dir(dir: &Path) -> Result<Self> {
27        let mut files: Vec<PathBuf> = std::fs::read_dir(dir)
28            .with_context(|| format!("failed to read config directory {}", dir.display()))?
29            .filter_map(|entry| entry.ok())
30            .map(|entry| entry.path())
31            .filter(|p| p.is_file() && is_supported_config_extension(p))
32            .collect();
33        files.sort();
34
35        let mut merged = Config {
36            pipelines: Vec::new(),
37            observability: None,
38        };
39        let mut seen_names: std::collections::HashMap<String, PathBuf> =
40            std::collections::HashMap::new();
41        let mut observability_file: Option<PathBuf> = None;
42
43        for file in files {
44            let Config {
45                pipelines,
46                observability,
47            } = Self::load_file(&file)?;
48
49            if let Some(observability) = observability {
50                if let Some(prev) = observability_file.replace(file.clone()) {
51                    anyhow::bail!(
52                        "duplicate observability config in {} (also defined in {})",
53                        file.display(),
54                        prev.display(),
55                    );
56                }
57                merged.observability = Some(observability);
58            }
59
60            for pipeline in pipelines {
61                if let Some(prev) = seen_names.insert(pipeline.name.clone(), file.clone()) {
62                    anyhow::bail!(
63                        "duplicate pipeline name '{}' in {} (also defined in {})",
64                        redact_secret(&pipeline.name),
65                        file.display(),
66                        prev.display(),
67                    );
68                }
69                merged.pipelines.push(pipeline);
70            }
71        }
72
73        Ok(merged)
74    }
75}
76
/// Reports whether `path` ends in one of the config extensions the loader understands.
///
/// The check is exact and case-sensitive: `a.toml` and `b.json` match, while `c.TOML`,
/// `d.yaml`, names without an extension, and non-UTF-8 extensions do not.
fn is_supported_config_extension(path: &Path) -> bool {
    const SUPPORTED_EXTENSIONS: [&str; 2] = ["toml", "json"];

    match path.extension().and_then(|ext| ext.to_str()) {
        Some(ext) => SUPPORTED_EXTENSIONS.contains(&ext),
        None => false,
    }
}
83
#[cfg(test)]
mod tests {
    use crate::config::{Config, ENV_LOCK, ErrorPolicyConfig, LogFormat};

    fn set_env_var(key: &str, value: &str) {
        // SAFETY: `set_var` is unsafe because mutating the process environment while other
        // threads read it is undefined behaviour on some platforms. In this module it is
        // only called while `ENV_LOCK` is held. This assumes every other test that reads or
        // writes these variables also takes `ENV_LOCK` (TODO confirm across the crate).
        unsafe {
            std::env::set_var(key, value);
        }
    }

    fn remove_env_var(key: &str) {
        // SAFETY: same contract as `set_env_var`. It is only reached from
        // `ScopedEnvVar::drop`, which runs while `ENV_LOCK` is still held.
        unsafe {
            std::env::remove_var(key);
        }
    }

    /// Sets an environment variable and removes it again when dropped. The variable is
    /// therefore cleaned up even if an assertion in the test panics first.
    ///
    /// Bind these *after* taking `ENV_LOCK`. Locals drop in reverse order, so the variable
    /// is removed while the lock is still held.
    struct ScopedEnvVar {
        key: &'static str,
    }

    impl ScopedEnvVar {
        fn set(key: &'static str, value: &str) -> Self {
            set_env_var(key, value);
            Self { key }
        }
    }

    impl Drop for ScopedEnvVar {
        fn drop(&mut self) {
            remove_env_var(self.key);
        }
    }

    #[test]
    fn load_reads_file_from_disk() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("courier.toml");
        std::fs::write(
            &path,
            r#"
            [[pipelines]]
            name = "from-disk"

            [pipelines.source]
            type = "noop"

            [[pipelines.sinks]]
            type = "noop"
            "#,
        )
        .unwrap();

        let config = Config::load(&path).unwrap();
        assert_eq!(config.pipelines.len(), 1);
        assert_eq!(config.pipelines[0].name, "from-disk");
    }

    #[test]
    fn load_reports_missing_file_with_path_context() {
        let err = Config::load("/nonexistent/courier.toml").unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("/nonexistent/courier.toml"), "{msg}");
    }

    #[test]
    fn load_dispatches_on_extension() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("courier.json");
        std::fs::write(
            &path,
            r#"{
              "pipelines": [
                {
                  "name": "from-json",
                  "source": { "type": "noop" },
                  "sinks": [{ "type": "noop" }]
                }
              ]
            }"#,
        )
        .unwrap();

        let config = Config::load(&path).unwrap();
        assert_eq!(config.pipelines.len(), 1);
        assert_eq!(config.pipelines[0].name, "from-json");
    }

    #[test]
    fn load_directory_concatenates_pipelines_in_sorted_order() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("b.toml"),
            r#"
            [[pipelines]]
            name = "second"
            [pipelines.source]
            type = "noop"
            [[pipelines.sinks]]
            type = "noop"
            "#,
        )
        .unwrap();
        std::fs::write(
            dir.path().join("a.json"),
            r#"{
              "pipelines": [
                {
                  "name": "first",
                  "source": { "type": "noop" },
                  "sinks": [{ "type": "noop" }]
                }
              ]
            }"#,
        )
        .unwrap();
        std::fs::write(dir.path().join("notes.txt"), "ignored").unwrap();

        let config = Config::load(dir.path()).unwrap();
        let names: Vec<_> = config.pipelines.iter().map(|p| p.name.as_str()).collect();
        assert_eq!(names, vec!["first", "second"]);
    }

    #[test]
    fn load_directory_skips_subdirectories_even_with_config_extension() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::create_dir(dir.path().join("nested.toml")).unwrap();
        std::fs::write(
            dir.path().join("a.toml"),
            r#"
            [[pipelines]]
            name = "top-level"
            [pipelines.source]
            type = "noop"
            [[pipelines.sinks]]
            type = "noop"
            "#,
        )
        .unwrap();

        let config = Config::load(dir.path()).unwrap();
        let names: Vec<_> = config.pipelines.iter().map(|p| p.name.as_str()).collect();
        assert_eq!(names, vec!["top-level"]);
    }

    #[test]
    fn load_directory_interpolates_each_toml_and_json_file() {
        // Another env-mutating test that panicked would have poisoned the lock. This test
        // sets every variable it reads before using them, so recovering the guard is safe
        // and avoids a cascade of unrelated failures.
        let _lock = ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Declared after the lock so they are removed before the lock is released, even
        // if an assertion below panics.
        let _suffix = ScopedEnvVar::set("COURIER_TEST_DIR_SUFFIX", "env");
        let _url = ScopedEnvVar::set("COURIER_TEST_DIR_URL", "https://example.test/data");

        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("a.json"),
            r#"{
              "pipelines": [
                {
                  "name": "json-${env:COURIER_TEST_DIR_SUFFIX}",
                  "source": {
                    "type": "noop",
                    "url": "${env:COURIER_TEST_DIR_URL}"
                  },
                  "sinks": [{ "type": "noop" }]
                }
              ]
            }"#,
        )
        .unwrap();
        std::fs::write(
            dir.path().join("b.toml"),
            r#"
            [[pipelines]]
            name = "toml-${env:COURIER_TEST_DIR_SUFFIX}"
            [pipelines.source]
            type = "noop"
            url = "${env:COURIER_TEST_DIR_URL}"
            [[pipelines.sinks]]
            type = "noop"
            "#,
        )
        .unwrap();

        let config = Config::load(dir.path()).unwrap();
        let names: Vec<_> = config.pipelines.iter().map(|p| p.name.as_str()).collect();
        assert_eq!(names, vec!["json-env", "toml-env"]);
        assert_eq!(
            config.pipelines[0].source.config["url"],
            "https://example.test/data"
        );
        assert_eq!(
            config.pipelines[1].source.config["url"],
            "https://example.test/data"
        );
    }

    #[test]
    fn load_resolves_script_file_relative_to_config_file() {
        let dir = tempfile::tempdir().unwrap();
        let script_dir = dir.path().join("transforms");
        std::fs::create_dir(&script_dir).unwrap();
        std::fs::write(script_dir.join("enrich.rhai"), "fn transform(env) { env }").unwrap();

        let config_path = dir.path().join("courier.toml");
        std::fs::write(
            &config_path,
            r#"
            [[pipelines]]
            name = "script-path"

            [pipelines.source]
            type = "noop"

            [[pipelines.transforms]]
            type = "script"
            runtime = "rhai"
            script_file = "./transforms/enrich.rhai"

            [[pipelines.sinks]]
            type = "noop"
            "#,
        )
        .unwrap();

        let config = Config::load(&config_path).unwrap();
        let script_file = config.pipelines[0].transforms[0].config["script_file"]
            .as_str()
            .unwrap();
        assert_eq!(
            script_file,
            dir.path()
                .join("./transforms/enrich.rhai")
                .to_string_lossy()
                .as_ref()
        );
        assert!(std::path::Path::new(script_file).is_absolute());
    }

    #[test]
    fn load_directory_rejects_duplicate_pipeline_names() {
        let dir = tempfile::tempdir().unwrap();
        let body = r#"
            [[pipelines]]
            name = "dup"
            [pipelines.source]
            type = "noop"
            [[pipelines.sinks]]
            type = "noop"
        "#;
        std::fs::write(dir.path().join("a.toml"), body).unwrap();
        std::fs::write(dir.path().join("b.toml"), body).unwrap();

        let err = Config::load(dir.path()).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("duplicate pipeline name 'dup'"), "{msg}");
        assert!(msg.contains("a.toml"), "{msg}");
        assert!(msg.contains("b.toml"), "{msg}");
    }

    #[test]
    fn load_directory_propagates_parse_error_with_file_context() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("broken.toml"), "not valid toml ===").unwrap();

        let err = Config::load(dir.path()).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("broken.toml"), "{msg}");
    }

    #[test]
    fn load_empty_directory_yields_no_pipelines() {
        let dir = tempfile::tempdir().unwrap();
        let config = Config::load(dir.path()).unwrap();
        assert!(config.pipelines.is_empty());
    }

    #[test]
    fn load_rejects_unsupported_extension() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("courier.yaml");
        std::fs::write(&path, "pipelines: []").unwrap();

        let err = Config::load(&path).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("unsupported config file extension"), "{msg}");
    }

    #[test]
    fn directory_mode_keeps_defaults_per_file() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("a.toml"),
            r#"
            [defaults.sink]
            on_error = "fail_pipeline"

            [[pipelines]]
            name = "with-default"
            [pipelines.source]
            type = "noop"
            [[pipelines.sinks]]
            type = "noop"
            "#,
        )
        .unwrap();
        std::fs::write(
            dir.path().join("b.toml"),
            r#"
            [[pipelines]]
            name = "no-default"
            [pipelines.source]
            type = "noop"
            [[pipelines.sinks]]
            type = "noop"
            "#,
        )
        .unwrap();

        let config = Config::load(dir.path()).unwrap();
        let by_name: std::collections::HashMap<_, _> = config
            .pipelines
            .iter()
            .map(|p| (p.name.as_str(), p))
            .collect();
        assert_eq!(
            by_name["with-default"].sinks[0].on_error,
            Some(ErrorPolicyConfig::FailPipeline),
        );
        assert_eq!(by_name["no-default"].sinks[0].on_error, None);
    }

    #[test]
    fn directory_mode_preserves_observability() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("a.toml"),
            r#"
            [observability]
            service_name = "courier-prod"
            log_format = "json"
            log_level = "courier=debug"
            log_keys = true

            [observability.metrics]
            otlp_endpoint = "http://metrics:4317"
            export_interval_ms = 5000

            [observability.tracing]
            otlp_endpoint = "http://traces:4317"
            sample_ratio = 0.25

            [[pipelines]]
            name = "p1"
            [pipelines.source]
            type = "noop"
            [[pipelines.sinks]]
            type = "noop"
            "#,
        )
        .unwrap();
        std::fs::write(
            dir.path().join("b.toml"),
            r#"
            [[pipelines]]
            name = "p2"
            [pipelines.source]
            type = "noop"
            [[pipelines.sinks]]
            type = "noop"
            "#,
        )
        .unwrap();

        let config = Config::load(dir.path()).unwrap();
        assert_eq!(config.pipelines.len(), 2);

        let obs = config.observability.expect("observability should load");
        assert_eq!(obs.log_format, LogFormat::Json);
        assert_eq!(obs.log_level.as_deref(), Some("courier=debug"));
        assert!(obs.log_keys);
        assert_eq!(
            obs.metrics.otlp_endpoint.as_deref(),
            Some("http://metrics:4317")
        );
        assert_eq!(obs.metrics.export_interval_ms, 5000);
        assert_eq!(
            obs.tracing.otlp_endpoint.as_deref(),
            Some("http://traces:4317")
        );
        assert_eq!(obs.tracing.sample_ratio, 0.25);
        assert_eq!(obs.service_name, "courier-prod");
    }

    #[test]
    fn directory_mode_rejects_duplicate_observability_blocks() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("a.toml"),
            r#"
            [observability]
            log_level = "info"

            [[pipelines]]
            name = "p1"
            [pipelines.source]
            type = "noop"
            [[pipelines.sinks]]
            type = "noop"
            "#,
        )
        .unwrap();
        std::fs::write(
            dir.path().join("b.toml"),
            r#"
            [observability]
            log_level = "debug"

            [[pipelines]]
            name = "p2"
            [pipelines.source]
            type = "noop"
            [[pipelines.sinks]]
            type = "noop"
            "#,
        )
        .unwrap();

        let msg = format!("{:#}", Config::load(dir.path()).unwrap_err());
        assert!(msg.contains("duplicate observability config"), "{msg}");
        assert!(msg.contains("a.toml"), "{msg}");
        assert!(msg.contains("b.toml"), "{msg}");
    }
}