use crate::cli::Args;
use crate::error::RsGuardError;
use crate::http::{validate_github_base_url, validate_provider_base_url};
use crate::llm::providers::{self, find_provider};
use crate::llm::ProviderConfig;
use serde::Deserialize;
use std::collections::HashMap;
use std::path::Path;
/// Fallback `max_tokens` for provider requests, used by `Config::from_env` when no
/// max-token limit is configured via `RS_GUARD_MAX_TOKENS` or the TOML `max_tokens` key.
pub const DEFAULT_MAX_TOKENS: u32 = 4096;
/// Built-in system prompt for reviews; `Config::load_prompt_file` replaces it when a prompt
/// file is present.
///
/// NOTE(review): the trailing `[RS_GUARD_VERDICT_METADATA]` block is presumably parsed by the
/// verdict extractor elsewhere in the crate — keep its keys and layout in sync with that parser
/// when editing this text.
pub const DEFAULT_PROMPT: &str = r#"You are a Staff Engineer conducting a thorough code review. Your role is to evaluate
the proposed changes and provide actionable, categorized feedback across five dimensions.
## Approval Standard
Approve a change when it definitely improves overall code health, even if it is not perfect.
The goal is continuous improvement — do not block a change because it is not exactly how
you would have written it. If it improves the codebase and follows project conventions, approve it.
## Five Review Axes (evaluate every change across all five)
### 1. Correctness
- Does the code do what it claims to do? Does it match the spec or task requirements?
- Are edge cases handled (null, empty, boundary values, off-by-one)?
- Are error paths handled (not just the happy path)?
- Are there race conditions, state inconsistencies, or incorrect control flow?
### 2. Security
- Is user input validated and sanitized at system boundaries?
- Are secrets kept out of code, logs, and version control?
- Is authentication/authorization checked where needed?
- Are queries parameterized? Is output encoded to prevent injection?
- Are dependencies from trusted sources with no known vulnerabilities?
- Is data from external sources treated as untrusted?
### 3. Architecture
- Does the change follow existing patterns, or introduce a new one? If new, is it justified?
- Are module boundaries maintained? Any circular dependencies or unwanted coupling?
- Is there code duplication that should be shared?
- Is the abstraction level appropriate — not over-engineered, not too coupled?
### 4. Readability & Simplicity
- Can another engineer understand this code without the author explaining it?
- Are names descriptive and consistent with project conventions?
- Is the control flow straightforward (avoid deeply nested logic)?
- Is there dead code, no-op variables, or over-complicated logic that could be simplified?
- Are abstractions earning their complexity?
### 5. Performance
- Any N+1 query patterns or unbounded loops?
- Any synchronous operations that should be async?
- Any unconstrained data fetching or missing pagination?
- Any large objects created in hot paths?
## Severity Taxonomy
Label every finding with its severity:
- `[Critical]` — Must fix before merge: data loss risk, broken functionality, incorrect behavior in production
- `[Security]` — Must fix before merge: vulnerability, unauthorized access, injection risk, exposed secret
- `[Important]` — Should fix before merge: missing test, wrong abstraction, poor error handling, significant tech debt
- `[Suggestion]` — Optional improvement: naming, style, minor optimization (author may ignore)
## Output Format
### Critical Issues
List each `[Critical]` finding with file/location, description, and a concrete fix recommendation.
### Security Issues
List each `[Security]` finding with file/location, description, and a concrete fix recommendation.
### Important Issues
List each `[Important]` finding with file/location and description.
### Suggestions
List each `[Suggestion]` briefly.
### What's Done Well
Always include at least one specific positive observation. Specific praise motivates good practices.
## Verdict Guidelines
- **POSITIVE** if the diff improves overall code health and is ready to merge
- **NEGATIVE** if there are `[Critical]` or `[Security]` findings that must block merging
At the end of your response, include exactly this metadata block (do not modify the format):
[RS_GUARD_VERDICT_METADATA]
Verdict: POSITIVE or NEGATIVE
CriticalIssues: <count>
SecurityIssues: <count>
ImportantIssues: <count>
Suggestions: <count>
"#;
/// Per-provider overrides from a `[providers.<name>]` table of the TOML config.
#[derive(Debug, Deserialize, Default, Clone)]
pub struct ProviderTomlConfig {
/// Name of the environment variable holding the API key; when set it replaces the
/// provider's standard variable (see `resolve_api_key_env_var`).
pub api_key_env: Option<String>,
/// Custom API endpoint. Checked with `validate_provider_base_url` in CI and with
/// `validate_local_provider_base_url` otherwise, then copied into `ProviderConfig::base_url`.
pub base_url: Option<String>,
/// Copied into `ProviderConfig::http_referer` when this provider is selected.
pub http_referer: Option<String>,
}
/// The optional `[circuit_breaker]` table of the TOML config.
///
/// `Config::from_env` only builds a `crate::retry::CircuitBreaker` when `enabled` is true.
#[derive(Debug, Deserialize, Default, Clone)]
pub struct CircuitBreakerTomlConfig {
/// Not an `Option`, so serde requires this key whenever the table is present.
pub enabled: bool,
/// First argument to `CircuitBreaker::new`; `Config::from_env` uses 3 when omitted.
pub threshold: Option<u32>,
/// Second argument to `CircuitBreaker::new`; `Config::from_env` uses 60 when omitted.
pub cooldown_secs: Option<u64>,
}
/// One entry of the optional `[pricing.<name>]` tables of the TOML config.
///
/// NOTE(review): keys are presumably model names and the values presumably prices per
/// million input/output tokens; the unit and currency are not visible in this file —
/// confirm against the cost-reporting code before relying on them.
#[derive(Debug, Deserialize, Default, Clone)]
pub struct PricingTomlConfig {
/// Required key; presumably the price per million input tokens.
pub input_per_million: u64,
/// Required key; presumably the price per million output tokens.
pub output_per_million: u64,
}
/// Top-level shape of the TOML configuration file read by `load_toml_config`.
///
/// Every key is optional. In `Config::from_env` the environment variables
/// `RS_GUARD_PROVIDER`, `RS_GUARD_MODEL`, `RS_GUARD_TEMPERATURE` and `RS_GUARD_MAX_TOKENS`
/// take precedence over the matching keys here; built-in defaults apply when both are absent.
#[derive(Debug, Deserialize, Default, Clone)]
pub struct TomlConfig {
/// LLM provider name; `Config::from_env` defaults to `"deepseek"`.
pub provider: Option<String>,
/// Model name; defaults to the provider's registry default model.
pub model: Option<String>,
/// Sampling temperature; the resolved value must lie in `0.0..=2.0`; defaults to 0.1.
pub temperature: Option<f32>,
/// Value for `ProviderConfig::max_tokens`; defaults to `DEFAULT_MAX_TOKENS`.
pub max_tokens: Option<u32>,
/// Defaults to `crate::diff::DEFAULT_CHUNK_HEAD_LINES` (presumably lines kept from the start
/// of a truncated diff chunk — TODO confirm in `crate::diff`).
pub chunk_head_lines: Option<usize>,
/// Defaults to `crate::diff::DEFAULT_CHUNK_TAIL_LINES` (presumably lines kept from the end
/// of a truncated diff chunk — TODO confirm in `crate::diff`).
pub chunk_tail_lines: Option<usize>,
/// Per-provider overrides keyed by provider name.
pub providers: Option<HashMap<String, ProviderTomlConfig>>,
/// Passed through unchanged to `Config::cache_dir`.
pub cache_dir: Option<String>,
/// Circuit-breaker settings; ignored unless `enabled` is true.
pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
/// Passed through unchanged to `Config::pricing`.
pub pricing: Option<HashMap<String, PricingTomlConfig>>,
/// Defaults to `true` when omitted.
pub auto_gitignore: Option<bool>,
}
/// Loads the optional TOML configuration file at `path`.
///
/// A missing file is not an error and yields `Ok(None)`. The file is read directly instead of
/// being probed with `Path::exists` first: `exists()` also returns `false` when the path cannot
/// be inspected (e.g. permission denied), which used to make an unreadable config silently
/// disappear, and the separate probe raced with the read.
///
/// # Errors
/// Returns [`RsGuardError::Config`] if the file exists but cannot be read, or if its contents
/// are not valid TOML for [`TomlConfig`].
pub fn load_toml_config(path: &Path) -> Result<Option<TomlConfig>, RsGuardError> {
    let content = match std::fs::read_to_string(path) {
        Ok(content) => content,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
        Err(e) => {
            return Err(RsGuardError::Config(format!(
                "Failed to read config file '{}': {}",
                path.display(),
                e
            )))
        }
    };
    toml::from_str::<TomlConfig>(&content)
        .map(Some)
        .map_err(|e| {
            RsGuardError::Config(format!(
                "Failed to parse config file '{}': {}",
                path.display(),
                e
            ))
        })
}
fn standard_api_key_env_var(provider: &str) -> Result<&'static str, RsGuardError> {
find_provider(provider)
.map(|m| m.api_key_env)
.ok_or_else(|| {
let names: Vec<&str> = crate::llm::providers::known_provider_names();
RsGuardError::Config(format!(
"Unknown provider: '{}'. Supported: {}",
provider,
names.join(", ")
))
})
}
fn default_model(provider: &str) -> Result<&'static str, RsGuardError> {
find_provider(provider)
.map(|m| m.default_model)
.ok_or_else(|| {
let names: Vec<&str> = crate::llm::providers::known_provider_names();
RsGuardError::Config(format!(
"Unknown provider: '{}'. Supported: {}",
provider,
names.join(", ")
))
})
}
/// Checks a provider base URL configured for a local (non-CI) run.
///
/// Local runs are permissive. After rejecting URLs that do not parse or that use a scheme
/// other than `http`/`https`, this only logs warnings for setups that are risky but possibly
/// intentional: plaintext HTTP, loopback/unspecified addresses (e.g. Ollama, LM Studio), and
/// HTTPS hosts outside the known provider list.
///
/// # Errors
/// Returns [`RsGuardError::Config`] if `base_url` is malformed or its scheme is neither
/// `http` nor `https`.
fn validate_local_provider_base_url(base_url: &str) -> Result<(), RsGuardError> {
    let parsed = url::Url::parse(base_url).map_err(|_| {
        RsGuardError::Config(format!(
            "Provider base URL is malformed: '{}'. Expected format: https://host/path",
            base_url
        ))
    })?;
    let scheme = parsed.scheme();
    // Only HTTP(S) endpoints are usable as a provider API; reject anything else (e.g. `file:`,
    // `ftp:`) here with a clear message instead of a confusing client error later.
    if scheme != "http" && scheme != "https" {
        return Err(RsGuardError::Config(format!(
            "Provider base URL '{}' uses unsupported scheme '{}'. Expected http or https",
            base_url, scheme
        )));
    }
    let host = parsed.host_str().unwrap_or("");
    // Classify by parsed address value so every loopback form (127.0.0.2, [::1], ...) and
    // unspecified address (0.0.0.0, [::]) is recognised, not just a fixed set of spellings.
    let is_local_host = match parsed.host() {
        Some(url::Host::Domain(domain)) => domain.eq_ignore_ascii_case("localhost"),
        Some(url::Host::Ipv4(ip)) => ip.is_loopback() || ip.is_unspecified(),
        Some(url::Host::Ipv6(ip)) => ip.is_loopback() || ip.is_unspecified(),
        None => false,
    };
    if scheme != "https" {
        log::warn!(
            "Provider base URL '{}' uses {} (not HTTPS). API keys will be transmitted in plaintext. \
             This is risky if the traffic leaves your machine.",
            base_url,
            scheme
        );
    } else if is_local_host {
        log::warn!(
            "Provider base URL '{}' points to a loopback address. \
             Your API key will be sent to a local server. \
             Ensure this is intentional (e.g. Ollama, LM Studio).",
            base_url
        );
    } else if !providers::all_ci_allowed_hosts()
        .iter()
        .any(|&(s, h)| scheme == s && host == h)
    {
        log::warn!(
            "Provider base URL '{}' (host: {}) is not a recognized LLM provider endpoint. \
             Your API key will be sent to a third-party server. \
             Verify this is intentional.",
            base_url,
            host
        );
    }
    Ok(())
}
fn resolve_api_key_env_var(
provider: &str,
toml_providers: Option<&HashMap<String, ProviderTomlConfig>>,
) -> Result<String, RsGuardError> {
if let Some(providers) = toml_providers {
if let Some(toml_provider) = providers.get(provider) {
if let Some(ref env_var) = toml_provider.api_key_env {
return Ok(env_var.clone());
}
}
}
standard_api_key_env_var(provider).map(|s| s.to_string())
}
/// Settings needed to talk to the GitHub API for a pull request in CI mode.
///
/// Produced by `Config::validate_for_ci`. `Debug` is implemented by hand so the GitHub token
/// is never written to logs or panic messages.
#[derive(Clone)]
pub struct CiConfig {
    /// Secret token for the GitHub API; redacted in `Debug` output.
    pub github_token: String,
    pub pr_number: u64,
    pub repo_owner: String,
    pub repo_name: String,
    pub github_base_url: String,
}

