remote_file/
lib.rs

1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4use futures_util::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
5use std::{num::NonZeroU64, task::ready};
6use tokio::io::{AsyncRead, AsyncSeek};
7
/// Boxed future of an in-flight GET request; resolves to the response body
/// stream, or to the `reqwest` error that made the request fail.
type RequestFuture = BoxFuture<'static, reqwest::Result<ResponseStream>>;
/// Boxed stream of response body chunks; each item is one chunk of bytes or
/// the `reqwest` error encountered while receiving the body.
type ResponseStream = BoxStream<'static, reqwest::Result<bytes::Bytes>>;
10
11fn new_request(client: &reqwest::Client, url: reqwest::Url, pos: u64) -> RequestFuture {
12    client
13        .get(url)
14        .header(reqwest::header::RANGE, format!("bytes={}-", pos))
15        .send()
16        .map(|resp| match resp {
17            Ok(resp) => match resp.error_for_status() {
18                Ok(resp) => Ok(resp.bytes_stream().boxed()),
19                Err(e) => Err(e),
20            },
21            Err(e) => Err(e),
22        })
23        .boxed()
24}
25
/// A remote file accessed over HTTP.
/// Implements `AsyncRead` and `AsyncSeek` traits.
///
/// * Supports seeking and reading at arbitrary positions.
/// * Uses HTTP Range requests to fetch data.
/// * Handles transient network errors with retries.
/// * `stream_position()` is cheap, as it is tracked locally.
///
pub struct HttpFile {
    /// HTTP client used for every request issued for this file.
    client: reqwest::Client,

    // info
    /// URL the requests are sent to.
    url: reqwest::Url,
    /// Size of the file in bytes, when a non-zero value is known.
    content_length: Option<NonZeroU64>,
    /// `ETag` header value, if one was captured.
    etag: Option<String>,
    /// `Content-Type` header value, if one was captured.
    mime: Option<String>,

