1use super::message::{AssistantContent, DocumentMediaType};
38use crate::message::ToolChoice;
39use crate::streaming::StreamingCompletionResponse;
40use crate::tool::server::ToolServerError;
41use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
42use crate::{OneOrMany, http_client};
43use crate::{
44 json_utils,
45 message::{Message, UserContent},
46 tool::ToolSetError,
47};
48use serde::de::DeserializeOwned;
49use serde::{Deserialize, Serialize};
50use std::collections::HashMap;
51use std::ops::{Add, AddAssign};
52use thiserror::Error;
53
/// Errors that can be produced while issuing a completion request.
///
/// The `#[from]` variants allow `?` to convert the underlying error types
/// into `CompletionError` automatically.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// An error reported by the crate's HTTP client layer.
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// A `serde_json` serialization or deserialization failure.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A URL failed to parse.
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),

    /// A type-erased error (non-wasm targets: the boxed error must be `Send + Sync`).
    #[cfg(not(target_family = "wasm"))]
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// A type-erased error (wasm targets: no `Send + Sync` requirement).
    #[cfg(target_family = "wasm")]
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// A free-form message; presumably used when a response cannot be
    /// interpreted (confirm against the call sites).
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// A free-form message; presumably an error reported by the provider
    /// itself (confirm against the call sites).
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
87
/// Errors that can be produced by the prompt-level APIs.
#[derive(Debug, Error)]
pub enum PromptError {
    /// The underlying completion request failed.
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// A `ToolSetError` occurred. Note that the rendered message uses the
    /// prefix `ToolCallError`, which differs from the variant name.
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// A (boxed) `ToolServerError` occurred.
    #[error("ToolServerError: {0}")]
    ToolServerError(#[from] Box<ToolServerError>),

    /// The maximum number of turns was reached.
    #[error("MaxTurnsError: reached max turns limit: {max_turns}")]
    MaxTurnsError {
        /// The turn limit that was hit (the only field shown in the message).
        max_turns: usize,
        /// Messages carried with the error; presumably the conversation so
        /// far (confirm against the producer). Boxed to keep the enum small.
        chat_history: Box<Vec<Message>>,
        /// Message carried with the error; presumably the prompt that was
        /// being processed (confirm against the producer).
        prompt: Box<Message>,
    },

    /// The prompt was cancelled; see [`PromptError::prompt_cancelled`].
    #[error("PromptCancelled: {reason}")]
    PromptCancelled {
        /// Messages carried with the error (not boxed, unlike the other variants).
        chat_history: Vec<Message>,
        /// Human-readable reason, shown in the error message.
        reason: String,
    },

    /// The model attempted to call a tool that is unknown or not allowed.
    #[error(
        "UnknownToolCall: model attempted to call unknown or disallowed tool `{tool_name}`. Available tools: {available_tools:?}. Allowed tools for this turn: {allowed_tools:?}"
    )]
    UnknownToolCall {
        /// Name of the tool the model tried to call.
        tool_name: String,
        /// Tool names reported as available in the error message.
        available_tools: Vec<String>,
        /// Tool names reported as allowed for the current turn.
        allowed_tools: Vec<String>,
        /// Messages carried with the error; not included in the message text.
        chat_history: Box<Vec<Message>>,
    },
}
132
133impl From<crate::memory::MemoryError> for PromptError {
137 fn from(err: crate::memory::MemoryError) -> Self {
138 Self::CompletionError(CompletionError::RequestError(Box::new(err)))
139 }
140}
141
142impl PromptError {
143 pub(crate) fn prompt_cancelled(
144 chat_history: impl IntoIterator<Item = Message>,
145 reason: impl Into<String>,
146 ) -> Self {
147 Self::PromptCancelled {
148 chat_history: chat_history.into_iter().collect(),
149 reason: reason.into(),
150 }
151 }
152}
153
/// Errors that can be produced when requesting structured (typed) output.
#[derive(Debug, Error)]
pub enum StructuredOutputError {
    /// The underlying prompt failed (boxed `PromptError`).
    #[error("PromptError: {0}")]
    PromptError(#[from] Box<PromptError>),

    /// A `serde_json` error; presumably raised while deserializing the model
    /// output into the target type (confirm against the call sites).
    #[error("DeserializationError: {0}")]
    DeserializationError(#[from] serde_json::Error),

    /// The model returned no content.
    #[error("EmptyResponse: model returned no content")]
    EmptyResponse,
}
169
/// A text document with an identifier and optional string metadata.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Identifier of the document.
    pub id: String,
    /// Text content of the document.
    pub text: String,
    /// Additional string key/value pairs. With `#[serde(flatten)]` these are
    /// (de)serialized as sibling keys of `id` and `text` rather than as a
    /// nested object.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
180
181impl std::fmt::Display for Document {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 write!(
184 f,
185 concat!("<file id: {}>\n", "{}\n", "</file>\n"),
186 self.id,
187 if self.additional_props.is_empty() {
188 self.text.clone()
189 } else {
190 let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
191 sorted_props.sort_by(|a, b| a.0.cmp(b.0));
192 let metadata = sorted_props
193 .iter()
194 .map(|(k, v)| format!("{k}: {v:?}"))
195 .collect::<Vec<_>>()
196 .join(" ");
197 format!("<metadata {} />\n{}", metadata, self.text)
198 }
199 )
200 }
201}
202
/// Description of a tool that can be attached to a completion request.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,
    /// Description of the tool.
    pub description: String,
    /// Arbitrary JSON describing the tool's parameters; presumably a JSON
    /// Schema object (confirm against the providers that consume it).
    pub parameters: serde_json::Value,
}
212
/// A provider-specific tool definition: a `type` tag plus free-form
/// configuration.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ProviderToolDefinition {
    /// Tool type; (de)serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub kind: String,
    /// Remaining configuration keys, flattened next to `type`. Defaults to an
    /// empty map when absent and is omitted from output when empty.
    #[serde(flatten, default, skip_serializing_if = "serde_json::Map::is_empty")]
    pub config: serde_json::Map<String, serde_json::Value>,
}
226
227impl ProviderToolDefinition {
228 pub fn new(kind: impl Into<String>) -> Self {
230 Self {
231 kind: kind.into(),
232 config: serde_json::Map::new(),
233 }
234 }
235
236 pub fn with_config(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
238 self.config.insert(key.into(), value);
239 self
240 }
241}
242
/// A type that can answer a single prompt with text.
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Sends `prompt` and resolves to the response text.
    ///
    /// The return value is an `IntoFuture`, so it must be awaited (or
    /// converted into a future) to obtain the result.
    ///
    /// # Errors
    /// Resolves to a [`PromptError`] on failure.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
