// agent_chain_core/outputs/llm_result.rs
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;

#[cfg(feature = "specta")]
use specta::Type;

use super::chat_generation::{ChatGeneration, ChatGenerationChunk};
use super::generation::{Generation, GenerationChunk};
use super::run_info::RunInfo;
/// Any of the generation payload shapes an LLM call can produce: plain
/// text generations, chat generations, and their streaming-chunk forms.
///
/// Serialized with `#[serde(untagged)]`: variants carry no explicit tag,
/// so deserialization picks the first variant whose fields match — the
/// variant order below therefore matters for ambiguous inputs.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum GenerationType {
    /// A complete plain-text generation.
    Generation(Generation),
    /// A streaming chunk of a plain-text generation.
    GenerationChunk(GenerationChunk),
    /// A complete chat-message generation.
    ChatGeneration(ChatGeneration),
    /// A streaming chunk of a chat-message generation.
    ChatGenerationChunk(ChatGenerationChunk),
}
34
35impl From<Generation> for GenerationType {
36 fn from(generation: Generation) -> Self {
37 GenerationType::Generation(generation)
38 }
39}
40
41impl From<GenerationChunk> for GenerationType {
42 fn from(generation: GenerationChunk) -> Self {
43 GenerationType::GenerationChunk(generation)
44 }
45}
46
47impl From<ChatGeneration> for GenerationType {
48 fn from(generation: ChatGeneration) -> Self {
49 GenerationType::ChatGeneration(generation)
50 }
51}
52
53impl From<ChatGenerationChunk> for GenerationType {
54 fn from(generation: ChatGenerationChunk) -> Self {
55 GenerationType::ChatGenerationChunk(generation)
56 }
57}
58
/// Container for the full result of an LLM call.
///
/// `generations` is a list of lists — presumably one inner list of
/// candidate generations per input prompt (see `flatten`, which splits
/// the result one outer entry at a time) — TODO confirm against callers.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LLMResult {
    /// Generated outputs; outer index groups candidates, inner list holds
    /// the individual generations in that group.
    pub generations: Vec<Vec<GenerationType>>,

    /// Arbitrary provider-specific metadata (e.g. a `token_usage` entry,
    /// as exercised by `flatten`); omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub llm_output: Option<HashMap<String, Value>>,

    /// Optional per-run metadata; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub run: Option<Vec<RunInfo>>,

    /// Type discriminator, serialized under the key `"type"`; defaults to
    /// `"LLMResult"` when missing during deserialization.
    #[serde(rename = "type", default = "default_llm_result_type")]
    pub result_type: String,
}
101
/// Serde default for the `"type"` discriminator field of [`LLMResult`].
fn default_llm_result_type() -> String {
    String::from("LLMResult")
}
105
106impl LLMResult {
107 pub fn new(generations: Vec<Vec<GenerationType>>) -> Self {
109 Self {
110 generations,
111 llm_output: None,
112 run: None,
113 result_type: "LLMResult".to_string(),
114 }
115 }
116
117 pub fn with_llm_output(
119 generations: Vec<Vec<GenerationType>>,
120 llm_output: HashMap<String, Value>,
121 ) -> Self {
122 Self {
123 generations,
124 llm_output: Some(llm_output),
125 run: None,
126 result_type: "LLMResult".to_string(),
127 }
128 }
129
130 pub fn flatten(&self) -> Vec<LLMResult> {
140 let mut llm_results = Vec::new();
141
142 for (i, gen_list) in self.generations.iter().enumerate() {
143 if i == 0 {
145 llm_results.push(LLMResult {
146 generations: vec![gen_list.clone()],
147 llm_output: self.llm_output.clone(),
148 run: None,
149 result_type: "LLMResult".to_string(),
150 });
151 } else {
152 let llm_output = if let Some(ref output) = self.llm_output {
153 let mut cloned = output.clone();
154 cloned.insert("token_usage".to_string(), Value::Object(Default::default()));
155 Some(cloned)
156 } else {
157 None
158 };
159 llm_results.push(LLMResult {
160 generations: vec![gen_list.clone()],
161 llm_output,
162 run: None,
163 result_type: "LLMResult".to_string(),
164 });
165 }
166 }
167
168 llm_results
169 }
170}
171
172impl Default for LLMResult {
173 fn default() -> Self {
174 Self {
175 generations: Vec::new(),
176 llm_output: None,
177 run: None,
178 result_type: "LLMResult".to_string(),
179 }
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::AIMessage;
    use serde_json::json;

    /// `new` stores the generations and leaves all metadata empty.
    #[test]
    fn test_llm_result_new() {
        let result = LLMResult::new(vec![vec![Generation::new("Hello").into()]]);
        assert_eq!(result.generations.len(), 1);
        assert_eq!(result.generations[0].len(), 1);
        assert!(result.llm_output.is_none());
    }

    /// Chat generations convert into the enum and are stored like any other.
    #[test]
    fn test_llm_result_with_chat_generation() {
        let chat: GenerationType = ChatGeneration::new(AIMessage::new("Hello").into()).into();
        let result = LLMResult::new(vec![vec![chat]]);
        assert_eq!(result.generations.len(), 1);
    }

    /// `flatten` keeps token usage on the first result and empties it afterwards.
    #[test]
    fn test_llm_result_flatten() {
        let usage = HashMap::from([("token_usage".to_string(), json!({"total": 100}))]);
        let result = LLMResult::with_llm_output(
            vec![
                vec![Generation::new("First").into()],
                vec![Generation::new("Second").into()],
            ],
            usage,
        );

        let flattened = result.flatten();
        assert_eq!(flattened.len(), 2);

        let first = flattened[0]
            .llm_output
            .as_ref()
            .expect("first result keeps llm_output");
        assert_eq!(first.get("token_usage"), Some(&json!({"total": 100})));

        let second = flattened[1]
            .llm_output
            .as_ref()
            .expect("later results keep llm_output");
        assert_eq!(second.get("token_usage"), Some(&json!({})));
    }

    /// Results built from equal generations compare equal.
    #[test]
    fn test_llm_result_equality() {
        let lhs = LLMResult::new(vec![vec![Generation::new("Hello").into()]]);
        let rhs = LLMResult::new(vec![vec![Generation::new("Hello").into()]]);
        assert_eq!(lhs, rhs);
    }

    /// The `result_type` field serializes under the JSON key `"type"`.
    #[test]
    fn test_llm_result_serialization() {
        let result = LLMResult::new(vec![vec![Generation::new("test").into()]]);
        let encoded = serde_json::to_string(&result).unwrap();
        assert!(encoded.contains("\"type\":\"LLMResult\""));
    }
}