agent_chain_core/output_parsers/
transform.rs1use std::fmt::Debug;
8
9use async_trait::async_trait;
10use futures::StreamExt;
11use futures::stream::BoxStream;
12
13use crate::error::Result;
14use crate::messages::BaseMessage;
15use crate::outputs::{ChatGenerationChunk, Generation, GenerationChunk};
16use crate::runnables::RunnableConfig;
17
18use super::base::BaseOutputParser;
19
20#[async_trait]
25pub trait BaseTransformOutputParser: BaseOutputParser {
26 fn parse_generation(&self, generation: &Generation) -> Result<Self::Output> {
28 self.parse(&generation.text)
29 }
30
31 fn transform<'a>(
35 &'a self,
36 input: BoxStream<'a, StringOrMessage>,
37 ) -> BoxStream<'a, Result<Self::Output>>
38 where
39 Self::Output: 'a,
40 {
41 Box::pin(async_stream::stream! {
42 let mut stream = input;
43 while let Some(chunk) = stream.next().await {
44 let generation = match chunk {
45 StringOrMessage::Text(text) => Generation::new(text),
46 StringOrMessage::Message(msg) => Generation::new((*msg).content()),
47 };
48 yield self.parse_generation(&generation);
49 }
50 })
51 }
52
53 fn atransform<'a>(
55 &'a self,
56 input: BoxStream<'a, StringOrMessage>,
57 ) -> BoxStream<'a, Result<Self::Output>>
58 where
59 Self::Output: 'a,
60 {
61 self.transform(input)
62 }
63}
64
65#[async_trait]
71pub trait BaseCumulativeTransformOutputParser: BaseOutputParser {
72 fn diff_mode(&self) -> bool {
75 false
76 }
77
78 fn compute_diff(&self, _prev: Option<&Self::Output>, next: Self::Output) -> Self::Output {
83 next
84 }
85
86 fn transform<'a>(
91 &'a self,
92 input: BoxStream<'a, StringOrMessage>,
93 _config: Option<RunnableConfig>,
94 ) -> BoxStream<'a, Result<Self::Output>>
95 where
96 Self::Output: PartialEq + 'a,
97 {
98 let diff_mode = self.diff_mode();
99
100 Box::pin(async_stream::stream! {
101 let mut prev_parsed: Option<Self::Output> = None;
102 let mut acc_gen: Option<AccumulatedGeneration> = None;
103 let mut stream = input;
104
105 while let Some(chunk) = stream.next().await {
106 let chunk_gen = match chunk {
107 StringOrMessage::Text(text) => AccumulatedGeneration::Text(text),
108 StringOrMessage::Message(msg) => {
109 AccumulatedGeneration::Text((*msg).content().to_string())
110 }
111 };
112
113 acc_gen = Some(match acc_gen {
114 None => chunk_gen,
115 Some(acc) => acc.add(chunk_gen),
116 });
117
118 if let Some(ref acc) = acc_gen {
119 let generation = acc.to_generation();
120 if let Ok(parsed) = self.parse_result(&[generation], true) {
121 let should_yield = match &prev_parsed {
122 Some(prev) => parsed != *prev,
123 None => true,
124 };
125
126 if should_yield {
127 if diff_mode {
128 yield Ok(self.compute_diff(prev_parsed.as_ref(), parsed.clone()));
129 } else {
130 yield Ok(parsed.clone());
131 }
132 prev_parsed = Some(parsed);
133 }
134 }
135 }
136 }
137 })
138 }
139
140 fn atransform<'a>(
142 &'a self,
143 input: BoxStream<'a, StringOrMessage>,
144 config: Option<RunnableConfig>,
145 ) -> BoxStream<'a, Result<Self::Output>>
146 where
147 Self::Output: PartialEq + 'a,
148 {
149 self.transform(input, config)
150 }
151}
152
153#[derive(Debug, Clone)]
155pub enum StringOrMessage {
156 Text(String),
158 Message(Box<BaseMessage>),
160}
161
162impl From<String> for StringOrMessage {
163 fn from(text: String) -> Self {
164 StringOrMessage::Text(text)
165 }
166}
167
168impl From<&str> for StringOrMessage {
169 fn from(text: &str) -> Self {
170 StringOrMessage::Text(text.to_string())
171 }
172}
173
174impl From<BaseMessage> for StringOrMessage {
175 fn from(msg: BaseMessage) -> Self {
176 StringOrMessage::Message(Box::new(msg))
177 }
178}
179
180#[derive(Debug, Clone)]
182#[allow(dead_code)]
183enum AccumulatedGeneration {
184 Text(String),
186 GenerationChunk(GenerationChunk),
188 ChatGenerationChunk(Box<ChatGenerationChunk>),
190}
191
192impl AccumulatedGeneration {
193 fn add(self, other: AccumulatedGeneration) -> Self {
195 match (self, other) {
196 (AccumulatedGeneration::Text(mut left), AccumulatedGeneration::Text(right)) => {
197 left.push_str(&right);
198 AccumulatedGeneration::Text(left)
199 }
200 (
201 AccumulatedGeneration::GenerationChunk(left),
202 AccumulatedGeneration::GenerationChunk(right),
203 ) => AccumulatedGeneration::GenerationChunk(left + right),
204 (
205 AccumulatedGeneration::ChatGenerationChunk(left),
206 AccumulatedGeneration::ChatGenerationChunk(right),
207 ) => AccumulatedGeneration::ChatGenerationChunk(Box::new(*left + *right)),
208 (AccumulatedGeneration::Text(text), AccumulatedGeneration::GenerationChunk(chunk)) => {
209 let combined = GenerationChunk::new(text) + chunk;
210 AccumulatedGeneration::GenerationChunk(combined)
211 }
212 (AccumulatedGeneration::GenerationChunk(chunk), AccumulatedGeneration::Text(text)) => {
213 let combined = chunk + GenerationChunk::new(text);
214 AccumulatedGeneration::GenerationChunk(combined)
215 }
216 (left, right) => {
217 let left_gen = left.to_generation();
218 let right_gen = right.to_generation();
219 let combined_text = format!("{}{}", left_gen.text, right_gen.text);
220 AccumulatedGeneration::Text(combined_text)
221 }
222 }
223 }
224
225 fn to_generation(&self) -> Generation {
227 match self {
228 AccumulatedGeneration::Text(text) => Generation::new(text.clone()),
229 AccumulatedGeneration::GenerationChunk(chunk) => Generation::from(chunk.clone()),
230 AccumulatedGeneration::ChatGenerationChunk(chunk) => {
231 Generation::new(chunk.as_ref().text.clone())
232 }
233 }
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal parser for the tests below: it upper-cases its input.
    #[derive(Debug, Clone)]
    struct TestTransformParser;

    impl BaseOutputParser for TestTransformParser {
        type Output = String;

        fn parse(&self, text: &str) -> Result<String> {
            Ok(text.to_uppercase())
        }

        fn parser_type(&self) -> &str {
            "test_transform"
        }
    }

    impl BaseTransformOutputParser for TestTransformParser {}

    /// `parse` upper-cases the text.
    #[test]
    fn test_transform_parser_parse() {
        let parser = TestTransformParser;
        let parsed = parser.parse("hello").unwrap();
        assert_eq!(parsed, "HELLO");
    }

    /// The default `parse_generation` forwards the generation's text to `parse`.
    #[test]
    fn test_transform_parser_parse_generation() {
        let parser = TestTransformParser;
        let source = Generation::new("world");
        let parsed = parser.parse_generation(&source).unwrap();
        assert_eq!(parsed, "WORLD");
    }

    /// Converting a string slice produces the `Text` variant.
    #[test]
    fn test_string_or_message_from_string() {
        let converted: StringOrMessage = "test".into();
        if let StringOrMessage::Text(t) = converted {
            assert_eq!(t, "test");
        } else {
            panic!("Expected Text variant");
        }
    }

    /// Adding two `Text` accumulations concatenates them in order.
    #[test]
    fn test_accumulated_generation_add_text() {
        let head = AccumulatedGeneration::Text("Hello ".to_string());
        let tail = AccumulatedGeneration::Text("World".to_string());

        match head.add(tail) {
            AccumulatedGeneration::Text(text) => assert_eq!(text, "Hello World"),
            _ => panic!("Expected Text variant"),
        }
    }

    /// A `Text` accumulation converts to a generation carrying the same text.
    #[test]
    fn test_accumulated_generation_to_generation() {
        let accumulated = AccumulatedGeneration::Text("test".to_string());
        let produced = accumulated.to_generation();
        assert_eq!(produced.text, "test");
    }
}