llm_chain/prompt/
model.rs

1use serde::{Deserialize, Serialize};
2
3use std::fmt;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6/// An enum representing either a collection of chat messages or a single text.
7pub enum Data<T> {
8    /// A collection of chat messages.
9    Chat(ChatMessageCollection<T>),
10    /// A text prompt.
11    Text(T),
12}
13
14impl<T> Data<T> {
15    pub fn text(text: T) -> Self {
16        Self::Text(text)
17    }
18
19    /// Maps the body of the chat messages or the text in the `Data` enum using the provided function.
20    ///
21    /// # Arguments
22    ///
23    /// * `f` - A function that takes a reference to the body of a chat message or the text and returns a value of type `U`.
24    ///
25    /// # Returns
26    ///
27    /// A new `Data<U>` with the body of the chat messages or the text mapped by the provided function.
28    pub fn map<U, F: Fn(&T) -> U>(&self, f: F) -> Data<U> {
29        match self {
30            Self::Chat(chat) => Data::Chat(chat.map(|msg| msg.map(|body| f(body)))),
31            Self::Text(text) => Data::Text(f(text)),
32        }
33    }
34
35    /// Maps the body of the chat messages or the text in the `Data` enum using the provided function that might fail.
36    ///
37    /// # Arguments
38    ///
39    /// * `f` - A function that takes a reference to the body of a chat message or the text and returns a `Result<U, E>` value.
40    ///
41    /// # Returns
42    ///
43    /// A `Result<Data<U>, E>` with the body of the chat messages or the text mapped by the provided function.
44    /// If the provided function returns an error, the error will be propagated in the result.
45    pub fn try_map<U, E, F: Fn(&T) -> Result<U, E>>(&self, f: F) -> Result<Data<U>, E> {
46        match self {
47            Self::Chat(chat) => {
48                let result = chat.try_map(|msg| f(msg))?;
49                Ok(Data::Chat(result))
50            }
51            Self::Text(text) => Ok(Data::Text(f(text)?)),
52        }
53    }
54
55    /// Extracts the body of the last message in the Data, or simply returns the Text if it is a text prompt
56    pub fn extract_last_body(&self) -> Option<&T> {
57        match self {
58            Self::Chat(c) => c.extract_last_body(),
59            Self::Text(t) => Some(t),
60        }
61    }
62}
63
64impl<T: fmt::Display> fmt::Display for Data<T> {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        match self {
67            Self::Chat(chat) => write!(f, "{}", chat),
68            Self::Text(text) => write!(f, "{}", text),
69        }
70    }
71}
72
73impl Data<String> {
74    pub fn to_chat(&self) -> ChatMessageCollection<String> {
75        match self {
76            Self::Chat(chat) => chat.clone(),
77            Self::Text(text) => {
78                let mut chat = ChatMessageCollection::new();
79                chat.add_message(ChatMessage::new(ChatRole::User, text.clone()));
80                chat
81            }
82        }
83    }
84    pub fn to_text(&self) -> String {
85        match self {
86            Self::Text(text) => text.clone(),
87            Self::Chat(chat) => chat.to_string(),
88        }
89    }
90
91    /// Combines two `Data` values into one.
92    ///
93    /// If both values are `Chat`, the two chat collections will be combined.
94    /// If one value is `Chat` and the other is `Text`, the text will be added as a message to the chat collection.
95    ///
96    /// # Arguments
97    /// - `other` - The other `Data` value to combine with.
98    pub fn combine(&self, other: &Self) -> Self {
99        match (self, other) {
100            (Self::Chat(chat1), Self::Chat(chat2)) => {
101                let mut chat = chat1.clone();
102                chat.append(chat2.clone());
103                Self::Chat(chat)
104            }
105            (Self::Chat(chat), Self::Text(text)) => {
106                let mut chat = chat.clone();
107                chat.add_message(ChatMessage::new(ChatRole::User, text.clone()));
108                Self::Chat(chat)
109            }
110            (Self::Text(text), Self::Chat(chat)) => {
111                let mut chat = chat.clone();
112                chat.add_message(ChatMessage::new(ChatRole::User, text.clone()));
113                Self::Chat(chat)
114            }
115            (Self::Text(text1), Self::Text(text2)) => {
116                let combined_text = format!("{}\n\n{}", text1, text2);
117                Self::Text(combined_text)
118            }
119        }
120    }
121}
122
123impl<T> From<T> for Data<T> {
124    fn from(text: T) -> Self {
125        Self::Text(text)
126    }
127}
128
129impl<T> From<ChatMessageCollection<T>> for Data<T> {
130    fn from(chat: ChatMessageCollection<T>) -> Self {
131        Self::Chat(chat)
132    }
133}
134
135impl<T> From<ChatMessage<T>> for Data<T> {
136    fn from(chat: ChatMessage<T>) -> Self {
137        Self::Chat(ChatMessageCollection::for_vector(vec![chat]))
138    }
139}
140
141use crate::frame::FormatAndExecuteError;
142use crate::output::Output;
143use crate::prompt::{StringTemplate, StringTemplateError};
144use crate::step::Step;
145use crate::traits::Executor;
146use crate::Parameters;
147
148use super::chat::ChatMessageCollection;
149use super::{ChatMessage, ChatRole};
150
151impl Data<StringTemplate> {
152    /// Helper function to run a prompt template.
153    ///
154    /// # Arguments
155    /// parameters: &Parameters - The parameters to use for the prompt template.
156    /// executor: &E - The executor to use for the prompt template.
157    ///
158    /// # Returns
159    /// The output of applying the prompt template to the model.
160    pub async fn run<E: Executor>(
161        &self,
162        parameters: &Parameters,
163        executor: &E,
164    ) -> Result<Output, FormatAndExecuteError> {
165        Step::for_prompt_template(self.clone())
166            .run(parameters, executor)
167            .await
168    }
169
170    pub fn format(&self, parameters: &Parameters) -> Result<Data<String>, StringTemplateError> {
171        self.try_map(|x| x.format(parameters))
172    }
173}