1pub mod dictionaries;
2mod logging;
3pub mod parser;
4pub mod queries;
5mod splitter;
6
7use std::sync::Arc;
8
9use codebook_config::CodebookConfig;
10use dictionaries::{dictionary, manager::DictionaryManager};
11use dictionary::Dictionary;
12use log::debug;
13use parser::WordLocation;
14
/// Spell checker that pairs a shared configuration with a dictionary manager.
pub struct Codebook {
    /// Shared configuration: path-ignore, allow and flag rules, plus the ids of the
    /// dictionaries to load.
    config: Arc<CodebookConfig>,
    /// Resolves dictionary ids to dictionary instances; built from `config.cache_dir`.
    manager: DictionaryManager,
}
19
/// Dictionary ids that are always consulted, after any configured and
/// language-specific dictionaries. Declared as a slice so adding an id does not
/// require updating an array length.
static DEFAULT_DICTIONARIES: &[&str] = &["codebook", "software_terms", "computing_acronyms"];
22
23impl Codebook {
24 pub fn new(config: Arc<CodebookConfig>) -> Result<Self, Box<dyn std::error::Error>> {
25 let manager = DictionaryManager::new(&config.cache_dir);
26 Ok(Self { config, manager })
27 }
28
29 pub fn spell_check(
32 &self,
33 text: &str,
34 language: Option<queries::LanguageType>,
35 file_path: Option<&str>,
36 ) -> Vec<parser::WordLocation> {
37 if file_path.is_some() && self.config.should_ignore_path(file_path.unwrap()) {
38 return Vec::new();
39 }
40 let language = self.resolve_language(language, file_path);
44 let dictionaries = self.get_dictionaries(Some(language));
45 parser::find_locations(text, language, |word| {
46 if self.config.should_flag_word(word) {
47 return false;
48 }
49 if word.len() < 3 {
50 return true;
51 }
52 if self.config.is_allowed_word(word) {
53 return true;
54 }
55 for dictionary in &dictionaries {
56 if dictionary.check(word) {
57 return true;
58 }
59 }
60 false
61 })
62 }
63
64 fn resolve_language(
65 &self,
66 language_type: Option<queries::LanguageType>,
67 path: Option<&str>,
68 ) -> queries::LanguageType {
69 match language_type {
71 Some(lang) => lang,
72 None => match path {
73 Some(path) => queries::get_language_name_from_filename(path),
74 None => queries::LanguageType::Text,
75 },
76 }
77 }
78
79 fn get_dictionaries(
80 &self,
81 language: Option<queries::LanguageType>,
82 ) -> Vec<Arc<dyn Dictionary>> {
83 let mut dictionary_ids = self.config.get_dictionary_ids();
84 if let Some(lang) = language {
85 let language_dictionary_ids = lang.dictionary_ids();
86 dictionary_ids.extend(language_dictionary_ids);
87 };
88 dictionary_ids.extend(DEFAULT_DICTIONARIES.iter().map(|f| f.to_string()));
89 let mut dictionaries = Vec::with_capacity(dictionary_ids.len());
90 debug!("Checking text with dictionaries: {:?}", dictionary_ids);
91 for dictionary_id in dictionary_ids {
92 let dictionary = self.manager.get_dictionary(&dictionary_id);
93 if let Some(d) = dictionary {
94 dictionaries.push(d);
95 }
96 }
97 dictionaries
98 }
99
100 pub fn spell_check_file(&self, path: &str) -> Vec<WordLocation> {
101 let lang_type = queries::get_language_name_from_filename(path);
102 let file_text = std::fs::read_to_string(path).unwrap();
103 self.spell_check(&file_text, Some(lang_type), Some(path))
104 }
105
106 pub fn get_suggestions(&self, word: &str) -> Option<Vec<String>> {
107 let max_results = 5;
109 let dictionaries = self.get_dictionaries(None);
110 let mut is_misspelled = false;
111 let suggestions: Vec<Vec<String>> = dictionaries
112 .iter()
113 .filter_map(|dict| {
114 if !dict.check(word) {
115 is_misspelled = true;
116 Some(dict.suggest(word))
117 } else {
118 None
119 }
120 })
121 .collect();
122 if !is_misspelled {
123 return None;
124 }
125 Some(collect_round_robin(&suggestions, max_results))
126 }
127}
128
/// Merges several ranked lists into one list of at most `max_count` unique items.
///
/// Items are taken round-robin: the first item of every source, then the second
/// item of every source, and so on. Items already collected are skipped. Sources
/// may have different lengths; a source with no item at the current rank is
/// skipped.
///
/// Ordering of the result depends on whether the cap is reached:
/// - if `max_count` items are collected, they come back in round-robin order;
/// - if fewer items are available, the collected items come back sorted.
///
/// NOTE(review): the unit tests pin this difference. Confirm it is intended, since
/// sorting discards the sources' ranking.
fn collect_round_robin<T: Clone + PartialEq + Ord>(sources: &[Vec<T>], max_count: usize) -> Vec<T> {
    // Guard explicitly: the early return below only fires after a push.
    if max_count == 0 {
        return Vec::new();
    }
    // Walk every rank present in some source, not just the first `max_count` ranks.
    // Duplicates can leave the result short while later ranks still hold unseen items.
    let longest = sources.iter().map(Vec::len).max().unwrap_or(0);
    let available: usize = sources.iter().map(Vec::len).sum();
    let mut result = Vec::with_capacity(max_count.min(available));
    for rank in 0..longest {
        for item in sources.iter().filter_map(|source| source.get(rank)) {
            if !result.contains(item) {
                result.push(item.clone());
                if result.len() >= max_count {
                    return result;
                }
            }
        }
    }
    result.sort();
    result
}
146
#[cfg(test)]
mod tests {
    //! Unit tests for `collect_round_robin`, which merges per-dictionary suggestion lists.
    use super::*;

