use bytes::Bytes;
use std::collections::HashMap;
use crate::buffer::{ReadBuffer, WriteBuffer};
use crate::capabilities::Capabilities;
use crate::constants::{auth_mode, verifier_type, FunctionCode, MessageType, PacketType, PACKET_HEADER_SIZE};
use crate::crypto::{
decrypt_cbc_192, decrypt_cbc_256, encrypt_cbc_192, encrypt_cbc_256_pkcs7,
generate_11g_combo_key, generate_11g_password_hash, generate_12c_combo_key,
generate_12c_password_hash, generate_salt, generate_session_key_part, pbkdf2_derive,
};
use crate::error::{Error, Result};
use crate::packet::PacketHeader;
/// Server-supplied authentication values collected from a response's key/value pairs.
///
/// Only the keys this module consumes are retained. Numeric entries are `None`
/// when the server's text does not parse as a `u32`.
#[derive(Debug, Default)]
pub struct SessionData {
    /// `AUTH_SESSKEY`: hex-encoded server session key.
    pub auth_sesskey: Option<String>,
    /// `AUTH_VFR_DATA`: hex-encoded verifier data.
    pub auth_vfr_data: Option<String>,
    /// `AUTH_PBKDF2_CSK_SALT`: hex-encoded salt passed to the 12c combo-key derivation.
    pub auth_pbkdf2_csk_salt: Option<String>,
    /// `AUTH_PBKDF2_VGEN_COUNT`: iteration count passed to the 12c password hash.
    pub auth_pbkdf2_vgen_count: Option<u32>,
    /// `AUTH_PBKDF2_SDER_COUNT`: iteration count passed to the 12c combo-key derivation.
    pub auth_pbkdf2_sder_count: Option<u32>,
    /// `AUTH_VERSION_NO`, parsed as a `u32`.
    pub auth_version_no: Option<u32>,
    /// `AUTH_GLOBALLY_UNIQUE_DBID`, kept verbatim.
    pub auth_globally_unique_dbid: Option<String>,
    /// `AUTH_SVR_RESPONSE`: hex-encoded encrypted server proof.
    pub auth_svr_response: Option<String>,
}
impl SessionData {
    /// Builds a [`SessionData`] from the `AUTH_*` entries found in `pairs`.
    ///
    /// Unrecognised keys are ignored. Numeric entries whose text does not parse
    /// as a `u32` are stored as `None`.
    pub fn from_pairs(pairs: &HashMap<String, String>) -> Self {
        let mut session = Self::default();
        for (name, raw) in pairs.iter() {
            let text = || Some(raw.clone());
            let number = || raw.parse::<u32>().ok();
            match name.as_str() {
                "AUTH_SESSKEY" => session.auth_sesskey = text(),
                "AUTH_VFR_DATA" => session.auth_vfr_data = text(),
                "AUTH_PBKDF2_CSK_SALT" => session.auth_pbkdf2_csk_salt = text(),
                "AUTH_PBKDF2_VGEN_COUNT" => session.auth_pbkdf2_vgen_count = number(),
                "AUTH_PBKDF2_SDER_COUNT" => session.auth_pbkdf2_sder_count = number(),
                "AUTH_VERSION_NO" => session.auth_version_no = number(),
                "AUTH_GLOBALLY_UNIQUE_DBID" => session.auth_globally_unique_dbid = text(),
                "AUTH_SVR_RESPONSE" => session.auth_svr_response = text(),
                _ => {}
            }
        }
        session
    }
}
/// State for one two-phase (`AuthPhaseOne` / `AuthPhaseTwo`) logon exchange.
///
/// `Debug` is implemented by hand rather than derived so that the plaintext
/// password and the derived key material never appear in `{:?}` output
/// (logs, panic messages, tracing spans).
pub struct AuthMessage {
    /// User name as sent on the wire (upper-cased by `new`).
    username: String,
    /// Plaintext password bytes; scrubbed by `clear_password` and on drop.
    password: Vec<u8>,
    /// Which request `build_request` will produce next.
    phase: AuthPhase,
    /// Logon mode bits (`auth_mode::*`).
    auth_mode: u32,
    /// Values returned by the server in the most recent response.
    session_data: SessionData,
    /// Verifier type reported alongside `AUTH_VFR_DATA`.
    verifier_type: u32,
    /// Key used to encrypt the password/speedy key and decrypt the server proof.
    combo_key: Option<Vec<u8>>,
    /// Encrypted client half of the session key, sent as `AUTH_SESSKEY` in phase two.
    client_session_key: Option<Vec<u8>>,
    /// Sent as `AUTH_TERMINAL`.
    terminal: String,
    /// Sent as `AUTH_PROGRAM_NM`.
    program: String,
    /// Sent as `AUTH_MACHINE`.
    machine: String,
    /// Sent as `AUTH_SID`.
    osuser: String,
    /// Sent as `AUTH_PID`.
    pid: String,
    /// Sent as `SESSION_CLIENT_DRIVER_NAME`.
    driver_name: String,
    /// Stored at construction; not read by the visible code.
    _service_name: String,
    /// Sequence number written into the phase-one function header.
    sequence_number: u8,
}
impl std::fmt::Debug for AuthMessage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only presence is revealed for secrets, never their contents.
        const REDACTED: &str = "<redacted>";
        f.debug_struct("AuthMessage")
            .field("username", &self.username)
            .field("password", &REDACTED)
            .field("phase", &self.phase)
            .field("auth_mode", &self.auth_mode)
            .field("session_data", &self.session_data)
            .field("verifier_type", &self.verifier_type)
            .field("combo_key", &self.combo_key.as_ref().map(|_| REDACTED))
            .field(
                "client_session_key",
                &self.client_session_key.as_ref().map(|_| REDACTED),
            )
            .field("terminal", &self.terminal)
            .field("program", &self.program)
            .field("machine", &self.machine)
            .field("osuser", &self.osuser)
            .field("pid", &self.pid)
            .field("driver_name", &self.driver_name)
            .field("_service_name", &self._service_name)
            .field("sequence_number", &self.sequence_number)
            .finish()
    }
}
/// Step of the logon exchange that the next request belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthPhase {
    /// Initial request identifying the client and user.
    One,
    /// Request carrying the session key and encrypted password.
    Two,
    /// Exchange finished; no further requests can be built.
    Complete,
}
impl AuthMessage {
/// Creates a phase-one authentication message.
///
/// The user name is upper-cased. Client identification (terminal, program,
/// machine, OS user, PID) is read from the local process environment, with
/// placeholder strings used when a value is unavailable.
pub fn new(username: &str, password: &[u8], service_name: &str) -> Self {
    let terminal = std::env::var("TERM").unwrap_or_else(|_| "unknown".to_string());
    let program = match std::env::current_exe() {
        Ok(exe) => exe
            .file_name()
            .unwrap_or_default()
            .to_string_lossy()
            .to_string(),
        Err(_) => "oracle-rs".to_string(),
    };
    let machine = match hostname::get() {
        Ok(host) => host.to_string_lossy().to_string(),
        Err(_) => "localhost".to_string(),
    };
    // Unix-style `USER` first, then the Windows-style `USERNAME`.
    let osuser = std::env::var("USER")
        .or_else(|_| std::env::var("USERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());

    Self {
        username: username.to_uppercase(),
        password: password.to_vec(),
        phase: AuthPhase::One,
        auth_mode: auth_mode::LOGON,
        session_data: SessionData::default(),
        verifier_type: 0,
        combo_key: None,
        client_session_key: None,
        terminal,
        program,
        machine,
        osuser,
        pid: std::process::id().to_string(),
        driver_name: format!("oracle-rs : {}", env!("CARGO_PKG_VERSION")),
        _service_name: service_name.to_string(),
        sequence_number: 1,
    }
}
/// Sets the sequence number written into the phase-one function header.
pub fn set_sequence_number(&mut self, seq: u8) {
    self.sequence_number = seq;
}
/// Adds the SYSDBA privilege bit to the logon mode.
pub fn with_sysdba(mut self) -> Self {
    self.auth_mode = self.auth_mode | auth_mode::SYSDBA;
    self
}
/// Adds the SYSOPER privilege bit to the logon mode.
pub fn with_sysoper(mut self) -> Self {
    self.auth_mode = self.auth_mode | auth_mode::SYSOPER;
    self
}
/// Returns the phase the next request belongs to.
pub fn phase(&self) -> AuthPhase {
    self.phase
}
/// Returns `true` once the exchange has reached [`AuthPhase::Complete`].
pub fn is_complete(&self) -> bool {
    matches!(self.phase, AuthPhase::Complete)
}
/// Returns the derived combo key, if verifier generation has run.
pub fn combo_key(&self) -> Option<&[u8]> {
    self.combo_key.as_ref().map(Vec::as_slice)
}
pub fn build_request(&self, caps: &Capabilities, large_sdu: bool) -> Result<Bytes> {
match self.phase {
AuthPhase::One => self.build_phase_one(caps, large_sdu),
AuthPhase::Two => self.build_phase_two(caps, large_sdu),
AuthPhase::Complete => Err(Error::Protocol("Authentication already complete".to_string())),
}
}
/// Serialises the phase-one logon request.
///
/// After a placeholder header, the payload contains the following, in order:
/// - data flags;
/// - the `AuthPhaseOne` function header with `self.sequence_number`;
/// - an 8-byte zero when `ttc_field_version >= 18`;
/// - the user-name preamble;
/// - five client-identification key/value pairs.
///
/// The real header is patched in once the total length is known.
fn build_phase_one(&self, caps: &Capabilities, large_sdu: bool) -> Result<Bytes> {
    let user = self.username.as_bytes();
    let user_present = !user.is_empty();

    let mut out = WriteBuffer::with_capacity(512);
    out.write_zeros(PACKET_HEADER_SIZE)?;
    out.write_u16_be(0)?; // data flags
    out.write_u8(MessageType::Function as u8)?;
    out.write_u8(FunctionCode::AuthPhaseOne as u8)?;
    out.write_u8(self.sequence_number)?;
    if caps.ttc_field_version >= 18 {
        out.write_ub8(0)?;
    }

    out.write_u8(u8::from(user_present))?;
    out.write_ub4(user.len() as u32)?;
    out.write_ub4(self.auth_mode)?;
    out.write_u8(1)?;
    out.write_ub4(5)?; // number of key/value pairs written below
    out.write_u8(1)?;
    out.write_u8(1)?;
    if user_present {
        out.write_bytes_with_length(Some(user))?;
    }

    let client_info: [(&str, &str); 5] = [
        ("AUTH_TERMINAL", self.terminal.as_str()),
        ("AUTH_PROGRAM_NM", self.program.as_str()),
        ("AUTH_MACHINE", self.machine.as_str()),
        ("AUTH_PID", self.pid.as_str()),
        ("AUTH_SID", self.osuser.as_str()),
    ];
    for &(key, value) in client_info.iter() {
        self.write_key_value(&mut out, key, value, 0)?;
    }

    let packet_len = out.len() as u32;
    let mut header_bytes = WriteBuffer::with_capacity(PACKET_HEADER_SIZE);
    PacketHeader::new(PacketType::Data, packet_len).write(&mut header_bytes, large_sdu)?;
    let mut packet = out.into_inner();
    packet[..PACKET_HEADER_SIZE].copy_from_slice(header_bytes.as_slice());
    Ok(packet.freeze())
}
/// Serialises the phase-two logon request carrying the session key and encrypted password.
///
/// The payload has the following pairs, in order:
/// - `AUTH_SESSKEY`, truncated to 64 hex characters for 12c and 96 otherwise;
/// - `AUTH_PBKDF2_SPEEDY_KEY`, 12c only;
/// - `AUTH_PASSWORD`;
/// - the three `SESSION_CLIENT_*` values;
/// - `AUTH_ALTER_SESSION`.
///
/// The header's sequence byte is the literal `2`. It does not use
/// `self.sequence_number`.
///
/// # Errors
/// `Error::Protocol` when the combo key or client session key has not been
/// generated yet, plus any buffer or crypto error.
fn build_phase_two(&self, caps: &Capabilities, large_sdu: bool) -> Result<Bytes> {
    let encoded_password = self.encode_password()?;
    let session_key = self
        .client_session_key
        .as_ref()
        .ok_or_else(|| Error::Protocol("Client session key not generated".to_string()))?;
    let is_12c = self.verifier_type == verifier_type::V12C;
    let user = self.username.as_bytes();
    let user_present = !user.is_empty();

    let mut out = WriteBuffer::with_capacity(1024);
    out.write_zeros(PACKET_HEADER_SIZE)?;
    out.write_u16_be(0)?; // data flags
    out.write_u8(MessageType::Function as u8)?;
    out.write_u8(FunctionCode::AuthPhaseTwo as u8)?;
    out.write_u8(2)?;
    if caps.ttc_field_version >= 18 {
        out.write_ub8(0)?;
    }

    out.write_u8(u8::from(user_present))?;
    out.write_ub4(user.len() as u32)?;
    out.write_ub4(self.auth_mode | auth_mode::WITH_PASSWORD)?;
    out.write_u8(1)?;
    // 12c adds the speedy key on top of the six common pairs.
    out.write_ub4(if is_12c { 7 } else { 6 })?;
    out.write_u8(1)?;
    out.write_u8(1)?;
    if user_present {
        out.write_bytes_with_length(Some(user))?;
    }

    let session_key_hex = hex::encode_upper(session_key);
    let wanted = if is_12c { 64 } else { 96 };
    let sesskey = &session_key_hex[..session_key_hex.len().min(wanted)];
    self.write_key_value(&mut out, "AUTH_SESSKEY", sesskey, 1)?;
    if is_12c {
        if let Some(speedy) = self.generate_speedy_key()? {
            self.write_key_value(&mut out, "AUTH_PBKDF2_SPEEDY_KEY", &speedy, 0)?;
        }
    }
    self.write_key_value(&mut out, "AUTH_PASSWORD", &encoded_password, 0)?;
    self.write_key_value(&mut out, "SESSION_CLIENT_CHARSET", "873", 0)?;
    self.write_key_value(&mut out, "SESSION_CLIENT_DRIVER_NAME", &self.driver_name, 0)?;
    self.write_key_value(&mut out, "SESSION_CLIENT_VERSION", "54530048", 0)?;
    let alter_session = self.get_alter_timezone_statement();
    self.write_key_value(&mut out, "AUTH_ALTER_SESSION", &alter_session, 1)?;

    let packet_len = out.len() as u32;
    let mut header_bytes = WriteBuffer::with_capacity(PACKET_HEADER_SIZE);
    PacketHeader::new(PacketType::Data, packet_len).write(&mut header_bytes, large_sdu)?;
    let mut packet = out.into_inner();
    packet[..PACKET_HEADER_SIZE].copy_from_slice(header_bytes.as_slice());
    Ok(packet.freeze())
}
/// Appends one key/value pair followed by its `flags` word.
///
/// Each of the key and value is written as a ub4 length followed by a
/// length-prefixed byte chunk. An empty value is encoded by its zero length
/// alone, with no chunk after it.
fn write_key_value(
    &self,
    buf: &mut WriteBuffer,
    key: &str,
    value: &str,
    flags: u32,
) -> Result<()> {
    let (k, v) = (key.as_bytes(), value.as_bytes());

    buf.write_ub4(k.len() as u32)?;
    buf.write_bytes_with_length(Some(k))?;

    buf.write_ub4(v.len() as u32)?;
    if v.len() > 0 {
        buf.write_bytes_with_length(Some(v))?;
    }

    buf.write_ub4(flags)?;
    Ok(())
}
/// Parses the server's reply to the current phase and advances the state machine.
///
/// The payload consists of:
/// - two data-flag bytes;
/// - a message-type byte;
/// - a ub2 pair count;
/// - for each pair, a key string, a value string and a trailing ub4. For
///   `AUTH_VFR_DATA` that ub4 holds the verifier type; otherwise it is skipped.
///
/// The parsed pairs replace `self.session_data`.
///
/// The phase advances only after the step for the current phase succeeds:
/// - in phase One, verifier generation;
/// - in phase Two, server-response verification.
///
/// If a step fails, the message stays in its current phase. In particular,
/// `is_complete()` never reports success for a rejected server proof.
///
/// # Errors
/// - `Error::AuthenticationFailed` when the server sent an error message, or
///   when verifier generation or server verification fails.
/// - `Error::UnsupportedVerifierType` for unknown verifier types.
/// - `Error::Protocol` for malformed hex fields.
/// - Any buffer read error for a truncated payload.
pub fn parse_response(&mut self, payload: &[u8]) -> Result<()> {
    let mut buf = ReadBuffer::from_slice(payload);
    buf.skip(2)?; // data flags
    let msg_type = buf.read_u8()?;
    if msg_type == MessageType::Error as u8 {
        return Err(Error::AuthenticationFailed("Server returned error".to_string()));
    }

    let num_params = buf.read_ub2()?;
    let mut pairs = HashMap::new();
    let mut vtype = 0u32;
    for _ in 0..num_params {
        let key = Self::read_auth_string(&mut buf)?;
        let value = Self::read_auth_string(&mut buf)?;
        if key == "AUTH_VFR_DATA" {
            vtype = buf.read_ub4()?;
        } else {
            buf.skip_ub4()?;
        }
        pairs.insert(key, value);
    }
    self.session_data = SessionData::from_pairs(&pairs);
    if vtype != 0 {
        self.verifier_type = vtype;
    }

    match self.phase {
        AuthPhase::One => {
            // Only move on once the keys needed for phase two exist.
            self.generate_verifier()?;
            self.phase = AuthPhase::Two;
        }
        AuthPhase::Two => {
            // Only report completion once the server has proven it knows the key.
            self.verify_server_response()?;
            self.phase = AuthPhase::Complete;
        }
        AuthPhase::Complete => {}
    }
    Ok(())
}
/// Reads one auth string: a ub4 length followed, when non-zero, by a length-prefixed chunk.
///
/// A zero ub4 length, or a null chunk, yields an empty string. Invalid UTF-8
/// is replaced with U+FFFD.
fn read_auth_string(buf: &mut ReadBuffer) -> Result<String> {
    if buf.read_ub4()? == 0 {
        return Ok(String::new());
    }
    let text = buf
        .read_bytes_with_length()?
        .map(|raw| String::from_utf8_lossy(&raw).into_owned())
        .unwrap_or_default();
    Ok(text)
}
fn generate_verifier(&mut self) -> Result<()> {
let vfr_data = self.session_data.auth_vfr_data.as_ref()
.ok_or_else(|| Error::AuthenticationFailed("Missing AUTH_VFR_DATA".to_string()))?;
let vfr_bytes = hex::decode(vfr_data)
.map_err(|e| Error::Protocol(format!("Invalid AUTH_VFR_DATA hex: {}", e)))?;
let server_key = self.session_data.auth_sesskey.as_ref()
.ok_or_else(|| Error::AuthenticationFailed("Missing AUTH_SESSKEY".to_string()))?;
let server_key_bytes = hex::decode(server_key)
.map_err(|e| Error::Protocol(format!("Invalid AUTH_SESSKEY hex: {}", e)))?;
match self.verifier_type {
verifier_type::V12C => self.generate_12c_verifier(&vfr_bytes, &server_key_bytes),
verifier_type::V11G_1 | verifier_type::V11G_2 => {
self.generate_11g_verifier(&vfr_bytes, &server_key_bytes)
}
_ => Err(Error::UnsupportedVerifierType(self.verifier_type)),
}
}
/// Derives the 12c client session key and combo key from the server challenge.
///
/// `vfr_data` is the decoded `AUTH_VFR_DATA`; `server_key` is the decoded
/// `AUTH_SESSKEY`.
///
/// All required session fields are validated before the iterated password
/// hash runs. Its cost scales with `AUTH_PBKDF2_VGEN_COUNT`. Both derived keys
/// are stored only after every step succeeds, so a failure never leaves
/// `client_session_key` set without a matching `combo_key`.
///
/// # Errors
/// - `Error::AuthenticationFailed` when `AUTH_PBKDF2_VGEN_COUNT`,
///   `AUTH_PBKDF2_CSK_SALT` or `AUTH_PBKDF2_SDER_COUNT` is missing.
/// - `Error::Protocol` when the CSK salt is not valid hex.
/// - Any error from the AES helpers.
fn generate_12c_verifier(&mut self, vfr_data: &[u8], server_key: &[u8]) -> Result<()> {
    // Validate every server-supplied input first.
    let iterations = self
        .session_data
        .auth_pbkdf2_vgen_count
        .ok_or_else(|| Error::AuthenticationFailed("Missing AUTH_PBKDF2_VGEN_COUNT".to_string()))?;
    let csk_salt = self
        .session_data
        .auth_pbkdf2_csk_salt
        .as_ref()
        .ok_or_else(|| Error::AuthenticationFailed("Missing AUTH_PBKDF2_CSK_SALT".to_string()))?;
    let csk_salt_bytes = hex::decode(csk_salt)
        .map_err(|e| Error::Protocol(format!("Invalid CSK_SALT hex: {}", e)))?;
    let sder_count = self
        .session_data
        .auth_pbkdf2_sder_count
        .ok_or_else(|| Error::AuthenticationFailed("Missing AUTH_PBKDF2_SDER_COUNT".to_string()))?;

    // Server half is decrypted with the password hash; the client half is
    // freshly generated with the same length and encrypted for the server.
    let password_hash = generate_12c_password_hash(&self.password, vfr_data, iterations);
    let session_key_part_a = decrypt_cbc_256(&password_hash, server_key)?;
    let session_key_part_b = generate_session_key_part(session_key_part_a.len());
    let encrypted_client_key = encrypt_cbc_256_pkcs7(&password_hash, &session_key_part_b)?;
    let combo_key = generate_12c_combo_key(
        &session_key_part_a,
        &session_key_part_b,
        &csk_salt_bytes,
        sder_count,
    );

    // Commit both keys together.
    self.client_session_key = Some(encrypted_client_key);
    self.combo_key = Some(combo_key);
    Ok(())
}
/// Derives the 11g client session key and combo key from the server challenge.
///
/// The server half of the session key is decrypted with the 11g password
/// hash. A client half of the same length is generated and encrypted with
/// that hash. The two halves are then combined into the combo key.
fn generate_11g_verifier(&mut self, vfr_data: &[u8], server_key: &[u8]) -> Result<()> {
    let hash = generate_11g_password_hash(&self.password, vfr_data);
    let server_half = decrypt_cbc_192(&hash, server_key)?;
    let client_half = generate_session_key_part(server_half.len());
    self.client_session_key = Some(encrypt_cbc_192(&hash, &client_half)?);
    self.combo_key = Some(generate_11g_combo_key(&server_half, &client_half));
    Ok(())
}
/// Encrypts `salt || password` with the combo key and returns it as upper-case hex.
///
/// A fresh salt from `generate_salt` is used on every call. 12c verifiers use
/// `encrypt_cbc_256_pkcs7`; all others use `encrypt_cbc_192`.
///
/// # Errors
/// `Error::Protocol` when the combo key has not been generated yet, plus any
/// encryption error.
fn encode_password(&self) -> Result<String> {
    let key = self
        .combo_key
        .as_ref()
        .ok_or_else(|| Error::Protocol("Combo key not generated".to_string()))?;
    let plaintext: Vec<u8> = generate_salt()
        .iter()
        .chain(self.password.iter())
        .copied()
        .collect();
    let ciphertext = match self.verifier_type {
        verifier_type::V12C => encrypt_cbc_256_pkcs7(key, &plaintext)?,
        _ => encrypt_cbc_192(key, &plaintext)?,
    };
    Ok(hex::encode_upper(&ciphertext))
}
/// Builds the `AUTH_PBKDF2_SPEEDY_KEY` value for 12c verifiers.
///
/// A 64-byte key is derived with `pbkdf2_derive` from the password. The salt
/// is the decoded verifier data followed by `AUTH_PBKDF2_SPEEDY_KEY`, and the
/// iteration count is `AUTH_PBKDF2_VGEN_COUNT`. Random salt bytes are
/// prepended to that key and the result is encrypted with the combo key. The
/// first 80 ciphertext bytes are returned as upper-case hex.
///
/// Returns `Ok(None)` for non-12c verifiers.
fn generate_speedy_key(&self) -> Result<Option<String>> {
    if self.verifier_type != verifier_type::V12C {
        return Ok(None);
    }
    let combo_key = self
        .combo_key
        .as_ref()
        .ok_or_else(|| Error::Protocol("Combo key not generated".to_string()))?;
    let vfr_hex = self
        .session_data
        .auth_vfr_data
        .as_ref()
        .ok_or_else(|| Error::AuthenticationFailed("Missing AUTH_VFR_DATA".to_string()))?;
    let mut pbkdf2_salt = hex::decode(vfr_hex)
        .map_err(|e| Error::Protocol(format!("Invalid AUTH_VFR_DATA hex: {}", e)))?;
    let iterations = self
        .session_data
        .auth_pbkdf2_vgen_count
        .ok_or_else(|| Error::AuthenticationFailed("Missing iterations".to_string()))?;

    pbkdf2_salt.extend_from_slice(b"AUTH_PBKDF2_SPEEDY_KEY");
    let password_key = pbkdf2_derive(&self.password, &pbkdf2_salt, iterations, 64);

    let mut plaintext = generate_salt().to_vec();
    plaintext.extend_from_slice(&password_key);
    let ciphertext = encrypt_cbc_256_pkcs7(combo_key, &plaintext)?;
    // Only the leading 80 bytes are sent; the remainder is discarded.
    Ok(Some(hex::encode_upper(&ciphertext[..80])))
}
fn verify_server_response(&self) -> Result<()> {
if let Some(response) = &self.session_data.auth_svr_response {
let combo_key = self.combo_key.as_ref()
.ok_or_else(|| Error::Protocol("Combo key not available".to_string()))?;
let encrypted = hex::decode(response)
.map_err(|e| Error::Protocol(format!("Invalid server response hex: {}", e)))?;
let decrypted = if self.verifier_type == verifier_type::V12C {
decrypt_cbc_256(combo_key, &encrypted)?
} else {
decrypt_cbc_192(combo_key, &encrypted)?
};
if decrypted.len() >= 32 && &decrypted[16..32] == b"SERVER_TO_CLIENT" {
Ok(())
} else {
Err(Error::AuthenticationFailed("Invalid server response".to_string()))
}
} else {
Ok(())
}
}
/// Builds the NUL-terminated `ALTER SESSION SET TIME_ZONE=...` statement sent in phase two.
///
/// If `ORA_SDTZ` is set to valid Unicode, its value is used verbatim, except
/// that single quotes are doubled so the value cannot terminate the SQL
/// literal. Otherwise the local UTC offset is rendered as `±HH:MM`.
///
/// The sign is taken from the full offset, not from the truncated hour count.
/// This means negative offsets under one hour (e.g. -00:30) are rendered with
/// a `-` rather than `+`. Sub-minute seconds are truncated.
fn get_alter_timezone_statement(&self) -> String {
    let time_zone = match std::env::var("ORA_SDTZ") {
        Ok(explicit) => explicit.replace('\'', "''"),
        Err(_) => {
            let offset_secs = chrono::Local::now().offset().local_minus_utc();
            let sign = if offset_secs < 0 { '-' } else { '+' };
            let total_minutes = offset_secs.unsigned_abs() / 60;
            format!("{}{:02}:{:02}", sign, total_minutes / 60, total_minutes % 60)
        }
    };
    format!("ALTER SESSION SET TIME_ZONE='{}'\x00", time_zone)
}
/// Overwrites the stored password with zeros and leaves it empty.
///
/// Volatile writes are used instead of `fill(0)`. The buffer is typically
/// freed right afterwards (this is also called from `Drop`), which makes a
/// plain fill a dead store that the optimiser may legally remove, leaving the
/// password in freed memory.
pub fn clear_password(&mut self) {
    for byte in self.password.iter_mut() {
        // SAFETY: `byte` is a unique, properly aligned reference to an
        // initialised `u8` inside the vector's live allocation.
        unsafe { std::ptr::write_volatile(byte, 0) };
    }
    // Prevent the compiler from moving the volatile writes past the
    // truncation / deallocation that follows.
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
    self.password.clear();
}
}
impl Drop for AuthMessage {
    /// Scrubs the password via `clear_password` before the message is released.
    fn drop(&mut self) {
        Self::clear_password(self);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a single auth string from raw wire bytes, panicking on read errors.
    fn decode_auth_string(bytes: &[u8]) -> String {
        let mut reader = ReadBuffer::from_slice(bytes);
        AuthMessage::read_auth_string(&mut reader).unwrap()
    }

    #[test]
    fn test_auth_message_creation() {
        let message = AuthMessage::new("SCOTT", b"tiger", "FREEPDB1");
        assert_eq!(message.username, "SCOTT");
        assert_eq!(message.phase(), AuthPhase::One);
        assert!(!message.is_complete());
    }

    #[test]
    fn test_auth_mode_sysdba() {
        let message = AuthMessage::new("SYS", b"password", "ORCL").with_sysdba();
        let mode = message.auth_mode;
        assert!(mode & auth_mode::SYSDBA != 0);
        assert!(mode & auth_mode::LOGON != 0);
    }

    #[test]
    fn test_session_data_parsing() {
        let pairs: HashMap<String, String> = [
            ("AUTH_SESSKEY", "AABBCCDD"),
            ("AUTH_VFR_DATA", "11223344"),
            ("AUTH_PBKDF2_VGEN_COUNT", "4096"),
        ]
        .iter()
        .map(|(k, v)| (k.to_string(), v.to_string()))
        .collect();
        let data = SessionData::from_pairs(&pairs);
        assert_eq!(data.auth_sesskey, Some("AABBCCDD".to_string()));
        assert_eq!(data.auth_vfr_data, Some("11223344".to_string()));
        assert_eq!(data.auth_pbkdf2_vgen_count, Some(4096));
    }

    #[test]
    fn test_phase_one_build() {
        let caps = Capabilities::new();
        let message = AuthMessage::new("TESTUSER", b"password", "TESTDB");
        let packet = message.build_request(&caps, false).unwrap();
        assert!(packet.len() > PACKET_HEADER_SIZE);
        assert_eq!(packet[4], PacketType::Data as u8);
        // Payload: two data-flag bytes, message type, then function code.
        assert_eq!(packet[PACKET_HEADER_SIZE + 3], FunctionCode::AuthPhaseOne as u8);
    }

    #[test]
    fn test_clear_password() {
        let mut message = AuthMessage::new("USER", b"secret", "DB");
        assert!(!message.password.is_empty());
        message.clear_password();
        assert!(message.password.is_empty());
    }

    #[test]
    fn test_read_auth_string_zero_length() {
        assert_eq!(decode_auth_string(&[0x00]), "");
    }

    #[test]
    fn test_read_auth_string_with_data() {
        let wire = [0x01, 0x05, 0x05, b'H', b'E', b'L', b'L', b'O'];
        assert_eq!(decode_auth_string(&wire), "HELLO");
    }

    #[test]
    fn test_read_auth_string_null_bytes() {
        // Non-zero declared length, but the chunk itself is the 0xFF null marker.
        let wire = [0x01, 0x05, 0xFF];
        assert_eq!(decode_auth_string(&wire), "");
    }
}