databend_driver/
rest_api.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use async_trait::async_trait;
16use jiff::tz::TimeZone;
17use log::info;
18use std::collections::{BTreeMap, VecDeque};
19use std::marker::PhantomData;
20use std::path::Path;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24use std::time::Instant;
25use tokio::fs::File;
26use tokio::io::BufReader;
27use tokio_stream::Stream;
28
29use crate::client::LoadMethod;
30use crate::conn::{ConnectionInfo, IConnection, Reader};
31use databend_client::schema::{Schema, SchemaRef};
32use databend_client::Pages;
33use databend_client::{APIClient, ResultFormatSettings};
34use databend_driver_core::error::{Error, Result};
35use databend_driver_core::raw_rows::{RawRow, RawRowIterator, RawRowWithStats};
36use databend_driver_core::rows::{
37    Row, RowIterator, RowStatsIterator, RowWithStats, Rows, ServerStats,
38};
39
/// Placeholder stage name used in load statements to stand for the data being loaded.
///
/// `load_data` looks for it in the statement (case-insensitively) and the stage-based
/// load path replaces it with the temporary stage location the data was uploaded to.
const LOAD_PLACEHOLDER: &str = "@_databend_load";
41
/// Connection that performs all of its work through a shared [`APIClient`].
///
/// Cloning the connection clones the `Arc`, so clones share the same client.
#[derive(Clone)]
pub struct RestAPIConnection {
    /// Shared client that every method of this connection delegates to.
    client: Arc<APIClient>,
}
46
47impl RestAPIConnection {
48    fn gen_temp_stage_location(&self) -> Result<String> {
49        let now = chrono::Utc::now()
50            .timestamp_nanos_opt()
51            .ok_or_else(|| Error::IO("Failed to get current timestamp".to_string()))?;
52        Ok(format!("@~/client/load/{now}"))
53    }
54
55    async fn load_data_with_stage(
56        &self,
57        sql: &str,
58        data: Reader,
59        size: u64,
60    ) -> Result<ServerStats> {
61        let location = self.gen_temp_stage_location()?;
62        self.upload_to_stage(&location, data, size).await?;
63        if self.client.capability().streaming_load {
64            let sql = sql.replace(LOAD_PLACEHOLDER, &location);
65            let page = self.client.query_all(&sql).await?;
66            Ok(ServerStats::from(page.stats))
67        } else {
68            let file_format_options = Self::default_file_format_options();
69            let copy_options = Self::default_copy_options();
70            let stats = self
71                .client
72                .insert_with_stage(sql, &location, file_format_options, copy_options)
73                .await?;
74            Ok(ServerStats::from(stats))
75        }
76    }
77
78    async fn load_data_with_streaming(
79        &self,
80        sql: &str,
81        data: Reader,
82        size: u64,
83    ) -> Result<ServerStats> {
84        let start = Instant::now();
85        let response = self
86            .client
87            .streaming_load(sql, data, "<no_filename>")
88            .await?;
89        Ok(ServerStats {
90            total_rows: 0,
91            total_bytes: 0,
92            read_rows: response.stats.rows,
93            read_bytes: size as usize,
94            write_rows: response.stats.rows,
95            write_bytes: response.stats.bytes,
96            running_time_ms: start.elapsed().as_millis() as f64,
97            spill_file_nums: 0,
98            spill_bytes: 0,
99        })
100    }
101    async fn load_data_with_options(
102        &self,
103        sql: &str,
104        data: Reader,
105        size: u64,
106        file_format_options: Option<BTreeMap<&str, &str>>,
107        copy_options: Option<BTreeMap<&str, &str>>,
108    ) -> Result<ServerStats> {
109        let location = self.gen_temp_stage_location()?;
110        let file_format_options =
111            file_format_options.unwrap_or_else(Self::default_file_format_options);
112        let copy_options = copy_options.unwrap_or_else(Self::default_copy_options);
113        self.upload_to_stage(&location, Box::new(data), size)
114            .await?;
115        let stats = self
116            .client
117            .insert_with_stage(sql, &location, file_format_options, copy_options)
118            .await?;
119        Ok(ServerStats::from(stats))
120    }
121}
122
123#[async_trait]
124impl IConnection for RestAPIConnection {
125    async fn info(&self) -> ConnectionInfo {
126        ConnectionInfo {
127            handler: "RestAPI".to_string(),
128            host: self.client.host().to_string(),
129            port: self.client.port(),
130            user: self.client.username(),
131            catalog: self.client.current_catalog(),
132            database: self.client.current_database(),
133            warehouse: self.client.current_warehouse(),
134        }
135    }
136
137    fn last_query_id(&self) -> Option<String> {
138        self.client.last_query_id()
139    }
140
141    fn set_warehouse(&self, warehouse: &str) -> Result<()> {
142        self.client.set_warehouse(warehouse.to_string());
143        Ok(())
144    }
145
146    fn set_database(&self, database: &str) -> Result<()> {
147        self.client.set_database(database.to_string());
148        Ok(())
149    }
150
151    fn set_role(&self, role: &str) -> Result<()> {
152        self.client.set_role(role.to_string());
153        Ok(())
154    }
155
156    fn set_session(&self, key: &str, value: &str) -> Result<()> {
157        self.client.set_session(key.to_string(), value.to_string());
158        Ok(())
159    }
160
161    async fn close(&self) -> Result<()> {
162        self.client.close().await;
163        Ok(())
164    }
165
166    fn close_with_spawn(&self) -> Result<()> {
167        self.client.close_with_spawn();
168        Ok(())
169    }
170
171    async fn exec(&self, sql: &str) -> Result<i64> {
172        info!("exec: {}", sql);
173        let page = self.client.query_all(sql).await?;
174        Ok(page.stats.progresses.write_progress.rows as i64)
175    }
176
177    async fn kill_query(&self, query_id: &str) -> Result<()> {
178        Ok(self.client.kill_query(query_id).await?)
179    }
180
181    async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
182        info!("query iter: {}", sql);
183        let rows_with_progress = self.query_iter_ext(sql).await?;
184        let rows = rows_with_progress.filter_rows().await?;
185        Ok(rows)
186    }
187
188    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
189        info!("query iter ext: {}", sql);
190        let pages = self.client.start_query(sql, true).await?;
191        let (schema, rows) = RestAPIRows::<RowWithStats>::from_pages(pages).await?;
192        Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
193    }
194
195    // raw data response query, only for test
196    async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
197        info!("query raw iter: {}", sql);
198        let pages = self.client.start_query(sql, true).await?;
199        let (schema, rows) = RestAPIRows::<RawRowWithStats>::from_pages(pages).await?;
200        Ok(RawRowIterator::new(Arc::new(schema), Box::pin(rows)))
201    }
202
203    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
204        self.client.upload_to_stage(stage, data, size).await?;
205        Ok(())
206    }
207
208    async fn load_data(
209        &self,
210        sql: &str,
211        data: Reader,
212        size: u64,
213        method: LoadMethod,
214    ) -> Result<ServerStats> {
215        let sql = sql.trim_end();
216        let sql = sql.trim_end_matches(';');
217        info!("load data: {}, size: {}, method: {method:?}", sql, size);
218        let sql_low = sql.to_lowercase();
219        let has_place_holder = sql_low.contains(LOAD_PLACEHOLDER);
220        let sql = match (self.client.capability().streaming_load, has_place_holder) {
221            (false, false) => {
222                // todo: deprecate this later
223                return self
224                    .load_data_with_options(sql, data, size, None, None)
225                    .await;
226            }
227            (false, true) => return Err(Error::BadArgument(
228                "Please upgrade your server to >= 1.2.781 to support insert from @_databend_load"
229                    .to_string(),
230            )),
231            (true, false) => {
232                format!("{sql} from @_databend_load file_format=(type=csv)")
233            }
234            (true, true) => sql.to_string(),
235        };
236
237        match method {
238            LoadMethod::Streaming => self.load_data_with_streaming(&sql, data, size).await,
239            LoadMethod::Stage => self.load_data_with_stage(&sql, data, size).await,
240        }
241    }
242
243    async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
244        info!("load file: {}, file: {:?}", sql, fp,);
245        let file = File::open(fp).await?;
246        let metadata = file.metadata().await?;
247        let size = metadata.len();
248        let data = BufReader::new(file);
249        self.load_data(sql, Box::new(data), size, method).await
250    }
251
252    async fn load_file_with_options(
253        &self,
254        sql: &str,
255        fp: &Path,
256        file_format_options: Option<BTreeMap<&str, &str>>,
257        copy_options: Option<BTreeMap<&str, &str>>,
258    ) -> Result<ServerStats> {
259        let file = File::open(fp).await?;
260        let metadata = file.metadata().await?;
261        let size = metadata.len();
262        let data = BufReader::new(file);
263        self.load_data_with_options(sql, Box::new(data), size, file_format_options, copy_options)
264            .await
265    }
266
267    async fn stream_load(
268        &self,
269        sql: &str,
270        data: Vec<Vec<&str>>,
271        method: LoadMethod,
272    ) -> Result<ServerStats> {
273        info!("stream load: {}; rows: {:?}", sql, data.len());
274        let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
275        for row in data {
276            wtr.write_record(row)
277                .map_err(|e| Error::BadArgument(e.to_string()))?;
278        }
279        let bytes = wtr.into_inner().map_err(|e| Error::IO(e.to_string()))?;
280        let size = bytes.len() as u64;
281        let reader = Box::new(std::io::Cursor::new(bytes));
282        let stats = if self.client.capability().streaming_load {
283            let sql = format!("{sql} from @_databend_load file_format = (type = csv)");
284            self.load_data(&sql, reader, size, method).await?
285        } else {
286            self.load_data_with_options(sql, reader, size, None, None)
287                .await?
288        };
289        Ok(stats)
290    }
291}
292
293impl<'o> RestAPIConnection {
294    pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
295        let client = APIClient::new(dsn, Some(name)).await?;
296        Ok(Self { client })
297    }
298
299    fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
300        vec![
301            ("type", "CSV"),
302            ("field_delimiter", ","),
303            ("record_delimiter", "\n"),
304            ("skip_header", "0"),
305        ]
306        .into_iter()
307        .collect()
308    }
309
310    fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
311        vec![("purge", "true")].into_iter().collect()
312    }
313}
314
/// Stream adapter that turns result [`Pages`] into a sequence of items of type `T`.
///
/// Items are produced by the `Stream` implementation, which interleaves per-page
/// stats with the rows buffered from each page.
pub struct RestAPIRows<T> {
    /// Source of result pages, polled whenever the buffers run low.
    pages: Pages,

