1use once_cell::sync::Lazy;
16use std::collections::BTreeMap;
17use std::path::Path;
18use std::str::FromStr;
19use url::Url;
20
21use crate::conn::IConnection;
22#[cfg(feature = "flight-sql")]
23use crate::flight_sql::FlightSQLConnection;
24use crate::ConnectionInfo;
25use crate::Params;
26
27use databend_client::PresignedResponse;
28use databend_driver_core::error::{Error, Result};
29use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
30use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, ServerStats};
31use databend_driver_core::value::Value;
32
33use crate::rest_api::RestAPIConnection;
34
35static VERSION: Lazy<String> = Lazy::new(|| {
36 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
37 version.to_string()
38});
39
/// Selects how data is delivered by [`Connection::load_data`],
/// [`Connection::load_file`] and [`Connection::stream_load`].
///
/// Parsed case-insensitively from `"stage"` or `"streaming"` via [`FromStr`].
/// Derives `Eq` and `Hash` so it can be used as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LoadMethod {
    /// The method parsed from `"stage"`.
    Stage,
    /// The method parsed from `"streaming"`.
    Streaming,
}
45
46impl FromStr for LoadMethod {
47 type Err = Error;
48
49 fn from_str(s: &str) -> Result<Self, Self::Err> {
50 match s.to_lowercase().as_str() {
51 "stage" => Ok(LoadMethod::Stage),
52 "streaming" => Ok(LoadMethod::Streaming),
53 _ => Err(Error::BadArgument(format!("invalid load method: {s}"))),
54 }
55 }
56}
57
/// Connection factory: holds a DSN and the client name handed to each backend
/// connection it creates.
#[derive(Clone)]
pub struct Client {
    /// DSN URL whose scheme selects the backend.
    dsn: String,
    /// Client name passed to the backend (defaults to `databend-driver-rust/<version>`).
    name: String,
}
63
64use crate::conn::Reader;
65
66pub struct Connection {
67 inner: Box<dyn IConnection>,
68}
69
70impl Client {
71 pub fn new(dsn: String) -> Self {
72 let name = format!("databend-driver-rust/{}", VERSION.as_str());
73 Self { dsn, name }
74 }
75
76 pub fn with_name(mut self, name: String) -> Self {
77 self.name = name;
78 self
79 }
80
81 pub async fn get_conn(&self) -> Result<Connection> {
82 let u = Url::parse(&self.dsn)?;
83 match u.scheme() {
84 "databend" | "databend+http" | "databend+https" => {
85 let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?;
86 Ok(Connection {
87 inner: Box::new(conn),
88 })
89 }
90 #[cfg(feature = "flight-sql")]
91 "databend+flight" | "databend+grpc" => {
92 let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?;
93 Ok(Connection {
94 inner: Box::new(conn),
95 })
96 }
97 _ => Err(Error::Parsing(format!(
98 "Unsupported scheme: {}",
99 u.scheme()
100 ))),
101 }
102 }
103}
104
105impl Connection {
106 pub fn inner(&self) -> &dyn IConnection {
107 self.inner.as_ref()
108 }
109
110 pub async fn info(&self) -> ConnectionInfo {
111 self.inner.info().await
112 }
113 pub async fn close(&self) -> Result<()> {
114 self.inner.close().await
115 }
116
117 pub fn last_query_id(&self) -> Option<String> {
118 self.inner.last_query_id()
119 }
120
121 pub async fn version(&self) -> Result<String> {
122 self.inner.version().await
123 }
124
125 pub fn format_sql<P: Into<Params> + Send>(&self, sql: &str, params: P) -> String {
126 let params = params.into();
127 params.replace(sql)
128 }
129
130 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
131 self.inner.kill_query(query_id).await
132 }
133
134 pub fn query(&self, sql: &str) -> QueryBuilder<'_> {
135 QueryBuilder::new(self, sql)
136 }
137
138 pub fn exec(&self, sql: &str) -> ExecBuilder<'_> {
139 ExecBuilder::new(self, sql)
140 }
141
142 pub async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
143 QueryBuilder::new(self, sql).iter().await
144 }
145
146 pub async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
147 QueryBuilder::new(self, sql).iter_ext().await
148 }
149
150 pub async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
151 QueryBuilder::new(self, sql).one().await
152 }
153
154 pub async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
155 QueryBuilder::new(self, sql).all().await
156 }
157
158 pub async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
160 self.inner.query_raw_iter(sql).await
161 }
162
163 pub async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
165 self.inner.query_raw_all(sql).await
166 }
167
168 pub async fn get_presigned_url(
171 &self,
172 operation: &str,
173 stage: &str,
174 ) -> Result<PresignedResponse> {
175 self.inner.get_presigned_url(operation, stage).await
176 }
177
178 pub async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
179 self.inner.upload_to_stage(stage, data, size).await
180 }
181
182 pub async fn load_data(
183 &self,
184 sql: &str,
185 data: Reader,
186 size: u64,
187 method: LoadMethod,
188 ) -> Result<ServerStats> {
189 self.inner.load_data(sql, data, size, method).await
190 }
191
192 pub async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
193 self.inner.load_file(sql, fp, method).await
194 }
195
196 pub async fn load_file_with_options(
197 &self,
198 sql: &str,
199 fp: &Path,
200 file_format_options: Option<BTreeMap<&str, &str>>,
201 copy_options: Option<BTreeMap<&str, &str>>,
202 ) -> Result<ServerStats> {
203 self.inner
204 .load_file_with_options(sql, fp, file_format_options, copy_options)
205 .await
206 }
207
208 pub async fn stream_load(
209 &self,
210 sql: &str,
211 data: Vec<Vec<&str>>,
212 method: LoadMethod,
213 ) -> Result<ServerStats> {
214 self.inner.stream_load(sql, data, method).await
215 }
216
217 pub async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
219 self.inner.put_files(local_file, stage).await
220 }
221
222 pub async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
223 self.inner.get_files(stage, local_file).await
224 }
225
226 pub fn query_as<T>(&self, sql: &str) -> ORMQueryBuilder<'_, T>
228 where
229 T: TryFrom<Row> + RowORM,
230 T::Error: std::fmt::Display,
231 {
232 ORMQueryBuilder::new(self, sql)
233 }
234
235 pub async fn insert<T>(&self, table_name: &str) -> Result<InsertCursor<'_, T>>
236 where
237 T: Clone + RowORM,
238 {
239 Ok(InsertCursor::new(self, table_name.to_string()))
240 }
241}
242
243pub struct QueryCursor<T> {
244 iter: RowIterator,
245 _phantom: std::marker::PhantomData<T>,
246}
247
248impl<T> QueryCursor<T>
249where
250 T: TryFrom<Row>,
251 T::Error: std::fmt::Display,
252{
253 fn new(iter: RowIterator) -> Self {
254 Self {
255 iter,
256 _phantom: std::marker::PhantomData,
257 }
258 }
259
260 pub async fn fetch(&mut self) -> Result<Option<T>> {
261 use tokio_stream::StreamExt;
262 match self.iter.next().await {
263 Some(row) => {
264 let row = row?;
265 let typed_row = T::try_from(row).map_err(|e| Error::Parsing(e.to_string()))?;
266 Ok(Some(typed_row))
267 }
268 None => Ok(None),
269 }
270 }
271
272 pub async fn next(&mut self) -> Result<Option<T>> {
273 self.fetch().await
274 }
275
276 pub async fn fetch_all(self) -> Result<Vec<T>> {
277 self.iter.try_collect().await
278 }
279}
280
281pub struct InsertCursor<'a, T> {
282 connection: &'a Connection,
283 table_name: String,
284 rows: Vec<T>,
285 _phantom: std::marker::PhantomData<T>,
286}
287
288impl<'a, T> InsertCursor<'a, T>
289where
290 T: Clone + RowORM,
291{
292 fn new(connection: &'a Connection, table_name: String) -> Self {
293 Self {
294 connection,
295 table_name,
296 rows: Vec::new(),
297 _phantom: std::marker::PhantomData,
298 }
299 }
300
301 pub async fn write(&mut self, row: &T) -> Result<()> {
302 self.rows.push(row.clone());
303 Ok(())
304 }
305
306 pub async fn end(self) -> Result<i64> {
307 if self.rows.is_empty() {
308 return Ok(0);
309 }
310 let connection = self.connection;
311 let field_names = T::insert_field_names();
313 let field_list = field_names.join(", ");
314 let placeholder_list = (0..field_names.len())
315 .map(|_| "?")
316 .collect::<Vec<_>>()
317 .join(", ");
318
319 let sql = format!(
320 "INSERT INTO {} ({}) VALUES ({})",
321 self.table_name, field_list, placeholder_list
322 );
323
324 let mut total_inserted = 0;
325 for row in &self.rows {
326 let values = row.to_values();
327 let param_strings: Vec<String> =
328 values.into_iter().map(|v| v.to_sql_string()).collect();
329 let params = Params::QuestionParams(param_strings);
330 let inserted = connection.exec(&sql).bind(params).await?;
331 total_inserted += inserted;
332 }
333
334 Ok(total_inserted)
335 }
336}
337
/// Replaces every `?fields` occurrence in `sql` with `field_names` joined by `", "`.
fn replace_query_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    sql.replace("?fields", &field_names.join(", "))
}
343
/// Replaces every `?fields` occurrence in `sql` with `field_names` joined by `", "`.
// Not referenced anywhere in this module. Presumably kept as the insert-side
// counterpart of `replace_query_fields_placeholder`, hence the dead_code allowance.
#[allow(dead_code)]
fn replace_insert_fields_placeholder(sql: &str, field_names: &[&str]) -> String {
    let column_list = field_names.join(", ");
    sql.split("?fields").collect::<Vec<_>>().join(&column_list)
}
350
351pub struct ORMQueryBuilder<'a, T> {
353 connection: &'a Connection,
354 sql: String,
355 params: Option<Params>,
356 _phantom: std::marker::PhantomData<T>,
357}
358
359impl<'a, T> ORMQueryBuilder<'a, T>
360where
361 T: TryFrom<Row> + RowORM,
362 T::Error: std::fmt::Display,
363{
364 fn new(connection: &'a Connection, sql: &str) -> Self {
365 Self {
366 connection,
367 sql: sql.to_string(),
368 params: None,
369 _phantom: std::marker::PhantomData,
370 }
371 }
372
373 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
374 self.params = Some(params.into());
375 self
376 }
377
378 pub async fn execute(self) -> Result<QueryCursor<T>> {
379 let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
380 let final_sql = if let Some(params) = self.params {
381 params.replace(&sql_with_fields)
382 } else {
383 sql_with_fields
384 };
385 let row_iter = self.connection.inner.query_iter(&final_sql).await?;
386 Ok(QueryCursor::new(row_iter))
387 }
388}
389
390impl<'a, T> std::future::IntoFuture for ORMQueryBuilder<'a, T>
391where
392 T: TryFrom<Row> + RowORM + Send + 'a,
393 T::Error: std::fmt::Display,
394{
395 type Output = Result<QueryCursor<T>>;
396 type IntoFuture =
397 std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
398
399 fn into_future(self) -> Self::IntoFuture {
400 Box::pin(self.execute())
401 }
402}
403
404pub struct QueryBuilder<'a> {
406 connection: &'a Connection,
407 sql: String,
408 params: Option<Params>,
409}
410
411impl<'a> QueryBuilder<'a> {
412 fn new(connection: &'a Connection, sql: &str) -> Self {
413 Self {
414 connection,
415 sql: sql.to_string(),
416 params: None,
417 }
418 }
419
420 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
421 self.params = Some(params.into());
422 self
423 }
424
425 pub async fn iter(self) -> Result<RowIterator> {
426 let sql = self.get_final_sql();
427 self.connection.inner.query_iter(&sql).await
428 }
429
430 pub async fn iter_ext(self) -> Result<RowStatsIterator> {
431 let sql = self.get_final_sql();
432 self.connection.inner.query_iter_ext(&sql).await
433 }
434
435 pub async fn one(self) -> Result<Option<Row>> {
436 let sql = self.get_final_sql();
437 self.connection.inner.query_row(&sql).await
438 }
439
440 pub async fn all(self) -> Result<Vec<Row>> {
441 let sql = self.get_final_sql();
442 self.connection.inner.query_all(&sql).await
443 }
444
445 pub async fn cursor_as<T>(self) -> Result<QueryCursor<T>>
446 where
447 T: TryFrom<Row> + RowORM,
448 T::Error: std::fmt::Display,
449 {
450 let sql_with_fields = replace_query_fields_placeholder(&self.sql, &T::query_field_names());
451 let final_sql = if let Some(params) = self.params {
452 params.replace(&sql_with_fields)
453 } else {
454 sql_with_fields
455 };
456 let row_iter = self.connection.inner.query_iter(&final_sql).await?;
457 Ok(QueryCursor::new(row_iter))
458 }
459
460 fn get_final_sql(&self) -> String {
461 match &self.params {
462 Some(params) => params.replace(&self.sql),
463 None => self.sql.clone(),
464 }
465 }
466}
467
468pub struct ExecBuilder<'a> {
470 connection: &'a Connection,
471 sql: String,
472 params: Option<Params>,
473}
474
475impl<'a> ExecBuilder<'a> {
476 fn new(connection: &'a Connection, sql: &str) -> Self {
477 Self {
478 connection,
479 sql: sql.to_string(),
480 params: None,
481 }
482 }
483
484 pub fn bind<P: Into<Params> + Send>(mut self, params: P) -> Self {
485 self.params = Some(params.into());
486 self
487 }
488
489 pub async fn execute(self) -> Result<i64> {
490 let sql = match self.params {
491 Some(params) => params.replace(&self.sql),
492 None => self.sql,
493 };
494 self.connection.inner.exec(&sql).await
495 }
496}
497
498impl<'a> std::future::IntoFuture for ExecBuilder<'a> {
499 type Output = Result<i64>;
500 type IntoFuture =
501 std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
502
503 fn into_future(self) -> Self::IntoFuture {
504 Box::pin(self.execute())
505 }
506}
507
508pub trait RowORM: TryFrom<Row> + Clone {
510 fn field_names() -> Vec<&'static str>; fn query_field_names() -> Vec<&'static str>; fn insert_field_names() -> Vec<&'static str>; fn to_values(&self) -> Vec<Value>;
514}