Skip to main content

innate_core/
llm.rs

1use serde_json::{json, Value};
2use std::time::Duration;
3
4use crate::embedding::EmbeddingProvider;
5use crate::errors::{InnateError, Result};
6use crate::refine::{DistillProvenance, DistilledChunk, Distiller};
7use crate::settings::{EmbeddingConfig, LlmConfig};
8
9// ---------------------------------------------------------------------------
10// Prompt for distillation
11// ---------------------------------------------------------------------------
12
/// Version tag of the distillation prompt defined in this file. It is reported as
/// `prompt_version` by `HttpDistiller::provenance`.
const DISTILL_PROMPT_VERSION: &str = "2";
14
15fn safe_prompt_field(value: Option<&str>) -> String {
16    let value = value.unwrap_or("");
17    let (cleaned, action) = crate::utils::sanitize(value);
18    match action {
19        crate::utils::SanitizeAction::Discard => "[removed unsafe content]".to_string(),
20        _ => cleaned,
21    }
22}
23
24fn build_distill_prompt(log: &Value) -> String {
25    let query = safe_prompt_field(log.get("query").and_then(Value::as_str));
26    let output = safe_prompt_field(log.get("output").and_then(Value::as_str));
27    let output_summary = safe_prompt_field(log.get("output_summary").and_then(Value::as_str));
28    let nomination = safe_prompt_field(log.get("nomination").and_then(Value::as_str));
29    let outcome = safe_prompt_field(log.get("outcome").and_then(Value::as_str));
30
31    let mut context_parts = vec![];
32    if !query.is_empty() {
33        context_parts.push(format!("Query: {query}"));
34    }
35    if !nomination.is_empty() {
36        context_parts.push(format!("Nominated insight: {nomination}"));
37    }
38    if !output_summary.is_empty() {
39        context_parts.push(format!("Summary: {output_summary}"));
40    }
41    if !output.is_empty() {
42        let truncated: String = output.chars().take(1500).collect();
43        context_parts.push(format!("Output (truncated): {truncated}"));
44    }
45    if !outcome.is_empty() {
46        context_parts.push(format!("Outcome: {outcome}"));
47    }
48
49    let context = context_parts.join("\n");
50
51    format!(
52        r#"You are a knowledge distillation assistant. Given an agent interaction log, \
53extract zero or more independent reusable procedural principles.
54
55Agent interaction:
56{context}
57
58Output a JSON array. Each item has:
59{{
60  "content": "<principle; when it applies; what to avoid>",
61  "trigger_desc": "<2-6 word canonical phrase>",
62  "anti_trigger_desc": "<when NOT to apply this, or null>"
63}}
64Return [] if nothing is worth keeping.
65
66Rules:
67- content must be self-contained and actionable for a future agent reading cold
68- trigger_desc must match the vocabulary a future agent would use in a search query
69- Never store conversation text verbatim; always distil to reusable principle form
70- If outcome is "fail", focus on what to avoid
71- Keep principles independent; do not combine unrelated lessons"#
72    )
73}
74
75fn build_distill_prompt_with_related(log: &Value, logs: &[Value]) -> String {
76    let mut prompt = build_distill_prompt(log);
77    let log_id = log.get("id").and_then(Value::as_str).unwrap_or("");
78    let context_key = log.get("context_key").and_then(Value::as_str);
79    let related: Vec<String> = logs
80        .iter()
81        .filter(|other| other.get("id").and_then(Value::as_str).unwrap_or("") != log_id)
82        .filter(|other| {
83            context_key.is_some() && other.get("context_key").and_then(Value::as_str) == context_key
84        })
85        .take(4)
86        .map(|other| {
87            let query = safe_prompt_field(other.get("query").and_then(Value::as_str));
88            let summary = safe_prompt_field(other.get("output_summary").and_then(Value::as_str));
89            let outcome = safe_prompt_field(other.get("outcome").and_then(Value::as_str));
90            format!("- Query: {query}; outcome: {outcome}; summary: {summary}")
91        })
92        .collect();
93    if !related.is_empty() {
94        prompt.push_str(
95            "\n\nRelated recent interactions (use only to identify repeated patterns or conflicts):\n",
96        );
97        prompt.push_str(&related.join("\n"));
98    }
99    prompt
100}
101
102// ---------------------------------------------------------------------------
103// Shared HTTP transport with retry/backoff
104// ---------------------------------------------------------------------------
105
106/// Max total attempts (initial try + retries) for a single LLM/embedding call.
107const HTTP_MAX_ATTEMPTS: u32 = 3;
108/// Per-request socket timeout. Each retry gets a fresh timeout window.
109const HTTP_TIMEOUT: Duration = Duration::from_secs(30);
110
111/// POST `body` as JSON to `url` with the given extra headers, retrying transient
112/// failures (network/timeout errors, HTTP 429, and 5xx) with exponential backoff.
113/// `Content-Type: application/json` is set automatically. `label` names the call
114/// site in error messages (e.g. "LLM", "Anthropic", "Embedding").
115fn post_json_retry(
116    url: &str,
117    headers: &[(&str, &str)],
118    body: &Value,
119    label: &str,
120) -> Result<Value> {
121    let mut attempt = 0;
122    loop {
123        attempt += 1;
124        let mut req = ureq::post(url)
125            .timeout(HTTP_TIMEOUT)
126            .set("Content-Type", "application/json");
127        for (k, v) in headers {
128            req = req.set(k, v);
129        }
130        match req.send_json(body) {
131            Ok(response) => {
132                return response
133                    .into_json()
134                    .map_err(|e| InnateError::Other(format!("{label} response parse error: {e}")));
135            }
136            Err(err) => {
137                let (retryable, retry_after) = match &err {
138                    ureq::Error::Status(code, resp) => (
139                        status_is_retryable(*code),
140                        resp.header("retry-after")
141                            .and_then(|s| s.trim().parse::<u64>().ok()),
142                    ),
143                    ureq::Error::Transport(_) => (true, None),
144                };
145                if retryable && attempt < HTTP_MAX_ATTEMPTS {
146                    std::thread::sleep(backoff_delay(attempt, retry_after));
147                    continue;
148                }
149                return Err(InnateError::Other(format!("{label} HTTP error: {err}")));
150            }
151        }
152    }
153}
154
/// Reports whether an HTTP status is a transient failure worth another attempt:
/// 429 (rate limited) or anything in the 500..=599 server-error range.
fn status_is_retryable(code: u16) -> bool {
    matches!(code, 429 | 500..=599)
}
159
/// Computes how long to wait before the next attempt.
///
/// * `attempt` — 1-based number of the attempt that just failed.
/// * `retry_after_secs` — the server's `Retry-After` hint in seconds, if any.
///
/// A `Retry-After` hint takes precedence and is capped at 30 seconds. Otherwise
/// the delay doubles per attempt starting at 250ms (250ms, 500ms, 1s, ...) and
/// stops growing at 16s. An `attempt` of 0 is treated like the first attempt
/// instead of underflowing.
fn backoff_delay(attempt: u32, retry_after_secs: Option<u64>) -> Duration {
    const BASE_MS: u64 = 250;
    const MAX_SHIFT: u32 = 6;
    const MAX_RETRY_AFTER_SECS: u64 = 30;

    match retry_after_secs {
        Some(secs) => Duration::from_secs(secs.min(MAX_RETRY_AFTER_SECS)),
        None => {
            // `saturating_sub` keeps `attempt == 0` from panicking in debug
            // builds (or wrapping to the maximum delay in release builds).
            let shift = attempt.saturating_sub(1).min(MAX_SHIFT);
            // 250 << 6 == 16_000, far below u64::MAX, so the shift cannot overflow.
            Duration::from_millis(BASE_MS << shift)
        }
    }
}
169
170// ---------------------------------------------------------------------------
171// HTTP distiller — one type for both OpenAI-compatible endpoints (GPT, DeepSeek,
172// local Ollama, ...) and the Anthropic Messages API. The request/response shape
173// is selected per call from `config.provider`; everything else (distill loop,
174// provenance, retry transport) is shared.
175// ---------------------------------------------------------------------------
176
177pub struct HttpDistiller {
178    config: LlmConfig,
179}
180
181impl HttpDistiller {
182    pub fn new(config: LlmConfig) -> Self {
183        Self { config }
184    }
185
186    fn call(&self, prompt: &str) -> Result<String> {
187        if self.config.provider == "anthropic" {
188            self.call_anthropic(prompt)
189        } else {
190            self.call_openai(prompt)
191        }
192    }
193
194    fn call_openai(&self, prompt: &str) -> Result<String> {
195        let api_key = self
196            .config
197            .resolved_api_key()
198            .ok_or_else(|| InnateError::Other("LLM API key not configured".into()))?;
199
200        let base = self.config.resolved_base_url();
201        let url = format!("{base}/chat/completions");
202
203        let body = json!({
204            "model": self.config.model_id,
205            "messages": [{"role": "user", "content": prompt}],
206            "max_tokens": 800,
207            "temperature": 0.2,
208        });
209
210        let auth = format!("Bearer {api_key}");
211        let resp_json = post_json_retry(&url, &[("Authorization", &auth)], &body, "LLM")?;
212
213        resp_json
214            .pointer("/choices/0/message/content")
215            .and_then(Value::as_str)
216            .map(str::to_string)
217            .ok_or_else(|| InnateError::Other("unexpected LLM response shape".into()))
218    }
219
220    fn call_anthropic(&self, prompt: &str) -> Result<String> {
221        let api_key = self
222            .config
223            .resolved_api_key()
224            .ok_or_else(|| InnateError::Other("Anthropic API key not configured".into()))?;
225
226        let base = self.config.resolved_base_url();
227        let url = format!("{base}/v1/messages");
228
229        let body = json!({
230            "model": self.config.model_id,
231            "max_tokens": 800,
232            "messages": [{"role": "user", "content": prompt}],
233        });
234
235        let resp_json = post_json_retry(
236            &url,
237            &[("x-api-key", &api_key), ("anthropic-version", "2023-06-01")],
238            &body,
239            "Anthropic",
240        )?;
241
242        resp_json
243            .pointer("/content/0/text")
244            .and_then(Value::as_str)
245            .map(str::to_string)
246            .ok_or_else(|| InnateError::Other("unexpected Anthropic response shape".into()))
247    }
248}
249
250impl Distiller for HttpDistiller {
251    fn distill(&self, log_entries: &[Value]) -> crate::errors::Result<Vec<DistilledChunk>> {
252        distill_with(log_entries, |prompt| self.call(prompt))
253    }
254
255    fn distill_with_context(
256        &self,
257        primary: &Value,
258        related_logs: &[Value],
259    ) -> crate::errors::Result<Vec<DistilledChunk>> {
260        distill_entry_with(primary, related_logs, |prompt| self.call(prompt))
261    }
262
263    fn provenance(&self) -> DistillProvenance {
264        DistillProvenance {
265            provider: Some(self.config.provider.clone()),
266            model: Some(self.config.model_id.clone()),
267            prompt_version: Some(DISTILL_PROMPT_VERSION.to_string()),
268        }
269    }
270}
271
272// ---------------------------------------------------------------------------
273// Shared parse logic
274// ---------------------------------------------------------------------------
275
276fn distill_with(
277    log_entries: &[Value],
278    call: impl Fn(&str) -> Result<String> + Copy,
279) -> Result<Vec<DistilledChunk>> {
280    let mut out = Vec::new();
281    for entry in log_entries {
282        out.extend(distill_entry_with(entry, log_entries, call)?);
283    }
284    Ok(out)
285}
286
287fn distill_entry_with(
288    entry: &Value,
289    related_logs: &[Value],
290    call: impl Fn(&str) -> Result<String>,
291) -> Result<Vec<DistilledChunk>> {
292    let log_id = entry["id"].as_str().unwrap_or("").to_string();
293    let prompt = build_distill_prompt_with_related(entry, related_logs);
294    let mut raw = call(&prompt)?;
295    let mut parsed = parse_distill_response(&raw);
296    if parsed.is_err() {
297        raw = call(&format!(
298            "{prompt}\n\nYour previous response was invalid. Return only a valid JSON array."
299        ))?;
300        parsed = parse_distill_response(&raw);
301    }
302    let items = parsed.map_err(|error| {
303        InnateError::Other(format!("LLM distillation response invalid: {error}"))
304    })?;
305    let mut out = Vec::new();
306    for parsed in items {
307        let content = parsed
308            .get("content")
309            .and_then(Value::as_str)
310            .map(str::trim)
311            .filter(|s| !s.is_empty());
312        let Some(content) = content else { continue };
313        let trigger_desc = parsed
314            .get("trigger_desc")
315            .and_then(Value::as_str)
316            .map(str::to_string)
317            .filter(|s| !s.is_empty());
318        let anti_trigger_desc = parsed
319            .get("anti_trigger_desc")
320            .and_then(Value::as_str)
321            .map(str::to_string)
322            .filter(|s| !s.is_empty() && s.to_lowercase() != "null");
323        out.push(DistilledChunk {
324            content: content.to_string(),
325            trigger_desc,
326            anti_trigger_desc,
327            source_log_id: log_id.clone(),
328            nomination: entry
329                .get("nomination")
330                .and_then(Value::as_str)
331                .map(str::to_string),
332        });
333    }
334    Ok(out)
335}
336
337fn parse_distill_response(raw: &str) -> std::result::Result<Vec<Value>, String> {
338    let json_str = extract_json(raw);
339    let parsed: Value = serde_json::from_str(json_str.trim()).map_err(|e| e.to_string())?;
340    if parsed.get("skip").and_then(Value::as_bool) == Some(true) {
341        return Ok(vec![]);
342    }
343    match parsed {
344        Value::Array(items) => Ok(items),
345        Value::Object(_) => Ok(vec![parsed]),
346        _ => Err("expected a JSON object or array".to_string()),
347    }
348}
349
/// Pulls the most likely JSON payload out of a raw model reply.
///
/// Candidates are tried in this order, on the trimmed input:
/// 1. the body of a markdown code fence that starts the text
///    (```` ```json ... ``` ```` or ```` ``` ... ``` ````), trimmed;
/// 2. the span from the first `[` to the last `]`;
/// 3. the span from the first `{` to the last `}` (older object-style replies);
/// 4. the trimmed input itself.
///
/// A bracket or brace pair is only used when the opener comes before the
/// closer, so text such as `"] oops ["` falls through instead of panicking on an
/// inverted slice range.
fn extract_json(text: &str) -> &str {
    /// Span from the first `open` to the last `close`, if they occur in that order.
    fn delimited(text: &str, open: char, close: char) -> Option<&str> {
        let start = text.find(open)?;
        let end = text.rfind(close)?;
        // Both delimiters are one-byte ASCII, so `start..=end` lies on char
        // boundaries whenever `start < end`.
        (start < end).then(|| &text[start..=end])
    }

    let stripped = text.trim();

    // Strip markdown code fences if present: ```json ... ``` or ``` ... ```
    let fenced = stripped
        .strip_prefix("```json")
        .or_else(|| stripped.strip_prefix("```"))
        .and_then(|inner| inner.rfind("```").map(|end| inner[..end].trim()));

    fenced
        .or_else(|| delimited(stripped, '[', ']'))
        // Backward-compatible object response.
        .or_else(|| delimited(stripped, '{', '}'))
        .unwrap_or(stripped)
}
370
371// ---------------------------------------------------------------------------
372// Build distiller from settings
373// ---------------------------------------------------------------------------
374
375pub fn build_distiller(config: &LlmConfig) -> std::sync::Arc<dyn Distiller + Send + Sync> {
376    std::sync::Arc::new(HttpDistiller::new(config.clone()))
377}
378
379// ---------------------------------------------------------------------------
380// LLM embedding provider (OpenAI-compatible /v1/embeddings)
381// ---------------------------------------------------------------------------
382
/// Embedding provider that requests vectors over HTTP from the endpoint
/// described by its `EmbeddingConfig` (API key, base URL, model id, dimension).
pub struct LlmEmbeddingProvider {
    config: EmbeddingConfig,
}
386
// Unit tests for the prompt builder, the retry policy helpers and the
// distillation parse/retry logic. None of them performs network I/O: the model
// is replaced by a closure.
#[cfg(test)]
// The `LlmEmbeddingProvider` impls and the connection-test helpers are defined
// after this module in the file, which is what this lint would flag.
#[allow(clippy::items_after_test_module)]
mod tests {
    use std::cell::Cell;

