Skip to main content

hematite/agent/
intent_embed.rs

1// Embedding-based intent classifier — semantic pre-filter for routing decisions.
2//
3// Uses nomic-embed-text-v2 (already loaded in LM Studio alongside the main model)
4// to verify whether a user query is genuinely diagnostic or conversational.
5//
6// When the keyword router would inject HOST INSPECTION MODE, this classifier runs
7// as a second-opinion pass. If it returns Advisory with high confidence, the
8// injection is suppressed and the model answers from context instead of fetching
9// fresh machine data.
10//
// Centroids are bootstrapped lazily on first use by batch-embedding a small set
// of labeled example phrases (~100ms total, one API call per class). Subsequent
// calls embed only the query (~50ms). Falls back to Ambiguous silently if the embed
// model is unavailable or slow — keyword routing continues as before.
15
16use tokio::sync::OnceCell;
17
18// ── Public API ────────────────────────────────────────────────────────────────
19
/// Outcome of the embedding-based second-opinion pass over a user query.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the value can be used
/// as a map/set key and in exhaustive equality checks; the enum is fieldless,
/// so both are trivially correct.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IntentClass {
    /// Conversational, advisory, or declarative. Suppress HOST INSPECTION MODE —
    /// the model should answer from existing context, not fetch new data.
    Advisory,
    /// Clearly diagnostic. The keyword router's topic choice is correct.
    Diagnostic,
    /// Uncertain. Defer to keyword router as before.
    Ambiguous,
}
30
31/// Classify user query intent using embedding similarity against labeled centroids.
32///
33/// Only called when the keyword router has already returned `host_inspection_mode = true`,
34/// so this is a veto path, not the primary routing path. Returning Ambiguous is always
35/// safe — it just falls through to existing behavior.
36pub async fn classify_intent(query: &str, api_url: &str) -> IntentClass {
37    let centroids = ensure_centroids(api_url).await;
38    let (adv_centroid, diag_centroid) = match centroids {
39        Some(c) => c,
40        None => return IntentClass::Ambiguous,
41    };
42
43    let query_vec = match embed_query(query, api_url).await {
44        Some(v) => v,
45        None => return IntentClass::Ambiguous,
46    };
47
48    let advisory_score = cosine_similarity(&query_vec, adv_centroid);
49    let diagnostic_score = cosine_similarity(&query_vec, diag_centroid);
50
51    classify_from_scores(advisory_score, diagnostic_score)
52}
53
54// ── Centroid bootstrap ────────────────────────────────────────────────────────
55
56// Stored as (advisory_centroid, diagnostic_centroid). None = embed model unavailable.
57static CENTROIDS: OnceCell<Option<(Vec<f32>, Vec<f32>)>> = OnceCell::const_new();
58
59async fn ensure_centroids(api_url: &str) -> Option<&'static (Vec<f32>, Vec<f32>)> {
60    let url = api_url.to_string();
61    let opt = CENTROIDS
62        .get_or_init(|| async move { compute_centroids(&url).await })
63        .await;
64    opt.as_ref()
65}
66
67async fn compute_centroids(api_url: &str) -> Option<(Vec<f32>, Vec<f32>)> {
68    let adv_vecs = embed_batch(ADVISORY_EXAMPLES, api_url).await?;
69    let diag_vecs = embed_batch(DIAGNOSTIC_EXAMPLES, api_url).await?;
70
71    let adv_centroid = mean_centroid(&adv_vecs)?;
72    let diag_centroid = mean_centroid(&diag_vecs)?;
73
74    eprintln!(
75        "[intent_embed] centroids ready ({} advisory, {} diagnostic examples)",
76        adv_vecs.len(),
77        diag_vecs.len()
78    );
79    Some((adv_centroid, diag_centroid))
80}
81
82// ── Example phrases ───────────────────────────────────────────────────────────
83
// Labeled phrases for the Advisory class: queries for which the model should
// answer from context and NOT call inspect_host. The set mixes opinion
// questions, hypotheticals, declarative statements, and acknowledgments.
// Order is preserved deliberately — the centroid is a float mean over these
// entries in sequence.
const ADVISORY_EXAMPLES: &[&str] = &[
    "would more ram help with this",
    "should I upgrade my GPU",
    "is that worth buying",
    "could I offload VRAM to system RAM",
    "i think the cpu is fine",
    "what if I had a faster SSD",
    "makes sense so the network is slow",
    "so the ram is the issue right",
    "do you think I should restart",
    "is it worth getting more storage",
    "if i upgraded the gpu would that help",
    "i believe the service is running",
    "i see the memory is fine",
    "everything looks good here",
    "ok so the cpu is at 8 percent that seems fine",
    "i think the service is already running",
    "my vram situation seems to be improving",
    "makes sense that the disk would be slow",
    "yeah that all adds up",
    "so the network was just congested",
    "that explains why the gpu was hot",
    "ah ok so it was the ram all along",
    "i guess the service crashed overnight",
    "would adding another monitor hurt gpu performance",
    "so basically the ssd is the bottleneck right",
];
113
// Labeled phrases for the Diagnostic class: queries for which the model
// SHOULD call inspect_host. The set mixes data requests, status checks, and
// show/list/check style commands.
// Order is preserved deliberately — the centroid is a float mean over these
// entries in sequence.
const DIAGNOSTIC_EXAMPLES: &[&str] = &[
    "how much RAM do I have",
    "show me running processes",
    "what is my CPU usage right now",
    "check my disk health",
    "why is my PC slow",
    "what services are running",
    "list my network adapters",
    "what GPU do I have",
    "is my firewall on",
    "show me recent errors",
    "what is my IP address",
    "check my wifi signal strength",
    "how much free disk space do I have",
    "what is taking up all my memory",
    "show hardware specs",
    "what processes are using the most RAM",
    "is my bluetooth working",
    "check my disk for errors",
    "what network connections are active",
    "show me the system logs",
    "what is the cpu temperature",
    "are there any pending windows updates",
    "is the docker daemon running",
    "check my battery status",
    "what is my gpu driver version",
];
143
144// ── Embedding helpers ─────────────────────────────────────────────────────────
145
146async fn embed_query(text: &str, api_url: &str) -> Option<Vec<f32>> {
147    // nomic-embed-text-v2 uses task instruction prefixes
148    let input = format!("search_query: {text}");
149    embed_single(&input, api_url).await
150}
151
152async fn embed_batch(texts: &[&str], api_url: &str) -> Option<Vec<Vec<f32>>> {
153    // Batch embed with document prefix — one API call for all examples
154    let inputs: Vec<String> = texts
155        .iter()
156        .map(|t| format!("search_document: {t}"))
157        .collect();
158
159    let embed_model = load_embed_model(api_url)?;
160    let trimmed = api_url.trim_end_matches('/');
161    let is_ollama = trimmed.contains("11434");
162    let body = serde_json::json!({
163        "model": embed_model,
164        "input": inputs
165    });
166
167    let url = if is_ollama {
168        format!("{}/api/embed", trimmed)
169    } else {
170        format!("{}/v1/embeddings", trimmed)
171    };
172
173    let client = reqwest::Client::builder()
174        .timeout(std::time::Duration::from_secs(20))
175        .build()
176        .ok()?;
177
178    let resp = client
179        .post(&url)
180        .header("Content-Type", "application/json")
181        .json(&body)
182        .send()
183        .await
184        .ok()?;
185
186    if !resp.status().is_success() {
187        return None;
188    }
189
190    let json: serde_json::Value = resp.json().await.ok()?;
191    let data = if is_ollama {
192        json["embeddings"].as_array()?
193    } else {
194        json["data"].as_array()?
195    };
196
197    let vecs: Vec<Vec<f32>> = data
198        .iter()
199        .filter_map(|item| {
200            let arr = if is_ollama {
201                item.as_array()
202            } else {
203                item["embedding"].as_array()
204            }?;
205            Some(
206                arr.iter()
207                    .filter_map(|v| v.as_f64().map(|f| f as f32))
208                    .collect(),
209            )
210        })
211        .collect();
212
213    if vecs.len() != texts.len() {
214        None
215    } else {
216        Some(vecs)
217    }
218}
219
220async fn embed_single(input: &str, api_url: &str) -> Option<Vec<f32>> {
221    let embed_model = load_embed_model(api_url)?;
222    let trimmed = api_url.trim_end_matches('/');
223    let is_ollama = trimmed.contains("11434");
224    let body = serde_json::json!({
225        "model": embed_model,
226        "input": input
227    });
228
229    let url = if is_ollama {
230        format!("{}/api/embed", trimmed)
231    } else {
232        format!("{}/v1/embeddings", trimmed)
233    };
234
235    let client = reqwest::Client::builder()
236        .timeout(std::time::Duration::from_secs(5))
237        .build()
238        .ok()?;
239
240    let resp = client
241        .post(&url)
242        .header("Content-Type", "application/json")
243        .json(&body)
244        .send()
245        .await
246        .ok()?;
247
248    if !resp.status().is_success() {
249        return None;
250    }
251
252    let json: serde_json::Value = resp.json().await.ok()?;
253    let arr = if is_ollama {
254        json["embeddings"][0].as_array()?
255    } else {
256        json["data"][0]["embedding"].as_array()?
257    };
258    let vec: Vec<f32> = arr
259        .iter()
260        .filter_map(|v| v.as_f64().map(|f| f as f32))
261        .collect();
262
263    if vec.is_empty() {
264        None
265    } else {
266        Some(vec)
267    }
268}
269
270fn load_embed_model(_api_url: &str) -> Option<String> {
271    let config = crate::agent::config::load_config();
272    let saved = config.embed_model?;
273    if saved.trim().is_empty() {
274        None
275    } else {
276        Some(saved)
277    }
278}
279
280// ── Vector math ───────────────────────────────────────────────────────────────
281
/// Cosine of the angle between `a` and `b`.
///
/// Returns `0.0` when the slices are empty, differ in length, or either has
/// zero magnitude, so callers never see a division by zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    // Euclidean length of a vector.
    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (mag_a * mag_b)
}
294
/// Component-wise arithmetic mean of a set of equal-length vectors.
///
/// Returns `None` when `vecs` is empty, when the vectors are zero-dimensional,
/// or when any vector's length differs from the first one's.
fn mean_centroid(vecs: &[Vec<f32>]) -> Option<Vec<f32>> {
    let dim = vecs.first()?.len();
    if dim == 0 || vecs.iter().any(|v| v.len() != dim) {
        return None;
    }

    let count = vecs.len() as f32;
    // Walk column by column; each column is accumulated from 0.0 in input
    // order and then divided by the number of vectors.
    let centroid = (0..dim)
        .map(|i| vecs.iter().fold(0.0f32, |acc, v| acc + v[i]) / count)
        .collect();

    Some(centroid)
}
315
316fn classify_from_scores(advisory: f32, diagnostic: f32) -> IntentClass {
317    // Require meaningful separation — if they're close, stay ambiguous.
318    // Tuned conservatively: suppressing a real diagnostic query is worse than
319    // failing to suppress a conversational one (keyword guard handles most of those).
320    const ADVISORY_MIN: f32 = 0.72; // minimum score to declare advisory
321    const DIAGNOSTIC_MIN: f32 = 0.68; // minimum score to declare diagnostic
322    const MIN_GAP: f32 = 0.08; // required margin over the other class
323
324    if advisory >= ADVISORY_MIN && advisory > diagnostic + MIN_GAP {
325        IntentClass::Advisory
326    } else if diagnostic >= DIAGNOSTIC_MIN && diagnostic > advisory + MIN_GAP {
327        IntentClass::Diagnostic
328    } else {
329        IntentClass::Ambiguous
330    }
331}
332
333// ── Tests ─────────────────────────────────────────────────────────────────────
334
// Unit tests for the pure helpers (vector math and score thresholds).
// Nothing here touches the network or the embed model.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-5);
    }

    #[test]
    fn cosine_opposite_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![-2.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-5);
    }

    // Degenerate inputs must yield 0.0 instead of NaN or a panic.
    #[test]
    fn cosine_degenerate_inputs_are_zero() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn centroid_of_two_identical() {
        let vecs = vec![vec![1.0, 2.0], vec![1.0, 2.0]];
        let c = mean_centroid(&vecs).unwrap();
        assert!((c[0] - 1.0).abs() < 1e-5);
        assert!((c[1] - 2.0).abs() < 1e-5);
    }

    #[test]
    fn centroid_averages_components() {
        let vecs = vec![vec![1.0, 3.0], vec![3.0, 5.0]];
        assert_eq!(mean_centroid(&vecs), Some(vec![2.0, 4.0]));
    }

    // Empty, zero-dimensional, and ragged inputs have no meaningful mean.
    #[test]
    fn centroid_rejects_invalid_input() {
        assert_eq!(mean_centroid(&[]), None);
        assert_eq!(mean_centroid(&[vec![]]), None);
        assert_eq!(mean_centroid(&[vec![1.0], vec![1.0, 2.0]]), None);
    }

    #[test]
    fn classify_from_scores_advisory() {
        assert_eq!(classify_from_scores(0.80, 0.60), IntentClass::Advisory);
    }

    #[test]
    fn classify_from_scores_diagnostic() {
        assert_eq!(classify_from_scores(0.55, 0.78), IntentClass::Diagnostic);
    }

    #[test]
    fn classify_from_scores_ambiguous_close_gap() {
        assert_eq!(classify_from_scores(0.74, 0.70), IntentClass::Ambiguous);
    }

    #[test]
    fn classify_from_scores_ambiguous_low_scores() {
        assert_eq!(classify_from_scores(0.50, 0.40), IntentClass::Ambiguous);
    }

    // A wide gap is not enough when the winner is below its own floor.
    #[test]
    fn classify_from_scores_respects_class_floors() {
        assert_eq!(classify_from_scores(0.71, 0.10), IntentClass::Ambiguous);
        assert_eq!(classify_from_scores(0.10, 0.67), IntentClass::Ambiguous);
        assert_eq!(classify_from_scores(0.10, 0.69), IntentClass::Diagnostic);
    }

    // NaN scores fail every comparison and must fall through to Ambiguous.
    #[test]
    fn classify_from_scores_nan_is_ambiguous() {
        assert_eq!(classify_from_scores(f32::NAN, 0.90), IntentClass::Ambiguous);
        assert_eq!(classify_from_scores(0.90, f32::NAN), IntentClass::Ambiguous);
    }
}