1use log::error;
16use once_cell::sync::Lazy;
17use std::collections::BTreeMap;
18use std::path::Path;
19use std::str::FromStr;
20use url::Url;
21
22use crate::conn::IConnection;
23#[cfg(feature = "flight-sql")]
24use crate::flight_sql::FlightSQLConnection;
25use crate::ConnectionInfo;
26use crate::Params;
27
28use databend_client::PresignedResponse;
29use databend_driver_core::error::{Error, Result};
30use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
31use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, ServerStats};
32use databend_driver_core::value::Value;
33
34use crate::rest_api::RestAPIConnection;
35
36static VERSION: Lazy<String> = Lazy::new(|| {
37 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
38 version.to_string()
39});
40
/// Selects how the data-loading methods (`load_data`, `load_file`,
/// `stream_load`) hand data to the server.
///
/// The value is forwarded untouched to the underlying connection; the exact
/// semantics of each variant are defined by that implementation.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the method can be
/// used as a map/set key and in exhaustive equality contexts.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LoadMethod {
    /// Parsed from the string `stage` (case-insensitive).
    Stage,
    /// Parsed from the string `streaming` (case-insensitive).
    Streaming,
}
46
47impl FromStr for LoadMethod {
48 type Err = Error;
49
50 fn from_str(s: &str) -> Result<Self, Self::Err> {
51 match s.to_lowercase().as_str() {
52 "stage" => Ok(LoadMethod::Stage),
53 "streaming" => Ok(LoadMethod::Streaming),
54 _ => Err(Error::BadArgument(format!("invalid load method: {s}"))),
55 }
56 }
57}
58
/// Factory for [`Connection`]s described by a DSN.
///
/// Cloning copies both strings; no connection state is held here.
#[derive(Clone)]
pub struct Client {
    /// Connection string; its URL scheme selects the connection backend.
    dsn: String,
    /// Client name passed to the backend when a connection is created.
    name: String,
}
64
65use crate::conn::Reader;
66
67pub struct Connection {
68 inner: Box<dyn IConnection>,
69}
70
71impl Client {
72 pub fn new(dsn: String) -> Self {
73 let name = format!("databend-driver-rust/{}", VERSION.as_str());
74 Self { dsn, name }
75 }
76
77 pub fn with_name(mut self, name: String) -> Self {
78 self.name = name;
79 self
80 }
81
82 pub async fn get_conn(&self) -> Result<Connection> {
83 let u = Url::parse(&self.dsn)?;
84 match u.scheme() {
85 "databend" | "databend+http" | "databend+https" => {
86 let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?;
87 Ok(Connection {
88 inner: Box::new(conn),
89 })
90 }
91 #[cfg(feature = "flight-sql")]
92 "databend+flight" | "databend+grpc" => {
93 let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?;
94 Ok(Connection {
95 inner: Box::new(conn),
96 })
97 }
98 _ => Err(Error::Parsing(format!(
99 "Unsupported scheme: {}",
100 u.scheme()
101 ))),
102 }
103 }
104}
105
106impl Drop for Connection {
107 fn drop(&mut self) {
108 if let Err(e) = self.inner.close_with_spawn() {
109 error!("fail to close connection: {}", e);
110 }
111 }
112}
113
114impl Connection {
115 pub fn inner(&self) -> &dyn IConnection {
116 self.inner.as_ref()
117 }
118
119 pub async fn info(&self) -> ConnectionInfo {
120 self.inner.info().await
121 }
122 pub async fn close(&self) -> Result<()> {
123 self.inner.close().await
124 }
125
126 pub fn last_query_id(&self) -> Option<String> {
127 self.inner.last_query_id()
128 }
129
130 pub async fn version(&self) -> Result<String> {
131 self.inner.version().await
132 }
133
134 pub fn format_sql<P: Into<Params> + Send>(&self, sql: &str, params: P) -> String {
135 let params = params.into();
136 params.replace(sql)
137 }
138
139 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
140 self.inner.kill_query(query_id).await
141 }
142
143 pub fn query(&self, sql: &str) -> QueryBuilder<'_> {
144 QueryBuilder::new(self, sql)
145 }
146
147 pub fn exec(&self, sql: &str) -> ExecBuilder<'_> {
148 ExecBuilder::new(self, sql)
149 }
150
151 pub async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
152 QueryBuilder::new(self, sql).iter().await
153 }
154
155 pub async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
156 QueryBuilder::new(self, sql).iter_ext().await
157 }
158
159 pub async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
160 QueryBuilder::new(self, sql).one().await
161 }
162
163 pub async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
164 QueryBuilder::new(self, sql).all().await
165 }
166
167 pub async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
169 self.inner.query_raw_iter(sql).await
170 }
171
172 pub async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
174 self.inner.query_raw_all(sql).await
175 }
176
177 pub async fn get_presigned_url(
180 &self,
181 operation: &str,
182 stage: &str,
183 ) -> Result<PresignedResponse> {
184 self.inner.get_presigned_url(operation, stage).await
185 }
186
187 pub async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
188 self.inner.upload_to_stage(stage, data, size).await
189 }
190
191 pub async fn load_data(
192 &self,
193 sql: &str,
194 data: Reader,
195 size: u64,
196 method: LoadMethod,
197 ) -> Result<ServerStats> {
198 self.inner.load_data(sql, data, size, method).await
199 }
200
201 pub async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
202 self.inner.load_file(sql, fp, method).await
203 }
204
205 pub async fn load_file_with_options(
206 &self,
207 sql: &str,
208 fp: &Path,
209 file_format_options: Option<BTreeMap<&str, &str>>,
210 copy_options: Option<BTreeMap<&str, &str>>,
211 ) -> Result<ServerStats> {
212 self.inner
213 .load_file_with_options(sql, fp, file_format_options, copy_options)
214 .await
215 }
216
217 pub async fn stream_load(
218 &self,
219 sql: &str,
220 data: Vec<Vec<&str>>,
221 method: LoadMethod,
222 ) -> Result<ServerStats> {
223 self.inner.stream_load(sql, data, method).await
224 }
225
226 pub fn set_warehouse(&self, warehouse: &str) -> Result<()> {
227 self.inner.set_warehouse(warehouse)
228 }
229
230 pub fn set_database(&self, database: &str) -> Result<()> {
231 self.inner.set_database(database)
232 }
233
234 pub fn set_role(&self, role: &str) -> Result<()> {
235 self.inner.set_role(role)
236 }
237
238 pub fn set_session(&self, key: &str, value: &str) -> Result<()> {
239 self.inner.set_session(key, value)
240 }
241
242 pub async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
244 self.inner.put_files(local_file, stage).await
245 }
246
247 pub async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
248 self.inner.get_files(stage, local_file).await
249 }
250
251 pub fn query_as<T>(&self, sql: &str) -> ORMQueryBuilder<'_, T>
253 where
254 T: TryFrom<Row> + RowORM,
255 T::Error: std::fmt::Display,
256 {
257 ORMQueryBuilder::new(self, sql)
258 }
259
260 pub async fn insert<T>(&self, table_name: &str) -> Result<InsertCursor<'_, T>>
261 where
262 T: Clone + RowORM,
263 {
264 Ok(InsertCursor::new(self, table_name.to_string()))
265 }
266}
267
268pub struct QueryCursor<T> {
269 iter: RowIterator,
270 _phantom: std::marker::PhantomData<T>,
271}
272
273impl<T> QueryCursor<T>
274where
275 T: TryFrom<Row>,
276 T::Error: std::fmt::Display,
277{
278 fn new(iter: RowIterator) -> Self {
279 Self {
280 iter,
281 _phantom: std::marker::PhantomData,
282 }
283 }
284
285 pub async fn fetch(&mut self) -> Result<Option<T>> {
286 use tokio_stream::StreamExt;
287 match self.iter.next().await {
288 Some(row) => {
289 let row = row?;
290 let typed_row = T::try_from(row).map_err(|e| Error::Parsing(e.to_string()))?;
291 Ok(Some(typed_row))
292 }
293 None => Ok(None),
294 }
295 }
296
297 pub async fn next(&mut self) -> Result<Option<T>> {
298 self.fetch().await
299 }
300
301 pub async fn fetch_all(self) -> Result<Vec<T>> {
302 self.iter.try_collect().await
303 }
304}
305
306pub struct InsertCursor<'a, T> {
307 connection: &'a Connection,
308 table_name: String,
309 rows: Vec<T>,
310 _phantom: std::marker::PhantomData<T>,
311}
312
313impl<'a, T> InsertCursor<'a, T>
314where
315 T: Clone + RowORM,
316{
317 fn new(connection: &'a Connection, table_name: String) -> Self {
318 Self {
319 connection,
320 table_name,
321 rows: Vec::new(),
322 _phantom: std::marker::PhantomData,
323 }
324 }
325
326 pub async fn write(&mut self, row: &T) -> Result<()> {
327 self.rows.push(row.clone());
328 Ok(())
329 }
330
331 pub async fn end(self) -> Result<i64> {
332 if self.rows.is_empty() {
333 return Ok(0);
334 }
335 let connection = self.connection;
336 let field_names = T::insert_field_names();
338 let field_list = field_names.join(", ");
339 let placeholder_list = (0..field_names.len())
340 .map(|_| "?")
341 .collect::<Vec<_>>()
342 .join(", ");
343
344 let sql = format!(
345 "INSERT INTO {} ({}) VALUES ({})",
346 self.table_name, field_list, placeholder_list
347 );
348
349 let mut total_inserted = 0;
350 for row in &self.rows {
351 let values = row.to_values();
352 let param_strings: Vec<String> =
353 values.into_iter().map(|v| v.to_sql_string()).collect();
354 let params = Params::QuestionParams(param_strings);
355 let inserted = connection.exec(&sql).bind(params).await?;
356 total_inserted += inserted;
357 }
358
359 Ok(total_inserted)
360 }
361}
362
/// Replaces every `?fields` in `sql` with `field_names` joined by `", "`.
///
/// `sql` is returned unchanged (as a new `String`) when it contains no
/// placeholder.
fn replace_query_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    let column_list = field_names.join(", ");
    // Splitting on the placeholder and re-joining with the column list
    // substitutes every occurrence.
    sql.split("?fields")
        .collect::<Vec<_>>()
        .join(column_list.as_str())
}
368
/// Replaces every `?fields` in `sql` with `field_names` joined by `", "`.
// NOTE(review): no caller is visible in this part of the file, hence the
// `dead_code` allowance — confirm whether this helper is still needed.
#[allow(dead_code)]
fn replace_insert_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    sql.replace("?fields", field_names.join(", ").as_str())
}
375
376pub struct ORMQueryBuilder<'a, T> {
378 connection: &'a Connection,
379 sql: String,
380 params: Option<Params>,
381 _phantom: std::marker::PhantomData<T>,
382}
383
384impl<'a, T> ORMQueryBuilder<'a, T>
385where
386 T: TryFrom<Row> + RowORM,
387 T::Error: std::fmt::Display,
388{
389 fn new(connection: &'a Connection, sql: &str) -> Self {
390 Self {
391 connection,
392 sql: sql.to_string(),
393 params: None,
394 _phantom: std::marker::PhantomData,
395 }
396 }
397
398 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
399 self.params = Some(params.into());
400 self
401 }
402
403 pub async fn execute(self) -> Result<QueryCursor<T>> {
404 let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
405 let final_sql = if let Some(params) = self.params {
406 params.replace(&sql_with_fields)
407 } else {
408 sql_with_fields
409 };
410 let row_iter = self.connection.inner.query_iter(&final_sql).await?;
411 Ok(QueryCursor::new(row_iter))
412 }
413}
414
415impl<'a, T> std::future::IntoFuture for ORMQueryBuilder<'a, T>
416where
417 T: TryFrom<Row> + RowORM + Send + 'a,
418 T::Error: std::fmt::Display,
419{
420 type Output = Result<QueryCursor<T>>;
421 type IntoFuture =
422 std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
423
424 fn into_future(self) -> Self::IntoFuture {
425 Box::pin(self.execute())
426 }
427}
428
429pub struct QueryBuilder<'a> {
431 connection: &'a Connection,
432 sql: String,
433 params: Option<Params>,
434}
435
436impl<'a> QueryBuilder<'a> {
437 fn new(connection: &'a Connection, sql: &str) -> Self {
438 Self {
439 connection,
440 sql: sql.to_string(),
441 params: None,
442 }
443 }
444
445 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
446 self.params = Some(params.into());
447 self
448 }
449
450 pub async fn iter(self) -> Result<RowIterator> {
451 let sql = self.get_final_sql();
452 self.connection.inner.query_iter(&sql).await
453 }
454
455 pub async fn iter_ext(self) -> Result<RowStatsIterator> {
456 let sql = self.get_final_sql();
457 self.connection.inner.query_iter_ext(&sql).await
458 }
459
460 pub async fn one(self) -> Result<Option<Row>> {
461 let sql = self.get_final_sql();
462 self.connection.inner.query_row(&sql).await
463 }
464
465 pub async fn all(self) -> Result<Vec<Row>> {
466 let sql = self.get_final_sql();
467 self.connection.inner.query_all(&sql).await
468 }
469
470 pub async fn cursor_as<T>(self) -> Result<QueryCursor<T>>
471 where
472 T: TryFrom<Row> + RowORM,
473 T::Error: std::fmt::Display,
474 {
475 let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
476 let final_sql = if let Some(params) = self.params {
477 params.replace(&sql_with_fields)
478 } else {
479 sql_with_fields
480 };
481 let row_iter = self.connection.inner.query_iter(&final_sql).await?;
482 Ok(QueryCursor::new(row_iter))
483 }
484
485 fn get_final_sql(&self) -> String {
486 match &self.params {
487 Some(params) => params.replace(&self.sql),
488 None => self.sql.clone(),
489 }
490 }
491}
492
493pub struct ExecBuilder<'a> {
495 connection: &'a Connection,
496 sql: String,
497 params: Option<Params>,
498}
499
500impl<'a> ExecBuilder<'a> {
501 fn new(connection: &'a Connection, sql: &str) -> Self {
502 Self {
503 connection,
504 sql: sql.to_string(),
505 params: None,
506 }
507 }
508
509 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
510 self.params = Some(params.into());
511 self
512 }
513
514 pub async fn execute(self) -> Result<i64> {
515 let sql = match self.params {
516 Some(params) => params.replace(&self.sql),
517 None => self.sql,
518 };
519 self.connection.inner.exec(&sql).await
520 }
521}
522
523impl<'a> std::future::IntoFuture for ExecBuilder<'a> {
524 type Output = Result<i64>;
525 type IntoFuture =
526 std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
527
528 fn into_future(self) -> Self::IntoFuture {
529 Box::pin(self.execute())
530 }
531}
532
533pub trait RowORM: TryFrom<Row> + Clone {
535 fn field_names() -> Vec<&'static str>; fn query_field_names() -> Vec<&'static str>; fn insert_field_names() -> Vec<&'static str>; fn to_values(&self) -> Vec<Value>;
539}