// alchemy_llm/types/event_stream.rs
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::channel::mpsc;
use futures::stream::FusedStream;
use futures::Stream;
use tokio::sync::oneshot;

use crate::error::{Error, Result};

use super::event::AssistantMessageEvent;
use super::message::AssistantMessage;

13pub struct AssistantMessageEventStream {
22 receiver: mpsc::UnboundedReceiver<AssistantMessageEvent>,
23 result_receiver: Option<oneshot::Receiver<AssistantMessage>>,
24}
25
26pub struct EventStreamSender {
28 sender: mpsc::UnboundedSender<AssistantMessageEvent>,
29 result_sender: Option<oneshot::Sender<AssistantMessage>>,
30}
31
32impl AssistantMessageEventStream {
33 pub fn new() -> (Self, EventStreamSender) {
38 let (tx, rx) = mpsc::unbounded();
39 let (result_tx, result_rx) = oneshot::channel();
40
41 let stream = Self {
42 receiver: rx,
43 result_receiver: Some(result_rx),
44 };
45
46 let sender = EventStreamSender {
47 sender: tx,
48 result_sender: Some(result_tx),
49 };
50
51 (stream, sender)
52 }
53
54 pub async fn result(mut self) -> Result<AssistantMessage> {
59 let receiver = self
60 .result_receiver
61 .take()
62 .ok_or_else(|| Error::InvalidResponse("result() already called".to_string()))?;
63
64 receiver
65 .await
66 .map_err(|_| Error::InvalidResponse("Stream ended without result".to_string()))
67 }
68}
69
70impl Default for AssistantMessageEventStream {
71 fn default() -> Self {
72 let (stream, _sender) = Self::new();
73 stream
74 }
75}
76
77impl Stream for AssistantMessageEventStream {
78 type Item = AssistantMessageEvent;
79
80 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81 Pin::new(&mut self.receiver).poll_next(cx)
82 }
83}
84
85impl EventStreamSender {
86 pub fn push(&mut self, event: AssistantMessageEvent) {
91 let is_terminal = matches!(
93 &event,
94 AssistantMessageEvent::Done { .. } | AssistantMessageEvent::Error { .. }
95 );
96
97 if is_terminal {
99 let message = match &event {
100 AssistantMessageEvent::Done { message, .. } => message.clone(),
101 AssistantMessageEvent::Error { error, .. } => error.clone(),
102 _ => unreachable!(),
103 };
104
105 if let Some(sender) = self.result_sender.take() {
107 let _ = sender.send(message);
108 }
109 }
110
111 let _ = self.sender.unbounded_send(event);
113 }
114
115 pub fn end(self) {
119 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Api, Provider, StopReason, StopReasonSuccess, Usage};
    use futures::StreamExt;

    /// Minimal assistant message used as the payload of test events.
    fn make_test_message() -> AssistantMessage {
        AssistantMessage {
            content: vec![],
            api: Api::OpenAICompletions,
            provider: Provider::Known(crate::types::KnownProvider::OpenAI),
            model: "gpt-4".to_string(),
            usage: Usage::default(),
            stop_reason: StopReason::Stop,
            error_message: None,
            timestamp: 0,
        }
    }

    #[tokio::test]
    async fn test_stream_events() {
        let (mut stream, mut sender) = AssistantMessageEventStream::new();

        let msg = make_test_message();

        sender.push(AssistantMessageEvent::Start {
            partial: msg.clone(),
        });
        sender.push(AssistantMessageEvent::TextDelta {
            content_index: 0,
            delta: "Hello".to_string(),
            partial: msg.clone(),
        });
        sender.push(AssistantMessageEvent::Done {
            reason: StopReasonSuccess::Stop,
            message: msg.clone(),
        });

        let events: Vec<_> = stream.by_ref().take(3).collect().await;
        assert_eq!(events.len(), 3);
        assert!(matches!(events[0], AssistantMessageEvent::Start { .. }));
        assert!(matches!(events[1], AssistantMessageEvent::TextDelta { .. }));
        assert!(matches!(events[2], AssistantMessageEvent::Done { .. }));
    }

    #[tokio::test]
    async fn test_result() {
        let (stream, mut sender) = AssistantMessageEventStream::new();

        let msg = make_test_message();

        sender.push(AssistantMessageEvent::Done {
            reason: StopReasonSuccess::Stop,
            message: msg.clone(),
        });

        let result = stream.result().await.expect("stream result should succeed");
        assert_eq!(result.model, "gpt-4");
    }

    #[tokio::test]
    async fn test_result_uses_first_terminal_event() {
        let (stream, mut sender) = AssistantMessageEventStream::new();

        let first = make_test_message();
        let mut second = make_test_message();
        second.model = "other-model".to_string();

        sender.push(AssistantMessageEvent::Done {
            reason: StopReasonSuccess::Stop,
            message: first,
        });
        sender.push(AssistantMessageEvent::Done {
            reason: StopReasonSuccess::Stop,
            message: second,
        });

        let result = stream.result().await.expect("stream result should succeed");
        assert_eq!(result.model, "gpt-4");
    }

    #[tokio::test]
    async fn test_end_without_terminal_event() {
        let (mut stream, sender) = AssistantMessageEventStream::new();

        sender.end();

        assert!(stream.next().await.is_none());
        assert!(stream.result().await.is_err());
    }

    #[tokio::test]
    async fn test_default_stream_is_closed() {
        let mut stream = AssistantMessageEventStream::default();

        assert!(stream.next().await.is_none());
        assert!(stream.result().await.is_err());
    }
}