mod file_config;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
/// Environment variable that switches configuration loading into test mode.
///
/// When it is set to `true` (ASCII case-insensitive), `load_config` and
/// `get_config` return `Config::without_env()` instead of reading a config file
/// or the API-key / source environment variables — presumably so tests are
/// isolated from the developer's own setup.
const TEST_MODE_ENV_VAR: &str = "RESEARCH_MASTER_TEST_MODE";
/// On-disk cache settings for search and citation results.
///
/// Fields missing from a deserialized config take the serde defaults named in
/// the attributes below. `CacheConfig::default()` uses the same values, except
/// that `enabled` is derived from the `RESEARCH_MASTER_CACHE_ENABLED`
/// environment variable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
/// Whether caching is turned on; `false` when absent from a deserialized config.
#[serde(default)]
pub enabled: bool,
/// Cache directory override; `None` when absent. Presumably callers then fall
/// back to `default_cache_dir()` — TODO confirm.
#[serde(default)]
pub directory: Option<PathBuf>,
/// Lifetime of cached search results, in seconds (default 1800 = 30 minutes).
#[serde(default = "default_search_ttl")]
pub search_ttl_seconds: u64,
/// Lifetime of cached citation results, in seconds (default 900 = 15 minutes).
#[serde(default = "default_citation_ttl")]
pub citation_ttl_seconds: u64,
/// Maximum cache size in megabytes (default 500); not enforced in this module.
#[serde(default = "default_max_cache_size")]
pub max_size_mb: usize,
}
impl Default for CacheConfig {
    /// Builds the default cache settings.
    ///
    /// `enabled` comes from the `RESEARCH_MASTER_CACHE_ENABLED` environment
    /// variable (see `cache_enabled_from_env`). All other fields use the same
    /// defaults serde applies when they are missing from a config file.
    fn default() -> Self {
        Self {
            enabled: cache_enabled_from_env(),
            directory: None,
            search_ttl_seconds: default_search_ttl(),
            citation_ttl_seconds: default_citation_ttl(),
            max_size_mb: default_max_cache_size(),
        }
    }
}

