1use super::message::{AssistantContent, DocumentMediaType};
67use crate::client::FinalCompletionResponse;
68#[allow(deprecated)]
69use crate::client::completion::CompletionModelHandle;
70use crate::message::ToolChoice;
71use crate::streaming::StreamingCompletionResponse;
72use crate::tool::server::ToolServerError;
73use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
74use crate::{OneOrMany, http_client, streaming};
75use crate::{
76 json_utils,
77 message::{Message, UserContent},
78 tool::ToolSetError,
79};
80use serde::de::DeserializeOwned;
81use serde::{Deserialize, Serialize};
82use std::collections::HashMap;
83use std::ops::{Add, AddAssign};
84use std::sync::Arc;
85use thiserror::Error;
86
/// Error type returned by [`CompletionModel::completion`] and [`CompletionModel::stream`].
///
/// The `Display` output of every variant is the variant's category label followed by the
/// wrapped error or message.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// An error raised by the HTTP client layer (`http_client::Error`).
    /// Convertible with `?` thanks to the generated `From` impl.
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// A `serde_json` error; convertible with `?`.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A URL parsing error (`url::ParseError`); convertible with `?`.
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),

    /// Any other boxed error. On non-wasm targets the boxed error must be `Send + Sync`.
    #[cfg(not(target_family = "wasm"))]
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Wasm variant of `RequestError`: the boxed error carries no `Send`/`Sync` bounds.
    #[cfg(target_family = "wasm")]
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// A free-form message labelled as a response error.
    // NOTE(review): presumably used for malformed/unexpected responses — confirm at call sites.
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// A free-form message labelled as a provider error.
    // NOTE(review): presumably carries an error reported by the model provider — confirm.
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
120
/// Error type returned by the [`Prompt`] and [`Chat`] traits.
#[derive(Debug, Error)]
pub enum PromptError {
    /// The underlying completion call failed; convertible from [`CompletionError`] with `?`.
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// A `ToolSetError` occurred; convertible with `?`.
    /// Note that the display label is `ToolCallError`, not the variant name.
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// A boxed `ToolServerError`; convertible from `Box<ToolServerError>` with `?`.
    #[error("ToolServerError: {0}")]
    ToolServerError(#[from] Box<ToolServerError>),

    /// The turn limit was reached. Only `max_turns` appears in the display output;
    /// the chat history and prompt are carried along for the caller.
    #[error("MaxTurnError: (reached max turn limit: {max_turns})")]
    MaxTurnsError {
        /// The turn limit that was hit.
        max_turns: usize,
        /// Messages carried with the error (boxed to keep the enum small).
        chat_history: Box<Vec<Message>>,
        /// The prompt message carried with the error.
        prompt: Box<Message>,
    },

    /// The prompt was cancelled. Only `reason` appears in the display output.
    #[error("PromptCancelled: {reason}")]
    PromptCancelled {
        /// Messages carried with the error.
        chat_history: Vec<Message>,
        /// Human-readable cancellation reason.
        reason: String,
    },
}
153
154impl PromptError {
155 pub(crate) fn prompt_cancelled(
156 chat_history: impl IntoIterator<Item = Message>,
157 reason: impl Into<String>,
158 ) -> Self {
159 Self::PromptCancelled {
160 chat_history: chat_history.into_iter().collect(),
161 reason: reason.into(),
162 }
163 }
164}
165
/// Error type produced by typed (structured-output) prompts; see [`TypedPrompt`].
#[derive(Debug, Error)]
pub enum StructuredOutputError {
    /// A boxed [`PromptError`]; convertible from `Box<PromptError>` with `?`.
    #[error("PromptError: {0}")]
    PromptError(#[from] Box<PromptError>),

    /// A `serde_json` error; convertible with `?`.
    // NOTE(review): presumably raised when the model output cannot be parsed into the target
    // type — confirm at call sites.
    #[error("DeserializationError: {0}")]
    DeserializationError(#[from] serde_json::Error),

    /// The model returned no content.
    #[error("EmptyResponse: model returned no content")]
    EmptyResponse,
}
181
/// A text document attached to a completion request.
///
/// Its `Display` impl renders the document as a `<file id: …>` block.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Identifier of the document.
    pub id: String,
    /// Text body of the document.
    pub text: String,
    /// Extra string properties. Because of `#[serde(flatten)]` they are (de)serialized at the
    /// same level as `id` and `text` rather than under a nested key.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
189
190impl std::fmt::Display for Document {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 write!(
193 f,
194 concat!("<file id: {}>\n", "{}\n", "</file>\n"),
195 self.id,
196 if self.additional_props.is_empty() {
197 self.text.clone()
198 } else {
199 let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
200 sorted_props.sort_by(|a, b| a.0.cmp(b.0));
201 let metadata = sorted_props
202 .iter()
203 .map(|(k, v)| format!("{k}: {v:?}"))
204 .collect::<Vec<_>>()
205 .join(" ");
206 format!("<metadata {} />\n{}", metadata, self.text)
207 }
208 )
209 }
210}
211
/// Description of a tool attached to a completion request.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,
    /// Free-text description of the tool.
    pub description: String,
    /// Arbitrary JSON describing the tool's parameters.
    // NOTE(review): presumably a JSON Schema object — confirm against provider integrations.
    pub parameters: serde_json::Value,
}
218
/// A provider-specific tool definition: a `type` tag plus arbitrary JSON configuration.
///
/// When serialized, `kind` is emitted under the key `"type"` and the entries of `config`
/// are flattened alongside it.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ProviderToolDefinition {
    /// Tool type tag; (de)serialized under the key `"type"`.
    #[serde(rename = "type")]
    pub kind: String,
    /// Additional key/value configuration, flattened into the same JSON object as `type`.
    /// Defaults to an empty map on deserialization and is skipped when empty on serialization.
    #[serde(flatten, default, skip_serializing_if = "serde_json::Map::is_empty")]
    pub config: serde_json::Map<String, serde_json::Value>,
}
232
233impl ProviderToolDefinition {
234 pub fn new(kind: impl Into<String>) -> Self {
236 Self {
237 kind: kind.into(),
238 config: serde_json::Map::new(),
239 }
240 }
241
242 pub fn with_config(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
244 self.config.insert(key.into(), value);
245 self
246 }
247}
248
/// Interface for sending a single prompt and getting the reply back as a `String`.
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Sends `prompt`, which may be anything convertible into a [`Message`].
    ///
    /// The returned value implements `IntoFuture`; awaiting it yields
    /// `Result<String, PromptError>`.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
