1use std::collections::BTreeMap;
16use std::path::Path;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use log::info;
21use tokio::fs::File;
22use tokio::io::AsyncRead;
23use tokio::io::BufReader;
24use tokio_stream::StreamExt;
25
26use crate::client::LoadMethod;
27use databend_client::StageLocation;
28use databend_client::{presign_download_from_stage, PresignedResponse};
29use databend_driver_core::error::{Error, Result};
30use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
31use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
32use databend_driver_core::schema::{DataType, Field, NumberDataType, Schema};
33use databend_driver_core::value::{NumberValue, Value};
34
/// Descriptive information about a connection, as returned by `IConnection::info`.
///
/// NOTE(review): the per-field descriptions below are inferred from the field names —
/// confirm against the connection implementations that populate them.
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
    /// Name of the handler backing the connection (presumably the protocol in use).
    pub handler: String,
    /// Host the connection targets.
    pub host: String,
    /// Port the connection targets.
    pub port: u16,
    /// User the connection authenticates as.
    pub user: String,
    /// Catalog in effect, if one is set.
    pub catalog: Option<String>,
    /// Database in effect, if one is set.
    pub database: Option<String>,
    /// Warehouse in effect, if one is set.
    pub warehouse: Option<String>,
}
44
45pub type Reader = Box<dyn AsyncRead + Send + Sync + Unpin + 'static>;
46
47#[async_trait]
48pub trait IConnection: Send + Sync {
49 async fn info(&self) -> ConnectionInfo;
50 async fn close(&self) -> Result<()> {
51 Ok(())
52 }
53
54 fn close_with_spawn(&self) -> Result<()> {
55 Ok(())
56 }
57
58 fn last_query_id(&self) -> Option<String>;
59
60 async fn version(&self) -> Result<String> {
61 let row = self.query_row("SELECT version()").await?;
62 let version = match row {
63 Some(row) => {
64 let (version,): (String,) = row.try_into().map_err(Error::Parsing)?;
65 version
66 }
67 None => "".to_string(),
68 };
69 Ok(version)
70 }
71
72 async fn exec(&self, sql: &str) -> Result<i64>;
73 async fn kill_query(&self, query_id: &str) -> Result<()>;
74 async fn query_iter(&self, sql: &str) -> Result<RowIterator>;
75 async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator>;
76
77 async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
78 let rows = self.query_all(sql).await?;
79 let row = rows.into_iter().next();
80 Ok(row)
81 }
82
83 async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
84 let rows = self.query_iter(sql).await?;
85 rows.collect().await
86 }
87
88 async fn query_raw_iter(&self, _sql: &str) -> Result<RawRowIterator> {
90 Err(Error::BadArgument(
91 "Unsupported implement query_raw_iter".to_string(),
92 ))
93 }
94
95 async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
97 let rows = self.query_raw_iter(sql).await?;
98 rows.collect().await
99 }
100
101 async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
104 info!("get presigned url: {} {}", operation, stage);
105 let sql = format!("PRESIGN {operation} {stage}");
106 let row = self.query_row(&sql).await?.ok_or_else(|| {
107 Error::InvalidResponse("Empty response from server for presigned request".to_string())
108 })?;
109 let (method, headers, url): (String, String, String) =
110 row.try_into().map_err(Error::Parsing)?;
111 let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
112 Ok(PresignedResponse {
113 method,
114 headers,
115 url,
116 })
117 }
118
119 async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()>;
120
121 async fn load_data(
122 &self,
123 sql: &str,
124 data: Reader,
125 size: u64,
126 method: LoadMethod,
127 ) -> Result<ServerStats>;
128
129 async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats>;
130
131 async fn load_file_with_options(
132 &self,
133 sql: &str,
134 fp: &Path,
135 file_format_options: Option<BTreeMap<&str, &str>>,
136 copy_options: Option<BTreeMap<&str, &str>>,
137 ) -> Result<ServerStats>;
138
139 async fn stream_load(
140 &self,
141 sql: &str,
142 data: Vec<Vec<&str>>,
143 _method: LoadMethod,
144 ) -> Result<ServerStats>;
145
146 async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
148 let mut total_count: usize = 0;
149 let mut total_size: usize = 0;
150 let local_dsn = url::Url::parse(local_file)?;
151 validate_local_scheme(local_dsn.scheme())?;
152 let mut results = Vec::new();
153 let stage_location = StageLocation::try_from(stage)?;
154 let schema = Arc::new(put_get_schema());
155 for entry in glob::glob(local_dsn.path())? {
156 let entry = entry?;
157 let filename = entry
158 .file_name()
159 .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?
160 .to_str()
161 .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?;
162 let stage_file = stage_location.file_path(filename);
163 let file = File::open(&entry).await?;
164 let size = file.metadata().await?.len();
165 let data = BufReader::new(file);
166 let (fname, status) = match self
167 .upload_to_stage(&stage_file, Box::new(data), size)
168 .await
169 {
170 Ok(_) => {
171 total_count += 1;
172 total_size += size as usize;
173 (entry.to_string_lossy().to_string(), "SUCCESS".to_owned())
174 }
175 Err(e) => (entry.to_string_lossy().to_string(), e.to_string()),
176 };
177 let ss = ServerStats {
178 write_rows: total_count,
179 write_bytes: total_size,
180
181 ..Default::default()
182 };
183 results.push(Ok(RowWithStats::Stats(ss)));
184 results.push(Ok(RowWithStats::Row(Row::from_vec(
185 schema.clone(),
186 vec![
187 Value::String(fname),
188 Value::String(status),
189 Value::Number(NumberValue::UInt64(size)),
190 ],
191 ))));
192 }
193 Ok(RowStatsIterator::new(
194 schema,
195 Box::pin(tokio_stream::iter(results)),
196 ))
197 }
198
199 async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
200 let mut total_count: usize = 0;
201 let mut total_size: usize = 0;
202 let local_dsn = url::Url::parse(local_file)?;
203 validate_local_scheme(local_dsn.scheme())?;
204 let mut location = StageLocation::try_from(stage)?;
205 if !location.path.ends_with('/') {
206 location.path.push('/');
207 }
208 let list_sql = format!("LIST {location}");
209 let mut response = self.query_iter(&list_sql).await?;
210 let mut results = Vec::new();
211 let schema = Arc::new(put_get_schema());
212 while let Some(row) = response.next().await {
213 let (mut name, _, _, _, _): (String, u64, Option<String>, String, Option<String>) =
214 row?.try_into().map_err(Error::Parsing)?;
215 if !location.path.is_empty() && name.starts_with(&location.path) {
216 name = name[location.path.len()..].to_string();
217 }
218 let stage_file = format!("{location}/{name}");
219 let presign = self.get_presigned_url("DOWNLOAD", &stage_file).await?;
220 let local_file = Path::new(local_dsn.path()).join(&name);
221 let status = presign_download_from_stage(presign, &local_file).await;
222 let (status, size) = match status {
223 Ok(size) => {
224 total_count += 1;
225 total_size += size as usize;
226 ("SUCCESS".to_owned(), size)
227 }
228 Err(e) => (e.to_string(), 0),
229 };
230 let ss = ServerStats {
231 read_rows: total_count,
232 read_bytes: total_size,
233 ..Default::default()
234 };
235 results.push(Ok(RowWithStats::Stats(ss)));
236 results.push(Ok(RowWithStats::Row(Row::from_vec(
237 schema.clone(),
238 vec![
239 Value::String(local_file.to_string_lossy().to_string()),
240 Value::String(status),
241 Value::Number(NumberValue::UInt64(size)),
242 ],
243 ))));
244 }
245 Ok(RowStatsIterator::new(
246 schema,
247 Box::pin(tokio_stream::iter(results)),
248 ))
249 }
250}
251
252fn put_get_schema() -> Schema {
253 Schema::from_vec(vec![
254 Field {
255 name: "file".to_string(),
256 data_type: DataType::String,
257 },
258 Field {
259 name: "status".to_string(),
260 data_type: DataType::String,
261 },
262 Field {
263 name: "size".to_string(),
264 data_type: DataType::Number(NumberDataType::UInt64),
265 },
266 ])
267}
268
269fn validate_local_scheme(scheme: &str) -> Result<()> {
270 match scheme {
271 "file" | "fs" => Ok(()),
272 _ => Err(Error::BadArgument(
273 "Supported schemes: file:// or fs://".to_string(),
274 )),
275 }
276}