Skip to main content

innate_core/
llm.rs

1use serde_json::{json, Value};
2use std::time::Duration;
3
4use crate::embedding::EmbeddingProvider;
5use crate::errors::{InnateError, Result};
6use crate::refine::{DistillProvenance, DistilledChunk, Distiller, Reranker};
7use crate::settings::{EmbeddingConfig, LlmConfig};
8
9// ---------------------------------------------------------------------------
10// Prompt for distillation
11// ---------------------------------------------------------------------------
12
/// Version tag for the distillation prompt produced by `build_distill_prompt`.
/// Reported as `prompt_version` in the provenance returned by `HttpDistiller`.
// NOTE(review): presumably this should be bumped whenever the prompt text changes — confirm.
const DISTILL_PROMPT_VERSION: &str = "4";
14
15fn safe_prompt_field(value: Option<&str>) -> String {
16    let value = value.unwrap_or("");
17    let (cleaned, action) = crate::utils::sanitize(value);
18    match action {
19        crate::utils::SanitizeAction::Discard => "[removed unsafe content]".to_string(),
20        _ => cleaned,
21    }
22}
23
24fn build_distill_prompt(log: &Value) -> String {
25    let query = safe_prompt_field(log.get("query").and_then(Value::as_str));
26    let output = safe_prompt_field(log.get("output").and_then(Value::as_str));
27    let output_summary = safe_prompt_field(log.get("output_summary").and_then(Value::as_str));
28    let nomination = safe_prompt_field(log.get("nomination").and_then(Value::as_str));
29    let outcome = safe_prompt_field(log.get("outcome").and_then(Value::as_str));
30
31    let mut context_parts = vec![];
32    if !query.is_empty() {
33        context_parts.push(format!("Query: {query}"));
34    }
35    if !nomination.is_empty() {
36        context_parts.push(format!("Nominated insight: {nomination}"));
37    }
38    if !output_summary.is_empty() {
39        context_parts.push(format!("Summary: {output_summary}"));
40    }
41    if !output.is_empty() {
42        let truncated: String = output.chars().take(1500).collect();
43        context_parts.push(format!("Output (truncated): {truncated}"));
44    }
45    if !outcome.is_empty() {
46        context_parts.push(format!("Outcome: {outcome}"));
47    }
48
49    let context = context_parts.join("\n");
50
51    format!(
52        r#"You are a knowledge distillation assistant. Given an agent interaction log, \
53extract zero or more independent reusable procedural principles. Favor GENERAL, \
54transferable skills, methods, and techniques over project-specific facts.
55
56Agent interaction:
57{context}
58
59Output a JSON array. Each item has:
60{{
61  "skill_name": "<1-3 word skill/topic label for this principle>",
62  "content": "<principle; when it applies; what to avoid>",
63  "trigger_desc": "<2-6 word canonical phrase>",
64  "anti_trigger_desc": "<when NOT to apply this, or null>"
65}}
66Return [] if nothing is worth keeping.
67
68Rules:
69- skill_name is a short human label (1-3 words) naming the skill/topic, e.g.
70  "error handling", "git rebase", "async retries"; not a sentence
71- content must be self-contained and actionable for a future agent reading cold
72- Prefer transferable methods and techniques; a principle that helps across many
73  projects is worth far more than one tied to this codebase
74- Abstract away project-specific detail: strip repo/file/function/path/variable names
75  and one-off identifiers, and rephrase the lesson as a general principle whoever the
76  next project is. Keep concrete project-specific detail ONLY when the lesson genuinely
77  cannot be generalized without losing its meaning
78- trigger_desc must match the vocabulary a future agent would use in a search query;
79  prefer general, technology- or domain-level phrasing over project-name phrasing
80- Never store conversation text verbatim; always distil to reusable principle form
81- If outcome is "fail", focus on what to avoid
82- Keep principles independent; do not combine unrelated lessons"#
83    )
84}
85
86fn build_distill_prompt_with_related(log: &Value, logs: &[Value]) -> String {
87    let mut prompt = build_distill_prompt(log);
88    let log_id = log.get("id").and_then(Value::as_str).unwrap_or("");
89    let context_key = log.get("context_key").and_then(Value::as_str);
90    let related: Vec<String> = logs
91        .iter()
92        .filter(|other| other.get("id").and_then(Value::as_str).unwrap_or("") != log_id)
93        .filter(|other| {
94            context_key.is_some() && other.get("context_key").and_then(Value::as_str) == context_key
95        })
96        .take(4)
97        .map(|other| {
98            let query = safe_prompt_field(other.get("query").and_then(Value::as_str));
99            let summary = safe_prompt_field(other.get("output_summary").and_then(Value::as_str));
100            let outcome = safe_prompt_field(other.get("outcome").and_then(Value::as_str));
101            format!("- Query: {query}; outcome: {outcome}; summary: {summary}")
102        })
103        .collect();
104    if !related.is_empty() {
105        prompt.push_str(
106            "\n\nRelated recent interactions (use only to identify repeated patterns or conflicts):\n",
107        );
108        prompt.push_str(&related.join("\n"));
109    }
110    prompt
111}
112
113// ---------------------------------------------------------------------------
114// Shared HTTP transport with retry/backoff
115// ---------------------------------------------------------------------------
116
/// Max total attempts (initial try + retries) for a single LLM/embedding call.
/// `post_json_retry` stops retrying once its attempt counter reaches this value.
const HTTP_MAX_ATTEMPTS: u32 = 3;
/// Per-request timeout, applied to every attempt via ureq's `timeout_global`.
/// Each retry builds a new request and so gets a fresh timeout window.
const HTTP_TIMEOUT: Duration = Duration::from_secs(30);
121
122/// POST `body` as JSON to `url` with the given extra headers, retrying transient
123/// failures (network/timeout errors, HTTP 429, and 5xx) with exponential backoff.
124/// `Content-Type: application/json` is set automatically. `label` names the call
125/// site in error messages (e.g. "LLM", "Anthropic", "Embedding").
126fn post_json_retry(
127    url: &str,
128    headers: &[(&str, &str)],
129    body: &Value,
130    label: &str,
131) -> Result<Value> {
132    // Single instrumentation point for all LLM/embedding calls: time the whole
133    // call (across retries) and emit one trace with the final outcome. The
134    // `Authorization` header is never handed to the tracer — only the body.
135    let start = std::time::Instant::now();
136    let mut attempt = 0;
137    let outcome: Result<Value> = loop {
138        attempt += 1;
139        // ureq 3 no longer carries the response inside the status error, so we opt
140        // out of `http_status_as_error`: non-2xx comes back as `Ok(response)` and we
141        // read its code + headers + body ourselves. A genuine `Err` is therefore a
142        // transport-level failure (timeout / connection / I/O) — always retryable.
143        let mut req = ureq::post(url)
144            .config()
145            .timeout_global(Some(HTTP_TIMEOUT))
146            .http_status_as_error(false)
147            .build()
148            .header("Content-Type", "application/json");
149        for (k, v) in headers {
150            req = req.header(*k, *v);
151        }
152        match req.send_json(body) {
153            Ok(mut response) => {
154                let code = response.status().as_u16();
155                if (200..300).contains(&code) {
156                    break response.body_mut().read_json::<Value>().map_err(|e| {
157                        InnateError::Other(format!("{label} response parse error: {e}"))
158                    });
159                }
160                let retry_after = response
161                    .headers()
162                    .get("retry-after")
163                    .and_then(|h| h.to_str().ok())
164                    .and_then(|s| s.trim().parse::<u64>().ok());
165                if status_is_retryable(code) && attempt < HTTP_MAX_ATTEMPTS {
166                    std::thread::sleep(backoff_delay(attempt, retry_after));
167                    continue;
168                }
169                // `status: {code}` is the substring llm_trace classifies into
170                // http_4xx / http_5xx / rate_limited (429); keep it ahead of the body.
171                let detail = response.body_mut().read_to_string().unwrap_or_default();
172                break Err(InnateError::Other(format!(
173                    "{label} HTTP error: status: {code} {detail}"
174                )));
175            }
176            Err(err) => {
177                if attempt < HTTP_MAX_ATTEMPTS {
178                    std::thread::sleep(backoff_delay(attempt, None));
179                    continue;
180                }
181                // `transport:` tag preserves the transport bucket in llm_trace.
182                break Err(InnateError::Other(format!(
183                    "{label} HTTP error: transport: {err}"
184                )));
185            }
186        }
187    };
188    crate::llm_trace::record(label, url, body, &outcome, attempt, start.elapsed());
189    outcome
190}
191
/// Reports whether an HTTP status is considered transient and worth retrying:
/// 429 (rate limited) or anything in the 5xx range.
fn status_is_retryable(code: u16) -> bool {
    matches!(code, 429 | 500..=599)
}
196
/// Computes how long to wait before the next attempt.
///
/// * `attempt` — 1-based number of the attempt that just failed. `0` is treated
///   like `1` rather than underflowing.
/// * `retry_after_secs` — the server's `Retry-After` value in seconds, if any.
///
/// When `retry_after_secs` is present it wins, capped at 30 seconds. Otherwise
/// the delay doubles per attempt starting at 250ms (250ms, 500ms, 1s, ...) and
/// stops growing at 16s.
fn backoff_delay(attempt: u32, retry_after_secs: Option<u64>) -> Duration {
    if let Some(secs) = retry_after_secs {
        return Duration::from_secs(secs.min(30));
    }
    // `saturating_sub` keeps `attempt == 0` from underflowing; the cap of 6
    // bounds the delay at 250ms << 6 = 16s.
    let shift = attempt.saturating_sub(1).min(6);
    Duration::from_millis(250u64 << shift)
}
206
207// ---------------------------------------------------------------------------
208// HTTP distiller — one type for both OpenAI-compatible endpoints (GPT, DeepSeek,
209// local Ollama, ...) and the Anthropic Messages API. The request/response shape
210// is selected per call from `config.provider`; everything else (distill loop,
211// provenance, retry transport) is shared.
212// ---------------------------------------------------------------------------
213
/// LLM client used for distillation and other one-shot completions.
///
/// The request/response format is chosen per call from `config.provider`:
/// `"anthropic"` uses the Anthropic Messages shape, anything else the
/// OpenAI-compatible chat-completions shape.
pub struct HttpDistiller {
    /// Supplies the provider name, model id, API key and base URL for each call.
    config: LlmConfig,
}
217
218impl HttpDistiller {
219    pub fn new(config: LlmConfig) -> Self {
220        Self { config }
221    }
222
223    /// One-shot completion. Public so other LLM-backed features (e.g. the opt-in
224    /// recall reranker) can reuse the same retrying transport and provider switch.
225    pub fn call(&self, prompt: &str) -> Result<String> {
226        if self.config.provider == "anthropic" {
227            self.call_anthropic(prompt)
228        } else {
229            self.call_openai(prompt)
230        }
231    }
232
233    fn call_openai(&self, prompt: &str) -> Result<String> {
234        let api_key = self
235            .config
236            .resolved_api_key()
237            .ok_or_else(|| InnateError::Other("LLM API key not configured".into()))?;
238
239        let base = self.config.resolved_base_url();
240        let url = format!("{base}/chat/completions");
241
242        let body = json!({
243            "model": self.config.model_id,
244            "messages": [{"role": "user", "content": prompt}],
245            "max_tokens": 800,
246            "temperature": 0.2,
247        });
248
249        let auth = format!("Bearer {api_key}");
250        let resp_json = post_json_retry(&url, &[("Authorization", &auth)], &body, "LLM")?;
251
252        resp_json
253            .pointer("/choices/0/message/content")
254            .and_then(Value::as_str)
255            .map(str::to_string)
256            .ok_or_else(|| InnateError::Other("unexpected LLM response shape".into()))
257    }
258
259    fn call_anthropic(&self, prompt: &str) -> Result<String> {
260        let api_key = self
261            .config
262            .resolved_api_key()
263            .ok_or_else(|| InnateError::Other("Anthropic API key not configured".into()))?;
264
265        let base = self.config.resolved_base_url();
266        let url = format!("{base}/v1/messages");
267
268        let body = json!({
269            "model": self.config.model_id,
270            "max_tokens": 800,
271            "messages": [{"role": "user", "content": prompt}],
272        });
273
274        let resp_json = post_json_retry(
275            &url,
276            &[("x-api-key", &api_key), ("anthropic-version", "2023-06-01")],
277            &body,
278            "Anthropic",
279        )?;
280
281        resp_json
282            .pointer("/content/0/text")
283            .and_then(Value::as_str)
284            .map(str::to_string)
285            .ok_or_else(|| InnateError::Other("unexpected Anthropic response shape".into()))
286    }
287}
288
289impl Distiller for HttpDistiller {
290    fn distill(&self, log_entries: &[Value]) -> crate::errors::Result<Vec<DistilledChunk>> {
291        distill_with(log_entries, |prompt| self.call(prompt))
292    }
293
294    fn distill_with_context(
295        &self,
296        primary: &Value,
297        related_logs: &[Value],
298    ) -> crate::errors::Result<Vec<DistilledChunk>> {
299        distill_entry_with(primary, related_logs, |prompt| self.call(prompt))
300    }
301
302    fn provenance(&self) -> DistillProvenance {
303        DistillProvenance {
304            provider: Some(self.config.provider.clone()),
305            model: Some(self.config.model_id.clone()),
306            prompt_version: Some(DISTILL_PROMPT_VERSION.to_string()),
307        }
308    }
309}
310
311// ---------------------------------------------------------------------------
312// Shared parse logic
313// ---------------------------------------------------------------------------
314
315fn distill_with(
316    log_entries: &[Value],
317    call: impl Fn(&str) -> Result<String> + Copy,
318) -> Result<Vec<DistilledChunk>> {
319    let mut out = Vec::new();
320    for entry in log_entries {
321        out.extend(distill_entry_with(entry, log_entries, call)?);
322    }
323    Ok(out)
324}
325
326fn distill_entry_with(
327    entry: &Value,
328    related_logs: &[Value],
329    call: impl Fn(&str) -> Result<String>,
330) -> Result<Vec<DistilledChunk>> {
331    let log_id = entry["id"].as_str().unwrap_or("").to_string();
332    let prompt = build_distill_prompt_with_related(entry, related_logs);
333    let mut raw = call(&prompt)?;
334    let mut parsed = parse_distill_response(&raw);
335    if parsed.is_err() {
336        raw = call(&format!(
337            "{prompt}\n\nYour previous response was invalid. Return only a valid JSON array."
338        ))?;
339        parsed = parse_distill_response(&raw);
340    }
341    let items = parsed.map_err(|error| {
342        InnateError::Other(format!("LLM distillation response invalid: {error}"))
343    })?;
344    let mut out = Vec::new();
345    for parsed in items {
346        let content = parsed
347            .get("content")
348            .and_then(Value::as_str)
349            .map(str::trim)
350            .filter(|s| !s.is_empty());
351        let Some(content) = content else { continue };
352        let skill_name = parsed
353            .get("skill_name")
354            .and_then(Value::as_str)
355            .map(|s| {
356                s.trim()
357                    .split_whitespace()
358                    .take(3)
359                    .collect::<Vec<_>>()
360                    .join(" ")
361            })
362            .filter(|s| !s.is_empty() && s.to_lowercase() != "null");
363        let trigger_desc = parsed
364            .get("trigger_desc")
365            .and_then(Value::as_str)
366            .map(str::to_string)
367            .filter(|s| !s.is_empty());
368        let anti_trigger_desc = parsed
369            .get("anti_trigger_desc")
370            .and_then(Value::as_str)
371            .map(str::to_string)
372            .filter(|s| !s.is_empty() && s.to_lowercase() != "null");
373        out.push(DistilledChunk {
374            content: content.to_string(),
375            skill_name,
376            trigger_desc,
377            anti_trigger_desc,
378            source_log_id: log_id.clone(),
379            nomination: entry
380                .get("nomination")
381                .and_then(Value::as_str)
382                .map(str::to_string),
383            provider_override: None,
384        });
385    }
386    Ok(out)
387}
388
389fn parse_distill_response(raw: &str) -> std::result::Result<Vec<Value>, String> {
390    let json_str = extract_json(raw);
391    let parsed: Value = serde_json::from_str(json_str.trim()).map_err(|e| e.to_string())?;
392    if parsed.get("skip").and_then(Value::as_bool) == Some(true) {
393        return Ok(vec![]);
394    }
395    match parsed {
396        Value::Array(items) => Ok(items),
397        Value::Object(_) => Ok(vec![parsed]),
398        _ => Err("expected a JSON object or array".to_string()),
399    }
400}
401
/// Locates the JSON payload inside a possibly chatty or fenced model reply.
///
/// Resolution order:
/// 1. If the trimmed text starts with a markdown fence (```` ```json ```` or
///    ```` ``` ````) and has a closing fence, the trimmed fence body is returned.
/// 2. Otherwise the span from the first `[` to the last `]` and the span from
///    the first `{` to the last `}` are considered; whichever *opens first* is
///    returned. This keeps a top-level object intact even when it contains
///    brackets (for example inside a string), and keeps a top-level array
///    intact when it contains objects.
/// 3. If neither span exists, the trimmed text is returned unchanged.
///
/// A span whose closer does not come after its opener is ignored, so malformed
/// replies such as `"] oops ["` fall through instead of panicking on an
/// inverted slice range. The result is not validated as JSON; that is left to
/// the caller.
fn extract_json(text: &str) -> &str {
    let stripped = text.trim();

    // Strip markdown code fences if present: ```json ... ``` or ``` ... ```
    if let Some(inner) = stripped
        .strip_prefix("```json")
        .or_else(|| stripped.strip_prefix("```"))
    {
        if let Some(end) = inner.rfind("```") {
            return inner[..end].trim();
        }
    }

    // Byte range from the first `open` to the last `close`, if well ordered.
    let span = |open: char, close: char| -> Option<(usize, usize)> {
        let start = stripped.find(open)?;
        let end = stripped.rfind(close)?;
        (end > start).then_some((start, end))
    };

    let array = span('[', ']');
    // Object form is still accepted for backward compatibility.
    let object = span('{', '}');

    let chosen = match (array, object) {
        (Some(arr), Some(obj)) => Some(if obj.0 < arr.0 { obj } else { arr }),
        (arr, obj) => arr.or(obj),
    };

    match chosen {
        // Both delimiters are single-byte ASCII, so the inclusive end is a
        // valid character boundary.
        Some((start, end)) => &stripped[start..=end],
        None => stripped,
    }
}
422
423// ---------------------------------------------------------------------------
424// Build distiller from settings
425// ---------------------------------------------------------------------------
426
427pub fn build_distiller(config: &LlmConfig) -> std::sync::Arc<dyn Distiller + Send + Sync> {
428    std::sync::Arc::new(HttpDistiller::new(config.clone()))
429}
430
431// ---------------------------------------------------------------------------
432// Opt-in LLM reranker (part d) — reasoning relevance over a small shortlist.
433// Reuses HttpDistiller's retrying transport. Errors propagate so recall can fall
434// back to the fused order (a flaky LLM must never break retrieval).
435// ---------------------------------------------------------------------------
436
/// Reranker that asks an LLM to order a shortlist of candidates by relevance.
pub struct LlmReranker {
    /// Client used for the completion call; provides the retrying transport and
    /// provider selection.
    inner: HttpDistiller,
}
440
441impl LlmReranker {
442    pub fn new(config: LlmConfig) -> Self {
443        Self {
444            inner: HttpDistiller::new(config),
445        }
446    }
447}
448
449impl Reranker for LlmReranker {
450    fn rerank(&self, query: &str, candidates: &[Value]) -> Result<Vec<String>> {
451        if candidates.is_empty() {
452            return Ok(Vec::new());
453        }
454        let mut list = String::new();
455        for c in candidates {
456            let id = c.get("id").and_then(Value::as_str).unwrap_or("");
457            let trig = c.get("trigger_desc").and_then(Value::as_str).unwrap_or("");
458            let content: String = c
459                .get("content")
460                .and_then(Value::as_str)
461                .unwrap_or("")
462                .chars()
463                .take(280)
464                .collect();
465            list.push_str(&format!("- id={id} | when={trig} | {content}\n"));
466        }
467        let prompt = format!(
468            "You are reranking knowledge snippets by how directly each one helps with the QUERY. \
469             Consider the snippet's `when` (trigger) and content. Return ONLY a JSON array of the \
470             ids, most relevant first, no prose, no ids that are not listed.\n\n\
471             QUERY: {query}\n\nCANDIDATES:\n{list}"
472        );
473        let resp = self.inner.call(&prompt)?;
474        parse_id_array(&resp)
475            .ok_or_else(|| InnateError::Other("reranker: no id array in LLM response".into()))
476    }
477}
478
479/// Extract a JSON array of string ids from a (possibly chatty) LLM response.
480fn parse_id_array(resp: &str) -> Option<Vec<String>> {
481    let start = resp.find('[')?;
482    let end = resp.rfind(']')?;
483    if end <= start {
484        return None;
485    }
486    let arr: Value = serde_json::from_str(&resp[start..=end]).ok()?;
487    let ids: Vec<String> = arr
488        .as_array()?
489        .iter()
490        .filter_map(|v| v.as_str().map(str::to_string))
491        .collect();
492    if ids.is_empty() {
493        None
494    } else {
495        Some(ids)
496    }
497}
498
499// ---------------------------------------------------------------------------
500// LLM embedding provider (OpenAI-compatible /v1/embeddings)
501// ---------------------------------------------------------------------------
502
/// Embedding provider that requests vectors from an OpenAI-compatible
/// embeddings endpoint over HTTP.
pub struct LlmEmbeddingProvider {
    /// Supplies the model id, API key, base URL and expected vector dimension.
    config: EmbeddingConfig,
}
506
// Unit tests for the pure helpers in this file (parsing, retry policy, prompt
// construction). No test here performs network I/O: LLM calls are replaced by
// closures.
#[cfg(test)]
// The module sits mid-file, ahead of the `LlmEmbeddingProvider` impls, hence
// the lint allowance.
#[allow(clippy::items_after_test_module)]
mod tests {
    use std::cell::Cell;

