1use rand::Rng;
2use std::collections::HashMap;
3use std::fs::File;
4use std::io::{Read, Write};
5
/// Every character the model can represent; a character's position in this
/// string is its vocabulary index. All characters are ASCII, so `len()`
/// (bytes) equals the number of characters.
pub const ALLOWED_CHARS: &str =
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?:;-";

/// Number of distinct characters (i.e. output classes of the model).
const VOCAB_SIZE: usize = ALLOWED_CHARS.len();

lazy_static::lazy_static! {
    /// char -> vocabulary index, built once from ALLOWED_CHARS.
    static ref CHAR_TO_INDEX: HashMap<char, usize> = {
        let mut map = HashMap::new();
        for (i, c) in ALLOWED_CHARS.chars().enumerate() {
            map.insert(c, i);
        }
        map
    };

    /// vocabulary index -> char (inverse of CHAR_TO_INDEX).
    static ref INDEX_TO_CHAR: Vec<char> = {
        ALLOWED_CHARS.chars().collect()
    };
}
23
/// Width of the hidden layer.
const HIDDEN: usize = 128;
/// Learning rate for SGD updates in `train`.
const LR: f32 = 0.005;
/// Number of preceding characters used as context for prediction.
pub const CONTEXT_SIZE: usize = 4;

