actpub_nodeinfo/
client.rs1use reqwest::redirect::Policy;
21use reqwest::{Client, ClientBuilder, header};
22use tracing::debug;
23use url::Url;
24
25use crate::{Discovery, Error, NodeInfo, Version};
26
/// Default upper bound, in bytes, on a single HTTP response body read by this module
/// (64 KiB).
///
/// [`fetch_discovery`] and [`fetch`] use this value; call the `*_with_limit` variants
/// to choose a different bound.
pub const DEFAULT_MAX_BODY_BYTES: u64 = 1 << 16;
30
31pub fn recommended_client() -> Result<Client, Error> {
44 Ok(ClientBuilder::new()
45 .redirect(Policy::custom(|attempt| {
46 const MAX_REDIRECTS: usize = 2;
47 if attempt.previous().len() >= MAX_REDIRECTS {
48 return attempt.error("too many redirects");
49 }
50 let origin = attempt.previous().first().unwrap_or_else(|| attempt.url());
51 if origin.host_str() == attempt.url().host_str()
52 && origin.scheme() == attempt.url().scheme()
53 {
54 attempt.follow()
55 } else {
56 attempt.error("cross-origin redirect forbidden for NodeInfo")
57 }
58 }))
59 .build()?)
60}
61
62pub async fn fetch_discovery(host: &Url, client: &Client) -> Result<Discovery, Error> {
78 fetch_discovery_with_limit(host, client, DEFAULT_MAX_BODY_BYTES).await
79}
80
81pub async fn fetch_discovery_with_limit(
87 host: &Url,
88 client: &Client,
89 max_body_bytes: u64,
90) -> Result<Discovery, Error> {
91 let url = host.join(crate::WELL_KNOWN_PATH)?;
92 debug!(%url, max_body_bytes, "fetching NodeInfo discovery");
93
94 let body = request_capped(client, url, max_body_bytes).await?;
95 Ok(serde_json::from_slice(&body)?)
96}
97
98pub async fn fetch(host: &Url, version: Version, client: &Client) -> Result<NodeInfo, Error> {
111 fetch_with_limit(host, version, client, DEFAULT_MAX_BODY_BYTES).await
112}
113
114pub async fn fetch_with_limit(
120 host: &Url,
121 version: Version,
122 client: &Client,
123 max_body_bytes: u64,
124) -> Result<NodeInfo, Error> {
125 let discovery = fetch_discovery_with_limit(host, client, max_body_bytes).await?;
126 let link = discovery
127 .find_link(version)
128 .ok_or(Error::VersionNotAdvertised {
129 requested: version.as_str(),
130 })?;
131
132 debug!(url = %link.href, max_body_bytes, "fetching NodeInfo document");
133
134 let body = request_capped(client, link.href.clone(), max_body_bytes).await?;
135 Ok(serde_json::from_slice(&body)?)
136}
137
138async fn request_capped(client: &Client, url: Url, max_body_bytes: u64) -> Result<Vec<u8>, Error> {
139 let response = client
140 .get(url)
141 .header(header::ACCEPT, "application/json")
142 .send()
143 .await?;
144
145 let status = response.status();
146 if !status.is_success() {
147 return Err(Error::BadStatus(status.as_u16()));
148 }
149
150 read_capped(response, max_body_bytes).await
151}
152
153async fn read_capped(
154 mut response: reqwest::Response,
155 max_body_bytes: u64,
156) -> Result<Vec<u8>, Error> {
157 let mut acc: Vec<u8> = Vec::new();
158 while let Some(chunk) = response.chunk().await? {
159 if max_body_bytes > 0 && (acc.len() as u64 + chunk.len() as u64) > max_body_bytes {
160 return Err(Error::ResponseTooLarge(max_body_bytes));
161 }
162 acc.extend_from_slice(&chunk);
163 }
164 Ok(acc)
165}
166
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;

    /// Happy path: discovery advertises a 2.1 link on the same server, and the linked
    /// document is fetched and parsed.
    #[tokio::test]
    async fn end_to_end_fetch_via_mock_server() {
        let server = MockServer::start().await;
        let base: Url = server.uri().parse().unwrap();

        let nodeinfo_url = format!("{base}nodeinfo/2.1");

        Mock::given(method("GET"))
            .and(path("/.well-known/nodeinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "links": [{
                    "rel": "http://nodeinfo.diaspora.software/ns/schema/2.1",
                    "href": nodeinfo_url
                }]
            })))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/nodeinfo/2.1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "version": "2.1",
                "software": { "name": "mock-server", "version": "0.1.0" },
                "protocols": ["activitypub"],
                "openRegistrations": false,
                "usage": {}
            })))
            .mount(&server)
            .await;

        let client = Client::new();
        let info = fetch(&base, Version::V2_1, &client).await.unwrap();

        assert_eq!(info.version, Version::V2_1);
        assert_eq!(info.software.name, "mock-server");
    }

    /// A discovery body above `DEFAULT_MAX_BODY_BYTES` must be refused with
    /// `ResponseTooLarge` carrying the cap.
    #[tokio::test]
    async fn oversized_discovery_body_is_rejected() {
        let server = MockServer::start().await;
        let base: Url = server.uri().parse().unwrap();

        // 128 KiB of padding is comfortably above the 64 KiB default cap.
        let padding = "x".repeat(128 * 1024);
        Mock::given(method("GET"))
            .and(path("/.well-known/nodeinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(
                format!(r#"{{"links":[],"padding":"{padding}"}}"#).into_bytes(),
                "application/json",
            ))
            .mount(&server)
            .await;

        let client = Client::new();
        let err = fetch_discovery(&base, &client)
            .await
            .expect_err("oversized body must be rejected");
        assert!(matches!(
            err,
            Error::ResponseTooLarge(DEFAULT_MAX_BODY_BYTES)
        ));
    }

    /// The recommended client must refuse to follow a redirect to another origin.
    #[tokio::test]
    async fn recommended_client_rejects_cross_origin_redirect() {
        let primary = MockServer::start().await;
        let attacker = MockServer::start().await;

        // The attacker serves a perfectly valid discovery document, so following the
        // redirect would make the fetch *succeed*. The test therefore only passes if the
        // redirect policy itself refuses the hop, not merely because the redirect target
        // happens to answer 404.
        Mock::given(method("GET"))
            .and(path("/.well-known/nodeinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "links": [] })))
            .mount(&attacker)
            .await;

        // Both mock servers listen on 127.0.0.1; address the attacker as `localhost` so
        // the redirect target is on a different host, not just a different port.
        let attacker_url = format!(
            "http://localhost:{}/.well-known/nodeinfo",
            attacker.address().port()
        );
        Mock::given(method("GET"))
            .and(path("/.well-known/nodeinfo"))
            .respond_with(ResponseTemplate::new(302).insert_header("Location", attacker_url))
            .mount(&primary)
            .await;

        let client = recommended_client().expect("client builds");
        let base: Url = primary.uri().parse().unwrap();
        let err = fetch_discovery(&base, &client)
            .await
            .expect_err("cross-origin redirect must fail");
        assert!(
            !matches!(err, Error::BadStatus(_)),
            "redirect should be refused by the client policy, not end in an HTTP error status"
        );
    }

    /// Requesting a version the server does not advertise yields the dedicated error,
    /// naming the requested version.
    #[tokio::test]
    async fn missing_version_returns_specific_error() {
        let server = MockServer::start().await;
        let base: Url = server.uri().parse().unwrap();

        Mock::given(method("GET"))
            .and(path("/.well-known/nodeinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "links": [{
                    "rel": "http://nodeinfo.diaspora.software/ns/schema/2.0",
                    "href": format!("{base}nodeinfo/2.0")
                }]
            })))
            .mount(&server)
            .await;

        let client = Client::new();
        let err = fetch(&base, Version::V2_1, &client).await.unwrap_err();
        assert!(matches!(
            err,
            Error::VersionNotAdvertised { requested: "2.1" }
        ));
    }
}