1use super::message::{AssistantContent, DocumentMediaType};
67use crate::client::FinalCompletionResponse;
68#[allow(deprecated)]
69use crate::client::completion::CompletionModelHandle;
70use crate::message::ToolChoice;
71use crate::streaming::StreamingCompletionResponse;
72use crate::tool::server::ToolServerError;
73use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
74use crate::{OneOrMany, http_client, streaming};
75use crate::{
76 json_utils,
77 message::{Message, UserContent},
78 tool::ToolSetError,
79};
80use serde::de::DeserializeOwned;
81use serde::{Deserialize, Serialize};
82use std::collections::HashMap;
83use std::ops::{Add, AddAssign};
84use std::sync::Arc;
85use thiserror::Error;
86
/// Errors produced while performing a completion request.
///
/// Most variants wrap a lower-level error and can be created from it via `?` thanks to the
/// `#[from]` conversions generated by `thiserror`.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Error from the crate's `http_client` layer.
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// `serde_json` serialization / deserialization error.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A URL failed to parse.
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),

    /// Any other boxed error. On non-wasm targets the box must be `Send + Sync`.
    #[cfg(not(target_family = "wasm"))]
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Any other boxed error. On wasm targets no `Send + Sync` bound is required.
    #[cfg(target_family = "wasm")]
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// Free-form message about a response.
    /// NOTE(review): the exact triggering conditions are decided by each provider.
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Free-form message describing a provider-side error.
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
120
/// Errors returned by the high-level prompting traits ([`Prompt`] and [`Chat`]).
#[derive(Debug, Error)]
pub enum PromptError {
    /// Wraps a [`CompletionError`] from the underlying model call.
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// Wraps a [`ToolSetError`].
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// Wraps a [`ToolServerError`].
    #[error("ToolServerError: {0}")]
    ToolServerError(#[from] ToolServerError),

    /// The turn limit `max_turns` was reached.
    ///
    /// Carries the chat history and the prompt at that point. They are boxed, presumably to
    /// keep the size of the enum small.
    #[error("MaxTurnError: (reached max turn limit: {max_turns})")]
    MaxTurnsError {
        max_turns: usize,
        chat_history: Box<Vec<Message>>,
        prompt: Box<Message>,
    },

    /// The prompt was cancelled for `reason`.
    ///
    /// Carries the chat history handed to [`PromptError::prompt_cancelled`].
    #[error("PromptCancelled: {reason}")]
    PromptCancelled {
        chat_history: Box<Vec<Message>>,
        reason: String,
    },
}
153
154impl PromptError {
155 pub(crate) fn prompt_cancelled(chat_history: Vec<Message>, reason: impl Into<String>) -> Self {
156 Self::PromptCancelled {
157 chat_history: Box::new(chat_history),
158 reason: reason.into(),
159 }
160 }
161}
162
/// A text document that can be attached to a completion request.
///
/// Its `Display` impl renders it as a `<file id: …>` block, which is how
/// [`CompletionRequest::normalized_documents`] turns it into message content.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Identifier shown in the rendered `<file id: …>` header.
    pub id: String,
    /// The document body.
    pub text: String,
    /// Extra string metadata. With `#[serde(flatten)]` these keys live at the same JSON
    /// level as `id` and `text` instead of under a nested object.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
170
171impl std::fmt::Display for Document {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 write!(
174 f,
175 concat!("<file id: {}>\n", "{}\n", "</file>\n"),
176 self.id,
177 if self.additional_props.is_empty() {
178 self.text.clone()
179 } else {
180 let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
181 sorted_props.sort_by(|a, b| a.0.cmp(b.0));
182 let metadata = sorted_props
183 .iter()
184 .map(|(k, v)| format!("{k}: {v:?}"))
185 .collect::<Vec<_>>()
186 .join(" ");
187 format!("<metadata {} />\n{}", metadata, self.text)
188 }
189 )
190 }
191}
192
/// Description of a tool that is advertised to the model in a [`CompletionRequest`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Arbitrary JSON describing the tool's parameters.
    /// NOTE(review): presumably a JSON Schema object; confirm against providers.
    pub parameters: serde_json::Value,
}
199
/// High-level "send a message, get text back" interface, without caller-supplied history.
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Sends `prompt` and resolves to the response text or a [`PromptError`].
    ///
    /// The return value only needs to implement `IntoFuture`. It can therefore be awaited
    /// directly, and implementors may return a builder-like value instead of a bare future.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