/// A tiny one-hidden-layer MLP that predicts the next character from the
/// previous CONTEXT_SIZE characters. Contexts are mapped to rows of `w1`
/// on first sight, so the input acts as a learned one-hot embedding lookup.
#[derive(Clone)]
pub struct NGramModel {
    // Input (context row) -> hidden weights; one row per registered context.
    w1: [[f32; HIDDEN]; VOCAB_SIZE * CONTEXT_SIZE],
    // Hidden -> output (vocabulary logits) weights.
    w2: [[f32; VOCAB_SIZE]; HIDDEN],
    // Maps a context (sequence of char indices) to its assigned `w1` row.
    context_to_index: HashMap<Vec<usize>, usize>,
    // Next free `w1` row; once it reaches VOCAB_SIZE * CONTEXT_SIZE, new
    // contexts fall back to hashing (see get_or_create_context_index).
    next_context_index: usize,
}
38
39impl NGramModel {
40 pub fn new() -> Self {
41 let mut rng = rand::thread_rng();
42
43 let mut w1 = [[0.0; HIDDEN]; VOCAB_SIZE * CONTEXT_SIZE];
45 let mut w2 = [[0.0; VOCAB_SIZE]; HIDDEN];
46
47 for i in 0..VOCAB_SIZE * CONTEXT_SIZE {
49 for j in 0..HIDDEN {
50 w1[i][j] = rng.gen_range(-0.1..0.1); }
52 }
53
54 for i in 0..HIDDEN {
55 for j in 0..VOCAB_SIZE {
56 w2[i][j] = rng.gen_range(-0.1..0.1);
57 }
58 }
59
60 Self {
61 w1,
62 w2,
63 context_to_index: HashMap::new(),
64 next_context_index: 0,
65 }
66 }
67
68 fn get_or_create_context_index(&mut self, context: &[usize]) -> usize {
70 let context_vec = context.to_vec();
71
72 if let Some(&idx) = self.context_to_index.get(&context_vec) {
73 return idx;
74 }
75
76 if self.next_context_index >= VOCAB_SIZE * CONTEXT_SIZE {
78 let mut hash = 0;
80 for &c in context {
81 hash = (hash * 31 + c) % (VOCAB_SIZE * CONTEXT_SIZE);
82 }
83 return hash;
84 }
85
86 let idx = self.next_context_index;
87 self.context_to_index.insert(context_vec, idx);
88 self.next_context_index += 1;
89 idx
90 }
91
92 pub fn forward(&mut self, context: &[usize]) -> Vec<f32> {
93 if context.len() != CONTEXT_SIZE {
94 panic!("Context must contain exactly {} characters", CONTEXT_SIZE);
95 }
96
97 let input_idx = self.get_or_create_context_index(context);
98
99 let mut hidden = [0.0; HIDDEN];
101 for j in 0..HIDDEN {
102 hidden[j] = self.w1[input_idx][j].tanh();
103 }
104
105 let mut output = vec![0.0; VOCAB_SIZE];
107 for i in 0..VOCAB_SIZE {
108 for j in 0..HIDDEN {
109 output[i] += hidden[j] * self.w2[j][i];
110 }
111 }
112
113 softmax(&mut output);
115 output
116 }
117
118 pub fn train(&mut self, context: &[usize], target: usize) {
119 if context.len() != CONTEXT_SIZE {
120 panic!("Context must contain exactly {} characters", CONTEXT_SIZE);
121 }
122
123 let target_idx = target;
125
126 let input_idx = self.get_or_create_context_index(context);
127
128 let mut hidden = [0.0; HIDDEN];
130 for j in 0..HIDDEN {
131 hidden[j] = self.w1[input_idx][j].tanh();
132 }
133
134 let mut logits = [0.0; VOCAB_SIZE];
135 for i in 0..VOCAB_SIZE {
136 for j in 0..HIDDEN {
137 logits[i] += hidden[j] * self.w2[j][i];
138 }
139 }
140
141 let mut probs = logits;
143 softmax(&mut probs);
144 probs[target_idx] -= 1.0; for j in 0..HIDDEN {
148 for i in 0..VOCAB_SIZE {
149 self.w2[j][i] -= LR * probs[i] * hidden[j];
150 }
151 }
152
153 for j in 0..HIDDEN {
155 let grad = (1.0 - hidden[j] * hidden[j]) * self.w2[j][target_idx];
157 self.w1[input_idx][j] -= LR * grad;
158 }
159 }
160
161 pub fn save(&self, filename: &str) -> std::io::Result<()> {
163 let mut file = File::create(filename)?;
164
165 file.write_all(&(VOCAB_SIZE as u32).to_le_bytes())?;
167 file.write_all(&(HIDDEN as u32).to_le_bytes())?;
168 file.write_all(&(CONTEXT_SIZE as u32).to_le_bytes())?;
169
170 for row in &self.w1 {
172 for &value in row {
173 file.write_all(&value.to_le_bytes())?;
174 }
175 }
176
177 for row in &self.w2 {
179 for &value in row {
180 file.write_all(&value.to_le_bytes())?;
181 }
182 }
183
184 file.write_all(&(self.context_to_index.len() as u32).to_le_bytes())?;
186
187 for (context, &index) in &self.context_to_index {
189 file.write_all(&(context.len() as u32).to_le_bytes())?;
191
192 for &c in context {
194 file.write_all(&(c as u32).to_le_bytes())?;
195 }
196
197 file.write_all(&(index as u32).to_le_bytes())?;
199 }
200
201 file.write_all(&(self.next_context_index as u32).to_le_bytes())?;
203
204 Ok(())
205 }
206
207 pub fn load(filename: &str) -> std::io::Result<Self> {
209 let mut file = File::open(filename)?;
210
211 let mut buffer = [0; 4];
213
214 file.read_exact(&mut buffer)?;
215 let vocab_size = u32::from_le_bytes(buffer) as usize;
216
217 file.read_exact(&mut buffer)?;
218 let hidden_size = u32::from_le_bytes(buffer) as usize;
219
220 file.read_exact(&mut buffer)?;
221 let context_size = u32::from_le_bytes(buffer) as usize;
222
223 if vocab_size != VOCAB_SIZE || hidden_size != HIDDEN || context_size != CONTEXT_SIZE {
225 return Err(std::io::Error::new(
226 std::io::ErrorKind::InvalidData,
227 format!(
228 "Dimensions do not match: expected {}x{}x{}, found {}x{}x{}",
229 VOCAB_SIZE, HIDDEN, CONTEXT_SIZE, vocab_size, hidden_size, context_size
230 ),
231 ));
232 }
233
234 let mut model = Self::new();
236
237 for i in 0..VOCAB_SIZE * CONTEXT_SIZE {
239 for j in 0..HIDDEN {
240 file.read_exact(&mut buffer)?;
241 model.w1[i][j] = f32::from_le_bytes(buffer);
242 }
243 }
244
245 for i in 0..HIDDEN {
247 for j in 0..VOCAB_SIZE {
248 file.read_exact(&mut buffer)?;
249 model.w2[i][j] = f32::from_le_bytes(buffer);
250 }
251 }
252
253 file.read_exact(&mut buffer)?;
255 let num_contexts = u32::from_le_bytes(buffer) as usize;
256
257 for _ in 0..num_contexts {
259 file.read_exact(&mut buffer)?;
261 let context_len = u32::from_le_bytes(buffer) as usize;
262
263 let mut context = Vec::with_capacity(context_len);
265 for _ in 0..context_len {
266 file.read_exact(&mut buffer)?;
267 context.push(u32::from_le_bytes(buffer) as usize);
268 }
269
270 file.read_exact(&mut buffer)?;
272 let index = u32::from_le_bytes(buffer) as usize;
273
274 model.context_to_index.insert(context, index);
276 }
277
278 file.read_exact(&mut buffer)?;
280 model.next_context_index = u32::from_le_bytes(buffer) as usize;
281
282 Ok(model)
283 }
284}
285
/// In-place softmax: rescales `x` so every entry is positive and the slice
/// sums to 1. Subtracts the maximum before exponentiating for numerical
/// stability. A no-op on an empty slice.
pub fn softmax(x: &mut [f32]) {
    let max = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Exponentiate in place and accumulate the normalizer in one pass.
    // (A plain loop replaces the original side-effecting `.map()` closure,
    // which clippy flags as suspicious.)
    let mut sum = 0.0;
    for v in x.iter_mut() {
        *v = (*v - max).exp();
        sum += *v;
    }

    for v in x.iter_mut() {
        *v /= sum;
    }
}
300
301pub fn sample(probs: &[f32]) -> usize {
303 const TEMPERATURE: f32 = 0.3; let mut adjusted_probs = Vec::with_capacity(probs.len());
309 for &p in probs {
310 adjusted_probs.push(p.powf(1.0 / TEMPERATURE));
311 }
312
313 let sum: f32 = adjusted_probs.iter().sum();
315 for p in &mut adjusted_probs {
316 *p /= sum;
317 }
318
319 let mut cumulative_sum = 0.0;
321 let r: f32 = rand::random();
322 for (i, &p) in adjusted_probs.iter().enumerate() {
323 cumulative_sum += p;
324 if r < cumulative_sum {
325 return i;
326 }
327 }
328
329 VOCAB_SIZE - 1
330}
331
332pub fn prepare_training_data(text: &str) -> Vec<(Vec<usize>, usize)> {
334 let allowed_chars: Vec<char> = ALLOWED_CHARS.chars().collect();
336
337 let filtered_text: String = text.chars().filter(|c| allowed_chars.contains(c)).collect();
339
340 let chars: Vec<usize> = filtered_text.chars().map(|c| char_to_index(c)).collect();
342
343 let mut data = Vec::new();
344
345 for i in CONTEXT_SIZE..chars.len() {
347 let context = chars[i - CONTEXT_SIZE..i].to_vec();
348 let target = chars[i];
349 data.push((context, target));
350 }
351
352 data
353}
354
355pub fn char_to_index(c: char) -> usize {
357 *CHAR_TO_INDEX.get(&c).unwrap_or(&0)
358}
359
360pub fn index_to_char(idx: usize) -> char {
361 INDEX_TO_CHAR[idx % VOCAB_SIZE]
362}