1use async_trait::async_trait;
16use log::info;
17use std::collections::{BTreeMap, VecDeque};
18use std::marker::PhantomData;
19use std::path::Path;
20use std::pin::Pin;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23use std::time::Instant;
24use tokio::fs::File;
25use tokio::io::BufReader;
26use tokio_stream::Stream;
27
28use crate::client::LoadMethod;
29use crate::conn::{ConnectionInfo, IConnection, Reader};
30use lake_client::schema::{Schema, SchemaRef};
31use lake_client::Pages;
32use lake_client::{APIClient, ResultFormatSettings};
33use lake_driver_core::error::{Error, Result};
34use lake_driver_core::raw_rows::{RawRow, RawRowIterator, RawRowWithStats};
35use lake_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, Rows, ServerStats};
36
/// Placeholder stage name that may appear in load SQL (`... from @_databend_load ...`).
///
/// `load_data` searches the statement for it to decide whether a `from` clause still
/// has to be appended, and `load_data_with_stage` swaps it for the temporary stage
/// location the data was uploaded to.
const LOAD_PLACEHOLDER: &str = "@_databend_load";
38
/// Connection that serves every request through the REST API client.
///
/// The client is held behind an `Arc`, so a clone of the connection shares the
/// same underlying [`APIClient`].
#[derive(Clone)]
pub struct RestAPIConnection {
    /// Shared API client that every operation on this connection delegates to.
    client: Arc<APIClient>,
}
43
44impl RestAPIConnection {
45 fn gen_temp_stage_location(&self) -> Result<String> {
46 let now = chrono::Utc::now()
47 .timestamp_nanos_opt()
48 .ok_or_else(|| Error::IO("Failed to get current timestamp".to_string()))?;
49 Ok(format!("@~/client/load/{now}"))
50 }
51
52 async fn load_data_with_stage(
53 &self,
54 sql: &str,
55 data: Reader,
56 size: u64,
57 ) -> Result<ServerStats> {
58 let location = self.gen_temp_stage_location()?;
59 self.upload_to_stage(&location, data, size).await?;
60 if self.client.capability().streaming_load {
61 let sql = sql.replace(LOAD_PLACEHOLDER, &location);
62 let page = self.client.query_all(&sql, None).await?;
63 Ok(ServerStats::from(page.stats))
64 } else {
65 let file_format_options = Self::default_file_format_options();
66 let copy_options = Self::default_copy_options();
67 let stats = self
68 .client
69 .insert_with_stage(sql, &location, file_format_options, copy_options)
70 .await?;
71 Ok(ServerStats::from(stats))
72 }
73 }
74
75 async fn load_data_with_streaming(
76 &self,
77 sql: &str,
78 data: Reader,
79 size: u64,
80 ) -> Result<ServerStats> {
81 let start = Instant::now();
82 let response = self
83 .client
84 .streaming_load(sql, data, "<no_filename>")
85 .await?;
86 Ok(ServerStats {
87 total_rows: 0,
88 total_bytes: 0,
89 read_rows: response.stats.rows,
90 read_bytes: size as usize,
91 write_rows: response.stats.rows,
92 write_bytes: response.stats.bytes,
93 running_time_ms: start.elapsed().as_millis() as f64,
94 spill_file_nums: 0,
95 spill_bytes: 0,
96 })
97 }
98 async fn load_data_with_options(
99 &self,
100 sql: &str,
101 data: Reader,
102 size: u64,
103 file_format_options: Option<BTreeMap<&str, &str>>,
104 copy_options: Option<BTreeMap<&str, &str>>,
105 ) -> Result<ServerStats> {
106 let location = self.gen_temp_stage_location()?;
107 let file_format_options =
108 file_format_options.unwrap_or_else(Self::default_file_format_options);
109 let copy_options = copy_options.unwrap_or_else(Self::default_copy_options);
110 self.upload_to_stage(&location, Box::new(data), size)
111 .await?;
112 let stats = self
113 .client
114 .insert_with_stage(sql, &location, file_format_options, copy_options)
115 .await?;
116 Ok(ServerStats::from(stats))
117 }
118}
119
120#[async_trait]
121impl IConnection for RestAPIConnection {
122 async fn info(&self) -> ConnectionInfo {
123 ConnectionInfo {
124 handler: "RestAPI".to_string(),
125 host: self.client.host().to_string(),
126 port: self.client.port(),
127 user: self.client.username(),
128 catalog: self.client.current_catalog(),
129 database: self.client.current_database(),
130 warehouse: self.client.current_warehouse(),
131 }
132 }
133
134 fn last_query_id(&self) -> Option<String> {
135 self.client.last_query_id()
136 }
137
138 fn set_warehouse(&self, warehouse: &str) -> Result<()> {
139 self.client.set_warehouse(warehouse.to_string());
140 Ok(())
141 }
142
143 fn set_database(&self, database: &str) -> Result<()> {
144 self.client.set_database(database.to_string());
145 Ok(())
146 }
147
148 fn set_role(&self, role: &str) -> Result<()> {
149 self.client.set_role(role.to_string());
150 Ok(())
151 }
152
153 fn set_session(&self, key: &str, value: &str) -> Result<()> {
154 self.client.set_session(key.to_string(), value.to_string());
155 Ok(())
156 }
157
158 async fn close(&self) -> Result<()> {
159 self.client.close().await;
160 Ok(())
161 }
162
163 fn close_with_spawn(&self) -> Result<()> {
164 self.client.close_with_spawn();
165 Ok(())
166 }
167
168 fn supports_server_side_params(&self) -> bool {
169 self.client.capability().server_side_params
170 }
171
172 async fn exec(&self, sql: &str) -> Result<i64> {
173 info!("exec: {}", sql);
174 let page = self.client.query_all(sql, None).await?;
175 Ok(page.stats.progresses.write_progress.rows as i64)
176 }
177
178 async fn exec_with_params(&self, sql: &str, params: Option<serde_json::Value>) -> Result<i64> {
179 info!("exec with params: {}", sql);
180 let page = self.client.query_all(sql, params).await?;
181 Ok(page.stats.progresses.write_progress.rows as i64)
182 }
183
184 async fn kill_query(&self, query_id: &str) -> Result<()> {
185 Ok(self.client.kill_query(query_id).await?)
186 }
187
188 async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
189 info!("query iter: {}", sql);
190 let rows_with_progress = self.query_iter_ext(sql).await?;
191 let rows = rows_with_progress.filter_rows().await?;
192 Ok(rows)
193 }
194
195 async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
196 info!("query iter ext: {}", sql);
197 let pages = self.client.start_query(sql, true, None).await?;
198 let (schema, rows) = RestAPIRows::<RowWithStats>::from_pages(pages).await?;
199 Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
200 }
201
202 async fn query_iter_with_params(
203 &self,
204 sql: &str,
205 params: Option<serde_json::Value>,
206 ) -> Result<RowIterator> {
207 info!("query iter with params: {}", sql);
208 let rows_with_progress = self.query_iter_ext_with_params(sql, params).await?;
209 let rows = rows_with_progress.filter_rows().await?;
210 Ok(rows)
211 }
212
213 async fn query_iter_ext_with_params(
214 &self,
215 sql: &str,
216 params: Option<serde_json::Value>,
217 ) -> Result<RowStatsIterator> {
218 info!("query iter ext with params: {}", sql);
219 let pages = self.client.start_query(sql, true, params).await?;
220 let (schema, rows) = RestAPIRows::<RowWithStats>::from_pages(pages).await?;
221 Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
222 }
223
224 async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
226 info!("query raw iter: {}", sql);
227 let pages = self.client.start_query(sql, true, None).await?;
228 let (schema, rows) = RestAPIRows::<RawRowWithStats>::from_pages(pages).await?;
229 Ok(RawRowIterator::new(Arc::new(schema), Box::pin(rows)))
230 }
231
232 async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
233 self.client.upload_to_stage(stage, data, size).await?;
234 Ok(())
235 }
236
237 async fn load_data(
238 &self,
239 sql: &str,
240 data: Reader,
241 size: u64,
242 method: LoadMethod,
243 ) -> Result<ServerStats> {
244 let sql = sql.trim_end();
245 let sql = sql.trim_end_matches(';');
246 info!("load data: {}, size: {}, method: {method:?}", sql, size);
247 let sql_low = sql.to_lowercase();
248 let has_place_holder = sql_low.contains(LOAD_PLACEHOLDER);
249 let sql = match (self.client.capability().streaming_load, has_place_holder) {
250 (false, false) => {
251 return self
253 .load_data_with_options(sql, data, size, None, None)
254 .await;
255 }
256 (false, true) => return Err(Error::BadArgument(
257 "Please upgrade your server to >= 1.2.781 to support insert from @_databend_load"
258 .to_string(),
259 )),
260 (true, false) => {
261 format!("{sql} from @_databend_load file_format=(type=csv)")
262 }
263 (true, true) => sql.to_string(),
264 };
265
266 match method {
267 LoadMethod::Streaming => self.load_data_with_streaming(&sql, data, size).await,
268 LoadMethod::Stage => self.load_data_with_stage(&sql, data, size).await,
269 }
270 }
271
272 async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
273 info!("load file: {}, file: {:?}", sql, fp,);
274 let file = File::open(fp).await?;
275 let metadata = file.metadata().await?;
276 let size = metadata.len();
277 let data = BufReader::new(file);
278 self.load_data(sql, Box::new(data), size, method).await
279 }
280
281 async fn load_file_with_options(
282 &self,
283 sql: &str,
284 fp: &Path,
285 file_format_options: Option<BTreeMap<&str, &str>>,
286 copy_options: Option<BTreeMap<&str, &str>>,
287 ) -> Result<ServerStats> {
288 let file = File::open(fp).await?;
289 let metadata = file.metadata().await?;
290 let size = metadata.len();
291 let data = BufReader::new(file);
292 self.load_data_with_options(sql, Box::new(data), size, file_format_options, copy_options)
293 .await
294 }
295
296 async fn stream_load(
297 &self,
298 sql: &str,
299 data: Vec<Vec<&str>>,
300 method: LoadMethod,
301 ) -> Result<ServerStats> {
302 info!("stream load: {}; rows: {:?}", sql, data.len());
303 let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
304 for row in data {
305 wtr.write_record(row)
306 .map_err(|e| Error::BadArgument(e.to_string()))?;
307 }
308 let bytes = wtr.into_inner().map_err(|e| Error::IO(e.to_string()))?;
309 let size = bytes.len() as u64;
310 let reader = Box::new(std::io::Cursor::new(bytes));
311 let stats = if self.client.capability().streaming_load {
312 let sql = format!("{sql} from @_databend_load file_format = (type = csv)");
313 self.load_data(&sql, reader, size, method).await?
314 } else {
315 self.load_data_with_options(sql, reader, size, None, None)
316 .await?
317 };
318 Ok(stats)
319 }
320}
321
322impl<'o> RestAPIConnection {
323 pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
324 let client = APIClient::new(dsn, Some(name)).await?;
325 Ok(Self { client })
326 }
327
328 fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
329 vec![
330 ("type", "CSV"),
331 ("field_delimiter", ","),
332 ("record_delimiter", "\n"),
333 ("skip_header", "0"),
334 ]
335 .into_iter()
336 .collect()
337 }
338
339 fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
340 vec![("purge", "true")].into_iter().collect()
341 }
342}
343
/// Stream adapter that turns the pages of a running query into items of type `T`
/// (see the `Stream` implementation for the item order).
pub struct RestAPIRows<T> {
    /// Page stream of the running query.
    pages: Pages,

