Skip to main content

innate_core/
llm.rs

1use serde_json::{json, Value};
2use std::time::Duration;
3
4use crate::embedding::EmbeddingProvider;
5use crate::errors::{InnateError, Result};
6use crate::refine::{DistillProvenance, DistilledChunk, Distiller};
7use crate::settings::{EmbeddingConfig, LlmConfig};
8
9// ---------------------------------------------------------------------------
10// Prompt for distillation
11// ---------------------------------------------------------------------------
12
13const DISTILL_PROMPT_VERSION: &str = "2";
14
15fn safe_prompt_field(value: Option<&str>) -> String {
16    let value = value.unwrap_or("");
17    let (cleaned, action) = crate::utils::sanitize(value);
18    match action {
19        crate::utils::SanitizeAction::Discard => "[removed unsafe content]".to_string(),
20        _ => cleaned,
21    }
22}
23
24fn build_distill_prompt(log: &Value) -> String {
25    let query = safe_prompt_field(log.get("query").and_then(Value::as_str));
26    let output = safe_prompt_field(log.get("output").and_then(Value::as_str));
27    let output_summary = safe_prompt_field(
28        log
29        .get("output_summary")
30        .and_then(Value::as_str)
31    );
32    let nomination = safe_prompt_field(log.get("nomination").and_then(Value::as_str));
33    let outcome = safe_prompt_field(log.get("outcome").and_then(Value::as_str));
34
35    let mut context_parts = vec![];
36    if !query.is_empty() {
37        context_parts.push(format!("Query: {query}"));
38    }
39    if !nomination.is_empty() {
40        context_parts.push(format!("Nominated insight: {nomination}"));
41    }
42    if !output_summary.is_empty() {
43        context_parts.push(format!("Summary: {output_summary}"));
44    }
45    if !output.is_empty() {
46        let truncated: String = output.chars().take(1500).collect();
47        context_parts.push(format!("Output (truncated): {truncated}"));
48    }
49    if !outcome.is_empty() {
50        context_parts.push(format!("Outcome: {outcome}"));
51    }
52
53    let context = context_parts.join("\n");
54
55    format!(
56        r#"You are a knowledge distillation assistant. Given an agent interaction log, \
57extract zero or more independent reusable procedural principles.
58
59Agent interaction:
60{context}
61
62Output a JSON array. Each item has:
63{{
64  "content": "<principle; when it applies; what to avoid>",
65  "trigger_desc": "<2-6 word canonical phrase>",
66  "anti_trigger_desc": "<when NOT to apply this, or null>"
67}}
68Return [] if nothing is worth keeping.
69
70Rules:
71- content must be self-contained and actionable for a future agent reading cold
72- trigger_desc must match the vocabulary a future agent would use in a search query
73- Never store conversation text verbatim; always distil to reusable principle form
74- If outcome is "fail", focus on what to avoid
75- Keep principles independent; do not combine unrelated lessons"#
76    )
77}
78
79fn build_distill_prompt_with_related(log: &Value, logs: &[Value]) -> String {
80    let mut prompt = build_distill_prompt(log);
81    let log_id = log.get("id").and_then(Value::as_str).unwrap_or("");
82    let related: Vec<String> = logs
83        .iter()
84        .filter(|other| other.get("id").and_then(Value::as_str).unwrap_or("") != log_id)
85        .take(4)
86        .map(|other| {
87            let query = safe_prompt_field(other.get("query").and_then(Value::as_str));
88            let summary =
89                safe_prompt_field(other.get("output_summary").and_then(Value::as_str));
90            let outcome = safe_prompt_field(other.get("outcome").and_then(Value::as_str));
91            format!("- Query: {query}; outcome: {outcome}; summary: {summary}")
92        })
93        .collect();
94    if !related.is_empty() {
95        prompt.push_str(
96            "\n\nRelated recent interactions (use only to identify repeated patterns or conflicts):\n",
97        );
98        prompt.push_str(&related.join("\n"));
99    }
100    prompt
101}
102
103// ---------------------------------------------------------------------------
104// OpenAI-compatible distiller  (works for GPT, DeepSeek, local Ollama, etc.)
105// ---------------------------------------------------------------------------
106
107pub struct OpenAiDistiller {
108    config: LlmConfig,
109}
110
111impl OpenAiDistiller {
112    pub fn new(config: LlmConfig) -> Self {
113        Self { config }
114    }
115
116    fn call(&self, prompt: &str) -> Result<String> {
117        let api_key = self
118            .config
119            .resolved_api_key()
120            .ok_or_else(|| InnateError::Other("LLM API key not configured".into()))?;
121
122        let base = self.config.resolved_base_url();
123        let url = format!("{base}/chat/completions");
124
125        let body = json!({
126            "model": self.config.model_id,
127            "messages": [{"role": "user", "content": prompt}],
128            "max_tokens": 800,
129            "temperature": 0.2,
130        });
131
132        let response = ureq::post(&url)
133            .timeout(Duration::from_secs(30))
134            .set("Authorization", &format!("Bearer {api_key}"))
135            .set("Content-Type", "application/json")
136            .send_json(&body)
137            .map_err(|e| InnateError::Other(format!("LLM HTTP error: {e}")))?;
138
139        let resp_json: Value = response
140            .into_json()
141            .map_err(|e| InnateError::Other(format!("LLM response parse error: {e}")))?;
142
143        resp_json
144            .pointer("/choices/0/message/content")
145            .and_then(Value::as_str)
146            .map(str::to_string)
147            .ok_or_else(|| InnateError::Other("unexpected LLM response shape".into()))
148    }
149}
150
151impl Distiller for OpenAiDistiller {
152    fn distill(&self, log_entries: &[Value]) -> crate::errors::Result<Vec<DistilledChunk>> {
153        distill_with(log_entries, |prompt| self.call(prompt))
154    }
155
156    fn distill_with_context(
157        &self,
158        primary: &Value,
159        related_logs: &[Value],
160    ) -> crate::errors::Result<Vec<DistilledChunk>> {
161        distill_entry_with(primary, related_logs, |prompt| self.call(prompt))
162    }
163
164    fn provenance(&self) -> DistillProvenance {
165        DistillProvenance {
166            provider: Some(self.config.provider.clone()),
167            model: Some(self.config.model_id.clone()),
168            prompt_version: Some(DISTILL_PROMPT_VERSION.to_string()),
169        }
170    }
171}
172
173// ---------------------------------------------------------------------------
174// Anthropic Messages API distiller
175// ---------------------------------------------------------------------------
176
177pub struct AnthropicDistiller {
178    config: LlmConfig,
179}
180
181impl AnthropicDistiller {
182    pub fn new(config: LlmConfig) -> Self {
183        Self { config }
184    }
185
186    fn call(&self, prompt: &str) -> Result<String> {
187        let api_key = self
188            .config
189            .resolved_api_key()
190            .ok_or_else(|| InnateError::Other("Anthropic API key not configured".into()))?;
191
192        let base = self.config.resolved_base_url();
193        let url = format!("{base}/v1/messages");
194
195        let body = json!({
196            "model": self.config.model_id,
197            "max_tokens": 800,
198            "messages": [{"role": "user", "content": prompt}],
199        });
200
201        let response = ureq::post(&url)
202            .timeout(Duration::from_secs(30))
203            .set("x-api-key", &api_key)
204            .set("anthropic-version", "2023-06-01")
205            .set("Content-Type", "application/json")
206            .send_json(&body)
207            .map_err(|e| InnateError::Other(format!("Anthropic HTTP error: {e}")))?;
208
209        let resp_json: Value = response
210            .into_json()
211            .map_err(|e| InnateError::Other(format!("Anthropic response parse error: {e}")))?;
212
213        resp_json
214            .pointer("/content/0/text")
215            .and_then(Value::as_str)
216            .map(str::to_string)
217            .ok_or_else(|| InnateError::Other("unexpected Anthropic response shape".into()))
218    }
219}
220
221impl Distiller for AnthropicDistiller {
222    fn distill(&self, log_entries: &[Value]) -> crate::errors::Result<Vec<DistilledChunk>> {
223        distill_with(log_entries, |prompt| self.call(prompt))
224    }
225
226    fn distill_with_context(
227        &self,
228        primary: &Value,
229        related_logs: &[Value],
230    ) -> crate::errors::Result<Vec<DistilledChunk>> {
231        distill_entry_with(primary, related_logs, |prompt| self.call(prompt))
232    }
233
234    fn provenance(&self) -> DistillProvenance {
235        DistillProvenance {
236            provider: Some(self.config.provider.clone()),
237            model: Some(self.config.model_id.clone()),
238            prompt_version: Some(DISTILL_PROMPT_VERSION.to_string()),
239        }
240    }
241}
242
243// ---------------------------------------------------------------------------
244// Shared parse logic
245// ---------------------------------------------------------------------------
246
247fn distill_with(
248    log_entries: &[Value],
249    call: impl Fn(&str) -> Result<String> + Copy,
250) -> Result<Vec<DistilledChunk>> {
251    let mut out = Vec::new();
252    for entry in log_entries {
253        out.extend(distill_entry_with(entry, log_entries, call)?);
254    }
255    Ok(out)
256}
257
258fn distill_entry_with(
259    entry: &Value,
260    related_logs: &[Value],
261    call: impl Fn(&str) -> Result<String>,
262) -> Result<Vec<DistilledChunk>> {
263    let log_id = entry["id"].as_str().unwrap_or("").to_string();
264    let prompt = build_distill_prompt_with_related(entry, related_logs);
265    let mut raw = call(&prompt)?;
266    let mut parsed = parse_distill_response(&raw);
267    if parsed.is_err() {
268        raw = call(&format!(
269            "{prompt}\n\nYour previous response was invalid. Return only a valid JSON array."
270        ))?;
271        parsed = parse_distill_response(&raw);
272    }
273    let items = parsed
274        .map_err(|error| InnateError::Other(format!("LLM distillation response invalid: {error}")))?;
275    let mut out = Vec::new();
276    for parsed in items {
277        let content = parsed
278            .get("content")
279            .and_then(Value::as_str)
280            .map(str::trim)
281            .filter(|s| !s.is_empty());
282        let Some(content) = content else { continue };
283        let trigger_desc = parsed
284            .get("trigger_desc")
285            .and_then(Value::as_str)
286            .map(str::to_string)
287            .filter(|s| !s.is_empty());
288        let anti_trigger_desc = parsed
289            .get("anti_trigger_desc")
290            .and_then(Value::as_str)
291            .map(str::to_string)
292            .filter(|s| !s.is_empty() && s.to_lowercase() != "null");
293        out.push(DistilledChunk {
294            content: content.to_string(),
295            trigger_desc,
296            anti_trigger_desc,
297            source_log_id: log_id.clone(),
298            nomination: entry
299                .get("nomination")
300                .and_then(Value::as_str)
301                .map(str::to_string),
302        });
303    }
304    Ok(out)
305}
306
307fn parse_distill_response(raw: &str) -> std::result::Result<Vec<Value>, String> {
308    let json_str = extract_json(raw);
309    let parsed: Value = serde_json::from_str(json_str.trim()).map_err(|e| e.to_string())?;
310    if parsed.get("skip").and_then(Value::as_bool) == Some(true) {
311        return Ok(vec![]);
312    }
313    match parsed {
314        Value::Array(items) => Ok(items),
315        Value::Object(_) => Ok(vec![parsed]),
316        _ => Err("expected a JSON object or array".to_string()),
317    }
318}
319
/// Extracts the JSON payload from a raw model reply.
///
/// Models often wrap their answer in a Markdown code fence or surround it with
/// prose. Candidates are tried in this order:
///
/// 1. The body of a leading code fence with a closing fence. A first line
///    holding an info string (such as `json` or `JSON`) is dropped.
/// 2. When the reply itself starts with `{`, the span up to the last `}`. This
///    keeps an object whose string values mention `[...]` from being cut down
///    to that bracketed text.
/// 3. The span from the first `[` to the last `]`.
/// 4. The span from the first `{` to the last `}` (backward-compatible
///    single-object replies).
/// 5. The trimmed reply unchanged, leaving the JSON parser to report the error.
///
/// Never panics. A closer that appears before its opener (e.g. `] ... [`)
/// fails that candidate instead of producing an inverted slice.
fn extract_json(text: &str) -> &str {
    // `text[first open ..= last close]`, when that range is well-formed.
    fn delimited_span(text: &str, open: char, close: char) -> Option<&str> {
        let start = text.find(open)?;
        let end = text.rfind(close)?;
        (start < end).then(|| &text[start..=end])
    }

    // Drops a fence's info-string line unless that line already holds the JSON.
    fn drop_info_line(body: &str) -> &str {
        match body.split_once('\n') {
            Some((first, rest))
                if !first.trim_start().starts_with(|c: char| c == '[' || c == '{') =>
            {
                rest
            }
            _ => body,
        }
    }

    let stripped = text.trim();

    if let Some(inner) = stripped
        .strip_prefix("```json")
        .or_else(|| stripped.strip_prefix("```"))
    {
        if let Some(end) = inner.rfind("```") {
            return drop_info_line(&inner[..end]).trim();
        }
    }

    if stripped.starts_with('{') {
        if let Some(object) = delimited_span(stripped, '{', '}') {
            return object;
        }
    }

    delimited_span(stripped, '[', ']')
        .or_else(|| delimited_span(stripped, '{', '}'))
        .unwrap_or(stripped)
}
340
341// ---------------------------------------------------------------------------
342// Build distiller from settings
343// ---------------------------------------------------------------------------
344
345pub fn build_distiller(
346    config: &LlmConfig,
347) -> std::sync::Arc<dyn Distiller + Send + Sync> {
348    match config.provider.as_str() {
349        "anthropic" => std::sync::Arc::new(AnthropicDistiller::new(config.clone())),
350        _ => std::sync::Arc::new(OpenAiDistiller::new(config.clone())),
351    }
352}
353
354// ---------------------------------------------------------------------------
355// LLM embedding provider (OpenAI-compatible /v1/embeddings)
356// ---------------------------------------------------------------------------
357
358pub struct LlmEmbeddingProvider {
359    config: EmbeddingConfig,
360}
361
#[cfg(test)]
// The module sits mid-file, before `impl LlmEmbeddingProvider`; the lint is
// silenced rather than relocating it.
#[allow(clippy::items_after_test_module)]
mod tests {
    use std::cell::Cell;

