smartcat 0.3.0

Putting a brain behind `cat`. A CLI to bring language models into the Unix ecosystem 🐈‍⬛
use log::debug;
use std::io::{self, Read, Result, Write};

use crate::config::{get_api_config, Message, Prompt, PLACEHOLDER_TOKEN};
use crate::request::{make_authenticated_request, OpenAiResponse};

// [tmp] mostly template to write tests
/// Copies everything from `input` to `output`, wrapping it in `prefix` and `suffix`.
///
/// The input is streamed through a fixed 1024-byte buffer, so it never has to be held
/// in memory as a whole. `prefix` is written just before the first non-empty chunk and
/// `suffix` after the last one; if `input` yields no bytes at all, nothing is written.
///
/// # Errors
///
/// Returns the first I/O error reported by `input` or `output`. Reads that fail with
/// `ErrorKind::Interrupted` are retried rather than reported, as `Read::read_to_end` does.
pub fn chunk_process_input<R: Read, W: Write>(
    input: &mut R,
    output: &mut W,
    prefix: &str,
    suffix: &str,
) -> Result<()> {
    let mut wrote_prefix = false;
    let mut buffer = [0; 1024];
    loop {
        let n = match input.read(&mut buffer) {
            Ok(0) => break, // end of input
            Ok(n) => n,
            // The read was interrupted before any data arrived; it is safe to retry.
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        // The prefix is deferred until we know there is at least one byte of input.
        if !wrote_prefix {
            output.write_all(prefix.as_bytes())?;
            wrote_prefix = true;
        }
        output.write_all(&buffer[..n])?;
    }

    if wrote_prefix {
        // we actually got some input
        output.write_all(suffix.as_bytes())?;
    }

    Ok(())
}

/// Sends the input through the prompt's API and writes the model's answer to `output`.
///
/// The input is `input_string` when it is `Some`, otherwise everything readable from
/// `input`. Every occurrence of `PLACEHOLDER_TOKEN` in the prompt's messages is replaced
/// by that input, the request is sent using the API config for `prompt.api`, and the
/// content of the first response choice is appended to the prompt as an assistant
/// message. When `repeat_input` is set, the input followed by a newline is written to
/// `output` before the response text.
///
/// Returns the updated prompt. If the input is empty, no request is made, nothing is
/// written, and the prompt is returned unchanged.
///
/// # Errors
///
/// Returns an `io::Error` if reading `input` fails, if the input is not valid UTF-8,
/// if the API call fails (error status or transport error), if the response body
/// cannot be decoded, if the response contains no choices, or if writing to `output`
/// fails.
pub fn process_input_with_request<R: Read, W: Write>(
    mut prompt: Prompt,
    input: &mut R,
    input_string: Option<String>,
    output: &mut W,
    repeat_input: bool,
) -> Result<Prompt> {
    let mut input = match input_string {
        Some(input_string) => input_string,
        None => {
            let mut buffer = Vec::new();
            input.read_to_end(&mut buffer)?;

            // Piped data is user-controlled: report bad encoding instead of panicking.
            String::from_utf8(buffer).map_err(|e| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("input is not valid UTF-8: {}", e),
                )
            })?
        }
    };

    // nothing to do if no input
    if input.is_empty() {
        return Ok(prompt);
    }

    // insert the input in the messages with placeholders
    for message in prompt.messages.iter_mut() {
        message.content = message.content.replace(PLACEHOLDER_TOKEN, &input)
    }
    // fetch the api config tied to the prompt
    let api_config = get_api_config(&prompt.api.to_string());

    // make the request; ureq errors are folded into io::Error so callers deal with a
    // single error type, and the server's response body is kept for diagnostics
    let response: OpenAiResponse = make_authenticated_request(api_config, &prompt)
        .map_err(|e| match e {
            ureq::Error::Status(status, response) => {
                let body = match response.into_string() {
                    Ok(body) => body,
                    Err(_) => "(non-UTF-8 response)".to_owned(),
                };
                io::Error::new(
                    io::ErrorKind::Other,
                    format!(
                        "API call failed with status code {} and body: {}",
                        status, body
                    ),
                )
            }
            ureq::Error::Transport(transport) => {
                io::Error::new(io::ErrorKind::Other, transport.to_string())
            }
        })?
        .into_json()?;

    // An empty `choices` array comes from the remote service, so it is reported as an
    // error rather than treated as an impossible state.
    let response_text = response
        .choices
        .first()
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "API response contained no choices",
            )
        })?
        .message
        .content
        .as_str();
    debug!("{}", response_text);

    prompt.messages.push(Message::assistant(&response_text));

    if repeat_input {
        input.push('\n');
        output.write_all(input.as_bytes())?;
    }

    output.write_all(response_text.as_bytes())?;

    Ok(prompt)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Generates a test that streams `$input` through `chunk_process_input` and checks
    /// that the output is `$prefix` + `$input` + `$suffix`, or empty when `$input` is empty.
    macro_rules! test_process_input {
        ($test_name:ident, $prefix:expr, $suffix:expr, $input:expr) => {
            #[test]
            fn $test_name() {
                let input = $input;
                let input_bytes = input.as_bytes();
                let mut output = Cursor::new(Vec::new());

                chunk_process_input(
                    &mut Cursor::new(input_bytes),
                    &mut output,
                    $prefix,
                    $suffix,
                )
                .expect("chunk_process_input failed");

                let expected_output = if !input_bytes.is_empty() {
                    format!("{}{}{}", $prefix, input, $suffix)
                } else {
                    "".into()
                };

                let expected_output_as_bytes = expected_output.as_bytes();
                let output_data: Vec<u8> = output.into_inner();
                // assert_eq! only borrows its operands, so output_data is still usable
                // in the message; show the actual output, not the expected one twice.
                assert_eq!(
                    expected_output_as_bytes,
                    output_data,
                    "\nexpected: {}\nGot: {}",
                    expected_output,
                    String::from_utf8_lossy(&output_data)
                );
            }
        };
    }

    test_process_input!(
        test_with_prefix_and_suffix,
        "Prefix: ",
        " Suffix",
        "Input data"
    );
    test_process_input!(
        test_with_custom_prefix_suffix,
        "Start: ",
        " End",
        "Custom input"
    );
    test_process_input!(test_empty_input, "Pre: ", " Post", "");
    // Larger than the 1024-byte read buffer, so several reads are needed.
    test_process_input!(
        test_multi_chunk_input,
        "<",
        ">",
        "0123456789".repeat(300)
    );
}