//! Search-result rendering and BM25/fuzzy scoring for the knowledge graph (kg/output.rs).
use std::cmp::Reverse;

use nucleo_matcher::pattern::{CaseMatching, Normalization, Pattern};
use nucleo_matcher::{Config, Matcher, Utf32Str};

use crate::graph::{Edge, GraphFile, Node};
use crate::index::Bm25Index;
8
/// BM25 term-frequency saturation parameter (typical range 1.2–2.0).
const BM25_K1: f64 = 1.5;
/// BM25 document-length normalization strength (0 = none, 1 = full).
const BM25_B: f64 = 0.75;
11
/// Strategy used to match nodes against a search query.
#[derive(Debug, Clone, Copy)]
pub enum FindMode {
    /// Fuzzy matching (nucleo) combined with substring/subsequence bonuses.
    Fuzzy,
    /// BM25 ranked retrieval over node text and attached notes.
    Bm25,
}
17
/// Render search results for each query as text sections.
///
/// Convenience wrapper around [`render_find_with_index`] that passes no
/// prebuilt BM25 index; see that function for parameter semantics.
pub fn render_find(
    graph: &GraphFile,
    queries: &[String],
    limit: usize,
    include_features: bool,
    mode: FindMode,
    full: bool,
) -> String {
    render_find_with_index(graph, queries, limit, include_features, mode, full, None)
}
28
29pub fn render_find_with_index(
30    graph: &GraphFile,
31    queries: &[String],
32    limit: usize,
33    include_features: bool,
34    mode: FindMode,
35    full: bool,
36    index: Option<&Bm25Index>,
37) -> String {
38    let mut sections = Vec::new();
39    for query in queries {
40        let matches = find_matches_with_index(graph, query, limit, include_features, mode, index);
41        let mut lines = vec![format!("? {query} ({})", matches.len())];
42        for node in matches {
43            lines.push(render_node_block(graph, node, full));
44        }
45        sections.push(lines.join("\n"));
46    }
47    format!("{}\n", sections.join("\n\n"))
48}
49
50pub fn find_nodes(
51    graph: &GraphFile,
52    query: &str,
53    limit: usize,
54    include_features: bool,
55    mode: FindMode,
56) -> Vec<Node> {
57    find_matches_with_index(graph, query, limit, include_features, mode, None)
58        .into_iter()
59        .cloned()
60        .collect()
61}
62
63pub fn find_nodes_with_index(
64    graph: &GraphFile,
65    query: &str,
66    limit: usize,
67    include_features: bool,
68    mode: FindMode,
69    index: Option<&Bm25Index>,
70) -> Vec<Node> {
71    find_matches_with_index(graph, query, limit, include_features, mode, index)
72        .into_iter()
73        .cloned()
74        .collect()
75}
76
/// Count total matches across all queries (capped at `limit` per query).
///
/// Convenience wrapper around [`count_find_results_with_index`] that passes
/// no prebuilt BM25 index.
pub fn count_find_results(
    graph: &GraphFile,
    queries: &[String],
    limit: usize,
    include_features: bool,
    mode: FindMode,
) -> usize {
    count_find_results_with_index(graph, queries, limit, include_features, mode, None)
}
86
87pub fn count_find_results_with_index(
88    graph: &GraphFile,
89    queries: &[String],
90    limit: usize,
91    include_features: bool,
92    mode: FindMode,
93    index: Option<&Bm25Index>,
94) -> usize {
95    let mut total = 0;
96    for query in queries {
97        let matches = find_matches_with_index(graph, query, limit, include_features, mode, index);
98        total += matches.len();
99    }
100    total
101}
102
103pub fn render_node(graph: &GraphFile, node: &Node, full: bool) -> String {
104    format!("{}\n", render_node_block(graph, node, full))
105}
106
/// Build the multi-line text block for one node: header, aliases, metadata,
/// key facts, note count, and a sample of connected edges.
///
/// `full` controls verbosity: all facts/edges plus provenance-style fields
/// when true; a truncated summary (2 facts, 3 edges per direction) when
/// false.
fn render_node_block(graph: &GraphFile, node: &Node, full: bool) -> String {
    let mut lines = Vec::new();
    // Header line: "# <id> | <name>".
    lines.push(format!("# {} | {}", node.id, node.name));

    if !node.properties.alias.is_empty() {
        lines.push(format!("aka: {}", node.properties.alias.join(", ")));
    }
    // Extended metadata appears only in full mode, and only when non-empty.
    if full {
        if !node.properties.domain_area.is_empty() {
            lines.push(format!("domain_area: {}", node.properties.domain_area));
        }
        if !node.properties.provenance.is_empty() {
            lines.push(format!("provenance: {}", node.properties.provenance));
        }
        if let Some(confidence) = node.properties.confidence {
            lines.push(format!("confidence: {confidence}"));
        }
        if !node.properties.created_at.is_empty() {
            lines.push(format!("created_at: {}", node.properties.created_at));
        }
    }

    // Compact mode caps facts at 2; full mode shows all of them.
    let facts_to_show = if full {
        node.properties.key_facts.len()
    } else {
        node.properties.key_facts.len().min(2)
    };
    for fact in node.properties.key_facts.iter().take(facts_to_show) {
        lines.push(format!("- {fact}"));
    }
    // Total is printed when facts were elided, and unconditionally in full
    // mode (so full mode with zero facts emits "(0 facts total)").
    if node.properties.key_facts.len() > facts_to_show || full {
        lines.push(format!("({} facts total)", node.properties.key_facts.len()));
    }

    // Note count is shown in full mode only, and only when non-zero.
    let note_count = graph
        .notes
        .iter()
        .filter(|note| note.node_id == node.id)
        .count();
    if full && note_count > 0 {
        lines.push(format!("notes: {note_count}"));
    }

    // Outgoing ("->") then incoming ("<-") edges; each direction is sorted
    // and capped by the helpers. Edges whose peer node is missing from the
    // graph are silently skipped.
    for edge in outgoing_edges(graph, &node.id, full) {
        if let Some(target) = graph.node_by_id(&edge.target_id) {
            lines.push(format_edge("->", edge, target));
        }
    }
    for edge in incoming_edges(graph, &node.id, full) {
        if let Some(source) = graph.node_by_id(&edge.source_id) {
            lines.push(format_edge("<-", edge, source));
        }
    }

    lines.join("\n")
}
163
164fn outgoing_edges<'a>(graph: &'a GraphFile, node_id: &str, full: bool) -> Vec<&'a Edge> {
165    let mut edges: Vec<&Edge> = graph
166        .edges
167        .iter()
168        .filter(|edge| edge.source_id == node_id)
169        .collect();
170    edges.sort_by_key(|edge| (&edge.relation, &edge.target_id));
171    if !full {
172        edges.truncate(3);
173    }
174    edges
175}
176
177fn incoming_edges<'a>(graph: &'a GraphFile, node_id: &str, full: bool) -> Vec<&'a Edge> {
178    let mut edges: Vec<&Edge> = graph
179        .edges
180        .iter()
181        .filter(|edge| edge.target_id == node_id)
182        .collect();
183    edges.sort_by_key(|edge| (&edge.relation, &edge.source_id));
184    if !full {
185        edges.truncate(3);
186    }
187    edges
188}
189
190fn format_edge(prefix: &str, edge: &Edge, related: &Node) -> String {
191    let (arrow, relation) = if edge.relation.starts_with("NOT_") {
192        (
193            format!("{prefix}!"),
194            edge.relation.trim_start_matches("NOT_"),
195        )
196    } else {
197        (prefix.to_owned(), edge.relation.as_str())
198    };
199
200    let mut line = format!("{arrow} {relation} | {} | {}", related.id, related.name);
201    if !edge.properties.detail.is_empty() {
202        line.push_str(" | ");
203        line.push_str(&truncate(&edge.properties.detail, 80));
204    }
205    line
206}
207
/// Truncate `value` to at most `max_len` characters, appending `...`
/// when anything was cut.
///
/// Counts `char`s (not bytes) so multi-byte text is never split.
/// Fix: for `max_len < 3` the old code still returned the 3-char `"..."`,
/// exceeding the requested maximum; such tiny budgets now get a plain
/// prefix of `max_len` chars. Behavior for `max_len >= 3` is unchanged.
fn truncate(value: &str, max_len: usize) -> String {
    if value.chars().count() <= max_len {
        return value.to_owned();
    }
    if max_len < 3 {
        // No room for an ellipsis; return a bare prefix within budget.
        return value.chars().take(max_len).collect();
    }
    // Reserve 3 chars for the ellipsis so the total is exactly max_len.
    let kept: String = value.chars().take(max_len - 3).collect();
    format!("{kept}...")
}
216
/// Score every eligible node against `query` and return the top `limit`
/// matches, best first.
///
/// Nodes of type "Feature" are skipped unless `include_features` is set.
/// Fuzzy mode combines the nucleo-based `score_node` with
/// `feedback_boost`; BM25 mode delegates to `score_bm25`, optionally
/// using a prebuilt `index`.
fn find_matches_with_index<'a>(
    graph: &'a GraphFile,
    query: &str,
    limit: usize,
    include_features: bool,
    mode: FindMode,
    index: Option<&Bm25Index>,
) -> Vec<&'a Node> {
    let mut scored: Vec<(i64, Reverse<&str>, &'a Node)> = match mode {
        FindMode::Fuzzy => {
            let pattern = Pattern::parse(query, CaseMatching::Ignore, Normalization::Smart);
            let mut matcher = Matcher::new(Config::DEFAULT);
            graph
                .nodes
                .iter()
                .filter(|node| include_features || node.r#type != "Feature")
                .filter_map(|node| {
                    // Only nodes with a positive fuzzy score survive;
                    // feedback then nudges the ranking up or down.
                    score_node(node, query, &pattern, &mut matcher).map(|score| {
                        let base = score as i64;
                        let boost = feedback_boost(node);
                        (base + boost, Reverse(node.id.as_str()), node)
                    })
                })
                .collect()
        }
        FindMode::Bm25 => score_bm25(graph, query, include_features, index),
    };

    // Primary order: score descending. Ties fall back to the
    // Reverse-wrapped id compared ascending, which sorts equal-score
    // nodes by id in DESCENDING lexicographic order.
    // NOTE(review): confirm the descending id tie-break is intended.
    scored.sort_by(|left, right| right.0.cmp(&left.0).then_with(|| left.1.cmp(&right.1)));
    scored
        .into_iter()
        .take(limit)
        .map(|(_, _, node)| node)
        .collect()
}
252
253fn feedback_boost(node: &Node) -> i64 {
254    let count = node.properties.feedback_count as f64;
255    if count <= 0.0 {
256        return 0;
257    }
258    let avg = node.properties.feedback_score / count;
259    let confidence = (count.ln_1p() / 3.0).min(1.0);
260    let scaled = avg * 200.0 * confidence;
261    scaled.clamp(-300.0, 300.0).round() as i64
262}
263
/// Rank nodes against `query` using BM25 (Okapi variant; `BM25_K1`,
/// `BM25_B` above).
///
/// When a prebuilt `Bm25Index` is supplied, scoring is delegated to it;
/// otherwise documents (node text plus attached notes) are tokenized and
/// scored on the fly. BM25 scores are scaled by 100 and combined with
/// `feedback_boost` into the integer sort key used by the caller.
fn score_bm25<'a>(
    graph: &'a GraphFile,
    query: &str,
    include_features: bool,
    index: Option<&Bm25Index>,
) -> Vec<(i64, Reverse<&'a str>, &'a Node)> {
    let terms = tokenize(query);
    if terms.is_empty() {
        return Vec::new();
    }

    // Fast path: the index already knows per-node scores; we only filter
    // out unknown ids and (optionally) Feature nodes, then add feedback.
    if let Some(idx) = index {
        let results = idx.search(&terms, graph);
        return results
            .into_iter()
            .filter_map(|(node_id, score)| {
                let node = graph.node_by_id(&node_id)?;
                if !include_features && node.r#type == "Feature" {
                    return None;
                }
                let boost = feedback_boost(node) as f64;
                let combined = (score as f64 * 100.0 + boost).round() as i64;
                Some((combined, Reverse(node.id.as_str()), node))
            })
            .collect();
    }

    // Slow path: build the corpus per call from node/note text.
    let mut docs: Vec<(&'a Node, Vec<String>)> = graph
        .nodes
        .iter()
        .filter(|node| include_features || node.r#type != "Feature")
        .map(|node| (node, tokenize(&node_document_text(graph, node))))
        .collect();

    if docs.is_empty() {
        return Vec::new();
    }

    // Document frequency of each query term (docs containing it at least once).
    let mut df: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
    for term in &terms {
        let mut count = 0usize;
        for (_, tokens) in &docs {
            if tokens.iter().any(|t| t == term) {
                count += 1;
            }
        }
        df.insert(term.as_str(), count);
    }

    let total_docs = docs.len() as f64;
    // Average document length, used by the length-normalization term below.
    let avgdl = docs
        .iter()
        .map(|(_, tokens)| tokens.len() as f64)
        .sum::<f64>()
        / total_docs;

    let mut scored = Vec::new();

    for (node, tokens) in docs.drain(..) {
        let dl = tokens.len() as f64;
        if dl == 0.0 {
            continue;
        }
        let mut score = 0.0f64;
        for term in &terms {
            // Term frequency within this document.
            let tf = tokens.iter().filter(|t| *t == term).count() as f64;
            if tf == 0.0 {
                continue;
            }
            let df_t = *df.get(term.as_str()).unwrap_or(&0) as f64;
            // Smoothed IDF; the +1.0 inside ln keeps it strictly positive.
            let idf = (1.0 + (total_docs - df_t + 0.5) / (df_t + 0.5)).ln();
            // Okapi BM25 saturation + document-length normalization.
            let denom = tf + BM25_K1 * (1.0 - BM25_B + BM25_B * (dl / avgdl));
            score += idf * (tf * (BM25_K1 + 1.0) / denom);
        }
        // Only positive-scoring documents enter the result set.
        if score > 0.0 {
            let boost = feedback_boost(node) as f64;
            let combined = score * 100.0 + boost;
            scored.push((combined.round() as i64, Reverse(node.id.as_str()), node));
        }
    }

    scored
}
347
348fn node_document_text(graph: &GraphFile, node: &Node) -> String {
349    let mut out = String::new();
350    push_field(&mut out, &node.id);
351    push_field(&mut out, &node.name);
352    push_field(&mut out, &node.properties.description);
353    for alias in &node.properties.alias {
354        push_field(&mut out, alias);
355    }
356    for fact in &node.properties.key_facts {
357        push_field(&mut out, fact);
358    }
359    for note in graph.notes.iter().filter(|note| note.node_id == node.id) {
360        push_field(&mut out, &note.body);
361        for tag in &note.tags {
362            push_field(&mut out, tag);
363        }
364    }
365    out
366}
367
/// Append `value` to `target`, inserting a single space separator when
/// `target` already has content. Empty values are ignored entirely.
fn push_field(target: &mut String, value: &str) {
    if !value.is_empty() {
        if !target.is_empty() {
            target.push(' ');
        }
        target.push_str(value);
    }
}
377
/// Split `text` into lowercase alphanumeric tokens; every other character
/// is a separator.
///
/// Fix: the old code admitted any Unicode alphanumeric (`is_alphanumeric`)
/// but lowercased only ASCII (`to_ascii_lowercase`), so e.g. "CAFÉ"
/// tokenized to "cafÉ" and could never match query text lowercased with
/// `str::to_lowercase`. Using `char::to_lowercase` (which may expand to
/// multiple chars) makes tokenization consistently case-insensitive.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    for ch in text.chars() {
        if ch.is_alphanumeric() {
            // Full Unicode lowercasing; may yield more than one char.
            current.extend(ch.to_lowercase());
        } else if !current.is_empty() {
            // Separator hit: flush the pending token and reset the buffer.
            tokens.push(std::mem::take(&mut current));
        }
    }
    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}