    use serde_json::json;

    use super::{build_distill_prompt, distill_entry_with, distill_with, parse_distill_response};

    /// Secrets in a log must be redacted before the prompt leaves the process.
    #[test]
    fn prompt_redacts_secrets_before_external_llm_call() {
        let log = json!({
            "query": "debug sk-12345678901234567890",
            "output_summary": "Authorization: Bearer secret-token-value"
        });
        let prompt = build_distill_prompt(&log);

        for secret in ["sk-12345678901234567890", "secret-token-value"] {
            assert!(!prompt.contains(secret), "prompt leaked {secret:?}");
        }
        assert!(prompt.contains("[REDACTED]"));
    }

    /// An unparseable first reply triggers exactly one retry.
    #[test]
    fn malformed_response_is_retried_instead_of_silently_skipped() {
        let attempts = Cell::new(0);
        let entries = [json!({"id": "log-1", "query": "q"})];

        let chunks = distill_with(&entries, |_| {
            attempts.set(attempts.get() + 1);
            let reply = match attempts.get() {
                1 => "not json",
                _ => r#"[{"content":"retry worked","trigger_desc":"retry"}]"#,
            };
            Ok(reply.to_string())
        })
        .unwrap();

        assert_eq!(attempts.get(), 2);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, "retry worked");
    }

    /// A JSON array reply yields one item per element.
    #[test]
    fn parser_accepts_multiple_distilled_chunks() {
        let reply = r#"[{"content":"one"},{"content":"two","anti_trigger_desc":"never"}]"#;
        let items = parse_distill_response(reply).expect("array reply should parse");
        assert_eq!(items.len(), 2);
    }

    /// The nomination is shown to the model and also kept on the chunk.
    #[test]
    fn nomination_is_distilled_instead_of_bypassing_the_model() {
        let entry = json!({
            "id": "log-1",
            "query": "original query",
            "nomination": "raw agent nomination",
            "output_summary": "summary",
            "outcome": "ok"
        });
        let saw_nomination = Cell::new(false);
        let reply = r#"[{"content":"generalized principle","trigger_desc":"generalize","anti_trigger_desc":null}]"#;

        let chunks = distill_entry_with(&entry, std::slice::from_ref(&entry), |prompt| {
            saw_nomination.set(prompt.contains("raw agent nomination"));
            Ok(reply.to_string())
        })
        .unwrap();

        assert!(saw_nomination.get(), "nomination never reached the model");
        assert_eq!(chunks[0].content, "generalized principle");
        assert_eq!(chunks[0].nomination.as_deref(), Some("raw agent nomination"));
    }
}
437
438impl LlmEmbeddingProvider {
439    pub fn new(config: EmbeddingConfig) -> Self {
440        Self { config }
441    }
442
443    fn embed(&self, text: &str) -> Result<Vec<f32>> {
444        let api_key = self
445            .config
446            .resolved_api_key()
447            .ok_or_else(|| InnateError::Other("Embedding API key not configured".into()))?;
448
449        let base = self.config.resolved_base_url();
450        let url = format!("{base}/embeddings");
451
452        let body = json!({
453            "input": text,
454            "model": self.config.model_id,
455        });
456
457        let response = ureq::post(&url)
458            .set("Authorization", &format!("Bearer {api_key}"))
459            .set("Content-Type", "application/json")
460            .send_json(&body)
461            .map_err(|e| InnateError::Other(format!("Embedding HTTP error: {e}")))?;
462
463        let resp_json: Value = response
464            .into_json()
465            .map_err(|e| InnateError::Other(format!("Embedding response parse: {e}")))?;
466
467        let embedding = resp_json
468            .pointer("/data/0/embedding")
469            .and_then(Value::as_array)
470            .ok_or_else(|| InnateError::Other("unexpected embedding response shape".into()))?;
471
472        Ok(embedding
473            .iter()
474            .filter_map(Value::as_f64)
475            .map(|x| x as f32)
476            .collect())
477    }
478}
479
480impl EmbeddingProvider for LlmEmbeddingProvider {
481    fn model_name(&self) -> &'static str {
482        "llm-embedding"
483    }
484
485    fn content_dim(&self) -> usize {
486        self.config.dim
487    }
488
489    fn trigger_dim(&self) -> usize {
490        self.config.dim
491    }
492
493    fn embed_content(&self, text: &str) -> Result<Vec<f32>> {
494        self.embed(text)
495    }
496
497    fn embed_trigger(&self, text: &str) -> Result<Vec<f32>> {
498        self.embed(text)
499    }
500}
501
502// ---------------------------------------------------------------------------
503// Connection test helpers (used by install)
504// ---------------------------------------------------------------------------
505
506/// Test the LLM config by sending a minimal request. Returns Ok(model_response) or Err.
507pub fn test_llm(config: &LlmConfig) -> Result<String> {
508    let distiller = build_distiller(config);
509    let dummy_log = json!({
510        "id": "test",
511        "query": "connection test",
512        "output_summary": "test",
513        "outcome": "ok"
514    });
515    // We don't care about the result — just that the call succeeds.
516    distiller.distill(&[dummy_log])?;
517    Ok(format!("OK — model: {}", config.model_id))
518}
519
520/// Test the embedding config. Returns Ok(dim) or Err.
521pub fn test_embedding(config: &EmbeddingConfig) -> Result<usize> {
522    let provider = LlmEmbeddingProvider::new(config.clone());
523    let vec = provider.embed("connection test")?;
524    Ok(vec.len())
525}