oracle_rs/
config.rs

//! Connection configuration and connection string parsing.
//!
//! Parses the Oracle EZConnect format (a leading `//` is optional):
//! - `host:port/service_name`
//! - `host/service_name`
//! - `host:port:sid`
//!
//! TNS-style connect descriptors can be generated from a `Config`, but parsing them is not yet supported.
9
10use std::fmt;
11use std::str::FromStr;
12use std::time::Duration;
13
14use crate::constants::charset;
15use crate::error::{Error, Result};
16use crate::transport::TlsConfig;
17
/// Port an Oracle listener uses unless told otherwise.
pub const DEFAULT_PORT: u16 = 1521;

/// Session Data Unit (SDU) size assigned to newly built configurations.
pub const DEFAULT_SDU: u32 = 8192;

/// Number of statements the statement cache holds by default
/// (the same default python-oracledb uses).
pub const DEFAULT_STMTCACHESIZE: usize = 20;
26
/// How the target database is identified to the listener.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServiceMethod {
    /// Address the database by its service name (emitted as `SERVICE_NAME=`).
    ServiceName(String),
    /// Address the database by its legacy system identifier (emitted as `SID=`).
    Sid(String),
}
35
36impl ServiceMethod {
37    /// Get the service name if this is a ServiceName variant
38    pub fn service_name(&self) -> Option<&str> {
39        match self {
40            ServiceMethod::ServiceName(s) => Some(s),
41            ServiceMethod::Sid(_) => None,
42        }
43    }
44
45    /// Get the SID if this is a Sid variant
46    pub fn sid(&self) -> Option<&str> {
47        match self {
48            ServiceMethod::ServiceName(_) => None,
49            ServiceMethod::Sid(s) => Some(s),
50        }
51    }
52}
53
/// Transport security used for the database connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TlsMode {
    /// Plain, unencrypted TCP; this is the default.
    #[default]
    Disable,
    /// TLS-encrypted transport (the TCPS protocol).
    Require,
}
63
64/// Connection configuration for Oracle databases.
65///
66/// This struct holds all the parameters needed to establish a connection to an
67/// Oracle database, including host, port, credentials, TLS settings, and more.
68///
69/// # Examples
70///
71/// ## Basic connection
72///
73/// ```rust
74/// use oracle_rs::Config;
75///
76/// let config = Config::new("localhost", 1521, "FREEPDB1", "user", "password");
77/// ```
78///
79/// ## TLS connection with system certificates
80///
81/// ```rust,no_run
82/// use oracle_rs::Config;
83///
84/// let config = Config::new("hostname", 2484, "service", "user", "password")
85///     .with_tls()
86///     .expect("TLS configuration failed");
87/// ```
88///
89/// ## TLS connection with Oracle wallet
90///
91/// ```rust,ignore
92/// use oracle_rs::Config;
93///
94/// let config = Config::new("hostname", 2484, "service", "user", "password")
95///     .with_wallet("/path/to/wallet", Some("wallet_password"))
96///     .expect("Wallet configuration failed");
97/// ```
98///
99/// ## With custom options
100///
101/// ```rust
102/// use oracle_rs::Config;
103/// use std::time::Duration;
104///
105/// let config = Config::new("localhost", 1521, "FREEPDB1", "user", "password")
106///     .connect_timeout(Duration::from_secs(30))
107///     .stmtcachesize(50);
108/// ```
109#[derive(Debug, Clone)]
110pub struct Config {
111    /// Host to connect to
112    pub host: String,
113    /// Port to connect to
114    pub port: u16,
115    /// Service name or SID
116    pub service: ServiceMethod,
117    /// Username for authentication
118    pub username: String,
119    /// Password for authentication (stored temporarily)
120    password: String,
121    /// TLS mode
122    pub tls_mode: TlsMode,
123    /// TLS configuration (certificates, wallet, etc.)
124    pub tls_config: Option<TlsConfig>,
125    /// Connection timeout
126    pub connect_timeout: Duration,
127    /// SDU (Session Data Unit) size
128    pub sdu: u32,
129    /// Client charset ID
130    pub charset_id: u16,
131    /// National charset ID
132    pub ncharset_id: u16,
133    /// Statement cache size (0 = disabled)
134    pub stmtcachesize: usize,
135}
136
137impl Config {
138    /// Create a new configuration with service name
139    pub fn new(
140        host: impl Into<String>,
141        port: u16,
142        service_name: impl Into<String>,
143        username: impl Into<String>,
144        password: impl Into<String>,
145    ) -> Self {
146        Self {
147            host: host.into(),
148            port,
149            service: ServiceMethod::ServiceName(service_name.into()),
150            username: username.into(),
151            password: password.into(),
152            tls_mode: TlsMode::Disable,
153            tls_config: None,
154            connect_timeout: Duration::from_secs(10),
155            sdu: DEFAULT_SDU,
156            charset_id: charset::UTF8,
157            ncharset_id: charset::UTF16,
158            stmtcachesize: DEFAULT_STMTCACHESIZE,
159        }
160    }
161
162    /// Create a new configuration with SID
163    pub fn with_sid(
164        host: impl Into<String>,
165        port: u16,
166        sid: impl Into<String>,
167        username: impl Into<String>,
168        password: impl Into<String>,
169    ) -> Self {
170        Self {
171            host: host.into(),
172            port,
173            service: ServiceMethod::Sid(sid.into()),
174            username: username.into(),
175            password: password.into(),
176            tls_mode: TlsMode::Disable,
177            tls_config: None,
178            connect_timeout: Duration::from_secs(10),
179            sdu: DEFAULT_SDU,
180            charset_id: charset::UTF8,
181            ncharset_id: charset::UTF16,
182            stmtcachesize: DEFAULT_STMTCACHESIZE,
183        }
184    }
185
186    /// Set TLS mode
187    pub fn tls(mut self, mode: TlsMode) -> Self {
188        self.tls_mode = mode;
189        self
190    }
191
192    /// Set TLS configuration
193    pub fn tls_config(mut self, config: TlsConfig) -> Self {
194        self.tls_config = Some(config);
195        self.tls_mode = TlsMode::Require;
196        self
197    }
198
199    /// Check if TLS is enabled
200    pub fn is_tls_enabled(&self) -> bool {
201        self.tls_mode == TlsMode::Require
202    }
203
204    /// Enable TLS with system root certificates.
205    ///
206    /// This configures the connection to use TLS (TCPS protocol) with the
207    /// system's trusted root certificate store for server verification.
208    ///
209    /// # Example
210    ///
211    /// ```rust
212    /// use oracle_rs::Config;
213    ///
214    /// let config = Config::new("hostname", 2484, "service", "user", "password")
215    ///     .with_tls()
216    ///     .expect("TLS configuration failed");
217    /// ```
218    pub fn with_tls(mut self) -> Result<Self> {
219        let tls_config = TlsConfig::new();
220        // Validate that TLS config can be built
221        tls_config.build_client_config()?;
222        self.tls_config = Some(tls_config);
223        self.tls_mode = TlsMode::Require;
224        Ok(self)
225    }
226
227    /// Enable TLS with an Oracle wallet.
228    ///
229    /// Oracle wallets (ewallet.p12 or ewallet.pem files) contain certificates
230    /// and keys for secure connections. This is the standard way to configure
231    /// TLS for Oracle Cloud and enterprise deployments.
232    ///
233    /// # Arguments
234    ///
235    /// * `wallet_path` - Path to the wallet directory containing ewallet.pem
236    /// * `wallet_password` - Optional password for encrypted wallets
237    ///
238    /// # Example
239    ///
240    /// ```rust,ignore
241    /// use oracle_rs::Config;
242    ///
243    /// let config = Config::new("hostname", 2484, "service", "user", "password")
244    ///     .with_wallet("/path/to/wallet", Some("wallet_password"))
245    ///     .expect("Wallet configuration failed");
246    /// ```
247    pub fn with_wallet(
248        mut self,
249        wallet_path: impl Into<String>,
250        wallet_password: Option<&str>,
251    ) -> Result<Self> {
252        let tls_config = TlsConfig::new()
253            .with_wallet(wallet_path, wallet_password.map(|s| s.to_string()));
254        // Validate that TLS config can be built
255        tls_config.build_client_config()?;
256        self.tls_config = Some(tls_config);
257        self.tls_mode = TlsMode::Require;
258        Ok(self)
259    }
260
261    /// Configure DRCP (Database Resident Connection Pooling).
262    ///
263    /// DRCP allows the database server to maintain a pool of connections that
264    /// can be shared across multiple client processes, reducing server resource
265    /// usage.
266    ///
267    /// # Arguments
268    ///
269    /// * `connection_class` - Name identifying this class of connections
270    /// * `purity` - Either "self" (dedicated) or "new" (can share with others)
271    ///
272    /// # Example
273    ///
274    /// ```rust
275    /// use oracle_rs::Config;
276    ///
277    /// let config = Config::new("hostname", 1521, "service", "user", "password")
278    ///     .with_drcp("my_app", "self");
279    /// ```
280    pub fn with_drcp(self, _connection_class: &str, _purity: &str) -> Self {
281        // TODO: Implement DRCP configuration storage
282        // For now, DRCP is handled at connection time via the connect descriptor
283        self
284    }
285
286    /// Set the statement cache size.
287    ///
288    /// Statement caching improves performance by reusing parsed SQL statements.
289    /// Set to 0 to disable caching.
290    ///
291    /// Default is 20 (matches python-oracledb default).
292    ///
293    /// # Example
294    ///
295    /// ```rust
296    /// use oracle_rs::Config;
297    ///
298    /// let config = Config::new("localhost", 1521, "FREEPDB1", "user", "password")
299    ///     .with_statement_cache_size(100);
300    /// ```
301    pub fn with_statement_cache_size(mut self, size: usize) -> Self {
302        self.stmtcachesize = size;
303        self
304    }
305
306    /// Set connection timeout
307    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
308        self.connect_timeout = timeout;
309        self
310    }
311
312    /// Set SDU size
313    pub fn sdu(mut self, sdu: u32) -> Self {
314        self.sdu = sdu;
315        self
316    }
317
318    /// Set statement cache size
319    ///
320    /// Set to 0 to disable statement caching.
321    /// Default is 20 (matches python-oracledb default).
322    pub fn stmtcachesize(mut self, size: usize) -> Self {
323        self.stmtcachesize = size;
324        self
325    }
326
327    /// Get the password (for authentication)
328    pub(crate) fn password(&self) -> &str {
329        &self.password
330    }
331
332    /// Set the password
333    pub fn set_password(&mut self, password: impl Into<String>) {
334        self.password = password.into();
335    }
336
337    /// Set the username
338    pub fn set_username(&mut self, username: impl Into<String>) {
339        self.username = username.into();
340    }
341
342    /// Build a TNS connect descriptor string
343    pub fn build_connect_string(&self) -> String {
344        let mut parts = Vec::new();
345
346        // Address
347        let protocol = match self.tls_mode {
348            TlsMode::Disable => "TCP",
349            TlsMode::Require => "TCPS",
350        };
351        parts.push(format!(
352            "(ADDRESS=(PROTOCOL={})(HOST={})(PORT={}))",
353            protocol, self.host, self.port
354        ));
355
356        // Connect data
357        let service_part = match &self.service {
358            ServiceMethod::ServiceName(name) => format!("(SERVICE_NAME={})", name),
359            ServiceMethod::Sid(sid) => format!("(SID={})", sid),
360        };
361        parts.push(format!("(CONNECT_DATA={})", service_part));
362
363        format!("(DESCRIPTION={})", parts.join(""))
364    }
365
366    /// Get the socket address string
367    pub fn socket_addr(&self) -> String {
368        format!("{}:{}", self.host, self.port)
369    }
370}
371
372impl Default for Config {
373    fn default() -> Self {
374        Self {
375            host: "localhost".to_string(),
376            port: DEFAULT_PORT,
377            service: ServiceMethod::ServiceName("FREEPDB1".to_string()),
378            username: String::new(),
379            password: String::new(),
380            tls_mode: TlsMode::Disable,
381            tls_config: None,
382            connect_timeout: Duration::from_secs(10),
383            sdu: DEFAULT_SDU,
384            charset_id: charset::UTF8,
385            ncharset_id: charset::UTF16,
386            stmtcachesize: DEFAULT_STMTCACHESIZE,
387        }
388    }
389}
390
391/// Parse an EZConnect-style connection string
392///
393/// Formats supported:
394/// - `host:port/service_name`
395/// - `host/service_name`
396/// - `host:port:sid`
397/// - `//host:port/service_name` (with optional leading slashes)
398impl FromStr for Config {
399    type Err = Error;
400
401    fn from_str(s: &str) -> Result<Self> {
402        let s = s.trim();
403
404        // Strip leading slashes if present
405        let s = s.trim_start_matches('/');
406
407        if s.is_empty() {
408            return Err(Error::InvalidConnectionString(
409                "empty connection string".to_string(),
410            ));
411        }
412
413        // Check for TNS descriptor format
414        if s.starts_with('(') {
415            return Err(Error::InvalidConnectionString(
416                "TNS descriptor format not yet supported, use EZConnect format".to_string(),
417            ));
418        }
419
420        // Parse EZConnect format
421        // Format: [//]host[:port][/service_name] or host:port:sid
422
423        let mut config = Config::default();
424
425        // Check for service_name format (contains /)
426        if let Some(slash_pos) = s.find('/') {
427            let host_port = &s[..slash_pos];
428            let service_name = &s[slash_pos + 1..];
429
430            if service_name.is_empty() {
431                return Err(Error::InvalidConnectionString(
432                    "missing service name after /".to_string(),
433                ));
434            }
435
436            config.service = ServiceMethod::ServiceName(service_name.to_string());
437
438            // Parse host:port
439            if let Some(colon_pos) = host_port.find(':') {
440                config.host = host_port[..colon_pos].to_string();
441                config.port = host_port[colon_pos + 1..]
442                    .parse()
443                    .map_err(|_| Error::InvalidConnectionString("invalid port number".to_string()))?;
444            } else {
445                config.host = host_port.to_string();
446                config.port = DEFAULT_PORT;
447            }
448        } else {
449            // Check for SID format (host:port:sid)
450            let parts: Vec<&str> = s.split(':').collect();
451
452            match parts.len() {
453                1 => {
454                    // Just host, use defaults
455                    config.host = parts[0].to_string();
456                }
457                2 => {
458                    // host:port
459                    config.host = parts[0].to_string();
460                    config.port = parts[1]
461                        .parse()
462                        .map_err(|_| Error::InvalidConnectionString("invalid port number".to_string()))?;
463                }
464                3 => {
465                    // host:port:sid
466                    config.host = parts[0].to_string();
467                    config.port = parts[1]
468                        .parse()
469                        .map_err(|_| Error::InvalidConnectionString("invalid port number".to_string()))?;
470                    config.service = ServiceMethod::Sid(parts[2].to_string());
471                }
472                _ => {
473                    return Err(Error::InvalidConnectionString(
474                        "too many colons in connection string".to_string(),
475                    ));
476                }
477            }
478        }
479
480        if config.host.is_empty() {
481            return Err(Error::InvalidConnectionString(
482                "missing host".to_string(),
483            ));
484        }
485
486        Ok(config)
487    }
488}
489
490impl fmt::Display for Config {
491    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
492        match &self.service {
493            ServiceMethod::ServiceName(name) => {
494                write!(f, "{}:{}/{}", self.host, self.port, name)
495            }
496            ServiceMethod::Sid(sid) => {
497                write!(f, "{}:{}:{}", self.host, self.port, sid)
498            }
499        }
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that parsing `input` fails with `InvalidConnectionString`.
    fn assert_invalid(input: &str) {
        let result: Result<Config> = input.parse();
        assert!(
            matches!(result, Err(Error::InvalidConnectionString(_))),
            "expected InvalidConnectionString for {:?}",
            input
        );
    }

    #[test]
    fn test_parse_ezconnect_full() {
        let config: Config = "myhost:1522/myservice".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
        assert_eq!(
            config.service,
            ServiceMethod::ServiceName("myservice".to_string())
        );
    }

    #[test]
    fn test_parse_ezconnect_default_port() {
        let config: Config = "myhost/myservice".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, DEFAULT_PORT);
        assert_eq!(
            config.service,
            ServiceMethod::ServiceName("myservice".to_string())
        );
    }

    #[test]
    fn test_parse_ezconnect_with_slashes() {
        let config: Config = "//myhost:1522/myservice".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
    }

    #[test]
    fn test_parse_trims_whitespace_and_slashes() {
        let config: Config = "  //myhost:1522/myservice  ".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
        assert_eq!(config.service.service_name(), Some("myservice"));
    }

    #[test]
    fn test_parse_ezconnect_sid_format() {
        let config: Config = "myhost:1522:ORCL".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
        assert_eq!(config.service, ServiceMethod::Sid("ORCL".to_string()));
    }

    #[test]
    fn test_parse_host_only() {
        let config: Config = "myhost".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, DEFAULT_PORT);
    }

    #[test]
    fn test_parse_host_port() {
        let config: Config = "myhost:1522".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
    }

    #[test]
    fn test_parse_empty() {
        let result: Result<Config> = "".parse();
        assert!(result.is_err());
        assert_invalid("   ");
        assert_invalid("//");
    }

    #[test]
    fn test_parse_invalid_port() {
        let result: Result<Config> = "myhost:notaport/service".parse();
        assert!(result.is_err());
        assert_invalid("myhost:99999");
    }

    #[test]
    fn test_parse_rejects_tns_descriptor() {
        assert_invalid("(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=h)(PORT=1521)))");
    }

    #[test]
    fn test_parse_rejects_too_many_colons() {
        assert_invalid("myhost:1522:ORCL:extra");
    }

    #[test]
    fn test_parse_rejects_missing_service_name() {
        assert_invalid("myhost:1522/");
    }

    #[test]
    fn test_parse_rejects_missing_host() {
        assert_invalid(":1522/myservice");
        assert_invalid(":1522");
    }

    #[test]
    fn test_service_method_accessors() {
        let by_name = ServiceMethod::ServiceName("svc".to_string());
        assert_eq!(by_name.service_name(), Some("svc"));
        assert_eq!(by_name.sid(), None);

        let by_sid = ServiceMethod::Sid("ORCL".to_string());
        assert_eq!(by_sid.service_name(), None);
        assert_eq!(by_sid.sid(), Some("ORCL"));
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, DEFAULT_PORT);
        assert_eq!(
            config.service,
            ServiceMethod::ServiceName("FREEPDB1".to_string())
        );
        assert!(config.username.is_empty());
        assert_eq!(config.tls_mode, TlsMode::Disable);
        assert!(config.tls_config.is_none());
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.sdu, DEFAULT_SDU);
        assert_eq!(config.charset_id, charset::UTF8);
        assert_eq!(config.ncharset_id, charset::UTF16);
        assert_eq!(config.stmtcachesize, DEFAULT_STMTCACHESIZE);
    }

    #[test]
    fn test_statement_cache_size_builders() {
        let base = Config::new("host", 1521, "svc", "user", "pass");
        assert_eq!(base.stmtcachesize, DEFAULT_STMTCACHESIZE);
        assert_eq!(base.clone().with_statement_cache_size(100).stmtcachesize, 100);
        assert_eq!(base.stmtcachesize(0).stmtcachesize, 0);
    }

    #[test]
    fn test_build_connect_string() {
        let config = Config::new("myhost", 1522, "myservice", "user", "pass");
        let connect_str = config.build_connect_string();
        assert!(connect_str.contains("(HOST=myhost)"));
        assert!(connect_str.contains("(PORT=1522)"));
        assert!(connect_str.contains("(SERVICE_NAME=myservice)"));
        assert!(connect_str.contains("(PROTOCOL=TCP)"));
    }

    #[test]
    fn test_build_connect_string_sid() {
        let config = Config::with_sid("myhost", 1522, "ORCL", "user", "pass");
        let connect_str = config.build_connect_string();
        assert!(connect_str.contains("(SID=ORCL)"));
    }

    #[test]
    fn test_build_connect_string_tls_uses_tcps() {
        let config = Config::new("myhost", 2484, "myservice", "user", "pass")
            .tls(TlsMode::Require);
        assert!(config.is_tls_enabled());
        assert!(config.build_connect_string().contains("(PROTOCOL=TCPS)"));
    }

    #[test]
    fn test_socket_addr() {
        let config = Config::new("myhost", 1522, "myservice", "user", "pass");
        assert_eq!(config.socket_addr(), "myhost:1522");
    }

    #[test]
    fn test_config_display() {
        let config = Config::new("myhost", 1522, "myservice", "user", "pass");
        assert_eq!(config.to_string(), "myhost:1522/myservice");

        let config_sid = Config::with_sid("myhost", 1522, "ORCL", "user", "pass");
        assert_eq!(config_sid.to_string(), "myhost:1522:ORCL");
    }

    #[test]
    fn test_display_round_trips_through_parse() {
        let originals = [
            Config::new("myhost", 1522, "myservice", "user", "pass"),
            Config::with_sid("myhost", 1522, "ORCL", "user", "pass"),
        ];
        for original in originals {
            let parsed: Config = original.to_string().parse().unwrap();
            assert_eq!(parsed.host, original.host);
            assert_eq!(parsed.port, original.port);
            assert_eq!(parsed.service, original.service);
        }
    }

    #[test]
    fn test_config_builder_pattern() {
        let config = Config::new("host", 1521, "svc", "user", "pass")
            .tls(TlsMode::Require)
            .connect_timeout(Duration::from_secs(30))
            .sdu(16384);

        assert_eq!(config.tls_mode, TlsMode::Require);
        assert_eq!(config.connect_timeout, Duration::from_secs(30));
        assert_eq!(config.sdu, 16384);
    }
}