use std::io;
use std::sync::Arc;
use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine};
use rcgen::{CertifiedKey, KeyPair};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use rustls::{ClientConfig, RootCertStore, ServerConfig};
use sha2::{Digest, Sha256};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio_rustls::TlsConnector;
use crate::error::TransportError;
/// Magic constant identifying the transport handshake (sent little-endian).
pub const EXA_MAGIC_NUMBER: u32 = 0x02212102;
/// Protocol major version advertised in the magic packet.
pub const EXA_PROTOCOL_VERSION_MAJOR: u32 = 1;
/// Protocol minor version advertised in the magic packet.
pub const EXA_PROTOCOL_VERSION_MINOR: u32 = 1;
/// Size of the outbound magic packet: three little-endian `u32`s (magic, major, minor).
pub const EXA_MAGIC_PACKET_SIZE: usize = 12;
/// Size of the server's response: 4 bytes not interpreted here, 4-byte LE port, 16-byte NUL-padded IP.
pub const EXA_RESPONSE_PACKET_SIZE: usize = 24;
/// Streaming chunk size (64 KiB). Not referenced in this file — presumably used by callers.
pub const HTTP_CHUNK_SIZE: usize = 64 * 1024;
/// Builds the 12-byte handshake packet: the magic number followed by the
/// protocol major and minor versions, each encoded as a little-endian `u32`.
#[must_use]
pub fn generate_magic_packet() -> [u8; EXA_MAGIC_PACKET_SIZE] {
    let words = [
        EXA_MAGIC_NUMBER,
        EXA_PROTOCOL_VERSION_MAJOR,
        EXA_PROTOCOL_VERSION_MINOR,
    ];
    let mut out = [0u8; EXA_MAGIC_PACKET_SIZE];
    for (slot, word) in out.chunks_exact_mut(4).zip(words) {
        slot.copy_from_slice(&word.to_le_bytes());
    }
    out
}
/// Parses and validates a handshake magic packet, returning
/// `(magic, major, minor)` on success.
///
/// # Errors
/// Returns `TransportError::ProtocolError` when the packet is shorter than
/// [`EXA_MAGIC_PACKET_SIZE`] or the magic number does not match.
pub fn parse_magic_packet(packet: &[u8]) -> Result<(u32, u32, u32), TransportError> {
    if packet.len() < EXA_MAGIC_PACKET_SIZE {
        return Err(TransportError::ProtocolError(format!(
            "Magic packet too short: expected {} bytes, got {}",
            EXA_MAGIC_PACKET_SIZE,
            packet.len()
        )));
    }
    // Little-endian u32 at the given byte offset.
    let read_u32 = |offset: usize| {
        let mut word = [0u8; 4];
        word.copy_from_slice(&packet[offset..offset + 4]);
        u32::from_le_bytes(word)
    };
    let (magic, major, minor) = (read_u32(0), read_u32(4), read_u32(8));
    if magic != EXA_MAGIC_NUMBER {
        return Err(TransportError::ProtocolError(format!(
            "Invalid magic number: expected 0x{:08X}, got 0x{:08X}",
            EXA_MAGIC_NUMBER, magic
        )));
    }
    Ok((magic, major, minor))
}
/// Parses the server's 24-byte handshake response into `(ip, port)`.
///
/// Layout: bytes 0..4 are not interpreted here, bytes 4..8 hold a
/// little-endian `i32` port, and bytes 8..24 hold a NUL-padded IP string.
///
/// # Errors
/// Returns `TransportError::ProtocolError` for a short packet, an
/// out-of-range port, or an empty IP field.
pub fn parse_response_packet(packet: &[u8]) -> Result<(String, u16), TransportError> {
    if packet.len() < EXA_RESPONSE_PACKET_SIZE {
        return Err(TransportError::ProtocolError(format!(
            "Response packet too short: expected {} bytes, got {}",
            EXA_RESPONSE_PACKET_SIZE,
            packet.len()
        )));
    }
    let mut port_bytes = [0u8; 4];
    port_bytes.copy_from_slice(&packet[4..8]);
    let port = i32::from_le_bytes(port_bytes);
    // The wire type is signed; reject anything outside the u16 range.
    if !(0..=i32::from(u16::MAX)).contains(&port) {
        return Err(TransportError::ProtocolError(format!(
            "Invalid port in response packet: {}",
            port
        )));
    }
    let ip_string = String::from_utf8_lossy(&packet[8..24])
        .trim_end_matches('\0')
        .to_string();
    if ip_string.is_empty() {
        return Err(TransportError::ProtocolError(
            "Empty IP address in response packet".to_string(),
        ));
    }
    Ok((ip_string, port as u16))
}
/// Runs the client side of the transport handshake on `stream`: sends the
/// magic packet, then reads and parses the server's 24-byte response into
/// `(ip, port)`.
///
/// # Errors
/// Returns `TransportError::IoError` on read/write failure and propagates
/// parse errors from [`parse_response_packet`].
pub async fn perform_handshake<S>(stream: &mut S) -> Result<(String, u16), TransportError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let hello = generate_magic_packet();
    stream
        .write_all(&hello)
        .await
        .map_err(|e| TransportError::IoError(format!("Failed to send magic packet: {e}")))?;
    stream
        .flush()
        .await
        .map_err(|e| TransportError::IoError(format!("Failed to flush magic packet: {e}")))?;
    let mut reply = [0u8; EXA_RESPONSE_PACKET_SIZE];
    stream
        .read_exact(&mut reply)
        .await
        .map_err(|e| TransportError::IoError(format!("Failed to read response packet: {e}")))?;
    parse_response_packet(&reply)
}
/// A freshly generated self-signed TLS certificate plus its key material.
#[derive(Clone)]
pub struct TlsCertificate {
    /// DER-encoded X.509 certificate.
    pub certificate_der: Vec<u8>,
    /// DER-encoded private key (treated as PKCS#8 by `HttpTransportClient::connect`).
    pub private_key_der: Vec<u8>,
    /// DER-encoded public key; the input for `fingerprint`.
    pub public_key_der: Vec<u8>,
    /// Public-key fingerprint in `sha256//<base64>` form.
    pub fingerprint: String,
}
impl TlsCertificate {
    /// Generates a fresh key pair and a self-signed certificate for
    /// "localhost", storing the DER encodings and the public-key
    /// fingerprint.
    ///
    /// # Errors
    /// Returns `TransportError::TlsError` if key-pair generation, parameter
    /// construction, or self-signing fails.
    pub fn generate() -> Result<Self, TransportError> {
        let key_pair = KeyPair::generate()
            .map_err(|e| TransportError::TlsError(format!("Failed to generate key pair: {e}")))?;
        let params = rcgen::CertificateParams::new(vec!["localhost".to_string()]).map_err(|e| {
            TransportError::TlsError(format!("Failed to create certificate params: {e}"))
        })?;
        let cert = params.self_signed(&key_pair).map_err(|e| {
            TransportError::TlsError(format!("Failed to generate certificate: {e}"))
        })?;
        let certified = CertifiedKey { cert, key_pair };
        let public_key_der = certified.key_pair.public_key_der().to_vec();
        Ok(Self {
            certificate_der: certified.cert.der().to_vec(),
            private_key_der: certified.key_pair.serialize_der(),
            fingerprint: compute_public_key_fingerprint(&public_key_der),
            public_key_der,
        })
    }
    /// Builds a rustls `ServerConfig` that presents this certificate and
    /// requires no client authentication.
    ///
    /// # Errors
    /// Returns `TransportError::TlsError` if the key cannot be parsed or
    /// the config cannot be assembled.
    pub fn to_server_config(&self) -> Result<ServerConfig, TransportError> {
        let chain = vec![CertificateDer::from(self.certificate_der.clone())];
        let key = PrivateKeyDer::try_from(self.private_key_der.clone())
            .map_err(|e| TransportError::TlsError(format!("Failed to parse private key: {e}")))?;
        ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(chain, key)
            .map_err(|e| TransportError::TlsError(format!("Failed to create TLS config: {e}")))
    }
    /// Builds a rustls `ClientConfig` with an empty root store and no
    /// client certificate.
    ///
    /// NOTE(review): the empty `RootCertStore` means this config trusts no
    /// server certificate, and `self` is never consulted — confirm whether
    /// callers depend on this or whether roots / client auth should be
    /// wired in here.
    pub fn to_client_config(&self) -> Result<ClientConfig, TransportError> {
        let empty_roots = RootCertStore::empty();
        Ok(ClientConfig::builder()
            .with_root_certificates(empty_roots)
            .with_no_client_auth())
    }
}
/// The underlying connection used by `HttpTransportClient`: plain TCP or a
/// client-side TLS stream (boxed to keep the variant small).
enum ClientConnectionStream {
    Tcp(TcpStream),
    Tls(Box<tokio_rustls::client::TlsStream<TcpStream>>),
}
/// Client for the database's HTTP transport: connects, performs the
/// magic-packet handshake, then speaks HTTP over the (optionally
/// TLS-wrapped) socket.
pub struct HttpTransportClient {
    // Plain-TCP or TLS connection, chosen at `connect` time.
    stream: ClientConnectionStream,
    // "ip:port" endpoint reported by the server during the handshake.
    internal_addr: String,
    // Present only when the connection was established with TLS.
    tls_certificate: Option<TlsCertificate>,
}
impl HttpTransportClient {
    /// Connects to `host:port`, performs the transport handshake, and
    /// optionally upgrades the socket to TLS.
    ///
    /// When `use_tls` is set, a fresh self-signed certificate is generated
    /// and presented as the client certificate, and the server's
    /// certificate is accepted without verification (see `NoVerifier`).
    /// NOTE(review): server authentication is presumably replaced by the
    /// fingerprint exchange (`public_key_fingerprint`) — confirm against
    /// the server-side protocol.
    ///
    /// # Errors
    /// `IoError` on connect/handshake failure, `TlsError` on TLS setup
    /// failure.
    pub async fn connect(host: &str, port: u16, use_tls: bool) -> Result<Self, TransportError> {
        let addr = format!("{host}:{port}");
        let mut tcp_stream = TcpStream::connect(&addr)
            .await
            .map_err(|e| TransportError::IoError(format!("Failed to connect to {addr}: {e}")))?;
        // The handshake response supplies the endpoint stored as
        // `internal_addr`.
        let (ip, internal_port) = perform_handshake(&mut tcp_stream).await?;
        let internal_addr = format!("{ip}:{internal_port}");
        if use_tls {
            let cert = TlsCertificate::generate()?;
            let cert_der = CertificateDer::from(cert.certificate_der.clone());
            let key_der = PrivatePkcs8KeyDer::from(cert.private_key_der.clone());
            let connector = TlsConnector::from(Arc::new(
                ClientConfig::builder()
                    .dangerous()
                    .with_custom_certificate_verifier(Arc::new(NoVerifier))
                    .with_client_auth_cert(vec![cert_der], key_der.into())
                    .map_err(|e| {
                        TransportError::TlsError(format!("Failed to set client cert: {e}"))
                    })?,
            ));
            // Fixed SNI name; certificate verification is disabled, so the
            // name is never checked against the peer's certificate.
            let server_name = "exasol"
                .try_into()
                .map_err(|e| TransportError::TlsError(format!("Invalid server name: {e:?}")))?;
            let tls_stream = connector
                .connect(server_name, tcp_stream)
                .await
                .map_err(|e| TransportError::TlsError(format!("TLS handshake failed: {e}")))?;
            Ok(Self {
                stream: ClientConnectionStream::Tls(Box::new(tls_stream)),
                internal_addr,
                tls_certificate: Some(cert),
            })
        } else {
            Ok(Self {
                stream: ClientConnectionStream::Tcp(tcp_stream),
                internal_addr,
                tls_certificate: None,
            })
        }
    }
    /// The "ip:port" endpoint reported by the server during the handshake.
    #[must_use]
    pub fn internal_address(&self) -> &str {
        &self.internal_addr
    }
    /// Fingerprint (`sha256//<base64>`) of the client certificate's public
    /// key; `None` when the connection is plain TCP.
    #[must_use]
    pub fn public_key_fingerprint(&self) -> Option<&str> {
        self.tls_certificate
            .as_ref()
            .map(|c| c.fingerprint.as_str())
    }
    /// Writes all of `data` to the underlying stream and flushes it.
    ///
    /// # Errors
    /// `IoError` if the write or flush fails.
    pub async fn write(&mut self, data: &[u8]) -> Result<(), TransportError> {
        match &mut self.stream {
            ClientConnectionStream::Tcp(stream) => {
                stream
                    .write_all(data)
                    .await
                    .map_err(|e| TransportError::IoError(format!("Failed to write data: {e}")))?;
                stream
                    .flush()
                    .await
                    .map_err(|e| TransportError::IoError(format!("Failed to flush data: {e}")))?;
            }
            ClientConnectionStream::Tls(stream) => {
                stream
                    .write_all(data)
                    .await
                    .map_err(|e| TransportError::IoError(format!("Failed to write data: {e}")))?;
                stream
                    .flush()
                    .await
                    .map_err(|e| TransportError::IoError(format!("Failed to flush data: {e}")))?;
            }
        }
        Ok(())
    }
    /// Reads up to `buf.len()` bytes, returning the number read (0 means
    /// end of stream).
    ///
    /// # Errors
    /// `IoError` if the read fails.
    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, TransportError> {
        match &mut self.stream {
            ClientConnectionStream::Tcp(stream) => stream
                .read(buf)
                .await
                .map_err(|e| TransportError::IoError(format!("Failed to read data: {e}"))),
            ClientConnectionStream::Tls(stream) => stream
                .read(buf)
                .await
                .map_err(|e| TransportError::IoError(format!("Failed to read data: {e}"))),
        }
    }
    /// Shuts down the connection's write half.
    ///
    /// # Errors
    /// `IoError` if the shutdown fails.
    pub async fn shutdown(&mut self) -> Result<(), TransportError> {
        match &mut self.stream {
            ClientConnectionStream::Tcp(stream) => stream.shutdown().await.map_err(|e| {
                TransportError::IoError(format!("Failed to shutdown connection: {e}"))
            }),
            ClientConnectionStream::Tls(stream) => stream.shutdown().await.map_err(|e| {
                TransportError::IoError(format!("Failed to shutdown connection: {e}"))
            }),
        }
    }
    /// Reads and parses the head of the next HTTP request on the stream.
    pub async fn read_http_request(&mut self) -> Result<HttpRequest, TransportError> {
        match &mut self.stream {
            ClientConnectionStream::Tcp(stream) => parse_http_request(stream).await,
            ClientConnectionStream::Tls(stream) => parse_http_request(stream.as_mut()).await,
        }
    }
    // Fills `buf` completely from whichever stream variant is active.
    async fn read_exact_bytes(&mut self, buf: &mut [u8]) -> io::Result<()> {
        match &mut self.stream {
            ClientConnectionStream::Tcp(stream) => {
                stream.read_exact(buf).await?;
                Ok(())
            }
            ClientConnectionStream::Tls(stream) => {
                stream.read_exact(buf).await?;
                Ok(())
            }
        }
    }
    // Reads one CRLF-terminated line from whichever stream variant is active.
    async fn read_line_from_stream(&mut self) -> io::Result<String> {
        match &mut self.stream {
            ClientConnectionStream::Tcp(stream) => read_line(stream).await,
            ClientConnectionStream::Tls(stream) => read_line(stream.as_mut()).await,
        }
    }
    /// Reads an entire `Transfer-Encoding: chunked` body, returning the
    /// concatenated chunk payloads with the framing removed.
    ///
    /// # Errors
    /// `ProtocolError` on malformed framing or underlying I/O failure.
    pub async fn read_chunked_body(&mut self) -> Result<Vec<u8>, TransportError> {
        let mut body = Vec::new();
        loop {
            let size_line = self.read_line_from_stream().await.map_err(|e| {
                TransportError::ProtocolError(format!("Failed to read chunk size: {e}"))
            })?;
            let chunk_size = parse_chunk_size(&size_line)?;
            if chunk_size == 0 {
                // Terminal chunk: consume the final CRLF and stop.
                let mut trailer = [0u8; 2];
                self.read_exact_bytes(&mut trailer).await.map_err(|e| {
                    TransportError::ProtocolError(format!("Failed to read chunk trailer: {e}"))
                })?;
                break;
            }
            let mut chunk = vec![0u8; chunk_size];
            self.read_exact_bytes(&mut chunk).await.map_err(|e| {
                TransportError::ProtocolError(format!("Failed to read chunk data: {e}"))
            })?;
            body.extend_from_slice(&chunk);
            // Each chunk payload is followed by a CRLF that is not data.
            let mut crlf = [0u8; 2];
            self.read_exact_bytes(&mut crlf).await.map_err(|e| {
                TransportError::ProtocolError(format!("Failed to read chunk CRLF: {e}"))
            })?;
        }
        Ok(body)
    }
    /// Builds and sends a complete HTTP response (status line, headers,
    /// optional body; `Content-Length` is added automatically when a body
    /// is present and the header was not supplied).
    pub async fn write_http_response(
        &mut self,
        status_code: u16,
        status_text: &str,
        headers: &[(&str, &str)],
        body: Option<&[u8]>,
    ) -> Result<(), TransportError> {
        let response = build_http_response(status_code, status_text, headers, body);
        self.write(&response).await
    }
    /// Sends the response head that starts a chunked (streaming) response.
    pub async fn write_chunked_response_headers(&mut self) -> Result<(), TransportError> {
        let headers = build_chunked_response_headers();
        self.write(&headers).await
    }
    /// Sends `data` as one HTTP chunk. Empty input is a no-op, because an
    /// empty chunk would terminate the stream — use `write_final_chunk`
    /// for that.
    pub async fn write_chunked_body(&mut self, data: &[u8]) -> Result<(), TransportError> {
        if data.is_empty() {
            return Ok(());
        }
        let chunk = encode_chunk(data);
        self.write(&chunk).await
    }
    /// Sends the terminal zero-length chunk that ends a chunked response.
    pub async fn write_final_chunk(&mut self) -> Result<(), TransportError> {
        let final_chunk = encode_chunk(&[]);
        self.write(&final_chunk).await
    }
    /// Waits for the GET request that begins an IMPORT transfer and replies
    /// with chunked-response headers, leaving the connection ready for
    /// `write_chunked_body` calls.
    ///
    /// # Errors
    /// `ProtocolError` if the incoming request is not a GET.
    pub async fn handle_import_request(&mut self) -> Result<HttpRequest, TransportError> {
        let request = self.read_http_request().await?;
        if request.method != HttpMethod::Get {
            return Err(TransportError::ProtocolError(format!(
                "Expected GET request for IMPORT, got {}",
                request.method
            )));
        }
        self.write_chunked_response_headers().await?;
        Ok(request)
    }
    /// Serves `file_bytes` to the peer as HEAD / (range) GET requests until
    /// the peer disconnects; see `serve_parquet_range_requests`.
    pub async fn handle_parquet_import_requests(
        &mut self,
        file_bytes: &[u8],
    ) -> Result<(), TransportError> {
        match &mut self.stream {
            ClientConnectionStream::Tcp(stream) => {
                serve_parquet_range_requests(stream, file_bytes).await
            }
            ClientConnectionStream::Tls(stream) => {
                serve_parquet_range_requests(stream.as_mut(), file_bytes).await
            }
        }
    }
    /// Handles one EXPORT transfer: expects a PUT request, reads its body
    /// (chunked or `Content-Length`-framed), acknowledges with `200 OK`,
    /// and returns the request plus the raw body bytes.
    ///
    /// # Errors
    /// `ProtocolError` for a non-PUT request or missing body framing,
    /// `IoError` if reading the body fails.
    pub async fn handle_export_request(
        &mut self,
    ) -> Result<(HttpRequest, Vec<u8>), TransportError> {
        let request = self.read_http_request().await?;
        if request.method != HttpMethod::Put {
            return Err(TransportError::ProtocolError(format!(
                "Expected PUT request for EXPORT, got {}",
                request.method
            )));
        }
        let body = if request.is_chunked() {
            self.read_chunked_body().await?
        } else if let Some(content_length) = request.content_length() {
            let mut body = vec![0u8; content_length];
            self.read_exact_bytes(&mut body)
                .await
                .map_err(|e| TransportError::IoError(format!("Failed to read body: {e}")))?;
            body
        } else {
            return Err(TransportError::ProtocolError(
                "PUT request has no Content-Length or Transfer-Encoding".to_string(),
            ));
        };
        self.write_http_response(200, "OK", &[], None).await?;
        Ok((request, body))
    }
}
/// A rustls `ServerCertVerifier` that accepts EVERY server certificate and
/// handshake signature without any checks.
///
/// SECURITY: this disables TLS server authentication entirely, leaving the
/// session open to man-in-the-middle attacks. NOTE(review): acceptable only
/// if the peer is authenticated by other means (e.g. the public-key
/// fingerprint exchanged out of band) — confirm the protocol guarantees
/// this.
#[derive(Debug)]
struct NoVerifier;
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
    /// Accepts any server certificate chain without inspecting it.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }
    /// Accepts any TLS 1.2 handshake signature without verification.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }
    /// Accepts any TLS 1.3 handshake signature without verification.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }
    /// Advertises a broad set of signature schemes so the handshake can
    /// proceed regardless of the server's key type.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::ED25519,
        ]
    }
}
/// Returns the SHA-256 digest of `data`, encoded as standard base64.
#[must_use]
pub fn compute_sha256_fingerprint(data: &[u8]) -> String {
    let digest = Sha256::digest(data);
    BASE64_STANDARD.encode(digest)
}
/// Formats a public key's SHA-256 digest as an HTTP public-key-pinning
/// style string: `sha256//<base64>`.
#[must_use]
pub fn compute_public_key_fingerprint(public_key_der: &[u8]) -> String {
    let mut pin = String::from("sha256//");
    pin.push_str(&compute_sha256_fingerprint(public_key_der));
    pin
}
/// One end of a bidirectional in-memory byte pipe built on tokio mpsc
/// channels; created in pairs by [`DataPipe::create_pair`].
pub struct DataPipe {
    // Sends buffers to the peer end.
    tx: mpsc::Sender<Vec<u8>>,
    // Receives buffers from the peer end.
    rx: mpsc::Receiver<Vec<u8>>,
}
impl DataPipe {
    /// Creates two connected ends; data sent on either end is received on
    /// the other. `buffer_size` is the capacity of each underlying channel.
    #[must_use]
    pub fn create_pair(buffer_size: usize) -> (Self, Self) {
        let (a_tx, a_rx) = mpsc::channel(buffer_size);
        let (b_tx, b_rx) = mpsc::channel(buffer_size);
        (
            DataPipe { tx: a_tx, rx: b_rx },
            DataPipe { tx: b_tx, rx: a_rx },
        )
    }
    /// Sends one buffer to the peer end.
    ///
    /// # Errors
    /// Returns `TransportError::SendError` if the peer's receiver is closed.
    pub async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
        self.tx.send(data).await.map_err(|e| {
            TransportError::SendError(format!("Failed to send data through pipe: {e}"))
        })
    }
    /// Receives the next buffer, or `None` once the peer's sender is dropped.
    pub async fn recv(&mut self) -> Option<Vec<u8>> {
        self.rx.recv().await
    }
}
/// Encodes `data` as a single HTTP/1.1 transfer chunk: an uppercase hex
/// length line, the payload, and a trailing CRLF. Empty input produces the
/// terminal `0\r\n\r\n` chunk.
#[must_use]
pub fn encode_chunk(data: &[u8]) -> Vec<u8> {
    if data.is_empty() {
        // Terminal chunk: zero length followed by the empty-trailer CRLF.
        return b"0\r\n\r\n".to_vec();
    }
    let mut encoded = format!("{:X}\r\n", data.len()).into_bytes();
    encoded.reserve(data.len() + 2);
    encoded.extend_from_slice(data);
    encoded.extend_from_slice(b"\r\n");
    encoded
}
/// Parses an HTTP chunk-size line: hex digits, optionally followed by a
/// `;extension`, with surrounding whitespace ignored.
///
/// # Errors
/// Returns `TransportError::ProtocolError` when the size is not valid hex.
pub fn parse_chunk_size(size_line: &str) -> Result<usize, TransportError> {
    // Drop any chunk extension after ';', then trim whitespace.
    let size_str = match size_line.split_once(';') {
        Some((before, _extension)) => before,
        None => size_line,
    }
    .trim();
    usize::from_str_radix(size_str, 16)
        .map_err(|e| TransportError::ProtocolError(format!("Invalid chunk size '{size_str}': {e}")))
}
/// Reads one CRLF-terminated line from `stream`, returning it without the
/// terminator. A lone `\r` or `\n` is kept as line content; only the exact
/// `\r\n` sequence ends the line. Bytes are read one at a time, so pass a
/// buffered stream when throughput matters.
///
/// # Errors
/// Propagates I/O errors, and returns `InvalidData` if the collected line
/// is not valid UTF-8.
pub async fn read_line<S: AsyncRead + Unpin>(stream: &mut S) -> io::Result<String> {
    let mut line = Vec::new();
    let mut buf = [0u8; 1];
    // Tracks whether the previous byte was '\r'. The previous two-byte
    // lookahead mishandled runs like "\r\r\n": it pushed both CRs and then
    // consumed the real "\r\n" terminator as line content.
    let mut pending_cr = false;
    loop {
        stream.read_exact(&mut buf).await?;
        match (pending_cr, buf[0]) {
            (true, b'\n') => break,
            // The earlier CR was content; this CR may still start the terminator.
            (true, b'\r') => line.push(b'\r'),
            (true, other) => {
                line.push(b'\r');
                line.push(other);
                pending_cr = false;
            }
            (false, b'\r') => pending_cr = true,
            (false, other) => line.push(other),
        }
    }
    String::from_utf8(line).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}