impl std::fmt::Debug for CiConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CiConfig")
            .field("github_token", &"<redacted>")
            .field("pr_number", &self.pr_number)
            .field("repo_owner", &self.repo_owner)
            .field("repo_name", &self.repo_name)
            .field("github_base_url", &self.github_base_url)
            .finish()
    }
}
/// Fully resolved runtime configuration for a review run.
///
/// Built by `Config::from_env` (environment + TOML) and refined by `Config::apply_args`
/// (command line). `Debug` is implemented by hand so the API key and GitHub token never end up
/// in logs or panic messages.
#[derive(Clone)]
pub struct Config {
    /// Provider name; validated against the built-in registry.
    pub provider: String,
    /// Model requested from the provider.
    pub model: String,
    /// Sampling temperature.
    pub temperature: f32,
    /// Secret API key for `provider`; redacted in `Debug` output.
    pub api_key: String,
    /// Secret GitHub token from `GITHUB_TOKEN`; redacted in `Debug` output.
    pub github_token: Option<String>,
    /// Pull request number from `PR_NUMBER`.
    pub pr_number: Option<u64>,
    /// Owner half of `REPO_FULL_NAME`.
    pub repo_owner: Option<String>,
    /// Repository half of `REPO_FULL_NAME`.
    pub repo_name: Option<String>,
    /// Review prompt; `DEFAULT_PROMPT` unless replaced by `load_prompt_file`.
    pub prompt: String,
    /// True when `GITHUB_ACTIONS` is set.
    pub is_ci: bool,
    /// GitHub API root, from `GITHUB_API_URL` or `https://api.github.com`.
    pub github_base_url: String,
    /// Request settings handed to the LLM client.
    pub provider_config: ProviderConfig,
    /// TOML provider overrides, kept so `apply_args` can switch provider later.
    toml_providers: HashMap<String, ProviderTomlConfig>,
    /// Set once `--model` has been applied, so a later provider switch keeps that model.
    model_set_via_cli: bool,
    pub no_cache: bool,
    pub dry_run: bool,
    pub cache_dir: Option<String>,
    /// Present only when the TOML circuit breaker is enabled.
    pub circuit_breaker: Option<crate::retry::CircuitBreaker>,
    pub pricing: Option<HashMap<String, PricingTomlConfig>>,
    pub auto_gitignore: bool,
    pub chunk_head_lines: usize,
    pub chunk_tail_lines: usize,
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Config")
            .field("provider", &self.provider)
            .field("model", &self.model)
            .field("temperature", &self.temperature)
            .field("api_key", &"<redacted>")
            .field(
                "github_token",
                &self.github_token.as_ref().map(|_| "<redacted>"),
            )
            .field("pr_number", &self.pr_number)
            .field("repo_owner", &self.repo_owner)
            .field("repo_name", &self.repo_name)
            .field("prompt", &self.prompt)
            .field("is_ci", &self.is_ci)
            .field("github_base_url", &self.github_base_url)
            .field("provider_config", &self.provider_config)
            .field("toml_providers", &self.toml_providers)
            .field("model_set_via_cli", &self.model_set_via_cli)
            .field("no_cache", &self.no_cache)
            .field("dry_run", &self.dry_run)
            .field("cache_dir", &self.cache_dir)
            .field("circuit_breaker", &self.circuit_breaker)
            .field("pricing", &self.pricing)
            .field("auto_gitignore", &self.auto_gitignore)
            .field("chunk_head_lines", &self.chunk_head_lines)
            .field("chunk_tail_lines", &self.chunk_tail_lines)
            .finish()
    }
}
/// Sampling temperature used when neither the environment nor the TOML file provides one.
const DEFAULT_TEMPERATURE: f32 = 0.1;
/// `threshold` passed to `CircuitBreaker::new` when the TOML table omits it.
const DEFAULT_CIRCUIT_BREAKER_THRESHOLD: u32 = 3;
/// `cooldown_secs` passed to `CircuitBreaker::new` when the TOML table omits it.
const DEFAULT_CIRCUIT_BREAKER_COOLDOWN_SECS: u64 = 60;

