1use std::collections::BTreeMap;
16use std::path::Path;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use log::info;
21use tokio::fs::File;
22use tokio::io::AsyncRead;
23use tokio::io::BufReader;
24use tokio_stream::StreamExt;
25
26use crate::client::LoadMethod;
27use databend_client::StageLocation;
28use databend_client::{presign_download_from_stage, PresignedResponse};
29use databend_driver_core::error::{Error, Result};
30use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
31use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
32use databend_driver_core::schema::{DataType, Field, NumberDataType, Schema};
33use databend_driver_core::value::{NumberValue, Value};
34
/// Descriptive details about a connection, as returned by `IConnection::info`.
///
/// The type derives `Debug`, `Clone`, `PartialEq` and `Eq` so callers can log
/// it, keep a copy, and compare two snapshots without hand-written glue.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionInfo {
    /// Identifier of the handler serving the connection (values are chosen by
    /// the implementor).
    pub handler: String,
    /// Host the connection targets.
    pub host: String,
    /// Port the connection targets.
    pub port: u16,
    /// User name associated with the connection.
    pub user: String,
    /// Catalog, if one is set.
    pub catalog: Option<String>,
    /// Database, if one is set.
    pub database: Option<String>,
    /// Warehouse, if one is set.
    pub warehouse: Option<String>,
}
44
/// Boxed, type-erased asynchronous byte source.
///
/// This is the `data` argument type of `IConnection::upload_to_stage` and
/// `IConnection::load_data`. The trait object must be `Send + Sync + Unpin`
/// and own its contents (`'static`).
pub type Reader = Box<dyn AsyncRead + Send + Sync + Unpin + 'static>;
46
47#[async_trait]
48pub trait IConnection: Send + Sync {
49 async fn info(&self) -> ConnectionInfo;
50 async fn close(&self) -> Result<()> {
51 Ok(())
52 }
53
54 fn last_query_id(&self) -> Option<String>;
55
56 async fn version(&self) -> Result<String> {
57 let row = self.query_row("SELECT version()").await?;
58 let version = match row {
59 Some(row) => {
60 let (version,): (String,) = row.try_into().map_err(Error::Parsing)?;
61 version
62 }
63 None => "".to_string(),
64 };
65 Ok(version)
66 }
67
68 async fn exec(&self, sql: &str) -> Result<i64>;
69 async fn kill_query(&self, query_id: &str) -> Result<()>;
70 async fn query_iter(&self, sql: &str) -> Result<RowIterator>;
71 async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator>;
72
73 async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
74 let rows = self.query_all(sql).await?;
75 let row = rows.into_iter().next();
76 Ok(row)
77 }
78
79 async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
80 let rows = self.query_iter(sql).await?;
81 rows.collect().await
82 }
83
84 async fn query_raw_iter(&self, _sql: &str) -> Result<RawRowIterator> {
86 Err(Error::BadArgument(
87 "Unsupported implement query_raw_iter".to_string(),
88 ))
89 }
90
91 async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
93 let rows = self.query_raw_iter(sql).await?;
94 rows.collect().await
95 }
96
97 async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
100 info!("get presigned url: {} {}", operation, stage);
101 let sql = format!("PRESIGN {operation} {stage}");
102 let row = self.query_row(&sql).await?.ok_or_else(|| {
103 Error::InvalidResponse("Empty response from server for presigned request".to_string())
104 })?;
105 let (method, headers, url): (String, String, String) =
106 row.try_into().map_err(Error::Parsing)?;
107 let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
108 Ok(PresignedResponse {
109 method,
110 headers,
111 url,
112 })
113 }
114
115 async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()>;
116
117 async fn load_data(
118 &self,
119 sql: &str,
120 data: Reader,
121 size: u64,
122 method: LoadMethod,
123 ) -> Result<ServerStats>;
124
125 async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats>;
126
127 async fn load_file_with_options(
128 &self,
129 sql: &str,
130 fp: &Path,
131 file_format_options: Option<BTreeMap<&str, &str>>,
132 copy_options: Option<BTreeMap<&str, &str>>,
133 ) -> Result<ServerStats>;
134
135 async fn stream_load(
136 &self,
137 sql: &str,
138 data: Vec<Vec<&str>>,
139 _method: LoadMethod,
140 ) -> Result<ServerStats>;
141
142 async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
144 let mut total_count: usize = 0;
145 let mut total_size: usize = 0;
146 let local_dsn = url::Url::parse(local_file)?;
147 validate_local_scheme(local_dsn.scheme())?;
148 let mut results = Vec::new();
149 let stage_location = StageLocation::try_from(stage)?;
150 let schema = Arc::new(put_get_schema());
151 for entry in glob::glob(local_dsn.path())? {
152 let entry = entry?;
153 let filename = entry
154 .file_name()
155 .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?
156 .to_str()
157 .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?;
158 let stage_file = stage_location.file_path(filename);
159 let file = File::open(&entry).await?;
160 let size = file.metadata().await?.len();
161 let data = BufReader::new(file);
162 let (fname, status) = match self
163 .upload_to_stage(&stage_file, Box::new(data), size)
164 .await
165 {
166 Ok(_) => {
167 total_count += 1;
168 total_size += size as usize;
169 (entry.to_string_lossy().to_string(), "SUCCESS".to_owned())
170 }
171 Err(e) => (entry.to_string_lossy().to_string(), e.to_string()),
172 };
173 let ss = ServerStats {
174 write_rows: total_count,
175 write_bytes: total_size,
176
177 ..Default::default()
178 };
179 results.push(Ok(RowWithStats::Stats(ss)));
180 results.push(Ok(RowWithStats::Row(Row::from_vec(
181 schema.clone(),
182 vec![
183 Value::String(fname),
184 Value::String(status),
185 Value::Number(NumberValue::UInt64(size)),
186 ],
187 ))));
188 }
189 Ok(RowStatsIterator::new(
190 schema,
191 Box::pin(tokio_stream::iter(results)),
192 ))
193 }
194
195 async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
196 let mut total_count: usize = 0;
197 let mut total_size: usize = 0;
198 let local_dsn = url::Url::parse(local_file)?;
199 validate_local_scheme(local_dsn.scheme())?;
200 let mut location = StageLocation::try_from(stage)?;
201 if !location.path.ends_with('/') {
202 location.path.push('/');
203 }
204 let list_sql = format!("LIST {location}");
205 let mut response = self.query_iter(&list_sql).await?;
206 let mut results = Vec::new();
207 let schema = Arc::new(put_get_schema());
208 while let Some(row) = response.next().await {
209 let (mut name, _, _, _, _): (String, u64, Option<String>, String, Option<String>) =
210 row?.try_into().map_err(Error::Parsing)?;
211 if !location.path.is_empty() && name.starts_with(&location.path) {
212 name = name[location.path.len()..].to_string();
213 }
214 let stage_file = format!("{location}/{name}");
215 let presign = self.get_presigned_url("DOWNLOAD", &stage_file).await?;
216 let local_file = Path::new(local_dsn.path()).join(&name);
217 let status = presign_download_from_stage(presign, &local_file).await;
218 let (status, size) = match status {
219 Ok(size) => {
220 total_count += 1;
221 total_size += size as usize;
222 ("SUCCESS".to_owned(), size)
223 }
224 Err(e) => (e.to_string(), 0),
225 };
226 let ss = ServerStats {
227 read_rows: total_count,
228 read_bytes: total_size,
229 ..Default::default()
230 };
231 results.push(Ok(RowWithStats::Stats(ss)));
232 results.push(Ok(RowWithStats::Row(Row::from_vec(
233 schema.clone(),
234 vec![
235 Value::String(local_file.to_string_lossy().to_string()),
236 Value::String(status),
237 Value::Number(NumberValue::UInt64(size)),
238 ],
239 ))));
240 }
241 Ok(RowStatsIterator::new(
242 schema,
243 Box::pin(tokio_stream::iter(results)),
244 ))
245 }
246}
247
248fn put_get_schema() -> Schema {
249 Schema::from_vec(vec![
250 Field {
251 name: "file".to_string(),
252 data_type: DataType::String,
253 },
254 Field {
255 name: "status".to_string(),
256 data_type: DataType::String,
257 },
258 Field {
259 name: "size".to_string(),
260 data_type: DataType::Number(NumberDataType::UInt64),
261 },
262 ])
263}
264
265fn validate_local_scheme(scheme: &str) -> Result<()> {
266 match scheme {
267 "file" | "fs" => Ok(()),
268 _ => Err(Error::BadArgument(
269 "Supported schemes: file:// or fs://".to_string(),
270 )),
271 }
272}