1#![warn(missing_docs)]
25#![warn(clippy::all)]
26#![warn(clippy::pedantic)]
27#![allow(clippy::cast_possible_truncation)]
29#![allow(clippy::cast_sign_loss)]
30#![allow(clippy::cast_precision_loss)]
31#![allow(clippy::cast_possible_wrap)]
32#![allow(clippy::missing_errors_doc)]
33#![allow(clippy::missing_panics_doc)]
34#![allow(clippy::must_use_candidate)]
35#![allow(clippy::module_name_repetitions)]
36#![allow(clippy::similar_names)]
37#![allow(clippy::many_single_char_names)]
38#![allow(clippy::too_many_arguments)]
39#![allow(clippy::doc_markdown)]
40#![allow(clippy::cast_lossless)]
41#![allow(clippy::needless_pass_by_value)]
42#![allow(clippy::redundant_closure_for_method_calls)]
43#![allow(clippy::uninlined_format_args)]
44#![allow(clippy::ptr_arg)]
45#![allow(clippy::return_self_not_must_use)]
46#![allow(clippy::not_unsafe_ptr_arg_deref)]
47#![allow(clippy::items_after_statements)]
48#![allow(clippy::unreadable_literal)]
49#![allow(clippy::if_same_then_else)]
50#![allow(clippy::needless_range_loop)]
51#![allow(clippy::trivially_copy_pass_by_ref)]
52#![allow(clippy::unnecessary_wraps)]
53#![allow(clippy::match_same_arms)]
54#![allow(clippy::unused_self)]
55#![allow(clippy::too_many_lines)]
56#![allow(clippy::single_match_else)]
57#![allow(clippy::fn_params_excessive_bools)]
58#![allow(clippy::struct_excessive_bools)]
59#![allow(clippy::format_push_string)]
60#![allow(clippy::erasing_op)]
61#![allow(clippy::type_repetition_in_bounds)]
62#![allow(clippy::iter_without_into_iter)]
63#![allow(clippy::should_implement_trait)]
64#![allow(clippy::use_debug)]
65#![allow(clippy::case_sensitive_file_extension_comparisons)]
66#![allow(clippy::large_enum_variant)]
67#![allow(clippy::panic)]
68#![allow(clippy::struct_field_names)]
69#![allow(clippy::missing_fields_in_debug)]
70#![allow(clippy::upper_case_acronyms)]
71#![allow(clippy::assigning_clones)]
72#![allow(clippy::option_if_let_else)]
73#![allow(clippy::manual_let_else)]
74#![allow(clippy::explicit_iter_loop)]
75#![allow(clippy::default_trait_access)]
76#![allow(clippy::only_used_in_recursion)]
77#![allow(clippy::manual_clamp)]
78#![allow(clippy::ref_option)]
79#![allow(clippy::multiple_bound_locations)]
80#![allow(clippy::comparison_chain)]
81#![allow(clippy::manual_assert)]
82#![allow(clippy::unnecessary_debug_formatting)]
83
84pub mod datasets;
85pub mod tokenizer;
86pub mod vocab;
87
88pub use vocab::{BOS_TOKEN, EOS_TOKEN, MASK_TOKEN, PAD_TOKEN, UNK_TOKEN, Vocab};
93
94pub use tokenizer::{
95 BasicBPETokenizer, CharTokenizer, NGramTokenizer, Tokenizer, UnigramTokenizer,
96 WhitespaceTokenizer, WordPunctTokenizer,
97};
98
99pub use datasets::{
100 LanguageModelDataset, SyntheticSentimentDataset, SyntheticSeq2SeqDataset, TextDataset,
101};
102
103pub mod prelude {
109 pub use crate::{
110 BOS_TOKEN,
111 BasicBPETokenizer,
112 CharTokenizer,
113 EOS_TOKEN,
114 LanguageModelDataset,
115 MASK_TOKEN,
116 NGramTokenizer,
117 PAD_TOKEN,
118 SyntheticSentimentDataset,
119 SyntheticSeq2SeqDataset,
120 TextDataset,
122 Tokenizer,
124 UNK_TOKEN,
125 UnigramTokenizer,
126 Vocab,
128 WhitespaceTokenizer,
129 WordPunctTokenizer,
130 };
131
132 pub use axonml_data::{DataLoader, Dataset};
133 pub use axonml_tensor::Tensor;
134}
135
#[cfg(test)]
mod tests {
    // Integration-style checks that drive the crate-root re-exports together
    // (vocabulary + tokenizers + datasets + data loader).
    use super::*;
    use axonml_data::Dataset;

