//! Text-processing utilities: vocabularies, tokenizers, and text datasets.
//!
//! The crate is organized into three modules — [`vocab`], [`tokenizer`], and
//! [`datasets`] — whose most common types are re-exported at the crate root
//! and bundled in the [`prelude`] for glob import.

// Documentation is required on all public items; clippy runs at its
// strictest tiers, with the specific pedantic/style lints below opted out.
#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// --- numeric-cast lints relaxed (casts are used deliberately here) ---
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::cast_lossless)]
// --- documentation-completeness lints relaxed ---
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::doc_markdown)]
// --- naming / signature-shape lints relaxed ---
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
// --- control-flow / code-shape lints relaxed ---
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]
76
// Public modules. NOTE(review): item-level docs are expected to live in each
// module file (`#![warn(missing_docs)]` is active crate-wide).
pub mod datasets;
pub mod tokenizer;
pub mod vocab;

// Flat re-exports so the common types are reachable from the crate root.

// Vocabulary type and its special-token constants.
pub use vocab::{BOS_TOKEN, EOS_TOKEN, MASK_TOKEN, PAD_TOKEN, UNK_TOKEN, Vocab};

// Tokenizer trait and the concrete tokenizer implementations.
pub use tokenizer::{
    BasicBPETokenizer, CharTokenizer, NGramTokenizer, Tokenizer, UnigramTokenizer,
    WhitespaceTokenizer, WordPunctTokenizer,
};

// Dataset types (real-text and synthetic).
pub use datasets::{
    LanguageModelDataset, SyntheticSentimentDataset, SyntheticSeq2SeqDataset, TextDataset,
};
95
96pub mod prelude {
102 pub use crate::{
103 BOS_TOKEN,
104 BasicBPETokenizer,
105 CharTokenizer,
106 EOS_TOKEN,
107 LanguageModelDataset,
108 MASK_TOKEN,
109 NGramTokenizer,
110 PAD_TOKEN,
111 SyntheticSentimentDataset,
112 SyntheticSeq2SeqDataset,
113 TextDataset,
115 Tokenizer,
117 UNK_TOKEN,
118 UnigramTokenizer,
119 Vocab,
121 WhitespaceTokenizer,
122 WordPunctTokenizer,
123 };
124
125 pub use axonml_data::{DataLoader, Dataset};
126 pub use axonml_tensor::Tensor;
127}
128
#[cfg(test)]
mod tests {
    //! Integration tests exercising vocab, tokenizer, and dataset types
    //! together rather than in isolation.

    use super::*;
    use axonml_data::Dataset;

    /// Tokenizing and encoding the same text must produce matching lengths.
    #[test]
    fn test_vocab_and_tokenizer_integration() {
        let corpus = "the quick brown fox jumps over the lazy dog";
        let vocab = Vocab::from_text(corpus, 1);
        let tok = WhitespaceTokenizer::new();

        let pieces = tok.tokenize("the fox");
        let ids = tok.encode("the fox", &vocab);

        assert_eq!(pieces.len(), 2);
        assert_eq!(ids.len(), 2);
    }

    /// A classification dataset built from labeled samples reports the right
    /// sample count and number of distinct classes.
    #[test]
    fn test_text_dataset_with_tokenizer() {
        let labeled = vec![
            ("good movie".to_string(), 1),
            ("bad movie".to_string(), 0),
            ("great film".to_string(), 1),
            ("terrible movie".to_string(), 0),
        ];

        let tok = WhitespaceTokenizer::new();
        let ds = TextDataset::from_samples(&labeled, &tok, 1, 10);

        assert_eq!(ds.len(), 4);
        assert_eq!(ds.num_classes(), 2);
    }

    /// A language-model dataset yields (input, target) windows of the
    /// requested sequence length.
    #[test]
    fn test_language_model_pipeline() {
        let corpus = "one two three four five six seven eight nine ten";
        let ds = LanguageModelDataset::from_text(corpus, 3, 1);
        assert!(ds.len() > 0);

        let (input, target) = ds.get(0).unwrap();
        assert_eq!(input.shape(), &[3]);
        assert_eq!(target.shape(), &[3]);
    }

    /// Training BPE on a tiny corpus yields a non-empty vocab and produces
    /// tokens for a known word.
    #[test]
    fn test_bpe_tokenizer_training() {
        let mut bpe = BasicBPETokenizer::new();
        bpe.train("low lower lowest newer newest", 10);

        assert!(!bpe.get_vocab().is_empty());
        assert!(!bpe.tokenize("low").is_empty());
    }

    /// Character-level encoding maps each input char to one index.
    #[test]
    fn test_char_tokenizer_with_vocab() {
        let tok = CharTokenizer::new();
        let mut vocab = Vocab::with_special_tokens();

        let alphabet = "abcdefghijklmnopqrstuvwxyz ";
        for ch in alphabet.chars() {
            vocab.add_token(&ch.to_string());
        }

        assert_eq!(tok.encode("hello", &vocab).len(), 5);
    }

    /// The synthetic sentiment dataset streams full batches through a
    /// DataLoader.
    #[test]
    fn test_synthetic_datasets_with_dataloader() {
        use axonml_data::DataLoader;

        let loader = DataLoader::new(SyntheticSentimentDataset::small(), 16);

        let mut seen = 0;
        for batch in loader.iter().take(3) {
            assert_eq!(batch.data.shape()[0], 16);
            seen += 1;
        }
        assert_eq!(seen, 3);
    }

    /// Word and character n-gram tokenizers emit len(input) - n + 1 grams.
    #[test]
    fn test_ngram_tokenizer() {
        let by_word = NGramTokenizer::word_ngrams(2);
        let grams = by_word.tokenize("one two three four");
        assert_eq!(grams.len(), 3);
        assert!(grams.contains(&"one two".to_string()));

        let by_char = NGramTokenizer::char_ngrams(3);
        assert_eq!(by_char.tokenize("hello").len(), 3);
    }

    /// The seq2seq copy task pairs each source sequence with its reversal.
    #[test]
    fn test_seq2seq_reverse_task() {
        let ds = SyntheticSeq2SeqDataset::copy_task(10, 5, 100);
        let (src, tgt) = ds.get(0).unwrap();

        let src_vals = src.to_vec();
        let tgt_vals = tgt.to_vec();

        let n = src_vals.len();
        for i in 0..n {
            // target is the source read back-to-front
            assert_eq!(src_vals[i], tgt_vals[n - 1 - i]);
        }
    }
}