interpolize 5.0.1

a rust program that scrapes discord, learns how your friends talk, and generates new messages in their collective voice. yes, this is what we're doing with our lives.
use std::io::{self, Write};
use crossterm::{
    execute,
    terminal::{Clear, ClearType},
    style::{Color, Print, ResetColor, SetForegroundColor},
    cursor,
    event::{self, Event, KeyCode, KeyModifiers, KeyEventKind},
    terminal::{enable_raw_mode, disable_raw_mode},
};
use crate::config::Config;
use crate::embed::Embeddings;
use crate::interpolate::Interpolator;
use crate::markov::Markov;
use crate::storage::Message;

/// Interactive terminal front end that turns typed queries into generated replies.
///
/// Holds data borrowed for `'a` (the configuration, plus an [`Interpolator`] built in
/// `new` from the borrowed embeddings and messages) and owns the [`Markov`] model.
pub struct Shell<'a> {
    /// Shell configuration; the header reads its channel list and normalized weights.
    config: &'a Config,
    /// Builds the `(seed, style_vec)` pair used to drive generation for each query.
    interpolator: Interpolator<'a>,
    /// Model that produces the reply tokens.
    markov: Markov,
}

impl<'a> Shell<'a> {
    /// Creates a shell over borrowed configuration, embeddings and message history.
    ///
    /// The [`Interpolator`] is built from `config`, `embeddings` and `messages`;
    /// `markov` is taken by value and owned by the shell.
    pub fn new(
        config: &'a Config,
        embeddings: &'a Embeddings,
        messages: &'a [Message],
        markov: Markov,
    ) -> Self {
        let interpolator = Interpolator::new(config, embeddings, messages);
        Self { config, interpolator, markov }
    }

    /// Runs the interactive read/respond loop until the user quits.
    ///
    /// The terminal is switched to raw mode so keystrokes can be echoed and edited
    /// by hand. The session ends on Ctrl+C, or when the trimmed input is `exit` or
    /// `quit`; empty lines are ignored.
    ///
    /// # Errors
    ///
    /// Returns any terminal I/O error. Raw mode is restored on every exit path,
    /// including early returns via `?` and panics, so the user's terminal is never
    /// left in raw mode.
    pub fn run(&self) -> anyhow::Result<()> {
        // Puts the terminal back into cooked mode when dropped, so an error
        // propagated with `?` (or a panic) cannot leave raw mode switched on.
        struct RawModeGuard;

        impl RawModeGuard {
            fn enable() -> io::Result<Self> {
                enable_raw_mode()?;
                Ok(Self)
            }

            // Leaves raw mode explicitly so that a failure is reported to the
            // caller instead of being swallowed by `Drop`.
            fn release(self) -> io::Result<()> {
                std::mem::forget(self);
                disable_raw_mode()
            }
        }

        impl Drop for RawModeGuard {
            fn drop(&mut self) {
                // Best effort: we are already on an error/unwind path here.
                let _ = disable_raw_mode();
            }
        }

        let mut stdout = io::stdout();
        let raw_mode = RawModeGuard::enable()?;

        self.print_header(&mut stdout)?;

        let mut input = String::new();

        loop {
            self.print_prompt(&mut stdout)?;

            // `false` means the user pressed Ctrl+C instead of submitting a line.
            let submitted = Self::read_line(&mut stdout, &mut input)?;
            let query = input.trim();

            if !submitted || query == "exit" || query == "quit" {
                raw_mode.release()?;
                execute!(stdout, Print("\n"))?;
                return Ok(());
            }
            if query.is_empty() {
                continue;
            }

            self.respond(query, &mut stdout)?;
        }
    }

    /// Reads one line of input from raw-mode key events, echoing it as it is typed.
    ///
    /// `input` is cleared first and then filled with the typed characters.
    /// Returns `Ok(true)` when the line is submitted with Enter, and `Ok(false)`
    /// when the user presses Ctrl+C.
    fn read_line(stdout: &mut io::Stdout, input: &mut String) -> anyhow::Result<bool> {
        input.clear();

        loop {
            let key = match event::read()? {
                Event::Key(key) if key.kind == KeyEventKind::Press => key,
                // Ignore non-key events and key releases/repeats.
                _ => continue,
            };

            match (key.code, key.modifiers) {
                (KeyCode::Char('c'), KeyModifiers::CONTROL) => return Ok(false),
                (KeyCode::Enter, _) => {
                    // Raw mode disables output newline translation, so a bare "\n"
                    // would not return the cursor to column 0.
                    execute!(stdout, Print("\r\n"))?;
                    return Ok(true);
                }
                (KeyCode::Backspace, _) => {
                    if input.pop().is_some() {
                        // NOTE(review): assumes the removed char is one column wide;
                        // wide (e.g. CJK/emoji) characters will leave a stray cell.
                        execute!(
                            stdout,
                            cursor::MoveLeft(1),
                            Print(" "),
                            cursor::MoveLeft(1)
                        )?;
                    }
                }
                // Other Ctrl+<key> chords are control input, not text, so don't
                // insert their bare letter. Only an exact CONTROL modifier matches,
                // so chords that also carry ALT (how Windows reports AltGr) still
                // fall through and insert their character.
                (KeyCode::Char(_), KeyModifiers::CONTROL) => {}
                (KeyCode::Char(c), _) => {
                    input.push(c);
                    execute!(
                        stdout,
                        SetForegroundColor(Color::White),
                        Print(c),
                        ResetColor
                    )?;
                }
                _ => {}
            }
            stdout.flush()?;
        }
    }

