// turul-a2a-client 0.1.4
//
// A2A Protocol v1.0 client library
// Documentation
//! SSE (Server-Sent Events) stream parser for A2A streaming responses.
//!
//! Parses the raw byte stream from `reqwest::Response::bytes_stream()` into
//! `SseEvent` values carrying the event `id` and its JSON `data`, and
//! optionally into typed `StreamEvent` values via `TypedSseStream`.

use futures::stream::{Stream, StreamExt};
use std::pin::Pin;
use std::task::{Context, Poll};

use turul_a2a_types::{Message, Task};

use crate::A2aClientError;

/// One event decoded from a `text/event-stream` response body.
///
/// Yielded by `SseStream`; `data` is the event's JSON payload.
#[derive(Debug, Clone)]
pub struct SseEvent {
    /// The event's `id:` field when present. Format: `{task_id}:{sequence}`.
    pub id: Option<String>,
    /// JSON value decoded from the event's `data:` field.
    pub data: serde_json::Value,
}

/// Asynchronous stream of raw [`SseEvent`]s read from an HTTP response.
///
/// Built from a `reqwest::Response`; the body's chunks are buffered and
/// split into SSE event blocks as they arrive.
pub struct SseStream {
    /// Pinned, boxed event source; boxing keeps the concrete stream type private.
    inner: Pin<Box<dyn Stream<Item = Result<SseEvent, A2aClientError>> + Send>>,
}

impl SseStream {
    /// Create an SSE stream from a reqwest response.
    ///
    /// The response must have `Content-Type: text/event-stream`.
    pub(crate) fn from_response(response: reqwest::Response) -> Self {
        let byte_stream = response.bytes_stream();

        let event_stream = futures::stream::unfold(
            (byte_stream, String::new()),
            |(mut stream, mut buffer)| async move {
                loop {
                    // Check if we have a complete event in the buffer
                    if let Some(pos) = buffer.find("\n\n") {
                        let event_text = buffer[..pos].to_string();
                        buffer = buffer[pos + 2..].to_string();

                        if let Some(event) = parse_sse_event(&event_text) {
                            return Some((Ok(event), (stream, buffer)));
                        }
                        // Empty event (keepalive comment) — continue
                        continue;
                    }

                    // Need more data from the stream
                    match stream.next().await {
                        Some(Ok(chunk)) => {
                            buffer.push_str(&String::from_utf8_lossy(&chunk));
                        }
                        Some(Err(e)) => {
                            return Some((Err(A2aClientError::Request(e)), (stream, buffer)));
                        }
                        None => {
                            // Stream ended — check if there's a final event in the buffer
                            let remaining = buffer.trim().to_string();
                            buffer.clear();
                            if !remaining.is_empty() {
                                if let Some(event) = parse_sse_event(&remaining) {
                                    return Some((Ok(event), (stream, buffer)));
                                }
                            }
                            return None; // Stream complete
                        }
                    }
                }
            },
        );

        Self {
            inner: Box::pin(event_stream),
        }
    }
}

impl Stream for SseStream {
    type Item = Result<SseEvent, A2aClientError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.inner.as_mut().poll_next(cx)
    }
}

