1use std::collections::BTreeMap;
16use std::path::Path;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use once_cell::sync::Lazy;
21use tokio::fs::File;
22use tokio::io::AsyncRead;
23use tokio::io::BufReader;
24use tokio_stream::StreamExt;
25use url::Url;
26
27#[cfg(feature = "flight-sql")]
28use crate::flight_sql::FlightSQLConnection;
29
30use databend_client::StageLocation;
31use databend_client::{presign_download_from_stage, PresignedResponse};
32use bigbytes_driver_core::error::{Error, Result};
33use bigbytes_driver_core::raw_rows::{RawRow, RawRowIterator};
34use bigbytes_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
35use bigbytes_driver_core::schema::{DataType, Field, NumberDataType, Schema};
36use bigbytes_driver_core::value::{NumberValue, Value};
37
38use crate::rest_api::RestAPIConnection;
39
40static VERSION: Lazy<String> = Lazy::new(|| {
41 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
42 version.to_string()
43});
44
45#[derive(Clone)]
46pub struct Client {
47 dsn: String,
48 name: String,
49}
50
51impl Client {
52 pub fn new(dsn: String) -> Self {
53 let name = format!("bigbytes-driver-rust/{}", VERSION.as_str());
54 Self { dsn, name }
55 }
56
57 pub fn with_name(mut self, name: String) -> Self {
58 self.name = name;
59 self
60 }
61
62 pub async fn get_conn(&self) -> Result<Box<dyn Connection>> {
63 let u = Url::parse(&self.dsn)?;
64 match u.scheme() {
65 "databend" | "databend+http" | "databend+https" => {
66 let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?;
67 Ok(Box::new(conn))
68 }
69 #[cfg(feature = "flight-sql")]
70 "databend+flight" | "databend+grpc" => {
71 let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?;
72 Ok(Box::new(conn))
73 }
74 _ => Err(Error::Parsing(format!(
75 "Unsupported scheme: {}",
76 u.scheme()
77 ))),
78 }
79 }
80}
81
/// Descriptive details about an open connection, as returned by
/// [`Connection::info`].
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so callers can log, copy and
/// compare connection details.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConnectionInfo {
    /// Handler identifier reported by the connection implementation.
    pub handler: String,
    /// Server host the connection points at.
    pub host: String,
    /// Server port the connection points at.
    pub port: u16,
    /// User the connection authenticates as.
    pub user: String,
    /// Current database, if one is set.
    pub database: Option<String>,
    /// Warehouse, if one is set.
    pub warehouse: Option<String>,
}
90
91pub type Reader = Box<dyn AsyncRead + Send + Sync + Unpin + 'static>;
92
93#[async_trait]
94pub trait Connection: Send + Sync {
95 async fn info(&self) -> ConnectionInfo;
96 async fn close(&self) -> Result<()> {
97 Ok(())
98 }
99
100 fn last_query_id(&self) -> Option<String>;
101
102 async fn version(&self) -> Result<String> {
103 let row = self.query_row("SELECT version()").await?;
104 let version = match row {
105 Some(row) => {
106 let (version,): (String,) = row.try_into().map_err(Error::Parsing)?;
107 version
108 }
109 None => "".to_string(),
110 };
111 Ok(version)
112 }
113
114 async fn exec(&self, sql: &str) -> Result<i64>;
115 async fn kill_query(&self, query_id: &str) -> Result<()>;
116 async fn query_iter(&self, sql: &str) -> Result<RowIterator>;
117 async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator>;
118
119 async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
120 let rows = self.query_all(sql).await?;
121 let row = rows.into_iter().next();
122 Ok(row)
123 }
124
125 async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
126 let rows = self.query_iter(sql).await?;
127 rows.collect().await
128 }
129
130 async fn query_raw_iter(&self, _sql: &str) -> Result<RawRowIterator> {
132 Err(Error::BadArgument(
133 "Unsupported implement query_raw_iter".to_string(),
134 ))
135 }
136
137 async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
139 let rows = self.query_raw_iter(sql).await?;
140 rows.collect().await
141 }
142
143 async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse>;
146
147 async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()>;
148
149 async fn load_data(
150 &self,
151 sql: &str,
152 data: Reader,
153 size: u64,
154 file_format_options: Option<BTreeMap<&str, &str>>,
155 copy_options: Option<BTreeMap<&str, &str>>,
156 ) -> Result<ServerStats>;
157
158 async fn load_file(
159 &self,
160 sql: &str,
161 fp: &Path,
162 format_options: Option<BTreeMap<&str, &str>>,
163 copy_options: Option<BTreeMap<&str, &str>>,
164 ) -> Result<ServerStats>;
165
166 async fn stream_load(&self, sql: &str, data: Vec<Vec<&str>>) -> Result<ServerStats>;
167
168 async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
170 let mut total_count: usize = 0;
171 let mut total_size: usize = 0;
172 let local_dsn = url::Url::parse(local_file)?;
173 validate_local_scheme(local_dsn.scheme())?;
174 let mut results = Vec::new();
175 let stage_location = StageLocation::try_from(stage)?;
176 let schema = Arc::new(put_get_schema());
177 for entry in glob::glob(local_dsn.path())? {
178 let entry = entry?;
179 let filename = entry
180 .file_name()
181 .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {:?}", entry)))?
182 .to_str()
183 .ok_or_else(|| {
184 Error::BadArgument(format!("Invalid local file path: {:?}", entry))
185 })?;
186 let stage_file = stage_location.file_path(filename);
187 let file = File::open(&entry).await?;
188 let size = file.metadata().await?.len();
189 let data = BufReader::new(file);
190 let (fname, status) = match self
191 .upload_to_stage(&stage_file, Box::new(data), size)
192 .await
193 {
194 Ok(_) => {
195 total_count += 1;
196 total_size += size as usize;
197 (entry.to_string_lossy().to_string(), "SUCCESS".to_owned())
198 }
199 Err(e) => (entry.to_string_lossy().to_string(), e.to_string()),
200 };
201 let ss = ServerStats {
202 write_rows: total_count,
203 write_bytes: total_size,
204
205 ..Default::default()
206 };
207 results.push(Ok(RowWithStats::Stats(ss)));
208 results.push(Ok(RowWithStats::Row(Row::from_vec(
209 schema.clone(),
210 vec![
211 Value::String(fname),
212 Value::String(status),
213 Value::Number(NumberValue::UInt64(size)),
214 ],
215 ))));
216 }
217 Ok(RowStatsIterator::new(
218 schema,
219 Box::pin(tokio_stream::iter(results)),
220 ))
221 }
222
223 async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
224 let mut total_count: usize = 0;
225 let mut total_size: usize = 0;
226 let local_dsn = url::Url::parse(local_file)?;
227 validate_local_scheme(local_dsn.scheme())?;
228 let mut location = StageLocation::try_from(stage)?;
229 if !location.path.ends_with('/') {
230 location.path.push('/');
231 }
232 let list_sql = format!("LIST {}", location);
233 let mut response = self.query_iter(&list_sql).await?;
234 let mut results = Vec::new();
235 let schema = Arc::new(put_get_schema());
236 while let Some(row) = response.next().await {
237 let (mut name, _, _, _, _): (String, u64, Option<String>, String, Option<String>) =
238 row?.try_into().map_err(Error::Parsing)?;
239 if !location.path.is_empty() && name.starts_with(&location.path) {
240 name = name[location.path.len()..].to_string();
241 }
242 let stage_file = format!("{}/{}", location, name);
243 let presign = self.get_presigned_url("DOWNLOAD", &stage_file).await?;
244 let local_file = Path::new(local_dsn.path()).join(&name);
245 let status = presign_download_from_stage(presign, &local_file).await;
246 let (status, size) = match status {
247 Ok(size) => {
248 total_count += 1;
249 total_size += size as usize;
250 ("SUCCESS".to_owned(), size)
251 }
252 Err(e) => (e.to_string(), 0),
253 };
254 let ss = ServerStats {
255 read_rows: total_count,
256 read_bytes: total_size,
257 ..Default::default()
258 };
259 results.push(Ok(RowWithStats::Stats(ss)));
260 results.push(Ok(RowWithStats::Row(Row::from_vec(
261 schema.clone(),
262 vec![
263 Value::String(local_file.to_string_lossy().to_string()),
264 Value::String(status),
265 Value::Number(NumberValue::UInt64(size)),
266 ],
267 ))));
268 }
269 Ok(RowStatsIterator::new(
270 schema,
271 Box::pin(tokio_stream::iter(results)),
272 ))
273 }
274}
275
276fn put_get_schema() -> Schema {
277 Schema::from_vec(vec![
278 Field {
279 name: "file".to_string(),
280 data_type: DataType::String,
281 },
282 Field {
283 name: "status".to_string(),
284 data_type: DataType::String,
285 },
286 Field {
287 name: "size".to_string(),
288 data_type: DataType::Number(NumberDataType::UInt64),
289 },
290 ])
291}
292
293fn validate_local_scheme(scheme: &str) -> Result<()> {
294 match scheme {
295 "file" | "fs" => Ok(()),
296 _ => Err(Error::BadArgument(
297 "Supported schemes: file:// or fs://".to_string(),
298 )),
299 }
300}