Skip to main content

atrg_core/
config.rs

1//! Configuration types and loader for `atrg.toml`.
2//!
3//! The [`Config`] struct is the single source of truth for all framework
4//! configuration. It is loaded once at startup by [`Config::load`] and then
5//! wrapped in an `Arc` inside [`AppState`](crate::state::AppState).
6
7use std::path::Path;
8
9use axum::http;
10use serde::Deserialize;
11use url::Url;
12
13// ---------------------------------------------------------------------------
14// Top-level config
15// ---------------------------------------------------------------------------
16
/// Root configuration, deserialized from `atrg.toml`.
///
/// Only `[app]` is mandatory. `[auth]` and `[database]` fall back to their
/// `Default` implementations when the section is absent, and every remaining
/// section is an `Option` that stays `None` unless its table is present.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    /// Application-level settings. This field carries no `#[serde(default)]`,
    /// so a file without an `[app]` table fails to deserialize.
    pub app: AppConfig,

    /// OAuth / authentication settings. Uses `AuthConfig::default()` when the
    /// `[auth]` table is missing.
    #[serde(default)]
    pub auth: AuthConfig,

    /// Database connection settings. Uses `DatabaseConfig::default()` when the
    /// `[database]` table is missing.
    #[serde(default)]
    pub database: DatabaseConfig,

    /// Optional Jetstream real-time event consumer settings (`[jetstream]`).
    pub jetstream: Option<JetstreamConfig>,

    /// Optional relay firehose consumer settings (`[firehose]`).
    pub firehose: Option<FirehoseConfig>,

    /// Optional feed generator settings (`[feed_generator]`).
    pub feed_generator: Option<FeedGeneratorConfig>,

    /// Optional labeler settings (`[labeler]`).
    pub labeler: Option<LabelerConfig>,

    /// Optional rate limiting settings (`[rate_limit]`).
    pub rate_limit: Option<RateLimitTomlConfig>,
}
46
47// ---------------------------------------------------------------------------
48// AppConfig
49// ---------------------------------------------------------------------------
50
/// `[app]` section of `atrg.toml`.
///
/// `name` and `secret_key` have no serde default and must be supplied; every
/// other field falls back to the default documented on it.
#[derive(Debug, Clone, Deserialize)]
pub struct AppConfig {
    /// Human-readable application name. Validation rejects a value that is
    /// empty or whitespace-only.
    pub name: String,

    /// Bind address for the HTTP server. Defaults to `"127.0.0.1"`.
    #[serde(default = "default_host")]
    pub host: String,

    /// Bind port for the HTTP server. Defaults to `3000`.
    #[serde(default = "default_port")]
    pub port: u16,

    /// Secret key used for session signing. Validation rejects an empty or
    /// whitespace-only key and logs a warning when it is shorter than 32
    /// bytes, so use at least 32 characters in production.
    pub secret_key: String,

    /// Allowed CORS origins. An empty list means same-origin only. A single
    /// `"*"` entry enables the permissive wildcard. Validation accepts `"*"`
    /// and any entry that parses as an `http::HeaderValue`.
    #[serde(default)]
    pub cors_origins: Vec<String>,

    /// `"development"` or `"production"`; defaults to `"development"`.
    /// Nothing in this file checks the value. How it is consumed (the
    /// original note mentions cookie flags and security headers) is defined
    /// by code outside this file.
    #[serde(default = "default_environment")]
    pub environment: String,

    /// DIDs to auto-provision as admin on startup. Defaults to an empty list.
    ///
    /// NOTE(review): the earlier comment said this can also be populated from
    /// the `ATRG_APP__ADMIN_DIDS` env var (comma-separated). The loader in
    /// this file never reads the environment, so that override is presumably
    /// applied elsewhere. Confirm.
    #[serde(default)]
    pub admin_dids: Vec<String>,
}
84
85impl Default for AppConfig {
86    fn default() -> Self {
87        Self {
88            name: String::new(),
89            host: default_host(),
90            port: default_port(),
91            secret_key: String::new(),
92            cors_origins: Vec::new(),
93            environment: default_environment(),
94            admin_dids: Vec::new(),
95        }
96    }
97}
98
/// Serde default for `app.host`: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}
102
/// Serde default for `app.port`.
fn default_port() -> u16 {
    const PORT: u16 = 3000;
    PORT
}
106
/// Serde default for `app.environment`.
fn default_environment() -> String {
    String::from("development")
}
110
111// ---------------------------------------------------------------------------
112// AuthConfig
113// ---------------------------------------------------------------------------
114
/// `[auth]` section of `atrg.toml`.
///
/// Every field has a serde default, so the whole table (or any single key)
/// may be omitted. The defaults point at `http://localhost:3000`.
#[derive(Debug, Clone, Deserialize)]
pub struct AuthConfig {
    /// AT Protocol OAuth client ID. Validation requires it to parse as a URL.
    /// Defaults to `"http://localhost:3000/client-metadata.json"`.
    #[serde(default = "default_client_id")]
    pub client_id: String,

