use std::collections::{BTreeMap, VecDeque};
use std::future::Future;
use std::marker::PhantomData;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use async_trait::async_trait;
use log::info;
use tokio::fs::File;
use tokio::io::BufReader;
use tokio_stream::Stream;
use databend_client::PresignedResponse;
use databend_client::QueryResponse;
use databend_client::{APIClient, SchemaField};
use bigbytes_driver_core::error::{Error, Result};
use bigbytes_driver_core::raw_rows::{RawRow, RawRowIterator, RawRowWithStats};
use bigbytes_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
use bigbytes_driver_core::schema::{Schema, SchemaRef};
use crate::conn::{Connection, ConnectionInfo, Reader};
/// Connection implementation that talks to the server through an `APIClient`.
///
/// The client sits behind an `Arc`, so cloning the connection yields another
/// handle to the same underlying client.
#[derive(Clone)]
pub struct RestAPIConnection {
// Client handle; it is also cloned into the row streams and row batches
// created by this connection so they can fetch further pages on their own.
client: Arc<APIClient>,
}
#[async_trait]
impl Connection for RestAPIConnection {
    /// Describes this connection using the values reported by the underlying client.
    ///
    /// The `handler` field is always `"RestAPI"`.
    async fn info(&self) -> ConnectionInfo {
        ConnectionInfo {
            handler: "RestAPI".to_string(),
            host: self.client.host().to_string(),
            port: self.client.port(),
            user: self.client.username(),
            database: self.client.current_database(),
            warehouse: self.client.current_warehouse(),
        }
    }

    /// Returns the last query id tracked by the underlying client, if any.
    fn last_query_id(&self) -> Option<String> {
        self.client.last_query_id()
    }

    /// Closes the underlying client. This method itself never reports an error.
    async fn close(&self) -> Result<()> {
        self.client.close().await;
        Ok(())
    }

    /// Runs `sql` to completion and returns the written-row count of the last page.
    ///
    /// Every `next_uri` is followed until the server stops returning one; the
    /// intermediate pages are discarded and only the statistics of the final
    /// response are used.
    ///
    /// # Errors
    /// Propagates any error from starting the query or fetching a page.
    async fn exec(&self, sql: &str) -> Result<i64> {
        info!("exec: {}", sql);
        let mut resp = self.client.start_query(sql).await?;
        let node_id = resp.node_id.clone();
        // Record the serving node on the client, exactly as `wait_for_schema`
        // does, so statements that only ever go through `exec` are tracked too.
        // NOTE(review): presumably the client uses this for node-sticky
        // follow-up requests — confirm against `APIClient`.
        if let Some(node_id) = &node_id {
            self.client.set_last_node_id(node_id.clone());
        }
        while let Some(next_uri) = resp.next_uri {
            resp = self
                .client
                .query_page(&resp.id, &next_uri, &node_id)
                .await?;
        }
        Ok(resp.stats.progresses.write_progress.rows as i64)
    }

    /// Asks the server to kill the query identified by `query_id`.
    async fn kill_query(&self, query_id: &str) -> Result<()> {
        Ok(self.client.kill_query(query_id).await?)
    }

