databend_driver/
conn.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::path::Path;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use log::info;
21use tokio::fs::File;
22use tokio::io::AsyncRead;
23use tokio::io::BufReader;
24use tokio_stream::StreamExt;
25
26use crate::client::LoadMethod;
27use databend_client::StageLocation;
28use databend_client::{presign_download_from_stage, PresignedResponse};
29use databend_driver_core::error::{Error, Result};
30use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
31use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
32use databend_driver_core::schema::{DataType, Field, NumberDataType, Schema};
33use databend_driver_core::value::{NumberValue, Value};
34
/// Descriptive metadata about an open connection, as returned by
/// [`IConnection::info`].
pub struct ConnectionInfo {
    /// Name of the handler backing this connection.
    pub handler: String,
    /// Host the connection targets.
    pub host: String,
    /// Port the connection targets.
    pub port: u16,
    /// User the connection was opened as.
    pub user: String,
    /// Catalog associated with the connection, if any.
    pub catalog: Option<String>,
    /// Database associated with the connection, if any.
    pub database: Option<String>,
    /// Warehouse associated with the connection, if any.
    pub warehouse: Option<String>,
}
44
45pub type Reader = Box<dyn AsyncRead + Send + Sync + Unpin + 'static>;
46
47#[async_trait]
48pub trait IConnection: Send + Sync {
49    async fn info(&self) -> ConnectionInfo;
50    async fn close(&self) -> Result<()> {
51        Ok(())
52    }
53
54    fn close_with_spawn(&self) -> Result<()> {
55        Ok(())
56    }
57
58    fn last_query_id(&self) -> Option<String>;
59
60    async fn version(&self) -> Result<String> {
61        let row = self.query_row("SELECT version()").await?;
62        let version = match row {
63            Some(row) => {
64                let (version,): (String,) = row.try_into().map_err(Error::Parsing)?;
65                version
66            }
67            None => "".to_string(),
68        };
69        Ok(version)
70    }
71
72    async fn exec(&self, sql: &str) -> Result<i64>;
73    async fn kill_query(&self, query_id: &str) -> Result<()>;
74    async fn query_iter(&self, sql: &str) -> Result<RowIterator>;
75    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator>;
76
77    async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
78        let rows = self.query_all(sql).await?;
79        let row = rows.into_iter().next();
80        Ok(row)
81    }
82
83    async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
84        let rows = self.query_iter(sql).await?;
85        rows.collect().await
86    }
87
88    // raw data response query, only for test
89    async fn query_raw_iter(&self, _sql: &str) -> Result<RawRowIterator> {
90        Err(Error::BadArgument(
91            "Unsupported implement query_raw_iter".to_string(),
92        ))
93    }
94
95    // raw data response query, only for test
96    async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
97        let rows = self.query_raw_iter(sql).await?;
98        rows.collect().await
99    }
100
101    /// Get presigned url for a given operation and stage location.
102    /// The operation can be "UPLOAD" or "DOWNLOAD".
103    async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
104        info!("get presigned url: {} {}", operation, stage);
105        let sql = format!("PRESIGN {operation} {stage}");
106        let row = self.query_row(&sql).await?.ok_or_else(|| {
107            Error::InvalidResponse("Empty response from server for presigned request".to_string())
108        })?;
109        let (method, headers, url): (String, String, String) =
110            row.try_into().map_err(Error::Parsing)?;
111        let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
112        Ok(PresignedResponse {
113            method,
114            headers,
115            url,
116        })
117    }
118
119    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()>;
120
121    async fn load_data(
122        &self,
123        sql: &str,
124        data: Reader,
125        size: u64,
126        method: LoadMethod,
127    ) -> Result<ServerStats>;
128
129    async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats>;
130
131    async fn load_file_with_options(
132        &self,
133        sql: &str,
134        fp: &Path,
135        file_format_options: Option<BTreeMap<&str, &str>>,
136        copy_options: Option<BTreeMap<&str, &str>>,
137    ) -> Result<ServerStats>;
138
139    async fn stream_load(
140        &self,
141        sql: &str,
142        data: Vec<Vec<&str>>,
143        _method: LoadMethod,
144    ) -> Result<ServerStats>;
145
146    // PUT file://<path_to_file>/<filename> internalStage|externalStage
147    async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
148        let mut total_count: usize = 0;
149        let mut total_size: usize = 0;
150        let local_dsn = url::Url::parse(local_file)?;
151        validate_local_scheme(local_dsn.scheme())?;
152        let mut results = Vec::new();
153        let stage_location = StageLocation::try_from(stage)?;
154        let schema = Arc::new(put_get_schema());
155        for entry in glob::glob(local_dsn.path())? {
156            let entry = entry?;
157            let filename = entry
158                .file_name()
159                .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?
160                .to_str()
161                .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?;
162            let stage_file = stage_location.file_path(filename);
163            let file = File::open(&entry).await?;
164            let size = file.metadata().await?.len();
165            let data = BufReader::new(file);
166            let (fname, status) = match self
167                .upload_to_stage(&stage_file, Box::new(data), size)
168                .await
169            {
170                Ok(_) => {
171                    total_count += 1;
172                    total_size += size as usize;
173                    (entry.to_string_lossy().to_string(), "SUCCESS".to_owned())
174                }
175                Err(e) => (entry.to_string_lossy().to_string(), e.to_string()),
176            };
177            let ss = ServerStats {
178                write_rows: total_count,
179                write_bytes: total_size,
180
181                ..Default::default()
182            };
183            results.push(Ok(RowWithStats::Stats(ss)));
184            results.push(Ok(RowWithStats::Row(Row::from_vec(
185                schema.clone(),
186                vec![
187                    Value::String(fname),
188                    Value::String(status),
189                    Value::Number(NumberValue::UInt64(size)),
190                ],
191            ))));
192        }
193        Ok(RowStatsIterator::new(
194            schema,
195            Box::pin(tokio_stream::iter(results)),
196        ))
197    }
198
199    async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
200        let mut total_count: usize = 0;
201        let mut total_size: usize = 0;
202        let local_dsn = url::Url::parse(local_file)?;
203        validate_local_scheme(local_dsn.scheme())?;
204        let mut location = StageLocation::try_from(stage)?;
205        if !location.path.ends_with('/') {
206            location.path.push('/');
207        }
208        let list_sql = format!("LIST {location}");
209        let mut response = self.query_iter(&list_sql).await?;
210        let mut results = Vec::new();
211        let schema = Arc::new(put_get_schema());
212        while let Some(row) = response.next().await {
213            let (mut name, _, _, _, _): (String, u64, Option<String>, String, Option<String>) =
214                row?.try_into().map_err(Error::Parsing)?;
215            if !location.path.is_empty() && name.starts_with(&location.path) {
216                name = name[location.path.len()..].to_string();
217            }
218            let stage_file = format!("{location}/{name}");
219            let presign = self.get_presigned_url("DOWNLOAD", &stage_file).await?;
220            let local_file = Path::new(local_dsn.path()).join(&name);
221            let status = presign_download_from_stage(presign, &local_file).await;
222            let (status, size) = match status {
223                Ok(size) => {
224                    total_count += 1;
225                    total_size += size as usize;
226                    ("SUCCESS".to_owned(), size)
227                }
228                Err(e) => (e.to_string(), 0),
229            };
230            let ss = ServerStats {
231                read_rows: total_count,
232                read_bytes: total_size,
233                ..Default::default()
234            };
235            results.push(Ok(RowWithStats::Stats(ss)));
236            results.push(Ok(RowWithStats::Row(Row::from_vec(
237                schema.clone(),
238                vec![
239                    Value::String(local_file.to_string_lossy().to_string()),
240                    Value::String(status),
241                    Value::Number(NumberValue::UInt64(size)),
242                ],
243            ))));
244        }
245        Ok(RowStatsIterator::new(
246            schema,
247            Box::pin(tokio_stream::iter(results)),
248        ))
249    }
250}
251
252fn put_get_schema() -> Schema {
253    Schema::from_vec(vec![
254        Field {
255            name: "file".to_string(),
256            data_type: DataType::String,
257        },
258        Field {
259            name: "status".to_string(),
260            data_type: DataType::String,
261        },
262        Field {
263            name: "size".to_string(),
264            data_type: DataType::Number(NumberDataType::UInt64),
265        },
266    ])
267}
268
269fn validate_local_scheme(scheme: &str) -> Result<()> {
270    match scheme {
271        "file" | "fs" => Ok(()),
272        _ => Err(Error::BadArgument(
273            "Supported schemes: file:// or fs://".to_string(),
274        )),
275    }
276}