394
395fn score_node(node: &Node, query: &str, pattern: &Pattern, matcher: &mut Matcher) -> Option<u32> {
396    let mut total = 0;
397    let mut primary_hits = 0;
398
399    let id_score = score_primary_field(query, pattern, matcher, &node.id, 4);
400    if id_score > 0 {
401        primary_hits += 1;
402    }
403    total += id_score;
404
405    let name_score = score_primary_field(query, pattern, matcher, &node.name, 3);
406    if name_score > 0 {
407        primary_hits += 1;
408    }
409    total += name_score;
410
411    for alias in &node.properties.alias {
412        let alias_score = score_primary_field(query, pattern, matcher, alias, 3);
413        if alias_score > 0 {
414            primary_hits += 1;
415        }
416        total += alias_score;
417    }
418
419    if primary_hits > 0 {
420        total += score_secondary_field(query, pattern, matcher, &node.properties.description, 1);
421    }
422
423    (total > 0).then_some(total)
424}
425
426fn score_field(pattern: &Pattern, matcher: &mut Matcher, value: &str) -> Option<u32> {
427    if value.is_empty() {
428        return None;
429    }
430    let mut buf = Vec::new();
431    let haystack = Utf32Str::new(value, &mut buf);
432    pattern.score(haystack, matcher)
433}
434
435fn score_primary_field(
436    query: &str,
437    pattern: &Pattern,
438    matcher: &mut Matcher,
439    value: &str,
440    weight: u32,
441) -> u32 {
442    let bonus = textual_bonus(query, value);
443    if bonus == 0 {
444        return 0;
445    }
446    let fuzzy = score_field(pattern, matcher, value).unwrap_or(0);
447    (fuzzy + bonus) * weight
448}
449
450fn score_secondary_field(
451    query: &str,
452    pattern: &Pattern,
453    matcher: &mut Matcher,
454    value: &str,
455    weight: u32,
456) -> u32 {
457    let bonus = textual_bonus(query, value);
458    if bonus == 0 {
459        return 0;
460    }
461    let fuzzy = score_field(pattern, matcher, value).unwrap_or(0);
462    (fuzzy + bonus / 2) * weight
463}
464
465fn textual_bonus(query: &str, value: &str) -> u32 {
466    let query = query.trim().to_lowercase();
467    let value = value.to_lowercase();
468
469    if value == query {
470        return 400;
471    }
472    if value.contains(&query) {
473        return 200;
474    }
475
476    query
477        .split_whitespace()
478        .map(|token| {
479            if value.contains(token) {
480                80
481            } else if is_subsequence(token, &value) {
482                40
483            } else {
484                0
485            }
486        })
487        .sum()
488}
489
/// True when every char of `needle` appears in `haystack` in order
/// (not necessarily contiguously). An empty needle never matches.
fn is_subsequence(needle: &str, haystack: &str) -> bool {
    let mut pending = needle.chars().peekable();
    if pending.peek().is_none() {
        // Empty needles are deliberately rejected.
        return false;
    }

    for ch in haystack.chars() {
        if pending.peek() == Some(&ch) {
            pending.next();
            if pending.peek().is_none() {
                // Consumed the whole needle.
                return true;
            }
        }
    }

    false
}