use anyhow::Result;
use arrow::array::RecordBatch;
use client_api::{ColumnId, StorageClient, StorageRequest, TableId};
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use rocket::serde::ser;
use std::collections::HashMap;
use std::fs::{self, File};
use std::io::Write;
use std::path::Path;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::task;
use crate::client_api::{self, DataRequest};
/// Storage client that resolves tables to parquet files, downloads them from a
/// storage server over HTTP, and caches them on local disk for reading.
pub struct StorageClientImpl {
    // Unique identifier of this client instance.
    id: usize,
    // Maps each table to the name of the parquet file that backs it.
    table_file_map: HashMap<TableId, String>,
    // Base URL of the storage server used for the first fetch of a file.
    server_url: String,
    // Directory path where fetched parquet files are written.
    local_cache: String,
    // When true, reads are served directly from the local cache and no
    // HTTP fetch is performed (used by tests).
    use_local_cache: bool,
    // Remembers which server address answered for a given file name so that
    // later fetches of the same file go back to the same server.
    file_server_map: HashMap<String, String>,
}
impl StorageClientImpl {
pub fn new(id: usize, server_url: &str) -> Self {
let cache = StorageClientImpl::local_cache_path();
if !Path::new(&cache).exists() {
fs::create_dir_all(&cache).unwrap();
}
Self {
id,
table_file_map: HashMap::new(),
server_url: server_url.to_string(),
local_cache: cache,
use_local_cache: false,
file_server_map: HashMap::new(),
}
}
pub fn getid(&self) -> usize {
self.id
}
pub fn new_for_test(
id: usize,
map: HashMap<TableId, String>,
server_url: &str,
use_local_cache: bool,
) -> Self {
let cache = StorageClientImpl::local_cache_path();
println!("Save cache to {}", cache);
if !Path::new(&cache).exists() {
fs::create_dir_all(&cache).unwrap();
}
Self {
id,
table_file_map: map,
server_url: server_url.to_string(),
local_cache: cache,
use_local_cache,
file_server_map: HashMap::new(),
}
}
pub fn local_cache_path() -> String {
String::from("./istziio_client_cache/")
}
pub async fn read_entire_table(
&mut self,
table: TableId,
request_id: usize,
) -> Result<Receiver<RecordBatch>> {
let mut file_path = self.get_path(table)?;
if !self.use_local_cache {
let start = std::time::Instant::now();
file_path = self.fetch_file(&file_path, request_id).await.unwrap();
let duration = start.elapsed();
}
let (sender, receiver) = channel::<RecordBatch>(1000);
task::spawn(async move {
if let Err(e) = Self::read_pqt_all(&file_path, sender).await {
println!("Error reading parquet file: {:?}", e);
}
});
Ok(receiver)
}
pub async fn read_entire_table_sync(
&mut self,
table: TableId,
request_id: usize,
) -> Result<Vec<RecordBatch>> {
let mut file_path = self.get_path(table)?;
if !self.use_local_cache {
let start = std::time::Instant::now();
file_path = self.fetch_file(&file_path, request_id).await.unwrap();
let duration = start.elapsed();
}
Self::read_pqt_all_sync(&file_path).await
}
#[allow(unused_variables)]
pub async fn entire_columns(
&self,
table: TableId,
columns: Vec<ColumnId>,
) -> Result<Receiver<RecordBatch>> {
let file_path = self.get_path(table)?;
let (sender, receiver) = channel::<RecordBatch>(1000);
task::spawn(async move {
if let Err(e) = Self::read_pqt_all(&file_path, sender).await {
println!("Error reading parquet file: {:?}", e);
}
});
Ok(receiver)
}
fn get_path(&self, table: TableId) -> Result<String> {
if let Some(file_path) = self.table_file_map.get(&table) {
Ok(file_path.clone())
} else {
panic!(
"Path not found in local table, catalog service is assume not available yet,
please check if local table_file_map is correctly initialized."
);
}
}
#[allow(dead_code)]
#[allow(unused_variables)]
fn consult_catalog(&self, table: TableId) -> Result<()> {
todo!()
}
async fn fetch_file(&mut self, file_path: &str, request_id: usize) -> Result<String> {
let trimmed_path: Vec<&str> = file_path.split('/').collect();
let file_name = trimmed_path.last().ok_or_else(|| {
anyhow::Error::msg("File path is empty")
})?;
let mut server_url = self.server_url.clone();
if self.file_server_map.contains_key(file_name.to_owned()) {
println!(
"File {} is in file_server_map, rid:{}",
file_name, request_id
);
server_url = self
.file_server_map
.get(file_name.to_owned())
.unwrap()
.clone();
}
let url = format!("{}/s3/{}?rid={}", server_url, file_name, request_id);
println!("Sending request: {}, rid:{}", url, request_id);
let start = std::time::Instant::now();
let response = reqwest::get(url).await?;
println!(
"Request id: {:?}, Response remote_addr: {:?}, Response status: {:?}",
request_id,
response.remote_addr().unwrap().to_string(),
response.status()
);
self.file_server_map.insert(
file_name.to_owned().to_string(),
response.remote_addr().unwrap().to_string(),
);
let file_contents = response.bytes().await?;
let duration = start.elapsed();
println!(
"Time used to wait for server response: {:?}, rid:{}",
duration, request_id
);
let mut file_path = self.local_cache.clone();
file_path.push_str(file_name);
let mut dup_id = 0;
while Path::new(&file_path).exists() {
dup_id += 1;
file_path = self.local_cache.clone();
file_path.push_str(file_name);
file_path.push_str(&format!("_{}", dup_id));
}
let mut file = File::create(&file_path)?;
file.write_all(&file_contents)?;
println!("parquet written to {}", file_path);
if dup_id > 0 {
Ok(file_name.to_string() + "_" + dup_id.to_string().as_str())
} else {
Ok(file_name.to_string())
}
}
async fn read_pqt_all(file_path: &str, sender: Sender<RecordBatch>) -> Result<()> {
let mut local_path = StorageClientImpl::local_cache_path();
local_path.push_str(file_path);
println!("read_pqt_all Reading from local_path: {:?}", local_path);
let start = std::time::Instant::now();
let file = File::open(local_path)?;
let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
let mut reader = builder.build()?;
while let Some(Ok(rb)) = reader.next() {
sender.send(rb).await?;
}
let duration = start.elapsed();
println!(
"read_pqt_all: Time used to read from parquet: {:?}",
duration
);
Ok(())
}
async fn read_pqt_all_sync(file_path: &str) -> Result<Vec<RecordBatch>> {
let mut local_path = StorageClientImpl::local_cache_path();
let start = std::time::Instant::now();
local_path.push_str(file_path);
println!(
"read_pqt_all_sync Reading from local_path: {:?}",
local_path
);
let file = File::open(local_path)?;
let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
let mut reader = builder.build()?;
let mut result: Vec<RecordBatch> = Vec::new();
while let Some(Ok(rb)) = reader.next() {
result.push(rb);
}
let duration = start.elapsed();
println!(
"read_pqt_all_sync: Time used to read from parquet: {:?}",
duration
);
Ok(result)
}
}
#[async_trait::async_trait]
impl StorageClient for StorageClientImpl {
    /// Dispatches a storage request to the matching read path and streams the
    /// result batches through a channel. Only whole-table requests are
    /// implemented today.
    async fn request_data(&mut self, request: StorageRequest) -> Result<Receiver<RecordBatch>> {
        let rid = request.request_id();
        match request.data_request() {
            DataRequest::Table(table_id) => self.read_entire_table(*table_id, rid).await,
            DataRequest::Columns(..) => unimplemented!("Column request is not supported yet"),
            DataRequest::Tuple(..) => unimplemented!("Tuple request is not supported yet"),
        }
    }