/// Parse a single SSE event block into an `SseEvent`.
///
/// `text` is one event block: the lines between two blank-line separators.
/// Field handling follows the SSE event-stream rules:
/// - lines starting with `:` are comments (e.g. keepalives) and are ignored;
/// - each line is split at its first `:` into field name and value, and one
///   leading space is removed from the value (a line with no `:` is a field
///   with an empty value);
/// - multiple `data:` lines are joined with `\n` before JSON decoding;
/// - lines may be terminated by `\n`, `\r\n` or `\r`.
///
/// Returns `None` when the block has no `data:` field, or when the joined data
/// is not valid JSON; such blocks are skipped by the caller.
///
/// SSE format:
/// ```text
/// id: task-123:1
/// data: {"statusUpdate":{"taskId":"task-123",...}}
/// ```
fn parse_sse_event(text: &str) -> Option<SseEvent> {
    let mut id = None;
    let mut data_lines: Vec<&str> = Vec::new();

    // Splitting on either character turns `\r\n` into an extra empty line,
    // which is skipped below, and also handles lone `\r` terminators.
    for line in text.split(|c: char| c == '\r' || c == '\n') {
        // Leading whitespace before the field name is tolerated.
        let line = line.trim_start();
        if line.is_empty() || line.starts_with(':') {
            continue; // blank line or SSE comment (keepalive)
        }
        let (field, value) = match line.split_once(':') {
            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
            None => (line, ""),
        };
        match field {
            "id" => id = Some(value.trim().to_string()),
            "data" => data_lines.push(value),
            // `event:`, `retry:` and unknown fields are not used by this parser.
            _ => {}
        }
    }

    if data_lines.is_empty() {
        return None;
    }
    // A payload spread over several `data:` lines is one value joined by '\n'.
    let data: serde_json::Value = serde_json::from_str(&data_lines.join("\n")).ok()?;
    Some(SseEvent { id, data })
}

// =========================================================
// Typed streaming events
// =========================================================

/// Typed payload of one server streaming event.
///
/// Each variant corresponds to one member of the proto `StreamResponse`
/// oneof, so callers can `match` on it instead of probing raw JSON.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum StreamEvent {
    /// A complete task snapshot (the first event of a subscribe stream).
    Task(Task),
    /// A message sent directly by the agent.
    Message(Message),
    /// The status of a task changed.
    StatusUpdate {
        /// ID of the task whose status changed.
        task_id: String,
        /// ID of the context the task belongs to.
        context_id: String,
        /// The raw JSON `status` object.
        status: serde_json::Value,
    },
    /// An artifact was produced or updated.
    ArtifactUpdate {
        /// ID of the task that owns the artifact.
        task_id: String,
        /// ID of the context the task belongs to.
        context_id: String,
        /// The raw JSON `artifact` object.
        artifact: serde_json::Value,
        /// Value of the event's `append` flag.
        append: bool,
        /// Value of the event's `lastChunk` flag.
        last_chunk: bool,
    },
}

/// An SSE event whose JSON `data` has been decoded into a [`StreamEvent`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TypedSseEvent {
    /// The event's `id:` field when present. Format: `{task_id}:{sequence}`.
    pub id: Option<String>,
    /// The decoded, typed event payload.
    pub event: StreamEvent,
}