    use serde_json::json;

    use std::time::Duration;

    use super::{
        backoff_delay, build_distill_prompt, distill_entry_with, distill_with,
        parse_distill_response, status_is_retryable,
    };

    // 429 and the 5xx range are retryable; other statuses are not.
    #[test]
    fn only_rate_limit_and_5xx_are_retryable() {
        assert!(status_is_retryable(429));
        assert!(status_is_retryable(500));
        assert!(status_is_retryable(503));
        assert!(status_is_retryable(599));
        assert!(!status_is_retryable(400));
        assert!(!status_is_retryable(401));
        assert!(!status_is_retryable(404));
        assert!(!status_is_retryable(200));
    }

    // Pins the backoff schedule and the Retry-After override with its cap.
    #[test]
    fn backoff_is_exponential_and_honors_retry_after() {
        // Exponential schedule: 250ms, 500ms, 1s for attempts 1..3.
        assert_eq!(backoff_delay(1, None), Duration::from_millis(250));
        assert_eq!(backoff_delay(2, None), Duration::from_millis(500));
        assert_eq!(backoff_delay(3, None), Duration::from_millis(1000));
        // Retry-After overrides the schedule and is capped at 30s.
        assert_eq!(backoff_delay(1, Some(5)), Duration::from_secs(5));
        assert_eq!(backoff_delay(1, Some(120)), Duration::from_secs(30));
    }

