Skip to main content

innate_core/kb/
recall.rs

1use super::*;
2
/// Inputs for [`KnowledgeBase::recall`].
///
/// Every field borrows from the caller and has a `Default`, so the usual way to build one is
/// `RecallParams { query, budget, source, ..Default::default() }`.
/// The two string-typed mode fields may be left empty: `recall` treats an empty
/// `expand_deps` as `"false"` and an empty `refine_mode` as `"off"`.
#[derive(Clone, Debug, Default)]
pub struct RecallParams<'a> {
    /// Free-text query; it is embedded for ANN search and also used for anti-trigger matching.
    pub query: &'a str,
    /// Upper bound on the summed `block_cost` (token count) of the packed knowledge.
    pub budget: usize,
    /// When set, usage traces and an episodic log row are written for this recall.
    pub trace: bool,
    /// When set, spark chunks (`origin == "spark"`) are searched and returned separately.
    pub include_sparks: bool,
    /// Optional cap handed to `limit_knowledge` for the packed knowledge.
    pub top: Option<usize>,
    /// Event-source label; checked by `validate_source` and stored with traces.
    pub source: &'a str,
    /// Hard-dependency expansion mode: `"false"` | `"direct"` | `"closure"`.
    pub expand_deps: &'a str,
    /// If true, `Refiner::trim` is tried when a block does not fit the remaining budget.
    pub allow_trim: bool,
    /// `"off"` | `"trim"` | `"adapt"` — recorded in the trace; `"adapt"` additionally runs
    /// `Refiner::refine` over the final selection.
    pub refine_mode: &'a str,
}
20
21impl KnowledgeBase {
22    pub fn recall(&self, params: RecallParams<'_>) -> Result<RecallResult> {
23        let RecallParams {
24            query,
25            budget,
26            trace,
27            include_sparks,
28            top,
29            source,
30            expand_deps,
31            allow_trim,
32            refine_mode,
33        } = params;
34        let expand_deps = if expand_deps.is_empty() {
35            "false"
36        } else {
37            expand_deps
38        };
39        let refine_mode = if refine_mode.is_empty() {
40            "off"
41        } else {
42            refine_mode
43        };
44        validate_source(source)?;
45        let trace_id = gen_uuid();
46        let now = utc_now_iso();
47
48        let (q_content, q_trigger) = self
49            .embedding
50            .embed_both(query)
51            .map_err(|e| InnateError::EmbeddingUnavailable(e.to_string()))?;
52
53        // ANN candidates (non-spark)
54        let mut candidates = self.ann_candidates(&q_content, &q_trigger)?;
55        self.apply_soft_dep_bonus(&mut candidates)?;
56
57        // Score + anti-trigger penalty
58        let scored = self.score_candidates(candidates, query)?;
59
60        // First-fit pack with dep expansion
61        let (selected, skipped, skipped_reasons) =
62            self.pack(&scored, budget, expand_deps, allow_trim, query)?;
63
64        let depth_skipped: Vec<String> = skipped_reasons
65            .iter()
66            .filter(|(_, r)| r.as_str() == "dep_depth_limit")
67            .map(|(id, _)| id.clone())
68            .collect();
69
70        // Density refill
71        let mut selected = selected;
72        if self.density_refill {
73            selected = self.density_refill(selected, &skipped, budget);
74        }
75
76        let limited = limit_knowledge(selected, top);
77        let visible = if refine_mode == "adapt" {
78            self.refiner
79                .refine(limited.clone(), Some(budget))
80                .unwrap_or(limited)
81        } else {
82            limited
83        };
84
85        // Sparks
86        let sparks = if include_sparks {
87            self.recall_sparks(&q_content, &q_trigger)?
88        } else {
89            vec![]
90        };
91
92        if trace {
93            self.write_recall_trace(
94                &trace_id,
95                query,
96                &scored,
97                &visible,
98                &sparks,
99                &depth_skipped,
100                &skipped_reasons,
101                refine_mode,
102                source,
103                &now,
104            )?;
105        }
106
107        let empty = visible.is_empty() && sparks.is_empty();
108        Ok(RecallResult {
109            knowledge: visible,
110            sparks,
111            trace_id,
112            empty,
113            depth_skipped,
114            skipped_reasons,
115        })
116    }
117
118    fn ann_candidates(
119        &self,
120        q_content: &[f32],
121        q_trigger: &[f32],
122    ) -> Result<HashMap<String, CandidateInfo>> {
123        let embed_version = self
124            .storage
125            .get_meta("embed_version")?
126            .and_then(|v| v.parse::<i64>().ok())
127            .unwrap_or(1);
128
129        let content_res = self
130            .storage
131            .search_vec_content(q_content, self.top_k_candidates * 2)?;
132        let trigger_res = self
133            .storage
134            .search_vec_trigger(q_trigger, self.top_k_candidates * 2)?;
135
136        // Collect unique ids and batch-fetch all chunks in two queries instead of N individual ones.
137        let all_ids: Vec<&str> = {
138            let mut seen = HashSet::new();
139            content_res
140                .iter()
141                .chain(trigger_res.iter())
142                .map(|(id, _)| id.as_str())
143                .filter(|id| seen.insert(*id))
144                .collect()
145        };
146        let chunks = self.storage.get_chunks_by_ids(&all_ids)?;
147
148        let mut candidates: HashMap<String, CandidateInfo> = HashMap::new();
149
150        for (cid, sim) in &content_res {
151            if let Some(chunk) = chunks.get(cid) {
152                if chunk_is_valid_for_recall(chunk, embed_version) {
153                    let e = candidates
154                        .entry(cid.clone())
155                        .or_insert_with(|| CandidateInfo {
156                            chunk: chunk.clone(),
157                            sim_content: 0.0,
158                            sim_trigger: 0.0,
159                        });
160                    e.sim_content = e.sim_content.max(*sim);
161                }
162            }
163        }
164        for (cid, sim) in &trigger_res {
165            if let Some(chunk) = chunks.get(cid) {
166                if chunk_is_valid_for_recall(chunk, embed_version) {
167                    let e = candidates
168                        .entry(cid.clone())
169                        .or_insert_with(|| CandidateInfo {
170                            chunk: chunk.clone(),
171                            sim_content: 0.0,
172                            sim_trigger: 0.0,
173                        });
174                    e.sim_trigger = e.sim_trigger.max(*sim);
175                }
176            }
177        }
178        Ok(candidates)
179    }
180
181    fn apply_soft_dep_bonus(&self, candidates: &mut HashMap<String, CandidateInfo>) -> Result<()> {
182        // Collect non-spark candidate ids and batch-fetch their outgoing deps
183        // in a single query (was one get_deps per candidate).
184        let src_ids: Vec<String> = candidates
185            .iter()
186            .filter(|(_, info)| {
187                info.chunk.get("origin").and_then(Value::as_str) != Some("spark")
188            })
189            .map(|(cid, _)| cid.clone())
190            .collect();
191        if src_ids.is_empty() {
192            return Ok(());
193        }
194        let src_refs: Vec<&str> = src_ids.iter().map(String::as_str).collect();
195        let deps_map = self.storage.get_deps_batch(&src_refs)?;
196
197        // Gather distinct soft-dep targets and batch-fetch them in one query
198        // (was one get_chunk per soft edge).
199        let mut target_ids: Vec<String> = Vec::new();
200        let mut seen: HashSet<String> = HashSet::new();
201        for deps in deps_map.values() {
202            for (dst, kind, _) in deps {
203                if kind == "soft" && seen.insert(dst.clone()) {
204                    target_ids.push(dst.clone());
205                }
206            }
207        }
208        if target_ids.is_empty() {
209            return Ok(());
210        }
211        let target_refs: Vec<&str> = target_ids.iter().map(String::as_str).collect();
212        let targets = self.storage.get_chunks_by_ids(&target_refs)?;
213
214        for src in &src_ids {
215            let Some(deps) = deps_map.get(src) else {
216                continue;
217            };
218            for (dst, kind, _) in deps {
219                if kind != "soft" {
220                    continue;
221                }
222                let Some(target) = targets.get(dst) else {
223                    continue;
224                };
225                if target.get("state").and_then(Value::as_str) == Some("archived") {
226                    continue;
227                }
228                if target.get("origin").and_then(Value::as_str) == Some("spark") {
229                    continue;
230                }
231                let e = candidates
232                    .entry(dst.clone())
233                    .or_insert_with(|| CandidateInfo {
234                        chunk: target.clone(),
235                        sim_content: 0.0,
236                        sim_trigger: 0.0,
237                    });
238                e.sim_content = (e.sim_content + 0.05).min(1.0);
239            }
240        }
241        Ok(())
242    }
243
244    fn score_candidates(
245        &self,
246        candidates: HashMap<String, CandidateInfo>,
247        query: &str,
248    ) -> Result<Vec<(f64, Value)>> {
249        let context_key = content_hash(&normalize_query(query));
250        // Batch-fetch context scores for all candidates in one query
251        // (was one context_score lookup per candidate).
252        let cand_ids: Vec<String> = candidates
253            .values()
254            .filter_map(|info| info.chunk.get("id").and_then(Value::as_str).map(str::to_string))
255            .collect();
256        let cand_refs: Vec<&str> = cand_ids.iter().map(String::as_str).collect();
257        let ctx_scores = self.storage.context_scores_batch(&cand_refs, &context_key)?;
258
259        let mut scored: Vec<(f64, Value)> = Vec::with_capacity(candidates.len());
260        for info in candidates.into_values() {
261            let conf = info
262                .chunk
263                .get("confidence")
264                .and_then(Value::as_f64)
265                .unwrap_or(0.5);
266            let chunk_id = info.chunk.get("id").and_then(Value::as_str).unwrap_or("");
267            let context_score = ctx_scores.get(chunk_id).copied().unwrap_or(0.0);
268            let mut fused = self.w_content * info.sim_content as f64
269                + self.w_trigger * info.sim_trigger as f64
270                + self.w_confidence * conf
271                + self.w_context * context_score;
272            if info.chunk.get("state").and_then(Value::as_str) == Some("pending") {
273                fused *= PENDING_RECALL_PENALTY;
274            }
275            let anti = info
276                .chunk
277                .get("anti_trigger_desc")
278                .and_then(Value::as_str)
279                .unwrap_or("");
280            if !anti.is_empty() && anti_trigger_hit(query, anti) {
281                fused *= self.anti_trigger_penalty;
282            }
283            let mut chunk = info.chunk;
284            chunk["_context_score"] = json!(context_score);
285            chunk["_fused_score"] = json!(fused);
286            scored.push((fused, chunk));
287        }
288        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
289        scored.truncate(self.top_k_candidates);
290        Ok(scored)
291    }
292
293    fn pack(
294        &self,
295        scored: &[(f64, Value)],
296        budget: usize,
297        expand_deps: &str,
298        allow_trim: bool,
299        query: &str,
300    ) -> Result<PackResult> {
301        let mut selected: Vec<Value> = vec![];
302        let mut skipped: Vec<(Vec<Value>, f64, usize)> = vec![];
303        let mut skipped_reasons: HashMap<String, String> = HashMap::new();
304        let mut used_ids: HashSet<String> = HashSet::new();
305        let mut used_tokens: usize = 0;
306
307        for (fused, chunk) in scored {
308            let cid = chunk["id"].as_str().unwrap_or("").to_string();
309            if used_ids.contains(&cid) {
310                continue;
311            }
312
313            // Build block with dep expansion; fail-closed on dep issues.
314            let (block, dep_skip_reason) = self.build_dep_block(chunk, expand_deps)?;
315            if let Some(reason) = dep_skip_reason {
316                skipped_reasons.insert(cid, reason);
317                continue;
318            }
319
320            let new_block: Vec<Value> = block
321                .iter()
322                .filter(|b| !used_ids.contains(b["id"].as_str().unwrap_or("")))
323                .cloned()
324                .collect();
325            let cost = block_cost(&new_block);
326
327            if used_tokens + cost <= budget {
328                for b in &block {
329                    let bid = b["id"].as_str().unwrap_or("").to_string();
330                    if !used_ids.contains(&bid) {
331                        let mut b = b.clone();
332                        b["_fused_score"] = json!(fused);
333                        selected.push(b);
334                        used_ids.insert(bid);
335                    }
336                }
337                used_tokens += cost;
338            } else if allow_trim {
339                // Attempt refiner trim — NullRefiner returns None (no-op).
340                if let Some(trimmed) =
341                    self.refiner
342                        .trim(&block, query, budget.saturating_sub(used_tokens))
343                {
344                    let trim_cost = block_cost(&trimmed);
345                    if used_tokens + trim_cost <= budget {
346                        for b in &trimmed {
347                            let bid = b["id"].as_str().unwrap_or("").to_string();
348                            if !used_ids.contains(&bid) {
349                                let mut b = b.clone();
350                                b["_fused_score"] = json!(fused);
351                                b["_trimmed"] = json!(true);
352                                selected.push(b);
353                                used_ids.insert(bid);
354                            }
355                        }
356                        used_tokens += trim_cost;
357                        continue;
358                    }
359                }
360                skipped.push((block, *fused, cost));
361            } else {
362                skipped.push((block, *fused, cost));
363            }
364        }
365        Ok((selected, skipped, skipped_reasons))
366    }
367
368    /// Expand a seed chunk into a block according to `expand_deps`.
369    /// Returns `(block, Some(skip_reason))` if the block should be discarded (fail-closed).
370    fn build_dep_block(
371        &self,
372        seed: &Value,
373        expand_deps: &str,
374    ) -> Result<(Vec<Value>, Option<String>)> {
375        if expand_deps == "false" || expand_deps.is_empty() {
376            return Ok((vec![seed.clone()], None));
377        }
378        let seed_id = seed["id"].as_str().unwrap_or("");
379        match expand_deps {
380            "direct" => {
381                let deps = self.storage.get_deps(seed_id)?;
382                let mut block = vec![seed.clone()];
383                for (dep_id, kind, _) in &deps {
384                    if kind != "hard" {
385                        continue;
386                    }
387                    match self.validate_hard_dep(dep_id)? {
388                        Some(chunk) => block.push(chunk),
389                        None => return Ok((vec![], Some("hard_dep_unavailable".to_string()))),
390                    }
391                }
392                Ok((block, None))
393            }
394            "closure" => {
395                let mut block = vec![seed.clone()];
396                let mut visited: HashSet<String> = [seed_id.to_string()].into();
397                match self.expand_hard_closure(seed_id, &mut visited, &mut block, 0, 3)? {
398                    Some(reason) => Ok((vec![], Some(reason))),
399                    None => Ok((block, None)),
400                }
401            }
402            _ => Ok((vec![seed.clone()], None)),
403        }
404    }
405
406    /// Returns the chunk if the hard dep is usable, None if it should cause fail-closed.
407    fn validate_hard_dep(&self, dep_id: &str) -> Result<Option<Value>> {
408        match self.storage.get_chunk(dep_id)? {
409            None => Ok(None),
410            Some(chunk) => {
411                let state = chunk.get("state").and_then(Value::as_str).unwrap_or("");
412                let origin = chunk.get("origin").and_then(Value::as_str).unwrap_or("");
413                let embed_v = chunk
414                    .get("embed_version")
415                    .and_then(Value::as_i64)
416                    .unwrap_or(0);
417                if state == "archived" || origin == "spark" || embed_v == 0 {
418                    Ok(None)
419                } else {
420                    Ok(Some(chunk))
421                }
422            }
423        }
424    }
425
426    /// BFS hard-dep expansion up to `max_depth`. Returns Some(reason) on fail-closed.
427    fn expand_hard_closure(
428        &self,
429        id: &str,
430        visited: &mut HashSet<String>,
431        block: &mut Vec<Value>,
432        depth: usize,
433        max_depth: usize,
434    ) -> Result<Option<String>> {
435        if depth >= max_depth {
436            return Ok(Some("dep_depth_limit".to_string()));
437        }
438        let deps = self.storage.get_deps(id)?;
439        for (dep_id, kind, _) in &deps {
440            if kind != "hard" {
441                continue;
442            }
443            if visited.contains(dep_id) {
444                continue;
445            } // cycle guard
446            visited.insert(dep_id.clone());
447            match self.validate_hard_dep(dep_id)? {
448                None => return Ok(Some("hard_dep_unavailable".to_string())),
449                Some(chunk) => {
450                    block.push(chunk);
451                    if let Some(reason) =
452                        self.expand_hard_closure(dep_id, visited, block, depth + 1, max_depth)?
453                    {
454                        return Ok(Some(reason));
455                    }
456                }
457            }
458        }
459        Ok(None)
460    }
461
462    fn density_refill(
463        &self,
464        mut selected: Vec<Value>,
465        skipped: &[(Vec<Value>, f64, usize)],
466        budget: usize,
467    ) -> Vec<Value> {
468        let used_tokens = block_cost(&selected);
469        if used_tokens >= budget {
470            return selected;
471        }
472
473        let selected_ids: HashSet<String> = selected
474            .iter()
475            .filter_map(|c| c["id"].as_str().map(str::to_string))
476            .collect();
477
478        let mut density_items: Vec<(f64, Vec<Value>, usize)> = skipped
479            .iter()
480            .filter_map(|(block, fscore, _)| {
481                let block: Vec<Value> = block
482                    .iter()
483                    .filter(|b| !selected_ids.contains(b["id"].as_str().unwrap_or("")))
484                    .cloned()
485                    .collect();
486                if block.is_empty() {
487                    return None;
488                }
489                let cost = block_cost(&block);
490                let density = fscore / cost.max(1) as f64;
491                Some((density, block, cost))
492            })
493            .collect();
494        density_items.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
495
496        let mut used_tokens = block_cost(&selected);
497        let mut added_ids: HashSet<String> = selected_ids;
498        for (_, block, cost) in density_items {
499            if used_tokens + cost <= budget {
500                for b in block {
501                    let bid = b["id"].as_str().unwrap_or("").to_string();
502                    if !added_ids.contains(&bid) {
503                        selected.push(b);
504                        added_ids.insert(bid);
505                    }
506                }
507                used_tokens += cost;
508            }
509        }
510        selected
511    }
512
513    fn recall_sparks(&self, q_content: &[f32], q_trigger: &[f32]) -> Result<Vec<Value>> {
514        let embed_version = self
515            .storage
516            .get_meta("embed_version")?
517            .and_then(|v| v.parse::<i64>().ok())
518            .unwrap_or(1);
519
520        let content_res = self
521            .storage
522            .search_vec_content(q_content, self.top_k_candidates)?;
523        let trigger_res = self
524            .storage
525            .search_vec_trigger(q_trigger, self.top_k_candidates)?;
526
527        // Batch-fetch all candidate chunk IDs (mirrors the pattern in ann_candidates).
528        let all_ids: Vec<&str> = {
529            let mut seen = HashSet::new();
530            content_res
531                .iter()
532                .chain(trigger_res.iter())
533                .map(|(id, _)| id.as_str())
534                .filter(|id| seen.insert(*id))
535                .collect()
536        };
537        let chunks = self.storage.get_chunks_by_ids(&all_ids)?;
538
539        let mut spark_scores: HashMap<String, (f32, Value)> = HashMap::new();
540        for (cid, sim) in content_res.iter().chain(trigger_res.iter()) {
541            if let Some(chunk) = chunks.get(cid) {
542                if chunk.get("origin").and_then(Value::as_str) != Some("spark") {
543                    continue;
544                }
545                if chunk.get("state").and_then(Value::as_str) == Some("archived") {
546                    continue;
547                }
548                let maturity = chunk.get("maturity").and_then(Value::as_str).unwrap_or("");
549                if maturity == "promoted" || maturity == "dropped" {
550                    continue;
551                }
552                let ev = chunk
553                    .get("embed_version")
554                    .and_then(Value::as_i64)
555                    .unwrap_or(1);
556                if ev < embed_version {
557                    continue;
558                }
559                let entry = spark_scores
560                    .entry(cid.clone())
561                    .or_insert_with(|| (*sim, chunk.clone()));
562                if *sim > entry.0 {
563                    *entry = (*sim, chunk.clone());
564                }
565            }
566        }
567        let mut sparks: Vec<(f32, Value)> = spark_scores.into_values().collect();
568        sparks.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
569        Ok(sparks
570            .into_iter()
571            .take(self.top_k_candidates)
572            .map(|(_, c)| c)
573            .collect())
574    }
575
576    #[allow(clippy::too_many_arguments)]
577    fn write_recall_trace(
578        &self,
579        trace_id: &str,
580        query: &str,
581        scored: &[(f64, Value)],
582        visible: &[Value],
583        sparks: &[Value],
584        depth_skipped: &[String],
585        skipped_reasons: &HashMap<String, String>,
586        refine_mode: &str,
587        source: &str,
588        now: &str,
589    ) -> Result<()> {
590        let lib_id = self.storage.lib_id()?;
591        self.storage.begin_immediate()?;
592        let result = (|| -> Result<()> {
593            for (rank, (_, chunk)) in scored.iter().enumerate() {
594                let cid = chunk["id"].as_str().unwrap_or("");
595                let sim = chunk.get("_fused_score").and_then(Value::as_f64);
596                // For dep-skipped seeds, record their skip reason as refine_mode.
597                let rm = skipped_reasons
598                    .get(cid)
599                    .map(|r| format!("skipped:{r}"))
600                    .or_else(|| {
601                        if refine_mode != "off" && !refine_mode.is_empty() {
602                            Some(refine_mode.to_string())
603                        } else {
604                            None
605                        }
606                    });
607                self.storage.insert_usage_trace(
608                    trace_id,
609                    Some(cid),
610                    "retrieved",
611                    1.0,
612                    sim,
613                    rm.as_deref(),
614                    None,
615                    Some((rank + 1) as i64),
616                    None,
617                    source,
618                    now,
619                )?;
620            }
621            for (rank, chunk) in visible.iter().enumerate() {
622                let cid = chunk["id"].as_str().unwrap_or("");
623                self.storage.insert_usage_trace(
624                    trace_id,
625                    Some(cid),
626                    "selected",
627                    1.0,
628                    None,
629                    None,
630                    None,
631                    Some((rank + 1) as i64),
632                    None,
633                    source,
634                    now,
635                )?;
636                // Write 'refined' event for chunks that came through the trim path.
637                if chunk
638                    .get("_trimmed")
639                    .and_then(Value::as_bool)
640                    .unwrap_or(false)
641                {
642                    self.storage.insert_usage_trace(
643                        trace_id,
644                        Some(cid),
645                        "refined",
646                        1.0,
647                        None,
648                        Some("trim"),
649                        None,
650                        Some((rank + 1) as i64),
651                        None,
652                        source,
653                        now,
654                    )?;
655                }
656            }
657            // Write 'retrieved' events for sparks (for recurring-spark count tracking).
658            for (rank, chunk) in sparks.iter().enumerate() {
659                let cid = chunk["id"].as_str().unwrap_or("");
660                self.storage.insert_usage_trace(
661                    trace_id,
662                    Some(cid),
663                    "retrieved",
664                    1.0,
665                    None,
666                    Some("spark"),
667                    None,
668                    Some((rank + 1) as i64),
669                    None,
670                    source,
671                    now,
672                )?;
673            }
674            let snapshot = json!({
675                "retrieved": scored.iter().map(|(_, c)| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
676                "selected": visible.iter().map(|c| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
677                "sparks": sparks.iter().map(|c| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
678                "depth_skipped": depth_skipped,
679                "skipped_reasons": skipped_reasons,
680            });
681            let log = EpisodicLogRow {
682                id: gen_uuid(),
683                trace_id: trace_id.to_string(),
684                lib_id,
685                ts: now.to_string(),
686                query: Some(query.to_string()),
687                recall_snapshot: Some(snapshot.to_string()),
688                event_source: source.to_string(),
689                task_state: "recalled".to_string(),
690                usage_state: "unknown".to_string(),
691                context_key: Some(content_hash(&normalize_query(query))),
692                distill_state: "open".to_string(),
693                ..Default::default()
694            };
695            self.storage.upsert_episodic_log(&log)?;
696            self.storage.commit()
697        })();
698        if result.is_err() {
699            let _ = self.storage.rollback();
700        }
701        result
702    }
703}