use std::fmt;
use std::str::FromStr;
use std::time::Duration;
use crate::constants::charset;
use crate::error::{Error, Result};
use crate::transport::TlsConfig;
/// Default TCP port; used by `Config::default` and when an EZConnect string
/// (`host/service`, `host`) does not specify one.
pub const DEFAULT_PORT: u16 = 1521;
/// Default SDU (session data unit) size for new configurations — presumably bytes.
pub const DEFAULT_SDU: u32 = 8192;
/// Default statement-cache capacity for new configurations.
pub const DEFAULT_STMTCACHESIZE: usize = 20;
/// How the target database is identified in the connect descriptor.
///
/// Rendered as `(SERVICE_NAME=...)` or `(SID=...)` by `Config::build_connect_string`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServiceMethod {
    /// Identify the database by service name.
    ServiceName(String),
    /// Identify the database by system identifier (SID).
    Sid(String),
}

impl ServiceMethod {
    /// Returns the service name, or `None` when this is a [`ServiceMethod::Sid`].
    pub fn service_name(&self) -> Option<&str> {
        match self {
            Self::Sid(_) => None,
            Self::ServiceName(name) => Some(name.as_str()),
        }
    }

    /// Returns the SID, or `None` when this is a [`ServiceMethod::ServiceName`].
    pub fn sid(&self) -> Option<&str> {
        match self {
            Self::ServiceName(_) => None,
            Self::Sid(identifier) => Some(identifier.as_str()),
        }
    }
}
/// Whether the connection runs over plain TCP or TLS.
///
/// `Config::build_connect_string` maps this to `PROTOCOL=TCP` / `PROTOCOL=TCPS`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TlsMode {
/// Plain TCP; this is the `Default` variant.
#[default]
Disable,
/// TLS is required (`PROTOCOL=TCPS`).
Require,
}
#[derive(Debug, Clone)]
pub struct Config {
pub host: String,
pub port: u16,
pub service: ServiceMethod,
pub username: String,
password: String,
pub tls_mode: TlsMode,
pub tls_config: Option<TlsConfig>,
pub connect_timeout: Duration,
pub sdu: u32,
pub charset_id: u16,
pub ncharset_id: u16,
pub stmtcachesize: usize,
}
impl Config {
    /// Creates a configuration that targets `service_name` on `host:port`.
    ///
    /// Every setting not passed here is taken from [`Config::default`].
    pub fn new(
        host: impl Into<String>,
        port: u16,
        service_name: impl Into<String>,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            service: ServiceMethod::ServiceName(service_name.into()),
            username: username.into(),
            password: password.into(),
            ..Self::default()
        }
    }

    /// Creates a configuration that targets the database identified by `sid` on
    /// `host:port`.
    ///
    /// Every setting not passed here is taken from [`Config::default`].
    pub fn with_sid(
        host: impl Into<String>,
        port: u16,
        sid: impl Into<String>,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            service: ServiceMethod::Sid(sid.into()),
            username: username.into(),
            password: password.into(),
            ..Self::default()
        }
    }

    /// Sets the TLS mode.
    ///
    /// A previously attached [`TlsConfig`] is kept, even when `mode` is
    /// [`TlsMode::Disable`].
    #[must_use]
    pub fn tls(mut self, mode: TlsMode) -> Self {
        self.tls_mode = mode;
        self
    }

    /// Attaches `config` and switches the mode to [`TlsMode::Require`].
    ///
    /// Unlike [`Config::with_tls`], the client configuration is not built here.
    #[must_use]
    pub fn tls_config(mut self, config: TlsConfig) -> Self {
        self.tls_config = Some(config);
        self.tls_mode = TlsMode::Require;
        self
    }

    /// Returns `true` when the TLS mode is [`TlsMode::Require`].
    pub fn is_tls_enabled(&self) -> bool {
        self.tls_mode == TlsMode::Require
    }

    /// Enables TLS with a default [`TlsConfig`].
    ///
    /// # Errors
    ///
    /// Returns any error produced while building the client TLS configuration.
    pub fn with_tls(self) -> Result<Self> {
        self.attach_validated_tls(TlsConfig::new())
    }

    /// Enables TLS using the wallet at `wallet_path`, optionally unlocked with
    /// `wallet_password`.
    ///
    /// # Errors
    ///
    /// Returns any error produced while building the client TLS configuration.
    pub fn with_wallet(
        self,
        wallet_path: impl Into<String>,
        wallet_password: Option<&str>,
    ) -> Result<Self> {
        let tls_config =
            TlsConfig::new().with_wallet(wallet_path, wallet_password.map(str::to_string));
        self.attach_validated_tls(tls_config)
    }

    /// Builds the client configuration from `tls_config` (discarding the result, so
    /// only its error matters), then attaches `tls_config` and requires TLS.
    fn attach_validated_tls(mut self, tls_config: TlsConfig) -> Result<Self> {
        tls_config.build_client_config()?;
        self.tls_config = Some(tls_config);
        self.tls_mode = TlsMode::Require;
        Ok(self)
    }

    /// Placeholder for DRCP (Database Resident Connection Pooling) options.
    ///
    /// NOTE: not implemented — both arguments are ignored and the configuration is
    /// returned unchanged.
    #[must_use]
    pub fn with_drcp(self, _connection_class: &str, _purity: &str) -> Self {
        self
    }

    /// Alias for [`Config::stmtcachesize`].
    #[must_use]
    pub fn with_statement_cache_size(self, size: usize) -> Self {
        self.stmtcachesize(size)
    }

    /// Sets the connect timeout.
    #[must_use]
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Sets the requested SDU size.
    #[must_use]
    pub fn sdu(mut self, sdu: u32) -> Self {
        self.sdu = sdu;
        self
    }

    /// Sets the statement cache capacity.
    #[must_use]
    pub fn stmtcachesize(mut self, size: usize) -> Self {
        self.stmtcachesize = size;
        self
    }

    /// Returns the password; crate-private because the field itself is private.
    pub(crate) fn password(&self) -> &str {
        &self.password
    }

    /// Replaces the password.
    pub fn set_password(&mut self, password: impl Into<String>) {
        self.password = password.into();
    }

    /// Replaces the user name.
    pub fn set_username(&mut self, username: impl Into<String>) {
        self.username = username.into();
    }

    /// Builds a TNS connect descriptor for this configuration, e.g.
    /// `(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=h)(PORT=1521))(CONNECT_DATA=(SERVICE_NAME=s)))`.
    ///
    /// `PROTOCOL` is `TCPS` when TLS is required and `TCP` otherwise. Host and
    /// service values are inserted verbatim; `(`, `)` and `=` are not escaped.
    pub fn build_connect_string(&self) -> String {
        let protocol = match self.tls_mode {
            TlsMode::Disable => "TCP",
            TlsMode::Require => "TCPS",
        };
        let service = match &self.service {
            ServiceMethod::ServiceName(name) => format!("(SERVICE_NAME={})", name),
            ServiceMethod::Sid(sid) => format!("(SID={})", sid),
        };
        format!(
            "(DESCRIPTION=(ADDRESS=(PROTOCOL={})(HOST={})(PORT={}))(CONNECT_DATA={}))",
            protocol, self.host, self.port, service
        )
    }

    /// Returns `"host:port"`, presumably for socket resolution.
    ///
    /// NOTE(review): IPv6 literals are not bracketed (`::1` yields `"::1:1521"`);
    /// confirm the resolver used by callers accepts that form.
    pub fn socket_addr(&self) -> String {
        format!("{}:{}", self.host, self.port)
    }
}
impl Default for Config {
fn default() -> Self {
Self {
host: "localhost".to_string(),
port: DEFAULT_PORT,
service: ServiceMethod::ServiceName("FREEPDB1".to_string()),
username: String::new(),
password: String::new(),
tls_mode: TlsMode::Disable,
tls_config: None,
connect_timeout: Duration::from_secs(10),
sdu: DEFAULT_SDU,
charset_id: charset::UTF8,
ncharset_id: charset::UTF16,
stmtcachesize: DEFAULT_STMTCACHESIZE,
}
}
}
impl FromStr for Config {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
let s = s.trim();
let s = s.trim_start_matches('/');
if s.is_empty() {
return Err(Error::InvalidConnectionString(
"empty connection string".to_string(),
));
}
if s.starts_with('(') {
return Err(Error::InvalidConnectionString(
"TNS descriptor format not yet supported, use EZConnect format".to_string(),
));
}
let mut config = Config::default();
if let Some(slash_pos) = s.find('/') {
let host_port = &s[..slash_pos];
let service_name = &s[slash_pos + 1..];
if service_name.is_empty() {
return Err(Error::InvalidConnectionString(
"missing service name after /".to_string(),
));
}
config.service = ServiceMethod::ServiceName(service_name.to_string());
if let Some(colon_pos) = host_port.find(':') {
config.host = host_port[..colon_pos].to_string();
config.port = host_port[colon_pos + 1..]
.parse()
.map_err(|_| Error::InvalidConnectionString("invalid port number".to_string()))?;
} else {
config.host = host_port.to_string();
config.port = DEFAULT_PORT;
}
} else {
let parts: Vec<&str> = s.split(':').collect();
match parts.len() {
1 => {
config.host = parts[0].to_string();
}
2 => {
config.host = parts[0].to_string();
config.port = parts[1]
.parse()
.map_err(|_| Error::InvalidConnectionString("invalid port number".to_string()))?;
}
3 => {
config.host = parts[0].to_string();
config.port = parts[1]
.parse()
.map_err(|_| Error::InvalidConnectionString("invalid port number".to_string()))?;
config.service = ServiceMethod::Sid(parts[2].to_string());
}
_ => {
return Err(Error::InvalidConnectionString(
"too many colons in connection string".to_string(),
));
}
}
}
if config.host.is_empty() {
return Err(Error::InvalidConnectionString(
"missing host".to_string(),
));
}
Ok(config)
}
}
impl fmt::Display for Config {
    /// Writes the configuration in EZConnect form: `host:port/service_name` for a
    /// service name, `host:port:sid` for a SID. Credentials are never included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}:{}", self.host, self.port)?;
        match &self.service {
            ServiceMethod::ServiceName(name) => write!(f, "/{}", name),
            ServiceMethod::Sid(sid) => write!(f, ":{}", sid),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_ezconnect_full() {
        let config: Config = "myhost:1522/myservice".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
        assert_eq!(
            config.service,
            ServiceMethod::ServiceName("myservice".to_string())
        );
    }

    #[test]
    fn test_parse_ezconnect_default_port() {
        let config: Config = "myhost/myservice".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, DEFAULT_PORT);
        assert_eq!(
            config.service,
            ServiceMethod::ServiceName("myservice".to_string())
        );
    }

