use std::collections::HashMap;
use ndarray::{Array1, Array2};
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;
use crate::optimizers::Optimizer;
/// Bidirectional character <-> index mapping over a fixed character set.
///
/// Invariant: `char_to_idx` and `idx_to_char` are exact inverses and both
/// contain exactly `vocab_size` entries with indices `0..vocab_size`.
#[derive(Clone, Debug)]
pub struct TextVocabulary {
    /// Forward lookup: character -> dense index.
    char_to_idx: HashMap<char, usize>,
    /// Reverse lookup: dense index -> character.
    idx_to_char: HashMap<usize, char>,
    /// Number of distinct characters in the vocabulary.
    vocab_size: usize,
}

impl TextVocabulary {
    /// Build both lookup maps from an ordered, duplicate-free character list.
    /// Shared by `from_text` and `from_chars` so the maps are always built
    /// consistently.
    fn from_unique_chars(chars: Vec<char>) -> Self {
        let vocab_size = chars.len();
        let char_to_idx: HashMap<char, usize> =
            chars.iter().enumerate().map(|(i, &c)| (c, i)).collect();
        let idx_to_char: HashMap<usize, char> =
            chars.iter().enumerate().map(|(i, &c)| (i, c)).collect();
        Self { char_to_idx, idx_to_char, vocab_size }
    }

    /// Build a vocabulary from the distinct characters of `text`, indexed in
    /// sorted (Unicode scalar value) order.
    pub fn from_text(text: &str) -> Self {
        // BTreeSet yields the characters already deduplicated AND sorted,
        // replacing the HashSet-collect-then-sort round trip.
        let chars: Vec<char> = text
            .chars()
            .collect::<std::collections::BTreeSet<_>>()
            .into_iter()
            .collect();
        Self::from_unique_chars(chars)
    }

    /// Build a vocabulary from an explicit character list, preserving the
    /// given order. Duplicates are ignored (first occurrence wins) so that
    /// `vocab_size` always equals the number of mapped characters.
    pub fn from_chars(chars: &[char]) -> Self {
        let mut seen = std::collections::HashSet::new();
        let unique: Vec<char> = chars.iter().copied().filter(|&c| seen.insert(c)).collect();
        Self::from_unique_chars(unique)
    }

    /// Index for `ch`, or `None` if it is not in the vocabulary.
    pub fn char_to_index(&self, ch: char) -> Option<usize> {
        self.char_to_idx.get(&ch).copied()
    }

    /// Character at `idx`, or `None` if `idx >= vocab_size`.
    pub fn index_to_char(&self, idx: usize) -> Option<char> {
        self.idx_to_char.get(&idx).copied()
    }

    /// Number of distinct characters.
    pub fn size(&self) -> usize {
        self.vocab_size
    }

    /// Whether `ch` is part of the vocabulary.
    pub fn contains(&self, ch: char) -> bool {
        self.char_to_idx.contains_key(&ch)
    }

    /// All characters in index order (index 0 first).
    pub fn chars(&self) -> Vec<char> {
        let mut pairs: Vec<_> = self.idx_to_char.iter().collect();
        pairs.sort_by_key(|(idx, _)| *idx);
        pairs.into_iter().map(|(_, &ch)| ch).collect()
    }

    /// Encode `text` to indices. Characters not in the vocabulary are
    /// silently dropped.
    pub fn encode(&self, text: &str) -> Vec<usize> {
        text.chars()
            .filter_map(|ch| self.char_to_index(ch))
            .collect()
    }

