Skip to main content

innate_core/
llm.rs

1use serde_json::{json, Value};
2use std::time::Duration;
3
4use crate::embedding::EmbeddingProvider;
5use crate::errors::{InnateError, Result};
6use crate::refine::{DistillProvenance, DistilledChunk, Distiller};
7use crate::settings::{EmbeddingConfig, LlmConfig};
8
9// ---------------------------------------------------------------------------
10// Prompt for distillation
11// ---------------------------------------------------------------------------
12
13const DISTILL_PROMPT_VERSION: &str = "2";
14
15fn safe_prompt_field(value: Option<&str>) -> String {
16    let value = value.unwrap_or("");
17    let (cleaned, action) = crate::utils::sanitize(value);
18    match action {
19        crate::utils::SanitizeAction::Discard => "[removed unsafe content]".to_string(),
20        _ => cleaned,
21    }
22}
23
24fn build_distill_prompt(log: &Value) -> String {
25    let query = safe_prompt_field(log.get("query").and_then(Value::as_str));
26    let output = safe_prompt_field(log.get("output").and_then(Value::as_str));
27    let output_summary = safe_prompt_field(
28        log
29        .get("output_summary")
30        .and_then(Value::as_str)
31    );
32    let nomination = safe_prompt_field(log.get("nomination").and_then(Value::as_str));
33    let outcome = safe_prompt_field(log.get("outcome").and_then(Value::as_str));
34
35    let mut context_parts = vec![];
36    if !query.is_empty() {
37        context_parts.push(format!("Query: {query}"));
38    }
39    if !nomination.is_empty() {
40        context_parts.push(format!("Nominated insight: {nomination}"));
41    }
42    if !output_summary.is_empty() {
43        context_parts.push(format!("Summary: {output_summary}"));
44    }
45    if !output.is_empty() {
46        let truncated: String = output.chars().take(1500).collect();
47        context_parts.push(format!("Output (truncated): {truncated}"));
48    }
49    if !outcome.is_empty() {
50        context_parts.push(format!("Outcome: {outcome}"));
51    }
52
53    let context = context_parts.join("\n");
54
55    format!(
56        r#"You are a knowledge distillation assistant. Given an agent interaction log, \
57extract zero or more independent reusable procedural principles.
58
59Agent interaction:
60{context}
61
62Output a JSON array. Each item has:
63{{
64  "content": "<principle; when it applies; what to avoid>",
65  "trigger_desc": "<2-6 word canonical phrase>",
66  "anti_trigger_desc": "<when NOT to apply this, or null>"
67}}
68Return [] if nothing is worth keeping.
69
70Rules:
71- content must be self-contained and actionable for a future agent reading cold
72- trigger_desc must match the vocabulary a future agent would use in a search query
73- Never store conversation text verbatim; always distil to reusable principle form
74- If outcome is "fail", focus on what to avoid
75- Keep principles independent; do not combine unrelated lessons"#
76    )
77}
78
79fn build_distill_prompt_with_related(log: &Value, logs: &[Value]) -> String {
80    let mut prompt = build_distill_prompt(log);
81    let log_id = log.get("id").and_then(Value::as_str).unwrap_or("");
82    let context_key = log.get("context_key").and_then(Value::as_str);
83    let related: Vec<String> = logs
84        .iter()
85        .filter(|other| other.get("id").and_then(Value::as_str).unwrap_or("") != log_id)
86        .filter(|other| {
87            context_key.is_some()
88                && other.get("context_key").and_then(Value::as_str) == context_key
89        })
90        .take(4)
91        .map(|other| {
92            let query = safe_prompt_field(other.get("query").and_then(Value::as_str));
93            let summary =
94                safe_prompt_field(other.get("output_summary").and_then(Value::as_str));
95            let outcome = safe_prompt_field(other.get("outcome").and_then(Value::as_str));
96            format!("- Query: {query}; outcome: {outcome}; summary: {summary}")
97        })
98        .collect();
99    if !related.is_empty() {
100        prompt.push_str(
101            "\n\nRelated recent interactions (use only to identify repeated patterns or conflicts):\n",
102        );
103        prompt.push_str(&related.join("\n"));
104    }
105    prompt
106}
107
108// ---------------------------------------------------------------------------
109// OpenAI-compatible distiller  (works for GPT, DeepSeek, local Ollama, etc.)
110// ---------------------------------------------------------------------------
111
112pub struct OpenAiDistiller {
113    config: LlmConfig,
114}
115
116impl OpenAiDistiller {
117    pub fn new(config: LlmConfig) -> Self {
118        Self { config }
119    }
120
121    fn call(&self, prompt: &str) -> Result<String> {
122        let api_key = self
123            .config
124            .resolved_api_key()
125            .ok_or_else(|| InnateError::Other("LLM API key not configured".into()))?;
126
127        let base = self.config.resolved_base_url();
128        let url = format!("{base}/chat/completions");
129
130        let body = json!({
131            "model": self.config.model_id,
132            "messages": [{"role": "user", "content": prompt}],
133            "max_tokens": 800,
134            "temperature": 0.2,
135        });
136
137        let response = ureq::post(&url)
138            .timeout(Duration::from_secs(30))
139            .set("Authorization", &format!("Bearer {api_key}"))
140            .set("Content-Type", "application/json")
141            .send_json(&body)
142            .map_err(|e| InnateError::Other(format!("LLM HTTP error: {e}")))?;
143
144        let resp_json: Value = response
145            .into_json()
146            .map_err(|e| InnateError::Other(format!("LLM response parse error: {e}")))?;
147
148        resp_json
149            .pointer("/choices/0/message/content")
150            .and_then(Value::as_str)
151            .map(str::to_string)
152            .ok_or_else(|| InnateError::Other("unexpected LLM response shape".into()))
153    }
154}
155
156impl Distiller for OpenAiDistiller {
157    fn distill(&self, log_entries: &[Value]) -> crate::errors::Result<Vec<DistilledChunk>> {
158        distill_with(log_entries, |prompt| self.call(prompt))
159    }
160
161    fn distill_with_context(
162        &self,
163        primary: &Value,
164        related_logs: &[Value],
165    ) -> crate::errors::Result<Vec<DistilledChunk>> {
166        distill_entry_with(primary, related_logs, |prompt| self.call(prompt))
167    }
168
169    fn provenance(&self) -> DistillProvenance {
170        DistillProvenance {
171            provider: Some(self.config.provider.clone()),
172            model: Some(self.config.model_id.clone()),
173            prompt_version: Some(DISTILL_PROMPT_VERSION.to_string()),
174        }
175    }
176}
177
178// ---------------------------------------------------------------------------
179// Anthropic Messages API distiller
180// ---------------------------------------------------------------------------
181
182pub struct AnthropicDistiller {
183    config: LlmConfig,
184}
185
186impl AnthropicDistiller {
187    pub fn new(config: LlmConfig) -> Self {
188        Self { config }
189    }
190
191    fn call(&self, prompt: &str) -> Result<String> {
192        let api_key = self
193            .config
194            .resolved_api_key()
195            .ok_or_else(|| InnateError::Other("Anthropic API key not configured".into()))?;
196
197        let base = self.config.resolved_base_url();
198        let url = format!("{base}/v1/messages");
199
200        let body = json!({
201            "model": self.config.model_id,
202            "max_tokens": 800,
203            "messages": [{"role": "user", "content": prompt}],
204        });
205
206        let response = ureq::post(&url)
207            .timeout(Duration::from_secs(30))
208            .set("x-api-key", &api_key)
209            .set("anthropic-version", "2023-06-01")
210            .set("Content-Type", "application/json")
211            .send_json(&body)
212            .map_err(|e| InnateError::Other(format!("Anthropic HTTP error: {e}")))?;
213
214        let resp_json: Value = response
215            .into_json()
216            .map_err(|e| InnateError::Other(format!("Anthropic response parse error: {e}")))?;
217
218        resp_json
219            .pointer("/content/0/text")
220            .and_then(Value::as_str)
221            .map(str::to_string)
222            .ok_or_else(|| InnateError::Other("unexpected Anthropic response shape".into()))
223    }
224}
225
226impl Distiller for AnthropicDistiller {
227    fn distill(&self, log_entries: &[Value]) -> crate::errors::Result<Vec<DistilledChunk>> {
228        distill_with(log_entries, |prompt| self.call(prompt))
229    }
230
231    fn distill_with_context(
232        &self,
233        primary: &Value,
234        related_logs: &[Value],
235    ) -> crate::errors::Result<Vec<DistilledChunk>> {
236        distill_entry_with(primary, related_logs, |prompt| self.call(prompt))
237    }
238
239    fn provenance(&self) -> DistillProvenance {
240        DistillProvenance {
241            provider: Some(self.config.provider.clone()),
242            model: Some(self.config.model_id.clone()),
243            prompt_version: Some(DISTILL_PROMPT_VERSION.to_string()),
244        }
245    }
246}
247
248// ---------------------------------------------------------------------------
249// Shared parse logic
250// ---------------------------------------------------------------------------
251
252fn distill_with(
253    log_entries: &[Value],
254    call: impl Fn(&str) -> Result<String> + Copy,
255) -> Result<Vec<DistilledChunk>> {
256    let mut out = Vec::new();
257    for entry in log_entries {
258        out.extend(distill_entry_with(entry, log_entries, call)?);
259    }
260    Ok(out)
261}
262
263fn distill_entry_with(
264    entry: &Value,
265    related_logs: &[Value],
266    call: impl Fn(&str) -> Result<String>,
267) -> Result<Vec<DistilledChunk>> {
268    let log_id = entry["id"].as_str().unwrap_or("").to_string();
269    let prompt = build_distill_prompt_with_related(entry, related_logs);
270    let mut raw = call(&prompt)?;
271    let mut parsed = parse_distill_response(&raw);
272    if parsed.is_err() {
273        raw = call(&format!(
274            "{prompt}\n\nYour previous response was invalid. Return only a valid JSON array."
275        ))?;
276        parsed = parse_distill_response(&raw);
277    }
278    let items = parsed
279        .map_err(|error| InnateError::Other(format!("LLM distillation response invalid: {error}")))?;
280    let mut out = Vec::new();
281    for parsed in items {
282        let content = parsed
283            .get("content")
284            .and_then(Value::as_str)
285            .map(str::trim)
286            .filter(|s| !s.is_empty());
287        let Some(content) = content else { continue };
288        let trigger_desc = parsed
289            .get("trigger_desc")
290            .and_then(Value::as_str)
291            .map(str::to_string)
292            .filter(|s| !s.is_empty());
293        let anti_trigger_desc = parsed
294            .get("anti_trigger_desc")
295            .and_then(Value::as_str)
296            .map(str::to_string)
297            .filter(|s| !s.is_empty() && s.to_lowercase() != "null");
298        out.push(DistilledChunk {
299            content: content.to_string(),
300            trigger_desc,
301            anti_trigger_desc,
302            source_log_id: log_id.clone(),
303            nomination: entry
304                .get("nomination")
305                .and_then(Value::as_str)
306                .map(str::to_string),
307        });
308    }
309    Ok(out)
310}
311
312fn parse_distill_response(raw: &str) -> std::result::Result<Vec<Value>, String> {
313    let json_str = extract_json(raw);
314    let parsed: Value = serde_json::from_str(json_str.trim()).map_err(|e| e.to_string())?;
315    if parsed.get("skip").and_then(Value::as_bool) == Some(true) {
316        return Ok(vec![]);
317    }
318    match parsed {
319        Value::Array(items) => Ok(items),
320        Value::Object(_) => Ok(vec![parsed]),
321        _ => Err("expected a JSON object or array".to_string()),
322    }
323}
324
/// Isolates the JSON value in a model reply.
///
/// Surrounding whitespace is trimmed, and a leading Markdown fence (```` ``` ```` with any
/// or no language tag) is removed up to its last closing fence. The result is the first
/// balanced `[...]` or `{...}` span, where brackets inside JSON strings are ignored. Text
/// before or after the value, and prose brackets that do not balance, are skipped. When no
/// balanced span exists, the trimmed (de-fenced) text is returned unchanged so the caller's
/// JSON parser reports the error.
///
/// Never panics. The worst case is quadratic in the reply length, which is fine for short
/// model replies.
fn extract_json(text: &str) -> &str {
    let mut body = text.trim();
    if let Some(after_fence) = body.strip_prefix("```") {
        if let Some(end) = after_fence.rfind("```") {
            // Any language tag (`json`, `JSON`, ...) is left in place; the scan below skips it.
            body = after_fence[..end].trim();
        }
    }
    for (start, c) in body.char_indices() {
        if c == '[' || c == '{' {
            if let Some(span) = balanced_json_prefix(&body[start..]) {
                return span;
            }
        }
    }
    body
}

