1use super::message::{AssistantContent, DocumentMediaType};
67use crate::client::FinalCompletionResponse;
68#[allow(deprecated)]
69use crate::client::completion::CompletionModelHandle;
70use crate::message::ToolChoice;
71use crate::streaming::StreamingCompletionResponse;
72use crate::tool::server::ToolServerError;
73use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
74use crate::{OneOrMany, http_client, streaming};
75use crate::{
76 json_utils,
77 message::{Message, UserContent},
78 tool::ToolSetError,
79};
80use serde::de::DeserializeOwned;
81use serde::{Deserialize, Serialize};
82use std::collections::HashMap;
83use std::ops::{Add, AddAssign};
84use std::sync::Arc;
85use thiserror::Error;
86
/// Errors that can occur while building or executing a completion request.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// An error from the HTTP client layer (converted via `From`).
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// A `serde_json` error (converted via `From`).
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A URL failed to parse (converted via `From`).
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),

    /// Any other boxed error raised while making the request.
    /// On non-wasm targets the boxed error must be `Send + Sync`.
    #[cfg(not(target_family = "wasm"))]
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// wasm variant of `RequestError`: the boxed error needs no `Send + Sync` bound.
    #[cfg(target_family = "wasm")]
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// A response-level problem described by a message.
    /// NOTE(review): presumably used when a provider response cannot be interpreted — confirm.
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// An error reported by (or attributed to) the provider, described by a message.
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
120
/// Errors returned by the high-level prompting APIs ([`Prompt`], [`Chat`]).
#[derive(Debug, Error)]
pub enum PromptError {
    /// The underlying completion request failed.
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// A [`ToolSetError`] occurred. Note: displayed as `ToolCallError` even though the
    /// variant is named `ToolError`.
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// A [`ToolServerError`] occurred.
    #[error("ToolServerError: {0}")]
    ToolServerError(#[from] ToolServerError),

    /// The configured turn limit (`max_turns`) was reached. Carries the conversation so far
    /// and the original prompt.
    /// NOTE(review): the boxes are presumably there to keep `PromptError` small — confirm.
    #[error("MaxTurnError: (reached max turn limit: {max_turns})")]
    MaxTurnsError {
        max_turns: usize,
        chat_history: Box<Vec<Message>>,
        prompt: Box<Message>,
    },

    /// The prompt was cancelled; carries the conversation so far and a human-readable reason.
    #[error("PromptCancelled: {reason}")]
    PromptCancelled {
        chat_history: Box<Vec<Message>>,
        reason: String,
    },
}
153
154impl PromptError {
155 pub(crate) fn prompt_cancelled(chat_history: Vec<Message>, reason: impl Into<String>) -> Self {
156 Self::PromptCancelled {
157 chat_history: Box::new(chat_history),
158 reason: reason.into(),
159 }
160 }
161}
162
/// Errors returned by typed / structured-output prompting ([`TypedPrompt`]).
#[derive(Debug, Error)]
pub enum StructuredOutputError {
    /// The underlying prompt failed.
    #[error("PromptError: {0}")]
    PromptError(#[from] PromptError),

    /// A `serde_json` error.
    /// NOTE(review): presumably raised when the model output does not deserialize into the
    /// requested type — confirm at the call site.
    #[error("DeserializationError: {0}")]
    DeserializationError(#[from] serde_json::Error),

    /// The model returned no content.
    #[error("EmptyResponse: model returned no content")]
    EmptyResponse,
}
178
/// A context document attached to a completion request.
///
/// `additional_props` are flattened into the serialized form (`#[serde(flatten)]`), so they
/// appear as top-level string keys next to `id` and `text`. The `Display` impl renders
/// them as a metadata line.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Identifier shown in the rendered `<file id: ...>` header.
    pub id: String,
    /// The document body.
    pub text: String,
    /// Extra string metadata, flattened during (de)serialization.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
186
187impl std::fmt::Display for Document {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 write!(
190 f,
191 concat!("<file id: {}>\n", "{}\n", "</file>\n"),
192 self.id,
193 if self.additional_props.is_empty() {
194 self.text.clone()
195 } else {
196 let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
197 sorted_props.sort_by(|a, b| a.0.cmp(b.0));
198 let metadata = sorted_props
199 .iter()
200 .map(|(k, v)| format!("{k}: {v:?}"))
201 .collect::<Vec<_>>()
202 .join(" ");
203 format!("<metadata {} />\n{}", metadata, self.text)
204 }
205 )
206 }
207}
208
/// A function tool that the model may call, as sent in [`CompletionRequest::tools`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    /// Tool name the model uses to refer to the tool.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Parameter description as raw JSON.
    /// NOTE(review): presumably a JSON Schema object — confirm against providers.
    pub parameters: serde_json::Value,
}
215
/// A provider-native tool definition, passed through to the provider as raw JSON.
///
/// Serializes as one flat JSON object: `kind` under the `"type"` key, plus every `config`
/// entry at the top level (`config` is omitted when empty). Avoid putting a `"type"` key in
/// `config`: it collides with `kind` when this struct is serialized directly.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ProviderToolDefinition {
    /// The provider's tool type identifier, serialized as `"type"`.
    #[serde(rename = "type")]
    pub kind: String,
    /// Additional provider-specific configuration, flattened into the tool object.
    #[serde(flatten, default, skip_serializing_if = "serde_json::Map::is_empty")]
    pub config: serde_json::Map<String, serde_json::Value>,
}
229
230impl ProviderToolDefinition {
231 pub fn new(kind: impl Into<String>) -> Self {
233 Self {
234 kind: kind.into(),
235 config: serde_json::Map::new(),
236 }
237 }
238
239 pub fn with_config(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
241 self.config.insert(key.into(), value);
242 self
243 }
244}
245
/// One-shot prompting: send a single prompt and get the model's text answer back.
///
/// The returned value implements `IntoFuture`, so callers can `.await` it directly.
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Sends `prompt` and resolves to the response text or a [`PromptError`].
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
264
/// Conversational prompting: send a prompt together with prior conversation history.
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Sends `prompt` in the context of `chat_history` and resolves to the response text.
    ///
    /// `chat_history` is taken by value; the trait signature does not hand an updated
    /// history back to the caller.
    fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