267
/// Interface for sending a prompt together with prior chat history.
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Sends `prompt` along with `chat_history`.
    ///
    /// `chat_history` may be any iterable whose items convert into [`Message`].
    /// The returned future resolves to `Result<String, PromptError>`.
    fn chat<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + WasmCompatSend
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
287
/// Interface for prompts whose reply is returned as a typed value `T` instead of a `String`.
pub trait TypedPrompt: WasmCompatSend + WasmCompatSync {
    /// Request type returned by [`TypedPrompt::prompt_typed`].
    ///
    /// Converting it into a future and awaiting it yields `Result<T, StructuredOutputError>`.
    type TypedRequest<T>: std::future::IntoFuture<Output = Result<T, StructuredOutputError>>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    /// Creates a typed request for `prompt`.
    ///
    /// `T` must provide a JSON schema (`schemars::JsonSchema`) and be deserializable.
    // NOTE(review): presumably the schema constrains the model output, which is then
    // deserialized into `T` — confirm in the implementors.
    fn prompt_typed<T>(&self, prompt: impl Into<Message> + WasmCompatSend) -> Self::TypedRequest<T>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend;
}
340
/// Interface for types that can prepare a [`CompletionRequestBuilder`] for a model `M`.
pub trait Completion<M: CompletionModel> {
    /// Prepares a completion request builder for `prompt` and `chat_history`.
    ///
    /// The returned future resolves to a [`CompletionRequestBuilder`] that the caller can
    /// customise further before sending, or to a [`CompletionError`].
    fn completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
364
/// Result of a completion call, generic over the provider's raw response type `T`.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// One or more pieces of assistant content produced by the model.
    pub choice: OneOrMany<AssistantContent>,
    /// Token usage recorded for this response.
    pub usage: Usage,
    /// The provider-specific raw response value.
    pub raw_response: T,
    /// Optional message identifier.
    // NOTE(review): presumably assigned by the provider — confirm.
    pub message_id: Option<String>,
}
380
/// Implemented by response types that may be able to report token usage.
pub trait GetTokenUsage {
    /// Returns the token usage if it is available, otherwise `None`.
    fn token_usage(&self) -> Option<crate::completion::Usage>;
}
387
/// The unit type carries no usage information.
impl GetTokenUsage for () {
    /// Always returns `None`.
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        None
    }
}
393
394impl<T> GetTokenUsage for Option<T>
395where
396 T: GetTokenUsage,
397{
398 fn token_usage(&self) -> Option<crate::completion::Usage> {
399 if let Some(usage) = self {
400 usage.token_usage()
401 } else {
402 None
403 }
404 }
405}
406
/// Token counters for a completion. Values can be summed with `+` and `+=`.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct Usage {
    /// Input token count.
    pub input_tokens: u64,
    /// Output token count.
    pub output_tokens: u64,
    /// Total token count. It is stored independently; this type does not derive it from the
    /// other counters.
    pub total_tokens: u64,
    /// Cached input token count.
    // NOTE(review): presumably input tokens served from a provider-side cache — confirm.
    pub cached_input_tokens: u64,
    /// Cache-creation input token count.
    // NOTE(review): presumably input tokens written to a provider-side cache — confirm.
    pub cache_creation_input_tokens: u64,
}
422
423impl Usage {
424 pub fn new() -> Self {
426 Self {
427 input_tokens: 0,
428 output_tokens: 0,
429 total_tokens: 0,
430 cached_input_tokens: 0,
431 cache_creation_input_tokens: 0,
432 }
433 }
434}
435
436impl Default for Usage {
437 fn default() -> Self {
438 Self::new()
439 }
440}
441
442impl Add for Usage {
443 type Output = Self;
444
445 fn add(self, other: Self) -> Self::Output {
446 Self {
447 input_tokens: self.input_tokens + other.input_tokens,
448 output_tokens: self.output_tokens + other.output_tokens,
449 total_tokens: self.total_tokens + other.total_tokens,
450 cached_input_tokens: self.cached_input_tokens + other.cached_input_tokens,
451 cache_creation_input_tokens: self.cache_creation_input_tokens
452 + other.cache_creation_input_tokens,
453 }
454 }
455}
456
457impl AddAssign for Usage {
458 fn add_assign(&mut self, other: Self) {
459 self.input_tokens += other.input_tokens;
460 self.output_tokens += other.output_tokens;
461 self.total_tokens += other.total_tokens;
462 self.cached_input_tokens += other.cached_input_tokens;
463 self.cache_creation_input_tokens += other.cache_creation_input_tokens;
464 }
465}
466
/// A completion model that can answer a [`CompletionRequest`] either in one shot or as a stream.
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// Provider-specific raw response type returned by [`CompletionModel::completion`].
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// Provider-specific response type carried by the stream returned from
    /// [`CompletionModel::stream`]. It must be able to report token usage.
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;

    /// Client type from which the model is constructed.
    type Client;

    /// Constructs a model from a client reference and a model name.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Executes `request` and resolves to the complete response or a [`CompletionError`].
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Executes `request` and resolves to a streaming response or a [`CompletionError`].
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Starts a [`CompletionRequestBuilder`] for `prompt`, bound to a clone of this model.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
506
/// Variant of [`CompletionModel`] that returns boxed futures and type-erased responses.
///
/// Deprecated since 0.25.0; use [`CompletionModel`] instead (see the deprecation note).
#[allow(deprecated)]
#[deprecated(
    since = "0.25.0",
    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `CompletionModel` instead."
)]
pub trait CompletionModelDyn: WasmCompatSend + WasmCompatSync {
    /// Executes `request`; the raw provider response is replaced by `()`.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    /// Executes `request` as a stream whose response type is `FinalCompletionResponse`.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<
        '_,
        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
    >;

