pub mod openai;
pub mod vertex_ai;
use std::fmt::{Debug, Display};
use std::str::FromStr;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::context::ChatEntry;
/// Shared, cloneable handle to a boxed [`Model`] trait object.
// NOTE(review): `Arc<Box<dyn Model>>` is a double indirection; `Arc<dyn Model>`
// would drop one allocation and pointer hop. Changing the alias would break
// callers constructing `Arc::new(Box::new(...))`, so it is left as-is.
pub type ModelRef = Arc<Box<dyn Model>>;
/// Errors surfaced by model backends and by model-name parsing.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Failure from the OpenAI client; the underlying error is kept as the
    /// `source` via `#[from]` (the display message does not include it).
    #[error("Model invocation failed")]
    OpenAIError(#[from] openai::OpenAIError),
    /// The backend produced no usable response for a query.
    #[error("No response from the model")]
    NoResponseFromModel,
    /// A model name did not match any [`SupportedModel`]; carries the
    /// offending string (see `SupportedModel::from_str`).
    #[error("Model not supported: {0}")]
    ModelNotSupported(String),
    /// Failure from the Vertex AI client, converted via `#[from]`.
    #[error("Vertex AI error: {0}")]
    VertexAIError(#[from] gcp_vertex_ai_generative_language::Error),
    /// The model's output was withheld — presumably by the provider's
    /// content filter; confirm against the backend implementations.
    #[error("Filtered output")]
    Filtered,
}
/// Author of a chat message.
///
/// Serialized in lowercase (e.g. `Role::Assistant` ⇄ `"assistant"`), matching
/// the names produced by its `Display` impl.
// Fieldless enum: `Copy`, `Eq`, and `Hash` are free and were missing
// alongside `PartialEq` (clippy: derive_partial_eq_without_eq).
#[derive(Debug, Serialize, Deserialize, Clone, Copy, Default, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System / instruction message.
    System,
    /// End-user message; the default role.
    #[default]
    User,
    /// Model-generated message.
    Assistant,
    /// Function-result message.
    Function,
    /// Tool-result message.
    Tool,
}
impl Display for Role {
    /// Writes the lowercase wire name of the role (mirrors the serde
    /// `rename_all = "lowercase"` names on [`Role`]).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Role::System => "system",
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::Function => "function",
            Role::Tool => "tool",
        };
        f.write_str(label)
    }
}
/// Token accounting for chat prompts, implemented per model backend.
#[async_trait::async_trait]
pub trait ChatEntryTokenNumber {
    /// Number of tokens `input` consumes for this model.
    async fn num_tokens(&self, input: ChatInput) -> usize;
    /// The model's total context size, in tokens.
    async fn context_size(&self) -> usize;
}
/// A prompt assembled from context, few-shot examples, and chat history.
#[derive(Debug, Clone)]
pub struct ChatInput {
    /// Leading context entries sent before the conversation.
    pub(crate) context: Vec<ChatEntry>,
    /// Few-shot example pairs — presumably (prompt, response); confirm
    /// against the backends that consume them.
    pub(crate) examples: Vec<(ChatEntry, ChatEntry)>,
    /// The live conversation entries.
    pub(crate) chat: Vec<ChatEntry>,
}
/// A chat model backend that can be queried for a completion.
#[async_trait::async_trait]
pub trait Model: ChatEntryTokenNumber + Send + Sync {
    /// Sends `input` to the model and returns its response.
    ///
    /// `max_tokens` optionally caps the completion length; `None` leaves
    /// the limit to the backend.
    async fn query(
        &self,
        input: ChatInput,
        max_tokens: Option<usize>,
    ) -> Result<ModelResponse, Error>;
}
/// A model's reply plus optional metadata reported by the backend.
#[derive(Clone)]
pub struct ModelResponse {
    /// The generated text.
    pub msg: String,
    /// Token usage, when the backend reports it.
    pub usage: Option<Usage>,
    /// Backend-specific finish/stop reason string, when reported.
    pub finish_reason: Option<String>,
}
impl Debug for ModelResponse {
    /// Multi-line debug rendering: the message body on its own lines, with
    /// `usage` and `finish_reason` sections emitted only when present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ModelResponse { \n")?;
        write!(f, "msg: \n{}, \n", self.msg)?;
        if let Some(usage) = self.usage.as_ref() {
            write!(f, "usage: {:#?}, \n", usage)?;
        }
        if let Some(reason) = self.finish_reason.as_ref() {
            write!(f, "finish_reason: {}, \n", reason)?;
        }
        f.write_str("}")
    }
}
/// Token counts for a single model invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Total tokens — presumably prompt + completion as reported by the
    /// backend; this struct does not enforce the invariant.
    pub total_tokens: u32,
}
#[derive(Clone, Serialize, Deserialize, Default)]
pub enum SupportedModel {
#[default]
GPT3_5Turbo,
GPT3_5Turbo0613,
GPT3_5Turbo16k,
Vicuna7B1_1,
Vicuna13B1_1,
ChatBison001,
}
impl Display for SupportedModel {
    /// Writes the canonical dashed model name (the inverse of `FromStr`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SupportedModel::GPT3_5Turbo => "gpt-3.5-turbo",
            SupportedModel::GPT3_5Turbo0613 => "gpt-3.5-turbo-0613",
            SupportedModel::GPT3_5Turbo16k => "gpt-3.5-turbo-16k",
            SupportedModel::Vicuna7B1_1 => "vicuna-7b-1.1",
            SupportedModel::Vicuna13B1_1 => "vicuna-13b-1.1",
            SupportedModel::ChatBison001 => "chat-bison-001",
        };
        f.write_str(name)
    }
}
impl Debug for SupportedModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SupportedModel::GPT3_5Turbo => write!(f, "gpt-3.5-turbo"),
SupportedModel::GPT3_5Turbo0613 => write!(f, "gpt-3.5-turbo-0613"),
SupportedModel::GPT3_5Turbo16k => write!(f, "gpt-3.5-turbo-16k"),
SupportedModel::Vicuna7B1_1 => write!(f, "vicuna-7b-1.1"),
SupportedModel::Vicuna13B1_1 => write!(f, "vicuna-13b-1.1"),
SupportedModel::ChatBison001 => write!(f, "chat-bison-001"),
}
}
}
impl FromStr for SupportedModel {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"gpt-3.5-turbo" => Ok(Self::GPT3_5Turbo),
"gpt-3.5-turbo-0613" => Ok(Self::GPT3_5Turbo0613),
"gpt-3.5-turbo-16k" => Ok(Self::GPT3_5Turbo16k),
"vicuna-7b-1.1" => Ok(Self::Vicuna7B1_1),
"vicuna-13b-1.1" => Ok(Self::Vicuna13B1_1),
"chat-bison-001" => Ok(Self::ChatBison001),
_ => Err(Error::ModelNotSupported(s.to_string())),
}
}
}
/// Makes the model selectable on the command line when the `clap` feature
/// is enabled.
#[cfg(feature = "clap")]
impl clap::ValueEnum for SupportedModel {
    /// All supported models, in declaration order.
    fn value_variants<'a>() -> &'a [Self] {
        &[
            SupportedModel::GPT3_5Turbo,
            SupportedModel::GPT3_5Turbo0613,
            SupportedModel::GPT3_5Turbo16k,
            SupportedModel::Vicuna7B1_1,
            SupportedModel::Vicuna13B1_1,
            SupportedModel::ChatBison001,
        ]
    }

    /// CLI value for each variant.
    ///
    /// Delegates to `Display` so the clap name, `Display`, and `FromStr`
    /// cannot drift apart; the original duplicated the string table a
    /// third time here. Output is unchanged — the per-variant strings were
    /// identical to the `Display` names.
    fn to_possible_value(&self) -> Option<clap::builder::PossibleValue> {
        Some(clap::builder::PossibleValue::new(self.to_string()))
    }
}