/// Reads environment variable `name`, treating unset, non-UTF-8, or blank values as absent.
///
/// CI workflows commonly export optional inputs as empty strings. Treating those as "not set"
/// lets the TOML/default fallback apply instead of e.g. sending an empty model name or API key.
fn env_non_empty(name: &str) -> Option<String> {
    std::env::var(name)
        .ok()
        .filter(|value| !value.trim().is_empty())
}

/// Reads the API key for `provider` from the environment variable `env_var`.
fn read_api_key(env_var: &str, provider: &str) -> Result<String, RsGuardError> {
    env_non_empty(env_var).ok_or_else(|| {
        RsGuardError::Config(format!(
            "API key not found. Set {} for provider '{}'",
            env_var, provider
        ))
    })
}

/// Rejects temperatures outside `0.0..=2.0` (this also rejects NaN).
fn validate_temperature(temperature: f32) -> Result<(), RsGuardError> {
    if (0.0..=2.0).contains(&temperature) {
        Ok(())
    } else {
        Err(RsGuardError::Config(format!(
            "Temperature must be between 0.0 and 2.0, got: {}",
            temperature
        )))
    }
}

/// Validates an optional provider base URL with the CI or the local rules, depending on mode.
fn validate_base_url_for_mode(is_ci: bool, base_url: Option<&str>) -> Result<(), RsGuardError> {
    if let Some(url) = base_url {
        if is_ci {
            validate_provider_base_url(url)?;
        } else {
            validate_local_provider_base_url(url)?;
        }
    }
    Ok(())
}

