1use serde_json::{json, Value};
2use std::time::Duration;
3
4use crate::embedding::EmbeddingProvider;
5use crate::errors::{InnateError, Result};
6use crate::refine::{DistillProvenance, DistilledChunk, Distiller};
7use crate::settings::{EmbeddingConfig, LlmConfig};
8
/// Version tag for the distillation prompt text built in this module.
///
/// Reported as `prompt_version` in the [`DistillProvenance`] returned by
/// `HttpDistiller::provenance`. NOTE(review): bump this whenever the prompt wording changes.
const DISTILL_PROMPT_VERSION: &str = "2";
14
15fn safe_prompt_field(value: Option<&str>) -> String {
16 let value = value.unwrap_or("");
17 let (cleaned, action) = crate::utils::sanitize(value);
18 match action {
19 crate::utils::SanitizeAction::Discard => "[removed unsafe content]".to_string(),
20 _ => cleaned,
21 }
22}
23
24fn build_distill_prompt(log: &Value) -> String {
25 let query = safe_prompt_field(log.get("query").and_then(Value::as_str));
26 let output = safe_prompt_field(log.get("output").and_then(Value::as_str));
27 let output_summary = safe_prompt_field(log.get("output_summary").and_then(Value::as_str));
28 let nomination = safe_prompt_field(log.get("nomination").and_then(Value::as_str));
29 let outcome = safe_prompt_field(log.get("outcome").and_then(Value::as_str));
30
31 let mut context_parts = vec![];
32 if !query.is_empty() {
33 context_parts.push(format!("Query: {query}"));
34 }
35 if !nomination.is_empty() {
36 context_parts.push(format!("Nominated insight: {nomination}"));
37 }
38 if !output_summary.is_empty() {
39 context_parts.push(format!("Summary: {output_summary}"));
40 }
41 if !output.is_empty() {
42 let truncated: String = output.chars().take(1500).collect();
43 context_parts.push(format!("Output (truncated): {truncated}"));
44 }
45 if !outcome.is_empty() {
46 context_parts.push(format!("Outcome: {outcome}"));
47 }
48
49 let context = context_parts.join("\n");
50
51 format!(
52 r#"You are a knowledge distillation assistant. Given an agent interaction log, \
53extract zero or more independent reusable procedural principles.
54
55Agent interaction:
56{context}
57
58Output a JSON array. Each item has:
59{{
60 "content": "<principle; when it applies; what to avoid>",
61 "trigger_desc": "<2-6 word canonical phrase>",
62 "anti_trigger_desc": "<when NOT to apply this, or null>"
63}}
64Return [] if nothing is worth keeping.
65
66Rules:
67- content must be self-contained and actionable for a future agent reading cold
68- trigger_desc must match the vocabulary a future agent would use in a search query
69- Never store conversation text verbatim; always distil to reusable principle form
70- If outcome is "fail", focus on what to avoid
71- Keep principles independent; do not combine unrelated lessons"#
72 )
73}
74
75fn build_distill_prompt_with_related(log: &Value, logs: &[Value]) -> String {
76 let mut prompt = build_distill_prompt(log);
77 let log_id = log.get("id").and_then(Value::as_str).unwrap_or("");
78 let context_key = log.get("context_key").and_then(Value::as_str);
79 let related: Vec<String> = logs
80 .iter()
81 .filter(|other| other.get("id").and_then(Value::as_str).unwrap_or("") != log_id)
82 .filter(|other| {
83 context_key.is_some() && other.get("context_key").and_then(Value::as_str) == context_key
84 })
85 .take(4)
86 .map(|other| {
87 let query = safe_prompt_field(other.get("query").and_then(Value::as_str));
88 let summary = safe_prompt_field(other.get("output_summary").and_then(Value::as_str));
89 let outcome = safe_prompt_field(other.get("outcome").and_then(Value::as_str));
90 format!("- Query: {query}; outcome: {outcome}; summary: {summary}")
91 })
92 .collect();
93 if !related.is_empty() {
94 prompt.push_str(
95 "\n\nRelated recent interactions (use only to identify repeated patterns or conflicts):\n",
96 );
97 prompt.push_str(&related.join("\n"));
98 }
99 prompt
100}
101
/// Total number of attempts (the first request plus retries) made by `post_json_retry`.
const HTTP_MAX_ATTEMPTS: u32 = 3;
/// Timeout applied to each individual HTTP request sent by `post_json_retry`.
const HTTP_TIMEOUT: Duration = Duration::from_secs(30);
110
111fn post_json_retry(
116 url: &str,
117 headers: &[(&str, &str)],
118 body: &Value,
119 label: &str,
120) -> Result<Value> {
121 let mut attempt = 0;
122 loop {
123 attempt += 1;
124 let mut req = ureq::post(url)
125 .timeout(HTTP_TIMEOUT)
126 .set("Content-Type", "application/json");
127 for (k, v) in headers {
128 req = req.set(k, v);
129 }
130 match req.send_json(body) {
131 Ok(response) => {
132 return response
133 .into_json()
134 .map_err(|e| InnateError::Other(format!("{label} response parse error: {e}")));
135 }
136 Err(err) => {
137 let (retryable, retry_after) = match &err {
138 ureq::Error::Status(code, resp) => (
139 status_is_retryable(*code),
140 resp.header("retry-after")
141 .and_then(|s| s.trim().parse::<u64>().ok()),
142 ),
143 ureq::Error::Transport(_) => (true, None),
144 };
145 if retryable && attempt < HTTP_MAX_ATTEMPTS {
146 std::thread::sleep(backoff_delay(attempt, retry_after));
147 continue;
148 }
149 return Err(InnateError::Other(format!("{label} HTTP error: {err}")));
150 }
151 }
152 }
153}
154
/// Reports whether an HTTP status code is worth retrying: 429 (rate limited) or any 5xx.
fn status_is_retryable(code: u16) -> bool {
    matches!(code, 429 | 500..=599)
}
159
/// Computes how long to wait before retrying after failed attempt number `attempt` (1-based).
///
/// A server-supplied `Retry-After` value in seconds takes precedence but is capped at 30 s.
/// Otherwise the delay starts at 250 ms and doubles per attempt. The exponent is capped at
/// 6, so the delay never exceeds 16 s. `attempt == 0` is treated like the first attempt
/// instead of underflowing. The old underflow panicked in debug builds and wrapped to the
/// maximum delay in release builds.
fn backoff_delay(attempt: u32, retry_after_secs: Option<u64>) -> Duration {
    if let Some(secs) = retry_after_secs {
        return Duration::from_secs(secs.min(30));
    }
    let shift = attempt.saturating_sub(1).min(6);
    Duration::from_millis(250u64 << shift)
}
169
/// [`Distiller`] that sends distillation prompts to a remote LLM over HTTP.
///
/// The provider named in `config` selects the wire format. "anthropic" uses the Anthropic
/// Messages API; anything else is treated as an OpenAI-compatible chat-completions
/// endpoint.
pub struct HttpDistiller {
    config: LlmConfig,
}
180
181impl HttpDistiller {
182 pub fn new(config: LlmConfig) -> Self {
183 Self { config }
184 }
185
186 fn call(&self, prompt: &str) -> Result<String> {
187 if self.config.provider == "anthropic" {
188 self.call_anthropic(prompt)
189 } else {
190 self.call_openai(prompt)
191 }
192 }
193
194 fn call_openai(&self, prompt: &str) -> Result<String> {
195 let api_key = self
196 .config
197 .resolved_api_key()
198 .ok_or_else(|| InnateError::Other("LLM API key not configured".into()))?;
199
200 let base = self.config.resolved_base_url();
201 let url = format!("{base}/chat/completions");
202
203 let body = json!({
204 "model": self.config.model_id,
205 "messages": [{"role": "user", "content": prompt}],
206 "max_tokens": 800,
207 "temperature": 0.2,
208 });
209
210 let auth = format!("Bearer {api_key}");
211 let resp_json = post_json_retry(&url, &[("Authorization", &auth)], &body, "LLM")?;
212
213 resp_json
214 .pointer("/choices/0/message/content")
215 .and_then(Value::as_str)
216 .map(str::to_string)
217 .ok_or_else(|| InnateError::Other("unexpected LLM response shape".into()))
218 }
219
220 fn call_anthropic(&self, prompt: &str) -> Result<String> {
221 let api_key = self
222 .config
223 .resolved_api_key()
224 .ok_or_else(|| InnateError::Other("Anthropic API key not configured".into()))?;
225
226 let base = self.config.resolved_base_url();
227 let url = format!("{base}/v1/messages");
228
229 let body = json!({
230 "model": self.config.model_id,
231 "max_tokens": 800,
232 "messages": [{"role": "user", "content": prompt}],
233 });
234
235 let resp_json = post_json_retry(
236 &url,
237 &[("x-api-key", &api_key), ("anthropic-version", "2023-06-01")],
238 &body,
239 "Anthropic",
240 )?;
241
242 resp_json
243 .pointer("/content/0/text")
244 .and_then(Value::as_str)
245 .map(str::to_string)
246 .ok_or_else(|| InnateError::Other("unexpected Anthropic response shape".into()))
247 }
248}
249
250impl Distiller for HttpDistiller {
251 fn distill(&self, log_entries: &[Value]) -> crate::errors::Result<Vec<DistilledChunk>> {
252 distill_with(log_entries, |prompt| self.call(prompt))
253 }
254
255 fn distill_with_context(
256 &self,
257 primary: &Value,
258 related_logs: &[Value],
259 ) -> crate::errors::Result<Vec<DistilledChunk>> {
260 distill_entry_with(primary, related_logs, |prompt| self.call(prompt))
261 }
262
263 fn provenance(&self) -> DistillProvenance {
264 DistillProvenance {
265 provider: Some(self.config.provider.clone()),
266 model: Some(self.config.model_id.clone()),
267 prompt_version: Some(DISTILL_PROMPT_VERSION.to_string()),
268 }
269 }
270}
271
272fn distill_with(
277 log_entries: &[Value],
278 call: impl Fn(&str) -> Result<String> + Copy,
279) -> Result<Vec<DistilledChunk>> {
280 let mut out = Vec::new();
281 for entry in log_entries {
282 out.extend(distill_entry_with(entry, log_entries, call)?);
283 }
284 Ok(out)
285}
286
287fn distill_entry_with(
288 entry: &Value,
289 related_logs: &[Value],
290 call: impl Fn(&str) -> Result<String>,
291) -> Result<Vec<DistilledChunk>> {
292 let log_id = entry["id"].as_str().unwrap_or("").to_string();
293 let prompt = build_distill_prompt_with_related(entry, related_logs);
294 let mut raw = call(&prompt)?;
295 let mut parsed = parse_distill_response(&raw);
296 if parsed.is_err() {
297 raw = call(&format!(
298 "{prompt}\n\nYour previous response was invalid. Return only a valid JSON array."
299 ))?;
300 parsed = parse_distill_response(&raw);
301 }
302 let items = parsed.map_err(|error| {
303 InnateError::Other(format!("LLM distillation response invalid: {error}"))
304 })?;
305 let mut out = Vec::new();
306 for parsed in items {
307 let content = parsed
308 .get("content")
309 .and_then(Value::as_str)
310 .map(str::trim)
311 .filter(|s| !s.is_empty());
312 let Some(content) = content else { continue };
313 let trigger_desc = parsed
314 .get("trigger_desc")
315 .and_then(Value::as_str)
316 .map(str::to_string)
317 .filter(|s| !s.is_empty());
318 let anti_trigger_desc = parsed
319 .get("anti_trigger_desc")
320 .and_then(Value::as_str)
321 .map(str::to_string)
322 .filter(|s| !s.is_empty() && s.to_lowercase() != "null");
323 out.push(DistilledChunk {
324 content: content.to_string(),
325 trigger_desc,
326 anti_trigger_desc,
327 source_log_id: log_id.clone(),
328 nomination: entry
329 .get("nomination")
330 .and_then(Value::as_str)
331 .map(str::to_string),
332 });
333 }
334 Ok(out)
335}
336
337fn parse_distill_response(raw: &str) -> std::result::Result<Vec<Value>, String> {
338 let json_str = extract_json(raw);
339 let parsed: Value = serde_json::from_str(json_str.trim()).map_err(|e| e.to_string())?;
340 if parsed.get("skip").and_then(Value::as_bool) == Some(true) {
341 return Ok(vec![]);
342 }
343 match parsed {
344 Value::Array(items) => Ok(items),
345 Value::Object(_) => Ok(vec![parsed]),
346 _ => Err("expected a JSON object or array".to_string()),
347 }
348}
349
/// Isolates the JSON payload in a model reply, trying three sources in order:
///
/// 1. The contents of a Markdown code fence (```` ```json ```` or a bare ```` ``` ````) that
///    is closed later in the text.
/// 2. The outermost `[...]` or `{...}` span. Whichever opening delimiter appears first
///    wins, so an object whose string values contain brackets is not cut down to the
///    bracketed fragment.
/// 3. The trimmed text itself.
///
/// A span is used only when its closing delimiter comes after its opening one. Replies
/// such as `"] ... ["` therefore fall back to the trimmed text instead of panicking on an
/// inverted slice range.
fn extract_json(text: &str) -> &str {
    /// Byte offsets `(start, end)` of the first `open` and last `close`, when properly ordered.
    fn delimited_span(text: &str, open: char, close: char) -> Option<(usize, usize)> {
        let start = text.find(open)?;
        let end = text.rfind(close)?;
        (start < end).then_some((start, end))
    }

    let stripped = text.trim();
    if let Some(inner) = stripped
        .strip_prefix("```json")
        .or_else(|| stripped.strip_prefix("```"))
    {
        if let Some(end) = inner.rfind("```") {
            return inner[..end].trim();
        }
    }

    let chosen = match (
        delimited_span(stripped, '[', ']'),
        delimited_span(stripped, '{', '}'),
    ) {
        (Some(array), Some(object)) => Some(if object.0 < array.0 { object } else { array }),
        (array, object) => array.or(object),
    };
    match chosen {
        // Both offsets point at single-byte ASCII delimiters, so the slice is on char boundaries.
        Some((start, end)) => &stripped[start..=end],
        None => stripped,
    }
}
370
371pub fn build_distiller(config: &LlmConfig) -> std::sync::Arc<dyn Distiller + Send + Sync> {
376 std::sync::Arc::new(HttpDistiller::new(config.clone()))
377}
378
/// [`EmbeddingProvider`] that requests vectors from an HTTP `/embeddings` endpoint
/// described by an [`EmbeddingConfig`].
pub struct LlmEmbeddingProvider {
    config: EmbeddingConfig,
}
386
// Unit tests for the pure helpers (retry policy, prompt building, response parsing) and for
// the distillation flow driven by stub `call` closures, so no network access is needed.
#[cfg(test)]
// Further items (the embedding provider and the connection checks) follow this module in
// the file, which `clippy::items_after_test_module` would otherwise flag.
#[allow(clippy::items_after_test_module)]
mod tests {
    use std::cell::Cell;

