use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use log::info;
use tokio::fs::File;
use tokio::io::AsyncRead;
use tokio::io::BufReader;
use tokio_stream::StreamExt;
use crate::client::LoadMethod;
use databend_client::schema::{DataType, Field, NumberDataType, Schema};
use databend_client::StageLocation;
use databend_client::{presign_download_from_stage, PresignedResponse};
use databend_driver_core::error::{Error, Result};
use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
use databend_driver_core::value::{NumberValue, Value};
/// Descriptive details about a connection, as returned by `IConnection::info`.
pub struct ConnectionInfo {
/// Name of the handler in use (presumably identifies the protocol/backend kind; TODO confirm with implementers).
pub handler: String,
/// Server host.
pub host: String,
/// Server port.
pub port: u16,
/// User name associated with the connection.
pub user: String,
/// Catalog, if one is set (`None` otherwise).
pub catalog: Option<String>,
/// Database, if one is set (`None` otherwise).
pub database: Option<String>,
/// Warehouse, if one is set (`None` otherwise).
pub warehouse: Option<String>,
}
/// Boxed asynchronous byte source; this is the type of the `data` argument taken by
/// `IConnection::upload_to_stage` and `IConnection::load_data`.
pub type Reader = Box<dyn AsyncRead + Send + Sync + Unpin + 'static>;
/// Interface shared by the concrete connection backends.
///
/// Required methods cover connection metadata, statement execution, data loading and
/// session settings. Provided methods build on them: row collection helpers, presigned
/// URL lookup, and the stage upload/download helpers [`put_files`](Self::put_files) and
/// [`get_files`](Self::get_files).
#[async_trait]
pub trait IConnection: Send + Sync {
    /// Returns the [`ConnectionInfo`] describing this connection.
    async fn info(&self) -> ConnectionInfo;

    /// Closes the connection. The default implementation does nothing and returns `Ok(())`.
    async fn close(&self) -> Result<()> {
        Ok(())
    }

    /// Non-async variant of [`close`](Self::close). The default implementation does nothing
    /// and returns `Ok(())`.
    // NOTE(review): the name suggests implementers spawn the shutdown work in the
    // background — confirm against the concrete backends.
    fn close_with_spawn(&self) -> Result<()> {
        Ok(())
    }

    /// Returns the id of the last query, if the implementer has one recorded.
    fn last_query_id(&self) -> Option<String>;

    /// Runs `SELECT version()` and returns the single string column of the first row.
    ///
    /// Returns an empty string when the query yields no row.
    ///
    /// # Errors
    /// Propagates query errors, and returns `Error::Parsing` when the row cannot be
    /// converted to a one-element string tuple.
    async fn version(&self) -> Result<String> {
        match self.query_row("SELECT version()").await? {
            Some(row) => {
                let (version,): (String,) = row.try_into().map_err(Error::Parsing)?;
                Ok(version)
            }
            None => Ok(String::new()),
        }
    }

    /// Executes `sql` and returns an `i64` (presumably the affected row count — confirm
    /// with implementers).
    async fn exec(&self, sql: &str) -> Result<i64>;

    /// Requests cancellation of the query identified by `query_id`.
    async fn kill_query(&self, query_id: &str) -> Result<()>;

    /// Runs `sql` and returns an iterator over the resulting rows.
    async fn query_iter(&self, sql: &str) -> Result<RowIterator>;

    /// Runs `sql` and returns an iterator whose items are rows or statistics.
    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator>;

    /// Whether this backend handles query parameters on the server. Defaults to `false`.
    fn supports_server_side_params(&self) -> bool {
        false
    }

    /// Parameterized form of [`exec`](Self::exec).
    ///
    /// The default implementation ignores `_params` and forwards `sql` unchanged.
    async fn exec_with_params(&self, sql: &str, _params: Option<serde_json::Value>) -> Result<i64> {
        self.exec(sql).await
    }

    /// Parameterized form of [`query_iter`](Self::query_iter).
    ///
    /// The default implementation ignores `_params` and forwards `sql` unchanged.
    async fn query_iter_with_params(
        &self,
        sql: &str,
        _params: Option<serde_json::Value>,
    ) -> Result<RowIterator> {
        self.query_iter(sql).await
    }

    /// Parameterized form of [`query_iter_ext`](Self::query_iter_ext).
    ///
    /// The default implementation ignores `_params` and forwards `sql` unchanged.
    async fn query_iter_ext_with_params(
        &self,
        sql: &str,
        _params: Option<serde_json::Value>,
    ) -> Result<RowStatsIterator> {
        self.query_iter_ext(sql).await
    }

