ai_chain/output/
stream.rs1use crate::prompt::{ChatRole, Data};
2use crate::traits::ExecutorError;
3use futures::StreamExt;
4use std::fmt;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use tokio::sync::mpsc::{self, UnboundedReceiver};
8use tokio_stream::Stream;
9
10use crate::prompt::{ChatMessage, ChatMessageCollection};
11#[derive(Debug)]
12pub enum StreamSegment {
13 Role(ChatRole),
14 Content(String),
15 Err(ExecutorError),
16}
17
18impl fmt::Display for StreamSegment {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 match self {
21 StreamSegment::Role(chat_role) => write!(f, "{}", chat_role),
22 StreamSegment::Content(content) => write!(f, "{}", content),
23 StreamSegment::Err(executor_error) => write!(f, "{}", executor_error),
24 }
25 }
26}
27
28pub struct OutputStream {
29 receiver: UnboundedReceiver<StreamSegment>,
30}
31
32impl OutputStream {
33 pub(super) fn new() -> (mpsc::UnboundedSender<StreamSegment>, Self) {
34 let (sender, receiver) = mpsc::unbounded_channel();
35 (sender, Self { receiver })
36 }
37
38 pub(super) fn from_stream<S>(stream: S) -> Self
39 where
40 S: Stream<Item = StreamSegment> + Send + 'static,
41 {
42 let (sender, receiver) = mpsc::unbounded_channel();
43 let sender_clone = sender;
44 let mut stream = Box::pin(stream);
45
46 tokio::spawn(async move {
47 while let Some(segment) = stream.next().await {
48 if sender_clone.send(segment).is_err() {
49 break;
50 }
51 }
52 });
53
54 Self { receiver }
55 }
56
57 pub(super) async fn into_data(self) -> Result<Data<String>, ExecutorError> {
58 let mut messages = ChatMessageCollection::new();
59 let mut current_role = None;
60 let mut current_body = Vec::new();
61
62 let mut stream = self.receiver;
63
64 while let Some(segment) = stream.recv().await {
65 match segment {
66 StreamSegment::Role(role) => {
67 if let Some(role) = current_role {
68 if !current_body.is_empty() {
69 let body = current_body.join("");
70 messages.add_message(ChatMessage::new(role, body));
71 current_body.clear();
72 }
73 }
74 current_role = Some(role);
75 }
76 StreamSegment::Content(text) => {
77 current_body.push(text);
78 }
79 StreamSegment::Err(err) => return Err(err),
80 }
81 }
82
83 let body = current_body.join("");
84 if let Some(role) = current_role {
86 if !current_body.is_empty() {
87 messages.add_message(ChatMessage::new(role, body));
88 }
89 Ok(messages.into())
90 } else {
91 Ok(Data::text(body))
92 }
93 }
94}
95
96impl Stream for OutputStream {
97 type Item = StreamSegment;
98
99 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100 self.receiver.poll_recv(cx)
101 }
102}