agentik_sdk/streaming/mod.rs
1//! Streaming support for real-time message generation.
2//!
3//! This module provides the `MessageStream` struct which handles Server-Sent Events (SSE)
4//! from the Anthropic API, accumulates messages from incremental updates, and provides
5//! an event-driven API for processing streaming responses.
6
7pub mod events;
8
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11use futures::Stream;
12use pin_project::pin_project;
13use tokio::sync::{broadcast, oneshot};
14use tokio_stream::wrappers::BroadcastStream;
15
16use crate::types::{
17 Message, MessageStreamEvent, ContentBlock, ContentBlockDelta,
18 AnthropicError, Result
19};
20
21use self::events::{EventHandler, EventType};
22
/// A streaming response from the Anthropic API.
///
/// `MessageStream` provides an event-driven interface for processing streaming responses
/// from Claude. It accumulates message content from incremental updates and provides
/// both callback-based and async iteration APIs.
///
/// # Examples
///
/// ## Callback-based processing:
/// ```ignore
/// # use agentik_sdk::{Anthropic, MessageCreateBuilder};
/// # async fn example() -> agentik_sdk::Result<()> {
/// let client = Anthropic::new("your-api-key")?;
/// let stream = client.messages().create_stream(
///     MessageCreateBuilder::new("claude-3-5-sonnet-latest", 1024)
///         .user("Write a story about AI")
///         .stream(true)
///         .build()
/// ).await?;
///
/// let final_message = stream
///     .on_text(|delta, _snapshot| {
///         print!("{}", delta);
///     })
///     .on_error(|error| {
///         eprintln!("Stream error: {}", error);
///     })
///     .final_message().await?;
/// # Ok(())
/// # }
/// ```
///
/// ## Async iteration:
///
/// The stream holds its own broadcast sender, so iteration does not end on its own
/// after the last event arrives; stop at `MessageStop` as shown below.
/// ```ignore
/// # use agentik_sdk::{Anthropic, MessageCreateBuilder, MessageStreamEvent};
/// # use futures::StreamExt;
/// # async fn example() -> agentik_sdk::Result<()> {
/// let client = Anthropic::new("your-api-key")?;
/// let mut stream = client.messages().create_stream(
///     MessageCreateBuilder::new("claude-3-5-sonnet-latest", 1024)
///         .user("Tell me a joke")
///         .stream(true)
///         .build()
/// ).await?;
///
/// while let Some(event) = stream.next().await {
///     match event? {
///         MessageStreamEvent::ContentBlockDelta { delta, .. } => {
///             // Process incremental content
///         }
///         // The event stream does not close by itself; stop here.
///         MessageStreamEvent::MessageStop => break,
///         _ => {}
///     }
/// }
/// # Ok(())
/// # }
/// ```
#[pin_project]
pub struct MessageStream {
    /// Message accumulated so far from incremental events; `None` until a
    /// `MessageStart` event has been applied.
    current_message: Arc<Mutex<Option<Message>>>,

    /// Callbacks registered through the `on_*` methods, grouped by the event
    /// type they listen for.
    event_handlers: Arc<Mutex<HashMap<EventType, Vec<EventHandler>>>>,

    /// Sending half of the broadcast channel that feeds `event_stream`.
    /// Because this handle lives as long as the stream, the channel is not
    /// closed while the `MessageStream` exists.
    event_sender: broadcast::Sender<MessageStreamEvent>,

    /// Receiving half of the broadcast channel, polled by the `Stream`
    /// implementation for async iteration.
    #[pin]
    event_stream: BroadcastStream<MessageStreamEvent>,

    /// One-shot channel resolved with the final message, or with the error
    /// that ended the stream. The sender is `Some` only for streams built with
    /// `new`; `from_http_stream` moves it into its background task and stores `None`.
    completion_sender: Option<oneshot::Sender<Result<Message>>>,
    completion_receiver: oneshot::Receiver<Result<Message>>,

    /// Set once a `MessageStop` event has been processed.
    ended: Arc<Mutex<bool>>,

    /// Set when the underlying HTTP stream yielded an error.
    errored: Arc<Mutex<bool>>,

    /// Set by `abort`.
    aborted: Arc<Mutex<bool>>,

