rig/completion/request.rs

//! This module provides functionality for working with completion models.
//! It provides traits, structs, and enums for generating completion requests,
//! handling completion responses, and defining completion models.
//!
//! The main traits defined in this module are:
//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
//! - [Chat]: Defines a high-level LLM chat interface with chat history.
//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
//! - [CompletionModel]: Defines a completion model that can be used to generate completion
//!   responses from requests.
//!
//! The [Prompt] and [Chat] traits are high-level traits that users are expected to use
//! to interact with LLM models. Moreover, it is good practice to implement one of these
//! traits for composite agents that use multiple LLM models to generate responses.
//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
//!
//! The [CompletionModel] trait is meant to act as the interface between providers and
//! the library. It defines the methods that need to be implemented by the user to define
//! a custom base completion model (i.e.: a private or third-party LLM provider).
//!
//! The module also provides various structs and enums for representing generic completion requests,
//! responses, and errors.
//!
//! Example usage:
//! ```rust
//! use rig::providers::openai::{Client, self};
//! use rig::completion::*;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let gpt_4 = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ")
//!     .temperature(0.5)
//!     .build();
//!
//! // Send the completion request and get the completion response
//! let response = gpt_4.completion(request)
//!     .await
//!     .expect("Failed to get completion response");
//!
//! // Handle the completion response
//! match response.choice.first() {
//!     AssistantContent::Text(text) => {
//!         // Handle the completion response as a text message
//!         println!("Received message: {}", text.text);
//!     }
//!     AssistantContent::ToolCall(tool_call) => {
//!         // Handle the completion response as a tool call
//!         println!(
//!             "Received tool call: {} {:?}",
//!             tool_call.function.name, tool_call.function.arguments
//!         );
//!     }
//! }
//! ```
//!
//! For more information on how to use the completion functionality, refer to the documentation of
//! the individual traits, structs, and enums defined in this module.

use super::message::{AssistantContent, DocumentMediaType};
use crate::client::builder::FinalCompletionResponse;
use crate::client::completion::CompletionModelHandle;
use crate::message::ToolChoice;
use crate::streaming::StreamingCompletionResponse;
use crate::tool::server::ToolServerError;
use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
use crate::{OneOrMany, http_client, streaming};
use crate::{
    json_utils,
    message::{Message, UserContent},
    tool::ToolSetError,
};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::{Add, AddAssign};
use std::sync::Arc;
use thiserror::Error;

// Errors
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Url error (e.g.: invalid URL)
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),

    #[cfg(not(target_family = "wasm"))]
    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    #[cfg(target_family = "wasm")]
    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}

/// Prompt errors
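///
/// Some variants carry the accumulated chat history so the conversation can be
/// inspected or resumed. A minimal sketch of recovering from
/// [`PromptError::MaxDepthError`] (assuming an agent and the `.multi_turn()`
/// prompt API referenced below):
/// ```rust,ignore
/// match agent.prompt("Plan my trip").multi_turn(3).await {
///     Ok(answer) => println!("{answer}"),
///     Err(PromptError::MaxDepthError { max_depth, chat_history, .. }) => {
///         // The model hit the turn limit; `chat_history` holds the
///         // conversation so far and can be replayed with a higher limit.
///         eprintln!("hit max depth {max_depth} with {} messages", chat_history.len());
///     }
///     Err(e) => eprintln!("prompt failed: {e}"),
/// }
/// ```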
#[derive(Debug, Error)]
pub enum PromptError {
    /// Something went wrong with the completion
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// There was an error while using a tool
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// There was an issue while executing a tool on a tool server
    #[error("ToolServerError: {0}")]
    ToolServerError(#[from] ToolServerError),

    /// The LLM tried to call too many tools during a multi-turn conversation.
    /// To fix this, you may either need to reduce the number of tools your model
    /// has access to (and then create other agents to share the tool load)
    /// or increase the number of turns given in `.multi_turn()`.
    #[error("MaxDepthError: (reached limit: {max_depth})")]
    MaxDepthError {
        max_depth: usize,
        chat_history: Box<Vec<Message>>,
        prompt: Message,
    },

    /// A prompting loop was cancelled.
    #[error("PromptCancelled")]
    PromptCancelled { chat_history: Box<Vec<Message>> },
}

impl PromptError {
    pub(crate) fn prompt_cancelled(chat_history: Vec<Message>) -> Self {
        Self::PromptCancelled {
            chat_history: Box::new(chat_history),
        }
    }
}

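/// A document that can be attached to a completion request as extra context.
///
/// Its [`std::fmt::Display`] implementation renders the tagged text format that
/// providers ultimately receive (see [`CompletionRequest::normalized_documents`]):
/// ```rust
/// use rig::completion::Document;
/// use std::collections::HashMap;
///
/// let doc = Document {
///     id: "doc-1".to_string(),
///     text: "Hello".to_string(),
///     additional_props: HashMap::new(),
/// };
/// assert_eq!(doc.to_string(), "<file id: doc-1>\nHello\n</file>\n");
/// ```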
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    pub id: String,
    pub text: String,
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}

impl std::fmt::Display for Document {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
            self.id,
            if self.additional_props.is_empty() {
                self.text.clone()
            } else {
                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
                let metadata = sorted_props
                    .iter()
                    .map(|(k, v)| format!("{k}: {v:?}"))
                    .collect::<Vec<_>>()
                    .join(" ");
                format!("<metadata {} />\n{}", metadata, self.text)
            }
        )
    }
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
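    ///
    /// A minimal sketch of calling `prompt` (assuming `agent` is any type
    /// implementing [Prompt], e.g. an agent built from a provider client):
    /// ```rust,ignore
    /// let answer = agent.prompt("What is the capital of France?").await?;
    /// println!("{answer}");
    /// ```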
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}

/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
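    ///
    /// A minimal sketch of calling `chat` (assuming `agent` implements [Chat]
    /// and `history` is a `Vec<Message>` collected from earlier turns):
    /// ```rust,ignore
    /// let answer = agent.chat("And what about Germany?", history).await?;
    /// ```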
    fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}

/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
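    ///
    /// A sketch of prompt-time customization (assuming `agent` implements
    /// [Completion] over some model and `chat_history` is a `Vec<Message>`):
    /// ```rust,ignore
    /// let response = agent
    ///     .completion("Summarize this", chat_history)
    ///     .await?
    ///     .temperature(0.2) // overrides any temperature the agent already set
    ///     .send()
    ///     .await?;
    /// ```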
    fn completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend;
}

/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more pieces of assistant content.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message contents)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// Tokens used during prompting and responding
    pub usage: Usage,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}

/// A trait for grabbing the token usage of a completion response.
///
/// Primarily designed for streamed completion responses in streamed multi-turn
/// conversations, where the token usage would otherwise be inaccessible.
pub trait GetTokenUsage {
    fn token_usage(&self) -> Option<crate::completion::Usage>;
}

impl GetTokenUsage for () {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        None
    }
}

impl<T> GetTokenUsage for Option<T>
where
    T: GetTokenUsage,
{
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        if let Some(usage) = self {
            usage.token_usage()
        } else {
            None
        }
    }
}
/// Struct representing the token usage for a completion request.
/// If the token counts are `0`, then the provider failed to supply token usage metrics.
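///
/// `Usage` implements [Add] and [AddAssign], so token counts can be aggregated
/// across requests, e.g. over the turns of a conversation:
/// ```rust
/// use rig::completion::Usage;
///
/// let mut total = Usage::new();
/// total += Usage { input_tokens: 12, output_tokens: 5, total_tokens: 17 };
/// total += Usage { input_tokens: 30, output_tokens: 9, total_tokens: 39 };
/// assert_eq!(total.total_tokens, 56);
/// ```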
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct Usage {
    /// The number of input ("prompt") tokens used in a given request.
    pub input_tokens: u64,
    /// The number of output ("completion") tokens used in a given request.
    pub output_tokens: u64,
    /// The total number of tokens used in a given request.
    /// Stored separately, as some providers may only report this one number.
    pub total_tokens: u64,
}

impl Usage {
    /// Creates a new instance of `Usage`.
    pub fn new() -> Self {
        Self {
            input_tokens: 0,
            output_tokens: 0,
            total_tokens: 0,
        }
    }
}

impl Default for Usage {
    fn default() -> Self {
        Self::new()
    }
}

impl Add for Usage {
    type Output = Self;

    fn add(self, other: Self) -> Self::Output {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            total_tokens: self.total_tokens + other.total_tokens,
        }
    }
}

impl AddAssign for Usage {
    fn add_assign(&mut self, other: Self) {
        self.input_tokens += other.input_tokens;
        self.output_tokens += other.output_tokens;
        self.total_tokens += other.total_tokens;
    }
}

/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
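///
/// A minimal sketch of a custom implementation (the provider type and its
/// internals are hypothetical; `()` is used as the streaming response type
/// since it already implements [GetTokenUsage]):
/// ```rust,ignore
/// #[derive(Clone)]
/// struct MyModel; // would hold an HTTP client, model name, etc.
///
/// impl CompletionModel for MyModel {
///     type Response = serde_json::Value;
///     type StreamingResponse = ();
///
///     async fn completion(
///         &self,
///         request: CompletionRequest,
///     ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
///         // Call the provider here and map its reply into `choice`,
///         // `usage`, and `raw_response`.
///         todo!()
///     }
///
///     async fn stream(
///         &self,
///         request: CompletionRequest,
///     ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
///         todo!()
///     }
/// }
/// ```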
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// The raw response type returned by the underlying completion model.
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// The raw response type returned by the underlying completion model when streaming.
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a streaming completion response for the given completion request.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}

/// A dyn-compatible counterpart to [CompletionModel], used where the concrete
/// model type must be erased (e.g. behind a [CompletionModelHandle]).
pub trait CompletionModelDyn: WasmCompatSend + WasmCompatSync {
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    fn stream(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<
        '_,
        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
    >;

    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}

impl<T, R> CompletionModelDyn for T
where
    T: CompletionModel<StreamingResponse = R>,
    R: Clone + Unpin + GetTokenUsage + 'static,
{
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>> {
        Box::pin(async move {
            self.completion(request)
                .await
                .map(|resp| CompletionResponse {
                    choice: resp.choice,
                    usage: resp.usage,
                    raw_response: (),
                })
        })
    }

    fn stream(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<
        '_,
        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
    > {
        Box::pin(async move {
            let resp = self.stream(request).await?;
            let inner = resp.inner;

            let stream = streaming::StreamingResultDyn {
                inner: Box::pin(inner),
            };

            Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
        })
    }

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
        CompletionRequestBuilder::new(
            CompletionModelHandle {
                inner: Arc::new(self.clone()),
            },
            prompt,
        )
    }
}

/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence there is *always* at least one)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Whether tools are required to be used by the model provider or not before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}

impl CompletionRequest {
    /// Returns the documents normalized into a single user message (if there are any).
    /// Most providers do not accept documents directly as input, so the documents need to be
    /// converted into a `Message` that can be incorporated into the `chat_history`.
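    ///
    /// A sketch of how a provider implementation might consume this
    /// (`request` and `full_history` are hypothetical):
    /// ```rust,ignore
    /// if let Some(doc_message) = request.normalized_documents() {
    ///     // The rendered documents become one extra user message in the
    ///     // history sent to the provider.
    ///     full_history.push(doc_message);
    /// }
    /// ```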
    pub fn normalized_documents(&self) -> Option<Message> {
        if self.documents.is_empty() {
            return None;
        }

        // Most providers will convert documents into plain text unless they can handle
        // document messages. We use `UserContent::document` for those that handle it directly!
        let messages = self
            .documents
            .iter()
            .map(|doc| {
                UserContent::document(
                    doc.to_string(),
                    // In the future, we can customize `Document` to pass these extra types through.
                    // Most providers ditch these but they might want to use them.
                    Some(DocumentMediaType::TXT),
                )
            })
            .collect::<Vec<_>>();

        Some(Message::User {
            content: OneOrMany::many(messages).expect("There will be at least one document"),
        })
    }
}

/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O).build();
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O).build();
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    model: M,
    prompt: Message,
    preamble: Option<String>,
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    tool_choice: Option<ToolChoice>,
    additional_params: Option<serde_json::Value>,
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
        }
    }

    /// Sets the preamble for the completion request.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Clears any previously set preamble.
    pub fn without_preamble(mut self) -> Self {
        self.preamble = None;
        self
    }
    /// Adds a message to the chat history for the completion request.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Adds a list of messages to the chat history for the completion request.
    pub fn messages(self, messages: Vec<Message>) -> Self {
        messages
            .into_iter()
            .fold(self, |builder, msg| builder.message(msg))
    }

    /// Adds a document to the completion request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Adds a list of documents to the completion request.
    pub fn documents(self, documents: Vec<Document>) -> Self {
        documents
            .into_iter()
            .fold(self, |builder, doc| builder.document(doc))
    }

    /// Adds a tool to the completion request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds a list of tools to the completion request.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        tools
            .into_iter()
            .fold(self, |builder, tool| builder.tool(tool))
    }

    /// Adds additional parameters to the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
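    ///
    /// Repeated calls merge the parameters rather than overwrite them
    /// (a sketch with hypothetical keys):
    /// ```rust,ignore
    /// let builder = builder
    ///     .additional_params(serde_json::json!({ "connectors": [{ "id": "web-search" }] }))
    ///     .additional_params(serde_json::json!({ "citation_quality": "accurate" }));
    /// // The final request carries both `connectors` and `citation_quality`.
    /// ```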
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        match self.additional_params {
            Some(params) => {
                self.additional_params = Some(json_utils::merge(params, additional_params));
            }
            None => {
                self.additional_params = Some(additional_params);
            }
        }
        self
    }

    /// Sets the additional parameters for the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Sets the tool choice for the completion request (i.e. whether and how the
    /// model is required to call tools before responding).
    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
        self.tool_choice = Some(tool_choice);
        self
    }

    /// Builds the completion request.
    pub fn build(self) -> CompletionRequest {
        let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
            .expect("There will always be at least the prompt");

        CompletionRequest {
            preamble: self.preamble,
            chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            tool_choice: self.tool_choice,
            additional_params: self.additional_params,
        }
    }

    /// Sends the completion request to the completion model provider and returns the completion response.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        let model = self.model.clone();
        model.completion(self.build()).await
    }

    /// Streams the completion request
    pub async fn stream<'a>(
        self,
    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
    where
        <M as CompletionModel>::StreamingResponse: 'a,
        Self: 'a,
    {
        let model = self.model.clone();
        model.stream(self.build()).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }
}