    /// OAuth redirect URI. Validation requires it to parse as a URL.
    /// Defaults to `"http://localhost:3000/auth/callback"`.
    #[serde(default = "default_redirect_uri")]
    pub redirect_uri: String,

    /// OAuth scope string. Defaults to `"atproto transition:generic"`.
    #[serde(default = "default_scope")]
    pub scope: String,

    /// URL to redirect the browser to after successful OAuth login.
    /// This is the **frontend** URL, not the OAuth callback.
    /// Defaults to `"/"`. Unlike the two fields above, this value is not
    /// URL-validated in this file.
    #[serde(default = "default_post_login_redirect")]
    pub post_login_redirect: String,
}
136
137impl Default for AuthConfig {
138    fn default() -> Self {
139        Self {
140            client_id: default_client_id(),
141            redirect_uri: default_redirect_uri(),
142            scope: default_scope(),
143            post_login_redirect: default_post_login_redirect(),
144        }
145    }
146}
147
/// Serde default for `auth.client_id`: client metadata on the local dev server.
fn default_client_id() -> String {
    String::from("http://localhost:3000/client-metadata.json")
}
151
/// Serde default for `auth.redirect_uri`: the callback on the local dev server.
fn default_redirect_uri() -> String {
    String::from("http://localhost:3000/auth/callback")
}
155
/// Serde default for `auth.scope`.
fn default_scope() -> String {
    String::from("atproto transition:generic")
}
159
/// Serde default for `auth.post_login_redirect`: the site root.
fn default_post_login_redirect() -> String {
    String::from("/")
}
163
164// ---------------------------------------------------------------------------
165// DatabaseConfig
166// ---------------------------------------------------------------------------
167
/// `[database]` section of `atrg.toml`.
///
/// The table is optional; when it or its `url` key is missing the default
/// URL below is used.
#[derive(Debug, Clone, Deserialize)]
pub struct DatabaseConfig {
    /// SQLite connection URL. Defaults to `"sqlite://atrg.db"`. The value is
    /// not validated in this file.
    #[serde(default = "default_database_url")]
    pub url: String,
}
175
176impl Default for DatabaseConfig {
177    fn default() -> Self {
178        Self {
179            url: default_database_url(),
180        }
181    }
182}
183
/// Serde default for `database.url`.
fn default_database_url() -> String {
    String::from("sqlite://atrg.db")
}
187
188// ---------------------------------------------------------------------------
189// JetstreamConfig
190// ---------------------------------------------------------------------------
191
/// `[jetstream]` section of `atrg.toml`. Only present when Jetstream
/// consumption is enabled.
///
/// `host` and `collections` are required once the table exists; the other
/// keys are optional.
#[derive(Debug, Clone, Deserialize)]
pub struct JetstreamConfig {
    /// Jetstream relay host, e.g. `"jetstream1.us-east.bsky.network"`.
    pub host: String,

    /// NSID collections to subscribe to, e.g. `["app.bsky.feed.post"]`.
    pub collections: Vec<String>,

    /// Optional path or URL to a ZSTD dictionary for decompression.
    /// `None` when the key is omitted.
    pub zstd_dict: Option<String>,

    /// Bounded back-pressure channel size. Defaults to `1024`.
    #[serde(default = "default_channel_capacity")]
    pub channel_capacity: usize,

