use tokio::sync::OnceCell;
/// Coarse intent category assigned to a user query by [`classify_intent`].
///
/// The enum is fieldless, so it derives `Eq` alongside `PartialEq`; equality
/// is total and the type can be used wherever a full equivalence is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IntentClass {
    /// The query is closest to the advisory examples (hypotheticals, upgrade
    /// questions, acknowledgements of earlier findings).
    Advisory,
    /// The query is closest to the diagnostic examples (requests to inspect
    /// or report on the current state of the machine).
    Diagnostic,
    /// No confident decision: embeddings were unavailable, or neither score
    /// cleared its threshold with a sufficient margin over the other.
    Ambiguous,
}
/// Classifies `query` by comparing its embedding against the advisory and
/// diagnostic centroids.
///
/// `api_url` is the base URL of the embedding backend. It is used both to
/// obtain the centroids (via `ensure_centroids`) and to embed the query.
///
/// Returns [`IntentClass::Ambiguous`] when the centroids are unavailable or
/// the query cannot be embedded. Otherwise the two cosine similarities are
/// handed to `classify_from_scores`.
pub async fn classify_intent(query: &str, api_url: &str) -> IntentClass {
    // Centroids first: without them there is no point embedding the query.
    if let Some((advisory_centroid, diagnostic_centroid)) = ensure_centroids(api_url).await {
        if let Some(embedding) = embed_query(query, api_url).await {
            return classify_from_scores(
                cosine_similarity(&embedding, advisory_centroid),
                cosine_similarity(&embedding, diagnostic_centroid),
            );
        }
    }
    // Any failure along the way means we cannot make a confident call.
    IntentClass::Ambiguous
}
/// Process-wide cache of the `(advisory, diagnostic)` centroid pair.
///
/// The cell is only ever filled with a successfully computed pair; a failed
/// computation leaves it empty so a later call can try again.
static CENTROIDS: OnceCell<(Vec<f32>, Vec<f32>)> = OnceCell::const_new();