    // Secrets in log fields must not reach the prompt text.
    #[test]
    fn prompt_redacts_secrets_before_external_llm_call() {
        let prompt = build_distill_prompt(&json!({
            "query": "debug sk-12345678901234567890",
            "output_summary": "Authorization: Bearer secret-token-value"
        }));
        assert!(!prompt.contains("sk-12345678901234567890"));
        assert!(!prompt.contains("secret-token-value"));
        assert!(prompt.contains("[REDACTED]"));
    }

    // An unparseable first reply triggers exactly one follow-up call.
    #[test]
    fn malformed_response_is_retried_instead_of_silently_skipped() {
        let calls = Cell::new(0);
        let chunks = distill_with(&[json!({"id": "log-1", "query": "q"})], |_| {
            calls.set(calls.get() + 1);
            if calls.get() == 1 {
                Ok("not json".to_string())
            } else {
                Ok(r#"[{"content":"retry worked","trigger_desc":"retry"}]"#.to_string())
            }
        })
        .unwrap();
        assert_eq!(calls.get(), 2);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, "retry worked");
    }

    // A JSON array reply yields one parsed item per element.
    #[test]
    fn parser_accepts_multiple_distilled_chunks() {
        let parsed = parse_distill_response(
            r#"[{"content":"one"},{"content":"two","anti_trigger_desc":"never"}]"#,
        )
        .unwrap();
        assert_eq!(parsed.len(), 2);
    }