    /// Decode indices back to a string. Out-of-range indices are silently
    /// dropped.
    pub fn decode(&self, indices: &[usize]) -> String {
        indices.iter()
            .filter_map(|&idx| self.index_to_char(idx))
            .collect()
    }
}
/// Gradient container produced by `CharacterEmbedding::backward`.
#[derive(Clone, Debug)]
pub struct EmbeddingGradients {
/// Gradient of the loss w.r.t. the embedding weight matrix; same shape as
/// `CharacterEmbedding::weight` (vocab_size x embed_dim).
pub weight: Array2<f64>,
}
/// Learned character-embedding table: row `i` of `weight` is the embedding
/// vector for vocabulary index `i`.
#[derive(Clone, Debug)]
pub struct CharacterEmbedding {
/// Embedding matrix of shape (vocab_size, embed_dim).
pub weight: Array2<f64>, vocab_size: usize, // number of rows / distinct characters
embed_dim: usize, // length of each embedding vector (columns)
input_cache: Option<Vec<usize>>, // indices from the last forward(); consumed by backward()
}
impl CharacterEmbedding {
    /// Create an embedding with weights drawn uniformly from
    /// `[-1/sqrt(embed_dim), 1/sqrt(embed_dim)]`.
    pub fn new(vocab_size: usize, embed_dim: usize) -> Self {
        // Scaling by 1/sqrt(embed_dim) keeps initial embedding norms roughly
        // independent of the embedding width.
        let scale = (1.0 / embed_dim as f64).sqrt();
        let weight = Array2::random((vocab_size, embed_dim), Uniform::new(-scale, scale));
        Self { weight, vocab_size, embed_dim, input_cache: None }
    }

    /// Create an all-zero embedding (e.g. for accumulators or tests).
    pub fn new_zeros(vocab_size: usize, embed_dim: usize) -> Self {
        Self {
            weight: Array2::zeros((vocab_size, embed_dim)),
            vocab_size,
            embed_dim,
            input_cache: None,
        }
    }

    /// Wrap an existing weight matrix; dimensions are inferred from its shape.
    pub fn from_weights(weight: Array2<f64>) -> Self {
        let (vocab_size, embed_dim) = weight.dim();
        Self { weight, vocab_size, embed_dim, input_cache: None }
    }

    /// Embedding vector length.
    pub fn embed_dim(&self) -> usize {
        self.embed_dim
    }

    /// Number of rows (distinct characters).
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Return an owned copy of the embedding row for `char_idx`.
    ///
    /// # Panics
    /// Panics if `char_idx >= vocab_size`.
    pub fn lookup(&self, char_idx: usize) -> Array1<f64> {
        assert!(char_idx < self.vocab_size, "Index {} out of vocabulary size {}", char_idx, self.vocab_size);
        self.weight.row(char_idx).to_owned()
    }

    /// Embed a sequence of indices; returns a (seq_len, embed_dim) matrix and
    /// caches the indices for the subsequent `backward` call.
    ///
    /// # Panics
    /// Panics if any index is `>= vocab_size`.
    pub fn forward(&mut self, char_indices: &[usize]) -> Array2<f64> {
        self.input_cache = Some(char_indices.to_vec());
        let seq_len = char_indices.len();
        let mut output = Array2::zeros((seq_len, self.embed_dim));
        for (i, &idx) in char_indices.iter().enumerate() {
            assert!(idx < self.vocab_size, "Index {} out of vocabulary size {}", idx, self.vocab_size);
            output.row_mut(i).assign(&self.weight.row(idx));
        }
        output
    }

    /// Scatter-accumulate `grad_output` rows into a weight gradient: row `i`
    /// of `grad_output` is added to the gradient row of the `i`-th cached
    /// input index (repeated indices accumulate).
    ///
    /// # Panics
    /// Panics if `forward` has not been called, or if `grad_output` does not
    /// have one row per cached input index.
    pub fn backward(&self, grad_output: &Array2<f64>) -> EmbeddingGradients {
        let indices = self.input_cache.as_ref().expect("No cached input for backward pass");
        // Fail fast with a clear message instead of an index panic mid-loop.
        assert_eq!(
            grad_output.nrows(),
            indices.len(),
            "grad_output has {} rows but {} indices were cached by forward()",
            grad_output.nrows(),
            indices.len()
        );
        let mut weight_grad = Array2::zeros((self.vocab_size, self.embed_dim));
        for (i, &idx) in indices.iter().enumerate() {
            // Row-wise accumulation instead of per-element double indexing.
            let mut row = weight_grad.row_mut(idx);
            row += &grad_output.row(i);
        }
        EmbeddingGradients { weight: weight_grad }
    }

