Skip to main content

hematite/agent/
intent_embed.rs

// Embedding-based intent classifier — semantic pre-filter for routing decisions.
//
// Uses nomic-embed-text-v2 (already loaded in LM Studio alongside the main model)
// to verify whether a user query is genuinely diagnostic or conversational.
//
// When the keyword router would inject HOST INSPECTION MODE, this classifier runs
// as a second-opinion pass. If it returns Advisory with high confidence, the
// injection is suppressed and the model answers from context instead of fetching
// fresh machine data.
//
// Centroids are bootstrapped lazily on first use by batch-embedding a small set
// of labeled example phrases (~100ms total). Subsequent calls embed only the
// query (~50ms). Falls back to Ambiguous silently if the embed model is
// unavailable or slow — keyword routing continues as before.
15
16use tokio::sync::OnceCell;
17
18// ── Public API ────────────────────────────────────────────────────────────────
19
/// Outcome of the embedding-based intent check.
///
/// Fieldless and `Copy`, so it also derives `Eq` and `Hash` and can be used
/// directly as a key in maps/sets (e.g. for routing statistics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IntentClass {
    /// Conversational, advisory, or declarative. Suppress HOST INSPECTION MODE —
    /// the model should answer from existing context, not fetch new data.
    Advisory,
    /// Clearly diagnostic. The keyword router's topic choice is correct.
    Diagnostic,
    /// Uncertain. Defer to keyword router as before.
    Ambiguous,
}
30
31/// Classify user query intent using embedding similarity against labeled centroids.
32///
33/// Only called when the keyword router has already returned `host_inspection_mode = true`,
34/// so this is a veto path, not the primary routing path. Returning Ambiguous is always
35/// safe — it just falls through to existing behavior.
36pub async fn classify_intent(query: &str, api_url: &str) -> IntentClass {
37    let centroids = ensure_centroids(api_url).await;
38    let (adv_centroid, diag_centroid) = match centroids {
39        Some(c) => c,
40        None => return IntentClass::Ambiguous,
41    };
42
43    let query_vec = match embed_query(query, api_url).await {
44        Some(v) => v,
45        None => return IntentClass::Ambiguous,
46    };
47
48    let advisory_score = cosine_similarity(&query_vec, adv_centroid);
49    let diagnostic_score = cosine_similarity(&query_vec, diag_centroid);
50
51    classify_from_scores(advisory_score, diagnostic_score)
52}
53
54// ── Centroid bootstrap ────────────────────────────────────────────────────────
55
56// Stored as (advisory_centroid, diagnostic_centroid). None = embed model unavailable.
57static CENTROIDS: OnceCell<Option<(Vec<f32>, Vec<f32>)>> = OnceCell::const_new();
58
59async fn ensure_centroids(api_url: &str) -> Option<&'static (Vec<f32>, Vec<f32>)> {
60    let url = api_url.to_string();
61    let opt = CENTROIDS
62        .get_or_init(|| async move { compute_centroids(&url).await })
63        .await;
64    opt.as_ref()
65}
66
67async fn compute_centroids(api_url: &str) -> Option<(Vec<f32>, Vec<f32>)> {
68    let adv_vecs = embed_batch(ADVISORY_EXAMPLES, api_url).await?;
69    let diag_vecs = embed_batch(DIAGNOSTIC_EXAMPLES, api_url).await?;
70
71    let adv_centroid = mean_centroid(&adv_vecs)?;
72    let diag_centroid = mean_centroid(&diag_vecs)?;
73
74    eprintln!(
75        "[intent_embed] centroids ready ({} advisory, {} diagnostic examples)",
76        adv_vecs.len(),
77        diag_vecs.len()
78    );
79    Some((adv_centroid, diag_centroid))
80}
81
82// ── Example phrases ───────────────────────────────────────────────────────────
83
// Advisory examples — the model should NOT call inspect_host for these.
// Covers: opinion questions, hypotheticals, declarative statements, acknowledgments.
// These phrases are embedded (with the `search_document:` prefix added by
// `embed_batch`) and averaged by `mean_centroid` into the advisory centroid.
// NOTE(review): editing this list moves the centroid; re-check the thresholds in
// `classify_from_scores` after changes.
const ADVISORY_EXAMPLES: &[&str] = &[
    "would more ram help with this",
    "should I upgrade my GPU",
    "is that worth buying",
    "could I offload VRAM to system RAM",
    "i think the cpu is fine",
    "what if I had a faster SSD",
    "makes sense so the network is slow",
    "so the ram is the issue right",
    "do you think I should restart",
    "is it worth getting more storage",
    "if i upgraded the gpu would that help",
    "i believe the service is running",
    "i see the memory is fine",
    "everything looks good here",
    "ok so the cpu is at 8 percent that seems fine",
    "i think the service is already running",
    "my vram situation seems to be improving",
    "makes sense that the disk would be slow",
    "yeah that all adds up",
    "so the network was just congested",
    "that explains why the gpu was hot",
    "ah ok so it was the ram all along",
    "i guess the service crashed overnight",
    "would adding another monitor hurt gpu performance",
    "so basically the ssd is the bottleneck right",
];
113
// Diagnostic examples — the model SHOULD call inspect_host for these.
// Covers: data requests, status checks, show/list/check commands.
// These phrases are embedded (with the `search_document:` prefix added by
// `embed_batch`) and averaged by `mean_centroid` into the diagnostic centroid.
// NOTE(review): editing this list moves the centroid; re-check the thresholds in
// `classify_from_scores` after changes.
const DIAGNOSTIC_EXAMPLES: &[&str] = &[
    "how much RAM do I have",
    "show me running processes",
    "what is my CPU usage right now",
    "check my disk health",
    "why is my PC slow",
    "what services are running",
    "list my network adapters",
    "what GPU do I have",
    "is my firewall on",
    "show me recent errors",
    "what is my IP address",
    "check my wifi signal strength",
    "how much free disk space do I have",
    "what is taking up all my memory",
    "show hardware specs",
    "what processes are using the most RAM",
    "is my bluetooth working",
    "check my disk for errors",
    "what network connections are active",
    "show me the system logs",
    "what is the cpu temperature",
    "are there any pending windows updates",
    "is the docker daemon running",
    "check my battery status",
    "what is my gpu driver version",
];
143
144// ── Embedding helpers ─────────────────────────────────────────────────────────
145
146async fn embed_query(text: &str, api_url: &str) -> Option<Vec<f32>> {
147    // nomic-embed-text-v2 uses task instruction prefixes
148    let input = format!("search_query: {text}");
149    embed_single(&input, api_url).await
150}
151
152async fn embed_batch(texts: &[&str], api_url: &str) -> Option<Vec<Vec<f32>>> {
153    // Batch embed with document prefix — one API call for all examples
154    let inputs: Vec<String> = texts
155        .iter()
156        .map(|t| format!("search_document: {t}"))
157        .collect();
158
159    let body = serde_json::json!({
160        "model": "nomic-embed-text-v2",
161        "input": inputs
162    });
163
164    let url = format!("{}/v1/embeddings", api_url.trim_end_matches('/'));
165
166    let client = reqwest::Client::builder()
167        .timeout(std::time::Duration::from_secs(20))
168        .build()
169        .ok()?;
170
171    let resp = client
172        .post(&url)
173        .header("Content-Type", "application/json")
174        .json(&body)
175        .send()
176        .await
177        .ok()?;
178
179    if !resp.status().is_success() {
180        return None;
181    }
182
183    let json: serde_json::Value = resp.json().await.ok()?;
184    let data = json["data"].as_array()?;
185
186    let vecs: Vec<Vec<f32>> = data
187        .iter()
188        .filter_map(|item| {
189            item["embedding"].as_array().map(|arr| {
190                arr.iter()
191                    .filter_map(|v| v.as_f64().map(|f| f as f32))
192                    .collect()
193            })
194        })
195        .collect();
196
197    if vecs.len() != texts.len() {
198        None
199    } else {
200        Some(vecs)
201    }
202}
203
204async fn embed_single(input: &str, api_url: &str) -> Option<Vec<f32>> {
205    let body = serde_json::json!({
206        "model": "nomic-embed-text-v2",
207        "input": input
208    });
209
210    let url = format!("{}/v1/embeddings", api_url.trim_end_matches('/'));
211
212    let client = reqwest::Client::builder()
213        .timeout(std::time::Duration::from_secs(5))
214        .build()
215        .ok()?;
216
217    let resp = client
218        .post(&url)
219        .header("Content-Type", "application/json")
220        .json(&body)
221        .send()
222        .await
223        .ok()?;
224
225    if !resp.status().is_success() {
226        return None;
227    }
228
229    let json: serde_json::Value = resp.json().await.ok()?;
230    let arr = json["data"][0]["embedding"].as_array()?;
231    let vec: Vec<f32> = arr
232        .iter()
233        .filter_map(|v| v.as_f64().map(|f| f as f32))
234        .collect();
235
236    if vec.is_empty() {
237        None
238    } else {
239        Some(vec)
240    }
241}
242
243// ── Vector math ───────────────────────────────────────────────────────────────
244
/// Cosine similarity of two equal-length vectors.
///
/// Returns `0.0` instead of dividing when the vectors differ in length, are empty,
/// or either has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass: dot product and both squared magnitudes, accumulated in order.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, sa, sb), (&x, &y)| {
            (d + x * y, sa + x * x, sb + y * y)
        });

    let (mag_a, mag_b) = (sq_a.sqrt(), sq_b.sqrt());
    if mag_a == 0.0 || mag_b == 0.0 {
        0.0
    } else {
        dot / (mag_a * mag_b)
    }
}
257
/// Component-wise mean of a set of equal-length vectors.
///
/// Returns `None` for an empty set, for zero-length vectors, or when any vector's
/// length differs from the first one's.
fn mean_centroid(vecs: &[Vec<f32>]) -> Option<Vec<f32>> {
    let width = vecs.first()?.len();
    if width == 0 || vecs.iter().any(|v| v.len() != width) {
        return None;
    }

    let totals = vecs.iter().fold(vec![0.0f32; width], |mut acc, v| {
        acc.iter_mut().zip(v).for_each(|(slot, x)| *slot += x);
        acc
    });

    let count = vecs.len() as f32;
    Some(totals.into_iter().map(|t| t / count).collect())
}
278
279fn classify_from_scores(advisory: f32, diagnostic: f32) -> IntentClass {
280    // Require meaningful separation — if they're close, stay ambiguous.
281    // Tuned conservatively: suppressing a real diagnostic query is worse than
282    // failing to suppress a conversational one (keyword guard handles most of those).
283    const ADVISORY_MIN: f32 = 0.72; // minimum score to declare advisory
284    const DIAGNOSTIC_MIN: f32 = 0.68; // minimum score to declare diagnostic
285    const MIN_GAP: f32 = 0.08; // required margin over the other class
286
287    if advisory >= ADVISORY_MIN && advisory > diagnostic + MIN_GAP {
288        IntentClass::Advisory
289    } else if diagnostic >= DIAGNOSTIC_MIN && diagnostic > advisory + MIN_GAP {
290        IntentClass::Diagnostic
291    } else {
292        IntentClass::Ambiguous
293    }
294}
295
296// ── Tests ─────────────────────────────────────────────────────────────────────
297
#[cfg(test)]
mod tests {
    use super::*;