218
/// Like [`Prompt`], but the caller also supplies the prior conversation.
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Sends `prompt` in the context of `chat_history` and resolves to the response text
    /// or a [`PromptError`].
    fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
235
/// Lower-level interface that prepares a request instead of sending it.
pub trait Completion<M: CompletionModel> {
    /// Produces a [`CompletionRequestBuilder`] for model `M`, seeded from `prompt` and
    /// `chat_history`.
    ///
    /// The caller can adjust the builder further before sending it. How the implementor seeds
    /// the builder (preamble, tools, …) is implementation-specific.
    fn completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend;
}
256
/// Result of a non-streaming completion.
///
/// `T` is the provider's raw response type (`CompletionModel::Response`), or `()` once it has
/// been erased.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// Assistant content returned by the model; always at least one item.
    pub choice: OneOrMany<AssistantContent>,
    /// Token usage for this call.
    pub usage: Usage,
    /// The untouched provider response.
    pub raw_response: T,
}
269
270pub trait GetTokenUsage {
274 fn token_usage(&self) -> Option<crate::completion::Usage>;
275}
276
277impl GetTokenUsage for () {
278 fn token_usage(&self) -> Option<crate::completion::Usage> {
279 None
280 }
281}
282
283impl<T> GetTokenUsage for Option<T>
284where
285 T: GetTokenUsage,
286{
287 fn token_usage(&self) -> Option<crate::completion::Usage> {
288 if let Some(usage) = self {
289 usage.token_usage()
290 } else {
291 None
292 }
293 }
294}
295
/// Token counts for a completion.
///
/// Values can be combined with `+` / `+=` to accumulate usage across several calls.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens counted on the input (prompt) side.
    pub input_tokens: u64,
    /// Tokens counted on the output (generated) side.
    pub output_tokens: u64,
    /// Total tokens as reported. It is stored, not derived from the other fields here.
    pub total_tokens: u64,
    /// Input tokens reported as served from a cache.
    /// NOTE(review): semantics are provider-specific.
    pub cached_input_tokens: u64,
}
309
310impl Usage {
311 pub fn new() -> Self {
313 Self {
314 input_tokens: 0,
315 output_tokens: 0,
316 total_tokens: 0,
317 cached_input_tokens: 0,
318 }
319 }
320}
321
322impl Default for Usage {
323 fn default() -> Self {
324 Self::new()
325 }
326}
327
328impl Add for Usage {
329 type Output = Self;
330
331 fn add(self, other: Self) -> Self::Output {
332 Self {
333 input_tokens: self.input_tokens + other.input_tokens,
334 output_tokens: self.output_tokens + other.output_tokens,
335 total_tokens: self.total_tokens + other.total_tokens,
336 cached_input_tokens: self.cached_input_tokens + other.cached_input_tokens,
337 }
338 }
339}
340
341impl AddAssign for Usage {
342 fn add_assign(&mut self, other: Self) {
343 self.input_tokens += other.input_tokens;
344 self.output_tokens += other.output_tokens;
345 self.total_tokens += other.total_tokens;
346 self.cached_input_tokens += other.cached_input_tokens;
347 }
348}
349
/// A provider model that can perform completions, either all at once or streamed.
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// Provider-specific raw response, returned as `CompletionResponse::raw_response`.
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// Provider-specific raw type for streamed responses. It must be able to report token
    /// usage via [`GetTokenUsage`].
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;

    /// Client type used by [`CompletionModel::make`] to construct models.
    type Client;

    /// Constructs a model from `client`.
    /// NOTE(review): `model` is presumably the provider's model identifier.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Performs a non-streaming completion for `request`.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Performs a streaming completion for `request`.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Starts a [`CompletionRequestBuilder`] for `prompt`, backed by a clone of this model.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
