use super::message::{AssistantContent, DocumentMediaType};
use crate::message::ToolChoice;
use crate::streaming::StreamingCompletionResponse;
use crate::tool::server::ToolServerError;
use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
use crate::{OneOrMany, http_client};
use crate::{
json_utils,
message::{Message, UserContent},
tool::ToolSetError,
};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::{Add, AddAssign};
use thiserror::Error;
/// Errors that can be returned while performing a completion request.
///
/// Most variants wrap an underlying error and can be produced with `?` thanks to
/// the `#[from]` conversions.
#[derive(Debug, Error)]
pub enum CompletionError {
/// Wraps an error coming from the crate's `http_client` module.
#[error("HttpError: {0}")]
HttpError(#[from] http_client::Error),
/// Wraps a `serde_json` serialization or deserialization error.
#[error("JsonError: {0}")]
JsonError(#[from] serde_json::Error),
/// Wraps a URL parsing error.
#[error("UrlError: {0}")]
UrlError(#[from] url::ParseError),
/// Wraps an arbitrary boxed error. On non-wasm targets the boxed error must be
/// `Send + Sync`.
#[cfg(not(target_family = "wasm"))]
#[error("RequestError: {0}")]
RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
/// Wraps an arbitrary boxed error. On wasm targets the `Send + Sync` requirement
/// is dropped.
#[cfg(target_family = "wasm")]
#[error("RequestError: {0}")]
RequestError(#[from] Box<dyn std::error::Error + 'static>),
/// A failure described by a free-form message.
// NOTE(review): presumably used for malformed/unexpected responses — confirm against callers.
#[error("ResponseError: {0}")]
ResponseError(String),
/// A failure described by a free-form message.
// NOTE(review): presumably an error reported by the model provider — confirm against callers.
#[error("ProviderError: {0}")]
ProviderError(String),
}
/// Errors returned by the prompting traits in this module ([`Prompt`], [`Chat`]).
#[derive(Debug, Error)]
pub enum PromptError {
/// The underlying completion request failed.
#[error("CompletionError: {0}")]
CompletionError(#[from] CompletionError),
/// Wraps a [`ToolSetError`]. Note that the rendered message is prefixed with
/// `ToolCallError`, not the variant name.
#[error("ToolCallError: {0}")]
ToolError(#[from] ToolSetError),
/// Wraps a boxed [`ToolServerError`].
#[error("ToolServerError: {0}")]
ToolServerError(#[from] Box<ToolServerError>),
/// The maximum number of turns was reached. Only `max_turns` appears in the
/// rendered message; the other fields are carried for the caller.
#[error("MaxTurnError: (reached max turn limit: {max_turns})")]
MaxTurnsError {
/// The turn limit that was hit.
max_turns: usize,
/// Messages carried with the error.
// NOTE(review): presumably the conversation accumulated so far — confirm against the producer.
chat_history: Box<Vec<Message>>,
/// A message carried with the error.
// NOTE(review): presumably the prompt that was being processed — confirm against the producer.
prompt: Box<Message>,
},
/// The prompt was cancelled. Only `reason` appears in the rendered message.
#[error("PromptCancelled: {reason}")]
PromptCancelled {
/// Messages carried with the error.
chat_history: Vec<Message>,
/// Human-readable reason for the cancellation.
reason: String,
},
}
impl PromptError {
    /// Convenience constructor for [`PromptError::PromptCancelled`].
    ///
    /// Accepts any iterable of messages (collected into a `Vec`) and any value
    /// convertible into a `String` as the cancellation reason.
    pub(crate) fn prompt_cancelled(
        chat_history: impl IntoIterator<Item = Message>,
        reason: impl Into<String>,
    ) -> Self {
        let chat_history: Vec<Message> = chat_history.into_iter().collect();
        let reason: String = reason.into();
        Self::PromptCancelled {
            chat_history,
            reason,
        }
    }
}
/// Errors returned by typed (structured-output) prompting, see [`TypedPrompt`].
#[derive(Debug, Error)]
pub enum StructuredOutputError {
/// Wraps a boxed [`PromptError`].
#[error("PromptError: {0}")]
PromptError(#[from] Box<PromptError>),
/// Wraps a `serde_json` error.
// NOTE(review): presumably raised when the model output cannot be deserialized into the target type.
#[error("DeserializationError: {0}")]
DeserializationError(#[from] serde_json::Error),
/// The model returned no content.
#[error("EmptyResponse: model returned no content")]
EmptyResponse,
}
/// A text document that can be attached to a completion request.
///
/// Its `Display` implementation renders it as a `<file id: …>` block.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
/// Identifier shown in the rendered `<file id: …>` header.
pub id: String,
/// The document body.
pub text: String,
/// Extra string key/value pairs. Because of `#[serde(flatten)]` they are
/// (de)serialized as siblings of `id` and `text` rather than as a nested object.
#[serde(flatten)]
pub additional_props: HashMap<String, String>,
}
impl std::fmt::Display for Document {
    /// Renders the document as a `<file>` block:
    ///
    /// ```text
    /// <file id: ID>
    /// <metadata key: "value" other: "value" />   (only when additional_props is non-empty)
    /// TEXT
    /// </file>
    /// ```
    ///
    /// Metadata entries are emitted in ascending key order and values are written
    /// with their `Debug` representation (quoted and escaped).
    ///
    /// The output is streamed straight into the formatter, so no copy of the
    /// document text or intermediate metadata string is allocated.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "<file id: {}>", self.id)?;

        if !self.additional_props.is_empty() {
            // `HashMap` iteration order is unspecified; sort by key so the rendered
            // output is deterministic.
            let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
            sorted_props.sort_unstable_by(|a, b| a.0.cmp(b.0));

            f.write_str("<metadata ")?;
            for (index, (key, value)) in sorted_props.into_iter().enumerate() {
                // Entries are separated by a single space, with none before the first.
                if index > 0 {
                    f.write_str(" ")?;
                }
                write!(f, "{key}: {value:?}")?;
            }
            f.write_str(" />\n")?;
        }

        writeln!(f, "{}", self.text)?;
        f.write_str("</file>\n")
    }
}
/// Describes a tool that can be offered to a model in a completion request.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
/// The tool's name.
pub name: String,
/// Free-text description of the tool.
pub description: String,
/// Arbitrary JSON describing the tool's parameters.
// NOTE(review): presumably a JSON Schema object — confirm against providers.
pub parameters: serde_json::Value,
}
/// A provider-specific tool entry, made of a `type` tag plus free-form configuration.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ProviderToolDefinition {
/// The tool kind; (de)serialized under the JSON key `type`.
#[serde(rename = "type")]
pub kind: String,
/// Additional configuration keys. They are flattened next to `type` in the
/// serialized object, default to an empty map when absent, and are omitted
/// from the output when empty.
#[serde(flatten, default, skip_serializing_if = "serde_json::Map::is_empty")]
pub config: serde_json::Map<String, serde_json::Value>,
}
impl ProviderToolDefinition {
    /// Creates a definition of the given kind with an empty configuration.
    pub fn new(kind: impl Into<String>) -> Self {
        let kind = kind.into();
        let config = serde_json::Map::new();
        Self { kind, config }
    }

    /// Returns the definition with `key` set to `value` in its configuration.
    ///
    /// An existing entry under the same key is replaced.
    pub fn with_config(self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let Self { kind, mut config } = self;
        config.insert(key.into(), value);
        Self { kind, config }
    }
}
/// Trait for types that can answer a single prompt with a text response.
pub trait Prompt: WasmCompatSend + WasmCompatSync {
/// Sends `prompt` (anything convertible into a [`Message`]).
///
/// The returned value converts into a future that resolves to the response
/// as a `String`, or to a [`PromptError`] on failure.
fn prompt(
&self,
prompt: impl Into<Message> + WasmCompatSend,
) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
/// Trait for types that can answer a prompt given prior chat history.
pub trait Chat: WasmCompatSend + WasmCompatSync {
/// Sends `prompt` together with `chat_history`.
///
/// `chat_history` may be any iterable whose items convert into [`Message`].
/// The returned future resolves to the response as a `String`, or to a
/// [`PromptError`] on failure.
fn chat<I, T>(
&self,
prompt: impl Into<Message> + WasmCompatSend,
chat_history: I,
) -> impl std::future::Future<Output = Result<String, PromptError>> + WasmCompatSend
where
I: IntoIterator<Item = T> + WasmCompatSend,
T: Into<Message>;
}
/// Trait for types that can answer a prompt with a response deserialized into a
/// caller-chosen type `T`.
pub trait TypedPrompt: WasmCompatSend + WasmCompatSync {
/// The request type returned by [`TypedPrompt::prompt_typed`]; it converts into a
/// future resolving to `T` or a [`StructuredOutputError`].
// NOTE(review): `T: 'static` is required here but is not repeated on `prompt_typed`'s
// where clause below — confirm this is intentional.
type TypedRequest<T>: std::future::IntoFuture<Output = Result<T, StructuredOutputError>>
where
T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend + 'static;
/// Starts a typed prompt for `prompt` (anything convertible into a [`Message`]).
///
/// `T` must provide a JSON schema and be deserializable.
fn prompt_typed<T>(&self, prompt: impl Into<Message> + WasmCompatSend) -> Self::TypedRequest<T>
where
T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend;
}
/// Trait for types that can prepare a [`CompletionRequestBuilder`] for a prompt
/// and chat history, leaving the caller free to customise the request before
/// sending it.
pub trait Completion<M: CompletionModel> {
/// Produces a request builder for `prompt` and `chat_history`.
///
/// `chat_history` may be any iterable whose items convert into [`Message`].
/// The returned future resolves to the builder, or to a [`CompletionError`].
fn completion<I, T>(
&self,
prompt: impl Into<Message> + WasmCompatSend,
chat_history: I,
) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
+ WasmCompatSend
where
I: IntoIterator<Item = T> + WasmCompatSend,
T: Into<Message>;
}
/// The result of a completion call, generic over the provider's raw response type `T`.
#[derive(Debug)]
pub struct CompletionResponse<T> {
/// One or more pieces of assistant content returned by the model.
pub choice: OneOrMany<AssistantContent>,
/// Token usage associated with this response.
pub usage: Usage,
/// The provider-specific response value.
pub raw_response: T,
/// Optional message identifier.
// NOTE(review): presumably assigned by the provider — confirm.
pub message_id: Option<String>,
}
/// Trait for values that may be able to report token usage.
pub trait GetTokenUsage {
/// Returns the token usage, or `None` when it is not available.
fn token_usage(&self) -> Option<crate::completion::Usage>;
}
/// The unit type carries no usage information.
impl GetTokenUsage for () {
/// Always returns `None`.
fn token_usage(&self) -> Option<crate::completion::Usage> {
None
}
}
/// An optional value reports whatever its inner value reports, if present.
impl<T> GetTokenUsage for Option<T>
where
    T: GetTokenUsage,
{
    /// Delegates to the inner value; returns `None` when `self` is `None`.
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        self.as_ref().and_then(|inner| inner.token_usage())
    }
}
/// Token counters for a completion. Values can be combined with `+` and `+=`,
/// which add each counter independently.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct Usage {
/// Number of input tokens.
pub input_tokens: u64,
/// Number of output tokens.
pub output_tokens: u64,
/// Total number of tokens. Stored as its own counter; nothing in this type
/// derives it from the other fields.
pub total_tokens: u64,
/// Number of cached input tokens.
// NOTE(review): presumably input tokens served from a provider-side cache — confirm.
pub cached_input_tokens: u64,
/// Number of cache-creation input tokens.
// NOTE(review): presumably input tokens written to a provider-side cache — confirm.
pub cache_creation_input_tokens: u64,
}
impl Usage {
    /// Creates a `Usage` with every counter set to zero.
    pub fn new() -> Self {
        const ZERO: u64 = 0;
        Self {
            input_tokens: ZERO,
            output_tokens: ZERO,
            total_tokens: ZERO,
            cached_input_tokens: ZERO,
            cache_creation_input_tokens: ZERO,
        }
    }
}
/// The default `Usage` is the all-zero value produced by [`Usage::new`].
impl Default for Usage {
/// Delegates to [`Usage::new`].
fn default() -> Self {
Self::new()
}
}
/// Field-wise addition of two usage records.
impl Add for Usage {
    type Output = Self;

    /// Returns a new `Usage` whose counters are the sums of the corresponding
    /// counters of `self` and `other`.
    fn add(self, other: Self) -> Self::Output {
        let Self {
            input_tokens,
            output_tokens,
            total_tokens,
            cached_input_tokens,
            cache_creation_input_tokens,
        } = other;

        Self {
            input_tokens: self.input_tokens + input_tokens,
            output_tokens: self.output_tokens + output_tokens,
            total_tokens: self.total_tokens + total_tokens,
            cached_input_tokens: self.cached_input_tokens + cached_input_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + cache_creation_input_tokens,
        }
    }
}
/// In-place field-wise accumulation of usage records.
impl AddAssign for Usage {
    /// Adds every counter of `other` onto the matching counter of `self`.
    fn add_assign(&mut self, other: Self) {
        // Pair each counter of `self` with the amount to add, in field order.
        let updates = [
            (&mut self.input_tokens, other.input_tokens),
            (&mut self.output_tokens, other.output_tokens),
            (&mut self.total_tokens, other.total_tokens),
            (&mut self.cached_input_tokens, other.cached_input_tokens),
            (
                &mut self.cache_creation_input_tokens,
                other.cache_creation_input_tokens,
            ),
        ];
        for (counter, delta) in updates {
            *counter += delta;
        }
    }
}
/// Trait implemented by provider models that can execute [`CompletionRequest`]s,
/// either as a single response or as a stream.
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
/// The provider's raw (non-streaming) response type, carried in
/// [`CompletionResponse::raw_response`].
type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
/// The provider's streaming response type; it must be able to report token
/// usage via [`GetTokenUsage`].
type StreamingResponse: Clone
+ Unpin
+ WasmCompatSend
+ WasmCompatSync
+ Serialize
+ DeserializeOwned
+ GetTokenUsage;
/// The client type from which a model instance is constructed.
type Client;
/// Constructs a model from a client reference and a model name.
fn make(client: &Self::Client, model: impl Into<String>) -> Self;
/// Executes `request` and resolves to the complete response, or a
/// [`CompletionError`].
fn completion(
&self,
request: CompletionRequest,
) -> impl std::future::Future<
Output = Result<CompletionResponse<Self::Response>, CompletionError>,
> + WasmCompatSend;
/// Executes `request` and resolves to a streaming response, or a
/// [`CompletionError`].
fn stream(
&self,
request: CompletionRequest,
) -> impl std::future::Future<
Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
> + WasmCompatSend;
/// Returns a [`CompletionRequestBuilder`] for `prompt`, bound to a clone of
/// this model.
fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
CompletionRequestBuilder::new(self.clone(), prompt)
}
}
/// A fully assembled completion request, as passed to
/// [`CompletionModel::completion`] and [`CompletionModel::stream`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
/// Optional model name carried with this request.
// NOTE(review): presumably overrides the model's default name — confirm against providers.
pub model: Option<String>,
/// Optional preamble. [`CompletionRequestBuilder::build`] always sets this to
/// `None` and instead places the preamble at the front of `chat_history` as a
/// system message.
pub preamble: Option<String>,
/// The conversation messages; never empty by construction of `OneOrMany`.
/// When produced by the builder, the prompt is the last element.
pub chat_history: OneOrMany<Message>,
/// Documents attached to the request; see [`CompletionRequest::normalized_documents`].
pub documents: Vec<Document>,
/// Tool definitions offered to the model.
pub tools: Vec<ToolDefinition>,
/// Optional sampling temperature.
pub temperature: Option<f64>,
/// Optional maximum number of tokens.
pub max_tokens: Option<u64>,
/// Optional tool-choice setting.
pub tool_choice: Option<ToolChoice>,
/// Optional free-form JSON parameters. Provider tools are merged into this
/// value under the `tools` key.
pub additional_params: Option<serde_json::Value>,
/// Optional JSON schema for the expected output.
pub output_schema: Option<schemars::Schema>,
}
impl CompletionRequest {
    /// Returns a name for the output schema, if one is set.
    ///
    /// The name is the schema's top-level string `title`; when the schema is not
    /// an object, or has no string `title`, `"response_schema"` is used instead.
    /// Returns `None` when no output schema is set.
    pub fn output_schema_name(&self) -> Option<String> {
        let schema = self.output_schema.as_ref()?;
        let title = schema
            .as_object()
            .and_then(|fields| fields.get("title"))
            .and_then(|title| title.as_str())
            .unwrap_or("response_schema");
        Some(title.to_string())
    }

