1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::fmt;
4use std::str::FromStr;
5use std::time::Duration;
6
7#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
8pub enum AccessMode {
9 #[serde(rename = "unrestricted")]
10 Unrestricted,
11 #[serde(rename = "restricted")]
12 Restricted,
13}
14
15impl fmt::Display for AccessMode {
16 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17 match self {
18 AccessMode::Unrestricted => write!(f, "unrestricted"),
19 AccessMode::Restricted => write!(f, "restricted"),
20 }
21 }
22}
23
24impl FromStr for AccessMode {
25 type Err = String;
26 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
27 match s.to_lowercase().as_str() {
28 "unrestricted" => Ok(AccessMode::Unrestricted),
29 "restricted" => Ok(AccessMode::Restricted),
30 _ => Err(format!(
31 "Invalid access mode: {s}. Use 'unrestricted' or 'restricted'"
32 )),
33 }
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct Config {
39 pub database: DatabaseConfig,
40 pub server: ServerConfig,
41 pub pool: PoolConfig,
42 pub metrics: MetricsConfig,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct DatabaseConfig {
47 pub url: String,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct ServerConfig {
52 pub host: String,
53 pub port: u16,
54 pub request_timeout: Duration,
55 pub access_mode: AccessMode,
56 #[serde(default, skip_serializing)]
59 pub auth_token: Option<String>,
60 #[serde(default)]
62 pub allow_url_import: bool,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct PoolConfig {
67 pub min_size: u32,
68 pub max_size: u32,
69 pub queue_timeout: Duration,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct MetricsConfig {
74 pub enabled: bool,
75 pub port: u16,
76}
77
78impl Config {
79 pub fn from_args(args: &super::Args) -> Result<Self> {
80 let database_url = args
81 .database_url
82 .clone()
83 .or_else(|| std::env::var("DATABASE_URL").ok())
84 .unwrap_or_else(|| "postgres://postgres:postgres@localhost:5432/postgres".to_string());
85
86 let min_size = args.min_connections.unwrap_or(5);
87 let max_size = args.max_connections.unwrap_or(20);
88
89 let auth_token = args
90 .auth_token
91 .clone()
92 .or_else(|| std::env::var("MCP_AUTH_TOKEN").ok())
93 .filter(|t| !t.is_empty());
94
95 Ok(Config {
96 database: DatabaseConfig { url: database_url },
97 server: ServerConfig {
98 host: args.host.clone(),
99 port: args.port,
100 request_timeout: Duration::from_secs(30),
101 access_mode: args.access_mode,
102 auth_token,
103 allow_url_import: args.allow_url_import,
104 },
105 pool: PoolConfig {
106 min_size,
107 max_size,
108 queue_timeout: Duration::from_secs(10),
109 },
110 metrics: MetricsConfig {
111 enabled: args.enable_metrics,
112 port: args.metrics_port,
113 },
114 })
115 }
116}
117
118impl Default for Config {
119 fn default() -> Self {
120 Self {
121 database: DatabaseConfig {
122 url: "postgres://postgres:postgres@localhost:5432/postgres".to_string(),
123 },
124 server: ServerConfig {
125 host: "127.0.0.1".to_string(),
126 port: 3000,
127 request_timeout: Duration::from_secs(30),
128 access_mode: AccessMode::Unrestricted,
129 auth_token: None,
130 allow_url_import: false,
131 },
132 pool: PoolConfig {
133 min_size: 5,
134 max_size: 20,
135 queue_timeout: Duration::from_secs(10),
136 },
137 metrics: MetricsConfig {
138 enabled: false,
139 port: 9090,
140 },
141 }
142 }
143}
144
145#[cfg(test)]
146mod tests {
147 use super::*;
148
149 #[test]
150 fn test_config_defaults() {
151 let cfg = Config::default();
152 assert_eq!(cfg.server.host, "127.0.0.1");
153 assert_eq!(cfg.server.port, 3000);
154 assert_eq!(cfg.server.request_timeout, Duration::from_secs(30));
155 }
156
157 #[test]
158 fn test_database_config_defaults() {
159 let cfg = Config::default();
160 assert_eq!(
161 cfg.database.url,
162 "postgres://postgres:postgres@localhost:5432/postgres"
163 );
164 }
165
166 #[test]
167 fn test_pool_config_defaults() {
168 let cfg = Config::default();
169 assert_eq!(cfg.pool.min_size, 5);
170 assert_eq!(cfg.pool.max_size, 20);
171 assert_eq!(cfg.pool.queue_timeout, Duration::from_secs(10));
172 }
173
174 #[test]
175 fn test_metrics_config_defaults() {
176 let cfg = Config::default();
177 assert!(!cfg.metrics.enabled);
178 assert_eq!(cfg.metrics.port, 9090);
179 }
180
181 #[test]
182 fn test_config_serde() {
183 let cfg = Config::default();
184 let json = serde_json::to_string(&cfg).unwrap();
185 let deserialized: Config = serde_json::from_str(&json).unwrap();
186 assert_eq!(deserialized.server.port, cfg.server.port);
187 assert_eq!(deserialized.pool.min_size, cfg.pool.min_size);
188 assert_eq!(deserialized.database.url, cfg.database.url);
189 }
190
191 #[test]
192 fn test_config_from_args_cpu_aware() {
193 let num_cpus = num_cpus::get() as u32;
194
195 let min_size = 1;
197 let max_size = num_cpus * 8;
198
199 assert_eq!(min_size, 1);
200 assert!(max_size > 0);
201 assert_eq!(max_size, num_cpus * 8);
202 }
203
204 #[test]
205 fn test_pool_config_values() {
206 let cfg = Config::default();
207 assert!(cfg.pool.min_size > 0);
208 assert!(cfg.pool.max_size >= cfg.pool.min_size);
209 }
210
211 #[test]
212 fn test_server_config_debug() {
213 let cfg = Config::default();
214 let debug = format!("{:?}", cfg);
215 assert!(debug.contains("127.0.0.1"));
216 assert!(debug.contains("3000"));
217 }
218}