use super::CodeReference;
use car_inference::{GenerateParams, GenerateRequest, InferenceEngine};
use serde::Deserialize;
use std::sync::Arc;
/// Re-ranks `refs` by asking the inference engine how relevant each one is to `query`.
///
/// The model's score is first clamped to `[0.0, 1.0]`. For every reference the model scores,
/// the new score is `0.7 * llm + 0.3 * previous`. A non-blank `why_relevant` from the model
/// replaces the stored one. References the model does not mention keep their previous score.
/// The result is sorted by descending score. The sort is stable, so ties keep their input order.
///
/// This is best-effort. If inference fails or the reply cannot be parsed, a warning is logged
/// and `refs` is returned unchanged, in its original order.
pub async fn score_with_llm(
    engine: &Arc<InferenceEngine>,
    query: &str,
    mut refs: Vec<CodeReference>,
    model: Option<&str>,
) -> Vec<CodeReference> {
    if refs.is_empty() {
        return refs;
    }
    let prompt = build_prompt(query, &refs);
    let req = GenerateRequest {
        prompt,
        model: model.map(String::from),
        params: GenerateParams {
            temperature: 0.1,
            max_tokens: 1024,
            ..Default::default()
        },
        context: None,
        tools: None,
        images: None,
        messages: None,
        cache_control: false,
        response_format: None,
        intent: None,
    };
    let text = match engine.generate(req).await {
        Ok(t) => t,
        Err(e) => {
            tracing::warn!(error = %e, "score_with_llm: inference failed, keeping input order");
            return refs;
        }
    };
    let scores = match parse_scores(&text) {
        Some(s) => s,
        None => {
            tracing::warn!(
                raw = %truncate(&text, 200),
                "score_with_llm: could not parse scores, keeping input order"
            );
            return refs;
        }
    };

    // Models sometimes repeat an index. Only the first entry per reference is applied.
    // A duplicate would otherwise blend the same reference twice, weighting it ever more
    // heavily toward the LLM score.
    let mut applied = vec![false; refs.len()];
    for entry in scores {
        // Indices outside `refs` cannot refer to anything we sent; ignore them.
        let r = match refs.get_mut(entry.index) {
            Some(r) => r,
            None => continue,
        };
        // `get_mut` succeeded and `applied` has the same length as `refs`, so this index is in range.
        if std::mem::replace(&mut applied[entry.index], true) {
            continue;
        }
        let llm = entry.score.clamp(0.0, 1.0);
        r.score = (0.7 * llm + 0.3 * r.score).clamp(0.0, 1.0);
        // Empty or whitespace-only reasons carry no information; keep the existing one.
        if let Some(reason) = entry.why_relevant.filter(|s| !s.trim().is_empty()) {
            r.why_relevant = reason;
        }
    }

    // `sort_by` is stable: equal (or incomparable) scores keep their relative input order.
    refs.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    refs
}
/// Builds the ranking prompt that is sent to the model.
///
/// The prompt contains the fixed instructions, then the user's `query`, then each reference.
/// Each reference is written as `[index] repo=… path=…`, followed by its snippet truncated
/// to about 600 bytes.
///
/// The `[index]` labels are 0-based positions in `refs`. `score_with_llm` uses them to map
/// the model's answers back to the references.
fn build_prompt(query: &str, refs: &[CodeReference]) -> String {
    // Lets `write!` append directly into `buf` without a temporary `String` per reference.
    use std::fmt::Write as _;

    let mut buf = String::new();
    buf.push_str(
        "You are ranking code references by how relevant they are to a user's query.\n\
         For each reference, return a score in [0.0, 1.0] (1.0 = perfectly relevant) and a one-sentence reason.\n\
         Reply with ONLY a JSON array, no prose, no markdown fences.\n\
         Each element: {\"index\": <0-based>, \"score\": <float>, \"why_relevant\": \"...\"}.\n\n",
    );
    buf.push_str("Query: ");
    buf.push_str(query);
    buf.push_str("\n\nReferences:\n");
    for (i, r) in refs.iter().enumerate() {
        // Formatting into a `String` cannot fail, so the `fmt::Result` is safely ignored.
        let _ = write!(
            buf,
            "[{i}] repo={} path={}\n{}\n\n",
            r.repo,
            r.path,
            truncate(&r.snippet, 600)
        );
    }
    buf.push_str("Return the JSON array now:");
    buf
}
/// One element of the JSON array the model is asked to return (see `build_prompt`).
#[derive(Debug, Deserialize)]
struct ScoreEntry {
    /// 0-based position of the reference in the prompt's list.
    index: usize,
    /// Relevance score. The prompt asks for `[0.0, 1.0]`; `score_with_llm` clamps it regardless.
    score: f32,
    /// Optional one-sentence justification. A missing field or a JSON `null` yields `None`.
    #[serde(default)]
    why_relevant: Option<String>,
}
/// Extracts the model's score array from its raw reply.
///
/// Models do not always obey "JSON only". Replies may be wrapped in markdown fences, or
/// preceded and followed by prose that itself contains brackets. Every `[` in the text is
/// tried, in order, as the start of a JSON array of `ScoreEntry`. The first position that
/// deserializes wins, and anything after that array is ignored. A plain JSON reply is
/// handled by the first `[` it contains.
///
/// Returns `None` if no position yields a valid array.
fn parse_scores(text: &str) -> Option<Vec<ScoreEntry>> {
    text.match_indices('[')
        .find_map(|(start, _)| parse_leading_array(&text[start..]))
}

