1use log::error;
16use once_cell::sync::Lazy;
17use std::collections::BTreeMap;
18use std::path::Path;
19use std::str::FromStr;
20use url::Url;
21
22use crate::conn::IConnection;
23#[cfg(feature = "flight-sql")]
24use crate::flight_sql::FlightSQLConnection;
25use crate::placeholder::PlaceholderVisitor;
26use crate::ConnectionInfo;
27use crate::Params;
28
29use databend_common_ast::parser::Dialect;
30use lake_client::PresignedResponse;
31use lake_driver_core::error::{Error, Result};
32use lake_driver_core::raw_rows::{RawRow, RawRowIterator};
33use lake_driver_core::rows::{Row, RowIterator, RowStatsIterator, ServerStats};
34use lake_driver_core::value::Value;
35use tokio_stream::StreamExt;
36
37use crate::rest_api::RestAPIConnection;
38
39static VERSION: Lazy<String> = Lazy::new(|| {
40 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
41 version.to_string()
42});
43
44#[derive(Clone, Copy, Debug, PartialEq)]
45pub enum LoadMethod {
46 Stage,
47 Streaming,
48}
49
50impl FromStr for LoadMethod {
51 type Err = Error;
52
53 fn from_str(s: &str) -> Result<Self, Self::Err> {
54 match s.to_lowercase().as_str() {
55 "stage" => Ok(LoadMethod::Stage),
56 "streaming" => Ok(LoadMethod::Streaming),
57 _ => Err(Error::BadArgument(format!("invalid load method: {s}"))),
58 }
59 }
60}
61
/// Factory for [`Connection`]s, configured by a DSN string and a client name.
#[derive(Clone)]
pub struct Client {
    // DSN parsed as a URL by `get_conn`; its scheme selects the transport.
    dsn: String,
    // Name handed to the transport when connecting; `new` sets it to `lake-driver-rust/<version>`.
    name: String,
}
67
68use crate::conn::Reader;
69
/// An open connection. Every operation is delegated to the boxed transport implementation
/// (REST API, or Flight SQL when the `flight-sql` feature is enabled) built by `Client::get_conn`.
/// Dropping it calls `close_with_spawn` on that implementation and logs any failure.
pub struct Connection {
    // Transport-specific implementation behind the public API.
    inner: Box<dyn IConnection>,
}
73
74impl Client {
75 pub fn new(dsn: String) -> Self {
76 let name = format!("lake-driver-rust/{}", VERSION.as_str());
77 Self { dsn, name }
78 }
79
80 pub fn with_name(mut self, name: String) -> Self {
81 self.name = name;
82 self
83 }
84
85 pub async fn get_conn(&self) -> Result<Connection> {
86 let u = Url::parse(&self.dsn)?;
87 match u.scheme() {
88 "databend" | "databend+http" | "databend+https" | "lake" | "lake+http"
89 | "lake+https" => {
90 let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?;
91 Ok(Connection {
92 inner: Box::new(conn),
93 })
94 }
95 #[cfg(feature = "flight-sql")]
96 "databend+flight" | "databend+grpc" | "lake+flight" | "lake+grpc" => {
97 let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?;
98 Ok(Connection {
99 inner: Box::new(conn),
100 })
101 }
102 _ => Err(Error::Parsing(format!(
103 "Unsupported scheme: {}",
104 u.scheme()
105 ))),
106 }
107 }
108}
109
110impl Drop for Connection {
111 fn drop(&mut self) {
112 if let Err(e) = self.inner.close_with_spawn() {
113 error!("fail to close connection: {}", e);
114 }
115 }
116}
117
118impl Connection {
119 pub fn inner(&self) -> &dyn IConnection {
120 self.inner.as_ref()
121 }
122
123 pub async fn info(&self) -> ConnectionInfo {
124 self.inner.info().await
125 }
126 pub async fn close(&self) -> Result<()> {
127 self.inner.close().await
128 }
129
130 pub fn last_query_id(&self) -> Option<String> {
131 self.inner.last_query_id()
132 }
133
134 pub async fn version(&self) -> Result<String> {
135 self.inner.version().await
136 }
137
138 pub fn format_sql<P: Into<Params> + Send>(&self, sql: &str, params: P) -> String {
139 let params = params.into();
140 params.replace(sql)
141 }
142
143 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
144 self.inner.kill_query(query_id).await
145 }
146
147 pub fn query(&self, sql: &str) -> QueryBuilder<'_> {
148 QueryBuilder::new(self, sql)
149 }
150
151 pub fn exec(&self, sql: &str) -> ExecBuilder<'_> {
152 ExecBuilder::new(self, sql)
153 }
154
155 pub async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
156 QueryBuilder::new(self, sql).iter().await
157 }
158
159 pub async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
160 QueryBuilder::new(self, sql).iter_ext().await
161 }
162
163 pub async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
164 QueryBuilder::new(self, sql).one().await
165 }
166
167 pub async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
168 QueryBuilder::new(self, sql).all().await
169 }
170
171 pub async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
173 self.inner.query_raw_iter(sql).await
174 }
175
176 pub async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
178 self.inner.query_raw_all(sql).await
179 }
180
181 pub async fn get_presigned_url(
184 &self,
185 operation: &str,
186 stage: &str,
187 ) -> Result<PresignedResponse> {
188 self.inner.get_presigned_url(operation, stage).await
189 }
190
191 pub async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
192 self.inner.upload_to_stage(stage, data, size).await
193 }
194
195 pub async fn load_data(
196 &self,
197 sql: &str,
198 data: Reader,
199 size: u64,
200 method: LoadMethod,
201 ) -> Result<ServerStats> {
202 self.inner.load_data(sql, data, size, method).await
203 }
204
205 pub async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
206 self.inner.load_file(sql, fp, method).await
207 }
208
209 pub async fn load_file_with_options(
210 &self,
211 sql: &str,
212 fp: &Path,
213 file_format_options: Option<BTreeMap<&str, &str>>,
214 copy_options: Option<BTreeMap<&str, &str>>,
215 ) -> Result<ServerStats> {
216 self.inner
217 .load_file_with_options(sql, fp, file_format_options, copy_options)
218 .await
219 }
220
221 pub async fn stream_load(
222 &self,
223 sql: &str,
224 data: Vec<Vec<&str>>,
225 method: LoadMethod,
226 ) -> Result<ServerStats> {
227 self.inner.stream_load(sql, data, method).await
228 }
229
230 pub fn set_warehouse(&self, warehouse: &str) -> Result<()> {
231 self.inner.set_warehouse(warehouse)
232 }
233
234 pub fn set_database(&self, database: &str) -> Result<()> {
235 self.inner.set_database(database)
236 }
237
238 pub fn set_role(&self, role: &str) -> Result<()> {
239 self.inner.set_role(role)
240 }
241
242 pub fn set_session(&self, key: &str, value: &str) -> Result<()> {
243 self.inner.set_session(key, value)
244 }
245
246 pub async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
248 self.inner.put_files(local_file, stage).await
249 }
250
251 pub async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
252 self.inner.get_files(stage, local_file).await
253 }
254
255 pub fn query_as<T>(&self, sql: &str) -> ORMQueryBuilder<'_, T>
257 where
258 T: TryFrom<Row> + RowORM,
259 T::Error: std::fmt::Display,
260 {
261 ORMQueryBuilder::new(self, sql)
262 }
263
264 pub async fn insert<T>(&self, table_name: &str) -> Result<InsertCursor<'_, T>>
265 where
266 T: Clone + RowORM,
267 {
268 Ok(InsertCursor::new(self, table_name.to_string()))
269 }
270}
271
272pub struct QueryCursor<T> {
273 iter: RowIterator,
274 _phantom: std::marker::PhantomData<T>,
275}
276
277impl<T> QueryCursor<T>
278where
279 T: TryFrom<Row>,
280 T::Error: std::fmt::Display,
281{
282 fn new(iter: RowIterator) -> Self {
283 Self {
284 iter,
285 _phantom: std::marker::PhantomData,
286 }
287 }
288
289 pub async fn fetch(&mut self) -> Result<Option<T>> {
290 use tokio_stream::StreamExt;
291 match self.iter.next().await {
292 Some(row) => {
293 let row = row?;
294 let typed_row = T::try_from(row).map_err(|e| Error::Parsing(e.to_string()))?;
295 Ok(Some(typed_row))
296 }
297 None => Ok(None),
298 }
299 }
300
301 pub async fn next(&mut self) -> Result<Option<T>> {
302 self.fetch().await
303 }
304
305 pub async fn fetch_all(self) -> Result<Vec<T>> {
306 self.iter.try_collect().await
307 }
308}
309
310pub struct InsertCursor<'a, T> {
311 connection: &'a Connection,
312 table_name: String,
313 rows: Vec<T>,
314 _phantom: std::marker::PhantomData<T>,
315}
316
317impl<'a, T> InsertCursor<'a, T>
318where
319 T: Clone + RowORM,
320{
321 fn new(connection: &'a Connection, table_name: String) -> Self {
322 Self {
323 connection,
324 table_name,
325 rows: Vec::new(),
326 _phantom: std::marker::PhantomData,
327 }
328 }
329
330 pub async fn write(&mut self, row: &T) -> Result<()> {
331 self.rows.push(row.clone());
332 Ok(())
333 }
334
335 pub async fn end(self) -> Result<i64> {
336 if self.rows.is_empty() {
337 return Ok(0);
338 }
339 let connection = self.connection;
340 let field_names = T::insert_field_names();
342 let field_list = field_names.join(", ");
343 let placeholder_list = (0..field_names.len())
344 .map(|_| "?")
345 .collect::<Vec<_>>()
346 .join(", ");
347
348 let sql = format!(
349 "INSERT INTO {} ({}) VALUES ({})",
350 self.table_name, field_list, placeholder_list
351 );
352
353 let mut total_inserted = 0;
354 for row in &self.rows {
355 let values = row.to_values();
356 let json_values: Vec<serde_json::Value> =
357 values.into_iter().map(|v| v.to_json_value()).collect();
358 let params = Params::QuestionParams(json_values);
359 let inserted = connection.exec(&sql).bind(params).await?;
360 total_inserted += inserted;
361 }
362
363 Ok(total_inserted)
364 }
365}
366
/// Expands every `?fields` placeholder in `sql` into `field_names` joined by `", "`.
///
/// This is a plain substring replacement. It therefore also rewrites `?fields` inside
/// string literals or longer tokens such as `?fieldsX`. SQL without the placeholder is
/// returned unchanged.
fn replace_query_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    const PLACEHOLDER: &str = "?fields";
    sql.replace(PLACEHOLDER, &field_names.join(", "))
}
372
373#[allow(dead_code)]
375fn replace_insert_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
376 let fields = field_names.join(", ");
377 sql.replace("?fields", &fields)
378}
379
/// Builder returned by `Connection::query_as`.
///
/// It expands `?fields` from `T::query_field_names()`, runs the query, and yields a
/// [`QueryCursor<T>`].
pub struct ORMQueryBuilder<'a, T> {
    // Connection the query is issued on.
    connection: &'a Connection,
    // SQL as given by the caller; may contain `?fields` and parameter placeholders.
    sql: String,
    // Parameters attached via `bind`; `execute` substitutes them into the SQL text client-side.
    params: Option<Params>,
    // Marks the row type `T` produced by the cursor.
    _phantom: std::marker::PhantomData<T>,
}
387
388impl<'a, T> ORMQueryBuilder<'a, T>
389where
390 T: TryFrom<Row> + RowORM,
391 T::Error: std::fmt::Display,
392{
393 fn new(connection: &'a Connection, sql: &str) -> Self {
394 Self {
395 connection,
396 sql: sql.to_string(),
397 params: None,
398 _phantom: std::marker::PhantomData,
399 }
400 }
401
402 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
403 self.params = Some(params.into());
404 self
405 }
406
407 pub async fn execute(self) -> Result<QueryCursor<T>> {
408 let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
409 let final_sql = if let Some(params) = self.params {
410 params.replace(&sql_with_fields)
411 } else {
412 sql_with_fields
413 };
414 let row_iter = self.connection.inner.query_iter(&final_sql).await?;
415 Ok(QueryCursor::new(row_iter))
416 }
417}
418
419impl<'a, T> std::future::IntoFuture for ORMQueryBuilder<'a, T>
420where
421 T: TryFrom<Row> + RowORM + Send + 'a,
422 T::Error: std::fmt::Display,
423{
424 type Output = Result<QueryCursor<T>>;
425 type IntoFuture =
426 std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
427
428 fn into_future(self) -> Self::IntoFuture {
429 Box::pin(self.execute())
430 }
431}
432
/// Builder returned by [`Connection::query`].
///
/// Finish it with `iter`, `iter_ext`, `one`, `all` or `cursor_as`.
pub struct QueryBuilder<'a> {
    // Connection the query is issued on.
    connection: &'a Connection,
    // SQL text as given by the caller; placeholders are not yet substituted.
    sql: String,
    // Parameters attached via `bind`, if any.
    params: Option<Params>,
}
439
440impl<'a> QueryBuilder<'a> {
441 fn new(connection: &'a Connection, sql: &str) -> Self {
442 Self {
443 connection,
444 sql: sql.to_string(),
445 params: None,
446 }
447 }
448
449 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
450 self.params = Some(params.into());
451 self
452 }
453
454 pub async fn iter(self) -> Result<RowIterator> {
455 if let Some(params) = &self.params {
456 if self.should_use_server_side_params() {
457 let json_params = params.to_json_value();
458 return self
459 .connection
460 .inner
461 .query_iter_with_params(&self.sql, Some(json_params))
462 .await;
463 }
464 }
465 let sql = self.get_final_sql();
466 self.connection.inner.query_iter(&sql).await
467 }
468
469 pub async fn iter_ext(self) -> Result<RowStatsIterator> {
470 if let Some(params) = &self.params {
471 if self.should_use_server_side_params() {
472 let json_params = params.to_json_value();
473 return self
474 .connection
475 .inner
476 .query_iter_ext_with_params(&self.sql, Some(json_params))
477 .await;
478 }
479 }
480 let sql = self.get_final_sql();
481 self.connection.inner.query_iter_ext(&sql).await
482 }
483
484 pub async fn one(self) -> Result<Option<Row>> {
485 if let Some(params) = &self.params {
486 if self.should_use_server_side_params() {
487 let json_params = params.to_json_value();
488 let mut rows = self
489 .connection
490 .inner
491 .query_iter_with_params(&self.sql, Some(json_params))
492 .await?;
493 return match rows.next().await {
494 Some(r) => Ok(Some(r?)),
495 None => Ok(None),
496 };
497 }
498 }
499 let sql = self.get_final_sql();
500 self.connection.inner.query_row(&sql).await
501 }
502
503 pub async fn all(self) -> Result<Vec<Row>> {
504 if let Some(params) = &self.params {
505 if self.should_use_server_side_params() {
506 let json_params = params.to_json_value();
507 let rows = self
508 .connection
509 .inner
510 .query_iter_with_params(&self.sql, Some(json_params))
511 .await?;
512 return rows.collect().await;
513 }
514 }
515 let sql = self.get_final_sql();
516 self.connection.inner.query_all(&sql).await
517 }
518
519 pub async fn cursor_as<T>(self) -> Result<QueryCursor<T>>
520 where
521 T: TryFrom<Row> + RowORM,
522 T::Error: std::fmt::Display,
523 {
524 let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
525 let final_sql = if let Some(params) = self.params {
526 params.replace(&sql_with_fields)
527 } else {
528 sql_with_fields
529 };
530 let row_iter = self.connection.inner.query_iter(&final_sql).await?;
531 Ok(QueryCursor::new(row_iter))
532 }
533
534 fn should_use_server_side_params(&self) -> bool {
535 self.connection.inner.supports_server_side_params()
536 && !sql_has_dollar_placeholders(&self.sql)
537 }
538
539 fn get_final_sql(&self) -> String {
540 match &self.params {
541 Some(params) => params.replace(&self.sql),
542 None => self.sql.clone(),
543 }
544 }
545}
546
/// Builder returned by [`Connection::exec`].
///
/// Await it directly, or call `execute`, to run the statement.
pub struct ExecBuilder<'a> {
    // Connection the statement is issued on.
    connection: &'a Connection,
    // Statement text as given by the caller; placeholders are not yet substituted.
    sql: String,
    // Parameters attached via `bind`, if any.
    params: Option<Params>,
}
553
554impl<'a> ExecBuilder<'a> {
555 fn new(connection: &'a Connection, sql: &str) -> Self {
556 Self {
557 connection,
558 sql: sql.to_string(),
559 params: None,
560 }
561 }
562
563 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
564 self.params = Some(params.into());
565 self
566 }
567
568 pub async fn execute(self) -> Result<i64> {
569 if let Some(ref params) = self.params {
570 if self.should_use_server_side_params() {
571 let json_params = params.to_json_value();
572 return self
573 .connection
574 .inner
575 .exec_with_params(&self.sql, Some(json_params))
576 .await;
577 }
578 }
579 let sql = match self.params {
580 Some(params) => params.replace(&self.sql),
581 None => self.sql,
582 };
583 self.connection.inner.exec(&sql).await
584 }
585
586 fn should_use_server_side_params(&self) -> bool {
587 self.connection.inner.supports_server_side_params()
588 && !sql_has_dollar_placeholders(&self.sql)
589 }
590}
591
592impl<'a> std::future::IntoFuture for ExecBuilder<'a> {
593 type Output = Result<i64>;
594 type IntoFuture =
595 std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
596
597 fn into_future(self) -> Self::IntoFuture {
598 Box::pin(self.execute())
599 }
600}
601
602fn sql_has_dollar_placeholders(sql: &str) -> bool {
603 let tokens = match databend_common_ast::parser::tokenize_sql(sql) {
604 Ok(t) => t,
605 Err(_) => return false,
606 };
607 if let Ok((stmt, _)) = databend_common_ast::parser::parse_sql(&tokens, Dialect::PostgreSQL) {
608 let mut visitor = PlaceholderVisitor::new();
609 return visitor.has_dollar_positions(&stmt);
610 }
611 false
612}
613
614pub trait RowORM: TryFrom<Row> + Clone {
616 fn field_names() -> Vec<&'static str>; fn query_field_names() -> Vec<&'static str>; fn insert_field_names() -> Vec<&'static str>; fn to_values(&self) -> Vec<Value>;
620}