1use async_trait::async_trait;
16use log::info;
17use std::collections::{BTreeMap, VecDeque};
18use std::marker::PhantomData;
19use std::path::Path;
20use std::pin::Pin;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23use std::time::Instant;
24use tokio::fs::File;
25use tokio::io::BufReader;
26use tokio_stream::Stream;
27
28use crate::client::LoadMethod;
29use crate::conn::{ConnectionInfo, IConnection, Reader};
30use databend_client::schema::{Schema, SchemaRef};
31use databend_client::Pages;
32use databend_client::{APIClient, ResultFormatSettings};
33use databend_driver_core::error::{Error, Result};
34use databend_driver_core::raw_rows::{RawRow, RawRowIterator, RawRowWithStats};
35use databend_driver_core::rows::{
36 Row, RowIterator, RowStatsIterator, RowWithStats, Rows, ServerStats,
37};
38
/// Stage placeholder looked for in load statements: `load_data` checks (case-insensitively)
/// whether the SQL contains it, and the stage-based load path substitutes it with the
/// temporary stage location the data was uploaded to.
const LOAD_PLACEHOLDER: &str = "@_databend_load";
40
/// Connection implementation that forwards every operation to a shared [`APIClient`].
#[derive(Clone)]
pub struct RestAPIConnection {
    /// Client handle; held in an `Arc`, so cloning the connection clones only the pointer.
    client: Arc<APIClient>,
}
45
46impl RestAPIConnection {
47 fn gen_temp_stage_location(&self) -> Result<String> {
48 let now = chrono::Utc::now()
49 .timestamp_nanos_opt()
50 .ok_or_else(|| Error::IO("Failed to get current timestamp".to_string()))?;
51 Ok(format!("@~/client/load/{now}"))
52 }
53
54 async fn load_data_with_stage(
55 &self,
56 sql: &str,
57 data: Reader,
58 size: u64,
59 ) -> Result<ServerStats> {
60 let location = self.gen_temp_stage_location()?;
61 self.upload_to_stage(&location, data, size).await?;
62 if self.client.capability().streaming_load {
63 let sql = sql.replace(LOAD_PLACEHOLDER, &location);
64 let page = self.client.query_all(&sql, None).await?;
65 Ok(ServerStats::from(page.stats))
66 } else {
67 let file_format_options = Self::default_file_format_options();
68 let copy_options = Self::default_copy_options();
69 let stats = self
70 .client
71 .insert_with_stage(sql, &location, file_format_options, copy_options)
72 .await?;
73 Ok(ServerStats::from(stats))
74 }
75 }
76
77 async fn load_data_with_streaming(
78 &self,
79 sql: &str,
80 data: Reader,
81 size: u64,
82 ) -> Result<ServerStats> {
83 let start = Instant::now();
84 let response = self
85 .client
86 .streaming_load(sql, data, "<no_filename>")
87 .await?;
88 Ok(ServerStats {
89 total_rows: 0,
90 total_bytes: 0,
91 read_rows: response.stats.rows,
92 read_bytes: size as usize,
93 write_rows: response.stats.rows,
94 write_bytes: response.stats.bytes,
95 running_time_ms: start.elapsed().as_millis() as f64,
96 spill_file_nums: 0,
97 spill_bytes: 0,
98 })
99 }
100 async fn load_data_with_options(
101 &self,
102 sql: &str,
103 data: Reader,
104 size: u64,
105 file_format_options: Option<BTreeMap<&str, &str>>,
106 copy_options: Option<BTreeMap<&str, &str>>,
107 ) -> Result<ServerStats> {
108 let location = self.gen_temp_stage_location()?;
109 let file_format_options =
110 file_format_options.unwrap_or_else(Self::default_file_format_options);
111 let copy_options = copy_options.unwrap_or_else(Self::default_copy_options);
112 self.upload_to_stage(&location, Box::new(data), size)
113 .await?;
114 let stats = self
115 .client
116 .insert_with_stage(sql, &location, file_format_options, copy_options)
117 .await?;
118 Ok(ServerStats::from(stats))
119 }
120}
121
122#[async_trait]
123impl IConnection for RestAPIConnection {
124 async fn info(&self) -> ConnectionInfo {
125 ConnectionInfo {
126 handler: "RestAPI".to_string(),
127 host: self.client.host().to_string(),
128 port: self.client.port(),
129 user: self.client.username(),
130 catalog: self.client.current_catalog(),
131 database: self.client.current_database(),
132 warehouse: self.client.current_warehouse(),
133 }
134 }
135
136 fn last_query_id(&self) -> Option<String> {
137 self.client.last_query_id()
138 }
139
140 fn set_warehouse(&self, warehouse: &str) -> Result<()> {
141 self.client.set_warehouse(warehouse.to_string());
142 Ok(())
143 }
144
145 fn set_database(&self, database: &str) -> Result<()> {
146 self.client.set_database(database.to_string());
147 Ok(())
148 }
149
150 fn set_role(&self, role: &str) -> Result<()> {
151 self.client.set_role(role.to_string());
152 Ok(())
153 }
154
155 fn set_session(&self, key: &str, value: &str) -> Result<()> {
156 self.client.set_session(key.to_string(), value.to_string());
157 Ok(())
158 }
159
160 async fn close(&self) -> Result<()> {
161 self.client.close().await;
162 Ok(())
163 }
164
165 fn close_with_spawn(&self) -> Result<()> {
166 self.client.close_with_spawn();
167 Ok(())
168 }
169
170 fn supports_server_side_params(&self) -> bool {
171 self.client.capability().server_side_params
172 }
173
174 async fn exec(&self, sql: &str) -> Result<i64> {
175 info!("exec: {}", sql);
176 let page = self.client.query_all(sql, None).await?;
177 Ok(page.stats.progresses.write_progress.rows as i64)
178 }
179
180 async fn exec_with_params(&self, sql: &str, params: Option<serde_json::Value>) -> Result<i64> {
181 info!("exec with params: {}", sql);
182 let page = self.client.query_all(sql, params).await?;
183 Ok(page.stats.progresses.write_progress.rows as i64)
184 }
185
186 async fn kill_query(&self, query_id: &str) -> Result<()> {
187 Ok(self.client.kill_query(query_id).await?)
188 }
189
190 async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
191 info!("query iter: {}", sql);
192 let rows_with_progress = self.query_iter_ext(sql).await?;
193 let rows = rows_with_progress.filter_rows().await?;
194 Ok(rows)
195 }
196
197 async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
198 info!("query iter ext: {}", sql);
199 let pages = self.client.start_query(sql, true, None).await?;
200 let (schema, rows) = RestAPIRows::<RowWithStats>::from_pages(pages).await?;
201 Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
202 }
203
204 async fn query_iter_with_params(
205 &self,
206 sql: &str,
207 params: Option<serde_json::Value>,
208 ) -> Result<RowIterator> {
209 info!("query iter with params: {}", sql);
210 let rows_with_progress = self.query_iter_ext_with_params(sql, params).await?;
211 let rows = rows_with_progress.filter_rows().await?;
212 Ok(rows)
213 }
214
215 async fn query_iter_ext_with_params(
216 &self,
217 sql: &str,
218 params: Option<serde_json::Value>,
219 ) -> Result<RowStatsIterator> {
220 info!("query iter ext with params: {}", sql);
221 let pages = self.client.start_query(sql, true, params).await?;
222 let (schema, rows) = RestAPIRows::<RowWithStats>::from_pages(pages).await?;
223 Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
224 }
225
226 async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
228 info!("query raw iter: {}", sql);
229 let pages = self.client.start_query(sql, true, None).await?;
230 let (schema, rows) = RestAPIRows::<RawRowWithStats>::from_pages(pages).await?;
231 Ok(RawRowIterator::new(Arc::new(schema), Box::pin(rows)))
232 }
233
234 async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
235 self.client.upload_to_stage(stage, data, size).await?;
236 Ok(())
237 }
238
239 async fn load_data(
240 &self,
241 sql: &str,
242 data: Reader,
243 size: u64,
244 method: LoadMethod,
245 ) -> Result<ServerStats> {
246 let sql = sql.trim_end();
247 let sql = sql.trim_end_matches(';');
248 info!("load data: {}, size: {}, method: {method:?}", sql, size);
249 let sql_low = sql.to_lowercase();
250 let has_place_holder = sql_low.contains(LOAD_PLACEHOLDER);
251 let sql = match (self.client.capability().streaming_load, has_place_holder) {
252 (false, false) => {
253 return self
255 .load_data_with_options(sql, data, size, None, None)
256 .await;
257 }
258 (false, true) => return Err(Error::BadArgument(
259 "Please upgrade your server to >= 1.2.781 to support insert from @_databend_load"
260 .to_string(),
261 )),
262 (true, false) => {
263 format!("{sql} from @_databend_load file_format=(type=csv)")
264 }
265 (true, true) => sql.to_string(),
266 };
267
268 match method {
269 LoadMethod::Streaming => self.load_data_with_streaming(&sql, data, size).await,
270 LoadMethod::Stage => self.load_data_with_stage(&sql, data, size).await,
271 }
272 }
273
274 async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
275 info!("load file: {}, file: {:?}", sql, fp,);
276 let file = File::open(fp).await?;
277 let metadata = file.metadata().await?;
278 let size = metadata.len();
279 let data = BufReader::new(file);
280 self.load_data(sql, Box::new(data), size, method).await
281 }
282
283 async fn load_file_with_options(
284 &self,
285 sql: &str,
286 fp: &Path,
287 file_format_options: Option<BTreeMap<&str, &str>>,
288 copy_options: Option<BTreeMap<&str, &str>>,
289 ) -> Result<ServerStats> {
290 let file = File::open(fp).await?;
291 let metadata = file.metadata().await?;
292 let size = metadata.len();
293 let data = BufReader::new(file);
294 self.load_data_with_options(sql, Box::new(data), size, file_format_options, copy_options)
295 .await
296 }
297
298 async fn stream_load(
299 &self,
300 sql: &str,
301 data: Vec<Vec<&str>>,
302 method: LoadMethod,
303 ) -> Result<ServerStats> {
304 info!("stream load: {}; rows: {:?}", sql, data.len());
305 let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
306 for row in data {
307 wtr.write_record(row)
308 .map_err(|e| Error::BadArgument(e.to_string()))?;
309 }
310 let bytes = wtr.into_inner().map_err(|e| Error::IO(e.to_string()))?;
311 let size = bytes.len() as u64;
312 let reader = Box::new(std::io::Cursor::new(bytes));
313 let stats = if self.client.capability().streaming_load {
314 let sql = format!("{sql} from @_databend_load file_format = (type = csv)");
315 self.load_data(&sql, reader, size, method).await?
316 } else {
317 self.load_data_with_options(sql, reader, size, None, None)
318 .await?
319 };
320 Ok(stats)
321 }
322}
323
324impl<'o> RestAPIConnection {
325 pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
326 let client = APIClient::new(dsn, Some(name)).await?;
327 Ok(Self { client })
328 }
329
330 fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
331 vec![
332 ("type", "CSV"),
333 ("field_delimiter", ","),
334 ("record_delimiter", "\n"),
335 ("skip_header", "0"),
336 ]
337 .into_iter()
338 .collect()
339 }
340
341 fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
342 vec![("purge", "true")].into_iter().collect()
343 }
344}
345
346pub struct RestAPIRows<T> {
347 pages: Pages,
348
349 schema: SchemaRef,
350 settings: ResultFormatSettings,
351
352 data: VecDeque<Vec<Option<String>>>,
353 rows: VecDeque<Row>,
354
355 stats: Option<ServerStats>,
356
357 _phantom: std::marker::PhantomData<T>,
358}
359
360impl<T> RestAPIRows<T> {
361 async fn from_pages(pages: Pages) -> Result<(Schema, Self)> {
362 let (pages, schema, settings) = pages.wait_for_schema(true).await?;
363 let rows = Self {
364 pages,
365 schema: Arc::new(schema.clone()),
366 settings,
367 data: Default::default(),
368 rows: Default::default(),
369 stats: None,
370 _phantom: PhantomData,
371 };
372 Ok((schema, rows))
373 }
374}
375
/// Yields buffered rows and, for every page pulled from the underlying page stream, one
/// statistics item built from that page's stats.
impl<T: FromRowStats + std::marker::Unpin> Stream for RestAPIRows<T> {
    type Item = Result<T>;