/// Interprets `RESEARCH_MASTER_CACHE_ENABLED` as an on/off switch.
///
/// If the variable is unset (or not valid UTF-8), caching is disabled. If it is
/// set, any value enables caching except the explicit opt-outs: an empty value,
/// `0`, `false`, `no` and `off`. Matching is ASCII case-insensitive and ignores
/// surrounding whitespace.
///
/// Merely testing for presence would make `RESEARCH_MASTER_CACHE_ENABLED=false`
/// turn caching *on*.
fn cache_enabled_from_env() -> bool {
    const OPT_OUTS: [&str; 4] = ["0", "false", "no", "off"];
    match std::env::var("RESEARCH_MASTER_CACHE_ENABLED") {
        Ok(raw) => {
            let value = raw.trim();
            !value.is_empty() && !OPT_OUTS.iter().any(|off| value.eq_ignore_ascii_case(off))
        }
        Err(_) => false,
    }
}
/// Serde default for `CacheConfig::search_ttl_seconds`: 30 minutes, in seconds.
fn default_search_ttl() -> u64 {
    30 * 60
}
/// Serde default for `CacheConfig::citation_ttl_seconds`: 15 minutes, in seconds.
fn default_citation_ttl() -> u64 {
    15 * 60
}
/// Serde default for `CacheConfig::max_size_mb`, in megabytes.
fn default_max_cache_size() -> usize {
    const DEFAULT_MAX_CACHE_MB: usize = 500;
    DEFAULT_MAX_CACHE_MB
}
/// Returns the platform-specific directory for on-disk cache data.
///
/// * macOS: `$HOME/Library/Caches/research-master`
/// * Linux: `$XDG_CACHE_HOME/research-master`, falling back to
///   `$HOME/.cache/research-master`
/// * Windows: `%LOCALAPPDATA%\research-master\cache`
///
/// Unset or empty variables are skipped. On Linux, a relative `XDG_CACHE_HOME`
/// is also ignored, as the XDG Base Directory specification requires. Without
/// these checks, an empty variable would silently produce a path relative to the
/// current working directory.
///
/// If no platform location is available, `.research-master-cache` (relative to
/// the working directory) is returned.
pub fn default_cache_dir() -> PathBuf {
    #[cfg(target_os = "macos")]
    {
        if let Some(home) = std::env::var_os("HOME").filter(|home| !home.is_empty()) {
            return PathBuf::from(home)
                .join("Library")
                .join("Caches")
                .join("research-master");
        }
    }
    #[cfg(target_os = "linux")]
    {
        // XDG treats relative values as invalid; an empty value is not absolute either.
        if let Some(xdg_cache) = std::env::var_os("XDG_CACHE_HOME")
            .map(PathBuf::from)
            .filter(|dir| dir.is_absolute())
        {
            return xdg_cache.join("research-master");
        }
        if let Some(home) = std::env::var_os("HOME").filter(|home| !home.is_empty()) {
            return PathBuf::from(home).join(".cache").join("research-master");
        }
    }
    #[cfg(target_os = "windows")]
    {
        if let Some(appdata) = std::env::var_os("LOCALAPPDATA").filter(|dir| !dir.is_empty()) {
            return PathBuf::from(appdata).join("research-master").join("cache");
        }
    }
    PathBuf::from(".research-master-cache")
}
/// Top-level application configuration.
///
/// Every section is `#[serde(default)]`, so a table missing from the config
/// file falls back to that section's `Default` impl. Note that the `Default`
/// impls of `ApiKeys`, `SourceConfig` and `CacheConfig` read environment
/// variables.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Config {
/// Third-party API credentials.
#[serde(default)]
pub api_keys: ApiKeys,
/// Download location and limits.
#[serde(default)]
pub downloads: DownloadConfig,
/// Global request throttling.
#[serde(default)]
pub rate_limits: RateLimitConfig,
/// Per-source enable/disable lists, proxies and rate overrides.
#[serde(default)]
pub sources: SourceConfig,
/// On-disk cache settings.
#[serde(default)]
pub cache: CacheConfig,
}
/// Per-source settings, each kept as an optional raw string.
///
/// `SourceConfig::default()` fills every field from its `RESEARCH_MASTER_*`
/// environment variable, and serde uses that default when the whole section is
/// missing from the config file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SourceConfig {
/// Sources to enable (`RESEARCH_MASTER_ENABLED_SOURCES`). Parsed elsewhere;
/// presumably a comma-separated list — TODO confirm.
#[serde(default)]
pub enabled_sources: Option<String>,
/// Sources to disable (`RESEARCH_MASTER_DISABLED_SOURCES`). Parsed elsewhere;
/// presumably a comma-separated list — TODO confirm.
#[serde(default)]
pub disabled_sources: Option<String>,
/// HTTP proxy setting (`RESEARCH_MASTER_PROXY_HTTP`), presumably a proxy URL.
#[serde(default)]
pub proxy_http: Option<String>,
/// HTTPS proxy setting (`RESEARCH_MASTER_PROXY_HTTPS`), presumably a proxy URL.
#[serde(default)]
pub proxy_https: Option<String>,
/// Per-source request rates as comma-separated `name:rate` pairs, e.g.
/// `semantic:0.5,arxiv:5` (`RESEARCH_MASTER_RATE_LIMITS`). Decoded by
/// `parse_rate_limits`.
#[serde(default)]
pub rate_limits: Option<String>,
}
impl Default for SourceConfig {
fn default() -> Self {
Self::from_env()
}
}
impl SourceConfig {
fn from_env() -> Self {
Self {
enabled_sources: std::env::var("RESEARCH_MASTER_ENABLED_SOURCES").ok(),
disabled_sources: std::env::var("RESEARCH_MASTER_DISABLED_SOURCES").ok(),
proxy_http: std::env::var("RESEARCH_MASTER_PROXY_HTTP").ok(),
proxy_https: std::env::var("RESEARCH_MASTER_PROXY_HTTPS").ok(),
rate_limits: std::env::var("RESEARCH_MASTER_RATE_LIMITS").ok(),
}
}
fn without_env() -> Self {
Self {
enabled_sources: None,
disabled_sources: None,
proxy_http: None,
proxy_https: None,
rate_limits: None,
}
}
pub fn parse_rate_limits(&self) -> std::collections::HashMap<String, f32> {
let mut limits = std::collections::HashMap::new();
if let Some(ref limits_str) = self.rate_limits {
for part in limits_str.split(',') {
let parts: Vec<&str> = part.split(':').collect();
if parts.len() == 2 {
if let Ok(rate) = parts[1].parse::<f32>() {
limits.insert(parts[0].trim().to_string(), rate);
}
}
}
}
limits
}
}
/// Credentials for third-party APIs.
///
/// `ApiKeys::default()` reads them from the `SEMANTIC_SCHOLAR_API_KEY` and
/// `CORE_API_KEY` environment variables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeys {
/// Semantic Scholar API key, if any.
#[serde(default)]
pub semantic_scholar: Option<String>,
/// CORE API key, if any.
#[serde(default)]
pub core: Option<String>,
}
impl Default for ApiKeys {
fn default() -> Self {
Self::from_env()
}
}
impl ApiKeys {
fn from_env() -> Self {
Self {
semantic_scholar: std::env::var("SEMANTIC_SCHOLAR_API_KEY").ok(),
core: std::env::var("CORE_API_KEY").ok(),
}
}
fn without_env() -> Self {
Self {
semantic_scholar: None,
core: None,
}
}
}
impl Config {
fn from_env() -> Self {
Self {
api_keys: ApiKeys::from_env(),
downloads: DownloadConfig::default(),
rate_limits: RateLimitConfig::default(),
sources: SourceConfig::from_env(),
cache: CacheConfig::default(),
}
}
fn without_env() -> Self {
Self {
api_keys: ApiKeys::without_env(),
downloads: DownloadConfig::default(),
rate_limits: RateLimitConfig::default(),
sources: SourceConfig::without_env(),
cache: CacheConfig::default(),
}
}
}
/// Settings for downloaded files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DownloadConfig {
/// Base directory for downloads (default `./downloads`).
#[serde(default = "default_download_dir")]
pub default_path: PathBuf,
/// Whether downloads are grouped by source, presumably into per-source
/// subdirectories (default `true`).
#[serde(default = "default_true")]
pub organize_by_source: bool,
/// Maximum file size in megabytes (default 100); enforced elsewhere.
#[serde(default = "default_max_file_size")]
pub max_file_size_mb: usize,
}
impl Default for DownloadConfig {
fn default() -> Self {
Self {
default_path: default_download_dir(),
organize_by_source: true,
max_file_size_mb: 100,
}
}
}
/// Serde default for `DownloadConfig::default_path`: `./downloads`, relative to
/// the working directory.
fn default_download_dir() -> PathBuf {
    Path::new("./downloads").to_path_buf()
}
/// Serde helper that returns `true`.
///
/// `#[serde(default = "...")]` needs a function path rather than a literal.
fn default_true() -> bool { true }
/// Serde default for `DownloadConfig::max_file_size_mb`, in megabytes.
fn default_max_file_size() -> usize {
    const DEFAULT_MAX_FILE_MB: usize = 100;
    DEFAULT_MAX_FILE_MB
}
/// Global request throttling settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
/// Request rate used when no per-source override applies, presumably one from
/// `SourceConfig::parse_rate_limits` (default 5.0).
#[serde(default = "default_rps")]
pub default_requests_per_second: f32,
/// Upper bound on in-flight requests (default 10); enforced elsewhere.
#[serde(default = "default_max_concurrent")]
pub max_concurrent_requests: usize,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
default_requests_per_second: default_rps(),
max_concurrent_requests: default_max_concurrent(),
}
}
}
/// Serde default for `RateLimitConfig::default_requests_per_second`.
fn default_rps() -> f32 {
    const DEFAULT_REQUESTS_PER_SECOND: f32 = 5.0;
    DEFAULT_REQUESTS_PER_SECOND
}
/// Serde default for `RateLimitConfig::max_concurrent_requests`.
fn default_max_concurrent() -> usize {
    const DEFAULT_MAX_CONCURRENT: usize = 10;
    DEFAULT_MAX_CONCURRENT
}
/// Loads configuration from the file at `path` and layers environment
/// variables on top.
///
/// The `config` crate builds the layers. The file is added first, then a
/// `RESEARCH_MASTER`-prefixed environment source.
///
/// When `RESEARCH_MASTER_TEST_MODE` equals `true` (ASCII case-insensitive), the
/// file is not read at all and `Config::without_env()` is returned.
///
/// # Errors
///
/// Returns any `config::ConfigError` raised while building the layered
/// settings or while deserializing them into [`Config`].
pub fn load_config(path: &Path) -> Result<Config, config::ConfigError> {
    let in_test_mode = matches!(
        std::env::var(TEST_MODE_ENV_VAR),
        Ok(ref flag) if flag.eq_ignore_ascii_case("true")
    );
    if in_test_mode {
        return Ok(Config::without_env());
    }
    config::Config::builder()
        .add_source(config::File::from(path))
        .add_source(config::Environment::with_prefix("RESEARCH_MASTER"))
        .build()
        .and_then(|settings| settings.try_deserialize())
}
/// Returns configuration built from defaults and environment variables,
/// without reading any config file.
///
/// When `RESEARCH_MASTER_TEST_MODE` equals `true` (ASCII case-insensitive),
/// this returns `Config::without_env()`; otherwise it returns
/// `Config::from_env()`.
pub fn get_config() -> Config {
    match std::env::var(TEST_MODE_ENV_VAR) {
        Ok(flag) if flag.eq_ignore_ascii_case("true") => Config::without_env(),
        _ => Config::from_env(),
    }
}
/// Locates the first existing configuration file, searching in this order:
///
/// 1. `./research-master.toml`
/// 2. `./.research-master.toml`
/// 3. `$XDG_CONFIG_HOME/research-master/config.toml`
/// 4. `$HOME/Library/Application Support/research-master/config.toml`
/// 5. `$HOME/.config/research-master/config.toml`
/// 6. `%APPDATA%/research-master/config.toml`
///
/// A base directory from the environment is used only when its variable is set
/// and non-empty. `XDG_CONFIG_HOME` must also be absolute, as the XDG Base
/// Directory specification requires. Without these checks, an empty variable
/// would silently become a lookup relative to the current working directory.
///
/// Returns `None` when no candidate exists.
pub fn find_config_file() -> Option<PathBuf> {
    // Reads a directory from the environment, treating an empty value as unset.
    fn env_dir(name: &str) -> Option<PathBuf> {
        std::env::var_os(name)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    }

    let app_config = |base: PathBuf| base.join("research-master").join("config.toml");

    let mut candidates = vec![
        PathBuf::from("research-master.toml"),
        PathBuf::from(".research-master.toml"),
    ];
    if let Some(xdg_home) = env_dir("XDG_CONFIG_HOME").filter(|dir| dir.is_absolute()) {
        candidates.push(app_config(xdg_home));
    }
    if let Some(home) = env_dir("HOME") {
        candidates.push(app_config(home.join("Library").join("Application Support")));
        candidates.push(app_config(home.join(".config")));
    }
    if let Some(appdata) = env_dir("APPDATA") {
        candidates.push(app_config(appdata));
    }

    candidates.into_iter().find(|path| path.exists())
}
pub use file_config::ConfigFile;
pub use file_config::ConfigFileError;
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests that mutate `TEST_MODE_ENV_VAR`. The default harness
    /// runs tests on parallel threads, and the process environment is shared.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Enables test mode for as long as it lives and unsets the variable on
    /// drop, so a failing assertion cannot leak test mode into other tests.
    struct TestModeGuard {
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl TestModeGuard {
        fn enable() -> Self {
            // A panic in another guarded test poisons the lock. The `()` it guards
            // cannot be left inconsistent, so recover the guard and continue.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            std::env::set_var(TEST_MODE_ENV_VAR, "true");
            Self { _lock: lock }
        }
    }

    impl Drop for TestModeGuard {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so the next guarded test starts clean.
            std::env::remove_var(TEST_MODE_ENV_VAR);
        }
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert!(config.downloads.organize_by_source);
        assert_eq!(config.rate_limits.default_requests_per_second, 5.0);
    }

    #[test]
    fn test_config_without_env() {
        let config = Config::without_env();
        assert!(config.api_keys.semantic_scholar.is_none());
        assert!(config.api_keys.core.is_none());
        assert!(config.sources.enabled_sources.is_none());
        assert!(config.sources.disabled_sources.is_none());
    }

    #[test]
    fn test_cache_config_defaults() {
        let cache = CacheConfig::default();
        assert_eq!(cache.search_ttl_seconds, 1800);
        assert_eq!(cache.citation_ttl_seconds, 900);
        assert_eq!(cache.max_size_mb, 500);
    }

    #[test]
    fn test_download_config_defaults() {
        let download = DownloadConfig::default();
        assert!(download.organize_by_source);
        assert_eq!(download.max_file_size_mb, 100);
    }

    #[test]
    fn test_rate_limit_config_defaults() {
        let rate = RateLimitConfig::default();
        assert_eq!(rate.default_requests_per_second, 5.0);
        assert_eq!(rate.max_concurrent_requests, 10);
    }

    #[test]
    fn test_source_config_without_env() {
        let source = SourceConfig::without_env();
        assert!(source.enabled_sources.is_none());
        assert!(source.disabled_sources.is_none());
        assert!(source.proxy_http.is_none());
        assert!(source.proxy_https.is_none());
        assert!(source.rate_limits.is_none());
    }

    #[test]
    fn test_api_keys_without_env() {
        let keys = ApiKeys::without_env();
        assert!(keys.semantic_scholar.is_none());
        assert!(keys.core.is_none());
    }

    #[test]
    fn test_parse_rate_limits() {
        let source_config = SourceConfig {
            rate_limits: Some("semantic:0.5,arxiv:5,openalex:2.5".to_string()),
            ..Default::default()
        };
        let limits = source_config.parse_rate_limits();
        assert_eq!(limits.get("semantic").copied(), Some(0.5));
        assert_eq!(limits.get("arxiv").copied(), Some(5.0));
        assert_eq!(limits.get("openalex").copied(), Some(2.5));
        assert_eq!(limits.get("nonexistent"), None);
    }

    #[test]
    fn test_parse_rate_limits_empty() {
        let source_config = SourceConfig {
            rate_limits: None,
            ..Default::default()
        };
        let limits = source_config.parse_rate_limits();
        assert!(limits.is_empty());
    }

    #[test]
    fn test_parse_rate_limits_invalid_format() {
        let source_config = SourceConfig {
            rate_limits: Some("semantic:0.5,invalidformat,arxiv:5".to_string()),
            ..Default::default()
        };
        let limits = source_config.parse_rate_limits();
        assert_eq!(limits.get("semantic").copied(), Some(0.5));
        assert_eq!(limits.get("arxiv").copied(), Some(5.0));
        assert_eq!(limits.len(), 2);
    }

    #[test]
    fn test_parse_rate_limits_whitespace() {
        // Whitespace after the comma must not become part of the source name.
        let source_config = SourceConfig {
            rate_limits: Some("semantic:0.5, arxiv:5".to_string()),
            ..Default::default()
        };
        let limits = source_config.parse_rate_limits();
        assert_eq!(
            limits.get("semantic").copied(),
            Some(0.5),
            "semantic rate should be 0.5"
        );
        assert_eq!(
            limits.get("arxiv").copied(),
            Some(5.0),
            "arxiv rate should be 5.0"
        );
    }

    #[test]
    fn test_find_config_file_result_exists() {
        // Whatever the environment holds, a returned path must point at a real file.
        if let Some(path) = find_config_file() {
            assert!(path.exists(), "returned path {} does not exist", path.display());
        }
    }

    #[test]
    fn test_get_config_test_mode() {
        let _guard = TestModeGuard::enable();
        let config = get_config();
        assert!(config.api_keys.semantic_scholar.is_none());
    }

    #[test]
    fn test_load_config_test_mode() {
        let _guard = TestModeGuard::enable();
        let result = load_config(Path::new("/nonexistent/path.toml"));
        assert!(result.is_ok());
        let config = result.unwrap();
        assert!(config.api_keys.semantic_scholar.is_none());
    }
}