llm_chain/prompt/
model.rs1use serde::{Deserialize, Serialize};
2
3use std::fmt;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub enum Data<T> {
8 Chat(ChatMessageCollection<T>),
10 Text(T),
12}
13
14impl<T> Data<T> {
15 pub fn text(text: T) -> Self {
16 Self::Text(text)
17 }
18
19 pub fn map<U, F: Fn(&T) -> U>(&self, f: F) -> Data<U> {
29 match self {
30 Self::Chat(chat) => Data::Chat(chat.map(|msg| msg.map(|body| f(body)))),
31 Self::Text(text) => Data::Text(f(text)),
32 }
33 }
34
35 pub fn try_map<U, E, F: Fn(&T) -> Result<U, E>>(&self, f: F) -> Result<Data<U>, E> {
46 match self {
47 Self::Chat(chat) => {
48 let result = chat.try_map(|msg| f(msg))?;
49 Ok(Data::Chat(result))
50 }
51 Self::Text(text) => Ok(Data::Text(f(text)?)),
52 }
53 }
54
55 pub fn extract_last_body(&self) -> Option<&T> {
57 match self {
58 Self::Chat(c) => c.extract_last_body(),
59 Self::Text(t) => Some(t),
60 }
61 }
62}
63
64impl<T: fmt::Display> fmt::Display for Data<T> {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 match self {
67 Self::Chat(chat) => write!(f, "{}", chat),
68 Self::Text(text) => write!(f, "{}", text),
69 }
70 }
71}
72
73impl Data<String> {
74 pub fn to_chat(&self) -> ChatMessageCollection<String> {
75 match self {
76 Self::Chat(chat) => chat.clone(),
77 Self::Text(text) => {
78 let mut chat = ChatMessageCollection::new();
79 chat.add_message(ChatMessage::new(ChatRole::User, text.clone()));
80 chat
81 }
82 }
83 }
84 pub fn to_text(&self) -> String {
85 match self {
86 Self::Text(text) => text.clone(),
87 Self::Chat(chat) => chat.to_string(),
88 }
89 }
90
91 pub fn combine(&self, other: &Self) -> Self {
99 match (self, other) {
100 (Self::Chat(chat1), Self::Chat(chat2)) => {
101 let mut chat = chat1.clone();
102 chat.append(chat2.clone());
103 Self::Chat(chat)
104 }
105 (Self::Chat(chat), Self::Text(text)) => {
106 let mut chat = chat.clone();
107 chat.add_message(ChatMessage::new(ChatRole::User, text.clone()));
108 Self::Chat(chat)
109 }
110 (Self::Text(text), Self::Chat(chat)) => {
111 let mut chat = chat.clone();
112 chat.add_message(ChatMessage::new(ChatRole::User, text.clone()));
113 Self::Chat(chat)
114 }
115 (Self::Text(text1), Self::Text(text2)) => {
116 let combined_text = format!("{}\n\n{}", text1, text2);
117 Self::Text(combined_text)
118 }
119 }
120 }
121}
122
123impl<T> From<T> for Data<T> {
124 fn from(text: T) -> Self {
125 Self::Text(text)
126 }
127}
128
129impl<T> From<ChatMessageCollection<T>> for Data<T> {
130 fn from(chat: ChatMessageCollection<T>) -> Self {
131 Self::Chat(chat)
132 }
133}
134
135impl<T> From<ChatMessage<T>> for Data<T> {
136 fn from(chat: ChatMessage<T>) -> Self {
137 Self::Chat(ChatMessageCollection::for_vector(vec![chat]))
138 }
139}
140
141use crate::frame::FormatAndExecuteError;
142use crate::output::Output;
143use crate::prompt::{StringTemplate, StringTemplateError};
144use crate::step::Step;
145use crate::traits::Executor;
146use crate::Parameters;
147
148use super::chat::ChatMessageCollection;
149use super::{ChatMessage, ChatRole};
150
151impl Data<StringTemplate> {
152 pub async fn run<E: Executor>(
161 &self,
162 parameters: &Parameters,
163 executor: &E,
164 ) -> Result<Output, FormatAndExecuteError> {
165 Step::for_prompt_template(self.clone())
166 .run(parameters, executor)
167 .await
168 }
169
170 pub fn format(&self, parameters: &Parameters) -> Result<Data<String>, StringTemplateError> {
171 self.try_map(|x| x.format(parameters))
172 }
173}