agent_chain_core/outputs/
chat_generation.rs1use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10use std::ops::Add;
11
12#[cfg(feature = "specta")]
13use specta::Type;
14
15use crate::messages::BaseMessage;
16use crate::utils::merge::merge_dicts;
17
18#[cfg_attr(feature = "specta", derive(Type))]
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
31pub struct ChatGeneration {
32 #[serde(default)]
37 pub text: String,
38
39 pub message: BaseMessage,
41
42 #[serde(skip_serializing_if = "Option::is_none")]
46 pub generation_info: Option<HashMap<String, Value>>,
47
48 #[serde(rename = "type", default = "default_chat_generation_type")]
50 pub generation_type: String,
51}
52
/// Serde default for `ChatGeneration::generation_type`: the literal tag
/// `"ChatGeneration"`.
fn default_chat_generation_type() -> String {
    String::from("ChatGeneration")
}
56
57impl ChatGeneration {
58 pub fn new(message: BaseMessage) -> Self {
62 let text = extract_text_from_message(&message);
63 Self {
64 text,
65 message,
66 generation_info: None,
67 generation_type: "ChatGeneration".to_string(),
68 }
69 }
70
71 pub fn with_info(message: BaseMessage, generation_info: HashMap<String, Value>) -> Self {
73 let text = extract_text_from_message(&message);
74 Self {
75 text,
76 message,
77 generation_info: Some(generation_info),
78 generation_type: "ChatGeneration".to_string(),
79 }
80 }
81}
82
83fn extract_text_from_message(message: &BaseMessage) -> String {
88 message.content().to_string()
89}
90
91#[cfg_attr(feature = "specta", derive(Type))]
95#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
96pub struct ChatGenerationChunk {
97 #[serde(default)]
99 pub text: String,
100
101 pub message: BaseMessage,
103
104 #[serde(skip_serializing_if = "Option::is_none")]
106 pub generation_info: Option<HashMap<String, Value>>,
107
108 #[serde(rename = "type", default = "default_chat_generation_chunk_type")]
110 pub generation_type: String,
111}
112
/// Serde default for `ChatGenerationChunk::generation_type`: the literal tag
/// `"ChatGenerationChunk"`.
fn default_chat_generation_chunk_type() -> String {
    String::from("ChatGenerationChunk")
}
116
117impl ChatGenerationChunk {
118 pub fn new(message: BaseMessage) -> Self {
120 let text = extract_text_from_message(&message);
121 Self {
122 text,
123 message,
124 generation_info: None,
125 generation_type: "ChatGenerationChunk".to_string(),
126 }
127 }
128
129 pub fn with_info(message: BaseMessage, generation_info: HashMap<String, Value>) -> Self {
131 let text = extract_text_from_message(&message);
132 Self {
133 text,
134 message,
135 generation_info: Some(generation_info),
136 generation_type: "ChatGenerationChunk".to_string(),
137 }
138 }
139}
140
141impl Add for ChatGenerationChunk {
142 type Output = ChatGenerationChunk;
143
144 fn add(self, other: ChatGenerationChunk) -> Self::Output {
148 let generation_info = merge_generation_info(self.generation_info, other.generation_info);
149
150 let merged_text = self.text + &other.text;
153
154 let merged_message = crate::messages::AIMessage::new(&merged_text);
156
157 ChatGenerationChunk {
158 text: merged_text,
159 message: merged_message.into(),
160 generation_info,
161 generation_type: "ChatGenerationChunk".to_string(),
162 }
163 }
164}
165
166fn merge_generation_info(
168 left: Option<HashMap<String, Value>>,
169 right: Option<HashMap<String, Value>>,
170) -> Option<HashMap<String, Value>> {
171 match (left, right) {
172 (Some(left_map), Some(right_map)) => {
173 let left_value =
174 serde_json::to_value(&left_map).unwrap_or(Value::Object(Default::default()));
175 let right_value =
176 serde_json::to_value(&right_map).unwrap_or(Value::Object(Default::default()));
177 match merge_dicts(left_value, vec![right_value]) {
178 Ok(Value::Object(map)) => {
179 let result: HashMap<String, Value> = map.into_iter().collect();
180 if result.is_empty() {
181 None
182 } else {
183 Some(result)
184 }
185 }
186 _ => None,
187 }
188 }
189 (Some(info), None) | (None, Some(info)) => Some(info),
190 (None, None) => None,
191 }
192}
193
194impl From<ChatGeneration> for ChatGenerationChunk {
195 fn from(chat_gen: ChatGeneration) -> Self {
196 ChatGenerationChunk {
197 text: chat_gen.text,
198 message: chat_gen.message,
199 generation_info: chat_gen.generation_info,
200 generation_type: "ChatGenerationChunk".to_string(),
201 }
202 }
203}
204
205impl From<ChatGenerationChunk> for ChatGeneration {
206 fn from(chunk: ChatGenerationChunk) -> Self {
207 ChatGeneration {
208 text: chunk.text,
209 message: chunk.message,
210 generation_info: chunk.generation_info,
211 generation_type: "ChatGeneration".to_string(),
212 }
213 }
214}
215
216pub fn merge_chat_generation_chunks(
220 chunks: Vec<ChatGenerationChunk>,
221) -> Option<ChatGenerationChunk> {
222 if chunks.is_empty() {
223 return None;
224 }
225
226 if chunks.len() == 1 {
227 return chunks.into_iter().next();
228 }
229
230 let mut iter = chunks.into_iter();
231 let first = iter.next()?;
232 Some(iter.fold(first, |acc, chunk| acc + chunk))
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::AIMessage;
    use serde_json::json;

    /// Builds a chunk whose message is an `AIMessage` holding `content`.
    fn chunk_of(content: &str) -> ChatGenerationChunk {
        ChatGenerationChunk::new(AIMessage::new(content).into())
    }

    // `new` copies the message text and leaves the info unset.
    #[test]
    fn test_chat_generation_new() {
        let generation = ChatGeneration::new(AIMessage::new("Hello, world!").into());

        assert_eq!(generation.text, "Hello, world!");
        assert!(generation.generation_info.is_none());
        assert_eq!(generation.generation_type, "ChatGeneration");
    }

    // `with_info` stores the supplied map as-is.
    #[test]
    fn test_chat_generation_with_info() {
        let info = HashMap::from([("finish_reason".to_string(), json!("stop"))]);

        let generation =
            ChatGeneration::with_info(AIMessage::new("Hello").into(), info.clone());

        assert_eq!(generation.text, "Hello");
        assert_eq!(generation.generation_info, Some(info));
    }

    // `+` concatenates the text of both operands in order.
    #[test]
    fn test_chat_generation_chunk_add() {
        let combined = chunk_of("Hello, ") + chunk_of("world!");

        assert_eq!(combined.text, "Hello, world!");
    }

    // Nothing to merge yields `None`.
    #[test]
    fn test_merge_chat_generation_chunks_empty() {
        assert!(merge_chat_generation_chunks(Vec::new()).is_none());
    }

    // A single chunk comes back with its text intact.
    #[test]
    fn test_merge_chat_generation_chunks_single() {
        let merged = merge_chat_generation_chunks(vec![chunk_of("Hello")]);

        assert!(merged.is_some());
        assert_eq!(merged.unwrap().text, "Hello");
    }

    // Several chunks are concatenated left to right.
    #[test]
    fn test_merge_chat_generation_chunks_multiple() {
        let parts = ["Hello, ", "world", "!"];
        let chunks: Vec<ChatGenerationChunk> = parts.iter().map(|p| chunk_of(p)).collect();

        let merged = merge_chat_generation_chunks(chunks);

        assert!(merged.is_some());
        assert_eq!(merged.unwrap().text, "Hello, world!");
    }
}