/// Splits a `REPO_FULL_NAME` value of the form `owner/repo` into its two non-empty parts.
fn parse_repo_full_name(full: &str) -> Result<(String, String), RsGuardError> {
    let (owner, repo) = match full.split_once('/') {
        Some(parts) => parts,
        None => {
            return Err(RsGuardError::Config(format!(
                "REPO_FULL_NAME must be in 'owner/repo' format, got: '{}'",
                full
            )))
        }
    };
    if owner.is_empty() || repo.is_empty() {
        return Err(RsGuardError::Config(format!(
            "REPO_FULL_NAME owner and repo cannot be empty, got: '{}'",
            full
        )));
    }
    // `split_once` splits at the first '/', so only the repo part can contain another one.
    if repo.contains('/') {
        return Err(RsGuardError::Config(format!(
            "REPO_FULL_NAME must be in 'owner/repo' format (no additional slashes), got: '{}'",
            full
        )));
    }
    Ok((owner.to_string(), repo.to_string()))
}

impl Config {
    /// Returns a placeholder configuration with no credentials and default tuning values.
    ///
    /// Hidden from docs; presumably meant for tests and callers that populate fields by hand.
    #[doc(hidden)]
    pub fn empty() -> Self {
        Self {
            provider: String::new(),
            model: String::new(),
            temperature: DEFAULT_TEMPERATURE,
            api_key: String::new(),
            github_token: None,
            pr_number: None,
            repo_owner: None,
            repo_name: None,
            prompt: String::new(),
            is_ci: false,
            github_base_url: String::new(),
            provider_config: ProviderConfig::default(),
            toml_providers: HashMap::new(),
            model_set_via_cli: false,
            no_cache: false,
            dry_run: false,
            cache_dir: None,
            circuit_breaker: None,
            pricing: None,
            auto_gitignore: true,
            chunk_head_lines: crate::diff::DEFAULT_CHUNK_HEAD_LINES,
            chunk_tail_lines: crate::diff::DEFAULT_CHUNK_TAIL_LINES,
        }
    }

