1use async_trait::async_trait;
16use jiff::tz::TimeZone;
17use log::info;
18use std::collections::{BTreeMap, VecDeque};
19use std::marker::PhantomData;
20use std::path::Path;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24use std::time::Instant;
25use tokio::fs::File;
26use tokio::io::BufReader;
27use tokio_stream::Stream;
28
29use crate::client::LoadMethod;
30use crate::conn::{ConnectionInfo, IConnection, Reader};
31use databend_client::schema::{Schema, SchemaRef};
32use databend_client::Pages;
33use databend_client::{APIClient, ResultFormatSettings};
34use databend_driver_core::error::{Error, Result};
35use databend_driver_core::raw_rows::{RawRow, RawRowIterator, RawRowWithStats};
36use databend_driver_core::rows::{
37 Row, RowIterator, RowStatsIterator, RowWithStats, Rows, ServerStats,
38};
39
/// Stage name that marks where uploaded data is read from in load SQL
/// (e.g. `INSERT INTO t FROM @_databend_load ...`). `load_data` looks for it
/// in a lower-cased copy of the SQL, i.e. detection ignores case.
const LOAD_PLACEHOLDER: &str = "@_databend_load";

/// Connection implementation that talks to the server through an [`APIClient`]
/// (reported as the `"RestAPI"` handler by `info`).
#[derive(Clone)]
pub struct RestAPIConnection {
    // Held in an `Arc`, so clones of the connection share the same client.
    client: Arc<APIClient>,
}
46
47impl RestAPIConnection {
48 fn gen_temp_stage_location(&self) -> Result<String> {
49 let now = chrono::Utc::now()
50 .timestamp_nanos_opt()
51 .ok_or_else(|| Error::IO("Failed to get current timestamp".to_string()))?;
52 Ok(format!("@~/client/load/{now}"))
53 }
54
55 async fn load_data_with_stage(
56 &self,
57 sql: &str,
58 data: Reader,
59 size: u64,
60 ) -> Result<ServerStats> {
61 let location = self.gen_temp_stage_location()?;
62 self.upload_to_stage(&location, data, size).await?;
63 if self.client.capability().streaming_load {
64 let sql = sql.replace(LOAD_PLACEHOLDER, &location);
65 let page = self.client.query_all(&sql).await?;
66 Ok(ServerStats::from(page.stats))
67 } else {
68 let file_format_options = Self::default_file_format_options();
69 let copy_options = Self::default_copy_options();
70 let stats = self
71 .client
72 .insert_with_stage(sql, &location, file_format_options, copy_options)
73 .await?;
74 Ok(ServerStats::from(stats))
75 }
76 }
77
78 async fn load_data_with_streaming(
79 &self,
80 sql: &str,
81 data: Reader,
82 size: u64,
83 ) -> Result<ServerStats> {
84 let start = Instant::now();
85 let response = self
86 .client
87 .streaming_load(sql, data, "<no_filename>")
88 .await?;
89 Ok(ServerStats {
90 total_rows: 0,
91 total_bytes: 0,
92 read_rows: response.stats.rows,
93 read_bytes: size as usize,
94 write_rows: response.stats.rows,
95 write_bytes: response.stats.bytes,
96 running_time_ms: start.elapsed().as_millis() as f64,
97 spill_file_nums: 0,
98 spill_bytes: 0,
99 })
100 }
101 async fn load_data_with_options(
102 &self,
103 sql: &str,
104 data: Reader,
105 size: u64,
106 file_format_options: Option<BTreeMap<&str, &str>>,
107 copy_options: Option<BTreeMap<&str, &str>>,
108 ) -> Result<ServerStats> {
109 let location = self.gen_temp_stage_location()?;
110 let file_format_options =
111 file_format_options.unwrap_or_else(Self::default_file_format_options);
112 let copy_options = copy_options.unwrap_or_else(Self::default_copy_options);
113 self.upload_to_stage(&location, Box::new(data), size)
114 .await?;
115 let stats = self
116 .client
117 .insert_with_stage(sql, &location, file_format_options, copy_options)
118 .await?;
119 Ok(ServerStats::from(stats))
120 }
121}
122
123#[async_trait]
124impl IConnection for RestAPIConnection {
125 async fn info(&self) -> ConnectionInfo {
126 ConnectionInfo {
127 handler: "RestAPI".to_string(),
128 host: self.client.host().to_string(),
129 port: self.client.port(),
130 user: self.client.username(),
131 catalog: self.client.current_catalog(),
132 database: self.client.current_database(),
133 warehouse: self.client.current_warehouse(),
134 }
135 }
136
137 fn last_query_id(&self) -> Option<String> {
138 self.client.last_query_id()
139 }
140
141 fn set_warehouse(&self, warehouse: &str) -> Result<()> {
142 self.client.set_warehouse(warehouse.to_string());
143 Ok(())
144 }
145
146 fn set_database(&self, database: &str) -> Result<()> {
147 self.client.set_database(database.to_string());
148 Ok(())
149 }
150
151 fn set_role(&self, role: &str) -> Result<()> {
152 self.client.set_role(role.to_string());
153 Ok(())
154 }
155
156 fn set_session(&self, key: &str, value: &str) -> Result<()> {
157 self.client.set_session(key.to_string(), value.to_string());
158 Ok(())
159 }
160
161 async fn close(&self) -> Result<()> {
162 self.client.close().await;
163 Ok(())
164 }
165
166 fn close_with_spawn(&self) -> Result<()> {
167 self.client.close_with_spawn();
168 Ok(())
169 }
170
171 async fn exec(&self, sql: &str) -> Result<i64> {
172 info!("exec: {}", sql);
173 let page = self.client.query_all(sql).await?;
174 Ok(page.stats.progresses.write_progress.rows as i64)
175 }
176
177 async fn kill_query(&self, query_id: &str) -> Result<()> {
178 Ok(self.client.kill_query(query_id).await?)
179 }
180
181 async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
182 info!("query iter: {}", sql);
183 let rows_with_progress = self.query_iter_ext(sql).await?;
184 let rows = rows_with_progress.filter_rows().await?;
185 Ok(rows)
186 }
187
188 async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
189 info!("query iter ext: {}", sql);
190 let pages = self.client.start_query(sql, true).await?;
191 let (schema, rows) = RestAPIRows::<RowWithStats>::from_pages(pages).await?;
192 Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
193 }
194
195 async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
197 info!("query raw iter: {}", sql);
198 let pages = self.client.start_query(sql, true).await?;
199 let (schema, rows) = RestAPIRows::<RawRowWithStats>::from_pages(pages).await?;
200 Ok(RawRowIterator::new(Arc::new(schema), Box::pin(rows)))
201 }
202
203 async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
204 self.client.upload_to_stage(stage, data, size).await?;
205 Ok(())
206 }
207
208 async fn load_data(
209 &self,
210 sql: &str,
211 data: Reader,
212 size: u64,
213 method: LoadMethod,
214 ) -> Result<ServerStats> {
215 let sql = sql.trim_end();
216 let sql = sql.trim_end_matches(';');
217 info!("load data: {}, size: {}, method: {method:?}", sql, size);
218 let sql_low = sql.to_lowercase();
219 let has_place_holder = sql_low.contains(LOAD_PLACEHOLDER);
220 let sql = match (self.client.capability().streaming_load, has_place_holder) {
221 (false, false) => {
222 return self
224 .load_data_with_options(sql, data, size, None, None)
225 .await;
226 }
227 (false, true) => return Err(Error::BadArgument(
228 "Please upgrade your server to >= 1.2.781 to support insert from @_databend_load"
229 .to_string(),
230 )),
231 (true, false) => {
232 format!("{sql} from @_databend_load file_format=(type=csv)")
233 }
234 (true, true) => sql.to_string(),
235 };
236
237 match method {
238 LoadMethod::Streaming => self.load_data_with_streaming(&sql, data, size).await,
239 LoadMethod::Stage => self.load_data_with_stage(&sql, data, size).await,
240 }
241 }
242
243 async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
244 info!("load file: {}, file: {:?}", sql, fp,);
245 let file = File::open(fp).await?;
246 let metadata = file.metadata().await?;
247 let size = metadata.len();
248 let data = BufReader::new(file);
249 self.load_data(sql, Box::new(data), size, method).await
250 }
251
252 async fn load_file_with_options(
253 &self,
254 sql: &str,
255 fp: &Path,
256 file_format_options: Option<BTreeMap<&str, &str>>,
257 copy_options: Option<BTreeMap<&str, &str>>,
258 ) -> Result<ServerStats> {
259 let file = File::open(fp).await?;
260 let metadata = file.metadata().await?;
261 let size = metadata.len();
262 let data = BufReader::new(file);
263 self.load_data_with_options(sql, Box::new(data), size, file_format_options, copy_options)
264 .await
265 }
266
267 async fn stream_load(
268 &self,
269 sql: &str,
270 data: Vec<Vec<&str>>,
271 method: LoadMethod,
272 ) -> Result<ServerStats> {
273 info!("stream load: {}; rows: {:?}", sql, data.len());
274 let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
275 for row in data {
276 wtr.write_record(row)
277 .map_err(|e| Error::BadArgument(e.to_string()))?;
278 }
279 let bytes = wtr.into_inner().map_err(|e| Error::IO(e.to_string()))?;
280 let size = bytes.len() as u64;
281 let reader = Box::new(std::io::Cursor::new(bytes));
282 let stats = if self.client.capability().streaming_load {
283 let sql = format!("{sql} from @_databend_load file_format = (type = csv)");
284 self.load_data(&sql, reader, size, method).await?
285 } else {
286 self.load_data_with_options(sql, reader, size, None, None)
287 .await?
288 };
289 Ok(stats)
290 }
291}
292
293impl<'o> RestAPIConnection {
294 pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
295 let client = APIClient::new(dsn, Some(name)).await?;
296 Ok(Self { client })
297 }
298
299 fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
300 vec![
301 ("type", "CSV"),
302 ("field_delimiter", ","),
303 ("record_delimiter", "\n"),
304 ("skip_header", "0"),
305 ]
306 .into_iter()
307 .collect()
308 }
309
310 fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
311 vec![("purge", "true")].into_iter().collect()
312 }
313}
314
315pub struct RestAPIRows<T> {
316 pages: Pages,
317
318 schema: SchemaRef,
319 settings: ResultFormatSettings,
320
321 data: VecDeque<Vec<Option<String>>>,
322 rows: VecDeque<Row>,
323
324 stats: Option<ServerStats>,
325
326 _phantom: std::marker::PhantomData<T>,
327}
328
329impl<T> RestAPIRows<T> {
330 async fn from_pages(pages: Pages) -> Result<(Schema, Self)> {
331 let (pages, schema, settings) = pages.wait_for_schema(true).await?;
332 let rows = Self {
333 pages,
334 schema: Arc::new(schema.clone()),
335 settings,
336 data: Default::default(),
337 rows: Default::default(),
338 stats: None,
339 _phantom: PhantomData,
340 };
341 Ok((schema, rows))
342 }
343}
344
/// Yields the items of a paginated REST query result.
///
/// Each page fetched from `pages` produces one stats item (from the page's
/// `stats`), and its rows are buffered: in `data` when the page has no record
/// batches, otherwise decoded batch by batch into `rows`. Buffered rows are then
/// yielded on subsequent polls.
impl<T: FromRowStats + std::marker::Unpin> Stream for RestAPIRows<T> {
    type Item = Result<T>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Stashed stats, if any, are emitted before anything else.
        if let Some(ss) = self.stats.take() {
            return Poll::Ready(Some(Ok(T::from_stats(ss))));
        }
        // Serve from a buffer only while more than one row is queued. The last
        // buffered row is held back until `pages` is polled again; it is returned
        // from the `Poll::Ready(None)` arm once the pages are exhausted.
        // NOTE(review): presumably this ensures the final page request has completed
        // before the last row is handed out — confirm against the client.
        if self.data.len() > 1 {
            if let Some(row) = self.data.pop_front() {
                // A decode error is yielded as `Poll::Ready(Some(Err(..)))` via `?`.
                let row = T::try_from_raw_row(row, self.schema.clone(), &self.settings.timezone)?;
                return Poll::Ready(Some(Ok(row)));
            }
        } else if self.rows.len() > 1 {
            if let Some(row) = self.rows.pop_front() {
                let row = T::from_row(row);
                return Poll::Ready(Some(Ok(row)));
            }
        }

