use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SourceConfig {
#[serde(default = "default_true")]
pub enabled: bool,
#[serde(default = "default_timeout")]
pub timeout: f64,
#[serde(default = "default_max_retries")]
pub max_retries: u32,
#[serde(default)]
pub api_key: String,
}
impl Default for SourceConfig {
fn default() -> Self {
Self {
enabled: true,
timeout: 30.0,
max_retries: 3,
api_key: String::new(),
}
}
}
fn default_true() -> bool {
true
}
fn default_timeout() -> f64 {
30.0
}
fn default_max_retries() -> u32 {
3
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpoConfig {
#[serde(default = "default_true")]
pub enabled: bool,
#[serde(default = "default_timeout")]
pub timeout: f64,
#[serde(default = "default_max_retries")]
pub max_retries: u32,
#[serde(default)]
pub consumer_key: String,
#[serde(default)]
pub consumer_secret: String,
}
impl Default for EpoConfig {
fn default() -> Self {
Self {
enabled: true,
timeout: 30.0,
max_retries: 3,
consumer_key: String::new(),
consumer_secret: String::new(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatConfig {
#[serde(default = "default_model")]
pub model: String,
#[serde(default = "default_max_tokens")]
pub max_tokens: u32,
#[serde(default = "default_scoring_concurrency")]
pub scoring_concurrency: u32,
}
impl Default for ChatConfig {
fn default() -> Self {
Self {
model: "claude-sonnet-4-6".to_string(),
max_tokens: 4096,
scoring_concurrency: 5,
}
}
}
fn default_model() -> String {
"claude-sonnet-4-6".to_string()
}
fn default_max_tokens() -> u32 {
4096
}
fn default_scoring_concurrency() -> u32 {
5
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UiConfig {
#[serde(default = "default_true")]
pub show_institutional_hint: bool,
}
impl Default for UiConfig {
fn default() -> Self {
Self {
show_institutional_hint: true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub db_path: PathBuf,
#[serde(default = "default_sources")]
pub default_sources: Vec<String>,
#[serde(default)]
pub pubmed: SourceConfig,
#[serde(default)]
pub arxiv: SourceConfig,
#[serde(default)]
pub openalex: SourceConfig,
#[serde(default)]
pub inspire: SourceConfig,
#[serde(default)]
pub patentsview: SourceConfig,
#[serde(default)]
pub lens: SourceConfig,
#[serde(default)]
pub epo: EpoConfig,
#[serde(default)]
pub chat: ChatConfig,
#[serde(default)]
pub ui: UiConfig,
}
fn default_sources() -> Vec<String> {
vec![
"pubmed".into(),
"arxiv".into(),
"openalex".into(),
"inspire".into(),
]
}
impl Config {
pub fn papers_dir(&self) -> PathBuf {
self.db_path
.parent()
.unwrap_or_else(|| Path::new("."))
.join("papers")
}
}
impl Default for Config {
fn default() -> Self {
let workspace = find_workspace_root().unwrap_or_else(|| std::env::current_dir().unwrap());
Self {
db_path: default_db_path(&workspace),
default_sources: default_sources(),
pubmed: SourceConfig::default(),
arxiv: SourceConfig::default(),
openalex: SourceConfig::default(),
inspire: SourceConfig::default(),
patentsview: SourceConfig::default(),
lens: SourceConfig::default(),
epo: EpoConfig::default(),
chat: ChatConfig::default(),
ui: UiConfig::default(),
}
}
}
fn find_workspace_root() -> Option<PathBuf> {
let output = std::process::Command::new("git")
.args(["rev-parse", "--show-toplevel"])
.output()
.ok()?;
if output.status.success() {
let path = String::from_utf8(output.stdout).ok()?;
Some(PathBuf::from(path.trim()))
} else {
None
}
}
fn default_db_path(workspace: &Path) -> PathBuf {
if let Ok(db) = std::env::var("SCITADEL_DB") {
let expanded = if db.starts_with('~') {
if let Ok(home) = std::env::var("HOME") {
db.replacen('~', &home, 1)
} else {
db
}
} else {
db
};
return PathBuf::from(expanded);
}
workspace.join(".scitadel").join("scitadel.db")
}
pub fn load_config() -> Config {
use crate::credentials::resolve;
let workspace = find_workspace_root().unwrap_or_else(|| std::env::current_dir().unwrap());
let db_path = default_db_path(&workspace);
let config_path = workspace.join(".scitadel").join("config.toml");
let mut config: Config = std::fs::read_to_string(&config_path)
.ok()
.and_then(|contents| toml::from_str(&contents).ok())
.unwrap_or_default();
config.db_path = db_path;
config.pubmed.api_key = resolve(
"pubmed.api_key",
"SCITADEL_PUBMED_API_KEY",
&config.pubmed.api_key,
)
.unwrap_or_default();
config.openalex.api_key = resolve(
"openalex.email",
"SCITADEL_OPENALEX_EMAIL",
&config.openalex.api_key,
)
.unwrap_or_default();
config.patentsview.api_key = resolve(
"patentsview.api_key",
"SCITADEL_PATENTSVIEW_KEY",
&config.patentsview.api_key,
)
.unwrap_or_default();
config.lens.api_key = resolve(
"lens.api_token",
"SCITADEL_LENS_TOKEN",
&config.lens.api_key,
)
.unwrap_or_default();
config.epo.consumer_key = resolve(
"epo.consumer_key",
"SCITADEL_EPO_KEY",
&config.epo.consumer_key,
)
.unwrap_or_default();
config.epo.consumer_secret = resolve(
"epo.consumer_secret",
"SCITADEL_EPO_SECRET",
&config.epo.consumer_secret,
)
.unwrap_or_default();
if let Ok(model) = std::env::var("SCITADEL_CHAT_MODEL") {
config.chat.model = model;
}
if let Ok(tokens) = std::env::var("SCITADEL_CHAT_MAX_TOKENS")
&& let Ok(v) = tokens.parse()
{
config.chat.max_tokens = v;
}
if let Ok(conc) = std::env::var("SCITADEL_SCORING_CONCURRENCY")
&& let Ok(v) = conc.parse()
{
config.chat.scoring_concurrency = v;
}
config
}
pub fn load_config_from(path: &Path) -> Result<Config, crate::error::CoreError> {
let contents = std::fs::read_to_string(path)?;
toml::from_str(&contents).map_err(|e| crate::error::CoreError::Config(e.to_string()))
}