Skip to main content

innate_core/kb/
recall.rs

1use super::*;
2
/// Parameters for [`KnowledgeBase::recall`].
///
/// Every field is either borrowed or `Copy` and the struct derives `Default`, so the
/// usual construction is `RecallParams { query, budget, source, ..Default::default() }`.
/// The two string-valued modes may be left empty: `recall` treats an empty
/// `expand_deps` as `"false"` and an empty `refine_mode` as `"off"`.
#[derive(Debug, Clone, Default)]
pub struct RecallParams<'a> {
    /// Query text; it is embedded and also handed to scoring and packing.
    pub query: &'a str,
    /// Token budget passed to packing, density refill and `"adapt"` refinement.
    pub budget: usize,
    /// When `true`, `recall` writes a usage trace for this call.
    pub trace: bool,
    /// When `true`, `recall` also looks up sparks and returns them alongside knowledge.
    pub include_sparks: bool,
    /// Optional cap forwarded to `limit_knowledge` after packing.
    pub top: Option<usize>,
    /// Event source; validated by `recall` and recorded in the trace.
    pub source: &'a str,
    /// Dependency expansion mode: `"false"` | `"direct"` | `"closure"`.
    pub expand_deps: &'a str,
    /// If `true`, `Refiner::trim` is invoked when a block does not fit the budget.
    pub allow_trim: bool,
    /// Refinement mode: `"off"` | `"trim"` | `"adapt"` — recorded in the trace.
    pub refine_mode: &'a str,
    /// Relevance gate: drop candidates whose fused score is below this value
    /// **before** packing/trace, so the trace only records knowledge that was
    /// actually surfaced. `None` disables the gate. Used by always-on hooks
    /// (UserPromptSubmit / SessionStart) to stay high-frequency without noise.
    pub min_score: Option<f64>,
}
25
26impl KnowledgeBase {
27    pub fn recall(&self, params: RecallParams<'_>) -> Result<RecallResult> {
28        let RecallParams {
29            query,
30            budget,
31            trace,
32            include_sparks,
33            top,
34            source,
35            expand_deps,
36            allow_trim,
37            refine_mode,
38            min_score,
39        } = params;
40        let expand_deps = if expand_deps.is_empty() {
41            "false"
42        } else {
43            expand_deps
44        };
45        let refine_mode = if refine_mode.is_empty() {
46            "off"
47        } else {
48            refine_mode
49        };
50        validate_source(source)?;
51        let trace_id = gen_uuid();
52        let now = utc_now_iso();
53
54        let (q_content, q_trigger) = self
55            .embedding
56            .embed_both(query)
57            .map_err(|e| InnateError::EmbeddingUnavailable(e.to_string()))?;
58
59        // ANN candidates (non-spark)
60        let mut candidates = self.ann_candidates(&q_content, &q_trigger)?;
61        self.apply_soft_dep_bonus(&mut candidates)?;
62
63        // Score + anti-trigger penalty
64        let mut scored = self.score_candidates(candidates, query)?;
65
66        // Relevance gate — drop sub-threshold candidates before packing/trace so the
67        // trace records only what was actually surfaced (keeps selected→used stats clean).
68        if let Some(min) = min_score {
69            scored.retain(|(fused, _)| *fused >= min);
70        }
71
72        // First-fit pack with dep expansion
73        let (selected, skipped, skipped_reasons) =
74            self.pack(&scored, budget, expand_deps, allow_trim, query)?;
75
76        let depth_skipped: Vec<String> = skipped_reasons
77            .iter()
78            .filter(|(_, r)| r.as_str() == "dep_depth_limit")
79            .map(|(id, _)| id.clone())
80            .collect();
81
82        // Density refill
83        let mut selected = selected;
84        if self.density_refill {
85            selected = self.density_refill(selected, &skipped, budget);
86        }
87
88        let limited = limit_knowledge(selected, top);
89        let visible = if refine_mode == "adapt" {
90            self.refiner
91                .refine(limited.clone(), Some(budget))
92                .unwrap_or(limited)
93        } else {
94            limited
95        };
96
97        // Sparks
98        let sparks = if include_sparks {
99            self.recall_sparks(&q_content, &q_trigger)?
100        } else {
101            vec![]
102        };
103
104        if trace {
105            self.write_recall_trace(
106                &trace_id,
107                query,
108                &scored,
109                &visible,
110                &sparks,
111                &depth_skipped,
112                &skipped_reasons,
113                refine_mode,
114                source,
115                &now,
116            )?;
117        }
118
119        let empty = visible.is_empty() && sparks.is_empty();
120        Ok(RecallResult {
121            knowledge: visible,
122            sparks,
123            trace_id,
124            empty,
125            depth_skipped,
126            skipped_reasons,
127        })
128    }
129
130    fn ann_candidates(
131        &self,
132        q_content: &[f32],
133        q_trigger: &[f32],
134    ) -> Result<HashMap<String, CandidateInfo>> {
135        let embed_version = self
136            .storage
137            .get_meta("embed_version")?
138            .and_then(|v| v.parse::<i64>().ok())
139            .unwrap_or(1);
140
141        let content_res = self
142            .storage
143            .search_vec_content(q_content, self.top_k_candidates * 2)?;
144        let trigger_res = self
145            .storage
146            .search_vec_trigger(q_trigger, self.top_k_candidates * 2)?;
147
148        // Collect unique ids and batch-fetch all chunks in two queries instead of N individual ones.
149        let all_ids: Vec<&str> = {
150            let mut seen = HashSet::new();
151            content_res
152                .iter()
153                .chain(trigger_res.iter())
154                .map(|(id, _)| id.as_str())
155                .filter(|id| seen.insert(*id))
156                .collect()
157        };
158        let chunks = self.storage.get_chunks_by_ids(&all_ids)?;
159
160        let mut candidates: HashMap<String, CandidateInfo> = HashMap::new();
161
162        for (cid, sim) in &content_res {
163            if let Some(chunk) = chunks.get(cid) {
164                if chunk_is_valid_for_recall(chunk, embed_version) {
165                    let e = candidates
166                        .entry(cid.clone())
167                        .or_insert_with(|| CandidateInfo {
168                            chunk: chunk.clone(),
169                            sim_content: 0.0,
170                            sim_trigger: 0.0,
171                        });
172                    e.sim_content = e.sim_content.max(*sim);
173                }
174            }
175        }
176        for (cid, sim) in &trigger_res {
177            if let Some(chunk) = chunks.get(cid) {
178                if chunk_is_valid_for_recall(chunk, embed_version) {
179                    let e = candidates
180                        .entry(cid.clone())
181                        .or_insert_with(|| CandidateInfo {
182                            chunk: chunk.clone(),
183                            sim_content: 0.0,
184                            sim_trigger: 0.0,
185                        });
186                    e.sim_trigger = e.sim_trigger.max(*sim);
187                }
188            }
189        }
190        Ok(candidates)
191    }
192
193    fn apply_soft_dep_bonus(&self, candidates: &mut HashMap<String, CandidateInfo>) -> Result<()> {
194        // Collect non-spark candidate ids and batch-fetch their outgoing deps
195        // in a single query (was one get_deps per candidate).
196        let src_ids: Vec<String> = candidates
197            .iter()
198            .filter(|(_, info)| {
199                info.chunk.get("origin").and_then(Value::as_str) != Some("spark")
200            })
201            .map(|(cid, _)| cid.clone())
202            .collect();
203        if src_ids.is_empty() {
204            return Ok(());
205        }
206        let src_refs: Vec<&str> = src_ids.iter().map(String::as_str).collect();
207        let deps_map = self.storage.get_deps_batch(&src_refs)?;
208
209        // Gather distinct soft-dep targets and batch-fetch them in one query
210        // (was one get_chunk per soft edge).
211        let mut target_ids: Vec<String> = Vec::new();
212        let mut seen: HashSet<String> = HashSet::new();
213        for deps in deps_map.values() {
214            for (dst, kind, _) in deps {
215                if kind == "soft" && seen.insert(dst.clone()) {
216                    target_ids.push(dst.clone());
217                }
218            }
219        }
220        if target_ids.is_empty() {
221            return Ok(());
222        }
223        let target_refs: Vec<&str> = target_ids.iter().map(String::as_str).collect();
224        let targets = self.storage.get_chunks_by_ids(&target_refs)?;
225
226        for src in &src_ids {
227            let Some(deps) = deps_map.get(src) else {
228                continue;
229            };
230            for (dst, kind, _) in deps {
231                if kind != "soft" {
232                    continue;
233                }
234                let Some(target) = targets.get(dst) else {
235                    continue;
236                };
237                if target.get("state").and_then(Value::as_str) == Some("archived") {
238                    continue;
239                }
240                if target.get("origin").and_then(Value::as_str) == Some("spark") {
241                    continue;
242                }
243                let e = candidates
244                    .entry(dst.clone())
245                    .or_insert_with(|| CandidateInfo {
246                        chunk: target.clone(),
247                        sim_content: 0.0,
248                        sim_trigger: 0.0,
249                    });
250                e.sim_content = (e.sim_content + 0.05).min(1.0);
251            }
252        }
253        Ok(())
254    }
255
256    fn score_candidates(
257        &self,
258        candidates: HashMap<String, CandidateInfo>,
259        query: &str,
260    ) -> Result<Vec<(f64, Value)>> {
261        let context_key = content_hash(&normalize_query(query));
262        // Batch-fetch context scores for all candidates in one query
263        // (was one context_score lookup per candidate).
264        let cand_ids: Vec<String> = candidates
265            .values()
266            .filter_map(|info| info.chunk.get("id").and_then(Value::as_str).map(str::to_string))
267            .collect();
268        let cand_refs: Vec<&str> = cand_ids.iter().map(String::as_str).collect();
269        let ctx_scores = self.storage.context_scores_batch(&cand_refs, &context_key)?;
270
271        let mut scored: Vec<(f64, Value)> = Vec::with_capacity(candidates.len());
272        for info in candidates.into_values() {
273            let conf = info
274                .chunk
275                .get("confidence")
276                .and_then(Value::as_f64)
277                .unwrap_or(0.5);
278            let chunk_id = info.chunk.get("id").and_then(Value::as_str).unwrap_or("");
279            let context_score = ctx_scores.get(chunk_id).copied().unwrap_or(0.0);
280            let mut fused = self.w_content * info.sim_content as f64
281                + self.w_trigger * info.sim_trigger as f64
282                + self.w_confidence * conf
283                + self.w_context * context_score;
284            if info.chunk.get("state").and_then(Value::as_str) == Some("pending") {
285                fused *= PENDING_RECALL_PENALTY;
286            }
287            let anti = info
288                .chunk
289                .get("anti_trigger_desc")
290                .and_then(Value::as_str)
291                .unwrap_or("");
292            if !anti.is_empty() && anti_trigger_hit(query, anti) {
293                fused *= self.anti_trigger_penalty;
294            }
295            let mut chunk = info.chunk;
296            chunk["_context_score"] = json!(context_score);
297            chunk["_fused_score"] = json!(fused);
298            scored.push((fused, chunk));
299        }
300        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
301        scored.truncate(self.top_k_candidates);
302        Ok(scored)
303    }
304
305    fn pack(
306        &self,
307        scored: &[(f64, Value)],
308        budget: usize,
309        expand_deps: &str,
310        allow_trim: bool,
311        query: &str,
312    ) -> Result<PackResult> {
313        let mut selected: Vec<Value> = vec![];
314        let mut skipped: Vec<(Vec<Value>, f64, usize)> = vec![];
315        let mut skipped_reasons: HashMap<String, String> = HashMap::new();
316        let mut used_ids: HashSet<String> = HashSet::new();
317        let mut used_tokens: usize = 0;
318
319        for (fused, chunk) in scored {
320            let cid = chunk["id"].as_str().unwrap_or("").to_string();
321            if used_ids.contains(&cid) {
322                continue;
323            }
324
325            // Build block with dep expansion; fail-closed on dep issues.
326            let (block, dep_skip_reason) = self.build_dep_block(chunk, expand_deps)?;
327            if let Some(reason) = dep_skip_reason {
328                skipped_reasons.insert(cid, reason);
329                continue;
330            }
331
332            let new_block: Vec<Value> = block
333                .iter()
334                .filter(|b| !used_ids.contains(b["id"].as_str().unwrap_or("")))
335                .cloned()
336                .collect();
337            let cost = block_cost(&new_block);
338
339            if used_tokens + cost <= budget {
340                for b in &block {
341                    let bid = b["id"].as_str().unwrap_or("").to_string();
342                    if !used_ids.contains(&bid) {
343                        let mut b = b.clone();
344                        b["_fused_score"] = json!(fused);
345                        selected.push(b);
346                        used_ids.insert(bid);
347                    }
348                }
349                used_tokens += cost;
350            } else if allow_trim {
351                // Attempt refiner trim — NullRefiner returns None (no-op).
352                if let Some(trimmed) =
353                    self.refiner
354                        .trim(&block, query, budget.saturating_sub(used_tokens))
355                {
356                    let trim_cost = block_cost(&trimmed);
357                    if used_tokens + trim_cost <= budget {
358                        for b in &trimmed {
359                            let bid = b["id"].as_str().unwrap_or("").to_string();
360                            if !used_ids.contains(&bid) {
361                                let mut b = b.clone();
362                                b["_fused_score"] = json!(fused);
363                                b["_trimmed"] = json!(true);
364                                selected.push(b);
365                                used_ids.insert(bid);
366                            }
367                        }
368                        used_tokens += trim_cost;
369                        continue;
370                    }
371                }
372                skipped.push((block, *fused, cost));
373            } else {
374                skipped.push((block, *fused, cost));
375            }
376        }
377        Ok((selected, skipped, skipped_reasons))
378    }
379
380    /// Expand a seed chunk into a block according to `expand_deps`.
381    /// Returns `(block, Some(skip_reason))` if the block should be discarded (fail-closed).
382    fn build_dep_block(
383        &self,
384        seed: &Value,
385        expand_deps: &str,
386    ) -> Result<(Vec<Value>, Option<String>)> {
387        if expand_deps == "false" || expand_deps.is_empty() {
388            return Ok((vec![seed.clone()], None));
389        }
390        let seed_id = seed["id"].as_str().unwrap_or("");
391        match expand_deps {
392            "direct" => {
393                let deps = self.storage.get_deps(seed_id)?;
394                let mut block = vec![seed.clone()];
395                for (dep_id, kind, _) in &deps {
396                    if kind != "hard" {
397                        continue;
398                    }
399                    match self.validate_hard_dep(dep_id)? {
400                        Some(chunk) => block.push(chunk),
401                        None => return Ok((vec![], Some("hard_dep_unavailable".to_string()))),
402                    }
403                }
404                Ok((block, None))
405            }
406            "closure" => {
407                let mut block = vec![seed.clone()];
408                let mut visited: HashSet<String> = [seed_id.to_string()].into();
409                match self.expand_hard_closure(seed_id, &mut visited, &mut block, 0, 3)? {
410                    Some(reason) => Ok((vec![], Some(reason))),
411                    None => Ok((block, None)),
412                }
413            }
414            _ => Ok((vec![seed.clone()], None)),
415        }
416    }
417
418    /// Returns the chunk if the hard dep is usable, None if it should cause fail-closed.
419    fn validate_hard_dep(&self, dep_id: &str) -> Result<Option<Value>> {
420        match self.storage.get_chunk(dep_id)? {
421            None => Ok(None),
422            Some(chunk) => {
423                let state = chunk.get("state").and_then(Value::as_str).unwrap_or("");
424                let origin = chunk.get("origin").and_then(Value::as_str).unwrap_or("");
425                let embed_v = chunk
426                    .get("embed_version")
427                    .and_then(Value::as_i64)
428                    .unwrap_or(0);
429                if state == "archived" || origin == "spark" || embed_v == 0 {
430                    Ok(None)
431                } else {
432                    Ok(Some(chunk))
433                }
434            }
435        }
436    }
437
438    /// BFS hard-dep expansion up to `max_depth`. Returns Some(reason) on fail-closed.
439    fn expand_hard_closure(
440        &self,
441        id: &str,
442        visited: &mut HashSet<String>,
443        block: &mut Vec<Value>,
444        depth: usize,
445        max_depth: usize,
446    ) -> Result<Option<String>> {
447        if depth >= max_depth {
448            return Ok(Some("dep_depth_limit".to_string()));
449        }
450        let deps = self.storage.get_deps(id)?;
451        for (dep_id, kind, _) in &deps {
452            if kind != "hard" {
453                continue;
454            }
455            if visited.contains(dep_id) {
456                continue;
457            } // cycle guard
458            visited.insert(dep_id.clone());
459            match self.validate_hard_dep(dep_id)? {
460                None => return Ok(Some("hard_dep_unavailable".to_string())),
461                Some(chunk) => {
462                    block.push(chunk);
463                    if let Some(reason) =
464                        self.expand_hard_closure(dep_id, visited, block, depth + 1, max_depth)?
465                    {
466                        return Ok(Some(reason));
467                    }
468                }
469            }
470        }
471        Ok(None)
472    }
473
474    fn density_refill(
475        &self,
476        mut selected: Vec<Value>,
477        skipped: &[(Vec<Value>, f64, usize)],
478        budget: usize,
479    ) -> Vec<Value> {
480        let used_tokens = block_cost(&selected);
481        if used_tokens >= budget {
482            return selected;
483        }
484
485        let selected_ids: HashSet<String> = selected
486            .iter()
487            .filter_map(|c| c["id"].as_str().map(str::to_string))
488            .collect();
489
490        let mut density_items: Vec<(f64, Vec<Value>, usize)> = skipped
491            .iter()
492            .filter_map(|(block, fscore, _)| {
493                let block: Vec<Value> = block
494                    .iter()
495                    .filter(|b| !selected_ids.contains(b["id"].as_str().unwrap_or("")))
496                    .cloned()
497                    .collect();
498                if block.is_empty() {
499                    return None;
500                }
501                let cost = block_cost(&block);
502                let density = fscore / cost.max(1) as f64;
503                Some((density, block, cost))
504            })
505            .collect();
506        density_items.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
507
508        let mut used_tokens = block_cost(&selected);
509        let mut added_ids: HashSet<String> = selected_ids;
510        for (_, block, cost) in density_items {
511            if used_tokens + cost <= budget {
512                for b in block {
513                    let bid = b["id"].as_str().unwrap_or("").to_string();
514                    if !added_ids.contains(&bid) {
515                        selected.push(b);
516                        added_ids.insert(bid);
517                    }
518                }
519                used_tokens += cost;
520            }
521        }
522        selected
523    }
524
525    fn recall_sparks(&self, q_content: &[f32], q_trigger: &[f32]) -> Result<Vec<Value>> {
526        let embed_version = self
527            .storage
528            .get_meta("embed_version")?
529            .and_then(|v| v.parse::<i64>().ok())
530            .unwrap_or(1);
531
532        let content_res = self
533            .storage
534            .search_vec_content(q_content, self.top_k_candidates)?;
535        let trigger_res = self
536            .storage
537            .search_vec_trigger(q_trigger, self.top_k_candidates)?;
538
539        // Batch-fetch all candidate chunk IDs (mirrors the pattern in ann_candidates).
540        let all_ids: Vec<&str> = {
541            let mut seen = HashSet::new();
542            content_res
543                .iter()
544                .chain(trigger_res.iter())
545                .map(|(id, _)| id.as_str())
546                .filter(|id| seen.insert(*id))
547                .collect()
548        };
549        let chunks = self.storage.get_chunks_by_ids(&all_ids)?;
550
551        let mut spark_scores: HashMap<String, (f32, Value)> = HashMap::new();
552        for (cid, sim) in content_res.iter().chain(trigger_res.iter()) {
553            if let Some(chunk) = chunks.get(cid) {
554                if chunk.get("origin").and_then(Value::as_str) != Some("spark") {
555                    continue;
556                }
557                if chunk.get("state").and_then(Value::as_str) == Some("archived") {
558                    continue;
559                }
560                let maturity = chunk.get("maturity").and_then(Value::as_str).unwrap_or("");
561                if maturity == "promoted" || maturity == "dropped" {
562                    continue;
563                }
564                let ev = chunk
565                    .get("embed_version")
566                    .and_then(Value::as_i64)
567                    .unwrap_or(1);
568                if ev < embed_version {
569                    continue;
570                }
571                let entry = spark_scores
572                    .entry(cid.clone())
573                    .or_insert_with(|| (*sim, chunk.clone()));
574                if *sim > entry.0 {
575                    *entry = (*sim, chunk.clone());
576                }
577            }
578        }
579        let mut sparks: Vec<(f32, Value)> = spark_scores.into_values().collect();
580        sparks.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
581        Ok(sparks
582            .into_iter()
583            .take(self.top_k_candidates)
584            .map(|(_, c)| c)
585            .collect())
586    }
587
588    #[allow(clippy::too_many_arguments)]
589    fn write_recall_trace(
590        &self,
591        trace_id: &str,
592        query: &str,
593        scored: &[(f64, Value)],
594        visible: &[Value],
595        sparks: &[Value],
596        depth_skipped: &[String],
597        skipped_reasons: &HashMap<String, String>,
598        refine_mode: &str,
599        source: &str,
600        now: &str,
601    ) -> Result<()> {
602        let lib_id = self.storage.lib_id()?;
603        self.storage.begin_immediate()?;
604        let result = (|| -> Result<()> {
605            for (rank, (_, chunk)) in scored.iter().enumerate() {
606                let cid = chunk["id"].as_str().unwrap_or("");
607                let sim = chunk.get("_fused_score").and_then(Value::as_f64);
608                // For dep-skipped seeds, record their skip reason as refine_mode.
609                let rm = skipped_reasons
610                    .get(cid)
611                    .map(|r| format!("skipped:{r}"))
612                    .or_else(|| {
613                        if refine_mode != "off" && !refine_mode.is_empty() {
614                            Some(refine_mode.to_string())
615                        } else {
616                            None
617                        }
618                    });
619                self.storage.insert_usage_trace(
620                    trace_id,
621                    Some(cid),
622                    "retrieved",
623                    1.0,
624                    sim,
625                    rm.as_deref(),
626                    None,
627                    Some((rank + 1) as i64),
628                    None,
629                    source,
630                    now,
631                )?;
632            }
633            for (rank, chunk) in visible.iter().enumerate() {
634                let cid = chunk["id"].as_str().unwrap_or("");
635                self.storage.insert_usage_trace(
636                    trace_id,
637                    Some(cid),
638                    "selected",
639                    1.0,
640                    None,
641                    None,
642                    None,
643                    Some((rank + 1) as i64),
644                    None,
645                    source,
646                    now,
647                )?;
648                // Write 'refined' event for chunks that came through the trim path.
649                if chunk
650                    .get("_trimmed")
651                    .and_then(Value::as_bool)
652                    .unwrap_or(false)
653                {
654                    self.storage.insert_usage_trace(
655                        trace_id,
656                        Some(cid),
657                        "refined",
658                        1.0,
659                        None,
660                        Some("trim"),
661                        None,
662                        Some((rank + 1) as i64),
663                        None,
664                        source,
665                        now,
666                    )?;
667                }
668            }
669            // Write 'retrieved' events for sparks (for recurring-spark count tracking).
670            for (rank, chunk) in sparks.iter().enumerate() {
671                let cid = chunk["id"].as_str().unwrap_or("");
672                self.storage.insert_usage_trace(
673                    trace_id,
674                    Some(cid),
675                    "retrieved",
676                    1.0,
677                    None,
678                    Some("spark"),
679                    None,
680                    Some((rank + 1) as i64),
681                    None,
682                    source,
683                    now,
684                )?;
685            }
686            let snapshot = json!({
687                "retrieved": scored.iter().map(|(_, c)| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
688                "selected": visible.iter().map(|c| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
689                "sparks": sparks.iter().map(|c| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
690                "depth_skipped": depth_skipped,
691                "skipped_reasons": skipped_reasons,
692            });
693            let log = EpisodicLogRow {
694                id: gen_uuid(),
695                trace_id: trace_id.to_string(),
696                lib_id,
697                ts: now.to_string(),
698                query: Some(query.to_string()),
699                recall_snapshot: Some(snapshot.to_string()),
700                event_source: source.to_string(),
701                task_state: "recalled".to_string(),
702                usage_state: "unknown".to_string(),
703                context_key: Some(content_hash(&normalize_query(query))),
704                distill_state: "open".to_string(),
705                ..Default::default()
706            };
707            self.storage.upsert_episodic_log(&log)?;
708            self.storage.commit()
709        })();
710        if result.is_err() {
711            let _ = self.storage.rollback();
712        }
713        result
714    }
715}