    /// Builds a configuration from environment variables layered over an optional TOML config.
    ///
    /// Each setting is resolved as: environment variable, then TOML key, then built-in default.
    /// Variables that are unset or contain only whitespace count as absent. CI mode is enabled
    /// when `GITHUB_ACTIONS` is set.
    ///
    /// # Errors
    /// Returns [`RsGuardError::Config`] if any of the following holds:
    /// - the provider is unknown;
    /// - its API-key variable is missing or blank;
    /// - `REPO_FULL_NAME` is not `owner/repo`;
    /// - `RS_GUARD_TEMPERATURE` or `RS_GUARD_MAX_TOKENS` cannot be parsed;
    /// - the temperature is outside `0.0..=2.0`;
    /// - the provider's TOML `base_url` fails validation.
    pub fn from_env(toml: Option<TomlConfig>) -> Result<Self, RsGuardError> {
        // The TOML value is owned, so move its fields out instead of cloning each one.
        let toml = toml.unwrap_or_default();
        let is_ci = std::env::var("GITHUB_ACTIONS").is_ok();
        let toml_providers = toml.providers.unwrap_or_default();

        let provider = env_non_empty("RS_GUARD_PROVIDER")
            .or(toml.provider)
            .unwrap_or_else(|| "deepseek".to_string());
        // One registry lookup both validates the provider and yields its fallback model.
        let provider_default_model = default_model(&provider)?;
        let api_key_env = resolve_api_key_env_var(&provider, Some(&toml_providers))?;
        let api_key = read_api_key(&api_key_env, &provider)?;

        let github_token = env_non_empty("GITHUB_TOKEN");
        let pr_number = env_non_empty("PR_NUMBER").and_then(|raw| match raw.parse::<u64>() {
            Ok(number) => Some(number),
            Err(_) => {
                // Still optional here (validate_for_ci enforces it), but say why it is ignored.
                log::warn!("Ignoring PR_NUMBER '{}': not a valid pull request number", raw);
                None
            }
        });
        let (repo_owner, repo_name) = match env_non_empty("REPO_FULL_NAME") {
            Some(full) => {
                let (owner, repo) = parse_repo_full_name(&full)?;
                (Some(owner), Some(repo))
            }
            None => (None, None),
        };
        let github_base_url = env_non_empty("GITHUB_API_URL")
            .unwrap_or_else(|| "https://api.github.com".to_string());

        let model = env_non_empty("RS_GUARD_MODEL")
            .or(toml.model)
            .unwrap_or_else(|| provider_default_model.to_string());
        let temperature = match env_non_empty("RS_GUARD_TEMPERATURE") {
            Some(val) => val.parse::<f32>().map_err(|_| {
                RsGuardError::Config(format!(
                    "Invalid RS_GUARD_TEMPERATURE '{}': must be a number between 0.0 and 2.0",
                    val
                ))
            })?,
            None => toml.temperature.unwrap_or(DEFAULT_TEMPERATURE),
        };
        validate_temperature(temperature)?;
        // Fail loudly on a malformed value, matching RS_GUARD_TEMPERATURE, rather than
        // silently falling back to the TOML/default limit.
        let max_tokens = match env_non_empty("RS_GUARD_MAX_TOKENS") {
            Some(val) => val.parse::<u32>().map_err(|_| {
                RsGuardError::Config(format!(
                    "Invalid RS_GUARD_MAX_TOKENS '{}': must be a non-negative integer",
                    val
                ))
            })?,
            None => toml.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
        };
        let chunk_head_lines = toml
            .chunk_head_lines
            .unwrap_or(crate::diff::DEFAULT_CHUNK_HEAD_LINES);
        let chunk_tail_lines = toml
            .chunk_tail_lines
            .unwrap_or(crate::diff::DEFAULT_CHUNK_TAIL_LINES);

        let toml_provider = toml_providers.get(&provider);
        let base_url = toml_provider.and_then(|p| p.base_url.clone());
        validate_base_url_for_mode(is_ci, base_url.as_deref())?;
        let provider_config = ProviderConfig {
            base_url,
            http_referer: toml_provider.and_then(|p| p.http_referer.clone()),
            max_tokens: Some(max_tokens),
            model: model.clone(),
        };

        let circuit_breaker = toml
            .circuit_breaker
            .filter(|cb| cb.enabled)
            .map(|cb| {
                crate::retry::CircuitBreaker::new(
                    cb.threshold.unwrap_or(DEFAULT_CIRCUIT_BREAKER_THRESHOLD),
                    cb.cooldown_secs
                        .unwrap_or(DEFAULT_CIRCUIT_BREAKER_COOLDOWN_SECS),
                )
            });

        Ok(Config {
            provider,
            model,
            temperature,
            api_key,
            github_token,
            pr_number,
            repo_owner,
            repo_name,
            prompt: DEFAULT_PROMPT.to_string(),
            is_ci,
            github_base_url,
            provider_config,
            toml_providers,
            model_set_via_cli: false,
            no_cache: false,
            dry_run: false,
            cache_dir: toml.cache_dir,
            circuit_breaker,
            pricing: toml.pricing,
            auto_gitignore: toml.auto_gitignore.unwrap_or(true),
            chunk_head_lines,
            chunk_tail_lines,
        })
    }

    /// Applies command-line overrides on top of the environment/TOML configuration.
    ///
    /// Switching provider reloads its API key and TOML overrides. Unless a model was chosen on
    /// the command line, the model is also reset to the new provider's default. All
    /// validation runs before any field is modified, so on error `self` is left unchanged.
    ///
    /// # Errors
    /// Returns [`RsGuardError::Config`] if any of the following holds:
    /// - the CLI temperature is outside `0.0..=2.0`;
    /// - the requested provider is unknown;
    /// - its API key is missing;
    /// - its TOML `base_url` fails validation.
    pub fn apply_args(&mut self, args: &Args) -> Result<(), RsGuardError> {
        // The CLI temperature used to bypass the range check that from_env enforces.
        if let Some(temperature) = args.temperature {
            validate_temperature(temperature)?;
        }
        if let Some(provider) = args.provider.as_deref() {
            if provider != self.provider {
                self.switch_provider(provider, args.model.is_some())?;
            }
        }
        if let Some(model) = &args.model {
            self.model = model.clone();
            self.provider_config.model = model.clone();
            self.model_set_via_cli = true;
        }
        if let Some(temperature) = args.temperature {
            self.temperature = temperature;
        }
        if let Some(max_tokens) = args.max_tokens {
            self.provider_config.max_tokens = Some(max_tokens);
        }
        self.no_cache |= args.no_cache;
        self.dry_run |= args.dry_run;
        Ok(())
    }

    /// Switches to `provider`, reloading its API key, base URL and HTTP referer.
    ///
    /// Every fallible step runs before `self` is touched. Without that ordering, a failed
    /// base-URL check used to leave the new API key paired with the old provider. The model is
    /// reset to the provider default unless a CLI model is in effect (`cli_model_given` or an
    /// earlier `--model`).
    fn switch_provider(
        &mut self,
        provider: &str,
        cli_model_given: bool,
    ) -> Result<(), RsGuardError> {
        let provider_default_model = default_model(provider)?;
        let api_key_env = resolve_api_key_env_var(provider, Some(&self.toml_providers))?;
        let api_key = read_api_key(&api_key_env, provider)?;
        let toml_provider = self.toml_providers.get(provider);
        let base_url = toml_provider.and_then(|p| p.base_url.clone());
        let http_referer = toml_provider.and_then(|p| p.http_referer.clone());
        validate_base_url_for_mode(self.is_ci, base_url.as_deref())?;

        log::debug!(
            "Provider switch '{}' -> '{}': base_url={:?}, http_referer={:?}",
            self.provider,
            provider,
            base_url,
            http_referer
        );

        self.provider = provider.to_string();
        self.api_key = api_key;
        self.provider_config.base_url = base_url;
        self.provider_config.http_referer = http_referer;
        if !self.model_set_via_cli && !cli_model_given {
            self.model = provider_default_model.to_string();
        }
        self.provider_config.model = self.model.clone();
        Ok(())
    }

