1use tokio::sync::OnceCell;
17
/// Coarse intent of a user query, as decided by [`classify_intent`].
///
/// `Eq` and `Hash` are derived (the enum is fieldless), so values can be used
/// as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IntentClass {
    /// The query's embedding is clearly closest to the advisory example set
    /// (hypotheticals, opinions, upgrade questions, acknowledgements).
    Advisory,
    /// The query's embedding is clearly closest to the diagnostic example set
    /// (requests to inspect or report the current state of the machine).
    Diagnostic,
    /// Neither class won by the required margin, or embeddings were unavailable.
    Ambiguous,
}
30
31pub async fn classify_intent(query: &str, api_url: &str) -> IntentClass {
37 let centroids = ensure_centroids(api_url).await;
38 let (adv_centroid, diag_centroid) = match centroids {
39 Some(c) => c,
40 None => return IntentClass::Ambiguous,
41 };
42
43 let query_vec = match embed_query(query, api_url).await {
44 Some(v) => v,
45 None => return IntentClass::Ambiguous,
46 };
47
48 let advisory_score = cosine_similarity(&query_vec, adv_centroid);
49 let diagnostic_score = cosine_similarity(&query_vec, diag_centroid);
50
51 classify_from_scores(advisory_score, diagnostic_score)
52}
53
54static CENTROIDS: OnceCell<Option<(Vec<f32>, Vec<f32>)>> = OnceCell::const_new();
58
59async fn ensure_centroids(api_url: &str) -> Option<&'static (Vec<f32>, Vec<f32>)> {
60 let url = api_url.to_string();
61 let opt = CENTROIDS
62 .get_or_init(|| async move { compute_centroids(&url).await })
63 .await;
64 opt.as_ref()
65}
66
67async fn compute_centroids(api_url: &str) -> Option<(Vec<f32>, Vec<f32>)> {
68 let adv_vecs = embed_batch(ADVISORY_EXAMPLES, api_url).await?;
69 let diag_vecs = embed_batch(DIAGNOSTIC_EXAMPLES, api_url).await?;
70
71 let adv_centroid = mean_centroid(&adv_vecs)?;
72 let diag_centroid = mean_centroid(&diag_vecs)?;
73
74 eprintln!(
75 "[intent_embed] centroids ready ({} advisory, {} diagnostic examples)",
76 adv_vecs.len(),
77 diag_vecs.len()
78 );
79 Some((adv_centroid, diag_centroid))
80}
81
/// Example utterances for the advisory class.
///
/// `compute_centroids` embeds them via `embed_batch` (which adds the
/// `search_document: ` prefix) and averages them into the advisory centroid.
/// Judging by the entries, they are hypotheticals, opinions, upgrade/purchase
/// questions, and acknowledgements of information already seen. None of them
/// asks for a fresh reading of system state.
const ADVISORY_EXAMPLES: &[&str] = &[
    // Hypotheticals / should-I / worth-it questions.
    "would more ram help with this",
    "should I upgrade my GPU",
    "is that worth buying",
    "could I offload VRAM to system RAM",
    "i think the cpu is fine",
    "what if I had a faster SSD",
    "makes sense so the network is slow",
    "so the ram is the issue right",
    "do you think I should restart",
    "is it worth getting more storage",
    "if i upgraded the gpu would that help",
    // Statements and acknowledgements about already-known state.
    "i believe the service is running",
    "i see the memory is fine",
    "everything looks good here",
    "ok so the cpu is at 8 percent that seems fine",
    "i think the service is already running",
    "my vram situation seems to be improving",
    "makes sense that the disk would be slow",
    "yeah that all adds up",
    "so the network was just congested",
    "that explains why the gpu was hot",
    "ah ok so it was the ram all along",
    "i guess the service crashed overnight",
    "would adding another monitor hurt gpu performance",
    "so basically the ssd is the bottleneck right",
];
113
/// Example utterances for the diagnostic class.
///
/// `compute_centroids` embeds them via `embed_batch` (which adds the
/// `search_document: ` prefix) and averages them into the diagnostic centroid.
/// Each entry asks to inspect, list, check, or report some aspect of the
/// machine's current state.
const DIAGNOSTIC_EXAMPLES: &[&str] = &[
    "how much RAM do I have",
    "show me running processes",
    "what is my CPU usage right now",
    "check my disk health",
    "why is my PC slow",
    "what services are running",
    "list my network adapters",
    "what GPU do I have",
    "is my firewall on",
    "show me recent errors",
    "what is my IP address",
    "check my wifi signal strength",
    "how much free disk space do I have",
    "what is taking up all my memory",
    "show hardware specs",
    "what processes are using the most RAM",
    "is my bluetooth working",
    "check my disk for errors",
    "what network connections are active",
    "show me the system logs",
    "what is the cpu temperature",
    "are there any pending windows updates",
    "is the docker daemon running",
    "check my battery status",
    "what is my gpu driver version",
];
143
144async fn embed_query(text: &str, api_url: &str) -> Option<Vec<f32>> {
147 let input = format!("search_query: {text}");
149 embed_single(&input, api_url).await
150}
151
152async fn embed_batch(texts: &[&str], api_url: &str) -> Option<Vec<Vec<f32>>> {
153 let inputs: Vec<String> = texts
155 .iter()
156 .map(|t| format!("search_document: {t}"))
157 .collect();
158
159 let body = serde_json::json!({
160 "model": "nomic-embed-text-v2",
161 "input": inputs
162 });
163
164 let url = format!("{}/v1/embeddings", api_url.trim_end_matches('/'));
165
166 let client = reqwest::Client::builder()
167 .timeout(std::time::Duration::from_secs(20))
168 .build()
169 .ok()?;
170
171 let resp = client
172 .post(&url)
173 .header("Content-Type", "application/json")
174 .json(&body)
175 .send()
176 .await
177 .ok()?;
178
179 if !resp.status().is_success() {
180 return None;
181 }
182
183 let json: serde_json::Value = resp.json().await.ok()?;
184 let data = json["data"].as_array()?;
185
186 let vecs: Vec<Vec<f32>> = data
187 .iter()
188 .filter_map(|item| {
189 item["embedding"].as_array().map(|arr| {
190 arr.iter()
191 .filter_map(|v| v.as_f64().map(|f| f as f32))
192 .collect()
193 })
194 })
195 .collect();
196
197 if vecs.len() != texts.len() {
198 None
199 } else {
200 Some(vecs)
201 }
202}
203
204async fn embed_single(input: &str, api_url: &str) -> Option<Vec<f32>> {
205 let body = serde_json::json!({
206 "model": "nomic-embed-text-v2",
207 "input": input
208 });
209
210 let url = format!("{}/v1/embeddings", api_url.trim_end_matches('/'));
211
212 let client = reqwest::Client::builder()
213 .timeout(std::time::Duration::from_secs(5))
214 .build()
215 .ok()?;
216
217 let resp = client
218 .post(&url)
219 .header("Content-Type", "application/json")
220 .json(&body)
221 .send()
222 .await
223 .ok()?;
224
225 if !resp.status().is_success() {
226 return None;
227 }
228
229 let json: serde_json::Value = resp.json().await.ok()?;
230 let arr = json["data"][0]["embedding"].as_array()?;
231 let vec: Vec<f32> = arr
232 .iter()
233 .filter_map(|v| v.as_f64().map(|f| f as f32))
234 .collect();
235
236 if vec.is_empty() {
237 None
238 } else {
239 Some(vec)
240 }
241}
242
/// Cosine similarity between `a` and `b`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has
/// zero magnitude, so degenerate inputs never produce a division by zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // A vector's squared norm is its dot product with itself.
    let dot = |p: &[f32], q: &[f32]| -> f32 { p.iter().zip(q).map(|(x, y)| x * y).sum() };

    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let numerator = dot(a, b);
    let magnitude_a = dot(a, a).sqrt();
    let magnitude_b = dot(b, b).sqrt();
    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        return 0.0;
    }
    numerator / (magnitude_a * magnitude_b)
}
257
/// Component-wise mean of `vecs`.
///
/// Returns `None` if `vecs` is empty, the first vector has no components, or
/// any vector's length differs from the first one's.
fn mean_centroid(vecs: &[Vec<f32>]) -> Option<Vec<f32>> {
    let first = vecs.first()?;
    let dim = first.len();
    if dim == 0 {
        return None;
    }

    // Accumulate in input order; a length mismatch aborts the fold.
    let totals = vecs.iter().try_fold(vec![0.0f32; dim], |mut acc, v| {
        if v.len() != dim {
            return None;
        }
        acc.iter_mut().zip(v).for_each(|(s, x)| *s += x);
        Some(acc)
    })?;

    let count = vecs.len() as f32;
    Some(totals.into_iter().map(|s| s / count).collect())
}
278
279fn classify_from_scores(advisory: f32, diagnostic: f32) -> IntentClass {
280 const ADVISORY_MIN: f32 = 0.72; const DIAGNOSTIC_MIN: f32 = 0.68; const MIN_GAP: f32 = 0.08; if advisory >= ADVISORY_MIN && advisory > diagnostic + MIN_GAP {
288 IntentClass::Advisory
289 } else if diagnostic >= DIAGNOSTIC_MIN && diagnostic > advisory + MIN_GAP {
290 IntentClass::Diagnostic
291 } else {
292 IntentClass::Ambiguous
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-5);
    }

    #[test]
    fn cosine_opposite_vectors() {
        let a = vec![1.0, 2.0];
        let b = vec![-1.0, -2.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_degenerate_inputs_are_zero() {
        // Mismatched lengths, empty slices, and zero-magnitude vectors.
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }

    #[test]
    fn centroid_of_two_identical() {
        let vecs = vec![vec![1.0, 2.0], vec![1.0, 2.0]];
        let c = mean_centroid(&vecs).unwrap();
        assert!((c[0] - 1.0).abs() < 1e-5);
        assert!((c[1] - 2.0).abs() < 1e-5);
    }

    #[test]
    fn centroid_averages_componentwise() {
        let vecs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        assert_eq!(mean_centroid(&vecs), Some(vec![2.0, 3.0]));
    }

    #[test]
    fn centroid_rejects_empty_and_mismatched_input() {
        assert!(mean_centroid(&[]).is_none());
        assert!(mean_centroid(&[Vec::new()]).is_none());
        assert!(mean_centroid(&[vec![1.0, 2.0], vec![1.0]]).is_none());
    }

    #[test]
    fn classify_from_scores_advisory() {
        assert_eq!(classify_from_scores(0.80, 0.60), IntentClass::Advisory);
    }

    #[test]
    fn classify_from_scores_diagnostic() {
        assert_eq!(classify_from_scores(0.55, 0.78), IntentClass::Diagnostic);
    }

    #[test]
    fn classify_from_scores_ambiguous_close_gap() {
        assert_eq!(classify_from_scores(0.74, 0.70), IntentClass::Ambiguous);
    }

    #[test]
    fn classify_from_scores_ambiguous_low_scores() {
        assert_eq!(classify_from_scores(0.50, 0.40), IntentClass::Ambiguous);
    }

    #[test]
    fn classify_from_scores_threshold_boundaries() {
        // The floors are inclusive.
        assert_eq!(classify_from_scores(0.72, 0.0), IntentClass::Advisory);
        assert_eq!(classify_from_scores(0.71, 0.0), IntentClass::Ambiguous);
        assert_eq!(classify_from_scores(0.0, 0.68), IntentClass::Diagnostic);
        assert_eq!(classify_from_scores(0.0, 0.67), IntentClass::Ambiguous);
        // Both high but too close together.
        assert_eq!(classify_from_scores(0.95, 0.90), IntentClass::Ambiguous);
    }

    #[test]
    fn classify_from_scores_nan_is_ambiguous() {
        assert_eq!(classify_from_scores(f32::NAN, 0.9), IntentClass::Ambiguous);
        assert_eq!(classify_from_scores(0.9, f32::NAN), IntentClass::Ambiguous);
    }
}
342}