use serde::{Deserialize, Serialize, Serializer};
use std::{
net::IpAddr,
path::{Path, PathBuf},
};
use figment::{
Figment,
providers::{Env, Format, Serialized, Toml},
};
use crate::errors::ConfigError;
/// Default for `Config::port`.
const DEFAULT_PORT: u16 = 5_200;
/// Default for `Config::oembed_timeout_ms`, in milliseconds.
const DEFAULT_OEMBED_TIMEOUT_MS: u64 = 500;
/// Default for `Config::oembed_cache_size`: 2 × 2^20.
/// NOTE(review): the unit is not visible here — presumably bytes; confirm against the cache.
const DEFAULT_OEMBED_CACHE_SIZE: usize = 2 << 20;
/// One sort key (see `Config::sort`).
///
/// `order` and `compare` are free-form strings here; which values are actually
/// accepted is decided by the code that consumes them (not visible in this file).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SortField {
    /// Name of the field to sort on (the built-in default sort uses `"title"`).
    pub field: String,
    /// Sort direction, e.g. `"asc"`; defaults to `"asc"` when omitted.
    #[serde(default = "default_sort_order")]
    pub order: String,
    /// Comparison mode; defaults to `"string"` when omitted.
    #[serde(default = "default_sort_compare")]
    pub compare: String,
}
/// Serde default for `SortField::order`: `"asc"`.
fn default_sort_order() -> String {
    String::from("asc")
}

/// Serde default for `SortField::compare`: `"string"`.
fn default_sort_compare() -> String {
    String::from("string")
}

/// Serde default for `Config::link_tracking`: enabled.
fn default_link_tracking() -> bool {
    true
}

/// Default text markers for `Config::incomplete_markers`.
pub fn default_incomplete_markers() -> Vec<String> {
    ["TK", "TODO", "FIXME", "XXX"]
        .iter()
        .map(|marker| marker.to_string())
        .collect()
}

/// Serde default for `Config::build_tag_pages`: enabled.
fn default_build_tag_pages() -> bool {
    true
}

/// Serde default for `Config::sidebar_style`: `"panel"`.
fn default_sidebar_style() -> String {
    String::from("panel")
}

/// Default cap for `Config::sidebar_max_items` (`Config::validate` rejects 0).
const DEFAULT_SIDEBAR_MAX_ITEMS: usize = 100;

