use crate::classify::tiers::ClassificationResult;
/// Commit-message classifier backed by a model hosted on Amazon Bedrock.
///
/// Construct it with `BedrockClassifier::new`. Without the `bedrock` cargo
/// feature the constructor always fails and the struct holds only the model id.
pub struct BedrockClassifier {
    /// Bedrock model id passed to `invoke_model`.
    // Only the `bedrock`-gated request path reads this field, so the
    // dead-code lint is silenced only in builds without that feature.
    #[cfg_attr(not(feature = "bedrock"), allow(dead_code))]
    pub(crate) model: String,
    /// Bedrock runtime client used to send `invoke_model` requests.
    #[cfg(feature = "bedrock")]
    client: aws_sdk_bedrockruntime::Client,
}
/// Default Bedrock model id: Claude 3 Haiku, release `20240307`, version `v1:0`.
pub const DEFAULT_BEDROCK_MODEL: &str = "anthropic.claude-3-haiku-20240307-v1:0";
impl BedrockClassifier {
    /// Builds a classifier that invokes `model` through Amazon Bedrock.
    ///
    /// Shared AWS configuration is loaded with
    /// `aws_config::defaults(BehaviorVersion::latest()).load()`.
    ///
    /// # Errors
    ///
    /// This variant never fails. The `Result` keeps the signature identical to
    /// the stub compiled when the `bedrock` feature is disabled.
    #[cfg(feature = "bedrock")]
    pub async fn new(model: &str) -> Result<Self, String> {
        let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
            .load()
            .await;
        let client = aws_sdk_bedrockruntime::Client::new(&config);
        Ok(Self {
            model: model.to_string(),
            client,
        })
    }

    /// Stub constructor compiled when the `bedrock` feature is disabled.
    ///
    /// # Errors
    ///
    /// Always returns `Err` with a message explaining how to enable the feature.
    #[cfg(not(feature = "bedrock"))]
    pub async fn new(_model: &str) -> Result<Self, String> {
        Err("bedrock feature not compiled in — rebuild with --features bedrock".to_string())
    }

    /// Classifies each message in input order, awaiting one request at a time.
    ///
    /// The output has exactly one entry per input message. An entry is `None`
    /// when that message could not be classified; the cause is logged by
    /// `classify_one`.
    #[cfg(feature = "bedrock")]
    pub async fn classify_batch_bedrock(
        &self,
        messages: &[&str],
    ) -> Vec<Option<ClassificationResult>> {
        let mut out = Vec::with_capacity(messages.len());
        for msg in messages {
            out.push(self.classify_one(msg).await);
        }
        out
    }

    /// Stub compiled when the `bedrock` feature is disabled: returns one
    /// `None` per input message without doing any work.
    #[cfg(not(feature = "bedrock"))]
    pub async fn classify_batch_bedrock(
        &self,
        messages: &[&str],
    ) -> Vec<Option<ClassificationResult>> {
        vec![None; messages.len()]
    }

