//! Documentation: high-level LLM wrapper and the token-by-token inference iterator.
use std::{
    cmp,
    io::{Read, Seek},
};

use crate::{
    errors::PllmError, transformer::Transformer, util::FloatVec, Config, Tokenizer, Weights,
};

/// Cap on the number of inference steps; `LLM::inference` uses the smaller of
/// this value and `config.seq_len`.
const DEFAULT_STEPS: u32 = 256;

/// Bundles everything needed to run inference: the model configuration, the
/// tokenizer, the weights and the transformer built from the configuration.
pub struct LLM {
    /// Model configuration; a clone of it is handed to `Transformer::new`.
    config: Config,
    /// Encodes the prompt and maps token ids back to their text.
    tokenizer: Tokenizer,
    /// Weights passed by reference to every `Transformer::run` call.
    weights: Weights,
    /// Transformer instance that produces logits for each step.
    transformer: Transformer,
}

impl LLM {
    /// Creates an `LLM` from an already-loaded configuration, tokenizer and
    /// weights, building the transformer from a clone of `config`.
    ///
    /// As a side effect this tries to configure rayon's global thread pool
    /// with 4 threads. The global pool can only be initialised once per
    /// process, so when it has already been set up (for example by an earlier
    /// `LLM::new` call) the existing pool is kept instead of panicking.
    pub fn new(config: Config, tokenizer: Tokenizer, weights: Weights) -> Self {
        let transformer = Transformer::new(config.clone());

        // `build_global` returns an error when the global pool already exists;
        // that is not fatal for us, so the result is deliberately ignored.
        let _ = rayon::ThreadPoolBuilder::new()
            .num_threads(4)
            .build_global();

        Self {
            config,
            tokenizer,
            weights,
            transformer,
        }
    }

    /// Consumes the model and returns an iterator that yields the generated
    /// text piece by piece.
    ///
    /// The number of steps is the smaller of `config.seq_len` and
    /// `DEFAULT_STEPS`. `temperature` is forwarded unchanged to the iterator.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Tokenizer::bpe_encode`, and returns
    /// `PllmError::Other("empty prompt")` when encoding produces no tokens.
    pub fn inference(
        self,
        prompt: String,
        temperature: f32,
    ) -> Result<InferenceIterator, PllmError> {
        let steps = cmp::min(self.config.seq_len, DEFAULT_STEPS);

        let prompt_tokens = self.tokenizer.bpe_encode(prompt)?;
        if prompt_tokens.is_empty() {
            // `InferenceIterator::new` indexes the first token, so an empty
            // prompt must be rejected here.
            return Err(PllmError::Other("empty prompt".to_string()));
        }

        Ok(InferenceIterator::new(
            self,
            prompt_tokens,
            steps,
            temperature,
        ))
    }
}

/// Iterator state for one generation run; each `next` call advances the
/// transformer by one position and yields the text of the chosen token.
pub struct InferenceIterator {
    /// The model being driven (owned for the lifetime of the iterator).
    llm: LLM,
    /// Encoded prompt; while `pos + 1` is inside it, the next token is taken
    /// from here instead of being chosen from the logits.
    prompt_tokens: Vec<u32>,
    /// Iteration stops once `pos` reaches this value.
    steps: u32,
    /// `0.0` selects `arg_max`; any other value divides the logits before
    /// `soft_max` and `sample`.
    temperature: f32,

    /// Token fed to the transformer on the next call.
    next_token: u32,
    /// Position passed to the transformer; incremented after every step.
    pos: u32,
}

impl InferenceIterator {
    pub fn new(
        llm: LLM,
        prompt_tokens: Vec<u32>,
        steps: u32,
        temperature: f32,
    ) -> InferenceIterator {
        let next_token = prompt_tokens[0];
        InferenceIterator {
            llm,
            prompt_tokens,
            steps,
            temperature,
            next_token,
            pos: 0,
        }
    }
}

impl Iterator for InferenceIterator {
    type Item = Result<String, PllmError>;

