use crate::error::RsGuardError;
use crate::llm::providers;
use reqwest::header::{self, HeaderMap, HeaderValue};
use url::Url;
/// `User-Agent` header value sent to GitHub, of the form `<crate-name>/<crate-version>`.
///
/// Assembled at compile time from Cargo's `CARGO_PKG_NAME` and `CARGO_PKG_VERSION`.
const USER_AGENT: &str = concat!(
    env!("CARGO_PKG_NAME"),
    "/",
    env!("CARGO_PKG_VERSION")
);
/// Builds the `reqwest::Client` used for GitHub API calls, configured with `timeout`.
///
/// # Errors
///
/// Returns [`RsGuardError::Config`] if the underlying client builder fails.
pub fn build_github_http_client(
    timeout: std::time::Duration,
) -> Result<reqwest::Client, RsGuardError> {
    let builder = reqwest::Client::builder().timeout(timeout);
    match builder.build() {
        Ok(client) => Ok(client),
        Err(build_err) => Err(RsGuardError::Config(format!(
            "Failed to build HTTP client: {}",
            build_err
        ))),
    }
}
/// Exact (trailing-slash-stripped) GitHub API base URLs accepted without further checks.
const ALLOWED_BASE_URLS: &[&str] = &[
    "https://api.github.com",
];
pub fn validate_github_base_url(base_url: &str) -> Result<(), RsGuardError> {
let trimmed = base_url.trim_end_matches('/');
if trimmed.starts_with("http://127.0.0.1") || trimmed.starts_with("http://localhost") {
return Ok(());
}
if !trimmed.starts_with("https://") {
return Err(RsGuardError::Config(format!(
"GitHub base URL must use HTTPS: '{}'. HTTP is not allowed.",
base_url
)));
}
if ALLOWED_BASE_URLS.contains(&trimmed) {
return Ok(());
}
if trimmed.ends_with("/api/v3") {
return Ok(());
}
Err(RsGuardError::Config(format!(
"GitHub base URL '{}' is not in the allowlist. \
Allowed: {} or https://<enterprise-host>/api/v3",
base_url,
ALLOWED_BASE_URLS.join(", ")
)))
}
/// Validates an LLM provider base URL for CI mode.
///
/// The URL must parse and use the `https` scheme. It must have a host that is not one of a
/// fixed set of loopback / unspecified addresses. Its `(scheme, host)` pair must exactly
/// match an entry returned by `providers::all_ci_allowed_hosts()`. Path, port and query are
/// not inspected.
///
/// # Errors
///
/// Returns [`RsGuardError::Config`] when the URL is malformed or has no host, is not HTTPS,
/// targets a loopback / unspecified address, or its host is not in the CI allowlist.
pub fn validate_provider_base_url(base_url: &str) -> Result<(), RsGuardError> {
    // Hosts rejected up front so the error names the loopback problem explicitly.
    const LOCAL_HOSTS: [&str; 5] = ["127.0.0.1", "localhost", "[::1]", "0.0.0.0", "[::]"];

    let parsed = match Url::parse(base_url) {
        Ok(url) => url,
        Err(_) => {
            return Err(RsGuardError::Config(format!(
                "Provider base URL is malformed: '{}'. Expected format: https://host/path",
                base_url
            )))
        }
    };

    let scheme = parsed.scheme();
    if scheme != "https" {
        return Err(RsGuardError::Config(format!(
            "Provider base URL must use HTTPS in CI mode: '{}'. HTTP is not allowed.",
            base_url
        )));
    }

    let host = match parsed.host_str() {
        Some(found) => found,
        None => {
            return Err(RsGuardError::Config(format!(
                "Provider base URL is malformed: '{}'. No host found.",
                base_url
            )))
        }
    };

    if LOCAL_HOSTS.contains(&host) {
        return Err(RsGuardError::Config(format!(
            "Provider base URL '{}' uses loopback address, which is not allowed in CI mode \
             to prevent token exfiltration. Use a known provider endpoint or run in local mode.",
            base_url
        )));
    }

    let ci_hosts = providers::all_ci_allowed_hosts();
    let is_allowlisted = ci_hosts
        .iter()
        .any(|&(allowed_scheme, allowed_host)| scheme == allowed_scheme && host == allowed_host);
    if is_allowlisted {
        return Ok(());
    }

    let allowed_listing = ci_hosts
        .iter()
        .map(|entry| format!("{}://{}", entry.0, entry.1))
        .collect::<Vec<String>>()
        .join(", ");
    Err(RsGuardError::Config(format!(
        "Provider base URL '{}' (host: {}) is not in the CI allowlist. \
         Allowed hosts: {}. \
         To use a custom endpoint, run in local mode (unset GITHUB_ACTIONS).",
        base_url, host, allowed_listing
    )))
}
/// Builds the default header set for GitHub REST API requests authenticated with `token`.
///
/// The map contains, in insertion order:
/// - `Accept: application/vnd.github+json`;
/// - `Authorization: Bearer <token>`, flagged sensitive;
/// - `X-GitHub-Api-Version: 2022-11-28`;
/// - `User-Agent` set to [`USER_AGENT`].
///
/// # Errors
///
/// Returns [`RsGuardError::Config`] if `Bearer <token>` is not a valid header value (for
/// example, when `token` contains control characters).
pub fn github_headers(token: &str) -> Result<HeaderMap, RsGuardError> {
    let mut authorization = HeaderValue::from_str(&format!("Bearer {}", token))
        .map_err(|e| RsGuardError::Config(format!("Invalid GitHub token format: {}", e)))?;
    // Flag the credential as sensitive. `HeaderValue`'s `Debug` impl then redacts it, so
    // logging this map (or a request built from it) does not leak the token.
    authorization.set_sensitive(true);

    let mut headers = HeaderMap::new();
    headers.insert(
        header::ACCEPT,
        HeaderValue::from_static("application/vnd.github+json"),
    );
    headers.insert(header::AUTHORIZATION, authorization);
    headers.insert(
        "X-GitHub-Api-Version",
        HeaderValue::from_static("2022-11-28"),
    );
    headers.insert(header::USER_AGENT, HeaderValue::from_static(USER_AGENT));
    Ok(headers)
}
/// Same headers as [`github_headers`], but with `Accept` overridden to request a diff
/// (`application/vnd.github.v3.diff`).
///
/// # Errors
///
/// Propagates any error from [`github_headers`].
pub fn github_diff_headers(token: &str) -> Result<HeaderMap, RsGuardError> {
    github_headers(token).map(|mut diff_headers| {
        // `insert` replaces the JSON `Accept` value set by `github_headers`.
        diff_headers.insert(
            header::ACCEPT,
            HeaderValue::from_static("application/vnd.github.v3.diff"),
        );
        diff_headers
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a validator rejected its input and returns the rendered error message.
    fn expect_rejection(outcome: Result<(), RsGuardError>) -> String {
        match outcome {
            Ok(()) => panic!("expected the URL to be rejected, but it was accepted"),
            Err(err) => err.to_string(),
        }
    }

    #[test]
    fn test_validate_allowed_url() {
        assert!(validate_github_base_url("https://api.github.com").is_ok());
    }

    #[test]
    fn test_validate_allowed_url_trailing_slash() {
        assert!(validate_github_base_url("https://api.github.com/").is_ok());
    }

    #[test]
    fn test_validate_enterprise_url() {
        assert!(validate_github_base_url("https://github.mycompany.com/api/v3").is_ok());
    }

    #[test]
    fn test_reject_http() {
        let message = expect_rejection(validate_github_base_url("http://api.github.com"));
        assert!(message.contains("HTTPS"));
    }

    #[test]
    fn test_allow_loopback_http() {
        for url in &["http://127.0.0.1:8080", "http://localhost:3000"] {
            assert!(validate_github_base_url(url).is_ok(), "{} should be accepted", url);
        }
    }

    #[test]
    fn test_reject_unknown_host() {
        let message = expect_rejection(validate_github_base_url("https://evil.example.com"));
        assert!(message.contains("allowlist"));
    }

    #[test]
    fn test_reject_partial_match() {
        assert!(validate_github_base_url("https://not-api.github.com").is_err());
    }

    #[test]
    fn test_github_headers_valid_token() {
        let headers = github_headers("valid-token-123").unwrap();
        let auth = headers.get(header::AUTHORIZATION).unwrap();
        assert_eq!(auth, "Bearer valid-token-123");
        let agent = headers.get(header::USER_AGENT).unwrap();
        assert_eq!(agent, USER_AGENT);
    }

    #[test]
    fn test_github_headers_invalid_token() {
        assert!(github_headers("token\x00with\x01control").is_err());
    }

    #[test]
    fn test_github_diff_headers_accept() {
        let headers = github_diff_headers("tok").unwrap();
        let accept = headers.get(header::ACCEPT).unwrap();
        assert_eq!(accept, "application/vnd.github.v3.diff");
    }

    #[test]
    fn test_provider_base_url_allows_known_hosts() {
        let known = [
            "https://api.deepseek.com",
            "https://api.deepseek.com/v1",
            "https://api.moonshot.ai/v1",
            "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
            "https://openrouter.ai/api/v1",
            "https://api.openai.com/v1",
        ];
        for url in &known {
            assert!(validate_provider_base_url(url).is_ok(), "{} should be accepted", url);
        }
    }

    #[test]
    fn test_provider_base_url_rejects_loopback() {
        let message = expect_rejection(validate_provider_base_url("http://127.0.0.1:11434/v1"));
        assert!(message.contains("loopback") || message.contains("HTTPS"));

        let message = expect_rejection(validate_provider_base_url("https://localhost:8080"));
        assert!(message.contains("loopback"));
    }

    #[test]
    fn test_provider_base_url_rejects_subdomain_spoof() {
        let message =
            expect_rejection(validate_provider_base_url("https://api.deepseek.com.evil.com/v1"));
        assert!(message.contains("not in the CI allowlist"));
    }

    #[test]
    fn test_provider_base_url_rejects_unknown_host() {
        let message = expect_rejection(validate_provider_base_url("https://evil.example.com/v1"));
        assert!(message.contains("not in the CI allowlist"));
    }

    #[test]
    fn test_provider_base_url_rejects_http() {
        let message = expect_rejection(validate_provider_base_url("http://api.deepseek.com"));
        assert!(message.contains("HTTPS"));
    }

    #[test]
    fn test_provider_base_url_rejects_malformed() {
        let message = expect_rejection(validate_provider_base_url("not-a-url"));
        assert!(message.contains("malformed"));
    }

    #[test]
    fn test_provider_base_url_rejects_ipv6_loopback() {
        let message = expect_rejection(validate_provider_base_url("https://[::1]:11434/v1"));
        assert!(message.contains("loopback"));
    }

    #[test]
    fn test_provider_base_url_rejects_bind_all() {
        for url in &["https://0.0.0.0:8080/v1", "https://[::]:8080/v1"] {
            let message = expect_rejection(validate_provider_base_url(url));
            assert!(message.contains("loopback"), "{} should be rejected as loopback", url);
        }
    }
}