/// Deserializes a `Vec<ScoreEntry>` from the start of `s`, ignoring any trailing text.
fn parse_leading_array(s: &str) -> Option<Vec<ScoreEntry>> {
    // A stream deserializer yields the first complete value without requiring the rest of
    // the input to be valid JSON. Arrays are self-delimiting, so trailing prose is fine.
    serde_json::Deserializer::from_str(s)
        .into_iter::<Vec<ScoreEntry>>()
        .next()?
        .ok()
}
/// Shortens `s` to at most `n` bytes and appends `…` when anything was cut.
///
/// The cut backs off to the nearest UTF-8 character boundary at or below `n`. Note that `n`
/// counts bytes, not characters, and the appended ellipsis is not counted against `n`.
fn truncate(s: &str, n: usize) -> String {
    if s.len() > n {
        // Byte 0 is always a char boundary, so this search cannot come up empty.
        let cut = (0..=n).rev().find(|&i| s.is_char_boundary(i)).unwrap_or(0);
        return format!("{}…", &s[..cut]);
    }
    s.to_owned()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_plain_array() {
        let v = parse_scores(r#"[{"index":0,"score":0.8,"why_relevant":"hits"}]"#).unwrap();
        assert_eq!(v.len(), 1);
        assert_eq!(v[0].index, 0);
    }

    #[test]
    fn parse_json_fence() {
        let v = parse_scores("```json\n[{\"index\":1,\"score\":0.3}]\n```").unwrap();
        assert_eq!(v[0].index, 1);
        assert!((v[0].score - 0.3).abs() < 1e-6);
    }

    #[test]
    fn parse_bare_fence() {
        // Fence with no language tag.
        let v = parse_scores("```\n[{\"index\":2,\"score\":1.0}]\n```").unwrap();
        assert_eq!(v[0].index, 2);
    }

    #[test]
    fn parse_with_preamble() {
        let v = parse_scores("Sure! [{\"index\":0,\"score\":0.5}]. done").unwrap();
        assert_eq!(v.len(), 1);
    }

    #[test]
    fn parse_missing_reason_is_none() {
        let v = parse_scores(r#"[{"index":0,"score":0.1}]"#).unwrap();
        assert!(v[0].why_relevant.is_none());
    }

    #[test]
    fn parse_garbage_returns_none() {
        assert!(parse_scores("not json at all").is_none());
    }

    #[test]
    fn truncate_keeps_short_strings() {
        assert_eq!(truncate("abc", 3), "abc");
        assert_eq!(truncate("", 0), "");
    }

    #[test]
    fn truncate_backs_off_to_char_boundary() {
        // 'é' occupies bytes 1..3, so a cut requested at byte 2 must fall back to byte 1.
        assert_eq!(truncate("héllo", 2), "h…");
        assert_eq!(truncate("abcdef", 3), "abc…");
    }
}