use super::message::{AssistantContent, DocumentMediaType};
use crate::client::FinalCompletionResponse;
#[allow(deprecated)]
use crate::client::completion::CompletionModelHandle;
use crate::message::ToolChoice;
use crate::streaming::StreamingCompletionResponse;
use crate::tool::server::ToolServerError;
use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
use crate::{OneOrMany, http_client, streaming};
use crate::{
json_utils,
message::{Message, UserContent},
tool::ToolSetError,
};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::{Add, AddAssign};
use std::sync::Arc;
use thiserror::Error;
/// Errors produced while building, sending, or decoding a completion request.
///
/// Most variants wrap a lower-level error via `#[from]`, so `?` converts them automatically.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Wraps an error from the crate's HTTP client layer (`http_client::Error`).
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),
    /// Wraps a `serde_json::Error` (JSON serialization or deserialization failure).
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),
    /// Wraps a `url::ParseError`.
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),
    /// Arbitrary boxed error; on non-wasm targets it must be `Send + Sync`.
    #[cfg(not(target_family = "wasm"))]
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
    /// Arbitrary boxed error; on wasm targets `Send + Sync` is not required.
    #[cfg(target_family = "wasm")]
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),
    /// Error described only by a message. NOTE(review): presumably used when the provider's
    /// response cannot be interpreted — confirm against the provider implementations.
    #[error("ResponseError: {0}")]
    ResponseError(String),
    /// Error described only by a message. NOTE(review): presumably an error reported by the
    /// model provider itself — confirm against the provider implementations.
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
/// Errors returned by the high-level [`Prompt`] and [`Chat`] interfaces.
#[derive(Debug, Error)]
pub enum PromptError {
    /// The underlying completion request failed.
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),
    /// Wraps a `ToolSetError`. Note that the display prefix is `ToolCallError`, not the
    /// variant name.
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),
    /// Wraps a `ToolServerError`.
    #[error("ToolServerError: {0}")]
    ToolServerError(#[from] ToolServerError),
    /// A depth limit of `max_depth` was reached.
    ///
    /// Carries the chat history and the prompt at that point. The fields are boxed,
    /// presumably to keep the size of the error enum small.
    /// NOTE(review): presumably the limit on multi-turn / tool-call rounds — confirm against
    /// the code that raises it.
    #[error("MaxDepthError: (reached limit: {max_depth})")]
    MaxDepthError {
        max_depth: usize,
        chat_history: Box<Vec<Message>>,
        prompt: Box<Message>,
    },
    /// Signals that the prompt was cancelled; carries the chat history supplied by the
    /// code that constructed it (see `PromptError::prompt_cancelled`).
    #[error("PromptCancelled")]
    PromptCancelled { chat_history: Box<Vec<Message>> },
}
impl PromptError {
    /// Builds a [`PromptError::PromptCancelled`] that owns `chat_history` (moved onto the heap).
    pub(crate) fn prompt_cancelled(chat_history: Vec<Message>) -> Self {
        let boxed_history: Box<Vec<Message>> = chat_history.into();
        Self::PromptCancelled {
            chat_history: boxed_history,
        }
    }
}
/// A text document attached to a completion request.
///
/// `additional_props` is `#[serde(flatten)]`-ed, so its string entries are (de)serialized as
/// top-level keys next to `id` and `text`. When rendered with `Display`, the entries appear in
/// a `<metadata … />` line.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Identifier shown in the rendered `<file id: …>` header.
    pub id: String,
    /// The document body.
    pub text: String,
    /// Extra string metadata (flattened during serde (de)serialization).
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
impl std::fmt::Display for Document {
    /// Renders the document as a `<file id: …>` envelope suitable for inclusion in a prompt:
    ///
    /// ```text
    /// <file id: {id}>
    /// <metadata key: "value" … />      (only when `additional_props` is non-empty)
    /// {text}
    /// </file>
    /// ```
    ///
    /// Metadata entries are sorted by key so the output is deterministic regardless of
    /// `HashMap` iteration order. Values use `Debug` formatting, so they are quoted and escaped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "<file id: {}>", self.id)?;
        if !self.additional_props.is_empty() {
            let mut sorted_props: Vec<(&String, &String)> = self.additional_props.iter().collect();
            // Map keys are unique, so an unstable sort gives the same order as a stable one.
            sorted_props.sort_unstable_by(|a, b| a.0.cmp(b.0));
            f.write_str("<metadata ")?;
            for (i, (key, value)) in sorted_props.into_iter().enumerate() {
                if i > 0 {
                    f.write_str(" ")?;
                }
                write!(f, "{key}: {value:?}")?;
            }
            f.write_str(" />\n")?;
        }
        // Stream the body straight into the formatter rather than cloning `text` (and
        // building intermediate Strings) on every render.
        writeln!(f, "{}", self.text)?;
        f.write_str("</file>\n")
    }
}
/// Describes a tool the model may be offered in a [`CompletionRequest`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Parameter specification as arbitrary JSON. NOTE(review): presumably a JSON Schema
    /// object — confirm with the provider integrations.
    pub parameters: serde_json::Value,
}
/// One-shot prompting: send a single prompt and get the reply back as a `String`.
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Sends `prompt` and resolves to the reply text or a [`PromptError`].
    ///
    /// The return value implements `IntoFuture` (so it can be `.await`-ed directly), and the
    /// resulting future must be `WasmCompatSend`.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
