use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::OneOrMany;
use crate::{
json_utils,
message::{Message, UserContent},
tool::ToolSetError,
};
use super::message::{AssistantContent, ContentFormat, DocumentMediaType};
/// Errors that can occur while producing a completion.
///
/// The `#[from]` attributes let `?` convert the wrapped error types into this enum.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Wraps a [`reqwest::Error`] (HTTP transport / status failures).
    #[error("HttpError: {0}")]
    HttpError(#[from] reqwest::Error),
    /// Wraps a [`serde_json::Error`] raised while (de)serializing JSON.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),
    /// Wraps any boxed, thread-safe error.
    ///
    /// The variant name suggests request-side failures (e.g. building the
    /// request) — NOTE(review): confirm intended usage with callers.
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
    /// A free-form description of a problem with a response
    /// (presumably an unexpected or unparseable response body — TODO confirm).
    #[error("ResponseError: {0}")]
    ResponseError(String),
    /// A free-form error message, presumably as reported by the model
    /// provider — TODO confirm against provider implementations.
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
/// Errors returned by the high-level [`Prompt`] and [`Chat`] interfaces.
#[derive(Debug, Error)]
pub enum PromptError {
    /// A lower-level [`CompletionError`] (converted via `#[from]`).
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),
    /// A [`ToolSetError`] raised while handling a tool call (converted via `#[from]`).
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),
    /// A depth limit of `max_depth` was reached.
    ///
    /// NOTE(review): presumably the limit on multi-turn / tool-call rounds;
    /// confirm with the code that constructs this variant.
    #[error("MaxDepthError: (reached limit: {max_depth})")]
    MaxDepthError {
        /// The limit that was reached.
        max_depth: usize,
        /// The chat history at the moment the limit was hit.
        chat_history: Vec<Message>,
        /// The prompt that was being processed when the limit was hit.
        prompt: Message,
    },
}
/// A text document that can be attached to a [`CompletionRequest`] as context.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Identifier rendered in the `<file id: …>` header by `Display`.
    pub id: String,
    /// The document body.
    pub text: String,
    /// Extra string-valued properties.
    ///
    /// Because of `#[serde(flatten)]`, these are serialized as sibling keys of
    /// `id`/`text` rather than under a nested object. `Display` renders them,
    /// sorted by key, in a `<metadata … />` line.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
impl std::fmt::Display for Document {
    /// Renders the document as a pseudo-XML `<file>` block.
    ///
    /// Layout:
    /// ```text
    /// <file id: {id}>
    /// <metadata key: "value" key2: "value2" />   (only if additional_props is non-empty)
    /// {text}
    /// </file>
    /// ```
    /// Metadata entries are sorted by key so the output is deterministic
    /// despite `HashMap` iteration order. Values use `Debug` formatting, so
    /// they are quoted and escaped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "<file id: {}>", self.id)?;

        if !self.additional_props.is_empty() {
            let mut sorted_props: Vec<_> = self.additional_props.iter().collect();
            // Map keys are unique, so an unstable sort yields the same order as a stable one.
            sorted_props.sort_unstable_by(|a, b| a.0.cmp(b.0));

            // Stream entries straight into the formatter instead of building
            // intermediate `String`s and joining them.
            f.write_str("<metadata ")?;
            for (index, (key, value)) in sorted_props.into_iter().enumerate() {
                if index > 0 {
                    f.write_str(" ")?;
                }
                write!(f, "{}: {:?}", key, value)?;
            }
            f.write_str(" />\n")?;
        }

