// eld_llm 0.0.1
//
// An LLM built from scratch in Rust
use crate::tokenizers::benchmark::dev::traits::TokenizerTrainingMetrics;
use crate::tokenizers::special_tokens::UNKNOWN;
use crate::tokenizers::traits::{Tokenizer, TokenizerError, TokenIdType};
use crate::utils::corpus_tracker::CorpusTracker; 



use std::collections::HashMap;
use std::io::BufRead;
use regex::Regex;
use serde_json;


// Time measurement for training metrics.
use std::time::Instant;

/// NaiveTokenizer: a word-level tokenizer whose vocabulary is every whitespace-separated
/// word seen during training, plus a single-space token and an UNKNOWN token.
///
/// Derives serde (de)serialization so the whole tokenizer can be saved to / loaded from JSON.
#[derive(Debug, serde_derive::Serialize, serde_derive::Deserialize)]
// Replace serde's inferred bounds so `T` only needs Serialize / DeserializeOwned.
#[serde(bound(serialize = "T: serde::Serialize", deserialize = "T: serde::de::DeserializeOwned"))]
pub struct NaiveTokenizer<T: TokenIdType = u64> {
    /// Token string -> token id.
    pub string_to_id: HashMap<String, T>,
    /// Token id -> token string (reverse mapping of `string_to_id`).
    pub id_to_string: HashMap<T, String>,
    /// Training corpora record; `train` adds the corpus name and the data file size in bytes.
    pub corpus: CorpusTracker,
    /// Id of the UNKNOWN special token; `None` until `train` has run.
    pub unknown_token_id: Option<T>,
    /// Path of the saved tokenizer file; only used by `size()` to report its byte size.
    pub tokenizer_path: Option<String>,
}

impl<T: TokenIdType> NaiveTokenizer<T> {
    /// Creates an empty, untrained tokenizer.
    ///
    /// `tokenizer_path` is stored as-is; it names the file whose byte size `size()` reports.
    pub fn new(tokenizer_path: Option<String>) -> Self {
        let string_to_id = HashMap::new();
        let id_to_string = HashMap::new();
        Self {
            tokenizer_path,
            unknown_token_id: None,
            corpus: CorpusTracker::new(),
            id_to_string,
            string_to_id,
        }
    }

    /// Seeds the vocabulary with the single-space token at id 0.
    ///
    /// Both maps are written unconditionally, so any existing entry for id 0 or for `" "`
    /// is overwritten.
    fn init(&mut self) {
        const SPACE: &str = " ";
        let space_id = T::from_usize(0);
        self.id_to_string.insert(space_id, SPACE.to_string());
        self.string_to_id.insert(SPACE.to_string(), space_id);
    }
}

impl<T> Tokenizer for NaiveTokenizer<T>
where
    T: TokenIdType + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + Eq + std::hash::Hash + Clone,
{
    type TokenId = T;

    fn name(&self) -> &'static str {
        "NaiveTokenizer"
    }

    fn corpus(&self) -> &CorpusTracker {
        &self.corpus
    }

    fn corpus_mut(&mut self) -> &mut CorpusTracker {
        &mut self.corpus
    }

    fn vocab_size(&self) -> usize {
        self.string_to_id.len()
    }

    fn train(&mut self, data_path: &str, corpus_name: &str) -> Result<TokenizerTrainingMetrics, TokenizerError> {
        let start_time = Instant::now();

        // Add default tokens 
        self.init();

        if !std::path::Path::new(data_path).is_file() {
            return Err(format!("'{}' is not a valid file", data_path));
        }

        let file = std::fs::File::open(data_path)
            .map_err(|e| format!("Failed to open '{}': {}", data_path, e))?;
        let reader = std::io::BufReader::new(file);


        // Update corpus tracker with the size of the training data
        let size = std::fs::metadata(data_path)
            .map_err(|e| format!("Failed to get metadata for '{}': {}", data_path, e))?
            .len();
        self.corpus.add(corpus_name.to_string(), size);

        // Read the file line by line and split into tokens by whitespace and add new tokens to the vocab
        for line in reader.lines() {
            let line = line.map_err(|e| format!("Failed to read line: {}", e))?;
            let tokens: Vec<&str> = line.split_whitespace().collect();
            for token in tokens {  
                if self.string_to_id.contains_key(token) {
                    continue;
                }
                let id = T::from_usize(self.string_to_id.len());
                self.string_to_id.insert(token.to_string(), id);
                self.id_to_string.insert(id, token.to_string());
            }
        }

        // Add the special tokens at the end
        let unkown_token_id = T::from_usize(self.string_to_id.len());
        self.id_to_string.insert(unkown_token_id, UNKNOWN.to_owned());
        self.string_to_id.insert(UNKNOWN.to_owned(), unkown_token_id);
        self.unknown_token_id = Some(unkown_token_id);

        // Return training metrics
        Ok(TokenizerTrainingMetrics {
            training_time_sec: start_time.elapsed().as_secs(),
        })
    }
  
    fn encode(&self, text: &str) -> Vec<Self::TokenId> { 
        // must have an unknown token id to encode. 
        assert!(self.unknown_token_id.is_some(), "Tokenizer must be trained before encoding");

        let mut res: Vec<Self::TokenId> = vec![]; 

        // split text by withespace and include the whitespaces in the vector using regex
        let re = Regex::new(r"\S+|\s+").unwrap();
        let text_tokens: Vec<&str> = re.find_iter(text).map(|mat| mat.as_str()).collect();

        // tokenize the text and convert to token ids, using the UNKNOWN token id for any token not in the vocab
        for token in text_tokens {
            if let Some(&id) = self.string_to_id.get(token.to_lowercase().as_str()) {
                res.push(id);
            } else {
                // If token not found, push the UNKNOWN token id
                res.push(self.unknown_token_id.unwrap());
            }
        }
        res
    }

    fn decode(&self, tokens: &[Self::TokenId]) -> String {
        let mut res = String::new();
        for &token_id in tokens {
            if let Some(token) = self.id_to_string.get(&token_id) {
                res.push_str(token);
            } else {
                // If token id not found, push the UNKNOWN token
                res.push_str(UNKNOWN);
            }
        }
        res
    }

    fn save(&self, path: &str) -> Result<(), TokenizerError> {
        let json = serde_json::to_string(self)
            .map_err(|e| format!("Failed to serialize tokenizer: {}", e))?;
        std::fs::write(path, json)
            .map_err(|e| format!("Failed to write tokenizer to '{}': {}", path, e))?;
        Ok(())
    }

    fn load(&mut self, path: &str) -> Result<(), TokenizerError> {
        let json = std::fs::read_to_string(path)
            .map_err(|e| format!("Failed to read tokenizer from '{}': {}", path, e))?;
        *self = serde_json::from_str(&json)
            .map_err(|e| format!("Failed to deserialize tokenizer: {}", e))?;
        Ok(())
    }

    fn vocab_tokens(&self) -> Vec<String> {
        self.string_to_id.keys().cloned().collect()
    }

    fn size(&self) -> usize {
        // calculate the size of the saved tokenizer in bytes
        if let Some(ref path) = self.tokenizer_path {
            std::fs::metadata(path)
                .map(|meta| meta.len() as usize)
                .unwrap_or(0)
        } else {
            0
        }
    }
}