/// Serde default for `Config::sidebar_max_items`.
fn default_sidebar_max_items() -> usize {
    DEFAULT_SIDEBAR_MAX_ITEMS
}
/// One configured tag source (see `Config::tag_sources`); the default is the
/// single field `tags`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TagSource {
    /// Field name. It may be dotted (e.g. `taxonomy.performers`), in which case
    /// the last segment is used to derive the default labels.
    pub field: String,
    /// Explicit singular display label; derived from `field` when `None`.
    #[serde(default)]
    pub label: Option<String>,
    /// Explicit plural display label; derived from the singular label when `None`.
    #[serde(default)]
    pub label_plural: Option<String>,
}
impl TagSource {
pub fn singular_label(&self) -> String {
if let Some(ref label) = self.label {
return label.clone();
}
let field_name = self.field.rsplit('.').next().unwrap_or(&self.field);
title_case(field_name)
}
pub fn plural_label(&self) -> String {
if let Some(ref label) = self.label_plural {
return label.clone();
}
format!("{}s", self.singular_label())
}
pub fn url_source(&self) -> String {
crate::wikilink::sanitize_path_component(&self.field.to_lowercase())
}
}
/// Turns a (presumably plural) field name into a singular, capitalised label.
///
/// The singularisation is deliberately naive: exactly one trailing `s` is dropped.
/// Then the first character is upper-cased, using Unicode rules, so `ß` becomes `SS`.
/// Special cases: `""` → `""`, and `"s"` → `"S"` so the label is never empty.
fn title_case(s: &str) -> String {
    let stem = match s.strip_suffix('s') {
        // The input was just "s": keep a visible label instead of an empty string.
        Some("") => return "S".to_string(),
        Some(rest) => rest,
        None => s,
    };
    let mut chars = stem.chars();
    chars
        .next()
        .map_or_else(String::new, |head| head.to_uppercase().chain(chars).collect())
}
pub fn default_tag_sources() -> Vec<TagSource> {
vec![TagSource {
field: "tags".to_string(),
label: None,
label_plural: None,
}]
}
pub fn tag_sources_to_set(sources: &[TagSource]) -> std::collections::HashSet<String> {
sources.iter().map(|s| s.field.clone()).collect()
}
pub fn tag_sources_to_url_sources(sources: &[TagSource]) -> Vec<String> {
sources.iter().map(|s| s.url_source()).collect()
}
impl Default for SortField {
    /// Sorts on `title`, using the serde defaults for `order` and `compare`.
    fn default() -> Self {
        let field = String::from("title");
        let order = default_sort_order();
        let compare = default_sort_compare();
        Self {
            field,
            order,
            compare,
        }
    }
}
/// Serde default for `Config::sort`: a single `SortField::default()` key.
pub fn default_sort_config() -> Vec<SortField> {
    std::iter::once(SortField::default()).collect()
}
/// An IPv4 address held as its four octets in dotted order (`a.b.c.d` → `[a, b, c, d]`).
///
/// In config files it is written as a dotted-quad string; IPv6 addresses are rejected
/// when deserializing.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct IpArray(pub [u8; 4]);
/// Application configuration.
///
/// [`Config::read`] builds this value. It combines `Config::default()`, the
/// project's `.mbr/config.toml` and `MBR_`-prefixed environment variables via
/// figment, then runs [`Config::validate`].
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Config {
    /// Project root. `Config::read` always overwrites it with the directory found by
    /// `find_root_dir`; the built-in default is the current directory.
    pub root_dir: PathBuf,
    /// Host address (IPv4 only); default `127.0.0.1`.
    pub host: IpArray,
    /// Port number; default 5200. `validate` rejects 0.
    pub port: u16,
    /// Static folder name; default `"static"`.
    pub static_folder: String,
    /// File extensions treated as Markdown, without the dot; default `["md"]`.
    pub markdown_extensions: Vec<String>,
    /// Theme name; default `"default"`.
    pub theme: String,
    /// Index file name; default `"index.md"`.
    pub index_file: String,
    /// Directory names to ignore (see `Config::default` for the built-in list).
    pub ignore_dirs: Vec<String>,
    /// Glob patterns to ignore (see `Config::default` for the built-in list).
    pub ignore_globs: Vec<String>,
    /// Directory names presumably ignored by the file watcher — confirm at the use site.
    pub watcher_ignore_dirs: Vec<String>,
    /// oEmbed timeout in milliseconds; default 500.
    pub oembed_timeout_ms: u64,
    /// oEmbed cache size (unit not visible here — presumably bytes); 0 passes validation.
    pub oembed_cache_size: usize,
    /// Optional custom template folder; `None` by default.
    #[serde(default)]
    pub template_folder: Option<PathBuf>,
    /// Sort keys; the default is `title` with `"asc"` order and `"string"` comparison.
    #[serde(default = "default_sort_config")]
    pub sort: Vec<SortField>,
    /// Optional build concurrency; `None` leaves the choice to the consumer, and `Some(0)`
    /// is rejected by `validate`.
    #[serde(default)]
    pub build_concurrency: Option<usize>,
    /// Transcoding toggle (behaviour not visible in this file); off by default.
    #[serde(default)]
    pub transcode: bool,
    /// Skip link checks when `true`; off by default.
    #[serde(default)]
    pub skip_link_checks: bool,
    /// Link tracking toggle; on by default.
    #[serde(default = "default_link_tracking")]
    pub link_tracking: bool,
    /// Fields used as tag sources; default is the single `tags` field.
    #[serde(default = "default_tag_sources")]
    pub tag_sources: Vec<TagSource>,
    /// Whether tag pages are built; on by default.
    #[serde(default = "default_build_tag_pages")]
    pub build_tag_pages: bool,
    /// Sidebar style name; default `"panel"`.
    #[serde(default = "default_sidebar_style")]
    pub sidebar_style: String,
    /// Maximum sidebar items; default 100. `validate` rejects 0.
    #[serde(default = "default_sidebar_max_items")]
    pub sidebar_max_items: usize,
    /// Text presumably prepended to page titles; empty by default.
    #[serde(default)]
    pub title_prefix: String,
    /// Text presumably appended to page titles; empty by default.
    #[serde(default)]
    pub title_suffix: String,
    /// Markers that flag content as incomplete; default `TK`, `TODO`, `FIXME`, `XXX`.
    #[serde(default = "default_incomplete_markers")]
    pub incomplete_markers: Vec<String>,
    /// Optional incomplete-marking toggle; `None` when not configured.
    #[serde(default)]
    pub mark_incomplete: Option<bool>,
}
impl std::fmt::Display for IpArray {
    /// Writes the address in dotted-quad form, e.g. `127.0.0.1`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Going through `write!` (rather than calling `Ipv4Addr`'s `fmt` directly) means
        // width/fill flags on `f` are ignored, exactly as with a plain format string.
        write!(f, "{}", std::net::Ipv4Addr::from(self.0))
    }
}
impl<'de> Deserialize<'de> for IpArray {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let ip_str = String::deserialize(deserializer)?;
let ip: IpAddr = ip_str.parse().map_err(serde::de::Error::custom)?;
match ip {
IpAddr::V4(v4) => Ok(IpArray(v4.octets())),
IpAddr::V6(_) => Err(serde::de::Error::custom("IPv6 addresses are not supported")),
}
}
}
impl Serialize for IpArray {
    /// Emits the address as a dotted-quad string such as `"127.0.0.1"`, so that it
    /// round-trips through the `Deserialize` impl.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let [a, b, c, d] = self.0;
        let dotted = format!("{a}.{b}.{c}.{d}");
        serializer.serialize_str(&dotted)
    }
}
impl Default for Config {
fn default() -> Self {
Config {
root_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
host: IpArray([127, 0, 0, 1]),
port: DEFAULT_PORT,
static_folder: "static".to_string(),
markdown_extensions: vec!["md".to_string()],
theme: "default".to_string(),
index_file: "index.md".to_string(),
ignore_dirs: [
"target",
"result",
"build",
"node_modules",
"ci",
"templates",
".git",
".github",
"dist",
"out",
"coverage",
]
.into_iter()
.map(|x| x.to_string())
.collect(),
ignore_globs: [
"*.log", "*.bak", "*.lock", "*.sh", "*.css", "*.scss", "*.js", "*.ts",
]
.into_iter()
.map(|x| x.to_string())
.collect(),
watcher_ignore_dirs: [".direnv", ".git", "result", "target", "build"]
.into_iter()
.map(|x| x.to_string())
.collect(),
oembed_timeout_ms: DEFAULT_OEMBED_TIMEOUT_MS,
oembed_cache_size: DEFAULT_OEMBED_CACHE_SIZE,
template_folder: None,
sort: default_sort_config(),
build_concurrency: None, transcode: false, skip_link_checks: false, link_tracking: true, tag_sources: default_tag_sources(),
build_tag_pages: true, sidebar_style: default_sidebar_style(),
sidebar_max_items: default_sidebar_max_items(),
title_prefix: String::new(),
title_suffix: String::new(),
incomplete_markers: default_incomplete_markers(),
mark_incomplete: None,
}
}
}
/// Returns `true` when `path` equals `$HOME` (compared component-wise, without
/// canonicalisation). Returns `false` when `HOME` is unset.
fn is_home_dir(path: &Path) -> bool {
    match std::env::var_os("HOME") {
        Some(home) => path == Path::new(&home),
        None => false,
    }
}
pub fn find_root_dir(start_path: &Path) -> PathBuf {
const DIR_MARKERS: &[&str] = &[".mbr", ".git", ".zk", ".obsidian"];
const FILE_MARKERS: &[&str] = &["book.toml", "mkdocs.yml", "docusaurus.config.js"];
let dir = if start_path.is_dir() {
start_path
} else {
start_path.parent().unwrap_or(start_path)
};
for marker in DIR_MARKERS {
if let Some(root) = dir
.ancestors()
.find(|a| a.join(marker).is_dir())
.map(|p| p.to_path_buf())
{
if is_home_dir(&root) {
break;
}
return root;
}
}
for marker in FILE_MARKERS {
if let Some(root) = dir
.ancestors()
.find(|a| a.join(marker).is_file())
.map(|p| p.to_path_buf())
{
if is_home_dir(&root) {
break;
}
return root;
}
}
dir.to_path_buf()
}
impl Config {
    /// Loads the configuration for the project containing `search_config_from`.
    ///
    /// Layers are merged with figment, from lowest to highest precedence; a later
    /// `merge` wins on conflicts:
    /// 1. `Config::default()`
    /// 2. `<root>/.mbr/config.toml`
    /// 3. `MBR_`-prefixed environment variables, so a one-off `MBR_PORT=...` overrides
    ///    the project file
    ///
    /// `root_dir` is always replaced by the directory returned by `find_root_dir`.
    /// The result is then checked with [`Config::validate`].
    ///
    /// # Errors
    /// Returns `ConfigError::ParseFailed` (converted into `MbrError`) if the layers cannot
    /// be extracted into a `Config`, or any error produced by [`Config::validate`].
    pub fn read(search_config_from: &Path) -> Result<Self, crate::MbrError> {
        let root_dir = find_root_dir(search_config_from);
        let config_file = root_dir.join(".mbr").join("config.toml");
        let mut config: Config = Figment::new()
            .merge(Serialized::defaults(Config::default()))
            .merge(Toml::file(config_file))
            // Merged last so environment variables take precedence over the config file.
            .merge(Env::prefixed("MBR_"))
            .extract()
            .map_err(|e| ConfigError::ParseFailed(Box::new(e)))?;
        config.root_dir = root_dir;
        // Logged after `root_dir` is set so the output reflects the effective config.
        tracing::debug!("Loaded config: {:?}", &config);
        config.validate()?;
        Ok(config)
    }