/// Conversational prompting: send a prompt together with prior chat history.
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Sends `prompt` in the context of `chat_history` and resolves to the reply text or a
    /// [`PromptError`]. The history is taken by value.
    fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
/// Low-level access to a completion request before it is sent.
pub trait Completion<M: CompletionModel> {
    /// Produces a [`CompletionRequestBuilder`] for `prompt` and `chat_history`. The caller can
    /// configure it further (tools, temperature, …) and then `send` or `stream` it.
    fn completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend;
}
/// Result of a non-streaming completion call.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The assistant's output: one or more content items.
    pub choice: OneOrMany<AssistantContent>,
    /// Token usage for this call.
    pub usage: Usage,
    /// The provider-specific raw response. `CompletionModelDyn` erases it to `()`.
    pub raw_response: T,
}
/// Types that may be able to report token usage.
pub trait GetTokenUsage {
    /// Returns the usage carried by `self`, or `None` if it has none to report.
    fn token_usage(&self) -> Option<crate::completion::Usage>;
}

/// The unit type never carries usage information.
impl GetTokenUsage for () {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        None
    }
}

/// An absent value reports no usage; a present one defers to the wrapped value.
impl<T> GetTokenUsage for Option<T>
where
    T: GetTokenUsage,
{
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        self.as_ref().and_then(T::token_usage)
    }
}
/// Token counts for one or more completion calls; values can be combined with `+` / `+=`,
/// which add the counters field by field.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct Usage {
    /// Input (prompt-side) token count.
    pub input_tokens: u64,
    /// Output (generated) token count.
    pub output_tokens: u64,
    /// Total token count. This is stored independently, not computed from the two fields
    /// above. NOTE(review): whether it always equals their sum depends on the provider.
    pub total_tokens: u64,
}
impl Usage {
pub fn new() -> Self {
Self {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
}
}
}
impl Default for Usage {
fn default() -> Self {
Self::new()
}
}
impl Add for Usage {
type Output = Self;
fn add(self, other: Self) -> Self::Output {
Self {
input_tokens: self.input_tokens + other.input_tokens,
output_tokens: self.output_tokens + other.output_tokens,
total_tokens: self.total_tokens + other.total_tokens,
}
}
}
impl AddAssign for Usage {
fn add_assign(&mut self, other: Self) {
self.input_tokens += other.input_tokens;
self.output_tokens += other.output_tokens;
self.total_tokens += other.total_tokens;
}
}
/// A provider-specific model that can serve plain and streaming completions.
///
/// `Clone` is required: the default [`CompletionModel::completion_request`] clones `self`
/// into the builder.
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// Provider-specific raw response returned inside [`CompletionResponse`].
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// Provider-specific payload type for streaming responses; must report token usage via
    /// [`GetTokenUsage`].
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;
    /// Client type from which models are constructed (see [`CompletionModel::make`]).
    type Client;

    /// Constructs a model from `client`. NOTE(review): `model` is presumably the provider's
    /// model identifier — confirm in implementations.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Sends `request` and resolves to the full, non-streaming response.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Sends `request` and resolves to a streaming response.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Starts a [`CompletionRequestBuilder`] for `prompt` that owns a clone of this model.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