    /// Starts a request builder for `prompt` that targets a `CompletionModelHandle`.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}
531
532#[allow(deprecated)]
533impl<T, R> CompletionModelDyn for T
534where
535 T: CompletionModel<StreamingResponse = R>,
536 R: Clone + Unpin + GetTokenUsage + 'static,
537{
538 fn completion(
539 &self,
540 request: CompletionRequest,
541 ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>> {
542 Box::pin(async move {
543 self.completion(request)
544 .await
545 .map(|resp| CompletionResponse {
546 choice: resp.choice,
547 usage: resp.usage,
548 raw_response: (),
549 message_id: resp.message_id,
550 })
551 })
552 }
553
554 fn stream(
555 &self,
556 request: CompletionRequest,
557 ) -> WasmBoxedFuture<
558 '_,
559 Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
560 > {
561 Box::pin(async move {
562 let resp = self.stream(request).await?;
563 let inner = resp.inner;
564
565 let stream = streaming::StreamingResultDyn {
566 inner: Box::pin(inner),
567 };
568
569 Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
570 })
571 }
572
573 fn completion_request(
575 &self,
576 prompt: Message,
577 ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
578 CompletionRequestBuilder::new(CompletionModelHandle::new(Arc::new(self.clone())), prompt)
579 }
580}
581
/// A fully assembled completion request, as handed to [`CompletionModel::completion`] and
/// [`CompletionModel::stream`].
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Optional model name set on the request.
    // NOTE(review): presumably overrides the model's own default name — confirm in providers.
    pub model: Option<String>,
    /// Optional preamble. [`CompletionRequestBuilder::build`] always leaves this `None` and
    /// places the preamble at the front of `chat_history` as a system message instead.
    pub preamble: Option<String>,
    /// The conversation messages. Requests produced by [`CompletionRequestBuilder::build`]
    /// end with the prompt message.
    pub chat_history: OneOrMany<Message>,
    /// Documents attached to the request; see [`CompletionRequest::normalized_documents`].
    pub documents: Vec<Document>,
    /// Tool definitions attached to the request.
    pub tools: Vec<ToolDefinition>,
    /// Optional sampling temperature.
    pub temperature: Option<f64>,
    /// Optional maximum token count.
    pub max_tokens: Option<u64>,
    /// Optional tool-choice setting.
    pub tool_choice: Option<ToolChoice>,
    /// Optional extra JSON parameters. Provider tools are merged under its `"tools"` key.
    pub additional_params: Option<serde_json::Value>,
    /// Optional JSON schema for the output; see [`CompletionRequest::output_schema_name`].
    pub output_schema: Option<schemars::Schema>,
}
611
612impl CompletionRequest {
613 pub fn output_schema_name(&self) -> Option<String> {
616 self.output_schema.as_ref().map(|schema| {
617 schema
618 .as_object()
619 .and_then(|o| o.get("title"))
620 .and_then(|v| v.as_str())
621 .unwrap_or("response_schema")
622 .to_string()
623 })
624 }
625
626 pub fn normalized_documents(&self) -> Option<Message> {
630 if self.documents.is_empty() {
631 return None;
632 }
633
634 let messages = self
637 .documents
638 .iter()
639 .map(|doc| {
640 UserContent::document(
641 doc.to_string(),
642 Some(DocumentMediaType::TXT),
645 )
646 })
647 .collect::<Vec<_>>();
648
649 Some(Message::User {
650 content: OneOrMany::many(messages).expect("There will be atleast one document"),
651 })
652 }
653
654 pub fn with_provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
656 self.additional_params =
657 merge_provider_tools_into_additional_params(self.additional_params, vec![tool]);
658 self
659 }
660
661 pub fn with_provider_tools(mut self, tools: Vec<ProviderToolDefinition>) -> Self {
663 self.additional_params =
664 merge_provider_tools_into_additional_params(self.additional_params, tools);
665 self
666 }
667}
668
669fn merge_provider_tools_into_additional_params(
670 additional_params: Option<serde_json::Value>,
671 provider_tools: Vec<ProviderToolDefinition>,
672) -> Option<serde_json::Value> {
673 if provider_tools.is_empty() {
674 return additional_params;
675 }
676
677 let mut provider_tools_json = provider_tools
678 .into_iter()
679 .map(|ProviderToolDefinition { kind, mut config }| {
680 config.insert("type".to_string(), serde_json::Value::String(kind));
682 serde_json::Value::Object(config)
683 })
684 .collect::<Vec<_>>();
685
686 let mut params_map = match additional_params {
687 Some(serde_json::Value::Object(map)) => map,
688 Some(serde_json::Value::Bool(stream)) => {
689 let mut map = serde_json::Map::new();
690 map.insert("stream".to_string(), serde_json::Value::Bool(stream));
691 map
692 }
693 _ => serde_json::Map::new(),
694 };
695
696 let mut merged_tools = match params_map.remove("tools") {
697 Some(serde_json::Value::Array(existing)) => existing,
698 _ => Vec::new(),
699 };
700 merged_tools.append(&mut provider_tools_json);
701 params_map.insert("tools".to_string(), serde_json::Value::Array(merged_tools));
702 Some(serde_json::Value::Object(params_map))
703}
704
/// Builder that assembles a [`CompletionRequest`] for a model `M` and can send or stream it.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    /// Model that `send` and `stream` dispatch to.
    model: M,
    /// Prompt message; `build` appends it as the last chat-history message.
    prompt: Message,
    /// Optional model name copied into `CompletionRequest::model`.
    request_model: Option<String>,
    /// Optional preamble; `build` inserts it as the first (system) chat-history message.
    preamble: Option<String>,
    /// Messages placed between the preamble and the prompt.
    chat_history: Vec<Message>,
    /// Documents attached to the request.
    documents: Vec<Document>,
    /// Tool definitions attached to the request.
    tools: Vec<ToolDefinition>,
    /// Provider tools; `build` merges them into `additional_params` under `"tools"`.
    provider_tools: Vec<ProviderToolDefinition>,
    /// Optional sampling temperature.
    temperature: Option<f64>,
    /// Optional maximum token count.
    max_tokens: Option<u64>,
    /// Optional tool-choice setting.
    tool_choice: Option<ToolChoice>,
    /// Optional extra JSON parameters.
    additional_params: Option<serde_json::Value>,
    /// Optional JSON schema for the output.
    output_schema: Option<schemars::Schema>,
}
764
765impl<M: CompletionModel> CompletionRequestBuilder<M> {
766 pub fn new(model: M, prompt: impl Into<Message>) -> Self {
767 Self {
768 model,
769 prompt: prompt.into(),
770 request_model: None,
771 preamble: None,
772 chat_history: Vec::new(),
773 documents: Vec::new(),
774 tools: Vec::new(),
775 provider_tools: Vec::new(),
776 temperature: None,
777 max_tokens: None,
778 tool_choice: None,
779 additional_params: None,
780 output_schema: None,
781 }
782 }
783
784 pub fn preamble(mut self, preamble: String) -> Self {
786 self.preamble = Some(preamble);
788 self
789 }
790
791 pub fn model(mut self, model: impl Into<String>) -> Self {
793 self.request_model = Some(model.into());
794 self
795 }
796
797 pub fn model_opt(mut self, model: Option<String>) -> Self {
799 self.request_model = model;
800 self
801 }
802
803 pub fn without_preamble(mut self) -> Self {
804 self.preamble = None;
805 self
806 }
807
808 pub fn message(mut self, message: Message) -> Self {
810 self.chat_history.push(message);
811
812 self
813 }
814
815 pub fn messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
817 self.chat_history.extend(messages);
818
819 self
820 }
821
822 pub fn document(mut self, document: Document) -> Self {
824 self.documents.push(document);
825 self
826 }
827
828 pub fn documents(self, documents: impl IntoIterator<Item = Document>) -> Self {
830 documents
831 .into_iter()
832 .fold(self, |builder, doc| builder.document(doc))
833 }
834
835 pub fn tool(mut self, tool: ToolDefinition) -> Self {
837 self.tools.push(tool);
838 self
839 }
840
841 pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
843 tools
844 .into_iter()
845 .fold(self, |builder, tool| builder.tool(tool))
846 }
847
848 pub fn provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
850 self.provider_tools.push(tool);
851 self
852 }
853
854 pub fn provider_tools(self, tools: Vec<ProviderToolDefinition>) -> Self {
856 tools
857 .into_iter()
858 .fold(self, |builder, tool| builder.provider_tool(tool))
859 }
860
861 pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
867 match self.additional_params {
868 Some(params) => {
869 self.additional_params = Some(json_utils::merge(params, additional_params));
870 }
871 None => {
872 self.additional_params = Some(additional_params);
873 }
874 }
875 self
876 }
877
878 pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
884 self.additional_params = additional_params;
885 self
886 }
887
888 pub fn temperature(mut self, temperature: f64) -> Self {
890 self.temperature = Some(temperature);
891 self
892 }
893
894 pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
896 self.temperature = temperature;
897 self
898 }
899
900 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
903 self.max_tokens = Some(max_tokens);
904 self
905 }
906
907 pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
910 self.max_tokens = max_tokens;
911 self
912 }
913
914 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
916 self.tool_choice = Some(tool_choice);
917 self
918 }
919
920 pub fn output_schema(mut self, schema: schemars::Schema) -> Self {
927 self.output_schema = Some(schema);
928 self
929 }
930
931 pub fn output_schema_opt(mut self, schema: Option<schemars::Schema>) -> Self {
937 self.output_schema = schema;
938 self
939 }
940
941 pub fn build(self) -> CompletionRequest {
943 let mut chat_history = self.chat_history;
945 if let Some(preamble) = self.preamble {
946 chat_history.insert(0, Message::system(preamble));
947 }
948 chat_history.push(self.prompt);
949
950 let chat_history =
951 OneOrMany::many(chat_history).expect("There will always be at least the prompt");
952 let additional_params = merge_provider_tools_into_additional_params(
953 self.additional_params,
954 self.provider_tools,
955 );
956
957 CompletionRequest {
958 model: self.request_model,
959 preamble: None,
960 chat_history,
961 documents: self.documents,
962 tools: self.tools,
963 temperature: self.temperature,
964 max_tokens: self.max_tokens,
965 tool_choice: self.tool_choice,
966 additional_params,
967 output_schema: self.output_schema,
968 }
969 }
970
971 pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
973 let model = self.model.clone();
974 model.completion(self.build()).await
975 }
976
977 pub async fn stream<'a>(
979 self,
980 ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
981 where
982 <M as CompletionModel>::StreamingResponse: 'a,
983 Self: 'a,
984 {
985 let model = self.model.clone();
986 model.stream(self.build()).await
987 }
988}
989
#[cfg(test)]
mod tests {