    // inner states
    /// Current read position (byte offset from the start of the file).
    pos: u64,
    /// In-flight GET request, paired with the byte offset it was issued for.
    request: Option<(u64, RequestFuture)>,
    /// Body stream of the most recently completed request, if any.
    response: Option<ResponseStream>,
    /// Bytes already received that did not fit into the caller's buffer.
    last_chunk: Option<bytes::Bytes>,
    /// Target offset recorded by `start_seek` until `poll_complete` applies it.
    seek: Option<u64>,
    /// Number of retries still allowed before a retryable error is returned.
    retry_attempt: u8,
}
51
52impl std::fmt::Debug for HttpFile {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("HttpFile")
55            .field("client", &self.client)
56            .field("url", &self.url)
57            .field("content_length", &self.content_length)
58            .field("etag", &self.etag)
59            .field("pos", &self.pos)
60            .field(
61                "request",
62                &self
63                    .request
64                    .as_ref()
65                    .map(|(pos, _)| format!("request at {}", pos)),
66            )
67            .field("response", &"[response stream]")
68            .field("last_chunk", &self.last_chunk)
69            .field("seek", &self.seek)
70            .finish()
71    }
72}
73
74impl HttpFile {
75    /// url of the file
76    pub fn url(&self) -> &reqwest::Url {
77        &self.url
78    }
79    /// content length of the file(in bytes), if present
80    pub fn content_length(&self) -> Option<u64> {
81        self.content_length.map(|v| v.get())
82    }
83    /// etag of the file, if present
84    pub fn etag(&self) -> Option<&str> {
85        self.etag.as_deref()
86    }
87    /// Mime type of the file, if present
88    pub fn mime(&self) -> Option<&str> {
89        self.mime.as_deref()
90    }
91}
92
93impl HttpFile {
94    /// Create a new `HttpFile` from a `reqwest::Client` and a file URL.
95    ///
96    /// Arguments:
97    /// * `client`: A `reqwest::Client` instance to make HTTP requests.
98    /// * `url`: The URL of the file to access.
99    ///
100    pub async fn new(client: reqwest::Client, url: &str) -> reqwest::Result<Self> {
101        log::debug!("HEAD {}", url);
102        let resp = client.head(url).send().await?.error_for_status()?;
103        let etag = resp
104            .headers()
105            .get(reqwest::header::ETAG)
106            .and_then(|v| v.to_str().ok())
107            .map(|s| s.to_string());
108
109        let content_length = resp
110            .headers()
111            .get(reqwest::header::CONTENT_LENGTH)
112            .and_then(|v| v.to_str().ok())
113            .and_then(|s| s.parse::<NonZeroU64>().ok());
114
115        let mime = resp
116            .headers()
117            .get(reqwest::header::CONTENT_TYPE)
118            .and_then(|v| v.to_str().ok())
119            .map(|s| s.to_string());
120
121        let url = resp.url().clone();
122        let pos = 0;
123
124        Ok(Self {
125            client,
126            content_length,
127            url,
128            pos,
129            request: None,
130            response: None,
131            last_chunk: None,
132            seek: None,
133            etag,
134            retry_attempt: 3,
135            mime,
136        })
137    }
138
139    fn reset_retry(&mut self) {
140        self.retry_attempt = 3;
141    }
142}
143
144impl AsyncRead for HttpFile {
145    fn poll_read(
146        mut self: std::pin::Pin<&mut Self>,
147        cx: &mut std::task::Context<'_>,
148        buf: &mut tokio::io::ReadBuf<'_>,
149    ) -> std::task::Poll<std::io::Result<()>> {
150        // Check if we're at or beyond the end of file
151        if let Some(content_length) = self.content_length {
152            if self.pos >= content_length.get() {
153                return std::task::Poll::Ready(Ok(()));
154            }
155        }
156
157        if let Some(last_chunk) = self.last_chunk.take() {
158            let size = last_chunk.len().min(buf.remaining());
159            buf.put_slice(&last_chunk[..size]);
160            self.pos += size as u64;
161            if size < last_chunk.len() {
162                self.last_chunk = Some(last_chunk.slice(size..));
163            }
164            return std::task::Poll::Ready(Ok(()));
165        }
166
167        let no_response = self.response.is_none();
168        let no_request = self.request.is_none();
169
170        if no_response && no_request {
171            log::debug!(bytes_from = self.pos ; "GET {}", self.url);
172            let request = new_request(&self.client, self.url.clone(), self.pos);
173            self.request = Some((self.pos, request));
174        }
175
176        if let Some((_pos, request)) = self.request.as_mut() {
177            match ready!(request.poll_unpin(cx)) {
178                Ok(stream) => {
179                    // put response stream
180                    self.response = Some(stream);
181                    self.request = None;
182                }
183                Err(err) => {
184                    self.request = None;
185                    return std::task::Poll::Ready(Err(std::io::Error::other(Box::new(err))));
186                }
187            }
188        }
189
190        let Some(response) = self.response.as_mut() else {
191            panic!("response should be Some after polled")
192        };
193
194        let Some(stream_chunks) = ready!(response.poll_next_unpin(cx)) else {
195            return std::task::Poll::Ready(Ok(()));
196        };
197
198        match stream_chunks {
199            Ok(chunk) => {
200                let size = chunk.len().min(buf.remaining());
201                buf.put_slice(&chunk[..size]);
202                self.pos += size as u64;
203                if size < chunk.len() {
204                    self.last_chunk = Some(chunk.slice(size..));
205                }
206                self.reset_retry();
207                std::task::Poll::Ready(Ok(()))
208            }
209            Err(e) => {
210                if self.retry_attempt == 0 {
211                    return std::task::Poll::Ready(Err(std::io::Error::other(Box::new(e))));
212                }
213
214                if e.is_timeout() || e.status().is_some_and(|s| s.is_server_error()) {
215                    log::warn!("timeout, retrying... attempts left: {}", self.retry_attempt);
216                    self.retry_attempt -= 1;
217                    self.response = None;
218                    return self.poll_read(cx, buf);
219                }
220
221                std::task::Poll::Ready(Err(std::io::Error::other(Box::new(e))))
222            }
223        }
224    }
225}
226
227impl AsyncSeek for HttpFile {
228    fn start_seek(
229        mut self: std::pin::Pin<&mut Self>,
230        position: std::io::SeekFrom,
231    ) -> std::io::Result<()> {
232        if let Some(content_length) = self.content_length {
233            let content_length = content_length.get();
234            let effective_pos = match position {
235                std::io::SeekFrom::Start(n) => n,
236                std::io::SeekFrom::End(n) => {
237                    content_length.checked_add_signed(n).ok_or_else(|| {
238                        std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid seek to end")
239                    })?
240                }
241                std::io::SeekFrom::Current(n) => {
242                    if n == 0 {
243                        self.seek = Some(self.pos);
244                        return Ok(());
245                    }
246                    self.pos.checked_add_signed(n).ok_or_else(|| {
247                        std::io::Error::new(
248                            std::io::ErrorKind::InvalidInput,
249                            "invalid seek to current",
250                        )
251                    })?
252                }
253            };
254            if effective_pos > content_length {
255                return Err(std::io::Error::new(
256                    std::io::ErrorKind::InvalidInput,
257                    "invalid seek beyond end",
258                ));
259            }
260            self.seek = Some(effective_pos);
261            Ok(())
262        } else {
263            if matches!(position, std::io::SeekFrom::End(_)) {
264                return Err(std::io::Error::new(
265                    std::io::ErrorKind::InvalidInput,
266                    "cannot seek from end without known content length",
267                ));
268            }
269
270            let effective_pos = match position {
271                std::io::SeekFrom::Start(n) => n,
272                std::io::SeekFrom::End(_) => {
273                    return Err(std::io::Error::new(
274                        std::io::ErrorKind::InvalidInput,
275                        "cannot seek from end without known content length",
276                    ));
277                }
278                std::io::SeekFrom::Current(n) => {
279                    if n == 0 {
280                        self.seek = Some(self.pos);
281                        return Ok(());
282                    }
283                    self.pos.checked_add_signed(n).ok_or_else(|| {
284                        std::io::Error::new(
285                            std::io::ErrorKind::InvalidInput,
286                            "invalid seek to current",
287                        )
288                    })?
289                }
290            };
291            self.seek = Some(effective_pos);
292            Ok(())
293        }
294    }
295    fn poll_complete(
296        mut self: std::pin::Pin<&mut Self>,
297        cx: &mut std::task::Context<'_>,
298    ) -> std::task::Poll<std::io::Result<u64>> {
299        if self.seek == Some(self.pos) {
300            self.seek = None;
301            return std::task::Poll::Ready(Ok(self.pos));
302        }
303
304        let Some(seek_pos) = self.seek else {
305            return std::task::Poll::Ready(Ok(self.pos));
306        };
307
308        // If seeking to or beyond EOF, just update position without making a request
309        if let Some(content_length) = self.content_length {
310            if seek_pos >= content_length.get() {
311                self.pos = seek_pos;
312                self.seek = None;
313                self.request = None;
314                self.response = None;
315                self.last_chunk = None;
316                return std::task::Poll::Ready(Ok(self.pos));
317            }
318        }
319
320        if self.request.is_none() || self.request.as_ref().unwrap().0 != seek_pos {
321            log::debug!(bytes_from = self.pos ; "GET {}", self.url);
322            let request = new_request(&self.client, self.url.clone(), seek_pos);
323            self.request = Some((seek_pos, request));
324        }
325
326        match ready!(self.request.as_mut().unwrap().1.poll_unpin(cx)) {
327            Ok(stream) => {
328                self.response = Some(stream);
329                self.pos = seek_pos;
330                self.seek = None;
331                self.request = None;
332                self.last_chunk = None;
333                std::task::Poll::Ready(Ok(self.pos))
334            }
335            Err(err) => {
336                self.request = None;
337                std::task::Poll::Ready(Err(std::io::Error::other(Box::new(err))))
338            }
339        }
340    }
341}