    /// Sends one commit message to the model and parses its JSON verdict.
    ///
    /// Returns `None`, after logging a warning, in any of these cases:
    /// - the request body cannot be serialized;
    /// - the `invoke_model` call fails;
    /// - the response envelope cannot be decoded;
    /// - the model's reply contains no parseable verdict, or its category is empty.
    ///
    /// Confidence is clamped to `[0.0, 1.0]`.
    #[cfg(feature = "bedrock")]
    async fn classify_one(&self, message: &str) -> Option<ClassificationResult> {
        use crate::core::models::ClassificationMethod;
        use aws_sdk_bedrockruntime::primitives::Blob;
        use serde::Deserialize;
        use tracing::warn;

        // Response envelope; only the `text` of the content blocks is used.
        #[derive(Deserialize)]
        struct BedrockResponse {
            content: Vec<ContentBlock>,
        }

        #[derive(Deserialize)]
        struct ContentBlock {
            #[serde(default)]
            text: Option<String>,
        }

        // Shape of the JSON object the system prompt asks the model to emit.
        #[derive(Deserialize)]
        struct Verdict {
            category: String,
            #[serde(default)]
            subcategory: Option<String>,
            #[serde(default = "default_confidence")]
            confidence: f64,
        }

        // Confidence used when the model omits the field.
        fn default_confidence() -> f64 {
            0.5
        }

        // Narrows the model output to its outermost `{ ... }` span. Replies
        // wrapped in markdown code fences or surrounded by prose still parse
        // this way. Falls back to the trimmed text when there is no such span.
        // For a reply that is already a bare JSON object, the span equals the
        // trimmed text.
        fn json_object_span(text: &str) -> &str {
            match (text.find('{'), text.rfind('}')) {
                // Both delimiters are single-byte ASCII, so `start..=end`
                // always lies on char boundaries.
                (Some(start), Some(end)) if start < end => &text[start..=end],
                _ => text.trim(),
            }
        }

        let body = serde_json::json!({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 256,
            "temperature": 0.0,
            "system": "You are a git commit classifier. Respond with ONLY a JSON \
                object: {\"category\": \"feature|bugfix|chore|documentation|refactor|test|ci|performance|style|build|revert|merge|breaking|uncategorized\", \
                \"subcategory\": \"optional string or null\", \"confidence\": 0.0-1.0}. \
                No prose, no markdown.",
            "messages": [
                {"role": "user", "content": format!("Classify this commit message:\n\n{message}")}
            ]
        });

        let body_bytes = match serde_json::to_vec(&body) {
            Ok(b) => b,
            Err(e) => {
                warn!(error = %e, "bedrock body serialize failed");
                return None;
            }
        };

        let resp = match self
            .client
            .invoke_model()
            .model_id(&self.model)
            .content_type("application/json")
            .accept("application/json")
            .body(Blob::new(body_bytes))
            .send()
            .await
        {
            Ok(r) => r,
            Err(e) => {
                warn!(error = %e, "bedrock invoke_model failed");
                return None;
            }
        };

        let raw = resp.body.into_inner();
        let parsed: BedrockResponse = match serde_json::from_slice(&raw) {
            Ok(p) => p,
            Err(e) => {
                warn!(error = %e, "bedrock response decode failed");
                return None;
            }
        };

        // The first content block that carries text holds the verdict. A reply
        // without any text falls through to the parse failure below.
        let text = parsed
            .content
            .into_iter()
            .find_map(|b| b.text)
            .unwrap_or_default();

        let verdict: Verdict = match serde_json::from_str(json_object_span(&text)) {
            Ok(v) => v,
            Err(e) => {
                warn!(error = %e, raw = %text, "bedrock verdict parse failed");
                return None;
            }
        };

        // The prompt lists lowercase category names. Normalize so stray
        // whitespace or capitalization in the reply does not produce a
        // distinct category string.
        let category = verdict.category.trim().to_ascii_lowercase();
        if category.is_empty() {
            warn!(raw = %text, "bedrock verdict has empty category");
            return None;
        }

        // A blank subcategory carries no information; treat it as omitted.
        let subcategory = verdict
            .subcategory
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty());

        Some(ClassificationResult {
            category,
            subcategory,
            top_level: None,
            confidence: verdict.confidence.clamp(0.0, 1.0),
            method: ClassificationMethod::LlmFallback,
            ticket_id: None,
        })
    }
}
#[cfg(all(test, not(feature = "bedrock")))]
mod tests {
    use super::*;

    /// Without the `bedrock` feature, construction must fail with a message
    /// telling the user how to enable it.
    #[tokio::test]
    async fn bedrock_stub_returns_error_without_feature() {
        // Use the crate's default id instead of repeating the literal.
        let err = match BedrockClassifier::new(DEFAULT_BEDROCK_MODEL).await {
            Err(e) => e,
            // `BedrockClassifier` does not derive `Debug`, so `expect_err`
            // cannot be used here.
            Ok(_) => panic!("must error without feature"),
        };
        assert!(err.contains("bedrock feature not compiled in"));
    }

    /// The stub batch classifier must return exactly one `None` per input.
    #[tokio::test]
    async fn bedrock_stub_batch_yields_one_none_per_message() {
        // Without the feature the struct holds only `model`, so it can be
        // built directly even though `new` refuses to construct it.
        let classifier = BedrockClassifier {
            model: DEFAULT_BEDROCK_MODEL.to_string(),
        };

        let out = classifier
            .classify_batch_bedrock(&["feat: add parser", "fix: off-by-one", ""])
            .await;
        assert_eq!(out.len(), 3);
        assert!(out.iter().all(Option::is_none));

        assert!(classifier.classify_batch_bedrock(&[]).await.is_empty());
    }
}