Skip to main content

lake_driver/
rest_api.rs

1// Copyright 2025 TiDB Cloud
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use async_trait::async_trait;
16use log::info;
17use std::collections::{BTreeMap, VecDeque};
18use std::marker::PhantomData;
19use std::path::Path;
20use std::pin::Pin;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23use std::time::Instant;
24use tokio::fs::File;
25use tokio::io::BufReader;
26use tokio_stream::Stream;
27
28use crate::client::LoadMethod;
29use crate::conn::{ConnectionInfo, IConnection, Reader};
30use lake_client::schema::{Schema, SchemaRef};
31use lake_client::Pages;
32use lake_client::{APIClient, ResultFormatSettings};
33use lake_driver_core::error::{Error, Result};
34use lake_driver_core::raw_rows::{RawRow, RawRowIterator, RawRowWithStats};
35use lake_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, Rows, ServerStats};
36
/// Placeholder that load statements use in place of a concrete stage location.
/// `load_data` looks for it in the statement, and the stage-based load path
/// substitutes the uploaded stage location for it.
const LOAD_PLACEHOLDER: &str = "@_databend_load";
38
/// Connection implementation that forwards every operation to an [`APIClient`].
#[derive(Clone)]
pub struct RestAPIConnection {
    // Shared client handle; cloning the connection clones the `Arc`.
    client: Arc<APIClient>,
}
43
44impl RestAPIConnection {
45    fn gen_temp_stage_location(&self) -> Result<String> {
46        let now = chrono::Utc::now()
47            .timestamp_nanos_opt()
48            .ok_or_else(|| Error::IO("Failed to get current timestamp".to_string()))?;
49        Ok(format!("@~/client/load/{now}"))
50    }
51
52    async fn load_data_with_stage(
53        &self,
54        sql: &str,
55        data: Reader,
56        size: u64,
57    ) -> Result<ServerStats> {
58        let location = self.gen_temp_stage_location()?;
59        self.upload_to_stage(&location, data, size).await?;
60        if self.client.capability().streaming_load {
61            let sql = sql.replace(LOAD_PLACEHOLDER, &location);
62            let page = self.client.query_all(&sql, None).await?;
63            Ok(ServerStats::from(page.stats))
64        } else {
65            let file_format_options = Self::default_file_format_options();
66            let copy_options = Self::default_copy_options();
67            let stats = self
68                .client
69                .insert_with_stage(sql, &location, file_format_options, copy_options)
70                .await?;
71            Ok(ServerStats::from(stats))
72        }
73    }
74
75    async fn load_data_with_streaming(
76        &self,
77        sql: &str,
78        data: Reader,
79        size: u64,
80    ) -> Result<ServerStats> {
81        let start = Instant::now();
82        let response = self
83            .client
84            .streaming_load(sql, data, "<no_filename>")
85            .await?;
86        Ok(ServerStats {
87            total_rows: 0,
88            total_bytes: 0,
89            read_rows: response.stats.rows,
90            read_bytes: size as usize,
91            write_rows: response.stats.rows,
92            write_bytes: response.stats.bytes,
93            running_time_ms: start.elapsed().as_millis() as f64,
94            spill_file_nums: 0,
95            spill_bytes: 0,
96        })
97    }
98    async fn load_data_with_options(
99        &self,
100        sql: &str,
101        data: Reader,
102        size: u64,
103        file_format_options: Option<BTreeMap<&str, &str>>,
104        copy_options: Option<BTreeMap<&str, &str>>,
105    ) -> Result<ServerStats> {
106        let location = self.gen_temp_stage_location()?;
107        let file_format_options =
108            file_format_options.unwrap_or_else(Self::default_file_format_options);
109        let copy_options = copy_options.unwrap_or_else(Self::default_copy_options);
110        self.upload_to_stage(&location, Box::new(data), size)
111            .await?;
112        let stats = self
113            .client
114            .insert_with_stage(sql, &location, file_format_options, copy_options)
115            .await?;
116        Ok(ServerStats::from(stats))
117    }
118}
119
120#[async_trait]
121impl IConnection for RestAPIConnection {
122    async fn info(&self) -> ConnectionInfo {
123        ConnectionInfo {
124            handler: "RestAPI".to_string(),
125            host: self.client.host().to_string(),
126            port: self.client.port(),
127            user: self.client.username(),
128            catalog: self.client.current_catalog(),
129            database: self.client.current_database(),
130            warehouse: self.client.current_warehouse(),
131        }
132    }
133
    /// Returns whatever the underlying client reports as its last query id.
    fn last_query_id(&self) -> Option<String> {
        self.client.last_query_id()
    }
137
138    fn set_warehouse(&self, warehouse: &str) -> Result<()> {
139        self.client.set_warehouse(warehouse.to_string());
140        Ok(())
141    }
142
143    fn set_database(&self, database: &str) -> Result<()> {
144        self.client.set_database(database.to_string());
145        Ok(())
146    }
147
148    fn set_role(&self, role: &str) -> Result<()> {
149        self.client.set_role(role.to_string());
150        Ok(())
151    }
152
153    fn set_session(&self, key: &str, value: &str) -> Result<()> {
154        self.client.set_session(key.to_string(), value.to_string());
155        Ok(())
156    }
157
    /// Awaits the client's `close` and then reports success.
    async fn close(&self) -> Result<()> {
        self.client.close().await;
        Ok(())
    }
162
    /// Synchronous counterpart of `close`: calls the client's `close_with_spawn` and
    /// reports success.
    fn close_with_spawn(&self) -> Result<()> {
        self.client.close_with_spawn();
        Ok(())
    }
167
    /// Reports the `server_side_params` flag from the client's capability set.
    fn supports_server_side_params(&self) -> bool {
        self.client.capability().server_side_params
    }
171
172    async fn exec(&self, sql: &str) -> Result<i64> {
173        info!("exec: {}", sql);
174        let page = self.client.query_all(sql, None).await?;
175        Ok(page.stats.progresses.write_progress.rows as i64)
176    }
177
178    async fn exec_with_params(&self, sql: &str, params: Option<serde_json::Value>) -> Result<i64> {
179        info!("exec with params: {}", sql);
180        let page = self.client.query_all(sql, params).await?;
181        Ok(page.stats.progresses.write_progress.rows as i64)
182    }
183
184    async fn kill_query(&self, query_id: &str) -> Result<()> {
185        Ok(self.client.kill_query(query_id).await?)
186    }
187
188    async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
189        info!("query iter: {}", sql);
190        let rows_with_progress = self.query_iter_ext(sql).await?;
191        let rows = rows_with_progress.filter_rows().await?;
192        Ok(rows)
193    }
194
195    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
196        info!("query iter ext: {}", sql);
197        let pages = self.client.start_query(sql, true, None).await?;
198        let (schema, rows) = RestAPIRows::<RowWithStats>::from_pages(pages).await?;
199        Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
200    }
201
202    async fn query_iter_with_params(
203        &self,
204        sql: &str,
205        params: Option<serde_json::Value>,
206    ) -> Result<RowIterator> {
207        info!("query iter with params: {}", sql);
208        let rows_with_progress = self.query_iter_ext_with_params(sql, params).await?;
209        let rows = rows_with_progress.filter_rows().await?;
210        Ok(rows)
211    }
212
213    async fn query_iter_ext_with_params(
214        &self,
215        sql: &str,
216        params: Option<serde_json::Value>,
217    ) -> Result<RowStatsIterator> {
218        info!("query iter ext with params: {}", sql);
219        let pages = self.client.start_query(sql, true, params).await?;
220        let (schema, rows) = RestAPIRows::<RowWithStats>::from_pages(pages).await?;
221        Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
222    }
223
224    // raw data response query, only for test
225    async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
226        info!("query raw iter: {}", sql);
227        let pages = self.client.start_query(sql, true, None).await?;
228        let (schema, rows) = RestAPIRows::<RawRowWithStats>::from_pages(pages).await?;
229        Ok(RawRowIterator::new(Arc::new(schema), Box::pin(rows)))
230    }
231
    /// Forwards `data` (with its declared `size`) to the client's `upload_to_stage` for
    /// the given `stage`, converting any client error into this crate's `Error`.
    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
        self.client.upload_to_stage(stage, data, size).await?;
        Ok(())
    }