/// Returns the prefix of `s` (which must start with `[` or `{`) up to and including its
/// matching closer, or `None` if the delimiters are unbalanced or mismatched.
/// Characters inside JSON string literals, including escaped quotes, are not counted.
fn balanced_json_prefix(s: &str) -> Option<&str> {
    let mut closers: Vec<char> = Vec::new();
    let mut in_string = false;
    let mut escaped = false;
    for (i, c) in s.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        match c {
            '"' => in_string = true,
            '[' => closers.push(']'),
            '{' => closers.push('}'),
            ']' | '}' => {
                if closers.pop() != Some(c) {
                    return None;
                }
                if closers.is_empty() {
                    return Some(&s[..i + c.len_utf8()]);
                }
            }
            _ => {}
        }
    }
    None
}
345
346// ---------------------------------------------------------------------------
347// Build distiller from settings
348// ---------------------------------------------------------------------------
349
350pub fn build_distiller(
351    config: &LlmConfig,
352) -> std::sync::Arc<dyn Distiller + Send + Sync> {
353    match config.provider.as_str() {
354        "anthropic" => std::sync::Arc::new(AnthropicDistiller::new(config.clone())),
355        _ => std::sync::Arc::new(OpenAiDistiller::new(config.clone())),
356    }
357}
358
359// ---------------------------------------------------------------------------
360// LLM embedding provider (OpenAI-compatible /v1/embeddings)
361// ---------------------------------------------------------------------------
362
363pub struct LlmEmbeddingProvider {
364    config: EmbeddingConfig,
365}
366
#[cfg(test)]
// The embedding-provider impls and the connection-test helpers follow this module in the
// file; the lint is silenced rather than moving them.
#[allow(clippy::items_after_test_module)]
mod tests {
    use std::cell::Cell;

