1use log::error;
16use once_cell::sync::Lazy;
17use std::collections::BTreeMap;
18use std::path::Path;
19use std::str::FromStr;
20use url::Url;
21
22use crate::conn::IConnection;
23#[cfg(feature = "flight-sql")]
24use crate::flight_sql::FlightSQLConnection;
25use crate::ConnectionInfo;
26use crate::Params;
27
28use databend_client::PresignedResponse;
29use databend_driver_core::error::{Error, Result};
30use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
31use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, ServerStats};
32use databend_driver_core::value::Value;
33
34use crate::rest_api::RestAPIConnection;
35
36static VERSION: Lazy<String> = Lazy::new(|| {
37 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
38 version.to_string()
39});
40
/// Strategy accepted by the `load_data` / `load_file` / `stream_load` helpers.
///
/// NOTE(review): `Stage` presumably uploads to a stage before loading and
/// `Streaming` sends the data inline — confirm against the `IConnection`
/// implementations.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so a method can be used as a
/// set/map key; both variants are field-less, so equality is total.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LoadMethod {
    Stage,
    Streaming,
}
46
47impl FromStr for LoadMethod {
48 type Err = Error;
49
50 fn from_str(s: &str) -> Result<Self, Self::Err> {
51 match s.to_lowercase().as_str() {
52 "stage" => Ok(LoadMethod::Stage),
53 "streaming" => Ok(LoadMethod::Streaming),
54 _ => Err(Error::BadArgument(format!("invalid load method: {s}"))),
55 }
56 }
57}
58
/// Entry point for opening connections from a DSN.
///
/// Deliberately not `Debug`: `dsn` may embed credentials.
#[derive(Clone)]
pub struct Client {
    /// Connection string; its URL scheme selects the transport in `get_conn`.
    dsn: String,
    /// Client name handed to the transport when a connection is created.
    name: String,
}
64
65use crate::conn::Reader;
66
67pub struct Connection {
68 inner: Box<dyn IConnection>,
69}
70
71impl Client {
72 pub fn new(dsn: String) -> Self {
73 let name = format!("databend-driver-rust/{}", VERSION.as_str());
74 Self { dsn, name }
75 }
76
77 pub fn with_name(mut self, name: String) -> Self {
78 self.name = name;
79 self
80 }
81
82 pub async fn get_conn(&self) -> Result<Connection> {
83 let u = Url::parse(&self.dsn)?;
84 match u.scheme() {
85 "databend" | "databend+http" | "databend+https" => {
86 let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?;
87 Ok(Connection {
88 inner: Box::new(conn),
89 })
90 }
91 #[cfg(feature = "flight-sql")]
92 "databend+flight" | "databend+grpc" => {
93 let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?;
94 Ok(Connection {
95 inner: Box::new(conn),
96 })
97 }
98 _ => Err(Error::Parsing(format!(
99 "Unsupported scheme: {}",
100 u.scheme()
101 ))),
102 }
103 }
104}
105
106impl Drop for Connection {
107 fn drop(&mut self) {
108 if let Err(e) = self.inner.close_with_spawn() {
109 error!("fail to close connection: {}", e);
110 }
111 }
112}
113
114impl Connection {
115 pub fn inner(&self) -> &dyn IConnection {
116 self.inner.as_ref()
117 }
118
119 pub async fn info(&self) -> ConnectionInfo {
120 self.inner.info().await
121 }
122 pub async fn close(&self) -> Result<()> {
123 self.inner.close().await
124 }
125
126 pub fn last_query_id(&self) -> Option<String> {
127 self.inner.last_query_id()
128 }
129
130 pub async fn version(&self) -> Result<String> {
131 self.inner.version().await
132 }
133
134 pub fn format_sql<P: Into<Params> + Send>(&self, sql: &str, params: P) -> String {
135 let params = params.into();
136 params.replace(sql)
137 }
138
139 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
140 self.inner.kill_query(query_id).await
141 }
142
143 pub fn query(&self, sql: &str) -> QueryBuilder<'_> {
144 QueryBuilder::new(self, sql)
145 }
146
147 pub fn exec(&self, sql: &str) -> ExecBuilder<'_> {
148 ExecBuilder::new(self, sql)
149 }
150
151 pub async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
152 QueryBuilder::new(self, sql).iter().await
153 }
154
155 pub async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
156 QueryBuilder::new(self, sql).iter_ext().await
157 }
158
159 pub async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
160 QueryBuilder::new(self, sql).one().await
161 }
162
163 pub async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
164 QueryBuilder::new(self, sql).all().await
165 }
166
167 pub async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
169 self.inner.query_raw_iter(sql).await
170 }
171
172 pub async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
174 self.inner.query_raw_all(sql).await
175 }
176
177 pub async fn get_presigned_url(
180 &self,
181 operation: &str,
182 stage: &str,
183 ) -> Result<PresignedResponse> {
184 self.inner.get_presigned_url(operation, stage).await
185 }
186
187 pub async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
188 self.inner.upload_to_stage(stage, data, size).await
189 }
190
191 pub async fn load_data(
192 &self,
193 sql: &str,
194 data: Reader,
195 size: u64,
196 method: LoadMethod,
197 ) -> Result<ServerStats> {
198 self.inner.load_data(sql, data, size, method).await
199 }
200
201 pub async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
202 self.inner.load_file(sql, fp, method).await
203 }
204
205 pub async fn load_file_with_options(
206 &self,
207 sql: &str,
208 fp: &Path,
209 file_format_options: Option<BTreeMap<&str, &str>>,
210 copy_options: Option<BTreeMap<&str, &str>>,
211 ) -> Result<ServerStats> {
212 self.inner
213 .load_file_with_options(sql, fp, file_format_options, copy_options)
214 .await
215 }
216
217 pub async fn stream_load(
218 &self,
219 sql: &str,
220 data: Vec<Vec<&str>>,
221 method: LoadMethod,
222 ) -> Result<ServerStats> {
223 self.inner.stream_load(sql, data, method).await
224 }
225
226 pub async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
228 self.inner.put_files(local_file, stage).await
229 }
230
231 pub async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
232 self.inner.get_files(stage, local_file).await
233 }
234
235 pub fn query_as<T>(&self, sql: &str) -> ORMQueryBuilder<'_, T>
237 where
238 T: TryFrom<Row> + RowORM,
239 T::Error: std::fmt::Display,
240 {
241 ORMQueryBuilder::new(self, sql)
242 }
243
244 pub async fn insert<T>(&self, table_name: &str) -> Result<InsertCursor<'_, T>>
245 where
246 T: Clone + RowORM,
247 {
248 Ok(InsertCursor::new(self, table_name.to_string()))
249 }
250}
251
252pub struct QueryCursor<T> {
253 iter: RowIterator,
254 _phantom: std::marker::PhantomData<T>,
255}
256
257impl<T> QueryCursor<T>
258where
259 T: TryFrom<Row>,
260 T::Error: std::fmt::Display,
261{
262 fn new(iter: RowIterator) -> Self {
263 Self {
264 iter,
265 _phantom: std::marker::PhantomData,
266 }
267 }
268
269 pub async fn fetch(&mut self) -> Result<Option<T>> {
270 use tokio_stream::StreamExt;
271 match self.iter.next().await {
272 Some(row) => {
273 let row = row?;
274 let typed_row = T::try_from(row).map_err(|e| Error::Parsing(e.to_string()))?;
275 Ok(Some(typed_row))
276 }
277 None => Ok(None),
278 }
279 }
280
281 pub async fn next(&mut self) -> Result<Option<T>> {
282 self.fetch().await
283 }
284
285 pub async fn fetch_all(self) -> Result<Vec<T>> {
286 self.iter.try_collect().await
287 }
288}
289
290pub struct InsertCursor<'a, T> {
291 connection: &'a Connection,
292 table_name: String,
293 rows: Vec<T>,
294 _phantom: std::marker::PhantomData<T>,
295}
296
297impl<'a, T> InsertCursor<'a, T>
298where
299 T: Clone + RowORM,
300{
301 fn new(connection: &'a Connection, table_name: String) -> Self {
302 Self {
303 connection,
304 table_name,
305 rows: Vec::new(),
306 _phantom: std::marker::PhantomData,
307 }
308 }
309
310 pub async fn write(&mut self, row: &T) -> Result<()> {
311 self.rows.push(row.clone());
312 Ok(())
313 }
314
315 pub async fn end(self) -> Result<i64> {
316 if self.rows.is_empty() {
317 return Ok(0);
318 }
319 let connection = self.connection;
320 let field_names = T::insert_field_names();
322 let field_list = field_names.join(", ");
323 let placeholder_list = (0..field_names.len())
324 .map(|_| "?")
325 .collect::<Vec<_>>()
326 .join(", ");
327
328 let sql = format!(
329 "INSERT INTO {} ({}) VALUES ({})",
330 self.table_name, field_list, placeholder_list
331 );
332
333 let mut total_inserted = 0;
334 for row in &self.rows {
335 let values = row.to_values();
336 let param_strings: Vec<String> =
337 values.into_iter().map(|v| v.to_sql_string()).collect();
338 let params = Params::QuestionParams(param_strings);
339 let inserted = connection.exec(&sql).bind(params).await?;
340 total_inserted += inserted;
341 }
342
343 Ok(total_inserted)
344 }
345}
346
/// Expands every `?fields` placeholder in `sql` into `field_names` joined by `", "`.
///
/// The ORM query paths call this before parameter substitution, so the field list
/// is spliced in ahead of any `?` parameter replacement.
fn replace_query_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    sql.replace("?fields", &field_names.join(", "))
}
352
/// Expands every `?fields` placeholder in an `INSERT` template into
/// `field_names` joined by `", "`.
///
/// Not called at present (`InsertCursor::end` builds its column list directly),
/// hence the `dead_code` allowance; kept for parity with the query variant.
#[allow(dead_code)]
fn replace_insert_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    let mut column_list = String::new();
    for (position, name) in field_names.iter().enumerate() {
        if position > 0 {
            column_list.push_str(", ");
        }
        column_list.push_str(name);
    }
    sql.replace("?fields", &column_list)
}
359
360pub struct ORMQueryBuilder<'a, T> {
362 connection: &'a Connection,
363 sql: String,
364 params: Option<Params>,
365 _phantom: std::marker::PhantomData<T>,
366}
367
368impl<'a, T> ORMQueryBuilder<'a, T>
369where
370 T: TryFrom<Row> + RowORM,
371 T::Error: std::fmt::Display,
372{
373 fn new(connection: &'a Connection, sql: &str) -> Self {
374 Self {
375 connection,
376 sql: sql.to_string(),
377 params: None,
378 _phantom: std::marker::PhantomData,
379 }
380 }
381
382 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
383 self.params = Some(params.into());
384 self
385 }
386
387 pub async fn execute(self) -> Result<QueryCursor<T>> {
388 let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
389 let final_sql = if let Some(params) = self.params {
390 params.replace(&sql_with_fields)
391 } else {
392 sql_with_fields
393 };
394 let row_iter = self.connection.inner.query_iter(&final_sql).await?;
395 Ok(QueryCursor::new(row_iter))
396 }
397}
398
399impl<'a, T> std::future::IntoFuture for ORMQueryBuilder<'a, T>
400where
401 T: TryFrom<Row> + RowORM + Send + 'a,
402 T::Error: std::fmt::Display,
403{
404 type Output = Result<QueryCursor<T>>;
405 type IntoFuture =
406 std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
407
408 fn into_future(self) -> Self::IntoFuture {
409 Box::pin(self.execute())
410 }
411}
412
413pub struct QueryBuilder<'a> {
415 connection: &'a Connection,
416 sql: String,
417 params: Option<Params>,
418}
419
420impl<'a> QueryBuilder<'a> {
421 fn new(connection: &'a Connection, sql: &str) -> Self {
422 Self {
423 connection,
424 sql: sql.to_string(),
425 params: None,
426 }
427 }
428
429 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
430 self.params = Some(params.into());
431 self
432 }
433
434 pub async fn iter(self) -> Result<RowIterator> {
435 let sql = self.get_final_sql();
436 self.connection.inner.query_iter(&sql).await
437 }
438
439 pub async fn iter_ext(self) -> Result<RowStatsIterator> {
440 let sql = self.get_final_sql();
441 self.connection.inner.query_iter_ext(&sql).await
442 }
443
444 pub async fn one(self) -> Result<Option<Row>> {
445 let sql = self.get_final_sql();
446 self.connection.inner.query_row(&sql).await
447 }
448
449 pub async fn all(self) -> Result<Vec<Row>> {
450 let sql = self.get_final_sql();
451 self.connection.inner.query_all(&sql).await
452 }
453
454 pub async fn cursor_as<T>(self) -> Result<QueryCursor<T>>
455 where
456 T: TryFrom<Row> + RowORM,
457 T::Error: std::fmt::Display,
458 {
459 let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
460 let final_sql = if let Some(params) = self.params {
461 params.replace(&sql_with_fields)
462 } else {
463 sql_with_fields
464 };
465 let row_iter = self.connection.inner.query_iter(&final_sql).await?;
466 Ok(QueryCursor::new(row_iter))
467 }
468
469 fn get_final_sql(&self) -> String {
470 match &self.params {
471 Some(params) => params.replace(&self.sql),
472 None => self.sql.clone(),
473 }
474 }
475}
476
477pub struct ExecBuilder<'a> {
479 connection: &'a Connection,
480 sql: String,
481 params: Option<Params>,
482}
483
484impl<'a> ExecBuilder<'a> {
485 fn new(connection: &'a Connection, sql: &str) -> Self {
486 Self {
487 connection,
488 sql: sql.to_string(),
489 params: None,
490 }
491 }
492
493 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
494 self.params = Some(params.into());
495 self
496 }
497
498 pub async fn execute(self) -> Result<i64> {
499 let sql = match self.params {
500 Some(params) => params.replace(&self.sql),
501 None => self.sql,
502 };
503 self.connection.inner.exec(&sql).await
504 }
505}
506
507impl<'a> std::future::IntoFuture for ExecBuilder<'a> {
508 type Output = Result<i64>;
509 type IntoFuture =
510 std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
511
512 fn into_future(self) -> Self::IntoFuture {
513 Box::pin(self.execute())
514 }
515}
516
517pub trait RowORM: TryFrom<Row> + Clone {
519 fn field_names() -> Vec<&'static str>; fn query_field_names() -> Vec<&'static str>; fn insert_field_names() -> Vec<&'static str>; fn to_values(&self) -> Vec<Value>;
523}