use std::path::PathBuf;
use tokio::task::JoinHandle;
use crate::query::import::Compression;
use crate::transport::HttpTransportClient;
use super::ImportError;
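/// Size of each chunk written to the HTTP transport body (64 KiB).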
const CHUNK_SIZE: usize = 64 * 1024;
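/// One file slot in a parallel IMPORT: the Exasol node's internal address,
/// the virtual file name referenced by the IMPORT statement, and the node's
/// TLS public-key fingerprint when encryption is in use.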
#[derive(Debug, Clone)]
pub struct ImportFileEntry {
pub address: String,
pub file_name: String,
pub public_key: Option<String>,
}
impl ImportFileEntry {
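    /// Creates an entry from a node address, a virtual file name, and an
    /// optional TLS public-key fingerprint.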
pub fn new(address: String, file_name: String, public_key: Option<String>) -> Self {
Self {
address,
file_name,
public_key,
}
}
}
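/// A pool of HTTP transport connections to the same Exasol host, opened
/// concurrently, with one connection (and one file entry) per file to import.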
pub struct ParallelTransportPool {
connections: Vec<HttpTransportClient>,
entries: Vec<ImportFileEntry>,
}
impl std::fmt::Debug for ParallelTransportPool {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ParallelTransportPool")
.field("connection_count", &self.connections.len())
.field("entries", &self.entries)
.finish()
}
}
impl ParallelTransportPool {
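    /// Opens `file_count` transport connections to `host:port` concurrently,
    /// assigning each one a sequential file name (`001.csv`, `002.csv`, ...).
    ///
    /// # Errors
    ///
    /// Returns [`ImportError::InvalidConfig`] if `file_count` is zero, and
    /// [`ImportError::ParallelImportError`] if any connection task fails or
    /// panics.
    ///
    /// # Example
    ///
    /// A minimal sketch, marked `ignore` because it needs a reachable Exasol
    /// node:
    ///
    /// ```ignore
    /// let pool = ParallelTransportPool::connect("localhost", 8563, false, 4).await?;
    /// assert_eq!(pool.len(), 4);
    /// for entry in pool.file_entries() {
    ///     println!("{} -> {}", entry.address, entry.file_name);
    /// }
    /// ```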
pub async fn connect(
host: &str,
port: u16,
use_tls: bool,
file_count: usize,
) -> Result<Self, ImportError> {
if file_count == 0 {
return Err(ImportError::InvalidConfig(
"file_count must be at least 1".to_string(),
));
}
let mut connect_handles: Vec<JoinHandle<Result<HttpTransportClient, ImportError>>> =
Vec::with_capacity(file_count);
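        // Spawn one connection task per file so the TCP/TLS handshakes
        // overlap instead of running sequentially.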
for _ in 0..file_count {
let host = host.to_string();
let handle = tokio::spawn(async move {
HttpTransportClient::connect(&host, port, use_tls)
.await
.map_err(|e| {
ImportError::HttpTransportError(format!("Failed to connect to Exasol: {e}"))
})
});
connect_handles.push(handle);
}
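        // Join in spawn order so connections, file names, and indices stay
        // aligned.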
let mut connections = Vec::with_capacity(file_count);
let mut entries = Vec::with_capacity(file_count);
for (idx, handle) in connect_handles.into_iter().enumerate() {
let client = handle
.await
.map_err(|e| {
ImportError::ParallelImportError(format!(
"Connection task {} panicked: {e}",
idx
))
})?
.map_err(|e| {
ImportError::ParallelImportError(format!("Connection {} failed: {e}", idx))
})?;
let file_name = format!("{:03}.csv", idx + 1);
let entry = ImportFileEntry::new(
client.internal_address().to_string(),
file_name,
client.public_key_fingerprint().map(String::from),
);
connections.push(client);
entries.push(entry);
}
Ok(Self {
connections,
entries,
})
}
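    /// Returns the file entries describing each connection's address, file
    /// name, and optional public key.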
#[must_use]
pub fn file_entries(&self) -> &[ImportFileEntry] {
&self.entries
}
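    /// Returns the number of connections in the pool.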
#[must_use]
pub fn len(&self) -> usize {
self.connections.len()
}
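    /// Returns `true` if the pool holds no connections.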
#[must_use]
pub fn is_empty(&self) -> bool {
self.connections.is_empty()
}
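    /// Consumes the pool, yielding the underlying connections for streaming.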
#[must_use]
pub fn into_connections(self) -> Vec<HttpTransportClient> {
self.connections
}
}
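/// Streams one in-memory file per connection concurrently: each task sends the
/// import request, writes its data in `CHUNK_SIZE` pieces using chunked
/// transfer encoding, and terminates the stream with a final chunk.
///
/// `connections` and `file_data` must be the same length. The compression
/// parameter is currently accepted but not applied to the stream.
///
/// # Example
///
/// A minimal sketch, marked `ignore` because it needs live connections:
///
/// ```ignore
/// let pool = ParallelTransportPool::connect("localhost", 8563, false, 2).await?;
/// let files = vec![b"1,a\n".to_vec(), b"2,b\n".to_vec()];
/// stream_files_parallel(pool.into_connections(), files, Compression::None).await?;
/// ```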
pub async fn stream_files_parallel(
connections: Vec<HttpTransportClient>,
file_data: Vec<Vec<u8>>,
_compression: Compression,
) -> Result<(), ImportError> {
if connections.len() != file_data.len() {
return Err(ImportError::InvalidConfig(format!(
"Connection count ({}) != file data count ({})",
connections.len(),
file_data.len()
)));
}
let mut stream_handles: Vec<JoinHandle<Result<(), ImportError>>> =
Vec::with_capacity(connections.len());
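    // One streaming task per connection; failures surface when joined below.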
for (idx, (mut client, data)) in connections.into_iter().zip(file_data).enumerate() {
let handle = tokio::spawn(async move {
client.handle_import_request().await.map_err(|e| {
ImportError::ParallelImportError(format!(
"File {} failed to handle import request: {e}",
idx
))
})?;
for chunk in data.chunks(CHUNK_SIZE) {
client.write_chunked_body(chunk).await.map_err(|e| {
ImportError::ParallelImportError(format!("File {} streaming failed: {e}", idx))
})?;
}
client.write_final_chunk().await.map_err(|e| {
ImportError::ParallelImportError(format!(
"File {} failed to send final chunk: {e}",
idx
))
})?;
Ok(())
});
stream_handles.push(handle);
}
for (idx, handle) in stream_handles.into_iter().enumerate() {
handle
.await
.map_err(|e| {
ImportError::ParallelImportError(format!("Stream task {} panicked: {e}", idx))
})?
.map_err(|e| ImportError::ParallelImportError(format!("Stream {} failed: {e}", idx)))?;
}
Ok(())
}
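/// Converts Parquet files to CSV byte buffers in parallel, one blocking task
/// per file, returning the buffers in input order.
///
/// Each file is decoded in batches of `batch_size` rows and rendered to CSV
/// using `null_value` for NULLs, `column_separator` between fields, and
/// `column_delimiter` for quoting.
///
/// # Example
///
/// A minimal sketch with a hypothetical input path, marked `ignore` because
/// it needs a real Parquet file on disk:
///
/// ```ignore
/// let buffers = convert_parquet_files_to_csv(
///     vec![PathBuf::from("part-0.parquet")], // hypothetical path
///     8192,          // rows per record batch
///     String::new(), // render NULL as an empty field
///     ',',
///     '"',
/// ).await?;
/// assert_eq!(buffers.len(), 1);
/// ```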
pub async fn convert_parquet_files_to_csv(
paths: Vec<PathBuf>,
batch_size: usize,
null_value: String,
column_separator: char,
column_delimiter: char,
) -> Result<Vec<Vec<u8>>, ImportError> {
use crate::import::parquet::{record_batch_to_csv, ParquetImportOptions};
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
let mut conversion_handles: Vec<JoinHandle<Result<Vec<u8>, ImportError>>> =
Vec::with_capacity(paths.len());
    for path in paths {
let null_value = null_value.clone();
let handle = tokio::task::spawn_blocking(move || {
let file = std::fs::File::open(&path).map_err(|e| {
ImportError::ParallelImportError(format!(
"Failed to open Parquet file {}: {e}",
path.display()
))
})?;
let builder = ParquetRecordBatchReaderBuilder::try_new(file).map_err(|e| {
ImportError::ParallelImportError(format!(
"Failed to read Parquet file {}: {e}",
path.display()
))
})?;
let reader = builder.with_batch_size(batch_size).build().map_err(|e| {
ImportError::ParallelImportError(format!(
"Failed to build Parquet reader for {}: {e}",
path.display()
))
})?;
let options = ParquetImportOptions::default()
.with_null_value(&null_value)
.with_column_separator(column_separator)
.with_column_delimiter(column_delimiter);
let mut csv_data = Vec::new();
for batch_result in reader {
let batch = batch_result.map_err(|e| {
ImportError::ParallelImportError(format!(
"Failed to read batch from {}: {e}",
path.display()
))
})?;
let csv_rows = record_batch_to_csv(&batch, &options).map_err(|e| {
ImportError::ParallelImportError(format!(
"Failed to convert batch to CSV from {}: {e}",
path.display()
))
})?;
for row in csv_rows {
csv_data.extend_from_slice(row.as_bytes());
csv_data.push(b'\n');
}
}
Ok(csv_data)
});
        // `spawn_blocking` already yields a `JoinHandle<Result<Vec<u8>, ImportError>>`,
        // so the handle is collected directly; join errors are mapped when awaited below.
conversion_handles.push(handle);
}
let mut results = Vec::with_capacity(conversion_handles.len());
for (idx, handle) in conversion_handles.into_iter().enumerate() {
let csv_data = handle
.await
.map_err(|e| {
ImportError::ParallelImportError(format!("Conversion task {} panicked: {e}", idx))
})?
.map_err(|e| {
ImportError::ParallelImportError(format!("Conversion {} failed: {e}", idx))
})?;
results.push(csv_data);
}
Ok(results)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_import_file_entry_new() {
let entry = ImportFileEntry::new(
"10.0.0.5:8563".to_string(),
"001.csv".to_string(),
Some("sha256//abc123".to_string()),
);
assert_eq!(entry.address, "10.0.0.5:8563");
assert_eq!(entry.file_name, "001.csv");
assert_eq!(entry.public_key, Some("sha256//abc123".to_string()));
}
#[test]
fn test_import_file_entry_no_tls() {
let entry = ImportFileEntry::new("10.0.0.5:8563".to_string(), "002.csv".to_string(), None);
assert_eq!(entry.address, "10.0.0.5:8563");
assert_eq!(entry.file_name, "002.csv");
assert!(entry.public_key.is_none());
}
#[tokio::test]
async fn test_parallel_transport_pool_zero_count_error() {
let result = ParallelTransportPool::connect("localhost", 8563, false, 0).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, ImportError::InvalidConfig(_)));
}
#[tokio::test]
async fn test_stream_files_parallel_mismatched_counts() {
let connections = vec![];
let file_data = vec![vec![1, 2, 3]];
let result = stream_files_parallel(connections, file_data, Compression::None).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, ImportError::InvalidConfig(_)));
}
#[tokio::test]
async fn test_stream_files_parallel_empty() {
let connections = vec![];
let file_data = vec![];
let result = stream_files_parallel(connections, file_data, Compression::None).await;
assert!(result.is_ok());
}
}