    use serde_json::json;

    use std::time::Duration;

    use super::{
        backoff_delay, build_distill_prompt, distill_entry_with, distill_with,
        parse_distill_response, parse_embedding_response, status_is_retryable,
    };

    /// Embedding parsing must reject anything that is not exactly a numeric
    /// vector of the expected dimension.
    #[test]
    fn embedding_response_is_parsed_fail_closed() {
        // Happy path: correct dimension parses.
        let resp = json!({"data": [{"embedding": [0.1, 0.2, 0.3]}]});
        assert_eq!(
            parse_embedding_response(&resp, 3).unwrap(),
            vec![0.1f32, 0.2, 0.3]
        );

        // Wrong dimension is rejected, not silently accepted.
        assert!(parse_embedding_response(&resp, 4).is_err());

        // A non-numeric element fails the whole parse (no silent drop).
        let bad = json!({"data": [{"embedding": [0.1, "oops", 0.3]}]});
        assert!(parse_embedding_response(&bad, 3).is_err());

        // Missing embedding field is rejected.
        let shape = json!({"data": []});
        assert!(parse_embedding_response(&shape, 3).is_err());
    }

    /// Only 429 and 5xx count as transient; other statuses are final.
    #[test]
    fn only_rate_limit_and_5xx_are_retryable() {
        assert!(status_is_retryable(429));
        assert!(status_is_retryable(500));
        assert!(status_is_retryable(503));
        assert!(status_is_retryable(599));
        assert!(!status_is_retryable(400));
        assert!(!status_is_retryable(401));
        assert!(!status_is_retryable(404));
        assert!(!status_is_retryable(200));
    }

