use crate::classify::tiers::ClassificationResult;
#[allow(unused_imports)]
use crate::classify::tiers::llm::{LlmVerdict, SYSTEM_PROMPT};
/// LLM-backed classifier that sends commit messages to a model hosted on AWS
/// Bedrock.
///
/// The runtime client only exists when the crate is built with the `bedrock`
/// feature; without it the constructors in this file return an error, so no
/// value of this type is created by them.
pub struct BedrockClassifier {
/// Bedrock model identifier, passed as `model_id` on every invocation.
/// NOTE(review): `dead_code` is presumably allowed because nothing in this
/// file reads the field when the `bedrock` feature is off — confirm.
#[allow(dead_code)] pub(crate) model: String,
/// Bedrock runtime client used to issue `invoke_model` requests.
#[cfg(feature = "bedrock")]
client: aws_sdk_bedrockruntime::Client,
}
/// Bedrock model identifier exposed as the default choice for
/// [`BedrockClassifier`]. The constant is not referenced by name elsewhere in
/// this file; callers are expected to pass it (or another model ID) to the
/// constructors themselves.
pub const DEFAULT_BEDROCK_MODEL: &str = "anthropic.claude-3-haiku-20240307-v1:0";
impl BedrockClassifier {
    /// Creates a classifier for `model` without overriding the AWS region.
    ///
    /// This is exactly [`BedrockClassifier::with_region`] with `region = None`,
    /// in both the feature-enabled and the stub build.
    ///
    /// # Errors
    ///
    /// Returns an explanatory message when the crate was built without the
    /// `bedrock` feature. With the feature enabled the current implementation
    /// always returns `Ok`.
    pub async fn new(model: &str) -> Result<Self, String> {
        Self::with_region(model, None).await
    }

    /// Creates a classifier for `model`, optionally pinning the AWS region.
    ///
    /// Configuration is loaded through `aws_config::defaults` with the latest
    /// behavior version; when `region` is `Some`, it is set on the loader
    /// before loading.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`; the `Result` keeps the
    /// signature identical to the stub compiled without the `bedrock` feature.
    #[cfg(feature = "bedrock")]
    pub async fn with_region(model: &str, region: Option<&str>) -> Result<Self, String> {
        let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
        if let Some(name) = region {
            loader = loader.region(aws_config::Region::new(name.to_string()));
        }
        let config = loader.load().await;
        Ok(Self {
            model: model.to_string(),
            client: aws_sdk_bedrockruntime::Client::new(&config),
        })
    }

    /// Stub used when the `bedrock` feature is not compiled in.
    ///
    /// # Errors
    ///
    /// Always returns an error telling the user to rebuild with the feature.
    #[cfg(not(feature = "bedrock"))]
    pub async fn with_region(_model: &str, _region: Option<&str>) -> Result<Self, String> {
        Err("bedrock feature not compiled in — rebuild with --features bedrock".to_string())
    }

    /// Classifies every message, returning one entry per input in input order.
    ///
    /// Messages are sent one at a time: each request is awaited before the
    /// next one starts. An entry is `None` when that message could not be
    /// classified (request failure or an unparseable reply).
    #[cfg(feature = "bedrock")]
    pub async fn classify_batch_bedrock(
        &self,
        messages: &[&str],
    ) -> Vec<Option<ClassificationResult>> {
        let mut out = Vec::with_capacity(messages.len());
        for msg in messages {
            out.push(self.classify_one(msg).await);
        }
        out
    }

    /// Stub used when the `bedrock` feature is not compiled in: yields `None`
    /// for every input message.
    #[cfg(not(feature = "bedrock"))]
    pub async fn classify_batch_bedrock(
        &self,
        messages: &[&str],
    ) -> Vec<Option<ClassificationResult>> {
        vec![None; messages.len()]
    }

