1use crate::error::RsGuardError;
8use crate::llm::providers;
9use reqwest::header::{self, HeaderMap, HeaderValue};
10use url::Url;
11
12const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
14
15pub fn build_github_http_client(
26 timeout: std::time::Duration,
27) -> Result<reqwest::Client, RsGuardError> {
28 reqwest::Client::builder()
29 .timeout(timeout)
30 .build()
31 .map_err(|e| RsGuardError::Config(format!("Failed to build HTTP client: {}", e)))
32}
33
34const ALLOWED_BASE_URLS: &[&str] = &["https://api.github.com"];
39
40pub fn validate_github_base_url(base_url: &str) -> Result<(), RsGuardError> {
54 let trimmed = base_url.trim_end_matches('/');
55
56 if trimmed.starts_with("http://127.0.0.1") || trimmed.starts_with("http://localhost") {
57 return Ok(());
58 }
59
60 if !trimmed.starts_with("https://") {
61 return Err(RsGuardError::Config(format!(
62 "GitHub base URL must use HTTPS: '{}'. HTTP is not allowed.",
63 base_url
64 )));
65 }
66
67 if ALLOWED_BASE_URLS.contains(&trimmed) {
68 return Ok(());
69 }
70
71 if trimmed.ends_with("/api/v3") {
72 return Ok(());
73 }
74
75 Err(RsGuardError::Config(format!(
76 "GitHub base URL '{}' is not in the allowlist. \
77 Allowed: {} or https://<enterprise-host>/api/v3",
78 base_url,
79 ALLOWED_BASE_URLS.join(", ")
80 )))
81}
82
83pub fn validate_provider_base_url(base_url: &str) -> Result<(), RsGuardError> {
99 let parsed = Url::parse(base_url).map_err(|_| {
100 RsGuardError::Config(format!(
101 "Provider base URL is malformed: '{}'. Expected format: https://host/path",
102 base_url
103 ))
104 })?;
105
106 if parsed.scheme() != "https" {
107 return Err(RsGuardError::Config(format!(
108 "Provider base URL must use HTTPS in CI mode: '{}'. HTTP is not allowed.",
109 base_url
110 )));
111 }
112
113 let host = parsed.host_str().ok_or_else(|| {
114 RsGuardError::Config(format!(
115 "Provider base URL is malformed: '{}'. No host found.",
116 base_url
117 ))
118 })?;
119
120 if host == "127.0.0.1"
121 || host == "localhost"
122 || host == "[::1]"
123 || host == "0.0.0.0"
124 || host == "[::]"
125 {
126 return Err(RsGuardError::Config(format!(
127 "Provider base URL '{}' uses loopback address, which is not allowed in CI mode \
128 to prevent token exfiltration. Use a known provider endpoint or run in local mode.",
129 base_url
130 )));
131 }
132
133 let ci_hosts = providers::all_ci_allowed_hosts();
134 for &(allowed_scheme, allowed_host) in &ci_hosts {
135 if parsed.scheme() == allowed_scheme && host == allowed_host {
136 return Ok(());
137 }
138 }
139
140 let allowed_display: Vec<String> = ci_hosts
141 .iter()
142 .map(|(s, h)| format!("{}://{}", s, h))
143 .collect();
144
145 Err(RsGuardError::Config(format!(
146 "Provider base URL '{}' (host: {}) is not in the CI allowlist. \
147 Allowed hosts: {}. \
148 To use a custom endpoint, run in local mode (unset GITHUB_ACTIONS).",
149 base_url,
150 host,
151 allowed_display.join(", ")
152 )))
153}
154
155pub fn github_headers(token: &str) -> Result<HeaderMap, RsGuardError> {
166 let mut headers = HeaderMap::new();
167 headers.insert(
168 header::ACCEPT,
169 HeaderValue::from_static("application/vnd.github+json"),
170 );
171 headers.insert(
172 header::AUTHORIZATION,
173 HeaderValue::from_str(&format!("Bearer {}", token))
174 .map_err(|e| RsGuardError::Config(format!("Invalid GitHub token format: {}", e)))?,
175 );
176 headers.insert(
177 "X-GitHub-Api-Version",
178 HeaderValue::from_static("2022-11-28"),
179 );
180 headers.insert(header::USER_AGENT, HeaderValue::from_static(USER_AGENT));
181 Ok(headers)
182}
183
184pub fn github_diff_headers(token: &str) -> Result<HeaderMap, RsGuardError> {
194 let mut headers = github_headers(token)?;
195 headers.insert(
196 header::ACCEPT,
197 HeaderValue::from_static("application/vnd.github.v3.diff"),
198 );
199 Ok(headers)
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that validation failed and returns the rendered error message.
    fn error_message(result: Result<(), RsGuardError>) -> String {
        result
            .expect_err("validation unexpectedly succeeded")
            .to_string()
    }

    // ---- validate_github_base_url ----

    #[test]
    fn test_validate_allowed_url() {
        let result = validate_github_base_url("https://api.github.com");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_allowed_url_trailing_slash() {
        let result = validate_github_base_url("https://api.github.com/");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_enterprise_url() {
        let result = validate_github_base_url("https://github.mycompany.com/api/v3");
        assert!(result.is_ok());
    }

    #[test]
    fn test_reject_http() {
        let message = error_message(validate_github_base_url("http://api.github.com"));
        assert!(message.contains("HTTPS"));
    }

    #[test]
    fn test_allow_loopback_http() {
        for url in ["http://127.0.0.1:8080", "http://localhost:3000"] {
            assert!(
                validate_github_base_url(url).is_ok(),
                "expected {} to be accepted",
                url
            );
        }
    }

    #[test]
    fn test_reject_unknown_host() {
        let message = error_message(validate_github_base_url("https://evil.example.com"));
        assert!(message.contains("allowlist"));
    }

    #[test]
    fn test_reject_partial_match() {
        assert!(validate_github_base_url("https://not-api.github.com").is_err());
    }

    // ---- github_headers / github_diff_headers ----

    #[test]
    fn test_github_headers_valid_token() {
        let headers = github_headers("valid-token-123").unwrap();
        let authorization = headers.get(header::AUTHORIZATION).unwrap();
        let user_agent = headers.get(header::USER_AGENT).unwrap();
        assert_eq!(authorization, "Bearer valid-token-123");
        assert_eq!(user_agent, USER_AGENT);
    }

    #[test]
    fn test_github_headers_invalid_token() {
        assert!(github_headers("token\x00with\x01control").is_err());
    }

    #[test]
    fn test_github_diff_headers_accept() {
        let headers = github_diff_headers("tok").unwrap();
        let accept = headers.get(header::ACCEPT).unwrap();
        assert_eq!(accept, "application/vnd.github.v3.diff");
    }

    // ---- validate_provider_base_url ----

    #[test]
    fn test_provider_base_url_allows_known_hosts() {
        let known = [
            "https://api.deepseek.com",
            "https://api.deepseek.com/v1",
            "https://api.moonshot.ai/v1",
            "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
            "https://openrouter.ai/api/v1",
            "https://api.openai.com/v1",
        ];
        for url in known {
            assert!(
                validate_provider_base_url(url).is_ok(),
                "expected {} to be accepted",
                url
            );
        }
    }

    #[test]
    fn test_provider_base_url_rejects_loopback() {
        // Plain HTTP trips the scheme check before the loopback check can run.
        let message = error_message(validate_provider_base_url("http://127.0.0.1:11434/v1"));
        assert!(message.contains("loopback") || message.contains("HTTPS"));

        let message = error_message(validate_provider_base_url("https://localhost:8080"));
        assert!(message.contains("loopback"));
    }

    #[test]
    fn test_provider_base_url_rejects_subdomain_spoof() {
        let message =
            error_message(validate_provider_base_url("https://api.deepseek.com.evil.com/v1"));
        assert!(message.contains("not in the CI allowlist"));
    }

    #[test]
    fn test_provider_base_url_rejects_unknown_host() {
        let message = error_message(validate_provider_base_url("https://evil.example.com/v1"));
        assert!(message.contains("not in the CI allowlist"));
    }

    #[test]
    fn test_provider_base_url_rejects_http() {
        let message = error_message(validate_provider_base_url("http://api.deepseek.com"));
        assert!(message.contains("HTTPS"));
    }

    #[test]
    fn test_provider_base_url_rejects_malformed() {
        let message = error_message(validate_provider_base_url("not-a-url"));
        assert!(message.contains("malformed"));
    }

    #[test]
    fn test_provider_base_url_rejects_ipv6_loopback() {
        let message = error_message(validate_provider_base_url("https://[::1]:11434/v1"));
        assert!(message.contains("loopback"));
    }

    #[test]
    fn test_provider_base_url_rejects_bind_all() {
        for url in ["https://0.0.0.0:8080/v1", "https://[::]:8080/v1"] {
            let message = error_message(validate_provider_base_url(url));
            assert!(message.contains("loopback"), "unexpected message for {}", url);
        }
    }
}