/// Type-erased counterpart of [`CompletionModel`] that uses boxed futures.
///
/// Provider-specific types are erased: non-streaming responses carry `()` as their raw
/// response, and streaming responses use `FinalCompletionResponse`. A blanket impl provides
/// this trait for every suitable `CompletionModel`.
#[allow(deprecated)]
#[deprecated(
    since = "0.25.0",
    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `CompletionModel` instead."
)]
pub trait CompletionModelDyn: WasmCompatSend + WasmCompatSync {
    /// Boxed, type-erased version of [`CompletionModel::completion`].
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>>;
    /// Boxed, type-erased version of [`CompletionModel::stream`].
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<
        '_,
        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
    >;
    /// Starts a request builder backed by a `CompletionModelHandle` to this model.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}
/// Blanket bridge from any [`CompletionModel`] to the type-erased [`CompletionModelDyn`].
///
/// Both traits declare methods named `completion`, `stream` and `completion_request`, so the
/// typed calls below use fully-qualified syntax to make it explicit which one is meant.
#[allow(deprecated)]
impl<T, R> CompletionModelDyn for T
where
    T: CompletionModel<StreamingResponse = R>,
    R: Clone + Unpin + GetTokenUsage + 'static,
{
    /// Runs the typed completion and drops the provider-specific raw response.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>> {
        Box::pin(async move {
            let typed = <T as CompletionModel>::completion(self, request).await?;
            Ok(CompletionResponse {
                choice: typed.choice,
                usage: typed.usage,
                raw_response: (),
            })
        })
    }

    /// Runs the typed stream and re-wraps its inner stream as a `StreamingResultDyn`.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<
        '_,
        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
    > {
        Box::pin(async move {
            let typed = <T as CompletionModel>::stream(self, request).await?;
            let erased = streaming::StreamingResultDyn {
                inner: Box::pin(typed.inner),
            };
            Ok(StreamingCompletionResponse::stream(Box::pin(erased)))
        })
    }

