use std::collections::HashMap;
use anyhow::{Context, Result, bail};
use serde::Deserialize;
/// Top-level configuration of a single site rule, deserialized from TOML.
///
/// `site`, `rewrite`, `json` and `template` are required; every other section
/// falls back to its `Default` (or an empty list) when omitted. Use
/// [`SiteRuleConfig::from_toml`] to parse and validate in one step.
#[derive(Debug, Clone, Deserialize)]
pub struct SiteRuleConfig {
/// Rule name and the regex patterns attached to this rule.
pub site: SiteConfig,
/// Primary rewrite: `from` is validated as a regex, `to` is stored as-is.
pub rewrite: RewriteConfig,
/// Request options; defaults to `RequestConfig::default()` when the section is absent.
#[serde(default)]
pub request: RequestConfig,
/// String-to-string mapping for the primary response (required section).
pub json: JsonConfig,
/// Entries read from the `fetch_additional` key; empty when the key is absent.
#[serde(default, rename = "fetch_additional")]
pub additional_fetches: Vec<AdditionalFetchConfig>,
/// Entries read from the `fetch_concurrent` key; empty when the key is absent.
#[serde(default, rename = "fetch_concurrent")]
pub concurrent_fetches: Vec<ConcurrentFetchConfig>,
/// Fallback entries; empty when the key is absent.
#[serde(default)]
pub fallback: Vec<FallbackConfig>,
/// Output template; its `format` must be non-empty (enforced by validation).
pub template: TemplateConfig,
/// Optional metadata section.
#[serde(default)]
pub metadata: MetadataConfig,
/// Optional engagement section.
#[serde(default)]
pub engagement: EngagementConfig,
}
/// One entry of the `fetch_additional` list of a site rule.
#[derive(Debug, Clone, Deserialize)]
pub struct AdditionalFetchConfig {
/// Must be non-empty (checked during rule validation).
// NOTE(review): presumably namespaces the values extracted by this fetch — confirm against the consumer.
pub prefix: String,
/// Regex source; rule validation rejects it if it does not compile.
pub rewrite_from: String,
/// Counterpart of `rewrite_from`; stored as-is and not validated here.
// NOTE(review): presumably the regex replacement string — confirm.
pub rewrite_to: String,
/// Optional value.
// NOTE(review): presumably sent as the HTTP `Accept` header — confirm.
pub accept: Option<String>,
/// String-to-string mapping for this fetch; empty when omitted.
#[serde(default)]
pub json: JsonConfig,
}
/// One entry of the `fetch_concurrent` list of a site rule.
#[derive(Debug, Clone, Deserialize)]
pub struct ConcurrentFetchConfig {
/// Must be non-empty (checked during rule validation).
// NOTE(review): presumably namespaces the values extracted by these fetches — confirm.
pub prefix: String,
/// Regex source; rule validation rejects it if it does not compile.
pub rewrite_from: String,
/// Counterpart of `rewrite_from`; stored as-is and not validated here.
// NOTE(review): presumably the regex replacement string — confirm.
pub rewrite_to: String,
/// Required string; not validated in this file.
// NOTE(review): presumably a path locating the list of items to fetch — confirm.
pub items_path: String,
/// String-to-string mapping for each fetch; empty when omitted.
#[serde(default)]
pub json: JsonConfig,
/// Optional value.
// NOTE(review): presumably sent as the HTTP `Accept` header — confirm.
pub accept: Option<String>,
/// Optional cap on items; `item_limit()` substitutes 10 when this is `None`.
pub max_items: Option<usize>,
}
impl ConcurrentFetchConfig {
    /// Returns the configured `max_items`, or 10 when none was configured.
    ///
    /// An explicit `Some(0)` is returned unchanged.
    #[must_use]
    pub fn item_limit(&self) -> usize {
        match self.max_items {
            Some(limit) => limit,
            None => 10,
        }
    }
}
/// Kind of a fallback entry.
///
/// Deserialized from the lowercase strings `"json"` and `"html"`;
/// `Json` is the default.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum FallbackType {
/// Written as `"json"` in configuration; the default variant.
#[default]
Json,
/// Written as `"html"` in configuration.
Html,
}
impl FallbackType {
    /// Returns the lowercase name of the variant: `"json"` or `"html"`.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Html => "html",
            Self::Json => "json",
        }
    }
}
/// One entry of the `fallback` list of a site rule.
#[derive(Debug, Clone, Deserialize)]
pub struct FallbackConfig {
/// Regex source; rule validation rejects it if it does not compile.
pub rewrite_from: String,
/// Counterpart of `rewrite_from`; stored as-is and not validated here.
// NOTE(review): presumably the regex replacement string — confirm.
pub rewrite_to: String,
/// Read from the `type` key; defaults to `FallbackType::Json` when absent.
#[serde(default, rename = "type")]
pub fallback_type: FallbackType,
/// Optional value.
// NOTE(review): presumably sent as the HTTP `Accept` header — confirm.
pub accept: Option<String>,
/// String-to-string mapping; empty when omitted.
#[serde(default)]
pub json: JsonConfig,
/// String-to-string mapping; empty when omitted.
// NOTE(review): presumably field name -> CSS selector, used for `html` fallbacks — confirm.
#[serde(default)]
pub css: HashMap<String, String>,
}
/// The `site` section of a rule: its name and its patterns.
#[derive(Debug, Clone, Deserialize)]
pub struct SiteConfig {
/// Rule name; must be non-empty and is quoted in validation error messages.
pub name: String,
/// Must contain at least one entry, and every entry must compile as a regex
/// (both checked during rule validation).
pub patterns: Vec<String>,
}
/// The `rewrite` section of a rule.
#[derive(Debug, Clone, Deserialize)]
pub struct RewriteConfig {
/// Regex source; rule validation rejects it if it does not compile.
pub from: String,
/// Stored as-is; not validated in this file.
// NOTE(review): presumably the replacement applied to matches of `from` — confirm.
pub to: String,
}
/// Selects which client a request uses.
///
/// Deserialized from the lowercase strings `"default"` and `"standard"`.
// NOTE(review): what distinguishes the two clients is not visible in this file — confirm at the use site.
#[derive(Debug, Clone, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ClientKind {
/// Written as `"default"`; used when the `client` key is absent.
#[default]
Default,
/// Written as `"standard"`.
Standard,
}
/// Parsed form of a `request.auth` string (see `AuthConfig::parse`).
///
/// This type is not deserialized directly; the raw string lives in
/// `RequestConfig::auth`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthConfig {
/// Name of the environment variable that `resolve` reads.
pub env_var: String,
/// Name of the header returned by `resolve`.
pub header_name: String,
/// When `true`, `resolve` prefixes the variable's value with `"Bearer "`.
pub bearer: bool,
}
impl AuthConfig {
    /// Parses an auth specification string.
    ///
    /// Two forms are accepted:
    ///
    /// * `env:VAR` — header `Authorization`, value sent as `Bearer <value>`.
    /// * `env:VAR:header=NAME` — header `NAME`, value sent verbatim.
    ///
    /// The input is split at the first `:` after the `env:` prefix, so `VAR`
    /// itself cannot contain a colon.
    ///
    /// # Errors
    ///
    /// Returns an error when the string does not start with `env:`, when the
    /// suffix is not of the form `header=NAME`, when `NAME` is empty, or when
    /// `VAR` is empty.
    pub fn parse(s: &str) -> anyhow::Result<Self> {
        let rest = s
            .strip_prefix("env:")
            .ok_or_else(|| anyhow::anyhow!("auth value must start with 'env:' (got '{s}')"))?;

        let (env_var, custom_header) = match rest.split_once(':') {
            None => (rest, None),
            Some((var, suffix)) => {
                let name = suffix.strip_prefix("header=").ok_or_else(|| {
                    anyhow::anyhow!("auth suffix must be 'header=NAME' (got '{suffix}')")
                })?;
                anyhow::ensure!(!name.is_empty(), "auth header name must not be empty");
                (var, Some(name))
            }
        };

        // An empty variable name can never be resolved from the environment, so
        // reject it here instead of letting the rule silently run unauthenticated.
        anyhow::ensure!(
            !env_var.is_empty(),
            "auth env var name must not be empty (got '{s}')"
        );

        let (header_name, bearer) = match custom_header {
            Some(name) => (name.to_string(), false),
            None => ("Authorization".to_string(), true),
        };
        Ok(Self {
            env_var: env_var.to_string(),
            header_name,
            bearer,
        })
    }

