1use super::message::{AssistantContent, DocumentMediaType};
67use crate::client::FinalCompletionResponse;
68#[allow(deprecated)]
69use crate::client::completion::CompletionModelHandle;
70use crate::message::ToolChoice;
71use crate::streaming::StreamingCompletionResponse;
72use crate::tool::server::ToolServerError;
73use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
74use crate::{OneOrMany, http_client, streaming};
75use crate::{
76 json_utils,
77 message::{Message, UserContent},
78 tool::ToolSetError,
79};
80use serde::de::DeserializeOwned;
81use serde::{Deserialize, Serialize};
82use std::collections::HashMap;
83use std::ops::{Add, AddAssign};
84use std::sync::Arc;
85use thiserror::Error;
86
87#[derive(Debug, Error)]
89pub enum CompletionError {
90 #[error("HttpError: {0}")]
92 HttpError(#[from] http_client::Error),
93
94 #[error("JsonError: {0}")]
96 JsonError(#[from] serde_json::Error),
97
98 #[error("UrlError: {0}")]
100 UrlError(#[from] url::ParseError),
101
102 #[cfg(not(target_family = "wasm"))]
103 #[error("RequestError: {0}")]
105 RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
106
107 #[cfg(target_family = "wasm")]
108 #[error("RequestError: {0}")]
110 RequestError(#[from] Box<dyn std::error::Error + 'static>),
111
112 #[error("ResponseError: {0}")]
114 ResponseError(String),
115
116 #[error("ProviderError: {0}")]
118 ProviderError(String),
119}
120
121#[derive(Debug, Error)]
123pub enum PromptError {
124 #[error("CompletionError: {0}")]
126 CompletionError(#[from] CompletionError),
127
128 #[error("ToolCallError: {0}")]
130 ToolError(#[from] ToolSetError),
131
132 #[error("ToolServerError: {0}")]
134 ToolServerError(#[from] ToolServerError),
135
136 #[error("MaxDepthError: (reached limit: {max_depth})")]
140 MaxDepthError {
141 max_depth: usize,
142 chat_history: Box<Vec<Message>>,
143 prompt: Message,
144 },
145
146 #[error("PromptCancelled")]
148 PromptCancelled { chat_history: Box<Vec<Message>> },
149}
150
151impl PromptError {
152 pub(crate) fn prompt_cancelled(chat_history: Vec<Message>) -> Self {
153 Self::PromptCancelled {
154 chat_history: Box::new(chat_history),
155 }
156 }
157}
158
159#[derive(Clone, Debug, Deserialize, Serialize)]
160pub struct Document {
161 pub id: String,
162 pub text: String,
163 #[serde(flatten)]
164 pub additional_props: HashMap<String, String>,
165}
166
167impl std::fmt::Display for Document {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 write!(
170 f,
171 concat!("<file id: {}>\n", "{}\n", "</file>\n"),
172 self.id,
173 if self.additional_props.is_empty() {
174 self.text.clone()
175 } else {
176 let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
177 sorted_props.sort_by(|a, b| a.0.cmp(b.0));
178 let metadata = sorted_props
179 .iter()
180 .map(|(k, v)| format!("{k}: {v:?}"))
181 .collect::<Vec<_>>()
182 .join(" ");
183 format!("<metadata {} />\n{}", metadata, self.text)
184 }
185 )
186 }
187}
188
189#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
190pub struct ToolDefinition {
191 pub name: String,
192 pub description: String,
193 pub parameters: serde_json::Value,
194}
195
196pub trait Prompt: WasmCompatSend + WasmCompatSync {
201 fn prompt(
210 &self,
211 prompt: impl Into<Message> + WasmCompatSend,
212 ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
213}
214
215pub trait Chat: WasmCompatSend + WasmCompatSync {
217 fn chat(
226 &self,
227 prompt: impl Into<Message> + WasmCompatSend,
228 chat_history: Vec<Message>,
229 ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
230}
231
232pub trait Completion<M: CompletionModel> {
234 fn completion(
246 &self,
247 prompt: impl Into<Message> + WasmCompatSend,
248 chat_history: Vec<Message>,
249 ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
250 + WasmCompatSend;
251}
252
253#[derive(Debug)]
256pub struct CompletionResponse<T> {
257 pub choice: OneOrMany<AssistantContent>,
260 pub usage: Usage,
262 pub raw_response: T,
264}
265
266pub trait GetTokenUsage {
270 fn token_usage(&self) -> Option<crate::completion::Usage>;
271}
272
273impl GetTokenUsage for () {
274 fn token_usage(&self) -> Option<crate::completion::Usage> {
275 None
276 }
277}
278
279impl<T> GetTokenUsage for Option<T>
280where
281 T: GetTokenUsage,
282{
283 fn token_usage(&self) -> Option<crate::completion::Usage> {
284 if let Some(usage) = self {
285 usage.token_usage()
286 } else {
287 None
288 }
289 }
290}
291
292#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
295pub struct Usage {
296 pub input_tokens: u64,
298 pub output_tokens: u64,
300 pub total_tokens: u64,
302}
303
304impl Usage {
305 pub fn new() -> Self {
307 Self {
308 input_tokens: 0,
309 output_tokens: 0,
310 total_tokens: 0,
311 }
312 }
313}
314
315impl Default for Usage {
316 fn default() -> Self {
317 Self::new()
318 }
319}
320
321impl Add for Usage {
322 type Output = Self;
323
324 fn add(self, other: Self) -> Self::Output {
325 Self {
326 input_tokens: self.input_tokens + other.input_tokens,
327 output_tokens: self.output_tokens + other.output_tokens,
328 total_tokens: self.total_tokens + other.total_tokens,
329 }
330 }
331}
332
333impl AddAssign for Usage {
334 fn add_assign(&mut self, other: Self) {
335 self.input_tokens += other.input_tokens;
336 self.output_tokens += other.output_tokens;
337 self.total_tokens += other.total_tokens;
338 }
339}
340
341pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
345 type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
347 type StreamingResponse: Clone
349 + Unpin
350 + WasmCompatSend
351 + WasmCompatSync
352 + Serialize
353 + DeserializeOwned
354 + GetTokenUsage;
355
356 type Client;
357
358 fn make(client: &Self::Client, model: impl Into<String>) -> Self;
359
360 fn completion(
362 &self,
363 request: CompletionRequest,
364 ) -> impl std::future::Future<
365 Output = Result<CompletionResponse<Self::Response>, CompletionError>,
366 > + WasmCompatSend;
367
368 fn stream(
369 &self,
370 request: CompletionRequest,
371 ) -> impl std::future::Future<
372 Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
373 > + WasmCompatSend;
374
375 fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
377 CompletionRequestBuilder::new(self.clone(), prompt)
378 }
379}
380
381#[allow(deprecated)]
382#[deprecated(
383 since = "0.25.0",
384 note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `CompletionModel` instead."
385)]
386pub trait CompletionModelDyn: WasmCompatSend + WasmCompatSync {
387 fn completion(
388 &self,
389 request: CompletionRequest,
390 ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>>;
391
392 fn stream(
393 &self,
394 request: CompletionRequest,
395 ) -> WasmBoxedFuture<
396 '_,
397 Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
398 >;
399
400 fn completion_request(
401 &self,
402 prompt: Message,
403 ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
404}
405
406#[allow(deprecated)]
407impl<T, R> CompletionModelDyn for T
408where
409 T: CompletionModel<StreamingResponse = R>,
410 R: Clone + Unpin + GetTokenUsage + 'static,
411{
412 fn completion(
413 &self,
414 request: CompletionRequest,
415 ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>> {
416 Box::pin(async move {
417 self.completion(request)
418 .await
419 .map(|resp| CompletionResponse {
420 choice: resp.choice,
421 usage: resp.usage,
422 raw_response: (),
423 })
424 })
425 }
426
427 fn stream(
428 &self,
429 request: CompletionRequest,
430 ) -> WasmBoxedFuture<
431 '_,
432 Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
433 > {
434 Box::pin(async move {
435 let resp = self.stream(request).await?;
436 let inner = resp.inner;
437
438 let stream = streaming::StreamingResultDyn {
439 inner: Box::pin(inner),
440 };
441
442 Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
443 })
444 }
445
446 fn completion_request(
448 &self,
449 prompt: Message,
450 ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
451 CompletionRequestBuilder::new(CompletionModelHandle::new(Arc::new(self.clone())), prompt)
452 }
453}
454
455#[derive(Debug, Clone)]
457pub struct CompletionRequest {
458 pub preamble: Option<String>,
460 pub chat_history: OneOrMany<Message>,
463 pub documents: Vec<Document>,
465 pub tools: Vec<ToolDefinition>,
467 pub temperature: Option<f64>,
469 pub max_tokens: Option<u64>,
471 pub tool_choice: Option<ToolChoice>,
473 pub additional_params: Option<serde_json::Value>,
475}
476
477impl CompletionRequest {
478 pub fn normalized_documents(&self) -> Option<Message> {
482 if self.documents.is_empty() {
483 return None;
484 }
485
486 let messages = self
489 .documents
490 .iter()
491 .map(|doc| {
492 UserContent::document(
493 doc.to_string(),
494 Some(DocumentMediaType::TXT),
497 )
498 })
499 .collect::<Vec<_>>();
500
501 Some(Message::User {
502 content: OneOrMany::many(messages).expect("There will be atleast one document"),
503 })
504 }
505}
506
507pub struct CompletionRequestBuilder<M: CompletionModel> {
552 model: M,
553 prompt: Message,
554 preamble: Option<String>,
555 chat_history: Vec<Message>,
556 documents: Vec<Document>,
557 tools: Vec<ToolDefinition>,
558 temperature: Option<f64>,
559 max_tokens: Option<u64>,
560 tool_choice: Option<ToolChoice>,
561 additional_params: Option<serde_json::Value>,
562}
563
564impl<M: CompletionModel> CompletionRequestBuilder<M> {
565 pub fn new(model: M, prompt: impl Into<Message>) -> Self {
566 Self {
567 model,
568 prompt: prompt.into(),
569 preamble: None,
570 chat_history: Vec::new(),
571 documents: Vec::new(),
572 tools: Vec::new(),
573 temperature: None,
574 max_tokens: None,
575 tool_choice: None,
576 additional_params: None,
577 }
578 }
579
580 pub fn preamble(mut self, preamble: String) -> Self {
582 self.preamble = Some(preamble);
583 self
584 }
585
586 pub fn without_preamble(mut self) -> Self {
587 self.preamble = None;
588 self
589 }
590
591 pub fn message(mut self, message: Message) -> Self {
593 self.chat_history.push(message);
594 self
595 }
596
597 pub fn messages(self, messages: Vec<Message>) -> Self {
599 messages
600 .into_iter()
601 .fold(self, |builder, msg| builder.message(msg))
602 }
603
604 pub fn document(mut self, document: Document) -> Self {
606 self.documents.push(document);
607 self
608 }
609
610 pub fn documents(self, documents: Vec<Document>) -> Self {
612 documents
613 .into_iter()
614 .fold(self, |builder, doc| builder.document(doc))
615 }
616
617 pub fn tool(mut self, tool: ToolDefinition) -> Self {
619 self.tools.push(tool);
620 self
621 }
622
623 pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
625 tools
626 .into_iter()
627 .fold(self, |builder, tool| builder.tool(tool))
628 }
629
630 pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
636 match self.additional_params {
637 Some(params) => {
638 self.additional_params = Some(json_utils::merge(params, additional_params));
639 }
640 None => {
641 self.additional_params = Some(additional_params);
642 }
643 }
644 self
645 }
646
647 pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
653 self.additional_params = additional_params;
654 self
655 }
656
657 pub fn temperature(mut self, temperature: f64) -> Self {
659 self.temperature = Some(temperature);
660 self
661 }
662
663 pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
665 self.temperature = temperature;
666 self
667 }
668
669 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
672 self.max_tokens = Some(max_tokens);
673 self
674 }
675
676 pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
679 self.max_tokens = max_tokens;
680 self
681 }
682
683 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
685 self.tool_choice = Some(tool_choice);
686 self
687 }
688
689 pub fn build(self) -> CompletionRequest {
691 let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
692 .expect("There will always be atleast the prompt");
693
694 CompletionRequest {
695 preamble: self.preamble,
696 chat_history,
697 documents: self.documents,
698 tools: self.tools,
699 temperature: self.temperature,
700 max_tokens: self.max_tokens,
701 tool_choice: self.tool_choice,
702 additional_params: self.additional_params,
703 }
704 }
705
706 pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
708 let model = self.model.clone();
709 model.completion(self.build()).await
710 }
711
712 pub async fn stream<'a>(
714 self,
715 ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
716 where
717 <M as CompletionModel>::StreamingResponse: 'a,
718 Self: 'a,
719 {
720 let model = self.model.clone();
721 model.stream(self.build()).await
722 }
723}
724
#[cfg(test)]
mod tests {
    use super::*;