/// Sends a bare `400 Bad Request` for malformed or unsatisfiable Range
/// requests. Extracted because the original repeated this block verbatim
/// three times.
async fn write_bad_request<S>(stream: &mut S) -> Result<(), TransportError>
where
    S: AsyncWrite + Unpin,
{
    stream
        .write_all(b"HTTP/1.1 400 Bad Request\r\n\r\n")
        .await
        .map_err(|e| TransportError::IoError(format!("Failed to write 400: {e}")))?;
    stream
        .flush()
        .await
        .map_err(|e| TransportError::IoError(format!("Failed to flush 400: {e}")))
}
/// Parses a `Range: bytes=start-end` header value into `(start, end)`.
/// Returns `None` for any other shape — open-ended (`bytes=5-`) and suffix
/// (`bytes=-5`) ranges are not supported, matching the original behavior.
fn parse_byte_range(range_header: &str) -> Option<(usize, usize)> {
    let (start_str, end_str) = range_header.strip_prefix("bytes=")?.split_once('-')?;
    let start = start_str.trim().parse::<usize>().ok()?;
    let end = end_str.trim().parse::<usize>().ok()?;
    Some((start, end))
}
/// Serves `file_bytes` over an already-established stream, answering HEAD
/// with the total length and GET (optionally with a byte Range) with the
/// requested slice, until the peer stops sending requests.
///
/// # Errors
/// Returns `ProtocolError` if the peer sends a PUT; I/O or parse failures
/// while reading the next request are treated as the peer disconnecting
/// and end the loop with `Ok(())`.
async fn serve_parquet_range_requests<S>(
    stream: &mut S,
    file_bytes: &[u8],
) -> Result<(), TransportError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    loop {
        // The peer signals completion by closing the connection, which
        // surfaces here as an I/O or protocol error.
        let request = match parse_http_request(stream).await {
            Ok(req) => req,
            Err(TransportError::IoError(_) | TransportError::ProtocolError(_)) => return Ok(()),
            Err(e) => return Err(e),
        };
        match request.method {
            HttpMethod::Head => {
                // HEAD: advertise the total size without a body.
                let response = format!(
                    "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n",
                    file_bytes.len()
                );
                stream.write_all(response.as_bytes()).await.map_err(|e| {
                    TransportError::IoError(format!("Failed to write HEAD response: {e}"))
                })?;
                stream.flush().await.map_err(|e| {
                    TransportError::IoError(format!("Failed to flush HEAD response: {e}"))
                })?;
            }
            HttpMethod::Get => {
                if let Some(range_header) = request.headers.get("range") {
                    match parse_byte_range(range_header) {
                        Some((start, end)) => {
                            if file_bytes.is_empty() || start >= file_bytes.len() {
                                write_bad_request(stream).await?;
                                continue;
                            }
                            // Clamp the inclusive end to the last valid index.
                            let clamped_end = end.min(file_bytes.len().saturating_sub(1));
                            if clamped_end < start {
                                write_bad_request(stream).await?;
                                continue;
                            }
                            let slice = &file_bytes[start..=clamped_end];
                            let response = format!(
                                "HTTP/1.1 206 Partial Content\r\nContent-Length: {}\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
                                slice.len(),
                                start,
                                clamped_end,
                                file_bytes.len()
                            );
                            stream.write_all(response.as_bytes()).await.map_err(|e| {
                                TransportError::IoError(format!("Failed to write 206 headers: {e}"))
                            })?;
                            stream.write_all(slice).await.map_err(|e| {
                                TransportError::IoError(format!("Failed to write 206 body: {e}"))
                            })?;
                            stream.flush().await.map_err(|e| {
                                TransportError::IoError(format!("Failed to flush 206 response: {e}"))
                            })?;
                        }
                        None => write_bad_request(stream).await?,
                    }
                } else {
                    // No Range header: send the whole file.
                    let response = format!(
                        "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n",
                        file_bytes.len()
                    );
                    stream.write_all(response.as_bytes()).await.map_err(|e| {
                        TransportError::IoError(format!("Failed to write GET headers: {e}"))
                    })?;
                    stream.write_all(file_bytes).await.map_err(|e| {
                        TransportError::IoError(format!("Failed to write GET body: {e}"))
                    })?;
                    stream.flush().await.map_err(|e| {
                        TransportError::IoError(format!("Failed to flush GET response: {e}"))
                    })?;
                }
            }
            HttpMethod::Put => {
                // PUT is never expected here; answer 405 best-effort, then bail.
                let _ = stream
                    .write_all(b"HTTP/1.1 405 Method Not Allowed\r\n\r\n")
                    .await;
                let _ = stream.flush().await;
                return Err(TransportError::ProtocolError(
                    "Unexpected PUT in Parquet import handler".to_string(),
                ));
            }
        }
    }
}
/// The subset of HTTP methods this transport handles.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HttpMethod {
    Get,
    Put,
    Head,
}
impl std::fmt::Display for HttpMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
HttpMethod::Get => write!(f, "GET"),
HttpMethod::Put => write!(f, "PUT"),
HttpMethod::Head => write!(f, "HEAD"),
}
}
}
/// A parsed HTTP request head (request line + headers); the body, if any,
/// is read separately.
#[derive(Debug, Clone)]
pub struct HttpRequest {
    pub method: HttpMethod,
    pub path: String,
    // Header names are stored lower-cased by `parse_http_request`.
    pub headers: std::collections::HashMap<String, String>,
}
impl HttpRequest {
    /// Value of the `Content-Length` header, if present and numeric.
    #[must_use]
    pub fn content_length(&self) -> Option<usize> {
        let raw = self.headers.get("content-length")?;
        raw.parse().ok()
    }
    /// Whether the body uses `Transfer-Encoding: chunked`.
    #[must_use]
    pub fn is_chunked(&self) -> bool {
        match self.headers.get("transfer-encoding") {
            Some(value) => value.to_lowercase().contains("chunked"),
            None => false,
        }
    }
    /// Value of the `Host` header, if present.
    #[must_use]
    pub fn host(&self) -> Option<&str> {
        self.headers.get("host").map(String::as_str)
    }
}
/// Reads and parses an HTTP/1.1 request head (request line + headers) from
/// `stream`. Header names are lower-cased and values trimmed; the body, if
/// any, is left unread on the stream.
///
/// # Errors
/// Returns `TransportError::ProtocolError` on read failure, a malformed
/// request line, or an unsupported method.
pub async fn parse_http_request<S: AsyncRead + Unpin>(
    stream: &mut S,
) -> Result<HttpRequest, TransportError> {
    let request_line = read_line(stream)
        .await
        .map_err(|e| TransportError::ProtocolError(format!("Failed to read request line: {e}")))?;
    // Request line must have at least METHOD, PATH, and VERSION tokens.
    let mut words = request_line.split_whitespace();
    let (method_token, path_token, _version) = match (words.next(), words.next(), words.next()) {
        (Some(m), Some(p), Some(v)) => (m, p, v),
        _ => {
            return Err(TransportError::ProtocolError(format!(
                "Invalid HTTP request line: '{request_line}'"
            )))
        }
    };
    let method = match method_token {
        "GET" => HttpMethod::Get,
        "PUT" => HttpMethod::Put,
        "HEAD" => HttpMethod::Head,
        other => {
            return Err(TransportError::ProtocolError(format!(
                "Unsupported HTTP method: '{other}'"
            )))
        }
    };
    let mut headers = std::collections::HashMap::new();
    loop {
        let line = read_line(stream).await.map_err(|e| {
            TransportError::ProtocolError(format!("Failed to read header line: {e}"))
        })?;
        // A blank line terminates the header section.
        if line.is_empty() {
            break;
        }
        if let Some((name, value)) = line.split_once(':') {
            headers.insert(name.trim().to_lowercase(), value.trim().to_string());
        }
    }
    Ok(HttpRequest {
        method,
        path: path_token.to_string(),
        headers,
    })
}
/// Serializes an HTTP/1.1 response: status line, the given headers, an
/// automatic `Content-Length` when a body is supplied and the caller did
/// not set one, a blank line, and the body bytes.
#[must_use]
pub fn build_http_response(
    status_code: u16,
    status_text: &str,
    headers: &[(&str, &str)],
    body: Option<&[u8]>,
) -> Vec<u8> {
    use std::fmt::Write as _;
    let mut head = String::new();
    let _ = write!(head, "HTTP/1.1 {status_code} {status_text}\r\n");
    for &(name, value) in headers {
        let _ = write!(head, "{name}: {value}\r\n");
    }
    if let Some(payload) = body {
        // Add Content-Length only if the caller did not supply it.
        let caller_set_length = headers
            .iter()
            .any(|(n, _)| n.eq_ignore_ascii_case("content-length"));
        if !caller_set_length {
            let _ = write!(head, "Content-Length: {}\r\n", payload.len());
        }
    }
    head.push_str("\r\n");
    let mut bytes = head.into_bytes();
    if let Some(payload) = body {
        bytes.extend_from_slice(payload);
    }
    bytes
}
/// Builds the response head that starts a chunked octet-stream response
/// (`200 OK` with `Transfer-Encoding: chunked` and no body).
#[must_use]
pub fn build_chunked_response_headers() -> Vec<u8> {
    let headers: [(&str, &str); 2] = [
        ("Content-Type", "application/octet-stream"),
        ("Transfer-Encoding", "chunked"),
    ];
    build_http_response(200, "OK", &headers, None)
}
/// Shorthand for a bare `HTTP/1.1 200 OK` response with no headers or body.
#[must_use]
pub fn build_ok_response() -> Vec<u8> {
    const NO_HEADERS: &[(&str, &str)] = &[];
    build_http_response(200, "OK", NO_HEADERS, None)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_magic_packet() {
let packet = generate_magic_packet();
assert_eq!(packet.len(), EXA_MAGIC_PACKET_SIZE);
let magic = u32::from_le_bytes([packet[0], packet[1], packet[2], packet[3]]);
assert_eq!(magic, EXA_MAGIC_NUMBER);
let major = u32::from_le_bytes([packet[4], packet[5], packet[6], packet[7]]);
let minor = u32::from_le_bytes([packet[8], packet[9], packet[10], packet[11]]);
assert_eq!(major, EXA_PROTOCOL_VERSION_MAJOR);
assert_eq!(minor, EXA_PROTOCOL_VERSION_MINOR);
}
#[test]
fn test_parse_magic_packet_valid() {
let packet = generate_magic_packet();
let result = parse_magic_packet(&packet);
assert!(result.is_ok());
let (magic, major, minor) = result.unwrap();
assert_eq!(magic, EXA_MAGIC_NUMBER);
assert_eq!(major, EXA_PROTOCOL_VERSION_MAJOR);
assert_eq!(minor, EXA_PROTOCOL_VERSION_MINOR);
}
#[test]
fn test_parse_magic_packet_too_short() {
let packet = [0u8; 8]; let result = parse_magic_packet(&packet);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, TransportError::ProtocolError(_)));
}
#[test]
fn test_parse_magic_packet_invalid_magic() {
let mut packet = generate_magic_packet();
packet[0] = 0xFF;
let result = parse_magic_packet(&packet);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, TransportError::ProtocolError(_)));
}
#[test]
fn test_parse_response_packet_valid() {
let mut packet = [0u8; EXA_RESPONSE_PACKET_SIZE];
packet[0..4].copy_from_slice(&0i32.to_le_bytes());
packet[4..8].copy_from_slice(&8563i32.to_le_bytes());
let ip_addr = b"192.168.1.100\0\0\0"; packet[8..24].copy_from_slice(ip_addr);
let result = parse_response_packet(&packet);
assert!(result.is_ok());
let (addr, p) = result.unwrap();
assert_eq!(addr, "192.168.1.100");
assert_eq!(p, 8563);
}
#[test]
fn test_parse_response_packet_ip_format() {
let mut packet = [0u8; EXA_RESPONSE_PACKET_SIZE];
packet[0..4].copy_from_slice(&0i32.to_le_bytes());
packet[4..8].copy_from_slice(&8563i32.to_le_bytes());
let ip_addr = b"10.0.0.5\0\0\0\0\0\0\0\0"; packet[8..24].copy_from_slice(ip_addr);
let result = parse_response_packet(&packet);
assert!(result.is_ok());
let (addr, p) = result.unwrap();
assert_eq!(addr, "10.0.0.5");
assert_eq!(p, 8563);
}
#[test]
fn test_parse_response_packet_pyexasol_format() {
let mut packet = [0u8; EXA_RESPONSE_PACKET_SIZE];
packet[0..4].copy_from_slice(&0i32.to_le_bytes());
packet[4..8].copy_from_slice(&8563i32.to_le_bytes());
let ip_addr = b"10.0.0.5\0\0\0\0\0\0\0\0"; packet[8..24].copy_from_slice(ip_addr);
let result = parse_response_packet(&packet);
assert!(result.is_ok());
let (addr, port) = result.unwrap();
assert_eq!(addr, "10.0.0.5");
assert_eq!(port, 8563);
}
#[test]
fn test_parse_response_packet_with_longer_ip() {
let mut packet = [0u8; EXA_RESPONSE_PACKET_SIZE];
packet[0..4].copy_from_slice(&0i32.to_le_bytes());
packet[4..8].copy_from_slice(&12345i32.to_le_bytes());
let ip_addr = b"192.168.100.123\0"; packet[8..24].copy_from_slice(ip_addr);
let result = parse_response_packet(&packet);
assert!(result.is_ok());
let (addr, port) = result.unwrap();
assert_eq!(addr, "192.168.100.123");
assert_eq!(port, 12345);
}
#[test]
fn test_parse_response_packet_too_short() {
let packet = [0u8; 16]; let result = parse_response_packet(&packet);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, TransportError::ProtocolError(_)));
}
#[test]
fn test_encode_chunk_with_data() {
let data = b"Hello, World!";
let chunk = encode_chunk(data);
let expected = b"D\r\nHello, World!\r\n";
assert_eq!(chunk, expected);
}
#[test]
fn test_encode_chunk_empty() {
let data = b"";
let chunk = encode_chunk(data);
let expected = b"0\r\n\r\n";
assert_eq!(chunk, expected);
}
#[test]
fn test_encode_chunk_large() {
let data = vec![0xAB; 256];
let chunk = encode_chunk(&data);
assert!(chunk.starts_with(b"100\r\n"));
assert!(chunk.ends_with(b"\r\n"));
assert_eq!(chunk.len(), 5 + 256 + 2); }
#[test]
fn test_parse_chunk_size_valid() {
assert_eq!(parse_chunk_size("D").unwrap(), 13);
assert_eq!(parse_chunk_size("100").unwrap(), 256);
assert_eq!(parse_chunk_size("0").unwrap(), 0);
assert_eq!(parse_chunk_size("FF").unwrap(), 255);
assert_eq!(parse_chunk_size("ff").unwrap(), 255); }
#[test]
fn test_parse_chunk_size_with_extension() {
assert_eq!(parse_chunk_size("D;extension=value").unwrap(), 13);
assert_eq!(parse_chunk_size("100;ext").unwrap(), 256);
}
#[test]
fn test_parse_chunk_size_with_whitespace() {
assert_eq!(parse_chunk_size(" D ").unwrap(), 13);
assert_eq!(parse_chunk_size("\t100\t").unwrap(), 256);
}
#[test]
fn test_parse_chunk_size_invalid() {
let result = parse_chunk_size("not_hex");
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
TransportError::ProtocolError(_)
));
}
#[test]
fn test_tls_certificate_generation() {
let result = TlsCertificate::generate();
assert!(result.is_ok());
let cert = result.unwrap();
assert!(!cert.certificate_der.is_empty());
assert!(!cert.private_key_der.is_empty());
assert!(!cert.public_key_der.is_empty());
assert!(!cert.fingerprint.is_empty());
assert!(
cert.fingerprint.starts_with("sha256//"),
"Fingerprint should start with 'sha256//': {}",
cert.fingerprint
);
assert_eq!(cert.fingerprint.len(), 52);
}
#[test]
fn test_compute_sha256_fingerprint() {
let data = b"test certificate data";
let fingerprint = compute_sha256_fingerprint(data);
assert!(!fingerprint.is_empty());
assert_eq!(fingerprint.len(), 44);
assert!(BASE64_STANDARD.decode(&fingerprint).is_ok());
}
#[test]
fn test_fingerprint_format() {
let cert = TlsCertificate::generate().unwrap();
let fingerprint = cert.fingerprint.clone();
assert!(
fingerprint.starts_with("sha256//"),
"Fingerprint should start with 'sha256//': {}",
fingerprint
);
let base64_part = fingerprint.strip_prefix("sha256//").unwrap();
assert_eq!(base64_part.len(), 44); assert!(BASE64_STANDARD.decode(base64_part).is_ok());
}
#[test]
fn test_compute_sha256_fingerprint_consistency() {
let data = b"consistent test data";
let fp1 = compute_sha256_fingerprint(data);
let fp2 = compute_sha256_fingerprint(data);
assert_eq!(fp1, fp2);
}
#[test]
fn test_compute_public_key_fingerprint() {
let data = b"test public key data";
let fingerprint = compute_public_key_fingerprint(data);
assert!(fingerprint.starts_with("sha256//"));
let base64_part = fingerprint.strip_prefix("sha256//").unwrap();
assert_eq!(base64_part.len(), 44); assert!(BASE64_STANDARD.decode(base64_part).is_ok());
}
#[test]
fn test_tls_certificate_public_key_stored() {
let cert = TlsCertificate::generate().unwrap();
assert!(!cert.public_key_der.is_empty());
let expected_fingerprint = compute_public_key_fingerprint(&cert.public_key_der);
assert_eq!(cert.fingerprint, expected_fingerprint);
}
#[test]
fn test_tls_certificate_to_server_config() {
let cert = TlsCertificate::generate().unwrap();
let config = cert.to_server_config();
assert!(config.is_ok());
}
#[tokio::test]
async fn test_data_pipe_send_recv() {
let (writer, mut reader) = DataPipe::create_pair(10);
let data = vec![1, 2, 3, 4, 5];
writer.send(data.clone()).await.unwrap();
let received = reader.recv().await;
assert!(received.is_some());
assert_eq!(received.unwrap(), data);
}
#[tokio::test]
async fn test_data_pipe_multiple_messages() {
let (writer, mut reader) = DataPipe::create_pair(10);
for i in 0..5 {
writer.send(vec![i]).await.unwrap();
}
for i in 0..5 {
let received = reader.recv().await.unwrap();
assert_eq!(received, vec![i]);
}
}
#[tokio::test]
async fn test_read_line() {
use tokio::io::AsyncWriteExt;
let data = b"Hello\r\nWorld\r\n";
let (mut client, mut server) = tokio::io::duplex(64);
tokio::spawn(async move {
server.write_all(data).await.unwrap();
});
let line1 = read_line(&mut client).await.unwrap();
assert_eq!(line1, "Hello");
let line2 = read_line(&mut client).await.unwrap();
assert_eq!(line2, "World");
}
#[tokio::test]
async fn test_parse_http_request_get() {
use tokio::io::AsyncWriteExt;
let request = b"GET /001.csv HTTP/1.1\r\nHost: 10.0.0.5:8563\r\n\r\n";
let (mut client, mut server) = tokio::io::duplex(256);
tokio::spawn(async move {
server.write_all(request).await.unwrap();
});
let parsed = parse_http_request(&mut client).await.unwrap();
assert_eq!(parsed.method, HttpMethod::Get);
assert_eq!(parsed.path, "/001.csv");
assert_eq!(parsed.host(), Some("10.0.0.5:8563"));
assert!(!parsed.is_chunked());
assert!(parsed.content_length().is_none());
}
#[tokio::test]
async fn test_parse_http_request_put_chunked() {
use tokio::io::AsyncWriteExt;
let request = b"PUT /001.csv HTTP/1.1\r\nContent-Type: application/octet-stream\r\nTransfer-Encoding: chunked\r\n\r\n";
let (mut client, mut server) = tokio::io::duplex(256);
tokio::spawn(async move {
server.write_all(request).await.unwrap();
});
let parsed = parse_http_request(&mut client).await.unwrap();
assert_eq!(parsed.method, HttpMethod::Put);
assert_eq!(parsed.path, "/001.csv");
assert!(parsed.is_chunked());
}
#[tokio::test]
async fn test_parse_http_request_put_content_length() {
use tokio::io::AsyncWriteExt;
let request = b"PUT /data.csv HTTP/1.1\r\nContent-Length: 1024\r\n\r\n";
let (mut client, mut server) = tokio::io::duplex(256);
tokio::spawn(async move {
server.write_all(request).await.unwrap();
});
let parsed = parse_http_request(&mut client).await.unwrap();
assert_eq!(parsed.method, HttpMethod::Put);
assert_eq!(parsed.content_length(), Some(1024));
assert!(!parsed.is_chunked());
}
#[tokio::test]
async fn test_parse_http_request_invalid_method() {
use tokio::io::AsyncWriteExt;
let request = b"POST /data HTTP/1.1\r\n\r\n";
let (mut client, mut server) = tokio::io::duplex(256);
tokio::spawn(async move {
server.write_all(request).await.unwrap();
});
let result = parse_http_request(&mut client).await;
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
TransportError::ProtocolError(_)
));
}
#[test]
fn test_build_http_response_simple() {
let response = build_http_response(200, "OK", &[], None);
assert_eq!(response, b"HTTP/1.1 200 OK\r\n\r\n");
}
#[test]
fn test_build_http_response_with_headers() {
let response = build_http_response(200, "OK", &[("Content-Type", "text/plain")], None);
assert_eq!(
response,
b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n"
);
}
#[test]
fn test_build_http_response_with_body() {
    // Passing a body must emit a matching Content-Length header automatically.
    let bytes = build_http_response(200, "OK", &[], Some(b"Hello"));
    assert_eq!(bytes, b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nHello");
}
#[test]
fn test_build_chunked_response_headers() {
    // The chunked-upload preamble: octet-stream content type plus
    // Transfer-Encoding: chunked, terminated by the blank line.
    let want = b"HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nTransfer-Encoding: chunked\r\n\r\n";
    assert_eq!(build_chunked_response_headers(), want);
}
#[test]
fn test_build_ok_response() {
    // Minimal success response: status line plus terminating blank line.
    assert_eq!(build_ok_response(), b"HTTP/1.1 200 OK\r\n\r\n");
}
#[test]
fn test_http_method_display() {
    // Display renders the canonical upper-case method token.
    for (method, token) in [(HttpMethod::Get, "GET"), (HttpMethod::Put, "PUT")] {
        assert_eq!(method.to_string(), token);
    }
}
#[test]
fn test_http_request_content_length() {
    // Header keys are stored lower-cased; content_length() parses the value.
    let headers: std::collections::HashMap<_, _> =
        [("content-length".to_string(), "1024".to_string())].into();
    let request = HttpRequest {
        method: HttpMethod::Put,
        path: "/data.csv".to_string(),
        headers,
    };
    assert_eq!(request.content_length(), Some(1024));
}
#[test]
fn test_http_request_is_chunked() {
    // A lower-cased "transfer-encoding: chunked" header flips is_chunked().
    let headers: std::collections::HashMap<_, _> =
        [("transfer-encoding".to_string(), "chunked".to_string())].into();
    let request = HttpRequest {
        method: HttpMethod::Put,
        path: "/data.csv".to_string(),
        headers,
    };
    assert!(request.is_chunked());
}
#[test]
fn test_http_request_host() {
    // host() hands back the Host header value verbatim, including the port.
    let headers: std::collections::HashMap<_, _> =
        [("host".to_string(), "10.0.0.5:8563".to_string())].into();
    let request = HttpRequest {
        method: HttpMethod::Get,
        path: "/001.csv".to_string(),
        headers,
    };
    assert_eq!(request.host(), Some("10.0.0.5:8563"));
}
#[test]
fn test_encode_chunk_roundtrip() {
    // HTTP/1.1 chunk framing is "<hex-size>\r\n<data>\r\n". The original test
    // only checked the size prefix and the payload; it never verified the CRLF
    // after the size line nor the trailing CRLF after the data, so a malformed
    // encoder could pass. Verify every part of the frame.
    let data = b"Hello, World!";
    let encoded = encode_chunk(data);
    // Locate the end of the hex size line.
    let hex_end = encoded
        .iter()
        .position(|&b| b == b'\r')
        .expect("size line must end with CR");
    let size_str = std::str::from_utf8(&encoded[..hex_end]).unwrap();
    let size = usize::from_str_radix(size_str, 16).unwrap();
    assert_eq!(size, data.len(), "hex size prefix must match payload length");
    // The size line must be terminated by a full CRLF, not just the CR.
    assert_eq!(&encoded[hex_end..hex_end + 2], b"\r\n");
    let data_start = hex_end + 2;
    let data_end = data_start + size;
    assert_eq!(&encoded[data_start..data_end], data);
    // The chunk must end with exactly one trailing CRLF and nothing more.
    assert_eq!(&encoded[data_end..], b"\r\n", "chunk must end with CRLF trailer");
}
#[test]
fn test_http_transport_client_internal_address_format() {
    // The transport joins host and port with a colon, socket-address style.
    let host = "10.0.0.5";
    let port: u16 = 8563;
    assert_eq!(format!("{host}:{port}"), "10.0.0.5:8563");
}
#[test]
fn test_tls_certificate_to_client_config() {
    // A freshly generated certificate must convert into a client config.
    let cert = TlsCertificate::generate().expect("certificate generation failed");
    assert!(cert.to_client_config().is_ok());
}
#[test]
fn test_http_method_display_head() {
    // HEAD renders as its canonical upper-case token, like GET and PUT.
    let rendered = HttpMethod::Head.to_string();
    assert_eq!(rendered, "HEAD");
}
#[tokio::test]
async fn test_parse_http_request_head() {
    use tokio::io::AsyncWriteExt;
    // A HEAD request must parse into method and path like any other request.
    let raw = b"HEAD /001.parquet HTTP/1.1\r\nHost: 10.0.0.5:8563\r\n\r\n";
    let (mut reader, mut writer) = tokio::io::duplex(256);
    tokio::spawn(async move { writer.write_all(raw).await.unwrap() });
    let req = parse_http_request(&mut reader).await.unwrap();
    assert_eq!(req.method, HttpMethod::Head);
    assert_eq!(req.path, "/001.parquet");
}
#[tokio::test]
// End-to-end exercise of serve_parquet_range_requests over an in-memory duplex
// pipe: a HEAD probe, two Range GETs, a full-body GET, then a clean close.
// The handler runs in a spawned task playing the server side of the pipe.
async fn test_handle_parquet_import_requests_serves_head_and_range() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// 16 distinct byte values (0..=15) so any mis-sliced range changes the payload.
let file_bytes: Vec<u8> = (0u8..16u8).collect();
let file_len = file_bytes.len();
let (mut server_side, mut client_side) = tokio::io::duplex(4096);
let file_bytes_for_handler = file_bytes.clone();
let handler = tokio::spawn(async move {
serve_parquet_range_requests(&mut server_side, &file_bytes_for_handler).await
});
// 1) HEAD: expect 200 OK advertising the full Content-Length, with no body.
client_side
.write_all(b"HEAD /001.parquet HTTP/1.1\r\nHost: x\r\n\r\n")
.await
.unwrap();
client_side.flush().await.unwrap();
// 39 = "HTTP/1.1 200 OK\r\n" (17) + "Content-Length: 16\r\n" (20) + "\r\n" (2);
// each piece is asserted individually below.
let head_response = read_exactly(&mut client_side, 39).await;
let head_str = std::str::from_utf8(&head_response).unwrap();
assert!(
head_str.starts_with("HTTP/1.1 200 OK\r\n"),
"HEAD response should start with 200 OK: {:?}",
head_str
);
assert!(
head_str.contains(&format!("Content-Length: {file_len}\r\n")),
"HEAD response should advertise Content-Length: {file_len}"
);
assert!(head_str.ends_with("\r\n\r\n"));
// 2) Ranged GET bytes=0-3: expect 206 with Content-Range and a 4-byte body.
client_side
.write_all(b"GET /001.parquet HTTP/1.1\r\nHost: x\r\nRange: bytes=0-3\r\n\r\n")
.await
.unwrap();
client_side.flush().await.unwrap();
let (status_line_1, body_1) = read_response_with_body(&mut client_side, 4).await;
assert!(
status_line_1.starts_with("HTTP/1.1 206 Partial Content\r\n"),
"expected 206 Partial Content, got {:?}",
status_line_1
);
assert!(
status_line_1.contains("Content-Length: 4\r\n"),
"expected Content-Length: 4: {:?}",
status_line_1
);
assert!(
status_line_1.contains(&format!("Content-Range: bytes 0-3/{file_len}\r\n")),
"expected Content-Range: bytes 0-3/{file_len}: {:?}",
status_line_1
);
assert_eq!(body_1, file_bytes[0..=3], "body slice for 0-3 mismatch");
// 3) Second ranged GET on the same connection: the handler must keep serving
// requests after the first response (connection reuse).
client_side
.write_all(b"GET /001.parquet HTTP/1.1\r\nHost: x\r\nRange: bytes=4-7\r\n\r\n")
.await
.unwrap();
client_side.flush().await.unwrap();
let (status_line_2, body_2) = read_response_with_body(&mut client_side, 4).await;
assert!(
status_line_2.contains("Content-Range: bytes 4-7/16\r\n"),
"expected Content-Range: bytes 4-7/16: {:?}",
status_line_2
);
assert_eq!(body_2, file_bytes[4..=7], "body slice for 4-7 mismatch");
// 4) GET without a Range header: expect the entire file with a plain 200 OK.
client_side
.write_all(b"GET /001.parquet HTTP/1.1\r\nHost: x\r\n\r\n")
.await
.unwrap();
client_side.flush().await.unwrap();
let (status_line_3, body_3) = read_response_with_body(&mut client_side, file_len).await;
assert!(
status_line_3.starts_with("HTTP/1.1 200 OK\r\n"),
"expected full 200 OK, got {:?}",
status_line_3
);
assert!(
status_line_3.contains(&format!("Content-Length: {file_len}\r\n")),
"expected Content-Length: {file_len}: {:?}",
status_line_3
);
assert_eq!(body_3, file_bytes, "full-body slice mismatch");
// Dropping the client end closes the pipe; the handler should treat the EOF
// as a graceful end of conversation and return Ok.
drop(client_side);
let result = handler.await.expect("handler task panicked");
assert!(
result.is_ok(),
"handler returned error on connection close: {:?}",
result
);
// Nested helper: read exactly n bytes from the stream, panicking on EOF.
async fn read_exactly(stream: &mut tokio::io::DuplexStream, n: usize) -> Vec<u8> {
let mut buf = vec![0u8; n];
stream.read_exact(&mut buf).await.unwrap();
buf
}
// Nested helper: read response headers byte-by-byte up to the blank line
// (so no body bytes are over-read), then read exactly body_len body bytes.
async fn read_response_with_body(
stream: &mut tokio::io::DuplexStream,
body_len: usize,
) -> (String, Vec<u8>) {
let mut headers = Vec::new();
let mut byte = [0u8; 1];
loop {
stream.read_exact(&mut byte).await.unwrap();
headers.push(byte[0]);
if headers.ends_with(b"\r\n\r\n") {
break;
}
// Guard against a runaway header section hanging the test forever.
if headers.len() > 4096 {
panic!("headers exceeded sanity limit");
}
}
let header_str = String::from_utf8(headers).unwrap();
let mut body = vec![0u8; body_len];
if body_len > 0 {
stream.read_exact(&mut body).await.unwrap();
}
(header_str, body)
}
}
}