    /// Event lag threshold before shedding/warning. Defaults to `10_000`.
    /// The shedding/warning behaviour itself lives in the consumer, outside
    /// this file.
    #[serde(default = "default_max_lag_events")]
    pub max_lag_events: usize,
}
213
/// Serde default for `jetstream.channel_capacity`.
fn default_channel_capacity() -> usize {
    const CAPACITY: usize = 1024;
    CAPACITY
}
217
/// Serde default for `jetstream.max_lag_events`.
fn default_max_lag_events() -> usize {
    const MAX_LAG: usize = 10_000;
    MAX_LAG
}
221
222// ---------------------------------------------------------------------------
223// FirehoseConfig
224// ---------------------------------------------------------------------------
225
/// `[firehose]` section of `atrg.toml`. Present when relay firehose
/// consumption is enabled (full `com.atproto.sync.subscribeRepos`).
///
/// Only `relay` is required once the table exists.
#[derive(Debug, Clone, Deserialize)]
pub struct FirehoseConfig {
    /// Relay WebSocket URL, e.g. `"wss://bsky.network"`. Not validated in
    /// this file.
    pub relay: String,

    /// Sequence number to resume from. `None` (key omitted) means start from
    /// head, per the original note; the consumer that applies it is outside
    /// this file.
    pub cursor: Option<i64>,

    /// Bounded back-pressure channel capacity. Defaults to `1024`.
    #[serde(default = "default_firehose_channel_capacity")]
    pub channel_capacity: usize,
}
240
/// Serde default for `firehose.channel_capacity`.
fn default_firehose_channel_capacity() -> usize {
    const CAPACITY: usize = 1024;
    CAPACITY
}
244
245// ---------------------------------------------------------------------------
246// FeedGeneratorConfig
247// ---------------------------------------------------------------------------
248
/// `[feed_generator]` section of `atrg.toml`. Present when the server
/// acts as an AT Protocol feed generator.
#[derive(Debug, Clone, Deserialize)]
pub struct FeedGeneratorConfig {
    /// DID of the feed generator service (typically `did:web:<hostname>`).
    /// Required once the table exists; the format is not validated in this
    /// file.
    pub did: String,
}
256
257// ---------------------------------------------------------------------------
258// LabelerConfig
259// ---------------------------------------------------------------------------
260
/// `[labeler]` section of `atrg.toml`. Present when the server acts as
/// an AT Protocol labeler.
///
/// Both signing-key fields are optional and this file does not require that
/// either one be set. NOTE(review): presumably the labeler itself insists on
/// one of them; confirm against the consumer.
#[derive(Debug, Clone, Deserialize)]
pub struct LabelerConfig {
    /// DID of the labeler service. Required once the table exists.
    pub did: String,

    /// Path to the signing key file (PEM format).
    pub signing_key_path: Option<String>,

    /// Inline signing key (base64-encoded, for env var injection).
    pub signing_key_base64: Option<String>,
}
274
275// ---------------------------------------------------------------------------
276// RateLimitConfig (TOML)
277// ---------------------------------------------------------------------------
278
/// `[rate_limit]` section of `atrg.toml`.
///
/// The table is optional. When it is present every key is optional and falls
/// back to the default documented on the field.
#[derive(Debug, Clone, Deserialize)]
pub struct RateLimitTomlConfig {
    /// Maximum sustained requests per second. Defaults to `10.0`.
    #[serde(default = "default_rps")]
    pub requests_per_second: f64,

    /// Maximum burst size. Defaults to `50`.
    #[serde(default = "default_burst")]
    pub burst: u32,

