1use super::*;
2
/// Options accepted by `KnowledgeBase::recall`.
///
/// `Default` gives an empty query, a zero budget, every flag `false`, no
/// `top` cap, no `min_score`, and empty strings for `source`, `expand_deps`
/// and `refine_mode`. `recall` treats an empty `expand_deps` as `"false"` and
/// an empty `refine_mode` as `"off"`.
#[derive(Debug, Clone, Default)]
pub struct RecallParams<'a> {
    /// Query text. It is embedded for vector search and also used for context
    /// scoring, anti-trigger checks and trimming.
    pub query: &'a str,
    /// Budget that packing compares accumulated `block_cost` totals against;
    /// also handed to the refiner.
    pub budget: usize,
    /// When `true`, usage-trace rows and an episodic log entry are written.
    pub trace: bool,
    /// When `true`, spark chunks are recalled alongside knowledge.
    pub include_sparks: bool,
    /// Optional cap forwarded to `limit_knowledge` after packing.
    pub top: Option<usize>,
    /// Event source label; checked by `validate_source` and stored in traces.
    pub source: &'a str,
    /// Hard-dependency expansion mode: `"false"` (seed only), `"direct"` or
    /// `"closure"`; empty means `"false"`.
    pub expand_deps: &'a str,
    /// Allow the refiner to trim a block that does not fit the remaining budget.
    pub allow_trim: bool,
    /// `"adapt"` runs the refiner over the final list; empty means `"off"`.
    pub refine_mode: &'a str,
    /// Candidates whose fused score is below this value are dropped.
    pub min_score: Option<f64>,
}
25
26impl KnowledgeBase {
27 pub fn recall(&self, params: RecallParams<'_>) -> Result<RecallResult> {
28 let RecallParams {
29 query,
30 budget,
31 trace,
32 include_sparks,
33 top,
34 source,
35 expand_deps,
36 allow_trim,
37 refine_mode,
38 min_score,
39 } = params;
40 let expand_deps = if expand_deps.is_empty() {
41 "false"
42 } else {
43 expand_deps
44 };
45 let refine_mode = if refine_mode.is_empty() {
46 "off"
47 } else {
48 refine_mode
49 };
50 validate_source(source)?;
51 let trace_id = gen_uuid();
52 let now = utc_now_iso();
53
54 let (q_content, q_trigger) = self
55 .embedding
56 .embed_both(query)
57 .map_err(|e| InnateError::EmbeddingUnavailable(e.to_string()))?;
58
59 let mut candidates = self.ann_candidates(&q_content, &q_trigger)?;
61 self.apply_soft_dep_bonus(&mut candidates)?;
62
63 let mut scored = self.score_candidates(candidates, query)?;
65
66 if let Some(min) = min_score {
69 scored.retain(|(fused, _)| *fused >= min);
70 }
71
72 let (selected, skipped, skipped_reasons) =
74 self.pack(&scored, budget, expand_deps, allow_trim, query)?;
75
76 let depth_skipped: Vec<String> = skipped_reasons
77 .iter()
78 .filter(|(_, r)| r.as_str() == "dep_depth_limit")
79 .map(|(id, _)| id.clone())
80 .collect();
81
82 let mut selected = selected;
84 if self.density_refill {
85 selected = self.density_refill(selected, &skipped, budget);
86 }
87
88 let limited = limit_knowledge(selected, top);
89 let visible = if refine_mode == "adapt" {
90 self.refiner
91 .refine(limited.clone(), Some(budget))
92 .unwrap_or(limited)
93 } else {
94 limited
95 };
96
97 let sparks = if include_sparks {
99 self.recall_sparks(&q_content, &q_trigger)?
100 } else {
101 vec![]
102 };
103
104 if trace {
105 self.write_recall_trace(
106 &trace_id,
107 query,
108 &scored,
109 &visible,
110 &sparks,
111 &depth_skipped,
112 &skipped_reasons,
113 refine_mode,
114 source,
115 &now,
116 )?;
117 }
118
119 let empty = visible.is_empty() && sparks.is_empty();
120 Ok(RecallResult {
121 knowledge: visible,
122 sparks,
123 trace_id,
124 empty,
125 depth_skipped,
126 skipped_reasons,
127 })
128 }
129
130 fn ann_candidates(
131 &self,
132 q_content: &[f32],
133 q_trigger: &[f32],
134 ) -> Result<HashMap<String, CandidateInfo>> {
135 let embed_version = self
136 .storage
137 .get_meta("embed_version")?
138 .and_then(|v| v.parse::<i64>().ok())
139 .unwrap_or(1);
140
141 let content_res = self
142 .storage
143 .search_vec_content(q_content, self.top_k_candidates * 2)?;
144 let trigger_res = self
145 .storage
146 .search_vec_trigger(q_trigger, self.top_k_candidates * 2)?;
147
148 let all_ids: Vec<&str> = {
150 let mut seen = HashSet::new();
151 content_res
152 .iter()
153 .chain(trigger_res.iter())
154 .map(|(id, _)| id.as_str())
155 .filter(|id| seen.insert(*id))
156 .collect()
157 };
158 let chunks = self.storage.get_chunks_by_ids(&all_ids)?;
159
160 let mut candidates: HashMap<String, CandidateInfo> = HashMap::new();
161
162 for (cid, sim) in &content_res {
163 if let Some(chunk) = chunks.get(cid) {
164 if chunk_is_valid_for_recall(chunk, embed_version) {
165 let e = candidates
166 .entry(cid.clone())
167 .or_insert_with(|| CandidateInfo {
168 chunk: chunk.clone(),
169 sim_content: 0.0,
170 sim_trigger: 0.0,
171 });
172 e.sim_content = e.sim_content.max(*sim);
173 }
174 }
175 }
176 for (cid, sim) in &trigger_res {
177 if let Some(chunk) = chunks.get(cid) {
178 if chunk_is_valid_for_recall(chunk, embed_version) {
179 let e = candidates
180 .entry(cid.clone())
181 .or_insert_with(|| CandidateInfo {
182 chunk: chunk.clone(),
183 sim_content: 0.0,
184 sim_trigger: 0.0,
185 });
186 e.sim_trigger = e.sim_trigger.max(*sim);
187 }
188 }
189 }
190 Ok(candidates)
191 }
192
193 fn apply_soft_dep_bonus(&self, candidates: &mut HashMap<String, CandidateInfo>) -> Result<()> {
194 let src_ids: Vec<String> = candidates
197 .iter()
198 .filter(|(_, info)| {
199 info.chunk.get("origin").and_then(Value::as_str) != Some("spark")
200 })
201 .map(|(cid, _)| cid.clone())
202 .collect();
203 if src_ids.is_empty() {
204 return Ok(());
205 }
206 let src_refs: Vec<&str> = src_ids.iter().map(String::as_str).collect();
207 let deps_map = self.storage.get_deps_batch(&src_refs)?;
208
209 let mut target_ids: Vec<String> = Vec::new();
212 let mut seen: HashSet<String> = HashSet::new();
213 for deps in deps_map.values() {
214 for (dst, kind, _) in deps {
215 if kind == "soft" && seen.insert(dst.clone()) {
216 target_ids.push(dst.clone());
217 }
218 }
219 }
220 if target_ids.is_empty() {
221 return Ok(());
222 }
223 let target_refs: Vec<&str> = target_ids.iter().map(String::as_str).collect();
224 let targets = self.storage.get_chunks_by_ids(&target_refs)?;
225
226 for src in &src_ids {
227 let Some(deps) = deps_map.get(src) else {
228 continue;
229 };
230 for (dst, kind, _) in deps {
231 if kind != "soft" {
232 continue;
233 }
234 let Some(target) = targets.get(dst) else {
235 continue;
236 };
237 if target.get("state").and_then(Value::as_str) == Some("archived") {
238 continue;
239 }
240 if target.get("origin").and_then(Value::as_str) == Some("spark") {
241 continue;
242 }
243 let e = candidates
244 .entry(dst.clone())
245 .or_insert_with(|| CandidateInfo {
246 chunk: target.clone(),
247 sim_content: 0.0,
248 sim_trigger: 0.0,
249 });
250 e.sim_content = (e.sim_content + 0.05).min(1.0);
251 }
252 }
253 Ok(())
254 }
255
256 fn score_candidates(
257 &self,
258 candidates: HashMap<String, CandidateInfo>,
259 query: &str,
260 ) -> Result<Vec<(f64, Value)>> {
261 let context_key = content_hash(&normalize_query(query));
262 let cand_ids: Vec<String> = candidates
265 .values()
266 .filter_map(|info| info.chunk.get("id").and_then(Value::as_str).map(str::to_string))
267 .collect();
268 let cand_refs: Vec<&str> = cand_ids.iter().map(String::as_str).collect();
269 let ctx_scores = self.storage.context_scores_batch(&cand_refs, &context_key)?;
270
271 let mut scored: Vec<(f64, Value)> = Vec::with_capacity(candidates.len());
272 for info in candidates.into_values() {
273 let conf = info
274 .chunk
275 .get("confidence")
276 .and_then(Value::as_f64)
277 .unwrap_or(0.5);
278 let chunk_id = info.chunk.get("id").and_then(Value::as_str).unwrap_or("");
279 let context_score = ctx_scores.get(chunk_id).copied().unwrap_or(0.0);
280 let mut fused = self.w_content * info.sim_content as f64
281 + self.w_trigger * info.sim_trigger as f64
282 + self.w_confidence * conf
283 + self.w_context * context_score;
284 if info.chunk.get("state").and_then(Value::as_str) == Some("pending") {
285 fused *= PENDING_RECALL_PENALTY;
286 }
287 let anti = info
288 .chunk
289 .get("anti_trigger_desc")
290 .and_then(Value::as_str)
291 .unwrap_or("");
292 if !anti.is_empty() && anti_trigger_hit(query, anti) {
293 fused *= self.anti_trigger_penalty;
294 }
295 let mut chunk = info.chunk;
296 chunk["_context_score"] = json!(context_score);
297 chunk["_fused_score"] = json!(fused);
298 scored.push((fused, chunk));
299 }
300 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
301 scored.truncate(self.top_k_candidates);
302 Ok(scored)
303 }
304
305 fn pack(
306 &self,
307 scored: &[(f64, Value)],
308 budget: usize,
309 expand_deps: &str,
310 allow_trim: bool,
311 query: &str,
312 ) -> Result<PackResult> {
313 let mut selected: Vec<Value> = vec![];
314 let mut skipped: Vec<(Vec<Value>, f64, usize)> = vec![];
315 let mut skipped_reasons: HashMap<String, String> = HashMap::new();
316 let mut used_ids: HashSet<String> = HashSet::new();
317 let mut used_tokens: usize = 0;
318
319 for (fused, chunk) in scored {
320 let cid = chunk["id"].as_str().unwrap_or("").to_string();
321 if used_ids.contains(&cid) {
322 continue;
323 }
324
325 let (block, dep_skip_reason) = self.build_dep_block(chunk, expand_deps)?;
327 if let Some(reason) = dep_skip_reason {
328 skipped_reasons.insert(cid, reason);
329 continue;
330 }
331
332 let new_block: Vec<Value> = block
333 .iter()
334 .filter(|b| !used_ids.contains(b["id"].as_str().unwrap_or("")))
335 .cloned()
336 .collect();
337 let cost = block_cost(&new_block);
338
339 if used_tokens + cost <= budget {
340 for b in &block {
341 let bid = b["id"].as_str().unwrap_or("").to_string();
342 if !used_ids.contains(&bid) {
343 let mut b = b.clone();
344 b["_fused_score"] = json!(fused);
345 selected.push(b);
346 used_ids.insert(bid);
347 }
348 }
349 used_tokens += cost;
350 } else if allow_trim {
351 if let Some(trimmed) =
353 self.refiner
354 .trim(&block, query, budget.saturating_sub(used_tokens))
355 {
356 let trim_cost = block_cost(&trimmed);
357 if used_tokens + trim_cost <= budget {
358 for b in &trimmed {
359 let bid = b["id"].as_str().unwrap_or("").to_string();
360 if !used_ids.contains(&bid) {
361 let mut b = b.clone();
362 b["_fused_score"] = json!(fused);
363 b["_trimmed"] = json!(true);
364 selected.push(b);
365 used_ids.insert(bid);
366 }
367 }
368 used_tokens += trim_cost;
369 continue;
370 }
371 }
372 skipped.push((block, *fused, cost));
373 } else {
374 skipped.push((block, *fused, cost));
375 }
376 }
377 Ok((selected, skipped, skipped_reasons))
378 }
379
380 fn build_dep_block(
383 &self,
384 seed: &Value,
385 expand_deps: &str,
386 ) -> Result<(Vec<Value>, Option<String>)> {
387 if expand_deps == "false" || expand_deps.is_empty() {
388 return Ok((vec![seed.clone()], None));
389 }
390 let seed_id = seed["id"].as_str().unwrap_or("");
391 match expand_deps {
392 "direct" => {
393 let deps = self.storage.get_deps(seed_id)?;
394 let mut block = vec![seed.clone()];
395 for (dep_id, kind, _) in &deps {
396 if kind != "hard" {
397 continue;
398 }
399 match self.validate_hard_dep(dep_id)? {
400 Some(chunk) => block.push(chunk),
401 None => return Ok((vec![], Some("hard_dep_unavailable".to_string()))),
402 }
403 }
404 Ok((block, None))
405 }
406 "closure" => {
407 let mut block = vec![seed.clone()];
408 let mut visited: HashSet<String> = [seed_id.to_string()].into();
409 match self.expand_hard_closure(seed_id, &mut visited, &mut block, 0, 3)? {
410 Some(reason) => Ok((vec![], Some(reason))),
411 None => Ok((block, None)),
412 }
413 }
414 _ => Ok((vec![seed.clone()], None)),
415 }
416 }
417
418 fn validate_hard_dep(&self, dep_id: &str) -> Result<Option<Value>> {
420 match self.storage.get_chunk(dep_id)? {
421 None => Ok(None),
422 Some(chunk) => {
423 let state = chunk.get("state").and_then(Value::as_str).unwrap_or("");
424 let origin = chunk.get("origin").and_then(Value::as_str).unwrap_or("");
425 let embed_v = chunk
426 .get("embed_version")
427 .and_then(Value::as_i64)
428 .unwrap_or(0);
429 if state == "archived" || origin == "spark" || embed_v == 0 {
430 Ok(None)
431 } else {
432 Ok(Some(chunk))
433 }
434 }
435 }
436 }
437
438 fn expand_hard_closure(
440 &self,
441 id: &str,
442 visited: &mut HashSet<String>,
443 block: &mut Vec<Value>,
444 depth: usize,
445 max_depth: usize,
446 ) -> Result<Option<String>> {
447 if depth >= max_depth {
448 return Ok(Some("dep_depth_limit".to_string()));
449 }
450 let deps = self.storage.get_deps(id)?;
451 for (dep_id, kind, _) in &deps {
452 if kind != "hard" {
453 continue;
454 }
455 if visited.contains(dep_id) {
456 continue;
457 } visited.insert(dep_id.clone());
459 match self.validate_hard_dep(dep_id)? {
460 None => return Ok(Some("hard_dep_unavailable".to_string())),
461 Some(chunk) => {
462 block.push(chunk);
463 if let Some(reason) =
464 self.expand_hard_closure(dep_id, visited, block, depth + 1, max_depth)?
465 {
466 return Ok(Some(reason));
467 }
468 }
469 }
470 }
471 Ok(None)
472 }
473
474 fn density_refill(
475 &self,
476 mut selected: Vec<Value>,
477 skipped: &[(Vec<Value>, f64, usize)],
478 budget: usize,
479 ) -> Vec<Value> {
480 let used_tokens = block_cost(&selected);
481 if used_tokens >= budget {
482 return selected;
483 }
484
485 let selected_ids: HashSet<String> = selected
486 .iter()
487 .filter_map(|c| c["id"].as_str().map(str::to_string))
488 .collect();
489
490 let mut density_items: Vec<(f64, Vec<Value>, usize)> = skipped
491 .iter()
492 .filter_map(|(block, fscore, _)| {
493 let block: Vec<Value> = block
494 .iter()
495 .filter(|b| !selected_ids.contains(b["id"].as_str().unwrap_or("")))
496 .cloned()
497 .collect();
498 if block.is_empty() {
499 return None;
500 }
501 let cost = block_cost(&block);
502 let density = fscore / cost.max(1) as f64;
503 Some((density, block, cost))
504 })
505 .collect();
506 density_items.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
507
508 let mut used_tokens = block_cost(&selected);
509 let mut added_ids: HashSet<String> = selected_ids;
510 for (_, block, cost) in density_items {
511 if used_tokens + cost <= budget {
512 for b in block {
513 let bid = b["id"].as_str().unwrap_or("").to_string();
514 if !added_ids.contains(&bid) {
515 selected.push(b);
516 added_ids.insert(bid);
517 }
518 }
519 used_tokens += cost;
520 }
521 }
522 selected
523 }
524
525 fn recall_sparks(&self, q_content: &[f32], q_trigger: &[f32]) -> Result<Vec<Value>> {
526 let embed_version = self
527 .storage
528 .get_meta("embed_version")?
529 .and_then(|v| v.parse::<i64>().ok())
530 .unwrap_or(1);
531
532 let content_res = self
533 .storage
534 .search_vec_content(q_content, self.top_k_candidates)?;
535 let trigger_res = self
536 .storage
537 .search_vec_trigger(q_trigger, self.top_k_candidates)?;
538
539 let all_ids: Vec<&str> = {
541 let mut seen = HashSet::new();
542 content_res
543 .iter()
544 .chain(trigger_res.iter())
545 .map(|(id, _)| id.as_str())
546 .filter(|id| seen.insert(*id))
547 .collect()
548 };
549 let chunks = self.storage.get_chunks_by_ids(&all_ids)?;
550
551 let mut spark_scores: HashMap<String, (f32, Value)> = HashMap::new();
552 for (cid, sim) in content_res.iter().chain(trigger_res.iter()) {
553 if let Some(chunk) = chunks.get(cid) {
554 if chunk.get("origin").and_then(Value::as_str) != Some("spark") {
555 continue;
556 }
557 if chunk.get("state").and_then(Value::as_str) == Some("archived") {
558 continue;
559 }
560 let maturity = chunk.get("maturity").and_then(Value::as_str).unwrap_or("");
561 if maturity == "promoted" || maturity == "dropped" {
562 continue;
563 }
564 let ev = chunk
565 .get("embed_version")
566 .and_then(Value::as_i64)
567 .unwrap_or(1);
568 if ev < embed_version {
569 continue;
570 }
571 let entry = spark_scores
572 .entry(cid.clone())
573 .or_insert_with(|| (*sim, chunk.clone()));
574 if *sim > entry.0 {
575 *entry = (*sim, chunk.clone());
576 }
577 }
578 }
579 let mut sparks: Vec<(f32, Value)> = spark_scores.into_values().collect();
580 sparks.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
581 Ok(sparks
582 .into_iter()
583 .take(self.top_k_candidates)
584 .map(|(_, c)| c)
585 .collect())
586 }
587
588 #[allow(clippy::too_many_arguments)]
589 fn write_recall_trace(
590 &self,
591 trace_id: &str,
592 query: &str,
593 scored: &[(f64, Value)],
594 visible: &[Value],
595 sparks: &[Value],
596 depth_skipped: &[String],
597 skipped_reasons: &HashMap<String, String>,
598 refine_mode: &str,
599 source: &str,
600 now: &str,
601 ) -> Result<()> {
602 let lib_id = self.storage.lib_id()?;
603 self.storage.begin_immediate()?;
604 let result = (|| -> Result<()> {
605 for (rank, (_, chunk)) in scored.iter().enumerate() {
606 let cid = chunk["id"].as_str().unwrap_or("");
607 let sim = chunk.get("_fused_score").and_then(Value::as_f64);
608 let rm = skipped_reasons
610 .get(cid)
611 .map(|r| format!("skipped:{r}"))
612 .or_else(|| {
613 if refine_mode != "off" && !refine_mode.is_empty() {
614 Some(refine_mode.to_string())
615 } else {
616 None
617 }
618 });
619 self.storage.insert_usage_trace(
620 trace_id,
621 Some(cid),
622 "retrieved",
623 1.0,
624 sim,
625 rm.as_deref(),
626 None,
627 Some((rank + 1) as i64),
628 None,
629 source,
630 now,
631 )?;
632 }
633 for (rank, chunk) in visible.iter().enumerate() {
634 let cid = chunk["id"].as_str().unwrap_or("");
635 self.storage.insert_usage_trace(
636 trace_id,
637 Some(cid),
638 "selected",
639 1.0,
640 None,
641 None,
642 None,
643 Some((rank + 1) as i64),
644 None,
645 source,
646 now,
647 )?;
648 if chunk
650 .get("_trimmed")
651 .and_then(Value::as_bool)
652 .unwrap_or(false)
653 {
654 self.storage.insert_usage_trace(
655 trace_id,
656 Some(cid),
657 "refined",
658 1.0,
659 None,
660 Some("trim"),
661 None,
662 Some((rank + 1) as i64),
663 None,
664 source,
665 now,
666 )?;
667 }
668 }
669 for (rank, chunk) in sparks.iter().enumerate() {
671 let cid = chunk["id"].as_str().unwrap_or("");
672 self.storage.insert_usage_trace(
673 trace_id,
674 Some(cid),
675 "retrieved",
676 1.0,
677 None,
678 Some("spark"),
679 None,
680 Some((rank + 1) as i64),
681 None,
682 source,
683 now,
684 )?;
685 }
686 let snapshot = json!({
687 "retrieved": scored.iter().map(|(_, c)| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
688 "selected": visible.iter().map(|c| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
689 "sparks": sparks.iter().map(|c| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
690 "depth_skipped": depth_skipped,
691 "skipped_reasons": skipped_reasons,
692 });
693 let log = EpisodicLogRow {
694 id: gen_uuid(),
695 trace_id: trace_id.to_string(),
696 lib_id,
697 ts: now.to_string(),
698 query: Some(query.to_string()),
699 recall_snapshot: Some(snapshot.to_string()),
700 event_source: source.to_string(),
701 task_state: "recalled".to_string(),
702 usage_state: "unknown".to_string(),
703 context_key: Some(content_hash(&normalize_query(query))),
704 distill_state: "open".to_string(),
705 ..Default::default()
706 };
707 self.storage.upsert_episodic_log(&log)?;
708 self.storage.commit()
709 })();
710 if result.is_err() {
711 let _ = self.storage.rollback();
712 }
713 result
714 }
715}