use std::collections::BTreeMap;
use std::path::Path;
use std::sync::Arc;

use async_trait::async_trait;
use log::info;
use tokio::fs::File;
use tokio::io::AsyncRead;
use tokio::io::BufReader;
use tokio_stream::StreamExt;

use crate::client::LoadMethod;
use databend_client::schema::{DataType, Field, NumberDataType, Schema};
use databend_client::StageLocation;
use databend_client::{presign_download_from_stage, PresignedResponse};
use databend_driver_core::error::{Error, Result};
use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
use databend_driver_core::value::{NumberValue, Value};

pub struct ConnectionInfo {
    pub handler: String,
    pub host: String,
    pub port: u16,
    pub user: String,
    pub catalog: Option<String>,
    pub database: Option<String>,
    pub warehouse: Option<String>,
}

pub type Reader = Box<dyn AsyncRead + Send + Sync + Unpin + 'static>;

#[async_trait]
pub trait IConnection: Send + Sync {
    async fn info(&self) -> ConnectionInfo;
    async fn close(&self) -> Result<()> {
        Ok(())
    }

    fn close_with_spawn(&self) -> Result<()> {
        Ok(())
    }

    fn last_query_id(&self) -> Option<String>;

    async fn version(&self) -> Result<String> {
        let row = self.query_row("SELECT version()").await?;
        let version = match row {
            Some(row) => {
                let (version,): (String,) = row.try_into().map_err(Error::Parsing)?;
                version
            }
            None => "".to_string(),
        };
        Ok(version)
    }

    async fn exec(&self, sql: &str) -> Result<i64>;
    async fn kill_query(&self, query_id: &str) -> Result<()>;
    async fn query_iter(&self, sql: &str) -> Result<RowIterator>;
    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator>;

    async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
        let rows = self.query_all(sql).await?;
        let row = rows.into_iter().next();
        Ok(row)
    }

    async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
        let rows = self.query_iter(sql).await?;
        rows.collect().await
    }

    /// Default implementation: raw row iteration is optional and not every
    /// connection provides it.
    async fn query_raw_iter(&self, _sql: &str) -> Result<RawRowIterator> {
        Err(Error::BadArgument(
            "query_raw_iter is not supported by this connection".to_string(),
        ))
    }

    async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
        let rows = self.query_raw_iter(sql).await?;
        rows.collect().await
    }

    /// Runs `PRESIGN <operation> <stage>` and returns the signed HTTP method,
    /// headers and URL. The headers column is a JSON object, so it is decoded
    /// with `serde_json`.
    async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
        info!("get presigned url: {} {}", operation, stage);
        let sql = format!("PRESIGN {operation} {stage}");
        let row = self.query_row(&sql).await?.ok_or_else(|| {
            Error::InvalidResponse("Empty response from server for presigned request".to_string())
        })?;
        let (method, headers, url): (String, String, String) =
            row.try_into().map_err(Error::Parsing)?;
        let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
        Ok(PresignedResponse {
            method,
            headers,
            url,
        })
    }

    /// Uploads `size` bytes from `data` to the given stage path.
    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()>;

    /// Loads `size` bytes from `data` into the target described by `sql`,
    /// using the given load method.
    async fn load_data(
        &self,
        sql: &str,
        data: Reader,
        size: u64,
        method: LoadMethod,
    ) -> Result<ServerStats>;

    /// Loads a local file into the target described by `sql`.
    async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats>;

    /// Like [`load_file`](Self::load_file), but with explicit file format and
    /// copy options.
    async fn load_file_with_options(
        &self,
        sql: &str,
        fp: &Path,
        file_format_options: Option<BTreeMap<&str, &str>>,
        copy_options: Option<BTreeMap<&str, &str>>,
    ) -> Result<ServerStats>;

    /// Loads rows given as in-memory string values into the target described
    /// by `sql`.
    async fn stream_load(
        &self,
        sql: &str,
        data: Vec<Vec<&str>>,
        _method: LoadMethod,
    ) -> Result<ServerStats>;

    fn set_warehouse(&self, warehouse: &str) -> Result<()>;

    fn set_database(&self, database: &str) -> Result<()>;

    fn set_role(&self, role: &str) -> Result<()>;

    fn set_session(&self, key: &str, value: &str) -> Result<()>;

    /// Expands the local glob (a `file://` or `fs://` URL), uploads each
    /// matching file to the stage location, and yields one result row per file
    /// along with running upload statistics.
    async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
        let mut total_count: usize = 0;
        let mut total_size: usize = 0;
        let local_dsn = url::Url::parse(local_file)?;
        validate_local_scheme(local_dsn.scheme())?;
        let mut results = Vec::new();
        let stage_location = StageLocation::try_from(stage)?;
        let schema = Arc::new(put_get_schema());
        for entry in glob::glob(local_dsn.path())? {
            let entry = entry?;
            let filename = entry
                .file_name()
                .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?
                .to_str()
                .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?;
            let stage_file = stage_location.file_path(filename);
            let file = File::open(&entry).await?;
            let size = file.metadata().await?.len();
            let data = BufReader::new(file);
            // Upload failures are reported per file instead of aborting the whole PUT.
            let (fname, status) = match self
                .upload_to_stage(&stage_file, Box::new(data), size)
                .await
            {
                Ok(_) => {
                    total_count += 1;
                    total_size += size as usize;
                    (entry.to_string_lossy().to_string(), "SUCCESS".to_owned())
                }
                Err(e) => (entry.to_string_lossy().to_string(), e.to_string()),
            };
            // Emit running totals followed by a (file, status, size) result row.
            let ss = ServerStats {
                write_rows: total_count,
                write_bytes: total_size,
                ..Default::default()
            };
            results.push(Ok(RowWithStats::Stats(ss)));
            results.push(Ok(RowWithStats::Row(Row::from_vec(
                schema.clone(),
                vec![
                    Value::String(fname),
                    Value::String(status),
                    Value::Number(NumberValue::UInt64(size)),
                ],
            ))));
        }
        Ok(RowStatsIterator::new(
            schema,
            Box::pin(tokio_stream::iter(results)),
        ))
    }

    /// Lists the files under the stage location, downloads each one through a
    /// presigned URL into the local directory (a `file://` or `fs://` URL), and
    /// yields one result row per file along with running download statistics.
    async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
        let mut total_count: usize = 0;
        let mut total_size: usize = 0;
        let local_dsn = url::Url::parse(local_file)?;
        validate_local_scheme(local_dsn.scheme())?;
        let mut location = StageLocation::try_from(stage)?;
        if !location.path.ends_with('/') {
            location.path.push('/');
        }
        let list_sql = format!("LIST {location}");
        let mut response = self.query_iter(&list_sql).await?;
        let mut results = Vec::new();
        let schema = Arc::new(put_get_schema());
        while let Some(row) = response.next().await {
            // LIST returns (name, size, md5, last_modified, creator); only the name is needed.
            let (mut name, _, _, _, _): (String, u64, Option<String>, String, Option<String>) =
                row?.try_into().map_err(Error::Parsing)?;
            if !location.path.is_empty() && name.starts_with(&location.path) {
                name = name[location.path.len()..].to_string();
            }
            let stage_file = format!("{location}/{name}");
            let presign = self.get_presigned_url("DOWNLOAD", &stage_file).await?;
            let local_file = Path::new(local_dsn.path()).join(&name);
            let status = presign_download_from_stage(presign, &local_file).await;
            // Download failures are reported per file instead of aborting the whole GET.
            let (status, size) = match status {
                Ok(size) => {
                    total_count += 1;
                    total_size += size as usize;
                    ("SUCCESS".to_owned(), size)
                }
                Err(e) => (e.to_string(), 0),
            };
            let ss = ServerStats {
                read_rows: total_count,
                read_bytes: total_size,
                ..Default::default()
            };
            results.push(Ok(RowWithStats::Stats(ss)));
            results.push(Ok(RowWithStats::Row(Row::from_vec(
                schema.clone(),
                vec![
                    Value::String(local_file.to_string_lossy().to_string()),
                    Value::String(status),
                    Value::Number(NumberValue::UInt64(size)),
                ],
            ))));
        }
        Ok(RowStatsIterator::new(
            schema,
            Box::pin(tokio_stream::iter(results)),
        ))
    }
}
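
// Usage sketch added for documentation; it is not called anywhere in this
// crate. It shows how the default helpers on `IConnection` compose for a
// caller holding a trait object; the SQL text and table name below are
// placeholders.
#[allow(dead_code)]
async fn example_query(conn: &dyn IConnection) -> Result<()> {
    // `version()` is a default method built on top of `query_row`.
    let version = conn.version().await?;
    info!("server version: {version}");

    // `query_all` buffers every row; each row converts into a tuple of values.
    for row in conn.query_all("SELECT name FROM system.databases").await? {
        let (name,): (String,) = row.try_into().map_err(Error::Parsing)?;
        info!("database: {name}");
    }

    // `exec` runs a statement without streaming results back.
    conn.exec("CREATE TABLE IF NOT EXISTS example_t (a Int64)").await?;
    Ok(())
}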

fn put_get_schema() -> Schema {
    Schema::from_vec(vec![
        Field {
            name: "file".to_string(),
            data_type: DataType::String,
        },
        Field {
            name: "status".to_string(),
            data_type: DataType::String,
        },
        Field {
            name: "size".to_string(),
            data_type: DataType::Number(NumberDataType::UInt64),
        },
    ])
}
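
// Usage sketch for the PUT helper, added for documentation; it is not called
// anywhere in this crate. It assumes `RowStatsIterator` can be consumed as a
// stream, the same way `get_files` consumes `query_iter` above; the glob
// `fs:///data/*.csv` and the stage `@my_stage` are placeholder values.
#[allow(dead_code)]
async fn example_put_files(conn: &dyn IConnection) -> Result<()> {
    let mut rows = conn.put_files("fs:///data/*.csv", "@my_stage").await?;
    while let Some(item) = rows.next().await {
        match item? {
            // One row per uploaded file, matching `put_get_schema()`.
            RowWithStats::Row(row) => {
                let (file, status, size): (String, String, u64) =
                    row.try_into().map_err(Error::Parsing)?;
                info!("{file}: {status} ({size} bytes)");
            }
            // Running totals emitted alongside the per-file rows.
            RowWithStats::Stats(ss) => {
                info!("uploaded {} bytes so far", ss.write_bytes);
            }
        }
    }
    Ok(())
}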

/// Local paths passed to `put_files`/`get_files` must be `file://` or `fs://` URLs.
fn validate_local_scheme(scheme: &str) -> Result<()> {
    match scheme {
        "file" | "fs" => Ok(()),
        _ => Err(Error::BadArgument(
            "Supported schemes: file:// or fs://".to_string(),
        )),
    }
}
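
// A minimal test sketch added alongside this module; it only exercises the
// local-scheme validation above and makes no network calls.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_local_scheme_accepts_only_local_urls() {
        assert!(validate_local_scheme("file").is_ok());
        assert!(validate_local_scheme("fs").is_ok());
        assert!(validate_local_scheme("http").is_err());
        assert!(validate_local_scheme("s3").is_err());
    }
}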