281
/// Structured-output prompting: the response is deserialized into a caller-chosen type `T`.
///
/// `T` must implement `schemars::JsonSchema`.
/// NOTE(review): presumably so a schema can be derived and sent as the request's output
/// schema — confirm in implementors.
pub trait TypedPrompt: WasmCompatSend + WasmCompatSync {
    /// The awaitable request type produced by [`TypedPrompt::prompt_typed`]; resolves to `T`
    /// or a [`StructuredOutputError`].
    type TypedRequest<'a, T>: std::future::IntoFuture<Output = Result<T, StructuredOutputError>>
    where
        Self: 'a,
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend + 'a;

    /// Starts a typed prompt whose response will be deserialized as `T`.
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> Self::TypedRequest<'_, T>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend;
}
338
/// Low-level access to completions: returns a pre-populated [`CompletionRequestBuilder`] that
/// callers can further customize before sending.
pub trait Completion<M: CompletionModel> {
    /// Creates a request builder for `prompt` with `chat_history` as prior context.
    fn completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend;
}
359
/// The result of a non-streaming completion.
///
/// `T` is the provider's raw response type; the type-erased path uses `()`.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The assistant content returned by the model (at least one item).
    pub choice: OneOrMany<AssistantContent>,
    /// Token usage for this request.
    pub usage: Usage,
    /// The provider's raw response.
    pub raw_response: T,
    /// Provider-assigned message identifier, when available.
    pub message_id: Option<String>,
}
375
/// Types that may report token [`Usage`] (e.g. streaming response payloads).
pub trait GetTokenUsage {
    /// Returns the usage carried by this value, or `None` if it has none.
    fn token_usage(&self) -> Option<crate::completion::Usage>;
}
382
383impl GetTokenUsage for () {
384 fn token_usage(&self) -> Option<crate::completion::Usage> {
385 None
386 }
387}
388
389impl<T> GetTokenUsage for Option<T>
390where
391 T: GetTokenUsage,
392{
393 fn token_usage(&self) -> Option<crate::completion::Usage> {
394 if let Some(usage) = self {
395 usage.token_usage()
396 } else {
397 None
398 }
399 }
400}
401
/// Token accounting for one or more completion requests.
///
/// Instances can be summed with `+` / `+=` to accumulate usage across requests.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens in the request input.
    pub input_tokens: u64,
    /// Tokens generated in the response.
    pub output_tokens: u64,
    /// Total token count.
    /// NOTE(review): presumably as reported by the provider; not computed from the other fields here.
    pub total_tokens: u64,
    /// Input tokens reported as cached.
    /// NOTE(review): presumably served from a provider-side prompt cache — confirm per provider.
    pub cached_input_tokens: u64,
}
415
416impl Usage {
417 pub fn new() -> Self {
419 Self {
420 input_tokens: 0,
421 output_tokens: 0,
422 total_tokens: 0,
423 cached_input_tokens: 0,
424 }
425 }
426}
427
428impl Default for Usage {
429 fn default() -> Self {
430 Self::new()
431 }
432}
433
434impl Add for Usage {
435 type Output = Self;
436
437 fn add(self, other: Self) -> Self::Output {
438 Self {
439 input_tokens: self.input_tokens + other.input_tokens,
440 output_tokens: self.output_tokens + other.output_tokens,
441 total_tokens: self.total_tokens + other.total_tokens,
442 cached_input_tokens: self.cached_input_tokens + other.cached_input_tokens,
443 }
444 }
445}
446
447impl AddAssign for Usage {
448 fn add_assign(&mut self, other: Self) {
449 self.input_tokens += other.input_tokens;
450 self.output_tokens += other.output_tokens;
451 self.total_tokens += other.total_tokens;
452 self.cached_input_tokens += other.cached_input_tokens;
453 }
454}
455
/// A provider-specific completion model.
///
/// Implementors perform non-streaming ([`CompletionModel::completion`]) and streaming
/// ([`CompletionModel::stream`]) requests. `Response` and `StreamingResponse` are the
/// provider's raw payload types. They are carried through in
/// [`CompletionResponse::raw_response`] and [`StreamingCompletionResponse`] respectively.
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// Raw provider response type of a non-streaming completion.
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// Raw provider payload type of a streaming completion; must be able to report usage.
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;

    /// The client type this model is constructed from (see [`CompletionModel::make`]).
    type Client;

    /// Constructs a model for the model name `model` using `client`.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Sends `request` and resolves to the complete response.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Sends `request` as a streaming completion.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Starts a [`CompletionRequestBuilder`] on a clone of this model, with `prompt` as the
    /// final message.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