    /// Runs `sql` and returns a row iterator.
    ///
    /// Built from [`Self::query_iter_ext`] by passing the result through
    /// `filter_rows()` (presumably dropping the statistics items — confirm).
    async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
        info!("query iter: {}", sql);
        let rows_with_progress = self.query_iter_ext(sql).await?;
        Ok(rows_with_progress.filter_rows().await)
    }

    /// Runs `sql` and returns an iterator that yields rows interleaved with
    /// server statistics.
    ///
    /// Waits until the server has produced data, a schema, or progress
    /// information before building the iterator.
    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
        info!("query iter ext: {}", sql);
        let resp = self.client.start_query(sql).await?;
        let resp = self.wait_for_schema(resp, true).await?;
        let (schema, rows) = RestAPIRows::<RowWithStats>::from_response(self.client.clone(), resp)?;
        Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
    }

    /// Same as [`Self::query_iter_ext`], but each row item also keeps the raw
    /// string cells it was decoded from.
    async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
        info!("query raw iter: {}", sql);
        let resp = self.client.start_query(sql).await?;
        let resp = self.wait_for_schema(resp, true).await?;
        let (schema, rows) =
            RestAPIRows::<RawRowWithStats>::from_response(self.client.clone(), resp)?;
        Ok(RawRowIterator::new(Arc::new(schema), Box::pin(rows)))
    }

    /// Requests a presigned URL by running `PRESIGN <operation> <stage>`.
    ///
    /// The single returned row is decoded as `(method, headers, url)`, where
    /// `headers` is a JSON object of string keys and values.
    ///
    /// # Errors
    /// * `Error::InvalidResponse` when the server returns no row.
    /// * `Error::Parsing` when the row cannot be converted to three strings.
    /// * A converted `serde_json` error when the headers are not valid JSON.
    async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
        info!("get presigned url: {} {}", operation, stage);
        // NOTE(review): `operation` and `stage` are interpolated into the SQL
        // text unescaped; callers must not pass untrusted input here.
        let sql = format!("PRESIGN {} {}", operation, stage);
        let row = self.query_row(&sql).await?.ok_or_else(|| {
            Error::InvalidResponse("Empty response from server for presigned request".to_string())
        })?;
        let (method, headers, url): (String, String, String) =
            row.try_into().map_err(Error::Parsing)?;
        let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
        Ok(PresignedResponse {
            method,
            headers,
            url,
        })
    }

    /// Uploads `size` bytes read from `data` to `stage` through the client.
    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
        self.client.upload_to_stage(stage, data, size).await?;
        Ok(())
    }

    /// Uploads `data` to a fresh user-stage path and runs `sql` against it.
    ///
    /// The stage path is `@~/client/load/<nanosecond timestamp>`. Missing
    /// `file_format_options` / `copy_options` fall back to the connection
    /// defaults (CSV format, `purge = true`).
    ///
    /// # Errors
    /// * `Error::IO` when the current time cannot be expressed in nanoseconds.
    /// * Any error from the upload or from the insert statement.
    async fn load_data(
        &self,
        sql: &str,
        data: Reader,
        size: u64,
        file_format_options: Option<BTreeMap<&str, &str>>,
        copy_options: Option<BTreeMap<&str, &str>>,
    ) -> Result<ServerStats> {
        info!(
            "load data: {}, size: {}, format: {:?}, copy: {:?}",
            sql, size, file_format_options, copy_options
        );
        let now = chrono::Utc::now()
            .timestamp_nanos_opt()
            .ok_or_else(|| Error::IO("Failed to get current timestamp".to_string()))?;
        let stage = format!("@~/client/load/{}", now);
        let file_format_options =
            file_format_options.unwrap_or_else(Self::default_file_format_options);
        let copy_options = copy_options.unwrap_or_else(Self::default_copy_options);
        self.upload_to_stage(&stage, data, size).await?;
        let resp = self
            .client
            .insert_with_stage(sql, &stage, file_format_options, copy_options)
            .await?;
        Ok(ServerStats::from(resp.stats))
    }

    /// Loads the file at `fp` through [`Self::load_data`].
    ///
    /// The upload size is taken from the file metadata. When the caller passes
    /// format options without a `"type"` key, the type is taken from the file
    /// extension; when no options are passed at all the defaults (which already
    /// contain `"type"`) are used unchanged.
    ///
    /// # Errors
    /// * I/O errors from opening the file or reading its metadata.
    /// * `Error::BadArgument` when the type must be inferred but the path has
    ///   no extension or the extension is not valid UTF-8.
    async fn load_file(
        &self,
        sql: &str,
        fp: &Path,
        format_options: Option<BTreeMap<&str, &str>>,
        copy_options: Option<BTreeMap<&str, &str>>,
    ) -> Result<ServerStats> {
        info!(
            "load file: {}, file: {:?}, format: {:?}, copy: {:?}",
            sql, fp, format_options, copy_options
        );
        let file = File::open(fp).await?;
        let metadata = file.metadata().await?;
        let size = metadata.len();
        let data = BufReader::new(file);
        let mut format_options = format_options.unwrap_or_else(Self::default_file_format_options);
        if !format_options.contains_key("type") {
            let file_type = fp
                .extension()
                .ok_or_else(|| Error::BadArgument("file type not specified".to_string()))?
                .to_str()
                .ok_or_else(|| Error::BadArgument("file type empty".to_string()))?;
            format_options.insert("type", file_type);
        }
        self.load_data(
            sql,
            Box::new(data),
            size,
            Some(format_options),
            copy_options,
        )
        .await
    }

    /// Encodes `data` as CSV in memory and loads it with the default format
    /// and copy options.
    ///
    /// # Errors
    /// * `Error::BadArgument` when a record cannot be written as CSV.
    /// * `Error::IO` when the CSV writer cannot be finalised.
    /// * Any error from [`Self::load_data`].
    async fn stream_load(&self, sql: &str, data: Vec<Vec<&str>>) -> Result<ServerStats> {
        info!("stream load: {}, length: {:?}", sql, data.len());
        let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
        for row in data {
            wtr.write_record(row)
                .map_err(|e| Error::BadArgument(e.to_string()))?;
        }
        let bytes = wtr.into_inner().map_err(|e| Error::IO(e.to_string()))?;
        let size = bytes.len() as u64;
        let reader = Box::new(std::io::Cursor::new(bytes));
        self.load_data(sql, reader, size, None, None).await
    }
}
impl<'o> RestAPIConnection {
    /// Builds a connection from `dsn`, handing `name` to the client as
    /// `Some(name)`.
    ///
    /// # Errors
    /// Propagates any error from creating the `APIClient`.
    pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
        let client = Arc::new(APIClient::new(dsn, Some(name)).await?);
        Ok(Self { client })
    }

    /// Follows `next_uri` pages until a response is "ready", then returns it.
    ///
    /// A response is ready when it carries data, carries a schema, or — only
    /// if `return_on_progress` is set — reports progress. If the server runs
    /// out of pages first, the last response is returned as-is.
    ///
    /// When the initial response is not ready, its node id (if any) is recorded
    /// on the client before further pages are requested.
    async fn wait_for_schema(
        &self,
        resp: QueryResponse,
        return_on_progress: bool,
    ) -> Result<QueryResponse> {
        // Readiness test shared by the initial response and every later page.
        let is_ready = move |page: &QueryResponse| -> bool {
            !page.data.is_empty()
                || !page.schema.is_empty()
                || (return_on_progress && page.stats.progresses.has_progress())
        };

        if is_ready(&resp) {
            return Ok(resp);
        }

        let node_id = resp.node_id.clone();
        if let Some(id) = &node_id {
            self.client.set_last_node_id(id.clone());
        }

        let mut current = resp;
        while let Some(uri) = current.next_uri {
            current = self.client.query_page(&current.id, &uri, &node_id).await?;
            if is_ready(&current) {
                break;
            }
        }
        Ok(current)
    }

    /// File-format options used when the caller supplies none: CSV with `,`
    /// as field delimiter, `\n` as record delimiter, and no skipped header.
    fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
        BTreeMap::from([
            ("type", "CSV"),
            ("field_delimiter", ","),
            ("record_delimiter", "\n"),
            ("skip_header", "0"),
        ])
    }

    /// Copy options used when the caller supplies none: `purge = true`.
    fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
        BTreeMap::from([("purge", "true")])
    }

    /// Starts `sql` and returns a [`RowBatch`] for fetching its pages.
    ///
    /// Unlike the iterator-based queries, progress alone does not end the
    /// wait: the batch is built once data or a schema is available, or the
    /// server has no more pages.
    pub async fn query_row_batch(&self, sql: &str) -> Result<RowBatch> {
        let first = self.client.start_query(sql).await?;
        let ready = self.wait_for_schema(first, false).await?;
        RowBatch::from_response(self.client.clone(), ready)
    }
}
/// Boxed, pinned, `Send` future that resolves to one page of query results.
type PageFut = Pin<Box<dyn Future<Output = Result<QueryResponse>> + Send>>;
/// Stream state for one query: buffered rows plus whatever is needed to fetch
/// the following pages.
///
/// `T` is the item type produced by the stream; no value of `T` is stored.
pub struct RestAPIRows<T> {
    /// Client used to request further pages.
    client: Arc<APIClient>,
    /// Schema used to decode buffered rows.
    schema: SchemaRef,
    /// Rows received from the server but not yet yielded.
    data: VecDeque<Vec<Option<String>>>,
    /// Statistics of the initial response, yielded once as the first item.
    stats: Option<ServerStats>,
    /// Id of the query the pages belong to.
    query_id: String,
    /// Node id from the initial response, forwarded with every page request.
    node_id: Option<String>,
    /// Location of the next page, `None` once the server reports no more.
    next_uri: Option<String>,
    /// In-flight page request, if one has been started.
    next_page: Option<PageFut>,
    _phantom: PhantomData<T>,
}
impl<T> RestAPIRows<T> {
fn from_response(client: Arc<APIClient>, resp: QueryResponse) -> Result<(Schema, Self)> {
let schema: Schema = resp.schema.try_into()?;
let rows = Self {
client,
query_id: resp.id,
node_id: resp.node_id,
next_uri: resp.next_uri,
schema: Arc::new(schema.clone()),
data: resp.data.into(),
stats: Some(ServerStats::from(resp.stats)),
next_page: None,
_phantom: PhantomData,
};
Ok((schema, rows))
}
}
impl<T: FromRowStats + std::marker::Unpin> Stream for RestAPIRows<T> {
    type Item = Result<T>;

