#![cfg(test)]
use super::{
Handshake, MySqlConnectOptions, MySqlConnection, MySqlConnectionInner, MySqlError, capability,
};
use crate::cx::Cx;
use crate::types::Outcome;
use std::io::{Read, Write};
use std::net::TcpListener;
use std::sync::atomic::AtomicBool;
use std::time::Duration;
/// Blocks the current thread until `future` completes, returning its output.
///
/// Lets the synchronous `#[test]` functions below call async connection APIs.
fn run<Fut>(future: Fut) -> Fut::Output
where
    Fut: std::future::Future,
{
    futures_lite::future::block_on(future)
}
/// Accumulates a MySQL packet payload and frames it with the 4-byte packet header.
struct PacketBuffer {
    /// Payload bytes written so far (the header is added by `build_packet`).
    buf: Vec<u8>,
    /// Sequence id written into byte 3 of the packet header.
    sequence: u8,
}

impl PacketBuffer {
    /// Largest payload that fits in one self-contained packet.
    ///
    /// The length field is 3 bytes wide. A length of exactly 0xFFFFFF means
    /// "payload continues in the next packet", so a standalone packet must be shorter.
    const MAX_SINGLE_PACKET_PAYLOAD: usize = 0x00FF_FFFE;

    /// Creates an empty payload buffer with sequence id 0.
    fn new() -> Self {
        Self {
            buf: Vec::new(),
            sequence: 0,
        }
    }

    /// Sets the sequence id used when the packet is framed.
    fn set_sequence(&mut self, seq: u8) {
        self.sequence = seq;
    }

    /// Appends a single byte to the payload.
    fn write_byte(&mut self, byte: u8) {
        self.buf.push(byte);
    }

    /// Appends a byte slice to the payload.
    fn write_bytes(&mut self, bytes: &[u8]) {
        self.buf.extend_from_slice(bytes);
    }

    /// Frames the payload as `[payload length: 3 bytes LE][sequence: 1 byte][payload]`.
    ///
    /// # Panics
    ///
    /// Panics if the payload is longer than `MAX_SINGLE_PACKET_PAYLOAD`. Such a
    /// payload cannot be expressed in a single packet; the previous `as u32` cast
    /// silently produced a corrupt header instead.
    fn build_packet(self) -> MySqlPacket {
        let len = self.buf.len();
        assert!(
            len <= Self::MAX_SINGLE_PACKET_PAYLOAD,
            "test packet payload of {len} bytes does not fit in a single MySQL packet (max {} bytes)",
            Self::MAX_SINGLE_PACKET_PAYLOAD
        );
        // Cannot truncate: `len` was bounded to 24 bits by the assert above.
        let header_len = (len as u32).to_le_bytes();
        let mut bytes = Vec::with_capacity(4 + len);
        bytes.extend_from_slice(&header_len[..3]);
        bytes.push(self.sequence);
        bytes.extend_from_slice(&self.buf);
        MySqlPacket { bytes }
    }
}