    /// Sends a single commit message to the configured model and converts the
    /// reply into a [`ClassificationResult`].
    ///
    /// Every failure (serializing the request, invoking the model, decoding
    /// the response envelope, missing text, unparseable verdict) is logged
    /// with `tracing::warn!` and reported as `None`; nothing is propagated.
    #[cfg(feature = "bedrock")]
    async fn classify_one(&self, message: &str) -> Option<ClassificationResult> {
        use crate::core::models::ClassificationMethod;
        use aws_sdk_bedrockruntime::primitives::Blob;
        use tracing::warn;

        /// Response envelope: only the `content` array is read.
        #[derive(serde::Deserialize)]
        struct BedrockResponse {
            content: Vec<ContentBlock>,
        }

        /// One entry of `content`; entries without a `text` field decode to `None`.
        #[derive(serde::Deserialize)]
        struct ContentBlock {
            #[serde(default)]
            text: Option<String>,
        }

        /// Returns the slice from the first `{` to the last `}` of `text`, if any.
        fn json_object_span(text: &str) -> Option<&str> {
            let start = text.find('{')?;
            let end = text.rfind('}')?;
            // Both indices point at single-byte ASCII characters, so the
            // inclusive range always falls on character boundaries.
            (start < end).then(|| &text[start..=end])
        }

        let body = serde_json::json!({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 256,
            "temperature": 0.0,
            "system": SYSTEM_PROMPT,
            "messages": [
                {"role": "user", "content": format!("Classify this commit message:\n\n{message}")}
            ]
        });
        let body_bytes = match serde_json::to_vec(&body) {
            Ok(bytes) => bytes,
            Err(e) => {
                warn!(error = %e, "bedrock body serialize failed");
                return None;
            }
        };

        let resp = match self
            .client
            .invoke_model()
            .model_id(&self.model)
            .content_type("application/json")
            .accept("application/json")
            .body(Blob::new(body_bytes))
            .send()
            .await
        {
            Ok(r) => r,
            Err(e) => {
                warn!(error = %e, "bedrock invoke_model failed");
                return None;
            }
        };

        let raw = resp.body.into_inner();
        let parsed: BedrockResponse = match serde_json::from_slice(&raw) {
            Ok(p) => p,
            Err(e) => {
                warn!(error = %e, "bedrock response decode failed");
                return None;
            }
        };

        // Use the first content block that carries text. Report its absence
        // explicitly instead of feeding an empty string to the JSON parser,
        // which would surface as a misleading "verdict parse failed" error.
        let text = match parsed.content.into_iter().find_map(|block| block.text) {
            Some(t) => t,
            None => {
                warn!("bedrock response contained no text block");
                return None;
            }
        };

        let trimmed = text.trim();
        let verdict: LlmVerdict = match serde_json::from_str(trimmed) {
            Ok(v) => v,
            Err(e) => {
                // The reply was not bare JSON. Retry on the outermost
                // `{ ... }` span so a verdict wrapped in a code fence or in
                // surrounding prose is still accepted. The span must itself
                // deserialize as a complete verdict, so anything else still
                // ends in `None`.
                let recovered = json_object_span(trimmed)
                    .filter(|span| *span != trimmed)
                    .and_then(|span| serde_json::from_str::<LlmVerdict>(span).ok());
                match recovered {
                    Some(v) => v,
                    None => {
                        warn!(error = %e, raw = %text, "bedrock verdict parse failed");
                        return None;
                    }
                }
            }
        };

        Some(ClassificationResult {
            category: verdict.category,
            subcategory: verdict.subcategory,
            top_level: None,
            // Keep model-reported values inside the ranges used here:
            // confidence in [0, 1], complexity in 1..=5.
            confidence: verdict.confidence.clamp(0.0, 1.0),
            method: ClassificationMethod::LlmFallback,
            ticket_id: None,
            complexity: verdict.complexity.map(|v| v.clamp(1, 5)),
        })
    }
}
#[cfg(all(test, not(feature = "bedrock")))]
mod tests {
    use super::*;

    /// Without the `bedrock` feature, construction must fail with the
    /// "not compiled in" hint rather than produce a classifier.
    #[tokio::test]
    async fn bedrock_stub_returns_error_without_feature() {
        let outcome = BedrockClassifier::new("anthropic.claude-3-haiku-20240307-v1:0").await;
        match outcome {
            Ok(_) => panic!("must error without feature"),
            Err(message) => assert!(message.contains("bedrock feature not compiled in")),
        }
    }

    /// The prompt shared with the other LLM tier must mention `complexity`.
    #[test]
    fn shared_system_prompt_contains_complexity_instruction() {
        let mentions_complexity = SYSTEM_PROMPT.contains("complexity");
        assert!(
            mentions_complexity,
            "shared SYSTEM_PROMPT must instruct the model to return a complexity score"
        );
    }
}