1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4use futures_util::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
5use std::{num::NonZeroU64, task::ready};
6use tokio::io::{AsyncRead, AsyncSeek};
7
8type RequestFuture = BoxFuture<'static, reqwest::Result<ResponseStream>>;
9type ResponseStream = BoxStream<'static, reqwest::Result<bytes::Bytes>>;
10
11fn new_request(client: &reqwest::Client, url: reqwest::Url, pos: u64) -> RequestFuture {
12 client
13 .get(url)
14 .header(reqwest::header::RANGE, format!("bytes={}-", pos))
15 .send()
16 .map(|resp| match resp {
17 Ok(resp) => match resp.error_for_status() {
18 Ok(resp) => Ok(resp.bytes_stream().boxed()),
19 Err(e) => Err(e),
20 },
21 Err(e) => Err(e),
22 })
23 .boxed()
24}
25
/// A remote HTTP resource that can be read with tokio's [`AsyncRead`] and repositioned with
/// [`AsyncSeek`].
///
/// Bytes are fetched with `GET` requests carrying a `Range: bytes=<offset>-` header. Completing
/// a seek replaces the current body stream with a new request that starts at the target offset.
pub struct HttpFile {
    /// Client used for the initial `HEAD` and for every ranged `GET`.
    client: reqwest::Client,

    /// URL taken from the `HEAD` response (`resp.url()`); every `GET` is sent here.
    url: reqwest::Url,
    /// `Content-Length` from the `HEAD` response; `None` when absent, unparseable, or zero.
    content_length: Option<NonZeroU64>,
    /// `ETag` header from the `HEAD` response, if present and valid text.
    etag: Option<String>,
    /// `Content-Type` header from the `HEAD` response, if present and valid text.
    mime: Option<String>,

