1use serde_json::{json, Value};
2use std::time::Duration;
3
4use crate::embedding::EmbeddingProvider;
5use crate::errors::{InnateError, Result};
6use crate::refine::{DistillProvenance, DistilledChunk, Distiller};
7use crate::settings::{EmbeddingConfig, LlmConfig};
8
/// Version tag of the distillation prompt; reported as `prompt_version` in
/// the provenance returned by both distillers in this file.
const DISTILL_PROMPT_VERSION: &str = "2";
14
15fn safe_prompt_field(value: Option<&str>) -> String {
16 let value = value.unwrap_or("");
17 let (cleaned, action) = crate::utils::sanitize(value);
18 match action {
19 crate::utils::SanitizeAction::Discard => "[removed unsafe content]".to_string(),
20 _ => cleaned,
21 }
22}
23
24fn build_distill_prompt(log: &Value) -> String {
25 let query = safe_prompt_field(log.get("query").and_then(Value::as_str));
26 let output = safe_prompt_field(log.get("output").and_then(Value::as_str));
27 let output_summary = safe_prompt_field(
28 log
29 .get("output_summary")
30 .and_then(Value::as_str)
31 );
32 let nomination = safe_prompt_field(log.get("nomination").and_then(Value::as_str));
33 let outcome = safe_prompt_field(log.get("outcome").and_then(Value::as_str));
34
35 let mut context_parts = vec![];
36 if !query.is_empty() {
37 context_parts.push(format!("Query: {query}"));
38 }
39 if !nomination.is_empty() {
40 context_parts.push(format!("Nominated insight: {nomination}"));
41 }
42 if !output_summary.is_empty() {
43 context_parts.push(format!("Summary: {output_summary}"));
44 }
45 if !output.is_empty() {
46 let truncated: String = output.chars().take(1500).collect();
47 context_parts.push(format!("Output (truncated): {truncated}"));
48 }
49 if !outcome.is_empty() {
50 context_parts.push(format!("Outcome: {outcome}"));
51 }
52
53 let context = context_parts.join("\n");
54
55 format!(
56 r#"You are a knowledge distillation assistant. Given an agent interaction log, \
57extract zero or more independent reusable procedural principles.
58
59Agent interaction:
60{context}
61
62Output a JSON array. Each item has:
63{{
64 "content": "<principle; when it applies; what to avoid>",
65 "trigger_desc": "<2-6 word canonical phrase>",
66 "anti_trigger_desc": "<when NOT to apply this, or null>"
67}}
68Return [] if nothing is worth keeping.
69
70Rules:
71- content must be self-contained and actionable for a future agent reading cold
72- trigger_desc must match the vocabulary a future agent would use in a search query
73- Never store conversation text verbatim; always distil to reusable principle form
74- If outcome is "fail", focus on what to avoid
75- Keep principles independent; do not combine unrelated lessons"#
76 )
77}
78
79fn build_distill_prompt_with_related(log: &Value, logs: &[Value]) -> String {
80 let mut prompt = build_distill_prompt(log);
81 let log_id = log.get("id").and_then(Value::as_str).unwrap_or("");
82 let context_key = log.get("context_key").and_then(Value::as_str);
83 let related: Vec<String> = logs
84 .iter()
85 .filter(|other| other.get("id").and_then(Value::as_str).unwrap_or("") != log_id)
86 .filter(|other| {
87 context_key.is_some()
88 && other.get("context_key").and_then(Value::as_str) == context_key
89 })
90 .take(4)
91 .map(|other| {
92 let query = safe_prompt_field(other.get("query").and_then(Value::as_str));
93 let summary =
94 safe_prompt_field(other.get("output_summary").and_then(Value::as_str));
95 let outcome = safe_prompt_field(other.get("outcome").and_then(Value::as_str));
96 format!("- Query: {query}; outcome: {outcome}; summary: {summary}")
97 })
98 .collect();
99 if !related.is_empty() {
100 prompt.push_str(
101 "\n\nRelated recent interactions (use only to identify repeated patterns or conflicts):\n",
102 );
103 prompt.push_str(&related.join("\n"));
104 }
105 prompt
106}
107
108pub struct OpenAiDistiller {
113 config: LlmConfig,
114}
115
116impl OpenAiDistiller {
117 pub fn new(config: LlmConfig) -> Self {
118 Self { config }
119 }
120
121 fn call(&self, prompt: &str) -> Result<String> {
122 let api_key = self
123 .config
124 .resolved_api_key()
125 .ok_or_else(|| InnateError::Other("LLM API key not configured".into()))?;
126
127 let base = self.config.resolved_base_url();
128 let url = format!("{base}/chat/completions");
129
130 let body = json!({
131 "model": self.config.model_id,
132 "messages": [{"role": "user", "content": prompt}],
133 "max_tokens": 800,
134 "temperature": 0.2,
135 });
136
137 let response = ureq::post(&url)
138 .timeout(Duration::from_secs(30))
139 .set("Authorization", &format!("Bearer {api_key}"))
140 .set("Content-Type", "application/json")
141 .send_json(&body)
142 .map_err(|e| InnateError::Other(format!("LLM HTTP error: {e}")))?;
143
144 let resp_json: Value = response
145 .into_json()
146 .map_err(|e| InnateError::Other(format!("LLM response parse error: {e}")))?;
147
148 resp_json
149 .pointer("/choices/0/message/content")
150 .and_then(Value::as_str)
151 .map(str::to_string)
152 .ok_or_else(|| InnateError::Other("unexpected LLM response shape".into()))
153 }
154}
155
156impl Distiller for OpenAiDistiller {
157 fn distill(&self, log_entries: &[Value]) -> crate::errors::Result<Vec<DistilledChunk>> {
158 distill_with(log_entries, |prompt| self.call(prompt))
159 }
160
161 fn distill_with_context(
162 &self,
163 primary: &Value,
164 related_logs: &[Value],
165 ) -> crate::errors::Result<Vec<DistilledChunk>> {
166 distill_entry_with(primary, related_logs, |prompt| self.call(prompt))
167 }
168
169 fn provenance(&self) -> DistillProvenance {
170 DistillProvenance {
171 provider: Some(self.config.provider.clone()),
172 model: Some(self.config.model_id.clone()),
173 prompt_version: Some(DISTILL_PROMPT_VERSION.to_string()),
174 }
175 }
176}
177
178pub struct AnthropicDistiller {
183 config: LlmConfig,
184}
185
186impl AnthropicDistiller {
187 pub fn new(config: LlmConfig) -> Self {
188 Self { config }
189 }
190
191 fn call(&self, prompt: &str) -> Result<String> {
192 let api_key = self
193 .config
194 .resolved_api_key()
195 .ok_or_else(|| InnateError::Other("Anthropic API key not configured".into()))?;
196
197 let base = self.config.resolved_base_url();
198 let url = format!("{base}/v1/messages");
199
200 let body = json!({
201 "model": self.config.model_id,
202 "max_tokens": 800,
203 "messages": [{"role": "user", "content": prompt}],
204 });
205
206 let response = ureq::post(&url)
207 .timeout(Duration::from_secs(30))
208 .set("x-api-key", &api_key)
209 .set("anthropic-version", "2023-06-01")
210 .set("Content-Type", "application/json")
211 .send_json(&body)
212 .map_err(|e| InnateError::Other(format!("Anthropic HTTP error: {e}")))?;
213
214 let resp_json: Value = response
215 .into_json()
216 .map_err(|e| InnateError::Other(format!("Anthropic response parse error: {e}")))?;
217
218 resp_json
219 .pointer("/content/0/text")
220 .and_then(Value::as_str)
221 .map(str::to_string)
222 .ok_or_else(|| InnateError::Other("unexpected Anthropic response shape".into()))
223 }
224}
225
226impl Distiller for AnthropicDistiller {
227 fn distill(&self, log_entries: &[Value]) -> crate::errors::Result<Vec<DistilledChunk>> {
228 distill_with(log_entries, |prompt| self.call(prompt))
229 }
230
231 fn distill_with_context(
232 &self,
233 primary: &Value,
234 related_logs: &[Value],
235 ) -> crate::errors::Result<Vec<DistilledChunk>> {
236 distill_entry_with(primary, related_logs, |prompt| self.call(prompt))
237 }
238
239 fn provenance(&self) -> DistillProvenance {
240 DistillProvenance {
241 provider: Some(self.config.provider.clone()),
242 model: Some(self.config.model_id.clone()),
243 prompt_version: Some(DISTILL_PROMPT_VERSION.to_string()),
244 }
245 }
246}
247
248fn distill_with(
253 log_entries: &[Value],
254 call: impl Fn(&str) -> Result<String> + Copy,
255) -> Result<Vec<DistilledChunk>> {
256 let mut out = Vec::new();
257 for entry in log_entries {
258 out.extend(distill_entry_with(entry, log_entries, call)?);
259 }
260 Ok(out)
261}
262
263fn distill_entry_with(
264 entry: &Value,
265 related_logs: &[Value],
266 call: impl Fn(&str) -> Result<String>,
267) -> Result<Vec<DistilledChunk>> {
268 let log_id = entry["id"].as_str().unwrap_or("").to_string();
269 let prompt = build_distill_prompt_with_related(entry, related_logs);
270 let mut raw = call(&prompt)?;
271 let mut parsed = parse_distill_response(&raw);
272 if parsed.is_err() {
273 raw = call(&format!(
274 "{prompt}\n\nYour previous response was invalid. Return only a valid JSON array."
275 ))?;
276 parsed = parse_distill_response(&raw);
277 }
278 let items = parsed
279 .map_err(|error| InnateError::Other(format!("LLM distillation response invalid: {error}")))?;
280 let mut out = Vec::new();
281 for parsed in items {
282 let content = parsed
283 .get("content")
284 .and_then(Value::as_str)
285 .map(str::trim)
286 .filter(|s| !s.is_empty());
287 let Some(content) = content else { continue };
288 let trigger_desc = parsed
289 .get("trigger_desc")
290 .and_then(Value::as_str)
291 .map(str::to_string)
292 .filter(|s| !s.is_empty());
293 let anti_trigger_desc = parsed
294 .get("anti_trigger_desc")
295 .and_then(Value::as_str)
296 .map(str::to_string)
297 .filter(|s| !s.is_empty() && s.to_lowercase() != "null");
298 out.push(DistilledChunk {
299 content: content.to_string(),
300 trigger_desc,
301 anti_trigger_desc,
302 source_log_id: log_id.clone(),
303 nomination: entry
304 .get("nomination")
305 .and_then(Value::as_str)
306 .map(str::to_string),
307 });
308 }
309 Ok(out)
310}
311
312fn parse_distill_response(raw: &str) -> std::result::Result<Vec<Value>, String> {
313 let json_str = extract_json(raw);
314 let parsed: Value = serde_json::from_str(json_str.trim()).map_err(|e| e.to_string())?;
315 if parsed.get("skip").and_then(Value::as_bool) == Some(true) {
316 return Ok(vec![]);
317 }
318 match parsed {
319 Value::Array(items) => Ok(items),
320 Value::Object(_) => Ok(vec![parsed]),
321 _ => Err("expected a JSON object or array".to_string()),
322 }
323}
324
/// Locates the JSON payload inside a model reply.
///
/// Resolution order:
/// 1. If the trimmed reply starts with a Markdown fence (```` ```json ```` or
///    ```` ``` ````) and contains a closing fence, the trimmed text between
///    the fences is returned.
/// 2. Otherwise the span from the first opening delimiter to the last
///    matching closing delimiter is returned. Arrays (`[`…`]`) are tried
///    first, then objects (`{`…`}`), except when the reply itself starts with
///    `{`, in which case the object span is tried first so that brackets
///    nested inside an object do not truncate it.
/// 3. If no well-ordered delimiter pair exists, the trimmed reply is returned
///    as is.
///
/// A delimiter pair is only used when the opener precedes the closer, so
/// replies such as `"] x ["` no longer cause an out-of-order slice panic.
fn extract_json(text: &str) -> &str {
    /// Returns `text[first open ..= last close]` when the opener comes first.
    fn delimited_span(text: &str, open: char, close: char) -> Option<&str> {
        let start = text.find(open)?;
        let end = text.rfind(close)?;
        // Both delimiters are single-byte ASCII, so `..=end` is a valid boundary.
        (start < end).then(|| &text[start..=end])
    }

    let stripped = text.trim();

    if let Some(inner) = stripped
        .strip_prefix("```json")
        .or_else(|| stripped.strip_prefix("```"))
    {
        if let Some(end) = inner.rfind("```") {
            return inner[..end].trim();
        }
    }

    let preference = if stripped.starts_with('{') {
        [('{', '}'), ('[', ']')]
    } else {
        [('[', ']'), ('{', '}')]
    };

    preference
        .iter()
        .find_map(|&(open, close)| delimited_span(stripped, open, close))
        .unwrap_or(stripped)
}
345
346pub fn build_distiller(
351 config: &LlmConfig,
352) -> std::sync::Arc<dyn Distiller + Send + Sync> {
353 match config.provider.as_str() {
354 "anthropic" => std::sync::Arc::new(AnthropicDistiller::new(config.clone())),
355 _ => std::sync::Arc::new(OpenAiDistiller::new(config.clone())),
356 }
357}
358
/// Embedding provider backed by a remote HTTP embeddings endpoint; holds the
/// embedding configuration used for each request.
pub struct LlmEmbeddingProvider {
    config: EmbeddingConfig,
}
366
#[cfg(test)]
#[allow(clippy::items_after_test_module)]
mod tests {
    use std::cell::Cell;

