databend_driver/
conn.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::path::Path;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use log::info;
21use tokio::fs::File;
22use tokio::io::AsyncRead;
23use tokio::io::BufReader;
24use tokio_stream::StreamExt;
25
26use crate::client::LoadMethod;
27use databend_client::StageLocation;
28use databend_client::{presign_download_from_stage, PresignedResponse};
29use databend_driver_core::error::{Error, Result};
30use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
31use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
32use databend_driver_core::schema::{DataType, Field, NumberDataType, Schema};
33use databend_driver_core::value::{NumberValue, Value};
34
/// Descriptive metadata about a connection, as returned by `IConnection::info`.
///
/// Derives `Debug` so it can be logged/inspected, and `Clone` so callers can keep a copy
/// independent of the connection that produced it.
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
    /// Name of the handler backing this connection (values are implementation-defined).
    pub handler: String,
    /// Host the connection targets.
    pub host: String,
    /// Port the connection targets.
    pub port: u16,
    /// User name of the connection.
    pub user: String,
    /// Catalog associated with the connection, if any.
    pub catalog: Option<String>,
    /// Database associated with the connection, if any.
    pub database: Option<String>,
    /// Warehouse associated with the connection, if any.
    pub warehouse: Option<String>,
}
44
45pub type Reader = Box<dyn AsyncRead + Send + Sync + Unpin + 'static>;
46
47#[async_trait]
48pub trait IConnection: Send + Sync {
49    async fn info(&self) -> ConnectionInfo;
50    async fn close(&self) -> Result<()> {
51        Ok(())
52    }
53
54    fn last_query_id(&self) -> Option<String>;
55
56    async fn version(&self) -> Result<String> {
57        let row = self.query_row("SELECT version()").await?;
58        let version = match row {
59            Some(row) => {
60                let (version,): (String,) = row.try_into().map_err(Error::Parsing)?;
61                version
62            }
63            None => "".to_string(),
64        };
65        Ok(version)
66    }
67
68    async fn exec(&self, sql: &str) -> Result<i64>;
69    async fn kill_query(&self, query_id: &str) -> Result<()>;
70    async fn query_iter(&self, sql: &str) -> Result<RowIterator>;
71    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator>;
72
73    async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
74        let rows = self.query_all(sql).await?;
75        let row = rows.into_iter().next();
76        Ok(row)
77    }
78
79    async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
80        let rows = self.query_iter(sql).await?;
81        rows.collect().await
82    }
83
84    // raw data response query, only for test
85    async fn query_raw_iter(&self, _sql: &str) -> Result<RawRowIterator> {
86        Err(Error::BadArgument(
87            "Unsupported implement query_raw_iter".to_string(),
88        ))
89    }
90
91    // raw data response query, only for test
92    async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
93        let rows = self.query_raw_iter(sql).await?;
94        rows.collect().await
95    }
96
97    /// Get presigned url for a given operation and stage location.
98    /// The operation can be "UPLOAD" or "DOWNLOAD".
99    async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
100        info!("get presigned url: {} {}", operation, stage);
101        let sql = format!("PRESIGN {operation} {stage}");
102        let row = self.query_row(&sql).await?.ok_or_else(|| {
103            Error::InvalidResponse("Empty response from server for presigned request".to_string())
104        })?;
105        let (method, headers, url): (String, String, String) =
106            row.try_into().map_err(Error::Parsing)?;
107        let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
108        Ok(PresignedResponse {
109            method,
110            headers,
111            url,
112        })
113    }
114
115    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()>;
116
117    async fn load_data(
118        &self,
119        sql: &str,
120        data: Reader,
121        size: u64,
122        method: LoadMethod,
123    ) -> Result<ServerStats>;
124
125    async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats>;
126
127    async fn load_file_with_options(
128        &self,
129        sql: &str,
130        fp: &Path,
131        file_format_options: Option<BTreeMap<&str, &str>>,
132        copy_options: Option<BTreeMap<&str, &str>>,
133    ) -> Result<ServerStats>;
134
135    async fn stream_load(
136        &self,
137        sql: &str,
138        data: Vec<Vec<&str>>,
139        _method: LoadMethod,
140    ) -> Result<ServerStats>;
141
142    // PUT file://<path_to_file>/<filename> internalStage|externalStage
143    async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
144        let mut total_count: usize = 0;
145        let mut total_size: usize = 0;
146        let local_dsn = url::Url::parse(local_file)?;
147        validate_local_scheme(local_dsn.scheme())?;
148        let mut results = Vec::new();
149        let stage_location = StageLocation::try_from(stage)?;
150        let schema = Arc::new(put_get_schema());
151        for entry in glob::glob(local_dsn.path())? {
152            let entry = entry?;
153            let filename = entry
154                .file_name()
155                .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?
156                .to_str()
157                .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?;
158            let stage_file = stage_location.file_path(filename);
159            let file = File::open(&entry).await?;
160            let size = file.metadata().await?.len();
161            let data = BufReader::new(file);
162            let (fname, status) = match self
163                .upload_to_stage(&stage_file, Box::new(data), size)
164                .await
165            {
166                Ok(_) => {
167                    total_count += 1;
168                    total_size += size as usize;
169                    (entry.to_string_lossy().to_string(), "SUCCESS".to_owned())
170                }
171                Err(e) => (entry.to_string_lossy().to_string(), e.to_string()),
172            };
173            let ss = ServerStats {
174                write_rows: total_count,
175                write_bytes: total_size,
176
177                ..Default::default()
178            };
179            results.push(Ok(RowWithStats::Stats(ss)));
180            results.push(Ok(RowWithStats::Row(Row::from_vec(
181                schema.clone(),
182                vec![
183                    Value::String(fname),
184                    Value::String(status),
185                    Value::Number(NumberValue::UInt64(size)),
186                ],
187            ))));
188        }
189        Ok(RowStatsIterator::new(
190            schema,
191            Box::pin(tokio_stream::iter(results)),
192        ))
193    }
194
195    async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
196        let mut total_count: usize = 0;
197        let mut total_size: usize = 0;
198        let local_dsn = url::Url::parse(local_file)?;
199        validate_local_scheme(local_dsn.scheme())?;
200        let mut location = StageLocation::try_from(stage)?;
201        if !location.path.ends_with('/') {
202            location.path.push('/');
203        }
204        let list_sql = format!("LIST {location}");
205        let mut response = self.query_iter(&list_sql).await?;
206        let mut results = Vec::new();
207        let schema = Arc::new(put_get_schema());
208        while let Some(row) = response.next().await {
209            let (mut name, _, _, _, _): (String, u64, Option<String>, String, Option<String>) =
210                row?.try_into().map_err(Error::Parsing)?;
211            if !location.path.is_empty() && name.starts_with(&location.path) {
212                name = name[location.path.len()..].to_string();
213            }
214            let stage_file = format!("{location}/{name}");
215            let presign = self.get_presigned_url("DOWNLOAD", &stage_file).await?;
216            let local_file = Path::new(local_dsn.path()).join(&name);
217            let status = presign_download_from_stage(presign, &local_file).await;
218            let (status, size) = match status {
219                Ok(size) => {
220                    total_count += 1;
221                    total_size += size as usize;
222                    ("SUCCESS".to_owned(), size)
223                }
224                Err(e) => (e.to_string(), 0),
225            };
226            let ss = ServerStats {
227                read_rows: total_count,
228                read_bytes: total_size,
229                ..Default::default()
230            };
231            results.push(Ok(RowWithStats::Stats(ss)));
232            results.push(Ok(RowWithStats::Row(Row::from_vec(
233                schema.clone(),
234                vec![
235                    Value::String(local_file.to_string_lossy().to_string()),
236                    Value::String(status),
237                    Value::Number(NumberValue::UInt64(size)),
238                ],
239            ))));
240        }
241        Ok(RowStatsIterator::new(
242            schema,
243            Box::pin(tokio_stream::iter(results)),
244        ))
245    }
246}
247
248fn put_get_schema() -> Schema {
249    Schema::from_vec(vec![
250        Field {
251            name: "file".to_string(),
252            data_type: DataType::String,
253        },
254        Field {
255            name: "status".to_string(),
256            data_type: DataType::String,
257        },
258        Field {
259            name: "size".to_string(),
260            data_type: DataType::Number(NumberDataType::UInt64),
261        },
262    ])
263}
264
265fn validate_local_scheme(scheme: &str) -> Result<()> {
266    match scheme {
267        "file" | "fs" => Ok(()),
268        _ => Err(Error::BadArgument(
269            "Supported schemes: file:// or fs://".to_string(),
270        )),
271    }
272}