// agent_chain_core/output_parsers/transform.rs

//! Base traits for output parsers that can handle streaming input.
//!
//! This module contains `BaseTransformOutputParser`, which parses every
//! streamed chunk on its own, and `BaseCumulativeTransformOutputParser`,
//! which re-parses the accumulated output after each chunk.
//! Mirrors `langchain_core.output_parsers.transform`.

use std::fmt::Debug;

use async_trait::async_trait;
use futures::StreamExt;
use futures::stream::BoxStream;

use crate::error::Result;
use crate::messages::BaseMessage;
use crate::outputs::{ChatGenerationChunk, Generation, GenerationChunk};
use crate::runnables::RunnableConfig;

use super::base::BaseOutputParser;

20/// Base trait for an output parser that can handle streaming input.
21///
22/// Transform output parsers can process input streams chunk by chunk,
23/// which is useful for streaming responses from LLMs.
24#[async_trait]
25pub trait BaseTransformOutputParser: BaseOutputParser {
26    /// Parse a generation into the output type.
27    fn parse_generation(&self, generation: &Generation) -> Result<Self::Output> {
28        self.parse(&generation.text)
29    }
30
31    /// Transform an input stream into an output stream.
32    ///
33    /// Default implementation yields a parsed result for each chunk.
34    fn transform<'a>(
35        &'a self,
36        input: BoxStream<'a, StringOrMessage>,
37    ) -> BoxStream<'a, Result<Self::Output>>
38    where
39        Self::Output: 'a,
40    {
41        Box::pin(async_stream::stream! {
42            let mut stream = input;
43            while let Some(chunk) = stream.next().await {
44                let generation = match chunk {
45                    StringOrMessage::Text(text) => Generation::new(text),
46                    StringOrMessage::Message(msg) => Generation::new((*msg).content()),
47                };
48                yield self.parse_generation(&generation);
49            }
50        })
51    }
52
53    /// Async transform an input stream into an output stream.
54    fn atransform<'a>(
55        &'a self,
56        input: BoxStream<'a, StringOrMessage>,
57    ) -> BoxStream<'a, Result<Self::Output>>
58    where
59        Self::Output: 'a,
60    {
61        self.transform(input)
62    }
63}
64
65/// Base trait for an output parser that accumulates chunks before parsing.
66///
67/// This is useful for parsers that need to see the complete output before
68/// parsing, but want to yield intermediate results during streaming.
69/// For example, a JSON parser might yield partial JSON objects as they're built up.
70#[async_trait]
71pub trait BaseCumulativeTransformOutputParser: BaseOutputParser {
72    /// Whether to yield diffs between the previous and current parsed output,
73    /// or just the current parsed output.
74    fn diff_mode(&self) -> bool {
75        false
76    }
77
78    /// Convert parsed outputs into a diff format.
79    ///
80    /// The semantics of this are up to the output parser.
81    /// Default implementation returns the next value unchanged.
82    fn compute_diff(&self, _prev: Option<&Self::Output>, next: Self::Output) -> Self::Output {
83        next
84    }
85
86    /// Transform an input stream into an output stream, accumulating chunks.
87    ///
88    /// This accumulates input chunks and parses the accumulated result,
89    /// yielding intermediate results as they change.
90    fn transform<'a>(
91        &'a self,
92        input: BoxStream<'a, StringOrMessage>,
93        _config: Option<RunnableConfig>,
94    ) -> BoxStream<'a, Result<Self::Output>>
95    where
96        Self::Output: PartialEq + 'a,
97    {
98        let diff_mode = self.diff_mode();
99
100        Box::pin(async_stream::stream! {
101            let mut prev_parsed: Option<Self::Output> = None;
102            let mut acc_gen: Option<AccumulatedGeneration> = None;
103            let mut stream = input;
104
105            while let Some(chunk) = stream.next().await {
106                let chunk_gen = match chunk {
107                    StringOrMessage::Text(text) => AccumulatedGeneration::Text(text),
108                    StringOrMessage::Message(msg) => {
109                        AccumulatedGeneration::Text((*msg).content().to_string())
110                    }
111                };
112
113                acc_gen = Some(match acc_gen {
114                    None => chunk_gen,
115                    Some(acc) => acc.add(chunk_gen),
116                });
117
118                if let Some(ref acc) = acc_gen {
119                    let generation = acc.to_generation();
120                    if let Ok(parsed) = self.parse_result(&[generation], true) {
121                        let should_yield = match &prev_parsed {
122                            Some(prev) => parsed != *prev,
123                            None => true,
124                        };
125
126                        if should_yield {
127                            if diff_mode {
128                                yield Ok(self.compute_diff(prev_parsed.as_ref(), parsed.clone()));
129                            } else {
130                                yield Ok(parsed.clone());
131                            }
132                            prev_parsed = Some(parsed);
133                        }
134                    }
135                }
136            }
137        })
138    }
139
140    /// Async transform an input stream into an output stream.
141    fn atransform<'a>(
142        &'a self,
143        input: BoxStream<'a, StringOrMessage>,
144        config: Option<RunnableConfig>,
145    ) -> BoxStream<'a, Result<Self::Output>>
146    where
147        Self::Output: PartialEq + 'a,
148    {
149        self.transform(input, config)
150    }
151}
152
153/// Input type that can be either a string or a message.
154#[derive(Debug, Clone)]
155pub enum StringOrMessage {
156    /// Raw text input.
157    Text(String),
158    /// Message input.
159    Message(Box<BaseMessage>),
160}
161
162impl From<String> for StringOrMessage {
163    fn from(text: String) -> Self {
164        StringOrMessage::Text(text)
165    }
166}
167
168impl From<&str> for StringOrMessage {
169    fn from(text: &str) -> Self {
170        StringOrMessage::Text(text.to_string())
171    }
172}
173
174impl From<BaseMessage> for StringOrMessage {
175    fn from(msg: BaseMessage) -> Self {
176        StringOrMessage::Message(Box::new(msg))
177    }
178}
179
180/// Accumulated generation state for streaming.
181#[derive(Debug, Clone)]
182#[allow(dead_code)]
183enum AccumulatedGeneration {
184    /// Accumulated text.
185    Text(String),
186    /// Accumulated generation chunk.
187    GenerationChunk(GenerationChunk),
188    /// Accumulated chat generation chunk.
189    ChatGenerationChunk(Box<ChatGenerationChunk>),
190}
191
192impl AccumulatedGeneration {
193    /// Add another chunk to this accumulation.
194    fn add(self, other: AccumulatedGeneration) -> Self {
195        match (self, other) {
196            (AccumulatedGeneration::Text(mut left), AccumulatedGeneration::Text(right)) => {
197                left.push_str(&right);
198                AccumulatedGeneration::Text(left)
199            }
200            (
201                AccumulatedGeneration::GenerationChunk(left),
202                AccumulatedGeneration::GenerationChunk(right),
203            ) => AccumulatedGeneration::GenerationChunk(left + right),
204            (
205                AccumulatedGeneration::ChatGenerationChunk(left),
206                AccumulatedGeneration::ChatGenerationChunk(right),
207            ) => AccumulatedGeneration::ChatGenerationChunk(Box::new(*left + *right)),
208            (AccumulatedGeneration::Text(text), AccumulatedGeneration::GenerationChunk(chunk)) => {
209                let combined = GenerationChunk::new(text) + chunk;
210                AccumulatedGeneration::GenerationChunk(combined)
211            }
212            (AccumulatedGeneration::GenerationChunk(chunk), AccumulatedGeneration::Text(text)) => {
213                let combined = chunk + GenerationChunk::new(text);
214                AccumulatedGeneration::GenerationChunk(combined)
215            }
216            (left, right) => {
217                let left_gen = left.to_generation();
218                let right_gen = right.to_generation();
219                let combined_text = format!("{}{}", left_gen.text, right_gen.text);
220                AccumulatedGeneration::Text(combined_text)
221            }
222        }
223    }
224
225    /// Convert to a Generation for parsing.
226    fn to_generation(&self) -> Generation {
227        match self {
228            AccumulatedGeneration::Text(text) => Generation::new(text.clone()),
229            AccumulatedGeneration::GenerationChunk(chunk) => Generation::from(chunk.clone()),
230            AccumulatedGeneration::ChatGenerationChunk(chunk) => {
231                Generation::new(chunk.as_ref().text.clone())
232            }
233        }
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Parser that upper-cases its input. It relies entirely on the default
    /// methods of `BaseTransformOutputParser`.
    #[derive(Debug, Clone)]
    struct TestTransformParser;

    impl BaseOutputParser for TestTransformParser {
        type Output = String;

        fn parse(&self, text: &str) -> Result<String> {
            Ok(text.to_uppercase())
        }

        fn parser_type(&self) -> &str {
            "test_transform"
        }
    }

    impl BaseTransformOutputParser for TestTransformParser {}

    #[test]
    fn test_transform_parser_parse() {
        let parser = TestTransformParser;
        let result = parser.parse("hello").unwrap();
        assert_eq!(result, "HELLO");
    }

    #[test]
    fn test_transform_parser_parse_generation() {
        let parser = TestTransformParser;
        let generation = Generation::new("world");
        let result = parser.parse_generation(&generation).unwrap();
        assert_eq!(result, "WORLD");
    }

    /// Exercises `From<&str>`.
    #[test]
    fn test_string_or_message_from_str() {
        let input: StringOrMessage = "test".into();
        match input {
            StringOrMessage::Text(t) => assert_eq!(t, "test"),
            _ => panic!("Expected Text variant"),
        }
    }

    /// Exercises `From<String>`, which previously had no coverage.
    #[test]
    fn test_string_or_message_from_string() {
        let input: StringOrMessage = String::from("test").into();
        match input {
            StringOrMessage::Text(t) => assert_eq!(t, "test"),
            _ => panic!("Expected Text variant"),
        }
    }

    #[test]
    fn test_accumulated_generation_add_text() {
        let left = AccumulatedGeneration::Text("Hello ".to_string());
        let right = AccumulatedGeneration::Text("World".to_string());
        let result = left.add(right);

        if let AccumulatedGeneration::Text(text) = result {
            assert_eq!(text, "Hello World");
        } else {
            panic!("Expected Text variant");
        }
    }

    /// Several text chunks folded in stream order concatenate in that order.
    #[test]
    fn test_accumulated_generation_add_text_repeatedly() {
        let acc = ["a", "b", "c"]
            .iter()
            .map(|s| AccumulatedGeneration::Text(s.to_string()))
            .reduce(AccumulatedGeneration::add)
            .unwrap();
        assert_eq!(acc.to_generation().text, "abc");
    }

    #[test]
    fn test_accumulated_generation_to_generation() {
        let acc = AccumulatedGeneration::Text("test".to_string());
        let generation = acc.to_generation();
        assert_eq!(generation.text, "test");
    }
}
301}