1use super::*;
2
/// Input parameters for [`KnowledgeBase::recall`].
///
/// `Default` yields empty strings, `0`, `false` and `None`; `recall` treats an
/// empty `expand_deps` as `"false"` and an empty `refine_mode` as `"off"`.
#[derive(Debug, Clone, Default)]
pub struct RecallParams<'a> {
    /// Query text that is embedded and used for scoring.
    pub query: &'a str,
    /// Upper bound on the summed `block_cost` of the selected chunks.
    pub budget: usize,
    /// When `true`, the recall is persisted through `write_recall_trace`.
    pub trace: bool,
    /// When `true`, sparks are looked up and returned alongside knowledge.
    pub include_sparks: bool,
    /// Optional cap passed to `limit_knowledge` after packing.
    pub top: Option<usize>,
    /// Caller-supplied source label; checked by `validate_source` and stored in traces.
    pub source: &'a str,
    /// Dependency expansion mode: `"direct"` or `"closure"`; any other value
    /// (including `"false"` and `""`) selects no expansion.
    pub expand_deps: &'a str,
    /// When `true`, a block that does not fit the budget may be trimmed by the refiner.
    pub allow_trim: bool,
    /// Refinement mode; `"adapt"` runs the refiner over the final selection.
    pub refine_mode: &'a str,
    /// Optional lower bound on the fused score of ranked candidates.
    pub min_score: Option<f64>,
}
25
26impl KnowledgeBase {
27 pub fn recall(&self, params: RecallParams<'_>) -> Result<RecallResult> {
28 let RecallParams {
29 query,
30 budget,
31 trace,
32 include_sparks,
33 top,
34 source,
35 expand_deps,
36 allow_trim,
37 refine_mode,
38 min_score,
39 } = params;
40 let expand_deps = if expand_deps.is_empty() {
41 "false"
42 } else {
43 expand_deps
44 };
45 let refine_mode = if refine_mode.is_empty() {
46 "off"
47 } else {
48 refine_mode
49 };
50 validate_source(source)?;
51 let trace_id = gen_uuid();
52 let now = utc_now_iso();
53
54 let situation = Situation::from_query(query);
58 let context_key = situation.context_key(&self.situation_coarse_keys);
59
60 let (q_content, q_trigger) = self
61 .embedding
62 .embed_both(query)
63 .map_err(|e| InnateError::EmbeddingUnavailable(e.to_string()))?;
64
65 let mut candidates = self.ann_candidates(&q_content, &q_trigger)?;
67 self.apply_soft_dep_bonus(&mut candidates)?;
68
69 let mut scored = self.score_candidates(candidates, query, &context_key, &now)?;
71
72 if let Some(min) = min_score {
75 scored.retain(|(fused, _)| *fused >= min);
76 }
77
78 let (selected, skipped, skipped_reasons) =
80 self.pack(&scored, budget, expand_deps, allow_trim, query)?;
81
82 let depth_skipped: Vec<String> = skipped_reasons
83 .iter()
84 .filter(|(_, r)| r.as_str() == "dep_depth_limit")
85 .map(|(id, _)| id.clone())
86 .collect();
87
88 let mut selected = selected;
90 if self.density_refill {
91 selected = self.density_refill(selected, &skipped, budget);
92 }
93
94 let limited = limit_knowledge(selected, top);
95 let visible = if refine_mode == "adapt" {
96 self.refiner
97 .refine(limited.clone(), Some(budget))
98 .unwrap_or(limited)
99 } else {
100 limited
101 };
102
103 let sparks = if include_sparks {
105 self.recall_sparks(&q_content, &q_trigger)?
106 } else {
107 vec![]
108 };
109
110 if trace {
111 self.write_recall_trace(
112 &trace_id,
113 query,
114 &context_key,
115 &scored,
116 &visible,
117 &sparks,
118 &depth_skipped,
119 &skipped_reasons,
120 refine_mode,
121 source,
122 &now,
123 )?;
124 }
125
126 let empty = visible.is_empty() && sparks.is_empty();
127 Ok(RecallResult {
128 knowledge: visible,
129 sparks,
130 trace_id,
131 empty,
132 depth_skipped,
133 skipped_reasons,
134 })
135 }
136
137 pub(super) fn ann_candidates(
138 &self,
139 q_content: &[f32],
140 q_trigger: &[f32],
141 ) -> Result<HashMap<String, CandidateInfo>> {
142 let embed_version = self
143 .storage
144 .get_meta("embed_version")?
145 .and_then(|v| v.parse::<i64>().ok())
146 .unwrap_or(1);
147
148 let content_res = self
149 .storage
150 .search_vec_content(q_content, self.top_k_candidates * 2)?;
151 let trigger_res = self
152 .storage
153 .search_vec_trigger(q_trigger, self.top_k_candidates * 2)?;
154
155 let all_ids: Vec<&str> = {
157 let mut seen = HashSet::new();
158 content_res
159 .iter()
160 .chain(trigger_res.iter())
161 .map(|(id, _)| id.as_str())
162 .filter(|id| seen.insert(*id))
163 .collect()
164 };
165 let chunks = self.storage.get_chunks_by_ids(&all_ids)?;
166
167 let mut candidates: HashMap<String, CandidateInfo> = HashMap::new();
168
169 for (cid, sim) in &content_res {
170 if let Some(chunk) = chunks.get(cid) {
171 if chunk_is_valid_for_recall(chunk, embed_version) {
172 let e = candidates
173 .entry(cid.clone())
174 .or_insert_with(|| CandidateInfo {
175 chunk: chunk.clone(),
176 sim_content: 0.0,
177 sim_trigger: 0.0,
178 });
179 e.sim_content = e.sim_content.max(*sim);
180 }
181 }
182 }
183 for (cid, sim) in &trigger_res {
184 if let Some(chunk) = chunks.get(cid) {
185 if chunk_is_valid_for_recall(chunk, embed_version) {
186 let e = candidates
187 .entry(cid.clone())
188 .or_insert_with(|| CandidateInfo {
189 chunk: chunk.clone(),
190 sim_content: 0.0,
191 sim_trigger: 0.0,
192 });
193 e.sim_trigger = e.sim_trigger.max(*sim);
194 }
195 }
196 }
197 Ok(candidates)
198 }
199
200 pub(super) fn apply_soft_dep_bonus(
201 &self,
202 candidates: &mut HashMap<String, CandidateInfo>,
203 ) -> Result<()> {
204 let src_ids: Vec<String> = candidates
207 .iter()
208 .filter(|(_, info)| {
209 info.chunk.get("origin").and_then(Value::as_str) != Some("spark")
210 })
211 .map(|(cid, _)| cid.clone())
212 .collect();
213 if src_ids.is_empty() {
214 return Ok(());
215 }
216 let src_refs: Vec<&str> = src_ids.iter().map(String::as_str).collect();
217 let deps_map = self.storage.get_deps_batch(&src_refs)?;
218
219 let mut target_ids: Vec<String> = Vec::new();
222 let mut seen: HashSet<String> = HashSet::new();
223 for deps in deps_map.values() {
224 for (dst, kind, _) in deps {
225 if kind == "soft" && seen.insert(dst.clone()) {
226 target_ids.push(dst.clone());
227 }
228 }
229 }
230 if target_ids.is_empty() {
231 return Ok(());
232 }
233 let target_refs: Vec<&str> = target_ids.iter().map(String::as_str).collect();
234 let targets = self.storage.get_chunks_by_ids(&target_refs)?;
235
236 for src in &src_ids {
237 let Some(deps) = deps_map.get(src) else {
238 continue;
239 };
240 for (dst, kind, _) in deps {
241 if kind != "soft" {
242 continue;
243 }
244 let Some(target) = targets.get(dst) else {
245 continue;
246 };
247 if target.get("state").and_then(Value::as_str) == Some("archived") {
248 continue;
249 }
250 if target.get("origin").and_then(Value::as_str) == Some("spark") {
251 continue;
252 }
253 let e = candidates
254 .entry(dst.clone())
255 .or_insert_with(|| CandidateInfo {
256 chunk: target.clone(),
257 sim_content: 0.0,
258 sim_trigger: 0.0,
259 });
260 e.sim_content = (e.sim_content + 0.05).min(1.0);
261 }
262 }
263 Ok(())
264 }
265
266 fn score_candidates(
267 &self,
268 candidates: HashMap<String, CandidateInfo>,
269 query: &str,
270 context_key: &str,
271 now: &str,
272 ) -> Result<Vec<(f64, Value)>> {
273 let cand_ids: Vec<String> = candidates
276 .values()
277 .filter_map(|info| info.chunk.get("id").and_then(Value::as_str).map(str::to_string))
278 .collect();
279 let cand_refs: Vec<&str> = cand_ids.iter().map(String::as_str).collect();
280 let ctx_scores = self.storage.context_scores_batch(&cand_refs, context_key)?;
281
282 let mut scored: Vec<(f64, Value)> = Vec::with_capacity(candidates.len());
283 for info in candidates.into_values() {
284 let conf = info
285 .chunk
286 .get("confidence")
287 .and_then(Value::as_f64)
288 .unwrap_or(0.5);
289 let chunk_id = info.chunk.get("id").and_then(Value::as_str).unwrap_or("");
290 let context_score = ctx_scores.get(chunk_id).copied().unwrap_or(0.0);
291 let used_count = info
294 .chunk
295 .get("used_count")
296 .and_then(Value::as_i64)
297 .unwrap_or(0);
298 let last_used_at = info.chunk.get("last_used_at").and_then(Value::as_str);
299 let activation = actr_activation(used_count, last_used_at, now);
300 let mut fused = self.w_content * info.sim_content as f64
301 + self.w_trigger * info.sim_trigger as f64
302 + self.w_confidence * conf
303 + self.w_context * context_score
304 + self.w_activation * activation;
305 if info.chunk.get("state").and_then(Value::as_str) == Some("pending") {
306 fused *= PENDING_RECALL_PENALTY;
307 }
308 let anti = info
309 .chunk
310 .get("anti_trigger_desc")
311 .and_then(Value::as_str)
312 .unwrap_or("");
313 if !anti.is_empty() && anti_trigger_hit(query, anti) {
314 fused *= self.anti_trigger_penalty;
315 }
316 let mut chunk = info.chunk;
317 chunk["_context_score"] = json!(context_score);
318 chunk["_activation"] = json!(activation);
319 chunk["_fused_score"] = json!(fused);
320 scored.push((fused, chunk));
321 }
322 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
323 scored.truncate(self.top_k_candidates);
324 Ok(scored)
325 }
326
327 fn pack(
328 &self,
329 scored: &[(f64, Value)],
330 budget: usize,
331 expand_deps: &str,
332 allow_trim: bool,
333 query: &str,
334 ) -> Result<PackResult> {
335 let mut selected: Vec<Value> = vec![];
336 let mut skipped: Vec<(Vec<Value>, f64, usize)> = vec![];
337 let mut skipped_reasons: HashMap<String, String> = HashMap::new();
338 let mut used_ids: HashSet<String> = HashSet::new();
339 let mut used_tokens: usize = 0;
340
341 for (fused, chunk) in scored {
342 let cid = chunk["id"].as_str().unwrap_or("").to_string();
343 if used_ids.contains(&cid) {
344 continue;
345 }
346
347 let (block, dep_skip_reason) = self.build_dep_block(chunk, expand_deps)?;
349 if let Some(reason) = dep_skip_reason {
350 skipped_reasons.insert(cid, reason);
351 continue;
352 }
353
354 let new_block: Vec<Value> = block
355 .iter()
356 .filter(|b| !used_ids.contains(b["id"].as_str().unwrap_or("")))
357 .cloned()
358 .collect();
359 let cost = block_cost(&new_block);
360
361 if used_tokens + cost <= budget {
362 for b in &block {
363 let bid = b["id"].as_str().unwrap_or("").to_string();
364 if !used_ids.contains(&bid) {
365 let mut b = b.clone();
366 b["_fused_score"] = json!(fused);
367 selected.push(b);
368 used_ids.insert(bid);
369 }
370 }
371 used_tokens += cost;
372 } else if allow_trim {
373 if let Some(trimmed) =
375 self.refiner
376 .trim(&block, query, budget.saturating_sub(used_tokens))
377 {
378 let trim_cost = block_cost(&trimmed);
379 if used_tokens + trim_cost <= budget {
380 for b in &trimmed {
381 let bid = b["id"].as_str().unwrap_or("").to_string();
382 if !used_ids.contains(&bid) {
383 let mut b = b.clone();
384 b["_fused_score"] = json!(fused);
385 b["_trimmed"] = json!(true);
386 selected.push(b);
387 used_ids.insert(bid);
388 }
389 }
390 used_tokens += trim_cost;
391 continue;
392 }
393 }
394 skipped.push((block, *fused, cost));
395 } else {
396 skipped.push((block, *fused, cost));
397 }
398 }
399 Ok((selected, skipped, skipped_reasons))
400 }
401
402 fn build_dep_block(
405 &self,
406 seed: &Value,
407 expand_deps: &str,
408 ) -> Result<(Vec<Value>, Option<String>)> {
409 if expand_deps == "false" || expand_deps.is_empty() {
410 return Ok((vec![seed.clone()], None));
411 }
412 let seed_id = seed["id"].as_str().unwrap_or("");
413 match expand_deps {
414 "direct" => {
415 let deps = self.storage.get_deps(seed_id)?;
416 let mut block = vec![seed.clone()];
417 for (dep_id, kind, _) in &deps {
418 if kind != "hard" {
419 continue;
420 }
421 match self.validate_hard_dep(dep_id)? {
422 Some(chunk) => block.push(chunk),
423 None => return Ok((vec![], Some("hard_dep_unavailable".to_string()))),
424 }
425 }
426 Ok((block, None))
427 }
428 "closure" => {
429 let mut block = vec![seed.clone()];
430 let mut visited: HashSet<String> = [seed_id.to_string()].into();
431 match self.expand_hard_closure(seed_id, &mut visited, &mut block, 0, 3)? {
432 Some(reason) => Ok((vec![], Some(reason))),
433 None => Ok((block, None)),
434 }
435 }
436 _ => Ok((vec![seed.clone()], None)),
437 }
438 }
439
440 fn validate_hard_dep(&self, dep_id: &str) -> Result<Option<Value>> {
442 match self.storage.get_chunk(dep_id)? {
443 None => Ok(None),
444 Some(chunk) => {
445 let state = chunk.get("state").and_then(Value::as_str).unwrap_or("");
446 let origin = chunk.get("origin").and_then(Value::as_str).unwrap_or("");
447 let embed_v = chunk
448 .get("embed_version")
449 .and_then(Value::as_i64)
450 .unwrap_or(0);
451 if state == "archived" || origin == "spark" || embed_v == 0 {
452 Ok(None)
453 } else {
454 Ok(Some(chunk))
455 }
456 }
457 }
458 }
459
460 fn expand_hard_closure(
462 &self,
463 id: &str,
464 visited: &mut HashSet<String>,
465 block: &mut Vec<Value>,
466 depth: usize,
467 max_depth: usize,
468 ) -> Result<Option<String>> {
469 if depth >= max_depth {
470 return Ok(Some("dep_depth_limit".to_string()));
471 }
472 let deps = self.storage.get_deps(id)?;
473 for (dep_id, kind, _) in &deps {
474 if kind != "hard" {
475 continue;
476 }
477 if visited.contains(dep_id) {
478 continue;
479 } visited.insert(dep_id.clone());
481 match self.validate_hard_dep(dep_id)? {
482 None => return Ok(Some("hard_dep_unavailable".to_string())),
483 Some(chunk) => {
484 block.push(chunk);
485 if let Some(reason) =
486 self.expand_hard_closure(dep_id, visited, block, depth + 1, max_depth)?
487 {
488 return Ok(Some(reason));
489 }
490 }
491 }
492 }
493 Ok(None)
494 }
495
496 fn density_refill(
497 &self,
498 mut selected: Vec<Value>,
499 skipped: &[(Vec<Value>, f64, usize)],
500 budget: usize,
501 ) -> Vec<Value> {
502 let used_tokens = block_cost(&selected);
503 if used_tokens >= budget {
504 return selected;
505 }
506
507 let selected_ids: HashSet<String> = selected
508 .iter()
509 .filter_map(|c| c["id"].as_str().map(str::to_string))
510 .collect();
511
512 let mut density_items: Vec<(f64, Vec<Value>, usize)> = skipped
513 .iter()
514 .filter_map(|(block, fscore, _)| {
515 let block: Vec<Value> = block
516 .iter()
517 .filter(|b| !selected_ids.contains(b["id"].as_str().unwrap_or("")))
518 .cloned()
519 .collect();
520 if block.is_empty() {
521 return None;
522 }
523 let cost = block_cost(&block);
524 let density = fscore / cost.max(1) as f64;
525 Some((density, block, cost))
526 })
527 .collect();
528 density_items.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
529
530 let mut used_tokens = block_cost(&selected);
531 let mut added_ids: HashSet<String> = selected_ids;
532 for (_, block, cost) in density_items {
533 if used_tokens + cost <= budget {
534 for b in block {
535 let bid = b["id"].as_str().unwrap_or("").to_string();
536 if !added_ids.contains(&bid) {
537 selected.push(b);
538 added_ids.insert(bid);
539 }
540 }
541 used_tokens += cost;
542 }
543 }
544 selected
545 }
546
547 fn recall_sparks(&self, q_content: &[f32], q_trigger: &[f32]) -> Result<Vec<Value>> {
548 let embed_version = self
549 .storage
550 .get_meta("embed_version")?
551 .and_then(|v| v.parse::<i64>().ok())
552 .unwrap_or(1);
553
554 let content_res = self
555 .storage
556 .search_vec_content(q_content, self.top_k_candidates)?;
557 let trigger_res = self
558 .storage
559 .search_vec_trigger(q_trigger, self.top_k_candidates)?;
560
561 let all_ids: Vec<&str> = {
563 let mut seen = HashSet::new();
564 content_res
565 .iter()
566 .chain(trigger_res.iter())
567 .map(|(id, _)| id.as_str())
568 .filter(|id| seen.insert(*id))
569 .collect()
570 };
571 let chunks = self.storage.get_chunks_by_ids(&all_ids)?;
572
573 let mut spark_scores: HashMap<String, (f32, Value)> = HashMap::new();
574 for (cid, sim) in content_res.iter().chain(trigger_res.iter()) {
575 if let Some(chunk) = chunks.get(cid) {
576 if chunk.get("origin").and_then(Value::as_str) != Some("spark") {
577 continue;
578 }
579 if chunk.get("state").and_then(Value::as_str) == Some("archived") {
580 continue;
581 }
582 let maturity = chunk.get("maturity").and_then(Value::as_str).unwrap_or("");
583 if maturity == "promoted" || maturity == "dropped" {
584 continue;
585 }
586 let ev = chunk
587 .get("embed_version")
588 .and_then(Value::as_i64)
589 .unwrap_or(1);
590 if ev < embed_version {
591 continue;
592 }
593 let entry = spark_scores
594 .entry(cid.clone())
595 .or_insert_with(|| (*sim, chunk.clone()));
596 if *sim > entry.0 {
597 *entry = (*sim, chunk.clone());
598 }
599 }
600 }
601 let mut sparks: Vec<(f32, Value)> = spark_scores.into_values().collect();
602 sparks.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
603 Ok(sparks
604 .into_iter()
605 .take(self.top_k_candidates)
606 .map(|(_, c)| c)
607 .collect())
608 }
609
610 #[allow(clippy::too_many_arguments)]
611 fn write_recall_trace(
612 &self,
613 trace_id: &str,
614 query: &str,
615 context_key: &str,
616 scored: &[(f64, Value)],
617 visible: &[Value],
618 sparks: &[Value],
619 depth_skipped: &[String],
620 skipped_reasons: &HashMap<String, String>,
621 refine_mode: &str,
622 source: &str,
623 now: &str,
624 ) -> Result<()> {
625 let lib_id = self.storage.lib_id()?;
626 self.storage.begin_immediate()?;
627 let result = (|| -> Result<()> {
628 for (rank, (_, chunk)) in scored.iter().enumerate() {
629 let cid = chunk["id"].as_str().unwrap_or("");
630 let sim = chunk.get("_fused_score").and_then(Value::as_f64);
631 let rm = skipped_reasons
633 .get(cid)
634 .map(|r| format!("skipped:{r}"))
635 .or_else(|| {
636 if refine_mode != "off" && !refine_mode.is_empty() {
637 Some(refine_mode.to_string())
638 } else {
639 None
640 }
641 });
642 self.storage.insert_usage_trace(
643 trace_id,
644 Some(cid),
645 "retrieved",
646 1.0,
647 sim,
648 rm.as_deref(),
649 None,
650 Some((rank + 1) as i64),
651 None,
652 source,
653 now,
654 )?;
655 }
656 for (rank, chunk) in visible.iter().enumerate() {
657 let cid = chunk["id"].as_str().unwrap_or("");
658 self.storage.insert_usage_trace(
659 trace_id,
660 Some(cid),
661 "selected",
662 1.0,
663 None,
664 None,
665 None,
666 Some((rank + 1) as i64),
667 None,
668 source,
669 now,
670 )?;
671 if chunk
673 .get("_trimmed")
674 .and_then(Value::as_bool)
675 .unwrap_or(false)
676 {
677 self.storage.insert_usage_trace(
678 trace_id,
679 Some(cid),
680 "refined",
681 1.0,
682 None,
683 Some("trim"),
684 None,
685 Some((rank + 1) as i64),
686 None,
687 source,
688 now,
689 )?;
690 }
691 }
692 for (rank, chunk) in sparks.iter().enumerate() {
694 let cid = chunk["id"].as_str().unwrap_or("");
695 self.storage.insert_usage_trace(
696 trace_id,
697 Some(cid),
698 "retrieved",
699 1.0,
700 None,
701 Some("spark"),
702 None,
703 Some((rank + 1) as i64),
704 None,
705 source,
706 now,
707 )?;
708 }
709 let snapshot = json!({
710 "retrieved": scored.iter().map(|(_, c)| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
711 "selected": visible.iter().map(|c| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
712 "sparks": sparks.iter().map(|c| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
713 "depth_skipped": depth_skipped,
714 "skipped_reasons": skipped_reasons,
715 });
716 let log = EpisodicLogRow {
717 id: gen_uuid(),
718 trace_id: trace_id.to_string(),
719 lib_id,
720 ts: now.to_string(),
721 query: Some(query.to_string()),
722 recall_snapshot: Some(snapshot.to_string()),
723 event_source: source.to_string(),
724 task_state: "recalled".to_string(),
725 usage_state: "unknown".to_string(),
726 context_key: Some(context_key.to_string()),
727 distill_state: "open".to_string(),
728 ..Default::default()
729 };
730 self.storage.upsert_episodic_log(&log)?;
731 self.storage.commit()
732 })();
733 if result.is_err() {
734 let _ = self.storage.rollback();
735 }
736 result
737 }
738}