    /// Converts the attached documents into a single user message.
    ///
    /// Each document is rendered through its `Display` implementation and wrapped
    /// as plain-text document content. Returns `None` when there are no documents.
    pub fn normalized_documents(&self) -> Option<Message> {
        if self.documents.is_empty() {
            return None;
        }

        let mut rendered = Vec::with_capacity(self.documents.len());
        for document in &self.documents {
            rendered.push(UserContent::document(
                document.to_string(),
                Some(DocumentMediaType::TXT),
            ));
        }

        let content = OneOrMany::from_iter_optional(rendered)?;
        Some(Message::User { content })
    }

    /// Adds a single provider tool; equivalent to calling
    /// [`CompletionRequest::with_provider_tools`] with a one-element vector.
    pub fn with_provider_tool(self, tool: ProviderToolDefinition) -> Self {
        self.with_provider_tools(vec![tool])
    }

    /// Merges `tools` into `additional_params` under its `tools` key and returns
    /// the updated request. An empty `tools` vector leaves the request unchanged.
    pub fn with_provider_tools(mut self, tools: Vec<ProviderToolDefinition>) -> Self {
        let current_params = self.additional_params.take();
        self.additional_params = merge_provider_tools_into_additional_params(current_params, tools);
        self
    }
}
fn merge_provider_tools_into_additional_params(
additional_params: Option<serde_json::Value>,
provider_tools: Vec<ProviderToolDefinition>,
) -> Option<serde_json::Value> {
if provider_tools.is_empty() {
return additional_params;
}
let mut provider_tools_json = provider_tools
.into_iter()
.map(|ProviderToolDefinition { kind, mut config }| {
config.insert("type".to_string(), serde_json::Value::String(kind));
serde_json::Value::Object(config)
})
.collect::<Vec<_>>();
let mut params_map = match additional_params {
Some(serde_json::Value::Object(map)) => map,
Some(serde_json::Value::Bool(stream)) => {
let mut map = serde_json::Map::new();
map.insert("stream".to_string(), serde_json::Value::Bool(stream));
map
}
_ => serde_json::Map::new(),
};
let mut merged_tools = match params_map.remove("tools") {
Some(serde_json::Value::Array(existing)) => existing,
_ => Vec::new(),
};
merged_tools.append(&mut provider_tools_json);
params_map.insert("tools".to_string(), serde_json::Value::Array(merged_tools));
Some(serde_json::Value::Object(params_map))
}
/// Builder for a [`CompletionRequest`] bound to a model `M`.
///
/// Call [`CompletionRequestBuilder::build`] to obtain the request, or
/// `send` / `stream` to build it and execute it on the model directly.
pub struct CompletionRequestBuilder<M: CompletionModel> {
// The model used by `send` and `stream`.
model: M,
// The prompt; `build` appends it as the last chat-history message.
prompt: Message,
// Optional model name copied into `CompletionRequest::model`.
request_model: Option<String>,
// Optional preamble; `build` turns it into a leading system message.
preamble: Option<String>,
// Messages that precede the prompt.
chat_history: Vec<Message>,
// Documents attached to the request.
documents: Vec<Document>,
// Tool definitions attached to the request.
tools: Vec<ToolDefinition>,
// Provider tools; `build` merges them into `additional_params`.
provider_tools: Vec<ProviderToolDefinition>,
// Optional sampling temperature.
temperature: Option<f64>,
// Optional maximum number of tokens.
max_tokens: Option<u64>,
// Optional tool-choice setting.
tool_choice: Option<ToolChoice>,
// Optional free-form JSON parameters.
additional_params: Option<serde_json::Value>,
// Optional JSON schema for the expected output.
output_schema: Option<schemars::Schema>,
}
impl<M: CompletionModel> CompletionRequestBuilder<M> {
    /// Creates a builder for `prompt` on `model` with every optional setting unset
    /// and all collections empty.
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        let prompt = prompt.into();
        Self {
            model,
            prompt,
            request_model: None,
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            provider_tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    /// Sets the preamble, replacing any previous one.
    pub fn preamble(self, preamble: String) -> Self {
        Self {
            preamble: Some(preamble),
            ..self
        }
    }

    /// Sets the model name carried by the request.
    pub fn model(self, model: impl Into<String>) -> Self {
        self.model_opt(Some(model.into()))
    }

    /// Sets or clears the model name carried by the request.
    pub fn model_opt(self, model: Option<String>) -> Self {
        Self {
            request_model: model,
            ..self
        }
    }

    /// Clears any previously set preamble.
    pub fn without_preamble(self) -> Self {
        Self {
            preamble: None,
            ..self
        }
    }

    /// Appends one message to the chat history.
    pub fn message(self, message: Message) -> Self {
        self.messages([message])
    }

    /// Appends several messages to the chat history, in iteration order.
    pub fn messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
        self.chat_history.extend(messages);
        self
    }

