dbfs_client/
lib.rs

1use std::{
2    cmp::min,
3    fmt::Display,
4    path::{Path, PathBuf},
5    pin::Pin,
6    sync::Arc,
7    task::Poll,
8};
9
10use async_trait::async_trait;
11use futures::{AsyncBufRead, AsyncRead, Future, FutureExt};
12use log::{debug, trace};
13use pin_project::pin_project;
14use reqwest::multipart::Part;
15use serde::{Deserialize, Serialize};
16use thiserror::Error;
17use tokio::io::AsyncWriteExt;
18
/// Number of bytes moved per block request: 1 MiB (`1024 * 1024`).
///
/// The whole-file helpers use it as the `length` of every `read_block` call and
/// as the slice size for `add_block` uploads; payloads shorter than this are
/// sent with a single `put` instead.
// NOTE(review): presumably chosen to match the DBFS per-request size limit — confirm against the API docs.
const CHUNK_SIZE: usize = 1024 * 1024;
20
/// Log if `Result` is an error.
///
/// Helper trait meant to be chained in front of `?`, so that a failure is
/// recorded before it propagates to the caller.
trait Logged {
    /// Logs `self` when it represents an error and hands `self` back.
    fn log(self) -> Self;
}
25
26impl<T, E> Logged for std::result::Result<T, E>
27where
28    E: std::fmt::Debug,
29{
30    fn log(self) -> Self {
31        if let Err(e) = &self {
32            log::debug!("---TraceError--- {:#?}", e)
33        }
34        self
35    }
36}
37
/// Turns an unsuccessful HTTP response into a descriptive `DbfsError`.
#[async_trait]
trait LoggedResponse {
    /// Consumes the response and returns it unchanged on success, or an error
    /// describing the failure otherwise.
    async fn detailed_error_for_status(self) -> Result<Self>
    where
        Self: Sized;
}
44
45#[async_trait]
46impl LoggedResponse for reqwest::Response {
47    async fn detailed_error_for_status(self) -> Result<Self> {
48        #[derive(Debug, Deserialize)]
49        struct DbfsErrorResponse {
50            error_code: DbfsErrorCode,
51            message: String,
52        }
53        
54        if self.status().is_client_error() || self.status().is_server_error() {
55            let url = self.url().to_string();
56            let status = self.status().to_string();
57            let text = self.text().await?;
58            Err(match serde_json::from_str::<DbfsErrorResponse>(&text) {
59                Ok(resp) => DbfsError::DbfsApiError(resp.error_code, resp.message),
60                Err(_) => DbfsError::HttpError(url, status, text),
61            })
62        } else {
63            Ok(self)
64        }
65    }
66}
67
/**
 * DBFS Error Code
 *
 * Value of the `error_code` field in a JSON error body. Variants are
 * (de)serialized in SCREAMING_SNAKE_CASE, e.g. `RESOURCE_ALREADY_EXISTS`.
 * An error body whose code is not listed here fails to decode and is
 * reported as `DbfsError::HttpError` instead.
 */
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum DbfsErrorCode {
    /// Wire name `RESOURCE_ALREADY_EXISTS`.
    ResourceAlreadyExists,
    /// Wire name `MAX_BLOCK_SIZE_EXCEEDED`.
    MaxBlockSizeExceeded,
    /// Wire name `INVALID_PARAMETER_VALUE`.
    InvalidParameterValue,
    /// Wire name `MAX_READ_SIZE_EXCEEDED`.
    MaxReadSizeExceeded,
    /// Wire name `RESOURCE_DOES_NOT_EXIST`.
    ResourceDoesNotExist,
}
80
81impl Display for DbfsErrorCode {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.write_str(
84            &serde_json::to_string(&self)
85                .unwrap()
86                .strip_prefix("\"")
87                .unwrap()
88                .strip_suffix("\"")
89                .unwrap(),
90        )
91    }
92}
93
94/**
95 * Error Type
96 */
97#[derive(Debug, Error)]
98pub enum DbfsError {
99    #[error(transparent)]
100    ReqwestError(#[from] reqwest::Error),
101
102    #[error("HTTP Error, URL: '{0}', Status: {1}, Response: '{2}' ")]
103    HttpError(String, String, String),
104
105    #[error(transparent)]
106    DecodeError(#[from] base64::DecodeError),
107
108    #[error(transparent)]
109    IoError(#[from] std::io::Error),
110
111    #[error(transparent)]
112    VarError(#[from] std::env::VarError),
113
114    #[error(transparent)]
115    JsonError(#[from] serde_json::Error),
116
117    #[error("DBFS Error, Code: {0}, message: {0}")]
118    DbfsApiError(DbfsErrorCode, String),
119
120    #[error("Invalid DBFS Path {0}")]
121    InvalidDbfsPath(String),
122}
123
/**
 * DBFS API Version
 *
 * Selects the version segment of the request URL. Formatting a value with
 * `Display` yields that segment, e.g. `api/2.0`.
 */
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum DbfsApiVersions {
    /// Version 2.0, rendered as `api/2.0`.
    API_2_0,
}

impl Default for DbfsApiVersions {
    /// The default version is `API_2_0`.
    fn default() -> Self {
        DbfsApiVersions::API_2_0
    }
}

impl Display for DbfsApiVersions {
    /// Writes the URL path segment that belongs to this version.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let segment = match self {
            Self::API_2_0 => "api/2.0",
        };
        f.write_str(segment)
    }
}
150
/**
 * The Result type
 *
 * Shorthand for `std::result::Result` with the error fixed to `DbfsError`.
 */
pub type Result<T> = std::result::Result<T, DbfsError>;
155
/**
 * DBFS File Status
 *
 * Deserialized from the JSON returned by the `get-status` and `list`
 * endpoints.
 */
#[derive(Debug, Deserialize)]
pub struct FileStatus {
    /// Path of the entry as reported by the server.
    pub path: String,
    /// `true` when the entry is a directory.
    pub is_dir: bool,
    /// Size of the file; the read helpers compare it against byte offsets.
    pub file_size: usize,
    /// Last modification time.
    // NOTE(review): presumably milliseconds since the Unix epoch — confirm against the DBFS API docs.
    pub modification_time: u64,
}
166
/**
 * DBFS Client
 *
 * Public entry point of the crate. The actual state lives in a shared
 * `Arc<DbfsClientInner>`, so cloning a client is cheap and all clones use
 * the same underlying HTTP client.
 */
#[derive(Clone, Debug)]
pub struct DbfsClient {
    // Shared implementation; also handed to every `DbfsReadStream`.
    inner: Arc<DbfsClientInner>,
}
174
175impl DbfsClient {
176    /**
177     * Create New DBFS Client
178     * 
179     * @param url_base: the base part of the DBFS endpoint, e.g. "https://adb-xxx.azuredatabricks.net"
180     * @param token: The Databricks API token
181     */
182    pub fn new(url_base: &str, token: &str) -> Self {
183        Self {
184            inner: Arc::new(DbfsClientInner::new(url_base, token)),
185        }
186    }
187
188    /** 
189     * Read DBFS file, returns AsyncRead + AsyncBufRead
190     */
191    pub fn read(&self, path: &str) -> Result<DbfsReadStream> {
192        let path = strip_dbfs_prefix(path)?.to_string();
193        let inner = self.inner.clone();
194        Ok(DbfsReadStream {
195            reader: inner.clone(),
196            path: path.clone(),
197            step: ReadStreamSteps::Len,
198            file_size: 0,
199            file_offset: 0,
200            current_buf: vec![],
201            current_buf_offset: 0,
202            len_future: Box::pin(async move {
203                inner
204                    .get_status(&path)
205                    .map(|r| {
206                        r.map(|s| s.file_size)
207                            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
208                    })
209                    .await
210            }),
211            current_future: None,
212        })
213    }
214
215    /**
216     * Read the whole file into a Vec<u8>
217     */
218    pub async fn read_file(&self, path: &str) -> Result<Vec<u8>> {
219        let path = strip_dbfs_prefix(path)?;
220        debug!("Reading DBFS file {}", path);
221        let file_size = self.inner.get_status(path).await?.file_size;
222        debug!("File size is {}", file_size);
223        let mut ret = Vec::with_capacity(file_size);
224        let mut offset = 0;
225        loop {
226            let data = self.inner.read_block(path, offset, CHUNK_SIZE).await?;
227            offset += data.len();
228            ret.extend(data.into_iter());
229            if offset >= file_size {
230                break;
231            }
232        }
233        Ok(ret)
234    }
235
236    /**
237     * Write data to file, the existing content will be overwritten
238     */
239    pub async fn write_file<T>(&self, path: &str, data: T) -> Result<()>
240    where
241        T: AsRef<[u8]>,
242    {
243        debug!(
244            "Writing {} bytes to DBFS file {}",
245            data.as_ref().len(),
246            path
247        );
248        let path = strip_dbfs_prefix(path)?;
249        if data.as_ref().len() < CHUNK_SIZE {
250            return self.inner.put(path, data, true).await;
251        }
252
253        let handle = self.inner.create(path, true).await?;
254        for chunk in data.as_ref().chunks(CHUNK_SIZE) {
255            self.inner.add_block(handle, chunk).await?;
256        }
257        self.inner.close(handle).await?;
258        Ok(())
259    }
260
261    /**
262     * Upload a local file to DBFS
263     */
264    pub async fn upload_file<T>(&self, local_path: T, remote_path: &str) -> Result<String>
265    where
266        T: AsRef<Path>,
267    {
268        debug!(
269            "Uploading local file {} to DBFS file {}",
270            local_path.as_ref().to_string_lossy(),
271            remote_path
272        );
273        let remote_path = strip_dbfs_prefix(remote_path)?;
274        let filename = local_path.as_ref().to_owned().to_string_lossy().to_string();
275        let file = tokio::fs::File::open(local_path).await?;
276        let length = file.metadata().await?.len();
277        let stream = tokio_util::codec::FramedRead::new(file, tokio_util::codec::BytesCodec::new());
278        let body = reqwest::Body::wrap_stream(stream);
279        self.inner
280            .put_stream(remote_path, &filename, body, length, true)
281            .await?;
282        Ok(remote_path.to_string())
283    }
284
285    /**
286     * Download DBFS file to local path
287     */
288    pub async fn download_file<T>(&self, remote_path: &str, local_path: T) -> Result<PathBuf>
289    where
290        T: AsRef<Path>,
291    {
292        debug!(
293            "Downloading DBFS file {} to local file {}",
294            remote_path,
295            local_path.as_ref().to_string_lossy()
296        );
297        let remote_path = strip_dbfs_prefix(remote_path)?;
298        let file_size = self.inner.get_status(remote_path).await?.file_size;
299        let mut offset = 0;
300        let mut file = tokio::fs::File::create(local_path.as_ref()).await?;
301        loop {
302            let data = self
303                .inner
304                .read_block(remote_path, offset, CHUNK_SIZE)
305                .await?;
306            offset += data.len();
307            file.write_all(&data).await?;
308            if offset >= file_size {
309                break;
310            }
311        }
312        file.flush().await?;
313        file.sync_all().await?;
314        Ok(PathBuf::from(local_path.as_ref()))
315    }
316
317    /**
318     * Get DBFS file status
319     */
320    pub async fn get_file_status(&self, path: &str) -> Result<FileStatus> {
321        debug!("Getting status of DBFS file {}", path);
322        self.inner.get_status(path).await
323    }
324
325    /**
326     * Delete DBFS file
327     */
328    pub async fn delete_file(&self, path: &str) -> Result<()> {
329        debug!("Deleting DBFS file {}", path);
330        self.inner.delete(strip_dbfs_prefix(path)?).await
331    }
332
333    /**
334     * Get all status of the files under the directory
335     */
336    pub async fn list(&self, path: &str) -> Result<Vec<FileStatus>> {
337        debug!("Listing DBFS directory {}", path);
338        self.inner.list(strip_dbfs_prefix(path)?).await
339    }
340
341    /**
342     * Create directory recursively
343     */
344    pub async fn mkdir(&self, path: &str) -> Result<()> {
345        debug!("Creating DBFS directory {}", path);
346        self.inner.mkdirs(strip_dbfs_prefix(path)?).await
347    }
348
349    /**
350     * Move a DBFS file from one place to another
351     */
352    pub async fn move_file(&self, src_path: &str, dest_path: &str) -> Result<()> {
353        debug!("Moving DBFS file from {} to {}", src_path, dest_path);
354        self.inner
355            .move_(strip_dbfs_prefix(src_path)?, strip_dbfs_prefix(dest_path)?)
356            .await
357    }
358}
359
/// Numeric handle returned by the `create` endpoint and passed back to
/// `add-block` and `close`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
struct Handle(u64);
362
/// Shared state behind `DbfsClient`: one method per DBFS REST endpoint.
#[derive(Debug)]
struct DbfsClientInner {
    /// Endpoint base URL, stored trimmed and without a trailing `/`.
    url_base: String,
    /// Version segment inserted into every request URL.
    api_version: DbfsApiVersions,
    /// HTTP client carrying the default headers (the bearer token, if any).
    client: reqwest::Client,
}
369
370impl DbfsClientInner {
371    fn new(url_base: &str, token: &str) -> Self {
372        let mut headers = reqwest::header::HeaderMap::new();
373        if !token.is_empty() {
374            headers.insert(
375                "Authorization",
376                reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token)).unwrap(),
377            );
378        }
379
380        Self {
381            url_base: url_base
382                .trim()
383                .strip_suffix("/")
384                .unwrap_or(url_base)
385                .trim()
386                .to_string(),
387            api_version: DbfsApiVersions::API_2_0,
388            client: reqwest::ClientBuilder::new()
389                .default_headers(headers)
390                .build()
391                .unwrap(),
392        }
393    }
394
395    fn get_url(&self, api: &str) -> String {
396        format!("{}/{}/dbfs/{}", self.url_base, self.api_version, api)
397    }
398
399    /// DBFS API
400
401    async fn add_block<T>(&self, handle: Handle, data: T) -> Result<()>
402    where
403        T: AsRef<[u8]>,
404    {
405        trace!("Add block to handle {}", handle.0);
406        #[derive(Debug, Serialize)]
407        struct Request {
408            handle: Handle,
409            data: String,
410        }
411        self.client
412            .post(self.get_url("add-block"))
413            .json(&Request {
414                handle,
415                data: base64::encode(data),
416            })
417            .send()
418            .await?
419            .detailed_error_for_status()
420            .await
421            .log()?
422            .text()
423            .await?;
424        Ok(())
425    }
426
427    async fn close(&self, handle: Handle) -> Result<()> {
428        trace!("Close handle {}", handle.0);
429        #[derive(Debug, Serialize)]
430        struct Request {
431            handle: Handle,
432        }
433        self.client
434            .post(self.get_url("close"))
435            .json(&Request { handle })
436            .send()
437            .await?
438            .detailed_error_for_status()
439            .await
440            .log()?
441            .text()
442            .await?;
443        Ok(())
444    }
445
446    async fn create(&self, path: &str, overwrite: bool) -> Result<Handle> {
447        trace!("Create file {}", path);
448        #[derive(Debug, Serialize)]
449        struct Request {
450            path: String,
451            overwrite: bool,
452        }
453        #[derive(Debug, Deserialize)]
454        struct Response {
455            handle: Handle,
456        }
457        let resp: Response = self
458            .client
459            .post(self.get_url("create"))
460            .json(&Request {
461                path: path.to_string(),
462                overwrite,
463            })
464            .send()
465            .await?
466            .detailed_error_for_status()
467            .await
468            .log()?
469            .json()
470            .await?;
471        Ok(resp.handle)
472    }
473
474    async fn delete(&self, path: &str) -> Result<()> {
475        trace!("Delete file {}", path);
476        #[derive(Debug, Serialize)]
477        struct Request {
478            path: String,
479        }
480        self.client
481            .post(self.get_url("delete"))
482            .json(&Request {
483                path: path.to_string(),
484            })
485            .send()
486            .await?
487            .detailed_error_for_status()
488            .await
489            .log()?
490            .text()
491            .await?;
492        Ok(())
493    }
494
495    async fn get_status(&self, path: &str) -> Result<FileStatus> {
496        trace!("Get status of file {}", path);
497        #[derive(Debug, Serialize)]
498        struct Request {
499            path: String,
500        }
501        Ok(self
502            .client
503            .get(self.get_url("get-status"))
504            .json(&Request {
505                path: path.to_string(),
506            })
507            .send()
508            .await?
509            .detailed_error_for_status()
510            .await
511            .log()?
512            .json()
513            .await?)
514    }
515
516    async fn list(&self, path: &str) -> Result<Vec<FileStatus>> {
517        trace!("List directory {}", path);
518        #[derive(Debug, Serialize)]
519        struct Request {
520            path: String,
521        }
522        #[derive(Debug, Deserialize)]
523        struct Response {
524            files: Vec<FileStatus>,
525        }
526        let resp: Response = self
527            .client
528            .get(self.get_url("list"))
529            .json(&Request {
530                path: path.to_string(),
531            })
532            .send()
533            .await?
534            .detailed_error_for_status()
535            .await
536            .log()?
537            .json()
538            .await?;
539        Ok(resp.files)
540    }
541
542    async fn mkdirs(&self, path: &str) -> Result<()> {
543        trace!("Make directory {}", path);
544        #[derive(Debug, Serialize)]
545        struct Request {
546            path: String,
547        }
548        self.client
549            .post(self.get_url("mkdirs"))
550            .json(&Request {
551                path: path.to_string(),
552            })
553            .send()
554            .await?
555            .detailed_error_for_status()
556            .await
557            .log()?
558            .text()
559            .await?;
560        Ok(())
561    }
562
563    async fn move_(&self, source_path: &str, destination_path: &str) -> Result<()> {
564        trace!("Move file from {} to {}", source_path, destination_path);
565        #[derive(Debug, Serialize)]
566        struct Request {
567            source_path: String,
568            destination_path: String,
569        }
570        self.client
571            .post(self.get_url("move"))
572            .json(&Request {
573                source_path: source_path.to_string(),
574                destination_path: destination_path.to_string(),
575            })
576            .send()
577            .await?
578            .detailed_error_for_status()
579            .await
580            .log()?
581            .text()
582            .await?;
583        Ok(())
584    }
585
586    async fn put<T>(&self, path: &str, content: T, overwrite: bool) -> Result<()>
587    where
588        T: AsRef<[u8]>,
589    {
590        trace!(
591            "Upload buffer to file {}, length is {}",
592            path,
593            content.as_ref().len()
594        );
595        #[derive(Debug, Serialize)]
596        struct Request {
597            path: String,
598            contents: String,
599            overwrite: bool,
600        }
601        self.client
602            .post(self.get_url("put"))
603            .json(&Request {
604                path: path.to_string(),
605                contents: base64::encode(content),
606                overwrite,
607            })
608            .send()
609            .await?
610            .detailed_error_for_status()
611            .await
612            .log()?
613            .text()
614            .await?;
615        Ok(())
616    }
617
618    async fn put_stream<S>(
619        &self,
620        path: &str,
621        filename: &str,
622        stream: S,
623        length: u64,
624        overwrite: bool,
625    ) -> Result<()>
626    where
627        S: Into<reqwest::Body>,
628    {
629        trace!("Upload stream to file {}, length is {}", path, length);
630        let path = path.to_string();
631        let form = reqwest::multipart::Form::new()
632            .part(
633                "contents",
634                Part::stream_with_length(stream, length).file_name(filename.to_owned()),
635            )
636            .text("path", path)
637            .text("overwrite", if overwrite { "true" } else { "false" });
638        self.client
639            .post(self.get_url("put"))
640            .multipart(form)
641            .send()
642            .await?
643            .detailed_error_for_status()
644            .await
645            .log()?
646            .text()
647            .await?;
648        Ok(())
649    }
650
651    async fn read_block(&self, path: &str, offset: usize, length: usize) -> Result<Vec<u8>> {
652        trace!("Read file {}", path);
653        #[derive(Debug, Serialize)]
654        struct Request {
655            path: String,
656            offset: usize,
657            length: usize,
658        }
659        #[allow(dead_code)]
660        #[derive(Debug, Deserialize)]
661        struct Response {
662            bytes_read: usize,
663            data: String,
664        }
665        let resp: Response = self
666            .client
667            .get(self.get_url("read"))
668            .json(&Request {
669                path: path.to_string(),
670                offset,
671                length,
672            })
673            .send()
674            .await?
675            .detailed_error_for_status()
676            .await
677            .log()?
678            .json()
679            .await?;
680        Ok(base64::decode(resp.data)?)
681    }
682}
683
/// Streaming reader over a DBFS file, created by `DbfsClient::read`.
///
/// Implements `AsyncRead` and `AsyncBufRead`. The first poll resolves the
/// file size; afterwards the content is fetched one block at a time and
/// served from `current_buf`.
#[pin_project]
pub struct DbfsReadStream {
    /// Client used to issue the block reads.
    reader: Arc<DbfsClientInner>,
    /// Remote path, already stripped of the `dbfs:` prefix.
    path: String,
    /// Current position in the stream's state machine.
    step: ReadStreamSteps,
    /// Total size of the file; valid once `len_future` has completed.
    file_size: usize,
    /// Number of bytes handed to the caller so far; also the offset of the
    /// next block request.
    file_offset: usize,
    /// Most recently fetched block.
    current_buf: Vec<u8>,
    /// Position inside `current_buf` of the first byte not yet consumed.
    current_buf_offset: usize,
    /// Pending lookup of the file size, polled while `step` is `Len`.
    len_future: Pin<Box<dyn Future<Output = std::result::Result<usize, std::io::Error>>>>,
    /// Block read in flight, if any.
    current_future:
        Option<Pin<Box<dyn Future<Output = std::result::Result<Vec<u8>, std::io::Error>>>>>,
}
697
/// States of the `DbfsReadStream` state machine.
#[derive(Clone, Copy, Debug)]
enum ReadStreamSteps {
    /// Waiting for the file size lookup to finish.
    Len,
    /// Serving buffered data or fetching the next block.
    Read,
    /// End of file has been reported; one more poll reports it again and
    /// moves to `Eof`.
    End,
    /// End of file has been reported twice; polling in this state panics.
    Eof,
}
705
706impl AsyncBufRead for DbfsReadStream {
707    fn poll_fill_buf(
708        self: Pin<&mut Self>,
709        cx: &mut std::task::Context<'_>,
710    ) -> Poll<std::io::Result<&[u8]>> {
711        let mut this = self.project();
712        let current_buf = &mut this.current_buf;
713        match *this.step {
714            ReadStreamSteps::Len => {
715                match this.len_future.poll_unpin(cx) {
716                    Poll::Ready(r) => {
717                        match r {
718                            Ok(sz) => {
719                                // Got file length, start reading
720                                if sz == 0 {
721                                    // File is empty
722                                    debug!("File is empty");
723                                    return Poll::Ready(Ok(&[]));
724                                }
725                                trace!("File length is {}", sz);
726                                *this.file_size = sz;
727                                *this.file_offset = 0;
728                                *this.current_buf_offset = 0;
729                                // this.current_buf.clear();
730                                *this.step = ReadStreamSteps::Read;
731                                cx.waker().wake_by_ref();
732                                Poll::Pending
733                            }
734                            Err(e) => {
735                                // Failed to get file length
736                                Poll::Ready(Err(e))
737                            }
738                        }
739                    }
740                    Poll::Pending => {
741                        // Pending on getting file length
742                        Poll::Pending
743                    }
744                }
745            }
746            ReadStreamSteps::Read => {
747                if *this.file_offset >= *this.file_size {
748                    // Reach EOF
749                    *this.step = ReadStreamSteps::End;
750                    trace!("Reach EOF");
751                    Poll::Ready(std::io::Result::Ok(&this.current_buf[0..0]))
752                } else if current_buf.len() > *this.current_buf_offset {
753                    // There are some data left in the current buffer
754                    let end_pos = current_buf.len();
755                    Poll::Ready(std::io::Result::Ok(
756                        &this.current_buf[*this.current_buf_offset..end_pos],
757                    ))
758                } else if let Some(f) = this.current_future {
759                    // Reading operation in progress
760                    let p = f.poll_unpin(cx);
761                    match p {
762                        Poll::Ready(r) => {
763                            // Current future completed
764                            *this.current_future = None;
765                            match r {
766                                Ok(b) => {
767                                    // Got a buffer
768                                    // Reset current buffer and pos
769                                    *this.current_buf_offset = 0;
770                                    *this.current_buf = b;
771                                    *this.step = ReadStreamSteps::Read;
772                                    cx.waker().wake_by_ref();
773                                    Poll::Pending
774                                }
775                                Err(e) => {
776                                    // Read error
777                                    *this.step = ReadStreamSteps::End;
778                                    Poll::Ready(Err(e))
779                                }
780                            }
781                        }
782                        Poll::Pending => Poll::Pending,
783                    }
784                } else {
785                    // Nothing to provide, start reading
786                    let path = this.path.clone();
787                    let reader = this.reader.clone();
788                    let offset = *this.file_offset;
789                    let f = async move {
790                        reader
791                            .read_block(&path, offset, 4096)
792                            .await
793                            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
794                    };
795                    *this.current_future = Some(Box::pin(f));
796                    cx.waker().wake_by_ref();
797                    Poll::Pending
798                }
799            }
800            ReadStreamSteps::End => {
801                *this.step = ReadStreamSteps::Eof;
802                trace!("Reach EOF Again");
803                Poll::Ready(std::io::Result::Ok(&this.current_buf[0..0]))
804            }
805            ReadStreamSteps::Eof => {
806                panic!(
807                    "ReadStreamState must not be polled after it returned `Poll::Ready(Ok(&[]))`"
808                )
809            }
810        }
811    }
812
813    fn consume(self: Pin<&mut Self>, amt: usize) {
814        let this = self.project();
815        *this.current_buf_offset += amt;
816        *this.file_offset += amt;
817    }
818}
819
820impl AsyncRead for DbfsReadStream {
821    fn poll_read(
822        self: std::pin::Pin<&mut Self>,
823        cx: &mut std::task::Context<'_>,
824        buf: &mut [u8],
825    ) -> std::task::Poll<std::io::Result<usize>> {
826        let mut this = self.project();
827        let current_buf = &mut this.current_buf;
828        match *this.step {
829            ReadStreamSteps::Len => {
830                match this.len_future.poll_unpin(cx) {
831                    Poll::Ready(r) => {
832                        match r {
833                            Ok(sz) => {
834                                if sz == 0 {
835                                    // File is empty
836                                    debug!("File is empty");
837                                    return Poll::Ready(Ok(0));
838                                }
839                                // Got file length, start reading
840                                debug!("File length is {}", sz);
841                                *this.file_size = sz;
842                                *this.file_offset = 0;
843                                *this.current_buf_offset = 0;
844                                this.current_buf.clear();
845                                *this.step = ReadStreamSteps::Read;
846                                cx.waker().wake_by_ref();
847                                Poll::Pending
848                            }
849                            Err(e) => {
850                                // Failed to get file length
851                                Poll::Ready(Err(e))
852                            }
853                        }
854                    }
855                    Poll::Pending => {
856                        // Pending on getting file length
857                        Poll::Pending
858                    }
859                }
860            }
861            ReadStreamSteps::Read => {
862                if *this.file_offset >= *this.file_size {
863                    // Reach EOF
864                    *this.step = ReadStreamSteps::End;
865                    Poll::Ready(Ok(0))
866                } else if current_buf.len() > *this.current_buf_offset {
867                    // There are some data left in the current buffer
868                    let existing_sz = current_buf.len() - *this.current_buf_offset;
869                    let required_sz = buf.len();
870                    let sz = min(existing_sz, required_sz);
871                    let end_pos = *this.current_buf_offset + sz;
872                    buf[0..sz].copy_from_slice(&current_buf[*this.current_buf_offset..end_pos]);
873                    if end_pos >= this.current_buf.len() {
874                        // Current buffer exhausted
875                        *this.current_buf_offset = 0;
876                        this.current_buf.clear();
877                    } else {
878                        // Current buffer still has data
879                        *this.current_buf_offset = end_pos;
880                    }
881                    *this.file_offset += sz;
882                    *this.step = ReadStreamSteps::Read;
883                    Poll::Ready(std::io::Result::Ok(sz))
884                } else if let Some(f) = this.current_future {
885                    // Reading operation in progress
886                    let p = f.poll_unpin(cx);
887                    match p {
888                        Poll::Ready(r) => {
889                            // Current future completed
890                            *this.current_future = None;
891                            match r {
892                                Ok(b) => {
893                                    // Got a buffer
894                                    *this.current_buf_offset = 0;
895                                    *this.current_buf = b;
896                                    *this.step = ReadStreamSteps::Read;
897                                    cx.waker().wake_by_ref();
898                                    Poll::Pending
899                                }
900                                Err(e) => {
901                                    // Read error
902                                    *this.step = ReadStreamSteps::Eof;
903                                    Poll::Ready(Err(e))
904                                }
905                            }
906                        }
907                        Poll::Pending => Poll::Pending,
908                    }
909                } else {
910                    // Nothing to provide, start reading
911                    let path = this.path.clone();
912                    let reader = this.reader.clone();
913                    let offset = *this.file_offset;
914                    let f = async move {
915                        reader
916                            .read_block(&path, offset, 4096)
917                            .await
918                            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
919                    };
920                    *this.current_future = Some(Box::pin(f));
921                    cx.waker().wake_by_ref();
922                    Poll::Pending
923                }
924            }
925            ReadStreamSteps::End => {
926                *this.step = ReadStreamSteps::Eof;
927                trace!("Reach EOF Again");
928                Poll::Ready(std::io::Result::Ok(0))
929            }
930            ReadStreamSteps::Eof => {
931                panic!("ReadStreamState must not be polled after it returned `Poll::Ready(Ok(0))`")
932            }
933        }
934    }
935}
936
937fn strip_dbfs_prefix(path: &str) -> Result<&str> {
938    let ret = path.strip_prefix("dbfs:").unwrap_or(path);
939    if ret.starts_with("/") {
940        Ok(ret)
941    } else {
942        Err(DbfsError::InvalidDbfsPath(path.to_string()))
943    }
944}
945
#[cfg(test)]
mod tests {
    //! Tests for the DBFS client.
    //!
    //! Every test except `test_strip_prefix` builds a client through `init`, which reads
    //! `DATABRICKS_URL` and `DATABRICKS_API_TOKEN` from the environment and panics if
    //! either is unset. Each test writes to its own fixed remote path.

    use std::sync::Once;

    use dotenv;
    use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt};
    use rand::Rng;

    use super::*;

    /// Ensures `env_logger::init` is invoked at most once, however many tests call
    /// `init_logger`.
    static INIT_ENV_LOGGER: Once = Once::new();

    /// Calls `dotenv::dotenv()` (ignoring its result) and initializes `env_logger` once.
    pub fn init_logger() {
        dotenv::dotenv().ok();
        INIT_ENV_LOGGER.call_once(env_logger::init);
    }

    /// Sets up logging and builds a client from `DATABRICKS_URL` and
    /// `DATABRICKS_API_TOKEN`.
    ///
    /// # Panics
    ///
    /// Panics if either environment variable is missing or not valid Unicode.
    fn init() -> DbfsClient {
        init_logger();
        DbfsClient::new(
            &std::env::var("DATABRICKS_URL").unwrap(),
            &std::env::var("DATABRICKS_API_TOKEN").unwrap(),
        )
    }

    /// `strip_dbfs_prefix` accepts `/x` and `dbfs:/x`, and rejects paths that lack the
    /// leading slash with or without the scheme.
    #[test]
    fn test_strip_prefix() {
        assert_eq!(strip_dbfs_prefix("/abc").unwrap(), "/abc");
        assert_ne!(strip_dbfs_prefix("/abc").unwrap(), "/abcd");
        assert_eq!(strip_dbfs_prefix("dbfs:/abc").unwrap(), "/abc");
        assert_ne!(strip_dbfs_prefix("dbfs:/abc").unwrap(), "/abcd");
        assert!(matches!(
            strip_dbfs_prefix("abc"),
            Err(DbfsError::InvalidDbfsPath(..))
        ));
        assert!(matches!(
            strip_dbfs_prefix("dbfs:abc"),
            Err(DbfsError::InvalidDbfsPath(..))
        ));
    }

    /// Writes a small file, reads it back, checks the reported size, deletes it, and
    /// verifies that a subsequent read reports `ResourceDoesNotExist`.
    #[tokio::test]
    async fn read_write_delete() {
        let client = init();
        let expected = "foo\nbar\nbaz\nspam\n".as_bytes();
        client
            .write_file("/test_read_write_delete", expected)
            .await
            .unwrap();
        let data = client.read_file("/test_read_write_delete").await.unwrap();
        assert_eq!(data, expected);
        assert_eq!(
            client
                .get_file_status("/test_read_write_delete")
                .await
                .unwrap()
                .file_size,
            expected.len()
        );
        client.delete_file("/test_read_write_delete").await.unwrap();
        let ret = client.read_file("/test_read_write_delete").await;
        assert!(matches!(
            ret,
            Err(DbfsError::DbfsApiError(
                DbfsErrorCode::ResourceDoesNotExist,
                ..
            ))
        ));
    }

    /// Creates a local file, uploads it, and checks that the remote content matches.
    #[tokio::test]
    async fn upload_file() {
        let client = init();
        let expected = "foo\nbar\nbaz\nspam\n".as_bytes();
        let mut f = tokio::fs::File::create("/tmp/test_upload_file")
            .await
            .unwrap();
        f.write_all(expected).await.unwrap();
        // Make sure the bytes are on disk before the client opens the file.
        f.flush().await.unwrap();
        f.sync_all().await.unwrap();
        client
            .upload_file("/tmp/test_upload_file", "/test_upload_file")
            .await
            .unwrap();
        let data = client.read_file("/test_upload_file").await.unwrap();
        assert_eq!(data, expected);
    }

    /// Round-trips 2 MiB of random bytes through `write_file` / `read_file`.
    #[tokio::test]
    async fn large_file() {
        let mut rng = rand::thread_rng();

        // Exceeds CHUNK_SIZE
        let expected: Vec<u8> = (0..1024 * 1024 * 2).map(|_| rng.gen()).collect();

        let client = init();
        client
            .write_file("dbfs:/large_file", &expected)
            .await
            .unwrap();

        let buf = client.read_file("/large_file").await.unwrap();
        assert_eq!(buf, expected);
    }

    /// Streams a 100 000-byte file through `client.read` in 5000-byte reads and checks
    /// every returned slice against the written data.
    #[tokio::test]
    async fn test_read() {
        let client = init();
        let mut rng = rand::thread_rng();
        const TOTAL: usize = 100000;
        let expected: Vec<u8> = (0..TOTAL).map(|_| rng.gen()).collect();
        client
            .write_file("dbfs:/test_read", &expected)
            .await
            .unwrap();

        let mut offset = 0;
        let mut buf = [0; 5000];
        let mut s = client.read("dbfs:/test_read").unwrap();
        loop {
            // Unwrap rather than stop on `Err`: an I/O error must fail the test, even
            // one raised after all the data has already been delivered.
            let sz = s.read(&mut buf).await.unwrap();
            debug!("Read {} bytes at {}", buf.len(), offset);
            debug!("Got {} bytes", sz);
            if sz == 0 {
                break;
            }
            assert_eq!(&buf[0..sz], &expected[offset..offset + sz]);
            offset += sz;
        }
        assert_eq!(offset, TOTAL);
    }

    /// Writes 1000 newline-separated lines and reads them back one by one through the
    /// line stream built on `client.read`.
    #[tokio::test]
    async fn test_read_line() {
        let client = init();
        const TOTAL: usize = 1000;
        let expected: Vec<String> = (0..TOTAL).map(|n| format!("Line {}", n)).collect();
        client
            .write_file("dbfs:/test_read_line", expected.join("\n").as_bytes())
            .await
            .unwrap();

        let mut counter = 0;
        let mut lines = client.read("dbfs:/test_read_line").unwrap().lines();
        while let Some(line) = lines.next().await {
            // Unwrap rather than stop on `Some(Err(_))`, so a stream error is reported
            // as such instead of being silently treated as end of input.
            let line = line.unwrap();
            debug!("Line is `{}`", line);
            assert_eq!(line, format!("Line {}", counter));
            counter += 1;
        }
        assert_eq!(counter, TOTAL);
    }
}