Skip to main content

alchemy_llm/types/
event_stream.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use futures::channel::mpsc;
5use futures::Stream;
6use tokio::sync::oneshot;
7
8use crate::error::{Error, Result};
9
10use super::event::AssistantMessageEvent;
11use super::message::AssistantMessage;
12
13/// A stream of assistant message events.
14///
15/// This wraps an async channel and provides:
16/// - Async iteration over events via `Stream` trait
17/// - A `result()` method to await the final `AssistantMessage`
18///
19/// The stream is created by provider implementations and events are pushed
20/// via the sender handle returned from `AssistantMessageEventStream::new()`.
21pub struct AssistantMessageEventStream {
22    receiver: mpsc::UnboundedReceiver<AssistantMessageEvent>,
23    result_receiver: Option<oneshot::Receiver<AssistantMessage>>,
24}
25
26/// Handle for pushing events into an `AssistantMessageEventStream`.
27pub struct EventStreamSender {
28    sender: mpsc::UnboundedSender<AssistantMessageEvent>,
29    result_sender: Option<oneshot::Sender<AssistantMessage>>,
30}
31
32impl AssistantMessageEventStream {
33    /// Create a new event stream and sender pair.
34    ///
35    /// The sender is used by provider implementations to push events.
36    /// The stream is returned to the caller for iteration.
37    pub fn new() -> (Self, EventStreamSender) {
38        let (tx, rx) = mpsc::unbounded();
39        let (result_tx, result_rx) = oneshot::channel();
40
41        let stream = Self {
42            receiver: rx,
43            result_receiver: Some(result_rx),
44        };
45
46        let sender = EventStreamSender {
47            sender: tx,
48            result_sender: Some(result_tx),
49        };
50
51        (stream, sender)
52    }
53
54    /// Await the final result of the stream.
55    ///
56    /// This consumes the result receiver, so it can only be called once.
57    /// Returns the final `AssistantMessage` when the stream completes.
58    pub async fn result(mut self) -> Result<AssistantMessage> {
59        let receiver = self
60            .result_receiver
61            .take()
62            .ok_or_else(|| Error::InvalidResponse("result() already called".to_string()))?;
63
64        receiver
65            .await
66            .map_err(|_| Error::InvalidResponse("Stream ended without result".to_string()))
67    }
68}
69
70impl Default for AssistantMessageEventStream {
71    fn default() -> Self {
72        let (stream, _sender) = Self::new();
73        stream
74    }
75}
76
77impl Stream for AssistantMessageEventStream {
78    type Item = AssistantMessageEvent;
79
80    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81        Pin::new(&mut self.receiver).poll_next(cx)
82    }
83}
84
85impl EventStreamSender {
86    /// Push an event to the stream.
87    ///
88    /// If the event is a terminal event (Done or Error), this also
89    /// resolves the result future with the final message.
90    pub fn push(&mut self, event: AssistantMessageEvent) {
91        // Check if this is a terminal event
92        let is_terminal = matches!(
93            &event,
94            AssistantMessageEvent::Done { .. } | AssistantMessageEvent::Error { .. }
95        );
96
97        // Extract the final message if terminal
98        if is_terminal {
99            let message = match &event {
100                AssistantMessageEvent::Done { message, .. } => message.clone(),
101                AssistantMessageEvent::Error { error, .. } => error.clone(),
102                _ => unreachable!(),
103            };
104
105            // Send the result (ignore error if receiver dropped)
106            if let Some(sender) = self.result_sender.take() {
107                let _ = sender.send(message);
108            }
109        }
110
111        // Send the event (ignore error if receiver dropped)
112        let _ = self.sender.unbounded_send(event);
113    }
114
115    /// End the stream without sending a terminal event.
116    ///
117    /// This closes the channel. Any pending `result()` calls will fail.
118    pub fn end(self) {
119        // Dropping self closes the channels
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Api, Provider, StopReason, StopReasonSuccess, Usage};
    use futures::StreamExt;

    /// Minimal message fixture shared by the tests below.
    fn make_test_message() -> AssistantMessage {
        AssistantMessage {
            content: vec![],
            api: Api::OpenAICompletions,
            provider: Provider::Known(crate::types::KnownProvider::OpenAI),
            model: "gpt-4".to_string(),
            usage: Usage::default(),
            stop_reason: StopReason::Stop,
            error_message: None,
            timestamp: 0,
        }
    }

    /// Events come out of the stream in the order they were pushed.
    #[tokio::test]
    async fn test_stream_events() {
        let (stream, mut sender) = AssistantMessageEventStream::new();
        let fixture = make_test_message();

        let start = AssistantMessageEvent::Start {
            partial: fixture.clone(),
        };
        let delta = AssistantMessageEvent::TextDelta {
            content_index: 0,
            delta: "Hello".to_string(),
            partial: fixture.clone(),
        };
        let done = AssistantMessageEvent::Done {
            reason: StopReasonSuccess::Stop,
            message: fixture,
        };

        sender.push(start);
        sender.push(delta);
        sender.push(done);

        // The sender is still alive, so the stream never terminates on its
        // own; read exactly the three events that were pushed.
        let received = stream.take(3).collect::<Vec<_>>().await;
        assert_eq!(received.len(), 3);
        assert!(matches!(
            received.as_slice(),
            [
                AssistantMessageEvent::Start { .. },
                AssistantMessageEvent::TextDelta { .. },
                AssistantMessageEvent::Done { .. },
            ]
        ));
    }

    /// A pushed `Done` event resolves `result()` with its message.
    #[tokio::test]
    async fn test_result() {
        let (stream, mut sender) = AssistantMessageEventStream::new();

        sender.push(AssistantMessageEvent::Done {
            reason: StopReasonSuccess::Stop,
            message: make_test_message(),
        });

        let final_message = stream.result().await.expect("stream result should succeed");
        assert_eq!(final_message.model, "gpt-4");
    }
}