    use serde_json::json;

    use super::{
        build_distill_prompt, distill_entry_with, distill_with, parse_distill_response,
    };

    /// Secret-looking substrings in log fields must not reach the prompt.
    #[test]
    fn prompt_redacts_secrets_before_external_llm_call() {
        let log = json!({
            "query": "debug sk-12345678901234567890",
            "output_summary": "Authorization: Bearer secret-token-value"
        });
        let prompt = build_distill_prompt(&log);

        for secret in ["sk-12345678901234567890", "secret-token-value"].iter() {
            assert!(!prompt.contains(secret));
        }
        assert!(prompt.contains("[REDACTED]"));
    }

    /// An unparseable first reply triggers exactly one retry.
    #[test]
    fn malformed_response_is_retried_instead_of_silently_skipped() {
        let attempts = Cell::new(0);
        let logs = [json!({"id": "log-1", "query": "q"})];

        let chunks = distill_with(&logs, |_| {
            let attempt = attempts.get() + 1;
            attempts.set(attempt);
            let reply = if attempt == 1 {
                "not json"
            } else {
                r#"[{"content":"retry worked","trigger_desc":"retry"}]"#
            };
            Ok(reply.to_string())
        })
        .unwrap();

        assert_eq!(attempts.get(), 2);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, "retry worked");
    }

    /// A JSON array reply yields one parsed item per element.
    #[test]
    fn parser_accepts_multiple_distilled_chunks() {
        let reply = r#"[{"content":"one"},{"content":"two","anti_trigger_desc":"never"}]"#;
        let items = parse_distill_response(reply).unwrap();
        assert_eq!(items.len(), 2);
    }

    /// The nomination is sent to the model and also preserved on the chunk.
    #[test]
    fn nomination_is_distilled_instead_of_bypassing_the_model() {
        let entry = json!({
            "id": "log-1",
            "query": "original query",
            "nomination": "raw agent nomination",
            "output_summary": "summary",
            "outcome": "ok"
        });
        let saw_nomination = Cell::new(false);

        let chunks = distill_entry_with(&entry, std::slice::from_ref(&entry), |prompt| {
            saw_nomination.set(prompt.contains("raw agent nomination"));
            Ok(
                r#"[{"content":"generalized principle","trigger_desc":"generalize","anti_trigger_desc":null}]"#
                    .to_string(),
            )
        })
        .unwrap();

        assert!(saw_nomination.get());
        assert_eq!(chunks[0].content, "generalized principle");
        assert_eq!(
            chunks[0].nomination.as_deref(),
            Some("raw agent nomination")
        );
    }
}
442
443impl LlmEmbeddingProvider {
444 pub fn new(config: EmbeddingConfig) -> Self {
445 Self { config }
446 }
447
448 fn embed(&self, text: &str) -> Result<Vec<f32>> {
449 let api_key = self
450 .config
451 .resolved_api_key()
452 .ok_or_else(|| InnateError::Other("Embedding API key not configured".into()))?;
453
454 let base = self.config.resolved_base_url();
455 let url = format!("{base}/embeddings");
456
457 let body = json!({
458 "input": text,
459 "model": self.config.model_id,
460 });
461
462 let response = ureq::post(&url)
463 .set("Authorization", &format!("Bearer {api_key}"))
464 .set("Content-Type", "application/json")
465 .send_json(&body)
466 .map_err(|e| InnateError::Other(format!("Embedding HTTP error: {e}")))?;
467
468 let resp_json: Value = response
469 .into_json()
470 .map_err(|e| InnateError::Other(format!("Embedding response parse: {e}")))?;
471
472 let embedding = resp_json
473 .pointer("/data/0/embedding")
474 .and_then(Value::as_array)
475 .ok_or_else(|| InnateError::Other("unexpected embedding response shape".into()))?;
476
477 Ok(embedding
478 .iter()
479 .filter_map(Value::as_f64)
480 .map(|x| x as f32)
481 .collect())
482 }
483}
484
485impl EmbeddingProvider for LlmEmbeddingProvider {
486 fn model_name(&self) -> &'static str {
487 "llm-embedding"
488 }
489
490 fn content_dim(&self) -> usize {
491 self.config.dim
492 }
493
494 fn trigger_dim(&self) -> usize {
495 self.config.dim
496 }
497
498 fn embed_content(&self, text: &str) -> Result<Vec<f32>> {
499 self.embed(text)
500 }
501
502 fn embed_trigger(&self, text: &str) -> Result<Vec<f32>> {
503 self.embed(text)
504 }
505}
506
507pub fn test_llm(config: &LlmConfig) -> Result<String> {
513 let distiller = build_distiller(config);
514 let dummy_log = json!({
515 "id": "test",
516 "query": "connection test",
517 "output_summary": "test",
518 "outcome": "ok"
519 });
520 distiller.distill(&[dummy_log])?;
522 Ok(format!("OK — model: {}", config.model_id))
523}
524
525pub fn test_embedding(config: &EmbeddingConfig) -> Result<usize> {
527 let provider = LlmEmbeddingProvider::new(config.clone());
528 let vec = provider.embed("connection test")?;
529 Ok(vec.len())
530}