    /// Raw HTTP response for streams built with `new`; `None` for `from_http_stream`.
    response: Option<reqwest::Response>,
    /// Request ID reported for this response, if any.
    request_id: Option<String>,
}
112
113impl MessageStream {
114 /// Create a new MessageStream from an HTTP response.
115 ///
116 /// This is typically called internally by the SDK when creating streaming requests.
117 pub fn new(response: reqwest::Response, request_id: Option<String>) -> Self {
118 let (event_sender, event_receiver) = broadcast::channel(1000);
119 let (completion_sender, completion_receiver) = oneshot::channel();
120
121 Self {
122 current_message: Arc::new(Mutex::new(None)),
123 event_handlers: Arc::new(Mutex::new(HashMap::new())),
124 event_sender,
125 event_stream: BroadcastStream::new(event_receiver),
126 completion_sender: Some(completion_sender),
127 completion_receiver,
128 ended: Arc::new(Mutex::new(false)),
129 errored: Arc::new(Mutex::new(false)),
130 aborted: Arc::new(Mutex::new(false)),
131 response: Some(response),
132 request_id,
133 }
134 }
135
136 /// Create a new MessageStream from an HttpStreamClient.
137 ///
138 /// This connects a real HTTP stream to the MessageStream, providing
139 /// proper streaming functionality for real-time response processing.
140 pub fn from_http_stream(mut http_stream: crate::http::streaming::HttpStreamClient) -> Result<Self> {
141 let (event_sender, event_receiver) = broadcast::channel(1000);
142 let (completion_sender, completion_receiver) = oneshot::channel();
143
144 let current_message = Arc::new(Mutex::new(None));
145 let ended = Arc::new(Mutex::new(false));
146 let errored = Arc::new(Mutex::new(false));
147 let request_id = http_stream.request_id().map(|s| s.to_string());
148
149 // Clone references for the background task
150 let current_message_clone = current_message.clone();
151 let ended_clone = ended.clone();
152 let errored_clone = errored.clone();
153 let event_sender_clone = event_sender.clone();
154
155 // Spawn task to process HTTP stream events
156 tokio::spawn(async move {
157 use futures::StreamExt;
158 let mut final_message: Option<crate::types::Message> = None;
159
160 while let Some(event_result) = http_stream.next().await {
161 match event_result {
162 Ok(event) => {
163 // Update current message state
164 match &event {
165 crate::types::MessageStreamEvent::MessageStart { message } => {
166 *current_message_clone.lock().unwrap() = Some(message.clone());
167 final_message = Some(message.clone());
168 }
169 crate::types::MessageStreamEvent::ContentBlockStart { content_block, index } => {
170 if let Some(ref mut msg) = *current_message_clone.lock().unwrap() {
171 while msg.content.len() <= *index {
172 msg.content.push(crate::types::ContentBlock::Text { text: String::new() });
173 }
174 msg.content[*index] = content_block.clone();
175 }
176 if let Some(ref mut msg) = final_message.as_mut() {
177 while msg.content.len() <= *index {
178 msg.content.push(crate::types::ContentBlock::Text { text: String::new() });
179 }
180 msg.content[*index] = content_block.clone();
181 }
182 }
183 crate::types::MessageStreamEvent::ContentBlockDelta { delta, index } => {
184 if let Some(ref mut msg) = *current_message_clone.lock().unwrap() {
185 if let Some(content_block) = msg.content.get_mut(*index) {
186 if let (crate::types::ContentBlock::Text { text },
187 crate::types::ContentBlockDelta::TextDelta { text: delta_text }) =
188 (content_block, delta) {
189 text.push_str(delta_text);
190 }
191 }
192 }
193 if let Some(ref mut msg) = final_message.as_mut() {
194 if let Some(content_block) = msg.content.get_mut(*index) {
195 if let (crate::types::ContentBlock::Text { text },
196 crate::types::ContentBlockDelta::TextDelta { text: delta_text }) =
197 (content_block, delta) {
198 text.push_str(delta_text);
199 }
200 }
201 }
202 }
203 crate::types::MessageStreamEvent::MessageDelta { delta, usage } => {
204 if let Some(ref mut msg) = *current_message_clone.lock().unwrap() {
205 if let Some(stop_reason) = &delta.stop_reason {
206 msg.stop_reason = Some(stop_reason.clone());
207 }
208 if let Some(stop_sequence) = &delta.stop_sequence {
209 msg.stop_sequence = Some(stop_sequence.clone());
210 }
211 msg.usage.output_tokens = usage.output_tokens;
212 if let Some(input_tokens) = usage.input_tokens {
213 msg.usage.input_tokens = input_tokens;
214 }
215 if let Some(cache_creation) = usage.cache_creation_input_tokens {
216 msg.usage.cache_creation_input_tokens = Some(cache_creation);
217 }
218 if let Some(cache_read) = usage.cache_read_input_tokens {
219 msg.usage.cache_read_input_tokens = Some(cache_read);
220 }
221 }
222 if let Some(ref mut msg) = final_message.as_mut() {
223 if let Some(stop_reason) = &delta.stop_reason {
224 msg.stop_reason = Some(stop_reason.clone());
225 }
226 if let Some(stop_sequence) = &delta.stop_sequence {
227 msg.stop_sequence = Some(stop_sequence.clone());
228 }
229 msg.usage.output_tokens = usage.output_tokens;
230 if let Some(input_tokens) = usage.input_tokens {
231 msg.usage.input_tokens = input_tokens;
232 }
233 if let Some(cache_creation) = usage.cache_creation_input_tokens {
234 msg.usage.cache_creation_input_tokens = Some(cache_creation);
235 }
236 if let Some(cache_read) = usage.cache_read_input_tokens {
237 msg.usage.cache_read_input_tokens = Some(cache_read);
238 }
239 }
240 }
241 crate::types::MessageStreamEvent::MessageStop => {
242 *ended_clone.lock().unwrap() = true;
243 // Send the final message
244 if let Some(message) = final_message.clone() {
245 let _ = completion_sender.send(Ok(message));
246 } else {
247 let _ = completion_sender.send(Err(crate::types::AnthropicError::StreamError(
248 "Stream ended without message".to_string()
249 )));
250 }
251 // Send final event and break
252 let _ = event_sender_clone.send(event);
253 break;
254 }
255 _ => {}
256 }
257
258 // Send event to broadcast channel for callbacks
259 let _ = event_sender_clone.send(event);
260 }
261 Err(e) => {
262 *errored_clone.lock().unwrap() = true;
263 let _ = completion_sender.send(Err(e));
264 break;
265 }
266 }
267 }
268 });
269
270 Ok(Self {
271 current_message,
272 event_handlers: Arc::new(Mutex::new(HashMap::new())),
273 event_sender,
274 event_stream: BroadcastStream::new(event_receiver),
275 completion_sender: None, // Already consumed by the task
276 completion_receiver,
277 ended,
278 errored,
279 aborted: Arc::new(Mutex::new(false)),
280 response: None, // No response needed for HTTP stream
281 request_id,
282 })
283 }
284
285 /// Register a callback for text delta events.
286 ///
287 /// The callback receives two parameters:
288 /// - `delta`: The new text being appended
289 /// - `snapshot`: The current accumulated text
290 ///
291 /// # Examples
292 /// ```rust,no_run
293 /// # use agentik_sdk::MessageStream;
294 /// # async fn example(stream: MessageStream) {
295 /// stream.on_text(|delta, snapshot| {
296 /// print!("{}", delta);
297 /// println!("Total so far: {}", snapshot);
298 /// });
299 /// # }
300 /// ```
301 pub fn on_text<F>(self, callback: F) -> Self
302 where
303 F: Fn(&str, &str) + Send + Sync + 'static,
304 {
305 self.on(EventType::Text, EventHandler::Text(Box::new(callback)))
306 }
307
308 /// Register a callback for stream events.
309 ///
310 /// This provides access to all raw stream events and the current message snapshot.
311 ///
312 /// # Examples
313 /// ```rust,no_run
314 /// # use agentik_sdk::{MessageStream, MessageStreamEvent, Message};
315 /// # async fn example(stream: MessageStream) {
316 /// stream.on_stream_event(|event, snapshot| {
317 /// match event {
318 /// MessageStreamEvent::ContentBlockStart { .. } => {
319 /// println!("New content block started");
320 /// }
321 /// _ => {}
322 /// }
323 /// });
324 /// # }
325 /// ```
326 pub fn on_stream_event<F>(self, callback: F) -> Self
327 where
328 F: Fn(&MessageStreamEvent, &Message) + Send + Sync + 'static,
329 {
330 self.on(EventType::StreamEvent, EventHandler::StreamEvent(Box::new(callback)))
331 }
332
333 /// Register a callback for when a complete message is received.
334 ///
335 /// # Examples
336 /// ```rust,no_run
337 /// # use agentik_sdk::{MessageStream, Message};
338 /// # async fn example(stream: MessageStream) {
339 /// stream.on_message(|message| {
340 /// println!("Received message: {:?}", message);
341 /// });
342 /// # }
343 /// ```
344 pub fn on_message<F>(self, callback: F) -> Self
345 where
346 F: Fn(&Message) + Send + Sync + 'static,
347 {
348 self.on(EventType::Message, EventHandler::Message(Box::new(callback)))
349 }
350
351 /// Register a callback for when the final message is complete.
352 ///
353 /// # Examples
354 /// ```rust,no_run
355 /// # use agentik_sdk::{MessageStream, Message};
356 /// # async fn example(stream: MessageStream) {
357 /// stream.on_final_message(|message| {
358 /// println!("Final message: {:?}", message);
359 /// });
360 /// # }
361 /// ```
362 pub fn on_final_message<F>(self, callback: F) -> Self
363 where
364 F: Fn(&Message) + Send + Sync + 'static,
365 {
366 self.on(EventType::FinalMessage, EventHandler::FinalMessage(Box::new(callback)))
367 }
368
369 /// Register a callback for errors.
370 ///
371 /// # Examples
372 /// ```rust,no_run
373 /// # use agentik_sdk::{MessageStream, AnthropicError};
374 /// # async fn example(stream: MessageStream) {
375 /// stream.on_error(|error| {
376 /// eprintln!("Stream error: {}", error);
377 /// });
378 /// # }
379 /// ```
380 pub fn on_error<F>(self, callback: F) -> Self
381 where
382 F: Fn(&AnthropicError) + Send + Sync + 'static,
383 {
384 self.on(EventType::Error, EventHandler::Error(Box::new(callback)))
385 }
386
387 /// Register a callback for when the stream ends.
388 ///
389 /// # Examples
390 /// ```rust,no_run
391 /// # use agentik_sdk::MessageStream;
392 /// # async fn example(stream: MessageStream) {
393 /// stream.on_end(|| {
394 /// println!("Stream ended");
395 /// });
396 /// # }
397 /// ```
398 pub fn on_end<F>(self, callback: F) -> Self
399 where
400 F: Fn() + Send + Sync + 'static,
401 {
402 self.on(EventType::End, EventHandler::End(Box::new(callback)))
403 }
404
405 /// Generic method to register event handlers.
406 fn on(self, event_type: EventType, handler: EventHandler) -> Self {
407 {
408 let mut handlers = self.event_handlers.lock().unwrap();
409 handlers.entry(event_type).or_insert_with(Vec::new).push(handler);
410 }
411 self
412 }
413
414 /// Wait for the stream to complete and return the final message.
415 ///
416 /// This method will block until the stream ends and return the accumulated message.
417 ///
418 /// # Examples
419 /// ```rust,no_run
420 /// # use agentik_sdk::MessageStream;
421 /// # async fn example(stream: MessageStream) -> agentik_sdk::Result<()> {
422 /// let final_message = stream.final_message().await?;
423 /// println!("Claude said: {:?}", final_message.content);
424 /// # Ok(())
425 /// # }
426 /// ```
427 pub async fn final_message(self) -> Result<Message> {
428 self.completion_receiver.await
429 .map_err(|_| AnthropicError::StreamError("Stream ended unexpectedly".to_string()))?
430 }
431
432 /// Wait for the stream to complete without returning the message.
433 ///
434 /// This is useful when you're processing events with callbacks and just need
435 /// to wait for completion.
436 ///
437 /// # Examples
438 /// ```rust,no_run
439 /// # use agentik_sdk::MessageStream;
440 /// # async fn example(stream: MessageStream) -> agentik_sdk::Result<()> {
441 /// stream.on_text(|delta, _| print!("{}", delta))
442 /// .done().await?;
443 /// println!("\nStream completed!");
444 /// # Ok(())
445 /// # }
446 /// ```
447 pub async fn done(self) -> Result<()> {
448 self.completion_receiver.await
449 .map_err(|_| AnthropicError::StreamError("Stream ended unexpectedly".to_string()))?
450 .map(|_| ())
451 }
452
453 /// Get the current accumulated message snapshot.
454 ///
455 /// Returns `None` if the stream hasn't started or no message has been received yet.
456 pub fn current_message(&self) -> Option<Message> {
457 self.current_message.lock().unwrap().clone()
458 }
459
460 /// Check if the stream has ended.
461 pub fn ended(&self) -> bool {
462 *self.ended.lock().unwrap()
463 }
464
465 /// Check if an error occurred.
466 pub fn errored(&self) -> bool {
467 *self.errored.lock().unwrap()
468 }
469
470 /// Check if the stream was aborted.
471 pub fn aborted(&self) -> bool {
472 *self.aborted.lock().unwrap()
473 }
474
475 /// Get the response metadata.
476 pub fn response(&self) -> Option<&reqwest::Response> {
477 self.response.as_ref()
478 }
479
480 /// Get the request ID.
481 pub fn request_id(&self) -> Option<&str> {
482 self.request_id.as_deref()
483 }
484
485 /// Abort the stream.
486 ///
487 /// This will cancel the underlying HTTP request and mark the stream as aborted.
488 pub fn abort(&self) {
489 *self.aborted.lock().unwrap() = true;
490 // In a real implementation, this would cancel the HTTP request
491 }
492
493 /// Process a stream event and update the internal state.
494 ///
495 /// This method accumulates message content from incremental updates and
496 /// dispatches events to registered handlers.
497 #[allow(dead_code)]
498 fn process_event(&self, event: MessageStreamEvent) -> Result<()> {
499 // Update current message state based on the event
500 match &event {
501 MessageStreamEvent::MessageStart { message } => {
502 *self.current_message.lock().unwrap() = Some(message.clone());
503 }
504 MessageStreamEvent::ContentBlockStart { content_block, index } => {
505 if let Some(ref mut msg) = *self.current_message.lock().unwrap() {
506 // Ensure the content array is large enough
507 while msg.content.len() <= *index {
508 msg.content.push(ContentBlock::Text { text: String::new() });
509 }
510 msg.content[*index] = content_block.clone();
511 }
512 }
513 MessageStreamEvent::ContentBlockDelta { delta, index } => {
514 if let Some(ref mut msg) = *self.current_message.lock().unwrap() {
515 if let Some(content_block) = msg.content.get_mut(*index) {
516 self.apply_delta(content_block, delta)?;
517 }
518 }
519 }
520 MessageStreamEvent::MessageDelta { delta, usage } => {
521 if let Some(ref mut msg) = *self.current_message.lock().unwrap() {
522 if let Some(stop_reason) = &delta.stop_reason {
523 msg.stop_reason = Some(stop_reason.clone());
524 }
525 if let Some(stop_sequence) = &delta.stop_sequence {
526 msg.stop_sequence = Some(stop_sequence.clone());
527 }
528 msg.usage.output_tokens = usage.output_tokens;
529 if let Some(input_tokens) = usage.input_tokens {
530 msg.usage.input_tokens = input_tokens;
531 }
532 if let Some(cache_creation) = usage.cache_creation_input_tokens {
533 msg.usage.cache_creation_input_tokens = Some(cache_creation);
534 }
535 if let Some(cache_read) = usage.cache_read_input_tokens {
536 msg.usage.cache_read_input_tokens = Some(cache_read);
537 }
538 }
539 }
540 MessageStreamEvent::MessageStop => {
541 *self.ended.lock().unwrap() = true;
542 }
543 _ => {}
544 }
545
546 // Dispatch event to handlers
547 self.dispatch_event(&event)?;
548
549 // Send event to broadcast channel for async iteration
550 let _ = self.event_sender.send(event);
551
552 Ok(())
553 }
554
555 /// Apply a content block delta to update the content.
556 #[allow(dead_code)]
557 fn apply_delta(&self, content_block: &mut ContentBlock, delta: &ContentBlockDelta) -> Result<()> {
558 match (content_block, delta) {
559 (ContentBlock::Text { text }, ContentBlockDelta::TextDelta { text: delta_text }) => {
560 text.push_str(delta_text);
561 }
562 (ContentBlock::ToolUse { input, .. }, ContentBlockDelta::InputJsonDelta { partial_json }) => {
563 // In a real implementation, we'd parse the partial JSON
564 // For now, we'll just store it as-is
565 *input = serde_json::from_str(partial_json)
566 .unwrap_or_else(|_| serde_json::Value::String(partial_json.clone()));
567 }
568 _ => {
569 // Other delta types would be handled here
570 }
571 }
572 Ok(())
573 }
574
575 /// Dispatch an event to all registered handlers.
576 fn dispatch_event(&self, event: &MessageStreamEvent) -> Result<()> {
577 let handlers = self.event_handlers.lock().unwrap();
578 let current_message = self.current_message.lock().unwrap();
579
580 // Dispatch to stream event handlers
581 if let Some(stream_handlers) = handlers.get(&EventType::StreamEvent) {
582 for handler in stream_handlers {
583 if let EventHandler::StreamEvent(callback) = handler {
584 if let Some(ref msg) = *current_message {
585 callback(event, msg);
586 }
587 }
588 }
589 }
590
591 // Dispatch specific event types
592 match event {
593 MessageStreamEvent::ContentBlockDelta { delta, .. } => {
594 if let ContentBlockDelta::TextDelta { text } = delta {
595 if let Some(text_handlers) = handlers.get(&EventType::Text) {
596 for handler in text_handlers {
597 if let EventHandler::Text(callback) = handler {
598 // Get current accumulated text for snapshot
599 let snapshot = if let Some(ref msg) = *current_message {
600 self.get_accumulated_text(msg)
601 } else {
602 String::new()
603 };
604 callback(text, &snapshot);
605 }
606 }
607 }
608 }
609 }
610 MessageStreamEvent::MessageStop => {
611 if let Some(end_handlers) = handlers.get(&EventType::End) {
612 for handler in end_handlers {
613 if let EventHandler::End(callback) = handler {
614 callback();
615 }
616 }
617 }
618
619 // Send final message
620 if let Some(ref msg) = *current_message {
621 if let Some(final_handlers) = handlers.get(&EventType::FinalMessage) {
622 for handler in final_handlers {
623 if let EventHandler::FinalMessage(callback) = handler {
624 callback(msg);
625 }
626 }
627 }
628 }
629 }
630 _ => {}
631 }
632
633 Ok(())
634 }
635
636 /// Get the accumulated text from all text content blocks.
637 fn get_accumulated_text(&self, message: &Message) -> String {
638 message.content
639 .iter()
640 .filter_map(|block| match block {
641 ContentBlock::Text { text } => Some(text.as_str()),
642 _ => None,
643 })
644 .collect::<Vec<_>>()
645 .join("")
646 }
647}
648
649impl Stream for MessageStream {
650 type Item = Result<MessageStreamEvent>;
651
652 fn poll_next(
653 self: std::pin::Pin<&mut Self>,
654 cx: &mut std::task::Context<'_>,
655 ) -> std::task::Poll<Option<Self::Item>> {
656 use futures::Stream as FuturesStream;
657
658 let this = self.project();
659
660 match FuturesStream::poll_next(this.event_stream, cx) {
661 std::task::Poll::Ready(Some(Ok(event))) => {
662 std::task::Poll::Ready(Some(Ok(event)))
663 }
664 std::task::Poll::Ready(Some(Err(err))) => {
665 // Handle any broadcast stream errors
666 std::task::Poll::Ready(Some(Err(AnthropicError::StreamError(
667 format!("Stream error: {}", err)
668 ))))
669 }
670 std::task::Poll::Ready(None) => {
671 std::task::Poll::Ready(None)
672 }
673 std::task::Poll::Pending => std::task::Poll::Pending,
674 }
675 }
676}
677
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Role, Usage};

    /// Fetch a real HTTP response for tests that need a `reqwest::Response`.
    ///
    /// This performs network I/O against httpbin.org, so tests using it are
    /// `#[ignore]`d by default; run them with `cargo test -- --ignored`.
    async fn create_dummy_response() -> reqwest::Response {
        let client = reqwest::Client::new();
        client
            .get("https://httpbin.org/status/200")
            .send()
            .await
            .expect("Failed to create test response")
    }

    /// Build a `MessageStream` with the same initial state as `MessageStream::new`,
    /// but without an HTTP response, so tests can run offline.
    fn offline_stream(request_id: Option<String>) -> MessageStream {
        let (event_sender, event_receiver) = broadcast::channel(1000);
        let (completion_sender, completion_receiver) = oneshot::channel();
        MessageStream {
            current_message: Arc::new(Mutex::new(None)),
            event_handlers: Arc::new(Mutex::new(HashMap::new())),
            event_sender,
            event_stream: BroadcastStream::new(event_receiver),
            completion_sender: Some(completion_sender),
            completion_receiver,
            ended: Arc::new(Mutex::new(false)),
            errored: Arc::new(Mutex::new(false)),
            aborted: Arc::new(Mutex::new(false)),
            response: None,
            request_id,
        }
    }

    /// A message as delivered by `MessageStart`, with no content yet.
    fn empty_message() -> Message {
        Message {
            id: "msg_test".to_string(),
            type_: "message".to_string(),
            role: Role::Assistant,
            content: vec![],
            model: "claude-3-5-sonnet-latest".to_string(),
            stop_reason: None,
            stop_sequence: None,
            usage: Usage {
                input_tokens: 10,
                output_tokens: 0,
                cache_creation_input_tokens: None,
                cache_read_input_tokens: None,
                server_tool_use: None,
                service_tier: None,
            },
            request_id: None,
        }
    }

    /// A text delta for the content block at `index`.
    fn text_delta(index: usize, text: &str) -> MessageStreamEvent {
        MessageStreamEvent::ContentBlockDelta {
            index,
            delta: ContentBlockDelta::TextDelta { text: text.to_string() },
        }
    }

    #[tokio::test]
    #[ignore = "requires network access to httpbin.org"]
    async fn test_message_stream_creation() {
        let response = create_dummy_response().await;
        let stream = MessageStream::new(response, Some("test-request-id".to_string()));

        assert!(!stream.ended());
        assert!(!stream.errored());
        assert!(!stream.aborted());
        assert_eq!(stream.request_id(), Some("test-request-id"));
        assert!(stream.response().is_some());
    }

    #[test]
    fn test_offline_stream_initial_state() {
        let stream = offline_stream(Some("test-request-id".to_string()));

        assert!(!stream.ended());
        assert!(!stream.errored());
        assert!(!stream.aborted());
        assert!(stream.current_message().is_none());
        assert_eq!(stream.request_id(), Some("test-request-id"));
    }

    #[test]
    fn test_event_processing() {
        let stream = offline_stream(None);

        stream
            .process_event(MessageStreamEvent::MessageStart { message: empty_message() })
            .unwrap();

        let current = stream.current_message().unwrap();
        assert_eq!(current.id, "msg_test");
        assert_eq!(current.role, Role::Assistant);
    }

    #[test]
    fn test_text_deltas_accumulate_and_reach_handlers() {
        let seen: Arc<Mutex<Vec<(String, String)>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&seen);
        let stream = offline_stream(None).on_text(move |delta, snapshot| {
            sink.lock()
                .unwrap()
                .push((delta.to_string(), snapshot.to_string()));
        });

        stream
            .process_event(MessageStreamEvent::MessageStart { message: empty_message() })
            .unwrap();
        stream
            .process_event(MessageStreamEvent::ContentBlockStart {
                index: 0,
                content_block: ContentBlock::Text { text: String::new() },
            })
            .unwrap();
        stream.process_event(text_delta(0, "Hello")).unwrap();
        stream.process_event(text_delta(0, ", world")).unwrap();
        stream.process_event(MessageStreamEvent::MessageStop).unwrap();

        assert!(stream.ended());
        let current = stream.current_message().unwrap();
        match &current.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "Hello, world"),
            _ => panic!("expected a text block"),
        }
        assert_eq!(
            *seen.lock().unwrap(),
            vec![
                ("Hello".to_string(), "Hello".to_string()),
                (", world".to_string(), "Hello, world".to_string()),
            ]
        );
    }

    #[test]
    fn test_event_handlers() {
        // The handler must actually run when invoked.
        let text_called = Arc::new(Mutex::new(false));
        let text_called_clone = text_called.clone();

        let handler = EventHandler::Text(Box::new(move |_delta, _snapshot| {
            *text_called_clone.lock().unwrap() = true;
        }));
        if let EventHandler::Text(callback) = &handler {
            callback("delta", "snapshot");
        }
        assert!(*text_called.lock().unwrap());

        // Test event type equality
        assert_eq!(EventType::Text, EventType::Text);
        assert_ne!(EventType::Text, EventType::Error);

        // Test using event types as hash keys
        let mut map: HashMap<EventType, String> = HashMap::new();
        map.insert(EventType::Text, "text_handler".to_string());
        assert_eq!(map.get(&EventType::Text), Some(&"text_handler".to_string()));
    }
}