use std::path::PathBuf;
use async_openai::types::CreateChatCompletionRequestArgs;
use clap::{Args, Parser, Subcommand, ValueEnum};
use strum_macros::{Display, EnumVariantNames, FromRepr};
/// Top-level command-line interface.
///
/// The chat arguments are both a subcommand (`Command::Chat`) and flattened
/// here, so chat flags can be passed without naming the subcommand —
/// presumably chat is the implicit default when `command` is `None`
/// (TODO confirm against the dispatch code, not visible in this file).
#[derive(Debug, Parser)]
#[command(author, version, about, long_about = None)]
pub struct Cli {
/// Optional explicit subcommand; `None` means no subcommand was given.
#[command(subcommand)]
pub command: Option<Command>,
/// Chat flags accepted at the top level (flattened into the root parser).
#[command(flatten)]
pub chat: ChatCommandArgs,
/// Enable verbose output (`-v` / `--verbose`).
#[arg(long, short = 'v')]
pub verbose: bool,
}
/// Available subcommands.
#[derive(Debug, Subcommand)]
pub enum Command {
/// Start a chat completion request with the given arguments.
Chat(ChatCommandArgs),
/// Show logs (takes no arguments).
Logs,
}
/// Arguments for the `chat` subcommand, mirroring the OpenAI chat-completion
/// request parameters (see the `From<&ChatCommandArgs>` conversion below).
///
/// All request knobs are `Option` so that unset flags are simply omitted
/// from the request builder.
#[derive(Debug, Args)]
pub struct ChatCommandArgs {
/// Model to use. Has a (hidden) default, so this field is effectively
/// always `Some` after parsing.
#[arg(
long,
short = 'm',
default_value = "gpt-4-1106-preview",
hide_default_value = true,
)]
pub model: Option<Model>,
/// Sampling temperature.
#[arg(long, short = 't')]
temperature: Option<f32>,
/// Nucleus-sampling probability mass.
#[arg(long, short = 'p')]
top_p: Option<f32>,
/// Number of completions to request (`-n`).
#[arg(long, short)]
n: Option<u8>,
/// Whether to stream the response (`--stream true|false`).
#[arg(long, short)]
stream: Option<bool>,
/// Stop sequences; may be given multiple times.
#[arg(long)]
stop: Option<Vec<String>>,
/// Maximum tokens to generate. NOTE(review): short flag is `-c`,
/// not `-m` (taken by `--model`).
#[arg(long, short = 'c')]
max_tokens: Option<u16>,
/// Presence penalty.
#[arg(long)]
presence_penalty: Option<f32>,
/// Frequency penalty.
#[arg(long)]
frequency_penalty: Option<f32>,
/// End-user identifier forwarded to the API.
#[arg(long, short = 'u')]
user: Option<String>,
/// Input file path — presumably read as the message body; confirm at call site.
#[arg(long, short = 'i')]
input: Option<PathBuf>,
/// Output file path — presumably where the response is written; confirm at call site.
#[arg(long, short = 'o')]
output: Option<PathBuf>,
/// System prompt text.
#[arg(long)]
pub system: Option<String>,
/// User message, captured as trailing arguments after `--`.
#[arg(last(true))]
pub message: Option<Vec<String>>,
}
/// OpenAI chat model identifiers.
///
/// Each variant carries two names that MUST agree:
/// - `strum(serialize = ...)`: the string produced by `Display`, which is
///   sent to the API as the model id (via `to_string()` in the request
///   conversion below);
/// - `value(name = ...)`: the string accepted on the command line by clap.
///
/// Fixes: `Gpt35Turbo0301` previously serialized as "gpt-3.5-turbo-0314"
/// (a typo — wrong date suffix), and `Gpt35Turbo16k` as "gpt-3.5-16k"
/// (missing "-turbo"); both would have sent invalid model ids to the API.
#[derive(
    Debug,
    Display,
    Default,
    Copy,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    ValueEnum,
    EnumVariantNames,
    FromRepr,
)]
#[value()]
pub enum Model {
    #[strum(serialize = "gpt-3.5-turbo")]
    #[value(name = "gpt-3.5-turbo", alias = "3.5")]
    Gpt35Turbo,
    // Was "gpt-3.5-turbo-0314": the Display string must match the 0301 id.
    #[strum(serialize = "gpt-3.5-turbo-0301")]
    #[value(name = "gpt-3.5-turbo-0301")]
    Gpt35Turbo0301,
    #[strum(serialize = "gpt-3.5-turbo-0613")]
    #[value(name = "gpt-3.5-turbo-0613")]
    Gpt35Turbo0613,
    #[strum(serialize = "gpt-3.5-turbo-1106")]
    #[value(name = "gpt-3.5-turbo-1106")]
    Gpt35Turbo1106,
    // Was "gpt-3.5-16k": the real model id includes "-turbo".
    #[strum(serialize = "gpt-3.5-turbo-16k")]
    #[value(name = "gpt-3.5-turbo-16k", alias = "3.5-16k")]
    Gpt35Turbo16k,
    #[strum(serialize = "gpt-3.5-turbo-16k-0613")]
    #[value(name = "gpt-3.5-turbo-16k-0613")]
    Gpt35Turbo16k0613,
    #[strum(serialize = "gpt-4")]
    #[value(name = "gpt-4", alias = "4")]
    Gpt4,
    #[strum(serialize = "gpt-4-0314")]
    #[value(name = "gpt-4-0314")]
    Gpt40314,
    #[strum(serialize = "gpt-4-0613")]
    #[value(name = "gpt-4-0613")]
    Gpt40613,
    // Default model — keep in sync with `default_value` on `ChatCommandArgs::model`.
    #[default]
    #[strum(serialize = "gpt-4-1106-preview")]
    #[value(name = "gpt-4-1106-preview")]
    Gpt41106Preview,
    #[strum(serialize = "gpt-4-32k")]
    #[value(name = "gpt-4-32k", alias = "4-32k")]
    Gpt432k,
    #[strum(serialize = "gpt-4-32k-0314")]
    #[value(name = "gpt-4-32k-0314")]
    Gpt432k0314,
    #[strum(serialize = "gpt-4-32k-0613")]
    #[value(name = "gpt-4-32k-0613")]
    Gpt432k0613,
}
// Kept alive even if temporarily unused; conversion is exercised only by
// the chat subcommand path.
#[allow(dead_code)]
impl From<&ChatCommandArgs> for CreateChatCompletionRequestArgs {
    /// Builds a chat-completion request builder from the parsed CLI flags,
    /// forwarding only the options the user actually supplied and leaving
    /// everything else at the builder's defaults.
    ///
    /// Note: `input`, `output`, `system`, and `message` are intentionally
    /// not forwarded here — they are handled elsewhere, not as request
    /// builder fields.
    fn from(value: &ChatCommandArgs) -> Self {
        let mut builder = CreateChatCompletionRequestArgs::default();
        // `if let` instead of `Option::map` for side effects: the previous
        // `value.x.map(|x| builder.x(x))` form discarded every result and
        // trips clippy's `option_map_unit_fn` lint.
        if let Some(model) = value.model {
            builder.model(model.to_string());
        }
        if let Some(temperature) = value.temperature {
            builder.temperature(temperature);
        }
        if let Some(top_p) = value.top_p {
            builder.top_p(top_p);
        }
        if let Some(n) = value.n {
            builder.n(n);
        }
        if let Some(stream) = value.stream {
            builder.stream(stream);
        }
        if let Some(stop) = value.stop.as_ref() {
            builder.stop(stop);
        }
        if let Some(max_tokens) = value.max_tokens {
            builder.max_tokens(max_tokens);
        }
        if let Some(presence_penalty) = value.presence_penalty {
            builder.presence_penalty(presence_penalty);
        }
        if let Some(frequency_penalty) = value.frequency_penalty {
            builder.frequency_penalty(frequency_penalty);
        }
        if let Some(user) = value.user.as_ref() {
            builder.user(user);
        }
        builder
    }
}