    /// Schema used to decode raw rows; replaced from the first page that carries
    /// one if it starts out without fields.
    schema: SchemaRef,
    /// Result format settings; the timezone is used for raw rows and the whole
    /// value is cloned when decoding batches.
    settings: ResultFormatSettings,

    /// Buffered raw rows (optional string values) from pages that carry no batches.
    data: VecDeque<Vec<Option<String>>>,
    /// Buffered decoded rows from pages that carry batches.
    rows: VecDeque<Row>,

    /// Stats waiting to be emitted; when set, they are yielded before anything else.
    stats: Option<ServerStats>,

    /// Ties the stream to its item type without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
328
329impl<T> RestAPIRows<T> {
330    async fn from_pages(pages: Pages) -> Result<(Schema, Self)> {
331        let (pages, schema, settings) = pages.wait_for_schema(true).await?;
332        let rows = Self {
333            pages,
334            schema: Arc::new(schema.clone()),
335            settings,
336            data: Default::default(),
337            rows: Default::default(),
338            stats: None,
339            _phantom: PhantomData,
340        };
341        Ok((schema, rows))
342    }
343}
344
345impl<T: FromRowStats + std::marker::Unpin> Stream for RestAPIRows<T> {
346    type Item = Result<T>;
347
348    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
349        if let Some(ss) = self.stats.take() {
350            return Poll::Ready(Some(Ok(T::from_stats(ss))));
351        }
352        // Skip to fetch next page if there is only one row left in buffer.
353        // Therefore, we could guarantee the `/final` called before the last row.
354        if self.data.len() > 1 {
355            if let Some(row) = self.data.pop_front() {
356                let row = T::try_from_raw_row(row, self.schema.clone(), &self.settings.timezone)?;
357                return Poll::Ready(Some(Ok(row)));
358            }
359        } else if self.rows.len() > 1 {
360            if let Some(row) = self.rows.pop_front() {
361                let row = T::from_row(row);
362                return Poll::Ready(Some(Ok(row)));
363            }
364        }
365
366        match Pin::new(&mut self.pages).poll_next(cx) {
367            Poll::Ready(Some(Ok(page))) => {
368                if self.schema.fields().is_empty() {
369                    if !page.raw_schema.is_empty() {
370                        self.schema = Arc::new(page.raw_schema.try_into()?);
371                    } else if !page.batches.is_empty() {
372                        self.schema = Arc::new(page.batches[0].schema().clone().try_into()?);
373                    }
374                }
375                if page.batches.is_empty() {
376                    let mut new_data = page.data.into();
377                    self.data.append(&mut new_data);
378                } else {
379                    for batch in page.batches.into_iter() {
380                        let rows = Rows::try_from((batch, self.settings.clone()))?;
381                        self.rows.extend(rows);
382                    }
383                }
384                Poll::Ready(Some(Ok(T::from_stats(page.stats.into()))))
385            }
386            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
387            Poll::Ready(None) => {
388                if let Some(row) = self.rows.pop_front() {
389                    let row = T::from_row(row);
390                    Poll::Ready(Some(Ok(row)))
391                } else if let Some(row) = self.data.pop_front() {
392                    let row =
393                        T::try_from_raw_row(row, self.schema.clone(), &self.settings.timezone)?;
394                    Poll::Ready(Some(Ok(row)))
395                } else {
396                    Poll::Ready(None)
397                }
398            }
399            Poll::Pending => Poll::Pending,
400        }
401    }
402}
403
/// Conversions that let `RestAPIRows` produce a single item type `T` from
/// server stats, raw rows, and already decoded rows.
trait FromRowStats: Send + Sync + Clone {
    /// Wraps server stats as an item.
    fn from_stats(stats: ServerStats) -> Self;
    /// Builds an item from a raw row of optional string values, decoding it
    /// with `schema` and the timezone `tz`.
    fn try_from_raw_row(row: Vec<Option<String>>, schema: SchemaRef, tz: &TimeZone)
        -> Result<Self>;
    /// Wraps an already decoded row as an item.
    fn from_row(row: Row) -> Self;
}
410
411impl FromRowStats for RowWithStats {
412    fn from_stats(stats: ServerStats) -> Self {
413        RowWithStats::Stats(stats)
414    }
415
416    fn try_from_raw_row(
417        row: Vec<Option<String>>,
418        schema: SchemaRef,
419        tz: &TimeZone,
420    ) -> Result<Self> {
421        Ok(RowWithStats::Row(Row::try_from((schema, row, tz))?))
422    }
423    fn from_row(row: Row) -> Self {
424        RowWithStats::Row(row)
425    }
426}
427
428impl FromRowStats for RawRowWithStats {
429    fn from_stats(stats: ServerStats) -> Self {
430        RawRowWithStats::Stats(stats)
431    }
432
433    fn try_from_raw_row(
434        row: Vec<Option<String>>,
435        schema: SchemaRef,
436        tz: &TimeZone,
437    ) -> Result<Self> {
438        let rows = Row::try_from((schema, row.clone(), tz))?;
439        Ok(RawRowWithStats::Row(RawRow::new(rows, row)))
440    }
441
442    fn from_row(row: Row) -> Self {
443        RawRowWithStats::Row(RawRow::from(row))
444    }
445}