    /// Yields, in order of priority:
    ///
    /// 1. the statistics of the initial response (once);
    /// 2. a buffered row, as long as more than one row is buffered;
    /// 3. the statistics of a freshly fetched page (its rows are buffered);
    /// 4. once no further page exists, the remaining buffered rows, then `None`.
    ///
    /// Errors from fetching a page or decoding a row/schema are yielded as
    /// `Some(Err(_))`; the stream itself is not terminated by them.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if let Some(ss) = self.stats.take() {
            return Poll::Ready(Some(Ok(T::from_stats(ss))));
        }

        // The last buffered row is deliberately held back while a next page
        // may still exist; it is only released by the `next_uri == None` arm
        // below. NOTE(review): presumably this guarantees the final page
        // request is issued before the last row is handed out — confirm.
        if self.data.len() > 1 {
            if let Some(row) = self.data.pop_front() {
                let row = T::try_from_row(row, self.schema.clone())?;
                return Poll::Ready(Some(Ok(row)));
            }
        }

        match self.next_page {
            Some(ref mut next_page) => match Pin::new(next_page).poll(cx) {
                Poll::Ready(Ok(resp)) => {
                    // The future has completed and must never be polled again.
                    // Drop it and advance the cursor *before* the fallible
                    // schema conversion, so an early return via `?` cannot
                    // leave a finished future behind for the next poll.
                    self.next_page = None;
                    self.next_uri = resp.next_uri;
                    if self.schema.fields().is_empty() {
                        self.schema = Arc::new(resp.schema.try_into()?);
                    }
                    let mut new_data = resp.data.into();
                    self.data.append(&mut new_data);
                    Poll::Ready(Some(Ok(T::from_stats(resp.stats.into()))))
                }
                Poll::Ready(Err(e)) => {
                    self.next_page = None;
                    Poll::Ready(Some(Err(e)))
                }
                Poll::Pending => Poll::Pending,
            },
            None => match self.next_uri {
                Some(ref next_uri) => {
                    // Start fetching the next page, then poll it right away so
                    // the waker in `cx` is registered with the new future.
                    let client = self.client.clone();
                    let next_uri = next_uri.clone();
                    let query_id = self.query_id.clone();
                    let node_id = self.node_id.clone();
                    self.next_page = Some(Box::pin(async move {
                        client
                            .query_page(&query_id, &next_uri, &node_id)
                            .await
                            .map_err(|e| e.into())
                    }));
                    self.poll_next(cx)
                }
                // No more pages: drain whatever is still buffered.
                None => match self.data.pop_front() {
                    Some(row) => {
                        let row = T::try_from_row(row, self.schema.clone())?;
                        Poll::Ready(Some(Ok(row)))
                    }
                    None => Poll::Ready(None),
                },
            },
        }
    }
}
/// Conversion hooks used by `RestAPIRows` to build its stream items from
/// either server statistics or a raw row.
trait FromRowStats: Send + Sync + Clone {
/// Wraps server statistics into a stream item.
fn from_stats(stats: ServerStats) -> Self;
/// Builds a stream item from one raw row of optional string cells, decoded
/// with `schema`. Fails when the row cannot be decoded.
fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self>;
}
impl FromRowStats for RowWithStats {
    /// Wraps `stats` in the `Stats` variant.
    fn from_stats(stats: ServerStats) -> Self {
        Self::Stats(stats)
    }

