Skip to main content

innate_core/kb/
recall.rs

1use super::*;
2
3impl KnowledgeBase {
4    #[allow(clippy::too_many_arguments)]
5    pub fn recall(
6        &self,
7        query: &str,
8        budget: usize,
9        trace: bool,
10        include_sparks: bool,
11        top: Option<usize>,
12        source: &str,
13        expand_deps: &str, // "false" | "direct" | "closure"
14        allow_trim: bool,  // if true, invoke Refiner::trim when block doesn't fit
15        refine_mode: &str, // "off" | "trim" | "adapt" — recorded in trace
16    ) -> Result<RecallResult> {
17        validate_source(source)?;
18        let trace_id = gen_uuid();
19        let now = utc_now_iso();
20
21        let (q_content, q_trigger) = self
22            .embedding
23            .embed_both(query)
24            .map_err(|e| InnateError::EmbeddingUnavailable(e.to_string()))?;
25
26        // ANN candidates (non-spark)
27        let mut candidates = self.ann_candidates(&q_content, &q_trigger)?;
28        self.apply_soft_dep_bonus(&mut candidates)?;
29
30        // Score + anti-trigger penalty
31        let scored = self.score_candidates(candidates, query)?;
32
33        // First-fit pack with dep expansion
34        let (selected, skipped, skipped_reasons) =
35            self.pack(&scored, budget, expand_deps, allow_trim, query)?;
36
37        let depth_skipped: Vec<String> = skipped_reasons
38            .iter()
39            .filter(|(_, r)| r.as_str() == "dep_depth_limit")
40            .map(|(id, _)| id.clone())
41            .collect();
42
43        // Density refill
44        let mut selected = selected;
45        if self.density_refill {
46            selected = self.density_refill(selected, &skipped, budget);
47        }
48
49        let limited = limit_knowledge(selected, top);
50        let visible = if refine_mode == "adapt" {
51            self.refiner
52                .refine(limited.clone(), Some(budget))
53                .unwrap_or(limited)
54        } else {
55            limited
56        };
57
58        // Sparks
59        let sparks = if include_sparks {
60            self.recall_sparks(&q_content, &q_trigger)?
61        } else {
62            vec![]
63        };
64
65        if trace {
66            self.write_recall_trace(
67                &trace_id,
68                query,
69                &scored,
70                &visible,
71                &sparks,
72                &depth_skipped,
73                &skipped_reasons,
74                refine_mode,
75                source,
76                &now,
77            )?;
78        }
79
80        let empty = visible.is_empty() && sparks.is_empty();
81        Ok(RecallResult {
82            knowledge: visible,
83            sparks,
84            trace_id,
85            empty,
86            depth_skipped,
87            skipped_reasons,
88        })
89    }
90
91    fn ann_candidates(
92        &self,
93        q_content: &[f32],
94        q_trigger: &[f32],
95    ) -> Result<HashMap<String, CandidateInfo>> {
96        let embed_version = self
97            .storage
98            .get_meta("embed_version")?
99            .and_then(|v| v.parse::<i64>().ok())
100            .unwrap_or(1);
101
102        let content_res = self
103            .storage
104            .search_vec_content(q_content, self.top_k_candidates * 2)?;
105        let trigger_res = self
106            .storage
107            .search_vec_trigger(q_trigger, self.top_k_candidates * 2)?;
108
109        // Collect unique ids and batch-fetch all chunks in two queries instead of N individual ones.
110        let all_ids: Vec<&str> = {
111            let mut seen = HashSet::new();
112            content_res
113                .iter()
114                .chain(trigger_res.iter())
115                .map(|(id, _)| id.as_str())
116                .filter(|id| seen.insert(*id))
117                .collect()
118        };
119        let chunks = self.storage.get_chunks_by_ids(&all_ids)?;
120
121        let mut candidates: HashMap<String, CandidateInfo> = HashMap::new();
122
123        for (cid, sim) in &content_res {
124            if let Some(chunk) = chunks.get(cid) {
125                if chunk_is_valid_for_recall(chunk, embed_version) {
126                    let e = candidates
127                        .entry(cid.clone())
128                        .or_insert_with(|| CandidateInfo {
129                            chunk: chunk.clone(),
130                            sim_content: 0.0,
131                            sim_trigger: 0.0,
132                        });
133                    e.sim_content = e.sim_content.max(*sim);
134                }
135            }
136        }
137        for (cid, sim) in &trigger_res {
138            if let Some(chunk) = chunks.get(cid) {
139                if chunk_is_valid_for_recall(chunk, embed_version) {
140                    let e = candidates
141                        .entry(cid.clone())
142                        .or_insert_with(|| CandidateInfo {
143                            chunk: chunk.clone(),
144                            sim_content: 0.0,
145                            sim_trigger: 0.0,
146                        });
147                    e.sim_trigger = e.sim_trigger.max(*sim);
148                }
149            }
150        }
151        Ok(candidates)
152    }
153
154    fn apply_soft_dep_bonus(&self, candidates: &mut HashMap<String, CandidateInfo>) -> Result<()> {
155        // Collect non-spark candidate ids and batch-fetch their outgoing deps
156        // in a single query (was one get_deps per candidate).
157        let src_ids: Vec<String> = candidates
158            .iter()
159            .filter(|(_, info)| {
160                info.chunk.get("origin").and_then(Value::as_str) != Some("spark")
161            })
162            .map(|(cid, _)| cid.clone())
163            .collect();
164        if src_ids.is_empty() {
165            return Ok(());
166        }
167        let src_refs: Vec<&str> = src_ids.iter().map(String::as_str).collect();
168        let deps_map = self.storage.get_deps_batch(&src_refs)?;
169
170        // Gather distinct soft-dep targets and batch-fetch them in one query
171        // (was one get_chunk per soft edge).
172        let mut target_ids: Vec<String> = Vec::new();
173        let mut seen: HashSet<String> = HashSet::new();
174        for deps in deps_map.values() {
175            for (dst, kind, _) in deps {
176                if kind == "soft" && seen.insert(dst.clone()) {
177                    target_ids.push(dst.clone());
178                }
179            }
180        }
181        if target_ids.is_empty() {
182            return Ok(());
183        }
184        let target_refs: Vec<&str> = target_ids.iter().map(String::as_str).collect();
185        let targets = self.storage.get_chunks_by_ids(&target_refs)?;
186
187        for src in &src_ids {
188            let Some(deps) = deps_map.get(src) else {
189                continue;
190            };
191            for (dst, kind, _) in deps {
192                if kind != "soft" {
193                    continue;
194                }
195                let Some(target) = targets.get(dst) else {
196                    continue;
197                };
198                if target.get("state").and_then(Value::as_str) == Some("archived") {
199                    continue;
200                }
201                if target.get("origin").and_then(Value::as_str) == Some("spark") {
202                    continue;
203                }
204                let e = candidates
205                    .entry(dst.clone())
206                    .or_insert_with(|| CandidateInfo {
207                        chunk: target.clone(),
208                        sim_content: 0.0,
209                        sim_trigger: 0.0,
210                    });
211                e.sim_content = (e.sim_content + 0.05).min(1.0);
212            }
213        }
214        Ok(())
215    }
216
217    fn score_candidates(
218        &self,
219        candidates: HashMap<String, CandidateInfo>,
220        query: &str,
221    ) -> Result<Vec<(f64, Value)>> {
222        let context_key = content_hash(&normalize_query(query));
223        // Batch-fetch context scores for all candidates in one query
224        // (was one context_score lookup per candidate).
225        let cand_ids: Vec<String> = candidates
226            .values()
227            .filter_map(|info| info.chunk.get("id").and_then(Value::as_str).map(str::to_string))
228            .collect();
229        let cand_refs: Vec<&str> = cand_ids.iter().map(String::as_str).collect();
230        let ctx_scores = self.storage.context_scores_batch(&cand_refs, &context_key)?;
231
232        let mut scored: Vec<(f64, Value)> = Vec::with_capacity(candidates.len());
233        for info in candidates.into_values() {
234            let conf = info
235                .chunk
236                .get("confidence")
237                .and_then(Value::as_f64)
238                .unwrap_or(0.5);
239            let chunk_id = info.chunk.get("id").and_then(Value::as_str).unwrap_or("");
240            let context_score = ctx_scores.get(chunk_id).copied().unwrap_or(0.0);
241            let mut fused = self.w_content * info.sim_content as f64
242                + self.w_trigger * info.sim_trigger as f64
243                + self.w_confidence * conf
244                + self.w_context * context_score;
245            if info.chunk.get("state").and_then(Value::as_str) == Some("pending") {
246                fused *= PENDING_RECALL_PENALTY;
247            }
248            let anti = info
249                .chunk
250                .get("anti_trigger_desc")
251                .and_then(Value::as_str)
252                .unwrap_or("");
253            if !anti.is_empty() && anti_trigger_hit(query, anti) {
254                fused *= self.anti_trigger_penalty;
255            }
256            let mut chunk = info.chunk;
257            chunk["_context_score"] = json!(context_score);
258            chunk["_fused_score"] = json!(fused);
259            scored.push((fused, chunk));
260        }
261        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
262        scored.truncate(self.top_k_candidates);
263        Ok(scored)
264    }
265
266    fn pack(
267        &self,
268        scored: &[(f64, Value)],
269        budget: usize,
270        expand_deps: &str,
271        allow_trim: bool,
272        query: &str,
273    ) -> Result<PackResult> {
274        let mut selected: Vec<Value> = vec![];
275        let mut skipped: Vec<(Vec<Value>, f64, usize)> = vec![];
276        let mut skipped_reasons: HashMap<String, String> = HashMap::new();
277        let mut used_ids: HashSet<String> = HashSet::new();
278        let mut used_tokens: usize = 0;
279
280        for (fused, chunk) in scored {
281            let cid = chunk["id"].as_str().unwrap_or("").to_string();
282            if used_ids.contains(&cid) {
283                continue;
284            }
285
286            // Build block with dep expansion; fail-closed on dep issues.
287            let (block, dep_skip_reason) = self.build_dep_block(chunk, expand_deps)?;
288            if let Some(reason) = dep_skip_reason {
289                skipped_reasons.insert(cid, reason);
290                continue;
291            }
292
293            let new_block: Vec<Value> = block
294                .iter()
295                .filter(|b| !used_ids.contains(b["id"].as_str().unwrap_or("")))
296                .cloned()
297                .collect();
298            let cost = block_cost(&new_block);
299
300            if used_tokens + cost <= budget {
301                for b in &block {
302                    let bid = b["id"].as_str().unwrap_or("").to_string();
303                    if !used_ids.contains(&bid) {
304                        let mut b = b.clone();
305                        b["_fused_score"] = json!(fused);
306                        selected.push(b);
307                        used_ids.insert(bid);
308                    }
309                }
310                used_tokens += cost;
311            } else if allow_trim {
312                // Attempt refiner trim — NullRefiner returns None (no-op).
313                if let Some(trimmed) =
314                    self.refiner
315                        .trim(&block, query, budget.saturating_sub(used_tokens))
316                {
317                    let trim_cost = block_cost(&trimmed);
318                    if used_tokens + trim_cost <= budget {
319                        for b in &trimmed {
320                            let bid = b["id"].as_str().unwrap_or("").to_string();
321                            if !used_ids.contains(&bid) {
322                                let mut b = b.clone();
323                                b["_fused_score"] = json!(fused);
324                                b["_trimmed"] = json!(true);
325                                selected.push(b);
326                                used_ids.insert(bid);
327                            }
328                        }
329                        used_tokens += trim_cost;
330                        continue;
331                    }
332                }
333                skipped.push((block, *fused, cost));
334            } else {
335                skipped.push((block, *fused, cost));
336            }
337        }
338        Ok((selected, skipped, skipped_reasons))
339    }
340
341    /// Expand a seed chunk into a block according to `expand_deps`.
342    /// Returns `(block, Some(skip_reason))` if the block should be discarded (fail-closed).
343    fn build_dep_block(
344        &self,
345        seed: &Value,
346        expand_deps: &str,
347    ) -> Result<(Vec<Value>, Option<String>)> {
348        if expand_deps == "false" || expand_deps.is_empty() {
349            return Ok((vec![seed.clone()], None));
350        }
351        let seed_id = seed["id"].as_str().unwrap_or("");
352        match expand_deps {
353            "direct" => {
354                let deps = self.storage.get_deps(seed_id)?;
355                let mut block = vec![seed.clone()];
356                for (dep_id, kind, _) in &deps {
357                    if kind != "hard" {
358                        continue;
359                    }
360                    match self.validate_hard_dep(dep_id)? {
361                        Some(chunk) => block.push(chunk),
362                        None => return Ok((vec![], Some("hard_dep_unavailable".to_string()))),
363                    }
364                }
365                Ok((block, None))
366            }
367            "closure" => {
368                let mut block = vec![seed.clone()];
369                let mut visited: HashSet<String> = [seed_id.to_string()].into();
370                match self.expand_hard_closure(seed_id, &mut visited, &mut block, 0, 3)? {
371                    Some(reason) => Ok((vec![], Some(reason))),
372                    None => Ok((block, None)),
373                }
374            }
375            _ => Ok((vec![seed.clone()], None)),
376        }
377    }
378
379    /// Returns the chunk if the hard dep is usable, None if it should cause fail-closed.
380    fn validate_hard_dep(&self, dep_id: &str) -> Result<Option<Value>> {
381        match self.storage.get_chunk(dep_id)? {
382            None => Ok(None),
383            Some(chunk) => {
384                let state = chunk.get("state").and_then(Value::as_str).unwrap_or("");
385                let origin = chunk.get("origin").and_then(Value::as_str).unwrap_or("");
386                let embed_v = chunk
387                    .get("embed_version")
388                    .and_then(Value::as_i64)
389                    .unwrap_or(0);
390                if state == "archived" || origin == "spark" || embed_v == 0 {
391                    Ok(None)
392                } else {
393                    Ok(Some(chunk))
394                }
395            }
396        }
397    }
398
399    /// BFS hard-dep expansion up to `max_depth`. Returns Some(reason) on fail-closed.
400    fn expand_hard_closure(
401        &self,
402        id: &str,
403        visited: &mut HashSet<String>,
404        block: &mut Vec<Value>,
405        depth: usize,
406        max_depth: usize,
407    ) -> Result<Option<String>> {
408        if depth >= max_depth {
409            return Ok(Some("dep_depth_limit".to_string()));
410        }
411        let deps = self.storage.get_deps(id)?;
412        for (dep_id, kind, _) in &deps {
413            if kind != "hard" {
414                continue;
415            }
416            if visited.contains(dep_id) {
417                continue;
418            } // cycle guard
419            visited.insert(dep_id.clone());
420            match self.validate_hard_dep(dep_id)? {
421                None => return Ok(Some("hard_dep_unavailable".to_string())),
422                Some(chunk) => {
423                    block.push(chunk);
424                    if let Some(reason) =
425                        self.expand_hard_closure(dep_id, visited, block, depth + 1, max_depth)?
426                    {
427                        return Ok(Some(reason));
428                    }
429                }
430            }
431        }
432        Ok(None)
433    }
434
435    fn density_refill(
436        &self,
437        mut selected: Vec<Value>,
438        skipped: &[(Vec<Value>, f64, usize)],
439        budget: usize,
440    ) -> Vec<Value> {
441        let used_tokens = block_cost(&selected);
442        if used_tokens >= budget {
443            return selected;
444        }
445
446        let selected_ids: HashSet<String> = selected
447            .iter()
448            .filter_map(|c| c["id"].as_str().map(str::to_string))
449            .collect();
450
451        let mut density_items: Vec<(f64, Vec<Value>, usize)> = skipped
452            .iter()
453            .filter_map(|(block, fscore, _)| {
454                let block: Vec<Value> = block
455                    .iter()
456                    .filter(|b| !selected_ids.contains(b["id"].as_str().unwrap_or("")))
457                    .cloned()
458                    .collect();
459                if block.is_empty() {
460                    return None;
461                }
462                let cost = block_cost(&block);
463                let density = fscore / cost.max(1) as f64;
464                Some((density, block, cost))
465            })
466            .collect();
467        density_items.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
468
469        let mut used_tokens = block_cost(&selected);
470        let mut added_ids: HashSet<String> = selected_ids;
471        for (_, block, cost) in density_items {
472            if used_tokens + cost <= budget {
473                for b in block {
474                    let bid = b["id"].as_str().unwrap_or("").to_string();
475                    if !added_ids.contains(&bid) {
476                        selected.push(b);
477                        added_ids.insert(bid);
478                    }
479                }
480                used_tokens += cost;
481            }
482        }
483        selected
484    }
485
486    fn recall_sparks(&self, q_content: &[f32], q_trigger: &[f32]) -> Result<Vec<Value>> {
487        let embed_version = self
488            .storage
489            .get_meta("embed_version")?
490            .and_then(|v| v.parse::<i64>().ok())
491            .unwrap_or(1);
492
493        let content_res = self
494            .storage
495            .search_vec_content(q_content, self.top_k_candidates)?;
496        let trigger_res = self
497            .storage
498            .search_vec_trigger(q_trigger, self.top_k_candidates)?;
499
500        // Batch-fetch all candidate chunk IDs (mirrors the pattern in ann_candidates).
501        let all_ids: Vec<&str> = {
502            let mut seen = HashSet::new();
503            content_res
504                .iter()
505                .chain(trigger_res.iter())
506                .map(|(id, _)| id.as_str())
507                .filter(|id| seen.insert(*id))
508                .collect()
509        };
510        let chunks = self.storage.get_chunks_by_ids(&all_ids)?;
511
512        let mut spark_scores: HashMap<String, (f32, Value)> = HashMap::new();
513        for (cid, sim) in content_res.iter().chain(trigger_res.iter()) {
514            if let Some(chunk) = chunks.get(cid) {
515                if chunk.get("origin").and_then(Value::as_str) != Some("spark") {
516                    continue;
517                }
518                if chunk.get("state").and_then(Value::as_str) == Some("archived") {
519                    continue;
520                }
521                let maturity = chunk.get("maturity").and_then(Value::as_str).unwrap_or("");
522                if maturity == "promoted" || maturity == "dropped" {
523                    continue;
524                }
525                let ev = chunk
526                    .get("embed_version")
527                    .and_then(Value::as_i64)
528                    .unwrap_or(1);
529                if ev < embed_version {
530                    continue;
531                }
532                let entry = spark_scores
533                    .entry(cid.clone())
534                    .or_insert_with(|| (*sim, chunk.clone()));
535                if *sim > entry.0 {
536                    *entry = (*sim, chunk.clone());
537                }
538            }
539        }
540        let mut sparks: Vec<(f32, Value)> = spark_scores.into_values().collect();
541        sparks.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
542        Ok(sparks
543            .into_iter()
544            .take(self.top_k_candidates)
545            .map(|(_, c)| c)
546            .collect())
547    }
548
549    #[allow(clippy::too_many_arguments)]
550    fn write_recall_trace(
551        &self,
552        trace_id: &str,
553        query: &str,
554        scored: &[(f64, Value)],
555        visible: &[Value],
556        sparks: &[Value],
557        depth_skipped: &[String],
558        skipped_reasons: &HashMap<String, String>,
559        refine_mode: &str,
560        source: &str,
561        now: &str,
562    ) -> Result<()> {
563        let lib_id = self.storage.lib_id()?;
564        self.storage.begin_immediate()?;
565        let result = (|| -> Result<()> {
566            for (rank, (_, chunk)) in scored.iter().enumerate() {
567                let cid = chunk["id"].as_str().unwrap_or("");
568                let sim = chunk.get("_fused_score").and_then(Value::as_f64);
569                // For dep-skipped seeds, record their skip reason as refine_mode.
570                let rm = skipped_reasons
571                    .get(cid)
572                    .map(|r| format!("skipped:{r}"))
573                    .or_else(|| {
574                        if refine_mode != "off" && !refine_mode.is_empty() {
575                            Some(refine_mode.to_string())
576                        } else {
577                            None
578                        }
579                    });
580                self.storage.insert_usage_trace(
581                    trace_id,
582                    Some(cid),
583                    "retrieved",
584                    1.0,
585                    sim,
586                    rm.as_deref(),
587                    None,
588                    Some((rank + 1) as i64),
589                    None,
590                    source,
591                    now,
592                )?;
593            }
594            for (rank, chunk) in visible.iter().enumerate() {
595                let cid = chunk["id"].as_str().unwrap_or("");
596                self.storage.insert_usage_trace(
597                    trace_id,
598                    Some(cid),
599                    "selected",
600                    1.0,
601                    None,
602                    None,
603                    None,
604                    Some((rank + 1) as i64),
605                    None,
606                    source,
607                    now,
608                )?;
609                // Write 'refined' event for chunks that came through the trim path.
610                if chunk
611                    .get("_trimmed")
612                    .and_then(Value::as_bool)
613                    .unwrap_or(false)
614                {
615                    self.storage.insert_usage_trace(
616                        trace_id,
617                        Some(cid),
618                        "refined",
619                        1.0,
620                        None,
621                        Some("trim"),
622                        None,
623                        Some((rank + 1) as i64),
624                        None,
625                        source,
626                        now,
627                    )?;
628                }
629            }
630            // Write 'retrieved' events for sparks (for recurring-spark count tracking).
631            for (rank, chunk) in sparks.iter().enumerate() {
632                let cid = chunk["id"].as_str().unwrap_or("");
633                self.storage.insert_usage_trace(
634                    trace_id,
635                    Some(cid),
636                    "retrieved",
637                    1.0,
638                    None,
639                    Some("spark"),
640                    None,
641                    Some((rank + 1) as i64),
642                    None,
643                    source,
644                    now,
645                )?;
646            }
647            let snapshot = json!({
648                "retrieved": scored.iter().map(|(_, c)| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
649                "selected": visible.iter().map(|c| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
650                "sparks": sparks.iter().map(|c| c["id"].as_str().unwrap_or("")).collect::<Vec<_>>(),
651                "depth_skipped": depth_skipped,
652                "skipped_reasons": skipped_reasons,
653            });
654            let log = EpisodicLogRow {
655                id: gen_uuid(),
656                trace_id: trace_id.to_string(),
657                lib_id,
658                ts: now.to_string(),
659                query: Some(query.to_string()),
660                recall_snapshot: Some(snapshot.to_string()),
661                event_source: source.to_string(),
662                task_state: "recalled".to_string(),
663                usage_state: "unknown".to_string(),
664                context_key: Some(content_hash(&normalize_query(query))),
665                distill_state: "open".to_string(),
666                ..Default::default()
667            };
668            self.storage.upsert_episodic_log(&log)?;
669            self.storage.commit()
670        })();
671        if result.is_err() {
672            let _ = self.storage.rollback();
673        }
674        result
675    }
676}