use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};
use crate::classify::tiers::bedrock::BedrockClassifier;
use crate::classify::tiers::ClassificationResult;
use crate::core::models::ClassificationMethod;
const DEFAULT_ENDPOINT: &str = "https://api.openai.com/v1/chat/completions";
const OPENROUTER_ENDPOINT: &str = "https://openrouter.ai/api/v1/chat/completions";
const OPENROUTER_REFERER: &str = "https://github.com/bobmatnyc/trusty-git-analytics";
const OPENROUTER_TITLE: &str = "trusty-git-analytics";
const SYSTEM_PROMPT: &str = "You are a git commit classifier. Respond with ONLY a JSON \
object: {\"category\": \"feature|bugfix|chore|documentation|refactor|test|ci|performance|style|build|revert|merge|breaking|uncategorized\", \
\"subcategory\": \"optional string or null\", \"confidence\": 0.0-1.0}. No prose, no markdown.";
pub struct LlmClassifier {
client: Client,
model: String,
api_key: Option<String>,
endpoint: String,
extra_headers: HeaderMap,
bedrock: Option<BedrockClassifier>,
}
impl LlmClassifier {
pub fn new(model: &str, api_key: Option<String>) -> Self {
Self {
client: Client::new(),
model: model.to_string(),
api_key,
endpoint: DEFAULT_ENDPOINT.to_string(),
extra_headers: HeaderMap::new(),
bedrock: None,
}
}
pub fn from_provider(
provider: &str,
model: &str,
openrouter_api_key: Option<String>,
) -> Result<Self, String> {
let normalized = provider.trim().to_ascii_lowercase();
match normalized.as_str() {
"openrouter" => Ok(Self::build_openrouter(model, openrouter_api_key)),
"openai" => Ok(Self::new(model, std::env::var("OPENAI_API_KEY").ok())),
"bedrock" => {
info!(model, "LLM provider: bedrock (requested via sync path)");
#[cfg(feature = "bedrock")]
{
Err("bedrock provider requires the async constructor; use \
LlmClassifier::from_provider_async"
.to_string())
}
#[cfg(not(feature = "bedrock"))]
{
let _ = model;
Err(
"bedrock feature not compiled in — rebuild with --features bedrock"
.to_string(),
)
}
}
"auto" | "" => {
let or_key =
openrouter_api_key.or_else(|| std::env::var("OPENROUTER_API_KEY").ok());
if or_key.is_some() {
info!("LLM provider auto-selected: openrouter");
Ok(Self::build_openrouter(model, or_key))
} else {
info!("LLM provider auto-selected: openai");
Ok(Self::new(model, std::env::var("OPENAI_API_KEY").ok()))
}
}
other => {
warn!(
provider = %other,
"unknown LLM provider; falling back to OpenAI endpoint"
);
Ok(Self::new(model, std::env::var("OPENAI_API_KEY").ok()))
}
}
}
pub async fn from_provider_async(
provider: &str,
model: &str,
openrouter_api_key: Option<String>,
) -> Result<Self, String> {
if provider.trim().eq_ignore_ascii_case("bedrock") {
info!(model, "LLM provider: bedrock (async init)");
let bedrock = BedrockClassifier::new(model).await?;
return Ok(Self {
client: Client::new(),
model: model.to_string(),
api_key: None,
endpoint: String::new(),
extra_headers: HeaderMap::new(),
bedrock: Some(bedrock),
});
}
Self::from_provider(provider, model, openrouter_api_key)
}
fn build_openrouter(model: &str, api_key: Option<String>) -> Self {
let key = api_key.or_else(|| std::env::var("OPENROUTER_API_KEY").ok());
let mut headers = HeaderMap::new();
headers.insert("HTTP-Referer", HeaderValue::from_static(OPENROUTER_REFERER));
headers.insert("X-Title", HeaderValue::from_static(OPENROUTER_TITLE));
Self {
client: Client::new(),
model: model.to_string(),
api_key: key,
endpoint: OPENROUTER_ENDPOINT.to_string(),
extra_headers: headers,
bedrock: None,
}
}
pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
self.endpoint = endpoint.into();
self
}
pub async fn classify(&self, message: &str) -> Option<ClassificationResult> {
if let Some(bedrock) = &self.bedrock {
return bedrock
.classify_batch_bedrock(&[message])
.await
.into_iter()
.next()
.flatten();
}
let api_key = self.api_key.as_deref()?;
let body = ChatRequest {
model: &self.model,
messages: vec![
ChatMessage {
role: "system",
content: SYSTEM_PROMPT.to_string(),
},
ChatMessage {
role: "user",
content: format!("Classify this commit message:\n\n{message}"),
},
],
temperature: 0.0,
response_format: Some(ResponseFormat {
kind: "json_object".to_string(),
}),
};
let response = match self
.client
.post(&self.endpoint)
.bearer_auth(api_key)
.headers(self.extra_headers.clone())
.json(&body)
.send()
.await
{
Ok(r) => r,
Err(e) => {
warn!(error = %e, "LLM request failed");
return None;
}
};
if !response.status().is_success() {
warn!(status = %response.status(), "LLM returned non-success status");
return None;
}
let parsed: ChatResponse = match response.json().await {
Ok(j) => j,
Err(e) => {
warn!(error = %e, "LLM response JSON decode failed");
return None;
}
};
let content = parsed.choices.first()?.message.content.clone();
debug!(content = %content, "LLM raw response");
let verdict: LlmVerdict = serde_json::from_str(&content)
.map_err(|e| warn!(error = %e, "LLM JSON parse failed"))
.ok()?;
Some(ClassificationResult {
category: verdict.category,
subcategory: verdict.subcategory,
top_level: None, confidence: verdict.confidence.clamp(0.0, 1.0),
method: ClassificationMethod::LlmFallback,
ticket_id: None,
})
}
}
#[derive(Serialize)]
struct ChatRequest<'a> {
model: &'a str,
messages: Vec<ChatMessage>,
temperature: f64,
#[serde(skip_serializing_if = "Option::is_none")]
response_format: Option<ResponseFormat>,
}
#[derive(Serialize)]
struct ChatMessage {
role: &'static str,
content: String,
}
#[derive(Serialize)]
struct ResponseFormat {
#[serde(rename = "type")]
kind: String,
}
#[derive(Deserialize)]
struct ChatResponse {
choices: Vec<ChatChoice>,
}
#[derive(Deserialize)]
struct ChatChoice {
message: ChatChoiceMessage,
}
#[derive(Deserialize)]
struct ChatChoiceMessage {
content: String,
}
#[derive(Deserialize)]
struct LlmVerdict {
category: String,
#[serde(default)]
subcategory: Option<String>,
#[serde(default = "default_confidence")]
confidence: f64,
}
fn default_confidence() -> f64 {
0.5
}