Skip to main content

rs_guard/
http.rs

1//! Shared HTTP utilities for GitHub API communication.
2//!
3//! Provides a single [`github_headers`] builder used by both diff fetching
4//! and review submission, along with [`validate_github_base_url`] for
5//! strict allowlisting of GitHub API endpoints.
6
7use crate::error::RsGuardError;
8use crate::llm::providers;
9use reqwest::header::{self, HeaderMap, HeaderValue};
10use url::Url;
11
/// `User-Agent` value sent with GitHub API requests, formatted as `<name>/<version>`.
///
/// Assembled at compile time from the `CARGO_PKG_NAME` and `CARGO_PKG_VERSION`
/// environment variables, so it always matches the package metadata of the build.
const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
14
15/// Builds a [`reqwest::Client`] with a standard timeout.
16///
17/// Shared helper to avoid duplicating client construction across modules
18/// that communicate with the GitHub API.
19///
20/// Note: User-Agent is added per-request via [`github_headers()`].
21///
22/// # Errors
23///
24/// Returns [`RsGuardError::Config`] if the TLS backend fails to initialise.
25pub fn build_github_http_client(
26    timeout: std::time::Duration,
27) -> Result<reqwest::Client, RsGuardError> {
28    reqwest::Client::builder()
29        .timeout(timeout)
30        .build()
31        .map_err(|e| RsGuardError::Config(format!("Failed to build HTTP client: {}", e)))
32}
33
34/// Allowed GitHub API base URLs.
35///
36/// Only HTTPS URLs matching these patterns are permitted. This prevents
37/// accidentally sending `Authorization` headers to arbitrary hosts.
38const ALLOWED_BASE_URLS: &[&str] = &["https://api.github.com"];
39
40/// Validates that a GitHub API base URL is on the allowlist.
41///
42/// Accepts:
43/// - Exact match against `ALLOWED_BASE_URLS` (e.g. `https://api.github.com`)
44/// - GitHub Enterprise pattern: `https://{host}/api/v3` where `{host}` is
45///   any valid hostname
46/// - Loopback addresses (`http://127.0.0.1`, `http://localhost`) for testing
47///
48/// All non-loopback URLs must use HTTPS. HTTP URLs to external hosts are rejected.
49///
50/// # Errors
51///
52/// Returns [`RsGuardError::Config`] if the URL is not allowed.
53pub fn validate_github_base_url(base_url: &str) -> Result<(), RsGuardError> {
54    let trimmed = base_url.trim_end_matches('/');
55
56    if trimmed.starts_with("http://127.0.0.1") || trimmed.starts_with("http://localhost") {
57        return Ok(());
58    }
59
60    if !trimmed.starts_with("https://") {
61        return Err(RsGuardError::Config(format!(
62            "GitHub base URL must use HTTPS: '{}'. HTTP is not allowed.",
63            base_url
64        )));
65    }
66
67    if ALLOWED_BASE_URLS.contains(&trimmed) {
68        return Ok(());
69    }
70
71    if trimmed.ends_with("/api/v3") {
72        return Ok(());
73    }
74
75    Err(RsGuardError::Config(format!(
76        "GitHub base URL '{}' is not in the allowlist. \
77         Allowed: {} or https://<enterprise-host>/api/v3",
78        base_url,
79        ALLOWED_BASE_URLS.join(", ")
80    )))
81}
82
83/// Validates that a provider API base URL is safe for use in CI mode.
84///
85/// In CI mode, TOML `base_url` overrides are restricted to an exact allowlist
86/// of known LLM provider (scheme, host) pairs to prevent SSRF attacks where a
87/// malicious `.reviewer.toml` could redirect API calls (and auth headers) to
88/// an attacker-controlled server.
89///
90/// Loopback addresses (`127.0.0.1`, `localhost`) are **rejected** in CI mode
91/// to prevent token exfiltration to attacker-controlled local servers.
92///
93/// All URLs must use HTTPS.
94///
95/// # Errors
96///
97/// Returns [`RsGuardError::Config`] if the URL is not allowed.
98pub fn validate_provider_base_url(base_url: &str) -> Result<(), RsGuardError> {
99    let parsed = Url::parse(base_url).map_err(|_| {
100        RsGuardError::Config(format!(
101            "Provider base URL is malformed: '{}'. Expected format: https://host/path",
102            base_url
103        ))
104    })?;
105
106    if parsed.scheme() != "https" {
107        return Err(RsGuardError::Config(format!(
108            "Provider base URL must use HTTPS in CI mode: '{}'. HTTP is not allowed.",
109            base_url
110        )));
111    }
112
113    let host = parsed.host_str().ok_or_else(|| {
114        RsGuardError::Config(format!(
115            "Provider base URL is malformed: '{}'. No host found.",
116            base_url
117        ))
118    })?;
119
120    if host == "127.0.0.1"
121        || host == "localhost"
122        || host == "[::1]"
123        || host == "0.0.0.0"
124        || host == "[::]"
125    {
126        return Err(RsGuardError::Config(format!(
127            "Provider base URL '{}' uses loopback address, which is not allowed in CI mode \
128             to prevent token exfiltration. Use a known provider endpoint or run in local mode.",
129            base_url
130        )));
131    }
132
133    let ci_hosts = providers::all_ci_allowed_hosts();
134    for &(allowed_scheme, allowed_host) in &ci_hosts {
135        if parsed.scheme() == allowed_scheme && host == allowed_host {
136            return Ok(());
137        }
138    }
139
140    let allowed_display: Vec<String> = ci_hosts
141        .iter()
142        .map(|(s, h)| format!("{}://{}", s, h))
143        .collect();
144
145    Err(RsGuardError::Config(format!(
146        "Provider base URL '{}' (host: {}) is not in the CI allowlist. \
147         Allowed hosts: {}. \
148         To use a custom endpoint, run in local mode (unset GITHUB_ACTIONS).",
149        base_url,
150        host,
151        allowed_display.join(", ")
152    )))
153}
154
155/// Builds default headers for GitHub API requests.
156///
157/// Includes `Authorization`, `Accept`, `X-GitHub-Api-Version`, and
158/// `User-Agent` headers. The `User-Agent` is derived from
159/// `CARGO_PKG_NAME`/`CARGO_PKG_VERSION` at compile time.
160///
161/// # Errors
162///
163/// Returns [`RsGuardError::Config`] if the token contains invalid
164/// header characters.
165pub fn github_headers(token: &str) -> Result<HeaderMap, RsGuardError> {
166    let mut headers = HeaderMap::new();
167    headers.insert(
168        header::ACCEPT,
169        HeaderValue::from_static("application/vnd.github+json"),
170    );
171    headers.insert(
172        header::AUTHORIZATION,
173        HeaderValue::from_str(&format!("Bearer {}", token))
174            .map_err(|e| RsGuardError::Config(format!("Invalid GitHub token format: {}", e)))?,
175    );
176    headers.insert(
177        "X-GitHub-Api-Version",
178        HeaderValue::from_static("2022-11-28"),
179    );
180    headers.insert(header::USER_AGENT, HeaderValue::from_static(USER_AGENT));
181    Ok(headers)
182}
183
184/// Builds headers specifically for fetching PR diffs.
185///
186/// Same as [`github_headers`] but uses the `application/vnd.github.v3.diff`
187/// accept header instead of `application/vnd.github+json`.
188///
189/// # Errors
190///
191/// Returns [`RsGuardError::Config`] if the token contains invalid
192/// header characters.
193pub fn github_diff_headers(token: &str) -> Result<HeaderMap, RsGuardError> {
194    let mut headers = github_headers(token)?;
195    headers.insert(
196        header::ACCEPT,
197        HeaderValue::from_static("application/vnd.github.v3.diff"),
198    );
199    Ok(headers)
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a validation result is an error and returns its message.
    fn rejection_message(outcome: Result<(), RsGuardError>) -> String {
        outcome
            .expect_err("expected the URL to be rejected")
            .to_string()
    }

    // --- validate_github_base_url -------------------------------------------

    #[test]
    fn test_validate_allowed_url() {
        let outcome = validate_github_base_url("https://api.github.com");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_allowed_url_trailing_slash() {
        let outcome = validate_github_base_url("https://api.github.com/");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_enterprise_url() {
        let outcome = validate_github_base_url("https://github.mycompany.com/api/v3");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_reject_http() {
        let message = rejection_message(validate_github_base_url("http://api.github.com"));
        assert!(message.contains("HTTPS"));
    }

    #[test]
    fn test_allow_loopback_http() {
        let loopback_urls = ["http://127.0.0.1:8080", "http://localhost:3000"];
        for url in loopback_urls.iter() {
            assert!(validate_github_base_url(url).is_ok());
        }
    }

    #[test]
    fn test_reject_unknown_host() {
        let message = rejection_message(validate_github_base_url("https://evil.example.com"));
        assert!(message.contains("allowlist"));
    }

    #[test]
    fn test_reject_partial_match() {
        assert!(validate_github_base_url("https://not-api.github.com").is_err());
    }

    // --- header builders ----------------------------------------------------

    #[test]
    fn test_github_headers_valid_token() {
        let headers = github_headers("valid-token-123").unwrap();

        let authorization = headers.get(header::AUTHORIZATION).unwrap();
        assert_eq!(authorization, "Bearer valid-token-123");

        let user_agent = headers.get(header::USER_AGENT).unwrap();
        assert_eq!(user_agent, USER_AGENT);
    }

    #[test]
    fn test_github_headers_invalid_token() {
        // Control characters are not valid in an HTTP header value.
        assert!(github_headers("token\x00with\x01control").is_err());
    }

    #[test]
    fn test_github_diff_headers_accept() {
        let headers = github_diff_headers("tok").unwrap();
        let accept = headers.get(header::ACCEPT).unwrap();
        assert_eq!(accept, "application/vnd.github.v3.diff");
    }

    // --- validate_provider_base_url -----------------------------------------

    #[test]
    fn test_provider_base_url_allows_known_hosts() {
        let known_endpoints = [
            "https://api.deepseek.com",
            "https://api.deepseek.com/v1",
            "https://api.moonshot.ai/v1",
            "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
            "https://openrouter.ai/api/v1",
            "https://api.openai.com/v1",
        ];
        for url in known_endpoints.iter() {
            assert!(validate_provider_base_url(url).is_ok());
        }
    }

    #[test]
    fn test_provider_base_url_rejects_loopback() {
        // Plain HTTP may trip either the scheme check or the loopback check.
        let over_http = rejection_message(validate_provider_base_url("http://127.0.0.1:11434/v1"));
        assert!(over_http.contains("loopback") || over_http.contains("HTTPS"));

        let over_https = rejection_message(validate_provider_base_url("https://localhost:8080"));
        assert!(over_https.contains("loopback"));
    }

    #[test]
    fn test_provider_base_url_rejects_subdomain_spoof() {
        let message =
            rejection_message(validate_provider_base_url("https://api.deepseek.com.evil.com/v1"));
        assert!(message.contains("not in the CI allowlist"));
    }

    #[test]
    fn test_provider_base_url_rejects_unknown_host() {
        let message = rejection_message(validate_provider_base_url("https://evil.example.com/v1"));
        assert!(message.contains("not in the CI allowlist"));
    }

    #[test]
    fn test_provider_base_url_rejects_http() {
        let message = rejection_message(validate_provider_base_url("http://api.deepseek.com"));
        assert!(message.contains("HTTPS"));
    }

    #[test]
    fn test_provider_base_url_rejects_malformed() {
        let message = rejection_message(validate_provider_base_url("not-a-url"));
        assert!(message.contains("malformed"));
    }

    #[test]
    fn test_provider_base_url_rejects_ipv6_loopback() {
        let message = rejection_message(validate_provider_base_url("https://[::1]:11434/v1"));
        assert!(message.contains("loopback"));
    }

    #[test]
    fn test_provider_base_url_rejects_bind_all() {
        let bind_all_urls = ["https://0.0.0.0:8080/v1", "https://[::]:8080/v1"];
        for url in bind_all_urls.iter() {
            let message = rejection_message(validate_provider_base_url(url));
            assert!(message.contains("loopback"));
        }
    }
}