    /// Replaces the review prompt with the contents of `path`, if that file exists.
    ///
    /// A missing file leaves the current prompt unchanged. The file is read directly rather
    /// than probed with `Path::exists` first, so an unreadable file is reported instead of being
    /// mistaken for a missing one.
    ///
    /// # Errors
    /// Returns [`RsGuardError::Config`] (naming the path) if the file exists but cannot be read.
    pub fn load_prompt_file(&mut self, path: &Path) -> Result<(), RsGuardError> {
        match std::fs::read_to_string(path) {
            Ok(content) => {
                self.prompt = content;
                Ok(())
            }
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(e) => Err(RsGuardError::Config(format!(
                "Failed to read prompt file '{}': {}",
                path.display(),
                e
            ))),
        }
    }

    /// Checks that everything needed to post results to GitHub is present.
    ///
    /// # Errors
    /// Returns [`RsGuardError::Config`] if any of the following holds:
    /// - the configuration is not in CI mode;
    /// - the GitHub base URL fails validation;
    /// - `GITHUB_TOKEN`, `PR_NUMBER` or `REPO_FULL_NAME` is missing.
    pub fn validate_for_ci(&self) -> Result<CiConfig, RsGuardError> {
        // Check the mode first: nothing else matters outside CI.
        if !self.is_ci {
            return Err(RsGuardError::Config(
                "validate_for_ci() called but not in CI mode".to_string(),
            ));
        }
        validate_github_base_url(&self.github_base_url)?;
        let github_token = self.github_token.clone().ok_or_else(|| {
            RsGuardError::Config("GITHUB_TOKEN is required in CI mode".to_string())
        })?;
        let pr_number = self
            .pr_number
            .ok_or_else(|| RsGuardError::Config("PR_NUMBER is required in CI mode".to_string()))?;
        let missing_repo = || {
            RsGuardError::Config(
                "REPO_FULL_NAME is required in CI mode (format: owner/repo)".to_string(),
            )
        };
        let repo_owner = self.repo_owner.clone().ok_or_else(missing_repo)?;
        let repo_name = self.repo_name.clone().ok_or_else(missing_repo)?;
        Ok(CiConfig {
            github_token,
            pr_number,
            repo_owner,
            repo_name,
            github_base_url: self.github_base_url.clone(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that mutate process-wide environment variables.
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Variables read by `Config::from_env`. They are cleared for the duration of every env
    /// test so values from the developer's shell or the CI runner cannot change the outcome.
    const CONFIG_ENV_VARS: &[&str] = &[
        "GITHUB_ACTIONS",
        "RS_GUARD_PROVIDER",
        "RS_GUARD_MODEL",
        "RS_GUARD_TEMPERATURE",
        "RS_GUARD_MAX_TOKENS",
        "GITHUB_TOKEN",
        "PR_NUMBER",
        "REPO_FULL_NAME",
        "GITHUB_API_URL",
        "DEEPSEEK_API_KEY",
    ];

    /// Holds `ENV_MUTEX` and restores the saved environment when dropped, including when the
    /// test panics.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Locks the env mutex, clears `CONFIG_ENV_VARS`, then sets each `(key, value)`.
        fn set(vars: &[(&'static str, &str)]) -> Self {
            // A failed assertion in another env test poisons the mutex. The protected data is
            // `()`, so recovering the guard is sound and avoids cascading failures.
            let lock = ENV_MUTEX
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = CONFIG_ENV_VARS
                .iter()
                .copied()
                .chain(vars.iter().map(|&(key, _)| key))
                .map(|key| (key, std::env::var_os(key)))
                .collect();
            for key in CONFIG_ENV_VARS {
                std::env::remove_var(key);
            }
            for (key, value) in vars {
                std::env::set_var(key, value);
            }
            EnvGuard { saved, _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (key, value) in &self.saved {
                match value {
                    Some(value) => std::env::set_var(key, value),
                    None => std::env::remove_var(key),
                }
            }
        }
    }

    #[test]
    fn test_standard_api_key_env_var_mapping() {
        assert_eq!(
            standard_api_key_env_var("deepseek").unwrap(),
            "DEEPSEEK_API_KEY"
        );
        assert_eq!(standard_api_key_env_var("kimi").unwrap(), "KIMI_API_KEY");
        assert_eq!(
            standard_api_key_env_var("qwen").unwrap(),
            "DASHSCOPE_API_KEY"
        );
        assert_eq!(
            standard_api_key_env_var("openrouter").unwrap(),
            "OPENROUTER_API_KEY"
        );
        assert_eq!(
            standard_api_key_env_var("openai").unwrap(),
            "OPENAI_API_KEY"
        );
    }

    #[test]
    fn test_unknown_provider_returns_error() {
        let result = standard_api_key_env_var("unknown");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Unknown provider"));
        assert!(err.contains("unknown"));
    }

    #[test]
    fn test_default_model_mapping() {
        assert_eq!(default_model("deepseek").unwrap(), "deepseek-v4-flash");
        assert_eq!(default_model("kimi").unwrap(), "kimi-k2.5");
        assert_eq!(default_model("qwen").unwrap(), "qwen-plus");
        assert_eq!(default_model("openrouter").unwrap(), "openai/gpt-4o-mini");
        assert_eq!(default_model("openai").unwrap(), "gpt-4o-mini");
    }

    #[test]
    fn test_default_model_unknown_provider_returns_error() {
        let result = default_model("unknown");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Unknown provider"));
    }

    #[test]
    fn test_resolve_api_key_env_var_toml_override() {
        let mut providers = HashMap::new();
        providers.insert(
            "openai".to_string(),
            ProviderTomlConfig {
                api_key_env: Some("MY_CUSTOM_KEY".to_string()),
                base_url: None,
                http_referer: None,
            },
        );
        let result = resolve_api_key_env_var("openai", Some(&providers)).unwrap();
        assert_eq!(result, "MY_CUSTOM_KEY");
    }

    #[test]
    fn test_resolve_api_key_env_var_standard_fallback() {
        let providers = HashMap::new();
        let result = resolve_api_key_env_var("deepseek", Some(&providers)).unwrap();
        assert_eq!(result, "DEEPSEEK_API_KEY");
    }

    #[test]
    fn test_validate_local_provider_base_url_http_warns() {
        let result = validate_local_provider_base_url("http://api.example.com/v1");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_local_provider_base_url_loopback_warns() {
        let result = validate_local_provider_base_url("http://127.0.0.1:11434/v1");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_local_provider_base_url_unknown_host_warns() {
        let result = validate_local_provider_base_url("https://custom-llm.example.com/v1");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_local_provider_base_url_malformed_errors() {
        let result = validate_local_provider_base_url("not-a-url");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("malformed"));
    }

    #[test]
    fn test_validate_local_provider_base_url_known_host_ok() {
        let result = validate_local_provider_base_url("https://api.deepseek.com/v1");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_for_ci_local_mode_rejected() {
        let mut config = Config::empty();
        config.is_ci = false;
        config.github_base_url = "https://api.github.com".to_string();
        assert!(config.validate_for_ci().is_err());
    }

    #[test]
    fn test_validate_for_ci_missing_github_token() {
        let mut config = Config::empty();
        config.is_ci = true;
        config.github_base_url = "https://api.github.com".to_string();
        config.github_token = None;
        config.pr_number = Some(1);
        config.repo_owner = Some("owner".to_string());
        config.repo_name = Some("repo".to_string());
        let result = config.validate_for_ci();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("GITHUB_TOKEN"));
    }

    #[test]
    fn test_validate_for_ci_missing_pr_number() {
        let mut config = Config::empty();
        config.is_ci = true;
        config.github_base_url = "https://api.github.com".to_string();
        config.github_token = Some("token".to_string());
        config.pr_number = None;
        config.repo_owner = Some("owner".to_string());
        config.repo_name = Some("repo".to_string());
        let result = config.validate_for_ci();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("PR_NUMBER"));
    }

    #[test]
    fn test_validate_for_ci_missing_repo_owner() {
        let mut config = Config::empty();
        config.is_ci = true;
        config.github_base_url = "https://api.github.com".to_string();
        config.github_token = Some("token".to_string());
        config.pr_number = Some(1);
        config.repo_owner = None;
        config.repo_name = Some("repo".to_string());
        let result = config.validate_for_ci();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("REPO_FULL_NAME"));
    }

    #[test]
    fn test_validate_for_ci_missing_repo_name() {
        let mut config = Config::empty();
        config.is_ci = true;
        config.github_base_url = "https://api.github.com".to_string();
        config.github_token = Some("token".to_string());
        config.pr_number = Some(1);
        config.repo_owner = Some("owner".to_string());
        config.repo_name = None;
        let result = config.validate_for_ci();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("REPO_FULL_NAME"));
    }

    #[test]
    fn test_validate_for_ci_all_fields_present() {
        let mut config = Config::empty();
        config.is_ci = true;
        config.github_base_url = "https://api.github.com".to_string();
        config.github_token = Some("token".to_string());
        config.pr_number = Some(42);
        config.repo_owner = Some("owner".to_string());
        config.repo_name = Some("repo".to_string());
        assert!(config.validate_for_ci().is_ok());
    }

    #[test]
    fn test_validate_for_ci_invalid_base_url() {
        let mut config = Config::empty();
        config.is_ci = true;
        config.github_base_url = "http://evil.com".to_string();
        config.github_token = Some("token".to_string());
        config.pr_number = Some(1);
        config.repo_owner = Some("owner".to_string());
        config.repo_name = Some("repo".to_string());
        let result = config.validate_for_ci();
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_for_ci_returns_ci_config() {
        let mut config = Config::empty();
        config.is_ci = true;
        config.github_token = Some("test-token".to_string());
        config.pr_number = Some(42);
        config.repo_owner = Some("owner".to_string());
        config.repo_name = Some("repo".to_string());
        config.github_base_url = "https://api.github.com".to_string();
        let ci_config = config.validate_for_ci().expect("should validate");
        assert_eq!(ci_config.github_token, "test-token");
        assert_eq!(ci_config.pr_number, 42);
        assert_eq!(ci_config.repo_owner, "owner");
        assert_eq!(ci_config.repo_name, "repo");
        assert_eq!(ci_config.github_base_url, "https://api.github.com");
    }

    #[test]
    fn test_validate_for_ci_not_in_ci_mode_returns_error() {
        let mut config = Config::empty();
        config.is_ci = false;
        config.github_base_url = "https://api.github.com".to_string();
        let result = config.validate_for_ci();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not in CI mode"));
    }

    #[test]
    fn test_load_prompt_file_existing_file() {
        let dir = tempfile::tempdir().unwrap();
        let prompt_path = dir.path().join("prompt.md");
        std::fs::write(&prompt_path, "Custom review prompt").unwrap();
        let mut config = Config::empty();
        config.load_prompt_file(&prompt_path).unwrap();
        assert_eq!(config.prompt, "Custom review prompt");
    }

    #[test]
    fn test_load_prompt_file_missing_file_keeps_default() {
        let mut config = Config::empty();
        config.prompt = "default prompt".to_string();
        let result = config.load_prompt_file(std::path::Path::new("/nonexistent/prompt.md"));
        assert!(result.is_ok());
        assert_eq!(config.prompt, "default prompt");
    }

    #[test]
    #[cfg(unix)]
    fn test_load_prompt_file_unreadable_file() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tempfile::tempdir().unwrap();
        let prompt_path = dir.path().join("unreadable.md");
        std::fs::write(&prompt_path, "content").unwrap();
        std::fs::set_permissions(&prompt_path, std::fs::Permissions::from_mode(0o000)).unwrap();
        // Privileged users (e.g. root in CI containers) ignore permission bits, so the file
        // stays readable and this scenario cannot be reproduced; skip rather than fail.
        if std::fs::read(&prompt_path).is_ok() {
            std::fs::set_permissions(&prompt_path, std::fs::Permissions::from_mode(0o644)).ok();
            return;
        }
        let mut config = Config::empty();
        let result = config.load_prompt_file(&prompt_path);
        std::fs::set_permissions(&prompt_path, std::fs::Permissions::from_mode(0o644)).ok();
        assert!(result.is_err());
    }

    #[test]
    fn test_config_empty_has_dry_run_false() {
        let config = Config::empty();
        assert!(!config.dry_run);
    }

    #[test]
    fn test_config_empty_has_cache_dir_none() {
        let config = Config::empty();
        assert!(config.cache_dir.is_none());
    }

    #[test]
    fn test_apply_args_sets_dry_run() {
        use clap::Parser;
        // Isolate from any environment-backed CLI defaults.
        let _env = EnvGuard::set(&[]);
        let mut config = Config::empty();
        assert!(!config.dry_run);
        let args = crate::cli::Args::parse_from(["rs-guard", "--dry-run"]);
        config.apply_args(&args).unwrap();
        assert!(config.dry_run);
    }

    #[test]
    fn test_circuit_breaker_disabled_produces_none() {
        let _env = EnvGuard::set(&[("DEEPSEEK_API_KEY", "test-key")]);
        let toml = TomlConfig {
            circuit_breaker: Some(CircuitBreakerTomlConfig {
                enabled: false,
                threshold: Some(5),
                cooldown_secs: Some(120),
            }),
            ..Default::default()
        };
        let config = Config::from_env(Some(toml)).expect("from_env should succeed");
        assert!(
            config.circuit_breaker.is_none(),
            "circuit_breaker should be None when enabled=false"
        );
    }

    #[test]
    fn test_circuit_breaker_enabled_produces_some() {
        let _env = EnvGuard::set(&[("DEEPSEEK_API_KEY", "test-key")]);
        let toml = TomlConfig {
            circuit_breaker: Some(CircuitBreakerTomlConfig {
                enabled: true,
                threshold: Some(5),
                cooldown_secs: Some(120),
            }),
            ..Default::default()
        };
        let config = Config::from_env(Some(toml)).expect("from_env should succeed");
        assert!(
            config.circuit_breaker.is_some(),
            "circuit_breaker should be Some when enabled=true"
        );
    }

    #[test]
    fn test_from_env_defaults() {
        let _env = EnvGuard::set(&[("DEEPSEEK_API_KEY", "test-key")]);
        let config = Config::from_env(None).expect("from_env should succeed");
        assert_eq!(config.provider, "deepseek");
        assert_eq!(config.model, "deepseek-v4-flash");
        assert_eq!(config.api_key, "test-key");
        assert!((config.temperature - 0.1).abs() < f32::EPSILON);
        assert_eq!(config.provider_config.max_tokens, Some(DEFAULT_MAX_TOKENS));
        assert!(config.auto_gitignore);
        assert!(!config.is_ci);
        assert_eq!(config.github_base_url, "https://api.github.com");
    }

    #[test]
    fn test_from_env_missing_api_key() {
        let _env = EnvGuard::set(&[]);
        let result = Config::from_env(None);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("DEEPSEEK_API_KEY"));
    }

    #[test]
    fn test_from_env_unknown_provider() {
        let _env = EnvGuard::set(&[("RS_GUARD_PROVIDER", "no-such-provider")]);
        let result = Config::from_env(None);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Unknown provider"));
    }

    #[test]
    fn test_from_env_invalid_temperature_env_rejected() {
        let _env = EnvGuard::set(&[
            ("DEEPSEEK_API_KEY", "test-key"),
            ("RS_GUARD_TEMPERATURE", "warm"),
        ]);
        let result = Config::from_env(None);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Invalid RS_GUARD_TEMPERATURE"));
    }

    #[test]
    fn test_from_env_out_of_range_toml_temperature_rejected() {
        let _env = EnvGuard::set(&[("DEEPSEEK_API_KEY", "test-key")]);
        let toml = TomlConfig {
            temperature: Some(3.5),
            ..Default::default()
        };
        let result = Config::from_env(Some(toml));
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Temperature must be between"));
    }

    #[test]
    fn test_repo_full_name_with_multiple_slashes() {
        let _env = EnvGuard::set(&[
            ("DEEPSEEK_API_KEY", "test-key"),
            ("REPO_FULL_NAME", "owner/repo/subpath"),
        ]);
        let result = Config::from_env(None);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("no additional slashes"));
    }

    #[test]
    fn test_repo_full_name_empty_owner() {
        let _env = EnvGuard::set(&[
            ("DEEPSEEK_API_KEY", "test-key"),
            ("REPO_FULL_NAME", "/repo"),
        ]);
        let result = Config::from_env(None);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
    }

    #[test]
    fn test_repo_full_name_empty_repo() {
        let _env = EnvGuard::set(&[
            ("DEEPSEEK_API_KEY", "test-key"),
            ("REPO_FULL_NAME", "owner/"),
        ]);
        let result = Config::from_env(None);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
    }

    #[test]
    fn test_repo_full_name_valid_format() {
        let _env = EnvGuard::set(&[
            ("DEEPSEEK_API_KEY", "test-key"),
            ("REPO_FULL_NAME", "owner/repo"),
        ]);
        let result = Config::from_env(None);
        assert!(result.is_ok());
        let config = result.unwrap();
        assert_eq!(config.repo_owner, Some("owner".to_string()));
        assert_eq!(config.repo_name, Some("repo".to_string()));
    }
}