Skip to main content

databend_driver/
client.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use log::error;
16use once_cell::sync::Lazy;
17use std::collections::BTreeMap;
18use std::path::Path;
19use std::str::FromStr;
20use url::Url;
21
22use crate::conn::IConnection;
23#[cfg(feature = "flight-sql")]
24use crate::flight_sql::FlightSQLConnection;
25use crate::placeholder::PlaceholderVisitor;
26use crate::ConnectionInfo;
27use crate::Params;
28
29use databend_client::PresignedResponse;
30use databend_common_ast::parser::Dialect;
31use databend_driver_core::error::{Error, Result};
32use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
33use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, ServerStats};
34use databend_driver_core::value::Value;
35use tokio_stream::StreamExt;
36
37use crate::rest_api::RestAPIConnection;
38
39static VERSION: Lazy<String> = Lazy::new(|| {
40    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
41    version.to_string()
42});
43
/// Selects how data is handed to the server by the `load_*` / `stream_load`
/// methods on [`Connection`].
///
/// The variants are parsed case-insensitively from the strings `"stage"` and
/// `"streaming"` by the [`FromStr`] implementation.
// `Eq` is derived alongside `PartialEq`: the enum is fieldless, so equality is
// total, and deriving it lets the type be used wherever `Eq` is required.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LoadMethod {
    /// Load through a stage.
    Stage,
    /// Load by streaming the data.
    Streaming,
}
49
50impl FromStr for LoadMethod {
51    type Err = Error;
52
53    fn from_str(s: &str) -> Result<Self, Self::Err> {
54        match s.to_lowercase().as_str() {
55            "stage" => Ok(LoadMethod::Stage),
56            "streaming" => Ok(LoadMethod::Streaming),
57            _ => Err(Error::BadArgument(format!("invalid load method: {s}"))),
58        }
59    }
60}
61
/// Entry point for opening connections from a DSN.
///
/// A `Client` only stores configuration; `get_conn` is where the DSN is
/// parsed and a connection is actually created.
#[derive(Clone)]
pub struct Client {
    /// Connection string. Its URL scheme decides which connection type
    /// `get_conn` creates.
    dsn: String,
    /// Client name handed to the connection constructors. `new` initializes
    /// it to `databend-driver-rust/<version>`.
    name: String,
}
67
68use crate::conn::Reader;
69
/// A connection handle that forwards every operation to a boxed
/// [`IConnection`] implementation.
pub struct Connection {
    /// The concrete connection chosen by `Client::get_conn`.
    inner: Box<dyn IConnection>,
}
73
74impl Client {
75    pub fn new(dsn: String) -> Self {
76        let name = format!("databend-driver-rust/{}", VERSION.as_str());
77        Self { dsn, name }
78    }
79
80    pub fn with_name(mut self, name: String) -> Self {
81        self.name = name;
82        self
83    }
84
85    pub async fn get_conn(&self) -> Result<Connection> {
86        let u = Url::parse(&self.dsn)?;
87        match u.scheme() {
88            "databend" | "databend+http" | "databend+https" => {
89                let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?;
90                Ok(Connection {
91                    inner: Box::new(conn),
92                })
93            }
94            #[cfg(feature = "flight-sql")]
95            "databend+flight" | "databend+grpc" => {
96                let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?;
97                Ok(Connection {
98                    inner: Box::new(conn),
99                })
100            }
101            _ => Err(Error::Parsing(format!(
102                "Unsupported scheme: {}",
103                u.scheme()
104            ))),
105        }
106    }
107}
108
109impl Drop for Connection {
110    fn drop(&mut self) {
111        if let Err(e) = self.inner.close_with_spawn() {
112            error!("fail to close connection: {}", e);
113        }
114    }
115}
116
/// Public API of a connection.
///
/// Apart from the builder/ORM constructors, each method here forwards its
/// arguments unchanged to the boxed [`IConnection`] and returns its result.
impl Connection {
    /// Borrows the underlying [`IConnection`] trait object.
    pub fn inner(&self) -> &dyn IConnection {
        self.inner.as_ref()
    }

    /// Returns the [`ConnectionInfo`] reported by the underlying connection.
    pub async fn info(&self) -> ConnectionInfo {
        self.inner.info().await
    }
    /// Closes the underlying connection and returns its result.
    pub async fn close(&self) -> Result<()> {
        self.inner.close().await
    }

    /// Returns the last query id reported by the underlying connection, if any.
    pub fn last_query_id(&self) -> Option<String> {
        self.inner.last_query_id()
    }