    /// Pins the backoff schedule and the `Retry-After` override with its cap.
    #[test]
    fn backoff_is_exponential_and_honors_retry_after() {
        // Exponential schedule: 250ms, 500ms, 1s for attempts 1..3.
        assert_eq!(backoff_delay(1, None), Duration::from_millis(250));
        assert_eq!(backoff_delay(2, None), Duration::from_millis(500));
        assert_eq!(backoff_delay(3, None), Duration::from_millis(1000));
        // Retry-After overrides the schedule and is capped at 30s.
        assert_eq!(backoff_delay(1, Some(5)), Duration::from_secs(5));
        assert_eq!(backoff_delay(1, Some(120)), Duration::from_secs(30));
    }

    /// Secrets in log fields must not appear in the prompt; the sanitizer's
    /// `[REDACTED]` marker must appear instead.
    #[test]
    fn prompt_redacts_secrets_before_external_llm_call() {
        let prompt = build_distill_prompt(&json!({
            "query": "debug sk-12345678901234567890",
            "output_summary": "Authorization: Bearer secret-token-value"
        }));
        assert!(!prompt.contains("sk-12345678901234567890"));
        assert!(!prompt.contains("secret-token-value"));
        assert!(prompt.contains("[REDACTED]"));
    }

    /// An unparsable first reply triggers exactly one retry, whose valid reply
    /// is then used.
    #[test]
    fn malformed_response_is_retried_instead_of_silently_skipped() {
        let calls = Cell::new(0);
        let chunks = distill_with(&[json!({"id": "log-1", "query": "q"})], |_| {
            calls.set(calls.get() + 1);
            if calls.get() == 1 {
                Ok("not json".to_string())
            } else {
                Ok(r#"[{"content":"retry worked","trigger_desc":"retry"}]"#.to_string())
            }
        })
        .unwrap();
        assert_eq!(calls.get(), 2);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, "retry worked");
    }

