use clap::Parser;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use crate::error::{Result, RossbyError};
/// Command-line arguments for the `rossby` binary.
///
/// Every option except the positional NetCDF path can also be supplied through
/// the `ROSSBY_*` environment variable named in its `env` attribute.
#[derive(Parser, Debug)]
#[command(name = "rossby")]
#[command(author, version, about, long_about = None)]
pub struct Args {
    /// Path to the NetCDF file
    pub netcdf_file: PathBuf,

    /// Server host address
    #[arg(short = 'H', long, env = "ROSSBY_HOST", default_value = "127.0.0.1")]
    pub host: String,

    /// Server port
    #[arg(short, long, env = "ROSSBY_PORT", default_value = "8000")]
    pub port: u16,

    /// Number of workers
    #[arg(short, long, env = "ROSSBY_WORKERS")]
    pub workers: Option<usize>,

    /// Path to a JSON configuration file
    #[arg(short, long, env = "ROSSBY_CONFIG")]
    pub config: Option<PathBuf>,

    /// Log level: trace, debug, info, warn or error
    #[arg(long, env = "ROSSBY_LOG_LEVEL", default_value = "info")]
    pub log_level: String,

    /// Discovery URL
    #[arg(long, env = "ROSSBY_DISCOVERY_URL")]
    pub discovery_url: Option<String>,
}
/// Server settings; deserialized from the `server` section of the JSON config file.
///
/// Every field carries a serde default, so the file may specify any subset of them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Host address; defaults to `127.0.0.1`. [`Config::validate`] rejects an empty string.
    #[serde(default = "default_host")]
    pub host: String,
    /// Port number; defaults to `8000`. [`Config::validate`] rejects `0`.
    #[serde(default = "default_port")]
    pub port: u16,
    // NOTE(review): how `None` is interpreted (e.g. "pick automatically") is
    // decided by whoever consumes this value, outside this file — confirm.
    /// Number of workers; `None` when not configured.
    #[serde(default)]
    pub workers: Option<usize>,
    /// Discovery URL; `None` when not configured.
    #[serde(default)]
    pub discovery_url: Option<String>,
    // NOTE(review): not read anywhere in this file; presumably caps the number
    // of points a single data request may return — confirm against the consumer.
    /// Maximum number of data points; defaults to `100_000_000`.
    #[serde(default = "default_max_data_points")]
    pub max_data_points: usize,
}
/// Data-handling settings; deserialized from the `data` section of the JSON config file.
///
/// Every field carries a serde default, so the file may specify any subset of them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataConfig {
    /// Interpolation method name; defaults to `"bilinear"`.
    /// [`Config::validate`] accepts only `nearest`, `bilinear` or `bicubic`.
    #[serde(default = "default_interpolation")]
    pub interpolation_method: String,
    // NOTE(review): not read in this file — `Config::load` returns the CLI's
    // positional NetCDF path separately. Confirm whether this field is used.
    /// Optional data file path; `None` when not configured.
    #[serde(default)]
    pub file_path: Option<PathBuf>,
    // NOTE(review): mapping direction (alias -> canonical name, or the reverse)
    // is not visible here — confirm with the code that reads it.
    /// Dimension-name aliases; empty by default.
    #[serde(default)]
    pub dimension_aliases: HashMap<String, String>,
}
/// Top-level application configuration, deserializable from a JSON file.
///
/// Any section or field missing from the file takes its default value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Server settings (the `server` section).
    #[serde(default)]
    pub server: ServerConfig,
    /// Data-handling settings (the `data` section).
    #[serde(default)]
    pub data: DataConfig,
    /// Log level; defaults to `"info"`. [`Config::validate`] accepts
    /// `trace`, `debug`, `info`, `warn` or `error`.
    #[serde(default = "default_log_level")]
    pub log_level: String,
}
impl Config {
pub fn load() -> Result<(Self, PathBuf)> {
let args = Args::parse();
let mut config = Config::default();
if let Some(config_path) = &args.config {
let json_config = Self::load_from_file(config_path)?;
config.merge(json_config);
}
config.server.host = args.host;
config.server.port = args.port;
if args.workers.is_some() {
config.server.workers = args.workers;
}
if args.discovery_url.is_some() {
config.server.discovery_url = args.discovery_url;
}
config.log_level = args.log_level;
let netcdf_path = args.netcdf_file;
Ok((config, netcdf_path))
}
fn load_from_file(path: &PathBuf) -> Result<Self> {
let content = std::fs::read_to_string(path)?;
let config: Config = serde_json::from_str(&content)?;
Ok(config)
}
fn merge(&mut self, other: Config) {
self.server.host = other.server.host;
self.server.port = other.server.port;
if other.server.workers.is_some() {
self.server.workers = other.server.workers;
}
self.data = other.data;
self.log_level = other.log_level;
}
pub fn validate(&self) -> Result<()> {
if self.server.host.is_empty() {
return Err(RossbyError::Config {
message: "Server host cannot be empty".to_string(),
});
}
if self.server.port == 0 {
return Err(RossbyError::Config {
message: "Server port cannot be 0".to_string(),
});
}
match self.log_level.as_str() {
"trace" | "debug" | "info" | "warn" | "error" => {}
_ => {
return Err(RossbyError::Config {
message: format!(
"Invalid log level: {}. Must be one of: trace, debug, info, warn, error",
self.log_level
),
});
}
}
match self.data.interpolation_method.as_str() {
"nearest" | "bilinear" | "bicubic" => {}
_ => {
return Err(RossbyError::Config {
message: format!(
"Invalid interpolation method: {}. Must be one of: nearest, bilinear, bicubic",
self.data.interpolation_method
)
});
}
}
Ok(())
}
}
impl Default for Config {
fn default() -> Self {
Self {
server: ServerConfig::default(),
data: DataConfig::default(),
log_level: default_log_level(),
}
}
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
host: default_host(),
port: default_port(),
workers: None,
discovery_url: None,
max_data_points: default_max_data_points(),
}
}
}
impl Default for DataConfig {
fn default() -> Self {
Self {
interpolation_method: default_interpolation(),
file_path: None,
dimension_aliases: HashMap::new(),
}
}
}
/// Default server host: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}
/// Default server port.
fn default_port() -> u16 {
    8_000
}
/// Default interpolation method name.
fn default_interpolation() -> String {
    "bilinear".into()
}
/// Default log level name.
fn default_log_level() -> String {
    String::from("info")
}
/// Default cap on data points: one hundred million.
fn default_max_data_points() -> usize {
    100 * 1_000_000
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let defaults = Config::default();
        assert_eq!(defaults.server.host, "127.0.0.1");
        assert_eq!(defaults.server.port, 8000);
        assert_eq!(defaults.data.interpolation_method, "bilinear");
        assert_eq!(defaults.log_level, "info");
    }

    #[test]
    fn test_config_merge() {
        let overlay = {
            let mut cfg = Config::default();
            cfg.server.port = 9000;
            cfg.server.workers = Some(4);
            cfg
        };

        let mut base = Config::default();
        base.merge(overlay);

        assert_eq!(base.server.port, 9000);
        assert_eq!(base.server.workers, Some(4));
    }

    #[test]
    fn test_config_validation() {
        // The untouched defaults must be accepted.
        assert!(Config::default().validate().is_ok());

        // Each mutation breaks exactly one rule and must be rejected.
        let mutations: [fn(&mut Config); 4] = [
            |cfg| cfg.server.host = String::new(),
            |cfg| cfg.server.port = 0,
            |cfg| cfg.log_level = "invalid".to_string(),
            |cfg| cfg.data.interpolation_method = "invalid".to_string(),
        ];
        for (idx, mutate) in mutations.iter().enumerate() {
            let mut cfg = Config::default();
            mutate(&mut cfg);
            assert!(
                cfg.validate().is_err(),
                "mutation #{} should fail validation",
                idx
            );
        }
    }
}