use std::sync::Arc;
use serde::Deserialize;
use tracing::{debug, warn};
use crate::{
llm::{ChatMessage, LlmProvider, LlmRequest},
pipeline::diff_analyzer::models::{FilteredHunk, HunkDropReason},
};
/// Default number of hunks sent to the classifier model in one request.
pub const DEFAULT_BATCH_SIZE: usize = 10;

/// Maximum length of a rendered hunk placed in the prompt, measured in
/// bytes (`str::len`) despite the name; longer hunks are truncated.
pub const MAX_HUNK_CHARS: usize = 4_000;

/// A `mechanical` verdict drops its hunk only when the model's confidence is
/// strictly greater than this value.
pub const DROP_CONFIDENCE_THRESHOLD: f64 = 0.70;

/// Default model identifier handed to the LLM provider for Stage C
/// (the ignored live test pairs it with `BedrockProvider`).
pub const DEFAULT_CLASSIFIER_MODEL: &str = "us.anthropic.claude-haiku-4-5-20251001-v1:0";
/// System prompt for Stage C hunk classification.
///
/// Asks the model to label every hunk `substantive`, `mechanical`, or
/// `uncertain` and to answer with a bare JSON array holding one object per
/// hunk, in order, with `hunk_id`, `classification`, `confidence`, and
/// `reason` keys.
// Each trailing `\` is a string-continuation escape: it removes the line
// break and the next line's leading whitespace, so only the explicit `\n`
// escapes become newlines in the prompt text.
// NOTE(review): "N" below is literal text; the actual count is stated in the
// user message built by `build_batch_prompt` ("Classify the following {} hunks").
const SYSTEM_PROMPT: &str = "\
You are a code review assistant. Below are N code diff hunks from a pull request.\n\
For each hunk, classify it as exactly one of:\n\
- \"substantive\": the hunk contains a meaningful change — logic change, schema\n\
change, API surface change, security-relevant change, control-flow change,\n\
new function/method body, bug fix, or any change that a reviewer must evaluate.\n\
- \"mechanical\": the hunk is boilerplate or noise — pure formatting, whitespace\n\
changes, import reordering, license/copyright header, generated code, fixture\n\
data, JavaDoc-only additions with no logic, getter/setter stubs, pure rename.\n\
- \"uncertain\": cannot determine substantiveness without more context.\n\
\n\
Return a JSON array with exactly one entry per hunk in order, each with:\n\
{\"hunk_id\": \"<id>\", \"classification\": \"<substantive|mechanical|uncertain>\",\n\
\"confidence\": <0.0-1.0>, \"reason\": \"<one sentence>\"}\n\
\n\
Do not include any text outside the JSON array. Do not wrap in markdown fences.";
/// Stage C verdict for a single hunk, produced either from the model's reply
/// or by the fail-open fallback when the model call/parse fails.
#[derive(Debug, Clone, PartialEq)]
pub struct HunkClassification {
    /// Index identifying the hunk within the classified input; assigned as
    /// `base_idx + position-in-batch` by the constructors in this module.
    pub hunk_index: usize,
    /// Label requested by the prompt: `"substantive"`, `"mechanical"`, or
    /// `"uncertain"`. The value is model-supplied.
    pub classification: String,
    /// Model confidence, expected in `[0.0, 1.0]`; values parsed from the
    /// model's reply are clamped to that range.
    pub confidence: f64,
    /// One-sentence justification from the model (may be empty).
    pub reason: String,
}
impl HunkClassification {
    /// Returns `true` when this hunk is confidently mechanical and may be
    /// dropped, judged against the crate default [`DROP_CONFIDENCE_THRESHOLD`].
    ///
    /// Only a `"mechanical"` label can ever drop a hunk, so `"substantive"`
    /// and `"uncertain"` results (including fail-open fallbacks) are kept.
    pub fn should_drop(&self) -> bool {
        self.should_drop_with_threshold(DROP_CONFIDENCE_THRESHOLD)
    }

    /// Like [`Self::should_drop`], but compares against a caller-supplied
    /// `threshold` (for example the one a `HunkClassifier` was configured
    /// with) instead of the crate default.
    ///
    /// The comparison is strict (`confidence > threshold`); a NaN on either
    /// side compares false, so the hunk is kept.
    pub fn should_drop_with_threshold(&self, threshold: f64) -> bool {
        self.classification == "mechanical" && self.confidence > threshold
    }

