Skip to main content

bigbytes_driver/
conn.rs

1// Copyright 2024 Digitrans Inc
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::path::Path;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use once_cell::sync::Lazy;
21use tokio::fs::File;
22use tokio::io::AsyncRead;
23use tokio::io::BufReader;
24use tokio_stream::StreamExt;
25use url::Url;
26
27#[cfg(feature = "flight-sql")]
28use crate::flight_sql::FlightSQLConnection;
29
30use databend_client::StageLocation;
31use databend_client::{presign_download_from_stage, PresignedResponse};
32use bigbytes_driver_core::error::{Error, Result};
33use bigbytes_driver_core::raw_rows::{RawRow, RawRowIterator};
34use bigbytes_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
35use bigbytes_driver_core::schema::{DataType, Field, NumberDataType, Schema};
36use bigbytes_driver_core::value::{NumberValue, Value};
37
38use crate::rest_api::RestAPIConnection;
39
40static VERSION: Lazy<String> = Lazy::new(|| {
41    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
42    version.to_string()
43});
44
45#[derive(Clone)]
46pub struct Client {
47    dsn: String,
48    name: String,
49}
50
51impl Client {
52    pub fn new(dsn: String) -> Self {
53        let name = format!("bigbytes-driver-rust/{}", VERSION.as_str());
54        Self { dsn, name }
55    }
56
57    pub fn with_name(mut self, name: String) -> Self {
58        self.name = name;
59        self
60    }
61
62    pub async fn get_conn(&self) -> Result<Box<dyn Connection>> {
63        let u = Url::parse(&self.dsn)?;
64        match u.scheme() {
65            "databend" | "databend+http" | "databend+https" => {
66                let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?;
67                Ok(Box::new(conn))
68            }
69            #[cfg(feature = "flight-sql")]
70            "databend+flight" | "databend+grpc" => {
71                let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?;
72                Ok(Box::new(conn))
73            }
74            _ => Err(Error::Parsing(format!(
75                "Unsupported scheme: {}",
76                u.scheme()
77            ))),
78        }
79    }
80}
81
/// Descriptive information about an open connection, as returned by
/// `Connection::info`.
///
/// Derives `Debug` and `Clone` so callers can log the value and keep copies
/// of it; none of the fields holds a credential.
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
    // NOTE(review): presumably names the protocol handler in use — confirm
    // against the implementors of `Connection::info`.
    pub handler: String,
    /// Server host name.
    pub host: String,
    /// Server port.
    pub port: u16,
    /// User name the connection was opened with.
    pub user: String,
    /// Database name, if one is set.
    pub database: Option<String>,
    /// Warehouse name, if one is set.
    pub warehouse: Option<String>,
}
90
91pub type Reader = Box<dyn AsyncRead + Send + Sync + Unpin + 'static>;
92
93#[async_trait]
94pub trait Connection: Send + Sync {
95    async fn info(&self) -> ConnectionInfo;
96    async fn close(&self) -> Result<()> {
97        Ok(())
98    }
99
100    fn last_query_id(&self) -> Option<String>;
101
102    async fn version(&self) -> Result<String> {
103        let row = self.query_row("SELECT version()").await?;
104        let version = match row {
105            Some(row) => {
106                let (version,): (String,) = row.try_into().map_err(Error::Parsing)?;
107                version
108            }
109            None => "".to_string(),
110        };
111        Ok(version)
112    }
113
114    async fn exec(&self, sql: &str) -> Result<i64>;
115    async fn kill_query(&self, query_id: &str) -> Result<()>;
116    async fn query_iter(&self, sql: &str) -> Result<RowIterator>;
117    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator>;
118
119    async fn query_row(&self, sql: &str) -> Result<Option<Row>> {
120        let rows = self.query_all(sql).await?;
121        let row = rows.into_iter().next();
122        Ok(row)
123    }
124
125    async fn query_all(&self, sql: &str) -> Result<Vec<Row>> {
126        let rows = self.query_iter(sql).await?;
127        rows.collect().await
128    }
129
130    // raw data response query, only for test
131    async fn query_raw_iter(&self, _sql: &str) -> Result<RawRowIterator> {
132        Err(Error::BadArgument(
133            "Unsupported implement query_raw_iter".to_string(),
134        ))
135    }
136
137    // raw data response query, only for test
138    async fn query_raw_all(&self, sql: &str) -> Result<Vec<RawRow>> {
139        let rows = self.query_raw_iter(sql).await?;
140        rows.collect().await
141    }
142
143    /// Get presigned url for a given operation and stage location.
144    /// The operation can be "UPLOAD" or "DOWNLOAD".
145    async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse>;
146
147    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()>;
148
149    async fn load_data(
150        &self,
151        sql: &str,
152        data: Reader,
153        size: u64,
154        file_format_options: Option<BTreeMap<&str, &str>>,
155        copy_options: Option<BTreeMap<&str, &str>>,
156    ) -> Result<ServerStats>;
157
158    async fn load_file(
159        &self,
160        sql: &str,
161        fp: &Path,
162        format_options: Option<BTreeMap<&str, &str>>,
163        copy_options: Option<BTreeMap<&str, &str>>,
164    ) -> Result<ServerStats>;
165
166    async fn stream_load(&self, sql: &str, data: Vec<Vec<&str>>) -> Result<ServerStats>;
167
168    // PUT file://<path_to_file>/<filename> internalStage|externalStage
169    async fn put_files(&self, local_file: &str, stage: &str) -> Result<RowStatsIterator> {
170        let mut total_count: usize = 0;
171        let mut total_size: usize = 0;
172        let local_dsn = url::Url::parse(local_file)?;
173        validate_local_scheme(local_dsn.scheme())?;
174        let mut results = Vec::new();
175        let stage_location = StageLocation::try_from(stage)?;
176        let schema = Arc::new(put_get_schema());
177        for entry in glob::glob(local_dsn.path())? {
178            let entry = entry?;
179            let filename = entry
180                .file_name()
181                .ok_or_else(|| Error::BadArgument(format!("Invalid local file path: {:?}", entry)))?
182                .to_str()
183                .ok_or_else(|| {
184                    Error::BadArgument(format!("Invalid local file path: {:?}", entry))
185                })?;
186            let stage_file = stage_location.file_path(filename);
187            let file = File::open(&entry).await?;
188            let size = file.metadata().await?.len();
189            let data = BufReader::new(file);
190            let (fname, status) = match self
191                .upload_to_stage(&stage_file, Box::new(data), size)
192                .await
193            {
194                Ok(_) => {
195                    total_count += 1;
196                    total_size += size as usize;
197                    (entry.to_string_lossy().to_string(), "SUCCESS".to_owned())
198                }
199                Err(e) => (entry.to_string_lossy().to_string(), e.to_string()),
200            };
201            let ss = ServerStats {
202                write_rows: total_count,
203                write_bytes: total_size,
204
205                ..Default::default()
206            };
207            results.push(Ok(RowWithStats::Stats(ss)));
208            results.push(Ok(RowWithStats::Row(Row::from_vec(
209                schema.clone(),
210                vec![
211                    Value::String(fname),
212                    Value::String(status),
213                    Value::Number(NumberValue::UInt64(size)),
214                ],
215            ))));
216        }
217        Ok(RowStatsIterator::new(
218            schema,
219            Box::pin(tokio_stream::iter(results)),
220        ))
221    }
222
223    async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
224        let mut total_count: usize = 0;
225        let mut total_size: usize = 0;
226        let local_dsn = url::Url::parse(local_file)?;
227        validate_local_scheme(local_dsn.scheme())?;
228        let mut location = StageLocation::try_from(stage)?;
229        if !location.path.ends_with('/') {
230            location.path.push('/');
231        }
232        let list_sql = format!("LIST {}", location);
233        let mut response = self.query_iter(&list_sql).await?;
234        let mut results = Vec::new();
235        let schema = Arc::new(put_get_schema());
236        while let Some(row) = response.next().await {
237            let (mut name, _, _, _, _): (String, u64, Option<String>, String, Option<String>) =
238                row?.try_into().map_err(Error::Parsing)?;
239            if !location.path.is_empty() && name.starts_with(&location.path) {
240                name = name[location.path.len()..].to_string();
241            }
242            let stage_file = format!("{}/{}", location, name);
243            let presign = self.get_presigned_url("DOWNLOAD", &stage_file).await?;
244            let local_file = Path::new(local_dsn.path()).join(&name);
245            let status = presign_download_from_stage(presign, &local_file).await;
246            let (status, size) = match status {
247                Ok(size) => {
248                    total_count += 1;
249                    total_size += size as usize;
250                    ("SUCCESS".to_owned(), size)
251                }
252                Err(e) => (e.to_string(), 0),
253            };
254            let ss = ServerStats {
255                read_rows: total_count,
256                read_bytes: total_size,
257                ..Default::default()
258            };
259            results.push(Ok(RowWithStats::Stats(ss)));
260            results.push(Ok(RowWithStats::Row(Row::from_vec(
261                schema.clone(),
262                vec![
263                    Value::String(local_file.to_string_lossy().to_string()),
264                    Value::String(status),
265                    Value::Number(NumberValue::UInt64(size)),
266                ],
267            ))));
268        }
269        Ok(RowStatsIterator::new(
270            schema,
271            Box::pin(tokio_stream::iter(results)),
272        ))
273    }
274}
275
276fn put_get_schema() -> Schema {
277    Schema::from_vec(vec![
278        Field {
279            name: "file".to_string(),
280            data_type: DataType::String,
281        },
282        Field {
283            name: "status".to_string(),
284            data_type: DataType::String,
285        },
286        Field {
287            name: "size".to_string(),
288            data_type: DataType::Number(NumberDataType::UInt64),
289        },
290    ])
291}
292
293fn validate_local_scheme(scheme: &str) -> Result<()> {
294    match scheme {
295        "file" | "fs" => Ok(()),
296        _ => Err(Error::BadArgument(
297            "Supported schemes: file:// or fs://".to_string(),
298        )),
299    }
300}