use crate::error::RsGuardError;
use crate::http::{build_github_http_client, github_diff_headers, validate_github_base_url};
use crate::retry::with_retry_simple;
use std::borrow::Cow;
/// Largest diff, in bytes, that the fetchers in this module accept (100 KiB).
const MAX_DIFF_BYTES: usize = 100 * 1024;
/// Largest diff, in lines, that the fetchers in this module accept.
const MAX_DIFF_LINES: usize = 1500;
/// Timeout handed to `build_github_http_client` when fetching a pull-request diff.
const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
/// Number of leading lines `chunk_diff` keeps when it truncates a diff.
pub const DEFAULT_CHUNK_HEAD_LINES: usize = 400;
/// Number of trailing lines `chunk_diff` keeps when it truncates a diff.
pub const DEFAULT_CHUNK_TAIL_LINES: usize = 400;
/// Marker inserted where lines were cut out of a diff; the `{removed}` token is
/// replaced with the number of omitted lines before the marker is emitted.
const CHUNK_PLACEHOLDER: &str = "\n# ... [diff truncated: {removed} lines omitted] ...\n";
/// A diff that passed validation and the size limits, together with its measurements.
#[derive(Debug, Clone)]
#[must_use = "DiffResult should be used for review processing"]
pub struct DiffResult {
/// The diff text.
pub content: String,
/// Length of `content` in bytes.
pub size_bytes: usize,
/// Number of lines in `content`, as counted by `str::lines`.
pub line_count: usize,
}
/// Checks that `content` plausibly is a diff rather than some other payload.
///
/// Content whose first non-whitespace character is `{` or `[` is rejected
/// outright (it looks like JSON). Otherwise at least one diff marker must be
/// present: one of a few substrings anywhere in the text, or a `diff ` /
/// `index ` prefix at the very start of the untrimmed text.
///
/// # Errors
///
/// Returns [`RsGuardError::InvalidDiffContent`] when either check fails.
fn validate_diff_content(content: &str) -> Result<(), RsGuardError> {
    // Substrings that count as a marker wherever they occur.
    const ANYWHERE_MARKERS: [&str; 4] = ["diff --git", "@@ ", "--- a/", "+++ b/"];
    // Prefixes that only count at the very beginning of the (untrimmed) input.
    const LEADING_MARKERS: [&str; 2] = ["diff ", "index "];

    let looks_like_json = matches!(content.trim_start().chars().next(), Some('{' | '['));
    if looks_like_json {
        return Err(RsGuardError::InvalidDiffContent);
    }

    let found_marker = ANYWHERE_MARKERS
        .iter()
        .any(|marker| content.contains(*marker))
        || LEADING_MARKERS
            .iter()
            .any(|prefix| content.starts_with(*prefix));

    if found_marker {
        Ok(())
    } else {
        Err(RsGuardError::InvalidDiffContent)
    }
}
/// Truncates `content` using the default head and tail line budgets.
///
/// Forwards to [`chunk_diff_with_params`] with [`DEFAULT_CHUNK_HEAD_LINES`]
/// and [`DEFAULT_CHUNK_TAIL_LINES`] and returns its
/// `(text, truncated, removed_line_count)` tuple unchanged.
pub fn chunk_diff(content: &str) -> (Cow<'_, str>, bool, usize) {
    let (head_lines, tail_lines) = (DEFAULT_CHUNK_HEAD_LINES, DEFAULT_CHUNK_TAIL_LINES);
    chunk_diff_with_params(content, head_lines, tail_lines)
}
/// Shortens `content` to its first `head_lines` and last `tail_lines` lines.
///
/// Lines are split with [`str::lines`]. When the input has more than
/// `head_lines + tail_lines` lines, the lines in between are replaced by a
/// placeholder that states how many were dropped. Otherwise the input is
/// returned borrowed and untouched.
///
/// Every kept line is terminated with `\r\n` if the input contains `\r\n`
/// anywhere, and with `\n` otherwise; the placeholder uses the same
/// terminator. The final kept line is terminated only when the input itself
/// ended with a newline.
///
/// # Returns
///
/// `(text, truncated, removed)` where `truncated` tells whether lines were
/// dropped and `removed` is the number of dropped lines (`0` when nothing was
/// dropped).
pub fn chunk_diff_with_params(
    content: &str,
    head_lines: usize,
    tail_lines: usize,
) -> (Cow<'_, str>, bool, usize) {
    let lines: Vec<&str> = content.lines().collect();
    let total = lines.len();

    // Saturate so that very large limits mean "never truncate" instead of
    // overflowing (a panic in debug builds, a wrapped threshold in release).
    let threshold = head_lines.saturating_add(tail_lines);
    if total <= threshold {
        return (Cow::Borrowed(content), false, 0);
    }

    // Past this point `head_lines + tail_lines < total`, so the subtraction
    // and the slicing below cannot go out of range.
    let removed = total - threshold;
    let (head, rest) = lines.split_at(head_lines);
    let tail = &rest[removed..];

    let line_ending = if content.contains("\r\n") { "\r\n" } else { "\n" };
    // A trailing "\r\n" also ends with '\n', so one check covers both styles.
    let ends_with_newline = content.ends_with('\n');

    // Use the detected terminator inside the placeholder as well, so CRLF
    // input does not come back with mixed line endings.
    let placeholder = CHUNK_PLACEHOLDER
        .replace("{removed}", &removed.to_string())
        .replace('\n', line_ending);

    let kept_bytes: usize = head
        .iter()
        .chain(tail)
        .map(|line| line.len() + line_ending.len())
        .sum();
    let mut result = String::with_capacity(kept_bytes + placeholder.len());

    for line in head {
        result.push_str(line);
        result.push_str(line_ending);
    }
    result.push_str(&placeholder);
    if let Some((last, init)) = tail.split_last() {
        for line in init {
            result.push_str(line);
            result.push_str(line_ending);
        }
        result.push_str(last);
        if ends_with_newline {
            result.push_str(line_ending);
        }
    }

    (Cow::Owned(result), true, removed)
}
/// Downloads the diff of pull request `pr_number` of `owner/repo`.
///
/// The request goes to `{base_url}/repos/{owner}/{repo}/pulls/{pr_number}`
/// (trailing slashes on `base_url` are dropped) with the headers produced by
/// `github_diff_headers(token)`, and is driven through `with_retry_simple`.
///
/// # Errors
///
/// * Whatever `validate_github_base_url`, `build_github_http_client` or
///   `github_diff_headers` return on failure.
/// * [`RsGuardError::GitHubApi`] when sending fails (status taken from the
///   transport error, `0` if it carries none), when the response status is
///   not a success (message is the response body), or when the body of a
///   successful response cannot be read (status `0`).
/// * [`RsGuardError::EmptyDiff`] when the body is empty.
/// * [`RsGuardError::InvalidDiffContent`] when the body does not look like a diff.
/// * [`RsGuardError::DiffTooLarge`] when the body exceeds `MAX_DIFF_BYTES`
///   bytes or `MAX_DIFF_LINES` lines.
pub async fn fetch_pr_diff(
    base_url: &str,
    owner: &str,
    repo: &str,
    pr_number: u64,
    token: &str,
) -> Result<DiffResult, RsGuardError> {
    validate_github_base_url(base_url)?;
    let client = build_github_http_client(REQUEST_TIMEOUT)?;
    let api_root = base_url.trim_end_matches('/');
    let url = format!("{}/repos/{}/{}/pulls/{}", api_root, owner, repo, pr_number);
    let headers = github_diff_headers(token)?;

    let diff_text = with_retry_simple(|| async {
        let sent = client.get(&url).headers(headers.clone()).send().await;
        let resp = match sent {
            Ok(resp) => resp,
            Err(err) => {
                return Err(RsGuardError::GitHubApi {
                    status: err.status().map_or(0, |code| code.as_u16()),
                    message: err.to_string(),
                });
            }
        };

        // Capture the status first: reading the body consumes the response.
        let code = resp.status();
        if code.is_success() {
            return resp.text().await.map_err(|err| RsGuardError::GitHubApi {
                status: 0,
                message: err.to_string(),
            });
        }

        // Error responses: report the body as the message, best effort.
        let message = match resp.text().await {
            Ok(text) => text,
            Err(err) => format!("[failed to read response body: {}]", err),
        };
        Err(RsGuardError::GitHubApi {
            status: code.as_u16(),
            message,
        })
    })
    .await?;

    if diff_text.is_empty() {
        return Err(RsGuardError::EmptyDiff);
    }
    validate_diff_content(&diff_text)?;

    let size_bytes = diff_text.len();
    let line_count = diff_text.lines().count();
    let within_limits = size_bytes <= MAX_DIFF_BYTES && line_count <= MAX_DIFF_LINES;
    if !within_limits {
        return Err(RsGuardError::DiffTooLarge {
            size_bytes,
            line_count,
        });
    }

    Ok(DiffResult {
        content: diff_text,
        size_bytes,
        line_count,
    })
}
/// Loads a diff from the file at `path`.
///
/// # Errors
///
/// * [`RsGuardError::Config`] when the file cannot be read as UTF-8 text.
/// * [`RsGuardError::EmptyDiff`] when the file is empty.
/// * [`RsGuardError::InvalidDiffContent`] when the text does not look like a diff.
/// * [`RsGuardError::DiffTooLarge`] when the text exceeds `MAX_DIFF_BYTES`
///   bytes or `MAX_DIFF_LINES` lines.
pub fn fetch_file_diff(path: &str) -> Result<DiffResult, RsGuardError> {
    let content = match std::fs::read_to_string(path) {
        Ok(text) => text,
        Err(err) => {
            let message = format!("Failed to read diff file '{}': {}", path, err);
            return Err(RsGuardError::Config(message));
        }
    };

    if content.is_empty() {
        return Err(RsGuardError::EmptyDiff);
    }
    validate_diff_content(&content)?;

    let (size_bytes, line_count) = (content.len(), content.lines().count());
    let oversized = size_bytes > MAX_DIFF_BYTES || line_count > MAX_DIFF_LINES;
    if oversized {
        Err(RsGuardError::DiffTooLarge {
            size_bytes,
            line_count,
        })
    } else {
        Ok(DiffResult {
            content,
            size_bytes,
            line_count,
        })
    }
}
/// Collects the staged changes by running `git diff --cached` in the current
/// working directory.
///
/// Standard output is decoded lossily as UTF-8 and handed to
/// `build_local_diff_result`.
///
/// # Errors
///
/// * [`RsGuardError::Io`] when the `git` process cannot be run.
/// * [`RsGuardError::Config`] when `git` exits unsuccessfully; the message
///   carries its standard error output.
/// * Any error returned by `build_local_diff_result`.
pub fn fetch_local_diff() -> Result<DiffResult, RsGuardError> {
    let mut git = std::process::Command::new("git");
    git.args(["diff", "--cached"]);
    let output = git.output().map_err(RsGuardError::Io)?;

    if output.status.success() {
        let content = String::from_utf8_lossy(&output.stdout).into_owned();
        return build_local_diff_result(content);
    }

    let stderr = String::from_utf8_lossy(&output.stderr);
    Err(RsGuardError::Config(format!(
        "git diff --cached failed: {}",
        stderr
    )))
}
/// Validates `content` and wraps it in a [`DiffResult`] with its byte and line counts.
///
/// # Errors
///
/// * [`RsGuardError::EmptyDiff`] when `content` is empty.
/// * [`RsGuardError::InvalidDiffContent`] when `content` does not look like a diff.
/// * [`RsGuardError::DiffTooLarge`] when `content` exceeds `MAX_DIFF_BYTES`
///   bytes or `MAX_DIFF_LINES` lines.
pub(crate) fn build_local_diff_result(content: String) -> Result<DiffResult, RsGuardError> {
    if content.is_empty() {
        return Err(RsGuardError::EmptyDiff);
    }
    validate_diff_content(&content)?;

    let size_bytes = content.len();
    let line_count = content.lines().count();
    match (size_bytes > MAX_DIFF_BYTES, line_count > MAX_DIFF_LINES) {
        // Neither limit exceeded: hand the text over together with its measurements.
        (false, false) => Ok(DiffResult {
            content,
            size_bytes,
            line_count,
        }),
        _ => Err(RsGuardError::DiffTooLarge {
            size_bytes,
            line_count,
        }),
    }
}
/// Unit tests for diff fetching, validation, size limits and truncation.
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // ---- fetch_pr_diff: exercised against a local wiremock server ----

    // A 200 response carrying a small diff is returned as-is.
    #[tokio::test]
    async fn test_fetch_pr_diff_success() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/repos/test-owner/test-repo/pulls/42"))
            .and(header("Accept", "application/vnd.github.v3.diff"))
            .respond_with(ResponseTemplate::new(200).set_body_string(
                "diff --git a/file.rs b/file.rs\n--- a/file.rs\n+++ b/file.rs\n@@ -1,2 +1,3 @@\n+line",
            ))
            .mount(&mock_server)
            .await;
        let result = fetch_pr_diff(
            &mock_server.uri(),
            "test-owner",
            "test-repo",
            42,
            "test-token",
        )
        .await;
        assert!(result.is_ok());
        let diff = result.unwrap();
        assert!(diff.content.contains("diff --git"));
        assert!(diff.line_count > 0);
    }

    // A non-success status surfaces as an error that mentions the status code.
    #[tokio::test]
    async fn test_fetch_pr_diff_not_found() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/repos/test-owner/test-repo/pulls/999"))
            .respond_with(ResponseTemplate::new(404).set_body_string("Not Found"))
            .mount(&mock_server)
            .await;
        let result = fetch_pr_diff(
            &mock_server.uri(),
            "test-owner",
            "test-repo",
            999,
            "test-token",
        )
        .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("404"));
    }

    // A 200 response whose body is JSON must not be accepted as a diff.
    #[tokio::test]
    async fn test_fetch_pr_diff_rejects_json_response() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/repos/test-owner/test-repo/pulls/42"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string(r#"{"message": "Not Found", "documentation_url": "..." }"#),
            )
            .mount(&mock_server)
            .await;
        let result = fetch_pr_diff(
            &mock_server.uri(),
            "test-owner",
            "test-repo",
            42,
            "test-token",
        )
        .await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("not appear to be a diff"));
    }

    // The byte limit is inclusive: exactly 100 KiB is still accepted.
    #[tokio::test]
    async fn test_fetch_pr_diff_exactly_100kb_passes() {
        let mock_server = MockServer::start().await;
        let diff_header =
            "diff --git a/file.rs b/file.rs\n--- a/file.rs\n+++ b/file.rs\n@@ -1,2 +1,3 @@\n";
        let header_bytes = diff_header.len();
        let content_bytes = 100 * 1024 - header_bytes;
        let diff_content = format!("{}{}", diff_header, "+".repeat(content_bytes));
        Mock::given(method("GET"))
            .and(path("/repos/test-owner/test-repo/pulls/42"))
            .and(header("Accept", "application/vnd.github.v3.diff"))
            .respond_with(ResponseTemplate::new(200).set_body_string(diff_content))
            .mount(&mock_server)
            .await;
        let result = fetch_pr_diff(
            &mock_server.uri(),
            "test-owner",
            "test-repo",
            42,
            "test-token",
        )
        .await;
        assert!(result.is_ok(), "Exactly 100KB diff should pass");
        let diff = result.unwrap();
        assert_eq!(diff.size_bytes, 100 * 1024);
    }

    // One byte over the limit is rejected.
    #[tokio::test]
    async fn test_fetch_pr_diff_100kb_plus_1_fails() {
        let mock_server = MockServer::start().await;
        let diff_header =
            "diff --git a/file.rs b/file.rs\n--- a/file.rs\n+++ b/file.rs\n@@ -1,2 +1,3 @@\n";
        let header_bytes = diff_header.len();
        let content_bytes = 100 * 1024 - header_bytes + 1;
        let diff_content = format!("{}{}", diff_header, "+".repeat(content_bytes));
        Mock::given(method("GET"))
            .and(path("/repos/test-owner/test-repo/pulls/42"))
            .and(header("Accept", "application/vnd.github.v3.diff"))
            .respond_with(ResponseTemplate::new(200).set_body_string(diff_content))
            .mount(&mock_server)
            .await;
        let result = fetch_pr_diff(
            &mock_server.uri(),
            "test-owner",
            "test-repo",
            42,
            "test-token",
        )
        .await;
        assert!(result.is_err(), "100KB + 1 byte diff should fail");
        assert!(matches!(result, Err(RsGuardError::DiffTooLarge { .. })));
    }

    // 4 header lines + 1497 body lines = 1501 lines, one over the line limit.
    #[tokio::test]
    async fn test_fetch_pr_diff_1501_lines_fails() {
        let mock_server = MockServer::start().await;
        let diff_header =
            "diff --git a/file.rs b/file.rs\n--- a/file.rs\n+++ b/file.rs\n@@ -1,2 +1,3 @@\n";
        let lines: Vec<String> = (0..1497).map(|i| format!("+line {}", i)).collect();
        let diff_content = format!("{}{}", diff_header, lines.join("\n"));
        Mock::given(method("GET"))
            .and(path("/repos/test-owner/test-repo/pulls/42"))
            .and(header("Accept", "application/vnd.github.v3.diff"))
            .respond_with(ResponseTemplate::new(200).set_body_string(diff_content))
            .mount(&mock_server)
            .await;
        let result = fetch_pr_diff(
            &mock_server.uri(),
            "test-owner",
            "test-repo",
            42,
            "test-token",
        )
        .await;
        assert!(result.is_err(), "1501 lines should fail");
        assert!(matches!(result, Err(RsGuardError::DiffTooLarge { .. })));
    }

    // ---- validate_diff_content ----

    // Each recognised marker is sufficient on its own.
    #[test]
    fn test_validate_diff_content_valid() {
        assert!(validate_diff_content("diff --git a/f.rs b/f.rs\n").is_ok());
        assert!(validate_diff_content("@@ -1,3 +1,4 @@\n").is_ok());
        assert!(validate_diff_content("--- a/f.rs\n+++ b/f.rs\n").is_ok());
        assert!(validate_diff_content("index abc123..def456 100644\n").is_ok());
    }

    // JSON objects and arrays are rejected.
    #[test]
    fn test_validate_diff_content_json() {
        assert!(validate_diff_content(r#"{"message": "error"}"#).is_err());
        assert!(validate_diff_content(r#"[{"error": true}]"#).is_err());
    }

    // Plain text without any diff marker is rejected.
    #[test]
    fn test_validate_diff_content_no_markers() {
        assert!(validate_diff_content("just some random text\nwith no diff markers").is_err());
    }

    // ---- chunk_diff / chunk_diff_with_params ----

    // Input below the threshold comes back unchanged.
    #[test]
    fn test_chunk_diff_small_diff_unchanged() {
        let content = "line1\nline2\nline3";
        let (result, truncated, _) = chunk_diff(content);
        assert!(!truncated);
        assert_eq!(result.as_ref(), content);
    }

    // 200 lines with a 50/50 budget: the middle 100 lines are dropped.
    #[test]
    fn test_chunk_diff_truncates_large_diff() {
        let lines: Vec<String> = (0..200).map(|i| format!("line {}", i)).collect();
        let content = lines.join("\n");
        let (result, truncated, removed) = chunk_diff_with_params(&content, 50, 50);
        assert!(truncated);
        assert_eq!(removed, 100);
        assert!(result.contains("line 0"));
        assert!(result.contains("line 49"));
        assert!(result.contains("line 150"));
        assert!(result.contains("line 199"));
        assert!(result.contains("100 lines omitted"));
        assert!(!result.contains("line 100"));
    }

    // Exactly head + tail lines is not truncated.
    #[test]
    fn test_chunk_diff_exactly_at_threshold_unchanged() {
        let lines: Vec<String> = (0..100).map(|i| format!("line {}", i)).collect();
        let content = lines.join("\n");
        let (result, truncated, _) = chunk_diff_with_params(&content, 50, 50);
        assert!(!truncated);
        assert_eq!(result.as_ref(), content);
    }

    // Output order is head, placeholder, tail.
    #[test]
    fn test_chunk_diff_preserves_head_and_tail_order() {
        let lines: Vec<String> = (0..150).map(|i| format!("line {}", i)).collect();
        let content = lines.join("\n");
        let (result, truncated, _) = chunk_diff_with_params(&content, 50, 50);
        assert!(truncated);
        let head_pos = result.find("line 0").unwrap();
        let placeholder_pos = result.find("lines omitted").unwrap();
        let tail_pos = result.find("line 100").unwrap();
        assert!(head_pos < placeholder_pos);
        assert!(placeholder_pos < tail_pos);
    }

    // A trailing newline in the input survives truncation.
    #[test]
    fn test_chunk_diff_preserves_line_endings() {
        let lines: Vec<String> = (0..150).map(|i| format!("line {}", i)).collect();
        let content = lines.join("\n") + "\n";
        let (result, truncated, _) = chunk_diff_with_params(&content, 50, 50);
        assert!(truncated);
        assert!(result.ends_with('\n'));
    }

    // CRLF input keeps CRLF terminators, including the trailing one.
    #[test]
    fn test_chunk_diff_preserves_crlf_line_endings() {
        let lines: Vec<String> = (0..150).map(|i| format!("line {}", i)).collect();
        let content = lines.join("\r\n") + "\r\n";
        let (result, truncated, removed) = chunk_diff_with_params(&content, 50, 50);
        assert!(truncated);
        assert_eq!(removed, 50);
        assert!(result.contains("\r\n"));
        assert!(result.ends_with("\r\n"));
    }

    // Small CRLF input is returned byte-for-byte.
    #[test]
    fn test_chunk_diff_small_crlf_unchanged() {
        let content = "line1\r\nline2\r\nline3\r\n";
        let (result, truncated, _) = chunk_diff(content);
        assert!(!truncated);
        assert_eq!(result.as_ref(), content);
    }

    // The untruncated path borrows the input instead of allocating.
    #[test]
    fn test_chunk_diff_no_allocation_when_small() {
        let content = "line1\nline2\nline3";
        let (result, truncated, _) = chunk_diff(content);
        assert!(!truncated);
        assert!(matches!(result, Cow::Borrowed(_)));
    }

    // Default budget is 400 + 400 lines, so 200 lines pass through untouched.
    #[test]
    fn test_chunk_diff_default_does_not_truncate_200_lines() {
        let lines: Vec<String> = (0..200).map(|i| format!("line {}", i)).collect();
        let content = lines.join("\n");
        let (result, truncated, removed) = chunk_diff(&content);
        assert!(
            !truncated,
            "200-line diff should not be truncated at new 800-line default"
        );
        assert_eq!(removed, 0);
        assert!(matches!(result, Cow::Borrowed(_)));
    }

    // 1000 lines against the 800-line default budget drops 200 lines.
    #[test]
    fn test_chunk_diff_default_truncates_at_1000_lines() {
        let lines: Vec<String> = (0..1000).map(|i| format!("line {}", i)).collect();
        let content = lines.join("\n");
        let (result, truncated, removed) = chunk_diff(&content);
        assert!(
            truncated,
            "1000-line diff should be truncated at 800-line default"
        );
        assert_eq!(removed, 200);
        assert!(result.contains("200 lines omitted"));
    }

    // Exactly 800 lines sits on the default threshold and is not truncated.
    #[test]
    fn test_chunk_diff_default_exactly_at_threshold() {
        let lines: Vec<String> = (0..800).map(|i| format!("line {}", i)).collect();
        let content = lines.join("\n");
        let (result, truncated, _) = chunk_diff(&content);
        assert!(
            !truncated,
            "800-line diff at threshold should not be truncated"
        );
        assert!(matches!(result, Cow::Borrowed(_)));
    }

    // Custom 20/20 budget on 100 lines drops the middle 60.
    #[test]
    fn test_chunk_diff_with_params_custom_thresholds() {
        let lines: Vec<String> = (0..100).map(|i| format!("line {}", i)).collect();
        let content = lines.join("\n");
        let (result, truncated, removed) = chunk_diff_with_params(&content, 20, 20);
        assert!(truncated);
        assert_eq!(removed, 60);
        assert!(result.contains("line 0"));
        assert!(result.contains("line 19"));
        assert!(result.contains("line 80"));
        assert!(result.contains("line 99"));
        assert!(!result.contains("line 50"));
    }

    // ---- fetch_file_diff ----

    // A valid diff file is read back verbatim.
    #[test]
    fn test_fetch_file_diff_valid() {
        let dir = tempfile::tempdir().unwrap();
        let diff_path = dir.path().join("test.diff");
        let diff_content =
            "diff --git a/f.rs b/f.rs\n--- a/f.rs\n+++ b/f.rs\n@@ -1 +1,2 @@\n+line1\n line0";
        std::fs::write(&diff_path, diff_content).unwrap();
        let result = fetch_file_diff(diff_path.to_str().unwrap()).unwrap();
        assert_eq!(result.content, diff_content);
        assert!(result.size_bytes > 0);
        assert!(result.line_count > 0);
    }

    // An empty file is reported as EmptyDiff.
    #[test]
    fn test_fetch_file_diff_empty() {
        let dir = tempfile::tempdir().unwrap();
        let diff_path = dir.path().join("empty.diff");
        std::fs::write(&diff_path, "").unwrap();
        let result = fetch_file_diff(diff_path.to_str().unwrap());
        assert!(matches!(result, Err(RsGuardError::EmptyDiff)));
    }

    // A file without diff markers is reported as InvalidDiffContent.
    #[test]
    fn test_fetch_file_diff_invalid_content() {
        let dir = tempfile::tempdir().unwrap();
        let diff_path = dir.path().join("invalid.diff");
        std::fs::write(&diff_path, "not a diff").unwrap();
        let result = fetch_file_diff(diff_path.to_str().unwrap());
        assert!(matches!(result, Err(RsGuardError::InvalidDiffContent)));
    }

    // A file far beyond both limits is reported as DiffTooLarge.
    #[test]
    fn test_fetch_file_diff_too_large() {
        let dir = tempfile::tempdir().unwrap();
        let diff_path = dir.path().join("large.diff");
        let diff_header = "diff --git a/f.rs b/f.rs\n--- a/f.rs\n+++ b/f.rs\n@@ -1 +1,2 @@\n";
        let large_content = format!("{}{}", diff_header, "+line\n".repeat(200 * 1024));
        std::fs::write(&diff_path, &large_content).unwrap();
        let result = fetch_file_diff(diff_path.to_str().unwrap());
        assert!(matches!(result, Err(RsGuardError::DiffTooLarge { .. })));
    }

    // A missing file maps to a Config error.
    #[test]
    fn test_fetch_file_diff_not_found() {
        let result = fetch_file_diff("/nonexistent/path.diff");
        assert!(matches!(result, Err(RsGuardError::Config(_))));
    }

    // ---- fetch_local_diff / build_local_diff_result ----

    // Runs `git diff --cached` from a fresh temporary directory and expects an
    // error. Serialised because it changes the process-wide working directory.
    #[test]
    #[serial_test::serial]
    fn test_fetch_local_diff_requires_git_repo() {
        let dir = tempfile::tempdir().unwrap();
        let original_dir = std::env::current_dir().unwrap();
        std::env::set_current_dir(dir.path()).unwrap();
        let result = fetch_local_diff();
        // Restore the working directory BEFORE asserting: a failed assertion
        // must not leave the process inside a directory that is about to be
        // deleted, which would break unrelated tests.
        std::env::set_current_dir(&original_dir).expect("failed to restore working directory");
        assert!(result.is_err(), "expected error, got Ok");
    }

    // Text without diff markers is rejected.
    #[test]
    fn test_build_local_diff_result_rejects_invalid_content() {
        let result = build_local_diff_result("this is not a diff at all".to_string());
        assert!(
            matches!(result, Err(RsGuardError::InvalidDiffContent)),
            "expected InvalidDiffContent, got {:?}",
            result
        );
    }

    // JSON is rejected.
    #[test]
    fn test_build_local_diff_result_rejects_json_content() {
        let result = build_local_diff_result(r#"{"error": "something went wrong"}"#.to_string());
        assert!(
            matches!(result, Err(RsGuardError::InvalidDiffContent)),
            "expected InvalidDiffContent, got {:?}",
            result
        );
    }

    // Empty input is reported as EmptyDiff.
    #[test]
    fn test_build_local_diff_result_rejects_empty() {
        let result = build_local_diff_result(String::new());
        assert!(matches!(result, Err(RsGuardError::EmptyDiff)));
    }

    // A valid diff is wrapped without modification.
    #[test]
    fn test_build_local_diff_result_accepts_valid_diff() {
        let content = "diff --git a/src/main.rs b/src/main.rs\n--- a/src/main.rs\n+++ b/src/main.rs\n@@ -1 +1,2 @@\n+new line\n old line".to_string();
        let result = build_local_diff_result(content.clone());
        assert!(result.is_ok(), "expected Ok, got {:?}", result);
        let diff = result.unwrap();
        assert_eq!(diff.content, content);
        assert!(diff.size_bytes > 0);
        assert!(diff.line_count > 0);
    }

    // Input far beyond both limits is reported as DiffTooLarge.
    #[test]
    fn test_build_local_diff_result_rejects_too_large() {
        let header = "diff --git a/f.rs b/f.rs\n--- a/f.rs\n+++ b/f.rs\n@@ -1 +1,2 @@\n";
        let huge = format!("{}{}", header, "+line\n".repeat(200 * 1024));
        let result = build_local_diff_result(huge);
        assert!(matches!(result, Err(RsGuardError::DiffTooLarge { .. })));
    }

    // ---- chunk_diff_with_params boundary cases ----

    // One line over the 50/50 budget drops exactly one line (line 50).
    #[test]
    fn test_chunk_diff_101_lines_truncates() {
        let lines: Vec<String> = (0..101).map(|i| format!("line {}", i)).collect();
        let content = lines.join("\n");
        let (result, truncated, removed) = chunk_diff_with_params(&content, 50, 50);
        assert!(truncated, "101 lines should truncate with 50/50 params");
        assert_eq!(removed, 1);
        assert!(result.contains("1 lines omitted"));
        assert!(result.contains("line 0"));
        assert!(result.contains("line 49"));
        assert!(result.contains("line 51"));
        assert!(result.contains("line 100"));
    }

    // Exactly on the 50/50 budget: nothing dropped, no placeholder.
    #[test]
    fn test_chunk_diff_100_lines_no_truncate() {
        let lines: Vec<String> = (0..100).map(|i| format!("line {}", i)).collect();
        let content = lines.join("\n");
        let (result, truncated, removed) = chunk_diff_with_params(&content, 50, 50);
        assert!(
            !truncated,
            "100 lines should not truncate with 50/50 params"
        );
        assert_eq!(removed, 0);
        assert!(!result.contains("lines omitted"));
        assert_eq!(result.as_ref(), content);
    }

    // Bytes that are not valid UTF-8 are replaced by the lossy decoding; the
    // resulting text still carries diff markers and must be accepted.
    #[test]
    #[serial_test::serial]
    fn test_build_local_diff_result_handles_non_utf8_lossy() {
        let mut content = "diff --git a/binary.bin b/binary.bin\n--- a/binary.bin\n+++ b/binary.bin\n@@ -1 +1,2 @@\n".as_bytes().to_vec();
        content.extend_from_slice(&[0xFF, 0xFE, 0xFD]);
        content.extend_from_slice(b"+some content\n");
        let lossy_string = String::from_utf8_lossy(&content).to_string();
        let result = build_local_diff_result(lossy_string);
        assert!(
            result.is_ok(),
            "non-UTF8 diff with valid markers should be accepted"
        );
    }
}