remote_file/
lib.rs

1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4use futures_util::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
5use std::{num::NonZeroU64, task::ready};
6use tokio::io::{AsyncRead, AsyncSeek};
7
8type RequestFuture = BoxFuture<'static, reqwest::Result<ResponseStream>>;
9type ResponseStream = BoxStream<'static, reqwest::Result<bytes::Bytes>>;
10
11fn new_request(client: &reqwest::Client, url: reqwest::Url, pos: u64) -> RequestFuture {
12    client
13        .get(url)
14        .header(reqwest::header::RANGE, format!("bytes={}-", pos))
15        .send()
16        .map(|resp| match resp {
17            Ok(resp) => match resp.error_for_status() {
18                Ok(resp) => Ok(resp.bytes_stream().boxed()),
19                Err(e) => Err(e),
20            },
21            Err(e) => Err(e),
22        })
23        .boxed()
24}
25
26/// An remote file accessed over HTTP.
27/// Implements `AsyncRead` and `AsyncSeek` traits.
28/// 
29/// * Supports seeking and reading at arbitrary positions.
30/// * Uses HTTP Range requests to fetch data.
31/// * Handles transient network errors with retries.
32/// * `stream_position()` is cheap, as it is tracked locally.
33/// 
34pub struct HttpFile {
35    client: reqwest::Client,
36
37    // info
38    url: reqwest::Url,
39    content_length: Option<NonZeroU64>,
40    etag: Option<String>,
41    mime: Option<String>,
42
43    // inner states
44    pos: u64,
45    request: Option<(u64, RequestFuture)>,
46    response: Option<ResponseStream>,
47    last_chunk: Option<bytes::Bytes>,
48    seek: Option<u64>,
49    retry_attempt: u8,
50}
51
52impl std::fmt::Debug for HttpFile {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("HttpFile")
55            .field("client", &self.client)
56            .field("url", &self.url)
57            .field("content_length", &self.content_length)
58            .field("etag", &self.etag)
59            .field("pos", &self.pos)
60            .field(
61                "request",
62                &self
63                    .request
64                    .as_ref()
65                    .map(|(pos, _)| format!("request at {}", pos)),
66            )
67            .field("response", &"[response stream]")
68            .field("last_chunk", &self.last_chunk)
69            .field("seek", &self.seek)
70            .finish()
71    }
72}
73
74impl HttpFile {
75    /// url of the file
76    pub fn url(&self) -> &reqwest::Url {
77        &self.url
78    }
79    /// content length of the file(in bytes), if present
80    pub fn content_length(&self) -> Option<u64> {
81        self.content_length.map(|v| v.get())
82    }
83    /// etag of the file, if present
84    pub fn etag(&self) -> Option<&str> {
85        self.etag.as_deref()
86    }
87    /// Mime type of the file, if present
88    pub fn mime(&self) -> Option<&str> {
89        self.mime.as_deref() 
90    }
91}
92
93impl HttpFile {
94    /// Create a new `HttpFile` from a `reqwest::Client` and a file URL.
95    /// 
96    /// Arguments:
97    /// * `client`: A `reqwest::Client` instance to make HTTP requests.
98    /// * `url`: The URL of the file to access.
99    ///
100    pub async fn new(client: reqwest::Client, url: &str) -> reqwest::Result<Self> {
101        log::debug!("HEAD {}", url);
102        let resp = client.head(url).send().await?.error_for_status()?;
103        let etag = resp
104            .headers()
105            .get(reqwest::header::ETAG)
106            .and_then(|v| v.to_str().ok())
107            .map(|s| s.to_string());
108
109        let content_length = resp
110            .headers()
111            .get(reqwest::header::CONTENT_LENGTH)
112            .and_then(|v| v.to_str().ok())
113            .and_then(|s| s.parse::<NonZeroU64>().ok());
114
115        let mime = resp
116            .headers()
117            .get(reqwest::header::CONTENT_TYPE)
118            .and_then(|v| v.to_str().ok())
119            .map(|s| s.to_string());
120
121        let url = resp.url().clone();
122        let pos = 0;
123
124        Ok(Self {
125            client,
126            content_length,
127            url,
128            pos,
129            request: None,
130            response: None,
131            last_chunk: None,
132            seek: None,
133            etag,
134            retry_attempt: 3,
135            mime,
136        })
137    }
138
139    fn reset_retry(&mut self) {
140        self.retry_attempt = 3;
141    }
142}
143
144impl AsyncRead for HttpFile {
145    fn poll_read(
146        mut self: std::pin::Pin<&mut Self>,
147        cx: &mut std::task::Context<'_>,
148        buf: &mut tokio::io::ReadBuf<'_>,
149    ) -> std::task::Poll<std::io::Result<()>> {
150        if let Some(last_chunk) = self.last_chunk.take() {
151            let size = last_chunk.len().min(buf.remaining());
152            buf.put_slice(&last_chunk[..size]);
153            self.pos += size as u64;
154            if size < last_chunk.len() {
155                self.last_chunk = Some(last_chunk.slice(size..));
156            }
157            return std::task::Poll::Ready(Ok(()));
158        }
159
160        let no_response = self.response.is_none();
161        let no_request = self.request.is_none();
162
163        if no_response && no_request {
164            log::debug!(bytes_from = self.pos ; "GET {}", self.url);
165            let request = new_request(&self.client, self.url.clone(), self.pos);
166            self.request = Some((self.pos, request));
167        }
168
169        if let Some((_pos, request)) = self.request.as_mut() {
170            match ready!(request.poll_unpin(cx)) {
171                Ok(stream) => {
172                    // put response stream
173                    self.response = Some(stream);
174                    self.request = None;
175                }
176                Err(err) => {
177                    self.request = None;
178                    return std::task::Poll::Ready(Err(std::io::Error::other(Box::new(err))));
179                }
180            }
181        }
182
183        let Some(response) = self.response.as_mut() else {
184            panic!("response should be Some after polled")
185        };
186
187        let Some(stream_chunks) = ready!(response.poll_next_unpin(cx)) else {
188            return std::task::Poll::Ready(Ok(()));
189        };
190
191        match stream_chunks {
192            Ok(chunk) => {
193                let size = chunk.len().min(buf.remaining());
194                buf.put_slice(&chunk[..size]);
195                self.pos += size as u64;
196                if size < chunk.len() {
197                    self.last_chunk = Some(chunk.slice(size..));
198                }
199                self.reset_retry();
200                std::task::Poll::Ready(Ok(()))
201            }
202            Err(e) => {
203                if self.retry_attempt == 0 {
204                    return std::task::Poll::Ready(Err(std::io::Error::other(Box::new(e))));
205                }
206
207                if e.is_timeout() || e.status().is_some_and(|s| s.is_server_error()) {
208                    log::warn!("timeout, retrying... attempts left: {}", self.retry_attempt);
209                    self.retry_attempt -= 1;
210                    self.response = None;
211                    return self.poll_read(cx, buf);
212                }
213
214                std::task::Poll::Ready(Err(std::io::Error::other(Box::new(e))))
215            }
216        }
217    }
218}
219
220impl AsyncSeek for HttpFile {
221    fn start_seek(
222        mut self: std::pin::Pin<&mut Self>,
223        position: std::io::SeekFrom,
224    ) -> std::io::Result<()> {
225        if let Some(content_length) = self.content_length {
226            let content_length = content_length.get();
227            let effective_pos = match position {
228                std::io::SeekFrom::Start(n) => n,
229                std::io::SeekFrom::End(n) => {
230                    content_length.checked_add_signed(n).ok_or_else(|| {
231                        std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid seek to end")
232                    })?
233                }
234                std::io::SeekFrom::Current(n) => {
235                    if n == 0 {
236                        self.seek = Some(self.pos);
237                        return Ok(());
238                    }
239                    self.pos.checked_add_signed(n).ok_or_else(|| {
240                        std::io::Error::new(
241                            std::io::ErrorKind::InvalidInput,
242                            "invalid seek to current",
243                        )
244                    })?
245                }
246            };
247            if effective_pos > content_length {
248                return Err(std::io::Error::new(
249                    std::io::ErrorKind::InvalidInput,
250                    "invalid seek beyond end",
251                ));
252            }
253            self.seek = Some(effective_pos);
254            Ok(())
255        } else {
256            if matches!(position, std::io::SeekFrom::End(_)) {
257                return Err(std::io::Error::new(
258                    std::io::ErrorKind::InvalidInput,
259                    "cannot seek from end without known content length",
260                ));
261            }
262
263            let effective_pos = match position {
264                std::io::SeekFrom::Start(n) => n,
265                std::io::SeekFrom::End(_) => {
266                    return Err(std::io::Error::new(
267                        std::io::ErrorKind::InvalidInput,
268                        "cannot seek from end without known content length",
269                    ));
270                }
271                std::io::SeekFrom::Current(n) => {
272                    if n == 0 {
273                        self.seek = Some(self.pos);
274                        return Ok(());
275                    }
276                    self.pos.checked_add_signed(n).ok_or_else(|| {
277                        std::io::Error::new(
278                            std::io::ErrorKind::InvalidInput,
279                            "invalid seek to current",
280                        )
281                    })?
282                }
283            };
284            self.seek = Some(effective_pos);
285            Ok(())
286        }
287    }
288    fn poll_complete(
289        mut self: std::pin::Pin<&mut Self>,
290        cx: &mut std::task::Context<'_>,
291    ) -> std::task::Poll<std::io::Result<u64>> {
292        if self.seek == Some(self.pos) {
293            self.seek = None;
294            return std::task::Poll::Ready(Ok(self.pos));
295        }
296
297        let Some(seek_pos) = self.seek else {
298            return std::task::Poll::Ready(Ok(self.pos));
299        };
300
301        if self.request.is_none() || self.request.as_ref().unwrap().0 != seek_pos {
302            log::debug!(bytes_from = self.pos ; "GET {}", self.url);
303            let request = new_request(&self.client, self.url.clone(), seek_pos);
304            self.request = Some((seek_pos, request));
305        }
306
307        match ready!(self.request.as_mut().unwrap().1.poll_unpin(cx)) {
308            Ok(stream) => {
309                self.response = Some(stream);
310                self.pos = seek_pos;
311                self.seek = None;
312                self.request = None;
313                self.last_chunk = None;
314                std::task::Poll::Ready(Ok(self.pos))
315            }
316            Err(err) => {
317                self.request = None;
318                std::task::Poll::Ready(Err(std::io::Error::other(Box::new(err))))
319            }
320        }
321    }
322}