use crate::message::AssistantContent;
use crate::message::Message;
use crate::one_or_many::OneOrMany;
use crate::types::ToolDefinition;
/// Error type returned by [`CompletionModel::completion`].
///
/// Every variant renders through `Display` as its own name, a colon, and its
/// payload (for example `HttpError: <message>`).
#[derive(Debug, thiserror::Error)]
pub enum CompletionError {
    /// HTTP-level failure, described by a message string.
    #[error("HttpError: {0}")]
    HttpError(String),

    /// JSON (de)serialization failure.
    ///
    /// Thanks to `#[from]`, applying `?` to a `serde_json::Error` converts it
    /// into this variant automatically.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Problem with the outgoing request, described by a message string.
    #[error("RequestError: {0}")]
    RequestError(String),

    /// Problem with the incoming response, described by a message string.
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Failure reported by the provider, described by a message string.
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
/// A model capable of answering [`CompletionRequest`]s.
///
/// The interface is synchronous: [`CompletionModel::completion`] returns its
/// result directly rather than a future.
pub trait CompletionModel {
    /// Returns a [`CompletionRequestBuilder`] for `prompt`.
    ///
    /// Whether any builder fields are pre-populated is left to the
    /// implementation.
    fn completion_request(&self, prompt: Message) -> CompletionRequestBuilder;

    /// Runs `request` and returns the model's [`CompletionResponse`].
    ///
    /// # Errors
    ///
    /// Returns a [`CompletionError`] when the completion cannot be produced;
    /// which variant is used is implementation-defined.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> Result<CompletionResponse, CompletionError>;
}
/// A fully specified completion request.
///
/// All fields are public, so a request can be written out literally or
/// assembled with [`CompletionRequestBuilder`] and its `build` method.
pub struct CompletionRequest {
    /// The prompt message for this request.
    pub prompt: Message,
    /// Optional preamble text; `None` when not set.
    pub preamble: Option<String>,
    /// Earlier messages, kept in the order they were supplied.
    pub chat_history: Vec<Message>,
    /// Tool definitions attached to the request.
    pub tools: Vec<ToolDefinition>,
    /// Sampling temperature; `None` when not set.
    pub temperature: Option<f64>,
    /// Token limit; `None` when not set.
    pub max_tokens: Option<u64>,
    /// Extra JSON parameters; `None` when not set.
    pub additional_params: Option<serde_json::Value>,
}
/// The result of a successful [`CompletionModel::completion`] call.
pub struct CompletionResponse {
    /// Assistant content returned for the request, held as a `OneOrMany`.
    pub choice: OneOrMany<AssistantContent>,
    /// Token accounting, when available; `None` otherwise.
    pub usage: Option<Usage>,
}
/// Token counts attached to a completion response.
///
/// The type is plain data made of three `u64` counters, so it is `Copy` and
/// comparable with `==`. `total_tokens` is stored as given; it is not computed
/// from the other two fields here.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Eq,
    serde::Serialize,
    serde::Deserialize,
)]
pub struct Usage {
    /// Tokens counted for the prompt side.
    pub prompt_tokens: u64,
    /// Tokens counted for the completion side.
    pub completion_tokens: u64,
    /// Total token count as reported.
    pub total_tokens: u64,
}
/// Incrementally assembles a [`CompletionRequest`].
///
/// Start with [`CompletionRequestBuilder::new`], chain setters, and finish
/// with [`CompletionRequestBuilder::build`]. Every setter consumes the builder
/// and returns it by value, so an unused builder silently does nothing. The
/// `#[must_use]` attribute makes the compiler warn about that mistake.
#[must_use = "a CompletionRequestBuilder does nothing until `build` is called"]
pub struct CompletionRequestBuilder {
    /// Prompt for the request; required, so it is passed to `new`.
    prompt: Message,
    /// Optional preamble; `None` until set.
    preamble: Option<String>,
    /// Messages appended so far, in call order.
    chat_history: Vec<Message>,
    /// Tool definitions appended so far, in call order.
    tools: Vec<ToolDefinition>,
    /// Sampling temperature; `None` until set.
    temperature: Option<f64>,
    /// Token limit; `None` until set.
    max_tokens: Option<u64>,
    /// Accumulated extra JSON parameters; `None` until first set.
    additional_params: Option<serde_json::Value>,
}
impl CompletionRequestBuilder {
    /// Starts a builder for `prompt`.
    ///
    /// All optional fields begin unset, and the chat history and tool list
    /// begin empty.
    pub fn new(prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            preamble: None,
            chat_history: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    /// Sets the preamble, replacing any previously set value.
    pub fn preamble(mut self, preamble: impl Into<String>) -> Self {
        self.preamble = Some(preamble.into());
        self
    }

    /// Sets the sampling temperature, replacing any previous value.
    ///
    /// The value is stored as-is; no range check is performed here.
    pub fn temperature(mut self, temp: f64) -> Self {
        self.temperature = Some(temp);
        self
    }

    /// Sets the token limit, replacing any previous value.
    pub fn max_tokens(mut self, tokens: u64) -> Self {
        self.max_tokens = Some(tokens);
        self
    }

    /// Appends one tool definition.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Appends every tool definition yielded by `tools`, preserving order.
    ///
    /// Equivalent to calling [`Self::tool`] once per item.
    pub fn tools(mut self, tools: impl IntoIterator<Item = ToolDefinition>) -> Self {
        self.tools.extend(tools);
        self
    }

    /// Appends one message to the chat history.
    pub fn message(mut self, msg: Message) -> Self {
        self.chat_history.push(msg);
        self
    }

    /// Appends every message yielded by `msgs` to the chat history,
    /// preserving order.
    ///
    /// Equivalent to calling [`Self::message`] once per item.
    pub fn messages(mut self, msgs: impl IntoIterator<Item = Message>) -> Self {
        self.chat_history.extend(msgs);
        self
    }

    /// Merges `params` into the additional parameters.
    ///
    /// The first call stores `params` verbatim. Each later call combines
    /// `params` with the stored value:
    /// - If both are JSON objects, every key of `params` is inserted into the
    ///   stored object. This is a shallow merge: a key present in both is
    ///   overwritten wholesale, and nested objects are not merged.
    /// - Otherwise, `params` replaces the stored value entirely. This includes
    ///   passing `Value::Null`, which discards everything accumulated so far.
    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
        self.additional_params = Some(match self.additional_params {
            Some(existing) => merge_json(existing, params),
            None => params,
        });
        self
    }

    /// Consumes the builder and returns the assembled [`CompletionRequest`].
    pub fn build(self) -> CompletionRequest {
        CompletionRequest {
            prompt: self.prompt,
            preamble: self.preamble,
            chat_history: self.chat_history,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
        }
    }
}
/// Shallow-merges two JSON values, with `b` taking precedence.
///
/// - If both `a` and `b` are objects, every key of `b` is inserted into `a`,
///   overwriting any entry with the same key. Nested objects are replaced,
///   not merged recursively.
/// - In every other case, `b` is returned unchanged and `a` is discarded.
fn merge_json(a: serde_json::Value, b: serde_json::Value) -> serde_json::Value {
    match (a, b) {
        (serde_json::Value::Object(mut merged), serde_json::Value::Object(overrides)) => {
            // `Map::extend` inserts each pair in turn, so keys from `overrides`
            // win exactly as they did with a manual `insert` loop.
            merged.extend(overrides);
            serde_json::Value::Object(merged)
        }
        (_, b) => b,
    }
}