//! Configuration types for `mcp_postgres` (`mcp_postgres/config.rs`).

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::fmt;
4use std::str::FromStr;
5use std::time::Duration;
6
7#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
8pub enum AccessMode {
9    #[serde(rename = "unrestricted")]
10    Unrestricted,
11    #[serde(rename = "restricted")]
12    Restricted,
13}
14
15impl fmt::Display for AccessMode {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        match self {
18            AccessMode::Unrestricted => write!(f, "unrestricted"),
19            AccessMode::Restricted => write!(f, "restricted"),
20        }
21    }
22}
23
24impl FromStr for AccessMode {
25    type Err = String;
26    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
27        match s.to_lowercase().as_str() {
28            "unrestricted" => Ok(AccessMode::Unrestricted),
29            "restricted" => Ok(AccessMode::Restricted),
30            _ => Err(format!("Invalid access mode: {s}. Use 'unrestricted' or 'restricted'")),
31        }
32    }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct Config {
37    pub database: DatabaseConfig,
38    pub server: ServerConfig,
39    pub pool: PoolConfig,
40    pub metrics: MetricsConfig,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct DatabaseConfig {
45    pub url: String,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct ServerConfig {
50    pub host: String,
51    pub port: u16,
52    pub request_timeout: Duration,
53    pub access_mode: AccessMode,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct PoolConfig {
58    pub min_size: u32,
59    pub max_size: u32,
60    pub queue_timeout: Duration,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct MetricsConfig {
65    pub enabled: bool,
66    pub port: u16,
67}
68
69impl Config {
70    pub fn from_args(args: &super::Args) -> Result<Self> {
71        let database_url = args.database_url.clone()
72            .or_else(|| std::env::var("DATABASE_URL").ok())
73            .unwrap_or_else(|| "postgres://postgres:postgres@localhost:5432/postgres".to_string());
74
75        let min_size = args.min_connections.unwrap_or(5);
76        let max_size = args.max_connections.unwrap_or(20);
77
78        Ok(Config {
79            database: DatabaseConfig {
80                url: database_url,
81            },
82            server: ServerConfig {
83                host: args.host.clone(),
84                port: args.port,
85                request_timeout: Duration::from_secs(30),
86                access_mode: args.access_mode,
87            },
88            pool: PoolConfig {
89                min_size,
90                max_size,
91                queue_timeout: Duration::from_secs(10),
92            },
93            metrics: MetricsConfig {
94                enabled: args.enable_metrics,
95                port: args.metrics_port,
96            },
97        })
98    }
99}
100
101impl Default for Config {
102    fn default() -> Self {
103        Self {
104            database: DatabaseConfig {
105                url: "postgres://postgres:postgres@localhost:5432/postgres".to_string(),
106            },
107            server: ServerConfig {
108                host: "127.0.0.1".to_string(),
109                port: 3000,
110                request_timeout: Duration::from_secs(30),
111                access_mode: AccessMode::Unrestricted,
112            },
113            pool: PoolConfig {
114                min_size: 5,
115                max_size: 20,
116                queue_timeout: Duration::from_secs(10),
117            },
118            metrics: MetricsConfig {
119                enabled: false,
120                port: 9090,
121            },
122        }
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_defaults() {
        let cfg = Config::default();
        assert_eq!(cfg.server.host, "127.0.0.1");
        assert_eq!(cfg.server.port, 3000);
        assert_eq!(cfg.server.request_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_database_config_defaults() {
        let cfg = Config::default();
        assert_eq!(cfg.database.url, "postgres://postgres:postgres@localhost:5432/postgres");
    }

    #[test]
    fn test_pool_config_defaults() {
        let cfg = Config::default();
        assert_eq!(cfg.pool.min_size, 5);
        assert_eq!(cfg.pool.max_size, 20);
        assert_eq!(cfg.pool.queue_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_metrics_config_defaults() {
        let cfg = Config::default();
        assert!(!cfg.metrics.enabled);
        assert_eq!(cfg.metrics.port, 9090);
    }

    #[test]
    fn test_config_serde() {
        let cfg = Config::default();
        let json = serde_json::to_string(&cfg).unwrap();
        let deserialized: Config = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.server.port, cfg.server.port);
        assert_eq!(deserialized.server.access_mode, cfg.server.access_mode);
        assert_eq!(deserialized.pool.min_size, cfg.pool.min_size);
        assert_eq!(deserialized.database.url, cfg.database.url);
    }

    // Replaces the former `test_config_from_args_cpu_aware`. That test only
    // asserted on locals it had just assigned, and its "defaults" (1 and
    // num_cpus * 8) did not match what `from_args` uses. The tests below
    // exercise real code paths instead.

    #[test]
    fn test_access_mode_default_is_unrestricted() {
        assert_eq!(Config::default().server.access_mode, AccessMode::Unrestricted);
    }

    #[test]
    fn test_access_mode_from_str_is_case_insensitive() {
        assert_eq!("restricted".parse::<AccessMode>(), Ok(AccessMode::Restricted));
        assert_eq!("Restricted".parse::<AccessMode>(), Ok(AccessMode::Restricted));
        assert_eq!("UNRESTRICTED".parse::<AccessMode>(), Ok(AccessMode::Unrestricted));
    }

    #[test]
    fn test_access_mode_from_str_rejects_unknown() {
        let err = "readonly".parse::<AccessMode>().unwrap_err();
        assert!(err.contains("readonly"));
    }

    #[test]
    fn test_access_mode_display_round_trips() {
        for mode in [AccessMode::Unrestricted, AccessMode::Restricted].iter().copied() {
            assert_eq!(mode.to_string().parse::<AccessMode>(), Ok(mode));
        }
    }

    #[test]
    fn test_access_mode_serde_uses_lowercase_names() {
        assert_eq!(
            serde_json::to_string(&AccessMode::Restricted).unwrap(),
            "\"restricted\""
        );
        let parsed: AccessMode = serde_json::from_str("\"unrestricted\"").unwrap();
        assert_eq!(parsed, AccessMode::Unrestricted);
    }

    #[test]
    fn test_pool_config_values() {
        let cfg = Config::default();
        assert!(cfg.pool.min_size > 0);
        assert!(cfg.pool.max_size >= cfg.pool.min_size);
    }

    #[test]
    fn test_server_config_debug() {
        let cfg = Config::default();
        let debug = format!("{:?}", cfg);
        assert!(debug.contains("127.0.0.1"));
        assert!(debug.contains("3000"));
    }
}