1use serde_json::{json, Value};
2use std::time::Duration;
3
4use crate::embedding::EmbeddingProvider;
5use crate::errors::{InnateError, Result};
6use crate::refine::{DistillProvenance, DistilledChunk, Distiller, Reranker};
7use crate::settings::{EmbeddingConfig, LlmConfig};
8
/// Version tag of the distillation prompt; reported as `prompt_version` in the
/// `DistillProvenance` returned by `HttpDistiller::provenance`.
// NOTE(review): presumably meant to be bumped whenever the prompt text changes — confirm.
const DISTILL_PROMPT_VERSION: &str = "4";
14
15fn safe_prompt_field(value: Option<&str>) -> String {
16 let value = value.unwrap_or("");
17 let (cleaned, action) = crate::utils::sanitize(value);
18 match action {
19 crate::utils::SanitizeAction::Discard => "[removed unsafe content]".to_string(),
20 _ => cleaned,
21 }
22}
23
24fn build_distill_prompt(log: &Value) -> String {
25 let query = safe_prompt_field(log.get("query").and_then(Value::as_str));
26 let output = safe_prompt_field(log.get("output").and_then(Value::as_str));
27 let output_summary = safe_prompt_field(log.get("output_summary").and_then(Value::as_str));
28 let nomination = safe_prompt_field(log.get("nomination").and_then(Value::as_str));
29 let outcome = safe_prompt_field(log.get("outcome").and_then(Value::as_str));
30
31 let mut context_parts = vec![];
32 if !query.is_empty() {
33 context_parts.push(format!("Query: {query}"));
34 }
35 if !nomination.is_empty() {
36 context_parts.push(format!("Nominated insight: {nomination}"));
37 }
38 if !output_summary.is_empty() {
39 context_parts.push(format!("Summary: {output_summary}"));
40 }
41 if !output.is_empty() {
42 let truncated: String = output.chars().take(1500).collect();
43 context_parts.push(format!("Output (truncated): {truncated}"));
44 }
45 if !outcome.is_empty() {
46 context_parts.push(format!("Outcome: {outcome}"));
47 }
48
49 let context = context_parts.join("\n");
50
51 format!(
52 r#"You are a knowledge distillation assistant. Given an agent interaction log, \
53extract zero or more independent reusable procedural principles. Favor GENERAL, \
54transferable skills, methods, and techniques over project-specific facts.
55
56Agent interaction:
57{context}
58
59Output a JSON array. Each item has:
60{{
61 "skill_name": "<1-3 word skill/topic label for this principle>",
62 "content": "<principle; when it applies; what to avoid>",
63 "trigger_desc": "<2-6 word canonical phrase>",
64 "anti_trigger_desc": "<when NOT to apply this, or null>"
65}}
66Return [] if nothing is worth keeping.
67
68Rules:
69- skill_name is a short human label (1-3 words) naming the skill/topic, e.g.
70 "error handling", "git rebase", "async retries"; not a sentence
71- content must be self-contained and actionable for a future agent reading cold
72- Prefer transferable methods and techniques; a principle that helps across many
73 projects is worth far more than one tied to this codebase
74- Abstract away project-specific detail: strip repo/file/function/path/variable names
75 and one-off identifiers, and rephrase the lesson as a general principle whoever the
76 next project is. Keep concrete project-specific detail ONLY when the lesson genuinely
77 cannot be generalized without losing its meaning
78- trigger_desc must match the vocabulary a future agent would use in a search query;
79 prefer general, technology- or domain-level phrasing over project-name phrasing
80- Never store conversation text verbatim; always distil to reusable principle form
81- If outcome is "fail", focus on what to avoid
82- Keep principles independent; do not combine unrelated lessons"#
83 )
84}
85
86fn build_distill_prompt_with_related(log: &Value, logs: &[Value]) -> String {
87 let mut prompt = build_distill_prompt(log);
88 let log_id = log.get("id").and_then(Value::as_str).unwrap_or("");
89 let context_key = log.get("context_key").and_then(Value::as_str);
90 let related: Vec<String> = logs
91 .iter()
92 .filter(|other| other.get("id").and_then(Value::as_str).unwrap_or("") != log_id)
93 .filter(|other| {
94 context_key.is_some() && other.get("context_key").and_then(Value::as_str) == context_key
95 })
96 .take(4)
97 .map(|other| {
98 let query = safe_prompt_field(other.get("query").and_then(Value::as_str));
99 let summary = safe_prompt_field(other.get("output_summary").and_then(Value::as_str));
100 let outcome = safe_prompt_field(other.get("outcome").and_then(Value::as_str));
101 format!("- Query: {query}; outcome: {outcome}; summary: {summary}")
102 })
103 .collect();
104 if !related.is_empty() {
105 prompt.push_str(
106 "\n\nRelated recent interactions (use only to identify repeated patterns or conflicts):\n",
107 );
108 prompt.push_str(&related.join("\n"));
109 }
110 prompt
111}
112
/// Upper bound on the number of attempts `post_json_retry` makes for one
/// request, the first try included.
const HTTP_MAX_ATTEMPTS: u32 = 3;
/// Timeout handed to the HTTP client via `timeout_global` for every attempt.
const HTTP_TIMEOUT: Duration = Duration::from_secs(30);
121
/// POSTs `body` as JSON to `url` and returns the response parsed as JSON,
/// retrying failures a bounded number of times.
///
/// * `url` – endpoint to call.
/// * `headers` – extra request headers, added after `Content-Type: application/json`.
/// * `body` – JSON payload to send.
/// * `label` – short name used as the prefix of error messages and handed to
///   the trace recorder.
///
/// At most `HTTP_MAX_ATTEMPTS` attempts are made. A non-2xx response is retried
/// only when `status_is_retryable` accepts its status code; a transport error
/// is retried whenever attempts remain. Between attempts the current thread
/// sleeps for `backoff_delay`, which receives the `retry-after` header value
/// when that header parses as a whole number.
///
/// The final outcome, the number of attempts and the elapsed time are passed to
/// `crate::llm_trace::record` exactly once before returning.
///
/// # Errors
/// Returns `InnateError::Other` when a 2xx body cannot be parsed as JSON, when
/// the last response has a non-2xx status (the message carries the status code
/// and the response body), or when the last attempt fails at the transport level.
fn post_json_retry(
    url: &str,
    headers: &[(&str, &str)],
    body: &Value,
    label: &str,
) -> Result<Value> {
    let start = std::time::Instant::now();
    // Counts attempts actually made; also reported to the trace recorder.
    let mut attempt = 0;
    let outcome: Result<Value> = loop {
        attempt += 1;
        // The request is rebuilt for every attempt.
        // NOTE(review): `http_status_as_error(false)` presumably makes non-2xx
        // responses arrive in the `Ok` arm below, which inspects the status
        // itself — confirm against the ureq version in use.
        let mut req = ureq::post(url)
            .config()
            .timeout_global(Some(HTTP_TIMEOUT))
            .http_status_as_error(false)
            .build()
            .header("Content-Type", "application/json");
        for (k, v) in headers {
            req = req.header(*k, *v);
        }
        match req.send_json(body) {
            Ok(mut response) => {
                let code = response.status().as_u16();
                if (200..300).contains(&code) {
                    // Success: a body that fails to parse ends the loop with an
                    // error; it is not retried.
                    break response.body_mut().read_json::<Value>().map_err(|e| {
                        InnateError::Other(format!("{label} response parse error: {e}"))
                    });
                }
                // Only a whole-number `retry-after` value is used; anything
                // else falls back to the computed backoff.
                let retry_after = response
                    .headers()
                    .get("retry-after")
                    .and_then(|h| h.to_str().ok())
                    .and_then(|s| s.trim().parse::<u64>().ok());
                if status_is_retryable(code) && attempt < HTTP_MAX_ATTEMPTS {
                    std::thread::sleep(backoff_delay(attempt, retry_after));
                    continue;
                }
                // Final failure: include whatever body text could be read
                // (empty if reading fails) to aid diagnosis.
                let detail = response.body_mut().read_to_string().unwrap_or_default();
                break Err(InnateError::Other(format!(
                    "{label} HTTP error: status: {code} {detail}"
                )));
            }
            Err(err) => {
                // Transport-level failure: retried while attempts remain.
                if attempt < HTTP_MAX_ATTEMPTS {
                    std::thread::sleep(backoff_delay(attempt, None));
                    continue;
                }
                break Err(InnateError::Other(format!(
                    "{label} HTTP error: transport: {err}"
                )));
            }
        }
    };
    // Record once per call, after the retry loop has settled.
    crate::llm_trace::record(label, url, body, &outcome, attempt, start.elapsed());
    outcome
}
191
/// Reports whether an HTTP status code is worth retrying: `429` or any code in
/// the `500..=599` range.
fn status_is_retryable(code: u16) -> bool {
    matches!(code, 429 | 500..=599)
}
196
/// Computes how long to wait before the next attempt.
///
/// * `attempt` – 1-based number of the attempt that just failed. `0` is
///   tolerated and treated like `1`.
/// * `retry_after_secs` – server-provided delay in seconds, if any.
///
/// A server-provided delay wins and is capped at 30 seconds. Otherwise the
/// delay starts at 250 ms and doubles per attempt, up to 250 ms × 2⁶ = 16 s.
fn backoff_delay(attempt: u32, retry_after_secs: Option<u64>) -> Duration {
    const MAX_RETRY_AFTER_SECS: u64 = 30;
    const BASE_DELAY_MS: u64 = 250;
    const MAX_DOUBLINGS: u32 = 6;

    if let Some(secs) = retry_after_secs {
        return Duration::from_secs(secs.min(MAX_RETRY_AFTER_SECS));
    }
    // `saturating_sub` keeps `attempt == 0` from underflowing, which would
    // panic in debug builds and silently wrap in release builds.
    let doublings = attempt.saturating_sub(1).min(MAX_DOUBLINGS);
    // At most 250 << 6 = 16_000, so the shift cannot overflow.
    Duration::from_millis(BASE_DELAY_MS << doublings)
}
206
/// Distiller that sends prompts to an LLM over HTTP.
///
/// Depending on the configured provider it talks either to the Anthropic
/// messages endpoint or to an OpenAI-style chat-completions endpoint.
pub struct HttpDistiller {
    /// Provider, model id, base URL and API key settings used for every call.
    config: LlmConfig,
}
217
218impl HttpDistiller {
219 pub fn new(config: LlmConfig) -> Self {
220 Self { config }
221 }
222
223 pub fn call(&self, prompt: &str) -> Result<String> {
226 if self.config.provider == "anthropic" {
227 self.call_anthropic(prompt)
228 } else {
229 self.call_openai(prompt)
230 }
231 }
232
233 fn call_openai(&self, prompt: &str) -> Result<String> {
234 let api_key = self
235 .config
236 .resolved_api_key()
237 .ok_or_else(|| InnateError::Other("LLM API key not configured".into()))?;
238
239 let base = self.config.resolved_base_url();
240 let url = format!("{base}/chat/completions");
241
242 let body = json!({
243 "model": self.config.model_id,
244 "messages": [{"role": "user", "content": prompt}],
245 "max_tokens": 800,
246 "temperature": 0.2,
247 });
248
249 let auth = format!("Bearer {api_key}");
250 let resp_json = post_json_retry(&url, &[("Authorization", &auth)], &body, "LLM")?;
251
252 resp_json
253 .pointer("/choices/0/message/content")
254 .and_then(Value::as_str)
255 .map(str::to_string)
256 .ok_or_else(|| InnateError::Other("unexpected LLM response shape".into()))
257 }
258
259 fn call_anthropic(&self, prompt: &str) -> Result<String> {
260 let api_key = self
261 .config
262 .resolved_api_key()
263 .ok_or_else(|| InnateError::Other("Anthropic API key not configured".into()))?;
264
265 let base = self.config.resolved_base_url();
266 let url = format!("{base}/v1/messages");
267
268 let body = json!({
269 "model": self.config.model_id,
270 "max_tokens": 800,
271 "messages": [{"role": "user", "content": prompt}],
272 });
273
274 let resp_json = post_json_retry(
275 &url,
276 &[("x-api-key", &api_key), ("anthropic-version", "2023-06-01")],
277 &body,
278 "Anthropic",
279 )?;
280
281 resp_json
282 .pointer("/content/0/text")
283 .and_then(Value::as_str)
284 .map(str::to_string)
285 .ok_or_else(|| InnateError::Other("unexpected Anthropic response shape".into()))
286 }
287}
288
289impl Distiller for HttpDistiller {
290 fn distill(&self, log_entries: &[Value]) -> crate::errors::Result<Vec<DistilledChunk>> {
291 distill_with(log_entries, |prompt| self.call(prompt))
292 }
293
294 fn distill_with_context(
295 &self,
296 primary: &Value,
297 related_logs: &[Value],
298 ) -> crate::errors::Result<Vec<DistilledChunk>> {
299 distill_entry_with(primary, related_logs, |prompt| self.call(prompt))
300 }
301
302 fn provenance(&self) -> DistillProvenance {
303 DistillProvenance {
304 provider: Some(self.config.provider.clone()),
305 model: Some(self.config.model_id.clone()),
306 prompt_version: Some(DISTILL_PROMPT_VERSION.to_string()),
307 }
308 }
309}
310
311fn distill_with(
316 log_entries: &[Value],
317 call: impl Fn(&str) -> Result<String> + Copy,
318) -> Result<Vec<DistilledChunk>> {
319 let mut out = Vec::new();
320 for entry in log_entries {
321 out.extend(distill_entry_with(entry, log_entries, call)?);
322 }
323 Ok(out)
324}
325
326fn distill_entry_with(
327 entry: &Value,
328 related_logs: &[Value],
329 call: impl Fn(&str) -> Result<String>,
330) -> Result<Vec<DistilledChunk>> {
331 let log_id = entry["id"].as_str().unwrap_or("").to_string();
332 let prompt = build_distill_prompt_with_related(entry, related_logs);
333 let mut raw = call(&prompt)?;
334 let mut parsed = parse_distill_response(&raw);
335 if parsed.is_err() {
336 raw = call(&format!(
337 "{prompt}\n\nYour previous response was invalid. Return only a valid JSON array."
338 ))?;
339 parsed = parse_distill_response(&raw);
340 }
341 let items = parsed.map_err(|error| {
342 InnateError::Other(format!("LLM distillation response invalid: {error}"))
343 })?;
344 let mut out = Vec::new();
345 for parsed in items {
346 let content = parsed
347 .get("content")
348 .and_then(Value::as_str)
349 .map(str::trim)
350 .filter(|s| !s.is_empty());
351 let Some(content) = content else { continue };
352 let skill_name = parsed
353 .get("skill_name")
354 .and_then(Value::as_str)
355 .map(|s| {
356 s.trim()
357 .split_whitespace()
358 .take(3)
359 .collect::<Vec<_>>()
360 .join(" ")
361 })
362 .filter(|s| !s.is_empty() && s.to_lowercase() != "null");
363 let trigger_desc = parsed
364 .get("trigger_desc")
365 .and_then(Value::as_str)
366 .map(str::to_string)
367 .filter(|s| !s.is_empty());
368 let anti_trigger_desc = parsed
369 .get("anti_trigger_desc")
370 .and_then(Value::as_str)
371 .map(str::to_string)
372 .filter(|s| !s.is_empty() && s.to_lowercase() != "null");
373 out.push(DistilledChunk {
374 content: content.to_string(),
375 skill_name,
376 trigger_desc,
377 anti_trigger_desc,
378 source_log_id: log_id.clone(),
379 nomination: entry
380 .get("nomination")
381 .and_then(Value::as_str)
382 .map(str::to_string),
383 provider_override: None,
384 });
385 }
386 Ok(out)
387}
388
389fn parse_distill_response(raw: &str) -> std::result::Result<Vec<Value>, String> {
390 let json_str = extract_json(raw);
391 let parsed: Value = serde_json::from_str(json_str.trim()).map_err(|e| e.to_string())?;
392 if parsed.get("skip").and_then(Value::as_bool) == Some(true) {
393 return Ok(vec![]);
394 }
395 match parsed {
396 Value::Array(items) => Ok(items),
397 Value::Object(_) => Ok(vec![parsed]),
398 _ => Err("expected a JSON object or array".to_string()),
399 }
400}
401
/// Isolates the JSON payload inside a model reply.
///
/// Tried in order:
/// 1. a reply that starts with a ```` ```json ```` or ```` ``` ```` fence and
///    contains a closing fence → the trimmed text between the fences;
/// 2. the span from the first `[` to the last `]`;
/// 3. the span from the first `{` to the last `}`;
/// 4. otherwise the trimmed reply itself.
///
/// A bracket pair is used only when the opener comes before the closer, so
/// replies such as `"] oops ["` fall through instead of panicking on an
/// inverted slice range.
fn extract_json(text: &str) -> &str {
    let stripped = text.trim();

    let fenced = stripped
        .strip_prefix("```json")
        .or_else(|| stripped.strip_prefix("```"));
    if let Some(inner) = fenced {
        if let Some(end) = inner.rfind("```") {
            return inner[..end].trim();
        }
    }

    // Arrays are tried first so that an array wrapped in prose or in an
    // enclosing object is still picked up.
    for (open, close) in [('[', ']'), ('{', '}')] {
        if let (Some(start), Some(end)) = (stripped.find(open), stripped.rfind(close)) {
            // Both delimiters are single-byte ASCII, so `start..=end` always
            // lands on char boundaries once the order is confirmed.
            if start < end {
                return &stripped[start..=end];
            }
        }
    }
    stripped
}
422
423pub fn build_distiller(config: &LlmConfig) -> std::sync::Arc<dyn Distiller + Send + Sync> {
428 std::sync::Arc::new(HttpDistiller::new(config.clone()))
429}
430
/// Reranker that asks an LLM to order candidate snippets by relevance to a query.
pub struct LlmReranker {
    /// HTTP client wrapper used to send the reranking prompt.
    inner: HttpDistiller,
}
440
441impl LlmReranker {
442 pub fn new(config: LlmConfig) -> Self {
443 Self {
444 inner: HttpDistiller::new(config),
445 }
446 }
447}
448
449impl Reranker for LlmReranker {
450 fn rerank(&self, query: &str, candidates: &[Value]) -> Result<Vec<String>> {
451 if candidates.is_empty() {
452 return Ok(Vec::new());
453 }
454 let mut list = String::new();
455 for c in candidates {
456 let id = c.get("id").and_then(Value::as_str).unwrap_or("");
457 let trig = c.get("trigger_desc").and_then(Value::as_str).unwrap_or("");
458 let content: String = c
459 .get("content")
460 .and_then(Value::as_str)
461 .unwrap_or("")
462 .chars()
463 .take(280)
464 .collect();
465 list.push_str(&format!("- id={id} | when={trig} | {content}\n"));
466 }
467 let prompt = format!(
468 "You are reranking knowledge snippets by how directly each one helps with the QUERY. \
469 Consider the snippet's `when` (trigger) and content. Return ONLY a JSON array of the \
470 ids, most relevant first, no prose, no ids that are not listed.\n\n\
471 QUERY: {query}\n\nCANDIDATES:\n{list}"
472 );
473 let resp = self.inner.call(&prompt)?;
474 parse_id_array(&resp)
475 .ok_or_else(|| InnateError::Other("reranker: no id array in LLM response".into()))
476 }
477}
478
479fn parse_id_array(resp: &str) -> Option<Vec<String>> {
481 let start = resp.find('[')?;
482 let end = resp.rfind(']')?;
483 if end <= start {
484 return None;
485 }
486 let arr: Value = serde_json::from_str(&resp[start..=end]).ok()?;
487 let ids: Vec<String> = arr
488 .as_array()?
489 .iter()
490 .filter_map(|v| v.as_str().map(str::to_string))
491 .collect();
492 if ids.is_empty() {
493 None
494 } else {
495 Some(ids)
496 }
497}
498
/// Embedding provider that fetches vectors from an HTTP `/embeddings` endpoint.
pub struct LlmEmbeddingProvider {
    /// Model id, base URL, API key and expected dimension used for every request.
    config: EmbeddingConfig,
}
506
/// Unit tests for the pure helpers of this file: response parsing, retry
/// policy, prompt construction and the distillation flow driven by a stubbed
/// model call.
#[cfg(test)]
// Further items (the `LlmEmbeddingProvider` impls and the connection-test
// helpers) are defined after this module in the file.
#[allow(clippy::items_after_test_module)]
mod tests {
    use std::cell::Cell;