        match Pin::new(&mut self.pages).poll_next(cx) {
            Poll::Ready(Some(Ok(page))) => {
                // Fill in a schema that was still empty, preferring the page's raw
                // schema and falling back to the first record batch's schema.
                if self.schema.fields().is_empty() {
                    if !page.raw_schema.is_empty() {
                        self.schema = Arc::new(page.raw_schema.try_into()?);
                    } else if !page.batches.is_empty() {
                        self.schema = Arc::new(page.batches[0].schema().clone().try_into()?);
                    }
                }
                // Buffer this page's rows: raw string rows when there are no
                // batches, otherwise rows decoded from each batch.
                if page.batches.is_empty() {
                    let mut new_data = page.data.into();
                    self.data.append(&mut new_data);
                } else {
                    for batch in page.batches.into_iter() {
                        let rows = Rows::try_from((batch, self.settings.clone()))?;
                        self.rows.extend(rows);
                    }
                }
                // Each fetched page yields its stats first; buffered rows follow.
                Poll::Ready(Some(Ok(T::from_stats(page.stats.into()))))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
            Poll::Ready(None) => {
                // No more pages: drain whatever is still buffered, one row per poll.
                if let Some(row) = self.rows.pop_front() {
                    let row = T::from_row(row);
                    Poll::Ready(Some(Ok(row)))
                } else if let Some(row) = self.data.pop_front() {
                    let row =
                        T::try_from_raw_row(row, self.schema.clone(), &self.settings.timezone)?;
                    Poll::Ready(Some(Ok(row)))
                } else {
                    Poll::Ready(None)
                }
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
403
404trait FromRowStats: Send + Sync + Clone {
405 fn from_stats(stats: ServerStats) -> Self;
406 fn try_from_raw_row(row: Vec<Option<String>>, schema: SchemaRef, tz: &TimeZone)
407 -> Result<Self>;
408 fn from_row(row: Row) -> Self;
409}
410
411impl FromRowStats for RowWithStats {
412 fn from_stats(stats: ServerStats) -> Self {
413 RowWithStats::Stats(stats)
414 }
415
416 fn try_from_raw_row(
417 row: Vec<Option<String>>,
418 schema: SchemaRef,
419 tz: &TimeZone,
420 ) -> Result<Self> {
421 Ok(RowWithStats::Row(Row::try_from((schema, row, tz))?))
422 }
423 fn from_row(row: Row) -> Self {
424 RowWithStats::Row(row)
425 }
426}
427
428impl FromRowStats for RawRowWithStats {
429 fn from_stats(stats: ServerStats) -> Self {
430 RawRowWithStats::Stats(stats)
431 }
432
433 fn try_from_raw_row(
434 row: Vec<Option<String>>,
435 schema: SchemaRef,
436 tz: &TimeZone,
437 ) -> Result<Self> {
438 let rows = Row::try_from((schema, row.clone(), tz))?;
439 Ok(RawRowWithStats::Row(RawRow::new(rows, row)))
440 }
441
442 fn from_row(row: Row) -> Self {
443 RawRowWithStats::Row(RawRow::from(row))
444 }
445}