    /// Returns the version string reported by the underlying connection.
    pub async fn version(&self) -> Result<String> {
        self.inner.version().await
    }

    /// Substitutes `params` into `sql` on the client side and returns the
    /// resulting SQL text. Nothing is sent to the server.
    pub fn format_sql<P: Into<Params> + Send>(&self, sql: &str, params: P) -> String {
        let params = params.into();
        params.replace(sql)
    }

    /// Forwards a kill request for `query_id` to the underlying connection.
    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
        self.inner.kill_query(query_id).await
    }

    /// Starts a [`QueryBuilder`] for `sql`. Parameters can be attached with
    /// `bind`; nothing runs until one of the builder's async methods is awaited.
    pub fn query(&self, sql: &str) -> QueryBuilder<'_> {
        QueryBuilder::new(self, sql)
    }

    /// Starts an [`ExecBuilder`] for `sql`. The builder can be awaited
    /// directly or run with `execute`.
    pub fn exec(&self, sql: &str) -> ExecBuilder<'_> {
        ExecBuilder::new(self, sql)
    }

    /// Runs `sql` without bound parameters and returns a row iterator.
    pub async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
        QueryBuilder::new(self, sql).iter().await
    }

    /// Runs `sql` without bound parameters and returns a [`RowStatsIterator`].
    pub async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
        QueryBuilder::new(self, sql).iter_ext().await
    }

    /// Runs `sql` without bound parameters and returns at most one row.
    pub async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
        QueryBuilder::new(self, sql).one().await
    }

    /// Runs `sql` without bound parameters and collects all rows.
    pub async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
        QueryBuilder::new(self, sql).all().await
    }

    // raw data response query, only for test
    pub async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
        self.inner.query_raw_iter(sql).await
    }

    // raw data response query, only for test
    pub async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
        self.inner.query_raw_all(sql).await
    }

    /// Get presigned url for a given operation and stage location.
    /// The operation can be "UPLOAD" or "DOWNLOAD".
    pub async fn get_presigned_url(
        &self,
        operation: &str,
        stage: &str,
    ) -> Result<PresignedResponse> {
        self.inner.get_presigned_url(operation, stage).await
    }

    /// Forwards an upload of `data` to `stage` to the underlying connection.
    /// `size` is passed through unchanged (presumably the byte length of
    /// `data` -- confirm against `IConnection::upload_to_stage`).
    pub async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
        self.inner.upload_to_stage(stage, data, size).await
    }

    /// Forwards a data load for `sql` from `data` using the given
    /// [`LoadMethod`] and returns the reported [`ServerStats`].
    pub async fn load_data(
        &self,
        sql: &str,
        data: Reader,
        size: u64,
        method: LoadMethod,
    ) -> Result<ServerStats> {
        self.inner.load_data(sql, data, size, method).await
    }

    /// Forwards a load of the file at `fp` for `sql` using the given
    /// [`LoadMethod`].
    pub async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
        self.inner.load_file(sql, fp, method).await
    }

    /// Like `load_file`, but forwards optional file-format and copy option
    /// maps instead of a [`LoadMethod`].
    pub async fn load_file_with_options(
        &self,
        sql: &str,
        fp: &Path,
        file_format_options: Option<BTreeMap<&str, &str>>,
        copy_options: Option<BTreeMap<&str, &str>>,
    ) -> Result<ServerStats> {
        self.inner
            .load_file_with_options(sql, fp, file_format_options, copy_options)
            .await
    }

    /// Forwards in-memory `data` for loading with `sql`.
    /// NOTE(review): `data` looks like rows of string cells -- confirm
    /// against `IConnection::stream_load`.
    pub async fn stream_load(
        &self,
        sql: &str,
        data: Vec<Vec<&str>>,
        method: LoadMethod,
    ) -> Result<ServerStats> {
        self.inner.stream_load(sql, data, method).await
    }

    /// Forwards the warehouse name to the underlying connection.
    pub fn set_warehouse(&self, warehouse: &str) -> Result<()> {
        self.inner.set_warehouse(warehouse)
    }

    /// Forwards the database name to the underlying connection.
    pub fn set_database(&self, database: &str) -> Result<()> {
        self.inner.set_database(database)
    }

    /// Forwards the role name to the underlying connection.
    pub fn set_role(&self, role: &str) -> Result<()> {
        self.inner.set_role(role)
    }

    /// Forwards a session `key`/`value` pair to the underlying connection.
    pub fn set_session(&self, key: &str, value: &str) -> Result<()> {
        self.inner.set_session(key, value)
    }

    // PUT file://<path_to_file>/<filename> internalStage|externalStage
    pub async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
        self.inner.put_files(local_file, stage).await
    }

    /// Counterpart of `put_files`: forwards `stage` and `local_file` to the
    /// underlying connection's `get_files`.
    pub async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
        self.inner.get_files(stage, local_file).await
    }

    // ORM Methods

    /// Starts an [`ORMQueryBuilder`] for `sql` whose rows are converted to `T`.
    ///
    /// A `?fields` placeholder in `sql` is replaced with
    /// `T::query_field_names()` when the builder is executed.
    pub fn query_as<T>(&self, sql: &str) -> ORMQueryBuilder<'_, T>
    where
        T: TryFrom<Row> + RowORM,
        T::Error: std::fmt::Display,
    {
        ORMQueryBuilder::new(self, sql)
    }

    /// Creates an [`InsertCursor`] that buffers rows of type `T` for
    /// `table_name`.
    ///
    /// No I/O happens here and the result is always `Ok`; rows are sent when
    /// the cursor's `end` is awaited.
    pub async fn insert<T>(&self, table_name: &str) -> Result<InsertCursor<'_, T>>
    where
        T: Clone + RowORM,
    {
        Ok(InsertCursor::new(self, table_name.to_string()))
    }
}
270
/// Cursor over a query result that converts each [`Row`] into `T`.
pub struct QueryCursor<T> {
    /// Source of untyped rows.
    iter: RowIterator,
    /// Ties the cursor to the target row type without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
275
276impl<T> QueryCursor<T>
277where
278    T: TryFrom<Row>,
279    T::Error: std::fmt::Display,
280{
281    fn new(iter: RowIterator) -> Self {
282        Self {
283            iter,
284            _phantom: std::marker::PhantomData,
285        }
286    }
287
288    pub async fn fetch(&mut self) -> Result<Option<T>> {
289        use tokio_stream::StreamExt;
290        match self.iter.next().await {
291            Some(row) => {
292                let row = row?;
293                let typed_row = T::try_from(row).map_err(|e| Error::Parsing(e.to_string()))?;
294                Ok(Some(typed_row))
295            }
296            None => Ok(None),
297        }
298    }
299
300    pub async fn next(&mut self) -> Result<Option<T>> {
301        self.fetch().await
302    }
303
304    pub async fn fetch_all(self) -> Result<Vec<T>> {
305        self.iter.try_collect().await
306    }
307}
308
/// Buffers rows of type `T` in memory and inserts them into a table when
/// `end` is awaited.
pub struct InsertCursor<'a, T> {
    /// Connection used to run the generated `INSERT` statements.
    connection: &'a Connection,
    /// Target table name, interpolated verbatim into the generated SQL.
    table_name: String,
    /// Rows queued by `write`, in insertion order.
    rows: Vec<T>,
    /// Marker for the row type `T`.
    _phantom: std::marker::PhantomData<T>,
}
315
316impl<'a, T> InsertCursor<'a, T>
317where
318    T: Clone + RowORM,
319{
320    fn new(connection: &'a Connection, table_name: String) -> Self {
321        Self {
322            connection,
323            table_name,
324            rows: Vec::new(),
325            _phantom: std::marker::PhantomData,
326        }
327    }
328
329    pub async fn write(&mut self, row: &T) -> Result<()> {
330        self.rows.push(row.clone());
331        Ok(())
332    }
333
334    pub async fn end(self) -> Result<i64> {
335        if self.rows.is_empty() {
336            return Ok(0);
337        }
338        let connection = self.connection;
339        // Generate field names and values for INSERT (exclude skip_serializing)
340        let field_names = T::insert_field_names();
341        let field_list = field_names.join(", ");
342        let placeholder_list = (0..field_names.len())
343            .map(|_| "?")
344            .collect::<Vec<_>>()
345            .join(", ");
346
347        let sql = format!(
348            "INSERT INTO {} ({}) VALUES ({})",
349            self.table_name, field_list, placeholder_list
350        );
351
352        let mut total_inserted = 0;
353        for row in &self.rows {
354            let values = row.to_values();
355            let json_values: Vec<serde_json::Value> =
356                values.into_iter().map(|v| v.to_json_value()).collect();
357            let params = Params::QuestionParams(json_values);
358            let inserted = connection.exec(&sql).bind(params).await?;
359            total_inserted += inserted;
360        }
361
362        Ok(total_inserted)
363    }
364}
365
/// Replaces every occurrence of the `?fields` placeholder in a query with
/// the comma-separated (`", "`) list of `field_names`.
///
/// SQL without the placeholder is returned unchanged; an empty
/// `field_names` slice replaces the placeholder with an empty string.
fn replace_query_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    let column_list = field_names.join(", ");
    // Splitting on the placeholder and re-joining with the column list is
    // the same left-to-right, non-overlapping substitution as `str::replace`.
    sql.split("?fields").collect::<Vec<_>>().join(&column_list)
}
371
/// Replaces every occurrence of the `?fields` placeholder in an insert
/// statement with the comma-separated (`", "`) list of `field_names`.
// No caller is visible in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
fn replace_insert_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    sql.replace("?fields", field_names.join(", ").as_str())
}
378
// ORM Query Builder
/// Builder for a query whose result rows are converted to `T`.
///
/// Created by `Connection::query_as`; run it with `execute` or by awaiting
/// the builder directly.
pub struct ORMQueryBuilder<'a, T> {
    /// Connection the query will run on.
    connection: &'a Connection,
    /// SQL text, possibly containing a `?fields` placeholder.
    sql: String,
    /// Parameters attached with `bind`, if any.
    params: Option<Params>,
    /// Marker for the target row type `T`.
    _phantom: std::marker::PhantomData<T>,
}
386
387impl<'a, T> ORMQueryBuilder<'a, T>
388where
389    T: TryFrom<Row> + RowORM,
390    T::Error: std::fmt::Display,
391{
392    fn new(connection: &'a Connection, sql: &str) -> Self {
393        Self {
394            connection,
395            sql: sql.to_string(),
396            params: None,
397            _phantom: std::marker::PhantomData,
398        }
399    }
400
401    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
402        self.params = Some(params.into());
403        self
404    }
405
406    pub async fn execute(self) -> Result<QueryCursor<T>> {
407        let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
408        let final_sql = if let Some(params) = self.params {
409            params.replace(&sql_with_fields)
410        } else {
411            sql_with_fields
412        };
413        let row_iter = self.connection.inner.query_iter(&final_sql).await?;
414        Ok(QueryCursor::new(row_iter))
415    }
416}
417
418impl<'a, T> std::future::IntoFuture for ORMQueryBuilder<'a, T>
419where
420    T: TryFrom<Row> + RowORM + Send + 'a,
421    T::Error: std::fmt::Display,
422{
423    type Output = Result<QueryCursor<T>>;
424    type IntoFuture =
425        std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
426
427    fn into_future(self) -> Self::IntoFuture {
428        Box::pin(self.execute())
429    }
430}
431
// Builder pattern for query operations
/// Builder for a row-returning statement, created by `Connection::query`.
pub struct QueryBuilder<'a> {
    /// Connection the query will run on.
    connection: &'a Connection,
    /// SQL text as supplied by the caller.
    sql: String,
    /// Parameters attached with `bind`, if any.
    params: Option<Params>,
}
438
439impl<'a> QueryBuilder<'a> {
440    fn new(connection: &'a Connection, sql: &str) -> Self {
441        Self {
442            connection,
443            sql: sql.to_string(),
444            params: None,
445        }
446    }
447
448    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
449        self.params = Some(params.into());
450        self
451    }
452
453    pub async fn iter(self) -> Result<RowIterator> {
454        if let Some(params) = &self.params {
455            if self.should_use_server_side_params() {
456                let json_params = params.to_json_value();
457                return self
458                    .connection
459                    .inner
460                    .query_iter_with_params(&self.sql, Some(json_params))
461                    .await;
462            }
463        }
464        let sql = self.get_final_sql();
465        self.connection.inner.query_iter(&sql).await
466    }
467
468    pub async fn iter_ext(self) -> Result<RowStatsIterator> {
469        if let Some(params) = &self.params {
470            if self.should_use_server_side_params() {
471                let json_params = params.to_json_value();
472                return self
473                    .connection
474                    .inner
475                    .query_iter_ext_with_params(&self.sql, Some(json_params))
476                    .await;
477            }
478        }
479        let sql = self.get_final_sql();
480        self.connection.inner.query_iter_ext(&sql).await
481    }
482
483    pub async fn one(self) -> Result<Option<Row>> {
484        if let Some(params) = &self.params {
485            if self.should_use_server_side_params() {
486                let json_params = params.to_json_value();
487                let mut rows = self
488                    .connection
489                    .inner
490                    .query_iter_with_params(&self.sql, Some(json_params))
491                    .await?;
492                return match rows.next().await {
493                    Some(r) => Ok(Some(r?)),
494                    None => Ok(None),
495                };
496            }
497        }
498        let sql = self.get_final_sql();
499        self.connection.inner.query_row(&sql).await
500    }
501
502    pub async fn all(self) -> Result<Vec<Row>> {
503        if let Some(params) = &self.params {
504            if self.should_use_server_side_params() {
505                let json_params = params.to_json_value();
506                let rows = self
507                    .connection
508                    .inner
509                    .query_iter_with_params(&self.sql, Some(json_params))
510                    .await?;
511                return rows.collect().await;
512            }
513        }
514        let sql = self.get_final_sql();
515        self.connection.inner.query_all(&sql).await
516    }
517
518    pub async fn cursor_as<T>(self) -> Result<QueryCursor<T>>
519    where
520        T: TryFrom<Row> + RowORM,
521        T::Error: std::fmt::Display,
522    {
523        let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
524        let final_sql = if let Some(params) = self.params {
525            params.replace(&sql_with_fields)
526        } else {
527            sql_with_fields
528        };
529        let row_iter = self.connection.inner.query_iter(&final_sql).await?;
530        Ok(QueryCursor::new(row_iter))
531    }
532
533    fn should_use_server_side_params(&self) -> bool {
534        self.connection.inner.supports_server_side_params()
535            && !sql_has_dollar_placeholders(&self.sql)
536    }
537
538    fn get_final_sql(&self) -> String {
539        match &self.params {
540            Some(params) => params.replace(&self.sql),
541            None => self.sql.clone(),
542        }
543    }
544}
545
// Builder pattern for execution operations
/// Builder for a statement run via `exec`, created by `Connection::exec`.
/// Run it with `execute` or by awaiting the builder directly.
pub struct ExecBuilder<'a> {
    /// Connection the statement will run on.
    connection: &'a Connection,
    /// SQL text as supplied by the caller.
    sql: String,
    /// Parameters attached with `bind`, if any.
    params: Option<Params>,
}
552
553impl<'a> ExecBuilder<'a> {
554    fn new(connection: &'a Connection, sql: &str) -> Self {
555        Self {
556            connection,
557            sql: sql.to_string(),
558            params: None,
559        }
560    }
561
562    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
563        self.params = Some(params.into());
564        self
565    }
566
567    pub async fn execute(self) -> Result<i64> {
568        if let Some(ref params) = self.params {
569            if self.should_use_server_side_params() {
570                let json_params = params.to_json_value();
571                return self
572                    .connection
573                    .inner
574                    .exec_with_params(&self.sql, Some(json_params))
575                    .await;
576            }
577        }
578        let sql = match self.params {
579            Some(params) => params.replace(&self.sql),
580            None => self.sql,
581        };
582        self.connection.inner.exec(&sql).await
583    }
584
585    fn should_use_server_side_params(&self) -> bool {
586        self.connection.inner.supports_server_side_params()
587            && !sql_has_dollar_placeholders(&self.sql)
588    }
589}
590
591impl<'a> std::future::IntoFuture for ExecBuilder<'a> {
592    type Output = Result<i64>;
593    type IntoFuture =
594        std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
595
596    fn into_future(self) -> Self::IntoFuture {
597        Box::pin(self.execute())
598    }
599}
600
601fn sql_has_dollar_placeholders(sql: &str) -> bool {
602    let tokens = match databend_common_ast::parser::tokenize_sql(sql) {
603        Ok(t) => t,
604        Err(_) => return false,
605    };
606    if let Ok((stmt, _)) = databend_common_ast::parser::parse_sql(&tokens, Dialect::PostgreSQL) {
607        let mut visitor = PlaceholderVisitor::new();
608        return visitor.has_dollar_positions(&stmt);
609    }
610    false
611}
612
// Add trait bounds for ORM functionality
/// Mapping between a Rust type and a table row, used by the ORM-style
/// helpers in this file (`query_as`, `cursor_as`, `InsertCursor`).
pub trait RowORM: TryFrom<Row> + Clone {
    /// All field names. Kept for backward compatibility.
    fn field_names() -> Vec<&'static str>; // For backward compatibility
    /// Field names substituted for `?fields` in SELECT queries.
    fn query_field_names() -> Vec<&'static str>; // For SELECT queries (exclude skip_deserializing)
    /// Field names used as the column list of generated INSERT statements.
    fn insert_field_names() -> Vec<&'static str>; // For INSERT statements (exclude skip_serializing)
    /// Converts this value into one [`Value`] per field.
    /// NOTE(review): `InsertCursor::end` binds these positionally against
    /// `insert_field_names()`, so the order presumably has to match -- confirm
    /// with the implementors.
    fn to_values(&self) -> Vec<Value>;
}