//! rml_core/lib.rs
//!
//! Character-level n-gram language model: a one-hot context lookup into a
//! tanh hidden layer with a softmax output over the vocabulary.

use rand::Rng;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
5
/// Every character the model can represent; anything outside this set is
/// filtered out of the training data.
pub const ALLOWED_CHARS: &str =
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?:;-";
// All characters above are ASCII, so byte length equals character count.
const VOCAB_SIZE: usize = ALLOWED_CHARS.len(); // e.g. 70
9
10lazy_static::lazy_static! {
11    static ref CHAR_TO_INDEX: HashMap<char, usize> = {
12        let mut map = HashMap::new();
13        for (i, c) in ALLOWED_CHARS.chars().enumerate() {
14            map.insert(c, i);
15        }
16        map
17    };
18
19    static ref INDEX_TO_CHAR: Vec<char> = {
20        ALLOWED_CHARS.chars().collect()
21    };
22}
23
/// Number of units in the single hidden layer.
const HIDDEN: usize = 128; // Significantly larger hidden layer for more complex patterns
/// SGD learning rate.
const LR: f32 = 0.005; // Lower learning rate for more stable convergence
/// Number of preceding characters used to predict the next one.
pub const CONTEXT_SIZE: usize = 4; // Larger context for better text coherence
27
/// A character-level n-gram model. Each distinct context is assigned a row of
/// `w1` (effectively a one-hot embedding lookup); the row is passed through a
/// tanh hidden layer and a linear output layer softmaxed over the vocabulary.
///
/// NOTE(review): the weight arrays are stored inline (~180 KB total), so
/// `Clone` and by-value moves copy the whole model — consider boxing the
/// arrays if that ever matters.
#[derive(Clone)]
pub struct NGramModel {
    // Weights for the input -> hidden layer
    w1: [[f32; HIDDEN]; VOCAB_SIZE * CONTEXT_SIZE],
    // Weights for the hidden -> output layer
    w2: [[f32; VOCAB_SIZE]; HIDDEN],
    // Mapping from n-grams to indices
    context_to_index: HashMap<Vec<usize>, usize>,
    // Next free row of `w1` to hand out to an unseen context.
    next_context_index: usize,
}
38
39impl NGramModel {
40    pub fn new() -> Self {
41        let mut rng = rand::thread_rng();
42
43        // Initialization of weight matrices
44        let mut w1 = [[0.0; HIDDEN]; VOCAB_SIZE * CONTEXT_SIZE];
45        let mut w2 = [[0.0; VOCAB_SIZE]; HIDDEN];
46
47        // Random initialization of weights
48        for i in 0..VOCAB_SIZE * CONTEXT_SIZE {
49            for j in 0..HIDDEN {
50                w1[i][j] = rng.gen_range(-0.1..0.1); // Smaller initialization for more stable training
51            }
52        }
53
54        for i in 0..HIDDEN {
55            for j in 0..VOCAB_SIZE {
56                w2[i][j] = rng.gen_range(-0.1..0.1);
57            }
58        }
59
60        Self {
61            w1,
62            w2,
63            context_to_index: HashMap::new(),
64            next_context_index: 0,
65        }
66    }
67
68    // Helper function to determine or create an index for a context
69    fn get_or_create_context_index(&mut self, context: &[usize]) -> usize {
70        let context_vec = context.to_vec();
71
72        if let Some(&idx) = self.context_to_index.get(&context_vec) {
73            return idx;
74        }
75
76        // If we have reached the maximum number of contexts, we use a fallback
77        if self.next_context_index >= VOCAB_SIZE * CONTEXT_SIZE {
78            // Simple hash-based fallback
79            let mut hash = 0;
80            for &c in context {
81                hash = (hash * 31 + c) % (VOCAB_SIZE * CONTEXT_SIZE);
82            }
83            return hash;
84        }
85
86        let idx = self.next_context_index;
87        self.context_to_index.insert(context_vec, idx);
88        self.next_context_index += 1;
89        idx
90    }
91
92    pub fn forward(&mut self, context: &[usize]) -> Vec<f32> {
93        if context.len() != CONTEXT_SIZE {
94            panic!("Context must contain exactly {} characters", CONTEXT_SIZE);
95        }
96
97        let input_idx = self.get_or_create_context_index(context);
98
99        // Forward pass to hidden layer
100        let mut hidden = [0.0; HIDDEN];
101        for j in 0..HIDDEN {
102            hidden[j] = self.w1[input_idx][j].tanh();
103        }
104
105        // Forward pass to output layer
106        let mut output = vec![0.0; VOCAB_SIZE];
107        for i in 0..VOCAB_SIZE {
108            for j in 0..HIDDEN {
109                output[i] += hidden[j] * self.w2[j][i];
110            }
111        }
112
113        // Softmax normalization for probabilities
114        softmax(&mut output);
115        output
116    }
117
118    pub fn train(&mut self, context: &[usize], target: usize) {
119        if context.len() != CONTEXT_SIZE {
120            panic!("Context must contain exactly {} characters", CONTEXT_SIZE);
121        }
122
123        // Ensure the target is within the vocabulary
124        let target_idx = target;
125
126        let input_idx = self.get_or_create_context_index(context);
127
128        // Forward pass, as in forward()
129        let mut hidden = [0.0; HIDDEN];
130        for j in 0..HIDDEN {
131            hidden[j] = self.w1[input_idx][j].tanh();
132        }
133
134        let mut logits = [0.0; VOCAB_SIZE];
135        for i in 0..VOCAB_SIZE {
136            for j in 0..HIDDEN {
137                logits[i] += hidden[j] * self.w2[j][i];
138            }
139        }
140
141        // Calculation of error with softmax cross-entropy
142        let mut probs = logits;
143        softmax(&mut probs);
144        probs[target_idx] -= 1.0; // Simple form of cross-entropy derivative
145
146        // Backpropagation for W2 (Hidden -> Output)
147        for j in 0..HIDDEN {
148            for i in 0..VOCAB_SIZE {
149                self.w2[j][i] -= LR * probs[i] * hidden[j];
150            }
151        }
152
153        // Backpropagation for W1 (Input -> Hidden)
154        for j in 0..HIDDEN {
155            // Derivative of tanh activation: 1 - tanh²(x)
156            let grad = (1.0 - hidden[j] * hidden[j]) * self.w2[j][target_idx];
157            self.w1[input_idx][j] -= LR * grad;
158        }
159    }
160
161    // Save the model to a file
162    pub fn save(&self, filename: &str) -> std::io::Result<()> {
163        let mut file = File::create(filename)?;
164
165        // Save the dimensions
166        file.write_all(&(VOCAB_SIZE as u32).to_le_bytes())?;
167        file.write_all(&(HIDDEN as u32).to_le_bytes())?;
168        file.write_all(&(CONTEXT_SIZE as u32).to_le_bytes())?;
169
170        // Save W1
171        for row in &self.w1 {
172            for &value in row {
173                file.write_all(&value.to_le_bytes())?;
174            }
175        }
176
177        // Save W2
178        for row in &self.w2 {
179            for &value in row {
180                file.write_all(&value.to_le_bytes())?;
181            }
182        }
183
184        // Save the number of contexts
185        file.write_all(&(self.context_to_index.len() as u32).to_le_bytes())?;
186
187        // Save the context map
188        for (context, &index) in &self.context_to_index {
189            // Save the context length
190            file.write_all(&(context.len() as u32).to_le_bytes())?;
191
192            // Save the context elements
193            for &c in context {
194                file.write_all(&(c as u32).to_le_bytes())?;
195            }
196
197            // Save the index
198            file.write_all(&(index as u32).to_le_bytes())?;
199        }
200
201        // Save the next context index
202        file.write_all(&(self.next_context_index as u32).to_le_bytes())?;
203
204        Ok(())
205    }
206
207    // Load the model from a file
208    pub fn load(filename: &str) -> std::io::Result<Self> {
209        let mut file = File::open(filename)?;
210
211        // Read the dimensions
212        let mut buffer = [0; 4];
213
214        file.read_exact(&mut buffer)?;
215        let vocab_size = u32::from_le_bytes(buffer) as usize;
216
217        file.read_exact(&mut buffer)?;
218        let hidden_size = u32::from_le_bytes(buffer) as usize;
219
220        file.read_exact(&mut buffer)?;
221        let context_size = u32::from_le_bytes(buffer) as usize;
222
223        // Check if the dimensions match
224        if vocab_size != VOCAB_SIZE || hidden_size != HIDDEN || context_size != CONTEXT_SIZE {
225            return Err(std::io::Error::new(
226                std::io::ErrorKind::InvalidData,
227                format!(
228                    "Dimensions do not match: expected {}x{}x{}, found {}x{}x{}",
229                    VOCAB_SIZE, HIDDEN, CONTEXT_SIZE, vocab_size, hidden_size, context_size
230                ),
231            ));
232        }
233
234        // Create a new model
235        let mut model = Self::new();
236
237        // Read W1
238        for i in 0..VOCAB_SIZE * CONTEXT_SIZE {
239            for j in 0..HIDDEN {
240                file.read_exact(&mut buffer)?;
241                model.w1[i][j] = f32::from_le_bytes(buffer);
242            }
243        }
244
245        // Read W2
246        for i in 0..HIDDEN {
247            for j in 0..VOCAB_SIZE {
248                file.read_exact(&mut buffer)?;
249                model.w2[i][j] = f32::from_le_bytes(buffer);
250            }
251        }
252
253        // Read the number of contexts
254        file.read_exact(&mut buffer)?;
255        let num_contexts = u32::from_le_bytes(buffer) as usize;
256
257        // Read the context map
258        for _ in 0..num_contexts {
259            // Read the context length
260            file.read_exact(&mut buffer)?;
261            let context_len = u32::from_le_bytes(buffer) as usize;
262
263            // Read the context elements
264            let mut context = Vec::with_capacity(context_len);
265            for _ in 0..context_len {
266                file.read_exact(&mut buffer)?;
267                context.push(u32::from_le_bytes(buffer) as usize);
268            }
269
270            // Read the index
271            file.read_exact(&mut buffer)?;
272            let index = u32::from_le_bytes(buffer) as usize;
273
274            // Add the context to the map
275            model.context_to_index.insert(context, index);
276        }
277
278        // Read the next context index
279        file.read_exact(&mut buffer)?;
280        model.next_context_index = u32::from_le_bytes(buffer) as usize;
281
282        Ok(model)
283    }
284}
285
286// Helper function: Softmax for normalizing output probabilities
/// Numerically stable softmax, computed in place.
///
/// Subtracting the running maximum before exponentiation prevents overflow
/// for large logits; afterwards the slice sums to 1 (for non-empty input).
pub fn softmax(x: &mut [f32]) {
    let peak = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut total = 0.0f32;
    for value in x.iter_mut() {
        *value = (*value - peak).exp();
        total += *value;
    }

    x.iter_mut().for_each(|value| *value /= total);
}
300
301// Helper function: Sampling from probabilities with temperature
302pub fn sample(probs: &[f32]) -> usize {
303    // Lower temperature = more conservative decisions (more typical characters)
304    // Higher temperature = more randomness and creativity
305    const TEMPERATURE: f32 = 0.3; // Even lower temperature for more conservative selection
306
307    // Copy the probabilities and apply temperature
308    let mut adjusted_probs = Vec::with_capacity(probs.len());
309    for &p in probs {
310        adjusted_probs.push(p.powf(1.0 / TEMPERATURE));
311    }
312
313    // Renormalize
314    let sum: f32 = adjusted_probs.iter().sum();
315    for p in &mut adjusted_probs {
316        *p /= sum;
317    }
318
319    // Perform sampling
320    let mut cumulative_sum = 0.0;
321    let r: f32 = rand::random();
322    for (i, &p) in adjusted_probs.iter().enumerate() {
323        cumulative_sum += p;
324        if r < cumulative_sum {
325            return i;
326        }
327    }
328
329    VOCAB_SIZE - 1
330}
331
332// Helper function to prepare a text for training
333pub fn prepare_training_data(text: &str) -> Vec<(Vec<usize>, usize)> {
334    // Only use letters, numbers, spaces and some punctuation marks
335    let allowed_chars: Vec<char> = ALLOWED_CHARS.chars().collect();
336
337    // Filter the characters
338    let filtered_text: String = text.chars().filter(|c| allowed_chars.contains(c)).collect();
339
340    // Convert to indices
341    let chars: Vec<usize> = filtered_text.chars().map(|c| char_to_index(c)).collect();
342
343    let mut data = Vec::new();
344
345    // Fill our training data with contexts and target characters
346    for i in CONTEXT_SIZE..chars.len() {
347        let context = chars[i - CONTEXT_SIZE..i].to_vec();
348        let target = chars[i];
349        data.push((context, target));
350    }
351
352    data
353}
354
355// Helper function for converting text to indices and back
356pub fn char_to_index(c: char) -> usize {
357    *CHAR_TO_INDEX.get(&c).unwrap_or(&0)
358}
359
360pub fn index_to_char(idx: usize) -> char {
361    INDEX_TO_CHAR[idx % VOCAB_SIZE]
362}