use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use clap::Parser;
use config::{ConfigError, File};
use serde::Deserialize;
/// Upstream server URL used when no other source supplies one.
const DEFAULT_UPSTREAM_URL: &str = "https://mastodon.social";
/// Default listen host; `0.0.0.0` is the IPv4 wildcard address.
const DEFAULT_HOST: &str = "0.0.0.0";
/// Default listen port.
const DEFAULT_PORT: u16 = 8080;
/// Default database file path (relative, so resolved against the working directory).
const DEFAULT_DATABASE_PATH: &str = "ivoryvalley.db";
/// Default maximum body size: 50 * 1024 * 1024 (50 MiB, assuming the unit is bytes).
const DEFAULT_MAX_BODY_SIZE: usize = 50 * 1024 * 1024;
/// Default connect timeout for the outbound HTTP client, in seconds.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Default overall request timeout for the outbound HTTP client, in seconds.
const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 30;
/// Default traffic-recording path; `None` means recording is off unless configured.
const DEFAULT_RECORD_TRAFFIC_PATH: Option<&str> = None;
/// Default cleanup interval in seconds (one hour).
const DEFAULT_CLEANUP_INTERVAL_SECS: u64 = 60 * 60;
/// Default cleanup maximum age in seconds (seven days).
const DEFAULT_CLEANUP_MAX_AGE_SECS: u64 = 7 * 24 * 60 * 60;
// Command-line / environment options, parsed through clap's derive API.
//
// Every option is an `Option<_>` so that "not given" stays distinguishable from
// an explicit value; `Config::load_from_args` relies on that to fall back to the
// config file and then to the built-in defaults.
//
// NOTE: plain `//` comments are used on purpose. clap's derive turns `///` doc
// comments into `--help` text, which would change the program's output.
#[derive(Parser, Debug)]
#[command(name = "ivoryvalley")]
#[command(about = "A Mastodon proxy server for filtering content")]
pub struct CliArgs {
// Upstream server URL; also read from the `IV_UPSTREAM_URL` environment variable.
#[arg(long, env = "IV_UPSTREAM_URL")]
pub upstream_url: Option<String>,
// Host to bind to; also read from `IV_HOST`.
#[arg(long, env = "IV_HOST")]
pub host: Option<String>,
// Port to bind to; also read from `IV_PORT`. `short` additionally enables a
// one-letter flag (derived by clap from the field name).
#[arg(short, long, env = "IV_PORT")]
pub port: Option<u16>,
// Database file path; also read from `IV_DATABASE_PATH`.
#[arg(long, env = "IV_DATABASE_PATH")]
pub database_path: Option<PathBuf>,
// Maximum body size; also read from `IV_MAX_BODY_SIZE`.
#[arg(long, env = "IV_MAX_BODY_SIZE")]
pub max_body_size: Option<usize>,
// Connect timeout in seconds; also read from `IV_CONNECT_TIMEOUT_SECS`.
#[arg(long, env = "IV_CONNECT_TIMEOUT_SECS")]
pub connect_timeout_secs: Option<u64>,
// Request timeout in seconds; also read from `IV_REQUEST_TIMEOUT_SECS`.
#[arg(long, env = "IV_REQUEST_TIMEOUT_SECS")]
pub request_timeout_secs: Option<u64>,
// Path for traffic recording; also read from `IV_RECORD_TRAFFIC_PATH`.
#[arg(long, env = "IV_RECORD_TRAFFIC_PATH")]
pub record_traffic_path: Option<PathBuf>,
// Cleanup interval in seconds; also read from `IV_CLEANUP_INTERVAL_SECS`.
#[arg(long, env = "IV_CLEANUP_INTERVAL_SECS")]
pub cleanup_interval_secs: Option<u64>,
// Cleanup maximum age in seconds; also read from `IV_CLEANUP_MAX_AGE_SECS`.
#[arg(long, env = "IV_CLEANUP_MAX_AGE_SECS")]
pub cleanup_max_age_secs: Option<u64>,
// Explicit configuration file path; also read from `IV_CONFIG`. When unset,
// `Config::load_file_config` probes the optional default file names instead.
#[arg(short, long, env = "IV_CONFIG")]
pub config: Option<PathBuf>,
}
/// Settings as deserialized from a configuration file.
///
/// Every field is optional: a key that is absent from the file stays `None`,
/// and `Config::load_from_args` then keeps the lower-priority (default) value.
/// The field names are the keys expected in the file.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct FileConfig {
/// Upstream server URL.
upstream_url: Option<String>,
/// Host to bind to.
host: Option<String>,
/// Port to bind to.
port: Option<u16>,
/// Database file path.
database_path: Option<PathBuf>,
/// Maximum body size.
max_body_size: Option<usize>,
/// Connect timeout in seconds.
connect_timeout_secs: Option<u64>,
/// Request timeout in seconds.
request_timeout_secs: Option<u64>,
/// Path for traffic recording.
record_traffic_path: Option<PathBuf>,
/// Cleanup interval in seconds.
cleanup_interval_secs: Option<u64>,
/// Cleanup maximum age in seconds.
cleanup_max_age_secs: Option<u64>,
}
/// Fully resolved runtime configuration.
///
/// Produced by `Config::load` / `Config::load_from_args`, which layer CLI/env
/// options over config-file values over the `DEFAULT_*` constants.
#[derive(Debug, Clone)]
pub struct Config {
/// Upstream server URL.
pub upstream_url: String,
/// Host part of the listen address (see `bind_addr`).
pub host: String,
/// Port part of the listen address (see `bind_addr`).
pub port: u16,
/// Database file path.
pub database_path: PathBuf,
/// Maximum body size (presumably in bytes; the enforcing code is not in this file).
pub max_body_size: usize,
/// Connect timeout, in seconds, applied to the HTTP client built in `AppState::new`.
pub connect_timeout_secs: u64,
/// Overall request timeout, in seconds, applied to the HTTP client built in `AppState::new`.
pub request_timeout_secs: u64,
/// When `Some`, `AppState::new` tries to start a traffic recorder at this path.
pub record_traffic_path: Option<PathBuf>,
/// Cleanup interval in seconds (consumer not visible in this file).
pub cleanup_interval_secs: u64,
/// Cleanup maximum age in seconds (consumer not visible in this file).
pub cleanup_max_age_secs: u64,
}
impl Default for Config {
fn default() -> Self {
Self {
upstream_url: DEFAULT_UPSTREAM_URL.to_string(),
host: DEFAULT_HOST.to_string(),
port: DEFAULT_PORT,
database_path: PathBuf::from(DEFAULT_DATABASE_PATH),
max_body_size: DEFAULT_MAX_BODY_SIZE,
connect_timeout_secs: DEFAULT_CONNECT_TIMEOUT_SECS,
request_timeout_secs: DEFAULT_REQUEST_TIMEOUT_SECS,
record_traffic_path: DEFAULT_RECORD_TRAFFIC_PATH.map(PathBuf::from),
cleanup_interval_secs: DEFAULT_CLEANUP_INTERVAL_SECS,
cleanup_max_age_secs: DEFAULT_CLEANUP_MAX_AGE_SECS,
}
}
}
impl Config {
    /// Creates a configuration from the four core settings, leaving every
    /// other field at its built-in default.
    // Not referenced by the non-test code in this file, hence the allow.
    #[allow(dead_code)]
    pub fn new(upstream_url: &str, host: &str, port: u16, database_path: PathBuf) -> Self {
        Self::with_max_body_size(upstream_url, host, port, database_path, DEFAULT_MAX_BODY_SIZE)
    }