    /// Apply an optimizer step to the weight matrix. `prefix` namespaces the
    /// optimizer's per-parameter state key.
    pub fn update_parameters<O: Optimizer>(&mut self, gradients: &EmbeddingGradients, optimizer: &mut O, prefix: &str) {
        optimizer.update(&format!("{}_weight", prefix), &mut self.weight, &gradients.weight);
    }

    /// Total number of trainable scalars (vocab_size * embed_dim).
    pub fn num_parameters(&self) -> usize {
        self.weight.len()
    }
}
/// Draw an index from the categorical distribution `softmax(logits / temperature)`.
///
/// Lower temperatures sharpen the distribution; higher ones flatten it.
///
/// # Panics
/// Panics if `temperature <= 0`.
pub fn sample_with_temperature(logits: &Array1<f64>, temperature: f64) -> usize {
    assert!(temperature > 0.0, "Temperature must be positive");
    // Temperature-scale, then compute softmax weights with the max
    // subtracted for numerical stability.
    let scaled: Vec<f64> = logits.iter().map(|&v| v / temperature).collect();
    let max_val = scaled.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let weights: Vec<f64> = scaled.iter().map(|&v| (v - max_val).exp()).collect();
    let total: f64 = weights.iter().sum();
    // Inverse-CDF sampling: walk the distribution until the draw is used up.
    let mut remaining = rand::random::<f64>();
    for (idx, &w) in weights.iter().enumerate() {
        remaining -= w / total;
        if remaining <= 0.0 {
            return idx;
        }
    }
    // Floating-point round-off can leave a tiny positive remainder;
    // fall back to the final index.
    weights.len() - 1
}
/// Top-k sampling: restrict to the `k` highest logits, then sample from
/// their temperature-scaled softmax. `k` is clamped to the number of logits.
///
/// # Panics
/// Panics if `k == 0` or `temperature <= 0`.
pub fn sample_top_k(logits: &Array1<f64>, k: usize, temperature: f64) -> usize {
    assert!(k > 0, "k must be positive");
    assert!(temperature > 0.0, "Temperature must be positive");
    let k = k.min(logits.len());
    let mut indexed: Vec<(usize, f64)> = logits.iter().enumerate().map(|(i, &v)| (i, v)).collect();
    // total_cmp is a total order on f64, so sorting cannot panic on NaN
    // logits (partial_cmp().unwrap() would).
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    let top_k: Vec<(usize, f64)> = indexed.into_iter().take(k).collect();
    // Numerically stable softmax over the retained logits only.
    let scaled: Vec<f64> = top_k.iter().map(|(_, v)| v / temperature).collect();
    let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exp_vals: Vec<f64> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
    let sum: f64 = exp_vals.iter().sum();
    let probs: Vec<f64> = exp_vals.iter().map(|&x| x / sum).collect();
    // Inverse-CDF sampling over the renormalized top-k distribution.
    let mut rng_val = rand::random::<f64>();
    for (i, &prob) in probs.iter().enumerate() {
        rng_val -= prob;
        if rng_val <= 0.0 {
            return top_k[i].0;
        }
    }
    // Round-off fallback: the least-probable retained index.
    top_k[k - 1].0
}
/// Nucleus (top-p) sampling: sample from the smallest set of indices whose
/// cumulative probability (after temperature-scaled softmax) reaches `p`.
///
/// # Panics
/// Panics if `p` is outside `(0, 1]` or `temperature <= 0`.
pub fn sample_nucleus(logits: &Array1<f64>, p: f64, temperature: f64) -> usize {
    assert!(p > 0.0 && p <= 1.0, "p must be in (0, 1]");
    assert!(temperature > 0.0, "Temperature must be positive");
    // Numerically stable softmax of the temperature-scaled logits.
    let scaled: Vec<f64> = logits.iter().map(|&x| x / temperature).collect();
    let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exp_vals: Vec<f64> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
    let sum: f64 = exp_vals.iter().sum();
    let probs: Vec<f64> = exp_vals.iter().map(|&x| x / sum).collect();
    let mut indexed: Vec<(usize, f64)> = probs.iter().enumerate().map(|(i, &v)| (i, v)).collect();
    // total_cmp is a total order on f64, so sorting cannot panic on NaN
    // (partial_cmp().unwrap() would).
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    // Keep the highest-probability entries until their mass reaches p.
    let mut cumulative = 0.0;
    let mut nucleus: Vec<(usize, f64)> = Vec::new();
    for (idx, prob) in indexed {
        cumulative += prob;
        nucleus.push((idx, prob));
        if cumulative >= p {
            break;
        }
    }
    // Renormalize within the nucleus and inverse-CDF sample (no need to
    // materialize a separate renormalized-probability vector).
    let nucleus_sum: f64 = nucleus.iter().map(|(_, prob)| prob).sum();
    let mut rng_val = rand::random::<f64>();
    for &(idx, prob) in &nucleus {
        rng_val -= prob / nucleus_sum;
        if rng_val <= 0.0 {
            return idx;
        }
    }
    // Round-off fallback: the least-probable nucleus member.
    nucleus.last().map(|(idx, _)| *idx).unwrap_or(0)
}
/// Index of the largest logit; returns 0 for an empty array.
///
/// Uses `f64::total_cmp` (a total order), so the comparison cannot panic on
/// NaN inputs the way `partial_cmp().unwrap()` did.
pub fn argmax(logits: &Array1<f64>) -> usize {
    logits.iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(idx, _)| idx)
        .unwrap_or(0)
}
/// Numerically stable softmax over a 1-D logit vector.
pub fn softmax(logits: &Array1<f64>) -> Array1<f64> {
    // Subtracting the maximum before exponentiating prevents overflow for
    // large logits without changing the result.
    let max_val = logits.fold(f64::NEG_INFINITY, |acc, &v| acc.max(v));
    let exps = logits.mapv(|v| (v - max_val).exp());
    let total = exps.sum();
    exps / total
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    #[test]
    fn test_vocabulary_from_text() {
        // "hello" contains exactly four distinct characters: e, h, l, o.
        let vocab = TextVocabulary::from_text("hello");
        assert_eq!(vocab.size(), 4);
        assert!(vocab.contains('h'));
        assert!(vocab.contains('l'));
        assert!(!vocab.contains('x'));
    }

    #[test]
    fn test_vocabulary_encode_decode() {
        // Encoding then decoding in-vocabulary text is the identity.
        let vocab = TextVocabulary::from_text("abc");
        let roundtrip = vocab.decode(&vocab.encode("cab"));
        assert_eq!(roundtrip, "cab");
    }

    #[test]
    fn test_embedding_forward() {
        // One embedding row per input index.
        let mut emb = CharacterEmbedding::new(10, 8);
        let output = emb.forward(&[0, 3, 5]);
        assert_eq!(output.shape(), &[3, 8]);
    }

    #[test]
    fn test_embedding_lookup() {
        let emb = CharacterEmbedding::new(10, 8);
        assert_eq!(emb.lookup(5).len(), 8);
    }

    #[test]
    fn test_sample_with_temperature() {
        // Any sampled index must be in range.
        let logits = arr1(&[1.0, 2.0, 3.0]);
        assert!(sample_with_temperature(&logits, 1.0) < 3);
    }

    #[test]
    fn test_sample_top_k() {
        // With k = 2, only the two largest logits (indices 1 and 2) qualify.
        let logits = arr1(&[1.0, 5.0, 2.0, 0.5]);
        let idx = sample_top_k(&logits, 2, 1.0);
        assert!(idx == 1 || idx == 2);
    }

    #[test]
    fn test_sample_nucleus() {
        // Index 1 dominates the distribution, so it must fill the 0.9 nucleus.
        let logits = arr1(&[0.0, 10.0, 0.0]);
        assert_eq!(sample_nucleus(&logits, 0.9, 1.0), 1);
    }

    #[test]
    fn test_argmax() {
        let logits = arr1(&[1.0, 5.0, 2.0]);
        assert_eq!(argmax(&logits), 1);
    }

    #[test]
    fn test_softmax() {
        // Softmax output is a probability distribution: sums to 1.
        let probs = softmax(&arr1(&[1.0, 2.0, 3.0]));
        assert!((probs.sum() - 1.0).abs() < 1e-6);
    }
}