    use serde_json::json;

    use std::time::Duration;

    use super::{
        backoff_delay, build_distill_prompt, distill_entry_with, distill_with,
        parse_distill_response, parse_embedding_response, status_is_retryable,
    };

    /// A well-formed response parses; every malformed variant is an error.
    #[test]
    fn embedding_response_is_parsed_fail_closed() {
        let resp = json!({"data": [{"embedding": [0.1, 0.2, 0.3]}]});
        // Happy path: the dimension matches.
        assert_eq!(
            parse_embedding_response(&resp, 3).unwrap(),
            vec![0.1f32, 0.2, 0.3]
        );

        // Dimension mismatch is rejected.
        assert!(parse_embedding_response(&resp, 4).is_err());

        // A non-numeric element is rejected.
        let bad = json!({"data": [{"embedding": [0.1, "oops", 0.3]}]});
        assert!(parse_embedding_response(&bad, 3).is_err());

        // A missing `/data/0/embedding` is rejected.
        let shape = json!({"data": []});
        assert!(parse_embedding_response(&shape, 3).is_err());
    }

    /// Only 429 and the 5xx range are retried.
    #[test]
    fn only_rate_limit_and_5xx_are_retryable() {
        assert!(status_is_retryable(429));
        assert!(status_is_retryable(500));
        assert!(status_is_retryable(503));
        assert!(status_is_retryable(599));
        assert!(!status_is_retryable(400));
        assert!(!status_is_retryable(401));
        assert!(!status_is_retryable(404));
        assert!(!status_is_retryable(200));
    }