    // The nomination is included in the prompt and copied onto the chunk, while
    // the chunk content comes from the model's reply.
    #[test]
    fn nomination_is_distilled_instead_of_bypassing_the_model() {
        let prompt_seen = Cell::new(false);
        let entry = json!({
            "id": "log-1",
            "query": "original query",
            "nomination": "raw agent nomination",
            "output_summary": "summary",
            "outcome": "ok"
        });
        let chunks = distill_entry_with(&entry, std::slice::from_ref(&entry), |prompt| {
            prompt_seen.set(prompt.contains("raw agent nomination"));
            Ok(
                r#"[{"content":"generalized principle","trigger_desc":"generalize","anti_trigger_desc":null}]"#
                    .to_string(),
            )
        })
        .unwrap();

        assert!(prompt_seen.get());
        assert_eq!(chunks[0].content, "generalized principle");
        assert_eq!(
            chunks[0].nomination.as_deref(),
            Some("raw agent nomination")
        );
    }
}
488
489impl LlmEmbeddingProvider {
490    pub fn new(config: EmbeddingConfig) -> Self {
491        Self { config }
492    }
493
494    fn embed(&self, text: &str) -> Result<Vec<f32>> {
495        let api_key = self
496            .config
497            .resolved_api_key()
498            .ok_or_else(|| InnateError::Other("Embedding API key not configured".into()))?;
499
500        let base = self.config.resolved_base_url();
501        let url = format!("{base}/embeddings");
502
503        let body = json!({
504            "input": text,
505            "model": self.config.model_id,
506        });
507
508        let auth = format!("Bearer {api_key}");
509        let resp_json = post_json_retry(
510            &url,
511            &[("Authorization", &auth)],
512            &body,
513            "Embedding",
514        )?;
515
516        let embedding = resp_json
517            .pointer("/data/0/embedding")
518            .and_then(Value::as_array)
519            .ok_or_else(|| InnateError::Other("unexpected embedding response shape".into()))?;
520
521        Ok(embedding
522            .iter()
523            .filter_map(Value::as_f64)
524            .map(|x| x as f32)
525            .collect())
526    }
527}
528
529impl EmbeddingProvider for LlmEmbeddingProvider {
530    fn model_name(&self) -> &'static str {
531        "llm-embedding"
532    }
533
534    fn content_dim(&self) -> usize {
535        self.config.dim
536    }
537
538    fn trigger_dim(&self) -> usize {
539        self.config.dim
540    }
541
542    fn embed_content(&self, text: &str) -> Result<Vec<f32>> {
543        self.embed(text)
544    }
545
546    fn embed_trigger(&self, text: &str) -> Result<Vec<f32>> {
547        self.embed(text)
548    }
549
550    /// Content and trigger share the same model and dimension here, so a single
551    /// HTTP request serves both spaces — half the round trips per recall.
552    fn embed_both(&self, text: &str) -> Result<(Vec<f32>, Vec<f32>)> {
553        let v = self.embed(text)?;
554        Ok((v.clone(), v))
555    }
556}
557
558// ---------------------------------------------------------------------------
559// Connection test helpers (used by install)
560// ---------------------------------------------------------------------------
561
562/// Test the LLM config by sending a minimal request. Returns Ok(model_response) or Err.
563pub fn test_llm(config: &LlmConfig) -> Result<String> {
564    let distiller = build_distiller(config);
565    let dummy_log = json!({
566        "id": "test",
567        "query": "connection test",
568        "output_summary": "test",
569        "outcome": "ok"
570    });
571    // We don't care about the result — just that the call succeeds.
572    distiller.distill(&[dummy_log])?;
573    Ok(format!("OK — model: {}", config.model_id))
574}
575
576/// Test the embedding config. Returns Ok(dim) or Err.
577pub fn test_embedding(config: &EmbeddingConfig) -> Result<usize> {
578    let provider = LlmEmbeddingProvider::new(config.clone());
579    let vec = provider.embed("connection test")?;
580    Ok(vec.len())
581}