use std::io::{self, Write};
use crossterm::{
execute,
terminal::{Clear, ClearType},
style::{Color, Print, ResetColor, SetForegroundColor},
cursor,
event::{self, Event, KeyCode, KeyModifiers, KeyEventKind},
terminal::{enable_raw_mode, disable_raw_mode},
};
use crate::config::Config;
use crate::embed::Embeddings;
use crate::interpolate::Interpolator;
use crate::markov::Markov;
use crate::storage::Message;
/// Interactive terminal shell that answers text queries with Markov-generated tokens.
///
/// Borrows its configuration for `'a`; the interpolator built in [`Shell::new`]
/// also borrows the embeddings and message slice for `'a`. The Markov model is owned.
pub struct Shell<'a> {
/// Configuration; its channels and normalized weights are printed in the header.
config: &'a Config,
/// Turns a query into a `(seed, style_vec)` pair used to drive generation.
interpolator: Interpolator<'a>,
/// Model that generates the response tokens.
markov: Markov,
}
impl<'a> Shell<'a> {
pub fn new(
config: &'a Config,
embeddings: &'a Embeddings,
messages: &'a [Message],
markov: Markov,
) -> Self {
let interpolator = Interpolator::new(config, embeddings, messages);
Self { config, interpolator, markov }
}
pub fn run(&self) -> anyhow::Result<()> {
let mut stdout = io::stdout();
enable_raw_mode()?;
self.print_header(&mut stdout)?;
let mut input = String::new();
loop {
self.print_prompt(&mut stdout)?;
input.clear();
loop {
if let Event::Key(key) = event::read()? {
if key.kind != KeyEventKind::Press {
continue;
}
match (key.code, key.modifiers) {
(KeyCode::Char('c'), KeyModifiers::CONTROL) => {
disable_raw_mode()?;
execute!(stdout, Print("\n"))?;
return Ok(());
}
(KeyCode::Enter, _) => {
execute!(stdout, Print("\n"))?;
break;
}
(KeyCode::Backspace, _) => {
if !input.is_empty() {
input.pop();
execute!(
stdout,
cursor::MoveLeft(1),
Print(" "),
cursor::MoveLeft(1)
)?;
}
}
(KeyCode::Char(c), _) => {
input.push(c);
execute!(
stdout,
SetForegroundColor(Color::White),
Print(c),
ResetColor
)?;
}
_ => {}
}
stdout.flush()?;
}
}
let query = input.trim().to_string();
if query.is_empty() { continue; }
if query == "exit" || query == "quit" {
disable_raw_mode()?;
execute!(stdout, Print("\n"))?;
break;
}
self.respond(&query, &mut stdout)?;
}
Ok(())
}
pub fn answer(&self, query: &str) -> anyhow::Result<String> {
let (seed, style_vec) = self.interpolator.build_seed(query);
let response = self.markov.generate(
&seed,
64,
Some(&style_vec),
None,
0.8,
);
Ok(response.join(" "))
}
fn respond(&self, query: &str, stdout: &mut io::Stdout) -> anyhow::Result<()> {
let (seed, style_vec) = self.interpolator.build_seed(query);
let response = self.markov.generate(
&seed,
64,
Some(&style_vec),
None,
0.8,
);
execute!(stdout, SetForegroundColor(Color::DarkGrey), Print(" ► "), ResetColor)?;
for (i, tok) in response.iter().enumerate() {
let color = self.token_color(tok);
execute!(stdout, SetForegroundColor(color), Print(tok), ResetColor)?;
if i < response.len() - 1 {
execute!(stdout, Print(" "))?;
}
stdout.flush()?;
std::thread::sleep(std::time::Duration::from_millis(30));
}
execute!(stdout, Print("\n\n"))?;
Ok(())
}
fn token_color(&self, tok: &str) -> Color {
let bytes = tok.bytes().fold(0u32, |acc, b| acc.wrapping_add(b as u32));
match bytes % 4 {
0 => Color::Cyan,
1 => Color::Green,
2 => Color::Magenta,
_ => Color::White,
}
}
fn print_header(&self, stdout: &mut io::Stdout) -> anyhow::Result<()> {
execute!(stdout, Clear(ClearType::All), cursor::MoveTo(0, 0))?;
execute!(
stdout,
SetForegroundColor(Color::Magenta),
Print(" ░░ interpolize ░░\n"),
ResetColor,
)?;
execute!(stdout, SetForegroundColor(Color::DarkGrey), Print(" channels: "), ResetColor)?;
let weights = self.config.normalized_weights();
for (ch, w) in self.config.channels.iter().zip(weights.iter()) {
execute!(
stdout,
SetForegroundColor(Color::Cyan),
Print(format!("{}({:.0}%) ", ch.name, w * 100.0)),
ResetColor,
)?;
}
execute!(stdout, Print("\n\n"))?;
stdout.flush()?;
Ok(())
}
fn print_prompt(&self, stdout: &mut io::Stdout) -> anyhow::Result<()> {
execute!(
stdout,
SetForegroundColor(Color::Magenta),
Print("〉 "),
ResetColor,
)?;
stdout.flush()?;
Ok(())
}
}