1use reqwest::redirect::Policy;
21use reqwest::{Client, ClientBuilder, header};
22use tracing::debug;
23use url::Url;
24
25use crate::{Discovery, Error, NodeInfo, Version};
26
/// Default upper bound, in bytes, on a response body read by this module (64 KiB).
///
/// [`fetch_discovery`] and [`fetch`] pass this value as the body limit; use the
/// `*_with_limit` variants to supply a different one.
pub const DEFAULT_MAX_BODY_BYTES: u64 = 1 << 16;
30
31pub fn recommended_client() -> Result<Client, Error> {
44 Ok(ClientBuilder::new()
45 .redirect(Policy::custom(|attempt| {
46 const MAX_REDIRECTS: usize = 2;
47 if attempt.previous().len() >= MAX_REDIRECTS {
48 return attempt.error("too many redirects");
49 }
50 let origin = attempt.previous().first().unwrap_or_else(|| attempt.url());
51 if origin.host_str() == attempt.url().host_str()
52 && origin.scheme() == attempt.url().scheme()
53 {
54 attempt.follow()
55 } else {
56 attempt.error("cross-origin redirect forbidden for NodeInfo")
57 }
58 }))
59 .build()?)
60}
61
62pub async fn fetch_discovery(host: &Url, client: &Client) -> Result<Discovery, Error> {
78 fetch_discovery_with_limit(host, client, DEFAULT_MAX_BODY_BYTES).await
79}
80
81pub async fn fetch_discovery_with_limit(
87 host: &Url,
88 client: &Client,
89 max_body_bytes: u64,
90) -> Result<Discovery, Error> {
91 let url = host.join(crate::WELL_KNOWN_PATH)?;
92 debug!(%url, max_body_bytes, "fetching NodeInfo discovery");
93
94 let body = request_capped(client, url, max_body_bytes).await?;
95 Ok(serde_json::from_slice(&body)?)
96}
97
98pub async fn fetch(host: &Url, version: Version, client: &Client) -> Result<NodeInfo, Error> {
111 fetch_with_limit(host, version, client, DEFAULT_MAX_BODY_BYTES).await
112}
113
114pub async fn fetch_with_limit(
120 host: &Url,
121 version: Version,
122 client: &Client,
123 max_body_bytes: u64,
124) -> Result<NodeInfo, Error> {
125 let discovery = fetch_discovery_with_limit(host, client, max_body_bytes).await?;
126 let link = discovery
127 .find_link(version)
128 .ok_or(Error::VersionNotAdvertised {
129 requested: version.as_str(),
130 })?;
131
132 if !same_origin(host, &link.href) {
143 return Err(Error::CrossOriginHref {
144 discovery: host.clone(),
145 href: link.href.clone(),
146 });
147 }
148
149 debug!(url = %link.href, max_body_bytes, "fetching NodeInfo document");
150
151 let body = request_capped(client, link.href.clone(), max_body_bytes).await?;
152 Ok(serde_json::from_slice(&body)?)
153}
154
155fn same_origin(a: &Url, b: &Url) -> bool {
161 a.scheme().eq_ignore_ascii_case(b.scheme())
162 && match (a.host_str(), b.host_str()) {
163 (Some(ha), Some(hb)) => ha.eq_ignore_ascii_case(hb),
164 _ => false,
165 }
166 && a.port_or_known_default() == b.port_or_known_default()
167}
168
169async fn request_capped(client: &Client, url: Url, max_body_bytes: u64) -> Result<Vec<u8>, Error> {
170 let response = client
171 .get(url)
172 .header(header::ACCEPT, "application/json")
173 .send()
174 .await?;
175
176 let status = response.status();
177 if !status.is_success() {
178 return Err(Error::BadStatus(status.as_u16()));
179 }
180
181 read_capped(response, max_body_bytes).await
182}
183
184async fn read_capped(
185 mut response: reqwest::Response,
186 max_body_bytes: u64,
187) -> Result<Vec<u8>, Error> {
188 let mut acc: Vec<u8> = Vec::new();
189 while let Some(chunk) = response.chunk().await? {
190 if max_body_bytes > 0 && (acc.len() as u64 + chunk.len() as u64) > max_body_bytes {
191 return Err(Error::ResponseTooLarge(max_body_bytes));
192 }
193 acc.extend_from_slice(&chunk);
194 }
195 Ok(acc)
196}
197
#[cfg(test)]
mod tests {
    //! Integration-style tests that run the fetch functions against local
    //! `wiremock` servers.

    use pretty_assertions::assert_eq;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;