        writeln!(f, "{}", self.text)?;
        f.write_str("</file>\n")
    }
}
/// Describes a tool that can be offered to a model in a [`CompletionRequest`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// The tool's name.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Arbitrary JSON describing the tool's parameters
    /// (presumably a JSON Schema object — TODO confirm with providers).
    pub parameters: serde_json::Value,
}
/// High-level, single-shot prompting interface.
pub trait Prompt: Send + Sync {
    /// Sends `prompt` (anything convertible into a [`Message`]) and resolves to
    /// the response text, or a [`PromptError`].
    ///
    /// The returned value is `IntoFuture` (so it can be `.await`ed directly),
    /// and its future is required to be `Send`.
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}
/// High-level conversational interface that takes prior messages into account.
pub trait Chat: Send + Sync {
    /// Sends `prompt` together with the given `chat_history` and resolves to
    /// the response text, or a [`PromptError`].
    ///
    /// The returned value is `IntoFuture`, and its future is required to be `Send`.
    fn chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: Send>;
}
/// Low-level interface that hands back a [`CompletionRequestBuilder`] instead
/// of sending the request, so callers can customize it before calling `send`.
pub trait Completion<M: CompletionModel> {
    /// Produces a builder for `prompt` and `chat_history` targeting model `M`.
    ///
    /// NOTE(review): presumably implementors pre-populate the builder with their
    /// own settings (preamble, tools, …) — confirm per implementation.
    fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}
/// The result of a [`CompletionModel::completion`] call.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The assistant content returned by the model (one or more items).
    pub choice: OneOrMany<AssistantContent>,
    /// The provider-specific raw response (the model's `Response` type).
    pub raw_response: T,
}
/// A model that can turn a [`CompletionRequest`] into a [`CompletionResponse`].
pub trait CompletionModel: Clone + Send + Sync {
    /// Raw response type, surfaced in [`CompletionResponse::raw_response`].
    type Response: Send + Sync;

    /// Sends `request` to the model and resolves to its response.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
           + Send;

    /// Starts a [`CompletionRequestBuilder`] for `prompt` bound to a clone of this model.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
/// A fully assembled completion request, as produced by
/// [`CompletionRequestBuilder::build`].
pub struct CompletionRequest {
    /// Optional preamble text (presumably a system prompt — TODO confirm per provider).
    pub preamble: Option<String>,
    /// Conversation messages; when built by the builder, the prompt is the last entry.
    pub chat_history: OneOrMany<Message>,
    /// Context documents; see [`CompletionRequest::normalized_documents`].
    pub documents: Vec<Document>,
    /// Tools offered to the model.
    pub tools: Vec<ToolDefinition>,
    /// Optional sampling temperature.
    pub temperature: Option<f64>,
    /// Optional cap on generated tokens.
    pub max_tokens: Option<u64>,
    /// Optional extra JSON parameters (presumably provider-specific — TODO confirm).
    pub additional_params: Option<serde_json::Value>,
}
impl CompletionRequest {
    /// Packs every attached document into a single user [`Message`].
    ///
    /// Each document is rendered with its `Display` form and wrapped as a
    /// plain-text (`ContentFormat::String`, `DocumentMediaType::TXT`) document
    /// part, in the original order. Returns `None` when there are no documents.
    pub fn normalized_documents(&self) -> Option<Message> {
        let has_documents = !self.documents.is_empty();
        has_documents.then(|| {
            let parts: Vec<_> = self
                .documents
                .iter()
                .map(|document| {
                    let rendered = document.to_string();
                    UserContent::document(
                        rendered,
                        Some(ContentFormat::String),
                        Some(DocumentMediaType::TXT),
                    )
                })
                .collect();
            // Non-empty is guaranteed by the `then` guard above.
            let content = OneOrMany::many(parts).expect("There will be atleast one document");
            Message::User { content }
        })
    }
}
/// Fluent builder for a [`CompletionRequest`] bound to a specific model.
///
/// Created via [`CompletionRequestBuilder::new`] or
/// [`CompletionModel::completion_request`]; finish with `build`, `send`, or
/// (for streaming models) `stream`.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    // Model the request will be sent to.
    model: M,
    // The prompt; appended after `chat_history` when building.
    prompt: Message,
    preamble: Option<String>,
    // Prior messages, in order, placed before the prompt.
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
}
impl<M: CompletionModel> CompletionRequestBuilder<M> {
    /// Creates a builder that will send `prompt` to `model`.
    ///
    /// Every optional setting starts as `None` and every collection starts empty.
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    /// Sets the preamble, replacing any previously set value.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Appends a single message to the chat history.
    ///
    /// History messages are placed before the prompt when the request is built.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Appends `messages` to the chat history, preserving their order.
    pub fn messages(mut self, messages: Vec<Message>) -> Self {
        self.chat_history.extend(messages);
        self
    }

