pub mod events;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use futures::Stream;
use pin_project::pin_project;
use tokio::sync::{broadcast, oneshot};
use tokio_stream::wrappers::BroadcastStream;
use crate::types::{
Message, MessageStreamEvent, ContentBlock, ContentBlockDelta,
AnthropicError, Result
};
use self::events::{EventHandler, EventType};
/// Streaming view of a single Messages API response.
///
/// Events are exposed through the [`Stream`] implementation, through callbacks
/// registered with the `on_*` builders, and through the accumulated snapshot
/// returned by [`MessageStream::current_message`].
#[pin_project]
pub struct MessageStream {
    /// Message assembled from the events seen so far (`None` before `MessageStart`).
    current_message: Arc<Mutex<Option<Message>>>,
    /// Callbacks registered through the `on_*` builders, grouped by event kind.
    event_handlers: Arc<Mutex<HashMap<EventType, Vec<EventHandler>>>>,
    /// Broadcast sender that `process_event` publishes events through.
    event_sender: broadcast::Sender<MessageStreamEvent>,
    /// Receiving side of the event channel, polled by the `Stream` impl.
    /// Pinned so `poll_next` can project to it.
    #[pin]
    event_stream: BroadcastStream<MessageStreamEvent>,
    /// Completion sender held by the struct itself; `None` for streams built by
    /// `from_http_stream`, whose background task owns the sender instead.
    completion_sender: Option<oneshot::Sender<Result<Message>>>,
    /// Resolved with the final message (or an error); awaited by
    /// `final_message` and `done`.
    completion_receiver: oneshot::Receiver<Result<Message>>,
    /// Set once a `MessageStop` event has been processed.
    ended: Arc<Mutex<bool>>,
    /// Set when the underlying stream reported an error.
    errored: Arc<Mutex<bool>>,
    /// Set by `abort`.
    aborted: Arc<Mutex<bool>>,
    /// HTTP response handed to `new`; `None` for `from_http_stream`.
    response: Option<reqwest::Response>,
    /// Request id reported for this response, if any.
    request_id: Option<String>,
}
impl MessageStream {
pub fn new(response: reqwest::Response, request_id: Option<String>) -> Self {
let (event_sender, event_receiver) = broadcast::channel(1000);
let (completion_sender, completion_receiver) = oneshot::channel();
Self {
current_message: Arc::new(Mutex::new(None)),
event_handlers: Arc::new(Mutex::new(HashMap::new())),
event_sender,
event_stream: BroadcastStream::new(event_receiver),
completion_sender: Some(completion_sender),
completion_receiver,
ended: Arc::new(Mutex::new(false)),
errored: Arc::new(Mutex::new(false)),
aborted: Arc::new(Mutex::new(false)),
response: Some(response),
request_id,
}
}
pub fn from_http_stream(mut http_stream: crate::http::streaming::HttpStreamClient) -> Result<Self> {
let (event_sender, event_receiver) = broadcast::channel(1000);
let (completion_sender, completion_receiver) = oneshot::channel();
let current_message = Arc::new(Mutex::new(None));
let ended = Arc::new(Mutex::new(false));
let errored = Arc::new(Mutex::new(false));
let request_id = http_stream.request_id().map(|s| s.to_string());
let current_message_clone = current_message.clone();
let ended_clone = ended.clone();
let errored_clone = errored.clone();
let event_sender_clone = event_sender.clone();
tokio::spawn(async move {
use futures::StreamExt;
let mut final_message: Option<crate::types::Message> = None;
while let Some(event_result) = http_stream.next().await {
match event_result {
Ok(event) => {
match &event {
crate::types::MessageStreamEvent::MessageStart { message } => {
*current_message_clone.lock().unwrap() = Some(message.clone());
final_message = Some(message.clone());
}
crate::types::MessageStreamEvent::ContentBlockStart { content_block, index } => {
if let Some(ref mut msg) = *current_message_clone.lock().unwrap() {
while msg.content.len() <= *index {
msg.content.push(crate::types::ContentBlock::Text { text: String::new() });
}
msg.content[*index] = content_block.clone();
}
if let Some(ref mut msg) = final_message.as_mut() {
while msg.content.len() <= *index {
msg.content.push(crate::types::ContentBlock::Text { text: String::new() });
}
msg.content[*index] = content_block.clone();
}
}
crate::types::MessageStreamEvent::ContentBlockDelta { delta, index } => {
if let Some(ref mut msg) = *current_message_clone.lock().unwrap() {
if let Some(content_block) = msg.content.get_mut(*index) {
if let (crate::types::ContentBlock::Text { text },
crate::types::ContentBlockDelta::TextDelta { text: delta_text }) =
(content_block, delta) {
text.push_str(delta_text);
}
}
}
if let Some(ref mut msg) = final_message.as_mut() {
if let Some(content_block) = msg.content.get_mut(*index) {
if let (crate::types::ContentBlock::Text { text },
crate::types::ContentBlockDelta::TextDelta { text: delta_text }) =
(content_block, delta) {
text.push_str(delta_text);
}
}
}
}
crate::types::MessageStreamEvent::MessageDelta { delta, usage } => {
if let Some(ref mut msg) = *current_message_clone.lock().unwrap() {
if let Some(stop_reason) = &delta.stop_reason {
msg.stop_reason = Some(stop_reason.clone());
}
if let Some(stop_sequence) = &delta.stop_sequence {
msg.stop_sequence = Some(stop_sequence.clone());
}
msg.usage.output_tokens = usage.output_tokens;
if let Some(input_tokens) = usage.input_tokens {
msg.usage.input_tokens = input_tokens;
}
if let Some(cache_creation) = usage.cache_creation_input_tokens {
msg.usage.cache_creation_input_tokens = Some(cache_creation);
}
if let Some(cache_read) = usage.cache_read_input_tokens {
msg.usage.cache_read_input_tokens = Some(cache_read);
}
}
if let Some(ref mut msg) = final_message.as_mut() {
if let Some(stop_reason) = &delta.stop_reason {
msg.stop_reason = Some(stop_reason.clone());
}
if let Some(stop_sequence) = &delta.stop_sequence {
msg.stop_sequence = Some(stop_sequence.clone());
}
msg.usage.output_tokens = usage.output_tokens;
if let Some(input_tokens) = usage.input_tokens {
msg.usage.input_tokens = input_tokens;
}
if let Some(cache_creation) = usage.cache_creation_input_tokens {
msg.usage.cache_creation_input_tokens = Some(cache_creation);
}
if let Some(cache_read) = usage.cache_read_input_tokens {
msg.usage.cache_read_input_tokens = Some(cache_read);
}
}
}
crate::types::MessageStreamEvent::MessageStop => {
*ended_clone.lock().unwrap() = true;
if let Some(message) = final_message.clone() {
let _ = completion_sender.send(Ok(message));
} else {
let _ = completion_sender.send(Err(crate::types::AnthropicError::StreamError(
"Stream ended without message".to_string()
)));
}
let _ = event_sender_clone.send(event);
break;
}
_ => {}
}
let _ = event_sender_clone.send(event);
}
Err(e) => {
*errored_clone.lock().unwrap() = true;
let _ = completion_sender.send(Err(e));
break;
}
}
}
});
Ok(Self {
current_message,
event_handlers: Arc::new(Mutex::new(HashMap::new())),
event_sender,
event_stream: BroadcastStream::new(event_receiver),
completion_sender: None, completion_receiver,
ended,
errored,
aborted: Arc::new(Mutex::new(false)),
response: None, request_id,
})
}
pub fn on_text<F>(self, callback: F) -> Self
where
F: Fn(&str, &str) + Send + Sync + 'static,
{
self.on(EventType::Text, EventHandler::Text(Box::new(callback)))
}
pub fn on_stream_event<F>(self, callback: F) -> Self
where
F: Fn(&MessageStreamEvent, &Message) + Send + Sync + 'static,
{
self.on(EventType::StreamEvent, EventHandler::StreamEvent(Box::new(callback)))
}
pub fn on_message<F>(self, callback: F) -> Self
where
F: Fn(&Message) + Send + Sync + 'static,
{
self.on(EventType::Message, EventHandler::Message(Box::new(callback)))
}
pub fn on_final_message<F>(self, callback: F) -> Self
where
F: Fn(&Message) + Send + Sync + 'static,
{
self.on(EventType::FinalMessage, EventHandler::FinalMessage(Box::new(callback)))
}
pub fn on_error<F>(self, callback: F) -> Self
where
F: Fn(&AnthropicError) + Send + Sync + 'static,
{
self.on(EventType::Error, EventHandler::Error(Box::new(callback)))
}
pub fn on_end<F>(self, callback: F) -> Self
where
F: Fn() + Send + Sync + 'static,
{
self.on(EventType::End, EventHandler::End(Box::new(callback)))
}
fn on(self, event_type: EventType, handler: EventHandler) -> Self {
{
let mut handlers = self.event_handlers.lock().unwrap();
handlers.entry(event_type).or_insert_with(Vec::new).push(handler);
}
self
}
pub async fn final_message(self) -> Result<Message> {
self.completion_receiver.await
.map_err(|_| AnthropicError::StreamError("Stream ended unexpectedly".to_string()))?
}
pub async fn done(self) -> Result<()> {
self.completion_receiver.await
.map_err(|_| AnthropicError::StreamError("Stream ended unexpectedly".to_string()))?
.map(|_| ())
}
pub fn current_message(&self) -> Option<Message> {
self.current_message.lock().unwrap().clone()
}
pub fn ended(&self) -> bool {
*self.ended.lock().unwrap()
}
pub fn errored(&self) -> bool {
*self.errored.lock().unwrap()
}
pub fn aborted(&self) -> bool {
*self.aborted.lock().unwrap()
}
pub fn response(&self) -> Option<&reqwest::Response> {
self.response.as_ref()
}
pub fn request_id(&self) -> Option<&str> {
self.request_id.as_deref()
}
pub fn abort(&self) {
*self.aborted.lock().unwrap() = true;
}
#[allow(dead_code)]
fn process_event(&self, event: MessageStreamEvent) -> Result<()> {
match &event {
MessageStreamEvent::MessageStart { message } => {
*self.current_message.lock().unwrap() = Some(message.clone());
}
MessageStreamEvent::ContentBlockStart { content_block, index } => {
if let Some(ref mut msg) = *self.current_message.lock().unwrap() {
while msg.content.len() <= *index {
msg.content.push(ContentBlock::Text { text: String::new() });
}
msg.content[*index] = content_block.clone();
}
}
MessageStreamEvent::ContentBlockDelta { delta, index } => {
if let Some(ref mut msg) = *self.current_message.lock().unwrap() {
if let Some(content_block) = msg.content.get_mut(*index) {
self.apply_delta(content_block, delta)?;
}
}
}
MessageStreamEvent::MessageDelta { delta, usage } => {
if let Some(ref mut msg) = *self.current_message.lock().unwrap() {
if let Some(stop_reason) = &delta.stop_reason {
msg.stop_reason = Some(stop_reason.clone());
}
if let Some(stop_sequence) = &delta.stop_sequence {
msg.stop_sequence = Some(stop_sequence.clone());
}
msg.usage.output_tokens = usage.output_tokens;
if let Some(input_tokens) = usage.input_tokens {
msg.usage.input_tokens = input_tokens;
}
if let Some(cache_creation) = usage.cache_creation_input_tokens {
msg.usage.cache_creation_input_tokens = Some(cache_creation);
}
if let Some(cache_read) = usage.cache_read_input_tokens {
msg.usage.cache_read_input_tokens = Some(cache_read);
}
}
}
MessageStreamEvent::MessageStop => {
*self.ended.lock().unwrap() = true;
}
_ => {}
}
self.dispatch_event(&event)?;
let _ = self.event_sender.send(event);
Ok(())
}
#[allow(dead_code)]
fn apply_delta(&self, content_block: &mut ContentBlock, delta: &ContentBlockDelta) -> Result<()> {
match (content_block, delta) {
(ContentBlock::Text { text }, ContentBlockDelta::TextDelta { text: delta_text }) => {
text.push_str(delta_text);
}
(ContentBlock::ToolUse { input, .. }, ContentBlockDelta::InputJsonDelta { partial_json }) => {
*input = serde_json::from_str(partial_json)
.unwrap_or_else(|_| serde_json::Value::String(partial_json.clone()));
}
_ => {
}
}
Ok(())
}
fn dispatch_event(&self, event: &MessageStreamEvent) -> Result<()> {
let handlers = self.event_handlers.lock().unwrap();
let current_message = self.current_message.lock().unwrap();
if let Some(stream_handlers) = handlers.get(&EventType::StreamEvent) {
for handler in stream_handlers {
if let EventHandler::StreamEvent(callback) = handler {
if let Some(ref msg) = *current_message {
callback(event, msg);
}
}
}
}
match event {
MessageStreamEvent::ContentBlockDelta { delta, .. } => {
if let ContentBlockDelta::TextDelta { text } = delta {
if let Some(text_handlers) = handlers.get(&EventType::Text) {
for handler in text_handlers {
if let EventHandler::Text(callback) = handler {
let snapshot = if let Some(ref msg) = *current_message {
self.get_accumulated_text(msg)
} else {
String::new()
};
callback(text, &snapshot);
}
}
}
}
}
MessageStreamEvent::MessageStop => {
if let Some(end_handlers) = handlers.get(&EventType::End) {
for handler in end_handlers {
if let EventHandler::End(callback) = handler {
callback();
}
}
}
if let Some(ref msg) = *current_message {
if let Some(final_handlers) = handlers.get(&EventType::FinalMessage) {
for handler in final_handlers {
if let EventHandler::FinalMessage(callback) = handler {
callback(msg);
}
}
}
}
}
_ => {}
}
Ok(())
}
fn get_accumulated_text(&self, message: &Message) -> String {
message.content
.iter()
.filter_map(|block| match block {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("")
}
}
impl Stream for MessageStream {
type Item = Result<MessageStreamEvent>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
use futures::Stream as FuturesStream;
let this = self.project();
match FuturesStream::poll_next(this.event_stream, cx) {
std::task::Poll::Ready(Some(Ok(event))) => {
std::task::Poll::Ready(Some(Ok(event)))
}
std::task::Poll::Ready(Some(Err(err))) => {
std::task::Poll::Ready(Some(Err(AnthropicError::StreamError(
format!("Stream error: {}", err)
))))
}
std::task::Poll::Ready(None) => {
std::task::Poll::Ready(None)
}
std::task::Poll::Pending => std::task::Poll::Pending,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Role, Usage};

    /// Produces a real `reqwest::Response` for `MessageStream::new`.
    ///
    /// NOTE: performs a live request to httpbin.org, so tests calling this
    /// require network access.
    async fn create_dummy_response() -> reqwest::Response {
        reqwest::Client::new()
            .get("https://httpbin.org/status/200")
            .send()
            .await
            .expect("Failed to create test response")
    }

    /// Assistant message used as the `MessageStart` payload.
    fn sample_message() -> Message {
        let usage = Usage {
            input_tokens: 10,
            output_tokens: 0,
            cache_creation_input_tokens: None,
            cache_read_input_tokens: None,
            server_tool_use: None,
            service_tier: None,
        };
        Message {
            id: "msg_test".to_string(),
            type_: "message".to_string(),
            role: Role::Assistant,
            content: vec![],
            model: "claude-3-5-sonnet-latest".to_string(),
            stop_reason: None,
            stop_sequence: None,
            usage,
            request_id: None,
        }
    }

    #[tokio::test]
    async fn test_message_stream_creation() {
        let stream =
            MessageStream::new(create_dummy_response().await, Some("test-request-id".to_string()));
        assert!(!stream.ended());
        assert!(!stream.errored());
        assert!(!stream.aborted());
        assert_eq!(stream.request_id(), Some("test-request-id"));
    }

    #[tokio::test]
    async fn test_event_processing() {
        let stream = MessageStream::new(create_dummy_response().await, None);
        stream
            .process_event(MessageStreamEvent::MessageStart { message: sample_message() })
            .unwrap();
        let snapshot = stream.current_message().unwrap();
        assert_eq!(snapshot.id, "msg_test");
        assert_eq!(snapshot.role, Role::Assistant);
    }

    #[test]
    fn test_event_handlers() {
        use std::collections::HashMap;
        use std::sync::{Arc, Mutex};

        let flag = Arc::new(Mutex::new(false));
        let flag_for_handler = flag.clone();
        let _handler = EventHandler::Text(Box::new(move |_delta, _snapshot| {
            *flag_for_handler.lock().unwrap() = true;
        }));

        assert_eq!(EventType::Text, EventType::Text);
        assert_ne!(EventType::Text, EventType::Error);

        let mut by_type: HashMap<EventType, String> = HashMap::new();
        by_type.insert(EventType::Text, "text_handler".to_string());
        assert_eq!(by_type.get(&EventType::Text), Some(&"text_handler".to_string()));
    }
}