    /// Happy path: discovery advertises a same-origin 2.1 document and the
    /// document is fetched and parsed.
    #[tokio::test]
    async fn end_to_end_fetch_via_mock_server() {
        let server = MockServer::start().await;
        let base: Url = server.uri().parse().unwrap();

        let nodeinfo_url = format!("{base}nodeinfo/2.1");

        Mock::given(method("GET"))
            .and(path("/.well-known/nodeinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "links": [{
                    "rel": "http://nodeinfo.diaspora.software/ns/schema/2.1",
                    "href": nodeinfo_url
                }]
            })))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/nodeinfo/2.1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "version": "2.1",
                "software": { "name": "mock-server", "version": "0.1.0" },
                "protocols": ["activitypub"],
                "openRegistrations": false,
                "usage": {}
            })))
            .mount(&server)
            .await;

        let client = Client::new();
        let info = fetch(&base, Version::V2_1, &client).await.unwrap();

        assert_eq!(info.version, Version::V2_1);
        assert_eq!(info.software.name, "mock-server");
    }

    /// A discovery document pointing at another server must be refused before
    /// any request is sent to that server.
    #[tokio::test]
    async fn fetch_refuses_cross_origin_href_advertised_by_discovery() {
        let primary = MockServer::start().await;
        let attacker = MockServer::start().await;
        let attacker_nodeinfo = format!("{}/nodeinfo/2.1", attacker.uri());
        Mock::given(method("GET"))
            .and(path("/.well-known/nodeinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "links": [{
                    "rel": "http://nodeinfo.diaspora.software/ns/schema/2.1",
                    "href": attacker_nodeinfo
                }]
            })))
            .mount(&primary)
            .await;

        let client = Client::new();
        let base: Url = primary.uri().parse().unwrap();
        let err = fetch(&base, Version::V2_1, &client)
            .await
            .expect_err("cross-origin href must be refused");
        assert!(
            matches!(err, Error::CrossOriginHref { .. }),
            "expected CrossOriginHref, got {err:?}",
        );
    }

    /// A body larger than the default cap is rejected with `ResponseTooLarge`.
    #[tokio::test]
    async fn oversized_discovery_body_is_rejected() {
        let server = MockServer::start().await;
        let base: Url = server.uri().parse().unwrap();

        let padding = "x".repeat(128 * 1024);
        Mock::given(method("GET"))
            .and(path("/.well-known/nodeinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(
                format!(r#"{{"links":[],"padding":"{padding}"}}"#).into_bytes(),
                "application/json",
            ))
            .mount(&server)
            .await;

        let client = Client::new();
        let err = fetch_discovery(&base, &client)
            .await
            .expect_err("oversized body must be rejected");
        assert!(matches!(
            err,
            Error::ResponseTooLarge(DEFAULT_MAX_BODY_BYTES)
        ));
    }

    /// The recommended client must not follow a redirect to another origin.
    ///
    /// The redirect target serves a perfectly valid discovery document, so the
    /// fetch can only fail if the redirect itself is refused; additionally the
    /// target must not have seen any request at all.
    #[tokio::test]
    async fn recommended_client_rejects_cross_origin_redirect() {
        let primary = MockServer::start().await;
        let attacker = MockServer::start().await;

        let base: Url = primary.uri().parse().unwrap();

        // Address the attacker through a host name that differs from the
        // primary's, so the redirect is cross-origin by host and not merely by
        // port.
        let attacker_host = if base.host_str() == Some("localhost") {
            "127.0.0.1"
        } else {
            "localhost"
        };
        let attacker_base = format!("http://{attacker_host}:{}", attacker.address().port());

        Mock::given(method("GET"))
            .and(path("/.well-known/nodeinfo"))
            .respond_with(ResponseTemplate::new(302).insert_header(
                "Location",
                format!("{attacker_base}/.well-known/nodeinfo"),
            ))
            .mount(&primary)
            .await;

        // Without this mock a followed redirect would end in a 404 and the
        // test would pass even with a broken redirect policy.
        Mock::given(method("GET"))
            .and(path("/.well-known/nodeinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "links": [{
                    "rel": "http://nodeinfo.diaspora.software/ns/schema/2.1",
                    "href": format!("{attacker_base}/nodeinfo/2.1")
                }]
            })))
            .mount(&attacker)
            .await;

        let client = recommended_client().expect("client builds");
        fetch_discovery(&base, &client)
            .await
            .expect_err("cross-origin redirect must fail");

        let hits = attacker.received_requests().await.unwrap_or_default();
        assert!(
            hits.is_empty(),
            "redirect target must never be contacted, got {} request(s)",
            hits.len(),
        );
    }

    /// Requesting a version that discovery does not advertise yields
    /// `VersionNotAdvertised` naming the requested version.
    #[tokio::test]
    async fn missing_version_returns_specific_error() {
        let server = MockServer::start().await;
        let base: Url = server.uri().parse().unwrap();

        Mock::given(method("GET"))
            .and(path("/.well-known/nodeinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "links": [{
                    "rel": "http://nodeinfo.diaspora.software/ns/schema/2.0",
                    "href": format!("{base}nodeinfo/2.0")
                }]
            })))
            .mount(&server)
            .await;

        let client = Client::new();
        let err = fetch(&base, Version::V2_1, &client).await.unwrap_err();
        assert!(matches!(
            err,
            Error::VersionNotAdvertised { requested: "2.1" }
        ));
    }
}