fgpt 0.1.1

A free reverse proxy and cli tool for OpenAI GPT-3.5-turbo.
use crate::fgpt::{CompletionRequest, Message};
use atty::Stream;
use futures::StreamExt;
use rustyline::error::ReadlineError;
use rustyline::DefaultEditor;
use std::io::{Read, Write};
use tokio::select;

pub async fn run_repl(state: crate::StateRef) -> Result<(), crate::fgpt::Error> {
    let help_text = r#"Type `/help` for more information.
Type `/exit` to exit the program.
Type `/reset` to reset the conversation.
    "#;

    println!("Ctrl-C to cancel, Ctrl-D to exit. \\ for a new line. ✨");

    let mut rl = DefaultEditor::new()?;
    let mut prompt_text = ">> ".to_string();
    let mut question = String::new();

    let mut last_message_id = Some(uuid::Uuid::new_v4().to_string());
    let mut conversation_id: Option<String> = None;

    loop {
        let readline = rl.readline(&prompt_text);
        match readline {
            Ok(line) => {
                let line = line.trim();
                match line {
                    "/exit" => break,
                    "/help" => {
                        println!("{}", help_text);
                        continue;
                    }
                    "/reset" => {
                        conversation_id = None;
                        last_message_id = Some(uuid::Uuid::new_v4().to_string());
                        println!("Conversation reset. ✨");
                        continue;
                    }
                    _ => {}
                }
                if line.ends_with("\\") {
                    prompt_text = ".. ".to_string();
                    question.push_str(&line[..line.len() - 1]);
                    question.push('\n');
                    continue;
                } else {
                    prompt_text = ">> ".to_string();
                    question.push_str(line);
                }
                rl.add_history_entry(&question).ok();
                question = String::new();

                let mut messages = vec![];
                messages.push(Message {
                    role: "user".to_string(),
                    content: line.to_string(),
                    content_type: "text".to_string(),
                });

                select! {
                    r = execute_plain(
                        state.clone(),
                        messages,
                        conversation_id.clone(),
                        last_message_id.clone(),
                    ) => {
                        let r = r?;
                        conversation_id = Some(r.conversation_id);
                        last_message_id = Some(r.last_message_id);
                    }
                    _ = tokio::signal::ctrl_c() => {
                        break;
                    }
                }
            }
            Err(ReadlineError::Interrupted) => {
                break;
            }
            Err(ReadlineError::Eof) => {
                break;
            }
            Err(err) => {
                println!("Error: {:?}", err);
                break;
            }
        }
    }
    Ok(())
}

pub async fn run(state: crate::StateRef) -> Result<(), crate::fgpt::Error> {
    if state.repl || (state.qusetion.is_none() && state.input_file.is_none()) {
        return run_repl(state).await;
    }

    let mut messages = vec![];
    if state.code {
        messages.push(Message {
            role: "system".to_string(),
            content: include_str!("./role.code.prompt").to_string(),
            content_type: "text".to_string(),
        });
    }

    match state.qusetion {
        Some(ref q) => {
            messages.push(Message {
                role: "user".to_string(),
                content: q.clone(),
                content_type: "text".to_string(),
            });
        }
        None => {}
    }

    match state.input_file {
        Some(ref file) => {
            let content = std::fs::read_to_string(file)?;
            messages.push(Message {
                role: "user".to_string(),
                content,
                content_type: "text".to_string(),
            });
        }
        None => {}
    }

    if !atty::is(Stream::Stdin) {
        // it may be a pipe or a file
        let mut content = String::new();
        std::io::stdin().read_to_string(&mut content)?;
        messages.push(Message {
            role: "user".to_string(),
            content,
            content_type: "text".to_string(),
        });
    }

    messages.iter().for_each(|m| log::debug!("{:?}", m));

    let tokenizer = gpt_tokenizer::Default::new();
    let prompt_tokens: usize = messages
        .iter()
        .map(|message| tokenizer.encode(&message.content).len())
        .sum();

    let start_at = std::time::Instant::now();
    let result = execute_plain(
        state.clone(),
        messages,
        None,
        Some(uuid::Uuid::new_v4().to_string()),
    )
    .await?;

    if state.dump_stats {
        let elapsed = start_at.elapsed().as_secs_f64();
        let completion_tokens = tokenizer.encode(&result.textbuf).len();
        let total_tokens = completion_tokens + prompt_tokens;
        let throughput = completion_tokens as f64 / elapsed as f64;

        println!(
            "Total tokens: \x1b[32m{}\x1b[0m, completion tokens: \x1b[32m{}\x1b[0m, prompt tokens: \x1b[32m{}\x1b[0m, elapsed: \x1b[33m{:.1}\x1b[0m secs, throughput: \x1b[33m{:.2}\x1b[0m tps",
            total_tokens,
            completion_tokens,
            prompt_tokens,
            elapsed,
            throughput
        );
    }
    Ok(())
}

struct CompletionResult {
    textbuf: String,
    conversation_id: String,
    last_message_id: String,
}

async fn execute_plain(
    state: crate::StateRef,
    messages: Vec<Message>,
    conversion_id: Option<String>,
    parent_message_id: Option<String>,
) -> Result<CompletionResult, crate::fgpt::Error> {
    let req = CompletionRequest::new(state.clone(), messages, conversion_id, parent_message_id);
    let mut stream = req.stream(state.clone()).await?;

    let mut textbuf = String::new();
    let mut conversation_id = String::new();
    let mut last_message_id = String::new();

    while let Some(message) = stream.next().await {
        match message {
            Ok(crate::fgpt::CompletionEvent::Data(message)) => {
                if message.message.author.role != "assistant" {
                    continue;
                }

                let text = message.message.content.parts.join("\n");
                if textbuf.len() > text.len() {
                    continue;
                }
                let delta_chars = &text[textbuf.len()..];
                textbuf = text.clone();
                print!("{}", delta_chars);
                let _ = std::io::stdout().flush();
                conversation_id = message.conversation_id.clone();
                last_message_id = message.message.id.clone();
            }
            Ok(crate::fgpt::CompletionEvent::Done) => {
                println!();
                break;
            }
            Ok(_) => {}
            Err(e) => {
                log::error!("{:?}", e);
                break;
            }
        }
    }
    Ok(CompletionResult {
        textbuf,
        conversation_id,
        last_message_id,
    })
}