    #[test]
    fn test_parse_ezconnect_with_slashes() {
        let config: Config = "//myhost:1522/myservice".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
    }

    #[test]
    fn test_parse_trims_whitespace() {
        let config: Config = "  myhost:1522/myservice \n".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
    }

    #[test]
    fn test_parse_ezconnect_sid_format() {
        let config: Config = "myhost:1522:ORCL".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
        assert_eq!(config.service, ServiceMethod::Sid("ORCL".to_string()));
    }

    #[test]
    fn test_parse_host_only() {
        let config: Config = "myhost".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, DEFAULT_PORT);
    }

    #[test]
    fn test_parse_host_port() {
        let config: Config = "myhost:1522".parse().unwrap();
        assert_eq!(config.host, "myhost");
        assert_eq!(config.port, 1522);
    }

    #[test]
    fn test_parse_empty() {
        let result: Result<Config> = "".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_invalid_port() {
        let result: Result<Config> = "myhost:notaport/service".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_missing_service_name() {
        let result: Result<Config> = "myhost:1522/".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_too_many_colons() {
        let result: Result<Config> = "myhost:1522:ORCL:extra".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_missing_host() {
        let result: Result<Config> = ":1522/myservice".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_tns_descriptor_rejected() {
        let result: Result<Config> = "(DESCRIPTION=(ADDRESS=(HOST=myhost)))".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, DEFAULT_PORT);
        assert_eq!(
            config.service,
            ServiceMethod::ServiceName("FREEPDB1".to_string())
        );
        assert!(!config.is_tls_enabled());
        assert_eq!(config.sdu, DEFAULT_SDU);
        assert_eq!(config.stmtcachesize, DEFAULT_STMTCACHESIZE);
    }

    #[test]
    fn test_service_method_accessors() {
        let by_name = ServiceMethod::ServiceName("svc".to_string());
        assert_eq!(by_name.service_name(), Some("svc"));
        assert_eq!(by_name.sid(), None);
        let by_sid = ServiceMethod::Sid("ORCL".to_string());
        assert_eq!(by_sid.sid(), Some("ORCL"));
        assert_eq!(by_sid.service_name(), None);
    }

    #[test]
    fn test_build_connect_string() {
        let config = Config::new("myhost", 1522, "myservice", "user", "pass");
        let connect_str = config.build_connect_string();
        assert!(connect_str.contains("(HOST=myhost)"));
        assert!(connect_str.contains("(PORT=1522)"));
        assert!(connect_str.contains("(SERVICE_NAME=myservice)"));
        assert!(connect_str.contains("(PROTOCOL=TCP)"));
        assert_eq!(
            connect_str,
            "(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=myhost)(PORT=1522))\
             (CONNECT_DATA=(SERVICE_NAME=myservice)))"
        );
    }

    #[test]
    fn test_build_connect_string_sid() {
        let config = Config::with_sid("myhost", 1522, "ORCL", "user", "pass");
        let connect_str = config.build_connect_string();
        assert!(connect_str.contains("(SID=ORCL)"));
    }

    #[test]
    fn test_build_connect_string_tls() {
        let config =
            Config::new("myhost", 2484, "myservice", "user", "pass").tls(TlsMode::Require);
        assert!(config.is_tls_enabled());
        assert!(config.build_connect_string().contains("(PROTOCOL=TCPS)"));
    }

    #[test]
    fn test_socket_addr() {
        let config = Config::new("myhost", 1522, "myservice", "user", "pass");
        assert_eq!(config.socket_addr(), "myhost:1522");
    }

    #[test]
    fn test_password_accessors() {
        let mut config = Config::new("myhost", 1522, "myservice", "user", "pass");
        assert_eq!(config.password(), "pass");
        config.set_password("secret");
        config.set_username("other");
        assert_eq!(config.password(), "secret");
        assert_eq!(config.username, "other");
    }

    #[test]
    fn test_config_display() {
        let config = Config::new("myhost", 1522, "myservice", "user", "pass");
        assert_eq!(config.to_string(), "myhost:1522/myservice");
        let config_sid = Config::with_sid("myhost", 1522, "ORCL", "user", "pass");
        assert_eq!(config_sid.to_string(), "myhost:1522:ORCL");
    }

    #[test]
    fn test_display_round_trip() {
        for text in ["myhost:1522/myservice", "myhost:1522:ORCL"] {
            let config: Config = text.parse().unwrap();
            assert_eq!(config.to_string(), text);
        }
    }

    #[test]
    fn test_config_builder_pattern() {
        let config = Config::new("host", 1521, "svc", "user", "pass")
            .tls(TlsMode::Require)
            .connect_timeout(Duration::from_secs(30))
            .sdu(16384);
        assert_eq!(config.tls_mode, TlsMode::Require);
        assert_eq!(config.connect_timeout, Duration::from_secs(30));
        assert_eq!(config.sdu, 16384);
    }

    #[test]
    fn test_statement_cache_size_builders() {
        let config = Config::default().with_statement_cache_size(5);
        assert_eq!(config.stmtcachesize, 5);
        let config = config.stmtcachesize(7);
        assert_eq!(config.stmtcachesize, 7);
    }
}