codebook/
lib.rs

1pub mod dictionaries;
2mod logging;
3pub mod parser;
4pub mod queries;
5mod splitter;
6
7use std::sync::Arc;
8
9use codebook_config::CodebookConfig;
10use dictionaries::{dictionary, manager::DictionaryManager};
11use dictionary::Dictionary;
12use log::debug;
13use parser::WordLocation;
14
/// Spell checker that pairs a shared configuration with a dictionary manager.
pub struct Codebook {
    /// Shared configuration; consulted for ignored paths, flagged and allowed
    /// words, and the configured dictionary ids.
    config: Arc<CodebookConfig>,
    /// Looks up dictionaries by id (see `DictionaryManager::get_dictionary`).
    manager: DictionaryManager,
}
19
// Dictionary ids that `Codebook::get_dictionaries` appends to every dictionary
// set it builds, in addition to the configured and language-specific ids.
// The custom 'codebook' dictionary could be removed later in favor of a more
// general solution.
static DEFAULT_DICTIONARIES: &[&str; 3] = &["codebook", "software_terms", "computing_acronyms"];
22
23impl Codebook {
24    pub fn new(config: Arc<CodebookConfig>) -> Result<Self, Box<dyn std::error::Error>> {
25        let manager = DictionaryManager::new(&config.cache_dir);
26        Ok(Self { config, manager })
27    }
28
29    /// Get WordLocations for a block of text.
30    /// Supply LanguageType, file path or both to use the correct code parser.
31    pub fn spell_check(
32        &self,
33        text: &str,
34        language: Option<queries::LanguageType>,
35        file_path: Option<&str>,
36    ) -> Vec<parser::WordLocation> {
37        if file_path.is_some() && self.config.should_ignore_path(file_path.unwrap()) {
38            return Vec::new();
39        }
40        // get needed dictionary names
41        // get needed dictionaries
42        // call spell check on each dictionary
43        let language = self.resolve_language(language, file_path);
44        let dictionaries = self.get_dictionaries(Some(language));
45        parser::find_locations(text, language, |word| {
46            if self.config.should_flag_word(word) {
47                return false;
48            }
49            if word.len() < 3 {
50                return true;
51            }
52            if self.config.is_allowed_word(word) {
53                return true;
54            }
55            for dictionary in &dictionaries {
56                if dictionary.check(word) {
57                    return true;
58                }
59            }
60            false
61        })
62    }
63
64    fn resolve_language(
65        &self,
66        language_type: Option<queries::LanguageType>,
67        path: Option<&str>,
68    ) -> queries::LanguageType {
69        // Check if we have a language_id first, fallback to path, fall back to text
70        match language_type {
71            Some(lang) => lang,
72            None => match path {
73                Some(path) => queries::get_language_name_from_filename(path),
74                None => queries::LanguageType::Text,
75            },
76        }
77    }
78
79    fn get_dictionaries(
80        &self,
81        language: Option<queries::LanguageType>,
82    ) -> Vec<Arc<dyn Dictionary>> {
83        let mut dictionary_ids = self.config.get_dictionary_ids();
84        if let Some(lang) = language {
85            let language_dictionary_ids = lang.dictionary_ids();
86            dictionary_ids.extend(language_dictionary_ids);
87        };
88        dictionary_ids.extend(DEFAULT_DICTIONARIES.iter().map(|f| f.to_string()));
89        let mut dictionaries = Vec::with_capacity(dictionary_ids.len());
90        debug!("Checking text with dictionaries: {:?}", dictionary_ids);
91        for dictionary_id in dictionary_ids {
92            let dictionary = self.manager.get_dictionary(&dictionary_id);
93            if let Some(d) = dictionary {
94                dictionaries.push(d);
95            }
96        }
97        dictionaries
98    }
99
100    pub fn spell_check_file(&self, path: &str) -> Vec<WordLocation> {
101        let lang_type = queries::get_language_name_from_filename(path);
102        let file_text = std::fs::read_to_string(path).unwrap();
103        self.spell_check(&file_text, Some(lang_type), Some(path))
104    }
105
106    pub fn get_suggestions(&self, word: &str) -> Option<Vec<String>> {
107        // Get top suggestions and return the first 5 suggestions in round robin order
108        let max_results = 5;
109        let dictionaries = self.get_dictionaries(None);
110        let mut is_misspelled = false;
111        let suggestions: Vec<Vec<String>> = dictionaries
112            .iter()
113            .filter_map(|dict| {
114                if !dict.check(word) {
115                    is_misspelled = true;
116                    Some(dict.suggest(word))
117                } else {
118                    None
119                }
120            })
121            .collect();
122        if !is_misspelled {
123            return None;
124        }
125        Some(collect_round_robin(&suggestions, max_results))
126    }
127}
128
/// Interleaves `sources` into at most `max_count` distinct items.
///
/// Items are visited rank by rank: the first item of every source, then the
/// second of every source, and so on. An item equal to one already collected
/// is skipped.
///
/// # Returns
///
/// * Exactly `max_count` items in round-robin order when the cap is reached.
/// * Otherwise every distinct item, **sorted ascending**. This is the
///   long-standing behavior callers and tests rely on when fewer than
///   `max_count` items are available.
/// * An empty vector when `max_count` is `0` or there is nothing to collect.
fn collect_round_robin<T: Clone + PartialEq + Ord>(sources: &[Vec<T>], max_count: usize) -> Vec<T> {
    if max_count == 0 {
        return Vec::new();
    }

    // Bound the work and the allocation by the data actually present rather
    // than by the requested cap, which may be arbitrarily large.
    let longest = sources.iter().map(Vec::len).max().unwrap_or(0);
    let available: usize = sources.iter().map(Vec::len).sum();
    let mut result = Vec::with_capacity(max_count.min(available));

    for rank in 0..longest {
        for item in sources.iter().filter_map(|source| source.get(rank)) {
            if result.contains(item) {
                continue;
            }
            result.push(item.clone());
            if result.len() == max_count {
                return result;
            }
        }
    }

    // Cap not reached: return the distinct items in sorted order.
    result.sort();
    result
}
146
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_collect_round_robin_basic() {
        let sources = vec![
            vec!["apple", "banana", "cherry"],
            vec!["date", "elderberry", "fig"],
            vec!["grape", "honeydew", "kiwi"],
        ];

        let result = collect_round_robin(&sources, 5);
        // Round-robin order: first from each source, then second from each source
        // until the cap of 5 is reached.
        assert_eq!(
            result,
            vec!["apple", "date", "grape", "banana", "elderberry"]
        );
    }

    #[test]
    fn test_collect_round_robin_with_duplicates() {
        let sources = vec![
            vec!["apple", "banana", "cherry"],
            vec!["banana", "cherry", "date"],
            vec!["cherry", "date", "elderberry"],
        ];

        // Items already collected are skipped:
        // Pass 1 (index 0): apple, banana, cherry - one new item from each source.
        // Pass 2 (index 1): banana and cherry are already taken; date is new.
        // Pass 3 (index 2): cherry and date are already taken; elderberry is new
        //                   and reaches the cap of 5.
        let result = collect_round_robin(&sources, 5);
        assert_eq!(
            result,
            vec!["apple", "banana", "cherry", "date", "elderberry"]
        );
    }

    #[test]
    fn test_collect_round_robin_uneven_sources() {
        let sources = vec![
            vec!["apple", "banana", "cherry", "date"],
            vec!["elderberry"],
            vec!["fig", "grape"],
        ];

        // Round-robin order with uneven sources: exhausted sources are skipped
        // in later passes. The cap of 7 is reached exactly on the last item.
        let result = collect_round_robin(&sources, 7);
        assert_eq!(
            result,
            vec![
                "apple",
                "elderberry",
                "fig",
                "banana",
                "grape",
                "cherry",
                "date"
            ]
        );
    }

    #[test]
    fn test_collect_round_robin_empty_sources() {
        // No sources at all yields an empty result.
        let sources: Vec<Vec<&str>> = vec![];
        let result = collect_round_robin(&sources, 5);
        assert_eq!(result, Vec::<&str>::new());
    }

    #[test]
    fn test_collect_round_robin_some_empty_sources() {
        let sources = vec![vec!["apple", "banana"], vec![], vec!["cherry", "date"]];

        // Round-robin order, skipping empty source
        let result = collect_round_robin(&sources, 4);
        assert_eq!(result, vec!["apple", "cherry", "banana", "date"]);
    }

    #[test]
    fn test_collect_round_robin_with_numbers() {
        let sources = vec![vec![1, 3, 5], vec![2, 4, 6]];

        // Round-robin order with numbers
        let result = collect_round_robin(&sources, 6);
        assert_eq!(result, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_collect_round_robin_max_count_exceeded() {
        let sources = vec![
            vec!["apple", "banana", "cherry"],
            vec!["date", "elderberry", "fig"],
            vec!["grape", "honeydew", "kiwi"],
        ];

        // More items are available than the cap allows: only the first round
        // of round-robin (first from each source) is returned.
        let result = collect_round_robin(&sources, 3);
        assert_eq!(result, vec!["apple", "date", "grape"]);
    }

    #[test]
    fn test_collect_round_robin_max_count_higher_than_available() {
        let sources = vec![vec!["apple", "banana"], vec!["cherry", "date"]];

        // Fewer items than the cap: the cap is never reached, so the result is
        // returned sorted rather than in round-robin order (which would have
        // been apple, cherry, banana, date).
        let result = collect_round_robin(&sources, 10);
        assert_eq!(result, vec!["apple", "banana", "cherry", "date"]);
    }
}