    /// The delay doubles per attempt, and a server-provided delay is used but capped.
    #[test]
    fn backoff_is_exponential_and_honors_retry_after() {
        assert_eq!(backoff_delay(1, None), Duration::from_millis(250));
        assert_eq!(backoff_delay(2, None), Duration::from_millis(500));
        assert_eq!(backoff_delay(3, None), Duration::from_millis(1000));
        // Server-provided delay wins over the computed backoff...
        assert_eq!(backoff_delay(1, Some(5)), Duration::from_secs(5));
        // ...but is capped at 30 seconds.
        assert_eq!(backoff_delay(1, Some(120)), Duration::from_secs(30));
    }

    /// Secret-looking values in log fields must not reach the prompt text.
    #[test]
    fn prompt_redacts_secrets_before_external_llm_call() {
        let prompt = build_distill_prompt(&json!({
            "query": "debug sk-12345678901234567890",
            "output_summary": "Authorization: Bearer secret-token-value"
        }));
        assert!(!prompt.contains("sk-12345678901234567890"));
        assert!(!prompt.contains("secret-token-value"));
        assert!(prompt.contains("[REDACTED]"));
    }

    /// An unparseable first reply triggers exactly one retry of the model call.
    #[test]
    fn malformed_response_is_retried_instead_of_silently_skipped() {
        let calls = Cell::new(0);
        let chunks = distill_with(&[json!({"id": "log-1", "query": "q"})], |_| {
            calls.set(calls.get() + 1);
            if calls.get() == 1 {
                Ok("not json".to_string())
            } else {
                Ok(r#"[{"content":"retry worked","trigger_desc":"retry"}]"#.to_string())
            }
        })
        .unwrap();
        assert_eq!(calls.get(), 2);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, "retry worked");
    }

