use futures::{SinkExt, StreamExt};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::net::TcpStream;
use tokio_tungstenite::{
connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream,
WebSocketStream as TungsteniteStream,
};
use crate::error::{StreamError, StreamResult};
use crate::events::AgentStreamEvent;
use crate::partial_response::ResponseDelta;
/// Connection settings for a WebSocket client.
///
/// Start from [`WebSocketConfig::new`] and refine the value with the chainable
/// `with_*` / `without_*` methods; each one consumes the config and returns
/// the updated copy.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Endpoint URL to connect to.
    pub url: String,
    /// Extra `(name, value)` header pairs, kept in insertion order.
    /// Entries with the same name are not merged or replaced.
    pub headers: Vec<(String, String)>,
    /// Ping interval in seconds; `None` means pinging is disabled.
    pub ping_interval: Option<u64>,
    /// Timeout in seconds.
    pub timeout: u64,
}

impl WebSocketConfig {
    /// Creates a config for `url` with no headers, a 30-second ping interval
    /// and a 30-second timeout.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self {
            url,
            headers: Vec::new(),
            ping_interval: Some(30),
            timeout: 30,
        }
    }

    /// Appends one header pair to the end of `headers`.
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let entry = (key.into(), value.into());
        self.headers.push(entry);
        self
    }

    /// Appends an `Authorization: Bearer <token>` header.
    pub fn with_auth(self, token: impl Into<String>) -> Self {
        let token: String = token.into();
        self.with_header("Authorization", format!("Bearer {}", token))
    }

    /// Sets the ping interval to `seconds`.
    pub fn with_ping_interval(self, seconds: u64) -> Self {
        Self {
            ping_interval: Some(seconds),
            ..self
        }
    }

    /// Disables pinging by clearing the ping interval.
    pub fn without_ping(self) -> Self {
        Self {
            ping_interval: None,
            ..self
        }
    }

    /// Sets the timeout to `seconds`.
    pub fn with_timeout(self, seconds: u64) -> Self {
        Self {
            timeout: seconds,
            ..self
        }
    }
}
/// Simplified view of a WebSocket message.
///
/// Text and binary messages keep their payload; control messages are reduced
/// to bare markers. Derives `PartialEq`/`Eq` so values can be compared
/// directly (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsStreamMessage {
    /// A text message and its contents.
    Text(String),
    /// A binary message and its raw bytes.
    Binary(Vec<u8>),
    /// A ping control message (payload not retained).
    Ping,
    /// A pong control message (payload not retained).
    Pong,
    /// A close control message (close details not retained).
    Close,
}
/// Converts a tungstenite message into the simplified [`WsStreamMessage`].
///
/// Text and binary payloads are copied into owned `String` / `Vec<u8>`
/// values. Ping, pong and close payloads are discarded, and a raw `Frame` is
/// mapped to an empty `Binary` message.
impl From<WsMessage> for WsStreamMessage {
    fn from(msg: WsMessage) -> Self {
        match msg {
            WsMessage::Text(payload) => Self::Text(payload.to_string()),
            WsMessage::Binary(payload) => Self::Binary(payload.to_vec()),
            WsMessage::Ping(_) => Self::Ping,
            WsMessage::Pong(_) => Self::Pong,
            WsMessage::Close(_) => Self::Close,
            // The frame contents are intentionally not carried over.
            WsMessage::Frame(_) => Self::Binary(Vec::new()),
        }
    }
}
pub struct WebSocketStream {
inner: TungsteniteStream<MaybeTlsStream<TcpStream>>,
config: WebSocketConfig,
}
impl WebSocketStream {
pub async fn connect(config: WebSocketConfig) -> StreamResult<Self> {
let (ws_stream, _) = connect_async(&config.url)
.await
.map_err(|e| StreamError::Connection(e.to_string()))?;
Ok(Self {
inner: ws_stream,
config,
})
}
pub async fn send_text(&mut self, text: impl Into<String>) -> StreamResult<()> {
self.inner
.send(WsMessage::Text(text.into()))
.await
.map_err(|e| StreamError::Send(e.to_string()))
}
pub async fn send_json<T: serde::Serialize>(&mut self, value: &T) -> StreamResult<()> {
let json =
serde_json::to_string(value).map_err(|e| StreamError::Serialization(e.to_string()))?;
self.send_text(json).await
}
pub async fn close(&mut self) -> StreamResult<()> {
self.inner
.close(None)
.await
.map_err(|e| StreamError::Connection(e.to_string()))
}
pub async fn next_message(&mut self) -> Option<StreamResult<WsStreamMessage>> {
match self.inner.next().await {
Some(Ok(msg)) => Some(Ok(msg.into())),
Some(Err(e)) => Some(Err(StreamError::Receive(e.to_string()))),
None => None,
}
}
pub async fn next_delta(&mut self) -> Option<StreamResult<ResponseDelta>> {
loop {
match self.next_message().await? {
Ok(WsStreamMessage::Text(text)) => {
match serde_json::from_str::<ResponseDelta>(&text) {
Ok(delta) => return Some(Ok(delta)),
Err(e) => {
tracing::warn!("Failed to parse WebSocket message as delta: {}", e);
continue;
}
}
}
Ok(WsStreamMessage::Close) => return None,
Ok(WsStreamMessage::Ping) | Ok(WsStreamMessage::Pong) => continue,
Ok(_) => continue,
Err(e) => return Some(Err(e)),
}
}
}
pub fn config(&self) -> &WebSocketConfig {
&self.config
}
}
impl futures::Stream for WebSocketStream {
type Item = StreamResult<WsStreamMessage>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match Pin::new(&mut self.inner).poll_next(cx) {
Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(Ok(msg.into()))),
Poll::Ready(Some(Err(e))) => {
Poll::Ready(Some(Err(StreamError::Receive(e.to_string()))))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
/// Adapts a [`WebSocketStream`] of response deltas into agent stream events
/// for a single run.
pub struct WebSocketAgentStream {
    /// Connection the deltas are read from.
    ws: WebSocketStream,
    /// Identifier reported in the `RunComplete` event.
    run_id: String,
}

impl WebSocketAgentStream {
    /// Wraps `ws`, tagging the completion event with `run_id`.
    pub fn new(ws: WebSocketStream, run_id: impl Into<String>) -> Self {
        let run_id = run_id.into();
        Self { ws, run_id }
    }

    /// Waits for the next delta and translates it into an [`AgentStreamEvent`].
    ///
    /// Mapping:
    /// - `Text` -> `TextDelta`
    /// - `ToolCall` with a name -> `ToolCallStart`
    /// - `ToolCall` without a name but with args -> `ToolCallDelta`
    /// - `ToolCall` with neither -> skipped
    /// - `Thinking` -> `ThinkingDelta`
    /// - `Finish` -> `RunComplete` (always reports `total_steps: 1`)
    /// - `Usage` -> `UsageUpdate`
    ///
    /// Returns `None` when the underlying delta stream ends, and `Some(Err(_))`
    /// when it yields an error.
    pub async fn next_event(&mut self) -> Option<StreamResult<AgentStreamEvent>> {
        loop {
            let delta = match self.ws.next_delta().await? {
                Ok(delta) => delta,
                Err(e) => return Some(Err(e)),
            };

            let event = match delta {
                ResponseDelta::Text { index, content } => AgentStreamEvent::TextDelta {
                    content,
                    part_index: index,
                },
                // NOTE(review): when a delta carries both a name and args, the
                // args are not forwarded — confirm this is intended.
                ResponseDelta::ToolCall {
                    index,
                    name: Some(name),
                    id,
                    ..
                } => AgentStreamEvent::ToolCallStart {
                    name,
                    tool_call_id: id,
                    index,
                },
                ResponseDelta::ToolCall {
                    index,
                    args: Some(args),
                    ..
                } => AgentStreamEvent::ToolCallDelta {
                    args_delta: args,
                    index,
                },
                // Neither name nor args: nothing to report, read the next delta.
                ResponseDelta::ToolCall { .. } => continue,
                ResponseDelta::Thinking { index, content, .. } => {
                    AgentStreamEvent::ThinkingDelta { content, index }
                }
                ResponseDelta::Finish { .. } => AgentStreamEvent::RunComplete {
                    run_id: self.run_id.clone(),
                    total_steps: 1,
                },
                ResponseDelta::Usage { usage } => AgentStreamEvent::UsageUpdate { usage },
            };

            return Some(Ok(event));
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder calls must end up in the matching config fields.
    #[test]
    fn test_config() {
        let cfg = WebSocketConfig::new("wss://example.com/stream")
            .with_auth("token123")
            .with_timeout(60)
            .with_ping_interval(15);

        assert_eq!(cfg.url, "wss://example.com/stream");
        assert_eq!(cfg.timeout, 60);
        assert_eq!(cfg.ping_interval, Some(15));

        let has_auth = cfg.headers.iter().any(|(name, _)| name == "Authorization");
        assert!(has_auth);
    }

    /// Text keeps its payload and close maps to the bare `Close` marker.
    #[test]
    fn test_ws_message_conversion() {
        let from_text = WsStreamMessage::from(WsMessage::Text(String::from("hello")));
        assert!(matches!(from_text, WsStreamMessage::Text(s) if s == "hello"));

        let from_close = WsStreamMessage::from(WsMessage::Close(None));
        assert!(matches!(from_close, WsStreamMessage::Close));
    }
}