Skip to main content

adk_memory/
text.rs

1//! Shared text extraction utilities for memory backends.
2
3use adk_core::Part;
4use std::collections::HashSet;
5
6/// Extract all text parts from a [`Content`](adk_core::Content) into a single string.
7///
8/// Parts are joined with a single space. Non-text parts (images, function calls,
9/// etc.) are silently skipped.
10pub fn extract_text(content: &adk_core::Content) -> String {
11    content
12        .parts
13        .iter()
14        .filter_map(|part| match part {
15            Part::Text { text } => Some(text.as_str()),
16            _ => None,
17        })
18        .collect::<Vec<_>>()
19        .join(" ")
20}
21
/// Returns `true` if `c` lies in one of the CJK-related ranges listed below.
///
/// Despite the name, this is broader than the CJK Unified Ideographs block:
/// it also accepts compatibility ideographs, radicals, CJK symbols and
/// punctuation, Hiragana, Katakana and Hangul syllables.
fn is_cjk_char(c: char) -> bool {
    // Inclusive `(first, last)` code point bounds.
    const CJK_RANGES: [(char, char); 8] = [
        ('\u{4e00}', '\u{9fff}'), // CJK Unified Ideographs
        ('\u{3400}', '\u{4dbf}'), // CJK Unified Ideographs Extension A
        ('\u{f900}', '\u{faff}'), // CJK Compatibility Ideographs
        ('\u{2e80}', '\u{2eff}'), // CJK Radicals Supplement
        ('\u{3000}', '\u{303f}'), // CJK Symbols and Punctuation
        ('\u{3040}', '\u{309f}'), // Hiragana
        ('\u{30a0}', '\u{30ff}'), // Katakana
        ('\u{ac00}', '\u{d7af}'), // Hangul Syllables
    ];

    CJK_RANGES
        .iter()
        .any(|&(first, last)| (first..=last).contains(&c))
}
35
36/// Tokenize text into a set of lowercase words for keyword matching.
37///
38/// For whitespace-separated languages (English, etc.), splits on whitespace.
39/// For CJK text (Chinese, Japanese, Korean) which has no word-separating
40/// whitespace, generates character-level unigrams and bigrams to enable
41/// substring matching.
42pub fn extract_words(text: &str) -> HashSet<String> {
43    let mut words = HashSet::new();
44
45    for token in text.split_whitespace() {
46        if token.is_empty() {
47            continue;
48        }
49        let lower = token.to_lowercase();
50
51        // Check if this token contains CJK characters
52        let has_cjk = lower.chars().any(is_cjk_char);
53
54        if has_cjk {
55            // For CJK tokens, generate character unigrams and bigrams
56            // This enables partial matching: "编程" matches within "用户喜欢用Rust编程"
57            let chars: Vec<char> = lower.chars().collect();
58            for c in &chars {
59                if is_cjk_char(*c) {
60                    words.insert(c.to_string());
61                }
62            }
63            for window in chars.windows(2) {
64                if window.iter().any(|c| is_cjk_char(*c)) {
65                    let bigram: String = window.iter().collect();
66                    words.insert(bigram);
67                }
68            }
69            // Also insert the full token for exact matches
70            words.insert(lower);
71        } else {
72            words.insert(lower);
73        }
74    }
75
76    // Handle text with no whitespace at all (pure CJK string)
77    if !text.contains(char::is_whitespace) && text.chars().any(is_cjk_char) {
78        let lower = text.to_lowercase();
79        let chars: Vec<char> = lower.chars().collect();
80        for c in &chars {
81            if is_cjk_char(*c) {
82                words.insert(c.to_string());
83            }
84        }
85        for window in chars.windows(2) {
86            if window.iter().any(|c| is_cjk_char(*c)) {
87                let bigram: String = window.iter().collect();
88                words.insert(bigram);
89            }
90        }
91        // Also add the full string
92        words.insert(lower);
93    }
94
95    words
96}
97
98/// Extract and tokenize all text from a [`Content`](adk_core::Content) into word set.
99pub fn extract_words_from_content(content: &adk_core::Content) -> HashSet<String> {
100    let mut words = HashSet::new();
101    for part in &content.parts {
102        if let Part::Text { text } = part {
103            words.extend(extract_words(text));
104        }
105    }
106    words
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Whitespace-separated Latin text yields one lowercased entry per word.
    #[test]
    fn test_extract_words_english() {
        let words = extract_words("Hello World foo bar");
        for expected in ["hello", "world", "foo", "bar"] {
            assert!(words.contains(expected), "missing {expected:?} in {words:?}");
        }
        assert_eq!(words.len(), 4, "unexpected extra tokens: {words:?}");
    }

    /// A two-character CJK query must be found as a bigram of a longer
    /// stored string.
    #[test]
    fn test_extract_words_cjk_bigram_matching() {
        let stored = extract_words("用户喜欢用Rust编程");
        let query = extract_words("编程");

        // Pin the bigram itself: a non-empty intersection alone would also
        // be satisfied by shared unigrams even if bigram generation broke.
        assert!(query.contains("编程"), "Query: {query:?}");
        assert!(
            stored.contains("编程"),
            "stored text should expose the bigram '编程'. Stored: {stored:?}"
        );

        let matches: HashSet<_> = stored.intersection(&query).collect();
        assert!(
            !matches.is_empty(),
            "CJK search should find matches. Stored: {stored:?}, Query: {query:?}"
        );
    }

    /// Individual CJK characters and the bigram they form are all indexed.
    #[test]
    fn test_extract_words_cjk_single_char() {
        let stored = extract_words("今天天气很好");
        let query = extract_words("天气");

        for unigram in ["天", "气"] {
            assert!(
                stored.contains(unigram),
                "missing unigram {unigram:?}. Stored: {stored:?}"
            );
        }
        assert!(
            stored.contains("天气"),
            "CJK bigram '天气' should match. Stored: {stored:?}, Query: {query:?}"
        );

        let matches: HashSet<_> = stored.intersection(&query).collect();
        assert!(
            !matches.is_empty(),
            "CJK bigram '天气' should match. Stored: {stored:?}, Query: {query:?}"
        );
    }

    /// Latin and CJK tokens separated by whitespace are each tokenized by
    /// their own rule.
    #[test]
    fn test_extract_words_mixed_cjk_english() {
        let words = extract_words("Hello 你好 World");
        for expected in ["hello", "world", "你", "好", "你好"] {
            assert!(words.contains(expected), "missing {expected:?} in {words:?}");
        }
    }

    /// Kanji and Katakana are both treated as CJK, so bigrams are produced
    /// across a mixed-script Japanese string.
    #[test]
    fn test_extract_words_japanese() {
        let stored = extract_words("東京タワー");
        let query = extract_words("東京");

        assert!(
            stored.contains("東京"),
            "Japanese bigram should match. Stored: {stored:?}"
        );
        assert!(
            stored.contains("タワ"),
            "Katakana bigram should be indexed. Stored: {stored:?}"
        );

        let matches: HashSet<_> = stored.intersection(&query).collect();
        assert!(!matches.is_empty(), "Japanese bigram should match");
    }
}