1use log::error;
16use once_cell::sync::Lazy;
17use std::collections::BTreeMap;
18use std::path::Path;
19use std::str::FromStr;
20use url::Url;
21
22use crate::conn::IConnection;
23#[cfg(feature = "flight-sql")]
24use crate::flight_sql::FlightSQLConnection;
25use crate::placeholder::PlaceholderVisitor;
26use crate::ConnectionInfo;
27use crate::Params;
28
29use databend_client::PresignedResponse;
30use databend_common_ast::parser::Dialect;
31use databend_driver_core::error::{Error, Result};
32use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
33use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, ServerStats};
34use databend_driver_core::value::Value;
35use tokio_stream::StreamExt;
36
37use crate::rest_api::RestAPIConnection;
38
39static VERSION: Lazy<String> = Lazy::new(|| {
40 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
41 version.to_string()
42});
43
/// Strategy passed to the `load_*` / `stream_load` methods of a connection.
///
/// Parsed case-insensitively from `"stage"` or `"streaming"` via `FromStr`.
/// `Eq` and `Hash` are derived so the method can be used as a map/set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LoadMethod {
    /// Selected by the string `"stage"`.
    Stage,
    /// Selected by the string `"streaming"`.
    Streaming,
}
49
50impl FromStr for LoadMethod {
51 type Err = Error;
52
53 fn from_str(s: &str) -> Result<Self, Self::Err> {
54 match s.to_lowercase().as_str() {
55 "stage" => Ok(LoadMethod::Stage),
56 "streaming" => Ok(LoadMethod::Streaming),
57 _ => Err(Error::BadArgument(format!("invalid load method: {s}"))),
58 }
59 }
60}
61
/// Entry point for opening connections from a DSN string.
#[derive(Clone)]
pub struct Client {
    /// Connection string; its URL scheme selects the transport.
    dsn: String,
    /// Client name handed to the transport when a connection is created.
    name: String,
}
67
68use crate::conn::Reader;
69
70pub struct Connection {
71 inner: Box<dyn IConnection>,
72}
73
74impl Client {
75 pub fn new(dsn: String) -> Self {
76 let name = format!("databend-driver-rust/{}", VERSION.as_str());
77 Self { dsn, name }
78 }
79
80 pub fn with_name(mut self, name: String) -> Self {
81 self.name = name;
82 self
83 }
84
85 pub async fn get_conn(&self) -> Result<Connection> {
86 let u = Url::parse(&self.dsn)?;
87 match u.scheme() {
88 "databend" | "databend+http" | "databend+https" => {
89 let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?;
90 Ok(Connection {
91 inner: Box::new(conn),
92 })
93 }
94 #[cfg(feature = "flight-sql")]
95 "databend+flight" | "databend+grpc" => {
96 let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?;
97 Ok(Connection {
98 inner: Box::new(conn),
99 })
100 }
101 _ => Err(Error::Parsing(format!(
102 "Unsupported scheme: {}",
103 u.scheme()
104 ))),
105 }
106 }
107}
108
109impl Drop for Connection {
110 fn drop(&mut self) {
111 if let Err(e) = self.inner.close_with_spawn() {
112 error!("fail to close connection: {}", e);
113 }
114 }
115}
116
117impl Connection {
118 pub fn inner(&self) -> &dyn IConnection {
119 self.inner.as_ref()
120 }
121
122 pub async fn info(&self) -> ConnectionInfo {
123 self.inner.info().await
124 }
125 pub async fn close(&self) -> Result<()> {
126 self.inner.close().await
127 }
128
129 pub fn last_query_id(&self) -> Option<String> {
130 self.inner.last_query_id()
131 }
132
133 pub async fn version(&self) -> Result<String> {
134 self.inner.version().await
135 }
136
137 pub fn format_sql<P: Into<Params> + Send>(&self, sql: &str, params: P) -> String {
138 let params = params.into();
139 params.replace(sql)
140 }
141
142 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
143 self.inner.kill_query(query_id).await
144 }
145
146 pub fn query(&self, sql: &str) -> QueryBuilder<'_> {
147 QueryBuilder::new(self, sql)
148 }
149
150 pub fn exec(&self, sql: &str) -> ExecBuilder<'_> {
151 ExecBuilder::new(self, sql)
152 }
153
154 pub async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
155 QueryBuilder::new(self, sql).iter().await
156 }
157
158 pub async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
159 QueryBuilder::new(self, sql).iter_ext().await
160 }
161
162 pub async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
163 QueryBuilder::new(self, sql).one().await
164 }
165
166 pub async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
167 QueryBuilder::new(self, sql).all().await
168 }
169
170 pub async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
172 self.inner.query_raw_iter(sql).await
173 }
174
175 pub async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
177 self.inner.query_raw_all(sql).await
178 }
179
180 pub async fn get_presigned_url(
183 &self,
184 operation: &str,
185 stage: &str,
186 ) -> Result<PresignedResponse> {
187 self.inner.get_presigned_url(operation, stage).await
188 }
189
190 pub async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
191 self.inner.upload_to_stage(stage, data, size).await
192 }
193
194 pub async fn load_data(
195 &self,
196 sql: &str,
197 data: Reader,
198 size: u64,
199 method: LoadMethod,
200 ) -> Result<ServerStats> {
201 self.inner.load_data(sql, data, size, method).await
202 }
203
204 pub async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
205 self.inner.load_file(sql, fp, method).await
206 }
207
208 pub async fn load_file_with_options(
209 &self,
210 sql: &str,
211 fp: &Path,
212 file_format_options: Option<BTreeMap<&str, &str>>,
213 copy_options: Option<BTreeMap<&str, &str>>,
214 ) -> Result<ServerStats> {
215 self.inner
216 .load_file_with_options(sql, fp, file_format_options, copy_options)
217 .await
218 }
219
220 pub async fn stream_load(
221 &self,
222 sql: &str,
223 data: Vec<Vec<&str>>,
224 method: LoadMethod,
225 ) -> Result<ServerStats> {
226 self.inner.stream_load(sql, data, method).await
227 }
228
229 pub fn set_warehouse(&self, warehouse: &str) -> Result<()> {
230 self.inner.set_warehouse(warehouse)
231 }
232
233 pub fn set_database(&self, database: &str) -> Result<()> {
234 self.inner.set_database(database)
235 }
236
237 pub fn set_role(&self, role: &str) -> Result<()> {
238 self.inner.set_role(role)
239 }
240
241 pub fn set_session(&self, key: &str, value: &str) -> Result<()> {
242 self.inner.set_session(key, value)
243 }
244
245 pub async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
247 self.inner.put_files(local_file, stage).await
248 }
249
250 pub async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
251 self.inner.get_files(stage, local_file).await
252 }
253
254 pub fn query_as<T>(&self, sql: &str) -> ORMQueryBuilder<'_, T>
256 where
257 T: TryFrom<Row> + RowORM,
258 T::Error: std::fmt::Display,
259 {
260 ORMQueryBuilder::new(self, sql)
261 }
262
263 pub async fn insert<T>(&self, table_name: &str) -> Result<InsertCursor<'_, T>>
264 where
265 T: Clone + RowORM,
266 {
267 Ok(InsertCursor::new(self, table_name.to_string()))
268 }
269}
270
271pub struct QueryCursor<T> {
272 iter: RowIterator,
273 _phantom: std::marker::PhantomData<T>,
274}
275
276impl<T> QueryCursor<T>
277where
278 T: TryFrom<Row>,
279 T::Error: std::fmt::Display,
280{
281 fn new(iter: RowIterator) -> Self {
282 Self {
283 iter,
284 _phantom: std::marker::PhantomData,
285 }
286 }
287
288 pub async fn fetch(&mut self) -> Result<Option<T>> {
289 use tokio_stream::StreamExt;
290 match self.iter.next().await {
291 Some(row) => {
292 let row = row?;
293 let typed_row = T::try_from(row).map_err(|e| Error::Parsing(e.to_string()))?;
294 Ok(Some(typed_row))
295 }
296 None => Ok(None),
297 }
298 }
299
300 pub async fn next(&mut self) -> Result<Option<T>> {
301 self.fetch().await
302 }
303
304 pub async fn fetch_all(self) -> Result<Vec<T>> {
305 self.iter.try_collect().await
306 }
307}
308
309pub struct InsertCursor<'a, T> {
310 connection: &'a Connection,
311 table_name: String,
312 rows: Vec<T>,
313 _phantom: std::marker::PhantomData<T>,
314}
315
316impl<'a, T> InsertCursor<'a, T>
317where
318 T: Clone + RowORM,
319{
320 fn new(connection: &'a Connection, table_name: String) -> Self {
321 Self {
322 connection,
323 table_name,
324 rows: Vec::new(),
325 _phantom: std::marker::PhantomData,
326 }
327 }
328
329 pub async fn write(&mut self, row: &T) -> Result<()> {
330 self.rows.push(row.clone());
331 Ok(())
332 }
333
334 pub async fn end(self) -> Result<i64> {
335 if self.rows.is_empty() {
336 return Ok(0);
337 }
338 let connection = self.connection;
339 let field_names = T::insert_field_names();
341 let field_list = field_names.join(", ");
342 let placeholder_list = (0..field_names.len())
343 .map(|_| "?")
344 .collect::<Vec<_>>()
345 .join(", ");
346
347 let sql = format!(
348 "INSERT INTO {} ({}) VALUES ({})",
349 self.table_name, field_list, placeholder_list
350 );
351
352 let mut total_inserted = 0;
353 for row in &self.rows {
354 let values = row.to_values();
355 let json_values: Vec<serde_json::Value> =
356 values.into_iter().map(|v| v.to_json_value()).collect();
357 let params = Params::QuestionParams(json_values);
358 let inserted = connection.exec(&sql).bind(params).await?;
359 total_inserted += inserted;
360 }
361
362 Ok(total_inserted)
363 }
364}
365
/// Expands every `?fields` marker in `sql` into the comma-separated `field_names`.
///
/// Occurrences are matched left to right without overlap, exactly like `str::replace`.
fn replace_query_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    let column_list = field_names.join(", ");
    let pieces: Vec<&str> = sql.split("?fields").collect();
    pieces.join(column_list.as_str())
}
371
// Not referenced by the code visible in this file, hence the lint allowance.
#[allow(dead_code)]
/// Expands every `?fields` marker in `sql` into the comma-separated `field_names`.
fn replace_insert_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    let mut columns = String::new();
    for (position, name) in field_names.iter().enumerate() {
        if position > 0 {
            columns.push_str(", ");
        }
        columns.push_str(name);
    }
    sql.replace("?fields", &columns)
}
378
379pub struct ORMQueryBuilder<'a, T> {
381 connection: &'a Connection,
382 sql: String,
383 params: Option<Params>,
384 _phantom: std::marker::PhantomData<T>,
385}
386
387impl<'a, T> ORMQueryBuilder<'a, T>
388where
389 T: TryFrom<Row> + RowORM,
390 T::Error: std::fmt::Display,
391{
392 fn new(connection: &'a Connection, sql: &str) -> Self {
393 Self {
394 connection,
395 sql: sql.to_string(),
396 params: None,
397 _phantom: std::marker::PhantomData,
398 }
399 }
400
401 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
402 self.params = Some(params.into());
403 self
404 }
405
406 pub async fn execute(self) -> Result<QueryCursor<T>> {
407 let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
408 let final_sql = if let Some(params) = self.params {
409 params.replace(&sql_with_fields)
410 } else {
411 sql_with_fields
412 };
413 let row_iter = self.connection.inner.query_iter(&final_sql).await?;
414 Ok(QueryCursor::new(row_iter))
415 }
416}
417
418impl<'a, T> std::future::IntoFuture for ORMQueryBuilder<'a, T>
419where
420 T: TryFrom<Row> + RowORM + Send + 'a,
421 T::Error: std::fmt::Display,
422{
423 type Output = Result<QueryCursor<T>>;
424 type IntoFuture =
425 std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
426
427 fn into_future(self) -> Self::IntoFuture {
428 Box::pin(self.execute())
429 }
430}
431
432pub struct QueryBuilder<'a> {
434 connection: &'a Connection,
435 sql: String,
436 params: Option<Params>,
437}
438
439impl<'a> QueryBuilder<'a> {
440 fn new(connection: &'a Connection, sql: &str) -> Self {
441 Self {
442 connection,
443 sql: sql.to_string(),
444 params: None,
445 }
446 }
447
448 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
449 self.params = Some(params.into());
450 self
451 }
452
453 pub async fn iter(self) -> Result<RowIterator> {
454 if let Some(params) = &self.params {
455 if self.should_use_server_side_params() {
456 let json_params = params.to_json_value();
457 return self
458 .connection
459 .inner
460 .query_iter_with_params(&self.sql, Some(json_params))
461 .await;
462 }
463 }
464 let sql = self.get_final_sql();
465 self.connection.inner.query_iter(&sql).await
466 }
467
468 pub async fn iter_ext(self) -> Result<RowStatsIterator> {
469 if let Some(params) = &self.params {
470 if self.should_use_server_side_params() {
471 let json_params = params.to_json_value();
472 return self
473 .connection
474 .inner
475 .query_iter_ext_with_params(&self.sql, Some(json_params))
476 .await;
477 }
478 }
479 let sql = self.get_final_sql();
480 self.connection.inner.query_iter_ext(&sql).await
481 }
482
483 pub async fn one(self) -> Result<Option<Row>> {
484 if let Some(params) = &self.params {
485 if self.should_use_server_side_params() {
486 let json_params = params.to_json_value();
487 let mut rows = self
488 .connection
489 .inner
490 .query_iter_with_params(&self.sql, Some(json_params))
491 .await?;
492 return match rows.next().await {
493 Some(r) => Ok(Some(r?)),
494 None => Ok(None),
495 };
496 }
497 }
498 let sql = self.get_final_sql();
499 self.connection.inner.query_row(&sql).await
500 }
501
502 pub async fn all(self) -> Result<Vec<Row>> {
503 if let Some(params) = &self.params {
504 if self.should_use_server_side_params() {
505 let json_params = params.to_json_value();
506 let rows = self
507 .connection
508 .inner
509 .query_iter_with_params(&self.sql, Some(json_params))
510 .await?;
511 return rows.collect().await;
512 }
513 }
514 let sql = self.get_final_sql();
515 self.connection.inner.query_all(&sql).await
516 }
517
518 pub async fn cursor_as<T>(self) -> Result<QueryCursor<T>>
519 where
520 T: TryFrom<Row> + RowORM,
521 T::Error: std::fmt::Display,
522 {
523 let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
524 let final_sql = if let Some(params) = self.params {
525 params.replace(&sql_with_fields)
526 } else {
527 sql_with_fields
528 };
529 let row_iter = self.connection.inner.query_iter(&final_sql).await?;
530 Ok(QueryCursor::new(row_iter))
531 }
532
533 fn should_use_server_side_params(&self) -> bool {
534 self.connection.inner.supports_server_side_params()
535 && !sql_has_dollar_placeholders(&self.sql)
536 }
537
538 fn get_final_sql(&self) -> String {
539 match &self.params {
540 Some(params) => params.replace(&self.sql),
541 None => self.sql.clone(),
542 }
543 }
544}
545
546pub struct ExecBuilder<'a> {
548 connection: &'a Connection,
549 sql: String,
550 params: Option<Params>,
551}
552
553impl<'a> ExecBuilder<'a> {
554 fn new(connection: &'a Connection, sql: &str) -> Self {
555 Self {
556 connection,
557 sql: sql.to_string(),
558 params: None,
559 }
560 }
561
562 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
563 self.params = Some(params.into());
564 self
565 }
566
567 pub async fn execute(self) -> Result<i64> {
568 if let Some(ref params) = self.params {
569 if self.should_use_server_side_params() {
570 let json_params = params.to_json_value();
571 return self
572 .connection
573 .inner
574 .exec_with_params(&self.sql, Some(json_params))
575 .await;
576 }
577 }
578 let sql = match self.params {
579 Some(params) => params.replace(&self.sql),
580 None => self.sql,
581 };
582 self.connection.inner.exec(&sql).await
583 }
584
585 fn should_use_server_side_params(&self) -> bool {
586 self.connection.inner.supports_server_side_params()
587 && !sql_has_dollar_placeholders(&self.sql)
588 }
589}
590
591impl<'a> std::future::IntoFuture for ExecBuilder<'a> {
592 type Output = Result<i64>;
593 type IntoFuture =
594 std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
595
596 fn into_future(self) -> Self::IntoFuture {
597 Box::pin(self.execute())
598 }
599}
600
601fn sql_has_dollar_placeholders(sql: &str) -> bool {
602 let tokens = match databend_common_ast::parser::tokenize_sql(sql) {
603 Ok(t) => t,
604 Err(_) => return false,
605 };
606 if let Ok((stmt, _)) = databend_common_ast::parser::parse_sql(&tokens, Dialect::PostgreSQL) {
607 let mut visitor = PlaceholderVisitor::new();
608 return visitor.has_dollar_positions(&stmt);
609 }
610 false
611}
612
613pub trait RowORM: TryFrom<Row> + Clone {
615 fn field_names() -> Vec<&'static str>; fn query_field_names() -> Vec<&'static str>; fn insert_field_names() -> Vec<&'static str>; fn to_values(&self) -> Vec<Value>;
619}