    use serde_json::json;

    use std::time::Duration;

    use super::{
        backoff_delay, build_distill_prompt, distill_entry_with, distill_with,
        parse_distill_response, status_is_retryable,
    };

    #[test]
    fn only_rate_limit_and_5xx_are_retryable() {
        assert!(status_is_retryable(429));
        assert!(status_is_retryable(500));
        assert!(status_is_retryable(503));
        assert!(status_is_retryable(599));
        assert!(!status_is_retryable(400));
        assert!(!status_is_retryable(401));
        assert!(!status_is_retryable(404));
        assert!(!status_is_retryable(200));
    }

    #[test]
    fn backoff_is_exponential_and_honors_retry_after() {
        // Without Retry-After: 250 ms doubled per attempt.
        assert_eq!(backoff_delay(1, None), Duration::from_millis(250));
        assert_eq!(backoff_delay(2, None), Duration::from_millis(500));
        assert_eq!(backoff_delay(3, None), Duration::from_millis(1000));
        // Retry-After wins, but is capped at 30 s.
        assert_eq!(backoff_delay(1, Some(5)), Duration::from_secs(5));
        assert_eq!(backoff_delay(1, Some(120)), Duration::from_secs(30));
    }

    // Relies on `crate::utils::sanitize` replacing these secret-looking strings with "[REDACTED]".
    #[test]
    fn prompt_redacts_secrets_before_external_llm_call() {
        let prompt = build_distill_prompt(&json!({
            "query": "debug sk-12345678901234567890",
            "output_summary": "Authorization: Bearer secret-token-value"
        }));
        assert!(!prompt.contains("sk-12345678901234567890"));
        assert!(!prompt.contains("secret-token-value"));
        assert!(prompt.contains("[REDACTED]"));
    }

