databend_driver/
client.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use log::error;
16use once_cell::sync::Lazy;
17use std::collections::BTreeMap;
18use std::path::Path;
19use std::str::FromStr;
20use url::Url;
21
22use crate::conn::IConnection;
23#[cfg(feature = "flight-sql")]
24use crate::flight_sql::FlightSQLConnection;
25use crate::ConnectionInfo;
26use crate::Params;
27
28use databend_client::PresignedResponse;
29use databend_driver_core::error::{Error, Result};
30use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
31use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, ServerStats};
32use databend_driver_core::value::Value;
33
34use crate::rest_api::RestAPIConnection;
35
36static VERSION: Lazy<String> = Lazy::new(|| {
37    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
38    version.to_string()
39});
40
/// Strategy passed to [`Connection::load_data`], [`Connection::load_file`] and
/// [`Connection::stream_load`].
///
/// Can be parsed from a string via [`str::parse`]; matching is case-insensitive.
///
/// Derives `Eq` and `Hash` (equality on a fieldless enum is total), so values can be
/// used as map/set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LoadMethod {
    /// Selected by the string `"stage"`.
    Stage,
    /// Selected by the string `"streaming"`.
    Streaming,
}
46
47impl FromStr for LoadMethod {
48    type Err = Error;
49
50    fn from_str(s: &str) -> Result<Self, Self::Err> {
51        match s.to_lowercase().as_str() {
52            "stage" => Ok(LoadMethod::Stage),
53            "streaming" => Ok(LoadMethod::Streaming),
54            _ => Err(Error::BadArgument(format!("invalid load method: {s}"))),
55        }
56    }
57}
58
/// Entry point for opening [`Connection`]s from a DSN.
///
/// The DSN's URL scheme selects the transport when connecting (see `Client::get_conn`).
#[derive(Clone)]
pub struct Client {
    /// Data source name; parsed as a URL when a connection is opened.
    dsn: String,
    /// Client name handed to the underlying connection on creation.
    name: String,
}
64
65use crate::conn::Reader;
66
67pub struct Connection {
68    inner: Box<dyn IConnection>,
69}
70
71impl Client {
72    pub fn new(dsn: String) -> Self {
73        let name = format!("databend-driver-rust/{}", VERSION.as_str());
74        Self { dsn, name }
75    }
76
77    pub fn with_name(mut self, name: String) -> Self {
78        self.name = name;
79        self
80    }
81
82    pub async fn get_conn(&self) -> Result<Connection> {
83        let u = Url::parse(&self.dsn)?;
84        match u.scheme() {
85            "databend" | "databend+http" | "databend+https" => {
86                let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?;
87                Ok(Connection {
88                    inner: Box::new(conn),
89                })
90            }
91            #[cfg(feature = "flight-sql")]
92            "databend+flight" | "databend+grpc" => {
93                let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?;
94                Ok(Connection {
95                    inner: Box::new(conn),
96                })
97            }
98            _ => Err(Error::Parsing(format!(
99                "Unsupported scheme: {}",
100                u.scheme()
101            ))),
102        }
103    }
104}
105
106impl Drop for Connection {
107    fn drop(&mut self) {
108        if let Err(e) = self.inner.close_with_spawn() {
109            error!("fail to close connection: {}", e);
110        }
111    }
112}
113
114impl Connection {
115    pub fn inner(&self) -> &dyn IConnection {
116        self.inner.as_ref()
117    }
118
119    pub async fn info(&self) -> ConnectionInfo {
120        self.inner.info().await
121    }
122    pub async fn close(&self) -> Result<()> {
123        self.inner.close().await
124    }
125
126    pub fn last_query_id(&self) -> Option<String> {
127        self.inner.last_query_id()
128    }
129
130    pub async fn version(&self) -> Result<String> {
131        self.inner.version().await
132    }
133
134    pub fn format_sql<P: Into<Params> + Send>(&self, sql: &str, params: P) -> String {
135        let params = params.into();
136        params.replace(sql)
137    }
138
139    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
140        self.inner.kill_query(query_id).await
141    }
142
143    pub fn query(&self, sql: &str) -> QueryBuilder<'_> {
144        QueryBuilder::new(self, sql)
145    }
146
147    pub fn exec(&self, sql: &str) -> ExecBuilder<'_> {
148        ExecBuilder::new(self, sql)
149    }
150
151    pub async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
152        QueryBuilder::new(self, sql).iter().await
153    }
154
155    pub async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
156        QueryBuilder::new(self, sql).iter_ext().await
157    }
158
159    pub async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
160        QueryBuilder::new(self, sql).one().await
161    }
162
163    pub async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
164        QueryBuilder::new(self, sql).all().await
165    }
166
167    // raw data response query, only for test
168    pub async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
169        self.inner.query_raw_iter(sql).await
170    }
171
172    // raw data response query, only for test
173    pub async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
174        self.inner.query_raw_all(sql).await
175    }
176
177    /// Get presigned url for a given operation and stage location.
178    /// The operation can be "UPLOAD" or "DOWNLOAD".
179    pub async fn get_presigned_url(
180        &self,
181        operation: &str,
182        stage: &str,
183    ) -> Result<PresignedResponse> {
184        self.inner.get_presigned_url(operation, stage).await
185    }
186
187    pub async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
188        self.inner.upload_to_stage(stage, data, size).await
189    }
190
191    pub async fn load_data(
192        &self,
193        sql: &str,
194        data: Reader,
195        size: u64,
196        method: LoadMethod,
197    ) -> Result<ServerStats> {
198        self.inner.load_data(sql, data, size, method).await
199    }
200
201    pub async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
202        self.inner.load_file(sql, fp, method).await
203    }
204
205    pub async fn load_file_with_options(
206        &self,
207        sql: &str,
208        fp: &Path,
209        file_format_options: Option<BTreeMap<&str, &str>>,
210        copy_options: Option<BTreeMap<&str, &str>>,
211    ) -> Result<ServerStats> {
212        self.inner
213            .load_file_with_options(sql, fp, file_format_options, copy_options)
214            .await
215    }
216
217    pub async fn stream_load(
218        &self,
219        sql: &str,
220        data: Vec<Vec<&str>>,
221        method: LoadMethod,
222    ) -> Result<ServerStats> {
223        self.inner.stream_load(sql, data, method).await
224    }
225
226    // PUT file://<path_to_file>/<filename> internalStage|externalStage
227    pub async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
228        self.inner.put_files(local_file, stage).await
229    }
230
231    pub async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
232        self.inner.get_files(stage, local_file).await
233    }
234
235    // ORM Methods
236    pub fn query_as<T>(&self, sql: &str) -> ORMQueryBuilder<'_, T>
237    where
238        T: TryFrom<Row> + RowORM,
239        T::Error: std::fmt::Display,
240    {
241        ORMQueryBuilder::new(self, sql)
242    }
243
244    pub async fn insert<T>(&self, table_name: &str) -> Result<InsertCursor<'_, T>>
245    where
246        T: Clone + RowORM,
247    {
248        Ok(InsertCursor::new(self, table_name.to_string()))
249    }
250}
251
252pub struct QueryCursor<T> {
253    iter: RowIterator,
254    _phantom: std::marker::PhantomData<T>,
255}
256
257impl<T> QueryCursor<T>
258where
259    T: TryFrom<Row>,
260    T::Error: std::fmt::Display,
261{
262    fn new(iter: RowIterator) -> Self {
263        Self {
264            iter,
265            _phantom: std::marker::PhantomData,
266        }
267    }
268
269    pub async fn fetch(&mut self) -> Result<Option<T>> {
270        use tokio_stream::StreamExt;
271        match self.iter.next().await {
272            Some(row) => {
273                let row = row?;
274                let typed_row = T::try_from(row).map_err(|e| Error::Parsing(e.to_string()))?;
275                Ok(Some(typed_row))
276            }
277            None => Ok(None),
278        }
279    }
280
281    pub async fn next(&mut self) -> Result<Option<T>> {
282        self.fetch().await
283    }
284
285    pub async fn fetch_all(self) -> Result<Vec<T>> {
286        self.iter.try_collect().await
287    }
288}
289
290pub struct InsertCursor<'a, T> {
291    connection: &'a Connection,
292    table_name: String,
293    rows: Vec<T>,
294    _phantom: std::marker::PhantomData<T>,
295}
296
297impl<'a, T> InsertCursor<'a, T>
298where
299    T: Clone + RowORM,
300{
301    fn new(connection: &'a Connection, table_name: String) -> Self {
302        Self {
303            connection,
304            table_name,
305            rows: Vec::new(),
306            _phantom: std::marker::PhantomData,
307        }
308    }
309
310    pub async fn write(&mut self, row: &T) -> Result<()> {
311        self.rows.push(row.clone());
312        Ok(())
313    }
314
315    pub async fn end(self) -> Result<i64> {
316        if self.rows.is_empty() {
317            return Ok(0);
318        }
319        let connection = self.connection;
320        // Generate field names and values for INSERT (exclude skip_serializing)
321        let field_names = T::insert_field_names();
322        let field_list = field_names.join(", ");
323        let placeholder_list = (0..field_names.len())
324            .map(|_| "?")
325            .collect::<Vec<_>>()
326            .join(", ");
327
328        let sql = format!(
329            "INSERT INTO {} ({}) VALUES ({})",
330            self.table_name, field_list, placeholder_list
331        );
332
333        let mut total_inserted = 0;
334        for row in &self.rows {
335            let values = row.to_values();
336            let param_strings: Vec<String> =
337                values.into_iter().map(|v| v.to_sql_string()).collect();
338            let params = Params::QuestionParams(param_strings);
339            let inserted = connection.exec(&sql).bind(params).await?;
340            total_inserted += inserted;
341        }
342
343        Ok(total_inserted)
344    }
345}
346
/// Expands every `?fields` placeholder in `sql` into `field_names` joined by `", "`.
///
/// This is a plain substring replacement, so occurrences inside string literals are
/// replaced too.
fn replace_query_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    sql.replace("?fields", &field_names.join(", "))
}
352
/// Insert-side counterpart of `replace_query_fields_placeholder`: expands every
/// `?fields` placeholder in `sql` into `field_names` joined by `", "`.
// No caller is visible in this module, hence the `dead_code` allowance.
// NOTE(review): remove it if no insert path ends up using it.
#[allow(dead_code)]
fn replace_insert_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    let columns = field_names.join(", ");
    sql.split("?fields")
        .collect::<Vec<&str>>()
        .join(columns.as_str())
}
359
360// ORM Query Builder
361pub struct ORMQueryBuilder<'a, T> {
362    connection: &'a Connection,
363    sql: String,
364    params: Option<Params>,
365    _phantom: std::marker::PhantomData<T>,
366}
367
368impl<'a, T> ORMQueryBuilder<'a, T>
369where
370    T: TryFrom<Row> + RowORM,
371    T::Error: std::fmt::Display,
372{
373    fn new(connection: &'a Connection, sql: &str) -> Self {
374        Self {
375            connection,
376            sql: sql.to_string(),
377            params: None,
378            _phantom: std::marker::PhantomData,
379        }
380    }
381
382    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
383        self.params = Some(params.into());
384        self
385    }
386
387    pub async fn execute(self) -> Result<QueryCursor<T>> {
388        let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
389        let final_sql = if let Some(params) = self.params {
390            params.replace(&sql_with_fields)
391        } else {
392            sql_with_fields
393        };
394        let row_iter = self.connection.inner.query_iter(&final_sql).await?;
395        Ok(QueryCursor::new(row_iter))
396    }
397}
398
399impl<'a, T> std::future::IntoFuture for ORMQueryBuilder<'a, T>
400where
401    T: TryFrom<Row> + RowORM + Send + 'a,
402    T::Error: std::fmt::Display,
403{
404    type Output = Result<QueryCursor<T>>;
405    type IntoFuture =
406        std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
407
408    fn into_future(self) -> Self::IntoFuture {
409        Box::pin(self.execute())
410    }
411}
412
413// Builder pattern for query operations
414pub struct QueryBuilder<'a> {
415    connection: &'a Connection,
416    sql: String,
417    params: Option<Params>,
418}
419
420impl<'a> QueryBuilder<'a> {
421    fn new(connection: &'a Connection, sql: &str) -> Self {
422        Self {
423            connection,
424            sql: sql.to_string(),
425            params: None,
426        }
427    }
428
429    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
430        self.params = Some(params.into());
431        self
432    }
433
434    pub async fn iter(self) -> Result<RowIterator> {
435        let sql = self.get_final_sql();
436        self.connection.inner.query_iter(&sql).await
437    }
438
439    pub async fn iter_ext(self) -> Result<RowStatsIterator> {
440        let sql = self.get_final_sql();
441        self.connection.inner.query_iter_ext(&sql).await
442    }
443
444    pub async fn one(self) -> Result<Option<Row>> {
445        let sql = self.get_final_sql();
446        self.connection.inner.query_row(&sql).await
447    }
448
449    pub async fn all(self) -> Result<Vec<Row>> {
450        let sql = self.get_final_sql();
451        self.connection.inner.query_all(&sql).await
452    }
453
454    pub async fn cursor_as<T>(self) -> Result<QueryCursor<T>>
455    where
456        T: TryFrom<Row> + RowORM,
457        T::Error: std::fmt::Display,
458    {
459        let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
460        let final_sql = if let Some(params) = self.params {
461            params.replace(&sql_with_fields)
462        } else {
463            sql_with_fields
464        };
465        let row_iter = self.connection.inner.query_iter(&final_sql).await?;
466        Ok(QueryCursor::new(row_iter))
467    }
468
469    fn get_final_sql(&self) -> String {
470        match &self.params {
471            Some(params) => params.replace(&self.sql),
472            None => self.sql.clone(),
473        }
474    }
475}
476
477// Builder pattern for execution operations
478pub struct ExecBuilder<'a> {
479    connection: &'a Connection,
480    sql: String,
481    params: Option<Params>,
482}
483
484impl<'a> ExecBuilder<'a> {
485    fn new(connection: &'a Connection, sql: &str) -> Self {
486        Self {
487            connection,
488            sql: sql.to_string(),
489            params: None,
490        }
491    }
492
493    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
494        self.params = Some(params.into());
495        self
496    }
497
498    pub async fn execute(self) -> Result<i64> {
499        let sql = match self.params {
500            Some(params) => params.replace(&self.sql),
501            None => self.sql,
502        };
503        self.connection.inner.exec(&sql).await
504    }
505}
506
507impl<'a> std::future::IntoFuture for ExecBuilder<'a> {
508    type Output = Result<i64>;
509    type IntoFuture =
510        std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
511
512    fn into_future(self) -> Self::IntoFuture {
513        Box::pin(self.execute())
514    }
515}
516
517// Add trait bounds for ORM functionality
518pub trait RowORM: TryFrom<Row> + Clone {
519    fn field_names() -> Vec<&'static str>; // For backward compatibility
520    fn query_field_names() -> Vec<&'static str>; // For SELECT queries (exclude skip_deserializing)
521    fn insert_field_names() -> Vec<&'static str>; // For INSERT statements (exclude skip_serializing)
522    fn to_values(&self) -> Vec<Value>;
523}