Skip to main content

kg/
output.rs

1use std::cmp::Reverse;
2
3use nucleo_matcher::pattern::{CaseMatching, Normalization, Pattern};
4use nucleo_matcher::{Config, Matcher, Utf32Str};
5
6use crate::graph::{Edge, GraphFile, Node};
7use crate::index::Bm25Index;
8
/// Okapi BM25 `k1`: controls how quickly repeated occurrences of a term stop adding score.
const BM25_K1: f64 = 1.5;
/// Okapi BM25 `b`: strength of document-length normalisation (0 = none, 1 = full `|d| / avgdl`).
const BM25_B: f64 = 0.75;
11
/// Ranking strategy used by the `find*` / `render_find*` functions of this module.
#[derive(Clone, Copy, Debug)]
pub enum FindMode {
    /// Fuzzy matching on id, name and aliases; the description only contributes once one of
    /// those matched.
    Fuzzy,
    /// Okapi BM25 ranking over the node's text: id, name, description, aliases, key facts and
    /// attached notes (bodies and tags).
    Bm25,
}
17
18pub fn render_find(
19    graph: &GraphFile,
20    queries: &[String],
21    limit: usize,
22    include_features: bool,
23    mode: FindMode,
24    full: bool,
25) -> String {
26    render_find_with_index(graph, queries, limit, include_features, mode, full, None)
27}
28
29pub fn render_find_with_index(
30    graph: &GraphFile,
31    queries: &[String],
32    limit: usize,
33    include_features: bool,
34    mode: FindMode,
35    full: bool,
36    index: Option<&Bm25Index>,
37) -> String {
38    let mut sections = Vec::new();
39    for query in queries {
40        let matches = find_matches_with_index(graph, query, limit, include_features, mode, index);
41        let mut lines = vec![format!("? {query} ({})", matches.len())];
42        for node in matches {
43            lines.push(render_node_block(graph, node, full));
44        }
45        sections.push(lines.join("\n"));
46    }
47    format!("{}\n", sections.join("\n\n"))
48}
49
50pub fn find_nodes(
51    graph: &GraphFile,
52    query: &str,
53    limit: usize,
54    include_features: bool,
55    mode: FindMode,
56) -> Vec<Node> {
57    find_matches_with_index(graph, query, limit, include_features, mode, None)
58        .into_iter()
59        .cloned()
60        .collect()
61}
62
63pub fn find_nodes_with_index(
64    graph: &GraphFile,
65    query: &str,
66    limit: usize,
67    include_features: bool,
68    mode: FindMode,
69    index: Option<&Bm25Index>,
70) -> Vec<Node> {
71    find_matches_with_index(graph, query, limit, include_features, mode, index)
72        .into_iter()
73        .cloned()
74        .collect()
75}
76
77pub fn count_find_results(
78    graph: &GraphFile,
79    queries: &[String],
80    limit: usize,
81    include_features: bool,
82    mode: FindMode,
83) -> usize {
84    count_find_results_with_index(graph, queries, limit, include_features, mode, None)
85}
86
87pub fn count_find_results_with_index(
88    graph: &GraphFile,
89    queries: &[String],
90    limit: usize,
91    include_features: bool,
92    mode: FindMode,
93    index: Option<&Bm25Index>,
94) -> usize {
95    let mut total = 0;
96    for query in queries {
97        let matches = find_matches_with_index(graph, query, limit, include_features, mode, index);
98        total += matches.len();
99    }
100    total
101}
102
103pub fn render_node(graph: &GraphFile, node: &Node, full: bool) -> String {
104    format!("{}\n", render_node_block(graph, node, full))
105}
106
107fn render_node_block(graph: &GraphFile, node: &Node, full: bool) -> String {
108    let mut lines = Vec::new();
109    lines.push(format!("# {} | {}", node.id, node.name));
110
111    if !node.properties.alias.is_empty() {
112        lines.push(format!("aka: {}", node.properties.alias.join(", ")));
113    }
114    if full {
115        if !node.properties.domain_area.is_empty() {
116            lines.push(format!("domain_area: {}", node.properties.domain_area));
117        }
118        if !node.properties.provenance.is_empty() {
119            lines.push(format!("provenance: {}", node.properties.provenance));
120        }
121        if let Some(confidence) = node.properties.confidence {
122            lines.push(format!("confidence: {confidence}"));
123        }
124        lines.push(format!("importance: {}", node.properties.importance));
125        if !node.properties.created_at.is_empty() {
126            lines.push(format!("created_at: {}", node.properties.created_at));
127        }
128    }
129
130    let facts_to_show = if full {
131        node.properties.key_facts.len()
132    } else {
133        node.properties.key_facts.len().min(2)
134    };
135    for fact in node.properties.key_facts.iter().take(facts_to_show) {
136        lines.push(format!("- {fact}"));
137    }
138    if node.properties.key_facts.len() > facts_to_show || full {
139        lines.push(format!("({} facts total)", node.properties.key_facts.len()));
140    }
141
142    let note_count = graph
143        .notes
144        .iter()
145        .filter(|note| note.node_id == node.id)
146        .count();
147    if full && note_count > 0 {
148        lines.push(format!("notes: {note_count}"));
149    }
150
151    for edge in outgoing_edges(graph, &node.id, full) {
152        if let Some(target) = graph.node_by_id(&edge.target_id) {
153            lines.push(format_edge("->", edge, target));
154        }
155    }
156    for edge in incoming_edges(graph, &node.id, full) {
157        if let Some(source) = graph.node_by_id(&edge.source_id) {
158            lines.push(format_edge("<-", edge, source));
159        }
160    }
161
162    lines.join("\n")
163}
164
165fn outgoing_edges<'a>(graph: &'a GraphFile, node_id: &str, full: bool) -> Vec<&'a Edge> {
166    let mut edges: Vec<&Edge> = graph
167        .edges
168        .iter()
169        .filter(|edge| edge.source_id == node_id)
170        .collect();
171    edges.sort_by_key(|edge| (&edge.relation, &edge.target_id));
172    if !full {
173        edges.truncate(3);
174    }
175    edges
176}
177
178fn incoming_edges<'a>(graph: &'a GraphFile, node_id: &str, full: bool) -> Vec<&'a Edge> {
179    let mut edges: Vec<&Edge> = graph
180        .edges
181        .iter()
182        .filter(|edge| edge.target_id == node_id)
183        .collect();
184    edges.sort_by_key(|edge| (&edge.relation, &edge.source_id));
185    if !full {
186        edges.truncate(3);
187    }
188    edges
189}
190
191fn format_edge(prefix: &str, edge: &Edge, related: &Node) -> String {
192    let (arrow, relation) = if edge.relation.starts_with("NOT_") {
193        (
194            format!("{prefix}!"),
195            edge.relation.trim_start_matches("NOT_"),
196        )
197    } else {
198        (prefix.to_owned(), edge.relation.as_str())
199    };
200
201    let mut line = format!("{arrow} {relation} | {} | {}", related.id, related.name);
202    if !edge.properties.detail.is_empty() {
203        line.push_str(" | ");
204        line.push_str(&truncate(&edge.properties.detail, 80));
205    }
206    line
207}
208
/// Shortens `value` to at most `max_len` characters (Unicode scalar values).
///
/// When shortening is needed, the result ends with `"..."` and still fits in `max_len`. If
/// `max_len` is too small to hold the ellipsis (< 3), the value is hard-cut to `max_len`
/// characters instead, so the length bound always holds.
fn truncate(value: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    // `nth(max_len)` is `None` exactly when the value has at most `max_len` characters;
    // unlike `chars().count()` it stops scanning as soon as the answer is known.
    if value.chars().nth(max_len).is_none() {
        return value.to_owned();
    }
    if max_len < ELLIPSIS.len() {
        return value.chars().take(max_len).collect();
    }

    let mut shortened: String = value.chars().take(max_len - ELLIPSIS.len()).collect();
    shortened.push_str(ELLIPSIS);
    shortened
}
217
218fn find_matches_with_index<'a>(
219    graph: &'a GraphFile,
220    query: &str,
221    limit: usize,
222    include_features: bool,
223    mode: FindMode,
224    index: Option<&Bm25Index>,
225) -> Vec<&'a Node> {
226    let mut scored: Vec<(i64, Reverse<&str>, &'a Node)> = match mode {
227        FindMode::Fuzzy => {
228            let pattern = Pattern::parse(query, CaseMatching::Ignore, Normalization::Smart);
229            let mut matcher = Matcher::new(Config::DEFAULT);
230            graph
231                .nodes
232                .iter()
233                .filter(|node| include_features || node.r#type != "Feature")
234                .filter_map(|node| {
235                    score_node(node, query, &pattern, &mut matcher).map(|score| {
236                        let base = score as i64;
237                        let boost = feedback_boost(node);
238                        (base + boost, Reverse(node.id.as_str()), node)
239                    })
240                })
241                .collect()
242        }
243        FindMode::Bm25 => score_bm25(graph, query, include_features, index),
244    };
245
246    scored.sort_by(|left, right| right.0.cmp(&left.0).then_with(|| left.1.cmp(&right.1)));
247    scored
248        .into_iter()
249        .take(limit)
250        .map(|(_, _, node)| node)
251        .collect()
252}
253
254fn feedback_boost(node: &Node) -> i64 {
255    let count = node.properties.feedback_count as f64;
256    if count <= 0.0 {
257        return 0;
258    }
259    let avg = node.properties.feedback_score / count;
260    let confidence = (count.ln_1p() / 3.0).min(1.0);
261    let scaled = avg * 200.0 * confidence;
262    scaled.clamp(-300.0, 300.0).round() as i64
263}
264
265fn score_bm25<'a>(
266    graph: &'a GraphFile,
267    query: &str,
268    include_features: bool,
269    index: Option<&Bm25Index>,
270) -> Vec<(i64, Reverse<&'a str>, &'a Node)> {
271    let terms = tokenize(query);
272    if terms.is_empty() {
273        return Vec::new();
274    }
275
276    if let Some(idx) = index {
277        let results = idx.search(&terms, graph);
278        return results
279            .into_iter()
280            .filter_map(|(node_id, score)| {
281                let node = graph.node_by_id(&node_id)?;
282                if !include_features && node.r#type == "Feature" {
283                    return None;
284                }
285                let boost = feedback_boost(node) as f64;
286                let combined = (score as f64 * 100.0 + boost).round() as i64;
287                Some((combined, Reverse(node.id.as_str()), node))
288            })
289            .collect();
290    }
291
292    let mut docs: Vec<(&'a Node, Vec<String>)> = graph
293        .nodes
294        .iter()
295        .filter(|node| include_features || node.r#type != "Feature")
296        .map(|node| (node, tokenize(&node_document_text(graph, node))))
297        .collect();
298
299    if docs.is_empty() {
300        return Vec::new();
301    }
302
303    let mut df: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
304    for term in &terms {
305        let mut count = 0usize;
306        for (_, tokens) in &docs {
307            if tokens.iter().any(|t| t == term) {
308                count += 1;
309            }
310        }
311        df.insert(term.as_str(), count);
312    }
313
314    let total_docs = docs.len() as f64;
315    let avgdl = docs
316        .iter()
317        .map(|(_, tokens)| tokens.len() as f64)
318        .sum::<f64>()
319        / total_docs;
320
321    let mut scored = Vec::new();
322
323    for (node, tokens) in docs.drain(..) {
324        let dl = tokens.len() as f64;
325        if dl == 0.0 {
326            continue;
327        }
328        let mut score = 0.0f64;
329        for term in &terms {
330            let tf = tokens.iter().filter(|t| *t == term).count() as f64;
331            if tf == 0.0 {
332                continue;
333            }
334            let df_t = *df.get(term.as_str()).unwrap_or(&0) as f64;
335            let idf = (1.0 + (total_docs - df_t + 0.5) / (df_t + 0.5)).ln();
336            let denom = tf + BM25_K1 * (1.0 - BM25_B + BM25_B * (dl / avgdl));
337            score += idf * (tf * (BM25_K1 + 1.0) / denom);
338        }
339        if score > 0.0 {
340            let boost = feedback_boost(node) as f64;
341            let combined = score * 100.0 + boost;
342            scored.push((combined.round() as i64, Reverse(node.id.as_str()), node));
343        }
344    }
345
346    scored
347}
348
349fn node_document_text(graph: &GraphFile, node: &Node) -> String {
350    let mut out = String::new();
351    push_field(&mut out, &node.id);
352    push_field(&mut out, &node.name);
353    push_field(&mut out, &node.properties.description);
354    for alias in &node.properties.alias {
355        push_field(&mut out, alias);
356    }
357    for fact in &node.properties.key_facts {
358        push_field(&mut out, fact);
359    }
360    for note in graph.notes.iter().filter(|note| note.node_id == node.id) {
361        push_field(&mut out, &note.body);
362        for tag in &note.tags {
363            push_field(&mut out, tag);
364        }
365    }
366    out
367}
368
/// Appends `value` to `target`, inserting a single space separator when `target` already has
/// content. Empty values are ignored, so no double or leading spaces are produced.
fn push_field(target: &mut String, value: &str) {
    if !value.is_empty() {
        let separator = if target.is_empty() { "" } else { " " };
        target.push_str(separator);
        target.push_str(value);
    }
}
378
/// Splits `text` into lowercase word tokens for BM25 scoring.
///
/// A token is a maximal run of Unicode alphanumeric characters. Everything else (whitespace,
/// punctuation, `_`, symbols) separates tokens. Tokens are lowercased with full Unicode rules
/// (`str::to_lowercase`), so "Élan"/"élan" and "ÜBER"/"über" yield the same token. The previous
/// `to_ascii_lowercase` left non-ASCII capitals untouched, so those pairs never matched. Full
/// lowercasing also agrees with the `to_lowercase` comparison used by the fuzzy scorer.
///
/// NOTE(review): query terms produced here are also passed to `Bm25Index::search`; confirm the
/// index tokenizes its documents with the same rules.
fn tokenize(text: &str) -> Vec<String> {
    text.split(|ch: char| !ch.is_alphanumeric())
        .filter(|word| !word.is_empty())
        .map(str::to_lowercase)
        .collect()
}
395
396fn score_node(node: &Node, query: &str, pattern: &Pattern, matcher: &mut Matcher) -> Option<u32> {
397    let mut total = 0;
398    let mut primary_hits = 0;
399
400    let id_score = score_primary_field(query, pattern, matcher, &node.id, 4);
401    if id_score > 0 {
402        primary_hits += 1;
403    }
404    total += id_score;
405
406    let name_score = score_primary_field(query, pattern, matcher, &node.name, 3);
407    if name_score > 0 {
408        primary_hits += 1;
409    }
410    total += name_score;
411
412    for alias in &node.properties.alias {
413        let alias_score = score_primary_field(query, pattern, matcher, alias, 3);
414        if alias_score > 0 {
415            primary_hits += 1;
416        }
417        total += alias_score;
418    }
419
420    if primary_hits > 0 {
421        total += score_secondary_field(query, pattern, matcher, &node.properties.description, 1);
422    }
423
424    (total > 0).then_some(total)
425}
426
427fn score_field(pattern: &Pattern, matcher: &mut Matcher, value: &str) -> Option<u32> {
428    if value.is_empty() {
429        return None;
430    }
431    let mut buf = Vec::new();
432    let haystack = Utf32Str::new(value, &mut buf);
433    pattern.score(haystack, matcher)
434}
435
436fn score_primary_field(
437    query: &str,
438    pattern: &Pattern,
439    matcher: &mut Matcher,
440    value: &str,
441    weight: u32,
442) -> u32 {
443    let bonus = textual_bonus(query, value);
444    if bonus == 0 {
445        return 0;
446    }
447    let fuzzy = score_field(pattern, matcher, value).unwrap_or(0);
448    (fuzzy + bonus) * weight
449}
450
451fn score_secondary_field(
452    query: &str,
453    pattern: &Pattern,
454    matcher: &mut Matcher,
455    value: &str,
456    weight: u32,
457) -> u32 {
458    let bonus = textual_bonus(query, value);
459    if bonus == 0 {
460        return 0;
461    }
462    let fuzzy = score_field(pattern, matcher, value).unwrap_or(0);
463    (fuzzy + bonus / 2) * weight
464}
465
/// Case-insensitive textual overlap bonus between `query` and a field `value`.
///
/// * 400 when the trimmed query equals the value;
/// * 200 when the value contains the whole query;
/// * otherwise the sum over whitespace-separated query tokens: 80 for a token contained in
///   the value, 40 for a token whose characters appear in order (`is_subsequence`), 0 else.
///
/// A blank (empty or whitespace-only) query scores 0. Without that guard,
/// `value.contains("")` awarded every field the 200 bonus, so fuzzy search returned every node,
/// ranked mostly by alias count. The BM25 mode already returns nothing for such a query.
fn textual_bonus(query: &str, value: &str) -> u32 {
    let query = query.trim().to_lowercase();
    if query.is_empty() {
        return 0;
    }
    let value = value.to_lowercase();

    if value == query {
        return 400;
    }
    if value.contains(&query) {
        return 200;
    }

    query
        .split_whitespace()
        .map(|token| {
            if value.contains(token) {
                80
            } else if is_subsequence(token, &value) {
                40
            } else {
                0
            }
        })
        .sum()
}

/// Returns `true` when every character of `needle` occurs in `haystack` in the same order
/// (not necessarily adjacent). An empty needle never matches.
fn is_subsequence(needle: &str, haystack: &str) -> bool {
    if needle.is_empty() {
        return false;
    }
    // Greedy scan: each needle character consumes the haystack up to its first match.
    let mut remaining = haystack.chars();
    needle
        .chars()
        .all(|wanted| remaining.any(|ch| ch == wanted))
}