    /// Decodes `row` against `schema` and wraps the result in the `Row` variant.
    fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self> {
        let decoded = Row::try_from((schema, row))?;
        Ok(Self::Row(decoded))
    }
}
impl FromRowStats for RawRowWithStats {
    /// Wraps `stats` in the `Stats` variant.
    fn from_stats(stats: ServerStats) -> Self {
        Self::Stats(stats)
    }

    /// Decodes a copy of `row` against `schema` and pairs the decoded row with
    /// the original raw cells in the `Row` variant.
    fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self> {
        // The raw cells are cloned for decoding because the originals are
        // kept alongside the decoded row.
        let decoded = Row::try_from((schema, row.clone()))?;
        let raw = RawRow::new(decoded, row);
        Ok(Self::Row(raw))
    }
}
/// Page-at-a-time access to a query result as raw string cells.
pub struct RowBatch {
// Schema fields taken from the response the batch was created from.
schema: Vec<SchemaField>,
// Client used to request further pages.
client: Arc<APIClient>,
// Id of the query the pages belong to.
query_id: String,
// Node id from the initial response, forwarded with every page request.
node_id: Option<String>,
// Location of the next page; `None` once the server reports no more.
next_uri: Option<String>,
// Rows of the initial response that have not been handed out yet.
data: Vec<Vec<Option<String>>>,
}
impl RowBatch {
pub fn schema(&self) -> Vec<SchemaField> {
self.schema.clone()
}
fn from_response(client: Arc<APIClient>, mut resp: QueryResponse) -> Result<Self> {
Ok(Self {
schema: std::mem::take(&mut resp.schema),
client,
query_id: resp.id,
node_id: resp.node_id,
next_uri: resp.next_uri,
data: resp.data,
})
}
pub async fn fetch_next_page(&mut self) -> Result<Vec<Vec<Option<String>>>> {
if !self.data.is_empty() {
return Ok(std::mem::take(&mut self.data));
}
while let Some(next_uri) = &self.next_uri {
let resp = self
.client
.query_page(&self.query_id, next_uri, &self.node_id)
.await?;
self.next_uri = resp.next_uri;
if !resp.data.is_empty() {
return Ok(resp.data);
}
}
Ok(vec![])
}
}