    /// Appends a single context document.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Appends `documents`, preserving their order.
    pub fn documents(mut self, documents: Vec<Document>) -> Self {
        self.documents.extend(documents);
        self
    }

    /// Appends a single tool definition.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Appends `tools`, preserving their order.
    pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
        self.tools.extend(tools);
        self
    }

    /// Adds `additional_params`, merging with any params set earlier.
    ///
    /// If params were already present they are combined via
    /// `json_utils::merge(existing, additional_params)`; otherwise
    /// `additional_params` is stored as-is.
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        self.additional_params = Some(match self.additional_params.take() {
            Some(existing) => json_utils::merge(existing, additional_params),
            None => additional_params,
        });
        self
    }

    /// Replaces the additional params with `additional_params`.
    ///
    /// Unlike [`Self::additional_params`], this overwrites instead of merging;
    /// passing `None` clears any previously set params.
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the sampling temperature.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets or clears (`None`) the sampling temperature.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the maximum number of tokens to generate.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets or clears (`None`) the maximum number of tokens to generate.
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Consumes the builder and assembles the final [`CompletionRequest`].
    ///
    /// The prompt is appended after the chat history, so it is always the last
    /// entry of `chat_history`.
    pub fn build(self) -> CompletionRequest {
        // Move the prompt onto the owned history. The previous
        // `[history, vec![prompt]].concat()` cloned every message.
        let mut chat_history = self.chat_history;
        chat_history.push(self.prompt);
        // The vector always contains at least the prompt, so `many` cannot see an empty input.
        let chat_history =
            OneOrMany::many(chat_history).expect("There will always be atleast the prompt");
        CompletionRequest {
            preamble: self.preamble,
            chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
        }
    }

    /// Builds the request and sends it through [`CompletionModel::completion`].
    ///
    /// # Errors
    /// Propagates any [`CompletionError`] returned by the model.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        // `build` consumes `self`, so keep a handle to the model first.
        let model = self.model.clone();
        model.completion(self.build()).await
    }
}
impl<M: StreamingCompletionModel> CompletionRequestBuilder<M> {
    /// Builds the request and sends it as a streaming completion.
    ///
    /// # Errors
    /// Propagates any [`CompletionError`] returned by the model's `stream`.
    pub async fn stream(self) -> Result<StreamingResult, CompletionError> {
        // Grab the model before `build` consumes the builder.
        let streaming_model = self.model.clone();
        let request = self.build();
        streaming_model.stream(request).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a document with no additional properties.
    fn plain_document(id: &str, text: &str) -> Document {
        Document {
            id: id.to_string(),
            text: text.to_string(),
            additional_props: HashMap::new(),
        }
    }

    /// Builds a request with a single user prompt and the given documents.
    fn request_with_documents(documents: Vec<Document>) -> CompletionRequest {
        CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents,
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            additional_params: None,
        }
    }

    /// Wraps already-rendered document text as a plain-text document part.
    fn txt_part(rendered: &str) -> UserContent {
        UserContent::document(
            rendered.to_string(),
            Some(ContentFormat::String),
            Some(DocumentMediaType::TXT),
        )
    }

    #[test]
    fn test_document_display_without_metadata() {
        let doc = plain_document("123", "This is a test document.");
        assert_eq!(
            doc.to_string(),
            "<file id: 123>\nThis is a test document.\n</file>\n"
        );
    }

    #[test]
    fn test_document_display_with_metadata() {
        let additional_props: HashMap<String, String> =
            [("author", "John Doe"), ("length", "42")]
                .into_iter()
                .map(|(key, value)| (key.to_string(), value.to_string()))
                .collect();
        let doc = Document {
            additional_props,
            ..plain_document("123", "This is a test document.")
        };
        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(doc.to_string(), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let request = request_with_documents(vec![
            plain_document("doc1", "Document 1 text."),
            plain_document("doc2", "Document 2 text."),
        ]);
        let expected = Message::User {
            content: OneOrMany::many(vec![
                txt_part("<file id: doc1>\nDocument 1 text.\n</file>\n"),
                txt_part("<file id: doc2>\nDocument 2 text.\n</file>\n"),
            ])
            .expect("There will be at least one document"),
        };
        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = request_with_documents(Vec::new());
        assert_eq!(request.normalized_documents(), None);
    }
}