    use super::*;
    use crate::streaming::StreamingCompletionResponse;
    use serde::{Deserialize, Serialize};

    // Minimal model used only to instantiate `CompletionRequestBuilder`; it is never
    // expected to produce a successful response.
    #[derive(Clone)]
    struct DummyModel;

    // Placeholder streaming response type required by `CompletionModel`.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct DummyStreamingResponse;

    impl GetTokenUsage for DummyStreamingResponse {
        fn token_usage(&self) -> Option<Usage> {
            None
        }
    }

    impl CompletionModel for DummyModel {
        type Response = serde_json::Value;
        type StreamingResponse = DummyStreamingResponse;
        type Client = ();

        fn make(_client: &Self::Client, _model: impl Into<String>) -> Self {
            Self
        }

        // Always fails: the tests below only exercise request building.
        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            Err(CompletionError::ProviderError(
                "dummy completion model".to_string(),
            ))
        }

        // Always fails: the tests below only exercise request building.
        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            Err(CompletionError::ProviderError(
                "dummy completion model".to_string(),
            ))
        }
    }

    // A document without extra properties renders no `<metadata … />` line.
    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    // Extra properties render as a single metadata line, sorted by key, with quoted values.
    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    // Documents are converted, in order, into one user message of text document parts.
    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    // With no documents attached, normalization yields `None`.
    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }

    // The builder moves the preamble into the chat history as the leading system message,
    // followed by the history message and finally the prompt.
    #[test]
    fn preamble_builder_funnels_to_system_message() {
        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
            .preamble("System prompt".to_string())
            .message(Message::user("History"))
            .build();

        assert_eq!(request.preamble, None);

        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 3);
        assert!(matches!(
            &history[0],
            Message::System { content } if content == "System prompt"
        ));
        assert!(matches!(&history[1], Message::User { .. }));
        assert!(matches!(&history[2], Message::User { .. }));
    }

    // Clearing the preamble leaves only the prompt in the built chat history.
    #[test]
    fn without_preamble_removes_legacy_preamble_injection() {
        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
            .preamble("System prompt".to_string())
            .without_preamble()
            .build();

        assert_eq!(request.preamble, None);
        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 1);
        assert!(matches!(&history[0], Message::User { .. }));
    }
}