    // The first reply is unparseable, so the entry must be re-prompted exactly once and
    // the second, valid reply used.
    #[test]
    fn malformed_response_is_retried_instead_of_silently_skipped() {
        let calls = Cell::new(0);
        let chunks = distill_with(&[json!({"id": "log-1", "query": "q"})], |_| {
            calls.set(calls.get() + 1);
            if calls.get() == 1 {
                Ok("not json".to_string())
            } else {
                Ok(r#"[{"content":"retry worked","trigger_desc":"retry"}]"#.to_string())
            }
        })
        .unwrap();
        assert_eq!(calls.get(), 2);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, "retry worked");
    }

    #[test]
    fn parser_accepts_multiple_distilled_chunks() {
        let parsed = parse_distill_response(
            r#"[{"content":"one"},{"content":"two","anti_trigger_desc":"never"}]"#,
        )
        .unwrap();
        assert_eq!(parsed.len(), 2);
    }

    // The nomination must reach the model through the prompt, and the raw nomination is
    // also copied onto the resulting chunk.
    #[test]
    fn nomination_is_distilled_instead_of_bypassing_the_model() {
        let prompt_seen = Cell::new(false);
        let entry = json!({
            "id": "log-1",
            "query": "original query",
            "nomination": "raw agent nomination",
            "output_summary": "summary",
            "outcome": "ok"
        });
        let chunks = distill_entry_with(&entry, std::slice::from_ref(&entry), |prompt| {
            prompt_seen.set(prompt.contains("raw agent nomination"));
            Ok(
                r#"[{"content":"generalized principle","trigger_desc":"generalize","anti_trigger_desc":null}]"#
                    .to_string(),
            )
        })
        .unwrap();

        assert!(prompt_seen.get());
        assert_eq!(chunks[0].content, "generalized principle");
        assert_eq!(
            chunks[0].nomination.as_deref(),
            Some("raw agent nomination")
        );
    }
}
488
489impl LlmEmbeddingProvider {
490 pub fn new(config: EmbeddingConfig) -> Self {
491 Self { config }
492 }
493
494 fn embed(&self, text: &str) -> Result<Vec<f32>> {
495 let api_key = self
496 .config
497 .resolved_api_key()
498 .ok_or_else(|| InnateError::Other("Embedding API key not configured".into()))?;
499
500 let base = self.config.resolved_base_url();
501 let url = format!("{base}/embeddings");
502
503 let body = json!({
504 "input": text,
505 "model": self.config.model_id,
506 });
507
508 let auth = format!("Bearer {api_key}");
509 let resp_json = post_json_retry(
510 &url,
511 &[("Authorization", &auth)],
512 &body,
513 "Embedding",
514 )?;
515
516 let embedding = resp_json
517 .pointer("/data/0/embedding")
518 .and_then(Value::as_array)
519 .ok_or_else(|| InnateError::Other("unexpected embedding response shape".into()))?;
520
521 Ok(embedding
522 .iter()
523 .filter_map(Value::as_f64)
524 .map(|x| x as f32)
525 .collect())
526 }
527}
528
529impl EmbeddingProvider for LlmEmbeddingProvider {
530 fn model_name(&self) -> &'static str {
531 "llm-embedding"
532 }
533
534 fn content_dim(&self) -> usize {
535 self.config.dim
536 }
537
538 fn trigger_dim(&self) -> usize {
539 self.config.dim
540 }
541
542 fn embed_content(&self, text: &str) -> Result<Vec<f32>> {
543 self.embed(text)
544 }
545
546 fn embed_trigger(&self, text: &str) -> Result<Vec<f32>> {
547 self.embed(text)
548 }
549
550 fn embed_both(&self, text: &str) -> Result<(Vec<f32>, Vec<f32>)> {
553 let v = self.embed(text)?;
554 Ok((v.clone(), v))
555 }
556}
557
558pub fn test_llm(config: &LlmConfig) -> Result<String> {
564 let distiller = build_distiller(config);
565 let dummy_log = json!({
566 "id": "test",
567 "query": "connection test",
568 "output_summary": "test",
569 "outcome": "ok"
570 });
571 distiller.distill(&[dummy_log])?;
573 Ok(format!("OK — model: {}", config.model_id))
574}
575
576pub fn test_embedding(config: &EmbeddingConfig) -> Result<usize> {
578 let provider = LlmEmbeddingProvider::new(config.clone());
579 let vec = provider.embed("connection test")?;
580 Ok(vec.len())
581}