databend_driver/
client.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use once_cell::sync::Lazy;
16use std::collections::BTreeMap;
17use std::path::Path;
18use std::str::FromStr;
19use url::Url;
20
21use crate::conn::IConnection;
22#[cfg(feature = "flight-sql")]
23use crate::flight_sql::FlightSQLConnection;
24use crate::ConnectionInfo;
25use crate::Params;
26
27use databend_client::PresignedResponse;
28use databend_driver_core::error::{Error, Result};
29use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
30use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, ServerStats};
31use databend_driver_core::value::Value;
32
33use crate::rest_api::RestAPIConnection;
34
35static VERSION: Lazy<String> = Lazy::new(|| {
36    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
37    version.to_string()
38});
39
40#[derive(Clone, Copy, Debug, PartialEq)]
41pub enum LoadMethod {
42    Stage,
43    Streaming,
44}
45
46impl FromStr for LoadMethod {
47    type Err = Error;
48
49    fn from_str(s: &str) -> Result<Self, Self::Err> {
50        match s.to_lowercase().as_str() {
51            "stage" => Ok(LoadMethod::Stage),
52            "streaming" => Ok(LoadMethod::Streaming),
53            _ => Err(Error::BadArgument(format!("invalid load method: {s}"))),
54        }
55    }
56}
57
/// Factory for [`Connection`]s; holds the DSN and the client name sent on connect.
#[derive(Clone)]
pub struct Client {
    /// Connection string; its URL scheme selects the transport in `get_conn`.
    dsn: String,
    /// Client identifier passed to the connection constructors.
    name: String,
}
63
64use crate::conn::Reader;
65
66pub struct Connection {
67    inner: Box<dyn IConnection>,
68}
69
70impl Client {
71    pub fn new(dsn: String) -> Self {
72        let name = format!("databend-driver-rust/{}", VERSION.as_str());
73        Self { dsn, name }
74    }
75
76    pub fn with_name(mut self, name: String) -> Self {
77        self.name = name;
78        self
79    }
80
81    pub async fn get_conn(&self) -> Result<Connection> {
82        let u = Url::parse(&self.dsn)?;
83        match u.scheme() {
84            "databend" | "databend+http" | "databend+https" => {
85                let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?;
86                Ok(Connection {
87                    inner: Box::new(conn),
88                })
89            }
90            #[cfg(feature = "flight-sql")]
91            "databend+flight" | "databend+grpc" => {
92                let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?;
93                Ok(Connection {
94                    inner: Box::new(conn),
95                })
96            }
97            _ => Err(Error::Parsing(format!(
98                "Unsupported scheme: {}",
99                u.scheme()
100            ))),
101        }
102    }
103}
104
105impl Connection {
106    pub fn inner(&self) -> &dyn IConnection {
107        self.inner.as_ref()
108    }
109
110    pub async fn info(&self) -> ConnectionInfo {
111        self.inner.info().await
112    }
113    pub async fn close(&self) -> Result<()> {
114        self.inner.close().await
115    }
116
117    pub fn last_query_id(&self) -> Option<String> {
118        self.inner.last_query_id()
119    }
120
121    pub async fn version(&self) -> Result<String> {
122        self.inner.version().await
123    }
124
125    pub fn format_sql<P: Into<Params> + Send>(&self, sql: &str, params: P) -> String {
126        let params = params.into();
127        params.replace(sql)
128    }
129
130    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
131        self.inner.kill_query(query_id).await
132    }
133
134    pub fn query(&self, sql: &str) -> QueryBuilder<'_> {
135        QueryBuilder::new(self, sql)
136    }
137
138    pub fn exec(&self, sql: &str) -> ExecBuilder<'_> {
139        ExecBuilder::new(self, sql)
140    }
141
142    pub async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
143        QueryBuilder::new(self, sql).iter().await
144    }
145
146    pub async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
147        QueryBuilder::new(self, sql).iter_ext().await
148    }
149
150    pub async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
151        QueryBuilder::new(self, sql).one().await
152    }
153
154    pub async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
155        QueryBuilder::new(self, sql).all().await
156    }
157
158    // raw data response query, only for test
159    pub async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
160        self.inner.query_raw_iter(sql).await
161    }
162
163    // raw data response query, only for test
164    pub async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
165        self.inner.query_raw_all(sql).await
166    }
167
168    /// Get presigned url for a given operation and stage location.
169    /// The operation can be "UPLOAD" or "DOWNLOAD".
170    pub async fn get_presigned_url(
171        &self,
172        operation: &str,
173        stage: &str,
174    ) -> Result<PresignedResponse> {
175        self.inner.get_presigned_url(operation, stage).await
176    }
177
178    pub async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
179        self.inner.upload_to_stage(stage, data, size).await
180    }
181
182    pub async fn load_data(
183        &self,
184        sql: &str,
185        data: Reader,
186        size: u64,
187        method: LoadMethod,
188    ) -> Result<ServerStats> {
189        self.inner.load_data(sql, data, size, method).await
190    }
191
192    pub async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
193        self.inner.load_file(sql, fp, method).await
194    }
195
196    pub async fn load_file_with_options(
197        &self,
198        sql: &str,
199        fp: &Path,
200        file_format_options: Option<BTreeMap<&str, &str>>,
201        copy_options: Option<BTreeMap<&str, &str>>,
202    ) -> Result<ServerStats> {
203        self.inner
204            .load_file_with_options(sql, fp, file_format_options, copy_options)
205            .await
206    }
207
208    pub async fn stream_load(
209        &self,
210        sql: &str,
211        data: Vec<Vec<&str>>,
212        method: LoadMethod,
213    ) -> Result<ServerStats> {
214        self.inner.stream_load(sql, data, method).await
215    }
216
217    // PUT file://<path_to_file>/<filename> internalStage|externalStage
218    pub async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
219        self.inner.put_files(local_file, stage).await
220    }
221
222    pub async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
223        self.inner.get_files(stage, local_file).await
224    }
225
226    // ORM Methods
227    pub fn query_as<T>(&self, sql: &str) -> ORMQueryBuilder<'_, T>
228    where
229        T: TryFrom<Row> + RowORM,
230        T::Error: std::fmt::Display,
231    {
232        ORMQueryBuilder::new(self, sql)
233    }
234
235    pub async fn insert<T>(&self, table_name: &str) -> Result<InsertCursor<'_, T>>
236    where
237        T: Clone + RowORM,
238    {
239        Ok(InsertCursor::new(self, table_name.to_string()))
240    }
241}
242
243pub struct QueryCursor<T> {
244    iter: RowIterator,
245    _phantom: std::marker::PhantomData<T>,
246}
247
248impl<T> QueryCursor<T>
249where
250    T: TryFrom<Row>,
251    T::Error: std::fmt::Display,
252{
253    fn new(iter: RowIterator) -> Self {
254        Self {
255            iter,
256            _phantom: std::marker::PhantomData,
257        }
258    }
259
260    pub async fn fetch(&mut self) -> Result<Option<T>> {
261        use tokio_stream::StreamExt;
262        match self.iter.next().await {
263            Some(row) => {
264                let row = row?;
265                let typed_row = T::try_from(row).map_err(|e| Error::Parsing(e.to_string()))?;
266                Ok(Some(typed_row))
267            }
268            None => Ok(None),
269        }
270    }
271
272    pub async fn next(&mut self) -> Result<Option<T>> {
273        self.fetch().await
274    }
275
276    pub async fn fetch_all(self) -> Result<Vec<T>> {
277        self.iter.try_collect().await
278    }
279}
280
281pub struct InsertCursor<'a, T> {
282    connection: &'a Connection,
283    table_name: String,
284    rows: Vec<T>,
285    _phantom: std::marker::PhantomData<T>,
286}
287
288impl<'a, T> InsertCursor<'a, T>
289where
290    T: Clone + RowORM,
291{
292    fn new(connection: &'a Connection, table_name: String) -> Self {
293        Self {
294            connection,
295            table_name,
296            rows: Vec::new(),
297            _phantom: std::marker::PhantomData,
298        }
299    }
300
301    pub async fn write(&mut self, row: &T) -> Result<()> {
302        self.rows.push(row.clone());
303        Ok(())
304    }
305
306    pub async fn end(self) -> Result<i64> {
307        if self.rows.is_empty() {
308            return Ok(0);
309        }
310        let connection = self.connection;
311        // Generate field names and values for INSERT (exclude skip_serializing)
312        let field_names = T::insert_field_names();
313        let field_list = field_names.join(", ");
314        let placeholder_list = (0..field_names.len())
315            .map(|_| "?")
316            .collect::<Vec<_>>()
317            .join(", ");
318
319        let sql = format!(
320            "INSERT INTO {} ({}) VALUES ({})",
321            self.table_name, field_list, placeholder_list
322        );
323
324        let mut total_inserted = 0;
325        for row in &self.rows {
326            let values = row.to_values();
327            let param_strings: Vec<String> =
328                values.into_iter().map(|v| v.to_sql_string()).collect();
329            let params = Params::QuestionParams(param_strings);
330            let inserted = connection.exec(&sql).bind(params).await?;
331            total_inserted += inserted;
332        }
333
334        Ok(total_inserted)
335    }
336}
337
/// Placeholder that the ORM helpers expand into a comma-separated column list.
const FIELDS_PLACEHOLDER: &str = "?fields";

