1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4use futures_util::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
5use std::{num::NonZeroU64, task::ready};
6use tokio::io::{AsyncRead, AsyncSeek};
7
8type RequestFuture = BoxFuture<'static, reqwest::Result<ResponseStream>>;
9type ResponseStream = BoxStream<'static, reqwest::Result<bytes::Bytes>>;
10
11fn new_request(client: &reqwest::Client, url: reqwest::Url, pos: u64) -> RequestFuture {
12 client
13 .get(url)
14 .header(reqwest::header::RANGE, format!("bytes={}-", pos))
15 .send()
16 .map(|resp| match resp {
17 Ok(resp) => match resp.error_for_status() {
18 Ok(resp) => Ok(resp.bytes_stream().boxed()),
19 Err(e) => Err(e),
20 },
21 Err(e) => Err(e),
22 })
23 .boxed()
24}
25
/// A remote file read over HTTP.
///
/// Implements [`AsyncRead`] and [`AsyncSeek`] by issuing ranged `GET`
/// requests against [`HttpFile::url`].
pub struct HttpFile {
    /// Client used for the initial `HEAD` probe and every ranged `GET`.
    client: reqwest::Client,

    /// URL taken from the `HEAD` response; all `GET`s are sent here.
    url: reqwest::Url,
    /// `Content-Length` from the `HEAD` response; `None` when missing,
    /// unparsable or zero.
    content_length: Option<NonZeroU64>,
    /// `ETag` header from the `HEAD` response, if present.
    etag: Option<String>,
    /// `Content-Type` header from the `HEAD` response, if present.
    mime: Option<String>,

