/// Home of the `Sealed` supertrait required by `KnownEngineDefinition`.
///
/// This module is deliberately not `pub`: code outside the module that
/// declares it cannot name `private::Sealed`, and therefore cannot satisfy the
/// `KnownEngineDefinition: private::Sealed` bound for its own types.
mod private {
    /// Empty marker trait used purely to seal `KnownEngineDefinition`.
    pub trait Sealed {}
}
use std::borrow::Cow;
/// An engine whose identifier and token limit are fixed at compile time.
///
/// The `private::Sealed` supertrait lives in a non-`pub` module, so only code
/// that can see that module is able to implement this trait.
pub trait KnownEngineDefinition: private::Sealed {
    /// Identifier string of the engine.
    const ID: &'static str;

    /// Token limit of the engine; implementors that do not override it get 1024.
    const MAX_TOKENS: usize = 1024;

    /// The same engine expressed as a [`CustomEngineDefinition`], assembled
    /// from [`Self::ID`] and [`Self::MAX_TOKENS`] via
    /// `CustomEngineDefinition::r#static`.
    const AS_CUSTOM_ENGINE_DEFINITION: CustomEngineDefinition =
        CustomEngineDefinition::r#static(Self::ID, Self::MAX_TOKENS);
}
/// Type-level marker for the `gptj_6B` engine, which overrides the token
/// limit to 2048.
///
/// The private `_priv` field stops code outside this module from building a
/// value with a struct literal; in this file the type is only used for its
/// [`KnownEngineDefinition`] constants.
pub struct GptJ6B {
    _priv: (),
}

impl private::Sealed for GptJ6B {}

impl KnownEngineDefinition for GptJ6B {
    const ID: &'static str = "gptj_6B";
    const MAX_TOKENS: usize = 2048;
}
/// Type-level marker for the `boris_6B` engine; it keeps the trait's default
/// `MAX_TOKENS`.
///
/// The private `_priv` field stops code outside this module from building a
/// value with a struct literal; in this file the type is only used for its
/// [`KnownEngineDefinition`] constants.
pub struct Boris6B {
    _priv: (),
}

impl private::Sealed for Boris6B {}

impl KnownEngineDefinition for Boris6B {
    const ID: &'static str = "boris_6B";
}
/// Type-level marker for the `fairseq_gpt_13B` engine; it keeps the trait's
/// default `MAX_TOKENS`.
///
/// The private `_priv` field stops code outside this module from building a
/// value with a struct literal; in this file the type is only used for its
/// [`KnownEngineDefinition`] constants.
pub struct FairseqGpt13B {
    _priv: (),
}

impl private::Sealed for FairseqGpt13B {}

impl KnownEngineDefinition for FairseqGpt13B {
    const ID: &'static str = "fairseq_gpt_13B";
}
/// An engine described by an arbitrary identifier and token limit.
///
/// `id` is a `Cow<'static, str>`, so it holds either a borrowed `'static`
/// string (no allocation) or an owned `String`.
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde_derives", derive(serde::Serialize, serde::Deserialize))]
pub struct CustomEngineDefinition {
    /// Identifier string of the engine.
    pub id: Cow<'static, str>,
    /// Token limit of the engine.
    pub max_tokens: usize,
}