236
237    async fn load_data(
238        &self,
239        sql: &str,
240        data: Reader,
241        size: u64,
242        method: LoadMethod,
243    ) -> Result<ServerStats> {
244        let sql = sql.trim_end();
245        let sql = sql.trim_end_matches(';');
246        info!("load data: {}, size: {}, method: {method:?}", sql, size);
247        let sql_low = sql.to_lowercase();
248        let has_place_holder = sql_low.contains(LOAD_PLACEHOLDER);
249        let sql = match (self.client.capability().streaming_load, has_place_holder) {
250            (false, false) => {
251                // todo: deprecate this later
252                return self
253                    .load_data_with_options(sql, data, size, None, None)
254                    .await;
255            }
256            (false, true) => return Err(Error::BadArgument(
257                "Please upgrade your server to >= 1.2.781 to support insert from @_databend_load"
258                    .to_string(),
259            )),
260            (true, false) => {
261                format!("{sql} from @_databend_load file_format=(type=csv)")
262            }
263            (true, true) => sql.to_string(),
264        };
265
266        match method {
267            LoadMethod::Streaming => self.load_data_with_streaming(&sql, data, size).await,
268            LoadMethod::Stage => self.load_data_with_stage(&sql, data, size).await,
269        }
270    }
271
272    async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
273        info!("load file: {}, file: {:?}", sql, fp,);
274        let file = File::open(fp).await?;
275        let metadata = file.metadata().await?;
276        let size = metadata.len();
277        let data = BufReader::new(file);
278        self.load_data(sql, Box::new(data), size, method).await
279    }
280
281    async fn load_file_with_options(
282        &self,
283        sql: &str,
284        fp: &Path,
285        file_format_options: Option<BTreeMap<&str, &str>>,
286        copy_options: Option<BTreeMap<&str, &str>>,
287    ) -> Result<ServerStats> {
288        let file = File::open(fp).await?;
289        let metadata = file.metadata().await?;
290        let size = metadata.len();
291        let data = BufReader::new(file);
292        self.load_data_with_options(sql, Box::new(data), size, file_format_options, copy_options)
293            .await
294    }
295
296    async fn stream_load(
297        &self,
298        sql: &str,
299        data: Vec<Vec<&str>>,
300        method: LoadMethod,
301    ) -> Result<ServerStats> {
302        info!("stream load: {}; rows: {:?}", sql, data.len());
303        let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
304        for row in data {
305            wtr.write_record(row)
306                .map_err(|e| Error::BadArgument(e.to_string()))?;
307        }
308        let bytes = wtr.into_inner().map_err(|e| Error::IO(e.to_string()))?;
309        let size = bytes.len() as u64;
310        let reader = Box::new(std::io::Cursor::new(bytes));
311        let stats = if self.client.capability().streaming_load {
312            let sql = format!("{sql} from @_databend_load file_format = (type = csv)");
313            self.load_data(&sql, reader, size, method).await?
314        } else {
315            self.load_data_with_options(sql, reader, size, None, None)
316                .await?
317        };
318        Ok(stats)
319    }
320}
321
322impl<'o> RestAPIConnection {
323    pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
324        let client = APIClient::new(dsn, Some(name)).await?;
325        Ok(Self { client })
326    }
327
328    fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
329        vec![
330            ("type", "CSV"),
331            ("field_delimiter", ","),
332            ("record_delimiter", "\n"),
333            ("skip_header", "0"),
334        ]
335        .into_iter()
336        .collect()
337    }
338
339    fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
340        vec![("purge", "true")].into_iter().collect()
341    }
342}
343
344pub struct RestAPIRows<T> {
345    pages: Pages,
346
347    schema: SchemaRef,
348    settings: ResultFormatSettings,
349
350    data: VecDeque<Vec<Option<String>>>,
351    rows: VecDeque<Row>,
352
353    stats: Option<ServerStats>,
354
355    _phantom: std::marker::PhantomData<T>,
356}
357
358impl<T> RestAPIRows<T> {
359    async fn from_pages(pages: Pages) -> Result<(Schema, Self)> {
360        let (pages, schema, settings) = pages.wait_for_schema(true).await?;
361        let rows = Self {
362            pages,
363            schema: Arc::new(schema.clone()),
364            settings,
365            data: Default::default(),
366            rows: Default::default(),
367            stats: None,
368            _phantom: PhantomData,
369        };
370        Ok((schema, rows))
371    }
372}
373
/// Yields buffered rows interleaved with one stats item per received page.
impl<T: FromRowStats + std::marker::Unpin> Stream for RestAPIRows<T> {
    type Item = Result<T>;