    /// Generates a reply for `query` without any terminal I/O.
    ///
    /// Builds a seed and style vector via the interpolator, generates up to 64
    /// tokens at temperature 0.8, and returns them joined by single spaces.
    /// Currently never returns `Err`; the `Result` leaves room for fallible
    /// generation.
    pub fn answer(&self, query: &str) -> anyhow::Result<String> {
        let (seed, style_vec) = self.interpolator.build_seed(query);

        let response = self.markov.generate(
            &seed,
            64,
            Some(&style_vec),
            None,
            0.8,
        );

        Ok(response.join(" "))
    }

    /// Generates a reply for `query` and streams it to `stdout` token by token.
    ///
    /// Uses the same generation parameters as [`Shell::answer`] (keep the two in
    /// sync). Each token is coloured by [`Shell::token_color`], and a 30 ms pause
    /// after each one gives a typing effect.
    fn respond(&self, query: &str, stdout: &mut io::Stdout) -> anyhow::Result<()> {
        let (seed, style_vec) = self.interpolator.build_seed(query);

        let response = self.markov.generate(
            &seed,
            64,
            Some(&style_vec),
            None,
            0.8,
        );

        // NOTE(review): this prints an empty string, so only the colour codes are
        // emitted; a reply-prefix glyph may have been lost here — confirm.
        execute!(stdout, SetForegroundColor(Color::DarkGrey), Print(""), ResetColor)?;

        for (i, tok) in response.iter().enumerate() {
            // Separate tokens with a single space, with none before the first.
            if i > 0 {
                execute!(stdout, Print(" "))?;
            }
            let color = self.token_color(tok);
            execute!(stdout, SetForegroundColor(color), Print(tok), ResetColor)?;
            stdout.flush()?;
            std::thread::sleep(std::time::Duration::from_millis(30));
        }

        // "\r\n" because we are still in raw mode (no newline translation).
        execute!(stdout, Print("\r\n\r\n"))?;
        Ok(())
    }

    /// Picks a display colour for `tok` from the wrapping sum of its bytes.
    ///
    /// The mapping is deterministic, so a given token is always drawn in the same
    /// colour.
    fn token_color(&self, tok: &str) -> Color {
        let sum = tok
            .bytes()
            .fold(0u32, |acc, b| acc.wrapping_add(u32::from(b)));
        match sum % 4 {
            0 => Color::Cyan,
            1 => Color::Green,
            2 => Color::Magenta,
            _ => Color::White,
        }
    }

    /// Clears the screen and prints the banner plus each configured channel with
    /// its normalized weight as a percentage.
    ///
    /// Called while raw mode is active, so line breaks are written as "\r\n".
    fn print_header(&self, stdout: &mut io::Stdout) -> anyhow::Result<()> {
        execute!(stdout, Clear(ClearType::All), cursor::MoveTo(0, 0))?;

        execute!(
            stdout,
            SetForegroundColor(Color::Magenta),
            Print(" ░░ interpolize ░░\r\n"),
            ResetColor,
        )?;

        execute!(stdout, SetForegroundColor(Color::DarkGrey), Print(" channels: "), ResetColor)?;

        // `zip` stops at the shorter side if the weight list and channel list differ.
        let weights = self.config.normalized_weights();
        for (ch, w) in self.config.channels.iter().zip(weights.iter()) {
            execute!(
                stdout,
                SetForegroundColor(Color::Cyan),
                Print(format!("{}({:.0}%) ", ch.name, w * 100.0)),
                ResetColor,
            )?;
        }

        execute!(stdout, Print("\r\n\r\n"))?;
        stdout.flush()?;
        Ok(())
    }

    /// Writes the input prompt and flushes it.
    ///
    /// NOTE(review): the prompt text is an empty string, so nothing visible is
    /// printed; a prompt glyph may have been lost — confirm.
    fn print_prompt(&self, stdout: &mut io::Stdout) -> anyhow::Result<()> {
        execute!(
            stdout,
            SetForegroundColor(Color::Magenta),
            Print(""),
            ResetColor,
        )?;
        stdout.flush()?;
        Ok(())
    }
}