    // ── cosine_similarity ────────────────────────────────────────────────────

    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-5);
    }

    #[test]
    fn cosine_opposite_vectors() {
        assert!((cosine_similarity(&[1.0, 0.0], &[-2.0, 0.0]) + 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_mismatched_lengths_is_zero() {
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
    }

    #[test]
    fn cosine_empty_vectors_is_zero() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn cosine_zero_vector_is_zero() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }

    // ── mean_centroid ────────────────────────────────────────────────────────

    #[test]
    fn centroid_of_two_identical() {
        let vecs = vec![vec![1.0, 2.0], vec![1.0, 2.0]];
        let c = mean_centroid(&vecs).unwrap();
        assert!((c[0] - 1.0).abs() < 1e-5);
        assert!((c[1] - 2.0).abs() < 1e-5);
    }

    #[test]
    fn centroid_averages_distinct_vectors() {
        let vecs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let c = mean_centroid(&vecs).unwrap();
        assert!((c[0] - 2.0).abs() < 1e-5);
        assert!((c[1] - 3.0).abs() < 1e-5);
    }

    #[test]
    fn centroid_rejects_empty_input() {
        let empty: Vec<Vec<f32>> = Vec::new();
        assert!(mean_centroid(&empty).is_none());
    }

    #[test]
    fn centroid_rejects_zero_dimension() {
        let vecs: Vec<Vec<f32>> = vec![Vec::new(), Vec::new()];
        assert!(mean_centroid(&vecs).is_none());
    }

    #[test]
    fn centroid_rejects_mismatched_dimensions() {
        let vecs = vec![vec![1.0, 2.0], vec![1.0]];
        assert!(mean_centroid(&vecs).is_none());
    }

    // ── classify_from_scores ────────────────────────────────────────────────

    #[test]
    fn classify_from_scores_advisory() {
        assert_eq!(classify_from_scores(0.80, 0.60), IntentClass::Advisory);
    }

    #[test]
    fn classify_from_scores_diagnostic() {
        assert_eq!(classify_from_scores(0.55, 0.78), IntentClass::Diagnostic);
    }

    #[test]
    fn classify_from_scores_ambiguous_close_gap() {
        assert_eq!(classify_from_scores(0.74, 0.70), IntentClass::Ambiguous);
    }

    #[test]
    fn classify_from_scores_ambiguous_low_scores() {
        assert_eq!(classify_from_scores(0.50, 0.40), IntentClass::Ambiguous);
    }

    #[test]
    fn classify_from_scores_high_but_close_is_ambiguous() {
        assert_eq!(classify_from_scores(0.90, 0.85), IntentClass::Ambiguous);
    }

    #[test]
    fn classify_from_scores_nan_is_ambiguous() {
        // NaN comparisons are always false, so neither class can win.
        assert_eq!(classify_from_scores(f32::NAN, 0.90), IntentClass::Ambiguous);
        assert_eq!(classify_from_scores(0.90, f32::NAN), IntentClass::Ambiguous);
    }

    // ── example sets ────────────────────────────────────────────────────────

    #[test]
    fn example_sets_are_non_empty() {
        // An empty set would make mean_centroid fail and disable the classifier.
        assert!(!ADVISORY_EXAMPLES.is_empty());
        assert!(!DIAGNOSTIC_EXAMPLES.is_empty());
    }
}