/// Returns the cached centroid pair, computing it on first use.
///
/// Returns `None` when the centroids cannot be computed right now (for
/// example the embedding request failed or no embed model is configured).
/// Failures are deliberately NOT cached: a backend that is temporarily
/// unreachable on the first query must not disable classification for the
/// rest of the process lifetime. The cost is that every call made while the
/// backend is down retries the computation.
///
/// Once a pair has been stored it is reused for every later call, regardless
/// of the `api_url` passed in.
async fn ensure_centroids(api_url: &str) -> Option<&'static (Vec<f32>, Vec<f32>)> {
    CENTROIDS
        // `get_or_try_init` leaves the cell empty when the initializer
        // returns `Err`, which is what makes the retry-on-next-call work.
        .get_or_try_init(|| async move { compute_centroids(api_url).await.ok_or(()) })
        .await
        .ok()
}
/// Embeds both example sets and averages each into a centroid.
///
/// Returns `(advisory_centroid, diagnostic_centroid)`, or `None` if either
/// batch embedding fails or either set cannot be averaged by
/// `mean_centroid`. On success a one-line status message is written to
/// stderr.
async fn compute_centroids(api_url: &str) -> Option<(Vec<f32>, Vec<f32>)> {
    // Both batches are fetched before any averaging is attempted.
    let advisory_embeddings = embed_batch(ADVISORY_EXAMPLES, api_url).await?;
    let diagnostic_embeddings = embed_batch(DIAGNOSTIC_EXAMPLES, api_url).await?;

    let centroids = (
        mean_centroid(&advisory_embeddings)?,
        mean_centroid(&diagnostic_embeddings)?,
    );

    eprintln!(
        "[intent_embed] centroids ready ({} advisory, {} diagnostic examples)",
        advisory_embeddings.len(),
        diagnostic_embeddings.len()
    );
    Some(centroids)
}
/// Example utterances that define the advisory intent.
///
/// `compute_centroids` embeds these and averages the vectors into the
/// advisory centroid. The set mixes hypothetical / "should I" questions with
/// statements that acknowledge or restate an earlier finding. Editing this
/// list changes the centroid and therefore classification results.
const ADVISORY_EXAMPLES: &[&str] = &[
"would more ram help with this",
"should I upgrade my GPU",
"is that worth buying",
"could I offload VRAM to system RAM",
"i think the cpu is fine",
"what if I had a faster SSD",
"makes sense so the network is slow",
"so the ram is the issue right",
"do you think I should restart",
"is it worth getting more storage",
"if i upgraded the gpu would that help",
"i believe the service is running",
"i see the memory is fine",
"everything looks good here",
"ok so the cpu is at 8 percent that seems fine",
"i think the service is already running",
"my vram situation seems to be improving",
"makes sense that the disk would be slow",
"yeah that all adds up",
"so the network was just congested",
"that explains why the gpu was hot",
"ah ok so it was the ram all along",
"i guess the service crashed overnight",
"would adding another monitor hurt gpu performance",
"so basically the ssd is the bottleneck right",
];
/// Example utterances that define the diagnostic intent.
///
/// `compute_centroids` embeds these and averages the vectors into the
/// diagnostic centroid. Every entry asks to inspect or report on the current
/// state of the machine (hardware, processes, network, logs, and so on).
/// Editing this list changes the centroid and therefore classification
/// results.
const DIAGNOSTIC_EXAMPLES: &[&str] = &[
"how much RAM do I have",
"show me running processes",
"what is my CPU usage right now",
"check my disk health",
"why is my PC slow",
"what services are running",
"list my network adapters",
"what GPU do I have",
"is my firewall on",
"show me recent errors",
"what is my IP address",
"check my wifi signal strength",
"how much free disk space do I have",
"what is taking up all my memory",
"show hardware specs",
"what processes are using the most RAM",
"is my bluetooth working",
"check my disk for errors",
"what network connections are active",
"show me the system logs",
"what is the cpu temperature",
"are there any pending windows updates",
"is the docker daemon running",
"check my battery status",
"what is my gpu driver version",
];
/// Embeds a user query for comparison against the example centroids.
///
/// The text is sent with a `search_query: ` prefix (the examples use
/// `search_document: `). Returns `None` if `embed_single` fails.
async fn embed_query(text: &str, api_url: &str) -> Option<Vec<f32>> {
    let prefixed = format!("search_query: {text}");
    let embedding = embed_single(&prefixed, api_url).await;
    embedding
}
/// Embeds every string in `texts` with a single request to the embedding
/// backend at `api_url`.
///
/// Each text is sent with a `search_document: ` prefix. The endpoint and the
/// response shape depend on the backend:
///
/// * URL contains `11434` (treated as Ollama): `POST {api_url}/api/embed`,
///   vectors are read from `embeddings[i]`.
/// * otherwise: `POST {api_url}/v1/embeddings`, vectors are read from
///   `data[i].embedding`.
///
/// Returns one vector per input text, in response order. Returns `None` if
/// no embed model is configured, the request fails or times out (20 s), the
/// status is not a success, the number of returned items differs from
/// `texts.len()`, or any item is malformed.
///
/// Parsing is strict: a non-array item or a non-numeric component fails the
/// whole batch instead of being silently dropped, so a truncated or
/// misaligned vector can never reach the centroid computation.
async fn embed_batch(texts: &[&str], api_url: &str) -> Option<Vec<Vec<f32>>> {
    let embed_model = load_embed_model(api_url)?;
    let inputs: Vec<String> = texts
        .iter()
        .map(|t| format!("search_document: {t}"))
        .collect();

    let trimmed = api_url.trim_end_matches('/');
    // Heuristic backend detection: a URL mentioning port 11434 is assumed to
    // be an Ollama server and gets Ollama's native embed endpoint.
    let is_ollama = trimmed.contains("11434");
    let url = if is_ollama {
        format!("{trimmed}/api/embed")
    } else {
        format!("{trimmed}/v1/embeddings")
    };
    let body = serde_json::json!({
        "model": embed_model,
        "input": inputs
    });

    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(20))
        .build()
        .ok()?;
    let resp = client
        .post(&url)
        .header("Content-Type", "application/json")
        .json(&body)
        .send()
        .await
        .ok()?;
    if !resp.status().is_success() {
        return None;
    }

    let json: serde_json::Value = resp.json().await.ok()?;
    let items = if is_ollama {
        json["embeddings"].as_array()?
    } else {
        json["data"].as_array()?
    };
    // The caller pairs vectors with inputs by position, so the counts must
    // agree before anything is parsed.
    if items.len() != texts.len() {
        return None;
    }

    items
        .iter()
        .map(|item| {
            let components = if is_ollama {
                item.as_array()?
            } else {
                item["embedding"].as_array()?
            };
            components
                .iter()
                // JSON numbers arrive as f64; embeddings are stored as f32.
                .map(|v| v.as_f64().map(|f| f as f32))
                .collect::<Option<Vec<f32>>>()
        })
        .collect()
}
/// Embeds a single, already-prefixed `input` string using the embedding
/// backend at `api_url`.
///
/// The endpoint and the response shape depend on the backend:
///
/// * URL contains `11434` (treated as Ollama): `POST {api_url}/api/embed`,
///   the vector is read from `embeddings[0]`.
/// * otherwise: `POST {api_url}/v1/embeddings`, the vector is read from
///   `data[0].embedding`.
///
/// Returns `None` if no embed model is configured, the request fails or
/// times out (5 s), the status is not a success, the expected array is
/// missing, the vector is empty, or any component is not a number.
///
/// Parsing is strict: a non-numeric component fails the call instead of
/// being silently skipped, which would otherwise yield a shortened vector
/// whose dimension no longer matches the centroids.
async fn embed_single(input: &str, api_url: &str) -> Option<Vec<f32>> {
    let embed_model = load_embed_model(api_url)?;

    let trimmed = api_url.trim_end_matches('/');
    // Heuristic backend detection: a URL mentioning port 11434 is assumed to
    // be an Ollama server and gets Ollama's native embed endpoint.
    let is_ollama = trimmed.contains("11434");
    let url = if is_ollama {
        format!("{trimmed}/api/embed")
    } else {
        format!("{trimmed}/v1/embeddings")
    };
    let body = serde_json::json!({
        "model": embed_model,
        "input": input
    });

    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(5))
        .build()
        .ok()?;
    let resp = client
        .post(&url)
        .header("Content-Type", "application/json")
        .json(&body)
        .send()
        .await
        .ok()?;
    if !resp.status().is_success() {
        return None;
    }

    let json: serde_json::Value = resp.json().await.ok()?;
    let components = if is_ollama {
        json["embeddings"][0].as_array()?
    } else {
        json["data"][0]["embedding"].as_array()?
    };
    let embedding = components
        .iter()
        // JSON numbers arrive as f64; embeddings are stored as f32.
        .map(|v| v.as_f64().map(|f| f as f32))
        .collect::<Option<Vec<f32>>>()?;

    if embedding.is_empty() {
        None
    } else {
        Some(embedding)
    }
}
/// Reads the configured embedding model name from the application config.
///
/// Returns `None` when no model is saved or the saved name is blank
/// (empty or whitespace only). The saved value is returned untrimmed.
/// `_api_url` is currently unused.
fn load_embed_model(_api_url: &str) -> Option<String> {
    let settings = crate::agent::config::load_config();
    settings
        .embed_model
        .filter(|name| !name.trim().is_empty())
}
/// Cosine similarity between `a` and `b`.
///
/// Returns `0.0` when the slices are empty, differ in length, or either one
/// has zero magnitude; otherwise returns `dot(a, b) / (|a| * |b|)`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    // Euclidean length of a vector.
    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        // A zero vector has no direction; treat it as dissimilar.
        return 0.0;
    }

    let inner: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    inner / (mag_a * mag_b)
}
/// Component-wise arithmetic mean of `vecs`.
///
/// Returns `None` when `vecs` is empty, when the first vector has zero
/// dimensions, or when any vector's length differs from the first one's.
fn mean_centroid(vecs: &[Vec<f32>]) -> Option<Vec<f32>> {
    let dim = vecs.first()?.len();
    if dim == 0 || vecs.iter().any(|v| v.len() != dim) {
        return None;
    }

    // Accumulate from an all-zero vector, adding each input in order.
    let totals = vecs.iter().fold(vec![0.0f32; dim], |mut acc, v| {
        acc.iter_mut().zip(v).for_each(|(slot, x)| *slot += x);
        acc
    });

    let count = vecs.len() as f32;
    Some(totals.into_iter().map(|total| total / count).collect())
}
/// Turns the two centroid similarities into an [`IntentClass`].
///
/// A class wins only if its score reaches its own minimum AND exceeds the
/// other score by more than `MIN_GAP`. The advisory check takes precedence;
/// if neither class wins the result is `Ambiguous`.
fn classify_from_scores(advisory: f32, diagnostic: f32) -> IntentClass {
    /// Minimum similarity for an advisory verdict.
    const ADVISORY_MIN: f32 = 0.72;
    /// Minimum similarity for a diagnostic verdict.
    const DIAGNOSTIC_MIN: f32 = 0.68;
    /// Margin by which the winning score must exceed the other.
    const MIN_GAP: f32 = 0.08;

    let advisory_wins = advisory >= ADVISORY_MIN && advisory > diagnostic + MIN_GAP;
    let diagnostic_wins = diagnostic >= DIAGNOSTIC_MIN && diagnostic > advisory + MIN_GAP;

    match (advisory_wins, diagnostic_wins) {
        (true, _) => IntentClass::Advisory,
        (false, true) => IntentClass::Diagnostic,
        (false, false) => IntentClass::Ambiguous,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons.
    const TOLERANCE: f32 = 1e-5;

    /// A vector compared with itself has similarity 1.
    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let similarity = cosine_similarity(&v, &v);
        assert!((similarity - 1.0).abs() < TOLERANCE);
    }

    /// Perpendicular vectors have similarity 0.
    #[test]
    fn cosine_orthogonal_vectors() {
        let x_axis = vec![1.0, 0.0];
        let y_axis = vec![0.0, 1.0];
        let similarity = cosine_similarity(&x_axis, &y_axis);
        assert!(similarity.abs() < TOLERANCE);
    }

    /// Averaging two equal vectors yields that same vector.
    #[test]
    fn centroid_of_two_identical() {
        let inputs = vec![vec![1.0, 2.0], vec![1.0, 2.0]];
        let centroid = mean_centroid(&inputs).unwrap();
        assert!((centroid[0] - 1.0).abs() < TOLERANCE);
        assert!((centroid[1] - 2.0).abs() < TOLERANCE);
    }

    /// High advisory score with a wide margin.
    #[test]
    fn classify_from_scores_advisory() {
        let verdict = classify_from_scores(0.80, 0.60);
        assert_eq!(verdict, IntentClass::Advisory);
    }

    /// High diagnostic score with a wide margin.
    #[test]
    fn classify_from_scores_diagnostic() {
        let verdict = classify_from_scores(0.55, 0.78);
        assert_eq!(verdict, IntentClass::Diagnostic);
    }

    /// Both scores clear their minimums but are too close together.
    #[test]
    fn classify_from_scores_ambiguous_close_gap() {
        let verdict = classify_from_scores(0.74, 0.70);
        assert_eq!(verdict, IntentClass::Ambiguous);
    }

    /// Neither score reaches its minimum.
    #[test]
    fn classify_from_scores_ambiguous_low_scores() {
        let verdict = classify_from_scores(0.50, 0.40);
        assert_eq!(verdict, IntentClass::Ambiguous);
    }
}