impl CustomEngineDefinition {
    /// Creates a definition that borrows `id` (`Cow::Borrowed`).
    ///
    /// Being a `const fn`, it can initialize constants.
    pub const fn r#static(id: &'static str, max_tokens: usize) -> Self {
        Self {
            max_tokens,
            id: Cow::Borrowed(id),
        }
    }

    /// Creates a definition that takes ownership of `id` (`Cow::Owned`).
    pub const fn dynamic(id: String, max_tokens: usize) -> Self {
        Self {
            max_tokens,
            id: Cow::Owned(id),
        }
    }

    /// Creates a definition from anything convertible into `Cow<'static, str>`.
    ///
    /// A `&'static str` ends up borrowed and a `String` ends up owned.
    pub fn new(id: impl Into<Cow<'static, str>>, max_tokens: usize) -> Self {
        let converted: Cow<'static, str> = id.into();
        Self {
            max_tokens,
            id: converted,
        }
    }
}
/// Selects an engine: one of the built-in engines or a caller-supplied
/// [`CustomEngineDefinition`].
///
/// Declaration order is significant: the derived `Ord`/`PartialOrd` compare
/// variants in the order listed below.
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde_derives", derive(serde::Serialize, serde::Deserialize))]
pub enum EngineDefinition {
    /// Built-in engine backed by the `GptJ6B` marker type.
    GptJ6B,
    /// Built-in engine backed by the `Boris6B` marker type.
    Boris6B,
    /// Built-in engine backed by the `FairseqGpt13B` marker type.
    FairseqGpt13B,
    /// Engine described at runtime.
    Custom(CustomEngineDefinition),
}
impl EngineDefinition {
    /// Returns this engine as a [`CustomEngineDefinition`].
    ///
    /// Built-in engines yield `Cow::Owned` copies of their
    /// `AS_CUSTOM_ENGINE_DEFINITION` constant. `Custom` engines are returned as
    /// `Cow::Borrowed`, so their definition is not cloned.
    pub const fn to_custom_engine_definition(&self) -> Cow<'_, CustomEngineDefinition> {
        match self {
            Self::GptJ6B => Cow::Owned(GptJ6B::AS_CUSTOM_ENGINE_DEFINITION),
            Self::Boris6B => Cow::Owned(Boris6B::AS_CUSTOM_ENGINE_DEFINITION),
            Self::FairseqGpt13B => Cow::Owned(FairseqGpt13B::AS_CUSTOM_ENGINE_DEFINITION),
            Self::Custom(custom_engine) => Cow::Borrowed(custom_engine),
        }
    }

    /// Returns the engine's identifier.
    pub fn id(&self) -> &str {
        match self {
            Self::GptJ6B => GptJ6B::ID,
            Self::Boris6B => Boris6B::ID,
            Self::FairseqGpt13B => FairseqGpt13B::ID,
            Self::Custom(custom_engine) => &custom_engine.id,
        }
    }

    /// Returns the engine's token limit.
    ///
    /// Like [`Self::id`], this reads the value directly instead of building a
    /// temporary `Cow` through [`Self::to_custom_engine_definition`]. As a
    /// result it needs no `Deref` and can be evaluated in `const` contexts.
    pub const fn max_tokens(&self) -> usize {
        match self {
            Self::GptJ6B => GptJ6B::MAX_TOKENS,
            Self::Boris6B => Boris6B::MAX_TOKENS,
            Self::FairseqGpt13B => FairseqGpt13B::MAX_TOKENS,
            Self::Custom(custom_engine) => custom_engine.max_tokens,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_custom_engine_definition_static() {
        let engine = CustomEngineDefinition::r#static("static", 42);
        assert!(matches!(engine.id, Cow::Borrowed("static")));
        assert_eq!(engine.max_tokens, 42);
    }

    #[test]
    fn test_custom_engine_definition_dynamic() {
        let engine = CustomEngineDefinition::dynamic("dynamic".into(), 42);
        assert!(matches!(&engine.id, Cow::Owned(id) if id == "dynamic"));
        assert_eq!(engine.max_tokens, 42);
    }

    #[test]
    fn test_custom_engine_definition_new() {
        let borrowed = CustomEngineDefinition::new("new", 42);
        assert!(matches!(borrowed.id, Cow::Borrowed("new")));
        let owned = CustomEngineDefinition::new(String::from("new"), 42);
        assert!(matches!(&owned.id, Cow::Owned(id) if id == "new"));
        // `Cow`'s `PartialEq` compares contents, so both representations are equal.
        assert_eq!(borrowed, owned);
        assert_eq!(borrowed.max_tokens, 42);
    }

    #[test]
    fn test_known_engine_constants() {
        assert_eq!(GptJ6B::ID, "gptj_6B");
        assert_eq!(GptJ6B::MAX_TOKENS, 2048);
        assert_eq!(Boris6B::ID, "boris_6B");
        assert_eq!(Boris6B::MAX_TOKENS, 1024);
        assert_eq!(FairseqGpt13B::ID, "fairseq_gpt_13B");
        assert_eq!(FairseqGpt13B::MAX_TOKENS, 1024);
    }

    #[test]
    fn test_engine_definition_to_custom_engine_definition() {
        assert_eq!(
            EngineDefinition::GptJ6B.to_custom_engine_definition(),
            Cow::Owned(GptJ6B::AS_CUSTOM_ENGINE_DEFINITION)
        );
        assert_eq!(
            EngineDefinition::Boris6B.to_custom_engine_definition(),
            Cow::Owned(Boris6B::AS_CUSTOM_ENGINE_DEFINITION)
        );
        assert_eq!(
            EngineDefinition::FairseqGpt13B.to_custom_engine_definition(),
            Cow::Owned(FairseqGpt13B::AS_CUSTOM_ENGINE_DEFINITION)
        );

        // `Cow`'s `PartialEq` ignores the Owned/Borrowed distinction, so the
        // variant is checked explicitly as well.
        assert!(matches!(
            EngineDefinition::GptJ6B.to_custom_engine_definition(),
            Cow::Owned(_)
        ));
        assert!(matches!(
            EngineDefinition::Boris6B.to_custom_engine_definition(),
            Cow::Owned(_)
        ));
        assert!(matches!(
            EngineDefinition::FairseqGpt13B.to_custom_engine_definition(),
            Cow::Owned(_)
        ));

        let custom_engine_definition = CustomEngineDefinition::new("custom", 42);
        let engine = EngineDefinition::Custom(custom_engine_definition.clone());
        let converted = engine.to_custom_engine_definition();
        assert!(matches!(converted, Cow::Borrowed(_)));
        assert_eq!(*converted, custom_engine_definition);
    }

    #[test]
    fn test_engine_definition_id() {
        assert_eq!(EngineDefinition::GptJ6B.id(), GptJ6B::ID);
        assert_eq!(EngineDefinition::Boris6B.id(), Boris6B::ID);
        assert_eq!(EngineDefinition::FairseqGpt13B.id(), FairseqGpt13B::ID);
        assert_eq!(
            EngineDefinition::Custom(CustomEngineDefinition::r#static("static", 42)).id(),
            "static"
        );
        assert_eq!(
            EngineDefinition::Custom(CustomEngineDefinition::dynamic("dynamic".into(), 42)).id(),
            "dynamic"
        );
    }

    #[test]
    fn test_engine_definition_max_tokens() {
        assert_eq!(EngineDefinition::GptJ6B.max_tokens(), GptJ6B::MAX_TOKENS);
        assert_eq!(EngineDefinition::Boris6B.max_tokens(), Boris6B::MAX_TOKENS);
        assert_eq!(
            EngineDefinition::FairseqGpt13B.max_tokens(),
            FairseqGpt13B::MAX_TOKENS
        );
        assert_eq!(
            EngineDefinition::Custom(CustomEngineDefinition::r#static("static", 42)).max_tokens(),
            42
        );
        assert_eq!(
            EngineDefinition::Custom(CustomEngineDefinition::dynamic("dynamic".into(), 7))
                .max_tokens(),
            7
        );
    }
}