    /// Logical read offset into the resource.
    pos: u64,
    /// In-flight `GET`, paired with the byte offset it starts at.
    request: Option<(u64, RequestFuture)>,
    /// Body stream of the active `GET`, if any.
    response: Option<ResponseStream>,
    /// Unread tail of the most recent chunk that did not fit into the caller's buffer.
    last_chunk: Option<bytes::Bytes>,
    /// Target offset recorded by `start_seek` and not yet completed by `poll_complete`.
    seek: Option<u64>,
    /// Remaining retries for retryable body-stream errors; reset after each successful chunk.
    retry_attempt: u8,
}
51
52impl std::fmt::Debug for HttpFile {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("HttpFile")
55 .field("client", &self.client)
56 .field("url", &self.url)
57 .field("content_length", &self.content_length)
58 .field("etag", &self.etag)
59 .field("pos", &self.pos)
60 .field(
61 "request",
62 &self
63 .request
64 .as_ref()
65 .map(|(pos, _)| format!("request at {}", pos)),
66 )
67 .field("response", &"[response stream]")
68 .field("last_chunk", &self.last_chunk)
69 .field("seek", &self.seek)
70 .finish()
71 }
72}
73
74impl HttpFile {
75 pub fn url(&self) -> &reqwest::Url {
77 &self.url
78 }
79 pub fn content_length(&self) -> Option<u64> {
81 self.content_length.map(|v| v.get())
82 }
83 pub fn etag(&self) -> Option<&str> {
85 self.etag.as_deref()
86 }
87 pub fn mime(&self) -> Option<&str> {
89 self.mime.as_deref()
90 }
91}
92
93impl HttpFile {
94 pub async fn new(client: reqwest::Client, url: &str) -> reqwest::Result<Self> {
101 log::debug!("HEAD {}", url);
102 let resp = client.head(url).send().await?.error_for_status()?;
103 let etag = resp
104 .headers()
105 .get(reqwest::header::ETAG)
106 .and_then(|v| v.to_str().ok())
107 .map(|s| s.to_string());
108
109 let content_length = resp
110 .headers()
111 .get(reqwest::header::CONTENT_LENGTH)
112 .and_then(|v| v.to_str().ok())
113 .and_then(|s| s.parse::<NonZeroU64>().ok());
114
115 let mime = resp
116 .headers()
117 .get(reqwest::header::CONTENT_TYPE)
118 .and_then(|v| v.to_str().ok())
119 .map(|s| s.to_string());
120
121 let url = resp.url().clone();
122 let pos = 0;
123
124 Ok(Self {
125 client,
126 content_length,
127 url,
128 pos,
129 request: None,
130 response: None,
131 last_chunk: None,
132 seek: None,
133 etag,
134 retry_attempt: 3,
135 mime,
136 })
137 }
138
139 fn reset_retry(&mut self) {
140 self.retry_attempt = 3;
141 }
142}
143
144impl AsyncRead for HttpFile {
145 fn poll_read(
146 mut self: std::pin::Pin<&mut Self>,
147 cx: &mut std::task::Context<'_>,
148 buf: &mut tokio::io::ReadBuf<'_>,
149 ) -> std::task::Poll<std::io::Result<()>> {
150 if let Some(content_length) = self.content_length {
152 if self.pos >= content_length.get() {
153 return std::task::Poll::Ready(Ok(()));
154 }
155 }
156
157 if let Some(last_chunk) = self.last_chunk.take() {
158 let size = last_chunk.len().min(buf.remaining());
159 buf.put_slice(&last_chunk[..size]);
160 self.pos += size as u64;
161 if size < last_chunk.len() {
162 self.last_chunk = Some(last_chunk.slice(size..));
163 }
164 return std::task::Poll::Ready(Ok(()));
165 }
166
167 let no_response = self.response.is_none();
168 let no_request = self.request.is_none();
169
170 if no_response && no_request {
171 log::debug!(bytes_from = self.pos ; "GET {}", self.url);
172 let request = new_request(&self.client, self.url.clone(), self.pos);
173 self.request = Some((self.pos, request));
174 }
175
176 if let Some((_pos, request)) = self.request.as_mut() {
177 match ready!(request.poll_unpin(cx)) {
178 Ok(stream) => {
179 self.response = Some(stream);
181 self.request = None;
182 }
183 Err(err) => {
184 self.request = None;
185 return std::task::Poll::Ready(Err(std::io::Error::other(Box::new(err))));
186 }
187 }
188 }
189
190 let Some(response) = self.response.as_mut() else {
191 panic!("response should be Some after polled")
192 };
193
194 let Some(stream_chunks) = ready!(response.poll_next_unpin(cx)) else {
195 return std::task::Poll::Ready(Ok(()));
196 };
197
198 match stream_chunks {
199 Ok(chunk) => {
200 let size = chunk.len().min(buf.remaining());
201 buf.put_slice(&chunk[..size]);
202 self.pos += size as u64;
203 if size < chunk.len() {
204 self.last_chunk = Some(chunk.slice(size..));
205 }
206 self.reset_retry();
207 std::task::Poll::Ready(Ok(()))
208 }
209 Err(e) => {
210 if self.retry_attempt == 0 {
211 return std::task::Poll::Ready(Err(std::io::Error::other(Box::new(e))));
212 }
213
214 if e.is_timeout() || e.status().is_some_and(|s| s.is_server_error()) {
215 log::warn!("timeout, retrying... attempts left: {}", self.retry_attempt);
216 self.retry_attempt -= 1;
217 self.response = None;
218 return self.poll_read(cx, buf);
219 }
220
221 std::task::Poll::Ready(Err(std::io::Error::other(Box::new(e))))
222 }
223 }
224 }
225}
226
227impl AsyncSeek for HttpFile {
228 fn start_seek(
229 mut self: std::pin::Pin<&mut Self>,
230 position: std::io::SeekFrom,
231 ) -> std::io::Result<()> {
232 if let Some(content_length) = self.content_length {
233 let content_length = content_length.get();
234 let effective_pos = match position {
235 std::io::SeekFrom::Start(n) => n,
236 std::io::SeekFrom::End(n) => {
237 content_length.checked_add_signed(n).ok_or_else(|| {
238 std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid seek to end")
239 })?
240 }
241 std::io::SeekFrom::Current(n) => {
242 if n == 0 {
243 self.seek = Some(self.pos);
244 return Ok(());
245 }
246 self.pos.checked_add_signed(n).ok_or_else(|| {
247 std::io::Error::new(
248 std::io::ErrorKind::InvalidInput,
249 "invalid seek to current",
250 )
251 })?
252 }
253 };
254 if effective_pos > content_length {
255 return Err(std::io::Error::new(
256 std::io::ErrorKind::InvalidInput,
257 "invalid seek beyond end",
258 ));
259 }
260 self.seek = Some(effective_pos);
261 Ok(())
262 } else {
263 if matches!(position, std::io::SeekFrom::End(_)) {
264 return Err(std::io::Error::new(
265 std::io::ErrorKind::InvalidInput,
266 "cannot seek from end without known content length",
267 ));
268 }
269
270 let effective_pos = match position {
271 std::io::SeekFrom::Start(n) => n,
272 std::io::SeekFrom::End(_) => {
273 return Err(std::io::Error::new(
274 std::io::ErrorKind::InvalidInput,
275 "cannot seek from end without known content length",
276 ));
277 }
278 std::io::SeekFrom::Current(n) => {
279 if n == 0 {
280 self.seek = Some(self.pos);
281 return Ok(());
282 }
283 self.pos.checked_add_signed(n).ok_or_else(|| {
284 std::io::Error::new(
285 std::io::ErrorKind::InvalidInput,
286 "invalid seek to current",
287 )
288 })?
289 }
290 };
291 self.seek = Some(effective_pos);
292 Ok(())
293 }
294 }
295 fn poll_complete(
296 mut self: std::pin::Pin<&mut Self>,
297 cx: &mut std::task::Context<'_>,
298 ) -> std::task::Poll<std::io::Result<u64>> {
299 if self.seek == Some(self.pos) {
300 self.seek = None;
301 return std::task::Poll::Ready(Ok(self.pos));
302 }
303
304 let Some(seek_pos) = self.seek else {
305 return std::task::Poll::Ready(Ok(self.pos));
306 };
307
308 if let Some(content_length) = self.content_length {
310 if seek_pos >= content_length.get() {
311 self.pos = seek_pos;
312 self.seek = None;
313 self.request = None;
314 self.response = None;
315 self.last_chunk = None;
316 return std::task::Poll::Ready(Ok(self.pos));
317 }
318 }
319
320 if self.request.is_none() || self.request.as_ref().unwrap().0 != seek_pos {
321 log::debug!(bytes_from = self.pos ; "GET {}", self.url);
322 let request = new_request(&self.client, self.url.clone(), seek_pos);
323 self.request = Some((seek_pos, request));
324 }
325
326 match ready!(self.request.as_mut().unwrap().1.poll_unpin(cx)) {
327 Ok(stream) => {
328 self.response = Some(stream);
329 self.pos = seek_pos;
330 self.seek = None;
331 self.request = None;
332 self.last_chunk = None;
333 std::task::Poll::Ready(Ok(self.pos))
334 }
335 Err(err) => {
336 self.request = None;
337 std::task::Poll::Ready(Err(std::io::Error::other(Box::new(err))))
338 }
339 }
340 }
341}