389
/// Type-erased counterpart of [`CompletionModel`].
///
/// Futures are boxed (`WasmBoxedFuture`). Raw responses are erased: `()` for completions and
/// `FinalCompletionResponse` for streams. Deprecated together with the dynamic client builder.
#[allow(deprecated)]
#[deprecated(
    since = "0.25.0",
    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `CompletionModel` instead."
)]
pub trait CompletionModelDyn: WasmCompatSend + WasmCompatSync {
    /// Non-streaming completion with the raw response erased to `()`.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    /// Streaming completion whose stream is typed over `FinalCompletionResponse`.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<
        '_,
        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
    >;

    /// Starts a request builder backed by a `CompletionModelHandle` to this model.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}
414
415#[allow(deprecated)]
416impl<T, R> CompletionModelDyn for T
417where
418 T: CompletionModel<StreamingResponse = R>,
419 R: Clone + Unpin + GetTokenUsage + 'static,
420{
421 fn completion(
422 &self,
423 request: CompletionRequest,
424 ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>> {
425 Box::pin(async move {
426 self.completion(request)
427 .await
428 .map(|resp| CompletionResponse {
429 choice: resp.choice,
430 usage: resp.usage,
431 raw_response: (),
432 })
433 })
434 }
435
436 fn stream(
437 &self,
438 request: CompletionRequest,
439 ) -> WasmBoxedFuture<
440 '_,
441 Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
442 > {
443 Box::pin(async move {
444 let resp = self.stream(request).await?;
445 let inner = resp.inner;
446
447 let stream = streaming::StreamingResultDyn {
448 inner: Box::pin(inner),
449 };
450
451 Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
452 })
453 }
454
455 fn completion_request(
457 &self,
458 prompt: Message,
459 ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
460 CompletionRequestBuilder::new(CompletionModelHandle::new(Arc::new(self.clone())), prompt)
461 }
462}
463
/// A fully assembled, provider-agnostic completion request.
///
/// Usually produced by [`CompletionRequestBuilder::build`].
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Optional preamble text.
    pub preamble: Option<String>,
    /// Conversation messages; the builder places the prompt last, so this is never empty.
    pub chat_history: OneOrMany<Message>,
    /// Documents attached to the request (see [`CompletionRequest::normalized_documents`]).
    pub documents: Vec<Document>,
    /// Tools advertised to the model.
    pub tools: Vec<ToolDefinition>,
    /// Optional sampling temperature.
    pub temperature: Option<f64>,
    /// Optional cap on generated tokens.
    pub max_tokens: Option<u64>,
    /// Optional tool-choice policy.
    pub tool_choice: Option<ToolChoice>,
    /// Extra provider-specific parameters as raw JSON.
    pub additional_params: Option<serde_json::Value>,
}
485
486impl CompletionRequest {
487 pub fn normalized_documents(&self) -> Option<Message> {
491 if self.documents.is_empty() {
492 return None;
493 }
494
495 let messages = self
498 .documents
499 .iter()
500 .map(|doc| {
501 UserContent::document(
502 doc.to_string(),
503 Some(DocumentMediaType::TXT),
506 )
507 })
508 .collect::<Vec<_>>();
509
510 Some(Message::User {
511 content: OneOrMany::many(messages).expect("There will be atleast one document"),
512 })
513 }
514}
515
/// Builder for a [`CompletionRequest`] bound to a concrete model `M`.
///
/// Create it with [`CompletionRequestBuilder::new`] or [`CompletionModel::completion_request`],
/// chain setters, then call `build`, `send` or `stream`.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    /// Model used by `send` / `stream`.
    model: M,
    /// Final user turn; `build` appends it after `chat_history`.
    prompt: Message,
    preamble: Option<String>,
    /// Messages that precede the prompt.
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    tool_choice: Option<ToolChoice>,
    /// Accumulated provider-specific JSON (merged by `additional_params`).
    additional_params: Option<serde_json::Value>,
}
572
573impl<M: CompletionModel> CompletionRequestBuilder<M> {
574 pub fn new(model: M, prompt: impl Into<Message>) -> Self {
575 Self {
576 model,
577 prompt: prompt.into(),
578 preamble: None,
579 chat_history: Vec::new(),
580 documents: Vec::new(),
581 tools: Vec::new(),
582 temperature: None,
583 max_tokens: None,
584 tool_choice: None,
585 additional_params: None,
586 }
587 }
588
589 pub fn preamble(mut self, preamble: String) -> Self {
591 self.preamble = Some(preamble);
592 self
593 }
594
595 pub fn without_preamble(mut self) -> Self {
596 self.preamble = None;
597 self
598 }
599
600 pub fn message(mut self, message: Message) -> Self {
602 self.chat_history.push(message);
603 self
604 }
605
606 pub fn messages(self, messages: Vec<Message>) -> Self {
608 messages
609 .into_iter()
610 .fold(self, |builder, msg| builder.message(msg))
611 }
612
613 pub fn document(mut self, document: Document) -> Self {
615 self.documents.push(document);
616 self
617 }
618
619 pub fn documents(self, documents: Vec<Document>) -> Self {
621 documents
622 .into_iter()
623 .fold(self, |builder, doc| builder.document(doc))
624 }
625
626 pub fn tool(mut self, tool: ToolDefinition) -> Self {
628 self.tools.push(tool);
629 self
630 }
631
632 pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
634 tools
635 .into_iter()
636 .fold(self, |builder, tool| builder.tool(tool))
637 }
638
639 pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
645 match self.additional_params {
646 Some(params) => {
647 self.additional_params = Some(json_utils::merge(params, additional_params));
648 }
649 None => {
650 self.additional_params = Some(additional_params);
651 }
652 }
653 self
654 }
655
656 pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
662 self.additional_params = additional_params;
663 self
664 }
665
666 pub fn temperature(mut self, temperature: f64) -> Self {
668 self.temperature = Some(temperature);
669 self
670 }
671
672 pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
674 self.temperature = temperature;
675 self
676 }
677
678 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
681 self.max_tokens = Some(max_tokens);
682 self
683 }
684
685 pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
688 self.max_tokens = max_tokens;
689 self
690 }
691
692 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
694 self.tool_choice = Some(tool_choice);
695 self
696 }
697
698 pub fn build(self) -> CompletionRequest {
700 let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
701 .expect("There will always be atleast the prompt");
702
703 CompletionRequest {
704 preamble: self.preamble,
705 chat_history,
706 documents: self.documents,
707 tools: self.tools,
708 temperature: self.temperature,
709 max_tokens: self.max_tokens,
710 tool_choice: self.tool_choice,
711 additional_params: self.additional_params,
712 }
713 }
714
715 pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
717 let model = self.model.clone();
718 model.completion(self.build()).await
719 }
720
721 pub async fn stream<'a>(
723 self,
724 ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
725 where
726 <M as CompletionModel>::StreamingResponse: 'a,
727 Self: 'a,
728 {
729 let model = self.model.clone();
730 model.stream(self.build()).await
731 }
732}
733
#[cfg(test)]
mod tests {

    use super::*;

    // A document without metadata renders as just the `<file>` wrapper around its text.
    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    // Metadata is emitted on its own line, sorted by key, with Debug-quoted values.
    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    // Each document becomes one TXT document item, in order, inside a single user message.
    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    // No documents means no extra message at all.
    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }
}