use futures::stream::{Stream, StreamExt};
use std::pin::Pin;
use std::task::{Context, Poll};
use turul_a2a_types::{Message, Task};
use crate::A2aClientError;
/// A single server-sent event after low-level parsing.
///
/// Produced by `parse_sse_event`; only events that carried a JSON `data:` payload
/// are represented.
#[derive(Debug, Clone)]
pub struct SseEvent {
/// Value of the event's `id:` field, or `None` when the event had no such field.
pub id: Option<String>,
/// JSON value decoded from the event's `data:` field.
pub data: serde_json::Value,
}
/// Stream of raw [`SseEvent`]s (or [`A2aClientError`]s) read from an HTTP response body.
pub struct SseStream {
// Boxed and pinned so the concrete stream type built in `from_response` stays hidden.
inner: Pin<Box<dyn Stream<Item = Result<SseEvent, A2aClientError>> + Send>>,
}
impl SseStream {
pub(crate) fn from_response(response: reqwest::Response) -> Self {
let byte_stream = response.bytes_stream();
let event_stream = futures::stream::unfold(
(byte_stream, String::new()),
|(mut stream, mut buffer)| async move {
loop {
if let Some(pos) = buffer.find("\n\n") {
let event_text = buffer[..pos].to_string();
buffer = buffer[pos + 2..].to_string();
if let Some(event) = parse_sse_event(&event_text) {
return Some((Ok(event), (stream, buffer)));
}
continue;
}
match stream.next().await {
Some(Ok(chunk)) => {
buffer.push_str(&String::from_utf8_lossy(&chunk));
}
Some(Err(e)) => {
return Some((Err(A2aClientError::Request(e)), (stream, buffer)));
}
None => {
let remaining = buffer.trim().to_string();
buffer.clear();
if !remaining.is_empty() {
if let Some(event) = parse_sse_event(&remaining) {
return Some((Ok(event), (stream, buffer)));
}
}
return None; }
}
}
},
);
Self {
inner: Box::pin(event_stream),
}
}
}
impl Stream for SseStream {
    type Item = Result<SseEvent, A2aClientError>;

