1use super::message::{AssistantContent, DocumentMediaType};
67use crate::client::FinalCompletionResponse;
68#[allow(deprecated)]
69use crate::client::completion::CompletionModelHandle;
70use crate::message::ToolChoice;
71use crate::streaming::StreamingCompletionResponse;
72use crate::tool::server::ToolServerError;
73use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
74use crate::{OneOrMany, http_client, streaming};
75use crate::{
76 json_utils,
77 message::{Message, UserContent},
78 tool::ToolSetError,
79};
80use serde::de::DeserializeOwned;
81use serde::{Deserialize, Serialize};
82use std::collections::HashMap;
83use std::ops::{Add, AddAssign};
84use std::sync::Arc;
85use thiserror::Error;
86
87#[derive(Debug, Error)]
89pub enum CompletionError {
90 #[error("HttpError: {0}")]
92 HttpError(#[from] http_client::Error),
93
94 #[error("JsonError: {0}")]
96 JsonError(#[from] serde_json::Error),
97
98 #[error("UrlError: {0}")]
100 UrlError(#[from] url::ParseError),
101
102 #[cfg(not(target_family = "wasm"))]
103 #[error("RequestError: {0}")]
105 RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
106
107 #[cfg(target_family = "wasm")]
108 #[error("RequestError: {0}")]
110 RequestError(#[from] Box<dyn std::error::Error + 'static>),
111
112 #[error("ResponseError: {0}")]
114 ResponseError(String),
115
116 #[error("ProviderError: {0}")]
118 ProviderError(String),
119}
120
121#[derive(Debug, Error)]
123pub enum PromptError {
124 #[error("CompletionError: {0}")]
126 CompletionError(#[from] CompletionError),
127
128 #[error("ToolCallError: {0}")]
130 ToolError(#[from] ToolSetError),
131
132 #[error("ToolServerError: {0}")]
134 ToolServerError(#[from] ToolServerError),
135
136 #[error("MaxDepthError: (reached limit: {max_depth})")]
140 MaxDepthError {
141 max_depth: usize,
142 chat_history: Box<Vec<Message>>,
143 prompt: Box<Message>,
144 },
145
146 #[error("PromptCancelled: {reason}")]
148 PromptCancelled {
149 chat_history: Box<Vec<Message>>,
150 reason: String,
151 },
152}
153
154impl PromptError {
155 pub(crate) fn prompt_cancelled(chat_history: Vec<Message>, reason: &str) -> Self {
156 Self::PromptCancelled {
157 chat_history: Box::new(chat_history),
158 reason: reason.to_string(),
159 }
160 }
161}
162
163#[derive(Clone, Debug, Deserialize, Serialize)]
164pub struct Document {
165 pub id: String,
166 pub text: String,
167 #[serde(flatten)]
168 pub additional_props: HashMap<String, String>,
169}
170
171impl std::fmt::Display for Document {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 write!(
174 f,
175 concat!("<file id: {}>\n", "{}\n", "</file>\n"),
176 self.id,
177 if self.additional_props.is_empty() {
178 self.text.clone()
179 } else {
180 let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
181 sorted_props.sort_by(|a, b| a.0.cmp(b.0));
182 let metadata = sorted_props
183 .iter()
184 .map(|(k, v)| format!("{k}: {v:?}"))
185 .collect::<Vec<_>>()
186 .join(" ");
187 format!("<metadata {} />\n{}", metadata, self.text)
188 }
189 )
190 }
191}
192
193#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
194pub struct ToolDefinition {
195 pub name: String,
196 pub description: String,
197 pub parameters: serde_json::Value,
198}
199
200pub trait Prompt: WasmCompatSend + WasmCompatSync {
205 fn prompt(
214 &self,
215 prompt: impl Into<Message> + WasmCompatSend,
216 ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
217}
218
219pub trait Chat: WasmCompatSend + WasmCompatSync {
221 fn chat(
230 &self,
231 prompt: impl Into<Message> + WasmCompatSend,
232 chat_history: Vec<Message>,
233 ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
234}
235
236pub trait Completion<M: CompletionModel> {
238 fn completion(
250 &self,
251 prompt: impl Into<Message> + WasmCompatSend,
252 chat_history: Vec<Message>,
253 ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
254 + WasmCompatSend;
255}
256
257#[derive(Debug)]
260pub struct CompletionResponse<T> {
261 pub choice: OneOrMany<AssistantContent>,
264 pub usage: Usage,
266 pub raw_response: T,
268}
269
270pub trait GetTokenUsage {
274 fn token_usage(&self) -> Option<crate::completion::Usage>;
275}
276
277impl GetTokenUsage for () {
278 fn token_usage(&self) -> Option<crate::completion::Usage> {
279 None
280 }
281}
282
283impl<T> GetTokenUsage for Option<T>
284where
285 T: GetTokenUsage,
286{
287 fn token_usage(&self) -> Option<crate::completion::Usage> {
288 if let Some(usage) = self {
289 usage.token_usage()
290 } else {
291 None
292 }
293 }
294}
295
296#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
299pub struct Usage {
300 pub input_tokens: u64,
302 pub output_tokens: u64,
304 pub total_tokens: u64,
306}
307
308impl Usage {
309 pub fn new() -> Self {
311 Self {
312 input_tokens: 0,
313 output_tokens: 0,
314 total_tokens: 0,
315 }
316 }
317}
318
319impl Default for Usage {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
325impl Add for Usage {
326 type Output = Self;
327
328 fn add(self, other: Self) -> Self::Output {
329 Self {
330 input_tokens: self.input_tokens + other.input_tokens,
331 output_tokens: self.output_tokens + other.output_tokens,
332 total_tokens: self.total_tokens + other.total_tokens,
333 }
334 }
335}
336
337impl AddAssign for Usage {
338 fn add_assign(&mut self, other: Self) {
339 self.input_tokens += other.input_tokens;
340 self.output_tokens += other.output_tokens;
341 self.total_tokens += other.total_tokens;
342 }
343}
344
345pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
349 type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
351 type StreamingResponse: Clone
353 + Unpin
354 + WasmCompatSend
355 + WasmCompatSync
356 + Serialize
357 + DeserializeOwned
358 + GetTokenUsage;
359
360 type Client;
361
362 fn make(client: &Self::Client, model: impl Into<String>) -> Self;
363
364 fn completion(
366 &self,
367 request: CompletionRequest,
368 ) -> impl std::future::Future<
369 Output = Result<CompletionResponse<Self::Response>, CompletionError>,
370 > + WasmCompatSend;
371
372 fn stream(
373 &self,
374 request: CompletionRequest,
375 ) -> impl std::future::Future<
376 Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
377 > + WasmCompatSend;
378
379 fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
381 CompletionRequestBuilder::new(self.clone(), prompt)
382 }
383}
384
385#[allow(deprecated)]
386#[deprecated(
387 since = "0.25.0",
388 note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `CompletionModel` instead."
389)]
390pub trait CompletionModelDyn: WasmCompatSend + WasmCompatSync {
391 fn completion(
392 &self,
393 request: CompletionRequest,
394 ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>>;
395
396 fn stream(
397 &self,
398 request: CompletionRequest,
399 ) -> WasmBoxedFuture<
400 '_,
401 Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
402 >;
403
404 fn completion_request(
405 &self,
406 prompt: Message,
407 ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
408}
409
410#[allow(deprecated)]
411impl<T, R> CompletionModelDyn for T
412where
413 T: CompletionModel<StreamingResponse = R>,
414 R: Clone + Unpin + GetTokenUsage + 'static,
415{
416 fn completion(
417 &self,
418 request: CompletionRequest,
419 ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>> {
420 Box::pin(async move {
421 self.completion(request)
422 .await
423 .map(|resp| CompletionResponse {
424 choice: resp.choice,
425 usage: resp.usage,
426 raw_response: (),
427 })
428 })
429 }
430
431 fn stream(
432 &self,
433 request: CompletionRequest,
434 ) -> WasmBoxedFuture<
435 '_,
436 Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
437 > {
438 Box::pin(async move {
439 let resp = self.stream(request).await?;
440 let inner = resp.inner;
441
442 let stream = streaming::StreamingResultDyn {
443 inner: Box::pin(inner),
444 };
445
446 Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
447 })
448 }
449
450 fn completion_request(
452 &self,
453 prompt: Message,
454 ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
455 CompletionRequestBuilder::new(CompletionModelHandle::new(Arc::new(self.clone())), prompt)
456 }
457}
458
459#[derive(Debug, Clone)]
461pub struct CompletionRequest {
462 pub preamble: Option<String>,
464 pub chat_history: OneOrMany<Message>,
467 pub documents: Vec<Document>,
469 pub tools: Vec<ToolDefinition>,
471 pub temperature: Option<f64>,
473 pub max_tokens: Option<u64>,
475 pub tool_choice: Option<ToolChoice>,
477 pub additional_params: Option<serde_json::Value>,
479}
480
481impl CompletionRequest {
482 pub fn normalized_documents(&self) -> Option<Message> {
486 if self.documents.is_empty() {
487 return None;
488 }
489
490 let messages = self
493 .documents
494 .iter()
495 .map(|doc| {
496 UserContent::document(
497 doc.to_string(),
498 Some(DocumentMediaType::TXT),
501 )
502 })
503 .collect::<Vec<_>>();
504
505 Some(Message::User {
506 content: OneOrMany::many(messages).expect("There will be atleast one document"),
507 })
508 }
509}
510
511pub struct CompletionRequestBuilder<M: CompletionModel> {
556 model: M,
557 prompt: Message,
558 preamble: Option<String>,
559 chat_history: Vec<Message>,
560 documents: Vec<Document>,
561 tools: Vec<ToolDefinition>,
562 temperature: Option<f64>,
563 max_tokens: Option<u64>,
564 tool_choice: Option<ToolChoice>,
565 additional_params: Option<serde_json::Value>,
566}
567
568impl<M: CompletionModel> CompletionRequestBuilder<M> {
569 pub fn new(model: M, prompt: impl Into<Message>) -> Self {
570 Self {
571 model,
572 prompt: prompt.into(),
573 preamble: None,
574 chat_history: Vec::new(),
575 documents: Vec::new(),
576 tools: Vec::new(),
577 temperature: None,
578 max_tokens: None,
579 tool_choice: None,
580 additional_params: None,
581 }
582 }
583
584 pub fn preamble(mut self, preamble: String) -> Self {
586 self.preamble = Some(preamble);
587 self
588 }
589
590 pub fn without_preamble(mut self) -> Self {
591 self.preamble = None;
592 self
593 }
594
595 pub fn message(mut self, message: Message) -> Self {
597 self.chat_history.push(message);
598 self
599 }
600
601 pub fn messages(self, messages: Vec<Message>) -> Self {
603 messages
604 .into_iter()
605 .fold(self, |builder, msg| builder.message(msg))
606 }
607
608 pub fn document(mut self, document: Document) -> Self {
610 self.documents.push(document);
611 self
612 }
613
614 pub fn documents(self, documents: Vec<Document>) -> Self {
616 documents
617 .into_iter()
618 .fold(self, |builder, doc| builder.document(doc))
619 }
620
621 pub fn tool(mut self, tool: ToolDefinition) -> Self {
623 self.tools.push(tool);
624 self
625 }
626
627 pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
629 tools
630 .into_iter()
631 .fold(self, |builder, tool| builder.tool(tool))
632 }
633
634 pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
640 match self.additional_params {
641 Some(params) => {
642 self.additional_params = Some(json_utils::merge(params, additional_params));
643 }
644 None => {
645 self.additional_params = Some(additional_params);
646 }
647 }
648 self
649 }
650
651 pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
657 self.additional_params = additional_params;
658 self
659 }
660
661 pub fn temperature(mut self, temperature: f64) -> Self {
663 self.temperature = Some(temperature);
664 self
665 }
666
667 pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
669 self.temperature = temperature;
670 self
671 }
672
673 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
676 self.max_tokens = Some(max_tokens);
677 self
678 }
679
680 pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
683 self.max_tokens = max_tokens;
684 self
685 }
686
687 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
689 self.tool_choice = Some(tool_choice);
690 self
691 }
692
693 pub fn build(self) -> CompletionRequest {
695 let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
696 .expect("There will always be atleast the prompt");
697
698 CompletionRequest {
699 preamble: self.preamble,
700 chat_history,
701 documents: self.documents,
702 tools: self.tools,
703 temperature: self.temperature,
704 max_tokens: self.max_tokens,
705 tool_choice: self.tool_choice,
706 additional_params: self.additional_params,
707 }
708 }
709
710 pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
712 let model = self.model.clone();
713 model.completion(self.build()).await
714 }
715
716 pub async fn stream<'a>(
718 self,
719 ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
720 where
721 <M as CompletionModel>::StreamingResponse: 'a,
722 Self: 'a,
723 {
724 let model = self.model.clone();
725 model.stream(self.build()).await
726 }
727}
728
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a document with the given id and text and no metadata.
    fn plain_document(id: &str, text: &str) -> Document {
        Document {
            id: id.to_string(),
            text: text.to_string(),
            additional_props: HashMap::new(),
        }
    }

    /// Builds a request carrying only a fixed prompt and the given documents.
    fn request_with_documents(documents: Vec<Document>) -> CompletionRequest {
        CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents,
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
        }
    }

    #[test]
    fn test_document_display_without_metadata() {
        let doc = plain_document("123", "This is a test document.");

        assert_eq!(
            doc.to_string(),
            "<file id: 123>\nThis is a test document.\n</file>\n"
        );
    }

    #[test]
    fn test_document_display_with_metadata() {
        let doc = Document {
            additional_props: HashMap::from([
                ("author".to_string(), "John Doe".to_string()),
                ("length".to_string(), "42".to_string()),
            ]),
            ..plain_document("123", "This is a test document.")
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(doc.to_string(), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let request = request_with_documents(vec![
            plain_document("doc1", "Document 1 text."),
            plain_document("doc2", "Document 2 text."),
        ]);

        let rendered = [
            "<file id: doc1>\nDocument 1 text.\n</file>\n",
            "<file id: doc2>\nDocument 2 text.\n</file>\n",
        ];
        let contents = rendered
            .iter()
            .map(|text| UserContent::document(text.to_string(), Some(DocumentMediaType::TXT)))
            .collect::<Vec<_>>();

        let expected = Message::User {
            content: OneOrMany::many(contents).expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = request_with_documents(Vec::new());

        assert_eq!(request.normalized_documents(), None);
    }
}