261
/// A type that can answer a prompt in the context of a chat history.
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Sends `prompt` together with `chat_history` and resolves to the
    /// response text.
    ///
    /// `chat_history` is taken by mutable reference; presumably implementations
    /// append the new messages to it (confirm against the implementors).
    ///
    /// # Errors
    /// Resolves to a [`PromptError`] on failure.
    fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: &mut Vec<Message>,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + WasmCompatSend;
}
283
/// A type that can answer a prompt with a value deserialized into a
/// caller-chosen type `T`.
pub trait TypedPrompt: WasmCompatSend + WasmCompatSync {
    /// The request type returned by [`TypedPrompt::prompt_typed`]; awaiting it
    /// yields `Result<T, StructuredOutputError>`.
    type TypedRequest<T>: std::future::IntoFuture<Output = Result<T, StructuredOutputError>>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    /// Creates a typed request for `prompt` whose output is a `T`.
    ///
    /// NOTE(review): the associated type requires `T: 'static` while this
    /// method's bound does not list it — confirm the difference is intended.
    fn prompt_typed<T>(&self, prompt: impl Into<Message> + WasmCompatSend) -> Self::TypedRequest<T>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend;
}
336
/// A type that can prepare a [`CompletionRequestBuilder`] for a model `M`.
pub trait Completion<M: CompletionModel> {
    /// Asynchronously produces a request builder for `prompt` and
    /// `chat_history`.
    ///
    /// The caller receives the builder rather than a response, so the request
    /// can still be customized before it is sent.
    ///
    /// # Errors
    /// Resolves to a [`CompletionError`] if the builder cannot be prepared.
    fn completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
360
/// A completion response, generic over the provider's raw response type `T`.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// One or more pieces of assistant content returned by the model.
    pub choice: OneOrMany<AssistantContent>,
    /// Token usage recorded for the request.
    pub usage: Usage,
    /// The provider-specific raw response.
    pub raw_response: T,
    /// Optional message identifier; presumably assigned by the provider
    /// (confirm against the providers that populate it).
    pub message_id: Option<String>,
}
376
/// Implemented by types that can report token usage.
pub trait GetTokenUsage {
    /// Returns the token usage associated with `self`.
    fn token_usage(&self) -> crate::completion::Usage;
}
386
387impl GetTokenUsage for () {
388 fn token_usage(&self) -> crate::completion::Usage {
389 crate::completion::Usage::new()
390 }
391}
392
393impl<T> GetTokenUsage for Option<T>
394where
395 T: GetTokenUsage,
396{
397 fn token_usage(&self) -> crate::completion::Usage {
398 if let Some(usage) = self {
399 usage.token_usage()
400 } else {
401 crate::completion::Usage::new()
402 }
403 }
404}
405
/// Token counters for a completion request.
///
/// NOTE(review): only `tool_use_prompt_tokens` carries `#[serde(default)]`, so
/// deserialization requires every other field to be present — confirm this is
/// intended for payloads that predate the cache/reasoning counters.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct Usage {
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of output tokens.
    pub output_tokens: u64,
    /// Total number of tokens. Stored independently; nothing in this type
    /// enforces that it equals the sum of the other counters.
    pub total_tokens: u64,
    /// Number of cached input tokens.
    pub cached_input_tokens: u64,
    /// Number of input tokens counted as cache creation.
    pub cache_creation_input_tokens: u64,
    /// Number of prompt tokens attributed to tool use; defaults to `0` when
    /// missing during deserialization.
    #[serde(default)]
    pub tool_use_prompt_tokens: u64,
    /// Number of reasoning tokens.
    pub reasoning_tokens: u64,
}
427
428impl Usage {
429 pub fn new() -> Self {
431 Self {
432 input_tokens: 0,
433 output_tokens: 0,
434 total_tokens: 0,
435 cached_input_tokens: 0,
436 cache_creation_input_tokens: 0,
437 tool_use_prompt_tokens: 0,
438 reasoning_tokens: 0,
439 }
440 }
441
442 pub fn has_values(&self) -> bool {
447 *self != Self::new()
448 }
449}
450
451impl Default for Usage {
452 fn default() -> Self {
453 Self::new()
454 }
455}
456
457impl Add for Usage {
458 type Output = Self;
459
460 fn add(self, other: Self) -> Self::Output {
461 Self {
462 input_tokens: self.input_tokens + other.input_tokens,
463 output_tokens: self.output_tokens + other.output_tokens,
464 total_tokens: self.total_tokens + other.total_tokens,
465 cached_input_tokens: self.cached_input_tokens + other.cached_input_tokens,
466 cache_creation_input_tokens: self.cache_creation_input_tokens
467 + other.cache_creation_input_tokens,
468 tool_use_prompt_tokens: self.tool_use_prompt_tokens + other.tool_use_prompt_tokens,
469 reasoning_tokens: self.reasoning_tokens + other.reasoning_tokens,
470 }
471 }
472}
473
474impl AddAssign for Usage {
475 fn add_assign(&mut self, other: Self) {
476 self.input_tokens += other.input_tokens;
477 self.output_tokens += other.output_tokens;
478 self.total_tokens += other.total_tokens;
479 self.cached_input_tokens += other.cached_input_tokens;
480 self.cache_creation_input_tokens += other.cache_creation_input_tokens;
481 self.tool_use_prompt_tokens += other.tool_use_prompt_tokens;
482 self.reasoning_tokens += other.reasoning_tokens;
483 }
484}
485
/// A model that can execute completion requests, either as a single response
/// or as a stream.
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// Provider-specific raw response type carried by [`CompletionResponse`].
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// Provider-specific response type carried by
    /// [`StreamingCompletionResponse`]; it must be able to report token usage.
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;

    /// Client type from which a model is constructed.
    type Client;

    /// Constructs a model from a client reference and a model identifier.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Executes `request` and resolves to the complete response.
    ///
    /// # Errors
    /// Resolves to a [`CompletionError`] on failure.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Executes `request` and resolves to a streaming response.
    ///
    /// # Errors
    /// Resolves to a [`CompletionError`] on failure.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Starts a [`CompletionRequestBuilder`] for `prompt`, bound to a clone of
    /// this model.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
