1use std::collections::BTreeMap;
16use std::collections::BTreeSet;
17use std::path::Path;
18use std::path::PathBuf;
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use log::info;
23use tokio::fs::File;
24use tokio::io::AsyncRead;
25use tokio::io::BufReader;
26use tokio_stream::StreamExt;
27
28use crate::client::LoadMethod;
29use databend_client::schema::{DataType, Field, NumberDataType, Schema};
30use databend_client::StageLocation;
31use databend_client::{presign_download_from_stage, PresignedResponse};
32use databend_driver_core::error::{Error, Result};
33use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
34use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
35use databend_driver_core::value::{NumberValue, Value};
36
/// Descriptive settings of a connection, as returned by `IConnection::info`.
pub struct ConnectionInfo {
    /// Name of the handler serving the connection.
    pub handler: String,
    /// Host component of the server endpoint.
    pub host: String,
    /// Port component of the server endpoint.
    pub port: u16,
    /// User name associated with the connection.
    pub user: String,
    /// Catalog, when one is set.
    pub catalog: Option<String>,
    /// Database, when one is set.
    pub database: Option<String>,
    /// Warehouse, when one is set.
    pub warehouse: Option<String>,
}
46
47pub type Reader = Box<dyn AsyncRead + Send + Sync + Unpin + 'static>;
48
49#[async_trait]
50pub trait IConnection: Send + Sync {
51 async fn info(&self) -> ConnectionInfo;
52 async fn close(&self) -> Result<()> {
53 Ok(())
54 }
55
56 fn close_with_spawn(&self) -> Result<()> {
57 Ok(())
58 }
59
60 fn last_query_id(&self) -> Option<String>;
61
62 async fn version(&self) -> Result<String> {
63 let row = self.query_row("SELECT version()").await?;
64 let version = match row {
65 Some(row) => {
66 let (version,): (String,) = row.try_into().map_err(Error::Parsing)?;
67 version
68 }
69 None => "".to_string(),
70 };
71 Ok(version)
72 }
73
74 async fn exec(&self, sql: &str) -> Result<i64>;
75 async fn kill_query(&self, query_id: &str) -> Result<()>;
76 async fn query_iter(&self, sql: &str) -> Result<RowIterator>;
77 async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator>;
78
79 fn supports_server_side_params(&self) -> bool {
80 false
81 }
82
83 async fn exec_with_params(&self, sql: &str, _params: Option<serde_json::Value>) -> Result<i64> {
84 self.exec(sql).await
85 }
86
87 async fn query_iter_with_params(
88 &self,
89 sql: &str,
90 _params: Option<serde_json::Value>,
91 ) -> Result<RowIterator> {
92 self.query_iter(sql).await
93 }
94
95 async fn query_iter_ext_with_params(
96 &self,
97 sql: &str,
98 _params: Option<serde_json::Value>,
99 ) -> Result<RowStatsIterator> {
100 self.query_iter_ext(sql).await
101 }
102
103 async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
104 let rows = self.query_all(sql).await?;
105 let row = rows.into_iter().next();
106 Ok(row)
107 }
108
109 async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
110 let rows = self.query_iter(sql).await?;
111 rows.collect().await
112 }
113
114 async fn query_raw_iter(&self, _sql: &str) -> Result<RawRowIterator> {
116 Err(Error::BadArgument(
117 "Unsupported implement query_raw_iter".to_string(),
118 ))
119 }
120
121 async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
123 let rows = self.query_raw_iter(sql).await?;
124 rows.collect().await
125 }
126
127 async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
130 info!("get presigned url: {} {}", operation, stage);
131 let sql = format!("PRESIGN {operation} {stage}");
132 let row = self.query_row(&sql).await?.ok_or_else(|| {
133 Error::InvalidResponse("Empty response from server for presigned request".to_string())
134 })?;
135 let (method, headers, url): (String, String, String) =
136 row.try_into().map_err(Error::Parsing)?;
137 let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
138 Ok(PresignedResponse {
139 method,
140 headers,
141 url,
142 })
143 }
144
145 async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()>;
146
147 async fn load_data(
148 &self,
149 sql: &str,
150 data: Reader,
151 size: u64,
152 method: LoadMethod,
153 ) -> Result<ServerStats>;
154
155 async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats>;
156
157 async fn load_file_with_options(
158 &self,
159 sql: &str,
160 fp: &Path,
161 file_format_options: Option<BTreeMap<&str, &str>>,
162 copy_options: Option<BTreeMap<&str, &str>>,
163 ) -> Result<ServerStats>;
164
165 async fn stream_load(
166 &self,
167 sql: &str,
168 data: Vec<Vec<&str>>,
169 _method: LoadMethod,
170 ) -> Result<ServerStats>;
171
172 fn set_warehouse(&self, warehouse: &str) -> Result<()>;
173
174 fn set_database(&self, database: &str) -> Result<()>;
175
176 fn set_role(&self, role: &str) -> Result<()>;
177
178 fn set_session(&self, key: &str, value: &str) -> Result<()>;
179
180 async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
182 let mut total_count: usize = 0;
183 let mut total_size: usize = 0;
184 let local_dsn = url::Url::parse(local_file)?;
185 validate_local_scheme(local_dsn.scheme())?;
186 let entries = expand_local_glob(local_dsn.path())?;
187 let mut results = Vec::new();
188 let stage_location = StageLocation::try_from(stage)?;
189 let schema = Arc::new(put_get_schema());
190 for entry in entries {
191 let filename = entry
192 .file_name()
193 .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?
194 .to_str()
195 .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?;
196 let stage_file = stage_location.file_path(filename);
197 let (fname, status, size) = if entry.is_dir() {
198 (
199 entry.to_string_lossy().to_string(),
200 format!(
201 "BadArgument: Local path is a directory: {}",
202 entry.display()
203 ),
204 0,
205 )
206 } else {
207 let file = File::open(&entry).await?;
208 let size = file.metadata().await?.len();
209 let data = BufReader::new(file);
210 match self
211 .upload_to_stage(&stage_file, Box::new(data), size)
212 .await
213 {
214 Ok(_) => {
215 total_count += 1;
216 total_size += size as usize;
217 (
218 entry.to_string_lossy().to_string(),
219 "SUCCESS".to_owned(),
220 size,
221 )
222 }
223 Err(e) => (entry.to_string_lossy().to_string(), e.to_string(), size),
224 }
225 };
226 let ss = ServerStats {
227 write_rows: total_count,
228 write_bytes: total_size,
229
230 ..Default::default()
231 };
232 results.push(Ok(RowWithStats::Stats(ss)));
233 results.push(Ok(RowWithStats::Row(Row::from_vec(
234 schema.clone(),
235 vec![
236 Value::String(fname),
237 Value::String(status),
238 Value::Number(NumberValue::UInt64(size)),
239 ],
240 ))));
241 }
242 Ok(RowStatsIterator::new(
243 schema,
244 Box::pin(tokio_stream::iter(results)),
245 ))
246 }
247
248 async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
249 let mut total_count: usize = 0;
250 let mut total_size: usize = 0;
251 let local_dsn = url::Url::parse(local_file)?;
252 validate_local_scheme(local_dsn.scheme())?;
253 ensure_local_destination_dir(local_dsn.path())?;
254 let location = StageLocation::try_from(stage)?;
255 let list_sql = format!("LIST {location}");
256 let mut response = self.query_iter(&list_sql).await?;
257 let mut files = Vec::new();
258 let mut results = Vec::new();
259 let schema = Arc::new(put_get_schema());
260 while let Some(row) = response.next().await {
261 let (name, _, _, _, _): (String, u64, Option<String>, String, Option<String>) =
262 row?.try_into().map_err(Error::Parsing)?;
263 if let Some(file) = normalize_stage_list_entry(&location, &name)? {
264 files.push(file);
265 }
266 }
267 if files.is_empty() {
268 return Err(Error::BadArgument(format!(
269 "No stage files matched: {stage}"
270 )));
271 }
272 ensure_unique_basenames(&files)?;
273 for file in files {
274 let local_file = Path::new(local_dsn.path()).join(&file.basename);
275 let (status, size) = match self.get_presigned_url("DOWNLOAD", &file.stage_file).await {
276 Ok(presign) => match presign_download_from_stage(presign, &local_file).await {
277 Ok(size) => {
278 total_count += 1;
279 total_size += size as usize;
280 ("SUCCESS".to_owned(), size)
281 }
282 Err(e) => (e.to_string(), 0),
283 },
284 Err(e) => (e.to_string(), 0),
285 };
286 let ss = ServerStats {
287 read_rows: total_count,
288 read_bytes: total_size,
289 ..Default::default()
290 };
291 results.push(Ok(RowWithStats::Stats(ss)));
292 results.push(Ok(RowWithStats::Row(Row::from_vec(
293 schema.clone(),
294 vec![
295 Value::String(local_file.to_string_lossy().to_string()),
296 Value::String(status),
297 Value::Number(NumberValue::UInt64(size)),
298 ],
299 ))));
300 }
301 Ok(RowStatsIterator::new(
302 schema,
303 Box::pin(tokio_stream::iter(results)),
304 ))
305 }
306}
307
/// A stage entry selected for download.
#[derive(Debug, Clone)]
struct StageFile {
    /// Stage path of the file, including the leading `@` and the stage name.
    stage_file: String,
    /// Final path component of the file, used as the local file name.
    basename: String,
}
313
314fn expand_local_glob(pattern: &str) -> Result<Vec<PathBuf>> {
315 let entries = glob::glob(pattern)?.collect::<std::result::Result<Vec<_>, _>>()?;
316 if entries.is_empty() {
317 return Err(Error::BadArgument(format!(
318 "No local files matched: {pattern}"
319 )));
320 }
321 Ok(entries)
322}
323
324fn ensure_local_destination_dir(path: &str) -> Result<()> {
325 let local_path = Path::new(path);
326 if local_path.exists() && !local_path.is_dir() {
327 return Err(Error::BadArgument(format!(
328 "Local destination is not a directory: {path}"
329 )));
330 }
331 Ok(())
332}
333
334fn normalize_stage_list_entry(
335 location: &StageLocation,
336 listed_name: &str,
337) -> Result<Option<StageFile>> {
338 let stage_prefix = format!("{}/", location.name);
339 let relative = if let Some(path) = listed_name.strip_prefix(&stage_prefix) {
340 if !location.path.is_empty() && !path.starts_with(&location.path) {
341 return Ok(None);
342 }
343 path
344 } else if location.path.is_empty() || listed_name.starts_with(&location.path) {
345 listed_name
346 } else {
347 return Ok(None);
348 };
349
350 let basename = Path::new(relative)
351 .file_name()
352 .and_then(|name| name.to_str())
353 .ok_or_else(|| Error::BadArgument(format!("Invalid stage file path: {listed_name}")))?
354 .to_string();
355 let stage_file = if listed_name.starts_with(&stage_prefix) {
356 format!("@{listed_name}")
357 } else {
358 format!("@{}/{}", location.name, relative)
359 };
360
361 Ok(Some(StageFile {
362 stage_file,
363 basename,
364 }))
365}
366
367fn ensure_unique_basenames(files: &[StageFile]) -> Result<()> {
368 let mut seen = BTreeSet::new();
369 for file in files {
370 if !seen.insert(file.basename.clone()) {
371 return Err(Error::BadArgument(format!(
372 "Duplicate local file basename in GET results: {}",
373 file.basename
374 )));
375 }
376 }
377 Ok(())
378}
379
380fn put_get_schema() -> Schema {
381 Schema::from_vec(vec![
382 Field {
383 name: "file".to_string(),
384 data_type: DataType::String,
385 },
386 Field {
387 name: "status".to_string(),
388 data_type: DataType::String,
389 },
390 Field {
391 name: "size".to_string(),
392 data_type: DataType::Number(NumberDataType::UInt64),
393 },
394 ])
395}
396
397fn validate_local_scheme(scheme: &str) -> Result<()> {
398 match scheme {
399 "file" | "fs" => Ok(()),
400 _ => Err(Error::BadArgument(
401 "Supported schemes: file:// or fs://".to_string(),
402 )),
403 }
404}