use crate::error::ProtoError;
use serde::{Deserialize, Serialize};
/// Fallback values that [`Request::resolve`] substitutes for any optional
/// request field that was left as `None`.
pub mod defaults {
    /// Used when `Request::temperature` is `None`.
    pub const TEMPERATURE: f64 = 1.0;

    /// Used when `Request::top_p` is `None`.
    pub const TOP_P: f64 = 0.95;

    /// Used when `Request::top_k` is `None`.
    pub const TOP_K: u32 = 64;

    /// Used when `Request::max_tokens` is `None`.
    pub const MAX_TOKENS: u32 = 1000;

    /// Used when `Request::stream` is `None`.
    pub const STREAM: bool = true;
}
/// The only values [`ImageTokenBudget::new`] accepts; any other `u32` is
/// rejected there. The same list is printed in the error message produced by
/// [`Request::resolve`] for an invalid budget.
pub const VALID_IMAGE_TOKEN_BUDGETS: [u32; 5] = [70, 140, 280, 560, 1120];
/// Role tag carried by each [`Message`].
///
/// Because of `rename_all = "lowercase"`, the variants are (de)serialized as
/// the strings `"system"`, `"user"` and `"assistant"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// Serialized as `"system"`.
System,
/// Serialized as `"user"`.
User,
/// Serialized as `"assistant"`.
Assistant,
}
/// One entry of [`Request::messages`]: a [`Role`] paired with text content.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Message {
/// Which role this message is attributed to.
pub role: Role,
/// The message text.
pub content: String,
}
/// A `u32` that has been checked against [`VALID_IMAGE_TOKEN_BUDGETS`].
///
/// The wrapped field is private, so code outside this module has to go
/// through [`ImageTokenBudget::new`] to obtain a value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImageTokenBudget(u32);

impl ImageTokenBudget {
    /// Wraps `value` if it is one of [`VALID_IMAGE_TOKEN_BUDGETS`].
    ///
    /// Returns `None` for every other value.
    pub fn new(value: u32) -> Option<Self> {
        // Look the candidate up in the allow-list and wrap the match, if any.
        VALID_IMAGE_TOKEN_BUDGETS
            .iter()
            .copied()
            .find(|&allowed| allowed == value)
            .map(Self)
    }

    /// Returns the wrapped budget value.
    pub fn get(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}
/// A request as it appears on the wire, before defaults are applied.
///
/// Every field except `messages` carries `#[serde(default)]`, so it may be
/// absent from the input. Unset or empty fields are skipped on serialization.
/// Call [`Request::resolve`] to validate the request and fill in defaults.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct Request {
/// Identifier string; defaults to empty on input and is omitted from the
/// output when empty. `resolve` passes it through unchanged.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub id: String,
/// The messages of the request. Has no serde default, so it must be present
/// on input; `resolve` additionally rejects an empty list.
pub messages: Vec<Message>,
/// Optional; `resolve` falls back to `defaults::TEMPERATURE` when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// Optional; `resolve` falls back to `defaults::TOP_P` when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
/// Optional; `resolve` falls back to `defaults::TOP_K` when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
/// Optional; `resolve` falls back to `defaults::MAX_TOKENS` when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Optional; `resolve` falls back to `defaults::STREAM` when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// Optional raw budget. Not validated at deserialization time; `resolve`
/// requires it to be one of `VALID_IMAGE_TOKEN_BUDGETS` when present.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub image_token_budget: Option<u32>,
/// Grammar string; defaults to empty on input and is omitted from the
/// output when empty. `resolve` passes it through unchanged.
// NOTE(review): the grammar format is not visible in this file — confirm with the consumer.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub grammar: String,
}
/// A validated [`Request`] with every optional parameter filled in.
///
/// Produced by [`Request::resolve`]; unlike `Request`, the parameter fields
/// are plain values rather than `Option`s.
#[derive(Debug, Clone, PartialEq)]
pub struct Resolved {
/// Copied unchanged from `Request::id`.
pub id: String,
/// Copied unchanged from `Request::messages`; `resolve` guarantees it is non-empty.
pub messages: Vec<Message>,
/// The request's value, or `defaults::TEMPERATURE` if it was unset.
pub temperature: f64,
/// The request's value, or `defaults::TOP_P` if it was unset.
pub top_p: f64,
/// The request's value, or `defaults::TOP_K` if it was unset.
pub top_k: u32,
/// The request's value, or `defaults::MAX_TOKENS` if it was unset.
pub max_tokens: u32,
/// The request's value, or `defaults::STREAM` if it was unset.
pub stream: bool,
/// The validated budget, or `None` if the request did not specify one
/// (no default is substituted for this field).
pub image_token_budget: Option<ImageTokenBudget>,
/// Copied unchanged from `Request::grammar`.
pub grammar: String,
}
impl Request {
    /// Validates the request and fills every unset optional parameter with
    /// the matching constant from [`defaults`].
    ///
    /// `id`, `messages` and `grammar` are moved into the result unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`ProtoError::InvalidRequest`] when
    /// - `messages` is empty, or
    /// - `image_token_budget` is set to a value that is not listed in
    ///   [`VALID_IMAGE_TOKEN_BUDGETS`].
    ///
    /// The `messages` check runs first, so it wins when both apply.
    pub fn resolve(self) -> Result<Resolved, ProtoError> {
        // Take the request apart once so each field can be moved independently.
        let Self {
            id,
            messages,
            temperature,
            top_p,
            top_k,
            max_tokens,
            stream,
            image_token_budget,
            grammar,
        } = self;

        if messages.is_empty() {
            return Err(ProtoError::InvalidRequest(
                "messages must not be empty".into(),
            ));
        }

        // An absent budget stays absent; a present one must pass validation.
        let image_token_budget = image_token_budget
            .map(|v| {
                ImageTokenBudget::new(v).ok_or_else(|| {
                    ProtoError::InvalidRequest(format!(
                        "image_token_budget {v} not in {VALID_IMAGE_TOKEN_BUDGETS:?}"
                    ))
                })
            })
            .transpose()?;

        let resolved = Resolved {
            id,
            messages,
            temperature: temperature.unwrap_or(defaults::TEMPERATURE),
            top_p: top_p.unwrap_or(defaults::TOP_P),
            top_k: top_k.unwrap_or(defaults::TOP_K),
            max_tokens: max_tokens.unwrap_or(defaults::MAX_TOKENS),
            stream: stream.unwrap_or(defaults::STREAM),
            image_token_budget,
            grammar,
        };
        Ok(resolved)
    }
}