    /// Whether rate limiting is enabled. Defaults to `true` whenever the key
    /// is omitted; the default does not depend on `app.environment`.
    #[serde(default = "default_rate_limit_enabled")]
    pub enabled: bool,
}
294
/// Serde default for `rate_limit.requests_per_second`.
fn default_rps() -> f64 {
    const RPS: f64 = 10.0;
    RPS
}
298
/// Serde default for `rate_limit.burst`.
fn default_burst() -> u32 {
    const BURST: u32 = 50;
    BURST
}
302
/// Serde default for `rate_limit.enabled`: on unless explicitly disabled.
fn default_rate_limit_enabled() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
306
307// ---------------------------------------------------------------------------
308// Loading & validation
309// ---------------------------------------------------------------------------
310
311impl Config {
312    /// Load and validate a [`Config`] from the TOML file at `path`.
313    ///
314    /// # Errors
315    ///
316    /// Returns an error if the file cannot be read, the TOML is malformed, or
317    /// mandatory validation checks fail (e.g. empty `app.name`).
318    pub fn load(path: impl AsRef<Path>) -> anyhow::Result<Self> {
319        let path = path.as_ref();
320        let contents = std::fs::read_to_string(path).map_err(|e| {
321            anyhow::anyhow!(
322                "Failed to read config file '{}': {}. \
323                 Make sure you're running from a directory that contains atrg.toml.",
324                path.display(),
325                e
326            )
327        })?;
328        Self::parse_toml(&contents)
329    }
330
331    /// Parse and validate a [`Config`] from a TOML string.
332    ///
333    /// This is the inner implementation shared by [`Config::load`] and tests.
334    pub fn parse_toml(toml_str: &str) -> anyhow::Result<Self> {
335        let config: Config = toml::from_str(toml_str).map_err(|e| {
336            // Provide a friendlier message when a required section is missing.
337            let msg = e.to_string();
338            if msg.contains("missing field `app`") {
339                anyhow::anyhow!(
340                    "Config error: the [app] section is required in atrg.toml. \
341                     At minimum you need:\n\n\
342                     [app]\n\
343                     name = \"my-app\"\n\
344                     secret_key = \"some-secret-key\"\n\n\
345                     Full error: {e}"
346                )
347            } else {
348                anyhow::anyhow!("Failed to parse atrg.toml: {e}")
349            }
350        })?;
351
352        config.validate()?;
353        Ok(config)
354    }
355
356    /// Run all validation checks and emit warnings.
357    fn validate(&self) -> anyhow::Result<()> {
358        // -- hard errors ------------------------------------------------
359
360        if self.app.name.trim().is_empty() {
361            anyhow::bail!("Config error: app.name must not be empty");
362        }
363
364        if self.app.secret_key.trim().is_empty() {
365            anyhow::bail!("Config error: app.secret_key must not be empty");
366        }
367
368        // Validate redirect_uri is a proper URL.
369        if Url::parse(&self.auth.redirect_uri).is_err() {
370            anyhow::bail!(
371                "Config error: auth.redirect_uri '{}' is not a valid URL",
372                self.auth.redirect_uri
373            );
374        }
375
376        // Validate client_id is a proper URL.
377        if Url::parse(&self.auth.client_id).is_err() {
378            anyhow::bail!(
379                "Config error: auth.client_id '{}' is not a valid URL",
380                self.auth.client_id
381            );
382        }
383
384        // Validate each CORS origin entry.
385        for origin in &self.app.cors_origins {
386            if origin == "*" {
387                continue; // wildcard is fine
388            }
389            if origin.parse::<http::HeaderValue>().is_err() {
390                anyhow::bail!(
391                    "Config error: cors_origins entry '{}' is not a valid origin",
392                    origin
393                );
394            }
395        }
396
397        // -- soft warnings ---------------------------------------------
398
399        if self.app.secret_key.len() < 32 {
400            tracing::warn!(
401                "app.secret_key is only {} characters — use at least 32 for production",
402                self.app.secret_key.len()
403            );
404        }
405
406        let is_local = self.app.host == "localhost" || self.app.host == "127.0.0.1";
407        if self.app.secret_key == "CHANGE_ME_IN_PRODUCTION" && !is_local {
408            tracing::warn!(
409                "app.secret_key is the scaffold default and host is '{}' — \
410                 change it before deploying!",
411                self.app.host
412            );
413        }
414
415        Ok(())
416    }
417}
418
419// ---------------------------------------------------------------------------
420// App-specific config loading
421// ---------------------------------------------------------------------------
422
423/// Load an app-specific configuration section from `atrg.toml`.
424///
425/// This allows apps to define custom `[section_name]` blocks in `atrg.toml`
426/// and deserialize them into typed structs, with automatic environment
427/// variable overrides using the `{PREFIX}_FIELD` convention.
428///
429/// # Examples
430///
431/// ```rust,ignore
432/// #[derive(serde::Deserialize)]
433/// struct MyAppConfig {
434///     database_url: String,
435///     admin_dids: Vec<String>,
436/// }
437///
438/// let config: MyAppConfig = atrg_core::config::load_app_config("myapp")?;
439/// ```
440///
441/// Fields can be overridden by env vars: set `MYAPP_DATABASE_URL` to override
442/// `[myapp] database_url`. The prefix is derived by uppercasing the section name.
443pub fn load_app_config<T: serde::de::DeserializeOwned>(section_name: &str) -> anyhow::Result<T> {
444    load_app_config_from_path::<T>(section_name, "atrg.toml")
445}
446
447/// Load an app-specific configuration section from a specific TOML file path.
448pub fn load_app_config_from_path<T: serde::de::DeserializeOwned>(
449    section_name: &str,
450    path: &str,
451) -> anyhow::Result<T> {
452    let toml_str = std::fs::read_to_string(path)
453        .map_err(|e| anyhow::anyhow!("Failed to read {}: {}", path, e))?;
454    let toml_val: toml::Value = toml::from_str(&toml_str)
455        .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", path, e))?;
456    let section = toml_val
457        .get(section_name)
458        .ok_or_else(|| anyhow::anyhow!("Missing [{}] section in {}", section_name, path))?;
459    let config: T = section.clone().try_into().map_err(|e| {
460        anyhow::anyhow!(
461            "Invalid [{}] configuration in {}: {}",
462            section_name,
463            path,
464            e
465        )
466    })?;
467    Ok(config)
468}
469
470// ---------------------------------------------------------------------------
471// Tests
472// ---------------------------------------------------------------------------
473
/// Unit tests for parsing, defaulting and validation of `atrg.toml`.
///
/// Every test feeds an inline TOML fixture to `Config::parse_toml`, so no
/// files are touched.
#[cfg(test)]
mod tests {
    use super::*;

