Skip to main content

hexz_core/store/http/
async_client.rs

//! HTTP storage backend driven by a Tokio runtime handle.
//!
//! This module provides an HTTP storage backend that runs the async `reqwest`
//! client on a Tokio runtime handle obtained from `global_handle()`. This allows
//! the backend to present a synchronous `StorageBackend` interface while
//! leveraging async I/O internally for efficient concurrent operations.
7//!
8//! # Architecture
9//!
//! The [`HttpBackend`] holds a Tokio runtime `Handle` (obtained from
//! `crate::store::runtime::global_handle`) and uses `Handle::block_on()` to
//! execute async operations synchronously. This design:
12//! - Maintains compatibility with the synchronous `StorageBackend` trait
13//! - Enables efficient connection pooling and concurrent requests via `reqwest`
14//! - Provides async benefits (low memory overhead per connection) without requiring
15//!   callers to use async/await
16//!
17//! # Redirect Handling
18//!
19//! reqwest's built-in redirect policy drops certain headers (including `Range`)
20//! on cross-origin redirects. Since CDNs like GitHub Releases → Azure Blob
21//! require cross-origin redirects with range headers, this backend disables
22//! auto-redirects and follows them manually, re-sending all headers on each hop.
23//!
24//! # Thread Safety
25//!
26//! The backend is fully thread-safe (`Send + Sync`):
27//! - The `reqwest::Client` is designed for concurrent use
//! - The Tokio runtime `Handle` is shared safely across threads
29//! - Multiple threads can call `read_exact()` concurrently without coordination
30//!
31//! # Security
32//!
33//! This backend validates URLs to prevent SSRF attacks:
34//! - Blocks access to localhost and private networks by default
35//! - Set `allow_restricted: true` only in trusted environments
36//! - See [`validate_url`](crate::store::utils::validate_url) for details
37//!
38//! # Examples
39//!
40//! ```no_run
41//! use hexz_core::store::http::HttpBackend;
42//! use hexz_core::store::StorageBackend;
43//!
44//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
45//! let backend = HttpBackend::new(
46//!     "https://cdn.example.com/snapshots/data.hxz".to_string(),
47//!     false // block restricted IPs
48//! )?;
49//!
50//! println!("Snapshot size: {} bytes", backend.len());
51//!
52//! let header = backend.read_exact(0, 512)?;
53//! assert_eq!(header.len(), 512);
54//! # Ok(())
55//! # }
56//! ```
57
58use crate::store::StorageBackend;
59use crate::store::runtime::global_handle;
60use crate::store::utils::validate_url;
61use bytes::Bytes;
62use hexz_common::{Error, Result};
63use reqwest::Client;
64use reqwest::header::HeaderMap;
65use reqwest::redirect::Policy;
66use std::io::{Error as IoError, ErrorKind};
67use tokio::runtime::Handle;
68
69/// Maximum number of redirects to follow before giving up.
70const MAX_REDIRECTS: usize = 10;
71
72/// Send an HTTP request, manually following redirects while preserving headers.
73///
74/// reqwest's built-in redirect policy strips headers like `Range` on cross-origin
75/// redirects. This function follows redirects manually and re-sends the provided
76/// extra headers on every hop, ensuring range requests work through CDNs.
77async fn send_with_redirects(
78    client: &Client,
79    method: reqwest::Method,
80    url: &str,
81    extra_headers: HeaderMap,
82) -> Result<reqwest::Response> {
83    let mut current_url = url.to_string();
84    let mut current_method = method;
85
86    for _ in 0..=MAX_REDIRECTS {
87        let resp = client
88            .request(current_method.clone(), &current_url)
89            .headers(extra_headers.clone())
90            .send()
91            .await
92            .map_err(|e| Error::Io(IoError::other(e)))?;
93
94        if !resp.status().is_redirection() {
95            if !resp.status().is_success() {
96                return Err(Error::Io(IoError::other(format!(
97                    "HTTP {} for {}",
98                    resp.status(),
99                    current_url
100                ))));
101            }
102            return Ok(resp);
103        }
104
105        let location = resp
106            .headers()
107            .get(reqwest::header::LOCATION)
108            .and_then(|v| v.to_str().ok())
109            .ok_or_else(|| {
110                Error::Io(IoError::new(
111                    ErrorKind::InvalidData,
112                    "Redirect without Location header",
113                ))
114            })?
115            .to_string();
116
117        // 303 See Other: switch to GET; all others preserve method
118        if resp.status().as_u16() == 303 {
119            current_method = reqwest::Method::GET;
120        }
121
122        current_url = location;
123    }
124
125    Err(Error::Io(IoError::other(format!(
126        "Too many redirects (>{MAX_REDIRECTS})"
127    ))))
128}
129
130/// HTTP storage backend with embedded Tokio runtime.
131///
132/// This backend wraps an async `reqwest::Client` and Tokio `Runtime` to provide
133/// synchronous `StorageBackend` operations while leveraging async I/O internally.
134/// It validates URLs for security, maintains a connection pool, and performs
135/// range requests to fetch specific byte ranges.
136///
137/// # Examples
138///
139/// ```no_run
140/// use hexz_core::store::http::HttpBackend;
141/// use hexz_core::store::StorageBackend;
142///
143/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
144/// let backend = HttpBackend::new(
145///     "https://example.com/snapshot.hxz".to_string(),
146///     false
147/// )?;
148///
149/// let data = backend.read_exact(8192, 4096)?;
150/// assert_eq!(data.len(), 4096);
151/// # Ok(())
152/// # }
153/// ```
154#[derive(Debug)]
155pub struct HttpBackend {
156    url: String,
157    client: Client,
158    len: u64,
159    handle: Handle,
160}
161
162impl HttpBackend {
163    /// Creates a new HTTP backend by validating the URL and fetching file metadata.
164    ///
165    /// This constructor:
166    /// 1. Validates the URL for security (blocks restricted IPs unless allowed)
167    /// 2. Creates a Tokio runtime for executing async operations
168    /// 3. Sends an async HEAD request to verify the server and fetch file size
169    /// 4. Extracts the `Content-Length` header to determine snapshot size
170    ///
171    /// # Parameters
172    ///
173    /// - `url`: The HTTP/HTTPS URL of the snapshot file
174    /// - `allow_restricted`: If `false`, blocks access to localhost and private networks
175    pub fn new(url: String, allow_restricted: bool) -> Result<Self> {
176        let safe_url = validate_url(&url, allow_restricted)?;
177
178        let handle = global_handle();
179
180        // Disable auto-redirects so we can manually follow them while
181        // preserving headers (e.g. Range) across cross-origin redirects.
182        let client = Client::builder()
183            .redirect(Policy::none())
184            .build()
185            .map_err(|e| Error::Io(IoError::other(e)))?;
186
187        let len = handle.block_on(async {
188            let resp =
189                send_with_redirects(&client, reqwest::Method::HEAD, &safe_url, HeaderMap::new())
190                    .await?;
191
192            resp.headers()
193                .get(reqwest::header::CONTENT_LENGTH)
194                .and_then(|val| val.to_str().ok())
195                .and_then(|s| s.parse::<u64>().ok())
196                .ok_or_else(|| {
197                    Error::Io(IoError::new(
198                        ErrorKind::InvalidData,
199                        "Missing Content-Length header",
200                    ))
201                })
202        })?;
203
204        Ok(Self {
205            url: safe_url,
206            client,
207            len,
208            handle,
209        })
210    }
211}
212
213impl StorageBackend for HttpBackend {
214    fn read_exact(&self, offset: u64, len: usize) -> Result<Bytes> {
215        if len == 0 {
216            return Ok(Bytes::new());
217        }
218        let end = offset + len as u64 - 1;
219
220        let mut headers = HeaderMap::new();
221        headers.insert(
222            reqwest::header::RANGE,
223            format!("bytes={offset}-{end}").parse().unwrap(),
224        );
225
226        self.handle.block_on(async {
227            let resp =
228                send_with_redirects(&self.client, reqwest::Method::GET, &self.url, headers).await?;
229
230            let bytes = resp
231                .bytes()
232                .await
233                .map_err(|e| Error::Io(IoError::other(e)))?;
234
235            if bytes.len() != len {
236                return Err(Error::Io(IoError::new(
237                    ErrorKind::UnexpectedEof,
238                    format!("Expected {} bytes, got {}", len, bytes.len()),
239                )));
240            }
241
242            Ok(bytes)
243        })
244    }
245
246    fn len(&self) -> u64 {
247        self.len
248    }
249}