1use std::{
2 cmp::min,
3 fmt::Display,
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 task::Poll,
8};
9
10use async_trait::async_trait;
11use futures::{AsyncBufRead, AsyncRead, Future, FutureExt};
12use log::{debug, trace};
13use pin_project::pin_project;
14use reqwest::multipart::Part;
15use serde::{Deserialize, Serialize};
16use thiserror::Error;
17use tokio::io::AsyncWriteExt;
18
/// Number of bytes moved per DBFS request: the `length` of each `read` issued by
/// `read_file`/`download_file`, the size of each `add-block` upload, and the threshold
/// at or above which `write_file` switches from a single `put` to create/add-block/close.
///
/// 1 MiB — presumably chosen to match the DBFS per-request limits; confirm against the
/// API reference before raising it.
const CHUNK_SIZE: usize = 1 << 20;
20
/// Extension trait that lets a value record itself in the log while being passed along,
/// so it can sit in the middle of a `?` chain.
trait Logged {
    /// Logs `self` as appropriate for the implementing type and returns it unchanged.
    fn log(self) -> Self;
}
25
26impl<T, E> Logged for std::result::Result<T, E>
27where
28 E: std::fmt::Debug,
29{
30 fn log(self) -> Self {
31 if let Err(e) = &self {
32 log::debug!("---TraceError--- {:#?}", e)
33 }
34 self
35 }
36}
37
38#[async_trait]
39trait LoggedResponse {
40 async fn detailed_error_for_status(self) -> Result<Self>
41 where
42 Self: Sized;
43}
44
45#[async_trait]
46impl LoggedResponse for reqwest::Response {
47 async fn detailed_error_for_status(self) -> Result<Self> {
48 #[derive(Debug, Deserialize)]
49 struct DbfsErrorResponse {
50 error_code: DbfsErrorCode,
51 message: String,
52 }
53
54 if self.status().is_client_error() || self.status().is_server_error() {
55 let url = self.url().to_string();
56 let status = self.status().to_string();
57 let text = self.text().await?;
58 Err(match serde_json::from_str::<DbfsErrorResponse>(&text) {
59 Ok(resp) => DbfsError::DbfsApiError(resp.error_code, resp.message),
60 Err(_) => DbfsError::HttpError(url, status, text),
61 })
62 } else {
63 Ok(self)
64 }
65 }
66}
67
68#[derive(Debug, Serialize, Deserialize)]
72#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
73pub enum DbfsErrorCode {
74 ResourceAlreadyExists,
75 MaxBlockSizeExceeded,
76 InvalidParameterValue,
77 MaxReadSizeExceeded,
78 ResourceDoesNotExist,
79}
80
81impl Display for DbfsErrorCode {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.write_str(
84 &serde_json::to_string(&self)
85 .unwrap()
86 .strip_prefix("\"")
87 .unwrap()
88 .strip_suffix("\"")
89 .unwrap(),
90 )
91 }
92}
93
94#[derive(Debug, Error)]
98pub enum DbfsError {
99 #[error(transparent)]
100 ReqwestError(#[from] reqwest::Error),
101
102 #[error("HTTP Error, URL: '{0}', Status: {1}, Response: '{2}' ")]
103 HttpError(String, String, String),
104
105 #[error(transparent)]
106 DecodeError(#[from] base64::DecodeError),
107
108 #[error(transparent)]
109 IoError(#[from] std::io::Error),
110
111 #[error(transparent)]
112 VarError(#[from] std::env::VarError),
113
114 #[error(transparent)]
115 JsonError(#[from] serde_json::Error),
116
117 #[error("DBFS Error, Code: {0}, message: {0}")]
118 DbfsApiError(DbfsErrorCode, String),
119
120 #[error("Invalid DBFS Path {0}")]
121 InvalidDbfsPath(String),
122}
123
/// Versions of the DBFS REST API this client can target.
///
/// The variant name mirrors the URL segment its `Display` impl produces
/// (`API_2_0` → `api/2.0`), which is why it is not camel case.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum DbfsApiVersions {
    /// API version 2.0.
    API_2_0,
}
132
133impl Default for DbfsApiVersions {
134 fn default() -> Self {
135 Self::API_2_0
136 }
137}
138
139impl Display for DbfsApiVersions {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 write!(
142 f,
143 "{}",
144 match &self {
145 DbfsApiVersions::API_2_0 => "api/2.0",
146 }
147 )
148 }
149}
150
151pub type Result<T> = std::result::Result<T, DbfsError>;
155
156#[derive(Debug, Deserialize)]
160pub struct FileStatus {
161 pub path: String,
162 pub is_dir: bool,
163 pub file_size: usize,
164 pub modification_time: u64,
165}
166
167#[derive(Clone, Debug)]
171pub struct DbfsClient {
172 inner: Arc<DbfsClientInner>,
173}
174
175impl DbfsClient {
176 pub fn new(url_base: &str, token: &str) -> Self {
183 Self {
184 inner: Arc::new(DbfsClientInner::new(url_base, token)),
185 }
186 }
187
188 pub fn read(&self, path: &str) -> Result<DbfsReadStream> {
192 let path = strip_dbfs_prefix(path)?.to_string();
193 let inner = self.inner.clone();
194 Ok(DbfsReadStream {
195 reader: inner.clone(),
196 path: path.clone(),
197 step: ReadStreamSteps::Len,
198 file_size: 0,
199 file_offset: 0,
200 current_buf: vec![],
201 current_buf_offset: 0,
202 len_future: Box::pin(async move {
203 inner
204 .get_status(&path)
205 .map(|r| {
206 r.map(|s| s.file_size)
207 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
208 })
209 .await
210 }),
211 current_future: None,
212 })
213 }
214
215 pub async fn read_file(&self, path: &str) -> Result<Vec<u8>> {
219 let path = strip_dbfs_prefix(path)?;
220 debug!("Reading DBFS file {}", path);
221 let file_size = self.inner.get_status(path).await?.file_size;
222 debug!("File size is {}", file_size);
223 let mut ret = Vec::with_capacity(file_size);
224 let mut offset = 0;
225 loop {
226 let data = self.inner.read_block(path, offset, CHUNK_SIZE).await?;
227 offset += data.len();
228 ret.extend(data.into_iter());
229 if offset >= file_size {
230 break;
231 }
232 }
233 Ok(ret)
234 }
235
236 pub async fn write_file<T>(&self, path: &str, data: T) -> Result<()>
240 where
241 T: AsRef<[u8]>,
242 {
243 debug!(
244 "Writing {} bytes to DBFS file {}",
245 data.as_ref().len(),
246 path
247 );
248 let path = strip_dbfs_prefix(path)?;
249 if data.as_ref().len() < CHUNK_SIZE {
250 return self.inner.put(path, data, true).await;
251 }
252
253 let handle = self.inner.create(path, true).await?;
254 for chunk in data.as_ref().chunks(CHUNK_SIZE) {
255 self.inner.add_block(handle, chunk).await?;
256 }
257 self.inner.close(handle).await?;
258 Ok(())
259 }
260
261 pub async fn upload_file<T>(&self, local_path: T, remote_path: &str) -> Result<String>
265 where
266 T: AsRef<Path>,
267 {
268 debug!(
269 "Uploading local file {} to DBFS file {}",
270 local_path.as_ref().to_string_lossy(),
271 remote_path
272 );
273 let remote_path = strip_dbfs_prefix(remote_path)?;
274 let filename = local_path.as_ref().to_owned().to_string_lossy().to_string();
275 let file = tokio::fs::File::open(local_path).await?;
276 let length = file.metadata().await?.len();
277 let stream = tokio_util::codec::FramedRead::new(file, tokio_util::codec::BytesCodec::new());
278 let body = reqwest::Body::wrap_stream(stream);
279 self.inner
280 .put_stream(remote_path, &filename, body, length, true)
281 .await?;
282 Ok(remote_path.to_string())
283 }
284
285 pub async fn download_file<T>(&self, remote_path: &str, local_path: T) -> Result<PathBuf>
289 where
290 T: AsRef<Path>,
291 {
292 debug!(
293 "Downloading DBFS file {} to local file {}",
294 remote_path,
295 local_path.as_ref().to_string_lossy()
296 );
297 let remote_path = strip_dbfs_prefix(remote_path)?;
298 let file_size = self.inner.get_status(remote_path).await?.file_size;
299 let mut offset = 0;
300 let mut file = tokio::fs::File::create(local_path.as_ref()).await?;
301 loop {
302 let data = self
303 .inner
304 .read_block(remote_path, offset, CHUNK_SIZE)
305 .await?;
306 offset += data.len();
307 file.write_all(&data).await?;
308 if offset >= file_size {
309 break;
310 }
311 }
312 file.flush().await?;
313 file.sync_all().await?;
314 Ok(PathBuf::from(local_path.as_ref()))
315 }
316
317 pub async fn get_file_status(&self, path: &str) -> Result<FileStatus> {
321 debug!("Getting status of DBFS file {}", path);
322 self.inner.get_status(path).await
323 }
324
325 pub async fn delete_file(&self, path: &str) -> Result<()> {
329 debug!("Deleting DBFS file {}", path);
330 self.inner.delete(strip_dbfs_prefix(path)?).await
331 }
332
333 pub async fn list(&self, path: &str) -> Result<Vec<FileStatus>> {
337 debug!("Listing DBFS directory {}", path);
338 self.inner.list(strip_dbfs_prefix(path)?).await
339 }
340
341 pub async fn mkdir(&self, path: &str) -> Result<()> {
345 debug!("Creating DBFS directory {}", path);
346 self.inner.mkdirs(strip_dbfs_prefix(path)?).await
347 }
348
349 pub async fn move_file(&self, src_path: &str, dest_path: &str) -> Result<()> {
353 debug!("Moving DBFS file from {} to {}", src_path, dest_path);
354 self.inner
355 .move_(strip_dbfs_prefix(src_path)?, strip_dbfs_prefix(dest_path)?)
356 .await
357 }
358}
359
360#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
361struct Handle(u64);
362
363#[derive(Debug)]
364struct DbfsClientInner {
365 url_base: String,
366 api_version: DbfsApiVersions,
367 client: reqwest::Client,
368}
369
370impl DbfsClientInner {
371 fn new(url_base: &str, token: &str) -> Self {
372 let mut headers = reqwest::header::HeaderMap::new();
373 if !token.is_empty() {
374 headers.insert(
375 "Authorization",
376 reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token)).unwrap(),
377 );
378 }
379
380 Self {
381 url_base: url_base
382 .trim()
383 .strip_suffix("/")
384 .unwrap_or(url_base)
385 .trim()
386 .to_string(),
387 api_version: DbfsApiVersions::API_2_0,
388 client: reqwest::ClientBuilder::new()
389 .default_headers(headers)
390 .build()
391 .unwrap(),
392 }
393 }
394
395 fn get_url(&self, api: &str) -> String {
396 format!("{}/{}/dbfs/{}", self.url_base, self.api_version, api)
397 }
398
399 async fn add_block<T>(&self, handle: Handle, data: T) -> Result<()>
402 where
403 T: AsRef<[u8]>,
404 {
405 trace!("Add block to handle {}", handle.0);
406 #[derive(Debug, Serialize)]
407 struct Request {
408 handle: Handle,
409 data: String,
410 }
411 self.client
412 .post(self.get_url("add-block"))
413 .json(&Request {
414 handle,
415 data: base64::encode(data),
416 })
417 .send()
418 .await?
419 .detailed_error_for_status()
420 .await
421 .log()?
422 .text()
423 .await?;
424 Ok(())
425 }
426
427 async fn close(&self, handle: Handle) -> Result<()> {
428 trace!("Close handle {}", handle.0);
429 #[derive(Debug, Serialize)]
430 struct Request {
431 handle: Handle,
432 }
433 self.client
434 .post(self.get_url("close"))
435 .json(&Request { handle })
436 .send()
437 .await?
438 .detailed_error_for_status()
439 .await
440 .log()?
441 .text()
442 .await?;
443 Ok(())
444 }
445
446 async fn create(&self, path: &str, overwrite: bool) -> Result<Handle> {
447 trace!("Create file {}", path);
448 #[derive(Debug, Serialize)]
449 struct Request {
450 path: String,
451 overwrite: bool,
452 }
453 #[derive(Debug, Deserialize)]
454 struct Response {
455 handle: Handle,
456 }
457 let resp: Response = self
458 .client
459 .post(self.get_url("create"))
460 .json(&Request {
461 path: path.to_string(),
462 overwrite,
463 })
464 .send()
465 .await?
466 .detailed_error_for_status()
467 .await
468 .log()?
469 .json()
470 .await?;
471 Ok(resp.handle)
472 }
473
474 async fn delete(&self, path: &str) -> Result<()> {
475 trace!("Delete file {}", path);
476 #[derive(Debug, Serialize)]
477 struct Request {
478 path: String,
479 }
480 self.client
481 .post(self.get_url("delete"))
482 .json(&Request {
483 path: path.to_string(),
484 })
485 .send()
486 .await?
487 .detailed_error_for_status()
488 .await
489 .log()?
490 .text()
491 .await?;
492 Ok(())
493 }
494
495 async fn get_status(&self, path: &str) -> Result<FileStatus> {
496 trace!("Get status of file {}", path);
497 #[derive(Debug, Serialize)]
498 struct Request {
499 path: String,
500 }
501 Ok(self
502 .client
503 .get(self.get_url("get-status"))
504 .json(&Request {
505 path: path.to_string(),
506 })
507 .send()
508 .await?
509 .detailed_error_for_status()
510 .await
511 .log()?
512 .json()
513 .await?)
514 }
515
516 async fn list(&self, path: &str) -> Result<Vec<FileStatus>> {
517 trace!("List directory {}", path);
518 #[derive(Debug, Serialize)]
519 struct Request {
520 path: String,
521 }
522 #[derive(Debug, Deserialize)]
523 struct Response {
524 files: Vec<FileStatus>,
525 }
526 let resp: Response = self
527 .client
528 .get(self.get_url("list"))
529 .json(&Request {
530 path: path.to_string(),
531 })
532 .send()
533 .await?
534 .detailed_error_for_status()
535 .await
536 .log()?
537 .json()
538 .await?;
539 Ok(resp.files)
540 }
541
542 async fn mkdirs(&self, path: &str) -> Result<()> {
543 trace!("Make directory {}", path);
544 #[derive(Debug, Serialize)]
545 struct Request {
546 path: String,
547 }
548 self.client
549 .post(self.get_url("mkdirs"))
550 .json(&Request {
551 path: path.to_string(),
552 })
553 .send()
554 .await?
555 .detailed_error_for_status()
556 .await
557 .log()?
558 .text()
559 .await?;
560 Ok(())
561 }
562
563 async fn move_(&self, source_path: &str, destination_path: &str) -> Result<()> {
564 trace!("Move file from {} to {}", source_path, destination_path);
565 #[derive(Debug, Serialize)]
566 struct Request {
567 source_path: String,
568 destination_path: String,
569 }
570 self.client
571 .post(self.get_url("move"))
572 .json(&Request {
573 source_path: source_path.to_string(),
574 destination_path: destination_path.to_string(),
575 })
576 .send()
577 .await?
578 .detailed_error_for_status()
579 .await
580 .log()?
581 .text()
582 .await?;
583 Ok(())
584 }
585
586 async fn put<T>(&self, path: &str, content: T, overwrite: bool) -> Result<()>
587 where
588 T: AsRef<[u8]>,
589 {
590 trace!(
591 "Upload buffer to file {}, length is {}",
592 path,
593 content.as_ref().len()
594 );
595 #[derive(Debug, Serialize)]
596 struct Request {
597 path: String,
598 contents: String,
599 overwrite: bool,
600 }
601 self.client
602 .post(self.get_url("put"))
603 .json(&Request {
604 path: path.to_string(),
605 contents: base64::encode(content),
606 overwrite,
607 })
608 .send()
609 .await?
610 .detailed_error_for_status()
611 .await
612 .log()?
613 .text()
614 .await?;
615 Ok(())
616 }
617
618 async fn put_stream<S>(
619 &self,
620 path: &str,
621 filename: &str,
622 stream: S,
623 length: u64,
624 overwrite: bool,
625 ) -> Result<()>
626 where
627 S: Into<reqwest::Body>,
628 {
629 trace!("Upload stream to file {}, length is {}", path, length);
630 let path = path.to_string();
631 let form = reqwest::multipart::Form::new()
632 .part(
633 "contents",
634 Part::stream_with_length(stream, length).file_name(filename.to_owned()),
635 )
636 .text("path", path)
637 .text("overwrite", if overwrite { "true" } else { "false" });
638 self.client
639 .post(self.get_url("put"))
640 .multipart(form)
641 .send()
642 .await?
643 .detailed_error_for_status()
644 .await
645 .log()?
646 .text()
647 .await?;
648 Ok(())
649 }
650
651 async fn read_block(&self, path: &str, offset: usize, length: usize) -> Result<Vec<u8>> {
652 trace!("Read file {}", path);
653 #[derive(Debug, Serialize)]
654 struct Request {
655 path: String,
656 offset: usize,
657 length: usize,
658 }
659 #[allow(dead_code)]
660 #[derive(Debug, Deserialize)]
661 struct Response {
662 bytes_read: usize,
663 data: String,
664 }
665 let resp: Response = self
666 .client
667 .get(self.get_url("read"))
668 .json(&Request {
669 path: path.to_string(),
670 offset,
671 length,
672 })
673 .send()
674 .await?
675 .detailed_error_for_status()
676 .await
677 .log()?
678 .json()
679 .await?;
680 Ok(base64::decode(resp.data)?)
681 }
682}
683
684#[pin_project]
685pub struct DbfsReadStream {
686 reader: Arc<DbfsClientInner>,
687 path: String,
688 step: ReadStreamSteps,
689 file_size: usize,
690 file_offset: usize,
691 current_buf: Vec<u8>,
692 current_buf_offset: usize,
693 len_future: Pin<Box<dyn Future<Output = std::result::Result<usize, std::io::Error>>>>,
694 current_future:
695 Option<Pin<Box<dyn Future<Output = std::result::Result<Vec<u8>, std::io::Error>>>>>,
696}
697
/// States of the [`DbfsReadStream`] state machine.
#[derive(Clone, Copy, Debug)]
enum ReadStreamSteps {
    /// Waiting for the file size (`get-status`) to resolve.
    Len,
    /// Delivering file contents, fetching chunks as needed.
    Read,
    /// Reading has finished: end of file, or an error, has been reported once.
    End,
    /// Final state, entered after `End` has been polled (or directly after some read
    /// errors).
    Eof,
}
705
706impl AsyncBufRead for DbfsReadStream {
707 fn poll_fill_buf(
708 self: Pin<&mut Self>,
709 cx: &mut std::task::Context<'_>,
710 ) -> Poll<std::io::Result<&[u8]>> {
711 let mut this = self.project();
712 let current_buf = &mut this.current_buf;
713 match *this.step {
714 ReadStreamSteps::Len => {
715 match this.len_future.poll_unpin(cx) {
716 Poll::Ready(r) => {
717 match r {
718 Ok(sz) => {
719 if sz == 0 {
721 debug!("File is empty");
723 return Poll::Ready(Ok(&[]));
724 }
725 trace!("File length is {}", sz);
726 *this.file_size = sz;
727 *this.file_offset = 0;
728 *this.current_buf_offset = 0;
729 *this.step = ReadStreamSteps::Read;
731 cx.waker().wake_by_ref();
732 Poll::Pending
733 }
734 Err(e) => {
735 Poll::Ready(Err(e))
737 }
738 }
739 }
740 Poll::Pending => {
741 Poll::Pending
743 }
744 }
745 }
746 ReadStreamSteps::Read => {
747 if *this.file_offset >= *this.file_size {
748 *this.step = ReadStreamSteps::End;
750 trace!("Reach EOF");
751 Poll::Ready(std::io::Result::Ok(&this.current_buf[0..0]))
752 } else if current_buf.len() > *this.current_buf_offset {
753 let end_pos = current_buf.len();
755 Poll::Ready(std::io::Result::Ok(
756 &this.current_buf[*this.current_buf_offset..end_pos],
757 ))
758 } else if let Some(f) = this.current_future {
759 let p = f.poll_unpin(cx);
761 match p {
762 Poll::Ready(r) => {
763 *this.current_future = None;
765 match r {
766 Ok(b) => {
767 *this.current_buf_offset = 0;
770 *this.current_buf = b;
771 *this.step = ReadStreamSteps::Read;
772 cx.waker().wake_by_ref();
773 Poll::Pending
774 }
775 Err(e) => {
776 *this.step = ReadStreamSteps::End;
778 Poll::Ready(Err(e))
779 }
780 }
781 }
782 Poll::Pending => Poll::Pending,
783 }
784 } else {
785 let path = this.path.clone();
787 let reader = this.reader.clone();
788 let offset = *this.file_offset;
789 let f = async move {
790 reader
791 .read_block(&path, offset, 4096)
792 .await
793 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
794 };
795 *this.current_future = Some(Box::pin(f));
796 cx.waker().wake_by_ref();
797 Poll::Pending
798 }
799 }
800 ReadStreamSteps::End => {
801 *this.step = ReadStreamSteps::Eof;
802 trace!("Reach EOF Again");
803 Poll::Ready(std::io::Result::Ok(&this.current_buf[0..0]))
804 }
805 ReadStreamSteps::Eof => {
806 panic!(
807 "ReadStreamState must not be polled after it returned `Poll::Ready(Ok(&[]))`"
808 )
809 }
810 }
811 }
812
813 fn consume(self: Pin<&mut Self>, amt: usize) {
814 let this = self.project();
815 *this.current_buf_offset += amt;
816 *this.file_offset += amt;
817 }
818}
819
820impl AsyncRead for DbfsReadStream {
821 fn poll_read(
822 self: std::pin::Pin<&mut Self>,
823 cx: &mut std::task::Context<'_>,
824 buf: &mut [u8],
825 ) -> std::task::Poll<std::io::Result<usize>> {
826 let mut this = self.project();
827 let current_buf = &mut this.current_buf;
828 match *this.step {
829 ReadStreamSteps::Len => {
830 match this.len_future.poll_unpin(cx) {
831 Poll::Ready(r) => {
832 match r {
833 Ok(sz) => {
834 if sz == 0 {
835 debug!("File is empty");
837 return Poll::Ready(Ok(0));
838 }
839 debug!("File length is {}", sz);
841 *this.file_size = sz;
842 *this.file_offset = 0;
843 *this.current_buf_offset = 0;
844 this.current_buf.clear();
845 *this.step = ReadStreamSteps::Read;
846 cx.waker().wake_by_ref();
847 Poll::Pending
848 }
849 Err(e) => {
850 Poll::Ready(Err(e))
852 }
853 }
854 }
855 Poll::Pending => {
856 Poll::Pending
858 }
859 }
860 }
861 ReadStreamSteps::Read => {
862 if *this.file_offset >= *this.file_size {
863 *this.step = ReadStreamSteps::End;
865 Poll::Ready(Ok(0))
866 } else if current_buf.len() > *this.current_buf_offset {
867 let existing_sz = current_buf.len() - *this.current_buf_offset;
869 let required_sz = buf.len();
870 let sz = min(existing_sz, required_sz);
871 let end_pos = *this.current_buf_offset + sz;
872 buf[0..sz].copy_from_slice(¤t_buf[*this.current_buf_offset..end_pos]);
873 if end_pos >= this.current_buf.len() {
874 *this.current_buf_offset = 0;
876 this.current_buf.clear();
877 } else {
878 *this.current_buf_offset = end_pos;
880 }
881 *this.file_offset += sz;
882 *this.step = ReadStreamSteps::Read;
883 Poll::Ready(std::io::Result::Ok(sz))
884 } else if let Some(f) = this.current_future {
885 let p = f.poll_unpin(cx);
887 match p {
888 Poll::Ready(r) => {
889 *this.current_future = None;
891 match r {
892 Ok(b) => {
893 *this.current_buf_offset = 0;
895 *this.current_buf = b;
896 *this.step = ReadStreamSteps::Read;
897 cx.waker().wake_by_ref();
898 Poll::Pending
899 }
900 Err(e) => {
901 *this.step = ReadStreamSteps::Eof;
903 Poll::Ready(Err(e))
904 }
905 }
906 }
907 Poll::Pending => Poll::Pending,
908 }
909 } else {
910 let path = this.path.clone();
912 let reader = this.reader.clone();
913 let offset = *this.file_offset;
914 let f = async move {
915 reader
916 .read_block(&path, offset, 4096)
917 .await
918 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
919 };
920 *this.current_future = Some(Box::pin(f));
921 cx.waker().wake_by_ref();
922 Poll::Pending
923 }
924 }
925 ReadStreamSteps::End => {
926 *this.step = ReadStreamSteps::Eof;
927 trace!("Reach EOF Again");
928 Poll::Ready(std::io::Result::Ok(0))
929 }
930 ReadStreamSteps::Eof => {
931 panic!("ReadStreamState must not be polled after it returned `Poll::Ready(Ok(0))`")
932 }
933 }
934 }
935}
936
937fn strip_dbfs_prefix(path: &str) -> Result<&str> {
938 let ret = path.strip_prefix("dbfs:").unwrap_or(path);
939 if ret.starts_with("/") {
940 Ok(ret)
941 } else {
942 Err(DbfsError::InvalidDbfsPath(path.to_string()))
943 }
944}
945
#[cfg(test)]
mod tests {
    use std::sync::Once;

    use dotenv;
    use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt};
    use rand::Rng;

    use super::*;

    static INIT_ENV_LOGGER: Once = Once::new();

    /// Loads variables from a `.env` file (if present) and installs `env_logger` once.
    pub fn init_logger() {
        dotenv::dotenv().ok();
        INIT_ENV_LOGGER.call_once(|| env_logger::init());
    }

    /// Builds a client for the workspace in `DATABRICKS_URL`, authenticated with
    /// `DATABRICKS_API_TOKEN`.
    ///
    /// The integration tests below talk to a real workspace and panic if either variable
    /// is unset.
    fn init() -> DbfsClient {
        // Call the local helper directly; `crate::tests` only resolves from the crate root.
        init_logger();
        DbfsClient::new(
            &std::env::var("DATABRICKS_URL").unwrap(),
            &std::env::var("DATABRICKS_API_TOKEN").unwrap(),
        )
    }

    #[test]
    fn test_strip_prefix() {
        assert_eq!(strip_dbfs_prefix("/abc").unwrap(), "/abc");
        assert_ne!(strip_dbfs_prefix("/abc").unwrap(), "/abcd");
        assert_eq!(strip_dbfs_prefix("dbfs:/abc").unwrap(), "/abc");
        assert_ne!(strip_dbfs_prefix("dbfs:/abc").unwrap(), "/abcd");
        assert!(matches!(
            strip_dbfs_prefix("abc"),
            Err(DbfsError::InvalidDbfsPath(..))
        ));
        assert!(matches!(
            strip_dbfs_prefix("dbfs:abc"),
            Err(DbfsError::InvalidDbfsPath(..))
        ));
    }

    #[test]
    fn test_error_code_display() {
        assert_eq!(
            DbfsErrorCode::ResourceAlreadyExists.to_string(),
            "RESOURCE_ALREADY_EXISTS"
        );
        assert_eq!(
            DbfsErrorCode::ResourceDoesNotExist.to_string(),
            "RESOURCE_DOES_NOT_EXIST"
        );
    }

    #[test]
    fn test_api_version_display() {
        assert_eq!(DbfsApiVersions::default().to_string(), "api/2.0");
    }

    // Runs inside a runtime so building the `reqwest` client never depends on one
    // being absent or present.
    #[tokio::test]
    async fn test_get_url() {
        let inner = DbfsClientInner::new(" https://example.com/ ", "");
        assert_eq!(
            inner.get_url("read"),
            "https://example.com/api/2.0/dbfs/read"
        );
    }

    #[tokio::test]
    async fn read_write_delete() {
        let client = init();
        let expected = "foo\nbar\nbaz\nspam\n".as_bytes();
        client
            .write_file("/test_read_write_delete", expected)
            .await
            .unwrap();
        let data = client.read_file("/test_read_write_delete").await.unwrap();
        assert_eq!(data, expected);
        assert_eq!(
            client
                .get_file_status("/test_read_write_delete")
                .await
                .unwrap()
                .file_size,
            expected.len()
        );
        client.delete_file("/test_read_write_delete").await.unwrap();
        let ret = client.read_file("/test_read_write_delete").await;
        assert!(matches!(
            ret,
            Err(DbfsError::DbfsApiError(
                DbfsErrorCode::ResourceDoesNotExist,
                ..
            ))
        ));
    }

    #[tokio::test]
    async fn upload_file() {
        let client = init();
        let expected = "foo\nbar\nbaz\nspam\n".as_bytes();
        // Portable scratch location instead of a hard-coded `/tmp`.
        let local_path = std::env::temp_dir().join("test_upload_file");
        let mut f = tokio::fs::File::create(&local_path).await.unwrap();
        f.write_all(expected).await.unwrap();
        f.flush().await.unwrap();
        f.sync_all().await.unwrap();
        drop(f);
        client
            .upload_file(&local_path, "/test_upload_file")
            .await
            .unwrap();
        let data = client.read_file("/test_upload_file").await.unwrap();
        assert_eq!(data, expected);
        client.delete_file("/test_upload_file").await.unwrap();
        tokio::fs::remove_file(&local_path).await.unwrap();
    }

    #[tokio::test]
    async fn large_file() {
        let mut rng = rand::thread_rng();

        // Larger than CHUNK_SIZE, so the create/add-block/close path is exercised.
        let expected: Vec<u8> = (0..1024 * 1024 * 2).map(|_| rng.gen()).collect();

        let client = init();
        client
            .write_file("dbfs:/large_file", &expected)
            .await
            .unwrap();

        let buf = client.read_file("/large_file").await.unwrap();
        assert_eq!(buf, expected);
        client.delete_file("/large_file").await.unwrap();
    }

    #[tokio::test]
    async fn test_read() {
        let client = init();
        let mut rng = rand::thread_rng();
        const TOTAL: usize = 100000;
        let expected: Vec<u8> = (0..TOTAL).map(|_| rng.gen()).collect();
        client
            .write_file("dbfs:/test_read", &expected)
            .await
            .unwrap();

        let mut offset = 0;
        let mut buf = [0; 5000];
        let mut s = client.read("dbfs:/test_read").unwrap();
        loop {
            // Fail loudly on I/O errors instead of silently ending the loop.
            let sz = s.read(&mut buf).await.unwrap();
            debug!("Read {} bytes at {}", buf.len(), offset);
            debug!("Got {} bytes", sz);
            if sz == 0 {
                break;
            }
            assert_eq!(&buf[0..sz], &expected[offset..offset + sz]);
            offset += sz;
        }
        assert_eq!(offset, TOTAL);
        client.delete_file("dbfs:/test_read").await.unwrap();
    }

    #[tokio::test]
    async fn test_read_line() {
        let client = init();
        const TOTAL: usize = 1000;
        let expected: Vec<String> = (0..TOTAL).map(|n| format!("Line {}", n)).collect();
        client
            .write_file("dbfs:/test_read_line", expected.join("\n").as_bytes())
            .await
            .unwrap();

        let mut counter = 0;
        let mut lines = client.read("dbfs:/test_read_line").unwrap().lines();
        while let Some(line) = lines.next().await {
            let line = line.unwrap();
            debug!("Line is `{}`", line);
            assert_eq!(line, format!("Line {}", counter));
            counter += 1;
        }
        assert_eq!(counter, TOTAL);
        client.delete_file("dbfs:/test_read_line").await.unwrap();
    }
}