    /// Rejects values that cannot work at runtime.
    ///
    /// # Errors
    /// - `ConfigError::InvalidPort` when `port` is 0.
    /// - `ConfigError::InvalidSidebarMaxItems` when `sidebar_max_items` is 0.
    /// - `ConfigError::InvalidBuildConcurrency` when `build_concurrency` is `Some(0)`.
    pub fn validate(&self) -> Result<(), ConfigError> {
        if self.port == 0 {
            return Err(ConfigError::InvalidPort { port: self.port });
        }
        if self.sidebar_max_items == 0 {
            return Err(ConfigError::InvalidSidebarMaxItems {
                value: self.sidebar_max_items,
            });
        }
        if matches!(self.build_concurrency, Some(0)) {
            return Err(ConfigError::InvalidBuildConcurrency { value: 0 });
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_title_case() {
        assert_eq!(title_case("tags"), "Tag");
        assert_eq!(title_case("performers"), "Performer");
        assert_eq!(title_case("category"), "Category");
        assert_eq!(title_case("Tag"), "Tag");
        assert_eq!(title_case("s"), "S");
        assert_eq!(title_case(""), "");
    }

    #[test]
    fn test_tag_source_singular_label_explicit() {
        let source = TagSource {
            field: "taxonomy.performers".to_string(),
            label: Some("Performer".to_string()),
            label_plural: None,
        };
        assert_eq!(source.singular_label(), "Performer");
    }

    #[test]
    fn test_tag_source_singular_label_derived() {
        let source = TagSource {
            field: "tags".to_string(),
            label: None,
            label_plural: None,
        };
        assert_eq!(source.singular_label(), "Tag");
    }

    #[test]
    fn test_tag_source_singular_label_derived_nested() {
        let source = TagSource {
            field: "taxonomy.performers".to_string(),
            label: None,
            label_plural: None,
        };
        assert_eq!(source.singular_label(), "Performer");
    }

    #[test]
    fn test_tag_source_plural_label_explicit() {
        let source = TagSource {
            field: "taxonomy.performers".to_string(),
            label: None,
            label_plural: Some("Performers".to_string()),
        };
        assert_eq!(source.plural_label(), "Performers");
    }

    #[test]
    fn test_tag_source_plural_label_derived() {
        let source = TagSource {
            field: "tags".to_string(),
            label: None,
            label_plural: None,
        };
        assert_eq!(source.plural_label(), "Tags");
    }

    #[test]
    fn test_tag_source_url_source() {
        let source = TagSource {
            field: "Tags".to_string(),
            label: None,
            label_plural: None,
        };
        assert_eq!(source.url_source(), "tags");
        let source = TagSource {
            field: "taxonomy.Performers".to_string(),
            label: None,
            label_plural: None,
        };
        assert_eq!(source.url_source(), "taxonomy.performers");
    }

    #[test]
    fn test_default_tag_sources() {
        let sources = default_tag_sources();
        assert_eq!(sources.len(), 1);
        assert_eq!(sources[0].field, "tags");
        assert_eq!(sources[0].singular_label(), "Tag");
        assert_eq!(sources[0].plural_label(), "Tags");
        assert_eq!(sources[0].url_source(), "tags");
    }

    #[test]
    fn test_config_default_has_tag_sources() {
        let config = Config::default();
        assert_eq!(config.tag_sources.len(), 1);
        assert_eq!(config.tag_sources[0].field, "tags");
        assert!(config.build_tag_pages);
    }

    #[test]
    fn test_tag_source_serialization() {
        let source = TagSource {
            field: "taxonomy.tags".to_string(),
            label: Some("Tag".to_string()),
            label_plural: Some("Tags".to_string()),
        };
        let json = serde_json::to_string(&source).unwrap();
        let parsed: TagSource = serde_json::from_str(&json).unwrap();
        assert_eq!(source, parsed);
    }

    #[test]
    fn test_tag_source_deserialization_minimal() {
        let json = r#"{"field": "tags"}"#;
        let source: TagSource = serde_json::from_str(json).unwrap();
        assert_eq!(source.field, "tags");
        assert!(source.label.is_none());
        assert!(source.label_plural.is_none());
    }

    #[test]
    fn test_validate_default_config_passes() {
        let config = Config::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_port_zero_fails() {
        let config = Config {
            port: 0,
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, ConfigError::InvalidPort { port: 0 }));
    }

    #[test]
    fn test_validate_valid_ports_pass() {
        let config = Config {
            port: 1,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
        let config = Config {
            port: 80,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
        let config = Config {
            port: 443,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
        let config = Config {
            port: 65535,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_sidebar_max_items_zero_fails() {
        let config = Config {
            sidebar_max_items: 0,
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(
            err,
            ConfigError::InvalidSidebarMaxItems { value: 0 }
        ));
    }

    #[test]
    fn test_validate_valid_sidebar_max_items_pass() {
        let config = Config {
            sidebar_max_items: 1,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
        let config = Config {
            sidebar_max_items: 100,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
        let config = Config {
            sidebar_max_items: 10000,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_build_concurrency_zero_fails() {
        let config = Config {
            build_concurrency: Some(0),
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(
            err,
            ConfigError::InvalidBuildConcurrency { value: 0 }
        ));
    }

    #[test]
    fn test_validate_build_concurrency_none_passes() {
        let config = Config {
            build_concurrency: None,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_valid_build_concurrency_pass() {
        let config = Config {
            build_concurrency: Some(1),
            ..Default::default()
        };
        assert!(config.validate().is_ok());
        let config = Config {
            build_concurrency: Some(8),
            ..Default::default()
        };
        assert!(config.validate().is_ok());
        let config = Config {
            build_concurrency: Some(32),
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_default_title_prefix_empty() {
        let config = Config::default();
        assert_eq!(config.title_prefix, "");
    }

    #[test]
    fn test_default_title_suffix_empty() {
        let config = Config::default();
        assert_eq!(config.title_suffix, "");
    }

    #[test]
    fn test_validate_oembed_cache_size_zero_is_valid() {
        let config = Config {
            oembed_cache_size: 0,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_find_root_dir_file_without_markers_returns_parent() {
        let tmp = tempfile::tempdir().unwrap();
        let file_path = tmp.path().join("test.md");
        std::fs::write(&file_path, "# Hello").unwrap();
        let root = find_root_dir(&file_path);
        assert!(
            root.is_dir(),
            "root_dir should be a directory, got: {root:?}"
        );
        assert_eq!(root, tmp.path());
    }

    #[test]
    fn test_find_root_dir_directory_without_markers_returns_itself() {
        let tmp = tempfile::tempdir().unwrap();
        let root = find_root_dir(tmp.path());
        assert!(
            root.is_dir(),
            "root_dir should be a directory, got: {root:?}"
        );
        // The test name promises the directory itself is returned; check that, not just is_dir().
        assert_eq!(root, tmp.path());
    }

    #[test]
    fn test_find_root_dir_with_git_marker_returns_marker_parent() {
        let tmp = tempfile::tempdir().unwrap();
        let nested = tmp.path().join("sub").join("dir");
        std::fs::create_dir_all(&nested).unwrap();
        std::fs::create_dir(tmp.path().join(".git")).unwrap();
        let file_path = nested.join("test.md");
        std::fs::write(&file_path, "# Hello").unwrap();
        let root = find_root_dir(&file_path);
        assert_eq!(root, tmp.path().to_path_buf());
        assert!(root.is_dir());
    }

    #[test]
    fn test_find_root_dir_with_file_marker_returns_marker_parent() {
        let tmp = tempfile::tempdir().unwrap();
        let nested = tmp.path().join("docs").join("guide");
        std::fs::create_dir_all(&nested).unwrap();
        std::fs::write(tmp.path().join("book.toml"), "[book]\n").unwrap();
        let file_path = nested.join("intro.md");
        std::fs::write(&file_path, "# Intro").unwrap();
        let root = find_root_dir(&file_path);
        assert_eq!(root, tmp.path().to_path_buf());
    }

    #[test]
    fn test_is_home_dir() {
        if let Some(home) = std::env::var_os("HOME").map(PathBuf::from) {
            assert!(is_home_dir(&home));
            // A child of $HOME can never equal $HOME, unlike a fixed path such as /tmp.
            assert!(!is_home_dir(&home.join("not-home-subdir")));
        }
    }
}