1use crate::error::ClientError;
9use serde::{Deserialize, Serialize};
10use std::time::Duration;
11
12#[derive(Debug, Clone, Serialize, Deserialize, Default)]
18pub struct ClientConfig {
19 pub connection: ConnectionConfig,
20 pub pool: PoolConfig,
21 pub retry: RetryConfig,
22 pub timeout: TimeoutConfig,
23}
24
25impl ClientConfig {
26 pub fn new(host: impl Into<String>, port: u16, database: impl Into<String>) -> Self {
28 Self {
29 connection: ConnectionConfig {
30 host: host.into(),
31 port,
32 database: database.into(),
33 ..Default::default()
34 },
35 ..Default::default()
36 }
37 }
38
39 pub fn from_url(url: &str) -> Result<Self, ClientError> {
41 let url = url
42 .strip_prefix("aegis://")
43 .ok_or_else(|| ClientError::InvalidUrl("URL must start with aegis://".to_string()))?;
44
45 let (auth_host, path) = url.split_once('/').unwrap_or((url, ""));
46
47 let (auth, host_port) = if auth_host.contains('@') {
48 let parts: Vec<&str> = auth_host.splitn(2, '@').collect();
49 (Some(parts[0]), parts[1])
50 } else {
51 (None, auth_host)
52 };
53
54 let (host, port) = if host_port.contains(':') {
55 let parts: Vec<&str> = host_port.splitn(2, ':').collect();
56 let port: u16 = parts[1]
57 .parse()
58 .map_err(|_| ClientError::InvalidUrl("Invalid port".to_string()))?;
59 (parts[0].to_string(), port)
60 } else {
61 (host_port.to_string(), 9090) };
63
64 let database = if path.is_empty() {
65 "default".to_string()
66 } else {
67 path.split('?').next().unwrap_or("default").to_string()
68 };
69
70 let (username, password) = if let Some(auth) = auth {
71 if auth.contains(':') {
72 let parts: Vec<&str> = auth.splitn(2, ':').collect();
73 (Some(parts[0].to_string()), Some(parts[1].to_string()))
74 } else {
75 (Some(auth.to_string()), None)
76 }
77 } else {
78 (None, None)
79 };
80
81 Ok(Self {
82 connection: ConnectionConfig {
83 host,
84 port,
85 database,
86 username,
87 password,
88 ..Default::default()
89 },
90 ..Default::default()
91 })
92 }
93
94 pub fn with_pool_size(mut self, min: usize, max: usize) -> Self {
96 self.pool.min_connections = min;
97 self.pool.max_connections = max;
98 self
99 }
100
101 pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
103 self.timeout.connect = timeout;
104 self
105 }
106
107 pub fn with_query_timeout(mut self, timeout: Duration) -> Self {
109 self.timeout.query = timeout;
110 self
111 }
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct ConnectionConfig {
121 pub host: String,
122 pub port: u16,
123 pub database: String,
124 pub username: Option<String>,
125 pub password: Option<String>,
126 pub ssl_mode: SslMode,
127 pub application_name: Option<String>,
128}
129
130impl Default for ConnectionConfig {
131 fn default() -> Self {
132 Self {
133 host: "localhost".to_string(),
134 port: 9090, database: "default".to_string(),
136 username: None,
137 password: None,
138 ssl_mode: SslMode::Prefer,
139 application_name: None,
140 }
141 }
142}
143
144impl ConnectionConfig {
145 pub fn connection_string(&self) -> String {
147 let mut parts = vec![format!("host={}", self.host), format!("port={}", self.port)];
148
149 parts.push(format!("dbname={}", self.database));
150
151 if let Some(ref user) = self.username {
152 parts.push(format!("user={}", user));
153 }
154
155 if let Some(ref app) = self.application_name {
156 parts.push(format!("application_name={}", app));
157 }
158
159 parts.push(format!("sslmode={}", self.ssl_mode.as_str()));
160
161 parts.join(" ")
162 }
163
164 pub fn address(&self) -> String {
166 format!("{}:{}", self.host, self.port)
167 }
168}
169
170#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
176pub enum SslMode {
177 Disable,
178 #[default]
179 Prefer,
180 Require,
181 VerifyCa,
182 VerifyFull,
183}
184
185impl SslMode {
186 pub fn as_str(&self) -> &'static str {
187 match self {
188 Self::Disable => "disable",
189 Self::Prefer => "prefer",
190 Self::Require => "require",
191 Self::VerifyCa => "verify-ca",
192 Self::VerifyFull => "verify-full",
193 }
194 }
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct PoolConfig {
204 pub min_connections: usize,
205 pub max_connections: usize,
206 pub acquire_timeout: Duration,
207 pub idle_timeout: Duration,
208 pub max_lifetime: Duration,
209 pub test_on_acquire: bool,
210}
211
212impl Default for PoolConfig {
213 fn default() -> Self {
214 Self {
215 min_connections: 1,
216 max_connections: 10,
217 acquire_timeout: Duration::from_secs(30),
218 idle_timeout: Duration::from_secs(600),
219 max_lifetime: Duration::from_secs(1800),
220 test_on_acquire: true,
221 }
222 }
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct RetryConfig {
232 pub max_retries: u32,
233 pub initial_delay: Duration,
234 pub max_delay: Duration,
235 pub multiplier: f64,
236}
237
238impl Default for RetryConfig {
239 fn default() -> Self {
240 Self {
241 max_retries: 3,
242 initial_delay: Duration::from_millis(100),
243 max_delay: Duration::from_secs(10),
244 multiplier: 2.0,
245 }
246 }
247}
248
249impl RetryConfig {
250 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
252 let delay_ms = self.initial_delay.as_millis() as f64 * self.multiplier.powi(attempt as i32);
253 let delay = Duration::from_millis(delay_ms as u64);
254 delay.min(self.max_delay)
255 }
256}
257
258#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct TimeoutConfig {
265 pub connect: Duration,
266 pub query: Duration,
267 pub statement: Duration,
268}
269
270impl Default for TimeoutConfig {
271 fn default() -> Self {
272 Self {
273 connect: Duration::from_secs(10),
274 query: Duration::from_secs(30),
275 statement: Duration::from_secs(300),
276 }
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `url`, failing the test with `what` if parsing errors.
    fn parse(url: &str, what: &str) -> ClientConfig {
        ClientConfig::from_url(url).expect(what)
    }

    #[test]
    fn test_client_config_default() {
        let ClientConfig {
            connection, pool, ..
        } = ClientConfig::default();
        assert_eq!(connection.host, "localhost");
        assert_eq!(connection.port, 9090);
        assert_eq!(pool.max_connections, 10);
    }

    #[test]
    fn test_from_url_simple() {
        let conn = parse("aegis://localhost:9090/testdb", "Should parse simple URL").connection;
        assert_eq!(conn.host, "localhost");
        assert_eq!(conn.port, 9090);
        assert_eq!(conn.database, "testdb");
    }

    #[test]
    fn test_from_url_with_auth() {
        let conn = parse(
            "aegis://user:pass@localhost:9090/testdb",
            "Should parse URL with auth",
        )
        .connection;
        assert_eq!(conn.host, "localhost");
        assert_eq!(conn.username.as_deref(), Some("user"));
        assert_eq!(conn.password.as_deref(), Some("pass"));
    }

    #[test]
    fn test_from_url_default_port() {
        let conn = parse(
            "aegis://localhost/testdb",
            "Should parse URL with default port",
        )
        .connection;
        assert_eq!(conn.port, 9090);
    }

    #[test]
    fn test_connection_string() {
        let conn_str = ConnectionConfig {
            host: "db.example.com".to_string(),
            port: 5433,
            database: "mydb".to_string(),
            username: Some("admin".to_string()),
            password: None,
            ssl_mode: SslMode::Require,
            application_name: Some("myapp".to_string()),
        }
        .connection_string();

        for expected in [
            "host=db.example.com",
            "port=5433",
            "dbname=mydb",
            "user=admin",
            "sslmode=require",
        ] {
            assert!(conn_str.contains(expected));
        }
    }

    #[test]
    fn test_retry_delay() {
        let config = RetryConfig::default();
        for (attempt, millis) in [(0, 100), (1, 200), (2, 400)] {
            assert_eq!(config.delay_for_attempt(attempt), Duration::from_millis(millis));
        }
    }

    #[test]
    fn test_retry_max_delay() {
        let capped = RetryConfig {
            max_delay: Duration::from_millis(500),
            ..RetryConfig::default()
        };
        assert_eq!(capped.delay_for_attempt(10), Duration::from_millis(500));
    }
}
363}