#![warn(missing_docs)]
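//! Text processing for AxonML: tokenizers, vocabularies, and text datasets.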
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]

/// Dataset types: text classification, language-model, and synthetic seq2seq datasets.
pub mod datasets;
/// Tokenizer implementations and the common `Tokenizer` trait.
pub mod tokenizer;
/// Vocabulary construction, lookup, and special-token constants.
pub mod vocab;

pub use vocab::{Vocab, BOS_TOKEN, EOS_TOKEN, MASK_TOKEN, PAD_TOKEN, UNK_TOKEN};

pub use tokenizer::{
    BasicBPETokenizer, CharTokenizer, NGramTokenizer, Tokenizer, UnigramTokenizer,
    WhitespaceTokenizer, WordPunctTokenizer,
};

pub use datasets::{
    LanguageModelDataset, SyntheticSentimentDataset, SyntheticSeq2SeqDataset, TextDataset,
};

/// Convenient re-exports of the crate's most commonly used types.
pub mod prelude {
    pub use crate::{
        BasicBPETokenizer,
        CharTokenizer,
        LanguageModelDataset,
        NGramTokenizer,
        SyntheticSentimentDataset,
        SyntheticSeq2SeqDataset,
        TextDataset,
        Tokenizer,
        UnigramTokenizer,
        Vocab,
        WhitespaceTokenizer,
        WordPunctTokenizer,
        BOS_TOKEN,
        EOS_TOKEN,
        MASK_TOKEN,
        PAD_TOKEN,
        UNK_TOKEN,
    };

    pub use axonml_data::{DataLoader, Dataset};
    pub use axonml_tensor::Tensor;
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_data::Dataset;

    #[test]
    fn test_vocab_and_tokenizer_integration() {
        let text = "the quick brown fox jumps over the lazy dog";
        let vocab = Vocab::from_text(text, 1);
        let tokenizer = WhitespaceTokenizer::new();

        let tokens = tokenizer.tokenize("the fox");
        let indices = tokenizer.encode("the fox", &vocab);

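        // tokenize() and encode() should agree on length: two whitespace tokens.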
        assert_eq!(tokens.len(), 2);
        assert_eq!(indices.len(), 2);
    }

    #[test]
    fn test_text_dataset_with_tokenizer() {
        let samples = vec![
            ("good movie".to_string(), 1),
            ("bad movie".to_string(), 0),
            ("great film".to_string(), 1),
            ("terrible movie".to_string(), 0),
        ];

        let tokenizer = WhitespaceTokenizer::new();
        let dataset = TextDataset::from_samples(&samples, &tokenizer, 1, 10);

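        // All four samples are retained, spanning two label classes (0 and 1).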
        assert_eq!(dataset.len(), 4);
        assert_eq!(dataset.num_classes(), 2);
    }

    #[test]
    fn test_language_model_pipeline() {
        let text = "one two three four five six seven eight nine ten";
        let dataset = LanguageModelDataset::from_text(text, 3, 1);

        assert!(dataset.len() > 0);

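        // Input and target windows both use the configured sequence length of 3.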
        let (input, target) = dataset.get(0).unwrap();

        assert_eq!(input.shape(), &[3]);
        assert_eq!(target.shape(), &[3]);
    }

    #[test]
    fn test_bpe_tokenizer_training() {
        let mut tokenizer = BasicBPETokenizer::new();
        let text = "low lower lowest newer newest";
        tokenizer.train(text, 10);

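        // Training on even a tiny corpus should produce a non-empty vocabulary.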
        let vocab = tokenizer.get_vocab();
        assert!(!vocab.is_empty());

        let tokens = tokenizer.tokenize("low");
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_char_tokenizer_with_vocab() {
        let tokenizer = CharTokenizer::new();
        let mut vocab = Vocab::with_special_tokens();

        for c in "abcdefghijklmnopqrstuvwxyz ".chars() {
            vocab.add_token(&c.to_string());
        }

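        // Every character of "hello" was added above, so encoding yields 5 indices.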
        let indices = tokenizer.encode("hello", &vocab);
        assert_eq!(indices.len(), 5);
    }

    #[test]
    fn test_synthetic_datasets_with_dataloader() {
        use axonml_data::DataLoader;

        let dataset = SyntheticSentimentDataset::small();
        let loader = DataLoader::new(dataset, 16);

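        // The first three batches should each hold a full 16 samples.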
        let mut batch_count = 0;
        for batch in loader.iter().take(3) {
            assert_eq!(batch.data.shape()[0], 16);
            batch_count += 1;
        }
        assert_eq!(batch_count, 3);
    }

    #[test]
    fn test_ngram_tokenizer() {
        let word_bigrams = NGramTokenizer::word_ngrams(2);
        let tokens = word_bigrams.tokenize("one two three four");

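        // Four words yield three overlapping word bigrams.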
        assert_eq!(tokens.len(), 3);
        assert!(tokens.contains(&"one two".to_string()));

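        // "hello" yields three character trigrams: "hel", "ell", "llo".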
        let char_trigrams = NGramTokenizer::char_ngrams(3);
        let tokens = char_trigrams.tokenize("hello");

        assert_eq!(tokens.len(), 3);
    }

    #[test]
    fn test_seq2seq_reverse_task() {
        let dataset = SyntheticSeq2SeqDataset::reverse_task(10, 5, 100);

        let (src, tgt) = dataset.get(0).unwrap();

        let src_vec = src.to_vec();
        let tgt_vec = tgt.to_vec();

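        // The target should be the source sequence in reverse order.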
        for (i, &val) in src_vec.iter().enumerate() {
            assert_eq!(val, tgt_vec[src_vec.len() - 1 - i]);
        }
    }
}