495
// `allow(deprecated)` is needed because this deprecated trait mentions the deprecated
// `CompletionModelHandle`.
#[allow(deprecated)]
#[deprecated(
    since = "0.25.0",
    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `CompletionModel` instead."
)]
/// Object-safe, type-erased counterpart of [`CompletionModel`] (deprecated).
///
/// Raw provider responses are erased: completions carry `()` and streams carry
/// [`FinalCompletionResponse`]. A blanket impl covers every [`CompletionModel`].
pub trait CompletionModelDyn: WasmCompatSend + WasmCompatSync {
    /// Type-erased non-streaming completion.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    /// Type-erased streaming completion.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<
        '_,
        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
    >;

    /// Starts a request builder backed by a [`CompletionModelHandle`] to this model.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}
520
521#[allow(deprecated)]
522impl<T, R> CompletionModelDyn for T
523where
524 T: CompletionModel<StreamingResponse = R>,
525 R: Clone + Unpin + GetTokenUsage + 'static,
526{
527 fn completion(
528 &self,
529 request: CompletionRequest,
530 ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>> {
531 Box::pin(async move {
532 self.completion(request)
533 .await
534 .map(|resp| CompletionResponse {
535 choice: resp.choice,
536 usage: resp.usage,
537 raw_response: (),
538 message_id: resp.message_id,
539 })
540 })
541 }
542
543 fn stream(
544 &self,
545 request: CompletionRequest,
546 ) -> WasmBoxedFuture<
547 '_,
548 Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
549 > {
550 Box::pin(async move {
551 let resp = self.stream(request).await?;
552 let inner = resp.inner;
553
554 let stream = streaming::StreamingResultDyn {
555 inner: Box::pin(inner),
556 };
557
558 Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
559 })
560 }
561
562 fn completion_request(
564 &self,
565 prompt: Message,
566 ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
567 CompletionRequestBuilder::new(CompletionModelHandle::new(Arc::new(self.clone())), prompt)
568 }
569}
570
/// A fully assembled, provider-agnostic completion request, normally produced by
/// [`CompletionRequestBuilder::build`].
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Optional per-request model name (set via [`CompletionRequestBuilder::model`]).
    /// NOTE(review): presumably overrides the model's default name — confirm per provider.
    pub model: Option<String>,
    /// Legacy system preamble. [`CompletionRequestBuilder::build`] always leaves this `None`
    /// and inserts the preamble as a leading system message in `chat_history` instead.
    pub preamble: Option<String>,
    /// Conversation messages (at least one). When built by the builder, the prompt is the
    /// last entry.
    pub chat_history: OneOrMany<Message>,
    /// Context documents; see [`CompletionRequest::normalized_documents`].
    pub documents: Vec<Document>,
    /// Function tools offered to the model.
    pub tools: Vec<ToolDefinition>,
    /// Sampling temperature, if set.
    pub temperature: Option<f64>,
    /// Response token limit, if set (exact semantics are provider-defined).
    pub max_tokens: Option<u64>,
    /// Tool-choice policy, if set.
    pub tool_choice: Option<ToolChoice>,
    /// Extra provider-specific JSON parameters. Provider-native tools are merged in under a
    /// `"tools"` array.
    pub additional_params: Option<serde_json::Value>,
    /// JSON schema requesting structured output; see [`CompletionRequest::output_schema_name`].
    pub output_schema: Option<schemars::Schema>,
}
600
601impl CompletionRequest {
602 pub fn output_schema_name(&self) -> Option<String> {
605 self.output_schema.as_ref().map(|schema| {
606 schema
607 .as_object()
608 .and_then(|o| o.get("title"))
609 .and_then(|v| v.as_str())
610 .unwrap_or("response_schema")
611 .to_string()
612 })
613 }
614
615 pub fn normalized_documents(&self) -> Option<Message> {
619 if self.documents.is_empty() {
620 return None;
621 }
622
623 let messages = self
626 .documents
627 .iter()
628 .map(|doc| {
629 UserContent::document(
630 doc.to_string(),
631 Some(DocumentMediaType::TXT),
634 )
635 })
636 .collect::<Vec<_>>();
637
638 Some(Message::User {
639 content: OneOrMany::many(messages).expect("There will be atleast one document"),
640 })
641 }
642
643 pub fn with_provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
645 self.additional_params =
646 merge_provider_tools_into_additional_params(self.additional_params, vec![tool]);
647 self
648 }
649
650 pub fn with_provider_tools(mut self, tools: Vec<ProviderToolDefinition>) -> Self {
652 self.additional_params =
653 merge_provider_tools_into_additional_params(self.additional_params, tools);
654 self
655 }
656}
657
658fn merge_provider_tools_into_additional_params(
659 additional_params: Option<serde_json::Value>,
660 provider_tools: Vec<ProviderToolDefinition>,
661) -> Option<serde_json::Value> {
662 if provider_tools.is_empty() {
663 return additional_params;
664 }
665
666 let mut provider_tools_json = provider_tools
667 .into_iter()
668 .map(|ProviderToolDefinition { kind, mut config }| {
669 config.insert("type".to_string(), serde_json::Value::String(kind));
671 serde_json::Value::Object(config)
672 })
673 .collect::<Vec<_>>();
674
675 let mut params_map = match additional_params {
676 Some(serde_json::Value::Object(map)) => map,
677 Some(serde_json::Value::Bool(stream)) => {
678 let mut map = serde_json::Map::new();
679 map.insert("stream".to_string(), serde_json::Value::Bool(stream));
680 map
681 }
682 _ => serde_json::Map::new(),
683 };
684
685 let mut merged_tools = match params_map.remove("tools") {
686 Some(serde_json::Value::Array(existing)) => existing,
687 _ => Vec::new(),
688 };
689 merged_tools.append(&mut provider_tools_json);
690 params_map.insert("tools".to_string(), serde_json::Value::Array(merged_tools));
691 Some(serde_json::Value::Object(params_map))
692}
693
/// Fluent builder for a [`CompletionRequest`] bound to a specific model `M`.
///
/// Call [`CompletionRequestBuilder::build`] to get the request, or
/// [`CompletionRequestBuilder::send`] / [`CompletionRequestBuilder::stream`] to execute it
/// against `model`.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    // Model that `send` / `stream` execute against.
    model: M,
    // Final user turn; appended after `chat_history` in `build`.
    prompt: Message,
    // Optional per-request model name, copied to `CompletionRequest::model`.
    request_model: Option<String>,
    // Turned into a leading system message in `build`.
    preamble: Option<String>,
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    // Merged into `additional_params["tools"]` in `build`.
    provider_tools: Vec<ProviderToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    tool_choice: Option<ToolChoice>,
    additional_params: Option<serde_json::Value>,
    output_schema: Option<schemars::Schema>,
}
753
754impl<M: CompletionModel> CompletionRequestBuilder<M> {
755 pub fn new(model: M, prompt: impl Into<Message>) -> Self {
756 Self {
757 model,
758 prompt: prompt.into(),
759 request_model: None,
760 preamble: None,
761 chat_history: Vec::new(),
762 documents: Vec::new(),
763 tools: Vec::new(),
764 provider_tools: Vec::new(),
765 temperature: None,
766 max_tokens: None,
767 tool_choice: None,
768 additional_params: None,
769 output_schema: None,
770 }
771 }
772
773 pub fn preamble(mut self, preamble: String) -> Self {
775 self.preamble = Some(preamble);
777 self
778 }
779
780 pub fn model(mut self, model: impl Into<String>) -> Self {
782 self.request_model = Some(model.into());
783 self
784 }
785
786 pub fn model_opt(mut self, model: Option<String>) -> Self {
788 self.request_model = model;
789 self
790 }
791
792 pub fn without_preamble(mut self) -> Self {
793 self.preamble = None;
794 self
795 }
796
797 pub fn message(mut self, message: Message) -> Self {
799 self.chat_history.push(message);
800 self
801 }
802
803 pub fn messages(self, messages: Vec<Message>) -> Self {
805 messages
806 .into_iter()
807 .fold(self, |builder, msg| builder.message(msg))
808 }
809
810 pub fn document(mut self, document: Document) -> Self {
812 self.documents.push(document);
813 self
814 }
815
816 pub fn documents(self, documents: Vec<Document>) -> Self {
818 documents
819 .into_iter()
820 .fold(self, |builder, doc| builder.document(doc))
821 }
822
823 pub fn tool(mut self, tool: ToolDefinition) -> Self {
825 self.tools.push(tool);
826 self
827 }
828
829 pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
831 tools
832 .into_iter()
833 .fold(self, |builder, tool| builder.tool(tool))
834 }
835
836 pub fn provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
838 self.provider_tools.push(tool);
839 self
840 }
841
842 pub fn provider_tools(self, tools: Vec<ProviderToolDefinition>) -> Self {
844 tools
845 .into_iter()
846 .fold(self, |builder, tool| builder.provider_tool(tool))
847 }
848
849 pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
855 match self.additional_params {
856 Some(params) => {
857 self.additional_params = Some(json_utils::merge(params, additional_params));
858 }
859 None => {
860 self.additional_params = Some(additional_params);
861 }
862 }
863 self
864 }
865
866 pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
872 self.additional_params = additional_params;
873 self
874 }
875
876 pub fn temperature(mut self, temperature: f64) -> Self {
878 self.temperature = Some(temperature);
879 self
880 }
881
882 pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
884 self.temperature = temperature;
885 self
886 }
887
888 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
891 self.max_tokens = Some(max_tokens);
892 self
893 }
894
895 pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
898 self.max_tokens = max_tokens;
899 self
900 }
901
902 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
904 self.tool_choice = Some(tool_choice);
905 self
906 }
907
908 pub fn output_schema(mut self, schema: schemars::Schema) -> Self {
915 self.output_schema = Some(schema);
916 self
917 }
918
919 pub fn output_schema_opt(mut self, schema: Option<schemars::Schema>) -> Self {
925 self.output_schema = schema;
926 self
927 }
928
929 pub fn build(self) -> CompletionRequest {
931 let mut chat_history = self.chat_history;
932 if let Some(preamble) = self.preamble {
933 chat_history.insert(0, Message::system(preamble));
934 }
935 let chat_history = OneOrMany::many([chat_history, vec![self.prompt]].concat())
936 .expect("There will always be atleast the prompt");
937 let additional_params = merge_provider_tools_into_additional_params(
938 self.additional_params,
939 self.provider_tools,
940 );
941
942 CompletionRequest {
943 model: self.request_model,
944 preamble: None,
945 chat_history,
946 documents: self.documents,
947 tools: self.tools,
948 temperature: self.temperature,
949 max_tokens: self.max_tokens,
950 tool_choice: self.tool_choice,
951 additional_params,
952 output_schema: self.output_schema,
953 }
954 }
955
956 pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
958 let model = self.model.clone();
959 model.completion(self.build()).await
960 }
961
962 pub async fn stream<'a>(
964 self,
965 ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
966 where
967 <M as CompletionModel>::StreamingResponse: 'a,
968 Self: 'a,
969 {
970 let model = self.model.clone();
971 model.stream(self.build()).await
972 }
973}
974
#[cfg(test)]
mod tests {

