Skip to main content

boundless_market/prover_utils/
storage.rs

1// Copyright 2026 Boundless Foundation, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! URI handling for fetching data from HTTP, S3, and file URLs.
16
use std::env;
use std::path::Path;

use alloy::primitives::bytes::Buf;
#[cfg(feature = "s3")]
use aws_config::retry::RetryConfig;
#[cfg(feature = "s3")]
use aws_sdk_s3::{
    config::{ProvideCredentials, SharedCredentialsProvider},
    error::ProvideErrorMetadata,
    Client as S3Client,
};
use futures_util::StreamExt;
use thiserror::Error;

use super::config::MarketConfig;
31
/// Name of the environment variable holding an IAM role ARN to assume before S3 access.
#[cfg(feature = "s3")]
const ENV_VAR_ROLE_ARN: &str = "AWS_ROLE_ARN";
34
/// Returns `true` if the dev mode environment variable is enabled.
///
/// Recognizes `1`, `true`, and `yes` (case-insensitive). An unset variable,
/// a non-unicode value, or any other value counts as disabled.
fn is_dev_mode() -> bool {
    env::var("RISC0_DEV_MODE")
        .map_or(false, |v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
}
43
/// Returns `true` if the `ALLOW_LOCAL_FILE_STORAGE` environment variable is enabled.
///
/// Recognizes `1`, `true`, and `yes` (case-insensitive), mirroring the
/// dev-mode check. An unset or unrecognized value counts as disabled.
fn allow_local_file_storage() -> bool {
    env::var("ALLOW_LOCAL_FILE_STORAGE")
        .map_or(false, |v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
}
52
53/// Returns `true` if file:// URLs are allowed based on environment variables.
54fn allow_file_urls() -> bool {
55    is_dev_mode() || allow_local_file_storage()
56}
57
/// Errors that can occur during URI fetching.
///
/// Marked `#[non_exhaustive]` so new variants can be added without breaking
/// downstream `match`es.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum StorageError {
    /// Unsupported URI scheme (also used for schemes disabled in the current
    /// context, e.g. `file://` without the opt-in env vars).
    #[error("unsupported URI scheme: {0}")]
    UnsupportedScheme(String),

    /// Failed to parse URL.
    #[error("failed to parse URL: {0}")]
    UriParse(#[from] url::ParseError),

    /// Invalid URL (structurally parseable but missing a required part,
    /// e.g. host or S3 key).
    #[error("invalid URL: {0}")]
    InvalidUrl(&'static str),

    /// Resource size exceeds maximum allowed. The payload is the observed
    /// (offending) size in bytes, not the configured limit.
    #[error("resource size exceeds maximum allowed size ({0} bytes)")]
    SizeLimitExceeded(usize),

    /// File I/O error from `file://` fetches.
    #[error("file error: {0}")]
    File(#[from] std::io::Error),

    /// HTTP error (transport failure or non-success status), stringified.
    #[error("HTTP error: {0}")]
    Http(String),

    /// AWS S3 error, stringified as `code: message`.
    #[cfg(feature = "s3")]
    #[error("AWS S3 error: {0}")]
    S3(String),
}
91
92/// Fetch data from a URI with default config. Supports HTTP/HTTPS and S3 schemes.
93///
94/// For more control over fetching behavior (size limits, retries, caching, file:// support),
95/// use [`fetch_uri_with_config`] instead.
96#[allow(unused)]
97pub async fn fetch_uri(uri: &str) -> Result<Vec<u8>, StorageError> {
98    fetch_uri_with_config(uri, &MarketConfig::default()).await
99}
100
/// Fetch data from a URI with the given market configuration.
///
/// Supports:
/// - `http://` and `https://` URLs
/// - `s3://bucket/key` URLs (requires AWS credentials)
/// - `file://` URLs (only if `RISC0_DEV_MODE` or `ALLOW_LOCAL_FILE_STORAGE` env var is set)
///
/// # Errors
///
/// Returns [`StorageError::UriParse`] for unparseable input,
/// [`StorageError::UnsupportedScheme`] for any other scheme (or `file://` when
/// not allowed), plus the scheme-specific errors of the underlying fetchers.
pub async fn fetch_uri_with_config(
    uri: &str,
    config: &MarketConfig,
) -> Result<Vec<u8>, StorageError> {
    let parsed = url::Url::parse(uri)?;

    match parsed.scheme() {
        "file" => {
            // Local file access is gated behind explicit opt-in env vars so
            // URI fetching cannot be abused as an arbitrary local file read.
            if !allow_file_urls() {
                return Err(StorageError::UnsupportedScheme(
                    "file (not allowed in this context)".to_string(),
                ));
            }
            // NOTE(review): `Url::path()` keeps percent-encoding (e.g. `%20`)
            // and is not a Windows drive path — confirm callers only pass
            // plain POSIX paths, or consider `Url::to_file_path()`.
            fetch_file(parsed.path(), Some(config.max_file_size)).await
        }
        "http" | "https" => fetch_http(parsed, config).await,
        // This arm only exists when built with the `s3` feature; without it,
        // `s3://` URIs fall through to `UnsupportedScheme`.
        #[cfg(feature = "s3")]
        "s3" => fetch_s3(parsed, config).await,
        scheme => Err(StorageError::UnsupportedScheme(scheme.to_string())),
    }
}
128
129/// Fetch data from a local file.
130async fn fetch_file(path: &str, max_size: Option<usize>) -> Result<Vec<u8>, StorageError> {
131    let metadata = tokio::fs::metadata(path).await?;
132    let size = metadata.len() as usize;
133
134    if let Some(max) = max_size {
135        if size > max {
136            return Err(StorageError::SizeLimitExceeded(size));
137        }
138    }
139
140    Ok(tokio::fs::read(path).await?)
141}
142
143/// Fetch data from an HTTP/HTTPS URL.
144async fn fetch_http(url: url::Url, config: &MarketConfig) -> Result<Vec<u8>, StorageError> {
145    use reqwest_middleware::ClientBuilder;
146    use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
147
148    if !url.has_host() {
149        return Err(StorageError::InvalidUrl("missing host"));
150    }
151
152    let mut builder = ClientBuilder::new(reqwest::Client::new());
153
154    // Add retry middleware if configured
155    if let Some(max_retries) = config.max_fetch_retries {
156        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries as u32);
157        let retry_middleware = RetryTransientMiddleware::new_with_policy(retry_policy);
158        builder = builder.with(retry_middleware);
159    }
160
161    let client = builder.build();
162    let response =
163        client.get(url.clone()).send().await.map_err(|e| StorageError::Http(e.to_string()))?;
164    let response = response.error_for_status().map_err(|e| StorageError::Http(e.to_string()))?;
165
166    let max_size = config.max_file_size;
167
168    // Check content-length header first
169    let capacity = response.content_length().unwrap_or_default() as usize;
170    if capacity > max_size {
171        return Err(StorageError::SizeLimitExceeded(capacity));
172    }
173
174    // Stream the response and check size incrementally
175    let mut buffer = Vec::with_capacity(capacity);
176    let mut stream = response.bytes_stream();
177
178    while let Some(chunk) = stream.next().await {
179        let chunk = chunk.map_err(|e| StorageError::Http(e.to_string()))?;
180        buffer.extend_from_slice(chunk.chunk());
181        if buffer.len() > max_size {
182            return Err(StorageError::SizeLimitExceeded(buffer.len()));
183        }
184    }
185
186    Ok(buffer)
187}
188
/// Fetch data from an S3 URL.
///
/// Authenticates using the default AWS credential chain (environment variables,
/// `~/.aws/credentials`, `~/.aws/config`, etc.).
///
/// If the `AWS_ROLE_ARN` environment variable is set, it will attempt to assume that
/// IAM role before accessing S3.
#[cfg(feature = "s3")]
async fn fetch_s3(url: url::Url, config: &MarketConfig) -> Result<Vec<u8>, StorageError> {
    // Mirror the HTTP retry setting: N retries => N + 1 total attempts.
    let retry_config = if let Some(max_retries) = config.max_fetch_retries {
        RetryConfig::standard().with_max_attempts(max_retries as u32 + 1)
    } else {
        RetryConfig::disabled()
    };

    let mut aws_config = aws_config::from_env().retry_config(retry_config).load().await;

    // Probe credentials up front so a broken/missing credential chain surfaces
    // as a clear error rather than an opaque GetObject failure later.
    if let Some(provider) = aws_config.credentials_provider() {
        if let Err(e) = provider.provide_credentials().await {
            tracing::debug!(error=%e, "Could not load AWS credentials for S3");
            return Err(StorageError::UnsupportedScheme(format!(
                "s3 (no credentials available: {})",
                e
            )));
        }
    } else {
        return Err(StorageError::UnsupportedScheme("s3 (no credentials provider)".to_string()));
    }

    // Handle role assumption if AWS_ROLE_ARN is set: wrap the base credentials
    // in an STS AssumeRole provider and rebuild the config around it.
    if let Ok(role_arn) = env::var(ENV_VAR_ROLE_ARN) {
        let role_provider = aws_config::sts::AssumeRoleProvider::builder(role_arn)
            .configure(&aws_config)
            .build()
            .await;
        aws_config = aws_config
            .into_builder()
            .credentials_provider(SharedCredentialsProvider::new(role_provider))
            .build();
    }

    // s3://bucket/key — the URL host is the bucket, the path (minus the
    // leading '/') is the object key.
    let bucket = url.host_str().ok_or(StorageError::InvalidUrl("missing bucket"))?;
    let key = url.path().trim_start_matches('/');
    if key.is_empty() {
        return Err(StorageError::InvalidUrl("empty key"));
    }

    let client = S3Client::new(&aws_config);
    let resp = client.get_object().bucket(bucket).key(key).send().await.map_err(|e| {
        let code = e.code().unwrap_or("unknown");
        tracing::debug!(error = %e, code = ?code, "S3 GetObject failed");
        StorageError::S3(format!("{}: {}", code, e))
    })?;

    let max_size = config.max_file_size;

    // Fail fast on the advertised content length before streaming.
    // NOTE(review): `content_length` is an `Option<i64>`; a negative value
    // would wrap to a huge usize here — presumably the SDK never reports one,
    // but worth confirming.
    let capacity = resp.content_length.unwrap_or_default() as usize;
    if capacity > max_size {
        return Err(StorageError::SizeLimitExceeded(capacity));
    }

    // Stream the body and enforce the limit incrementally, since the advertised
    // length cannot be fully trusted.
    let mut buffer = Vec::with_capacity(capacity);
    let mut stream = resp.body;

    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|e| StorageError::S3(e.to_string()))?;
        buffer.extend_from_slice(chunk.chunk());
        if buffer.len() > max_size {
            return Err(StorageError::SizeLimitExceeded(buffer.len()));
        }
    }

    Ok(buffer)
}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_unsupported_scheme() {
        // Schemes other than http(s)/s3/file must be rejected outright.
        assert!(matches!(
            fetch_uri("ftp://example.com/file").await,
            Err(StorageError::UnsupportedScheme(_))
        ));
    }

    #[tokio::test]
    async fn test_invalid_url() {
        // Input that does not parse as a URL surfaces the parse error.
        assert!(matches!(fetch_uri("not a url").await, Err(StorageError::UriParse(_))));
    }

    #[tokio::test]
    async fn test_file_url_disabled_by_default() {
        // Skip when the environment already opts in to file:// access.
        if allow_file_urls() {
            return;
        }
        assert!(matches!(
            fetch_uri("file:///tmp/test").await,
            Err(StorageError::UnsupportedScheme(_))
        ));
    }
}