databend_driver/
client.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use log::error;
16use once_cell::sync::Lazy;
17use std::collections::BTreeMap;
18use std::path::Path;
19use std::str::FromStr;
20use url::Url;
21
22use crate::conn::IConnection;
23#[cfg(feature = "flight-sql")]
24use crate::flight_sql::FlightSQLConnection;
25use crate::ConnectionInfo;
26use crate::Params;
27
28use databend_client::PresignedResponse;
29use databend_driver_core::error::{Error, Result};
30use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
31use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, ServerStats};
32use databend_driver_core::value::Value;
33
34use crate::rest_api::RestAPIConnection;
35
36static VERSION: Lazy<String> = Lazy::new(|| {
37    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
38    version.to_string()
39});
40
41#[derive(Clone, Copy, Debug, PartialEq)]
42pub enum LoadMethod {
43    Stage,
44    Streaming,
45}
46
47impl FromStr for LoadMethod {
48    type Err = Error;
49
50    fn from_str(s: &str) -> Result<Self, Self::Err> {
51        match s.to_lowercase().as_str() {
52            "stage" => Ok(LoadMethod::Stage),
53            "streaming" => Ok(LoadMethod::Streaming),
54            _ => Err(Error::BadArgument(format!("invalid load method: {s}"))),
55        }
56    }
57}
58
/// Factory for [`Connection`]s, configured with a DSN and a client name.
///
/// Cheap to clone; each call to `get_conn` opens a fresh connection.
#[derive(Clone)]
pub struct Client {
    /// Connection string; its URL scheme decides which transport `get_conn` uses.
    dsn: String,
    /// Name handed to the transport on connect; `new` sets it to
    /// `databend-driver-rust/<version>` and `with_name` overrides it.
    name: String,
}
64
65use crate::conn::Reader;
66
67pub struct Connection {
68    inner: Box<dyn IConnection>,
69}
70
71impl Client {
72    pub fn new(dsn: String) -> Self {
73        let name = format!("databend-driver-rust/{}", VERSION.as_str());
74        Self { dsn, name }
75    }
76
77    pub fn with_name(mut self, name: String) -> Self {
78        self.name = name;
79        self
80    }
81
82    pub async fn get_conn(&self) -> Result<Connection> {
83        let u = Url::parse(&self.dsn)?;
84        match u.scheme() {
85            "databend" | "databend+http" | "databend+https" => {
86                let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?;
87                Ok(Connection {
88                    inner: Box::new(conn),
89                })
90            }
91            #[cfg(feature = "flight-sql")]
92            "databend+flight" | "databend+grpc" => {
93                let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?;
94                Ok(Connection {
95                    inner: Box::new(conn),
96                })
97            }
98            _ => Err(Error::Parsing(format!(
99                "Unsupported scheme: {}",
100                u.scheme()
101            ))),
102        }
103    }
104}
105
106impl Drop for Connection {
107    fn drop(&mut self) {
108        if let Err(e) = self.inner.close_with_spawn() {
109            error!("fail to close connection: {}", e);
110        }
111    }
112}
113
114impl Connection {
115    pub fn inner(&self) -> &dyn IConnection {
116        self.inner.as_ref()
117    }
118
119    pub async fn info(&self) -> ConnectionInfo {
120        self.inner.info().await
121    }
122    pub async fn close(&self) -> Result<()> {
123        self.inner.close().await
124    }
125
126    pub fn last_query_id(&self) -> Option<String> {
127        self.inner.last_query_id()
128    }
129
130    pub async fn version(&self) -> Result<String> {
131        self.inner.version().await
132    }
133
134    pub fn format_sql<P: Into<Params> + Send>(&self, sql: &str, params: P) -> String {
135        let params = params.into();
136        params.replace(sql)
137    }
138
139    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
140        self.inner.kill_query(query_id).await
141    }
142
143    pub fn query(&self, sql: &str) -> QueryBuilder<'_> {
144        QueryBuilder::new(self, sql)
145    }
146
147    pub fn exec(&self, sql: &str) -> ExecBuilder<'_> {
148        ExecBuilder::new(self, sql)
149    }
150
151    pub async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
152        QueryBuilder::new(self, sql).iter().await
153    }
154
155    pub async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
156        QueryBuilder::new(self, sql).iter_ext().await
157    }
158
159    pub async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
160        QueryBuilder::new(self, sql).one().await
161    }
162
163    pub async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
164        QueryBuilder::new(self, sql).all().await
165    }
166
167    // raw data response query, only for test
168    pub async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
169        self.inner.query_raw_iter(sql).await
170    }
171
172    // raw data response query, only for test
173    pub async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
174        self.inner.query_raw_all(sql).await
175    }
176
177    /// Get presigned url for a given operation and stage location.
178    /// The operation can be "UPLOAD" or "DOWNLOAD".
179    pub async fn get_presigned_url(
180        &self,
181        operation: &str,
182        stage: &str,
183    ) -> Result<PresignedResponse> {
184        self.inner.get_presigned_url(operation, stage).await
185    }
186
187    pub async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
188        self.inner.upload_to_stage(stage, data, size).await
189    }
190
191    pub async fn load_data(
192        &self,
193        sql: &str,
194        data: Reader,
195        size: u64,
196        method: LoadMethod,
197    ) -> Result<ServerStats> {
198        self.inner.load_data(sql, data, size, method).await
199    }
200
201    pub async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
202        self.inner.load_file(sql, fp, method).await
203    }
204
205    pub async fn load_file_with_options(
206        &self,
207        sql: &str,
208        fp: &Path,
209        file_format_options: Option<BTreeMap<&str, &str>>,
210        copy_options: Option<BTreeMap<&str, &str>>,
211    ) -> Result<ServerStats> {
212        self.inner
213            .load_file_with_options(sql, fp, file_format_options, copy_options)
214            .await
215    }
216
217    pub async fn stream_load(
218        &self,
219        sql: &str,
220        data: Vec<Vec<&str>>,
221        method: LoadMethod,
222    ) -> Result<ServerStats> {
223        self.inner.stream_load(sql, data, method).await
224    }
225
226    pub fn set_warehouse(&self, warehouse: &str) -> Result<()> {
227        self.inner.set_warehouse(warehouse)
228    }
229
230    pub fn set_database(&self, database: &str) -> Result<()> {
231        self.inner.set_database(database)
232    }
233
234    pub fn set_role(&self, role: &str) -> Result<()> {
235        self.inner.set_role(role)
236    }
237
238    pub fn set_session(&self, key: &str, value: &str) -> Result<()> {
239        self.inner.set_session(key, value)
240    }
241
242    // PUT file://<path_to_file>/<filename> internalStage|externalStage
243    pub async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
244        self.inner.put_files(local_file, stage).await
245    }
246
247    pub async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
248        self.inner.get_files(stage, local_file).await
249    }
250
251    // ORM Methods
252    pub fn query_as<T>(&self, sql: &str) -> ORMQueryBuilder<'_, T>
253    where
254        T: TryFrom<Row> + RowORM,
255        T::Error: std::fmt::Display,
256    {
257        ORMQueryBuilder::new(self, sql)
258    }
259
260    pub async fn insert<T>(&self, table_name: &str) -> Result<InsertCursor<'_, T>>
261    where
262        T: Clone + RowORM,
263    {
264        Ok(InsertCursor::new(self, table_name.to_string()))
265    }
266}
267
268pub struct QueryCursor<T> {
269    iter: RowIterator,
270    _phantom: std::marker::PhantomData<T>,
271}
272
273impl<T> QueryCursor<T>
274where
275    T: TryFrom<Row>,
276    T::Error: std::fmt::Display,
277{
278    fn new(iter: RowIterator) -> Self {
279        Self {
280            iter,
281            _phantom: std::marker::PhantomData,
282        }
283    }
284
285    pub async fn fetch(&mut self) -> Result<Option<T>> {
286        use tokio_stream::StreamExt;
287        match self.iter.next().await {
288            Some(row) => {
289                let row = row?;
290                let typed_row = T::try_from(row).map_err(|e| Error::Parsing(e.to_string()))?;
291                Ok(Some(typed_row))
292            }
293            None => Ok(None),
294        }
295    }
296
297    pub async fn next(&mut self) -> Result<Option<T>> {
298        self.fetch().await
299    }
300
301    pub async fn fetch_all(self) -> Result<Vec<T>> {
302        self.iter.try_collect().await
303    }
304}
305
306pub struct InsertCursor<'a, T> {
307    connection: &'a Connection,
308    table_name: String,
309    rows: Vec<T>,
310    _phantom: std::marker::PhantomData<T>,
311}
312
313impl<'a, T> InsertCursor<'a, T>
314where
315    T: Clone + RowORM,
316{
317    fn new(connection: &'a Connection, table_name: String) -> Self {
318        Self {
319            connection,
320            table_name,
321            rows: Vec::new(),
322            _phantom: std::marker::PhantomData,
323        }
324    }
325
326    pub async fn write(&mut self, row: &T) -> Result<()> {
327        self.rows.push(row.clone());
328        Ok(())
329    }
330
331    pub async fn end(self) -> Result<i64> {
332        if self.rows.is_empty() {
333            return Ok(0);
334        }
335        let connection = self.connection;
336        // Generate field names and values for INSERT (exclude skip_serializing)
337        let field_names = T::insert_field_names();
338        let field_list = field_names.join(", ");
339        let placeholder_list = (0..field_names.len())
340            .map(|_| "?")
341            .collect::<Vec<_>>()
342            .join(", ");
343
344        let sql = format!(
345            "INSERT INTO {} ({}) VALUES ({})",
346            self.table_name, field_list, placeholder_list
347        );
348
349        let mut total_inserted = 0;
350        for row in &self.rows {
351            let values = row.to_values();
352            let param_strings: Vec<String> =
353                values.into_iter().map(|v| v.to_sql_string()).collect();
354            let params = Params::QuestionParams(param_strings);
355            let inserted = connection.exec(&sql).bind(params).await?;
356            total_inserted += inserted;
357        }
358
359        Ok(total_inserted)
360    }
361}
362
/// Expands every `?fields` placeholder in a query into the comma-separated `field_names`.
///
/// The substitution is purely textual. Occurrences inside string literals, or as the
/// prefix of a longer token such as `?fields_x`, are replaced too.
fn replace_query_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    sql.replace("?fields", &field_names.join(", "))
}
368
/// Expands every `?fields` placeholder in an INSERT statement into the comma-separated
/// `field_names`.
///
/// Same textual expansion as the query-side helper. Nothing in this module calls it
/// at present, so `dead_code` is allowed to keep it available without a warning.
#[allow(dead_code)]
fn replace_insert_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    let joined = field_names.join(", ");
    sql.split("?fields")
        .collect::<Vec<_>>()
        .join(joined.as_str())
}
375
376// ORM Query Builder
377pub struct ORMQueryBuilder<'a, T> {
378    connection: &'a Connection,
379    sql: String,
380    params: Option<Params>,
381    _phantom: std::marker::PhantomData<T>,
382}
383
384impl<'a, T> ORMQueryBuilder<'a, T>
385where
386    T: TryFrom<Row> + RowORM,
387    T::Error: std::fmt::Display,
388{
389    fn new(connection: &'a Connection, sql: &str) -> Self {
390        Self {
391            connection,
392            sql: sql.to_string(),
393            params: None,
394            _phantom: std::marker::PhantomData,
395        }
396    }
397
398    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
399        self.params = Some(params.into());
400        self
401    }
402
403    pub async fn execute(self) -> Result<QueryCursor<T>> {
404        let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
405        let final_sql = if let Some(params) = self.params {
406            params.replace(&sql_with_fields)
407        } else {
408            sql_with_fields
409        };
410        let row_iter = self.connection.inner.query_iter(&final_sql).await?;
411        Ok(QueryCursor::new(row_iter))
412    }
413}
414
415impl<'a, T> std::future::IntoFuture for ORMQueryBuilder<'a, T>
416where
417    T: TryFrom<Row> + RowORM + Send + 'a,
418    T::Error: std::fmt::Display,
419{
420    type Output = Result<QueryCursor<T>>;
421    type IntoFuture =
422        std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
423
424    fn into_future(self) -> Self::IntoFuture {
425        Box::pin(self.execute())
426    }
427}
428
429// Builder pattern for query operations
430pub struct QueryBuilder<'a> {
431    connection: &'a Connection,
432    sql: String,
433    params: Option<Params>,
434}
435
436impl<'a> QueryBuilder<'a> {
437    fn new(connection: &'a Connection, sql: &str) -> Self {
438        Self {
439            connection,
440            sql: sql.to_string(),
441            params: None,
442        }
443    }
444
445    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
446        self.params = Some(params.into());
447        self
448    }
449
450    pub async fn iter(self) -> Result<RowIterator> {
451        let sql = self.get_final_sql();
452        self.connection.inner.query_iter(&sql).await
453    }
454
455    pub async fn iter_ext(self) -> Result<RowStatsIterator> {
456        let sql = self.get_final_sql();
457        self.connection.inner.query_iter_ext(&sql).await
458    }
459
460    pub async fn one(self) -> Result<Option<Row>> {
461        let sql = self.get_final_sql();
462        self.connection.inner.query_row(&sql).await
463    }
464
465    pub async fn all(self) -> Result<Vec<Row>> {
466        let sql = self.get_final_sql();
467        self.connection.inner.query_all(&sql).await
468    }
469
470    pub async fn cursor_as<T>(self) -> Result<QueryCursor<T>>
471    where
472        T: TryFrom<Row> + RowORM,
473        T::Error: std::fmt::Display,
474    {
475        let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
476        let final_sql = if let Some(params) = self.params {
477            params.replace(&sql_with_fields)
478        } else {
479            sql_with_fields
480        };
481        let row_iter = self.connection.inner.query_iter(&final_sql).await?;
482        Ok(QueryCursor::new(row_iter))
483    }
484
485    fn get_final_sql(&self) -> String {
486        match &self.params {
487            Some(params) => params.replace(&self.sql),
488            None => self.sql.clone(),
489        }
490    }
491}
492
493// Builder pattern for execution operations
494pub struct ExecBuilder<'a> {
495    connection: &'a Connection,
496    sql: String,
497    params: Option<Params>,
498}
499
500impl<'a> ExecBuilder<'a> {
501    fn new(connection: &'a Connection, sql: &str) -> Self {
502        Self {
503            connection,
504            sql: sql.to_string(),
505            params: None,
506        }
507    }
508
509    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
510        self.params = Some(params.into());
511        self
512    }
513
514    pub async fn execute(self) -> Result<i64> {
515        let sql = match self.params {
516            Some(params) => params.replace(&self.sql),
517            None => self.sql,
518        };
519        self.connection.inner.exec(&sql).await
520    }
521}
522
523impl<'a> std::future::IntoFuture for ExecBuilder<'a> {
524    type Output = Result<i64>;
525    type IntoFuture =
526        std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
527
528    fn into_future(self) -> Self::IntoFuture {
529        Box::pin(self.execute())
530    }
531}
532
533// Add trait bounds for ORM functionality
534pub trait RowORM: TryFrom<Row> + Clone {
535    fn field_names() -> Vec<&'static str>; // For backward compatibility
536    fn query_field_names() -> Vec<&'static str>; // For SELECT queries (exclude skip_deserializing)
537    fn insert_field_names() -> Vec<&'static str>; // For INSERT statements (exclude skip_serializing)
538    fn to_values(&self) -> Vec<Value>;
539}