    /// Drop reason to record for a hunk removed by this stage; always
    /// [`HunkDropReason::MechanicalHaiku`].
    pub fn drop_reason(&self) -> HunkDropReason {
        HunkDropReason::MechanicalHaiku
    }
}
#[derive(Debug, Deserialize)]
struct ClassificationEntry {
#[allow(dead_code)] hunk_id: String,
classification: String,
confidence: f64,
#[serde(default)]
reason: String,
}
/// Stage C classifier: sends batches of diff hunks to an LLM and asks it to
/// label each one substantive, mechanical, or uncertain.
pub struct HunkClassifier {
    /// Backend that receives one completion request per batch.
    provider: Arc<dyn LlmProvider>,
    /// Model identifier placed in every request.
    model: String,
    /// Maximum number of hunks sent in a single request.
    batch_size: usize,
    // Confidence threshold supplied at construction. Drop decisions in this
    // module (`HunkClassification::should_drop`) use the
    // `DROP_CONFIDENCE_THRESHOLD` constant instead. The allow silences the
    // unused-field lint when nothing else reads this value.
    #[allow(dead_code)]
    drop_threshold: f64,
}
impl HunkClassifier {
    /// Creates a Stage C classifier.
    ///
    /// * `provider` – LLM backend that receives one request per batch.
    /// * `model` – model identifier sent with every request.
    /// * `batch_size` – maximum number of hunks per request; [`Self::classify`]
    ///   treats `0` as `1` instead of panicking.
    /// * `drop_threshold` – confidence threshold, reported by
    ///   [`Self::drop_threshold`].
    pub fn new(
        provider: Arc<dyn LlmProvider>,
        model: impl Into<String>,
        batch_size: usize,
        drop_threshold: f64,
    ) -> Self {
        Self {
            provider,
            model: model.into(),
            batch_size,
            drop_threshold,
        }
    }

    /// Returns the confidence threshold this classifier was constructed with.
    ///
    /// Note that [`HunkClassification::should_drop`] compares against the
    /// [`DROP_CONFIDENCE_THRESHOLD`] constant, not this value.
    pub fn drop_threshold(&self) -> f64 {
        self.drop_threshold
    }

    /// Classifies every hunk and returns exactly one result per input hunk,
    /// in input order.
    ///
    /// Hunks are sent in batches of at most `batch_size`. Each result's
    /// `hunk_index` is the hunk's position in `hunks`. If a batch's LLM call
    /// or reply parsing fails, that batch yields `"uncertain"` results
    /// (fail-open), so errors never cause hunks to be dropped.
    pub async fn classify(&self, hunks: &[FilteredHunk]) -> Vec<HunkClassification> {
        // `slice::chunks` panics on a chunk size of zero.
        let batch_size = self.batch_size.max(1);
        let mut results = Vec::with_capacity(hunks.len());
        for (chunk_no, batch) in hunks.chunks(batch_size).enumerate() {
            // Offset of this batch's first hunk in `hunks`. Using the bare
            // chunk number here would make later batches' indices collide
            // with earlier ones and map verdicts onto the wrong hunks.
            let base_idx = chunk_no * batch_size;
            results.extend(self.classify_batch_slice(batch, base_idx).await);
        }
        results
    }

    /// Classifies one batch with a single LLM request.
    ///
    /// `base_idx` is the position of `batch[0]` within the full hunk list. It
    /// is used for both the `--- HUNK <id> ---` labels in the prompt and the
    /// returned `hunk_index` values. This never fails: an LLM error or an
    /// unparseable reply is logged and replaced by `"uncertain"` results for
    /// the whole batch (fail-open).
    async fn classify_batch_slice(
        &self,
        batch: &[FilteredHunk],
        base_idx: usize,
    ) -> Vec<HunkClassification> {
        let user_content = self.build_batch_prompt(batch, base_idx);
        let req = LlmRequest {
            model: self.model.clone(),
            system: SYSTEM_PROMPT.to_string(),
            messages: vec![ChatMessage {
                role: "user".to_string(),
                content: user_content,
            }],
            // Deterministic output for a classification task.
            temperature: 0.0,
            max_tokens: 1024,
            response_schema: None,
        };
        match self.provider.complete(req).await {
            Ok(resp) => match parse_classification_array(&resp.text, batch.len(), base_idx) {
                Ok(classifications) => {
                    debug!(
                        batch_start = base_idx,
                        count = classifications.len(),
                        "Stage C classified batch"
                    );
                    classifications
                }
                Err(e) => {
                    warn!(
                        error = %e,
                        batch_start = base_idx,
                        batch_size = batch.len(),
                        "Stage C parse failed — keeping all hunks (fail-open)"
                    );
                    uncertain_batch(batch, base_idx)
                }
            },
            Err(e) => {
                warn!(
                    error = %e,
                    batch_start = base_idx,
                    batch_size = batch.len(),
                    "Stage C LLM error — keeping all hunks (fail-open, spec REV-208)"
                );
                uncertain_batch(batch, base_idx)
            }
        }
    }

