// boundless_market/prover_utils/storage.rs
use alloy::primitives::bytes::Buf;
#[cfg(feature = "s3")]
use aws_config::retry::RetryConfig;
#[cfg(feature = "s3")]
use aws_sdk_s3::{
    config::{ProvideCredentials, SharedCredentialsProvider},
    error::ProvideErrorMetadata,
    Client as S3Client,
};
use futures_util::StreamExt;
use std::env;
use thiserror::Error;

use super::config::MarketConfig;

/// Environment variable read by the S3 fetcher: when set, its value is an
/// IAM role ARN that is assumed (via STS) before issuing S3 requests.
#[cfg(feature = "s3")]
const ENV_VAR_ROLE_ARN: &str = "AWS_ROLE_ARN";
34
/// True when `RISC0_DEV_MODE` is set to a truthy value
/// (`1`, `true`, or `yes`, case-insensitive).
fn is_dev_mode() -> bool {
    matches!(
        env::var("RISC0_DEV_MODE").map(|v| v.to_lowercase()).as_deref(),
        Ok("1" | "true" | "yes")
    )
}
43
/// True when `ALLOW_LOCAL_FILE_STORAGE` is set to a truthy value
/// (`1`, `true`, or `yes`, case-insensitive).
fn allow_local_file_storage() -> bool {
    matches!(
        env::var("ALLOW_LOCAL_FILE_STORAGE").map(|v| v.to_lowercase()).as_deref(),
        Ok("1" | "true" | "yes")
    )
}
52
53fn allow_file_urls() -> bool {
55 is_dev_mode() || allow_local_file_storage()
56}
57
58#[derive(Error, Debug)]
60#[non_exhaustive]
61pub enum StorageError {
62 #[error("unsupported URI scheme: {0}")]
64 UnsupportedScheme(String),
65
66 #[error("failed to parse URL: {0}")]
68 UriParse(#[from] url::ParseError),
69
70 #[error("invalid URL: {0}")]
72 InvalidUrl(&'static str),
73
74 #[error("resource size exceeds maximum allowed size ({0} bytes)")]
76 SizeLimitExceeded(usize),
77
78 #[error("file error: {0}")]
80 File(#[from] std::io::Error),
81
82 #[error("HTTP error: {0}")]
84 Http(String),
85
86 #[cfg(feature = "s3")]
88 #[error("AWS S3 error: {0}")]
89 S3(String),
90}
91
/// Fetch the contents of `uri` using a default [`MarketConfig`].
///
/// Convenience wrapper over [`fetch_uri_with_config`]; see that function for
/// the supported schemes and size-limit behavior.
#[allow(unused)]
pub async fn fetch_uri(uri: &str) -> Result<Vec<u8>, StorageError> {
    fetch_uri_with_config(uri, &MarketConfig::default()).await
}
100
/// Fetch the contents of `uri`, dispatching on the URI scheme.
///
/// Supported schemes:
/// - `file` — only when dev mode (`RISC0_DEV_MODE`) or local file storage
///   (`ALLOW_LOCAL_FILE_STORAGE`) is enabled; reads are capped at
///   `config.max_file_size`.
/// - `http` / `https` — with optional retries and a streaming size limit.
/// - `s3` — only when built with the `s3` feature.
///
/// # Errors
/// [`StorageError::UriParse`] for unparsable input,
/// [`StorageError::UnsupportedScheme`] for disallowed schemes, plus the
/// scheme-specific errors of the underlying fetchers.
pub async fn fetch_uri_with_config(
    uri: &str,
    config: &MarketConfig,
) -> Result<Vec<u8>, StorageError> {
    let parsed = url::Url::parse(uri)?;

    match parsed.scheme() {
        "file" => {
            // file:// is a gated escape hatch to avoid reading arbitrary
            // local files in production contexts.
            if !allow_file_urls() {
                return Err(StorageError::UnsupportedScheme(
                    "file (not allowed in this context)".to_string(),
                ));
            }
            // NOTE(review): Url::path() is percent-encoded and ignores any
            // host component of the file URL — confirm callers only pass
            // plain local paths.
            fetch_file(parsed.path(), Some(config.max_file_size)).await
        }
        "http" | "https" => fetch_http(parsed, config).await,
        #[cfg(feature = "s3")]
        "s3" => fetch_s3(parsed, config).await,
        scheme => Err(StorageError::UnsupportedScheme(scheme.to_string())),
    }
}
128
129async fn fetch_file(path: &str, max_size: Option<usize>) -> Result<Vec<u8>, StorageError> {
131 let metadata = tokio::fs::metadata(path).await?;
132 let size = metadata.len() as usize;
133
134 if let Some(max) = max_size {
135 if size > max {
136 return Err(StorageError::SizeLimitExceeded(size));
137 }
138 }
139
140 Ok(tokio::fs::read(path).await?)
141}
142
143async fn fetch_http(url: url::Url, config: &MarketConfig) -> Result<Vec<u8>, StorageError> {
145 use reqwest_middleware::ClientBuilder;
146 use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
147
148 if !url.has_host() {
149 return Err(StorageError::InvalidUrl("missing host"));
150 }
151
152 let mut builder = ClientBuilder::new(reqwest::Client::new());
153
154 if let Some(max_retries) = config.max_fetch_retries {
156 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries as u32);
157 let retry_middleware = RetryTransientMiddleware::new_with_policy(retry_policy);
158 builder = builder.with(retry_middleware);
159 }
160
161 let client = builder.build();
162 let response =
163 client.get(url.clone()).send().await.map_err(|e| StorageError::Http(e.to_string()))?;
164 let response = response.error_for_status().map_err(|e| StorageError::Http(e.to_string()))?;
165
166 let max_size = config.max_file_size;
167
168 let capacity = response.content_length().unwrap_or_default() as usize;
170 if capacity > max_size {
171 return Err(StorageError::SizeLimitExceeded(capacity));
172 }
173
174 let mut buffer = Vec::with_capacity(capacity);
176 let mut stream = response.bytes_stream();
177
178 while let Some(chunk) = stream.next().await {
179 let chunk = chunk.map_err(|e| StorageError::Http(e.to_string()))?;
180 buffer.extend_from_slice(chunk.chunk());
181 if buffer.len() > max_size {
182 return Err(StorageError::SizeLimitExceeded(buffer.len()));
183 }
184 }
185
186 Ok(buffer)
187}
188
/// Fetch `s3://bucket/key` using credentials from the ambient AWS
/// environment, optionally assuming the role named in `AWS_ROLE_ARN`.
///
/// The object is streamed into memory; `config.max_file_size` is enforced
/// both on the advertised content length and on the bytes actually received.
///
/// # Errors
/// - [`StorageError::UnsupportedScheme`] when no usable credentials exist.
/// - [`StorageError::InvalidUrl`] for a missing bucket or empty key.
/// - [`StorageError::S3`] for SDK failures, [`StorageError::SizeLimitExceeded`]
///   for oversized objects.
#[cfg(feature = "s3")]
async fn fetch_s3(url: url::Url, config: &MarketConfig) -> Result<Vec<u8>, StorageError> {
    // The SDK counts attempts, not retries, hence max_retries + 1.
    let retry_config = if let Some(max_retries) = config.max_fetch_retries {
        RetryConfig::standard().with_max_attempts(max_retries as u32 + 1)
    } else {
        RetryConfig::disabled()
    };

    let mut aws_config = aws_config::from_env().retry_config(retry_config).load().await;

    // Fail fast (and with a clearer error than the SDK's) when no credentials
    // can actually be resolved from the environment.
    if let Some(provider) = aws_config.credentials_provider() {
        if let Err(e) = provider.provide_credentials().await {
            tracing::debug!(error=%e, "Could not load AWS credentials for S3");
            return Err(StorageError::UnsupportedScheme(format!(
                "s3 (no credentials available: {})",
                e
            )));
        }
    } else {
        return Err(StorageError::UnsupportedScheme("s3 (no credentials provider)".to_string()));
    }

    // Optionally swap in an assume-role provider; this must happen AFTER the
    // base credentials are validated, since STS uses them to assume the role.
    if let Ok(role_arn) = env::var(ENV_VAR_ROLE_ARN) {
        let role_provider = aws_config::sts::AssumeRoleProvider::builder(role_arn)
            .configure(&aws_config)
            .build()
            .await;
        aws_config = aws_config
            .into_builder()
            .credentials_provider(SharedCredentialsProvider::new(role_provider))
            .build();
    }

    // s3://bucket/key → host is the bucket, path (minus leading '/') the key.
    let bucket = url.host_str().ok_or(StorageError::InvalidUrl("missing bucket"))?;
    let key = url.path().trim_start_matches('/');
    if key.is_empty() {
        return Err(StorageError::InvalidUrl("empty key"));
    }

    let client = S3Client::new(&aws_config);
    let resp = client.get_object().bucket(bucket).key(key).send().await.map_err(|e| {
        let code = e.code().unwrap_or("unknown");
        tracing::debug!(error = %e, code = ?code, "S3 GetObject failed");
        StorageError::S3(format!("{}: {}", code, e))
    })?;

    let max_size = config.max_file_size;

    // Early rejection based on the advertised object size.
    // NOTE(review): content_length is a signed integer in the SDK; a negative
    // value would wrap to a huge usize here and be rejected — confirm intended.
    let capacity = resp.content_length.unwrap_or_default() as usize;
    if capacity > max_size {
        return Err(StorageError::SizeLimitExceeded(capacity));
    }

    let mut buffer = Vec::with_capacity(capacity);
    let mut stream = resp.body;

    // Stream the body, enforcing the limit on bytes actually received.
    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|e| StorageError::S3(e.to_string()))?;
        buffer.extend_from_slice(chunk.chunk());
        if buffer.len() > max_size {
            return Err(StorageError::SizeLimitExceeded(buffer.len()));
        }
    }

    Ok(buffer)
}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_unsupported_scheme() {
        // Schemes other than file/http(s)/s3 must be rejected outright.
        let err = fetch_uri("ftp://example.com/file").await.unwrap_err();
        assert!(matches!(err, StorageError::UnsupportedScheme(_)));
    }

    #[tokio::test]
    async fn test_invalid_url() {
        // Unparsable input surfaces as a UriParse error, not a panic.
        let err = fetch_uri("not a url").await.unwrap_err();
        assert!(matches!(err, StorageError::UriParse(_)));
    }

    #[tokio::test]
    async fn test_file_url_disabled_by_default() {
        // Skip when the environment has explicitly opted in to file:// URLs.
        if allow_file_urls() {
            return;
        }
        let err = fetch_uri("file:///tmp/test").await.unwrap_err();
        assert!(matches!(err, StorageError::UnsupportedScheme(_)));
    }
}