    /// A JSON array reply yields one parsed item per element.
    #[test]
    fn parser_accepts_multiple_distilled_chunks() {
        let parsed = parse_distill_response(
            r#"[{"content":"one"},{"content":"two","anti_trigger_desc":"never"}]"#,
        )
        .unwrap();
        assert_eq!(parsed.len(), 2);
    }

    /// The nomination is shown to the model and copied onto the chunk, while
    /// the chunk content comes from the model's reply.
    #[test]
    fn nomination_is_distilled_instead_of_bypassing_the_model() {
        let prompt_seen = Cell::new(false);
        let entry = json!({
            "id": "log-1",
            "query": "original query",
            "nomination": "raw agent nomination",
            "output_summary": "summary",
            "outcome": "ok"
        });
        let chunks = distill_entry_with(&entry, std::slice::from_ref(&entry), |prompt| {
            prompt_seen.set(prompt.contains("raw agent nomination"));
            Ok(
                r#"[{"content":"generalized principle","trigger_desc":"generalize","anti_trigger_desc":null}]"#
                    .to_string(),
            )
        })
        .unwrap();

        assert!(prompt_seen.get());
        assert_eq!(chunks[0].content, "generalized principle");
        assert_eq!(
            chunks[0].nomination.as_deref(),
            Some("raw agent nomination")
        );
    }
}
629
630impl LlmEmbeddingProvider {
631 pub fn new(config: EmbeddingConfig) -> Self {
632 Self { config }
633 }
634
635 fn embed(&self, text: &str) -> Result<Vec<f32>> {
636 let api_key = self
637 .config
638 .resolved_api_key()
639 .ok_or_else(|| InnateError::Other("Embedding API key not configured".into()))?;
640
641 let base = self.config.resolved_base_url();
642 let url = format!("{base}/embeddings");
643
644 let body = json!({
645 "input": text,
646 "model": self.config.model_id,
647 });
648
649 let auth = format!("Bearer {api_key}");
650 let resp_json = post_json_retry(&url, &[("Authorization", &auth)], &body, "Embedding")?;
651
652 parse_embedding_response(&resp_json, self.config.dim)
653 }
654}
655
656fn parse_embedding_response(resp_json: &Value, expected_dim: usize) -> Result<Vec<f32>> {
662 let embedding = resp_json
663 .pointer("/data/0/embedding")
664 .and_then(Value::as_array)
665 .ok_or_else(|| InnateError::Other("unexpected embedding response shape".into()))?;
666 let vec: Vec<f32> = embedding
667 .iter()
668 .map(|v| {
669 v.as_f64().map(|x| x as f32).ok_or_else(|| {
670 InnateError::Other("embedding response contains a non-numeric element".into())
671 })
672 })
673 .collect::<Result<Vec<f32>>>()?;
674 if vec.len() != expected_dim {
675 return Err(InnateError::Other(format!(
676 "embedding dimension mismatch: provider returned {}, expected {expected_dim} (check embedding.dim)",
677 vec.len(),
678 )));
679 }
680 Ok(vec)
681}
682
683impl EmbeddingProvider for LlmEmbeddingProvider {
684 fn model_name(&self) -> &'static str {
685 "llm-embedding"
686 }
687
688 fn content_dim(&self) -> usize {
689 self.config.dim
690 }
691
692 fn trigger_dim(&self) -> usize {
693 self.config.dim
694 }
695
696 fn embed_content(&self, text: &str) -> Result<Vec<f32>> {
697 self.embed(text)
698 }
699
700 fn embed_trigger(&self, text: &str) -> Result<Vec<f32>> {
701 self.embed(text)
702 }
703
704 fn embed_both(&self, text: &str) -> Result<(Vec<f32>, Vec<f32>)> {
707 let v = self.embed(text)?;
708 Ok((v.clone(), v))
709 }
710}
711
712pub fn test_llm(config: &LlmConfig) -> Result<String> {
718 let distiller = build_distiller(config);
719 let dummy_log = json!({
720 "id": "test",
721 "query": "connection test",
722 "output_summary": "test",
723 "outcome": "ok"
724 });
725 distiller.distill(&[dummy_log])?;
727 Ok(format!("OK — model: {}", config.model_id))
728}
729
730pub fn test_embedding(config: &EmbeddingConfig) -> Result<usize> {
732 let provider = LlmEmbeddingProvider::new(config.clone());
733 let vec = provider.embed("connection test")?;
734 Ok(vec.len())
735}