    /// Same as [`Config::new`], but with an explicit `max_body_size`.
    ///
    /// All fields not passed as arguments come from [`Config::default`].
    // Not referenced by the non-test code in this file, hence the allow.
    #[allow(dead_code)]
    pub fn with_max_body_size(
        upstream_url: &str,
        host: &str,
        port: u16,
        database_path: PathBuf,
        max_body_size: usize,
    ) -> Self {
        Self {
            upstream_url: upstream_url.to_string(),
            host: host.to_string(),
            port,
            database_path,
            max_body_size,
            ..Self::default()
        }
    }

    /// Loads the configuration for the running process.
    ///
    /// Parses the process arguments and `IV_*` environment variables with
    /// `CliArgs::parse` and hands the result to [`Config::load_from_args`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Config::load_from_args`].
    pub fn load() -> Result<Self, ConfigError> {
        Self::load_from_args(CliArgs::parse())
    }

    /// Resolves the final configuration from already-parsed arguments.
    ///
    /// Each setting is taken from the first source that provides it:
    ///
    /// 1. `args` (command line / environment),
    /// 2. the configuration file,
    /// 3. the built-in default.
    ///
    /// # Errors
    ///
    /// Returns a `ConfigError` when the configuration file cannot be loaded
    /// or cannot be deserialized into the expected settings.
    pub fn load_from_args(args: CliArgs) -> Result<Self, ConfigError> {
        let file = Self::load_file_config(args.config.as_deref())?;
        let defaults = Config::default();

        Ok(Self {
            upstream_url: args
                .upstream_url
                .or(file.upstream_url)
                .unwrap_or(defaults.upstream_url),
            host: args.host.or(file.host).unwrap_or(defaults.host),
            port: args.port.or(file.port).unwrap_or(defaults.port),
            database_path: args
                .database_path
                .or(file.database_path)
                .unwrap_or(defaults.database_path),
            max_body_size: args
                .max_body_size
                .or(file.max_body_size)
                .unwrap_or(defaults.max_body_size),
            connect_timeout_secs: args
                .connect_timeout_secs
                .or(file.connect_timeout_secs)
                .unwrap_or(defaults.connect_timeout_secs),
            request_timeout_secs: args
                .request_timeout_secs
                .or(file.request_timeout_secs)
                .unwrap_or(defaults.request_timeout_secs),
            // Already optional in `Config`, so the default is the last `or` link.
            record_traffic_path: args
                .record_traffic_path
                .or(file.record_traffic_path)
                .or(defaults.record_traffic_path),
            cleanup_interval_secs: args
                .cleanup_interval_secs
                .or(file.cleanup_interval_secs)
                .unwrap_or(defaults.cleanup_interval_secs),
            cleanup_max_age_secs: args
                .cleanup_max_age_secs
                .or(file.cleanup_max_age_secs)
                .unwrap_or(defaults.cleanup_max_age_secs),
        })
    }