    /// Offset of the next byte handed to the reader.
    pos: u64,
    /// In-flight ranged `GET`, paired with the offset it starts at.
    request: Option<(u64, RequestFuture)>,
    /// Body stream of the current response.
    response: Option<ResponseStream>,
    /// Unread tail of a chunk that did not fit into the caller's buffer.
    last_chunk: Option<bytes::Bytes>,
    /// Target offset recorded by `start_seek`, pending `poll_complete`.
    seek: Option<u64>,
    /// Remaining retries for body-stream errors (timeouts / 5xx).
    retry_attempt: u8,
}
51
52impl std::fmt::Debug for HttpFile {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("HttpFile")
55 .field("client", &self.client)
56 .field("url", &self.url)
57 .field("content_length", &self.content_length)
58 .field("etag", &self.etag)
59 .field("pos", &self.pos)
60 .field(
61 "request",
62 &self
63 .request
64 .as_ref()
65 .map(|(pos, _)| format!("request at {}", pos)),
66 )
67 .field("response", &"[response stream]")
68 .field("last_chunk", &self.last_chunk)
69 .field("seek", &self.seek)
70 .finish()
71 }
72}
73
74impl HttpFile {
75 pub fn url(&self) -> &reqwest::Url {
77 &self.url
78 }
79 pub fn content_length(&self) -> Option<u64> {
81 self.content_length.map(|v| v.get())
82 }
83 pub fn etag(&self) -> Option<&str> {
85 self.etag.as_deref()
86 }
87 pub fn mime(&self) -> Option<&str> {
89 self.mime.as_deref()
90 }
91}
92
93impl HttpFile {
94 pub async fn new(client: reqwest::Client, url: &str) -> reqwest::Result<Self> {
101 log::debug!("HEAD {}", url);
102 let resp = client.head(url).send().await?.error_for_status()?;
103 let etag = resp
104 .headers()
105 .get(reqwest::header::ETAG)
106 .and_then(|v| v.to_str().ok())
107 .map(|s| s.to_string());
108
109 let content_length = resp
110 .headers()
111 .get(reqwest::header::CONTENT_LENGTH)
112 .and_then(|v| v.to_str().ok())
113 .and_then(|s| s.parse::<NonZeroU64>().ok());
114
115 let mime = resp
116 .headers()
117 .get(reqwest::header::CONTENT_TYPE)
118 .and_then(|v| v.to_str().ok())
119 .map(|s| s.to_string());
120
121 let url = resp.url().clone();
122 let pos = 0;
123
124 Ok(Self {
125 client,
126 content_length,
127 url,
128 pos,
129 request: None,
130 response: None,
131 last_chunk: None,
132 seek: None,
133 etag,
134 retry_attempt: 3,
135 mime,
136 })
137 }
138
139 fn reset_retry(&mut self) {
140 self.retry_attempt = 3;
141 }
142}
143
144impl AsyncRead for HttpFile {
145 fn poll_read(
146 mut self: std::pin::Pin<&mut Self>,
147 cx: &mut std::task::Context<'_>,
148 buf: &mut tokio::io::ReadBuf<'_>,
149 ) -> std::task::Poll<std::io::Result<()>> {
150 if let Some(last_chunk) = self.last_chunk.take() {
151 let size = last_chunk.len().min(buf.remaining());
152 buf.put_slice(&last_chunk[..size]);
153 self.pos += size as u64;
154 if size < last_chunk.len() {
155 self.last_chunk = Some(last_chunk.slice(size..));
156 }
157 return std::task::Poll::Ready(Ok(()));
158 }
159
160 let no_response = self.response.is_none();
161 let no_request = self.request.is_none();
162
163 if no_response && no_request {
164 log::debug!(bytes_from = self.pos ; "GET {}", self.url);
165 let request = new_request(&self.client, self.url.clone(), self.pos);
166 self.request = Some((self.pos, request));
167 }
168
169 if let Some((_pos, request)) = self.request.as_mut() {
170 match ready!(request.poll_unpin(cx)) {
171 Ok(stream) => {
172 self.response = Some(stream);
174 self.request = None;
175 }
176 Err(err) => {
177 self.request = None;
178 return std::task::Poll::Ready(Err(std::io::Error::other(Box::new(err))));
179 }
180 }
181 }
182
183 let Some(response) = self.response.as_mut() else {
184 panic!("response should be Some after polled")
185 };
186
187 let Some(stream_chunks) = ready!(response.poll_next_unpin(cx)) else {
188 return std::task::Poll::Ready(Ok(()));
189 };
190
191 match stream_chunks {
192 Ok(chunk) => {
193 let size = chunk.len().min(buf.remaining());
194 buf.put_slice(&chunk[..size]);
195 self.pos += size as u64;
196 if size < chunk.len() {
197 self.last_chunk = Some(chunk.slice(size..));
198 }
199 self.reset_retry();
200 std::task::Poll::Ready(Ok(()))
201 }
202 Err(e) => {
203 if self.retry_attempt == 0 {
204 return std::task::Poll::Ready(Err(std::io::Error::other(Box::new(e))));
205 }
206
207 if e.is_timeout() || e.status().is_some_and(|s| s.is_server_error()) {
208 log::warn!("timeout, retrying... attempts left: {}", self.retry_attempt);
209 self.retry_attempt -= 1;
210 self.response = None;
211 return self.poll_read(cx, buf);
212 }
213
214 std::task::Poll::Ready(Err(std::io::Error::other(Box::new(e))))
215 }
216 }
217 }
218}
219
220impl AsyncSeek for HttpFile {
221 fn start_seek(
222 mut self: std::pin::Pin<&mut Self>,
223 position: std::io::SeekFrom,
224 ) -> std::io::Result<()> {
225 if let Some(content_length) = self.content_length {
226 let content_length = content_length.get();
227 let effective_pos = match position {
228 std::io::SeekFrom::Start(n) => n,
229 std::io::SeekFrom::End(n) => {
230 content_length.checked_add_signed(n).ok_or_else(|| {
231 std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid seek to end")
232 })?
233 }
234 std::io::SeekFrom::Current(n) => {
235 if n == 0 {
236 self.seek = Some(self.pos);
237 return Ok(());
238 }
239 self.pos.checked_add_signed(n).ok_or_else(|| {
240 std::io::Error::new(
241 std::io::ErrorKind::InvalidInput,
242 "invalid seek to current",
243 )
244 })?
245 }
246 };
247 if effective_pos > content_length {
248 return Err(std::io::Error::new(
249 std::io::ErrorKind::InvalidInput,
250 "invalid seek beyond end",
251 ));
252 }
253 self.seek = Some(effective_pos);
254 Ok(())
255 } else {
256 if matches!(position, std::io::SeekFrom::End(_)) {
257 return Err(std::io::Error::new(
258 std::io::ErrorKind::InvalidInput,
259 "cannot seek from end without known content length",
260 ));
261 }
262
263 let effective_pos = match position {
264 std::io::SeekFrom::Start(n) => n,
265 std::io::SeekFrom::End(_) => {
266 return Err(std::io::Error::new(
267 std::io::ErrorKind::InvalidInput,
268 "cannot seek from end without known content length",
269 ));
270 }
271 std::io::SeekFrom::Current(n) => {
272 if n == 0 {
273 self.seek = Some(self.pos);
274 return Ok(());
275 }
276 self.pos.checked_add_signed(n).ok_or_else(|| {
277 std::io::Error::new(
278 std::io::ErrorKind::InvalidInput,
279 "invalid seek to current",
280 )
281 })?
282 }
283 };
284 self.seek = Some(effective_pos);
285 Ok(())
286 }
287 }
288 fn poll_complete(
289 mut self: std::pin::Pin<&mut Self>,
290 cx: &mut std::task::Context<'_>,
291 ) -> std::task::Poll<std::io::Result<u64>> {
292 if self.seek == Some(self.pos) {
293 self.seek = None;
294 return std::task::Poll::Ready(Ok(self.pos));
295 }
296
297 let Some(seek_pos) = self.seek else {
298 return std::task::Poll::Ready(Ok(self.pos));
299 };
300
301 if self.request.is_none() || self.request.as_ref().unwrap().0 != seek_pos {
302 log::debug!(bytes_from = self.pos ; "GET {}", self.url);
303 let request = new_request(&self.client, self.url.clone(), seek_pos);
304 self.request = Some((seek_pos, request));
305 }
306
307 match ready!(self.request.as_mut().unwrap().1.poll_unpin(cx)) {
308 Ok(stream) => {
309 self.response = Some(stream);
310 self.pos = seek_pos;
311 self.seek = None;
312 self.request = None;
313 self.last_chunk = None;
314 std::task::Poll::Ready(Ok(self.pos))
315 }
316 Err(err) => {
317 self.request = None;
318 std::task::Poll::Ready(Err(std::io::Error::other(Box::new(err))))
319 }
320 }
321 }
322}