1use super::*;
2
/// Inputs to `KnowledgeBase::recall`.
///
/// `Default` yields empty strings, `0`, `false` and `None`. `recall` treats an
/// empty `expand_deps` as `"false"` and an empty `refine_mode` as `"off"`.
#[derive(Debug, Clone, Default)]
pub struct RecallParams<'a> {
    /// Free-text query; it is embedded, and also used for context scoring,
    /// anti-trigger checks and trimming.
    pub query: &'a str,
    /// Token budget the packed knowledge list must fit into.
    pub budget: usize,
    /// Persist usage-trace rows and an episodic log row for this recall.
    pub trace: bool,
    /// Also recall spark chunks alongside regular knowledge.
    pub include_sparks: bool,
    /// Optional limit forwarded to `limit_knowledge` on the selected chunks.
    pub top: Option<usize>,
    /// Caller / event source; checked by `validate_source`.
    pub source: &'a str,
    /// Hard-dependency expansion mode: `"false"`/`""` (none), `"direct"` or
    /// `"closure"`; any other value behaves like `"false"`.
    pub expand_deps: &'a str,
    /// Allow a block that overflows the remaining budget to be trimmed by the refiner.
    pub allow_trim: bool,
    /// `"adapt"` runs the refiner over the final list; `"off"`/`""` disables
    /// refinement; any other value is only recorded in the usage trace.
    pub refine_mode: &'a str,
}
20
21impl KnowledgeBase {
22 pub fn recall(&self, params: RecallParams<'_>) -> Result<RecallResult> {
23 let RecallParams {
24 query,
25 budget,
26 trace,
27 include_sparks,
28 top,
29 source,
30 expand_deps,
31 allow_trim,
32 refine_mode,
33 } = params;
34 let expand_deps = if expand_deps.is_empty() {
35 "false"
36 } else {
37 expand_deps
38 };
39 let refine_mode = if refine_mode.is_empty() {
40 "off"
41 } else {
42 refine_mode
43 };
44 validate_source(source)?;
45 let trace_id = gen_uuid();
46 let now = utc_now_iso();
47
48 let (q_content, q_trigger) = self
49 .embedding
50 .embed_both(query)
51 .map_err(|e| InnateError::EmbeddingUnavailable(e.to_string()))?;
52
53 let mut candidates = self.ann_candidates(&q_content, &q_trigger)?;
55 self.apply_soft_dep_bonus(&mut candidates)?;
56
57 let scored = self.score_candidates(candidates, query)?;
59
60 let (selected, skipped, skipped_reasons) =
62 self.pack(&scored, budget, expand_deps, allow_trim, query)?;
63
64 let depth_skipped: Vec<String> = skipped_reasons
65 .iter()
66 .filter(|(_, r)| r.as_str() == "dep_depth_limit")
67 .map(|(id, _)| id.clone())
68 .collect();
69
70 let mut selected = selected;
72 if self.density_refill {
73 selected = self.density_refill(selected, &skipped, budget);
74 }
75
76 let limited = limit_knowledge(selected, top);
77 let visible = if refine_mode == "adapt" {
78 self.refiner
79 .refine(limited.clone(), Some(budget))
80 .unwrap_or(limited)
81 } else {
82 limited
83 };
84
85 let sparks = if include_sparks {
87 self.recall_sparks(&q_content, &q_trigger)?
88 } else {
89 vec![]
90 };
91
92 if trace {
93 self.write_recall_trace(
94 &trace_id,
95 query,
96 &scored,
97 &visible,
98 &sparks,
99 &depth_skipped,
100 &skipped_reasons,
101 refine_mode,
102 source,
103 &now,
104 )?;
105 }
106
107 let empty = visible.is_empty() && sparks.is_empty();
108 Ok(RecallResult {
109 knowledge: visible,
110 sparks,
111 trace_id,
112 empty,
113 depth_skipped,
114 skipped_reasons,
115 })
116 }
117
118 fn ann_candidates(
119 &self,
120 q_content: &[f32],
121 q_trigger: &[f32],
122 ) -> Result<HashMap<String, CandidateInfo>> {
123 let embed_version = self
124 .storage
125 .get_meta("embed_version")?
126 .and_then(|v| v.parse::<i64>().ok())
127 .unwrap_or(1);
128
129 let content_res = self
130 .storage
131 .search_vec_content(q_content, self.top_k_candidates * 2)?;
132 let trigger_res = self
133 .storage
134 .search_vec_trigger(q_trigger, self.top_k_candidates * 2)?;
135
136 let all_ids: Vec<&str> = {
138 let mut seen = HashSet::new();
139 content_res
140 .iter()
141 .chain(trigger_res.iter())
142 .map(|(id, _)| id.as_str())
143 .filter(|id| seen.insert(*id))
144 .collect()
145 };
146 let chunks = self.storage.get_chunks_by_ids(&all_ids)?;
147
148 let mut candidates: HashMap<String, CandidateInfo> = HashMap::new();
149
150 for (cid, sim) in &content_res {
151 if let Some(chunk) = chunks.get(cid) {
152 if chunk_is_valid_for_recall(chunk, embed_version) {
153 let e = candidates
154 .entry(cid.clone())
155 .or_insert_with(|| CandidateInfo {
156 chunk: chunk.clone(),
157 sim_content: 0.0,
158 sim_trigger: 0.0,
159 });
160 e.sim_content = e.sim_content.max(*sim);
161 }
162 }
163 }
164 for (cid, sim) in &trigger_res {
165 if let Some(chunk) = chunks.get(cid) {
166 if chunk_is_valid_for_recall(chunk, embed_version) {
167 let e = candidates
168 .entry(cid.clone())
169 .or_insert_with(|| CandidateInfo {
170 chunk: chunk.clone(),
171 sim_content: 0.0,
172 sim_trigger: 0.0,
173 });
174 e.sim_trigger = e.sim_trigger.max(*sim);
175 }
176 }
177 }
178 Ok(candidates)
179 }
180
181 fn apply_soft_dep_bonus(&self, candidates: &mut HashMap<String, CandidateInfo>) -> Result<()> {
182 let src_ids: Vec<String> = candidates
185 .iter()
186 .filter(|(_, info)| {
187 info.chunk.get("origin").and_then(Value::as_str) != Some("spark")
188 })
189 .map(|(cid, _)| cid.clone())
190 .collect();
191 if src_ids.is_empty() {
192 return Ok(());
193 }
194 let src_refs: Vec<&str> = src_ids.iter().map(String::as_str).collect();
195 let deps_map = self.storage.get_deps_batch(&src_refs)?;
196
197 let mut target_ids: Vec<String> = Vec::new();
200 let mut seen: HashSet<String> = HashSet::new();
201 for deps in deps_map.values() {
202 for (dst, kind, _) in deps {
203 if kind == "soft" && seen.insert(dst.clone()) {
204 target_ids.push(dst.clone());
205 }
206 }
207 }
208 if target_ids.is_empty() {
209 return Ok(());
210 }
211 let target_refs: Vec<&str> = target_ids.iter().map(String::as_str).collect();
212 let targets = self.storage.get_chunks_by_ids(&target_refs)?;
213
214 for src in &src_ids {
215 let Some(deps) = deps_map.get(src) else {
216 continue;
217 };
218 for (dst, kind, _) in deps {
219 if kind != "soft" {
220 continue;
221 }
222 let Some(target) = targets.get(dst) else {
223 continue;
224 };
225 if target.get("state").and_then(Value::as_str) == Some("archived") {
226 continue;
227 }
228 if target.get("origin").and_then(Value::as_str) == Some("spark") {
229 continue;
230 }
231 let e = candidates
232 .entry(dst.clone())
233 .or_insert_with(|| CandidateInfo {
234 chunk: target.clone(),
235 sim_content: 0.0,
236 sim_trigger: 0.0,
237 });
238 e.sim_content = (e.sim_content + 0.05).min(1.0);
239 }
240 }
241 Ok(())
242 }
243
244 fn score_candidates(
245 &self,
246 candidates: HashMap<String, CandidateInfo>,
247 query: &str,
248 ) -> Result<Vec<(f64, Value)>> {
249 let context_key = content_hash(&normalize_query(query));
250 let cand_ids: Vec<String> = candidates
253 .values()
254 .filter_map(|info| info.chunk.get("id").and_then(Value::as_str).map(str::to_string))
255 .collect();
256 let cand_refs: Vec<&str> = cand_ids.iter().map(String::as_str).collect();
257 let ctx_scores = self.storage.context_scores_batch(&cand_refs, &context_key)?;
258
259 let mut scored: Vec<(f64, Value)> = Vec::with_capacity(candidates.len());
260 for info in candidates.into_values() {
261 let conf = info
262 .chunk
263 .get("confidence")
264 .and_then(Value::as_f64)
265 .unwrap_or(0.5);
266 let chunk_id = info.chunk.get("id").and_then(Value::as_str).unwrap_or("");
267 let context_score = ctx_scores.get(chunk_id).copied().unwrap_or(0.0);
268 let mut fused = self.w_content * info.sim_content as f64
269 + self.w_trigger * info.sim_trigger as f64
270 + self.w_confidence * conf
271 + self.w_context * context_score;
272 if info.chunk.get("state").and_then(Value::as_str) == Some("pending") {
273 fused *= PENDING_RECALL_PENALTY;
274 }
275 let anti = info
276 .chunk
277 .get("anti_trigger_desc")
278 .and_then(Value::as_str)
279 .unwrap_or("");
280 if !anti.is_empty() && anti_trigger_hit(query, anti) {
281 fused *= self.anti_trigger_penalty;
282 }
283 let mut chunk = info.chunk;
284 chunk["_context_score"] = json!(context_score);
285 chunk["_fused_score"] = json!(fused);
286 scored.push((fused, chunk));
287 }
288 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
289 scored.truncate(self.top_k_candidates);
290 Ok(scored)
291 }
292
293 fn pack(
294 &self,
295 scored: &[(f64, Value)],
296 budget: usize,
297 expand_deps: &str,
298 allow_trim: bool,
299 query: &str,
300 ) -> Result<PackResult> {
301 let mut selected: Vec<Value> = vec![];
302 let mut skipped: Vec<(Vec<Value>, f64, usize)> = vec![];
303 let mut skipped_reasons: HashMap<String, String> = HashMap::new();
304 let mut used_ids: HashSet<String> = HashSet::new();
305 let mut used_tokens: usize = 0;
306
307 for (fused, chunk) in scored {
308 let cid = chunk["id"].as_str().unwrap_or("").to_string();
309 if used_ids.contains(&cid) {
310 continue;
311 }
312
313 let (block, dep_skip_reason) = self.build_dep_block(chunk, expand_deps)?;
315 if let Some(reason) = dep_skip_reason {
316 skipped_reasons.insert(cid, reason);
317 continue;
318 }
319
320 let new_block: Vec<Value> = block
321 .iter()
322 .filter(|b| !used_ids.contains(b["id"].as_str().unwrap_or("")))
323 .cloned()
324 .collect();
325 let cost = block_cost(&new_block);
326
327 if used_tokens + cost <= budget {
328 for b in &block {
329 let bid = b["id"].as_str().unwrap_or("").to_string();
330 if !used_ids.contains(&bid) {
331 let mut b = b.clone();
332 b["_fused_score"] = json!(fused);
333 selected.push(b);
334 used_ids.insert(bid);
335 }
336 }
337 used_tokens += cost;
338 } else if allow_trim {
339 if let Some(trimmed) =
341 self.refiner
342 .trim(&block, query, budget.saturating_sub(used_tokens))
343 {
344 let trim_cost = block_cost(&trimmed);
345 if used_tokens + trim_cost <= budget {
346 for b in &trimmed {
347 let bid = b["id"].as_str().unwrap_or("").to_string();
348 if !used_ids.contains(&bid) {
349 let mut b = b.clone();
350 b["_fused_score"] = json!(fused);
351 b["_trimmed"] = json!(true);
352 selected.push(b);
353 used_ids.insert(bid);
354 }
355 }
356 used_tokens += trim_cost;
357 continue;
358 }
359 }
360 skipped.push((block, *fused, cost));
361 } else {
362 skipped.push((block, *fused, cost));
363 }
364 }
365 Ok((selected, skipped, skipped_reasons))
366 }
367
368 fn build_dep_block(
371 &self,
372 seed: &Value,
373 expand_deps: &str,
374 ) -> Result<(Vec<Value>, Option<String>)> {
375 if expand_deps == "false" || expand_deps.is_empty() {
376 return Ok((vec![seed.clone()], None));
377 }
378 let seed_id = seed["id"].as_str().unwrap_or("");
379 match expand_deps {
380 "direct" => {
381 let deps = self.storage.get_deps(seed_id)?;
382 let mut block = vec![seed.clone()];
383 for (dep_id, kind, _) in &deps {
384 if kind != "hard" {
385 continue;
386 }
387 match self.validate_hard_dep(dep_id)? {
388 Some(chunk) => block.push(chunk),
389 None => return Ok((vec![], Some("hard_dep_unavailable".to_string()))),
390 }
391 }
392 Ok((block, None))
393 }
394 "closure" => {
395 let mut block = vec![seed.clone()];
396 let mut visited: HashSet<String> = [seed_id.to_string()].into();
397 match self.expand_hard_closure(seed_id, &mut visited, &mut block, 0, 3)? {
398 Some(reason) => Ok((vec![], Some(reason))),
399 None => Ok((block, None)),
400 }
401 }
402 _ => Ok((vec![seed.clone()], None)),
403 }
404 }
405
406 fn validate_hard_dep(&self, dep_id: &str) -> Result<Option<Value>> {
408 match self.storage.get_chunk(dep_id)? {
409 None => Ok(None),
410 Some(chunk) => {
411 let state = chunk.get("state").and_then(Value::as_str).unwrap_or("");
412 let origin = chunk.get("origin").and_then(Value::as_str).unwrap_or("");
413 let embed_v = chunk
414 .get("embed_version")
415 .and_then(Value::as_i64)
416 .unwrap_or(0);
417 if state == "archived" || origin == "spark" || embed_v == 0 {
418 Ok(None)
419 } else {
420 Ok(Some(chunk))
421 }
422 }
423 }
424 }
425
426 fn expand_hard_closure(
428 &self,
429 id: &str,
430 visited: &mut HashSet<String>,
431 block: &mut Vec<Value>,
432 depth: usize,
433 max_depth: usize,
434 ) -> Result<Option<String>> {
435 if depth >= max_depth {
436 return Ok(Some("dep_depth_limit".to_string()));
437 }
438 let deps = self.storage.get_deps(id)?;
439 for (dep_id, kind, _) in &deps {
440 if kind != "hard" {
441 continue;
442 }
443 if visited.contains(dep_id) {
444 continue;
445 } visited.insert(dep_id.clone());
447 match self.validate_hard_dep(dep_id)? {
448 None => return Ok(Some("hard_dep_unavailable".to_string())),
449 Some(chunk) => {
450 block.push(chunk);
451 if let Some(reason) =
452 self.expand_hard_closure(dep_id, visited, block, depth + 1, max_depth)?
453 {
454 return Ok(Some(reason));
455 }
456 }
457 }
458 }
459 Ok(None)
460 }
461
462 fn density_refill(
463 &self,
464 mut selected: Vec<Value>,
465 skipped: &[(Vec<Value>, f64, usize)],
466 budget: usize,
467 ) -> Vec<Value> {
468 let used_tokens = block_cost(&selected);
469 if used_tokens >= budget {
470 return selected;
471 }
472
473 let selected_ids: HashSet<String> = selected
474 .iter()
475 .filter_map(|c| c["id"].as_str().map(str::to_string))
476 .collect();
477
478 let mut density_items: Vec<(f64, Vec<Value>, usize)> = skipped
479 .iter()
480 .filter_map(|(block, fscore, _)| {
481 let block: Vec<Value> = block
482 .iter()
483 .filter(|b| !selected_ids.contains(b["id"].as_str().unwrap_or("")))
484 .cloned()
485 .collect();
486 if block.is_empty() {
487 return None;
488 }
489 let cost = block_cost(&block);
490 let density = fscore / cost.max(1) as f64;
491 Some((density, block, cost))
492 })
493 .collect();
494 density_items.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
495
496 let mut used_tokens = block_cost(&selected);
497 let mut added_ids: HashSet<String> = selected_ids;
498 for (_, block, cost) in density_items {
499 if used_tokens + cost <= budget {
500 for b in block {
501 let bid = b["id"].as_str().unwrap_or("").to_string();
502 if !added_ids.contains(&bid) {
503 selected.push(b);
504 added_ids.insert(bid);
505 }
506 }
507 used_tokens += cost;
508 }
509 }
510 selected
511 }
512
513 fn recall_sparks(&self, q_content: &[f32], q_trigger: &[f32]) -> Result<Vec<Value>> {
514 let embed_version = self
515 .storage
516 .get_meta("embed_version")?
517 .and_then(|v| v.parse::<i64>().ok())
518 .unwrap_or(1);
519
520 let content_res = self
521 .storage
522 .search_vec_content(q_content, self.top_k_candidates)?;
523 let trigger_res = self
524 .storage
525 .search_vec_trigger(q_trigger, self.top_k_candidates)?;
526
527 let all_ids: Vec<&str> = {
529 let mut seen = HashSet::new();
530 content_res
531 .iter()
532 .chain(trigger_res.iter())
533 .map(|(id, _)| id.as_str())
534 .filter(|id| seen.insert(*id))
535 .collect()
536 };
537 let chunks = self.storage.get_chunks_by_ids(&all_ids)?;
538
539 let mut spark_scores: HashMap<String, (f32, Value)> = HashMap::new();
540 for (cid, sim) in content_res.iter().chain(trigger_res.iter()) {
541 if let Some(chunk) = chunks.get(cid) {
542 if chunk.get("origin").and_then(Value::as_str) != Some("spark") {
543 continue;
544 }
545 if chunk.get("state").and_then(Value::as_str) == Some("archived") {
546 continue;
547 }
548 let maturity = chunk.get("maturity").and_then(Value::as_str).unwrap_or("");
549 if maturity == "promoted" || maturity == "dropped" {
550 continue;
551 }
552 let ev = chunk
553 .get("embed_version")
554 .and_then(Value::as_i64)
555 .unwrap_or(1);
556 if ev < embed_version {
557 continue;
558 }
559 let entry = spark_scores
560 .entry(cid.clone())
561 .or_insert_with(|| (*sim, chunk.clone()));
562 if *sim > entry.0 {
563 *entry = (*sim, chunk.clone());
564 }
565 }
566 }
567 let mut sparks: Vec<(f32, Value)> = spark_scores.into_values().collect();
568 sparks.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
569 Ok(sparks
570 .into_iter()
571 .take(self.top_k_candidates)
572 .map(|(_, c)| c)
573 .collect())
574 }
575
576 #[allow(clippy::too_many_arguments)]
577 fn write_recall_trace(
578 &self,
579 trace_id: &str,
580 query: &str,
581 scored: &[(f64, Value)],
582 visible: &[Value],
583 sparks: &[Value],
584 depth_skipped: &[String],
585 skipped_reasons: &HashMap<String, String>,
586 refine_mode: &str,
587 source: &str,
588 now: &str,
589 ) -> Result<()> {
590 let lib_id = self.storage.lib_id()?;
591 self.storage.begin_immediate()?;
592 let result = (|| -> Result<()> {
593 for (rank, (_, chunk)) in scored.iter().enumerate() {
594 let cid = chunk["id"].as_str().unwrap_or("");
595 let sim = chunk.get("_fused_score").and_then(Value::as_f64);
596 let rm = skipped_reasons
598 .get(cid)
599 .map(|r| format!("skipped:{r}"))
600 .or_else(|| {
601 if refine_mode != "off" && !refine_mode.is_empty() {
602 Some(refine_mode.to_string())
603 } else {
604 None
605 }
606 });
607 self.storage.insert_usage_trace(
608 trace_id,
609 Some(cid),
610 "retrieved",
611 1.0,
612 sim,
613 rm.as_deref(),
614 None,
615 Some((rank + 1) as i64),
616 None,
617 source,
618 now,
619 )?;
620 }
621 for (rank, chunk) in visible.iter().enumerate() {
622 let cid = chunk["id"].as_str().unwrap_or("");
623 self.storage.insert_usage_trace(
624 trace_id,
625 Some(cid),
626 "selected",
627 1.0,
628 None,
629 None,
630 None,
631 Some((rank + 1) as i64),
632 None,
633 source,
634 now,
635 )?;
636 if chunk
638 .get("_trimmed")
639 .and_then(Value::as_bool)
640 .unwrap_or(false)
641 {
642 self.storage.insert_usage_trace(
643 trace_id,
644 Some(cid),
645 "refined",
646 1.0,
647 None,
648 Some("trim"),
649 None,
650 Some((rank + 1) as i64),
651 None,
652 source,
653 now,
654 )?;
655 }
656 }
657 for (rank, chunk) in sparks.iter().enumerate() {
659 let cid = chunk["id"].as_str().unwrap_or("");
660 self.storage.insert_usage_trace(
661 trace_id,
662 Some(cid),
663 "retrieved",
664 1.0,
665 None,
666 Some("spark"),
667 None,
668 Some((rank + 1) as i64),
669 None,
670 source,
671 now,
672 )?;
673 }
674 let snapshot = json!({
675 "retrieved": scored.iter().map(|(_, c)| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
676 "selected": visible.iter().map(|c| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
677 "sparks": sparks.iter().map(|c| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
678 "depth_skipped": depth_skipped,
679 "skipped_reasons": skipped_reasons,
680 });
681 let log = EpisodicLogRow {
682 id: gen_uuid(),
683 trace_id: trace_id.to_string(),
684 lib_id,
685 ts: now.to_string(),
686 query: Some(query.to_string()),
687 recall_snapshot: Some(snapshot.to_string()),
688 event_source: source.to_string(),
689 task_state: "recalled".to_string(),
690 usage_state: "unknown".to_string(),
691 context_key: Some(content_hash(&normalize_query(query))),
692 distill_state: "open".to_string(),
693 ..Default::default()
694 };
695 self.storage.upsert_episodic_log(&log)?;
696 self.storage.commit()
697 })();
698 if result.is_err() {
699 let _ = self.storage.rollback();
700 }
701 result
702 }
703}