    /// Attaches one document.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Attaches several documents, in iteration order.
    pub fn documents(mut self, documents: impl IntoIterator<Item = Document>) -> Self {
        self.documents.extend(documents);
        self
    }

    /// Adds one tool definition.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds several tool definitions, in order.
    pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
        self.tools.extend(tools);
        self
    }

    /// Adds one provider tool.
    pub fn provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
        self.provider_tools.push(tool);
        self
    }

    /// Adds several provider tools, in order.
    pub fn provider_tools(mut self, tools: Vec<ProviderToolDefinition>) -> Self {
        self.provider_tools.extend(tools);
        self
    }

    /// Adds free-form JSON parameters.
    ///
    /// If parameters were already set, the new value is combined with them via
    /// `json_utils::merge`; otherwise it is stored as-is.
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        let combined = match self.additional_params.take() {
            Some(existing) => json_utils::merge(existing, additional_params),
            None => additional_params,
        };
        self.additional_params = Some(combined);
        self
    }

    /// Replaces the free-form JSON parameters outright (no merging), or clears them.
    pub fn additional_params_opt(self, additional_params: Option<serde_json::Value>) -> Self {
        Self {
            additional_params,
            ..self
        }
    }

    /// Sets the sampling temperature.
    pub fn temperature(self, temperature: f64) -> Self {
        self.temperature_opt(Some(temperature))
    }

    /// Sets or clears the sampling temperature.
    pub fn temperature_opt(self, temperature: Option<f64>) -> Self {
        Self {
            temperature,
            ..self
        }
    }

    /// Sets the maximum number of tokens.
    pub fn max_tokens(self, max_tokens: u64) -> Self {
        self.max_tokens_opt(Some(max_tokens))
    }

    /// Sets or clears the maximum number of tokens.
    pub fn max_tokens_opt(self, max_tokens: Option<u64>) -> Self {
        Self { max_tokens, ..self }
    }

    /// Sets the tool-choice setting.
    pub fn tool_choice(self, tool_choice: ToolChoice) -> Self {
        Self {
            tool_choice: Some(tool_choice),
            ..self
        }
    }

    /// Sets the output schema.
    pub fn output_schema(self, schema: schemars::Schema) -> Self {
        self.output_schema_opt(Some(schema))
    }

    /// Sets or clears the output schema.
    pub fn output_schema_opt(self, schema: Option<schemars::Schema>) -> Self {
        Self {
            output_schema: schema,
            ..self
        }
    }

    /// Assembles the final [`CompletionRequest`].
    ///
    /// The resulting chat history is: the preamble as a system message (if set),
    /// then the accumulated history, then the prompt. The request's `preamble`
    /// field is always `None`. Provider tools are merged into `additional_params`
    /// under its `tools` key.
    pub fn build(self) -> CompletionRequest {
        let Self {
            model: _,
            prompt,
            request_model,
            preamble,
            chat_history: history,
            documents,
            tools,
            provider_tools,
            temperature,
            max_tokens,
            tool_choice,
            additional_params,
            output_schema,
        } = self;

        // Room for the history plus an optional system message and the prompt.
        let mut messages = Vec::with_capacity(history.len() + 2);
        if let Some(preamble) = preamble {
            messages.push(Message::system(preamble));
        }
        messages.extend(history);
        messages.push(prompt.clone());

        // `messages` always holds at least the prompt; the fallback keeps this
        // function free of panicking paths should the conversion return `None`.
        let chat_history = match OneOrMany::from_iter_optional(messages) {
            Some(non_empty) => non_empty,
            None => OneOrMany::one(prompt),
        };

        let additional_params =
            merge_provider_tools_into_additional_params(additional_params, provider_tools);

        CompletionRequest {
            model: request_model,
            preamble: None,
            chat_history,
            documents,
            tools,
            temperature,
            max_tokens,
            tool_choice,
            additional_params,
            output_schema,
        }
    }

    /// Builds the request and executes it on the bound model, resolving to the
    /// complete response.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        // Clone the model first: `build` consumes the builder.
        let model = self.model.clone();
        let request = self.build();
        model.completion(request).await
    }

    /// Builds the request and executes it on the bound model as a stream.
    pub async fn stream<'a>(
        self,
    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
    where
        <M as CompletionModel>::StreamingResponse: 'a,
        Self: 'a,
    {
        // Clone the model first: `build` consumes the builder.
        let model = self.model.clone();
        let request = self.build();
        model.stream(request).await
    }
}
// Unit tests for document rendering, document normalization, and preamble
// handling in the request builder.
#[cfg(test)]
mod tests {
use super::*;
use crate::streaming::StreamingCompletionResponse;
use serde::{Deserialize, Serialize};
// Minimal `CompletionModel` used only to construct builders; both of its
// request methods return a `ProviderError`.
#[derive(Clone)]
struct DummyModel;
// Streaming response placeholder that never reports usage.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct DummyStreamingResponse;
impl GetTokenUsage for DummyStreamingResponse {
fn token_usage(&self) -> Option<Usage> {
None
}
}
impl CompletionModel for DummyModel {
type Response = serde_json::Value;
type StreamingResponse = DummyStreamingResponse;
type Client = ();
fn make(_client: &Self::Client, _model: impl Into<String>) -> Self {
Self
}
async fn completion(
&self,
_request: CompletionRequest,
) -> Result<CompletionResponse<Self::Response>, CompletionError> {
Err(CompletionError::ProviderError(
"dummy completion model".to_string(),
))
}
async fn stream(
&self,
_request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
Err(CompletionError::ProviderError(
"dummy completion model".to_string(),
))
}
}
// A document without extra properties renders without a `<metadata>` line.
#[test]
fn test_document_display_without_metadata() {
let doc = Document {
id: "123".to_string(),
text: "This is a test document.".to_string(),
additional_props: HashMap::new(),
};
let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
assert_eq!(format!("{doc}"), expected);
}
// Extra properties render as a `<metadata>` line, sorted by key, with quoted values.
#[test]
fn test_document_display_with_metadata() {
let mut additional_props = HashMap::new();
additional_props.insert("author".to_string(), "John Doe".to_string());
additional_props.insert("length".to_string(), "42".to_string());
let doc = Document {
id: "123".to_string(),
text: "This is a test document.".to_string(),
additional_props,
};
let expected = concat!(
"<file id: 123>\n",
"<metadata author: \"John Doe\" length: \"42\" />\n",
"This is a test document.\n",
"</file>\n"
);
assert_eq!(format!("{doc}"), expected);
}
// Multiple documents are normalized into one user message with one text
// document content entry per document, in order.
#[test]
fn test_normalize_documents_with_documents() {
let doc1 = Document {
id: "doc1".to_string(),
text: "Document 1 text.".to_string(),
additional_props: HashMap::new(),
};
let doc2 = Document {
id: "doc2".to_string(),
text: "Document 2 text.".to_string(),
additional_props: HashMap::new(),
};
let request = CompletionRequest {
model: None,
preamble: None,
chat_history: OneOrMany::one("What is the capital of France?".into()),
documents: vec![doc1, doc2],
tools: Vec::new(),
temperature: None,
max_tokens: None,
tool_choice: None,
additional_params: None,
output_schema: None,
};
let expected = Message::User {
content: OneOrMany::many(vec![
UserContent::document(
"<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
Some(DocumentMediaType::TXT),
),
UserContent::document(
"<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
Some(DocumentMediaType::TXT),
),
])
.expect("There will be at least one document"),
};
assert_eq!(request.normalized_documents(), Some(expected));
}
// With no documents attached, normalization yields `None`.
#[test]
fn test_normalize_documents_without_documents() {
let request = CompletionRequest {
model: None,
preamble: None,
chat_history: OneOrMany::one("What is the capital of France?".into()),
documents: Vec::new(),
tools: Vec::new(),
temperature: None,
max_tokens: None,
tool_choice: None,
additional_params: None,
output_schema: None,
};
assert_eq!(request.normalized_documents(), None);
}
// The builder moves the preamble into the chat history as a leading system
// message (followed by history, then the prompt) and leaves `preamble` unset.
#[test]
fn preamble_builder_funnels_to_system_message() {
let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
.preamble("System prompt".to_string())
.message(Message::user("History"))
.build();
assert_eq!(request.preamble, None);
let history = request.chat_history.into_iter().collect::<Vec<_>>();
assert_eq!(history.len(), 3);
assert!(matches!(
&history[0],
Message::System { content } if content == "System prompt"
));
assert!(matches!(&history[1], Message::User { .. }));
assert!(matches!(&history[2], Message::User { .. }));
}
// `without_preamble` cancels a previously set preamble, so only the prompt remains.
#[test]
fn without_preamble_removes_legacy_preamble_injection() {
let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
.preamble("System prompt".to_string())
.without_preamble()
.build();
assert_eq!(request.preamble, None);
let history = request.chat_history.into_iter().collect::<Vec<_>>();
assert_eq!(history.len(), 1);
assert!(matches!(&history[0], Message::User { .. }));
}
}