use std::fmt::Display;
use std::fs;
use std::io::{self, StdoutLock};
use std::time::Duration;
use clap::Parser;
use lazy_static::lazy_static;
use regex::Regex;
use crossterm::event::{Event, KeyCode, KeyEvent, KeyModifiers};
use crossterm::terminal::{
disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen,
};
use futures::{Stream, StreamExt};
use ratatui::backend::CrosstermBackend;
use ratatui::layout::{Alignment, Constraint, Direction};
use ratatui::widgets::{Block, Borders, Paragraph, Wrap};
use ratatui::Terminal;
use tui_input::backend::crossterm::EventHandler;
use tui_input::Input;
use crate::anthropic::AnthropicModel;
use crate::context::Context;
use crate::model::Task;
use crate::openai::OpenAIGPTModel;
use crate::{model_stream_request, AskConfig, ConfigKind, ExplainConfig, ModelError, ModelKind};
// Top-level command line of shai: every variant is a subcommand dispatched by
// `run`.
// NOTE: plain `//` comments on purpose — clap turns `///` doc comments on
// this type and its variants into `--help` text.
#[derive(Parser, Clone)]
#[command(author, version, about, long_about = None)]
pub enum ShaiCLIArgs {
    // Generate a command from a prompt in the interactive UI.
    #[command(arg_required_else_help = true)]
    Ask(AskArgs),
    // Explain a command in the interactive UI.
    #[command(arg_required_else_help = true)]
    Explain(ExplainArgs),
    // Print the shell integration script for the chosen shell to stdout.
    #[command(arg_required_else_help = true)]
    GenerateScript(IntegrationScriptArgs),
}
// Options of the `ask` subcommand; converted into `AskConfig` for requests.
// NOTE: `//` rather than `///` on purpose — clap would turn doc comments into
// `--help` text.
#[derive(clap::Args, Clone)]
#[command(author, about, long_about = None)]
pub struct AskArgs {
    // Operating system forwarded to `AskConfig::operating_system`.
    #[arg(long, default_value = "Linux")]
    operating_system: String,
    // Shell forwarded to `AskConfig::shell`.
    #[arg(long, default_value = "Bash")]
    shell: String,
    // Forwarded to `AskConfig::environment`; clap collects every occurrence of
    // the flag into the `Vec`. Presumably environment variable names — confirm.
    #[arg(long, short, default_value = None)]
    environment: Option<Vec<String>>,
    // Forwarded to `AskConfig::programs` (note the plural field name there).
    #[arg(long, short, default_value = None)]
    program: Option<Vec<String>>,
    // When set, `AskConfig::cwd` becomes `Some(())`.
    #[arg(long)]
    cwd: bool,
    // Forwarded unchanged to `AskConfig::depth`; presumably a directory
    // listing depth used with `--cwd` — TODO confirm.
    #[arg(long, default_value=None)]
    depth: Option<u32>,
    // Model used for every request of this session.
    #[arg(long, value_enum)]
    model: ArgModelKind,
    // Print the final main response to stdout after the UI closes.
    #[arg(long)]
    write_stdout: bool,
    // File whose trimmed contents prefill the prompt; when a command is
    // accepted (Ctrl-a / Ctrl-r) `ShaiUI::run` overwrites it with the response.
    #[arg(long)]
    edit_file: Option<std::path::PathBuf>,
}
// Options of the `explain` subcommand; converted into `ExplainConfig`.
// NOTE: `//` rather than `///` on purpose — clap would turn doc comments into
// `--help` text.
#[derive(clap::Args, Clone)]
#[command(author, about, long_about = None)]
pub struct ExplainArgs {
    // Operating system forwarded to `ExplainConfig::operating_system`.
    #[arg(long, default_value = "Linux")]
    operating_system: String,
    // Shell forwarded to `ExplainConfig::shell`.
    #[arg(long, default_value = "Bash")]
    shell: String,
    // Forwarded to `ExplainConfig::environment`; clap collects every
    // occurrence of the flag into the `Vec`.
    // NOTE(review): unlike `AskArgs::environment` this has no `short` flag.
    #[arg(long, default_value = None)]
    environment: Option<Vec<String>>,
    // When set, `ExplainConfig::cwd` becomes `Some(())`.
    #[arg(long)]
    cwd: bool,
    // Forwarded unchanged to `ExplainConfig::depth`.
    #[arg(long, default_value=None)]
    depth: Option<u32>,
    // Model used for every request of this session.
    #[arg(long, value_enum)]
    model: ArgModelKind,
    // Print the final main response to stdout after the UI closes.
    #[arg(long)]
    write_stdout: bool,
    // File whose trimmed contents prefill the prompt. `ShaiUI::run` only
    // writes back for `ask`, so it is read-only for `explain`.
    #[arg(long)]
    edit_file: Option<std::path::PathBuf>,
}
// Options of the `generate-script` subcommand.
// NOTE: `//` on purpose — clap would turn `///` into `--help` text.
#[derive(clap::Args, Clone)]
#[command(author, about, long_about = None)]
pub struct IntegrationScriptArgs {
    // Shell whose integration script `run` prints.
    #[arg(long, value_enum)]
    shell: Shell,
}
/// Arguments of the subcommands that open the interactive UI.
///
/// Cloned into a `ConfigKind` for every request sent from the UI.
#[derive(Clone)]
pub enum ShaiArgs {
    /// `ask`: generate a command from the prompt.
    Ask(AskArgs),
    /// `explain`: explain the command in the prompt.
    Explain(ExplainArgs),
}
impl ShaiArgs {
const fn edit_file(&self) -> &Option<std::path::PathBuf> {
match self {
Self::Ask(args) => &args.edit_file,
Self::Explain(args) => &args.edit_file,
}
}
const fn write_stdout(&self) -> bool {
match self {
Self::Ask(args) => args.write_stdout,
Self::Explain(args) => args.write_stdout,
}
}
}
impl From<ShaiArgs> for ConfigKind {
    /// Turns the UI arguments into the matching request configuration.
    fn from(value: ShaiArgs) -> Self {
        match value {
            ShaiArgs::Ask(ask) => Self::Ask(ask.into()),
            ShaiArgs::Explain(explain) => Self::Explain(explain.into()),
        }
    }
}
// Model names accepted by `--model`; mapped to `ModelKind` by the `From` impl
// below.
// NOTE: `//` on purpose — clap uses `///` on value-enum variants as help text.
#[derive(clap::ValueEnum, Clone)]
// Variant names such as `OpenAIGPT4_1` contain underscores.
#[allow(non_camel_case_types)]
enum ArgModelKind {
    // OpenAI models.
    OpenAIGPT4_1,
    OpenAIGPT4_1Mini,
    OpenAIGPT4_1Nano,
    OpenAIGPT4o,
    OpenAIGPT4oMini,
    OpenAIO3,
    OpenAIO3Mini,
    OpenAIO4Mini,
    OpenAIO1,
    OpenAIGPT4Turbo,
    OpenAIGPT4,
    // Anthropic models.
    AnthropicClaudeOpus46,
    AnthropicClaudeSonnet46,
    AnthropicClaudeHaiku45,
    AnthropicClaudeOpus45,
    AnthropicClaudeSonnet45,
    AnthropicClaudeOpus4,
    AnthropicClaudeSonnet4,
    AnthropicClaudeOpus41,
}
impl From<ArgModelKind> for ModelKind {
fn from(value: ArgModelKind) -> Self {
match value {
ArgModelKind::OpenAIGPT4_1 => Self::OpenAIGPT(OpenAIGPTModel::GPT4_1),
ArgModelKind::OpenAIGPT4_1Mini => Self::OpenAIGPT(OpenAIGPTModel::GPT4_1Mini),
ArgModelKind::OpenAIGPT4_1Nano => Self::OpenAIGPT(OpenAIGPTModel::GPT4_1Nano),
ArgModelKind::OpenAIGPT4o => Self::OpenAIGPT(OpenAIGPTModel::GPT4o),
ArgModelKind::OpenAIGPT4oMini => Self::OpenAIGPT(OpenAIGPTModel::GPT4oMini),
ArgModelKind::OpenAIO3 => Self::OpenAIGPT(OpenAIGPTModel::O3),
ArgModelKind::OpenAIO3Mini => Self::OpenAIGPT(OpenAIGPTModel::O3Mini),
ArgModelKind::OpenAIO4Mini => Self::OpenAIGPT(OpenAIGPTModel::O4Mini),
ArgModelKind::OpenAIO1 => Self::OpenAIGPT(OpenAIGPTModel::O1),
ArgModelKind::OpenAIGPT4Turbo => Self::OpenAIGPT(OpenAIGPTModel::GPT4Turbo),
ArgModelKind::OpenAIGPT4 => Self::OpenAIGPT(OpenAIGPTModel::GPT4),
ArgModelKind::AnthropicClaudeOpus46 => Self::Anthropic(AnthropicModel::ClaudeOpus46),
ArgModelKind::AnthropicClaudeSonnet46 => Self::Anthropic(AnthropicModel::ClaudeSonnet46),
ArgModelKind::AnthropicClaudeHaiku45 => Self::Anthropic(AnthropicModel::ClaudeHaiku45),
ArgModelKind::AnthropicClaudeOpus45 => Self::Anthropic(AnthropicModel::ClaudeOpus45),
ArgModelKind::AnthropicClaudeSonnet45 => Self::Anthropic(AnthropicModel::ClaudeSonnet45),
ArgModelKind::AnthropicClaudeOpus4 => Self::Anthropic(AnthropicModel::ClaudeOpus4),
ArgModelKind::AnthropicClaudeSonnet4 => Self::Anthropic(AnthropicModel::ClaudeSonnet4),
ArgModelKind::AnthropicClaudeOpus41 => Self::Anthropic(AnthropicModel::ClaudeOpus41),
}
}
}
// Shells for which `generate-script` can print an integration script; each
// variant selects one file under `scripts/` in `run`.
// NOTE: `//` on purpose — clap uses `///` on value-enum variants as help text.
#[derive(clap::ValueEnum, Clone)]
enum Shell {
    Bash,
    Zsh,
    Fish,
    Nushell,
    PowerShell,
}
impl From<AskArgs> for AskConfig {
fn from(value: AskArgs) -> Self {
let cwd = if value.cwd { Some(()) } else { None };
let model = value.model.into();
Self {
operating_system: value.operating_system,
shell: value.shell,
environment: value.environment,
programs: value.program,
cwd,
depth: value.depth,
model,
}
}
}
impl From<ExplainArgs> for ExplainConfig {
fn from(value: ExplainArgs) -> Self {
let cwd = if value.cwd { Some(()) } else { None };
let model = value.model.into();
Self {
operating_system: value.operating_system,
shell: value.shell,
environment: value.environment,
cwd,
depth: value.depth,
model,
}
}
}
#[allow(clippy::missing_errors_doc)]
pub async fn run() -> Result<(), Box<dyn std::error::Error>> {
let args = ShaiCLIArgs::parse();
match args {
ShaiCLIArgs::Ask(shai_args) => {
let mut ui = ShaiUI::new(ShaiArgs::Ask(shai_args))?;
ui.run().await?;
}
ShaiCLIArgs::Explain(shai_args) => {
let mut ui = ShaiUI::new(ShaiArgs::Explain(shai_args))?;
ui.run().await?;
}
ShaiCLIArgs::GenerateScript(integration_args) => match integration_args.shell {
Shell::Bash => println!("{}", include_str!("../scripts//bash_assistant.sh")),
Shell::Zsh => println!("{}", include_str!("../scripts/zsh_assistant.zsh")),
Shell::Fish => println!("{}", include_str!("../scripts/fish_assistant.fish")),
Shell::Nushell => println!("{}", include_str!("../scripts/nushell_assistant.nu")),
Shell::PowerShell => {
println!("{}", include_str!("../scripts/powershell_assistant.ps1"));
}
},
}
Ok(())
}
/// What `ShaiUI::run` writes to the `ask` edit file once the UI exits.
enum WriteBuffer {
    /// Accepted with Ctrl-a: write the fenced code blocks of the main response
    /// joined by newlines, or the whole response if it has none.
    Yes,
    /// Accepted with Ctrl-r: write the main response verbatim.
    Raw,
    /// Exited without accepting: leave the file untouched.
    No,
}
/// Phase of a request inside `ShaiUI::send_request`.
enum RequestState {
    /// The spawned request task has not finished yet.
    WaitRequest,
    /// The task finished; its response stream is being consumed.
    Streaming,
}
/// How a request ended.
enum RequestExit {
    /// The user pressed Esc; the UI keeps running.
    Cancel,
    /// The user pressed Ctrl-c; the UI should exit.
    Exit,
    /// The response stream ended normally.
    Finished,
}
/// Spinner shown in a response pane's title while its request is in flight.
///
/// `None` means no request is running; `S0`..`S3` are the animation frames
/// rendered by the `Display` impl.
#[derive(Copy, Clone, Default)]
enum ShaiRequestProgress {
    #[default]
    None,
    S0,
    S1,
    S2,
    S3,
}
/// Coarse UI state derived by `ShaiUI::state`.
///
/// It selects the key-binding hints and gates which shortcuts are active.
#[derive(Clone, Copy)]
enum ShaiState {
    /// No response has been produced yet.
    Started,
    /// At least one pane has a request in flight.
    Processing,
    /// `explain` produced a response.
    ExplanationGenerated,
    /// `ask` produced a command and there is no explanation of it.
    CommandGenerated,
    /// `ask` produced a command and an explanation of it.
    AuxExplanationGenerated,
}
/// Which pane a request fills.
#[derive(Clone, Copy)]
enum RequestType {
    /// Prompt from the input box; the answer goes to the main pane.
    Normal,
    /// Explanation of the main response; the answer goes to the auxiliary pane.
    Auxiliary,
}
impl ShaiRequestProgress {
    /// Advances the spinner by one frame.
    ///
    /// An idle spinner starts at `S0`, and the last frame `S3` wraps back to `S0`.
    const fn next_state(self) -> Self {
        match self {
            Self::S0 => Self::S1,
            Self::S1 => Self::S2,
            Self::S2 => Self::S3,
            Self::S3 | Self::None => Self::S0,
        }
    }
}
impl Display for ShaiRequestProgress {
    /// Writes the spinner glyph for this frame (nothing when idle).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let glyph = match self {
            Self::None => "",
            Self::S0 => "-",
            Self::S1 => "\\",
            Self::S2 => "|",
            Self::S3 => "/",
        };
        f.write_str(glyph)
    }
}
/// Builds a bordered, word-wrapped pane that shows a model response.
///
/// The title carries the spinner frame `thinking` and a `*` when the pane has
/// focus.
fn create_explanation_paragraph<'t>(
    text: String,
    thinking: ShaiRequestProgress,
    focus: bool,
) -> Paragraph<'t> {
    let marker = if focus { "*" } else { "" };
    let frame = Block::default()
        .borders(Borders::ALL)
        .title(format!("Shai {thinking} {marker}"));
    Paragraph::new(text)
        .block(frame)
        .wrap(Wrap { trim: true })
        .alignment(Alignment::Left)
}
/// Builds the bordered, single-line prompt box titled `title`.
///
/// It is not wrapped; the caller scrolls it horizontally.
fn create_input_paragraph<'t>(text: String, title: String) -> Paragraph<'t> {
    let frame = Block::default().title(title).borders(Borders::ALL);
    Paragraph::new(text).alignment(Alignment::Left).block(frame)
}
/// Builds the key-binding hint bar shown at the bottom of the UI for `state`.
///
/// The hints mirror the bindings handled in `ShaiUI::mainloop`. Raw accept is
/// bound to Ctrl-r there; terminals generally cannot distinguish Ctrl-Shift-a
/// from Ctrl-a, so the former `<C-A>` label advertised a key that did nothing.
fn create_controls_paragraph<'t>(state: ShaiState) -> Paragraph<'t> {
    let text = match state {
        ShaiState::Started => "<C-c>: Exit | Enter: Send Prompt",
        ShaiState::Processing => "<C-c>: Exit | Esc: Cancel ",
        ShaiState::ExplanationGenerated => {
            "<C-c>: Exit | Enter: Send Prompt | <C-u|d>: Scroll"
        }
        ShaiState::CommandGenerated => {
            "<C-c>: Exit | Enter: Send Prompt | <C-a>: Accept | <C-r>: Accept (raw) | <C-e>: Explain"
        }
        ShaiState::AuxExplanationGenerated => {
            "<C-c>: Exit | Enter: Send Prompt | <C-a>: Accept | <C-r>: Accept (raw) | <C-e>: Explain | <Tab>: Toggle Focus | <C-u|d>: Scroll | <S-Up|Down>: Resize explanation"
        }
    };
    Paragraph::new(text)
        .block(Block::default().borders(Borders::TOP))
        .alignment(Alignment::Left)
        .wrap(Wrap { trim: true })
}
/// State of one response pane.
#[derive(Default)]
struct Response {
    /// Text received so far.
    text: String,
    /// User scroll position in half-page steps: `draw` renders the pane
    /// scrolled by `scroll * height / 2` rows when no request is in flight.
    scroll: u16,
    /// Spinner frame; `None` when no request is running for this pane.
    request_state: ShaiRequestProgress,
}
/// Interactive terminal UI for the `ask` and `explain` subcommands.
pub struct ShaiUI<'t> {
    /// Parsed subcommand arguments; cloned into a config for each request.
    args: ShaiArgs,
    /// ratatui terminal drawing on the locked stdout.
    term: Terminal<CrosstermBackend<StdoutLock<'t>>>,
    /// Current pane arrangement.
    layout: Layout,
    /// Copy of `input.value()` used for rendering; refreshed after each
    /// keystroke forwarded to `input`.
    input_text: String,
    /// Editable prompt state (cursor, horizontal scroll).
    input: Input,
    /// Answer to the prompt.
    main_response: Response,
    /// Explanation of the main response (`ask` only, requested with Ctrl-e).
    auxiliary_response: Response,
    /// Height constraint of the main pane; `Min` in the one-pane layout and
    /// `Length` in the two-pane layout. Changed with Shift-Up/Down (minimum 3).
    main_response_size: u16,
    /// Pane scrolled by Ctrl-u/Ctrl-d; marked with `*` in its title.
    response_focus: Focus,
}
/// Returns the contents of every fenced (```` ``` ````) block in `text`, in order.
///
/// An optional word-character language tag after the opening fence is
/// skipped. The captured text excludes the newlines right after the opening
/// fence and right before the closing fence.
fn extract_code_blocks(text: &str) -> Vec<String> {
    lazy_static! {
        static ref RE: Regex = Regex::new(r"(?s)```(?:\w+)?\n(.*?)\n```")
            .expect("The regex expression should be valid");
    }
    RE.captures_iter(text)
        .filter_map(|caps| caps.get(1))
        .map(|body| body.as_str().to_owned())
        .collect()
}
/// Pane arrangement of the UI (top to bottom).
enum Layout {
    /// Input box, main response, controls bar.
    InputResponse,
    /// Input box, main response, explanation pane, controls bar.
    InputResponseExplanation,
}
/// Response pane that receives scroll keys; toggled with Tab in the two-pane
/// layout.
enum Focus {
    MainResponse,
    AuxiliaryResponse,
}
impl Layout {
    /// Builds the vertical ratatui layout for this arrangement.
    ///
    /// Chunks, top to bottom:
    /// - a 3-row input box;
    /// - the main pane, at least `main_response_size` rows in the one-pane
    ///   layout and exactly that many in the two-pane layout;
    /// - in the two-pane layout, an explanation pane of at least 3 rows;
    /// - a 2-row controls bar.
    fn create(&self, main_response_size: u16) -> ratatui::layout::Layout {
        let constraints = match self {
            Self::InputResponse => vec![
                Constraint::Length(3),
                Constraint::Min(main_response_size),
                Constraint::Length(2),
            ],
            Self::InputResponseExplanation => vec![
                Constraint::Length(3),
                Constraint::Length(main_response_size),
                Constraint::Min(3),
                Constraint::Length(2),
            ],
        };
        ratatui::layout::Layout::default()
            .direction(Direction::Vertical)
            .constraints(constraints)
    }
}
/// Divides `lhs` by `rhs`, rounding up.
///
/// A zero divisor is treated as 1 (the result is `lhs`) instead of panicking.
/// Callers pass terminal widths, which can shrink to 0 on a very narrow
/// window. The previous version divided before its `rhs > 0` check, so that
/// check could never prevent the panic.
const fn div_ceil(lhs: usize, rhs: usize) -> usize {
    let rhs = if rhs == 0 { 1 } else { rhs };
    let quotient = lhs / rhs;
    if lhs % rhs == 0 {
        quotient
    } else {
        quotient + 1
    }
}
/// Roughly estimates how many rows `text` needs when wrapped to `width` columns.
///
/// Each `\n`-separated line counts as `ceil(len / width)` rows, with at least
/// one row even when the line is empty. The count starts at one extra row.
/// Length is measured in bytes, so multi-byte characters are over-counted,
/// and word wrapping is not modelled.
fn estimate_text_height(text: &str, width: usize) -> usize {
    let wrapped_rows: usize = text
        .split('\n')
        .map(|line| div_ceil(line.len().max(1), width))
        .sum();
    wrapped_rows + 1
}
impl<'t> ShaiUI<'t> {
fn initialization(args: ShaiArgs) -> Result<Self, Box<dyn std::error::Error>> {
let mut stdout = io::stdout().lock();
crossterm::execute!(stdout, EnterAlternateScreen)?;
let backend = CrosstermBackend::new(stdout);
let term = Terminal::new(backend)?;
let cli_text = args
.edit_file()
.as_ref()
.and_then(|file| fs::read_to_string(file).ok())
.map(|bufstr| bufstr.trim().to_string())
.unwrap_or_default();
Ok(ShaiUI {
args,
term,
layout: Layout::InputResponse,
input_text: cli_text.clone(),
input: Input::default().with_value(cli_text),
main_response: Response::default(),
auxiliary_response: Response::default(),
main_response_size: 3,
response_focus: Focus::MainResponse,
})
}
fn new(args: ShaiArgs) -> Result<Self, Box<dyn std::error::Error>> {
enable_raw_mode().expect("Terminal needs to be set in raw mode for Shai UI to work");
match Self::initialization(args) {
Ok(shai) => Ok(shai),
Err(err) => {
disable_raw_mode()?;
Err(err)
}
}
}
async fn run(&mut self) -> Result<(), Box<dyn std::error::Error>> {
let write_mode = self.mainloop().await;
disable_raw_mode()?;
crossterm::execute!(self.term.backend_mut(), LeaveAlternateScreen,)?;
self.term.show_cursor()?;
let write_mode = write_mode?;
if let ShaiArgs::Ask(_) = self.args {
if let Some(file) = &self.args.edit_file() {
match write_mode {
WriteBuffer::Yes => {
let code_blocks = extract_code_blocks(&self.main_response.text);
if code_blocks.is_empty() {
fs::write(file, &self.main_response.text)?;
} else {
fs::write(file, code_blocks.join("\n"))?;
}
}
WriteBuffer::Raw => fs::write(file, &self.main_response.text)?,
WriteBuffer::No => (),
}
}
}
if self.args.write_stdout() {
let response = &self.main_response.text;
println!("{response}");
}
Ok(())
}
fn state(&self) -> ShaiState {
match (
self.main_response.request_state,
self.auxiliary_response.request_state,
) {
(ShaiRequestProgress::None, ShaiRequestProgress::None) => match self.args {
ShaiArgs::Ask(_) => {
if self.main_response.text.is_empty() {
ShaiState::Started
} else if self.auxiliary_response.text.is_empty() {
ShaiState::CommandGenerated
} else {
ShaiState::AuxExplanationGenerated
}
}
ShaiArgs::Explain(_) => {
if self.main_response.text.is_empty() {
ShaiState::Started
} else {
ShaiState::ExplanationGenerated
}
}
},
_ => ShaiState::Processing,
}
}
#[allow(clippy::too_many_lines)]
async fn mainloop(&mut self) -> Result<WriteBuffer, Box<dyn std::error::Error>> {
loop {
self.draw()?;
if let Event::Key(key) = crossterm::event::read()? {
match key {
KeyEvent {
code: KeyCode::Char('c'),
modifiers: KeyModifiers::CONTROL,
..
} => return Ok(WriteBuffer::No),
KeyEvent {
code: KeyCode::Char('r'),
modifiers: KeyModifiers::CONTROL,
..
} if matches!(
self.state(),
ShaiState::CommandGenerated | ShaiState::AuxExplanationGenerated
) =>
{
return Ok(WriteBuffer::Raw)
}
KeyEvent {
code: KeyCode::Char('a'),
modifiers: KeyModifiers::CONTROL,
..
} if matches!(
self.state(),
ShaiState::CommandGenerated | ShaiState::AuxExplanationGenerated
) =>
{
return Ok(WriteBuffer::Yes)
}
KeyEvent {
code: KeyCode::Enter,
..
} => {
if matches!(
self.send_request(RequestType::Normal).await?,
RequestExit::Exit
) {
return Ok(WriteBuffer::No);
}
}
KeyEvent {
code: KeyCode::Char('e'),
modifiers: KeyModifiers::CONTROL,
..
} if matches!(
self.state(),
ShaiState::CommandGenerated | ShaiState::AuxExplanationGenerated
) =>
{
self.layout = Layout::InputResponseExplanation;
self.response_focus = Focus::AuxiliaryResponse;
if matches!(
self.send_request(RequestType::Auxiliary).await?,
RequestExit::Exit
) {
return Ok(WriteBuffer::No);
}
}
KeyEvent {
code: dirchar @ KeyCode::Char('d' | 'u'),
modifiers: KeyModifiers::CONTROL,
..
} if matches!(
self.state(),
ShaiState::ExplanationGenerated | ShaiState::AuxExplanationGenerated
) =>
{
match self.response_focus {
Focus::MainResponse => {
if dirchar == KeyCode::Char('d') {
self.main_response.scroll =
self.main_response.scroll.saturating_add(1);
} else {
self.main_response.scroll =
self.main_response.scroll.saturating_sub(1);
}
}
Focus::AuxiliaryResponse => {
if dirchar == KeyCode::Char('d') {
self.auxiliary_response.scroll =
self.auxiliary_response.scroll.saturating_add(1);
} else {
self.auxiliary_response.scroll =
self.auxiliary_response.scroll.saturating_sub(1);
}
}
}
}
KeyEvent {
code: KeyCode::Char('d' | 'u'),
modifiers: KeyModifiers::CONTROL,
..
} => (),
KeyEvent {
code: dirchar @ (KeyCode::Up | KeyCode::Down),
modifiers: KeyModifiers::SHIFT,
..
} if matches!(self.state(), ShaiState::AuxExplanationGenerated) => {
if dirchar == KeyCode::Up {
self.main_response_size =
self.main_response_size.saturating_sub(1).max(3);
} else {
self.main_response_size += 1; }
}
KeyEvent {
code: KeyCode::Tab, ..
} if matches!(self.layout, Layout::InputResponseExplanation) => {
self.response_focus = match self.response_focus {
Focus::MainResponse => Focus::AuxiliaryResponse,
Focus::AuxiliaryResponse => Focus::MainResponse,
}
}
_ => {
self.input.handle_event(&Event::Key(key));
self.input_text = self.input.value().to_string();
}
}
}
}
}
fn draw(&mut self) -> Result<(), Box<dyn std::error::Error>> {
let state = self.state();
self.term.draw(|f| {
let layout = self.layout.create(self.main_response_size);
let chunks = layout.split(f.size());
let width = chunks[0].width.max(3) - 3; let scroll = self.input.visual_scroll(width as usize);
f.render_widget(
create_input_paragraph(self.input_text.clone(), Self::title(&self.args))
.scroll((0, u16::try_from(scroll).unwrap_or_default())),
chunks[0],
);
f.set_cursor(
chunks[0].x
+ u16::try_from(self.input.visual_cursor().max(scroll) - scroll)
.unwrap_or_default()
+ 1,
chunks[0].y + 1,
);
let main_scroll =
if matches!(self.main_response.request_state, ShaiRequestProgress::None) {
self.main_response.scroll * chunks[1].height / 2
} else {
let full_scroll = estimate_text_height(
&self.main_response.text,
(chunks[1].width.saturating_sub(2)).into(),
)
.saturating_sub(chunks[1].height.saturating_sub(2).into())
.try_into()
.unwrap_or_default();
self.main_response.scroll = full_scroll / chunks[1].height * 2;
full_scroll
};
f.render_widget(
create_explanation_paragraph(
self.main_response.text.clone(),
self.main_response.request_state,
matches!(self.response_focus, Focus::MainResponse),
)
.scroll((main_scroll, 0)),
chunks[1],
);
match &self.layout {
Layout::InputResponse => {
f.render_widget(create_controls_paragraph(state), chunks[2]);
}
Layout::InputResponseExplanation => {
let chunk = chunks[2];
let aux_scroll = if matches!(
self.auxiliary_response.request_state,
ShaiRequestProgress::None
) {
self.auxiliary_response.scroll * chunk.height / 2
} else {
let full_scroll = estimate_text_height(
&self.auxiliary_response.text,
(chunk.width.saturating_sub(2)).into(),
)
.saturating_sub(chunk.height.saturating_sub(2).into())
.try_into()
.unwrap_or_default();
self.auxiliary_response.scroll = full_scroll / chunk.height * 2;
full_scroll
};
f.render_widget(
create_explanation_paragraph(
self.auxiliary_response.text.clone(),
self.auxiliary_response.request_state,
matches!(self.response_focus, Focus::AuxiliaryResponse),
)
.scroll((aux_scroll, 0)),
chunk,
);
f.render_widget(create_controls_paragraph(state), chunks[3]);
}
}
})?;
Ok(())
}
fn update_request_state(&mut self, request_type: RequestType, finished: bool) {
if finished {
match request_type {
RequestType::Normal => {
self.main_response.request_state = ShaiRequestProgress::None;
}
RequestType::Auxiliary => {
self.auxiliary_response.request_state = ShaiRequestProgress::None;
}
}
} else {
match request_type {
RequestType::Normal => {
self.main_response.request_state =
self.main_response.request_state.next_state();
}
RequestType::Auxiliary => {
self.auxiliary_response.request_state =
self.auxiliary_response.request_state.next_state();
}
}
}
}
async fn send_request(
&mut self,
request_type: RequestType,
) -> Result<RequestExit, Box<dyn std::error::Error>> {
let config = ConfigKind::from(self.args.clone());
let model = config.model().clone();
let task = match config {
ConfigKind::Ask(_) => match request_type {
RequestType::Normal => Task::GenerateCommand,
RequestType::Auxiliary => Task::Explain,
},
ConfigKind::Explain(_) => Task::Explain,
};
let context = Context::from(config);
let user_prompt = match request_type {
RequestType::Normal => self.input.value().to_string(),
RequestType::Auxiliary => self.main_response.text.clone(),
};
let request_task = tokio::spawn(model_stream_request(
model.clone(),
user_prompt,
context.clone(),
task,
));
let mut reqstate = RequestState::WaitRequest;
let ret = loop {
self.draw()?;
match reqstate {
RequestState::WaitRequest => {
if crossterm::event::poll(Duration::from_millis(100))? {
if let Event::Key(key) = crossterm::event::read()? {
match key {
KeyEvent {
code: KeyCode::Esc, ..
} => break Ok(RequestExit::Cancel),
KeyEvent {
code: KeyCode::Char('c'),
modifiers: KeyModifiers::CONTROL,
..
} => break Ok(RequestExit::Exit),
_ => (),
}
}
}
if request_task.is_finished() {
reqstate = RequestState::Streaming;
self.clear_response(request_type);
}
}
RequestState::Streaming => {
break self
.stream_response(
request_task.await??,
request_type,
)
.await
}
}
self.update_request_state(request_type, false);
};
self.update_request_state(request_type, true);
ret
}
async fn stream_response(
&mut self,
mut response_stream: impl Stream<Item = Result<String, ModelError>> + Unpin,
request_type: RequestType,
) -> Result<RequestExit, Box<dyn std::error::Error>> {
while let Some(message) = response_stream.next().await {
self.append_message_response(&message?, request_type);
self.draw()?;
if crossterm::event::poll(Duration::from_millis(100))? {
if let Event::Key(key) = crossterm::event::read()? {
match key {
KeyEvent {
code: KeyCode::Char('c'),
modifiers: KeyModifiers::CONTROL,
..
} => return Ok(RequestExit::Exit),
KeyEvent {
code: KeyCode::Esc, ..
} => return Ok(RequestExit::Cancel),
_ => (),
}
}
}
self.update_request_state(request_type, false);
}
Ok(RequestExit::Finished)
}
fn clear_response(&mut self, request_type: RequestType) {
match request_type {
RequestType::Normal => {
self.layout = Layout::InputResponse;
self.response_focus = Focus::MainResponse;
self.main_response = Response::default();
self.auxiliary_response = Response::default();
}
RequestType::Auxiliary => {
self.auxiliary_response = Response::default();
}
}
}
fn append_message_response(&mut self, response: &str, request_type: RequestType) {
let old_text = match request_type {
RequestType::Normal => &self.main_response.text,
RequestType::Auxiliary => &self.auxiliary_response.text,
};
let new = format!("{old_text}{response}");
match request_type {
RequestType::Normal => self.main_response.text = new,
RequestType::Auxiliary => self.auxiliary_response.text = new,
}
}
fn title(args: &ShaiArgs) -> String {
match args {
ShaiArgs::Ask(_) => "What should shai's command do?",
ShaiArgs::Explain(_) => "What command should shai explain?",
}
.to_string()
}
}
// Unit tests for `extract_code_blocks`, which `ShaiUI::run` uses when a
// generated command is accepted with Ctrl-a.
#[cfg(test)]
mod tests {
use super::extract_code_blocks;
// Three fences (tagged `rust`, untagged, tagged `python`) must come back in
// order with only the text between the fence lines captured. The untagged and
// python snippets start with a newline (the python one also ends with one),
// and those newlines must be preserved.
#[test]
fn code_blocks_regex() {
let code_rust = "fn main() {
println!(\"Hello, World!\");
}";
let code_no_tag = "
Hello my friend";
let code_python = "
print('Hello, World!')
";
let text = format!(
"
Some text before the code block
```rust
{code_rust}
```
```
{code_no_tag}
```
Some text after the code block
```python
{code_python}
```
"
);
let blocks = extract_code_blocks(&text);
assert_eq!(blocks.len(), 3);
assert_eq!(blocks[0], code_rust);
assert_eq!(blocks[1], code_no_tag);
assert_eq!(blocks[2], code_python);
}
}