    /// Clones this model behind an `Arc` and starts a builder over a `CompletionModelHandle`.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
        let handle = CompletionModelHandle::new(Arc::new(T::clone(self)));
        CompletionRequestBuilder::new(handle, prompt)
    }
}
/// A fully assembled completion request, as consumed by [`CompletionModel::completion`] and
/// [`CompletionModel::stream`].
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Optional preamble text.
    pub preamble: Option<String>,
    /// Ordered conversation messages (never empty). When produced by
    /// [`CompletionRequestBuilder::build`], the prompt is the last entry.
    pub chat_history: OneOrMany<Message>,
    /// Attached documents; see [`CompletionRequest::normalized_documents`].
    pub documents: Vec<Document>,
    /// Tool definitions offered with the request.
    pub tools: Vec<ToolDefinition>,
    /// Temperature value forwarded to the provider, if set.
    pub temperature: Option<f64>,
    /// Token limit forwarded to the provider, if set (interpretation is provider-defined).
    pub max_tokens: Option<u64>,
    /// Tool-choice setting, if set.
    pub tool_choice: Option<ToolChoice>,
    /// Extra provider-specific JSON parameters, if any.
    pub additional_params: Option<serde_json::Value>,
}
impl CompletionRequest {
    /// Bundles the request's documents into a single user [`Message`].
    ///
    /// Each [`Document`] is rendered via its `Display` implementation and attached, in order,
    /// as a plain-text (`DocumentMediaType::TXT`) document part. Returns `None` when the
    /// request has no documents.
    pub fn normalized_documents(&self) -> Option<Message> {
        if self.documents.is_empty() {
            None
        } else {
            let mut parts = Vec::with_capacity(self.documents.len());
            for doc in &self.documents {
                let rendered = doc.to_string();
                parts.push(UserContent::document(rendered, Some(DocumentMediaType::TXT)));
            }
            let content = OneOrMany::many(parts).expect("There will be atleast one document");
            Some(Message::User { content })
        }
    }
}
/// Fluent builder for a [`CompletionRequest`] targeting model `M`.
///
/// Created with [`CompletionRequestBuilder::new`] (or [`CompletionModel::completion_request`]).
/// It is finished with `build`, `send`, or `stream`.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    /// Model the request will be sent to.
    model: M,
    /// Final message of the request; appended after `chat_history` on build.
    prompt: Message,
    preamble: Option<String>,
    /// Messages that precede the prompt, in insertion order.
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    tool_choice: Option<ToolChoice>,
    additional_params: Option<serde_json::Value>,
}
impl<M: CompletionModel> CompletionRequestBuilder<M> {
    /// Creates a builder for `model` whose prompt (the final message of the request) is
    /// `prompt`. The chat history starts empty and every optional setting starts unset.
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
        }
    }

    /// Sets the preamble, replacing any previously set one.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Removes any preamble set so far.
    pub fn without_preamble(mut self) -> Self {
        self.preamble = None;
        self
    }

    /// Appends `message` to the chat history that precedes the prompt.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Appends every message in `messages`, in order, to the chat history.
    pub fn messages(mut self, messages: Vec<Message>) -> Self {
        self.chat_history.extend(messages);
        self
    }

    /// Attaches `document` to the request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Attaches every document in `documents`, in order.
    pub fn documents(mut self, documents: Vec<Document>) -> Self {
        self.documents.extend(documents);
        self
    }

    /// Adds `tool` to the tools offered with the request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds every tool in `tools`, in order.
    pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
        self.tools.extend(tools);
        self
    }

    /// Merges `additional_params` into the provider-specific parameters set so far.
    ///
    /// The first call stores the value as-is. Later calls combine it with the existing value
    /// via `json_utils::merge(existing, additional_params)`.
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        self.additional_params = Some(match self.additional_params.take() {
            Some(existing) => json_utils::merge(existing, additional_params),
            None => additional_params,
        });
        self
    }

    /// Replaces the provider-specific parameters with `additional_params`.
    ///
    /// Unlike [`Self::additional_params`], this does **not** merge: previously set
    /// parameters are discarded, and `None` clears them.
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the temperature.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets or clears (`None`) the temperature.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the token limit.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets or clears (`None`) the token limit.
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Sets the tool-choice setting.
    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
        self.tool_choice = Some(tool_choice);
        self
    }

    /// Assembles the [`CompletionRequest`].
    ///
    /// The request's chat history contains the messages added via [`Self::message`] /
    /// [`Self::messages`] in insertion order, followed by the prompt as the final entry.
    pub fn build(self) -> CompletionRequest {
        self.into_parts().1
    }

    /// Builds the request and sends it to the model, awaiting the full response.
    ///
    /// # Errors
    /// Returns whatever [`CompletionError`] the model's `completion` implementation produces.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        let (model, request) = self.into_parts();
        model.completion(request).await
    }

    /// Builds the request and sends it to the model as a streaming request.
    ///
    /// # Errors
    /// Returns whatever [`CompletionError`] the model's `stream` implementation produces.
    pub async fn stream<'a>(
        self,
    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
    where
        <M as CompletionModel>::StreamingResponse: 'a,
        Self: 'a,
    {
        let (model, request) = self.into_parts();
        model.stream(request).await
    }

    /// Splits the builder into its model and the assembled request, moving (never cloning)
    /// the model and every message.
    fn into_parts(self) -> (M, CompletionRequest) {
        let Self {
            model,
            prompt,
            preamble,
            mut chat_history,
            documents,
            tools,
            temperature,
            max_tokens,
            tool_choice,
            additional_params,
        } = self;
        // Push the prompt onto the owned history; `[history, vec![prompt]].concat()` would
        // clone every message already in the history.
        chat_history.push(prompt);
        let chat_history =
            OneOrMany::many(chat_history).expect("There will always be atleast the prompt");
        let request = CompletionRequest {
            preamble,
            chat_history,
            documents,
            tools,
            temperature,
            max_tokens,
            tool_choice,
            additional_params,
        };
        (model, request)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A document with no metadata.
    fn plain_doc(id: &str, text: &str) -> Document {
        Document {
            id: id.to_string(),
            text: text.to_string(),
            additional_props: HashMap::new(),
        }
    }

    /// A minimal request holding one user question plus the given `documents`.
    fn request_with(documents: Vec<Document>) -> CompletionRequest {
        CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents,
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
        }
    }

    #[test]
    fn test_document_display_without_metadata() {
        let doc = plain_doc("123", "This is a test document.");
        assert_eq!(
            doc.to_string(),
            "<file id: 123>\nThis is a test document.\n</file>\n"
        );
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut doc = plain_doc("123", "This is a test document.");
        doc.additional_props
            .insert("author".to_string(), "John Doe".to_string());
        doc.additional_props
            .insert("length".to_string(), "42".to_string());
        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(doc.to_string(), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let request = request_with(vec![
            plain_doc("doc1", "Document 1 text."),
            plain_doc("doc2", "Document 2 text."),
        ]);
        let expected_parts = vec![
            UserContent::document(
                "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                Some(DocumentMediaType::TXT),
            ),
            UserContent::document(
                "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                Some(DocumentMediaType::TXT),
            ),
        ];
        let expected = Message::User {
            content: OneOrMany::many(expected_parts)
                .expect("There will be at least one document"),
        };
        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = request_with(Vec::new());
        assert_eq!(request.normalized_documents(), None);
    }
}