    /// Reads the configured environment variable and builds the header pair.
    ///
    /// Returns `Some((header_name, header_value))`, where the value is prefixed
    /// with `"Bearer "` when `bearer` is set. Returns `None` when
    /// `std::env::var` fails, i.e. the variable is unset or not valid Unicode.
    #[must_use]
    pub fn resolve(&self) -> Option<(String, String)> {
        let value = std::env::var(&self.env_var).ok()?;
        let header_value = if self.bearer {
            format!("Bearer {value}")
        } else {
            value
        };
        Some((self.header_name.clone(), header_value))
    }
}
/// The optional `request` section of a rule.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct RequestConfig {
/// Client selection; `ClientKind::Default` when omitted.
#[serde(default)]
pub client: ClientKind,
/// String-to-string map; empty when omitted.
// NOTE(review): presumably extra HTTP headers (name -> value) — confirm.
#[serde(default)]
pub headers: HashMap<String, String>,
/// Optional value.
// NOTE(review): presumably sent as the HTTP `Accept` header — confirm.
pub accept: Option<String>,
/// Raw auth specification; rule validation requires it to be accepted by
/// `AuthConfig::parse` (`env:VAR` or `env:VAR:header=NAME`).
pub auth: Option<String>,
/// Optional string; not validated in this file.
// NOTE(review): presumably a path into the response used to decide success — confirm.
pub success_path: Option<String>,
}
/// Newtype around a string-to-string map; `Default` yields an empty map.
// NOTE(review): presumably output field name -> path into the JSON response — confirm against the extractor.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct JsonConfig(pub HashMap<String, String>);
/// The `template` section of a rule.
#[derive(Debug, Clone, Deserialize)]
pub struct TemplateConfig {
/// Template text; must be non-empty (checked during rule validation).
pub format: String,
}
/// The optional `metadata` section of a rule.
// NOTE(review): the `*_field` entries presumably name extracted fields to use for each piece of metadata — confirm.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct MetadataConfig {
/// Platform string; empty when omitted.
#[serde(default)]
pub platform: String,
/// Optional author setting.
pub author: Option<String>,
/// Optional title field setting.
pub title_field: Option<String>,
/// Optional published-date field setting.
pub published_field: Option<String>,
/// Optional canonical-URL field setting.
pub canonical_url_field: Option<String>,
/// Optional media-URLs field setting.
pub media_urls_field: Option<String>,
/// Any other string-valued keys of the section, collected via `serde(flatten)`.
#[serde(flatten)]
pub extra: HashMap<String, String>,
}
/// The optional `engagement` section of a rule; every entry is optional.
// NOTE(review): each value presumably names where the corresponding counter is found in the extracted data — confirm.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct EngagementConfig {
/// Setting for the likes counter.
pub likes: Option<String>,
/// Setting for the reposts counter.
pub reposts: Option<String>,
/// Setting for the replies counter.
pub replies: Option<String>,
/// Setting for the views counter.
pub views: Option<String>,
}
impl SiteRuleConfig {
pub fn from_toml(toml_str: &str) -> Result<Self> {
let config: Self = toml::from_str(toml_str).context("failed to parse site rule TOML")?;
config.validate()?;
Ok(config)
}
fn validate(&self) -> Result<()> {
if self.site.name.is_empty() {
bail!("site.name must not be empty");
}
if self.site.patterns.is_empty() {
bail!(
"site.patterns must not be empty for rule '{}'",
self.site.name
);
}
for pattern in &self.site.patterns {
regex::Regex::new(pattern).with_context(|| {
format!(
"invalid pattern regex '{}' in rule '{}'",
pattern, self.site.name
)
})?;
}
regex::Regex::new(&self.rewrite.from).with_context(|| {
format!(
"invalid rewrite.from regex '{}' in rule '{}'",
self.rewrite.from, self.site.name
)
})?;
if self.template.format.is_empty() {
bail!(
"template.format must not be empty in rule '{}'",
self.site.name
);
}
for (i, af) in self.additional_fetches.iter().enumerate() {
if af.prefix.is_empty() {
bail!(
"fetch_additional[{i}].prefix must not be empty in rule '{}'",
self.site.name
);
}
regex::Regex::new(&af.rewrite_from).with_context(|| {
format!(
"invalid fetch_additional[{i}].rewrite_from regex '{}' in rule '{}'",
af.rewrite_from, self.site.name
)
})?;
}
for (i, cf) in self.concurrent_fetches.iter().enumerate() {
if cf.prefix.is_empty() {
bail!(
"fetch_concurrent[{i}].prefix must not be empty in rule '{}'",
self.site.name
);
}
regex::Regex::new(&cf.rewrite_from).with_context(|| {
format!(
"invalid fetch_concurrent[{i}].rewrite_from regex '{}' in rule '{}'",
cf.rewrite_from, self.site.name
)
})?;
}
for (i, fb) in self.fallback.iter().enumerate() {
regex::Regex::new(&fb.rewrite_from).with_context(|| {
format!(
"invalid fallback[{i}].rewrite_from regex '{}' in rule '{}'",
fb.rewrite_from, self.site.name
)
})?;
}
if let Some(auth) = &self.request.auth {
AuthConfig::parse(auth).with_context(|| {
format!(
"invalid request.auth '{}' in rule '{}'",
auth, self.site.name
)
})?;
}
Ok(())
}
}
// Unit tests for this module are kept in `config_tests.rs` and are only
// compiled for test builds.
#[cfg(test)]
#[path = "config_tests.rs"]
mod tests;