agent_chain_core/outputs/
llm_result.rs

//! LLMResult type.
2//!
3//! This module contains the `LLMResult` type which is a container
4//! for results of an LLM call.
5//! Mirrors `langchain_core.outputs.llm_result`.
6
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
11#[cfg(feature = "specta")]
12use specta::Type;
13
14use super::chat_generation::{ChatGeneration, ChatGenerationChunk};
15use super::generation::{Generation, GenerationChunk};
16use super::run_info::RunInfo;
17
/// Enum representing different types of generations.
///
/// This allows LLMResult to hold different generation types.
///
/// NOTE: `#[serde(untagged)]` means deserialization tries the variants in
/// declaration order and picks the first that matches — do not reorder the
/// variants without checking that round-tripping still resolves correctly.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum GenerationType {
    /// A standard text generation.
    Generation(Generation),
    /// A text generation chunk.
    GenerationChunk(GenerationChunk),
    /// A chat generation.
    ChatGeneration(ChatGeneration),
    /// A chat generation chunk.
    ChatGenerationChunk(ChatGenerationChunk),
}
34
35impl From<Generation> for GenerationType {
36    fn from(generation: Generation) -> Self {
37        GenerationType::Generation(generation)
38    }
39}
40
41impl From<GenerationChunk> for GenerationType {
42    fn from(generation: GenerationChunk) -> Self {
43        GenerationType::GenerationChunk(generation)
44    }
45}
46
47impl From<ChatGeneration> for GenerationType {
48    fn from(generation: ChatGeneration) -> Self {
49        GenerationType::ChatGeneration(generation)
50    }
51}
52
53impl From<ChatGenerationChunk> for GenerationType {
54    fn from(generation: ChatGenerationChunk) -> Self {
55        GenerationType::ChatGenerationChunk(generation)
56    }
57}
58
/// A container for results of an LLM call.
///
/// Both chat models and LLMs generate an LLMResult object. This object contains the
/// generated outputs and any additional information that the model provider wants to
/// return.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LLMResult {
    /// Generated outputs.
    ///
    /// The first dimension of the list represents completions for different input prompts.
    ///
    /// The second dimension of the list represents different candidate generations for a
    /// given prompt.
    ///
    /// - When returned from **an LLM**, the type is `list[list[Generation]]`.
    /// - When returned from a **chat model**, the type is `list[list[ChatGeneration]]`.
    ///
    /// ChatGeneration is a subclass of Generation that has a field for a structured chat
    /// message.
    pub generations: Vec<Vec<GenerationType>>,

    /// For arbitrary LLM provider specific output.
    ///
    /// This dictionary is a free-form dictionary that can contain any information that the
    /// provider wants to return. It is not standardized and is provider-specific.
    ///
    /// Users should generally avoid relying on this field and instead rely on accessing
    /// relevant information from standardized fields present in AIMessage.
    ///
    /// Omitted from serialized output entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub llm_output: Option<HashMap<String, Value>>,

    /// List of metadata info for model call for each input.
    ///
    /// See `RunInfo` for details. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub run: Option<Vec<RunInfo>>,

    /// Type is used exclusively for serialization purposes.
    ///
    /// Serialized under the JSON key `"type"`; falls back to `"LLMResult"`
    /// (via `default_llm_result_type`) when the key is absent during
    /// deserialization.
    #[serde(rename = "type", default = "default_llm_result_type")]
    pub result_type: String,
}
101
/// Default value for `LLMResult::result_type`, used by serde when the
/// `"type"` key is missing from the serialized form.
fn default_llm_result_type() -> String {
    String::from("LLMResult")
}
105
106impl LLMResult {
107    /// Create a new LLMResult with the given generations.
108    pub fn new(generations: Vec<Vec<GenerationType>>) -> Self {
109        Self {
110            generations,
111            llm_output: None,
112            run: None,
113            result_type: "LLMResult".to_string(),
114        }
115    }
116
117    /// Create a new LLMResult with generations and LLM output.
118    pub fn with_llm_output(
119        generations: Vec<Vec<GenerationType>>,
120        llm_output: HashMap<String, Value>,
121    ) -> Self {
122        Self {
123            generations,
124            llm_output: Some(llm_output),
125            run: None,
126            result_type: "LLMResult".to_string(),
127        }
128    }
129
130    /// Flatten generations into a single list.
131    ///
132    /// Unpack list\[list\[Generation\]\] -> list\[LLMResult\] where each returned LLMResult
133    /// contains only a single Generation. If token usage information is available,
134    /// it is kept only for the LLMResult corresponding to the top-choice
135    /// Generation, to avoid over-counting of token usage downstream.
136    ///
137    /// Returns a list of LLMResults where each returned LLMResult contains a single
138    /// Generation.
139    pub fn flatten(&self) -> Vec<LLMResult> {
140        let mut llm_results = Vec::new();
141
142        for (i, gen_list) in self.generations.iter().enumerate() {
143            // Avoid double counting tokens in OpenAICallback
144            if i == 0 {
145                llm_results.push(LLMResult {
146                    generations: vec![gen_list.clone()],
147                    llm_output: self.llm_output.clone(),
148                    run: None,
149                    result_type: "LLMResult".to_string(),
150                });
151            } else {
152                let llm_output = if let Some(ref output) = self.llm_output {
153                    let mut cloned = output.clone();
154                    cloned.insert("token_usage".to_string(), Value::Object(Default::default()));
155                    Some(cloned)
156                } else {
157                    None
158                };
159                llm_results.push(LLMResult {
160                    generations: vec![gen_list.clone()],
161                    llm_output,
162                    run: None,
163                    result_type: "LLMResult".to_string(),
164                });
165            }
166        }
167
168        llm_results
169    }
170}
171
172impl Default for LLMResult {
173    fn default() -> Self {
174        Self {
175            generations: Vec::new(),
176            llm_output: None,
177            run: None,
178            result_type: "LLMResult".to_string(),
179        }
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::AIMessage;
    use serde_json::json;

    #[test]
    fn test_llm_result_new() {
        let result = LLMResult::new(vec![vec![Generation::new("Hello").into()]]);
        assert_eq!(result.generations.len(), 1);
        assert_eq!(result.generations[0].len(), 1);
        assert!(result.llm_output.is_none());
    }

    #[test]
    fn test_llm_result_with_chat_generation() {
        let chat_gen = ChatGeneration::new(AIMessage::new("Hello").into());
        let result = LLMResult::new(vec![vec![chat_gen.into()]]);
        assert_eq!(result.generations.len(), 1);
    }

    #[test]
    fn test_llm_result_flatten() {
        let mut usage = HashMap::new();
        usage.insert("token_usage".to_string(), json!({"total": 100}));
        let result = LLMResult::with_llm_output(
            vec![
                vec![Generation::new("First").into()],
                vec![Generation::new("Second").into()],
            ],
            usage,
        );

        let flattened = result.flatten();
        assert_eq!(flattened.len(), 2);

        // The first entry keeps the original token usage untouched.
        assert!(flattened[0].llm_output.is_some());
        let head = flattened[0].llm_output.as_ref().unwrap();
        assert_eq!(head.get("token_usage"), Some(&json!({"total": 100})));

        // Later entries carry an emptied token_usage to avoid double counting.
        assert!(flattened[1].llm_output.is_some());
        let tail = flattened[1].llm_output.as_ref().unwrap();
        assert_eq!(tail.get("token_usage"), Some(&json!({})));
    }

    #[test]
    fn test_llm_result_equality() {
        let left = LLMResult::new(vec![vec![Generation::new("Hello").into()]]);
        let right = LLMResult::new(vec![vec![Generation::new("Hello").into()]]);
        assert_eq!(left, right);
    }

    #[test]
    fn test_llm_result_serialization() {
        let result = LLMResult::new(vec![vec![Generation::new("test").into()]]);
        let serialized = serde_json::to_string(&result).unwrap();
        assert!(serialized.contains("\"type\":\"LLMResult\""));
    }
}