/// Convert a raw [`SseEvent`] into a [`TypedSseEvent`].
///
/// The variant is selected by the first known top-level key present in
/// `raw.data`, tested in this order:
/// `task` → `Task`, `message` → `Message`, `statusUpdate` → `StatusUpdate`,
/// `artifactUpdate` → `ArtifactUpdate`.
///
/// For the two update variants:
/// - absent or non-string ID fields become `""`;
/// - absent or non-boolean flags become `false`;
/// - absent `status` / `artifact` payloads become JSON `null`.
///
/// # Errors
///
/// Returns `A2aClientError::Conversion` in two cases:
/// - a `task` or `message` payload cannot be deserialized into its proto
///   type, or converted from it;
/// - `raw.data` contains none of the known keys.
fn parse_stream_event(raw: &SseEvent) -> Result<TypedSseEvent, A2aClientError> {
    let data = &raw.data;

    // String member of a JSON object, or "" when missing / not a string.
    let string_at = |obj: &serde_json::Value, key: &str| -> String {
        obj.get(key)
            .and_then(serde_json::Value::as_str)
            .unwrap_or("")
            .to_string()
    };
    // Boolean member of a JSON object, or `false` when missing / not a bool.
    let flag_at = |obj: &serde_json::Value, key: &str| -> bool {
        obj.get(key)
            .and_then(serde_json::Value::as_bool)
            .unwrap_or(false)
    };

    let event = match (
        data.get("task"),
        data.get("message"),
        data.get("statusUpdate"),
        data.get("artifactUpdate"),
    ) {
        (Some(task_json), _, _, _) => {
            let proto: turul_a2a_proto::Task = serde_json::from_value(task_json.clone())
                .map_err(|e| A2aClientError::Conversion(format!("Invalid Task: {e}")))?;
            let task =
                Task::try_from(proto).map_err(|e| A2aClientError::Conversion(e.to_string()))?;
            StreamEvent::Task(task)
        }
        (None, Some(msg_json), _, _) => {
            let proto: turul_a2a_proto::Message = serde_json::from_value(msg_json.clone())
                .map_err(|e| A2aClientError::Conversion(format!("Invalid Message: {e}")))?;
            let message =
                Message::try_from(proto).map_err(|e| A2aClientError::Conversion(e.to_string()))?;
            StreamEvent::Message(message)
        }
        (None, None, Some(status_json), _) => StreamEvent::StatusUpdate {
            task_id: string_at(status_json, "taskId"),
            context_id: string_at(status_json, "contextId"),
            status: status_json.get("status").cloned().unwrap_or_default(),
        },
        (None, None, None, Some(artifact_json)) => StreamEvent::ArtifactUpdate {
            task_id: string_at(artifact_json, "taskId"),
            context_id: string_at(artifact_json, "contextId"),
            artifact: artifact_json.get("artifact").cloned().unwrap_or_default(),
            append: flag_at(artifact_json, "append"),
            last_chunk: flag_at(artifact_json, "lastChunk"),
        },
        (None, None, None, None) => {
            return Err(A2aClientError::Conversion(format!(
                "Unknown stream event shape: {data}"
            )));
        }
    };

    Ok(TypedSseEvent {
        id: raw.id.clone(),
        event,
    })
}

/// Asynchronous stream of [`TypedSseEvent`]s.
///
/// Sits on top of an [`SseStream`] and decodes each raw JSON event into a
/// [`StreamEvent`] variant.
pub struct TypedSseStream {
    /// Pinned, boxed typed-event source; boxing hides the adapter type.
    inner: Pin<Box<dyn Stream<Item = Result<TypedSseEvent, A2aClientError>> + Send>>,
}

impl TypedSseStream {
    /// Wrap a raw [`SseStream`], converting each event with `parse_stream_event`.
    ///
    /// Errors from the raw stream pass through unchanged. A raw event that
    /// cannot be typed is yielded as an `Err` item, and the stream continues
    /// with the next event.
    pub(crate) fn from_raw(raw: SseStream) -> Self {
        let converted = raw.map(|item| item.and_then(|raw_event| parse_stream_event(&raw_event)));
        Self {
            inner: Box::pin(converted),
        }
    }
}

impl Stream for TypedSseStream {
    type Item = Result<TypedSseEvent, A2aClientError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.inner.as_mut().poll_next(cx)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_status_update_event() {
        let block = concat!(
            "id: task-1:1\n",
            r#"data: {"statusUpdate":{"taskId":"task-1"}}"#
        );
        let parsed = parse_sse_event(block).unwrap();
        assert_eq!(parsed.id, Some(String::from("task-1:1")));
        assert!(parsed.data.get("statusUpdate").is_some());
    }

    #[test]
    fn parse_event_without_id() {
        let block = r#"data: {"task":{"id":"t-1"}}"#;
        let parsed = parse_sse_event(block).unwrap();
        assert_eq!(parsed.id, None);
        assert!(parsed.data.get("task").is_some());
    }

    #[test]
    fn parse_comment_only_returns_none() {
        let keepalive = ": keepalive";
        let parsed = parse_sse_event(keepalive);
        assert!(parsed.is_none());
    }

    #[test]
    fn parse_empty_returns_none() {
        let parsed = parse_sse_event("");
        assert!(parsed.is_none());
    }
}