databend_driver/
rest_api.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use async_trait::async_trait;
16use chrono_tz::Tz;
17use log::info;
18use std::collections::{BTreeMap, VecDeque};
19use std::marker::PhantomData;
20use std::path::Path;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24use std::time::Instant;
25use tokio::fs::File;
26use tokio::io::BufReader;
27use tokio_stream::Stream;
28
29use crate::client::LoadMethod;
30use crate::conn::{ConnectionInfo, IConnection, Reader};
31use databend_client::APIClient;
32use databend_client::Pages;
33use databend_driver_core::error::{Error, Result};
34use databend_driver_core::raw_rows::{RawRow, RawRowIterator, RawRowWithStats};
35use databend_driver_core::rows::{
36    Row, RowIterator, RowStatsIterator, RowWithStats, Rows, ServerStats,
37};
38use databend_driver_core::schema::{Schema, SchemaRef};
39
/// Placeholder stage name that load SQL may contain.
///
/// `load_data` looks for it to decide how to build the statement, and
/// `load_data_with_stage` substitutes the generated stage location for it.
const LOAD_PLACEHOLDER: &str = "@_databend_load";
41
/// Connection that issues every request through a shared [`APIClient`].
///
/// Cloning only clones the `Arc`, so all clones use the same client.
#[derive(Clone)]
pub struct RestAPIConnection {
    /// Client used for queries, uploads and loads.
    client: Arc<APIClient>,
}
46
47impl RestAPIConnection {
48    fn gen_temp_stage_location(&self) -> Result<String> {
49        let now = chrono::Utc::now()
50            .timestamp_nanos_opt()
51            .ok_or_else(|| Error::IO("Failed to get current timestamp".to_string()))?;
52        Ok(format!("@~/client/load/{now}"))
53    }
54
55    async fn load_data_with_stage(
56        &self,
57        sql: &str,
58        data: Reader,
59        size: u64,
60    ) -> Result<ServerStats> {
61        let location = self.gen_temp_stage_location()?;
62        self.upload_to_stage(&location, data, size).await?;
63        if self.client.capability().streaming_load {
64            let sql = sql.replace(LOAD_PLACEHOLDER, &location);
65            let page = self.client.query_all(&sql).await?;
66            Ok(ServerStats::from(page.stats))
67        } else {
68            let file_format_options = Self::default_file_format_options();
69            let copy_options = Self::default_copy_options();
70            let stats = self
71                .client
72                .insert_with_stage(sql, &location, file_format_options, copy_options)
73                .await?;
74            Ok(ServerStats::from(stats))
75        }
76    }
77
78    async fn load_data_with_streaming(
79        &self,
80        sql: &str,
81        data: Reader,
82        size: u64,
83    ) -> Result<ServerStats> {
84        let start = Instant::now();
85        let response = self
86            .client
87            .streaming_load(sql, data, "<no_filename>")
88            .await?;
89        Ok(ServerStats {
90            total_rows: 0,
91            total_bytes: 0,
92            read_rows: response.stats.rows,
93            read_bytes: size as usize,
94            write_rows: response.stats.rows,
95            write_bytes: response.stats.bytes,
96            running_time_ms: start.elapsed().as_millis() as f64,
97            spill_file_nums: 0,
98            spill_bytes: 0,
99        })
100    }
101    async fn load_data_with_options(
102        &self,
103        sql: &str,
104        data: Reader,
105        size: u64,
106        file_format_options: Option<BTreeMap<&str, &str>>,
107        copy_options: Option<BTreeMap<&str, &str>>,
108    ) -> Result<ServerStats> {
109        let location = self.gen_temp_stage_location()?;
110        let file_format_options =
111            file_format_options.unwrap_or_else(Self::default_file_format_options);
112        let copy_options = copy_options.unwrap_or_else(Self::default_copy_options);
113        self.upload_to_stage(&location, Box::new(data), size)
114            .await?;
115        let stats = self
116            .client
117            .insert_with_stage(sql, &location, file_format_options, copy_options)
118            .await?;
119        Ok(ServerStats::from(stats))
120    }
121}
122
123#[async_trait]
124impl IConnection for RestAPIConnection {
125    async fn info(&self) -> ConnectionInfo {
126        ConnectionInfo {
127            handler: "RestAPI".to_string(),
128            host: self.client.host().to_string(),
129            port: self.client.port(),
130            user: self.client.username(),
131            catalog: self.client.current_catalog(),
132            database: self.client.current_database(),
133            warehouse: self.client.current_warehouse(),
134        }
135    }
136
137    fn last_query_id(&self) -> Option<String> {
138        self.client.last_query_id()
139    }
140
141    async fn close(&self) -> Result<()> {
142        self.client.close().await;
143        Ok(())
144    }
145
146    fn close_with_spawn(&self) -> Result<()> {
147        self.client.close_with_spawn();
148        Ok(())
149    }
150
151    async fn exec(&self, sql: &str) -> Result<i64> {
152        info!("exec: {}", sql);
153        let page = self.client.query_all(sql).await?;
154        Ok(page.stats.progresses.write_progress.rows as i64)
155    }
156
157    async fn kill_query(&self, query_id: &str) -> Result<()> {
158        Ok(self.client.kill_query(query_id).await?)
159    }
160
161    async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
162        info!("query iter: {}", sql);
163        let rows_with_progress = self.query_iter_ext(sql).await?;
164        let rows = rows_with_progress.filter_rows().await?;
165        Ok(rows)
166    }
167
168    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
169        info!("query iter ext: {}", sql);
170        let pages = self.client.start_query(sql, true).await?;
171        let (schema, rows) = RestAPIRows::<RowWithStats>::from_pages(pages).await?;
172        Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
173    }
174
175    // raw data response query, only for test
176    async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
177        info!("query raw iter: {}", sql);
178        let pages = self.client.start_query(sql, true).await?;
179        let (schema, rows) = RestAPIRows::<RawRowWithStats>::from_pages(pages).await?;
180        Ok(RawRowIterator::new(Arc::new(schema), Box::pin(rows)))
181    }
182
183    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
184        self.client.upload_to_stage(stage, data, size).await?;
185        Ok(())
186    }
187
188    async fn load_data(
189        &self,
190        sql: &str,
191        data: Reader,
192        size: u64,
193        method: LoadMethod,
194    ) -> Result<ServerStats> {
195        let sql = sql.trim_end();
196        let sql = sql.trim_end_matches(';');
197        info!("load data: {}, size: {}, method: {method:?}", sql, size);
198        let sql_low = sql.to_lowercase();
199        let has_place_holder = sql_low.contains(LOAD_PLACEHOLDER);
200        let sql = match (self.client.capability().streaming_load, has_place_holder) {
201            (false, false) => {
202                // todo: deprecate this later
203                return self
204                    .load_data_with_options(sql, data, size, None, None)
205                    .await;
206            }
207            (false, true) => return Err(Error::BadArgument(
208                "Please upgrade your server to >= 1.2.781 to support insert from @_databend_load"
209                    .to_string(),
210            )),
211            (true, false) => {
212                format!("{sql} from @_databend_load file_format=(type=csv)")
213            }
214            (true, true) => sql.to_string(),
215        };
216
217        match method {
218            LoadMethod::Streaming => self.load_data_with_streaming(&sql, data, size).await,
219            LoadMethod::Stage => self.load_data_with_stage(&sql, data, size).await,
220        }
221    }
222
223    async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
224        info!("load file: {}, file: {:?}", sql, fp,);
225        let file = File::open(fp).await?;
226        let metadata = file.metadata().await?;
227        let size = metadata.len();
228        let data = BufReader::new(file);
229        self.load_data(sql, Box::new(data), size, method).await
230    }
231
232    async fn load_file_with_options(
233        &self,
234        sql: &str,
235        fp: &Path,
236        file_format_options: Option<BTreeMap<&str, &str>>,
237        copy_options: Option<BTreeMap<&str, &str>>,
238    ) -> Result<ServerStats> {
239        let file = File::open(fp).await?;
240        let metadata = file.metadata().await?;
241        let size = metadata.len();
242        let data = BufReader::new(file);
243        self.load_data_with_options(sql, Box::new(data), size, file_format_options, copy_options)
244            .await
245    }
246
247    async fn stream_load(
248        &self,
249        sql: &str,
250        data: Vec<Vec<&str>>,
251        method: LoadMethod,
252    ) -> Result<ServerStats> {
253        info!("stream load: {}; rows: {:?}", sql, data.len());
254        let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
255        for row in data {
256            wtr.write_record(row)
257                .map_err(|e| Error::BadArgument(e.to_string()))?;
258        }
259        let bytes = wtr.into_inner().map_err(|e| Error::IO(e.to_string()))?;
260        let size = bytes.len() as u64;
261        let reader = Box::new(std::io::Cursor::new(bytes));
262        let stats = if self.client.capability().streaming_load {
263            let sql = format!("{sql} from @_databend_load file_format = (type = csv)");
264            self.load_data(&sql, reader, size, method).await?
265        } else {
266            self.load_data_with_options(sql, reader, size, None, None)
267                .await?
268        };
269        Ok(stats)
270    }
271}
272
273impl<'o> RestAPIConnection {
274    pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
275        let client = APIClient::new(dsn, Some(name)).await?;
276        Ok(Self { client })
277    }
278
279    fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
280        vec![
281            ("type", "CSV"),
282            ("field_delimiter", ","),
283            ("record_delimiter", "\n"),
284            ("skip_header", "0"),
285        ]
286        .into_iter()
287        .collect()
288    }
289
290    fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
291        vec![("purge", "true")].into_iter().collect()
292    }
293}
294
295pub struct RestAPIRows<T> {
296    pages: Pages,
297
298    schema: SchemaRef,
299    timezone: Tz,
300
301    data: VecDeque<Vec<Option<String>>>,
302    rows: VecDeque<Row>,
303
304    stats: Option<ServerStats>,
305
306    _phantom: std::marker::PhantomData<T>,
307}
308
309impl<T> RestAPIRows<T> {
310    async fn from_pages(pages: Pages) -> Result<(Schema, Self)> {
311        let (pages, schema, timezone) = pages.wait_for_schema(true).await?;
312        let schema: Schema = schema.try_into()?;
313        let rows = Self {
314            pages,
315            schema: Arc::new(schema.clone()),
316            timezone,
317            data: Default::default(),
318            rows: Default::default(),
319            stats: None,
320            _phantom: PhantomData,
321        };
322        Ok((schema, rows))
323    }
324}
325
326impl<T: FromRowStats + std::marker::Unpin> Stream for RestAPIRows<T> {
327    type Item = Result<T>;
328
329    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
330        if let Some(ss) = self.stats.take() {
331            return Poll::Ready(Some(Ok(T::from_stats(ss))));
332        }
333        // Skip to fetch next page if there is only one row left in buffer.
334        // Therefore, we could guarantee the `/final` called before the last row.
335        if self.data.len() > 1 {
336            if let Some(row) = self.data.pop_front() {
337                let row = T::try_from_raw_row(row, self.schema.clone(), self.timezone)?;
338                return Poll::Ready(Some(Ok(row)));
339            }
340        } else if self.rows.len() > 1 {
341            if let Some(row) = self.rows.pop_front() {
342                let row = T::from_row(row);
343                return Poll::Ready(Some(Ok(row)));
344            }
345        }
346
347        match Pin::new(&mut self.pages).poll_next(cx) {
348            Poll::Ready(Some(Ok(page))) => {
349                if self.schema.fields().is_empty() {
350                    self.schema = Arc::new(page.raw_schema.try_into()?);
351                }
352                if page.batches.is_empty() {
353                    let mut new_data = page.data.into();
354                    self.data.append(&mut new_data);
355                } else {
356                    for batch in page.batches.into_iter() {
357                        let rows = Rows::try_from((batch, self.timezone))?;
358                        self.rows.extend(rows);
359                    }
360                }
361                Poll::Ready(Some(Ok(T::from_stats(page.stats.into()))))
362            }
363            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
364            Poll::Ready(None) => {
365                if let Some(row) = self.rows.pop_front() {
366                    let row = T::from_row(row);
367                    Poll::Ready(Some(Ok(row)))
368                } else if let Some(row) = self.data.pop_front() {
369                    let row = T::try_from_raw_row(row, self.schema.clone(), self.timezone)?;
370                    Poll::Ready(Some(Ok(row)))
371                } else {
372                    Poll::Ready(None)
373                }
374            }
375            Poll::Pending => Poll::Pending,
376        }
377    }
378}
379
380trait FromRowStats: Send + Sync + Clone {
381    fn from_stats(stats: ServerStats) -> Self;
382    fn try_from_raw_row(row: Vec<Option<String>>, schema: SchemaRef, tz: Tz) -> Result<Self>;
383    fn from_row(row: Row) -> Self;
384}
385
386impl FromRowStats for RowWithStats {
387    fn from_stats(stats: ServerStats) -> Self {
388        RowWithStats::Stats(stats)
389    }
390
391    fn try_from_raw_row(row: Vec<Option<String>>, schema: SchemaRef, tz: Tz) -> Result<Self> {
392        Ok(RowWithStats::Row(Row::try_from((schema, row, tz))?))
393    }
394    fn from_row(row: Row) -> Self {
395        RowWithStats::Row(row)
396    }
397}
398
399impl FromRowStats for RawRowWithStats {
400    fn from_stats(stats: ServerStats) -> Self {
401        RawRowWithStats::Stats(stats)
402    }
403
404    fn try_from_raw_row(row: Vec<Option<String>>, schema: SchemaRef, tz: Tz) -> Result<Self> {
405        let rows = Row::try_from((schema, row.clone(), tz))?;
406        Ok(RawRowWithStats::Row(RawRow::new(rows, row)))
407    }
408
409    fn from_row(row: Row) -> Self {
410        RawRowWithStats::Row(RawRow::from(row))
411    }
412}