    /// A full config fixture exercising every field.
    const FULL_CONFIG: &str = r#"
[app]
name = "my-app"
host = "0.0.0.0"
port = 8080
secret_key = "super-secret-key-that-is-long-enough"
cors_origins = ["http://localhost:5173", "https://example.com"]
environment = "production"

[auth]
client_id = "https://myapp.example.com/client-metadata.json"
redirect_uri = "https://myapp.example.com/auth/callback"
scope = "atproto transition:generic"

[database]
url = "sqlite://prod.db"

[jetstream]
host = "jetstream1.us-east.bsky.network"
collections = ["app.bsky.feed.post", "app.bsky.feed.like"]
zstd_dict = "/tmp/dict.bin"
channel_capacity = 2048
max_lag_events = 20000
"#;

    /// Minimal config — only the required fields.
    const MINIMAL_CONFIG: &str = r#"
[app]
name = "tiny"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;

    /// Explicit values in every section are carried through unchanged.
    #[test]
    fn parse_full_config() {
        let cfg = Config::parse_toml(FULL_CONFIG).expect("should parse full config");

        assert_eq!(cfg.app.name, "my-app");
        assert_eq!(cfg.app.host, "0.0.0.0");
        assert_eq!(cfg.app.port, 8080);
        assert_eq!(cfg.app.environment, "production");
        assert_eq!(cfg.app.cors_origins.len(), 2);

        assert_eq!(
            cfg.auth.client_id,
            "https://myapp.example.com/client-metadata.json"
        );
        assert_eq!(
            cfg.auth.redirect_uri,
            "https://myapp.example.com/auth/callback"
        );
        assert_eq!(cfg.auth.scope, "atproto transition:generic");

        assert_eq!(cfg.database.url, "sqlite://prod.db");

        let js = cfg.jetstream.expect("jetstream should be present");
        assert_eq!(js.host, "jetstream1.us-east.bsky.network");
        assert_eq!(js.collections.len(), 2);
        assert_eq!(js.zstd_dict.as_deref(), Some("/tmp/dict.bin"));
        assert_eq!(js.channel_capacity, 2048);
        assert_eq!(js.max_lag_events, 20000);
    }

    /// With only `[app]` present, every other value comes from a default.
    #[test]
    fn parse_minimal_config_defaults_applied() {
        let cfg = Config::parse_toml(MINIMAL_CONFIG).expect("should parse minimal config");

        // Explicit values
        assert_eq!(cfg.app.name, "tiny");

        // Defaults
        assert_eq!(cfg.app.host, "127.0.0.1");
        assert_eq!(cfg.app.port, 3000);
        assert_eq!(cfg.app.environment, "development");
        assert!(cfg.app.cors_origins.is_empty());

        assert_eq!(
            cfg.auth.client_id,
            "http://localhost:3000/client-metadata.json"
        );
        assert_eq!(cfg.auth.redirect_uri, "http://localhost:3000/auth/callback");
        assert_eq!(cfg.auth.scope, "atproto transition:generic");

        assert_eq!(cfg.database.url, "sqlite://atrg.db");
        assert!(cfg.jetstream.is_none());
    }

