1use async_trait::async_trait;
16use chrono_tz::Tz;
17use log::info;
18use std::collections::{BTreeMap, VecDeque};
19use std::marker::PhantomData;
20use std::path::Path;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24use std::time::Instant;
25use tokio::fs::File;
26use tokio::io::BufReader;
27use tokio_stream::Stream;
28
29use crate::client::LoadMethod;
30use crate::conn::{ConnectionInfo, IConnection, Reader};
31use databend_client::APIClient;
32use databend_client::Pages;
33use databend_driver_core::error::{Error, Result};
34use databend_driver_core::raw_rows::{RawRow, RawRowIterator, RawRowWithStats};
35use databend_driver_core::rows::{
36 Row, RowIterator, RowStatsIterator, RowWithStats, Rows, ServerStats,
37};
38use databend_driver_core::schema::{Schema, SchemaRef};
39
/// Placeholder stage name that load statements may reference.
///
/// `load_data` checks (case-insensitively) whether the statement contains it, and
/// `load_data_with_stage` substitutes the generated temporary stage location for it.
const LOAD_PLACEHOLDER: &str = "@_databend_load";
41
/// Connection that delegates every operation to an [`APIClient`].
///
/// The client is held behind an [`Arc`], so cloning the connection shares the
/// same underlying client instead of creating a new one.
#[derive(Clone)]
pub struct RestAPIConnection {
    /// Shared client used for queries, uploads and loads.
    client: Arc<APIClient>,
}
46
47impl RestAPIConnection {
48 fn gen_temp_stage_location(&self) -> Result<String> {
49 let now = chrono::Utc::now()
50 .timestamp_nanos_opt()
51 .ok_or_else(|| Error::IO("Failed to get current timestamp".to_string()))?;
52 Ok(format!("@~/client/load/{now}"))
53 }
54
55 async fn load_data_with_stage(
56 &self,
57 sql: &str,
58 data: Reader,
59 size: u64,
60 ) -> Result<ServerStats> {
61 let location = self.gen_temp_stage_location()?;
62 self.upload_to_stage(&location, data, size).await?;
63 if self.client.capability().streaming_load {
64 let sql = sql.replace(LOAD_PLACEHOLDER, &location);
65 let page = self.client.query_all(&sql).await?;
66 Ok(ServerStats::from(page.stats))
67 } else {
68 let file_format_options = Self::default_file_format_options();
69 let copy_options = Self::default_copy_options();
70 let stats = self
71 .client
72 .insert_with_stage(sql, &location, file_format_options, copy_options)
73 .await?;
74 Ok(ServerStats::from(stats))
75 }
76 }
77
78 async fn load_data_with_streaming(
79 &self,
80 sql: &str,
81 data: Reader,
82 size: u64,
83 ) -> Result<ServerStats> {
84 let start = Instant::now();
85 let response = self
86 .client
87 .streaming_load(sql, data, "<no_filename>")
88 .await?;
89 Ok(ServerStats {
90 total_rows: 0,
91 total_bytes: 0,
92 read_rows: response.stats.rows,
93 read_bytes: size as usize,
94 write_rows: response.stats.rows,
95 write_bytes: response.stats.bytes,
96 running_time_ms: start.elapsed().as_millis() as f64,
97 spill_file_nums: 0,
98 spill_bytes: 0,
99 })
100 }
101 async fn load_data_with_options(
102 &self,
103 sql: &str,
104 data: Reader,
105 size: u64,
106 file_format_options: Option<BTreeMap<&str, &str>>,
107 copy_options: Option<BTreeMap<&str, &str>>,
108 ) -> Result<ServerStats> {
109 let location = self.gen_temp_stage_location()?;
110 let file_format_options =
111 file_format_options.unwrap_or_else(Self::default_file_format_options);
112 let copy_options = copy_options.unwrap_or_else(Self::default_copy_options);
113 self.upload_to_stage(&location, Box::new(data), size)
114 .await?;
115 let stats = self
116 .client
117 .insert_with_stage(sql, &location, file_format_options, copy_options)
118 .await?;
119 Ok(ServerStats::from(stats))
120 }
121}
122
123#[async_trait]
124impl IConnection for RestAPIConnection {
125 async fn info(&self) -> ConnectionInfo {
126 ConnectionInfo {
127 handler: "RestAPI".to_string(),
128 host: self.client.host().to_string(),
129 port: self.client.port(),
130 user: self.client.username(),
131 catalog: self.client.current_catalog(),
132 database: self.client.current_database(),
133 warehouse: self.client.current_warehouse(),
134 }
135 }
136
137 fn last_query_id(&self) -> Option<String> {
138 self.client.last_query_id()
139 }
140
141 async fn close(&self) -> Result<()> {
142 self.client.close().await;
143 Ok(())
144 }
145
146 fn close_with_spawn(&self) -> Result<()> {
147 self.client.close_with_spawn();
148 Ok(())
149 }
150
151 async fn exec(&self, sql: &str) -> Result<i64> {
152 info!("exec: {}", sql);
153 let page = self.client.query_all(sql).await?;
154 Ok(page.stats.progresses.write_progress.rows as i64)
155 }
156
157 async fn kill_query(&self, query_id: &str) -> Result<()> {
158 Ok(self.client.kill_query(query_id).await?)
159 }
160
161 async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
162 info!("query iter: {}", sql);
163 let rows_with_progress = self.query_iter_ext(sql).await?;
164 let rows = rows_with_progress.filter_rows().await?;
165 Ok(rows)
166 }
167
168 async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
169 info!("query iter ext: {}", sql);
170 let pages = self.client.start_query(sql, true).await?;
171 let (schema, rows) = RestAPIRows::<RowWithStats>::from_pages(pages).await?;
172 Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
173 }
174
175 async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
177 info!("query raw iter: {}", sql);
178 let pages = self.client.start_query(sql, true).await?;
179 let (schema, rows) = RestAPIRows::<RawRowWithStats>::from_pages(pages).await?;
180 Ok(RawRowIterator::new(Arc::new(schema), Box::pin(rows)))
181 }
182
183 async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
184 self.client.upload_to_stage(stage, data, size).await?;
185 Ok(())
186 }
187
188 async fn load_data(
189 &self,
190 sql: &str,
191 data: Reader,
192 size: u64,
193 method: LoadMethod,
194 ) -> Result<ServerStats> {
195 let sql = sql.trim_end();
196 let sql = sql.trim_end_matches(';');
197 info!("load data: {}, size: {}, method: {method:?}", sql, size);
198 let sql_low = sql.to_lowercase();
199 let has_place_holder = sql_low.contains(LOAD_PLACEHOLDER);
200 let sql = match (self.client.capability().streaming_load, has_place_holder) {
201 (false, false) => {
202 return self
204 .load_data_with_options(sql, data, size, None, None)
205 .await;
206 }
207 (false, true) => return Err(Error::BadArgument(
208 "Please upgrade your server to >= 1.2.781 to support insert from @_databend_load"
209 .to_string(),
210 )),
211 (true, false) => {
212 format!("{sql} from @_databend_load file_format=(type=csv)")
213 }
214 (true, true) => sql.to_string(),
215 };
216
217 match method {
218 LoadMethod::Streaming => self.load_data_with_streaming(&sql, data, size).await,
219 LoadMethod::Stage => self.load_data_with_stage(&sql, data, size).await,
220 }
221 }
222
223 async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
224 info!("load file: {}, file: {:?}", sql, fp,);
225 let file = File::open(fp).await?;
226 let metadata = file.metadata().await?;
227 let size = metadata.len();
228 let data = BufReader::new(file);
229 self.load_data(sql, Box::new(data), size, method).await
230 }
231
232 async fn load_file_with_options(
233 &self,
234 sql: &str,
235 fp: &Path,
236 file_format_options: Option<BTreeMap<&str, &str>>,
237 copy_options: Option<BTreeMap<&str, &str>>,
238 ) -> Result<ServerStats> {
239 let file = File::open(fp).await?;
240 let metadata = file.metadata().await?;
241 let size = metadata.len();
242 let data = BufReader::new(file);
243 self.load_data_with_options(sql, Box::new(data), size, file_format_options, copy_options)
244 .await
245 }
246
247 async fn stream_load(
248 &self,
249 sql: &str,
250 data: Vec<Vec<&str>>,
251 method: LoadMethod,
252 ) -> Result<ServerStats> {
253 info!("stream load: {}; rows: {:?}", sql, data.len());
254 let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
255 for row in data {
256 wtr.write_record(row)
257 .map_err(|e| Error::BadArgument(e.to_string()))?;
258 }
259 let bytes = wtr.into_inner().map_err(|e| Error::IO(e.to_string()))?;
260 let size = bytes.len() as u64;
261 let reader = Box::new(std::io::Cursor::new(bytes));
262 let stats = if self.client.capability().streaming_load {
263 let sql = format!("{sql} from @_databend_load file_format = (type = csv)");
264 self.load_data(&sql, reader, size, method).await?
265 } else {
266 self.load_data_with_options(sql, reader, size, None, None)
267 .await?
268 };
269 Ok(stats)
270 }
271}
272
273impl<'o> RestAPIConnection {
274 pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
275 let client = APIClient::new(dsn, Some(name)).await?;
276 Ok(Self { client })
277 }
278
279 fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
280 vec![
281 ("type", "CSV"),
282 ("field_delimiter", ","),
283 ("record_delimiter", "\n"),
284 ("skip_header", "0"),
285 ]
286 .into_iter()
287 .collect()
288 }
289
290 fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
291 vec![("purge", "true")].into_iter().collect()
292 }
293}
294
295pub struct RestAPIRows<T> {
296 pages: Pages,
297
298 schema: SchemaRef,
299 timezone: Tz,
300
301 data: VecDeque<Vec<Option<String>>>,
302 rows: VecDeque<Row>,
303
304 stats: Option<ServerStats>,
305
306 _phantom: std::marker::PhantomData<T>,
307}
308
309impl<T> RestAPIRows<T> {
310 async fn from_pages(pages: Pages) -> Result<(Schema, Self)> {
311 let (pages, schema, timezone) = pages.wait_for_schema(true).await?;
312 let schema: Schema = schema.try_into()?;
313 let rows = Self {
314 pages,
315 schema: Arc::new(schema.clone()),
316 timezone,
317 data: Default::default(),
318 rows: Default::default(),
319 stats: None,
320 _phantom: PhantomData,
321 };
322 Ok((schema, rows))
323 }
324}
325
326impl<T: FromRowStats + std::marker::Unpin> Stream for RestAPIRows<T> {
327 type Item = Result<T>;
328
329 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
330 if let Some(ss) = self.stats.take() {
331 return Poll::Ready(Some(Ok(T::from_stats(ss))));
332 }
333 if self.data.len() > 1 {
336 if let Some(row) = self.data.pop_front() {
337 let row = T::try_from_raw_row(row, self.schema.clone(), self.timezone)?;
338 return Poll::Ready(Some(Ok(row)));
339 }
340 } else if self.rows.len() > 1 {
341 if let Some(row) = self.rows.pop_front() {
342 let row = T::from_row(row);
343 return Poll::Ready(Some(Ok(row)));
344 }
345 }
346
347 match Pin::new(&mut self.pages).poll_next(cx) {
348 Poll::Ready(Some(Ok(page))) => {
349 if self.schema.fields().is_empty() {
350 self.schema = Arc::new(page.raw_schema.try_into()?);
351 }
352 if page.batches.is_empty() {
353 let mut new_data = page.data.into();
354 self.data.append(&mut new_data);
355 } else {
356 for batch in page.batches.into_iter() {
357 let rows = Rows::try_from((batch, self.timezone))?;
358 self.rows.extend(rows);
359 }
360 }
361 Poll::Ready(Some(Ok(T::from_stats(page.stats.into()))))
362 }
363 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
364 Poll::Ready(None) => {
365 if let Some(row) = self.rows.pop_front() {
366 let row = T::from_row(row);
367 Poll::Ready(Some(Ok(row)))
368 } else if let Some(row) = self.data.pop_front() {
369 let row = T::try_from_raw_row(row, self.schema.clone(), self.timezone)?;
370 Poll::Ready(Some(Ok(row)))
371 } else {
372 Poll::Ready(None)
373 }
374 }
375 Poll::Pending => Poll::Pending,
376 }
377 }
378}
379
380trait FromRowStats: Send + Sync + Clone {
381 fn from_stats(stats: ServerStats) -> Self;
382 fn try_from_raw_row(row: Vec<Option<String>>, schema: SchemaRef, tz: Tz) -> Result<Self>;
383 fn from_row(row: Row) -> Self;
384}
385
386impl FromRowStats for RowWithStats {
387 fn from_stats(stats: ServerStats) -> Self {
388 RowWithStats::Stats(stats)
389 }
390
391 fn try_from_raw_row(row: Vec<Option<String>>, schema: SchemaRef, tz: Tz) -> Result<Self> {
392 Ok(RowWithStats::Row(Row::try_from((schema, row, tz))?))
393 }
394 fn from_row(row: Row) -> Self {
395 RowWithStats::Row(row)
396 }
397}
398
399impl FromRowStats for RawRowWithStats {
400 fn from_stats(stats: ServerStats) -> Self {
401 RawRowWithStats::Stats(stats)
402 }
403
404 fn try_from_raw_row(row: Vec<Option<String>>, schema: SchemaRef, tz: Tz) -> Result<Self> {
405 let rows = Row::try_from((schema, row.clone(), tz))?;
406 Ok(RawRowWithStats::Row(RawRow::new(rows, row)))
407 }
408
409 fn from_row(row: Row) -> Self {
410 RawRowWithStats::Row(RawRow::from(row))
411 }
412}