    /// Schema used to decode raw rows; replaced from the first page that carries
    /// one if it starts out without fields.
    schema: SchemaRef,
    /// Result-format settings handed to the row decoders.
    settings: ResultFormatSettings,

    /// Buffered raw rows (one optional string per column) from pages without batches.
    data: VecDeque<Vec<Option<String>>>,
    /// Buffered rows already decoded from page batches.
    rows: VecDeque<Row>,

    /// Stats to emit before anything else on the next poll.
    /// NOTE(review): initialised to `None` and only ever `take`n in the code visible
    /// here — confirm whether anything still sets it.
    stats: Option<ServerStats>,

    /// Marks the item type produced by the stream; no `T` is stored.
    _phantom: std::marker::PhantomData<T>,
}
357
358impl<T> RestAPIRows<T> {
359 async fn from_pages(pages: Pages) -> Result<(Schema, Self)> {
360 let (pages, schema, settings) = pages.wait_for_schema(true).await?;
361 let rows = Self {
362 pages,
363 schema: Arc::new(schema.clone()),
364 settings,
365 data: Default::default(),
366 rows: Default::default(),
367 stats: None,
368 _phantom: PhantomData,
369 };
370 Ok((schema, rows))
371 }
372}
373
374impl<T: FromRowStats + std::marker::Unpin> Stream for RestAPIRows<T> {
375 type Item = Result<T>;
376
377 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
378 if let Some(ss) = self.stats.take() {
379 return Poll::Ready(Some(Ok(T::from_stats(ss))));
380 }
381 if self.data.len() > 1 {
384 if let Some(row) = self.data.pop_front() {
385 let row = T::try_from_raw_row(row, self.schema.clone(), &self.settings)?;
386 return Poll::Ready(Some(Ok(row)));
387 }
388 } else if self.rows.len() > 1 {
389 if let Some(row) = self.rows.pop_front() {
390 let row = T::from_row(row);
391 return Poll::Ready(Some(Ok(row)));
392 }
393 }
394
395 match Pin::new(&mut self.pages).poll_next(cx) {
396 Poll::Ready(Some(Ok(page))) => {
397 if self.schema.fields().is_empty() {
398 if !page.raw_schema.is_empty() {
399 self.schema = Arc::new(page.raw_schema.try_into()?);
400 } else if !page.batches.is_empty() {
401 self.schema = Arc::new(page.batches[0].schema().clone().try_into()?);
402 }
403 }
404 if page.batches.is_empty() {
405 let mut new_data = page.data.into();
406 self.data.append(&mut new_data);
407 } else {
408 for batch in page.batches.into_iter() {
409 let rows = Rows::try_from((batch, self.settings.clone()))?;
410 self.rows.extend(rows);
411 }
412 }
413 Poll::Ready(Some(Ok(T::from_stats(page.stats.into()))))
414 }
415 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
416 Poll::Ready(None) => {
417 if let Some(row) = self.rows.pop_front() {
418 let row = T::from_row(row);
419 Poll::Ready(Some(Ok(row)))
420 } else if let Some(row) = self.data.pop_front() {
421 let row = T::try_from_raw_row(row, self.schema.clone(), &self.settings)?;
422 Poll::Ready(Some(Ok(row)))
423 } else {
424 Poll::Ready(None)
425 }
426 }
427 Poll::Pending => Poll::Pending,
428 }
429 }
430}
431
/// Conversions that let `RestAPIRows<T>` produce its item type `T` from each kind
/// of input it handles: server stats, raw string rows, and decoded rows.
trait FromRowStats: Send + Sync + Clone {
    /// Wraps server stats as an item.
    fn from_stats(stats: ServerStats) -> Self;
    /// Decodes one raw row (one optional string per column) against `schema`.
    /// Despite its name, `tz` carries the full result-format settings.
    fn try_from_raw_row(
        row: Vec<Option<String>>,
        schema: SchemaRef,
        tz: &ResultFormatSettings,
    ) -> Result<Self>;
    /// Wraps an already decoded row as an item.
    fn from_row(row: Row) -> Self;
}
441
442impl FromRowStats for RowWithStats {
443 fn from_stats(stats: ServerStats) -> Self {
444 RowWithStats::Stats(stats)
445 }
446
447 fn try_from_raw_row(
448 row: Vec<Option<String>>,
449 schema: SchemaRef,
450 settings: &ResultFormatSettings,
451 ) -> Result<Self> {
452 Ok(RowWithStats::Row(Row::try_from((schema, row, settings))?))
453 }
454 fn from_row(row: Row) -> Self {
455 RowWithStats::Row(row)
456 }
457}
458
459impl FromRowStats for RawRowWithStats {
460 fn from_stats(stats: ServerStats) -> Self {
461 RawRowWithStats::Stats(stats)
462 }
463
464 fn try_from_raw_row(
465 row: Vec<Option<String>>,
466 schema: SchemaRef,
467 tz: &ResultFormatSettings,
468 ) -> Result<Self> {
469 let rows = Row::try_from((schema, row.clone(), tz))?;
470 Ok(RawRowWithStats::Row(RawRow::new(rows, row)))
471 }
472
473 fn from_row(row: Row) -> Self {
474 RawRowWithStats::Row(RawRow::from(row))
475 }
476}