Skip to main content

hexz_store/http/
async_client.rs

1//! HTTP storage backend with embedded Tokio runtime.
2
3use crate::runtime::global_handle;
4use crate::utils::validate_url;
5use bytes::Bytes;
6use hexz_common::{Error, Result};
7use hexz_core::store::StorageBackend;
8use reqwest::Client;
9use reqwest::header::HeaderMap;
10use reqwest::redirect::Policy;
11use std::io::{Error as IoError, ErrorKind};
12use tokio::runtime::Handle;
13
14const MAX_REDIRECTS: usize = 10;
15
16async fn send_with_redirects(
17    client: &Client,
18    method: reqwest::Method,
19    url: &str,
20    extra_headers: HeaderMap,
21) -> Result<reqwest::Response> {
22    let mut current_url = url.to_string();
23    let mut current_method = method;
24
25    for _ in 0..=MAX_REDIRECTS {
26        let resp = client
27            .request(current_method.clone(), &current_url)
28            .headers(extra_headers.clone())
29            .send()
30            .await
31            .map_err(|e| Error::Io(IoError::other(e)))?;
32
33        if !resp.status().is_redirection() {
34            if !resp.status().is_success() {
35                return Err(Error::Io(IoError::other(format!(
36                    "HTTP {} for {}",
37                    resp.status(),
38                    current_url
39                ))));
40            }
41            return Ok(resp);
42        }
43
44        let location = resp
45            .headers()
46            .get(reqwest::header::LOCATION)
47            .and_then(|v| v.to_str().ok())
48            .ok_or_else(|| {
49                Error::Io(IoError::new(
50                    ErrorKind::InvalidData,
51                    "Redirect without Location header",
52                ))
53            })?
54            .to_string();
55
56        if resp.status().as_u16() == 303 {
57            current_method = reqwest::Method::GET;
58        }
59        current_url = location;
60    }
61
62    Err(Error::Io(IoError::other(format!(
63        "Too many redirects (>{MAX_REDIRECTS})"
64    ))))
65}
66
67/// HTTP storage backend with embedded Tokio runtime.
68#[derive(Debug)]
69pub struct HttpBackend {
70    url: String,
71    client: Client,
72    len: u64,
73    handle: Handle,
74}
75
76impl HttpBackend {
77    /// Creates an HTTP backend, validates the URL, and fetches file length via HEAD.
78    pub fn new(url: &str, allow_restricted: bool) -> Result<Self> {
79        let safe_url = validate_url(url, allow_restricted)?;
80        let handle = global_handle().map_err(Error::Io)?;
81        let client = Client::builder()
82            .redirect(Policy::none())
83            .build()
84            .map_err(|e| Error::Io(IoError::other(e)))?;
85
86        let len = tokio::task::block_in_place(|| {
87            handle.block_on(async {
88                let resp = send_with_redirects(
89                    &client,
90                    reqwest::Method::HEAD,
91                    &safe_url,
92                    HeaderMap::new(),
93                )
94                .await?;
95                resp.headers()
96                    .get(reqwest::header::CONTENT_LENGTH)
97                    .and_then(|v| v.to_str().ok())
98                    .and_then(|s| s.parse::<u64>().ok())
99                    .ok_or_else(|| {
100                        Error::Io(IoError::new(
101                            ErrorKind::InvalidData,
102                            "Missing Content-Length header",
103                        ))
104                    })
105            })
106        })?;
107
108        Ok(Self {
109            url: safe_url,
110            client,
111            len,
112            handle,
113        })
114    }
115}
116
117impl StorageBackend for HttpBackend {
118    fn read_exact(&self, offset: u64, len: usize) -> Result<Bytes> {
119        if len == 0 {
120            return Ok(Bytes::new());
121        }
122        let end = offset + len as u64 - 1;
123        let mut headers = HeaderMap::new();
124        _ = headers.insert(
125            reqwest::header::RANGE,
126            format!("bytes={offset}-{end}")
127                .parse()
128                .map_err(|e: reqwest::header::InvalidHeaderValue| Error::Io(IoError::other(e)))?,
129        );
130
131        tokio::task::block_in_place(|| {
132            self.handle.block_on(async {
133                let resp =
134                    send_with_redirects(&self.client, reqwest::Method::GET, &self.url, headers)
135                        .await?;
136                let bytes = resp
137                    .bytes()
138                    .await
139                    .map_err(|e| Error::Io(IoError::other(e)))?;
140                if bytes.len() != len {
141                    return Err(Error::Io(IoError::new(
142                        ErrorKind::UnexpectedEof,
143                        format!("Expected {} bytes, got {}", len, bytes.len()),
144                    )));
145                }
146                Ok(bytes)
147            })
148        })
149    }
150
151    fn len(&self) -> u64 {
152        self.len
153    }
154}