Skip to main content

lake_driver/
client.rs

1// Copyright 2025 TiDB Cloud
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use log::error;
16use once_cell::sync::Lazy;
17use std::collections::BTreeMap;
18use std::path::Path;
19use std::str::FromStr;
20use url::Url;
21
22use crate::conn::IConnection;
23#[cfg(feature = "flight-sql")]
24use crate::flight_sql::FlightSQLConnection;
25use crate::placeholder::PlaceholderVisitor;
26use crate::ConnectionInfo;
27use crate::Params;
28
29use databend_common_ast::parser::Dialect;
30use lake_client::PresignedResponse;
31use lake_driver_core::error::{Error, Result};
32use lake_driver_core::raw_rows::{RawRow, RawRowIterator};
33use lake_driver_core::rows::{Row, RowIterator, RowStatsIterator, ServerStats};
34use lake_driver_core::value::Value;
35use tokio_stream::StreamExt;
36
37use crate::rest_api::RestAPIConnection;
38
39static VERSION: Lazy<String> = Lazy::new(|| {
40    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
41    version.to_string()
42});
43
/// Selects how data is loaded; passed through to the connection's
/// `load_data`, `load_file` and `stream_load` calls.
///
/// Parsed case-insensitively from `"stage"` / `"streaming"` via `FromStr`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LoadMethod {
    /// Parsed from the string `"stage"`.
    Stage,
    /// Parsed from the string `"streaming"`.
    Streaming,
}
49
50impl FromStr for LoadMethod {
51    type Err = Error;
52
53    fn from_str(s: &str) -> Result<Self, Self::Err> {
54        match s.to_lowercase().as_str() {
55            "stage" => Ok(LoadMethod::Stage),
56            "streaming" => Ok(LoadMethod::Streaming),
57            _ => Err(Error::BadArgument(format!("invalid load method: {s}"))),
58        }
59    }
60}
61
/// Entry point for opening connections from a DSN.
///
/// Stores the DSN string and a client name; `get_conn` hands both to the
/// connection constructor chosen by the DSN's URL scheme.
#[derive(Clone)]
pub struct Client {
    // Connection string; parsed as a URL in `get_conn` to read its scheme.
    dsn: String,
    // Client name passed to the connection constructor; `new` defaults it to
    // `lake-driver-rust/<version>` and `with_name` overrides it.
    name: String,
}
67
68use crate::conn::Reader;
69
/// A live connection that wraps a boxed [`IConnection`] implementation.
///
/// All methods on this type forward to the wrapped object or build helper
/// builders/cursors around it.
pub struct Connection {
    // Protocol-specific implementation selected by `Client::get_conn`.
    inner: Box<dyn IConnection>,
}
73
74impl Client {
75    pub fn new(dsn: String) -> Self {
76        let name = format!("lake-driver-rust/{}", VERSION.as_str());
77        Self { dsn, name }
78    }
79
80    pub fn with_name(mut self, name: String) -> Self {
81        self.name = name;
82        self
83    }
84
85    pub async fn get_conn(&self) -> Result<Connection> {
86        let u = Url::parse(&self.dsn)?;
87        match u.scheme() {
88            "databend" | "databend+http" | "databend+https" | "lake" | "lake+http"
89            | "lake+https" => {
90                let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?;
91                Ok(Connection {
92                    inner: Box::new(conn),
93                })
94            }
95            #[cfg(feature = "flight-sql")]
96            "databend+flight" | "databend+grpc" | "lake+flight" | "lake+grpc" => {
97                let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?;
98                Ok(Connection {
99                    inner: Box::new(conn),
100                })
101            }
102            _ => Err(Error::Parsing(format!(
103                "Unsupported scheme: {}",
104                u.scheme()
105            ))),
106        }
107    }
108}
109
110impl Drop for Connection {
111    fn drop(&mut self) {
112        if let Err(e) = self.inner.close_with_spawn() {
113            error!("fail to close connection: {}", e);
114        }
115    }
116}
117
/// Public API of [`Connection`].
///
/// Nearly every method forwards straight to the wrapped [`IConnection`];
/// the query/exec entry points build a [`QueryBuilder`] / [`ExecBuilder`]
/// so parameters can be bound before the statement runs.
impl Connection {
    /// Returns a reference to the wrapped [`IConnection`] trait object.
    pub fn inner(&self) -> &dyn IConnection {
        self.inner.as_ref()
    }