    /// A two-word phrase yields two tokens and two encoded indices when run
    /// through `WhitespaceTokenizer` against a vocabulary built from text.
    #[test]
    fn test_vocab_and_tokenizer_integration() {
        let corpus = "the quick brown fox jumps over the lazy dog";
        let phrase = "the fox";

        let vocab = Vocab::from_text(corpus, 1);
        let splitter = WhitespaceTokenizer::new();

        let pieces = splitter.tokenize(phrase);
        let ids = splitter.encode(phrase, &vocab);

        assert_eq!(pieces.len(), 2);
        assert_eq!(ids.len(), 2);
    }

    /// `TextDataset::from_samples` keeps all four labelled samples and
    /// reports the two distinct labels as two classes.
    #[test]
    fn test_text_dataset_with_tokenizer() {
        let splitter = WhitespaceTokenizer::new();
        let labelled = vec![
            ("good movie".to_string(), 1),
            ("bad movie".to_string(), 0),
            ("great film".to_string(), 1),
            ("terrible movie".to_string(), 0),
        ];

        let corpus = TextDataset::from_samples(&labelled, &splitter, 1, 10);

        assert_eq!(corpus.len(), 4);
        assert_eq!(corpus.num_classes(), 2);
    }

    /// A language-model dataset built from ten words is non-empty, and its
    /// first (input, target) pair both have shape `[3]`.
    #[test]
    fn test_language_model_pipeline() {
        let lm_data = LanguageModelDataset::from_text(
            "one two three four five six seven eight nine ten",
            3,
            1,
        );

        assert!(lm_data.len() > 0);

        let (input, target) = lm_data.get(0).unwrap();
        assert_eq!(input.shape(), &[3]);
        assert_eq!(target.shape(), &[3]);
    }

    /// After training, the BPE tokenizer exposes a non-empty vocabulary and
    /// produces at least one token for a word from the training text.
    #[test]
    fn test_bpe_tokenizer_training() {
        let mut bpe = BasicBPETokenizer::new();
        bpe.train("low lower lowest newer newest", 10);

        assert!(!bpe.get_vocab().is_empty());
        assert!(!bpe.tokenize("low").is_empty());
    }

    /// Encoding "hello" with `CharTokenizer` gives one index per character
    /// once every lowercase letter (and space) has been added to the vocab.
    #[test]
    fn test_char_tokenizer_with_vocab() {
        let mut vocab = Vocab::with_special_tokens();
        let alphabet = "abcdefghijklmnopqrstuvwxyz ";
        alphabet.chars().for_each(|ch| {
            vocab.add_token(&ch.to_string());
        });

        let encoded = CharTokenizer::new().encode("hello", &vocab);
        assert_eq!(encoded.len(), 5);
    }

    /// The first three batches from a `DataLoader` over the small synthetic
    /// sentiment dataset each have a leading dimension of 16, matching the
    /// value passed to `DataLoader::new`.
    #[test]
    fn test_synthetic_datasets_with_dataloader() {
        use axonml_data::DataLoader;

        let loader = DataLoader::new(SyntheticSentimentDataset::small(), 16);

        // Collect the leading dimension of up to three batches; comparing the
        // whole list checks both the batch count (3) and each batch size (16).
        let leading_dims: Vec<_> = loader
            .iter()
            .take(3)
            .map(|batch| batch.data.shape()[0])
            .collect();

        assert_eq!(leading_dims, [16, 16, 16]);
    }

    /// Word bigrams over four words give three tokens (including "one two");
    /// character trigrams over "hello" give three tokens.
    #[test]
    fn test_ngram_tokenizer() {
        let bigrams = NGramTokenizer::word_ngrams(2).tokenize("one two three four");
        assert_eq!(bigrams.len(), 3);
        assert!(bigrams.iter().any(|gram| gram == "one two"));

        let trigrams = NGramTokenizer::char_ngrams(3).tokenize("hello");
        assert_eq!(trigrams.len(), 3);
    }

    /// The target sequence of the first sample must be the source sequence
    /// read back to front.
    ///
    /// NOTE(review): the dataset is built with `copy_task` although the test
    /// name and the assertion describe a *reverse* task. Confirm against
    /// `datasets.rs` that `copy_task` really yields reversed targets; if not,
    /// this test is exercising the wrong constructor.
    #[test]
    fn test_seq2seq_reverse_task() {
        let seq_data = SyntheticSeq2SeqDataset::copy_task(10, 5, 100);
        let (src, tgt) = seq_data.get(0).unwrap();

        let source = src.to_vec();
        let target = tgt.to_vec();
        let last = source.len();

        // Walk the source front to back while reading the target back to
        // front; position `pos` mirrors to `last - 1 - pos`.
        for pos in 0..last {
            assert_eq!(source[pos], target[last - 1 - pos]);
        }
    }
}
257}