    /// Produces the next item in this order: a pending statistics item, a buffered row (only
    /// while more than one is buffered), otherwise whatever polling the page stream yields.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // A stored statistics item takes priority over rows.
        // NOTE(review): nothing in this block assigns `stats` after construction.
        if let Some(ss) = self.stats.take() {
            return Poll::Ready(Some(Ok(T::from_stats(ss))));
        }
        // Rows are served from a buffer only while it holds more than one; the last buffered
        // row is held back until the page stream has been polled again.
        // NOTE(review): presumably so the final page request is issued before the last row is
        // handed out — confirm against the `Pages` implementation.
        if self.data.len() > 1 {
            if let Some(row) = self.data.pop_front() {
                let row = T::try_from_raw_row(row, self.schema.clone(), &self.settings)?;
                return Poll::Ready(Some(Ok(row)));
            }
        } else if self.rows.len() > 1 {
            if let Some(row) = self.rows.pop_front() {
                let row = T::from_row(row);
                return Poll::Ready(Some(Ok(row)));
            }
        }

        match Pin::new(&mut self.pages).poll_next(cx) {
            Poll::Ready(Some(Ok(page))) => {
                // Fill in a still-empty schema from the page: its raw schema when present,
                // otherwise the schema of its first batch.
                if self.schema.fields().is_empty() {
                    if !page.raw_schema.is_empty() {
                        self.schema = Arc::new(page.raw_schema.try_into()?);
                    } else if !page.batches.is_empty() {
                        self.schema = Arc::new(page.batches[0].schema().clone().try_into()?);
                    }
                }
                // Buffer the page's rows: raw `data` when there are no batches, otherwise the
                // rows decoded from each batch.
                if page.batches.is_empty() {
                    let mut new_data = page.data.into();
                    self.data.append(&mut new_data);
                } else {
                    for batch in page.batches.into_iter() {
                        let rows = Rows::try_from((batch, self.settings.clone()))?;
                        self.rows.extend(rows);
                    }
                }
                // The page itself is reported as a statistics item; its rows follow on
                // later polls.
                Poll::Ready(Some(Ok(T::from_stats(page.stats.into()))))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
            Poll::Ready(None) => {
                // No more pages: hand out whatever is still buffered, decoded rows first.
                if let Some(row) = self.rows.pop_front() {
                    let row = T::from_row(row);
                    Poll::Ready(Some(Ok(row)))
                } else if let Some(row) = self.data.pop_front() {
                    let row = T::try_from_raw_row(row, self.schema.clone(), &self.settings)?;
                    Poll::Ready(Some(Ok(row)))
                } else {
                    Poll::Ready(None)
                }
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
433
434trait FromRowStats: Send + Sync + Clone {
435 fn from_stats(stats: ServerStats) -> Self;
436 fn try_from_raw_row(
437 row: Vec<Option<String>>,
438 schema: SchemaRef,
439 tz: &ResultFormatSettings,
440 ) -> Result<Self>;
441 fn from_row(row: Row) -> Self;
442}
443
444impl FromRowStats for RowWithStats {
445 fn from_stats(stats: ServerStats) -> Self {
446 RowWithStats::Stats(stats)
447 }
448
449 fn try_from_raw_row(
450 row: Vec<Option<String>>,
451 schema: SchemaRef,
452 settings: &ResultFormatSettings,
453 ) -> Result<Self> {
454 Ok(RowWithStats::Row(Row::try_from((schema, row, settings))?))
455 }
456 fn from_row(row: Row) -> Self {
457 RowWithStats::Row(row)
458 }
459}
460
461impl FromRowStats for RawRowWithStats {
462 fn from_stats(stats: ServerStats) -> Self {
463 RawRowWithStats::Stats(stats)
464 }
465
466 fn try_from_raw_row(
467 row: Vec<Option<String>>,
468 schema: SchemaRef,
469 tz: &ResultFormatSettings,
470 ) -> Result<Self> {
471 let rows = Row::try_from((schema, row.clone(), tz))?;
472 Ok(RawRowWithStats::Row(RawRow::new(rows, row)))
473 }
474
475 fn from_row(row: Row) -> Self {
476 RawRowWithStats::Row(RawRow::from(row))
477 }
478}