Skip to main content

ai_chain/output/
stream.rs

1use crate::prompt::{ChatRole, Data};
2use crate::traits::ExecutorError;
3use futures::StreamExt;
4use std::fmt;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use tokio::sync::mpsc::{self, UnboundedReceiver};
8use tokio_stream::Stream;
9
10use crate::prompt::{ChatMessage, ChatMessageCollection};
11#[derive(Debug)]
12pub enum StreamSegment {
13    Role(ChatRole),
14    Content(String),
15    Err(ExecutorError),
16}
17
18impl fmt::Display for StreamSegment {
19    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20        match self {
21            StreamSegment::Role(chat_role) => write!(f, "{}", chat_role),
22            StreamSegment::Content(content) => write!(f, "{}", content),
23            StreamSegment::Err(executor_error) => write!(f, "{}", executor_error),
24        }
25    }
26}
27
28pub struct OutputStream {
29    receiver: UnboundedReceiver<StreamSegment>,
30}
31
32impl OutputStream {
33    pub(super) fn new() -> (mpsc::UnboundedSender<StreamSegment>, Self) {
34        let (sender, receiver) = mpsc::unbounded_channel();
35        (sender, Self { receiver })
36    }
37
38    pub(super) fn from_stream<S>(stream: S) -> Self
39    where
40        S: Stream<Item = StreamSegment> + Send + 'static,
41    {
42        let (sender, receiver) = mpsc::unbounded_channel();
43        let sender_clone = sender;
44        let mut stream = Box::pin(stream);
45
46        tokio::spawn(async move {
47            while let Some(segment) = stream.next().await {
48                if sender_clone.send(segment).is_err() {
49                    break;
50                }
51            }
52        });
53
54        Self { receiver }
55    }
56
57    pub(super) async fn into_data(self) -> Result<Data<String>, ExecutorError> {
58        let mut messages = ChatMessageCollection::new();
59        let mut current_role = None;
60        let mut current_body = Vec::new();
61
62        let mut stream = self.receiver;
63
64        while let Some(segment) = stream.recv().await {
65            match segment {
66                StreamSegment::Role(role) => {
67                    if let Some(role) = current_role {
68                        if !current_body.is_empty() {
69                            let body = current_body.join("");
70                            messages.add_message(ChatMessage::new(role, body));
71                            current_body.clear();
72                        }
73                    }
74                    current_role = Some(role);
75                }
76                StreamSegment::Content(text) => {
77                    current_body.push(text);
78                }
79                StreamSegment::Err(err) => return Err(err),
80            }
81        }
82
83        let body = current_body.join("");
84        // Handle any remaining message
85        if let Some(role) = current_role {
86            if !current_body.is_empty() {
87                messages.add_message(ChatMessage::new(role, body));
88            }
89            Ok(messages.into())
90        } else {
91            Ok(Data::text(body))
92        }
93    }
94}
95
96impl Stream for OutputStream {
97    type Item = StreamSegment;
98
99    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100        self.receiver.poll_recv(cx)
101    }
102}