    /// Reads the file layer of the configuration.
    ///
    /// With an explicit `config_path` that file is the only source. Without
    /// one, the sources named `config` and `ivoryvalley` are added as
    /// optional (`required(false)`).
    ///
    /// # Errors
    ///
    /// Returns a `ConfigError` if building the sources or deserializing them
    /// into `FileConfig` fails.
    fn load_file_config(config_path: Option<&std::path::Path>) -> Result<FileConfig, ConfigError> {
        let builder = config::Config::builder();
        let builder = match config_path {
            Some(path) => builder.add_source(File::from(path)),
            None => builder
                .add_source(File::with_name("config").required(false))
                .add_source(File::with_name("ivoryvalley").required(false)),
        };
        builder.build()?.try_deserialize()
    }

    /// Returns the listen address as `host:port`.
    ///
    /// A host that is a bare IPv6 literal (for example `::1`) is wrapped in
    /// brackets, giving `[::1]:8080`. Every other host, including one that is
    /// already bracketed, is emitted unchanged.
    pub fn bind_addr(&self) -> String {
        // Without brackets the colons of an IPv6 literal cannot be told apart
        // from the port separator, and the result is not a valid socket address.
        if self.host.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("[{}]:{}", self.host, self.port)
        } else {
            format!("{}:{}", self.host, self.port)
        }
    }
}
/// State bundle assembled by `AppState::new`.
///
/// The struct derives `Clone`; the configuration, store and recorder are held
/// behind `Arc`, so cloning the state does not copy them.
#[derive(Clone)]
pub struct AppState {
/// The resolved configuration.
pub config: Arc<Config>,
/// HTTP client built in `AppState::new` with redirects disabled and the
/// configured connect/request timeouts.
pub http_client: reqwest::Client,
/// Store handed in by the caller of `AppState::new` (defined in `crate::db`).
pub seen_uri_store: Arc<crate::db::SeenUriStore>,
/// Present only when a recording path is configured and the recorder was
/// created successfully.
pub traffic_recorder: Option<Arc<crate::recording::TrafficRecorder>>,
}
impl AppState {
    /// Assembles the application state from a resolved configuration and an
    /// existing seen-URI store.
    ///
    /// The HTTP client never follows redirects and uses the connect and request
    /// timeouts from `config`. If `config.record_traffic_path` is set, a traffic
    /// recorder is created for it; a failure there is logged and recording is
    /// simply left disabled.
    ///
    /// # Panics
    ///
    /// Panics with "Failed to create HTTP client" if the client cannot be built.
    pub fn new(config: Config, seen_store: Arc<crate::db::SeenUriStore>) -> Self {
        let connect_timeout = Duration::from_secs(config.connect_timeout_secs);
        let request_timeout = Duration::from_secs(config.request_timeout_secs);

        let http_client = reqwest::Client::builder()
            .redirect(reqwest::redirect::Policy::none())
            .connect_timeout(connect_timeout)
            .timeout(request_timeout)
            .build()
            .expect("Failed to create HTTP client");

        let traffic_recorder = match &config.record_traffic_path {
            None => None,
            Some(path) => match crate::recording::TrafficRecorder::new(path.clone()) {
                Err(e) => {
                    // Best effort: a broken recorder must not stop the server.
                    tracing::error!("Failed to initialize traffic recorder: {}", e);
                    None
                }
                Ok(recorder) => {
                    tracing::info!("Traffic recording enabled: {}", path.display());
                    Some(Arc::new(recorder))
                }
            },
        };

        Self {
            config: Arc::new(config),
            http_client,
            seen_uri_store: seen_store,
            traffic_recorder,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::NamedTempFile;

    /// Builds `CliArgs` with every option unset except the config-file path.
    ///
    /// Tests override individual options with struct-update syntax.
    fn args_with_config(config: Option<PathBuf>) -> CliArgs {
        CliArgs {
            upstream_url: None,
            host: None,
            port: None,
            database_path: None,
            max_body_size: None,
            connect_timeout_secs: None,
            request_timeout_secs: None,
            record_traffic_path: None,
            cleanup_interval_secs: None,
            cleanup_max_age_secs: None,
            config,
        }
    }

    /// Writes `contents` (plus a trailing newline) to a fresh temporary file
    /// whose name ends in `suffix`. The file is deleted when the returned
    /// handle is dropped, so callers must keep it alive while loading.
    fn write_config(suffix: &str, contents: &str) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(suffix).unwrap();
        writeln!(file, "{}", contents).unwrap();
        file
    }

    #[test]
    fn test_config_default() {
        let config = Config::default();
        assert_eq!(config.upstream_url, "https://mastodon.social");
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
        assert_eq!(config.database_path, PathBuf::from("ivoryvalley.db"));
        assert_eq!(config.max_body_size, 50 * 1024 * 1024);
        assert_eq!(config.connect_timeout_secs, 10);
        assert_eq!(config.request_timeout_secs, 30);
        assert_eq!(config.cleanup_interval_secs, 3600);
        assert_eq!(config.cleanup_max_age_secs, 7 * 24 * 3600);
    }

    #[test]
    fn test_config_new() {
        let config = Config::new(
            "https://example.com",
            "127.0.0.1",
            3000,
            PathBuf::from("/data/test.db"),
        );
        assert_eq!(config.upstream_url, "https://example.com");
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
        assert_eq!(config.database_path, PathBuf::from("/data/test.db"));
        // Everything not passed to `new` stays at its default.
        assert_eq!(config.max_body_size, 50 * 1024 * 1024);
        assert_eq!(config.connect_timeout_secs, 10);
        assert_eq!(config.request_timeout_secs, 30);
        assert_eq!(config.record_traffic_path, None);
        assert_eq!(config.cleanup_interval_secs, 3600);
        assert_eq!(config.cleanup_max_age_secs, 7 * 24 * 3600);
    }

    #[test]
    fn test_config_with_max_body_size() {
        let config = Config::with_max_body_size(
            "https://example.com",
            "127.0.0.1",
            3000,
            PathBuf::from("/data/test.db"),
            1024,
        );
        assert_eq!(config.upstream_url, "https://example.com");
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
        assert_eq!(config.database_path, PathBuf::from("/data/test.db"));
        assert_eq!(config.max_body_size, 1024);
        assert_eq!(config.connect_timeout_secs, 10);
        assert_eq!(config.request_timeout_secs, 30);
        assert_eq!(config.record_traffic_path, None);
        assert_eq!(config.cleanup_interval_secs, 3600);
        assert_eq!(config.cleanup_max_age_secs, 7 * 24 * 3600);
    }

    #[test]
    fn test_bind_addr() {
        let config = Config::new(
            "https://mastodon.social",
            "127.0.0.1",
            3000,
            PathBuf::from("test.db"),
        );
        assert_eq!(config.bind_addr(), "127.0.0.1:3000");
    }

    #[test]
    fn test_load_defaults_when_no_config() {
        let file = write_config(".toml", "# empty config");
        let args = args_with_config(Some(file.path().to_path_buf()));
        let config = Config::load_from_args(args).unwrap();
        assert_eq!(config.upstream_url, "https://mastodon.social");
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
        assert_eq!(config.database_path, PathBuf::from("ivoryvalley.db"));
        assert_eq!(config.max_body_size, 50 * 1024 * 1024);
        assert_eq!(config.connect_timeout_secs, 10);
        assert_eq!(config.request_timeout_secs, 30);
        assert_eq!(config.record_traffic_path, None);
        assert_eq!(config.cleanup_interval_secs, 3600);
        assert_eq!(config.cleanup_max_age_secs, 7 * 24 * 3600);
    }

    #[test]
    fn test_load_from_cli_args() {
        // `config: None` makes the loader probe the optional default file
        // names; every option is set here, so CLI values win regardless.
        let args = CliArgs {
            upstream_url: Some("https://cli.example.com".to_string()),
            host: Some("192.168.1.1".to_string()),
            port: Some(9000),
            database_path: Some(PathBuf::from("/cli/path.db")),
            max_body_size: Some(100 * 1024 * 1024),
            connect_timeout_secs: Some(5),
            request_timeout_secs: Some(60),
            record_traffic_path: Some(PathBuf::from("/tmp/traffic.jsonl")),
            cleanup_interval_secs: Some(1800),
            cleanup_max_age_secs: Some(86400),
            config: None,
        };
        let config = Config::load_from_args(args).unwrap();
        assert_eq!(config.upstream_url, "https://cli.example.com");
        assert_eq!(config.host, "192.168.1.1");
        assert_eq!(config.port, 9000);
        assert_eq!(config.database_path, PathBuf::from("/cli/path.db"));
        assert_eq!(config.max_body_size, 100 * 1024 * 1024);
        assert_eq!(config.connect_timeout_secs, 5);
        assert_eq!(config.request_timeout_secs, 60);
        assert_eq!(
            config.record_traffic_path,
            Some(PathBuf::from("/tmp/traffic.jsonl"))
        );
        assert_eq!(config.cleanup_interval_secs, 1800);
        assert_eq!(config.cleanup_max_age_secs, 86400);
    }

    #[test]
    fn test_load_from_toml_file() {
        let file = write_config(
            ".toml",
            r#"
upstream_url = "https://toml.example.com"
host = "10.0.0.1"
port = 7000
database_path = "/toml/db.sqlite"
connect_timeout_secs = 15
request_timeout_secs = 45
"#,
        );
        let args = args_with_config(Some(file.path().to_path_buf()));
        let config = Config::load_from_args(args).unwrap();
        assert_eq!(config.upstream_url, "https://toml.example.com");
        assert_eq!(config.host, "10.0.0.1");
        assert_eq!(config.port, 7000);
        assert_eq!(config.database_path, PathBuf::from("/toml/db.sqlite"));
        assert_eq!(config.connect_timeout_secs, 15);
        assert_eq!(config.request_timeout_secs, 45);
    }

    #[test]
    fn test_load_from_yaml_file() {
        let file = write_config(
            ".yaml",
            r#"
upstream_url: "https://yaml.example.com"
host: "10.0.0.2"
port: 6000
database_path: "/yaml/db.sqlite"
connect_timeout_secs: 20
request_timeout_secs: 120
"#,
        );
        let args = args_with_config(Some(file.path().to_path_buf()));
        let config = Config::load_from_args(args).unwrap();
        assert_eq!(config.upstream_url, "https://yaml.example.com");
        assert_eq!(config.host, "10.0.0.2");
        assert_eq!(config.port, 6000);
        assert_eq!(config.database_path, PathBuf::from("/yaml/db.sqlite"));
        assert_eq!(config.connect_timeout_secs, 20);
        assert_eq!(config.request_timeout_secs, 120);
    }

    #[test]
    fn test_cli_overrides_file() {
        let file = write_config(
            ".toml",
            r#"
upstream_url = "https://file.example.com"
host = "10.0.0.1"
port = 7000
database_path = "/file/db.sqlite"
connect_timeout_secs = 15
request_timeout_secs = 45
"#,
        );
        let args = CliArgs {
            upstream_url: Some("https://cli.example.com".to_string()),
            port: Some(9999),
            connect_timeout_secs: Some(5),
            ..args_with_config(Some(file.path().to_path_buf()))
        };
        let config = Config::load_from_args(args).unwrap();
        // Set on the CLI: the CLI value wins.
        assert_eq!(config.upstream_url, "https://cli.example.com");
        assert_eq!(config.port, 9999);
        assert_eq!(config.connect_timeout_secs, 5);
        // Not set on the CLI: the file value is kept.
        assert_eq!(config.host, "10.0.0.1");
        assert_eq!(config.database_path, PathBuf::from("/file/db.sqlite"));
        assert_eq!(config.request_timeout_secs, 45);
    }

    #[test]
    fn test_partial_file_config_uses_defaults() {
        let file = write_config(
            ".toml",
            r#"
upstream_url = "https://partial.example.com"
"#,
        );
        let args = args_with_config(Some(file.path().to_path_buf()));
        let config = Config::load_from_args(args).unwrap();
        assert_eq!(config.upstream_url, "https://partial.example.com");
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
        assert_eq!(config.database_path, PathBuf::from("ivoryvalley.db"));
        assert_eq!(config.max_body_size, 50 * 1024 * 1024);
        assert_eq!(config.connect_timeout_secs, 10);
        assert_eq!(config.request_timeout_secs, 30);
        assert_eq!(config.record_traffic_path, None);
        assert_eq!(config.cleanup_interval_secs, 3600);
        assert_eq!(config.cleanup_max_age_secs, 7 * 24 * 3600);
    }

    /// Serializes the tests that modify process-wide environment variables.
    static ENV_VAR_TEST_MUTEX: Mutex<()> = Mutex::new(());

    /// Sets environment variables for the duration of a test and puts the
    /// previous state back on drop, holding `ENV_VAR_TEST_MUTEX` throughout.
    struct EnvVarGuard {
        /// Each variable that was set, with the value it had before (if any).
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
        /// Declared last so the lock is released only after `drop` has
        /// restored the environment.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvVarGuard {
        fn new(vars: &[(&'static str, &str)]) -> Self {
            // A poisoned lock only means another env test panicked. The mutex
            // protects no data, so take the guard anyway.
            let lock = ENV_VAR_TEST_MUTEX
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = vars
                .iter()
                .map(|&(key, value)| {
                    let previous = std::env::var_os(key);
                    std::env::set_var(key, value);
                    (key, previous)
                })
                .collect();
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // Reverse order, so a key listed twice ends at its original value.
            for (key, previous) in self.saved.iter().rev() {
                match previous {
                    Some(value) => std::env::set_var(key, value),
                    None => std::env::remove_var(key),
                }
            }
        }
    }

    #[test]
    fn test_env_var_prefix_uses_iv() {
        let _guard = EnvVarGuard::new(&[
            ("IV_HOST", "192.168.99.1"),
            ("IV_PORT", "9999"),
            ("IV_UPSTREAM_URL", "https://env.example.com"),
        ]);
        let args = CliArgs::try_parse_from(["ivoryvalley"]).unwrap();
        assert_eq!(args.host, Some("192.168.99.1".to_string()));
        assert_eq!(args.port, Some(9999));
        assert_eq!(
            args.upstream_url,
            Some("https://env.example.com".to_string())
        );
    }

    #[test]
    fn test_kubernetes_style_env_vars_ignored() {
        // Variables with an `IVORYVALLEY_` prefix must not be read as settings.
        let _guard = EnvVarGuard::new(&[
            ("IVORYVALLEY_PORT", "tcp://10.43.62.146:80"),
            ("IVORYVALLEY_PORT_80_TCP", "tcp://10.43.62.146:80"),
        ]);
        let args = CliArgs::try_parse_from(["ivoryvalley"]).unwrap();
        assert_eq!(args.port, None);
    }
}