    /// A JSON array reply yields one parsed item per element.
    #[test]
    fn parser_accepts_multiple_distilled_chunks() {
        let parsed = parse_distill_response(
            r#"[{"content":"one"},{"content":"two","anti_trigger_desc":"never"}]"#,
        )
        .unwrap();
        assert_eq!(parsed.len(), 2);
    }

    /// The nomination is shown to the model in the prompt, the chunk content
    /// comes from the model's reply, and the raw nomination is kept on the chunk.
    #[test]
    fn nomination_is_distilled_instead_of_bypassing_the_model() {
        let prompt_seen = Cell::new(false);
        let entry = json!({
            "id": "log-1",
            "query": "original query",
            "nomination": "raw agent nomination",
            "output_summary": "summary",
            "outcome": "ok"
        });
        let chunks = distill_entry_with(&entry, std::slice::from_ref(&entry), |prompt| {
            prompt_seen.set(prompt.contains("raw agent nomination"));
            Ok(
                r#"[{"content":"generalized principle","trigger_desc":"generalize","anti_trigger_desc":null}]"#
                    .to_string(),
            )
        })
        .unwrap();

        assert!(prompt_seen.get());
        assert_eq!(chunks[0].content, "generalized principle");
        assert_eq!(
            chunks[0].nomination.as_deref(),
            Some("raw agent nomination")
        );
    }
}
629
630impl LlmEmbeddingProvider {
631    pub fn new(config: EmbeddingConfig) -> Self {
632        Self { config }
633    }
634
635    fn embed(&self, text: &str) -> Result<Vec<f32>> {
636        let api_key = self
637            .config
638            .resolved_api_key()
639            .ok_or_else(|| InnateError::Other("Embedding API key not configured".into()))?;
640
641        let base = self.config.resolved_base_url();
642        let url = format!("{base}/embeddings");
643
644        let body = json!({
645            "input": text,
646            "model": self.config.model_id,
647        });
648
649        let auth = format!("Bearer {api_key}");
650        let resp_json = post_json_retry(&url, &[("Authorization", &auth)], &body, "Embedding")?;
651
652        parse_embedding_response(&resp_json, self.config.dim)
653    }
654}
655
656/// Parse an OpenAI-compatible embedding response, fail-closed.
657///
658/// Every element must be numeric (bad entries are not silently dropped) and the
659/// resulting length must equal `expected_dim`, so a malformed or wrong-dimension
660/// vector never reaches cosine similarity.
661fn parse_embedding_response(resp_json: &Value, expected_dim: usize) -> Result<Vec<f32>> {
662    let embedding = resp_json
663        .pointer("/data/0/embedding")
664        .and_then(Value::as_array)
665        .ok_or_else(|| InnateError::Other("unexpected embedding response shape".into()))?;
666    let vec: Vec<f32> = embedding
667        .iter()
668        .map(|v| {
669            v.as_f64().map(|x| x as f32).ok_or_else(|| {
670                InnateError::Other("embedding response contains a non-numeric element".into())
671            })
672        })
673        .collect::<Result<Vec<f32>>>()?;
674    if vec.len() != expected_dim {
675        return Err(InnateError::Other(format!(
676            "embedding dimension mismatch: provider returned {}, expected {expected_dim} (check embedding.dim)",
677            vec.len(),
678        )));
679    }
680    Ok(vec)
681}
682
683impl EmbeddingProvider for LlmEmbeddingProvider {
684    fn model_name(&self) -> &'static str {
685        "llm-embedding"
686    }
687
688    fn content_dim(&self) -> usize {
689        self.config.dim
690    }
691
692    fn trigger_dim(&self) -> usize {
693        self.config.dim
694    }
695
696    fn embed_content(&self, text: &str) -> Result<Vec<f32>> {
697        self.embed(text)
698    }
699
700    fn embed_trigger(&self, text: &str) -> Result<Vec<f32>> {
701        self.embed(text)
702    }
703
704    /// Content and trigger share the same model and dimension here, so a single
705    /// HTTP request serves both spaces — half the round trips per recall.
706    fn embed_both(&self, text: &str) -> Result<(Vec<f32>, Vec<f32>)> {
707        let v = self.embed(text)?;
708        Ok((v.clone(), v))
709    }
710}
711
712// ---------------------------------------------------------------------------
713// Connection test helpers (used by install)
714// ---------------------------------------------------------------------------
715
716/// Test the LLM config by sending a minimal request. Returns Ok(model_response) or Err.
717pub fn test_llm(config: &LlmConfig) -> Result<String> {
718    let distiller = build_distiller(config);
719    let dummy_log = json!({
720        "id": "test",
721        "query": "connection test",
722        "output_summary": "test",
723        "outcome": "ok"
724    });
725    // We don't care about the result — just that the call succeeds.
726    distiller.distill(&[dummy_log])?;
727    Ok(format!("OK — model: {}", config.model_id))
728}
729
730/// Test the embedding config. Returns Ok(dim) or Err.
731pub fn test_embedding(config: &EmbeddingConfig) -> Result<usize> {
732    let provider = LlmEmbeddingProvider::new(config.clone());
733    let vec = provider.embed("connection test")?;
734    Ok(vec.len())
735}