/// A fully framed MySQL packet ready to be written to a socket.
struct MySqlPacket {
    /// 3-byte payload length, 1-byte sequence id, then the payload.
    bytes: Vec<u8>,
}
/// Wraps an already-connected socket in a `MySqlConnection` with zeroed handshake
/// state, so tests can drive protocol methods without a real handshake.
///
/// `sequence` is stored as the connection's `sequence` field. The connection starts
/// open, with no pending rollback, no query in flight, and no connect options.
fn make_test_connection(stream: crate::net::TcpStream, sequence: u8) -> MySqlConnection {
    let inner = MySqlConnectionInner {
        stream,
        sequence,
        connection_id: 0,
        capabilities: 0,
        charset: 0,
        status_flags: 0,
        server_version: String::new(),
        closed: false,
        needs_rollback: false,
        query_in_flight: AtomicBool::new(false),
        max_result_rows: super::DEFAULT_MAX_RESULT_ROWS,
        prepared_statement_epoch: 0,
    };
    MySqlConnection {
        inner,
        options: None,
    }
}
/// Checks that the client's handshake response never sets `CLIENT_LOCAL_FILES`,
/// even when the server's `Handshake` advertises it.
///
/// A mock server thread accepts one connection, sends a greeting packet, reads
/// one packet back, and interprets its first four payload bytes as the client
/// capability flags.
#[test]
fn audit_handshake_does_not_advertise_local_infile_capability() {
    let listener = TcpListener::bind("127.0.0.1:0").expect("bind test listener");
    let addr = listener.local_addr().expect("listener addr");

    let server = std::thread::spawn(move || {
        let (mut sock, _peer) = listener.accept().expect("accept client");
        sock.set_read_timeout(Some(Duration::from_secs(2)))
            .expect("set read timeout");

        let greeting = create_initial_handshake_packet();
        sock.write_all(&greeting.bytes).expect("write handshake");
        sock.flush().expect("flush handshake");

        let mut header = [0u8; 4];
        sock.read_exact(&mut header)
            .expect("read response header");
        let [len0, len1, len2, _seq] = header;
        let payload_len = u32::from_le_bytes([len0, len1, len2, 0]) as usize;
        let mut payload = vec![0u8; payload_len];
        sock.read_exact(&mut payload)
            .expect("read response payload");

        let cap_bytes: [u8; 4] = payload
            .get(..4)
            .and_then(|head| head.try_into().ok())
            .expect("client capability bytes missing");
        let client_caps = u32::from_le_bytes(cap_bytes);

        assert_eq!(
            client_caps & capability::CLIENT_LOCAL_FILES,
            0,
            "SECURITY: Client must not advertise CLIENT_LOCAL_FILES capability by default"
        );
        assert_ne!(
            client_caps & capability::CLIENT_PROTOCOL_41,
            0,
            "Sanity check: expected normal handshake capabilities"
        );
    });

    let stream = run(async {
        crate::net::TcpStream::connect_socket_addr(addr)
            .await
            .expect("connect to test server")
    });
    let mut conn = make_test_connection(stream, 1);

    let options = MySqlConnectOptions::parse("mysql://user:pass@localhost/testdb")
        .expect("parse mysql options");
    // The server side deliberately offers CLIENT_LOCAL_FILES; the client must not echo it.
    let server_caps = capability::CLIENT_PROTOCOL_41
        | capability::CLIENT_SECURE_CONNECTION
        | capability::CLIENT_PLUGIN_AUTH
        | capability::CLIENT_LOCAL_FILES;
    let handshake = Handshake {
        server_version: "8.0.0-test".to_string(),
        connection_id: 99,
        auth_plugin_data: b"01234567890123456789".to_vec(),
        capabilities: server_caps,
        charset: 45,
        status_flags: 0,
        auth_plugin_name: "caching_sha2_password".to_string(),
    };

    run(conn.send_handshake_response(&options, &handshake)).expect("send handshake response");
    server.join().expect("server thread join");
}
/// Checks that a server answering a query with a LOCAL INFILE request (a 0xFB
/// packet naming `/etc/passwd`) gets a protocol error, not file contents.
///
/// Also checks that the connection is marked closed afterwards (fail-closed).
#[test]
fn audit_server_local_infile_request_rejection() {
    let listener = TcpListener::bind("127.0.0.1:0").expect("bind test listener");
    let addr = listener.local_addr().expect("listener addr");

    let server = std::thread::spawn(move || {
        let (mut sock, _peer) = listener.accept().expect("accept client");
        sock.set_read_timeout(Some(Duration::from_secs(2)))
            .expect("set read timeout");

        // Consume the client's query packet; its contents are not inspected.
        let mut header = [0u8; 4];
        sock.read_exact(&mut header).expect("read query header");
        let [len0, len1, len2, _seq] = header;
        let mut query = vec![0u8; u32::from_le_bytes([len0, len1, len2, 0]) as usize];
        sock.read_exact(&mut query)
            .expect("read query payload");

        // Reply with a LOCAL INFILE request for a sensitive file.
        let mut reply = PacketBuffer::new();
        reply.set_sequence(1);
        reply.write_byte(0xFB);
        reply.write_bytes(b"/etc/passwd");
        let reply = reply.build_packet();

        sock.write_all(&reply.bytes)
            .expect("write malicious LOCAL INFILE request");
        sock.flush().expect("flush LOCAL INFILE request");
    });

    let stream = run(async {
        crate::net::TcpStream::connect_socket_addr(addr)
            .await
            .expect("connect client")
    });
    let mut conn = make_test_connection(stream, 0);
    let cx = Cx::for_testing();

    match run(conn.query_static_sql(&cx, "SELECT 1")) {
        Outcome::Err(MySqlError::Protocol(msg)) => {
            assert!(
                msg.contains("LOAD DATA LOCAL INFILE request rejected"),
                "SECURITY: Error message must indicate LOCAL INFILE rejection, got: {msg}"
            );
            assert!(
                msg.contains("disabled by default"),
                "SECURITY: Error message must indicate LOCAL INFILE is disabled, got: {msg}"
            );
        }
        other => panic!(
            "CRITICAL: Expected LOCAL INFILE rejection, got: {other:?}. \
             Client may be vulnerable to file exfiltration attacks!"
        ),
    }

    server.join().expect("server thread join");
    assert!(
        conn.inner.closed,
        "SECURITY: Connection must be closed after LOCAL INFILE rejection (fail-closed behavior)"
    );
}
/// Replays a LOCAL INFILE request (a 0xFB packet followed by a file name) for a
/// range of paths: sensitive, Windows-style, relative, empty, and nonexistent.
///
/// For every path it checks two things:
/// - the query fails with a protocol error reporting the rejection;
/// - the connection is left closed (fail-closed), matching the single-path test.
#[test]
fn audit_local_infile_rejection_comprehensive_paths() {
    let malicious_paths: &[&[u8]] = &[
        b"/etc/passwd",
        b"/etc/shadow",
        b"C:\\windows\\system32\\config\\SAM",
        b"../../../etc/passwd",
        b"/proc/self/environ",
        b"/home/user/.ssh/id_rsa",
        b"/var/log/mysql/mysql.log",
        b"",
        b"/tmp/does-not-exist.txt",
    ];

    for (i, &path) in malicious_paths.iter().enumerate() {
        let case = i + 1;
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test listener");
        let addr = listener.local_addr().expect("listener addr");
        let requested_path = path.to_vec();

        let server = std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept client");
            stream
                .set_read_timeout(Some(Duration::from_secs(2)))
                .expect("set read timeout");

            // Drain the client's query packet before replying.
            let mut header = [0u8; 4];
            stream.read_exact(&mut header).expect("read query header");
            let length = u32::from_le_bytes([header[0], header[1], header[2], 0]);
            let mut query = vec![0u8; length as usize];
            stream
                .read_exact(&mut query)
                .expect("read query payload");

            // 0xFB followed by the file name is the LOCAL INFILE request.
            let mut request = PacketBuffer::new();
            request.set_sequence(1);
            request.write_byte(0xFB);
            request.write_bytes(&requested_path);
            let request = request.build_packet();
            stream
                .write_all(&request.bytes)
                .expect("write LOCAL INFILE request");
            stream.flush().expect("flush LOCAL INFILE request");
        });

        let stream = run(async {
            crate::net::TcpStream::connect_socket_addr(addr)
                .await
                .expect("connect client")
        });
        let mut conn = make_test_connection(stream, 0);
        let cx = Cx::for_testing();
        let outcome = run(conn.query_static_sql(&cx, "SELECT 1"));
        assert!(
            matches!(outcome, Outcome::Err(MySqlError::Protocol(ref msg))
                if msg.contains("LOAD DATA LOCAL INFILE request rejected")),
            "SECURITY: Path {} (test {}) must be rejected, got: {outcome:?}",
            String::from_utf8_lossy(path),
            case
        );

        server.join().expect("server thread join");
        // Same fail-closed guarantee as the single-path test, checked for every path.
        assert!(
            conn.inner.closed,
            "SECURITY: Connection must be closed after LOCAL INFILE rejection for path {} (test {})",
            String::from_utf8_lossy(path),
            case
        );
    }
}
/// Builds a server greeting (protocol version 10 handshake) packet with sequence id 0.
///
/// The advertised capabilities deliberately include `CLIENT_LOCAL_FILES`, so a test
/// can check that the client does not echo that flag back.
///
/// Fields are written in this order:
/// 1. protocol version (10)
/// 2. NUL-terminated server version
/// 3. 4-byte connection id
/// 4. auth-plugin-data part 1 (8 bytes)
/// 5. a 0x00 filler byte
/// 6. lower 16 capability bits
/// 7. character-set id
/// 8. 2-byte status flags
/// 9. upper 16 capability bits
/// 10. auth-plugin-data length (21)
/// 11. 10 reserved zero bytes
/// 12. auth-plugin-data part 2 (12 bytes plus a NUL, i.e. 13 bytes)
/// 13. NUL-terminated auth plugin name
fn create_initial_handshake_packet() -> MySqlPacket {
    let capabilities: u32 = capability::CLIENT_PROTOCOL_41
        | capability::CLIENT_SECURE_CONNECTION
        | capability::CLIENT_PLUGIN_AUTH
        | capability::CLIENT_LOCAL_FILES;
    // Capabilities are split across two 16-bit fields. A single `as u16` would
    // silently drop CLIENT_PLUGIN_AUTH, which lives in the upper half.
    let caps_low = (capabilities & 0xFFFF) as u16;
    let caps_high = (capabilities >> 16) as u16;

    let mut handshake = PacketBuffer::new();
    handshake.set_sequence(0);
    handshake.write_byte(10); // protocol version
    handshake.write_bytes(b"8.0.0-test\0");
    handshake.write_bytes(&99u32.to_le_bytes()); // connection id
    handshake.write_bytes(b"12345678"); // auth-plugin-data part 1
    handshake.write_byte(0); // filler
    handshake.write_bytes(&caps_low.to_le_bytes());
    handshake.write_byte(45); // character-set id
    handshake.write_bytes(&0u16.to_le_bytes()); // status flags
    handshake.write_bytes(&caps_high.to_le_bytes());
    handshake.write_byte(21); // auth-plugin-data length: 8 + 12 + trailing NUL
    handshake.write_bytes(&[0u8; 10]); // reserved
    // Part 2 is max(13, 21 - 8) = 13 bytes, so the 12-byte scramble needs its NUL.
    handshake.write_bytes(b"123456789012\0");
    handshake.write_bytes(b"caching_sha2_password\0");
    handshake.build_packet()
}