    #[test]
    fn test_collect_round_robin_basic() {
        // One item per source per round, stopping as soon as the cap is reached.
        let sources = vec![
            vec!["apple", "banana", "cherry"],
            vec!["date", "elderberry", "fig"],
            vec!["grape", "honeydew", "kiwi"],
        ];

        let result = collect_round_robin(&sources, 5);
        assert_eq!(
            result,
            vec!["apple", "date", "grape", "banana", "elderberry"]
        );
    }

    #[test]
    fn test_collect_round_robin_with_duplicates() {
        // Items already collected from an earlier source are skipped.
        let sources = vec![
            vec!["apple", "banana", "cherry"],
            vec!["banana", "cherry", "date"],
            vec!["cherry", "date", "elderberry"],
        ];

        let result = collect_round_robin(&sources, 5);
        assert_eq!(
            result,
            vec!["apple", "banana", "cherry", "date", "elderberry"]
        );
    }

    #[test]
    fn test_collect_round_robin_uneven_sources() {
        // Exhausted sources are skipped in later rounds.
        let sources = vec![
            vec!["apple", "banana", "cherry", "date"],
            vec!["elderberry"],
            vec!["fig", "grape"],
        ];

        let result = collect_round_robin(&sources, 7);
        assert_eq!(
            result,
            vec![
                "apple",
                "elderberry",
                "fig",
                "banana",
                "grape",
                "cherry",
                "date"
            ]
        );
    }

    #[test]
    fn test_collect_round_robin_empty_sources() {
        let sources: Vec<Vec<&str>> = vec![];
        let result = collect_round_robin(&sources, 5);
        assert_eq!(result, Vec::<&str>::new());
    }

    #[test]
    fn test_collect_round_robin_some_empty_sources() {
        let sources = vec![vec!["apple", "banana"], vec![], vec!["cherry", "date"]];

        let result = collect_round_robin(&sources, 4);
        assert_eq!(result, vec!["apple", "cherry", "banana", "date"]);
    }

    #[test]
    fn test_collect_round_robin_with_numbers() {
        let sources = vec![vec![1, 3, 5], vec![2, 4, 6]];

        let result = collect_round_robin(&sources, 6);
        assert_eq!(result, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_collect_round_robin_max_count_exceeded() {
        let sources = vec![
            vec!["apple", "banana", "cherry"],
            vec!["date", "elderberry", "fig"],
            vec!["grape", "honeydew", "kiwi"],
        ];

        let result = collect_round_robin(&sources, 3);
        assert_eq!(result, vec!["apple", "date", "grape"]);
    }

    #[test]
    fn test_collect_round_robin_max_count_higher_than_available() {
        // When the cap is never reached the collected items are returned sorted,
        // not in round-robin order (which would be apple, cherry, banana, date).
        let sources = vec![vec!["apple", "banana"], vec!["cherry", "date"]];

        let result = collect_round_robin(&sources, 10);
        assert_eq!(result, vec!["apple", "banana", "cherry", "date"]);
    }

    #[test]
    fn test_collect_round_robin_zero_max_count() {
        // A zero cap yields nothing, even when the sources are non-empty.
        let sources = vec![vec!["apple", "banana"], vec!["cherry"]];

        let result = collect_round_robin(&sources, 0);
        assert_eq!(result, Vec::<&str>::new());
    }
}
258}