1use std::collections::{BTreeMap, VecDeque};
16use std::future::Future;
17use std::marker::PhantomData;
18use std::path::Path;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
23use async_trait::async_trait;
24use log::info;
25use tokio::fs::File;
26use tokio::io::BufReader;
27use tokio_stream::Stream;
28
29use databend_client::PresignedResponse;
30use databend_client::QueryResponse;
31use databend_client::{APIClient, SchemaField};
32use bigbytes_driver_core::error::{Error, Result};
33use bigbytes_driver_core::raw_rows::{RawRow, RawRowIterator, RawRowWithStats};
34use bigbytes_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
35use bigbytes_driver_core::schema::{Schema, SchemaRef};
36
37use crate::conn::{Connection, ConnectionInfo, Reader};
38
39#[derive(Clone)]
40pub struct RestAPIConnection {
41 client: Arc<APIClient>,
42}
43
44#[async_trait]
45impl Connection for RestAPIConnection {
46 async fn info(&self) -> ConnectionInfo {
47 ConnectionInfo {
48 handler: "RestAPI".to_string(),
49 host: self.client.host().to_string(),
50 port: self.client.port(),
51 user: self.client.username(),
52 database: self.client.current_database(),
53 warehouse: self.client.current_warehouse(),
54 }
55 }
56
57 fn last_query_id(&self) -> Option<String> {
58 self.client.last_query_id()
59 }
60
61 async fn close(&self) -> Result<()> {
62 self.client.close().await;
63 Ok(())
64 }
65
66 async fn exec(&self, sql: &str) -> Result<i64> {
67 info!("exec: {}", sql);
68 let mut resp = self.client.start_query(sql).await?;
69 let node_id = resp.node_id.clone();
70 while let Some(next_uri) = resp.next_uri {
71 resp = self
72 .client
73 .query_page(&resp.id, &next_uri, &node_id)
74 .await?;
75 }
76 Ok(resp.stats.progresses.write_progress.rows as i64)
77 }
78
79 async fn kill_query(&self, query_id: &str) -> Result<()> {
80 Ok(self.client.kill_query(query_id).await?)
81 }
82
83 async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
84 info!("query iter: {}", sql);
85 let rows_with_progress = self.query_iter_ext(sql).await?;
86 let rows = rows_with_progress.filter_rows().await;
87 Ok(rows)
88 }
89
90 async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
91 info!("query iter ext: {}", sql);
92 let resp = self.client.start_query(sql).await?;
93 let resp = self.wait_for_schema(resp, true).await?;
94 let (schema, rows) = RestAPIRows::<RowWithStats>::from_response(self.client.clone(), resp)?;
95 Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
96 }
97
98 async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
100 info!("query raw iter: {}", sql);
101 let resp = self.client.start_query(sql).await?;
102 let resp = self.wait_for_schema(resp, true).await?;
103 let (schema, rows) =
104 RestAPIRows::<RawRowWithStats>::from_response(self.client.clone(), resp)?;
105 Ok(RawRowIterator::new(Arc::new(schema), Box::pin(rows)))
106 }
107
108 async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
109 info!("get presigned url: {} {}", operation, stage);
110 let sql = format!("PRESIGN {} {}", operation, stage);
111 let row = self.query_row(&sql).await?.ok_or_else(|| {
112 Error::InvalidResponse("Empty response from server for presigned request".to_string())
113 })?;
114 let (method, headers, url): (String, String, String) =
115 row.try_into().map_err(Error::Parsing)?;
116 let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
117 Ok(PresignedResponse {
118 method,
119 headers,
120 url,
121 })
122 }
123
124 async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
125 self.client.upload_to_stage(stage, data, size).await?;
126 Ok(())
127 }
128
129 async fn load_data(
130 &self,
131 sql: &str,
132 data: Reader,
133 size: u64,
134 file_format_options: Option<BTreeMap<&str, &str>>,
135 copy_options: Option<BTreeMap<&str, &str>>,
136 ) -> Result<ServerStats> {
137 info!(
138 "load data: {}, size: {}, format: {:?}, copy: {:?}",
139 sql, size, file_format_options, copy_options
140 );
141 let now = chrono::Utc::now()
142 .timestamp_nanos_opt()
143 .ok_or_else(|| Error::IO("Failed to get current timestamp".to_string()))?;
144 let stage = format!("@~/client/load/{}", now);
145
146 let file_format_options =
147 file_format_options.unwrap_or_else(Self::default_file_format_options);
148 let copy_options = copy_options.unwrap_or_else(Self::default_copy_options);
149
150 self.upload_to_stage(&stage, data, size).await?;
151 let resp = self
152 .client
153 .insert_with_stage(sql, &stage, file_format_options, copy_options)
154 .await?;
155 Ok(ServerStats::from(resp.stats))
156 }
157
158 async fn load_file(
159 &self,
160 sql: &str,
161 fp: &Path,
162 format_options: Option<BTreeMap<&str, &str>>,
163 copy_options: Option<BTreeMap<&str, &str>>,
164 ) -> Result<ServerStats> {
165 info!(
166 "load file: {}, file: {:?}, format: {:?}, copy: {:?}",
167 sql, fp, format_options, copy_options
168 );
169 let file = File::open(fp).await?;
170 let metadata = file.metadata().await?;
171 let size = metadata.len();
172 let data = BufReader::new(file);
173 let mut format_options = format_options.unwrap_or_else(Self::default_file_format_options);
174 if !format_options.contains_key("type") {
175 let file_type = fp
176 .extension()
177 .ok_or_else(|| Error::BadArgument("file type not specified".to_string()))?
178 .to_str()
179 .ok_or_else(|| Error::BadArgument("file type empty".to_string()))?;
180 format_options.insert("type", file_type);
181 }
182 self.load_data(
183 sql,
184 Box::new(data),
185 size,
186 Some(format_options),
187 copy_options,
188 )
189 .await
190 }
191
192 async fn stream_load(&self, sql: &str, data: Vec<Vec<&str>>) -> Result<ServerStats> {
193 info!("stream load: {}, length: {:?}", sql, data.len());
194 let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
195 for row in data {
196 wtr.write_record(row)
197 .map_err(|e| Error::BadArgument(e.to_string()))?;
198 }
199 let bytes = wtr.into_inner().map_err(|e| Error::IO(e.to_string()))?;
200 let size = bytes.len() as u64;
201 let reader = Box::new(std::io::Cursor::new(bytes));
202 let stats = self.load_data(sql, reader, size, None, None).await?;
203 Ok(stats)
204 }
205}
206
207impl<'o> RestAPIConnection {
208 pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
209 let client = APIClient::new(dsn, Some(name)).await?;
210 Ok(Self {
211 client: Arc::new(client),
212 })
213 }
214
215 async fn wait_for_schema(
216 &self,
217 resp: QueryResponse,
218 return_on_progress: bool,
219 ) -> Result<QueryResponse> {
220 if !resp.data.is_empty()
221 || !resp.schema.is_empty()
222 || (return_on_progress && resp.stats.progresses.has_progress())
223 {
224 return Ok(resp);
225 }
226 let node_id = resp.node_id.clone();
227 if let Some(node_id) = &node_id {
228 self.client.set_last_node_id(node_id.clone());
229 }
230 let mut result = resp;
231 while let Some(next_uri) = result.next_uri {
233 result = self
234 .client
235 .query_page(&result.id, &next_uri, &node_id)
236 .await?;
237
238 if !result.data.is_empty()
239 || !result.schema.is_empty()
240 || (return_on_progress && result.stats.progresses.has_progress())
241 {
242 break;
243 }
244 }
245 Ok(result)
246 }
247
248 fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
249 vec![
250 ("type", "CSV"),
251 ("field_delimiter", ","),
252 ("record_delimiter", "\n"),
253 ("skip_header", "0"),
254 ]
255 .into_iter()
256 .collect()
257 }
258
259 fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
260 vec![("purge", "true")].into_iter().collect()
261 }
262
263 pub async fn query_row_batch(&self, sql: &str) -> Result<RowBatch> {
264 let resp = self.client.start_query(sql).await?;
265 let resp = self.wait_for_schema(resp, false).await?;
266 RowBatch::from_response(self.client.clone(), resp)
267 }
268}
269
270type PageFut = Pin<Box<dyn Future<Output = Result<QueryResponse>> + Send>>;
271
272pub struct RestAPIRows<T> {
273 client: Arc<APIClient>,
274 schema: SchemaRef,
275 data: VecDeque<Vec<Option<String>>>,
276 stats: Option<ServerStats>,
277 query_id: String,
278 node_id: Option<String>,
279 next_uri: Option<String>,
280 next_page: Option<PageFut>,
281 _phantom: std::marker::PhantomData<T>,
282}
283
284impl<T> RestAPIRows<T> {
285 fn from_response(client: Arc<APIClient>, resp: QueryResponse) -> Result<(Schema, Self)> {
286 let schema: Schema = resp.schema.try_into()?;
287 let rows = Self {
288 client,
289 query_id: resp.id,
290 node_id: resp.node_id,
291 next_uri: resp.next_uri,
292 schema: Arc::new(schema.clone()),
293 data: resp.data.into(),
294 stats: Some(ServerStats::from(resp.stats)),
295 next_page: None,
296 _phantom: PhantomData,
297 };
298 Ok((schema, rows))
299 }
300}
301
302impl<T: FromRowStats + std::marker::Unpin> Stream for RestAPIRows<T> {
303 type Item = Result<T>;
304
305 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
306 if let Some(ss) = self.stats.take() {
307 return Poll::Ready(Some(Ok(T::from_stats(ss))));
308 }
309 if self.data.len() > 1 {
312 if let Some(row) = self.data.pop_front() {
313 let row = T::try_from_row(row, self.schema.clone())?;
314 return Poll::Ready(Some(Ok(row)));
315 }
316 }
317 match self.next_page {
318 Some(ref mut next_page) => match Pin::new(next_page).poll(cx) {
319 Poll::Ready(Ok(resp)) => {
320 if self.schema.fields().is_empty() {
321 self.schema = Arc::new(resp.schema.try_into()?);
322 }
323 self.next_uri = resp.next_uri;
324 self.next_page = None;
325 let mut new_data = resp.data.into();
326 self.data.append(&mut new_data);
327 Poll::Ready(Some(Ok(T::from_stats(resp.stats.into()))))
328 }
329 Poll::Ready(Err(e)) => {
330 self.next_page = None;
331 Poll::Ready(Some(Err(e)))
332 }
333 Poll::Pending => Poll::Pending,
334 },
335 None => match self.next_uri {
336 Some(ref next_uri) => {
337 let client = self.client.clone();
338 let next_uri = next_uri.clone();
339 let query_id = self.query_id.clone();
340 let node_id = self.node_id.clone();
341 self.next_page = Some(Box::pin(async move {
342 client
343 .query_page(&query_id, &next_uri, &node_id)
344 .await
345 .map_err(|e| e.into())
346 }));
347 self.poll_next(cx)
348 }
349 None => match self.data.pop_front() {
350 Some(row) => {
351 let row = T::try_from_row(row, self.schema.clone())?;
352 Poll::Ready(Some(Ok(row)))
353 }
354 None => Poll::Ready(None),
355 },
356 },
357 }
358 }
359}
360
361trait FromRowStats: Send + Sync + Clone {
362 fn from_stats(stats: ServerStats) -> Self;
363 fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self>;
364}
365
366impl FromRowStats for RowWithStats {
367 fn from_stats(stats: ServerStats) -> Self {
368 RowWithStats::Stats(stats)
369 }
370
371 fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self> {
372 Ok(RowWithStats::Row(Row::try_from((schema, row))?))
373 }
374}
375
376impl FromRowStats for RawRowWithStats {
377 fn from_stats(stats: ServerStats) -> Self {
378 RawRowWithStats::Stats(stats)
379 }
380
381 fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self> {
382 let rows = Row::try_from((schema, row.clone()))?;
383 Ok(RawRowWithStats::Row(RawRow::new(rows, row)))
384 }
385}
386
387pub struct RowBatch {
388 schema: Vec<SchemaField>,
389 client: Arc<APIClient>,
390 query_id: String,
391 node_id: Option<String>,
392
393 next_uri: Option<String>,
394 data: Vec<Vec<Option<String>>>,
395}
396
397impl RowBatch {
398 pub fn schema(&self) -> Vec<SchemaField> {
399 self.schema.clone()
400 }
401
402 fn from_response(client: Arc<APIClient>, mut resp: QueryResponse) -> Result<Self> {
403 Ok(Self {
404 schema: std::mem::take(&mut resp.schema),
405 client,
406 query_id: resp.id,
407 node_id: resp.node_id,
408 next_uri: resp.next_uri,
409 data: resp.data,
410 })
411 }
412
413 pub async fn fetch_next_page(&mut self) -> Result<Vec<Vec<Option<String>>>> {
414 if !self.data.is_empty() {
415 return Ok(std::mem::take(&mut self.data));
416 }
417 while let Some(next_uri) = &self.next_uri {
418 let resp = self
419 .client
420 .query_page(&self.query_id, next_uri, &self.node_id)
421 .await?;
422
423 self.next_uri = resp.next_uri;
424 if !resp.data.is_empty() {
425 return Ok(resp.data);
426 }
427 }
428 Ok(vec![])
429 }
430}