use std::path::Path;
use axum::http;
use serde::Deserialize;
use url::Url;
/// Top-level application configuration, deserialized from TOML.
///
/// `app` is the only section without a fallback, so it must be present.
/// `auth` and `database` use their `Default` impls when the section is
/// missing; every other section is `None` when missing.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
/// Required `[app]` section.
pub app: AppConfig,
/// `[auth]` section; falls back to `AuthConfig::default()` when absent.
#[serde(default)]
pub auth: AuthConfig,
/// `[database]` section; falls back to `DatabaseConfig::default()` when absent.
#[serde(default)]
pub database: DatabaseConfig,
/// Optional `[jetstream]` section.
pub jetstream: Option<JetstreamConfig>,
/// Optional `[firehose]` section.
pub firehose: Option<FirehoseConfig>,
/// Optional `[feed_generator]` section.
pub feed_generator: Option<FeedGeneratorConfig>,
/// Optional `[labeler]` section.
pub labeler: Option<LabelerConfig>,
/// Optional `[rate_limit]` section.
pub rate_limit: Option<RateLimitTomlConfig>,
}
/// Settings from the `[app]` section.
///
/// `name` and `secret_key` have no serde default and must be supplied;
/// `Config::validate` additionally rejects values that are empty after
/// trimming whitespace.
#[derive(Debug, Clone, Deserialize)]
pub struct AppConfig {
/// Application name (required, must not be blank).
pub name: String,
/// Host string; defaults to `"127.0.0.1"`.
#[serde(default = "default_host")]
pub host: String,
/// Port number; defaults to `3000`.
#[serde(default = "default_port")]
pub port: u16,
/// Secret key (required, must not be blank). Validation logs a warning when
/// it is shorter than 32 bytes.
pub secret_key: String,
/// CORS origins; defaults to an empty list. `"*"` is accepted as-is, any
/// other entry must parse as an HTTP header value.
#[serde(default)]
pub cors_origins: Vec<String>,
/// Environment name; defaults to `"development"`.
#[serde(default = "default_environment")]
pub environment: String,
}
impl Default for AppConfig {
fn default() -> Self {
Self {
name: String::new(),
host: default_host(),
port: default_port(),
secret_key: String::new(),
cors_origins: Vec::new(),
environment: default_environment(),
}
}
}
/// Fallback for `AppConfig::host`: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}
/// Fallback for `AppConfig::port`.
fn default_port() -> u16 {
    const PORT: u16 = 3000;
    PORT
}
/// Fallback for `AppConfig::environment`.
fn default_environment() -> String {
    "development".to_owned()
}
/// Settings from the `[auth]` section. Every key is optional and has a
/// per-field fallback.
///
/// `Config::validate` requires `client_id` and `redirect_uri` to parse as URLs.
#[derive(Debug, Clone, Deserialize)]
pub struct AuthConfig {
/// Client identifier; must parse as a URL. Defaults to
/// `"http://localhost:3000/client-metadata.json"`.
#[serde(default = "default_client_id")]
pub client_id: String,
/// Redirect URI; must parse as a URL. Defaults to
/// `"http://localhost:3000/auth/callback"`.
#[serde(default = "default_redirect_uri")]
pub redirect_uri: String,
/// Scope string; defaults to `"atproto transition:generic"`.
#[serde(default = "default_scope")]
pub scope: String,
/// Defaults to `"/"`. Presumably where the user is sent after a successful
/// login — confirm against the auth handler.
#[serde(default = "default_post_login_redirect")]
pub post_login_redirect: String,
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
client_id: default_client_id(),
redirect_uri: default_redirect_uri(),
scope: default_scope(),
post_login_redirect: default_post_login_redirect(),
}
}
}
/// Fallback for `AuthConfig::client_id`; points at localhost on port 3000.
fn default_client_id() -> String {
    String::from("http://localhost:3000/client-metadata.json")
}
/// Fallback for `AuthConfig::redirect_uri`; points at localhost on port 3000.
fn default_redirect_uri() -> String {
    "http://localhost:3000/auth/callback".to_owned()
}
/// Fallback for `AuthConfig::scope`.
fn default_scope() -> String {
    String::from("atproto transition:generic")
}
/// Fallback for `AuthConfig::post_login_redirect`: the root path.
fn default_post_login_redirect() -> String {
    "/".to_owned()
}
/// Settings from the `[database]` section.
#[derive(Debug, Clone, Deserialize)]
pub struct DatabaseConfig {
/// Database connection URL; defaults to `"sqlite://atrg.db"`.
/// This module does not validate the value.
#[serde(default = "default_database_url")]
pub url: String,
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
url: default_database_url(),
}
}
}
/// Fallback for `DatabaseConfig::url`.
fn default_database_url() -> String {
    String::from("sqlite://atrg.db")
}
/// Settings from the optional `[jetstream]` section.
///
/// When the section is present, `host` and `collections` are required.
#[derive(Debug, Clone, Deserialize)]
pub struct JetstreamConfig {
/// Jetstream host (required). The tests use a bare hostname without a scheme.
pub host: String,
/// Collection names (required). The tests use NSIDs such as
/// `app.bsky.feed.post`.
pub collections: Vec<String>,
/// Optional string; presumably a filesystem path to a zstd dictionary
/// (the tests use `/tmp/dict.bin`) — confirm against the consumer.
pub zstd_dict: Option<String>,
/// Defaults to `1024`.
#[serde(default = "default_channel_capacity")]
pub channel_capacity: usize,
/// Defaults to `10_000`.
#[serde(default = "default_max_lag_events")]
pub max_lag_events: usize,
}
/// Fallback for `JetstreamConfig::channel_capacity`.
fn default_channel_capacity() -> usize {
    const CAPACITY: usize = 1024;
    CAPACITY
}
/// Fallback for `JetstreamConfig::max_lag_events`.
fn default_max_lag_events() -> usize {
    const MAX_LAG: usize = 10_000;
    MAX_LAG
}
/// Settings from the optional `[firehose]` section.
///
/// When the section is present, `relay` is required.
#[derive(Debug, Clone, Deserialize)]
pub struct FirehoseConfig {
/// Relay address (required). The tests use a `wss://` URL.
pub relay: String,
/// Optional cursor; `None` when the key is absent. Presumably the position
/// to resume from — confirm against the firehose consumer.
pub cursor: Option<i64>,
/// Defaults to `1024`.
#[serde(default = "default_firehose_channel_capacity")]
pub channel_capacity: usize,
}
/// Fallback for `FirehoseConfig::channel_capacity`.
fn default_firehose_channel_capacity() -> usize {
    const CAPACITY: usize = 1024;
    CAPACITY
}
/// Settings from the optional `[feed_generator]` section.
#[derive(Debug, Clone, Deserialize)]
pub struct FeedGeneratorConfig {
/// DID string (required when the section is present). The tests use a
/// `did:web:` identifier; this module does not validate the format.
pub did: String,
}
/// Settings from the optional `[labeler]` section.
#[derive(Debug, Clone, Deserialize)]
pub struct LabelerConfig {
/// DID string (required when the section is present). This module does not
/// validate the format.
pub did: String,
/// Optional; presumably a filesystem path to the signing key — confirm
/// against the labeler implementation.
pub signing_key_path: Option<String>,
/// Optional; presumably the signing key encoded as base64.
/// NOTE(review): nothing here enforces that exactly one of the two key
/// sources is set — confirm how the consumer resolves both/neither.
pub signing_key_base64: Option<String>,
}
/// Settings from the optional `[rate_limit]` section.
///
/// The per-field defaults below only apply when the section itself is
/// present; `Config::rate_limit` is `None` when the section is missing.
#[derive(Debug, Clone, Deserialize)]
pub struct RateLimitTomlConfig {
/// Defaults to `10.0`.
#[serde(default = "default_rps")]
pub requests_per_second: f64,
/// Defaults to `50`.
#[serde(default = "default_burst")]
pub burst: u32,
/// Defaults to `true`.
#[serde(default = "default_rate_limit_enabled")]
pub enabled: bool,
}
/// Fallback for `RateLimitTomlConfig::requests_per_second`.
fn default_rps() -> f64 {
    const REQUESTS_PER_SECOND: f64 = 10.0;
    REQUESTS_PER_SECOND
}
/// Fallback for `RateLimitTomlConfig::burst`.
fn default_burst() -> u32 {
    const BURST: u32 = 50;
    BURST
}
/// Fallback for `RateLimitTomlConfig::enabled`: rate limiting is on unless
/// explicitly disabled.
fn default_rate_limit_enabled() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
impl Config {
pub fn load(path: impl AsRef<Path>) -> anyhow::Result<Self> {
let path = path.as_ref();
let contents = std::fs::read_to_string(path).map_err(|e| {
anyhow::anyhow!(
"Failed to read config file '{}': {}. \
Make sure you're running from a directory that contains atrg.toml.",
path.display(),
e
)
})?;
Self::parse_toml(&contents)
}
pub fn parse_toml(toml_str: &str) -> anyhow::Result<Self> {
let config: Config = toml::from_str(toml_str).map_err(|e| {
let msg = e.to_string();
if msg.contains("missing field `app`") {
anyhow::anyhow!(
"Config error: the [app] section is required in atrg.toml. \
At minimum you need:\n\n\
[app]\n\
name = \"my-app\"\n\
secret_key = \"some-secret-key\"\n\n\
Full error: {e}"
)
} else {
anyhow::anyhow!("Failed to parse atrg.toml: {e}")
}
})?;
config.validate()?;
Ok(config)
}
fn validate(&self) -> anyhow::Result<()> {
if self.app.name.trim().is_empty() {
anyhow::bail!("Config error: app.name must not be empty");
}
if self.app.secret_key.trim().is_empty() {
anyhow::bail!("Config error: app.secret_key must not be empty");
}
if Url::parse(&self.auth.redirect_uri).is_err() {
anyhow::bail!(
"Config error: auth.redirect_uri '{}' is not a valid URL",
self.auth.redirect_uri
);
}
if Url::parse(&self.auth.client_id).is_err() {
anyhow::bail!(
"Config error: auth.client_id '{}' is not a valid URL",
self.auth.client_id
);
}
for origin in &self.app.cors_origins {
if origin == "*" {
continue; }
if origin.parse::<http::HeaderValue>().is_err() {
anyhow::bail!(
"Config error: cors_origins entry '{}' is not a valid origin",
origin
);
}
}
if self.app.secret_key.len() < 32 {
tracing::warn!(
"app.secret_key is only {} characters — use at least 32 for production",
self.app.secret_key.len()
);
}
let is_local = self.app.host == "localhost" || self.app.host == "127.0.0.1";
if self.app.secret_key == "CHANGE_ME_IN_PRODUCTION" && !is_local {
tracing::warn!(
"app.secret_key is the scaffold default and host is '{}' — \
change it before deploying!",
self.app.host
);
}
Ok(())
}
}
// Unit tests for `Config::parse_toml`: defaults, optional sections, and the
// validation error messages.
#[cfg(test)]
mod tests {
use super::*;
// Fixture that sets every `[app]` key plus `[auth]`, `[database]` and
// `[jetstream]`, so no fallback applies to the asserted fields.
const FULL_CONFIG: &str = r#"
[app]
name = "my-app"
host = "0.0.0.0"
port = 8080
secret_key = "super-secret-key-that-is-long-enough"
cors_origins = ["http://localhost:5173", "https://example.com"]
environment = "production"
[auth]
client_id = "https://myapp.example.com/client-metadata.json"
redirect_uri = "https://myapp.example.com/auth/callback"
scope = "atproto transition:generic"
[database]
url = "sqlite://prod.db"
[jetstream]
host = "jetstream1.us-east.bsky.network"
collections = ["app.bsky.feed.post", "app.bsky.feed.like"]
zstd_dict = "/tmp/dict.bin"
channel_capacity = 2048
max_lag_events = 20000
"#;
// Smallest accepted config: only the two required `[app]` keys. The secret
// key is exactly 32 bytes, so the short-key warning does not fire.
const MINIMAL_CONFIG: &str = r#"
[app]
name = "tiny"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;
// Every explicitly configured value must come through unchanged.
#[test]
fn parse_full_config() {
let cfg = Config::parse_toml(FULL_CONFIG).expect("should parse full config");
assert_eq!(cfg.app.name, "my-app");
assert_eq!(cfg.app.host, "0.0.0.0");
assert_eq!(cfg.app.port, 8080);
assert_eq!(cfg.app.environment, "production");
assert_eq!(cfg.app.cors_origins.len(), 2);
assert_eq!(
cfg.auth.client_id,
"https://myapp.example.com/client-metadata.json"
);
assert_eq!(
cfg.auth.redirect_uri,
"https://myapp.example.com/auth/callback"
);
assert_eq!(cfg.auth.scope, "atproto transition:generic");
assert_eq!(cfg.database.url, "sqlite://prod.db");
let js = cfg.jetstream.expect("jetstream should be present");
assert_eq!(js.host, "jetstream1.us-east.bsky.network");
assert_eq!(js.collections.len(), 2);
assert_eq!(js.zstd_dict.as_deref(), Some("/tmp/dict.bin"));
assert_eq!(js.channel_capacity, 2048);
assert_eq!(js.max_lag_events, 20000);
}
// With only `[app]` given, every other field must hold its documented fallback.
#[test]
fn parse_minimal_config_defaults_applied() {
let cfg = Config::parse_toml(MINIMAL_CONFIG).expect("should parse minimal config");
assert_eq!(cfg.app.name, "tiny");
assert_eq!(cfg.app.host, "127.0.0.1");
assert_eq!(cfg.app.port, 3000);
assert_eq!(cfg.app.environment, "development");
assert!(cfg.app.cors_origins.is_empty());
assert_eq!(
cfg.auth.client_id,
"http://localhost:3000/client-metadata.json"
);
assert_eq!(cfg.auth.redirect_uri, "http://localhost:3000/auth/callback");
assert_eq!(cfg.auth.scope, "atproto transition:generic");
assert_eq!(cfg.database.url, "sqlite://atrg.db");
assert!(cfg.jetstream.is_none());
}
// A config without `[app]` must produce the dedicated, human-friendly message.
#[test]
fn missing_app_section_gives_friendly_error() {
let toml = r#"
[database]
url = "sqlite://test.db"
"#;
let err = Config::parse_toml(toml).unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("[app] section is required"),
"expected friendly error, got: {msg}"
);
}
// An empty `app.name` deserializes fine but must fail validation.
#[test]
fn empty_name_is_rejected() {
let toml = r#"
[app]
name = ""
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;
let err = Config::parse_toml(toml).unwrap_err();
assert!(
err.to_string().contains("app.name must not be empty"),
"got: {}",
err
);
}
// An empty `app.secret_key` deserializes fine but must fail validation.
#[test]
fn empty_secret_key_is_rejected() {
let toml = r#"
[app]
name = "test"
secret_key = ""
"#;
let err = Config::parse_toml(toml).unwrap_err();
assert!(
err.to_string().contains("app.secret_key must not be empty"),
"got: {}",
err
);
}
// `auth.redirect_uri` must parse as a URL.
#[test]
fn invalid_redirect_uri_is_rejected() {
let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
[auth]
redirect_uri = "not a url at all"
"#;
let err = Config::parse_toml(toml).unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("auth.redirect_uri") && msg.contains("not a valid URL"),
"expected redirect_uri error, got: {msg}"
);
}
// `auth.client_id` must parse as a URL.
#[test]
fn invalid_client_id_is_rejected() {
let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
[auth]
client_id = "not a url"
"#;
let err = Config::parse_toml(toml).unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("auth.client_id") && msg.contains("not a valid URL"),
"expected client_id error, got: {msg}"
);
}
// A CORS origin that cannot be a header value must be rejected.
// NOTE(review): inside a Rust raw string, `\x00` reaches the TOML parser
// verbatim. Whether that is a legal escape depends on the TOML version the
// parser implements, so this test may be satisfied by a TOML parse error that
// merely quotes the `cors_origins` line rather than by `validate()` — confirm,
// and consider asserting on "not a valid origin" instead.
#[test]
fn invalid_cors_origin_is_rejected() {
let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
cors_origins = ["http://ok.example.com", "\x00bad"]
"#;
let err = Config::parse_toml(toml).unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("cors_origins"),
"expected cors origin error, got: {msg}"
);
}
// The literal `"*"` bypasses the header-value check.
#[test]
fn wildcard_cors_origin_is_accepted() {
let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
cors_origins = ["*"]
"#;
Config::parse_toml(toml).expect("wildcard should be accepted");
}
// The firehose, feed generator, labeler and rate-limit sections parse, and
// unspecified firehose keys receive their fallbacks.
#[test]
fn parse_config_with_firehose_and_feeds() {
let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
[firehose]
relay = "wss://bsky.network"
[feed_generator]
did = "did:web:feeds.example.com"
[labeler]
did = "did:web:labels.example.com"
signing_key_path = "/etc/keys/labeler.pem"
[rate_limit]
requests_per_second = 20.0
burst = 100
enabled = true
"#;
let cfg = Config::parse_toml(toml).unwrap();
let fh = cfg.firehose.unwrap();
assert_eq!(fh.relay, "wss://bsky.network");
assert!(fh.cursor.is_none());
assert_eq!(fh.channel_capacity, 1024);
let fg = cfg.feed_generator.unwrap();
assert_eq!(fg.did, "did:web:feeds.example.com");
let lb = cfg.labeler.unwrap();
assert_eq!(lb.did, "did:web:labels.example.com");
assert_eq!(lb.signing_key_path.unwrap(), "/etc/keys/labeler.pem");
let rl = cfg.rate_limit.unwrap();
assert!((rl.requests_per_second - 20.0).abs() < f64::EPSILON);
assert_eq!(rl.burst, 100);
}
// Omitting the optional sections yields `None` rather than an error.
#[test]
fn new_sections_are_all_optional() {
let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
"#;
let cfg = Config::parse_toml(toml).unwrap();
assert!(cfg.firehose.is_none());
assert!(cfg.feed_generator.is_none());
assert!(cfg.labeler.is_none());
assert!(cfg.rate_limit.is_none());
}
// A `[jetstream]` section with only the required keys picks up the
// per-field fallbacks.
#[test]
fn jetstream_defaults_applied() {
let toml = r#"
[app]
name = "test"
secret_key = "abcdefghijklmnopqrstuvwxyz123456"
[jetstream]
host = "jetstream1.us-east.bsky.network"
collections = ["app.bsky.feed.post"]
"#;
let cfg = Config::parse_toml(toml).unwrap();
let js = cfg.jetstream.unwrap();
assert_eq!(js.channel_capacity, 1024);
assert_eq!(js.max_lag_events, 10_000);
assert!(js.zstd_dict.is_none());
}
}