    /// Runs one transformer step and yields the text of the token selected
    /// for the following position.
    ///
    /// Returns `None` once `pos` reaches `steps`, or once the end-of-sequence
    /// token has been selected (both on the call that selects it and on every
    /// later call). Errors from `Transformer::run` and unknown token ids are
    /// yielded as `Some(Err(_))`.
    fn next(&mut self) -> Option<Self::Item> {
        let eos_token = self.llm.tokenizer.eos_token;

        // `pos != 0` lets the very first input token through even if it
        // happens to equal the EOS id.
        if (self.pos != 0 && self.next_token == eos_token) || self.pos >= self.steps {
            return None;
        }

        let logits = match self
            .llm
            .transformer
            .run(self.next_token, self.pos, &self.llm.weights)
        {
            Ok(l) => l,
            Err(e) => return Some(Err(e)),
        };

        let following = self.pos as usize + 1;
        let next_token = if following < self.prompt_tokens.len() {
            // Still inside the prompt: force the next prompt token.
            self.prompt_tokens[following]
        } else if self.temperature == 0.0 {
            logits.arg_max()
        } else {
            // Copy the scalar so the closure does not need to borrow `self`.
            let temperature = self.temperature;
            logits.iter_mut().for_each(|x| *x /= temperature);
            logits.soft_max();
            logits.sample()
        };

        let token_str = match self.llm.tokenizer.get_token(next_token as usize) {
            // NOTE(review): the char literal was empty in the source (invalid
            // Rust); restored as U+2581, the SentencePiece whitespace marker.
            // Confirm against the tokenizer vocabulary.
            Some(t) => t.replace('\u{2581}', " ").replace("<0x0A>", "\n"),
            None => {
                // Report the id that was actually looked up, not the previous
                // input token.
                return Some(Err(PllmError::Other(format!(
                    "token not found, idx={}",
                    next_token
                ))));
            }
        };

        self.pos += 1;
        self.next_token = next_token;

        if next_token == eos_token {
            None
        } else {
            Some(Ok(token_str))
        }
    }
}

#[cfg(test)]
mod tests {
    use std::{
        fs::File,
        io::{self, BufReader, Write},
        time::Instant,
    };

    use crate::{
        gguf::GgufFile,
        llm::{Weights, LLM},
    };

    use super::{Config, InferenceIterator, Tokenizer};

    /// Drains `iterator`, printing each piece as it arrives, then prints the
    /// throughput.
    ///
    /// The rate is computed from the number of pieces actually yielded and
    /// from fractional seconds, so a run shorter than one second no longer
    /// triggers an integer division by zero.
    fn print_and_time(iterator: InferenceIterator) {
        let start = Instant::now();
        let mut emitted: u64 = 0;
        for piece in iterator {
            print!("{}", piece.unwrap());
            io::stdout().flush().unwrap();
            emitted += 1;
        }
        let elapsed = start.elapsed().as_secs_f64();
        println!("\ntoken/s: {:.2}\n", emitted as f64 / elapsed);
    }

    /// End-to-end run from the `.bin` model format.
    /// Requires `testdata/stories15M.bin` and `testdata/tokenizer.bin`.
    #[test]
    fn test_reader() {
        let f = File::open("testdata/stories15M.bin").unwrap();
        let mut reader = BufReader::new(f);
        let config = Config::from_reader(&mut reader).unwrap();
        println!("{:?}", config);

        let mut weights = Weights::new(config.clone());
        weights.load_data(&mut reader).unwrap();

        let tokenizer_file = File::open("testdata/tokenizer.bin").unwrap();
        let tokenizer_reader = BufReader::new(tokenizer_file);

        let tokenizer =
            Tokenizer::from_reader(config.vocab_size as usize, tokenizer_reader).unwrap();
        println!("{}", tokenizer.max_token_length);

        let iterator = LLM::new(config, tokenizer, weights)
            .inference("a dog".to_string(), 0.8)
            .unwrap();
        print_and_time(iterator);
    }

    /// End-to-end run from a GGUF file. Requires `testdata/gemma2b`.
    #[test]
    fn test_gguf() {
        let f = File::open("testdata/gemma2b").unwrap();
        let reader = BufReader::new(f);
        let mut gf = GgufFile::from_reader(reader).unwrap();

        let config = Config::from_gguf(&gf).unwrap();
        println!("{:?}", config.clone());

        let tokenizer = Tokenizer::from_gguf(&gf).unwrap();
        println!("{}", tokenizer.max_token_length);

        let mut weights = Weights::new(config.clone());
        weights.load_from_gguf(&mut gf, config.clone()).unwrap();

        let iterator = LLM::new(config, tokenizer, weights)
            .inference("a dog".to_string(), 0.8)
            .unwrap();
        print_and_time(iterator);
    }
}