    /// Returns the [`ConnectionInfo`] reported by the wrapped connection.
    pub async fn info(&self) -> ConnectionInfo {
        self.inner.info().await
    }
    /// Forwards to the wrapped connection's `close`.
    pub async fn close(&self) -> Result<()> {
        self.inner.close().await
    }

    /// Returns the wrapped connection's last query id, if it reports one.
    pub fn last_query_id(&self) -> Option<String> {
        self.inner.last_query_id()
    }

    /// Forwards to the wrapped connection's `version`.
    pub async fn version(&self) -> Result<String> {
        self.inner.version().await
    }

    /// Substitutes `params` into `sql` with `Params::replace` and returns the
    /// resulting text. Does not use the connection.
    pub fn format_sql<P: Into<Params> + Send>(&self, sql: &str, params: P) -> String {
        let params = params.into();
        params.replace(sql)
    }

    /// Forwards `query_id` to the wrapped connection's `kill_query`.
    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
        self.inner.kill_query(query_id).await
    }

    /// Starts a [`QueryBuilder`] for `sql`; nothing is sent until one of the
    /// builder's terminal methods is awaited.
    pub fn query(&self, sql: &str) -> QueryBuilder<'_> {
        QueryBuilder::new(self, sql)
    }

    /// Starts an [`ExecBuilder`] for `sql`; nothing is sent until the builder
    /// is executed or awaited.
    pub fn exec(&self, sql: &str) -> ExecBuilder<'_> {
        ExecBuilder::new(self, sql)
    }

    /// Shorthand for a parameterless [`QueryBuilder`] followed by `iter()`.
    pub async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
        QueryBuilder::new(self, sql).iter().await
    }

    /// Shorthand for a parameterless [`QueryBuilder`] followed by `iter_ext()`.
    pub async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
        QueryBuilder::new(self, sql).iter_ext().await
    }

    /// Shorthand for a parameterless [`QueryBuilder`] followed by `one()`.
    pub async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
        QueryBuilder::new(self, sql).one().await
    }

    /// Shorthand for a parameterless [`QueryBuilder`] followed by `all()`.
    pub async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
        QueryBuilder::new(self, sql).all().await
    }

    /// Raw-data response query returning an iterator; intended for tests only.
    pub async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
        self.inner.query_raw_iter(sql).await
    }

    /// Raw-data response query returning all rows; intended for tests only.
    pub async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
        self.inner.query_raw_all(sql).await
    }

    /// Get presigned url for a given operation and stage location.
    /// The operation can be "UPLOAD" or "DOWNLOAD".
    pub async fn get_presigned_url(
        &self,
        operation: &str,
        stage: &str,
    ) -> Result<PresignedResponse> {
        self.inner.get_presigned_url(operation, stage).await
    }

    /// Forwards to the wrapped connection's `upload_to_stage`.
    ///
    /// NOTE(review): `size` is presumably the byte length of `data` — confirm
    /// against the `IConnection` implementations.
    pub async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
        self.inner.upload_to_stage(stage, data, size).await
    }

    /// Forwards to the wrapped connection's `load_data` and returns the
    /// [`ServerStats`] it reports.
    pub async fn load_data(
        &self,
        sql: &str,
        data: Reader,
        size: u64,
        method: LoadMethod,
    ) -> Result<ServerStats> {
        self.inner.load_data(sql, data, size, method).await
    }

    /// Forwards to the wrapped connection's `load_file` for the file at `fp`.
    pub async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
        self.inner.load_file(sql, fp, method).await
    }

    /// Forwards to the wrapped connection's `load_file_with_options`, passing
    /// the optional file-format and copy option maps through unchanged.
    pub async fn load_file_with_options(
        &self,
        sql: &str,
        fp: &Path,
        file_format_options: Option<BTreeMap<&str, &str>>,
        copy_options: Option<BTreeMap<&str, &str>>,
    ) -> Result<ServerStats> {
        self.inner
            .load_file_with_options(sql, fp, file_format_options, copy_options)
            .await
    }

    /// Forwards to the wrapped connection's `stream_load`.
    ///
    /// NOTE(review): `data` looks like rows of string cells — confirm the
    /// expected layout against the `IConnection` implementations.
    pub async fn stream_load(
        &self,
        sql: &str,
        data: Vec<Vec<&str>>,
        method: LoadMethod,
    ) -> Result<ServerStats> {
        self.inner.stream_load(sql, data, method).await
    }

    /// Forwards to the wrapped connection's `set_warehouse`.
    pub fn set_warehouse(&self, warehouse: &str) -> Result<()> {
        self.inner.set_warehouse(warehouse)
    }

    /// Forwards to the wrapped connection's `set_database`.
    pub fn set_database(&self, database: &str) -> Result<()> {
        self.inner.set_database(database)
    }

    /// Forwards to the wrapped connection's `set_role`.
    pub fn set_role(&self, role: &str) -> Result<()> {
        self.inner.set_role(role)
    }

    /// Forwards a session `key`/`value` pair to the wrapped connection's
    /// `set_session`.
    pub fn set_session(&self, key: &str, value: &str) -> Result<()> {
        self.inner.set_session(key, value)
    }

    /// Forwards to the wrapped connection's `put_files`.
    ///
    /// Corresponds to: `PUT file://<path_to_file>/<filename> internalStage|externalStage`
    pub async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
        self.inner.put_files(local_file, stage).await
    }

    /// Forwards to the wrapped connection's `get_files`.
    pub async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
        self.inner.get_files(stage, local_file).await
    }

    // ORM Methods

    /// Starts an [`ORMQueryBuilder`] for `sql` whose rows will be converted
    /// into `T`; nothing is sent until the builder is executed or awaited.
    pub fn query_as<T>(&self, sql: &str) -> ORMQueryBuilder<'_, T>
    where
        T: TryFrom<Row> + RowORM,
        T::Error: std::fmt::Display,
    {
        ORMQueryBuilder::new(self, sql)
    }

    /// Creates an [`InsertCursor`] for `table_name`.
    ///
    /// No statement is sent here and this always returns `Ok`; rows are
    /// buffered by the cursor and inserted when its `end` method runs.
    pub async fn insert<T>(&self, table_name: &str) -> Result<InsertCursor<'_, T>>
    where
        T: Clone + RowORM,
    {
        Ok(InsertCursor::new(self, table_name.to_string()))
    }
}
271
/// Typed cursor over a [`RowIterator`]: each fetched row is converted into `T`.
pub struct QueryCursor<T> {
    // Underlying untyped row stream.
    iter: RowIterator,
    // Ties the cursor to the target row type without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
276
277impl<T> QueryCursor<T>
278where
279    T: TryFrom<Row>,
280    T::Error: std::fmt::Display,
281{
282    fn new(iter: RowIterator) -> Self {
283        Self {
284            iter,
285            _phantom: std::marker::PhantomData,
286        }
287    }
288
289    pub async fn fetch(&mut self) -> Result<Option<T>> {
290        use tokio_stream::StreamExt;
291        match self.iter.next().await {
292            Some(row) => {
293                let row = row?;
294                let typed_row = T::try_from(row).map_err(|e| Error::Parsing(e.to_string()))?;
295                Ok(Some(typed_row))
296            }
297            None => Ok(None),
298        }
299    }
300
301    pub async fn next(&mut self) -> Result<Option<T>> {
302        self.fetch().await
303    }
304
305    pub async fn fetch_all(self) -> Result<Vec<T>> {
306        self.iter.try_collect().await
307    }
308}
309
/// Buffers rows of type `T` and inserts them into a table when `end` is called.
pub struct InsertCursor<'a, T> {
    // Connection used to run the generated INSERT statements.
    connection: &'a Connection,
    // Target table name, interpolated verbatim into the INSERT statement.
    table_name: String,
    // Rows queued by `write`, in insertion order.
    rows: Vec<T>,
    _phantom: std::marker::PhantomData<T>,
}
316
317impl<'a, T> InsertCursor<'a, T>
318where
319    T: Clone + RowORM,
320{
321    fn new(connection: &'a Connection, table_name: String) -> Self {
322        Self {
323            connection,
324            table_name,
325            rows: Vec::new(),
326            _phantom: std::marker::PhantomData,
327        }
328    }
329
330    pub async fn write(&mut self, row: &T) -> Result<()> {
331        self.rows.push(row.clone());
332        Ok(())
333    }
334
335    pub async fn end(self) -> Result<i64> {
336        if self.rows.is_empty() {
337            return Ok(0);
338        }
339        let connection = self.connection;
340        // Generate field names and values for INSERT (exclude skip_serializing)
341        let field_names = T::insert_field_names();
342        let field_list = field_names.join(", ");
343        let placeholder_list = (0..field_names.len())
344            .map(|_| "?")
345            .collect::<Vec<_>>()
346            .join(", ");
347
348        let sql = format!(
349            "INSERT INTO {} ({}) VALUES ({})",
350            self.table_name, field_list, placeholder_list
351        );
352
353        let mut total_inserted = 0;
354        for row in &self.rows {
355            let values = row.to_values();
356            let json_values: Vec<serde_json::Value> =
357                values.into_iter().map(|v| v.to_json_value()).collect();
358            let params = Params::QuestionParams(json_values);
359            let inserted = connection.exec(&sql).bind(params).await?;
360            total_inserted += inserted;
361        }
362
363        Ok(total_inserted)
364    }
365}
366
/// Expands every `?fields` occurrence in a query into the comma-separated
/// list of `field_names`. SQL without the placeholder is returned unchanged.
fn replace_query_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    sql.replace("?fields", &field_names.join(", "))
}
372
/// Expands every `?fields` occurrence in an insert statement into the
/// comma-separated list of `field_names`.
// Kept even though nothing in this file calls it, hence the lint allowance.
#[allow(dead_code)]
fn replace_insert_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    sql.replace("?fields", &field_names.join(", "))
}
379
/// ORM query builder: holds a SQL template plus optional parameters and, when
/// executed, yields a [`QueryCursor`] over rows converted into `T`.
pub struct ORMQueryBuilder<'a, T> {
    // Connection the query will run on.
    connection: &'a Connection,
    // SQL text; may contain the `?fields` placeholder expanded at execution time.
    sql: String,
    // Parameters set by `bind`; `None` until then.
    params: Option<Params>,
    _phantom: std::marker::PhantomData<T>,
}
387
388impl<'a, T> ORMQueryBuilder<'a, T>
389where
390    T: TryFrom<Row> + RowORM,
391    T::Error: std::fmt::Display,
392{
393    fn new(connection: &'a Connection, sql: &str) -> Self {
394        Self {
395            connection,
396            sql: sql.to_string(),
397            params: None,
398            _phantom: std::marker::PhantomData,
399        }
400    }
401
402    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
403        self.params = Some(params.into());
404        self
405    }
406
407    pub async fn execute(self) -> Result<QueryCursor<T>> {
408        let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
409        let final_sql = if let Some(params) = self.params {
410            params.replace(&sql_with_fields)
411        } else {
412            sql_with_fields
413        };
414        let row_iter = self.connection.inner.query_iter(&final_sql).await?;
415        Ok(QueryCursor::new(row_iter))
416    }
417}
418
419impl<'a, T> std::future::IntoFuture for ORMQueryBuilder<'a, T>
420where
421    T: TryFrom<Row> + RowORM + Send + 'a,
422    T::Error: std::fmt::Display,
423{
424    type Output = Result<QueryCursor<T>>;
425    type IntoFuture =
426        std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
427
428    fn into_future(self) -> Self::IntoFuture {
429        Box::pin(self.execute())
430    }
431}
432
/// Builder for query operations: holds the SQL text and optional bound
/// parameters until one of its terminal methods sends the statement.
pub struct QueryBuilder<'a> {
    // Connection the query will run on.
    connection: &'a Connection,
    // SQL text as supplied by the caller.
    sql: String,
    // Parameters set by `bind`; `None` until then.
    params: Option<Params>,
}
439
440impl<'a> QueryBuilder<'a> {
441    fn new(connection: &'a Connection, sql: &str) -> Self {
442        Self {
443            connection,
444            sql: sql.to_string(),
445            params: None,
446        }
447    }
448
449    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
450        self.params = Some(params.into());
451        self
452    }
453
454    pub async fn iter(self) -> Result<RowIterator> {
455        if let Some(params) = &self.params {
456            if self.should_use_server_side_params() {
457                let json_params = params.to_json_value();
458                return self
459                    .connection
460                    .inner
461                    .query_iter_with_params(&self.sql, Some(json_params))
462                    .await;
463            }
464        }
465        let sql = self.get_final_sql();
466        self.connection.inner.query_iter(&sql).await
467    }
468
469    pub async fn iter_ext(self) -> Result<RowStatsIterator> {
470        if let Some(params) = &self.params {
471            if self.should_use_server_side_params() {
472                let json_params = params.to_json_value();
473                return self
474                    .connection
475                    .inner
476                    .query_iter_ext_with_params(&self.sql, Some(json_params))
477                    .await;
478            }
479        }
480        let sql = self.get_final_sql();
481        self.connection.inner.query_iter_ext(&sql).await
482    }
483
484    pub async fn one(self) -> Result<Option<Row>> {
485        if let Some(params) = &self.params {
486            if self.should_use_server_side_params() {
487                let json_params = params.to_json_value();
488                let mut rows = self
489                    .connection
490                    .inner
491                    .query_iter_with_params(&self.sql, Some(json_params))
492                    .await?;
493                return match rows.next().await {
494                    Some(r) => Ok(Some(r?)),
495                    None => Ok(None),
496                };
497            }
498        }
499        let sql = self.get_final_sql();
500        self.connection.inner.query_row(&sql).await
501    }
502
503    pub async fn all(self) -> Result<Vec<Row>> {
504        if let Some(params) = &self.params {
505            if self.should_use_server_side_params() {
506                let json_params = params.to_json_value();
507                let rows = self
508                    .connection
509                    .inner
510                    .query_iter_with_params(&self.sql, Some(json_params))
511                    .await?;
512                return rows.collect().await;
513            }
514        }
515        let sql = self.get_final_sql();
516        self.connection.inner.query_all(&sql).await
517    }
518
519    pub async fn cursor_as<T>(self) -> Result<QueryCursor<T>>
520    where
521        T: TryFrom<Row> + RowORM,
522        T::Error: std::fmt::Display,
523    {
524        let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
525        let final_sql = if let Some(params) = self.params {
526            params.replace(&sql_with_fields)
527        } else {
528            sql_with_fields
529        };
530        let row_iter = self.connection.inner.query_iter(&final_sql).await?;
531        Ok(QueryCursor::new(row_iter))
532    }
533
534    fn should_use_server_side_params(&self) -> bool {
535        self.connection.inner.supports_server_side_params()
536            && !sql_has_dollar_placeholders(&self.sql)
537    }
538
539    fn get_final_sql(&self) -> String {
540        match &self.params {
541            Some(params) => params.replace(&self.sql),
542            None => self.sql.clone(),
543        }
544    }
545}
546
/// Builder for execution operations: holds the SQL text and optional bound
/// parameters until `execute` (or awaiting the builder) sends the statement.
pub struct ExecBuilder<'a> {
    // Connection the statement will run on.
    connection: &'a Connection,
    // SQL text as supplied by the caller.
    sql: String,
    // Parameters set by `bind`; `None` until then.
    params: Option<Params>,
}
553
554impl<'a> ExecBuilder<'a> {
555    fn new(connection: &'a Connection, sql: &str) -> Self {
556        Self {
557            connection,
558            sql: sql.to_string(),
559            params: None,
560        }
561    }
562
563    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
564        self.params = Some(params.into());
565        self
566    }
567
568    pub async fn execute(self) -> Result<i64> {
569        if let Some(ref params) = self.params {
570            if self.should_use_server_side_params() {
571                let json_params = params.to_json_value();
572                return self
573                    .connection
574                    .inner
575                    .exec_with_params(&self.sql, Some(json_params))
576                    .await;
577            }
578        }
579        let sql = match self.params {
580            Some(params) => params.replace(&self.sql),
581            None => self.sql,
582        };
583        self.connection.inner.exec(&sql).await
584    }
585
586    fn should_use_server_side_params(&self) -> bool {
587        self.connection.inner.supports_server_side_params()
588            && !sql_has_dollar_placeholders(&self.sql)
589    }
590}
591
592impl<'a> std::future::IntoFuture for ExecBuilder<'a> {
593    type Output = Result<i64>;
594    type IntoFuture =
595        std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
596
597    fn into_future(self) -> Self::IntoFuture {
598        Box::pin(self.execute())
599    }
600}
601
602fn sql_has_dollar_placeholders(sql: &str) -> bool {
603    let tokens = match databend_common_ast::parser::tokenize_sql(sql) {
604        Ok(t) => t,
605        Err(_) => return false,
606    };
607    if let Ok((stmt, _)) = databend_common_ast::parser::parse_sql(&tokens, Dialect::PostgreSQL) {
608        let mut visitor = PlaceholderVisitor::new();
609        return visitor.has_dollar_positions(&stmt);
610    }
611    false
612}
613
/// Trait bounds for ORM functionality: a row type that can be built from a
/// [`Row`], cloned, and described as column names and values.
pub trait RowORM: TryFrom<Row> + Clone {
    /// Field names; kept for backward compatibility.
    fn field_names() -> Vec<&'static str>;
    /// Field names for SELECT queries (exclude skip_deserializing).
    fn query_field_names() -> Vec<&'static str>;
    /// Field names for INSERT statements (exclude skip_serializing).
    fn insert_field_names() -> Vec<&'static str>;
    /// Values of this row used as INSERT parameters.
    ///
    /// NOTE(review): presumably ordered to match `insert_field_names()` —
    /// confirm against the derive/implementations.
    fn to_values(&self) -> Vec<Value>;
}