1use std::net::IpAddr;
4use tokio::net::lookup_host;
5use tracing::debug;
6use url::Url;
7
8use super::{
9 FetcherError,
10 canonical::extract_canonical_url,
11 charset::{Detected, decode_to_utf8},
12 dns::SSRF_LEVEL,
13 ssrf::{self, SsrfLevel},
14};
15
16#[derive(Debug, Clone)]
18pub struct FetchedPage {
19 pub final_url: Url,
21
22 pub canonical_url: Url,
24
25 pub status: u16,
27
28 pub content_type: Option<String>,
30
31 pub body: String,
33
34 pub charset: Detected,
36
37 pub link_header: Option<String>,
39
40 pub etag: Option<String>,
42
43 pub last_modified: Option<String>,
45
46 pub cache_control: Option<String>,
48
49 pub expires: Option<String>,
51
52 pub retry_after: Option<String>,
55}
56
/// Validators for an HTTP conditional GET.
///
/// The `Default` value has both fields `None`, so no conditional headers are sent. It is what
/// [`fetch_url`] passes to [`fetch_url_conditional`].
///
/// `PartialEq`/`Eq`/`Hash` are derived so callers can compare or deduplicate stored validators.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct ConditionalGet {
    /// When `Some`, sent verbatim as the `If-None-Match` request header
    /// (e.g. an `ETag` stored from a previous fetch).
    pub if_none_match: Option<String>,
    /// When `Some`, sent verbatim as the `If-Modified-Since` request header
    /// (e.g. a `Last-Modified` value stored from a previous fetch).
    pub if_modified_since: Option<String>,
}
67
68pub async fn fetch_url(
70 client: &reqwest::Client,
71 url: &Url,
72 level: SsrfLevel,
73 project_root: Option<&std::path::Path>,
74 har_recorder: Option<&std::sync::Arc<super::har::HarRecorder>>,
75) -> Result<FetchedPage, FetcherError> {
76 fetch_url_conditional(
77 client,
78 url,
79 level,
80 project_root,
81 har_recorder,
82 &ConditionalGet::default(),
83 )
84 .await
85}
86
87pub async fn fetch_url_conditional(
98 client: &reqwest::Client,
99 url: &Url,
100 level: SsrfLevel,
101 project_root: Option<&std::path::Path>,
102 har_recorder: Option<&std::sync::Arc<super::har::HarRecorder>>,
103 cond: &ConditionalGet,
104) -> Result<FetchedPage, FetcherError> {
105 let start = std::time::Instant::now();
106 ssrf::validate_url_with_project_root(url, level, project_root)?;
107 let host = url
108 .host_str()
109 .ok_or(FetcherError::Ssrf(ssrf::SsrfError::NoHost))?;
110 let port = url.port_or_known_default().unwrap_or(0);
111
112 let addrs = resolve_host(host, port).await?;
118 ssrf::validate_addresses(&addrs, level)?;
119
120 let mut req = client.get(url.clone());
121 let mut request_headers_pairs: Vec<(String, String)> = Vec::new();
125 if let Some(etag) = &cond.if_none_match {
126 req = req.header(reqwest::header::IF_NONE_MATCH, etag);
127 request_headers_pairs.push(("if-none-match".into(), etag.clone()));
128 }
129 if let Some(lm) = &cond.if_modified_since {
130 req = req.header(reqwest::header::IF_MODIFIED_SINCE, lm);
131 request_headers_pairs.push(("if-modified-since".into(), lm.clone()));
132 }
133 let response = SSRF_LEVEL.scope(level, req.send()).await?;
139 let status = response.status().as_u16();
140 let final_url = Url::parse(response.url().as_str())?;
141
142 let response_headers_pairs: Vec<(String, String)> = response
144 .headers()
145 .iter()
146 .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
147 .collect();
148
149 let content_type = response
150 .headers()
151 .get(reqwest::header::CONTENT_TYPE)
152 .and_then(|v| v.to_str().ok())
153 .map(str::to_string);
154 let link_header = response
155 .headers()
156 .get(reqwest::header::LINK)
157 .and_then(|v| v.to_str().ok())
158 .map(str::to_string);
159 let etag = response
160 .headers()
161 .get(reqwest::header::ETAG)
162 .and_then(|v| v.to_str().ok())
163 .map(str::to_string);
164 let last_modified = response
165 .headers()
166 .get(reqwest::header::LAST_MODIFIED)
167 .and_then(|v| v.to_str().ok())
168 .map(str::to_string);
169 let cache_control = response
170 .headers()
171 .get(reqwest::header::CACHE_CONTROL)
172 .and_then(|v| v.to_str().ok())
173 .map(str::to_string);
174 let expires = response
175 .headers()
176 .get(reqwest::header::EXPIRES)
177 .and_then(|v| v.to_str().ok())
178 .map(str::to_string);
179 let retry_after = response
180 .headers()
181 .get(reqwest::header::RETRY_AFTER)
182 .and_then(|v| v.to_str().ok())
183 .map(str::to_string);
184
185 let challenge = super::challenge::detect(status, &response_headers_pairs);
192
193 let bytes = response.bytes().await?;
194
195 if let Some(recorder) = har_recorder {
196 let ex = super::har::RecordedExchange {
197 url: final_url.to_string(),
198 method: "GET".to_string(),
199 request_headers: request_headers_pairs,
200 response_status: status,
201 response_headers: response_headers_pairs,
202 response_body: bytes.to_vec(),
203 duration: start.elapsed(),
204 };
205 if let Err(e) = recorder.record(ex).await {
206 tracing::warn!(target: "rover::fetcher", error = ?e, "failed to record har entry");
207 }
208 }
209
210 if let Some(kind) = challenge {
211 return Err(FetcherError::BotChallenge {
212 url: final_url.to_string(),
213 provider: kind.provider().to_string(),
214 });
215 }
216
217 let (body, charset) = decode_to_utf8(content_type.as_deref(), &bytes);
218
219 if let Some(ref ct) = content_type
220 && ct.to_ascii_lowercase().contains("charset=")
221 {
222 debug!(
223 target: "rover::fetcher::charset",
224 http_charset = ct.as_str(),
225 detected = %charset.encoding.name(),
226 "charset detection complete"
227 );
228 }
229
230 let canonical_url = extract_canonical_url(&body, &final_url, link_header.as_deref());
231
232 Ok(FetchedPage {
233 final_url,
234 canonical_url,
235 status,
236 content_type,
237 body,
238 charset,
239 link_header,
240 etag,
241 last_modified,
242 cache_control,
243 expires,
244 retry_after,
245 })
246}
247
248async fn resolve_host(host: &str, port: u16) -> Result<Vec<IpAddr>, FetcherError> {
249 let target = format!("{host}:{port}");
250 let iter = lookup_host(target.as_str())
251 .await
252 .map_err(|e| FetcherError::Dns {
253 host: host.to_string(),
254 source: e,
255 })?;
256 Ok(iter.map(|sa| sa.ip()).collect())
257}