    /// Builds the user message for one batch. Each hunk is labeled
    /// `--- HUNK <base_idx + i> ---` and its rendered text is truncated to at
    /// most [`MAX_HUNK_CHARS`] bytes.
    fn build_batch_prompt(&self, batch: &[FilteredHunk], base_idx: usize) -> String {
        use std::fmt::Write as _;

        let mut out = format!("Classify the following {} hunks:\n\n", batch.len());
        for (i, hunk) in batch.iter().enumerate() {
            let hunk_id = base_idx + i;
            let body = hunk.render();
            let truncated = Self::truncate_to_char_boundary(&body, MAX_HUNK_CHARS);
            // Writing into a `String` cannot fail, so the result is ignored.
            let _ = writeln!(out, "--- HUNK {hunk_id} ---\n{truncated}\n");
        }
        out
    }

    /// Returns the longest prefix of `s` that is at most `max_bytes` long and
    /// ends on a UTF-8 character boundary. A plain `&s[..max_bytes]` panics
    /// when the cut falls inside a multi-byte character.
    fn truncate_to_char_boundary(s: &str, max_bytes: usize) -> &str {
        if s.len() <= max_bytes {
            return s;
        }
        let mut end = max_bytes;
        // Terminates because index 0 is always a char boundary.
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        &s[..end]
    }
}
pub fn parse_classification_array(
text: &str,
expected_count: usize,
base_idx: usize,
) -> Result<Vec<HunkClassification>, String> {
let cleaned = text
.trim()
.trim_start_matches("```json")
.trim_start_matches("```")
.trim_end_matches("```")
.trim();
let entries: Vec<ClassificationEntry> =
serde_json::from_str(cleaned).map_err(|e| format!("JSON parse error: {e}"))?;
if entries.len() != expected_count {
return Err(format!(
"expected {expected_count} entries, got {}",
entries.len()
));
}
let mut results = Vec::with_capacity(entries.len());
for (i, entry) in entries.into_iter().enumerate() {
results.push(HunkClassification {
hunk_index: base_idx + i,
classification: entry.classification,
confidence: entry.confidence.clamp(0.0, 1.0),
reason: entry.reason,
});
}
Ok(results)
}
/// Fail-open fallback: builds one `"uncertain"`, zero-confidence result for
/// each hunk in `batch`, indexed consecutively from `base_idx`.
///
/// Only a `"mechanical"` label can drop a hunk, so every hunk covered by
/// these results is kept.
fn uncertain_batch(batch: &[FilteredHunk], base_idx: usize) -> Vec<HunkClassification> {
    (0..batch.len())
        .map(|offset| HunkClassification {
            hunk_index: base_idx + offset,
            classification: "uncertain".to_string(),
            confidence: 0.0,
            reason: "fail-open: LLM error or parse failure".to_string(),
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a classification at index 0 for `should_drop` checks.
    fn verdict(classification: &str, confidence: f64, reason: &str) -> HunkClassification {
        HunkClassification {
            hunk_index: 0,
            classification: classification.to_string(),
            confidence,
            reason: reason.to_string(),
        }
    }

    /// Builds a test `FilteredHunk` with `substantive_confidence` 1.0.
    fn kept_hunk(header: &str, lines: &[&str]) -> FilteredHunk {
        FilteredHunk {
            header: header.to_string(),
            lines: lines.iter().map(|line| (*line).to_owned()).collect(),
            substantive_confidence: 1.0,
            reason_kept: "test".to_string(),
        }
    }

    #[test]
    fn hunk_classification_mechanical_high_confidence_droppable() {
        assert!(verdict("mechanical", 0.85, "pure import reorder").should_drop());
    }

    #[test]
    fn hunk_classification_mechanical_low_confidence_kept() {
        assert!(!verdict("mechanical", 0.5, "unsure").should_drop());
    }

    #[test]
    fn hunk_classification_substantive_not_dropped() {
        assert!(!verdict("substantive", 0.99, "logic change").should_drop());
    }

    #[test]
    fn hunk_classification_uncertain_not_dropped() {
        assert!(!verdict("uncertain", 0.9, "need context").should_drop());
    }

    #[test]
    fn parse_classification_array_valid() {
        let json = r#"[
{"hunk_id": "0", "classification": "substantive", "confidence": 0.9, "reason": "logic change"},
{"hunk_id": "1", "classification": "mechanical", "confidence": 0.8, "reason": "import reorder"}
]"#;
        let parsed = parse_classification_array(json, 2, 0).unwrap();
        assert_eq!(parsed.len(), 2);
        let (first, second) = (&parsed[0], &parsed[1]);
        assert_eq!(first.classification, "substantive");
        assert!(!first.should_drop());
        assert_eq!(second.classification, "mechanical");
        assert!(second.should_drop());
    }

    #[test]
    fn parse_classification_array_strips_markdown_fence() {
        let fenced = "```json\n[{\"hunk_id\": \"0\", \"classification\": \"uncertain\", \"confidence\": 0.5, \"reason\": \"x\"}]\n```";
        let parsed = parse_classification_array(fenced, 1, 0).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].classification, "uncertain");
    }

    #[test]
    fn parse_classification_array_wrong_count_returns_error() {
        let single = r#"[{"hunk_id": "0", "classification": "substantive", "confidence": 0.9, "reason": "x"}]"#;
        let message = parse_classification_array(single, 2, 0).unwrap_err();
        assert!(message.contains("expected 2"), "error: {message}");
    }

    #[test]
    fn parse_classification_array_invalid_json_returns_error() {
        let message = parse_classification_array("not json", 1, 0).unwrap_err();
        assert!(message.contains("JSON parse error"));
    }

    #[test]
    fn uncertain_batch_has_correct_length() {
        let hunks: Vec<FilteredHunk> = (0..3)
            .map(|i| {
                kept_hunk(
                    &format!("@@ -{i},1 +{i},1 @@"),
                    &[format!("+line {i}").as_str()],
                )
            })
            .collect();
        let fallback = uncertain_batch(&hunks, 5);
        assert_eq!(fallback.len(), 3);
        assert_eq!(fallback[0].hunk_index, 5);
        assert_eq!(fallback[2].hunk_index, 7);
        for entry in &fallback {
            assert!(!entry.should_drop(), "uncertain results must not be dropped");
        }
    }

    /// Hits the real Bedrock endpoint; run explicitly with `--ignored`.
    #[tokio::test]
    #[ignore]
    async fn stage_c_live_bedrock() {
        use crate::llm::BedrockProvider;

        let provider = BedrockProvider::new(DEFAULT_CLASSIFIER_MODEL.to_string(), None)
            .await
            .expect("BedrockProvider must build");
        let classifier = HunkClassifier::new(
            Arc::new(provider),
            DEFAULT_CLASSIFIER_MODEL,
            10,
            DROP_CONFIDENCE_THRESHOLD,
        );
        let hunks = vec![
            kept_hunk(
                "@@ -1,1 +1,1 @@",
                &["-use std::io;", "+use std::io::{Read, Write};"],
            ),
            kept_hunk(
                "@@ -10,3 +10,5 @@",
                &[
                    "-fn process(data: &[u8]) -> Result<(), Error> {",
                    "+fn process(data: &[u8], config: &Config) -> Result<(), Error> {",
                    "+ let timeout = config.timeout();",
                    "+ validate_input(data, timeout)?;",
                ],
            ),
        ];
        let verdicts = classifier.classify(&hunks).await;
        assert_eq!(verdicts.len(), 2);
        assert!(
            !verdicts[1].should_drop(),
            "substantive hunk must not be dropped: {:?}",
            verdicts[1]
        );
    }
}