    /// Produces the next item in this order of preference:
    /// 1. a pending stats item, if one is stored;
    /// 2. a buffered row, as long as more than one remains in that buffer;
    /// 3. otherwise the outcome of polling the page stream: a stats item for a new page
    ///    (after buffering its rows), an error, the remaining buffered rows once the
    ///    pages are exhausted, or `Pending`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if let Some(ss) = self.stats.take() {
            return Poll::Ready(Some(Ok(T::from_stats(ss))));
        }
        // Hold back the last buffered row and fetch the next page first.
        // That way `/final` is guaranteed to be called before the last row is yielded.
        if self.data.len() > 1 {
            if let Some(row) = self.data.pop_front() {
                let row = T::try_from_raw_row(row, self.schema.clone(), &self.settings)?;
                return Poll::Ready(Some(Ok(row)));
            }
        } else if self.rows.len() > 1 {
            if let Some(row) = self.rows.pop_front() {
                let row = T::from_row(row);
                return Poll::Ready(Some(Ok(row)));
            }
        }

        match Pin::new(&mut self.pages).poll_next(cx) {
            Poll::Ready(Some(Ok(page))) => {
                // Fill in the schema lazily if it is still empty: prefer the page's raw
                // schema, otherwise fall back to the schema of its first batch.
                if self.schema.fields().is_empty() {
                    if !page.raw_schema.is_empty() {
                        self.schema = Arc::new(page.raw_schema.try_into()?);
                    } else if !page.batches.is_empty() {
                        self.schema = Arc::new(page.batches[0].schema().clone().try_into()?);
                    }
                }
                // A page without batches contributes raw rows; otherwise each batch is
                // converted and its rows are buffered.
                if page.batches.is_empty() {
                    let mut new_data = page.data.into();
                    self.data.append(&mut new_data);
                } else {
                    for batch in page.batches.into_iter() {
                        let rows = Rows::try_from((batch, self.settings.clone()))?;
                        self.rows.extend(rows);
                    }
                }
                // Every received page is reported as a stats item; its rows are yielded
                // by later polls.
                Poll::Ready(Some(Ok(T::from_stats(page.stats.into()))))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
            Poll::Ready(None) => {
                // No more pages: drain what is left, decoded rows first, then raw rows.
                if let Some(row) = self.rows.pop_front() {
                    let row = T::from_row(row);
                    Poll::Ready(Some(Ok(row)))
                } else if let Some(row) = self.data.pop_front() {
                    let row = T::try_from_raw_row(row, self.schema.clone(), &self.settings)?;
                    Poll::Ready(Some(Ok(row)))
                } else {
                    Poll::Ready(None)
                }
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
431
432trait FromRowStats: Send + Sync + Clone {
433    fn from_stats(stats: ServerStats) -> Self;
434    fn try_from_raw_row(
435        row: Vec<Option<String>>,
436        schema: SchemaRef,
437        tz: &ResultFormatSettings,
438    ) -> Result<Self>;
439    fn from_row(row: Row) -> Self;
440}
441
442impl FromRowStats for RowWithStats {
443    fn from_stats(stats: ServerStats) -> Self {
444        RowWithStats::Stats(stats)
445    }
446
447    fn try_from_raw_row(
448        row: Vec<Option<String>>,
449        schema: SchemaRef,
450        settings: &ResultFormatSettings,
451    ) -> Result<Self> {
452        Ok(RowWithStats::Row(Row::try_from((schema, row, settings))?))
453    }
454    fn from_row(row: Row) -> Self {
455        RowWithStats::Row(row)
456    }
457}
458
459impl FromRowStats for RawRowWithStats {
460    fn from_stats(stats: ServerStats) -> Self {
461        RawRowWithStats::Stats(stats)
462    }
463
464    fn try_from_raw_row(
465        row: Vec<Option<String>>,
466        schema: SchemaRef,
467        tz: &ResultFormatSettings,
468    ) -> Result<Self> {
469        let rows = Row::try_from((schema, row.clone(), tz))?;
470        Ok(RawRowWithStats::Row(RawRow::new(rows, row)))
471    }
472
473    fn from_row(row: Row) -> Self {
474        RawRowWithStats::Row(RawRow::from(row))
475    }
476}