    /// A file without `[app]` produces the expanded, human-friendly hint.
    #[test]
    fn missing_app_section_gives_friendly_error() {
        let toml = r#"
[database]
url = "sqlite://test.db"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("[app] section is required"),
            "expected friendly error, got: {msg}"
        );
    }

    #[test]
    fn empty_name_is_rejected() {
        let toml = r#"
[app]
name = ""
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        assert!(
            err.to_string().contains("app.name must not be empty"),
            "got: {}",
            err
        );
    }

    #[test]
    fn empty_secret_key_is_rejected() {
        let toml = r#"
[app]
name = "test"
secret_key = ""
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        assert!(
            err.to_string().contains("app.secret_key must not be empty"),
            "got: {}",
            err
        );
    }

    #[test]
    fn invalid_redirect_uri_is_rejected() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[auth]
redirect_uri = "not a url at all"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("auth.redirect_uri") && msg.contains("not a valid URL"),
            "expected redirect_uri error, got: {msg}"
        );
    }

    #[test]
    fn invalid_client_id_is_rejected() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[auth]
client_id = "not a url"
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("auth.client_id") && msg.contains("not a valid URL"),
            "expected client_id error, got: {msg}"
        );
    }

    /// An origin containing a NUL character must be rejected by validation.
    ///
    /// The fixture uses the TOML escape `\u0000`, so the document itself is
    /// well-formed and the NUL reaches the CORS check, where it cannot form
    /// an `http::HeaderValue`. (`\x00` is not a TOML 1.0 escape: it is
    /// rejected by the TOML parser, and a loose `contains("cors_origins")`
    /// assertion can then pass on the echoed source line without the CORS
    /// check ever running.)
    #[test]
    fn invalid_cors_origin_is_rejected() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
cors_origins = ["http://ok.example.com", "\u0000bad"]
"#;
        let err = Config::parse_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("cors_origins entry") && msg.contains("is not a valid origin"),
            "expected cors origin validation error, got: {msg}"
        );
    }

    #[test]
    fn wildcard_cors_origin_is_accepted() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
cors_origins = ["*"]
"#;
        Config::parse_toml(toml).expect("wildcard should be accepted");
    }

    /// The optional firehose / feed generator / labeler / rate-limit tables
    /// deserialize, with omitted keys taking their defaults.
    #[test]
    fn parse_config_with_firehose_and_feeds() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[firehose]
relay = "wss://bsky.network"

[feed_generator]
did = "did:web:feeds.example.com"

[labeler]
did = "did:web:labels.example.com"
signing_key_path = "/etc/keys/labeler.pem"

[rate_limit]
requests_per_second = 20.0
burst = 100
enabled = true
"#;
        let cfg = Config::parse_toml(toml).unwrap();
        let fh = cfg.firehose.unwrap();
        assert_eq!(fh.relay, "wss://bsky.network");
        assert!(fh.cursor.is_none());
        assert_eq!(fh.channel_capacity, 1024);

        let fg = cfg.feed_generator.unwrap();
        assert_eq!(fg.did, "did:web:feeds.example.com");

        let lb = cfg.labeler.unwrap();
        assert_eq!(lb.did, "did:web:labels.example.com");
        assert_eq!(lb.signing_key_path.unwrap(), "/etc/keys/labeler.pem");

        let rl = cfg.rate_limit.unwrap();
        assert!((rl.requests_per_second - 20.0).abs() < f64::EPSILON);
        assert_eq!(rl.burst, 100);
    }

    #[test]
    fn new_sections_are_all_optional() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;
        let cfg = Config::parse_toml(toml).unwrap();
        assert!(cfg.firehose.is_none());
        assert!(cfg.feed_generator.is_none());
        assert!(cfg.labeler.is_none());
        assert!(cfg.rate_limit.is_none());
    }

    #[test]
    fn jetstream_defaults_applied() {
        let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"

[jetstream]
host = "jetstream1.us-east.bsky.network"
collections = ["app.bsky.feed.post"]
"#;
        let cfg = Config::parse_toml(toml).unwrap();
        let js = cfg.jetstream.unwrap();
        assert_eq!(js.channel_capacity, 1024);
        assert_eq!(js.max_lag_events, 10_000);
        assert!(js.zstd_dict.is_none());
    }
}