    /// Forwards the poll to the boxed inner event stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `SseStream` only holds a `Pin<Box<..>>`, so it is `Unpin` and can be
        // accessed mutably without any projection.
        let this = self.get_mut();
        this.inner.poll_next_unpin(cx)
    }
}
/// Parses the text of one server-sent event into an [`SseEvent`].
///
/// Recognised lines (after trimming surrounding whitespace):
/// * lines starting with `:` are comments and ignored;
/// * `id:` sets the event id (the last one wins);
/// * `data:` contributes to the JSON payload.
///
/// When an event has several `data:` lines they are joined with `\n` and parsed as a
/// single JSON document, so a payload split across lines is decoded correctly. If
/// the joined text is not valid JSON, the last `data:` line that is valid JSON on
/// its own is used instead.
///
/// Returns `None` when the event carries no `data:` line that yields valid JSON
/// (for example comment-only keep-alives or empty input).
fn parse_sse_event(text: &str) -> Option<SseEvent> {
    let mut id = None;
    let mut data_lines: Vec<&str> = Vec::new();

    for line in text.lines().map(str::trim) {
        if line.starts_with(':') {
            continue;
        }
        if let Some(value) = line.strip_prefix("id:") {
            id = Some(value.trim().to_string());
        } else if let Some(value) = line.strip_prefix("data:") {
            data_lines.push(value.trim());
        }
    }

    let data = match data_lines.as_slice() {
        [] => None,
        [single] => serde_json::from_str::<serde_json::Value>(single).ok(),
        many => serde_json::from_str::<serde_json::Value>(&many.join("\n"))
            .ok()
            // Fallback keeps the lenient "last valid line wins" rule for servers
            // that send a complete JSON value on each `data:` line.
            .or_else(|| {
                many.iter()
                    .rev()
                    .find_map(|line| serde_json::from_str::<serde_json::Value>(line).ok())
            }),
    }?;

    Some(SseEvent { id, data })
}
/// A decoded streaming payload, classified by the top-level key found in the event's JSON.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum StreamEvent {
/// Built from the value under the `task` key.
Task(Task),
/// Built from the value under the `message` key.
Message(Message),
/// Built from the object under the `statusUpdate` key.
StatusUpdate {
/// The `taskId` string; empty when the field is missing or not a string.
task_id: String,
/// The `contextId` string; empty when the field is missing or not a string.
context_id: String,
/// The raw `status` JSON value; the `serde_json::Value` default when missing.
status: serde_json::Value,
},
/// Built from the object under the `artifactUpdate` key.
ArtifactUpdate {
/// The `taskId` string; empty when the field is missing or not a string.
task_id: String,
/// The `contextId` string; empty when the field is missing or not a string.
context_id: String,
/// The raw `artifact` JSON value; the `serde_json::Value` default when missing.
artifact: serde_json::Value,
/// The `append` flag; `false` when the field is missing or not a boolean.
append: bool,
/// The `lastChunk` flag; `false` when the field is missing or not a boolean.
last_chunk: bool,
},
}
/// An [`SseEvent`] whose JSON payload has been decoded into a [`StreamEvent`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TypedSseEvent {
/// Event id copied from the raw event, if it had one.
pub id: Option<String>,
/// The decoded payload.
pub event: StreamEvent,
}
fn parse_stream_event(raw: &SseEvent) -> Result<TypedSseEvent, A2aClientError> {
let data = &raw.data;
let event = if let Some(task_json) = data.get("task") {
let proto: turul_a2a_proto::Task = serde_json::from_value(task_json.clone())
.map_err(|e| A2aClientError::Conversion(format!("Invalid Task: {e}")))?;
let task = Task::try_from(proto).map_err(|e| A2aClientError::Conversion(e.to_string()))?;
StreamEvent::Task(task)
} else if let Some(msg_json) = data.get("message") {
let proto: turul_a2a_proto::Message = serde_json::from_value(msg_json.clone())
.map_err(|e| A2aClientError::Conversion(format!("Invalid Message: {e}")))?;
let msg =
Message::try_from(proto).map_err(|e| A2aClientError::Conversion(e.to_string()))?;
StreamEvent::Message(msg)
} else if let Some(su) = data.get("statusUpdate") {
StreamEvent::StatusUpdate {
task_id: su
.get("taskId")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
context_id: su
.get("contextId")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
status: su.get("status").cloned().unwrap_or_default(),
}
} else if let Some(au) = data.get("artifactUpdate") {
StreamEvent::ArtifactUpdate {
task_id: au
.get("taskId")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
context_id: au
.get("contextId")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
artifact: au.get("artifact").cloned().unwrap_or_default(),
append: au.get("append").and_then(|v| v.as_bool()).unwrap_or(false),
last_chunk: au
.get("lastChunk")
.and_then(|v| v.as_bool())
.unwrap_or(false),
}
} else {
return Err(A2aClientError::Conversion(format!(
"Unknown stream event shape: {data}"
)));
};
Ok(TypedSseEvent {
id: raw.id.clone(),
event,
})
}
/// Stream of [`TypedSseEvent`]s (or [`A2aClientError`]s) derived from an [`SseStream`].
pub struct TypedSseStream {
// Boxed and pinned so the concrete stream type built in `from_raw` stays hidden.
inner: Pin<Box<dyn Stream<Item = Result<TypedSseEvent, A2aClientError>> + Send>>,
}
impl TypedSseStream {
    /// Wraps a raw [`SseStream`] so that each item is decoded into a [`TypedSseEvent`].
    ///
    /// Successfully received events go through `parse_stream_event` (whose failure
    /// becomes the item's error); errors from the raw stream are passed through
    /// unchanged. The stream ends when the raw stream ends.
    pub(crate) fn from_raw(raw: SseStream) -> Self {
        let typed = raw.map(|item| item.and_then(|raw_event| parse_stream_event(&raw_event)));
        Self {
            inner: Box::pin(typed),
        }
    }
}
impl Stream for TypedSseStream {
    type Item = Result<TypedSseEvent, A2aClientError>;

    /// Forwards the poll to the boxed inner typed stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `TypedSseStream` only holds a `Pin<Box<..>>`, so it is `Unpin` and can be
        // accessed mutably without any projection.
        let this = self.get_mut();
        this.inner.poll_next_unpin(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An event with both `id:` and `data:` lines exposes the id and the JSON payload.
    #[test]
    fn parse_status_update_event() {
        let raw = "id: task-1:1\ndata: {\"statusUpdate\":{\"taskId\":\"task-1\"}}";
        let SseEvent { id, data } = parse_sse_event(raw).unwrap();
        assert_eq!(id.as_deref(), Some("task-1:1"));
        let status_update = data.get("statusUpdate");
        assert!(status_update.is_some());
    }

    /// A `data:`-only event parses with no id.
    #[test]
    fn parse_event_without_id() {
        let raw = "data: {\"task\":{\"id\":\"t-1\"}}";
        let SseEvent { id, data } = parse_sse_event(raw).unwrap();
        assert!(id.is_none());
        let task = data.get("task");
        assert!(task.is_some());
    }

    /// A comment-only (keep-alive) event produces nothing.
    #[test]
    fn parse_comment_only_returns_none() {
        let parsed = parse_sse_event(": keepalive");
        assert!(parsed.is_none());
    }

    /// Empty input produces nothing.
    #[test]
    fn parse_empty_returns_none() {
        let parsed = parse_sse_event("");
        assert!(parsed.is_none());
    }
}