    use serde_json::json;

    use super::{
        build_distill_prompt, distill_entry_with, distill_with, parse_distill_response,
    };

    /// Secrets in log fields must be redacted before the prompt leaves the process.
    #[test]
    fn prompt_redacts_secrets_before_external_llm_call() {
        let log = json!({
            "query": "debug sk-12345678901234567890",
            "output_summary": "Authorization: Bearer secret-token-value"
        });
        let rendered = build_distill_prompt(&log);
        for secret in ["sk-12345678901234567890", "secret-token-value"] {
            assert!(!rendered.contains(secret), "prompt leaked {secret}");
        }
        assert!(rendered.contains("[REDACTED]"));
    }

    /// An unparseable first reply triggers exactly one retry, whose output is used.
    #[test]
    fn malformed_response_is_retried_instead_of_silently_skipped() {
        let attempts = Cell::new(0);
        let logs = [json!({"id": "log-1", "query": "q"})];
        let chunks = distill_with(&logs, |_| {
            attempts.set(attempts.get() + 1);
            let reply = match attempts.get() {
                1 => "not json",
                _ => r#"[{"content":"retry worked","trigger_desc":"retry"}]"#,
            };
            Ok(reply.to_string())
        })
        .unwrap();
        assert_eq!(attempts.get(), 2);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, "retry worked");
    }

    /// A JSON array reply yields one item per element.
    #[test]
    fn parser_accepts_multiple_distilled_chunks() {
        let raw = r#"[{"content":"one"},{"content":"two","anti_trigger_desc":"never"}]"#;
        let items = parse_distill_response(raw).expect("array reply should parse");
        assert_eq!(items.len(), 2);
    }

    /// The nomination is shown to the model and carried through onto the chunk.
    #[test]
    fn nomination_is_distilled_instead_of_bypassing_the_model() {
        let saw_nomination = Cell::new(false);
        let entry = json!({
            "id": "log-1",
            "query": "original query",
            "nomination": "raw agent nomination",
            "output_summary": "summary",
            "outcome": "ok"
        });
        let chunks = distill_entry_with(&entry, std::slice::from_ref(&entry), |prompt| {
            saw_nomination.set(prompt.contains("raw agent nomination"));
            Ok(
                r#"[{"content":"generalized principle","trigger_desc":"generalize","anti_trigger_desc":null}]"#
                    .to_string(),
            )
        })
        .unwrap();

        assert!(saw_nomination.get());
        let first = &chunks[0];
        assert_eq!(first.content, "generalized principle");
        assert_eq!(first.nomination.as_deref(), Some("raw agent nomination"));
    }
}
442
443impl LlmEmbeddingProvider {
444    pub fn new(config: EmbeddingConfig) -> Self {
445        Self { config }
446    }
447
448    fn embed(&self, text: &str) -> Result<Vec<f32>> {
449        let api_key = self
450            .config
451            .resolved_api_key()
452            .ok_or_else(|| InnateError::Other("Embedding API key not configured".into()))?;
453
454        let base = self.config.resolved_base_url();
455        let url = format!("{base}/embeddings");
456
457        let body = json!({
458            "input": text,
459            "model": self.config.model_id,
460        });
461
462        let response = ureq::post(&url)
463            .set("Authorization", &format!("Bearer {api_key}"))
464            .set("Content-Type", "application/json")
465            .send_json(&body)
466            .map_err(|e| InnateError::Other(format!("Embedding HTTP error: {e}")))?;
467
468        let resp_json: Value = response
469            .into_json()
470            .map_err(|e| InnateError::Other(format!("Embedding response parse: {e}")))?;
471
472        let embedding = resp_json
473            .pointer("/data/0/embedding")
474            .and_then(Value::as_array)
475            .ok_or_else(|| InnateError::Other("unexpected embedding response shape".into()))?;
476
477        Ok(embedding
478            .iter()
479            .filter_map(Value::as_f64)
480            .map(|x| x as f32)
481            .collect())
482    }
483}
484
485impl EmbeddingProvider for LlmEmbeddingProvider {
486    fn model_name(&self) -> &'static str {
487        "llm-embedding"
488    }
489
490    fn content_dim(&self) -> usize {
491        self.config.dim
492    }
493
494    fn trigger_dim(&self) -> usize {
495        self.config.dim
496    }
497
498    fn embed_content(&self, text: &str) -> Result<Vec<f32>> {
499        self.embed(text)
500    }
501
502    fn embed_trigger(&self, text: &str) -> Result<Vec<f32>> {
503        self.embed(text)
504    }
505}
506
507// ---------------------------------------------------------------------------
508// Connection test helpers (used by install)
509// ---------------------------------------------------------------------------
510
511/// Test the LLM config by sending a minimal request. Returns Ok(model_response) or Err.
512pub fn test_llm(config: &LlmConfig) -> Result<String> {
513    let distiller = build_distiller(config);
514    let dummy_log = json!({
515        "id": "test",
516        "query": "connection test",
517        "output_summary": "test",
518        "outcome": "ok"
519    });
520    // We don't care about the result — just that the call succeeds.
521    distiller.distill(&[dummy_log])?;
522    Ok(format!("OK — model: {}", config.model_id))
523}
524
525/// Test the embedding config. Returns Ok(dim) or Err.
526pub fn test_embedding(config: &EmbeddingConfig) -> Result<usize> {
527    let provider = LlmEmbeddingProvider::new(config.clone());
528    let vec = provider.embed("connection test")?;
529    Ok(vec.len())
530}