    use super::*;
    use crate::streaming::StreamingCompletionResponse;
    use serde::{Deserialize, Serialize};

    // Minimal model used to drive `CompletionRequestBuilder`; it never performs requests.
    #[derive(Clone)]
    struct DummyModel;

    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct DummyStreamingResponse;

    impl GetTokenUsage for DummyStreamingResponse {
        fn token_usage(&self) -> Option<Usage> {
            None
        }
    }

    impl CompletionModel for DummyModel {
        type Response = serde_json::Value;
        type StreamingResponse = DummyStreamingResponse;
        type Client = ();

        fn make(_client: &Self::Client, _model: impl Into<String>) -> Self {
            Self
        }

        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            Err(CompletionError::ProviderError(
                "dummy completion model".to_string(),
            ))
        }

        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            Err(CompletionError::ProviderError(
                "dummy completion model".to_string(),
            ))
        }
    }

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }

    #[test]
    fn preamble_builder_funnels_to_system_message() {
        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
            .preamble("System prompt".to_string())
            .message(Message::user("History"))
            .build();

        assert_eq!(request.preamble, None);

        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 3);
        assert!(matches!(
            &history[0],
            Message::System { content } if content == "System prompt"
        ));
        assert!(matches!(&history[1], Message::User { .. }));
        assert!(matches!(&history[2], Message::User { .. }));
    }

    #[test]
    fn without_preamble_removes_legacy_preamble_injection() {
        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
            .preamble("System prompt".to_string())
            .without_preamble()
            .build();

        assert_eq!(request.preamble, None);
        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 1);
        assert!(matches!(&history[0], Message::User { .. }));
    }

    // History order must be preserved, with the prompt always last.
    #[test]
    fn build_appends_prompt_after_history_in_order() {
        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
            .messages(vec![Message::user("First"), Message::user("Second")])
            .build();

        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(
            history,
            vec![
                Message::user("First"),
                Message::user("Second"),
                Message::user("Prompt"),
            ]
        );
    }

    // Provider tools are appended after an existing `tools` array; other keys survive.
    #[test]
    fn provider_tools_are_appended_to_existing_tools_array() {
        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
            .additional_params(serde_json::json!({
                "tools": [{ "type": "existing" }],
                "foo": 1
            }))
            .provider_tool(
                ProviderToolDefinition::new("web_search")
                    .with_config("max_uses", serde_json::json!(3)),
            )
            .build();

        let params = request
            .additional_params
            .expect("provider tools should populate additional_params");
        assert_eq!(params["foo"], serde_json::json!(1));
        assert_eq!(
            params["tools"],
            serde_json::json!([
                { "type": "existing" },
                { "type": "web_search", "max_uses": 3 }
            ])
        );
    }

    // `kind` must win over a conflicting `"type"` entry in `config`.
    #[test]
    fn provider_tool_kind_overrides_config_type() {
        let merged = merge_provider_tools_into_additional_params(
            None,
            vec![
                ProviderToolDefinition::new("code_execution")
                    .with_config("type", serde_json::json!("ignored")),
            ],
        );
        assert_eq!(
            merged,
            Some(serde_json::json!({ "tools": [{ "type": "code_execution" }] }))
        );
    }

    #[test]
    fn boolean_additional_params_become_stream_flag() {
        let merged = merge_provider_tools_into_additional_params(
            Some(serde_json::Value::Bool(true)),
            vec![ProviderToolDefinition::new("web_search")],
        );
        assert_eq!(
            merged,
            Some(serde_json::json!({
                "stream": true,
                "tools": [{ "type": "web_search" }]
            }))
        );
    }

    #[test]
    fn empty_provider_tools_leave_params_untouched() {
        let params = Some(serde_json::json!({ "foo": "bar" }));
        assert_eq!(
            merge_provider_tools_into_additional_params(params.clone(), Vec::new()),
            params
        );
    }
}