Skip to main content

databend_driver/
conn.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::collections::BTreeSet;
17use std::path::Path;
18use std::path::PathBuf;
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use log::info;
23use tokio::fs::File;
24use tokio::io::AsyncRead;
25use tokio::io::BufReader;
26use tokio_stream::StreamExt;
27
28use crate::client::LoadMethod;
29use databend_client::schema::{DataType, Field, NumberDataType, Schema};
30use databend_client::StageLocation;
31use databend_client::{presign_download_from_stage, PresignedResponse};
32use databend_driver_core::error::{Error, Result};
33use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
34use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
35use databend_driver_core::value::{NumberValue, Value};
36
/// Metadata describing a connection, as returned by `IConnection::info`.
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
    /// Name of the handler serving this connection.
    // NOTE(review): presumably the transport kind (e.g. HTTP / FlightSQL) — confirm.
    pub handler: String,
    /// Server host name or address.
    pub host: String,
    /// Server port.
    pub port: u16,
    /// User name of the connection.
    pub user: String,
    /// Current catalog, if one is set.
    pub catalog: Option<String>,
    /// Current database, if one is set.
    pub database: Option<String>,
    /// Current warehouse, if one is set.
    pub warehouse: Option<String>,
}
46
47pub type Reader = Box<dyn AsyncRead + Send + Sync + Unpin + 'static>;
48
49#[async_trait]
50pub trait IConnection: Send + Sync {
51    async fn info(&self) -> ConnectionInfo;
52    async fn close(&self) -> Result<()> {
53        Ok(())
54    }
55
56    fn close_with_spawn(&self) -> Result<()> {
57        Ok(())
58    }
59
60    fn last_query_id(&self) -> Option<String>;
61
62    async fn version(&self) -> Result<String> {
63        let row = self.query_row("SELECT version()").await?;
64        let version = match row {
65            Some(row) => {
66                let (version,): (String,) = row.try_into().map_err(Error::Parsing)?;
67                version
68            }
69            None => "".to_string(),
70        };
71        Ok(version)
72    }
73
74    async fn exec(&self, sql: &str) -> Result<i64>;
75    async fn kill_query(&self, query_id: &str) -> Result<()>;
76    async fn query_iter(&self, sql: &str) -> Result<RowIterator>;
77    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator>;
78
79    fn supports_server_side_params(&self) -> bool {
80        false
81    }
82
83    async fn exec_with_params(&self, sql: &str, _params: Option<serde_json::Value>) -> Result<i64> {
84        self.exec(sql).await
85    }
86
87    async fn query_iter_with_params(
88        &self,
89        sql: &str,
90        _params: Option<serde_json::Value>,
91    ) -> Result<RowIterator> {
92        self.query_iter(sql).await
93    }
94
95    async fn query_iter_ext_with_params(
96        &self,
97        sql: &str,
98        _params: Option<serde_json::Value>,
99    ) -> Result<RowStatsIterator> {
100        self.query_iter_ext(sql).await
101    }
102
103    async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
104        let rows = self.query_all(sql).await?;
105        let row = rows.into_iter().next();
106        Ok(row)
107    }
108
109    async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
110        let rows = self.query_iter(sql).await?;
111        rows.collect().await
112    }
113
114    // raw data response query, only for test
115    async fn query_raw_iter(&self, _sql: &str) -> Result<RawRowIterator> {
116        Err(Error::BadArgument(
117            "Unsupported implement query_raw_iter".to_string(),
118        ))
119    }
120
121    // raw data response query, only for test
122    async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
123        let rows = self.query_raw_iter(sql).await?;
124        rows.collect().await
125    }
126
127    /// Get presigned url for a given operation and stage location.
128    /// The operation can be "UPLOAD" or "DOWNLOAD".
129    async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
130        info!("get presigned url: {} {}", operation, stage);
131        let sql = format!("PRESIGN {operation} {stage}");
132        let row = self.query_row(&sql).await?.ok_or_else(|| {
133            Error::InvalidResponse("Empty response from server for presigned request".to_string())
134        })?;
135        let (method, headers, url): (String, String, String) =
136            row.try_into().map_err(Error::Parsing)?;
137        let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
138        Ok(PresignedResponse {
139            method,
140            headers,
141            url,
142        })
143    }
144
145    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()>;
146
147    async fn load_data(
148        &self,
149        sql: &str,
150        data: Reader,
151        size: u64,
152        method: LoadMethod,
153    ) -> Result<ServerStats>;
154
155    async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats>;
156
157    async fn load_file_with_options(
158        &self,
159        sql: &str,
160        fp: &Path,
161        file_format_options: Option<BTreeMap<&str, &str>>,
162        copy_options: Option<BTreeMap<&str, &str>>,
163    ) -> Result<ServerStats>;
164
165    async fn stream_load(
166        &self,
167        sql: &str,
168        data: Vec<Vec<&str>>,
169        _method: LoadMethod,
170    ) -> Result<ServerStats>;
171
172    fn set_warehouse(&self, warehouse: &str) -> Result<()>;
173
174    fn set_database(&self, database: &str) -> Result<()>;
175
176    fn set_role(&self, role: &str) -> Result<()>;
177
178    fn set_session(&self, key: &str, value: &str) -> Result<()>;
179
180    // PUT file://<path_to_file>/<filename> internalStage|externalStage
181    async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
182        let mut total_count: usize = 0;
183        let mut total_size: usize = 0;
184        let local_dsn = url::Url::parse(local_file)?;
185        validate_local_scheme(local_dsn.scheme())?;
186        let entries = expand_local_glob(local_dsn.path())?;
187        let mut results = Vec::new();
188        let stage_location = StageLocation::try_from(stage)?;
189        let schema = Arc::new(put_get_schema());
190        for entry in entries {
191            let filename = entry
192                .file_name()
193                .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?
194                .to_str()
195                .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?;
196            let stage_file = stage_location.file_path(filename);
197            let (fname, status, size) = if entry.is_dir() {
198                (
199                    entry.to_string_lossy().to_string(),
200                    format!(
201                        "BadArgument: Local path is a directory: {}",
202                        entry.display()
203                    ),
204                    0,
205                )
206            } else {
207                let file = File::open(&entry).await?;
208                let size = file.metadata().await?.len();
209                let data = BufReader::new(file);
210                match self
211                    .upload_to_stage(&stage_file, Box::new(data), size)
212                    .await
213                {
214                    Ok(_) => {
215                        total_count += 1;
216                        total_size += size as usize;
217                        (
218                            entry.to_string_lossy().to_string(),
219                            "SUCCESS".to_owned(),
220                            size,
221                        )
222                    }
223                    Err(e) => (entry.to_string_lossy().to_string(), e.to_string(), size),
224                }
225            };
226            let ss = ServerStats {
227                write_rows: total_count,
228                write_bytes: total_size,
229
230                ..Default::default()
231            };
232            results.push(Ok(RowWithStats::Stats(ss)));
233            results.push(Ok(RowWithStats::Row(Row::from_vec(
234                schema.clone(),
235                vec![
236                    Value::String(fname),
237                    Value::String(status),
238                    Value::Number(NumberValue::UInt64(size)),
239                ],
240            ))));
241        }
242        Ok(RowStatsIterator::new(
243            schema,
244            Box::pin(tokio_stream::iter(results)),
245        ))
246    }
247
248    async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
249        let mut total_count: usize = 0;
250        let mut total_size: usize = 0;
251        let local_dsn = url::Url::parse(local_file)?;
252        validate_local_scheme(local_dsn.scheme())?;
253        ensure_local_destination_dir(local_dsn.path())?;
254        let location = StageLocation::try_from(stage)?;
255        let list_sql = format!("LIST {location}");
256        let mut response = self.query_iter(&list_sql).await?;
257        let mut files = Vec::new();
258        let mut results = Vec::new();
259        let schema = Arc::new(put_get_schema());
260        while let Some(row) = response.next().await {
261            let (name, _, _, _, _): (String, u64, Option<String>, String, Option<String>) =
262                row?.try_into().map_err(Error::Parsing)?;
263            if let Some(file) = normalize_stage_list_entry(&location, &name)? {
264                files.push(file);
265            }
266        }
267        if files.is_empty() {
268            return Err(Error::BadArgument(format!(
269                "No stage files matched: {stage}"
270            )));
271        }
272        ensure_unique_basenames(&files)?;
273        for file in files {
274            let local_file = Path::new(local_dsn.path()).join(&file.basename);
275            let (status, size) = match self.get_presigned_url("DOWNLOAD", &file.stage_file).await {
276                Ok(presign) => match presign_download_from_stage(presign, &local_file).await {
277                    Ok(size) => {
278                        total_count += 1;
279                        total_size += size as usize;
280                        ("SUCCESS".to_owned(), size)
281                    }
282                    Err(e) => (e.to_string(), 0),
283                },
284                Err(e) => (e.to_string(), 0),
285            };
286            let ss = ServerStats {
287                read_rows: total_count,
288                read_bytes: total_size,
289                ..Default::default()
290            };
291            results.push(Ok(RowWithStats::Stats(ss)));
292            results.push(Ok(RowWithStats::Row(Row::from_vec(
293                schema.clone(),
294                vec![
295                    Value::String(local_file.to_string_lossy().to_string()),
296                    Value::String(status),
297                    Value::Number(NumberValue::UInt64(size)),
298                ],
299            ))));
300        }
301        Ok(RowStatsIterator::new(
302            schema,
303            Box::pin(tokio_stream::iter(results)),
304        ))
305    }
306}
307
308#[derive(Debug, Clone)]
309struct StageFile {
310    stage_file: String,
311    basename: String,
312}
313
314fn expand_local_glob(pattern: &str) -> Result<Vec<PathBuf>> {
315    let entries = glob::glob(pattern)?.collect::<std::result::Result<Vec<_>, _>>()?;
316    if entries.is_empty() {
317        return Err(Error::BadArgument(format!(
318            "No local files matched: {pattern}"
319        )));
320    }
321    Ok(entries)
322}
323
324fn ensure_local_destination_dir(path: &str) -> Result<()> {
325    let local_path = Path::new(path);
326    if local_path.exists() && !local_path.is_dir() {
327        return Err(Error::BadArgument(format!(
328            "Local destination is not a directory: {path}"
329        )));
330    }
331    Ok(())
332}
333
334fn normalize_stage_list_entry(
335    location: &StageLocation,
336    listed_name: &str,
337) -> Result<Option<StageFile>> {
338    let stage_prefix = format!("{}/", location.name);
339    let relative = if let Some(path) = listed_name.strip_prefix(&stage_prefix) {
340        if !location.path.is_empty() && !path.starts_with(&location.path) {
341            return Ok(None);
342        }
343        path
344    } else if location.path.is_empty() || listed_name.starts_with(&location.path) {
345        listed_name
346    } else {
347        return Ok(None);
348    };
349
350    let basename = Path::new(relative)
351        .file_name()
352        .and_then(|name| name.to_str())
353        .ok_or_else(|| Error::BadArgument(format!("Invalid stage file path: {listed_name}")))?
354        .to_string();
355    let stage_file = if listed_name.starts_with(&stage_prefix) {
356        format!("@{listed_name}")
357    } else {
358        format!("@{}/{}", location.name, relative)
359    };
360
361    Ok(Some(StageFile {
362        stage_file,
363        basename,
364    }))
365}
366
367fn ensure_unique_basenames(files: &[StageFile]) -> Result<()> {
368    let mut seen = BTreeSet::new();
369    for file in files {
370        if !seen.insert(file.basename.clone()) {
371            return Err(Error::BadArgument(format!(
372                "Duplicate local file basename in GET results: {}",
373                file.basename
374            )));
375        }
376    }
377    Ok(())
378}
379
380fn put_get_schema() -> Schema {
381    Schema::from_vec(vec![
382        Field {
383            name: "file".to_string(),
384            data_type: DataType::String,
385        },
386        Field {
387            name: "status".to_string(),
388            data_type: DataType::String,
389        },
390        Field {
391            name: "size".to_string(),
392            data_type: DataType::Number(NumberDataType::UInt64),
393        },
394    ])
395}
396
397fn validate_local_scheme(scheme: &str) -> Result<()> {
398    match scheme {
399        "file" | "fs" => Ok(()),
400        _ => Err(Error::BadArgument(
401            "Supported schemes: file:// or fs://".to_string(),
402        )),
403    }
404}