agent_chain_core/outputs/
chat_generation.rs

//! Chat generation output types.
//!
//! This module contains the `ChatGeneration` and `ChatGenerationChunk` types,
//! which represent chat message generation outputs from chat models.
//! Mirrors `langchain_core.outputs.chat_generation`.
6
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10use std::ops::Add;
11
12#[cfg(feature = "specta")]
13use specta::Type;
14
15use crate::messages::BaseMessage;
16use crate::utils::merge::merge_dicts;
17
18/// A single chat generation output.
19///
20/// A subclass of `Generation` that represents the response from a chat model
21/// that generates chat messages.
22///
23/// The `message` attribute is a structured representation of the chat message.
24/// Most of the time, the message will be of type `AIMessage`.
25///
26/// Users working with chat models will usually access information via either
27/// `AIMessage` (returned from runnable interfaces) or `LLMResult` (available
28/// via callbacks).
29#[cfg_attr(feature = "specta", derive(Type))]
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
31pub struct ChatGeneration {
32    /// The text contents of the output message.
33    ///
34    /// **Warning:** This field is automatically set from the message content
35    /// and should not be set directly!
36    #[serde(default)]
37    pub text: String,
38
39    /// The message output by the chat model.
40    pub message: BaseMessage,
41
42    /// Raw response from the provider.
43    ///
44    /// May include things like the reason for finishing or token log probabilities.
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub generation_info: Option<HashMap<String, Value>>,
47
48    /// Type is used exclusively for serialization purposes.
49    #[serde(rename = "type", default = "default_chat_generation_type")]
50    pub generation_type: String,
51}
52
/// Serde default for `ChatGeneration::generation_type`.
fn default_chat_generation_type() -> String {
    String::from("ChatGeneration")
}
56
57impl ChatGeneration {
58    /// Create a new ChatGeneration from a message.
59    ///
60    /// The text field is automatically set from the message content.
61    pub fn new(message: BaseMessage) -> Self {
62        let text = extract_text_from_message(&message);
63        Self {
64            text,
65            message,
66            generation_info: None,
67            generation_type: "ChatGeneration".to_string(),
68        }
69    }
70
71    /// Create a new ChatGeneration with generation info.
72    pub fn with_info(message: BaseMessage, generation_info: HashMap<String, Value>) -> Self {
73        let text = extract_text_from_message(&message);
74        Self {
75            text,
76            message,
77            generation_info: Some(generation_info),
78            generation_type: "ChatGeneration".to_string(),
79        }
80    }
81}
82
83/// Extract text from a message.
84///
85/// This corresponds to the `set_text` model validator in Python which
86/// extracts the text content from the message.
87fn extract_text_from_message(message: &BaseMessage) -> String {
88    message.content().to_string()
89}
90
91/// `ChatGeneration` chunk.
92///
93/// `ChatGeneration` chunks can be concatenated with other `ChatGeneration` chunks.
94#[cfg_attr(feature = "specta", derive(Type))]
95#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
96pub struct ChatGenerationChunk {
97    /// The text contents of the output message.
98    #[serde(default)]
99    pub text: String,
100
101    /// The message chunk output by the chat model.
102    pub message: BaseMessage,
103
104    /// Raw response from the provider.
105    #[serde(skip_serializing_if = "Option::is_none")]
106    pub generation_info: Option<HashMap<String, Value>>,
107
108    /// Type is used exclusively for serialization purposes.
109    #[serde(rename = "type", default = "default_chat_generation_chunk_type")]
110    pub generation_type: String,
111}
112
/// Serde default for `ChatGenerationChunk::generation_type`.
fn default_chat_generation_chunk_type() -> String {
    String::from("ChatGenerationChunk")
}
116
117impl ChatGenerationChunk {
118    /// Create a new ChatGenerationChunk from a message.
119    pub fn new(message: BaseMessage) -> Self {
120        let text = extract_text_from_message(&message);
121        Self {
122            text,
123            message,
124            generation_info: None,
125            generation_type: "ChatGenerationChunk".to_string(),
126        }
127    }
128
129    /// Create a new ChatGenerationChunk with generation info.
130    pub fn with_info(message: BaseMessage, generation_info: HashMap<String, Value>) -> Self {
131        let text = extract_text_from_message(&message);
132        Self {
133            text,
134            message,
135            generation_info: Some(generation_info),
136            generation_type: "ChatGenerationChunk".to_string(),
137        }
138    }
139}
140
141impl Add for ChatGenerationChunk {
142    type Output = ChatGenerationChunk;
143
144    /// Concatenate two `ChatGenerationChunk`s.
145    ///
146    /// Returns a new `ChatGenerationChunk` concatenated from self and other.
147    fn add(self, other: ChatGenerationChunk) -> Self::Output {
148        let generation_info = merge_generation_info(self.generation_info, other.generation_info);
149
150        // For message merging, we concatenate the text content
151        // In a more complete implementation, this would use proper message chunk merging
152        let merged_text = self.text + &other.text;
153
154        // Create a new AI message with the merged content
155        let merged_message = crate::messages::AIMessage::new(&merged_text);
156
157        ChatGenerationChunk {
158            text: merged_text,
159            message: merged_message.into(),
160            generation_info,
161            generation_type: "ChatGenerationChunk".to_string(),
162        }
163    }
164}
165
166/// Merge generation info from two chunks.
167fn merge_generation_info(
168    left: Option<HashMap<String, Value>>,
169    right: Option<HashMap<String, Value>>,
170) -> Option<HashMap<String, Value>> {
171    match (left, right) {
172        (Some(left_map), Some(right_map)) => {
173            let left_value =
174                serde_json::to_value(&left_map).unwrap_or(Value::Object(Default::default()));
175            let right_value =
176                serde_json::to_value(&right_map).unwrap_or(Value::Object(Default::default()));
177            match merge_dicts(left_value, vec![right_value]) {
178                Ok(Value::Object(map)) => {
179                    let result: HashMap<String, Value> = map.into_iter().collect();
180                    if result.is_empty() {
181                        None
182                    } else {
183                        Some(result)
184                    }
185                }
186                _ => None,
187            }
188        }
189        (Some(info), None) | (None, Some(info)) => Some(info),
190        (None, None) => None,
191    }
192}
193
194impl From<ChatGeneration> for ChatGenerationChunk {
195    fn from(chat_gen: ChatGeneration) -> Self {
196        ChatGenerationChunk {
197            text: chat_gen.text,
198            message: chat_gen.message,
199            generation_info: chat_gen.generation_info,
200            generation_type: "ChatGenerationChunk".to_string(),
201        }
202    }
203}
204
205impl From<ChatGenerationChunk> for ChatGeneration {
206    fn from(chunk: ChatGenerationChunk) -> Self {
207        ChatGeneration {
208            text: chunk.text,
209            message: chunk.message,
210            generation_info: chunk.generation_info,
211            generation_type: "ChatGeneration".to_string(),
212        }
213    }
214}
215
216/// Merge a list of `ChatGenerationChunk`s into a single `ChatGenerationChunk`.
217///
218/// Returns `None` if the input list is empty.
219pub fn merge_chat_generation_chunks(
220    chunks: Vec<ChatGenerationChunk>,
221) -> Option<ChatGenerationChunk> {
222    if chunks.is_empty() {
223        return None;
224    }
225
226    if chunks.len() == 1 {
227        return chunks.into_iter().next();
228    }
229
230    let mut iter = chunks.into_iter();
231    let first = iter.next()?;
232    Some(iter.fold(first, |acc, chunk| acc + chunk))
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::AIMessage;
    use serde_json::json;

    /// Build a chunk whose message is an `AIMessage` with the given content.
    fn chunk_of(content: &str) -> ChatGenerationChunk {
        ChatGenerationChunk::new(AIMessage::new(content).into())
    }

    /// `new` copies the message content into `text` and leaves info unset.
    #[test]
    fn test_chat_generation_new() {
        let generation = ChatGeneration::new(AIMessage::new("Hello, world!").into());

        assert_eq!(generation.text, "Hello, world!");
        assert!(generation.generation_info.is_none());
        assert_eq!(generation.generation_type, "ChatGeneration");
    }

    /// `with_info` stores the supplied generation info verbatim.
    #[test]
    fn test_chat_generation_with_info() {
        let mut expected_info = HashMap::new();
        expected_info.insert("finish_reason".to_string(), json!("stop"));

        let generation =
            ChatGeneration::with_info(AIMessage::new("Hello").into(), expected_info.clone());

        assert_eq!(generation.text, "Hello");
        assert_eq!(generation.generation_info, Some(expected_info));
    }

    /// Adding two chunks concatenates their text.
    #[test]
    fn test_chat_generation_chunk_add() {
        let combined = chunk_of("Hello, ") + chunk_of("world!");

        assert_eq!(combined.text, "Hello, world!");
    }

    /// Merging no chunks yields `None`.
    #[test]
    fn test_merge_chat_generation_chunks_empty() {
        assert!(merge_chat_generation_chunks(Vec::new()).is_none());
    }

    /// Merging a single chunk yields that chunk.
    #[test]
    fn test_merge_chat_generation_chunks_single() {
        let merged = merge_chat_generation_chunks(vec![chunk_of("Hello")]);

        assert!(merged.is_some());
        assert_eq!(merged.unwrap().text, "Hello");
    }

    /// Merging several chunks concatenates their text in order.
    #[test]
    fn test_merge_chat_generation_chunks_multiple() {
        let pieces = vec![chunk_of("Hello, "), chunk_of("world"), chunk_of("!")];

        let merged = merge_chat_generation_chunks(pieces);

        assert!(merged.is_some());
        assert_eq!(merged.unwrap().text, "Hello, world!");
    }
}