    /// Same dispatch as [`Self::request_data`], but collects all batches
    /// before returning instead of streaming them.
    async fn request_data_sync(&mut self, request: StorageRequest) -> Result<Vec<RecordBatch>> {
        let rid = request.request_id();
        match request.data_request() {
            DataRequest::Table(table_id) => self.read_entire_table_sync(*table_id, rid).await,
            DataRequest::Columns(..) => unimplemented!("Column request is not supported yet"),
            DataRequest::Tuple(..) => unimplemented!("Tuple request is not supported yet"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use parquet::column::writer::ColumnWriter;
    use parquet::data_type::ByteArray;
    use parquet::file::properties::WriterProperties;
    use parquet::file::writer::SerializedFileWriter;
    use parquet::schema::parser::parse_message_type;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::runtime::Runtime;
    use tokio::time::sleep;

    /// Builds the in-memory batch that mirrors the on-disk sample file:
    /// 10 rows of (id, "testrow_{id}").
    fn create_sample_rb() -> RecordBatch {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]);
        let row_num = 10;
        let ids_vec = (1..=row_num).collect::<Vec<i32>>();
        let names_vec = (1..=row_num)
            .map(|id| format!("testrow_{}", id))
            .collect::<Vec<String>>();
        let ids = Int32Array::from(ids_vec);
        let names = StringArray::from(names_vec);
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(ids), Arc::new(names)]).unwrap()
    }

    /// Writes a sample parquet file with `row_num` rows of (id, "testrow_{id}")
    /// into the client cache directory, creating the directory if needed.
    fn create_sample_parquet_file(file_name: &str, row_num: usize) -> anyhow::Result<()> {
        let mut cache_path = StorageClientImpl::local_cache_path();
        if !Path::new(&cache_path).exists() {
            fs::create_dir_all(&cache_path).unwrap();
        }
        cache_path.push_str(file_name);
        let path = Path::new(&cache_path);
        // Low-level parquet schema matching `create_sample_rb`'s Arrow schema.
        let message_type = "
        message schema {
            REQUIRED INT32 id;
            REQUIRED BYTE_ARRAY name (UTF8);
        }
        ";
        let schema = Arc::new(parse_message_type(message_type).unwrap());
        let file = fs::File::create(path).unwrap();
        let props: WriterProperties = WriterProperties::builder().build();
        let mut writer = SerializedFileWriter::new(file, schema, Arc::new(props))?;
        let ids = (1..=row_num as i32).collect::<Vec<i32>>();
        let names = (1..=row_num)
            .map(|id| format!("testrow_{}", id))
            .collect::<Vec<String>>();
        let names_str: Vec<&str> = names.iter().map(|name| name.as_str()).collect();
        // Single row group; columns are visited in schema order, so the first
        // iteration writes `id` and the second writes `name`.
        let mut row_group_writer = writer.next_row_group().unwrap();
        while let Some(mut col_writer) = row_group_writer.next_column().unwrap() {
            match col_writer.untyped() {
                ColumnWriter::Int32ColumnWriter(ref mut typed_writer) => {
                    typed_writer.write_batch(&ids, None, None)?;
                }
                ColumnWriter::ByteArrayColumnWriter(ref mut typed_writer) => {
                    let byte_array_names: Vec<ByteArray> = names_str
                        .iter()
                        .map(|&name| ByteArray::from(name))
                        .collect();
                    typed_writer.write_batch(&byte_array_names, None, None)?;
                }
                _ => {}
            }
            col_writer.close().unwrap()
        }
        row_group_writer.close().unwrap();
        writer.close().unwrap();
        Ok(())
    }

    /// Shared local-mode setup: writes a sample parquet file with `row_num`
    /// rows into the cache and maps table 0 to it. (Previously duplicated in
    /// `setup_local` and `setup_local_large`, which also computed an unused
    /// `file_path` local.)
    fn setup_local_with_rows(row_num: usize) -> (StorageClientImpl, String) {
        let mut table_file_map: HashMap<TableId, String> = HashMap::new();
        let file_name: &str = "sample.parquet";
        table_file_map.insert(0, file_name.to_string());
        create_sample_parquet_file(file_name, row_num).unwrap();
        (
            StorageClientImpl::new_for_test(1, table_file_map, "http://localhost:26380", true),
            file_name.to_string(),
        )
    }

    fn setup_local() -> (StorageClientImpl, String) {
        setup_local_with_rows(10)
    }

    fn setup_local_large() -> (StorageClientImpl, String) {
        setup_local_with_rows(1_000_000)
    }

    /// Remote-mode setup: maps table 0 to the sample file name but does not
    /// create it locally — the client is expected to fetch it over HTTP.
    fn setup_remote() -> StorageClientImpl {
        let mut table_file_map: HashMap<TableId, String> = HashMap::new();
        let file_name: &str = "sample.parquet";
        table_file_map.insert(0, file_name.to_string());
        StorageClientImpl::new_for_test(1, table_file_map, "http://localhost:26380", false)
    }

    /// Local streaming read yields the expected sample batch.
    #[test]
    fn test_entire_table_local() {
        let (mut client, _file_name) = setup_local();
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let mut receiver = client.read_entire_table(0, 0).await.unwrap();
            // Give the background reader task time to produce the first batch.
            sleep(Duration::from_secs(1)).await;
            let res = receiver.try_recv();
            assert!(res.is_ok());
            let record_batch = res.unwrap();
            let sample_rb = create_sample_rb();
            assert_eq!(record_batch, sample_rb);
        });
    }

    /// Local synchronous read yields the expected sample batch.
    #[test]
    fn test_entire_table_local_sync() {
        let (mut client, _file_name) = setup_local();
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let res = client.read_entire_table_sync(0, 0).await;
            assert!(res.is_ok());
            let record_batch = res.unwrap();
            let rb = record_batch.first();
            assert!(rb.is_some());
            let sample_rb = create_sample_rb();
            assert_eq!(rb.unwrap().clone(), sample_rb);
        });
    }

    /// Requires a storage server on localhost:26380, hence `#[ignore]`.
    #[test]
    #[ignore]
    fn test_entire_table_remote() {
        let mut client = setup_remote();
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let mut receiver = client.read_entire_table(0, 0).await.unwrap();
            sleep(Duration::from_secs(1)).await;
            let res = receiver.try_recv();
            assert!(res.is_ok());
            let record_batch = res.unwrap();
            let sample_rb = create_sample_rb();
            assert_eq!(record_batch, sample_rb);
        });
    }

    /// `request_data` with a table request goes through the streaming path.
    #[test]
    fn test_request_data_table_local_simple() {
        let (mut client, _file_name) = setup_local();
        let rt: Runtime = Runtime::new().unwrap();
        rt.block_on(async {
            let mut receiver = client
                .request_data(StorageRequest::new(0, DataRequest::Table(0)))
                .await
                .unwrap();
            sleep(Duration::from_secs(1)).await;
            let res = receiver.try_recv();
            assert!(res.is_ok());
            let record_batch = res.unwrap();
            let sample_rb = create_sample_rb();
            assert_eq!(record_batch, sample_rb);
        });
    }

    /// Drains the channel for a 1M-row table and checks the total row count,
    /// verifying that all batches arrive, not just the first.
    #[test]
    fn test_request_data_table_local_exhaustive() {
        let (mut client, _file_name) = setup_local_large();
        let rt: Runtime = Runtime::new().unwrap();
        rt.block_on(async {
            let mut receiver = client
                .request_data(StorageRequest::new(0, DataRequest::Table(0)))
                .await
                .unwrap();
            sleep(Duration::from_secs(1)).await;
            let mut total_num_rows = 0;
            while let Some(rb) = receiver.recv().await {
                total_num_rows += rb.num_rows();
            }
            assert_eq!(total_num_rows, 1_000_000);
        });
    }

    /// Requires a storage server on localhost:26380, hence `#[ignore]`.
    #[test]
    #[ignore]
    fn test_request_data_table_remote() {
        let mut client = setup_remote();
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let mut receiver = client
                .request_data(StorageRequest::new(0, DataRequest::Table(0)))
                .await
                .unwrap();
            sleep(Duration::from_secs(1)).await;
            let res = receiver.try_recv();
            assert!(res.is_ok());
            let record_batch = res.unwrap();
            let sample_rb: RecordBatch = create_sample_rb();
            assert_eq!(record_batch, sample_rb);
        });
    }

    /// `request_data_sync` with a table request returns all batches at once.
    #[test]
    fn test_request_data_table_local_sync() {
        let (mut client, _file_name) = setup_local();
        let rt: Runtime = Runtime::new().unwrap();
        rt.block_on(async {
            let res = client
                .request_data_sync(StorageRequest::new(0, DataRequest::Table(0)))
                .await;
            assert!(res.is_ok());
            let record_batch = res.unwrap();
            let rb = record_batch.first();
            assert!(rb.is_some());
            let sample_rb = create_sample_rb();
            assert_eq!(rb.unwrap().clone(), sample_rb);
        });
    }
}