527
/// A fully assembled completion request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Optional model identifier for this request.
    pub model: Option<String>,
    /// Optional preamble. [`CompletionRequestBuilder::build`] always leaves
    /// this `None` and places the preamble in `chat_history` as a system
    /// message instead.
    pub preamble: Option<String>,
    /// The conversation; never empty by construction of `OneOrMany`.
    pub chat_history: OneOrMany<Message>,
    /// Documents attached to the request.
    pub documents: Vec<Document>,
    /// Tool definitions attached to the request.
    pub tools: Vec<ToolDefinition>,
    /// Optional sampling temperature.
    pub temperature: Option<f64>,
    /// Optional maximum number of tokens.
    pub max_tokens: Option<u64>,
    /// Optional tool-choice setting.
    pub tool_choice: Option<ToolChoice>,
    /// Optional free-form JSON parameters. Provider tools are merged into the
    /// `tools` array of this object.
    pub additional_params: Option<serde_json::Value>,
    /// Optional schema describing the desired output.
    pub output_schema: Option<schemars::Schema>,
}
557
558impl CompletionRequest {
559 pub fn output_schema_name(&self) -> Option<String> {
562 self.output_schema.as_ref().map(|schema| {
563 schema
564 .as_object()
565 .and_then(|o| o.get("title"))
566 .and_then(|v| v.as_str())
567 .unwrap_or("response_schema")
568 .to_string()
569 })
570 }
571
572 pub fn normalized_documents(&self) -> Option<Message> {
576 Self::normalized_documents_from(&self.documents)
577 }
578
579 fn normalized_documents_from(documents: &[Document]) -> Option<Message> {
580 if documents.is_empty() {
581 return None;
582 }
583
584 let messages = documents
587 .iter()
588 .map(|doc| {
589 UserContent::document(
590 doc.to_string(),
591 Some(DocumentMediaType::TXT),
594 )
595 })
596 .collect::<Vec<_>>();
597
598 OneOrMany::from_iter_optional(messages).map(|content| Message::User { content })
599 }
600
601 pub(crate) fn chat_history_with_documents(&self) -> Vec<Message> {
602 let mut chat_history = self.chat_history.iter().cloned().collect::<Vec<_>>();
603 if let Some(documents) = self.normalized_documents() {
604 let insert_at = chat_history
605 .iter()
606 .position(|message| !matches!(message, Message::System { .. }))
607 .unwrap_or(chat_history.len());
608 chat_history.insert(insert_at, documents);
609 }
610 chat_history
611 }
612
613 pub fn with_provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
615 self.additional_params =
616 merge_provider_tools_into_additional_params(self.additional_params, vec![tool]);
617 self
618 }
619
620 pub fn with_provider_tools(mut self, tools: Vec<ProviderToolDefinition>) -> Self {
622 self.additional_params =
623 merge_provider_tools_into_additional_params(self.additional_params, tools);
624 self
625 }
626}
627
628fn merge_provider_tools_into_additional_params(
629 additional_params: Option<serde_json::Value>,
630 provider_tools: Vec<ProviderToolDefinition>,
631) -> Option<serde_json::Value> {
632 if provider_tools.is_empty() {
633 return additional_params;
634 }
635
636 let mut provider_tools_json = provider_tools
637 .into_iter()
638 .map(|ProviderToolDefinition { kind, mut config }| {
639 config.insert("type".to_string(), serde_json::Value::String(kind));
641 serde_json::Value::Object(config)
642 })
643 .collect::<Vec<_>>();
644
645 let mut params_map = match additional_params {
646 Some(serde_json::Value::Object(map)) => map,
647 Some(serde_json::Value::Bool(stream)) => {
648 let mut map = serde_json::Map::new();
649 map.insert("stream".to_string(), serde_json::Value::Bool(stream));
650 map
651 }
652 _ => serde_json::Map::new(),
653 };
654
655 let mut merged_tools = match params_map.remove("tools") {
656 Some(serde_json::Value::Array(existing)) => existing,
657 _ => Vec::new(),
658 };
659 merged_tools.append(&mut provider_tools_json);
660 params_map.insert("tools".to_string(), serde_json::Value::Array(merged_tools));
661 Some(serde_json::Value::Object(params_map))
662}
663
/// Builder that assembles a [`CompletionRequest`] and can send it through the
/// model `M` it was created with.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    /// Model used by `send` and `stream`.
    model: M,
    /// The prompt; appended as the last chat-history message on `build`.
    prompt: Message,
    /// Optional model identifier copied into `CompletionRequest::model`.
    request_model: Option<String>,
    /// Optional preamble; inserted as the first (system) message on `build`.
    preamble: Option<String>,
    /// Messages placed between the preamble and the prompt.
    chat_history: Vec<Message>,
    /// Documents attached to the request.
    documents: Vec<Document>,
    /// Tool definitions attached to the request.
    tools: Vec<ToolDefinition>,
    /// Provider tools; merged into `additional_params.tools` on `build`.
    provider_tools: Vec<ProviderToolDefinition>,
    /// Optional sampling temperature.
    temperature: Option<f64>,
    /// Optional maximum number of tokens.
    max_tokens: Option<u64>,
    /// Optional tool-choice setting.
    tool_choice: Option<ToolChoice>,
    /// Optional free-form JSON parameters.
    additional_params: Option<serde_json::Value>,
    /// Optional schema describing the desired output.
    output_schema: Option<schemars::Schema>,
}
728
729impl<M: CompletionModel> CompletionRequestBuilder<M> {
730 pub fn new(model: M, prompt: impl Into<Message>) -> Self {
731 Self {
732 model,
733 prompt: prompt.into(),
734 request_model: None,
735 preamble: None,
736 chat_history: Vec::new(),
737 documents: Vec::new(),
738 tools: Vec::new(),
739 provider_tools: Vec::new(),
740 temperature: None,
741 max_tokens: None,
742 tool_choice: None,
743 additional_params: None,
744 output_schema: None,
745 }
746 }
747
748 pub fn preamble(mut self, preamble: String) -> Self {
750 self.preamble = Some(preamble);
752 self
753 }
754
755 pub fn model(mut self, model: impl Into<String>) -> Self {
757 self.request_model = Some(model.into());
758 self
759 }
760
761 pub fn model_opt(mut self, model: Option<String>) -> Self {
763 self.request_model = model;
764 self
765 }
766
767 pub fn without_preamble(mut self) -> Self {
768 self.preamble = None;
769 self
770 }
771
772 pub fn message(mut self, message: Message) -> Self {
774 self.chat_history.push(message);
775
776 self
777 }
778
779 pub fn messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
781 self.chat_history.extend(messages);
782
783 self
784 }
785
786 pub fn document(mut self, document: Document) -> Self {
788 self.documents.push(document);
789 self
790 }
791
792 pub fn documents(self, documents: impl IntoIterator<Item = Document>) -> Self {
794 documents
795 .into_iter()
796 .fold(self, |builder, doc| builder.document(doc))
797 }
798
799 pub fn tool(mut self, tool: ToolDefinition) -> Self {
801 self.tools.push(tool);
802 self
803 }
804
805 pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
807 tools
808 .into_iter()
809 .fold(self, |builder, tool| builder.tool(tool))
810 }
811
812 pub fn provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
814 self.provider_tools.push(tool);
815 self
816 }
817
818 pub fn provider_tools(self, tools: Vec<ProviderToolDefinition>) -> Self {
820 tools
821 .into_iter()
822 .fold(self, |builder, tool| builder.provider_tool(tool))
823 }
824
825 pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
831 match self.additional_params {
832 Some(params) => {
833 self.additional_params = Some(json_utils::merge(params, additional_params));
834 }
835 None => {
836 self.additional_params = Some(additional_params);
837 }
838 }
839 self
840 }
841
842 pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
848 self.additional_params = additional_params;
849 self
850 }
851
852 pub fn temperature(mut self, temperature: f64) -> Self {
854 self.temperature = Some(temperature);
855 self
856 }
857
858 pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
860 self.temperature = temperature;
861 self
862 }
863
864 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
867 self.max_tokens = Some(max_tokens);
868 self
869 }
870
871 pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
874 self.max_tokens = max_tokens;
875 self
876 }
877
878 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
880 self.tool_choice = Some(tool_choice);
881 self
882 }
883
884 pub fn output_schema(mut self, schema: schemars::Schema) -> Self {
891 self.output_schema = Some(schema);
892 self
893 }
894
895 pub fn output_schema_opt(mut self, schema: Option<schemars::Schema>) -> Self {
901 self.output_schema = schema;
902 self
903 }
904
905 pub fn build(self) -> CompletionRequest {
907 let mut chat_history = self.chat_history;
909 let prompt = self.prompt;
910 if let Some(preamble) = self.preamble {
911 chat_history.insert(0, Message::system(preamble));
912 }
913
914 chat_history.push(prompt.clone());
915
916 let chat_history =
917 OneOrMany::from_iter_optional(chat_history).unwrap_or_else(|| OneOrMany::one(prompt));
918 let additional_params = merge_provider_tools_into_additional_params(
919 self.additional_params,
920 self.provider_tools,
921 );
922
923 CompletionRequest {
924 model: self.request_model,
925 preamble: None,
926 chat_history,
927 documents: self.documents,
928 tools: self.tools,
929 temperature: self.temperature,
930 max_tokens: self.max_tokens,
931 tool_choice: self.tool_choice,
932 additional_params,
933 output_schema: self.output_schema,
934 }
935 }
936
937 pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
939 let model = self.model.clone();
940 model.completion(self.build()).await
941 }
942
943 pub async fn stream<'a>(
945 self,
946 ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
947 where
948 <M as CompletionModel>::StreamingResponse: 'a,
949 Self: 'a,
950 {
951 let model = self.model.clone();
952 model.stream(self.build()).await
953 }
954}
955
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::MockCompletionModel;

    /// Builds a document without metadata.
    fn test_document(id: &str, text: &str) -> Document {
        Document {
            id: id.to_string(),
            text: text.to_string(),
            additional_props: HashMap::new(),
        }
    }

    /// Builds a request that carries only a chat history and documents; every
    /// other field is unset or empty.
    fn bare_request(chat_history: OneOrMany<Message>, documents: Vec<Document>) -> CompletionRequest {
        CompletionRequest {
            model: None,
            preamble: None,
            chat_history,
            documents,
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    /// Starts a builder on the mock model with the user prompt `"Prompt"`.
    fn prompt_builder() -> CompletionRequestBuilder<MockCompletionModel> {
        CompletionRequestBuilder::new(MockCompletionModel::default(), Message::user("Prompt"))
    }

    /// Returns `true` when `message` is a user message containing a document
    /// part rendered for `expected_id`.
    fn is_document_message(message: &Message, expected_id: &str) -> bool {
        match message {
            Message::User { content } => content.iter().any(|part| {
                matches!(
                    part,
                    UserContent::Document(document)
                        if document.data.to_string().contains(&format!("<file id: {expected_id}>"))
                )
            }),
            _ => false,
        }
    }

    #[test]
    fn usage_has_values_reflects_the_zero_sentinel() {
        assert!(!Usage::new().has_values());

        let mut usage = Usage::new();
        usage.reasoning_tokens = 1;
        assert!(usage.has_values());
    }

    #[test]
    fn test_document_display_without_metadata() {
        let doc = test_document("123", "This is a test document.");

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut doc = test_document("123", "This is a test document.");
        doc.additional_props
            .insert("author".to_string(), "John Doe".to_string());
        doc.additional_props
            .insert("length".to_string(), "42".to_string());

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let request = bare_request(
            OneOrMany::one("What is the capital of France?".into()),
            vec![
                test_document("doc1", "Document 1 text."),
                test_document("doc2", "Document 2 text."),
            ],
        );

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = bare_request(
            OneOrMany::one("What is the capital of France?".into()),
            Vec::new(),
        );

        assert_eq!(request.normalized_documents(), None);
    }

    #[test]
    fn preamble_builder_funnels_to_system_message() {
        let request = prompt_builder()
            .preamble("System prompt".to_string())
            .message(Message::user("History"))
            .build();

        assert_eq!(request.preamble, None);

        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 3);
        assert!(matches!(
            &history[0],
            Message::System { content } if content == "System prompt"
        ));
        assert!(matches!(&history[1], Message::User { .. }));
        assert!(matches!(&history[2], Message::User { .. }));
    }

    #[test]
    fn without_preamble_removes_legacy_preamble_injection() {
        let request = prompt_builder()
            .preamble("System prompt".to_string())
            .without_preamble()
            .build();

        assert_eq!(request.preamble, None);
        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 1);
        assert!(matches!(&history[0], Message::User { .. }));
    }

    #[test]
    fn build_places_documents_after_preamble_system_message() {
        let request = prompt_builder()
            .preamble("System prompt".to_string())
            .document(test_document("doc1", "Document text."))
            .build();

        assert_eq!(request.documents.len(), 1);

        let history = request.chat_history_with_documents();
        let history = history.iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 3);
        assert!(matches!(
            history[0],
            Message::System { content } if content == "System prompt"
        ));
        assert!(is_document_message(history[1], "doc1"));
        assert!(matches!(history[2], Message::User { .. }));
    }

    #[test]
    fn build_places_documents_after_leading_system_messages_before_prior_history() {
        let request = prompt_builder()
            .message(Message::system("System one"))
            .message(Message::system("System two"))
            .message(Message::user("Earlier user turn"))
            .message(Message::assistant("Earlier assistant turn"))
            .document(test_document("doc1", "Document text."))
            .build();

        let history = request.chat_history_with_documents();
        let history = history.iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 6);
        assert!(matches!(
            history[0],
            Message::System { content } if content == "System one"
        ));
        assert!(matches!(
            history[1],
            Message::System { content } if content == "System two"
        ));
        assert!(is_document_message(history[2], "doc1"));
        assert!(matches!(history[3], Message::User { .. }));
        assert!(matches!(history[4], Message::Assistant { .. }));
        assert!(matches!(history[5], Message::User { .. }));
    }

    #[test]
    fn build_without_documents_keeps_message_order_unchanged() {
        let request = prompt_builder()
            .message(Message::system("System prompt"))
            .message(Message::user("Earlier user turn"))
            .build();

        let history = request.chat_history.iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 3);
        assert!(matches!(
            history[0],
            Message::System { content } if content == "System prompt"
        ));
        assert!(matches!(history[1], Message::User { .. }));
        assert!(matches!(history[2], Message::User { .. }));
    }

    #[test]
    fn chat_history_with_documents_places_documents_after_leading_system_messages() {
        let request = bare_request(
            OneOrMany::many(vec![
                Message::system("System prompt"),
                Message::assistant("Earlier assistant turn"),
                Message::user("Earlier user turn"),
                Message::user("Prompt"),
            ])
            .unwrap(),
            vec![test_document("doc1", "Document text.")],
        );

        assert_eq!(request.documents.len(), 1);

        let history = request.chat_history_with_documents();
        let history = history.iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 5);
        assert!(matches!(history[0], Message::System { .. }));
        assert!(is_document_message(history[1], "doc1"));
        assert!(matches!(history[2], Message::Assistant { .. }));
        assert!(matches!(history[3], Message::User { .. }));
        assert!(matches!(history[4], Message::User { .. }));
    }

    #[test]
    fn chat_history_with_documents_places_documents_before_mid_conversation_system_messages() {
        let request = bare_request(
            OneOrMany::many(vec![
                Message::system("Leading system prompt"),
                Message::assistant("Earlier assistant turn"),
                Message::system("Mid-conversation instruction"),
                Message::user("Prompt"),
            ])
            .unwrap(),
            vec![test_document("doc1", "Document text.")],
        );

        let history = request.chat_history_with_documents();
        let history = history.iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 5);
        assert!(matches!(
            history[0],
            Message::System { content } if content == "Leading system prompt"
        ));
        assert!(is_document_message(history[1], "doc1"));
        assert!(matches!(history[2], Message::Assistant { .. }));
        assert!(matches!(
            history[3],
            Message::System { content } if content == "Mid-conversation instruction"
        ));
        assert!(matches!(history[4], Message::User { .. }));
    }

    #[test]
    fn chat_history_with_documents_does_not_duplicate_documents() {
        let request = bare_request(
            OneOrMany::many(vec![
                Message::system("System prompt"),
                Message::user("Earlier user turn"),
                Message::assistant("Earlier assistant turn"),
                Message::user("Prompt"),
            ])
            .unwrap(),
            vec![test_document("doc1", "Document text.")],
        );

        let history = request.chat_history_with_documents();
        let document_messages = history
            .iter()
            .filter(|message| is_document_message(message, "doc1"))
            .count();
        assert_eq!(document_messages, 1);
    }
}