/// Replaces every `?fields` occurrence in `sql` with `field_names` joined by `", "`.
///
/// When `sql` has no placeholder it is returned unchanged, and the column list is not built.
fn replace_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    if !sql.contains(FIELDS_PLACEHOLDER) {
        return sql.to_owned();
    }
    sql.replace(FIELDS_PLACEHOLDER, &field_names.join(", "))
}

/// Expands `?fields` for SELECT-style queries (callers pass `T::query_field_names()`).
fn replace_query_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    replace_fields_placeholder(sql, field_names)
}

/// Expands `?fields` for INSERT statements.
// Not called anywhere in the visible module, so `dead_code` is silenced. It is kept as the
// insert-side counterpart of `replace_query_fields_placeholder`.
#[allow(dead_code)]
fn replace_insert_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    replace_fields_placeholder(sql, field_names)
}
350
351// ORM Query Builder
352pub struct ORMQueryBuilder<'a, T> {
353    connection: &'a Connection,
354    sql: String,
355    params: Option<Params>,
356    _phantom: std::marker::PhantomData<T>,
357}
358
359impl<'a, T> ORMQueryBuilder<'a, T>
360where
361    T: TryFrom<Row> + RowORM,
362    T::Error: std::fmt::Display,
363{
364    fn new(connection: &'a Connection, sql: &str) -> Self {
365        Self {
366            connection,
367            sql: sql.to_string(),
368            params: None,
369            _phantom: std::marker::PhantomData,
370        }
371    }
372
373    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
374        self.params = Some(params.into());
375        self
376    }
377
378    pub async fn execute(self) -> Result<QueryCursor<T>> {
379        let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
380        let final_sql = if let Some(params) = self.params {
381            params.replace(&sql_with_fields)
382        } else {
383            sql_with_fields
384        };
385        let row_iter = self.connection.inner.query_iter(&final_sql).await?;
386        Ok(QueryCursor::new(row_iter))
387    }
388}
389
390impl<'a, T> std::future::IntoFuture for ORMQueryBuilder<'a, T>
391where
392    T: TryFrom<Row> + RowORM + Send + 'a,
393    T::Error: std::fmt::Display,
394{
395    type Output = Result<QueryCursor<T>>;
396    type IntoFuture =
397        std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
398
399    fn into_future(self) -> Self::IntoFuture {
400        Box::pin(self.execute())
401    }
402}
403
404// Builder pattern for query operations
405pub struct QueryBuilder<'a> {
406    connection: &'a Connection,
407    sql: String,
408    params: Option<Params>,
409}
410
411impl<'a> QueryBuilder<'a> {
412    fn new(connection: &'a Connection, sql: &str) -> Self {
413        Self {
414            connection,
415            sql: sql.to_string(),
416            params: None,
417        }
418    }
419
420    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
421        self.params = Some(params.into());
422        self
423    }
424
425    pub async fn iter(self) -> Result<RowIterator> {
426        let sql = self.get_final_sql();
427        self.connection.inner.query_iter(&sql).await
428    }
429
430    pub async fn iter_ext(self) -> Result<RowStatsIterator> {
431        let sql = self.get_final_sql();
432        self.connection.inner.query_iter_ext(&sql).await
433    }
434
435    pub async fn one(self) -> Result<Option<Row>> {
436        let sql = self.get_final_sql();
437        self.connection.inner.query_row(&sql).await
438    }
439
440    pub async fn all(self) -> Result<Vec<Row>> {
441        let sql = self.get_final_sql();
442        self.connection.inner.query_all(&sql).await
443    }
444
445    pub async fn cursor_as<T>(self) -> Result<QueryCursor<T>>
446    where
447        T: TryFrom<Row> + RowORM,
448        T::Error: std::fmt::Display,
449    {
450        let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
451        let final_sql = if let Some(params) = self.params {
452            params.replace(&sql_with_fields)
453        } else {
454            sql_with_fields
455        };
456        let row_iter = self.connection.inner.query_iter(&final_sql).await?;
457        Ok(QueryCursor::new(row_iter))
458    }
459
460    fn get_final_sql(&self) -> String {
461        match &self.params {
462            Some(params) => params.replace(&self.sql),
463            None => self.sql.clone(),
464        }
465    }
466}
467
468// Builder pattern for execution operations
469pub struct ExecBuilder<'a> {
470    connection: &'a Connection,
471    sql: String,
472    params: Option<Params>,
473}
474
475impl<'a> ExecBuilder<'a> {
476    fn new(connection: &'a Connection, sql: &str) -> Self {
477        Self {
478            connection,
479            sql: sql.to_string(),
480            params: None,
481        }
482    }
483
484    pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
485        self.params = Some(params.into());
486        self
487    }
488
489    pub async fn execute(self) -> Result<i64> {
490        let sql = match self.params {
491            Some(params) => params.replace(&self.sql),
492            None => self.sql,
493        };
494        self.connection.inner.exec(&sql).await
495    }
496}
497
498impl<'a> std::future::IntoFuture for ExecBuilder<'a> {
499    type Output = Result<i64>;
500    type IntoFuture =
501        std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
502
503    fn into_future(self) -> Self::IntoFuture {
504        Box::pin(self.execute())
505    }
506}
507
508// Add trait bounds for ORM functionality
509pub trait RowORM: TryFrom<Row> + Clone {
510    fn field_names() -> Vec<&'static str>; // For backward compatibility
511    fn query_field_names() -> Vec<&'static str>; // For SELECT queries (exclude skip_deserializing)
512    fn insert_field_names() -> Vec<&'static str>; // For INSERT statements (exclude skip_serializing)
513    fn to_values(&self) -> Vec<Value>;
514}