databend_driver/
conn.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::path::Path;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use log::info;
21use tokio::fs::File;
22use tokio::io::AsyncRead;
23use tokio::io::BufReader;
24use tokio_stream::StreamExt;
25
26use crate::client::LoadMethod;
27use databend_client::schema::{DataType, Field, NumberDataType, Schema};
28use databend_client::StageLocation;
29use databend_client::{presign_download_from_stage, PresignedResponse};
30use databend_driver_core::error::{Error, Result};
31use databend_driver_core::raw_rows::{RawRow, RawRowIterator};
32use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
33use databend_driver_core::value::{NumberValue, Value};
34
/// Descriptive information about a connection, as returned by `IConnection::info`.
///
/// All fields are plain owned values, so the struct derives the usual value-type
/// traits and can be logged, cloned and compared by callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionInfo {
    /// Name of the handler backing the connection.
    pub handler: String,
    /// Host part of the connection target.
    pub host: String,
    /// Port part of the connection target.
    pub port: u16,
    /// User name associated with the connection.
    pub user: String,
    /// Catalog, if one is set.
    pub catalog: Option<String>,
    /// Database, if one is set.
    pub database: Option<String>,
    /// Warehouse, if one is set.
    pub warehouse: Option<String>,
}
44
45pub type Reader = Box<dyn AsyncRead + Send + Sync + Unpin + 'static>;
46
47#[async_trait]
48pub trait IConnection: Send + Sync {
49    async fn info(&self) -> ConnectionInfo;
50    async fn close(&self) -> Result<()> {
51        Ok(())
52    }
53
54    fn close_with_spawn(&self) -> Result<()> {
55        Ok(())
56    }
57
58    fn last_query_id(&self) -> Option<String>;
59
60    async fn version(&self) -> Result<String> {
61        let row = self.query_row("SELECT version()").await?;
62        let version = match row {
63            Some(row) => {
64                let (version,): (String,) = row.try_into().map_err(Error::Parsing)?;
65                version
66            }
67            None => "".to_string(),
68        };
69        Ok(version)
70    }
71
72    async fn exec(&self, sql: &str) -> Result<i64>;
73    async fn kill_query(&self, query_id: &str) -> Result<()>;
74    async fn query_iter(&self, sql: &str) -> Result<RowIterator>;
75    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator>;
76
77    async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
78        let rows = self.query_all(sql).await?;
79        let row = rows.into_iter().next();
80        Ok(row)
81    }
82
83    async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
84        let rows = self.query_iter(sql).await?;
85        rows.collect().await
86    }
87
88    // raw data response query, only for test
89    async fn query_raw_iter(&self, _sql: &str) -> Result<RawRowIterator> {
90        Err(Error::BadArgument(
91            "Unsupported implement query_raw_iter".to_string(),
92        ))
93    }
94
95    // raw data response query, only for test
96    async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
97        let rows = self.query_raw_iter(sql).await?;
98        rows.collect().await
99    }
100
101    /// Get presigned url for a given operation and stage location.
102    /// The operation can be "UPLOAD" or "DOWNLOAD".
103    async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
104        info!("get presigned url: {} {}", operation, stage);
105        let sql = format!("PRESIGN {operation} {stage}");
106        let row = self.query_row(&sql).await?.ok_or_else(|| {
107            Error::InvalidResponse("Empty response from server for presigned request".to_string())
108        })?;
109        let (method, headers, url): (String, String, String) =
110            row.try_into().map_err(Error::Parsing)?;
111        let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
112        Ok(PresignedResponse {
113            method,
114            headers,
115            url,
116        })
117    }
118
119    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()>;
120
121    async fn load_data(
122        &self,
123        sql: &str,
124        data: Reader,
125        size: u64,
126        method: LoadMethod,
127    ) -> Result<ServerStats>;
128
129    async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats>;
130
131    async fn load_file_with_options(
132        &self,
133        sql: &str,
134        fp: &Path,
135        file_format_options: Option<BTreeMap<&str, &str>>,
136        copy_options: Option<BTreeMap<&str, &str>>,
137    ) -> Result<ServerStats>;
138
139    async fn stream_load(
140        &self,
141        sql: &str,
142        data: Vec<Vec<&str>>,
143        _method: LoadMethod,
144    ) -> Result<ServerStats>;
145
146    fn set_warehouse(&self, warehouse: &str) -> Result<()>;
147
148    fn set_database(&self, database: &str) -> Result<()>;
149
150    fn set_role(&self, role: &str) -> Result<()>;
151
152    fn set_session(&self, key: &str, value: &str) -> Result<()>;
153
154    // PUT file://<path_to_file>/<filename> internalStage|externalStage
155    async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
156        let mut total_count: usize = 0;
157        let mut total_size: usize = 0;
158        let local_dsn = url::Url::parse(local_file)?;
159        validate_local_scheme(local_dsn.scheme())?;
160        let mut results = Vec::new();
161        let stage_location = StageLocation::try_from(stage)?;
162        let schema = Arc::new(put_get_schema());
163        for entry in glob::glob(local_dsn.path())? {
164            let entry = entry?;
165            let filename = entry
166                .file_name()
167                .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?
168                .to_str()
169                .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {entry:?}")))?;
170            let stage_file = stage_location.file_path(filename);
171            let file = File::open(&entry).await?;
172            let size = file.metadata().await?.len();
173            let data = BufReader::new(file);
174            let (fname, status) = match self
175                .upload_to_stage(&stage_file, Box::new(data), size)
176                .await
177            {
178                Ok(_) => {
179                    total_count += 1;
180                    total_size += size as usize;
181                    (entry.to_string_lossy().to_string(), "SUCCESS".to_owned())
182                }
183                Err(e) => (entry.to_string_lossy().to_string(), e.to_string()),
184            };
185            let ss = ServerStats {
186                write_rows: total_count,
187                write_bytes: total_size,
188
189                ..Default::default()
190            };
191            results.push(Ok(RowWithStats::Stats(ss)));
192            results.push(Ok(RowWithStats::Row(Row::from_vec(
193                schema.clone(),
194                vec![
195                    Value::String(fname),
196                    Value::String(status),
197                    Value::Number(NumberValue::UInt64(size)),
198                ],
199            ))));
200        }
201        Ok(RowStatsIterator::new(
202            schema,
203            Box::pin(tokio_stream::iter(results)),
204        ))
205    }
206
207    async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
208        let mut total_count: usize = 0;
209        let mut total_size: usize = 0;
210        let local_dsn = url::Url::parse(local_file)?;
211        validate_local_scheme(local_dsn.scheme())?;
212        let mut location = StageLocation::try_from(stage)?;
213        if !location.path.ends_with('/') {
214            location.path.push('/');
215        }
216        let list_sql = format!("LIST {location}");
217        let mut response = self.query_iter(&list_sql).await?;
218        let mut results = Vec::new();
219        let schema = Arc::new(put_get_schema());
220        while let Some(row) = response.next().await {
221            let (mut name, _, _, _, _): (String, u64, Option<String>, String, Option<String>) =
222                row?.try_into().map_err(Error::Parsing)?;
223            if !location.path.is_empty() && name.starts_with(&location.path) {
224                name = name[location.path.len()..].to_string();
225            }
226            let stage_file = format!("{location}/{name}");
227            let presign = self.get_presigned_url("DOWNLOAD", &stage_file).await?;
228            let local_file = Path::new(local_dsn.path()).join(&name);
229            let status = presign_download_from_stage(presign, &local_file).await;
230            let (status, size) = match status {
231                Ok(size) => {
232                    total_count += 1;
233                    total_size += size as usize;
234                    ("SUCCESS".to_owned(), size)
235                }
236                Err(e) => (e.to_string(), 0),
237            };
238            let ss = ServerStats {
239                read_rows: total_count,
240                read_bytes: total_size,
241                ..Default::default()
242            };
243            results.push(Ok(RowWithStats::Stats(ss)));
244            results.push(Ok(RowWithStats::Row(Row::from_vec(
245                schema.clone(),
246                vec![
247                    Value::String(local_file.to_string_lossy().to_string()),
248                    Value::String(status),
249                    Value::Number(NumberValue::UInt64(size)),
250                ],
251            ))));
252        }
253        Ok(RowStatsIterator::new(
254            schema,
255            Box::pin(tokio_stream::iter(results)),
256        ))
257    }
258}
259
260fn put_get_schema() -> Schema {
261    Schema::from_vec(vec![
262        Field {
263            name: "file".to_string(),
264            data_type: DataType::String,
265        },
266        Field {
267            name: "status".to_string(),
268            data_type: DataType::String,
269        },
270        Field {
271            name: "size".to_string(),
272            data_type: DataType::Number(NumberDataType::UInt64),
273        },
274    ])
275}
276
277fn validate_local_scheme(scheme: &str) -> Result<()> {
278    match scheme {
279        "file" | "fs" => Ok(()),
280        _ => Err(Error::BadArgument(
281            "Supported schemes: file:// or fs://".to_string(),
282        )),
283    }
284}