1use std::fmt;
11use std::str::FromStr;
12use std::time::Duration;
13
14use crate::constants::charset;
15use crate::error::{Error, Result};
16use crate::transport::TlsConfig;
17
/// Default listener port assigned when a connection string omits one.
pub const DEFAULT_PORT: u16 = 1521;

/// Default session data unit (SDU) size assigned to new configurations.
pub const DEFAULT_SDU: u32 = 8 * 1024;

/// Default statement cache size assigned to new configurations.
pub const DEFAULT_STMTCACHESIZE: usize = 20;
26
/// How the target database is identified in the connect descriptor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServiceMethod {
    /// Identify the database by service name (rendered as `SERVICE_NAME=`).
    ServiceName(String),
    /// Identify the database by SID (rendered as `SID=`).
    Sid(String),
}

impl ServiceMethod {
    /// Returns the service name, or `None` when this value holds a SID.
    pub fn service_name(&self) -> Option<&str> {
        if let ServiceMethod::ServiceName(name) = self {
            Some(name.as_str())
        } else {
            None
        }
    }

    /// Returns the SID, or `None` when this value holds a service name.
    pub fn sid(&self) -> Option<&str> {
        if let ServiceMethod::Sid(sid) = self {
            Some(sid.as_str())
        } else {
            None
        }
    }
}
53
/// Whether the transport is wrapped in TLS.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum TlsMode {
    /// Plain TCP (`PROTOCOL=TCP` in the connect descriptor). This is the default.
    #[default]
    Disable,
    /// TLS is required (`PROTOCOL=TCPS` in the connect descriptor).
    Require,
}
63
64#[derive(Debug, Clone)]
110pub struct Config {
111 pub host: String,
113 pub port: u16,
115 pub service: ServiceMethod,
117 pub username: String,
119 password: String,
121 pub tls_mode: TlsMode,
123 pub tls_config: Option<TlsConfig>,
125 pub connect_timeout: Duration,
127 pub sdu: u32,
129 pub charset_id: u16,
131 pub ncharset_id: u16,
133 pub stmtcachesize: usize,
135}
136
137impl Config {
138 pub fn new(
140 host: impl Into<String>,
141 port: u16,
142 service_name: impl Into<String>,
143 username: impl Into<String>,
144 password: impl Into<String>,
145 ) -> Self {
146 Self {
147 host: host.into(),
148 port,
149 service: ServiceMethod::ServiceName(service_name.into()),
150 username: username.into(),
151 password: password.into(),
152 tls_mode: TlsMode::Disable,
153 tls_config: None,
154 connect_timeout: Duration::from_secs(10),
155 sdu: DEFAULT_SDU,
156 charset_id: charset::UTF8,
157 ncharset_id: charset::UTF16,
158 stmtcachesize: DEFAULT_STMTCACHESIZE,
159 }
160 }
161
162 pub fn with_sid(
164 host: impl Into<String>,
165 port: u16,
166 sid: impl Into<String>,
167 username: impl Into<String>,
168 password: impl Into<String>,
169 ) -> Self {
170 Self {
171 host: host.into(),
172 port,
173 service: ServiceMethod::Sid(sid.into()),
174 username: username.into(),
175 password: password.into(),
176 tls_mode: TlsMode::Disable,
177 tls_config: None,
178 connect_timeout: Duration::from_secs(10),
179 sdu: DEFAULT_SDU,
180 charset_id: charset::UTF8,
181 ncharset_id: charset::UTF16,
182 stmtcachesize: DEFAULT_STMTCACHESIZE,
183 }
184 }
185
186 pub fn tls(mut self, mode: TlsMode) -> Self {
188 self.tls_mode = mode;
189 self
190 }
191
192 pub fn tls_config(mut self, config: TlsConfig) -> Self {
194 self.tls_config = Some(config);
195 self.tls_mode = TlsMode::Require;
196 self
197 }
198
199 pub fn is_tls_enabled(&self) -> bool {
201 self.tls_mode == TlsMode::Require
202 }
203
204 pub fn with_tls(mut self) -> Result<Self> {
219 let tls_config = TlsConfig::new();
220 tls_config.build_client_config()?;
222 self.tls_config = Some(tls_config);
223 self.tls_mode = TlsMode::Require;
224 Ok(self)
225 }
226
227 pub fn with_wallet(
248 mut self,
249 wallet_path: impl Into<String>,
250 wallet_password: Option<&str>,
251 ) -> Result<Self> {
252 let tls_config = TlsConfig::new()
253 .with_wallet(wallet_path, wallet_password.map(|s| s.to_string()));
254 tls_config.build_client_config()?;
256 self.tls_config = Some(tls_config);
257 self.tls_mode = TlsMode::Require;
258 Ok(self)
259 }
260
261 pub fn with_drcp(self, _connection_class: &str, _purity: &str) -> Self {
281 self
284 }
285
286 pub fn with_statement_cache_size(mut self, size: usize) -> Self {
302 self.stmtcachesize = size;
303 self
304 }
305
306 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
308 self.connect_timeout = timeout;
309 self
310 }
311
312 pub fn sdu(mut self, sdu: u32) -> Self {
314 self.sdu = sdu;
315 self
316 }
317
318 pub fn stmtcachesize(mut self, size: usize) -> Self {
323 self.stmtcachesize = size;
324 self
325 }
326
327 pub(crate) fn password(&self) -> &str {
329 &self.password
330 }
331
332 pub fn set_password(&mut self, password: impl Into<String>) {
334 self.password = password.into();
335 }
336
337 pub fn set_username(&mut self, username: impl Into<String>) {
339 self.username = username.into();
340 }
341
342 pub fn build_connect_string(&self) -> String {
344 let mut parts = Vec::new();
345
346 let protocol = match self.tls_mode {
348 TlsMode::Disable => "TCP",
349 TlsMode::Require => "TCPS",
350 };
351 parts.push(format!(
352 "(ADDRESS=(PROTOCOL={})(HOST={})(PORT={}))",
353 protocol, self.host, self.port
354 ));
355
356 let service_part = match &self.service {
358 ServiceMethod::ServiceName(name) => format!("(SERVICE_NAME={})", name),
359 ServiceMethod::Sid(sid) => format!("(SID={})", sid),
360 };
361 parts.push(format!("(CONNECT_DATA={})", service_part));
362
363 format!("(DESCRIPTION={})", parts.join(""))
364 }
365
366 pub fn socket_addr(&self) -> String {
368 format!("{}:{}", self.host, self.port)
369 }
370}
371
372impl Default for Config {
373 fn default() -> Self {
374 Self {
375 host: "localhost".to_string(),
376 port: DEFAULT_PORT,
377 service: ServiceMethod::ServiceName("FREEPDB1".to_string()),
378 username: String::new(),
379 password: String::new(),
380 tls_mode: TlsMode::Disable,
381 tls_config: None,
382 connect_timeout: Duration::from_secs(10),
383 sdu: DEFAULT_SDU,
384 charset_id: charset::UTF8,
385 ncharset_id: charset::UTF16,
386 stmtcachesize: DEFAULT_STMTCACHESIZE,
387 }
388 }
389}
390
391impl FromStr for Config {
399 type Err = Error;
400
401 fn from_str(s: &str) -> Result<Self> {
402 let s = s.trim();
403
404 let s = s.trim_start_matches('/');
406
407 if s.is_empty() {
408 return Err(Error::InvalidConnectionString(
409 "empty connection string".to_string(),
410 ));
411 }
412
413 if s.starts_with('(') {
415 return Err(Error::InvalidConnectionString(
416 "TNS descriptor format not yet supported, use EZConnect format".to_string(),
417 ));
418 }
419
420 let mut config = Config::default();
424
425 if let Some(slash_pos) = s.find('/') {
427 let host_port = &s[..slash_pos];
428 let service_name = &s[slash_pos + 1..];
429
430 if service_name.is_empty() {
431 return Err(Error::InvalidConnectionString(
432 "missing service name after /".to_string(),
433 ));
434 }
435
436 config.service = ServiceMethod::ServiceName(service_name.to_string());
437
438 if let Some(colon_pos) = host_port.find(':') {
440 config.host = host_port[..colon_pos].to_string();
441 config.port = host_port[colon_pos + 1..]
442 .parse()
443 .map_err(|_| Error::InvalidConnectionString("invalid port number".to_string()))?;
444 } else {
445 config.host = host_port.to_string();
446 config.port = DEFAULT_PORT;
447 }
448 } else {
449 let parts: Vec<&str> = s.split(':').collect();
451
452 match parts.len() {
453 1 => {
454 config.host = parts[0].to_string();
456 }
457 2 => {
458 config.host = parts[0].to_string();
460 config.port = parts[1]
461 .parse()
462 .map_err(|_| Error::InvalidConnectionString("invalid port number".to_string()))?;
463 }
464 3 => {
465 config.host = parts[0].to_string();
467 config.port = parts[1]
468 .parse()
469 .map_err(|_| Error::InvalidConnectionString("invalid port number".to_string()))?;
470 config.service = ServiceMethod::Sid(parts[2].to_string());
471 }
472 _ => {
473 return Err(Error::InvalidConnectionString(
474 "too many colons in connection string".to_string(),
475 ));
476 }
477 }
478 }
479
480 if config.host.is_empty() {
481 return Err(Error::InvalidConnectionString(
482 "missing host".to_string(),
483 ));
484 }
485
486 Ok(config)
487 }
488}
489
490impl fmt::Display for Config {
491 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
492 match &self.service {
493 ServiceMethod::ServiceName(name) => {
494 write!(f, "{}:{}/{}", self.host, self.port, name)
495 }
496 ServiceMethod::Sid(sid) => {
497 write!(f, "{}:{}:{}", self.host, self.port, sid)
498 }
499 }
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_ezconnect_full() {
        let config: Config = "myhost:1522/myservice".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
        assert_eq!(
            config.service,
            ServiceMethod::ServiceName("myservice".to_string())
        );
    }

    #[test]
    fn test_parse_ezconnect_default_port() {
        let config: Config = "myhost/myservice".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, DEFAULT_PORT);
        assert_eq!(
            config.service,
            ServiceMethod::ServiceName("myservice".to_string())
        );
    }

    #[test]
    fn test_parse_ezconnect_with_slashes() {
        let config: Config = "//myhost:1522/myservice".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
    }

    #[test]
    fn test_parse_ezconnect_sid_format() {
        let config: Config = "myhost:1522:ORCL".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
        assert_eq!(config.service, ServiceMethod::Sid("ORCL".to_string()));
    }

    #[test]
    fn test_parse_host_only() {
        let config: Config = "myhost".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, DEFAULT_PORT);
    }

    #[test]
    fn test_parse_host_port() {
        let config: Config = "myhost:1522".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
    }

    #[test]
    fn test_parse_surrounding_whitespace() {
        let config: Config = "  myhost:1522/myservice  ".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
        assert_eq!(
            config.service,
            ServiceMethod::ServiceName("myservice".to_string())
        );
    }

    #[test]
    fn test_parse_empty() {
        let result: Result<Config> = "".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_invalid_port() {
        let result: Result<Config> = "myhost:notaport/service".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_tns_descriptor_rejected() {
        let result: Result<Config> =
            "(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=h)(PORT=1521)))".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_missing_service_name() {
        let result: Result<Config> = "myhost:1522/".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_missing_host() {
        let result: Result<Config> = ":1522/myservice".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_too_many_colons() {
        let result: Result<Config> = "myhost:1522:ORCL:extra".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_build_connect_string() {
        let config = Config::new("myhost", 1522, "myservice", "user", "pass");
        let connect_str = config.build_connect_string();
        assert!(connect_str.contains("(HOST=myhost)"));
        assert!(connect_str.contains("(PORT=1522)"));
        assert!(connect_str.contains("(SERVICE_NAME=myservice)"));
        assert!(connect_str.contains("(PROTOCOL=TCP)"));
    }

    #[test]
    fn test_build_connect_string_sid() {
        let config = Config::with_sid("myhost", 1522, "ORCL", "user", "pass");
        let connect_str = config.build_connect_string();
        assert!(connect_str.contains("(SID=ORCL)"));
    }

    #[test]
    fn test_build_connect_string_tls() {
        let config = Config::new("myhost", 2484, "myservice", "user", "pass").tls(TlsMode::Require);
        assert!(config.is_tls_enabled());
        assert!(config.build_connect_string().contains("(PROTOCOL=TCPS)"));
    }

    #[test]
    fn test_socket_addr() {
        let config = Config::new("myhost", 1522, "myservice", "user", "pass");
        assert_eq!(config.socket_addr(), "myhost:1522");
    }

    #[test]
    fn test_config_display() {
        let config = Config::new("myhost", 1522, "myservice", "user", "pass");
        assert_eq!(config.to_string(), "myhost:1522/myservice");

        let config_sid = Config::with_sid("myhost", 1522, "ORCL", "user", "pass");
        assert_eq!(config_sid.to_string(), "myhost:1522:ORCL");
    }

    #[test]
    fn test_display_parse_roundtrip() {
        for original in [
            Config::new("myhost", 1522, "myservice", "user", "pass"),
            Config::with_sid("myhost", 1522, "ORCL", "user", "pass"),
        ] {
            let parsed: Config = original.to_string().parse().unwrap();
            assert_eq!(parsed.host, original.host);
            assert_eq!(parsed.port, original.port);
            assert_eq!(parsed.service, original.service);
        }
    }

    #[test]
    fn test_config_builder_pattern() {
        let config = Config::new("host", 1521, "svc", "user", "pass")
            .tls(TlsMode::Require)
            .connect_timeout(Duration::from_secs(30))
            .sdu(16384);

        assert_eq!(config.tls_mode, TlsMode::Require);
        assert_eq!(config.connect_timeout, Duration::from_secs(30));
        assert_eq!(config.sdu, 16384);
    }

    #[test]
    fn test_statement_cache_size_builders() {
        assert_eq!(Config::default().with_statement_cache_size(5).stmtcachesize, 5);
        assert_eq!(Config::default().stmtcachesize(7).stmtcachesize, 7);
    }

    #[test]
    fn test_credential_setters() {
        let mut config = Config::default();
        config.set_username("scott");
        config.set_password("tiger");
        assert_eq!(config.username, "scott");
        assert_eq!(config.password(), "tiger");
    }
}