    /// A document with no metadata.
    fn bare_document(id: &str, text: &str) -> Document {
        Document {
            id: id.to_string(),
            text: text.to_string(),
            additional_props: HashMap::new(),
        }
    }

    /// A minimal request with one user question and the given documents.
    fn request_with_documents(documents: Vec<Document>) -> CompletionRequest {
        CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents,
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
        }
    }

    /// The plain-text document part that `normalized_documents` is expected to produce.
    fn txt_part(rendered: &str) -> UserContent {
        UserContent::document(rendered.to_string(), Some(DocumentMediaType::TXT))
    }

    #[test]
    fn test_document_display_without_metadata() {
        let doc = bare_document("123", "This is a test document.");
        assert_eq!(
            doc.to_string(),
            "<file id: 123>\nThis is a test document.\n</file>\n"
        );
    }

    #[test]
    fn test_document_display_with_metadata() {
        let doc = Document {
            additional_props: HashMap::from([
                ("author".to_string(), "John Doe".to_string()),
                ("length".to_string(), "42".to_string()),
            ]),
            ..bare_document("123", "This is a test document.")
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(doc.to_string(), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let request = request_with_documents(vec![
            bare_document("doc1", "Document 1 text."),
            bare_document("doc2", "Document 2 text."),
        ]);

        let content = OneOrMany::many(vec![
            txt_part("<file id: doc1>\nDocument 1 text.\n</file>\n"),
            txt_part("<file id: doc2>\nDocument 2 text.\n</file>\n"),
        ])
        .expect("There will be at least one document");

        assert_eq!(
            request.normalized_documents(),
            Some(Message::User { content })
        );
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = request_with_documents(Vec::new());
        assert_eq!(request.normalized_documents(), None);
    }
}