    /// Runs `sql` and returns its first row, or `None` when the result is empty.
    ///
    /// The whole result is collected through [`query_all`](Self::query_all) first, so an
    /// error on any row is reported even if it occurs after the first one.
    async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
        let rows = self.query_all(sql).await?;
        Ok(rows.into_iter().next())
    }

    /// Runs `sql` and collects every row into a `Vec`.
    async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
        let rows = self.query_iter(sql).await?;
        rows.collect().await
    }

    /// Runs `sql` and returns an iterator over raw rows.
    ///
    /// The default implementation is unsupported and always returns `Error::BadArgument`.
    async fn query_raw_iter(&self, _sql: &str) -> Result<RawRowIterator> {
        Err(Error::BadArgument(
            "Unsupported implement query_raw_iter".to_string(),
        ))
    }

    /// Runs `sql` through [`query_raw_iter`](Self::query_raw_iter) and collects every raw row.
    async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
        let rows = self.query_raw_iter(sql).await?;
        rows.collect().await
    }

    /// Issues `PRESIGN {operation} {stage}` and converts the first result row into a
    /// [`PresignedResponse`].
    ///
    /// The row is read as three strings `(method, headers, url)`; `headers` is parsed as a
    /// JSON object of string keys and string values.
    ///
    /// # Errors
    /// Returns `Error::InvalidResponse` when the server returns no row, `Error::Parsing`
    /// when the row does not have the expected shape, and propagates query and JSON errors.
    async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
        info!("get presigned url: {} {}", operation, stage);
        let sql = format!("PRESIGN {operation} {stage}");
        let row = self.query_row(&sql).await?.ok_or_else(|| {
            Error::InvalidResponse("Empty response from server for presigned request".to_string())
        })?;
        let (method, headers, url): (String, String, String) =
            row.try_into().map_err(Error::Parsing)?;
        let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
        Ok(PresignedResponse {
            method,
            headers,
            url,
        })
    }

    /// Uploads `size` bytes read from `data` to the stage path `stage`.
    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()>;

    /// Loads `size` bytes read from `data` using `sql` and the given load `method`.
    async fn load_data(
        &self,
        sql: &str,
        data: Reader,
        size: u64,
        method: LoadMethod,
    ) -> Result<ServerStats>;

    /// Loads the local file `fp` using `sql` and the given load `method`.
    async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats>;

    /// Loads the local file `fp` using `sql`, with optional file-format and copy options
    /// supplied as key/value maps (their interpretation is backend-defined).
    async fn load_file_with_options(
        &self,
        sql: &str,
        fp: &Path,
        file_format_options: Option<BTreeMap<&str, &str>>,
        copy_options: Option<BTreeMap<&str, &str>>,
    ) -> Result<ServerStats>;

    /// Loads in-memory `data`, given as rows of string cells, using `sql`.
    async fn stream_load(
        &self,
        sql: &str,
        data: Vec<Vec<&str>>,
        _method: LoadMethod,
    ) -> Result<ServerStats>;

    /// Sets the warehouse for this connection (exact semantics are backend-defined).
    fn set_warehouse(&self, warehouse: &str) -> Result<()>;

    /// Sets the database for this connection (exact semantics are backend-defined).
    fn set_database(&self, database: &str) -> Result<()>;

    /// Sets the role for this connection (exact semantics are backend-defined).
    fn set_role(&self, role: &str) -> Result<()>;

    /// Sets the session setting `key` to `value` (exact semantics are backend-defined).
    fn set_session(&self, key: &str, value: &str) -> Result<()>;

    /// Uploads every local file matching `local_file` to `stage`.
    ///
    /// `local_file` is a `file://` or `fs://` URL whose path is used as a glob pattern.
    /// Each matched entry contributes two items to the returned iterator: a `Stats` item
    /// with the running totals of successfully uploaded files and bytes, followed by a
    /// `Row` item `(file, status, size)`. `status` is `"SUCCESS"` or a failure message.
    ///
    /// Failures that concern a single entry (it is a directory, it cannot be opened or
    /// inspected, or its upload fails) are reported in that entry's row, so that results
    /// for the other entries — including uploads that already happened — are not lost.
    ///
    /// # Errors
    /// Returns an error when the URL is malformed, its scheme is unsupported, the pattern
    /// matches nothing, `stage` is not a valid stage location, or a matched path has no
    /// UTF-8 file name.
    async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
        let local_dsn = url::Url::parse(local_file)?;
        validate_local_scheme(local_dsn.scheme())?;
        let entries = expand_local_glob(local_dsn.path())?;
        let stage_location = StageLocation::try_from(stage)?;
        let schema = Arc::new(put_get_schema());

        let mut total_count: usize = 0;
        let mut total_size: usize = 0;
        // Two items (stats + row) are pushed per entry.
        let mut results = Vec::with_capacity(entries.len() * 2);
        for entry in entries {
            let filename = entry
                .file_name()
                .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?
                .to_str()
                .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?;
            let stage_file = stage_location.file_path(filename);
            let fname = entry.to_string_lossy().to_string();

            let (status, size): (String, u64) = if entry.is_dir() {
                (
                    format!(
                        "BadArgument: Local path is a directory: {}",
                        entry.display()
                    ),
                    0,
                )
            } else {
                // Open the file and read its length as one fallible step so that a local
                // I/O failure becomes this entry's status instead of aborting the batch.
                let opened = async {
                    let file = File::open(&entry).await?;
                    let size = file.metadata().await?.len();
                    Ok::<_, std::io::Error>((file, size))
                }
                .await;
                match opened {
                    Ok((file, size)) => {
                        let data: Reader = Box::new(BufReader::new(file));
                        match self.upload_to_stage(&stage_file, data, size).await {
                            Ok(()) => {
                                total_count += 1;
                                // Saturate rather than truncate where `usize` is narrower than `u64`.
                                total_size = total_size
                                    .saturating_add(usize::try_from(size).unwrap_or(usize::MAX));
                                ("SUCCESS".to_owned(), size)
                            }
                            Err(e) => (e.to_string(), size),
                        }
                    }
                    // Convert first so the message matches what `?` would have produced.
                    Err(e) => (Error::from(e).to_string(), 0),
                }
            };

            let ss = ServerStats {
                write_rows: total_count,
                write_bytes: total_size,
                ..Default::default()
            };
            results.push(Ok(RowWithStats::Stats(ss)));
            results.push(Ok(RowWithStats::Row(Row::from_vec(
                schema.clone(),
                vec![
                    Value::String(fname),
                    Value::String(status),
                    Value::Number(NumberValue::UInt64(size)),
                ],
            ))));
        }
        Ok(RowStatsIterator::new(
            schema,
            Box::pin(tokio_stream::iter(results)),
        ))
    }

    /// Downloads the files listed under `stage` into the local directory named by `local_file`.
    ///
    /// `local_file` is a `file://` or `fs://` URL; its path must not be an existing
    /// non-directory. The stage is enumerated with `LIST {location}`, reading the first
    /// column of each row as the file name. Each selected file is fetched through a
    /// `DOWNLOAD` presigned URL and saved under its basename.
    ///
    /// Each file contributes two items to the returned iterator: a `Stats` item with the
    /// running totals of successfully downloaded files and bytes, followed by a `Row` item
    /// `(local path, status, size)`. Presign and download failures are reported in the
    /// file's row with a size of `0`.
    ///
    /// # Errors
    /// Returns an error when the URL is malformed, its scheme is unsupported, the
    /// destination exists but is not a directory, `stage` is not a valid stage location,
    /// the listing fails or has an unexpected shape, no file is selected, or two selected
    /// files share a basename.
    async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
        let local_dsn = url::Url::parse(local_file)?;
        validate_local_scheme(local_dsn.scheme())?;
        ensure_local_destination_dir(local_dsn.path())?;
        let location = StageLocation::try_from(stage)?;

        let list_sql = format!("LIST {location}");
        let mut response = self.query_iter(&list_sql).await?;
        let mut files = Vec::new();
        while let Some(row) = response.next().await {
            // Only the name column is used; the remaining columns are ignored.
            let (name, _, _, _, _): (String, u64, Option<String>, String, Option<String>) =
                row?.try_into().map_err(Error::Parsing)?;
            if let Some(file) = normalize_stage_list_entry(&location, &name)? {
                files.push(file);
            }
        }
        if files.is_empty() {
            return Err(Error::BadArgument(format!(
                "No stage files matched: {stage}"
            )));
        }
        // Files are saved by basename, so duplicates would overwrite each other.
        ensure_unique_basenames(&files)?;

        let schema = Arc::new(put_get_schema());
        let destination_dir = Path::new(local_dsn.path());
        let mut total_count: usize = 0;
        let mut total_size: usize = 0;
        // Two items (stats + row) are pushed per file.
        let mut results = Vec::with_capacity(files.len() * 2);
        for file in files {
            let target = destination_dir.join(&file.basename);
            let (status, size) = match self.get_presigned_url("DOWNLOAD", &file.stage_file).await {
                Ok(presign) => match presign_download_from_stage(presign, &target).await {
                    Ok(size) => {
                        total_count += 1;
                        // Saturate rather than truncate where `usize` is narrower than `u64`.
                        total_size = total_size
                            .saturating_add(usize::try_from(size).unwrap_or(usize::MAX));
                        ("SUCCESS".to_owned(), size)
                    }
                    Err(e) => (e.to_string(), 0),
                },
                Err(e) => (e.to_string(), 0),
            };
            let ss = ServerStats {
                read_rows: total_count,
                read_bytes: total_size,
                ..Default::default()
            };
            results.push(Ok(RowWithStats::Stats(ss)));
            results.push(Ok(RowWithStats::Row(Row::from_vec(
                schema.clone(),
                vec![
                    Value::String(target.to_string_lossy().to_string()),
                    Value::String(status),
                    Value::Number(NumberValue::UInt64(size)),
                ],
            ))));
        }
        Ok(RowStatsIterator::new(
            schema,
            Box::pin(tokio_stream::iter(results)),
        ))
    }
}
/// A stage file selected for download by `IConnection::get_files`.
#[derive(Debug, Clone)]
struct StageFile {
/// Stage reference starting with `@`; passed as the `stage` argument of `get_presigned_url`.
stage_file: String,
/// Final path component of the listed name; joined onto the local destination directory.
basename: String,
}
fn expand_local_glob(pattern: &str) -> Result<Vec<PathBuf>> {
let entries = glob::glob(pattern)?.collect::<std::result::Result<Vec<_>, _>>()?;
if entries.is_empty() {
return Err(Error::BadArgument(format!(
"No local files matched: {pattern}"
)));
}
Ok(entries)
}
fn ensure_local_destination_dir(path: &str) -> Result<()> {
let local_path = Path::new(path);
if local_path.exists() && !local_path.is_dir() {
return Err(Error::BadArgument(format!(
"Local destination is not a directory: {path}"
)));
}
Ok(())
}
fn normalize_stage_list_entry(
location: &StageLocation,
listed_name: &str,
) -> Result<Option<StageFile>> {
let stage_prefix = format!("{}/", location.name);
let relative = if let Some(path) = listed_name.strip_prefix(&stage_prefix) {
if !location.path.is_empty() && !path.starts_with(&location.path) {
return Ok(None);
}
path
} else if location.path.is_empty() || listed_name.starts_with(&location.path) {
listed_name
} else {
return Ok(None);
};
let basename = Path::new(relative)
.file_name()
.and_then(|name| name.to_str())
.ok_or_else(|| Error::BadArgument(format!("Invalid stage file path: {listed_name}")))?
.to_string();
let stage_file = if listed_name.starts_with(&stage_prefix) {
format!("@{listed_name}")
} else {
format!("@{}/{}", location.name, relative)
};
Ok(Some(StageFile {
stage_file,
basename,
}))
}
fn ensure_unique_basenames(files: &[StageFile]) -> Result<()> {
let mut seen = BTreeSet::new();
for file in files {
if !seen.insert(file.basename.clone()) {
return Err(Error::BadArgument(format!(
"Duplicate local file basename in GET results: {}",
file.basename
)));
}
}
Ok(())
}
/// Builds the three-column schema used for the per-file result rows:
/// `file` (string), `status` (string) and `size` (unsigned 64-bit number).
fn put_get_schema() -> Schema {
    let string_column = |name: &str| Field {
        name: name.to_string(),
        data_type: DataType::String,
    };
    let size_column = Field {
        name: "size".to_string(),
        data_type: DataType::Number(NumberDataType::UInt64),
    };
    Schema::from_vec(vec![
        string_column("file"),
        string_column("status"),
        size_column,
    ])
}
fn validate_local_scheme(scheme: &str) -> Result<()> {
match scheme {
"file" | "fs" => Ok(()),
_ => Err(Error::BadArgument(
"Supported schemes: file:// or fs://".to_string(),
)),
}
}