1use std::collections::{BTreeMap, VecDeque};
16use std::future::Future;
17use std::marker::PhantomData;
18use std::path::Path;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
23use async_trait::async_trait;
24use log::info;
25use tokio::fs::File;
26use tokio::io::BufReader;
27use tokio_stream::Stream;
28
29use databend_client::PresignedResponse;
30use databend_client::QueryResponse;
31use databend_client::{APIClient, SchemaField};
32use databend_driver_core::error::{Error, Result};
33use databend_driver_core::raw_rows::{RawRow, RawRowIterator, RawRowWithStats};
34use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
35use databend_driver_core::schema::{Schema, SchemaRef};
36
37use crate::conn::{ConnectionInfo, IConnection, Reader};
38
39#[derive(Clone)]
40pub struct RestAPIConnection {
41 client: Arc<APIClient>,
42}
43
44#[async_trait]
45impl IConnection for RestAPIConnection {
46 async fn info(&self) -> ConnectionInfo {
47 ConnectionInfo {
48 handler: "RestAPI".to_string(),
49 host: self.client.host().to_string(),
50 port: self.client.port(),
51 user: self.client.username(),
52 database: self.client.current_database(),
53 warehouse: self.client.current_warehouse(),
54 }
55 }
56
57 fn last_query_id(&self) -> Option<String> {
58 self.client.last_query_id()
59 }
60
61 async fn close(&self) -> Result<()> {
62 self.client.close().await;
63 Ok(())
64 }
65
66 async fn exec(&self, sql: &str) -> Result<i64> {
67 info!("exec: {}", sql);
68 let mut resp = self.client.start_query(sql).await?;
69 let node_id = resp.node_id.clone();
70 if let Some(node_id) = &node_id {
71 self.client.set_last_node_id(node_id.clone());
72 }
73 while let Some(next_uri) = resp.next_uri {
74 resp = self
75 .client
76 .query_page(&resp.id, &next_uri, &node_id)
77 .await?;
78 }
79 Ok(resp.stats.progresses.write_progress.rows as i64)
80 }
81
82 async fn kill_query(&self, query_id: &str) -> Result<()> {
83 Ok(self.client.kill_query(query_id).await?)
84 }
85
86 async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
87 info!("query iter: {}", sql);
88 let rows_with_progress = self.query_iter_ext(sql).await?;
89 let rows = rows_with_progress.filter_rows().await;
90 Ok(rows)
91 }
92
93 async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
94 info!("query iter ext: {}", sql);
95 let resp = self.client.start_query(sql).await?;
96 let resp = self.wait_for_schema(resp, true).await?;
97 let (schema, rows) = RestAPIRows::<RowWithStats>::from_response(self.client.clone(), resp)?;
98 Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
99 }
100
101 async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
103 info!("query raw iter: {}", sql);
104 let resp = self.client.start_query(sql).await?;
105 let resp = self.wait_for_schema(resp, true).await?;
106 let (schema, rows) =
107 RestAPIRows::<RawRowWithStats>::from_response(self.client.clone(), resp)?;
108 Ok(RawRowIterator::new(Arc::new(schema), Box::pin(rows)))
109 }
110
111 async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
112 info!("get presigned url: {} {}", operation, stage);
113 let sql = format!("PRESIGN {} {}", operation, stage);
114 let row = self.query_row(&sql).await?.ok_or_else(|| {
115 Error::InvalidResponse("Empty response from server for presigned request".to_string())
116 })?;
117 let (method, headers, url): (String, String, String) =
118 row.try_into().map_err(Error::Parsing)?;
119 let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
120 Ok(PresignedResponse {
121 method,
122 headers,
123 url,
124 })
125 }
126
127 async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
128 self.client.upload_to_stage(stage, data, size).await?;
129 Ok(())
130 }
131
132 async fn load_data(
133 &self,
134 sql: &str,
135 data: Reader,
136 size: u64,
137 file_format_options: Option<BTreeMap<&str, &str>>,
138 copy_options: Option<BTreeMap<&str, &str>>,
139 ) -> Result<ServerStats> {
140 info!(
141 "load data: {}, size: {}, format: {:?}, copy: {:?}",
142 sql, size, file_format_options, copy_options
143 );
144 let now = chrono::Utc::now()
145 .timestamp_nanos_opt()
146 .ok_or_else(|| Error::IO("Failed to get current timestamp".to_string()))?;
147 let stage = format!("@~/client/load/{}", now);
148
149 let file_format_options =
150 file_format_options.unwrap_or_else(Self::default_file_format_options);
151 let copy_options = copy_options.unwrap_or_else(Self::default_copy_options);
152
153 self.upload_to_stage(&stage, data, size).await?;
154 let resp = self
155 .client
156 .insert_with_stage(sql, &stage, file_format_options, copy_options)
157 .await?;
158 Ok(ServerStats::from(resp.stats))
159 }
160
161 async fn load_file(
162 &self,
163 sql: &str,
164 fp: &Path,
165 format_options: Option<BTreeMap<&str, &str>>,
166 copy_options: Option<BTreeMap<&str, &str>>,
167 ) -> Result<ServerStats> {
168 info!(
169 "load file: {}, file: {:?}, format: {:?}, copy: {:?}",
170 sql, fp, format_options, copy_options
171 );
172 let file = File::open(fp).await?;
173 let metadata = file.metadata().await?;
174 let size = metadata.len();
175 let data = BufReader::new(file);
176 let mut format_options = format_options.unwrap_or_else(Self::default_file_format_options);
177 if !format_options.contains_key("type") {
178 let file_type = fp
179 .extension()
180 .ok_or_else(|| Error::BadArgument("file type not specified".to_string()))?
181 .to_str()
182 .ok_or_else(|| Error::BadArgument("file type empty".to_string()))?;
183 format_options.insert("type", file_type);
184 }
185 self.load_data(
186 sql,
187 Box::new(data),
188 size,
189 Some(format_options),
190 copy_options,
191 )
192 .await
193 }
194
195 async fn stream_load(&self, sql: &str, data: Vec<Vec<&str>>) -> Result<ServerStats> {
196 info!("stream load: {}, length: {:?}", sql, data.len());
197 let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
198 for row in data {
199 wtr.write_record(row)
200 .map_err(|e| Error::BadArgument(e.to_string()))?;
201 }
202 let bytes = wtr.into_inner().map_err(|e| Error::IO(e.to_string()))?;
203 let size = bytes.len() as u64;
204 let reader = Box::new(std::io::Cursor::new(bytes));
205 let stats = self.load_data(sql, reader, size, None, None).await?;
206 Ok(stats)
207 }
208}
209
210impl<'o> RestAPIConnection {
211 pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
212 let client = APIClient::new(dsn, Some(name)).await?;
213 Ok(Self {
214 client: Arc::new(client),
215 })
216 }
217
218 async fn wait_for_schema(
219 &self,
220 resp: QueryResponse,
221 return_on_progress: bool,
222 ) -> Result<QueryResponse> {
223 let node_id = resp.node_id.clone();
224 if let Some(node_id) = &node_id {
225 self.client.set_last_node_id(node_id.clone());
226 }
227 if !resp.data.is_empty()
228 || !resp.schema.is_empty()
229 || (return_on_progress && resp.stats.progresses.has_progress())
230 {
231 return Ok(resp);
232 }
233 let mut result = resp;
234 while let Some(next_uri) = result.next_uri {
236 result = self
237 .client
238 .query_page(&result.id, &next_uri, &node_id)
239 .await?;
240
241 if !result.data.is_empty()
242 || !result.schema.is_empty()
243 || (return_on_progress && result.stats.progresses.has_progress())
244 {
245 break;
246 }
247 }
248 Ok(result)
249 }
250
251 fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
252 vec![
253 ("type", "CSV"),
254 ("field_delimiter", ","),
255 ("record_delimiter", "\n"),
256 ("skip_header", "0"),
257 ]
258 .into_iter()
259 .collect()
260 }
261
262 fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
263 vec![("purge", "true")].into_iter().collect()
264 }
265
266 pub async fn query_row_batch(&self, sql: &str) -> Result<RowBatch> {
267 let resp = self.client.start_query(sql).await?;
268 let resp = self.wait_for_schema(resp, false).await?;
269 RowBatch::from_response(self.client.clone(), resp)
270 }
271}
272
273type PageFut = Pin<Box<dyn Future<Output = Result<QueryResponse>> + Send>>;
274
275pub struct RestAPIRows<T> {
276 client: Arc<APIClient>,
277 schema: SchemaRef,
278 data: VecDeque<Vec<Option<String>>>,
279 stats: Option<ServerStats>,
280 query_id: String,
281 node_id: Option<String>,
282 next_uri: Option<String>,
283 next_page: Option<PageFut>,
284 _phantom: std::marker::PhantomData<T>,
285}
286
287impl<T> RestAPIRows<T> {
288 fn from_response(client: Arc<APIClient>, resp: QueryResponse) -> Result<(Schema, Self)> {
289 let schema: Schema = resp.schema.try_into()?;
290 let rows = Self {
291 client,
292 query_id: resp.id,
293 node_id: resp.node_id,
294 next_uri: resp.next_uri,
295 schema: Arc::new(schema.clone()),
296 data: resp.data.into(),
297 stats: Some(ServerStats::from(resp.stats)),
298 next_page: None,
299 _phantom: PhantomData,
300 };
301 Ok((schema, rows))
302 }
303}
304
305impl<T: FromRowStats + std::marker::Unpin> Stream for RestAPIRows<T> {
306 type Item = Result<T>;
307
308 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
309 if let Some(ss) = self.stats.take() {
310 return Poll::Ready(Some(Ok(T::from_stats(ss))));
311 }
312 if self.data.len() > 1 {
315 if let Some(row) = self.data.pop_front() {
316 let row = T::try_from_row(row, self.schema.clone())?;
317 return Poll::Ready(Some(Ok(row)));
318 }
319 }
320 match self.next_page {
321 Some(ref mut next_page) => match Pin::new(next_page).poll(cx) {
322 Poll::Ready(Ok(resp)) => {
323 if self.schema.fields().is_empty() {
324 self.schema = Arc::new(resp.schema.try_into()?);
325 }
326 self.next_uri = resp.next_uri;
327 self.next_page = None;
328 let mut new_data = resp.data.into();
329 self.data.append(&mut new_data);
330 Poll::Ready(Some(Ok(T::from_stats(resp.stats.into()))))
331 }
332 Poll::Ready(Err(e)) => {
333 self.next_page = None;
334 Poll::Ready(Some(Err(e)))
335 }
336 Poll::Pending => Poll::Pending,
337 },
338 None => match self.next_uri {
339 Some(ref next_uri) => {
340 let client = self.client.clone();
341 let next_uri = next_uri.clone();
342 let query_id = self.query_id.clone();
343 let node_id = self.node_id.clone();
344 self.next_page = Some(Box::pin(async move {
345 client
346 .query_page(&query_id, &next_uri, &node_id)
347 .await
348 .map_err(|e| e.into())
349 }));
350 self.poll_next(cx)
351 }
352 None => match self.data.pop_front() {
353 Some(row) => {
354 let row = T::try_from_row(row, self.schema.clone())?;
355 Poll::Ready(Some(Ok(row)))
356 }
357 None => Poll::Ready(None),
358 },
359 },
360 }
361 }
362}
363
364trait FromRowStats: Send + Sync + Clone {
365 fn from_stats(stats: ServerStats) -> Self;
366 fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self>;
367}
368
369impl FromRowStats for RowWithStats {
370 fn from_stats(stats: ServerStats) -> Self {
371 RowWithStats::Stats(stats)
372 }
373
374 fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self> {
375 Ok(RowWithStats::Row(Row::try_from((schema, row))?))
376 }
377}
378
379impl FromRowStats for RawRowWithStats {
380 fn from_stats(stats: ServerStats) -> Self {
381 RawRowWithStats::Stats(stats)
382 }
383
384 fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self> {
385 let rows = Row::try_from((schema, row.clone()))?;
386 Ok(RawRowWithStats::Row(RawRow::new(rows, row)))
387 }
388}
389
390pub struct RowBatch {
391 schema: Vec<SchemaField>,
392 client: Arc<APIClient>,
393 query_id: String,
394 node_id: Option<String>,
395
396 next_uri: Option<String>,
397 data: Vec<Vec<Option<String>>>,
398}
399
400impl RowBatch {
401 pub fn schema(&self) -> Vec<SchemaField> {
402 self.schema.clone()
403 }
404
405 fn from_response(client: Arc<APIClient>, mut resp: QueryResponse) -> Result<Self> {
406 Ok(Self {
407 schema: std::mem::take(&mut resp.schema),
408 client,
409 query_id: resp.id,
410 node_id: resp.node_id,
411 next_uri: resp.next_uri,
412 data: resp.data,
413 })
414 }
415
416 pub async fn fetch_next_page(&mut self) -> Result<Vec<Vec<Option<String>>>> {
417 if !self.data.is_empty() {
418 return Ok(std::mem::take(&mut self.data));
419 }
420 while let Some(next_uri) = &self.next_uri {
421 let resp = self
422 .client
423 .query_page(&self.query_id, next_uri, &self.node_id)
424 .await?;
425
426 self.next_uri = resp.next_uri;
427 if !resp.data.is_empty() {
428 return Ok(resp.data);
429 }
430 }
431 Ok(vec![])
432 }
433}