1use tokio::sync::OnceCell;
17
/// Outcome of classifying a user query by embedding similarity.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the value can be
/// used as a key in hash-based collections and in exhaustive equality checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IntentClass {
    /// The query scored clearly closer to the advisory examples.
    Advisory,
    /// The query scored clearly closer to the diagnostic examples.
    Diagnostic,
    /// Neither class won clearly, or the scores could not be computed.
    Ambiguous,
}
30
31pub async fn classify_intent(query: &str, api_url: &str) -> IntentClass {
37 let centroids = ensure_centroids(api_url).await;
38 let (adv_centroid, diag_centroid) = match centroids {
39 Some(c) => c,
40 None => return IntentClass::Ambiguous,
41 };
42
43 let query_vec = match embed_query(query, api_url).await {
44 Some(v) => v,
45 None => return IntentClass::Ambiguous,
46 };
47
48 let advisory_score = cosine_similarity(&query_vec, adv_centroid);
49 let diagnostic_score = cosine_similarity(&query_vec, diag_centroid);
50
51 classify_from_scores(advisory_score, diagnostic_score)
52}
53
54static CENTROIDS: OnceCell<Option<(Vec<f32>, Vec<f32>)>> = OnceCell::const_new();
58
59async fn ensure_centroids(api_url: &str) -> Option<&'static (Vec<f32>, Vec<f32>)> {
60 let url = api_url.to_string();
61 let opt = CENTROIDS
62 .get_or_init(|| async move { compute_centroids(&url).await })
63 .await;
64 opt.as_ref()
65}
66
67async fn compute_centroids(api_url: &str) -> Option<(Vec<f32>, Vec<f32>)> {
68 let adv_vecs = embed_batch(ADVISORY_EXAMPLES, api_url).await?;
69 let diag_vecs = embed_batch(DIAGNOSTIC_EXAMPLES, api_url).await?;
70
71 let adv_centroid = mean_centroid(&adv_vecs)?;
72 let diag_centroid = mean_centroid(&diag_vecs)?;
73
74 eprintln!(
75 "[intent_embed] centroids ready ({} advisory, {} diagnostic examples)",
76 adv_vecs.len(),
77 diag_vecs.len()
78 );
79 Some((adv_centroid, diag_centroid))
80}
81
/// Example phrasings whose embeddings are averaged into the advisory centroid.
///
/// The list mixes hypothetical "would/should I" questions with statements
/// that restate or confirm a conclusion. Order is kept stable because the
/// centroid is a floating-point sum over these entries.
const ADVISORY_EXAMPLES: &[&str] = &[
    "would more ram help with this",
    "should I upgrade my GPU",
    "is that worth buying",
    "could I offload VRAM to system RAM",
    "i think the cpu is fine",
    "what if I had a faster SSD",
    "makes sense so the network is slow",
    "so the ram is the issue right",
    "do you think I should restart",
    "is it worth getting more storage",
    "if i upgraded the gpu would that help",
    "i believe the service is running",
    "i see the memory is fine",
    "everything looks good here",
    "ok so the cpu is at 8 percent that seems fine",
    "i think the service is already running",
    "my vram situation seems to be improving",
    "makes sense that the disk would be slow",
    "yeah that all adds up",
    "so the network was just congested",
    "that explains why the gpu was hot",
    "ah ok so it was the ram all along",
    "i guess the service crashed overnight",
    "would adding another monitor hurt gpu performance",
    "so basically the ssd is the bottleneck right",
];
113
/// Example phrasings whose embeddings are averaged into the diagnostic centroid.
///
/// Every entry asks for the current state of the machine (hardware, processes,
/// network, logs). Order is kept stable because the centroid is a
/// floating-point sum over these entries.
const DIAGNOSTIC_EXAMPLES: &[&str] = &[
    "how much RAM do I have",
    "show me running processes",
    "what is my CPU usage right now",
    "check my disk health",
    "why is my PC slow",
    "what services are running",
    "list my network adapters",
    "what GPU do I have",
    "is my firewall on",
    "show me recent errors",
    "what is my IP address",
    "check my wifi signal strength",
    "how much free disk space do I have",
    "what is taking up all my memory",
    "show hardware specs",
    "what processes are using the most RAM",
    "is my bluetooth working",
    "check my disk for errors",
    "what network connections are active",
    "show me the system logs",
    "what is the cpu temperature",
    "are there any pending windows updates",
    "is the docker daemon running",
    "check my battery status",
    "what is my gpu driver version",
];
143
144async fn embed_query(text: &str, api_url: &str) -> Option<Vec<f32>> {
147 let input = format!("search_query: {text}");
149 embed_single(&input, api_url).await
150}
151
152async fn embed_batch(texts: &[&str], api_url: &str) -> Option<Vec<Vec<f32>>> {
153 let inputs: Vec<String> = texts
155 .iter()
156 .map(|t| format!("search_document: {t}"))
157 .collect();
158
159 let embed_model = load_embed_model(api_url)?;
160 let trimmed = api_url.trim_end_matches('/');
161 let is_ollama = trimmed.contains("11434");
162 let body = serde_json::json!({
163 "model": embed_model,
164 "input": inputs
165 });
166
167 let url = if is_ollama {
168 format!("{}/api/embed", trimmed)
169 } else {
170 format!("{}/v1/embeddings", trimmed)
171 };
172
173 let client = reqwest::Client::builder()
174 .timeout(std::time::Duration::from_secs(20))
175 .build()
176 .ok()?;
177
178 let resp = client
179 .post(&url)
180 .header("Content-Type", "application/json")
181 .json(&body)
182 .send()
183 .await
184 .ok()?;
185
186 if !resp.status().is_success() {
187 return None;
188 }
189
190 let json: serde_json::Value = resp.json().await.ok()?;
191 let data = if is_ollama {
192 json["embeddings"].as_array()?
193 } else {
194 json["data"].as_array()?
195 };
196
197 let vecs: Vec<Vec<f32>> = data
198 .iter()
199 .filter_map(|item| {
200 let arr = if is_ollama {
201 item.as_array()
202 } else {
203 item["embedding"].as_array()
204 }?;
205 Some(
206 arr.iter()
207 .filter_map(|v| v.as_f64().map(|f| f as f32))
208 .collect(),
209 )
210 })
211 .collect();
212
213 if vecs.len() != texts.len() {
214 None
215 } else {
216 Some(vecs)
217 }
218}
219
220async fn embed_single(input: &str, api_url: &str) -> Option<Vec<f32>> {
221 let embed_model = load_embed_model(api_url)?;
222 let trimmed = api_url.trim_end_matches('/');
223 let is_ollama = trimmed.contains("11434");
224 let body = serde_json::json!({
225 "model": embed_model,
226 "input": input
227 });
228
229 let url = if is_ollama {
230 format!("{}/api/embed", trimmed)
231 } else {
232 format!("{}/v1/embeddings", trimmed)
233 };
234
235 let client = reqwest::Client::builder()
236 .timeout(std::time::Duration::from_secs(5))
237 .build()
238 .ok()?;
239
240 let resp = client
241 .post(&url)
242 .header("Content-Type", "application/json")
243 .json(&body)
244 .send()
245 .await
246 .ok()?;
247
248 if !resp.status().is_success() {
249 return None;
250 }
251
252 let json: serde_json::Value = resp.json().await.ok()?;
253 let arr = if is_ollama {
254 json["embeddings"][0].as_array()?
255 } else {
256 json["data"][0]["embedding"].as_array()?
257 };
258 let vec: Vec<f32> = arr
259 .iter()
260 .filter_map(|v| v.as_f64().map(|f| f as f32))
261 .collect();
262
263 if vec.is_empty() {
264 None
265 } else {
266 Some(vec)
267 }
268}
269
270fn load_embed_model(_api_url: &str) -> Option<String> {
271 let config = crate::agent::config::load_config();
272 let saved = config.embed_model?;
273 if saved.trim().is_empty() {
274 None
275 } else {
276 Some(saved)
277 }
278}
279
/// Cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices are empty, differ in length, or either has
/// zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    /// Euclidean length of `v`.
    fn magnitude(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (mag_a * mag_b)
}
294
/// Component-wise mean of `vecs`.
///
/// Returns `None` when `vecs` is empty, the first vector has no components,
/// or any vector's length differs from the first one's.
fn mean_centroid(vecs: &[Vec<f32>]) -> Option<Vec<f32>> {
    let dim = vecs.first()?.len();
    if dim == 0 || vecs.iter().any(|v| v.len() != dim) {
        return None;
    }

    // Accumulate from zero in input order, then scale once at the end.
    let mut acc = vec![0.0f32; dim];
    for v in vecs {
        for (slot, component) in acc.iter_mut().zip(v) {
            *slot += *component;
        }
    }

    let count = vecs.len() as f32;
    for slot in &mut acc {
        *slot /= count;
    }
    Some(acc)
}
315
316fn classify_from_scores(advisory: f32, diagnostic: f32) -> IntentClass {
317 const ADVISORY_MIN: f32 = 0.72; const DIAGNOSTIC_MIN: f32 = 0.68; const MIN_GAP: f32 = 0.08; if advisory >= ADVISORY_MIN && advisory > diagnostic + MIN_GAP {
325 IntentClass::Advisory
326 } else if diagnostic >= DIAGNOSTIC_MIN && diagnostic > advisory + MIN_GAP {
327 IntentClass::Diagnostic
328 } else {
329 IntentClass::Ambiguous
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons in these tests.
    const EPS: f32 = 1e-5;

    #[test]
    fn cosine_identical_vectors() {
        let v = [1.0_f32, 2.0, 3.0];
        let similarity = cosine_similarity(&v, &v);
        assert!((similarity - 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_orthogonal_vectors() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        let similarity = cosine_similarity(&x_axis, &y_axis);
        assert!(similarity.abs() < EPS);
    }

    #[test]
    fn centroid_of_two_identical() {
        let inputs = vec![vec![1.0, 2.0], vec![1.0, 2.0]];
        let centroid = mean_centroid(&inputs).unwrap();
        assert!((centroid[0] - 1.0).abs() < EPS);
        assert!((centroid[1] - 2.0).abs() < EPS);
    }

    #[test]
    fn classify_from_scores_advisory() {
        let verdict = classify_from_scores(0.80, 0.60);
        assert_eq!(verdict, IntentClass::Advisory);
    }

    #[test]
    fn classify_from_scores_diagnostic() {
        let verdict = classify_from_scores(0.55, 0.78);
        assert_eq!(verdict, IntentClass::Diagnostic);
    }

    #[test]
    fn classify_from_scores_ambiguous_close_gap() {
        // Both scores clear their minimum, but the gap is below the margin.
        let verdict = classify_from_scores(0.74, 0.70);
        assert_eq!(verdict, IntentClass::Ambiguous);
    }

    #[test]
    fn classify_from_scores_ambiguous_low_scores() {
        // The gap is wide enough, but neither score reaches its minimum.
        let verdict = classify_from_scores(0.50, 0.40);
        assert_eq!(verdict, IntentClass::Ambiguous);
    }
}