use crate::{
errors::{Result, SdkError},
transport::{InputMessage, Transport, TransportState},
types::{ControlRequest, ControlResponse, Message},
};
use async_trait::async_trait;
use futures::stream::Stream;
use serde_json::Value as JsonValue;
use std::pin::Pin;
use tokio::sync::{broadcast, mpsc, watch};
use tracing::{debug, error, info, warn};
/// Tuning knobs for a `WebSocketTransport`.
///
/// NOTE(review): the three `*reconnect*` fields are not read anywhere in this
/// file — the transport here makes a single connection attempt per `connect`.
/// Confirm whether another module consumes them.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Intended maximum number of reconnection attempts (see note above).
    pub max_reconnect_attempts: u32,
    /// Intended initial reconnect back-off, in milliseconds (see note above).
    pub base_reconnect_delay_ms: u64,
    /// Intended cap on the reconnect back-off, in milliseconds (see note above).
    pub max_reconnect_delay_ms: u64,
    /// Seconds between application-level `{"type":"keep_alive"}` messages.
    pub ping_interval_secs: u64,
    /// Capacity of the broadcast buffer for incoming messages; a subscriber
    /// that falls further behind than this skips (lags past) messages.
    pub message_buffer_capacity: usize,
    /// Optional token sent as `Authorization: Bearer <token>` in the handshake.
    pub auth_token: Option<String>,
}

impl Default for WebSocketConfig {
    /// 3 reconnect attempts, 1 s base / 30 s max back-off, a 10 s keepalive
    /// interval, a 1000-message buffer and no auth token.
    fn default() -> Self {
        const MILLIS_PER_SEC: u64 = 1_000;
        WebSocketConfig {
            auth_token: None,
            message_buffer_capacity: 1_000,
            ping_interval_secs: 10,
            max_reconnect_delay_ms: 30 * MILLIS_PER_SEC,
            base_reconnect_delay_ms: MILLIS_PER_SEC,
            max_reconnect_attempts: 3,
        }
    }
}
/// WebSocket-backed [`Transport`] implementation.
///
/// Built by `WebSocketTransport::new`; all channel/handle fields start as
/// `None` and are populated when a connection is established.
pub struct WebSocketTransport {
    /// Parsed endpoint; `new` only accepts the `ws` and `wss` schemes.
    url: url::Url,
    /// Connection settings supplied at construction.
    config: WebSocketConfig,
    /// Feeds the background write task; each queued `String` becomes one text frame.
    ws_tx: Option<mpsc::Sender<String>>,
    /// Fan-out of parsed incoming messages; `receive_messages` subscribes to it.
    message_broadcast_tx: Option<broadcast::Sender<Message>>,
    /// Control responses routed by the read task, drained by `receive_control_response`.
    control_rx: Option<mpsc::Receiver<ControlResponse>>,
    /// Raw control JSON routed by the read task, handed out by `take_sdk_control_receiver`.
    sdk_control_rx: Option<mpsc::Receiver<JsonValue>>,
    /// Lifecycle state; `is_connected` reports whether it is `Connected`.
    state: TransportState,
    /// Incremented per `send_control_request`; otherwise only shown by `Debug` in this file.
    request_counter: u64,
    /// Signals the background read/write/keepalive tasks to stop.
    shutdown_tx: Option<watch::Sender<bool>>,
}
impl std::fmt::Debug for WebSocketTransport {
    /// Compact diagnostic view: endpoint, lifecycle state, request counter,
    /// and whether a writer handle is currently held (the handle itself is
    /// not printed).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_writer = self.ws_tx.is_some();
        let mut view = f.debug_struct("WebSocketTransport");
        view.field("url", &self.url);
        view.field("state", &self.state);
        view.field("request_counter", &self.request_counter);
        view.field("ws_tx", &has_writer);
        view.finish()
    }
}
impl WebSocketTransport {
    /// Create an unconnected transport for `url`.
    ///
    /// Only parses and validates the URL; no network I/O happens until the
    /// transport is connected.
    ///
    /// # Errors
    ///
    /// Returns [`SdkError::WebSocketError`] if `url` does not parse or its
    /// scheme is not `ws` / `wss`.
    pub fn new(url: &str, config: WebSocketConfig) -> Result<Self> {
        let parsed_url = url::Url::parse(url).map_err(|e| {
            SdkError::WebSocketError(format!("Invalid WebSocket URL '{url}': {e}"))
        })?;
        match parsed_url.scheme() {
            "ws" | "wss" => {}
            scheme => {
                return Err(SdkError::WebSocketError(format!(
                    "Unsupported URL scheme '{scheme}', expected 'ws' or 'wss'"
                )));
            }
        }
        Ok(Self {
            url: parsed_url,
            config,
            ws_tx: None,
            message_broadcast_tx: None,
            control_rx: None,
            sdk_control_rx: None,
            state: TransportState::Disconnected,
            request_counter: 0,
            shutdown_tx: None,
        })
    }

    /// Value for the handshake's `Host` header: `host` or `host:port`.
    ///
    /// RFC 6455 §4.1 requires the port whenever it is not the scheme's
    /// default. `Url::port` already returns `None` for default ports (80 for
    /// `ws`, 443 for `wss`), so only explicit non-default ports are appended.
    /// IPv6 literals are returned by `host_str` with their brackets.
    fn host_header_value(&self) -> String {
        let host = self.url.host_str().unwrap_or("localhost");
        match self.url.port() {
            Some(port) => format!("{host}:{port}"),
            None => host.to_string(),
        }
    }

    /// Build the HTTP upgrade request for the WebSocket handshake, adding
    /// `Authorization: Bearer <token>` when `config.auth_token` is set.
    ///
    /// # Errors
    ///
    /// [`SdkError::WebSocketError`] if the request cannot be assembled, e.g.
    /// because the auth token is not a valid header value.
    fn build_ws_request(&self) -> Result<http::Request<()>> {
        let mut request = http::Request::builder()
            .uri(self.url.as_str())
            // Must carry a non-default port, otherwise servers/proxies that
            // route on Host see the wrong authority.
            .header("Host", self.host_header_value())
            .header("Connection", "Upgrade")
            .header("Upgrade", "websocket")
            .header("Sec-WebSocket-Version", "13")
            .header(
                "Sec-WebSocket-Key",
                tokio_tungstenite::tungstenite::handshake::client::generate_key(),
            );
        if let Some(ref token) = self.config.auth_token {
            request = request.header("Authorization", format!("Bearer {token}"));
        }
        request
            .body(())
            .map_err(|e| SdkError::WebSocketError(format!("Failed to build WS request: {e}")))
    }

    /// Open the socket and spawn the background write, read and keepalive tasks.
    ///
    /// On success the channel handles on `self` are replaced and `state`
    /// becomes `Connected`. Replacing `shutdown_tx` drops any previous sender,
    /// which makes the previous connection's tasks exit. If the handshake
    /// fails, `state` is reset to `Disconnected` rather than being left at
    /// `Connecting`.
    ///
    /// A `message_buffer_capacity` of 0 is treated as 1 (tokio's broadcast
    /// channel panics on 0). A `ping_interval_secs` of 0 disables keepalives
    /// (`tokio::time::interval` panics on a zero period).
    ///
    /// # Errors
    ///
    /// [`SdkError::WebSocketError`] if the request cannot be built or the
    /// handshake fails.
    async fn establish_connection(&mut self) -> Result<()> {
        use futures::StreamExt;
        use tokio_tungstenite::tungstenite::Message as WsMessage;

        let request = self.build_ws_request()?;
        self.state = TransportState::Connecting;
        let (ws_stream, _response) = match tokio_tungstenite::connect_async(request).await {
            Ok(connected) => connected,
            Err(e) => {
                self.state = TransportState::Disconnected;
                return Err(SdkError::WebSocketError(format!(
                    "Failed to connect to {}: {e}",
                    self.url
                )));
            }
        };
        info!("WebSocket connected to {}", self.url);

        let (ws_sink, ws_stream) = ws_stream.split();
        let (ws_tx, ws_rx) = mpsc::channel::<String>(256);
        let broadcast_capacity = self.config.message_buffer_capacity.max(1);
        let (message_broadcast_tx, _) = broadcast::channel::<Message>(broadcast_capacity);
        let (control_tx, control_rx) = mpsc::channel::<ControlResponse>(32);
        let (sdk_control_tx, sdk_control_rx) = mpsc::channel::<JsonValue>(64);
        let (shutdown_tx, shutdown_rx) = watch::channel(false);

        // Write task: forwards queued strings as text frames until the queue
        // closes, a write fails, or shutdown is signalled.
        let mut shutdown_rx_write = shutdown_rx.clone();
        tokio::spawn(async move {
            use futures::SinkExt;
            let mut ws_sink = ws_sink;
            let mut ws_rx = ws_rx;
            loop {
                tokio::select! {
                    msg = ws_rx.recv() => {
                        match msg {
                            Some(line) => {
                                if let Err(e) = ws_sink.send(WsMessage::Text(line.into())).await {
                                    error!("WebSocket write error: {e}");
                                    break;
                                }
                            }
                            None => {
                                debug!("Write channel closed, shutting down write task");
                                break;
                            }
                        }
                    }
                    // Also fires when the sender is dropped (`changed()` returns Err).
                    _ = shutdown_rx_write.changed() => {
                        debug!("Shutdown signal received in write task");
                        let _ = ws_sink.send(WsMessage::Close(None)).await;
                        break;
                    }
                }
            }
            debug!("WebSocket write task ended");
        });

        // Read task: decodes newline-delimited JSON from text frames and routes it.
        let read_broadcast_tx = message_broadcast_tx.clone();
        let mut shutdown_rx_read = shutdown_rx.clone();
        tokio::spawn(async move {
            let mut ws_stream = ws_stream;
            loop {
                tokio::select! {
                    msg = ws_stream.next() => {
                        match msg {
                            Some(Ok(WsMessage::Text(text))) => {
                                let text_str: &str = &text;
                                for line in text_str.split('\n') {
                                    let line = line.trim();
                                    if line.is_empty() {
                                        continue;
                                    }
                                    match serde_json::from_str::<JsonValue>(line) {
                                        Ok(json) => {
                                            Self::route_incoming_message(
                                                json,
                                                &read_broadcast_tx,
                                                &control_tx,
                                                &sdk_control_tx,
                                            )
                                            .await;
                                        }
                                        Err(e) => {
                                            warn!("Failed to parse WebSocket JSON: {e} — line: {line}");
                                        }
                                    }
                                }
                            }
                            Some(Ok(WsMessage::Ping(_))) => {
                                debug!("Received WS ping, pong is auto-sent by tungstenite");
                            }
                            Some(Ok(WsMessage::Pong(_))) => {
                                debug!("Received WS pong");
                            }
                            Some(Ok(WsMessage::Close(frame))) => {
                                info!("WebSocket closed by server: {frame:?}");
                                break;
                            }
                            // Binary and raw frames are not part of this protocol; ignore them.
                            Some(Ok(_)) => {}
                            Some(Err(e)) => {
                                error!("WebSocket read error: {e}");
                                break;
                            }
                            None => {
                                info!("WebSocket stream ended");
                                break;
                            }
                        }
                    }
                    _ = shutdown_rx_read.changed() => {
                        debug!("Shutdown signal received in read task");
                        break;
                    }
                }
            }
            debug!("WebSocket read task ended");
        });

        // Keepalive task: application-level `keep_alive` message every interval.
        let ping_interval = self.config.ping_interval_secs;
        if ping_interval == 0 {
            debug!("WebSocket keepalive disabled (ping_interval_secs = 0)");
        } else {
            let keepalive_tx = ws_tx.clone();
            let mut shutdown_rx_keepalive = shutdown_rx.clone();
            tokio::spawn(async move {
                let mut interval =
                    tokio::time::interval(tokio::time::Duration::from_secs(ping_interval));
                // The first tick completes immediately; consume it so the first
                // keepalive goes out one full interval after connecting.
                interval.tick().await;
                loop {
                    tokio::select! {
                        _ = interval.tick() => {
                            let keep_alive = serde_json::json!({"type": "keep_alive"}).to_string();
                            if keepalive_tx.send(keep_alive).await.is_err() {
                                debug!("Keepalive channel closed");
                                break;
                            }
                        }
                        _ = shutdown_rx_keepalive.changed() => {
                            debug!("Shutdown signal received in keepalive task");
                            break;
                        }
                    }
                }
                debug!("WebSocket keepalive task ended");
            });
        }

        self.ws_tx = Some(ws_tx);
        self.message_broadcast_tx = Some(message_broadcast_tx);
        self.control_rx = Some(control_rx);
        self.sdk_control_rx = Some(sdk_control_rx);
        self.shutdown_tx = Some(shutdown_tx);
        self.state = TransportState::Connected;
        Ok(())
    }

    /// Dispatch one decoded JSON value from the socket by its `"type"` field.
    ///
    /// * `control_response`: the JSON is forwarded raw to `sdk_control_tx`. If it
    ///   carries `response.request_id` (or `response.requestId`), an
    ///   `InterruptAck` is also sent on `control_tx`. `success` is set when
    ///   `response.subtype == "success"`.
    /// * `control_request` and `sdk_control_request`: forwarded raw to `sdk_control_tx`.
    /// * `control`: the nested `"control"` value is forwarded to `sdk_control_tx`.
    /// * `system`: forwarded to `sdk_control_tx` when `subtype` starts with
    ///   `sdk_control:`. It is then parsed and, if a message results, broadcast.
    /// * `keep_alive`: logged and dropped.
    /// * anything else: parsed and, if a message results, broadcast.
    ///
    /// Values without a string `"type"` are logged and dropped. Send errors
    /// (e.g. no live receivers) are ignored.
    ///
    /// NOTE(review): the mpsc sends are awaited on bounded channels. If nobody
    /// drains the control / SDK-control receivers, this call stalls once they
    /// fill, and the whole read loop stalls with it. Confirm consumers always
    /// drain them.
    async fn route_incoming_message(
        json: JsonValue,
        message_broadcast_tx: &broadcast::Sender<Message>,
        control_tx: &mpsc::Sender<ControlResponse>,
        sdk_control_tx: &mpsc::Sender<JsonValue>,
    ) {
        let msg_type = match json.get("type").and_then(|v| v.as_str()) {
            Some(t) => t,
            None => {
                warn!("Received JSON without 'type' field: {json}");
                return;
            }
        };
        match msg_type {
            "control_response" => {
                debug!("Received control response: {json:?}");
                let _ = sdk_control_tx.send(json.clone()).await;
                if let Some(response_obj) = json.get("response") {
                    if let Some(request_id) = response_obj
                        .get("request_id")
                        .or_else(|| response_obj.get("requestId"))
                        .and_then(|v| v.as_str())
                    {
                        let subtype = response_obj.get("subtype").and_then(|v| v.as_str());
                        let success = subtype == Some("success");
                        let control_resp = ControlResponse::InterruptAck {
                            request_id: request_id.to_string(),
                            success,
                        };
                        let _ = control_tx.send(control_resp).await;
                    }
                }
            }
            "control_request" => {
                debug!("Received control request: {json:?}");
                let _ = sdk_control_tx.send(json).await;
            }
            "sdk_control_request" => {
                debug!("Received SDK control request (legacy): {json:?}");
                let _ = sdk_control_tx.send(json).await;
            }
            "control" => {
                if let Some(control) = json.get("control") {
                    debug!("Received control message: {control:?}");
                    let _ = sdk_control_tx.send(control.clone()).await;
                }
            }
            "system" => {
                if let Some(subtype) = json.get("subtype").and_then(|v| v.as_str()) {
                    if subtype.starts_with("sdk_control:") {
                        debug!("Received SDK control message: {subtype}");
                        let _ = sdk_control_tx.send(json.clone()).await;
                    }
                }
                match crate::message_parser::parse_message(json) {
                    Ok(Some(message)) => {
                        let _ = message_broadcast_tx.send(message);
                    }
                    Ok(None) => {}
                    Err(e) => {
                        warn!("Failed to parse system message: {e}");
                    }
                }
            }
            "keep_alive" => {
                debug!("Received keep_alive");
            }
            _ => match crate::message_parser::parse_message(json) {
                Ok(Some(message)) => {
                    let _ = message_broadcast_tx.send(message);
                }
                Ok(None) => {}
                Err(e) => {
                    warn!("Failed to parse message: {e}");
                }
            },
        }
    }
}
#[async_trait]
impl Transport for WebSocketTransport {
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    /// Connect to the configured URL; a no-op when already `Connected`.
    async fn connect(&mut self) -> Result<()> {
        if self.state == TransportState::Connected {
            return Ok(());
        }
        self.establish_connection().await?;
        info!("WebSocket transport connected to {}", self.url);
        Ok(())
    }

    /// Serialize `message` as JSON and queue it for the write task.
    ///
    /// # Errors
    ///
    /// `InvalidState` when not connected or after `end_input`; `WebSocketError`
    /// when the write task has already exited.
    async fn send_message(&mut self, message: InputMessage) -> Result<()> {
        if self.state != TransportState::Connected {
            return Err(SdkError::InvalidState {
                message: "Not connected".into(),
            });
        }
        let json = serde_json::to_string(&message)?;
        debug!("Sending message via WebSocket: {json}");
        let tx = self.ws_tx.as_ref().ok_or_else(|| SdkError::InvalidState {
            message: "WebSocket write channel not available".into(),
        })?;
        tx.send(json)
            .await
            .map_err(|_| SdkError::WebSocketError("Write channel closed".into()))
    }

    /// Subscribe to parsed incoming messages.
    ///
    /// Each call gets an independent broadcast subscription that only sees
    /// messages arriving after the call. Messages a lagging subscriber misses
    /// are skipped with a warning. Returns an empty stream when no connection
    /// is registered.
    fn receive_messages(
        &mut self,
    ) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + 'static>> {
        use futures::StreamExt;
        if let Some(ref tx) = self.message_broadcast_tx {
            let rx = tx.subscribe();
            Box::pin(
                tokio_stream::wrappers::BroadcastStream::new(rx).filter_map(|result| async move {
                    match result {
                        Ok(msg) => Some(Ok(msg)),
                        Err(
                            tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(n),
                        ) => {
                            warn!("WebSocket receiver lagged by {n} messages");
                            None
                        }
                    }
                }),
            )
        } else {
            Box::pin(futures::stream::empty())
        }
    }

    /// Send a control request.
    ///
    /// An interrupt is encoded as
    /// `{"type":"control_request","request":{"type":"interrupt","request_id":…}}`.
    ///
    /// # Errors
    ///
    /// `InvalidState` when not connected; otherwise the same errors as
    /// `send_sdk_control_request`.
    async fn send_control_request(&mut self, request: ControlRequest) -> Result<()> {
        if self.state != TransportState::Connected {
            return Err(SdkError::InvalidState {
                message: "Not connected".into(),
            });
        }
        self.request_counter += 1;
        let control_msg = match request {
            ControlRequest::Interrupt { request_id } => {
                serde_json::json!({
                    "type": "control_request",
                    "request": {
                        "type": "interrupt",
                        "request_id": request_id
                    }
                })
            }
        };
        self.send_sdk_control_request(control_msg).await
    }

    /// Wait for the next control response routed by the read task.
    ///
    /// Returns `Ok(None)` when no connection has been registered, or once the
    /// read task has exited and the buffered responses are drained.
    async fn receive_control_response(&mut self) -> Result<Option<ControlResponse>> {
        if let Some(ref mut rx) = self.control_rx {
            Ok(rx.recv().await)
        } else {
            Ok(None)
        }
    }

    /// Serialize `request` verbatim and queue it for the write task.
    ///
    /// This is the shared write path for all raw JSON control traffic.
    ///
    /// # Errors
    ///
    /// `InvalidState` when no writer is registered; `WebSocketError` when the
    /// write task has already exited.
    async fn send_sdk_control_request(&mut self, request: JsonValue) -> Result<()> {
        let json = serde_json::to_string(&request)?;
        let tx = self.ws_tx.as_ref().ok_or_else(|| SdkError::InvalidState {
            message: "WebSocket write channel not available".into(),
        })?;
        tx.send(json)
            .await
            .map_err(|_| SdkError::WebSocketError("Write channel closed".into()))
    }

    /// Wrap `response` as `{"type":"control_response","response":…}` and send it.
    async fn send_sdk_control_response(&mut self, response: JsonValue) -> Result<()> {
        let control_response = serde_json::json!({
            "type": "control_response",
            "response": response
        });
        self.send_sdk_control_request(control_response).await
    }

    /// Hand out the raw control-message receiver. It can be taken only once
    /// per connection.
    fn take_sdk_control_receiver(&mut self) -> Option<mpsc::Receiver<JsonValue>> {
        self.sdk_control_rx.take()
    }

    fn is_connected(&self) -> bool {
        self.state == TransportState::Connected
    }

    /// Stop the background tasks and release the connection's handles.
    ///
    /// Any state other than `Disconnected` is cleaned up. This includes a
    /// `Connecting` left behind by a failed handshake. The broadcast sender is
    /// released as well, so streams from `receive_messages` end once the read
    /// task (which holds the other sender) exits, instead of pending forever.
    async fn disconnect(&mut self) -> Result<()> {
        if self.state == TransportState::Disconnected {
            return Ok(());
        }
        self.state = TransportState::Disconnecting;
        if let Some(tx) = self.shutdown_tx.take() {
            // Err only means every task has already exited.
            let _ = tx.send(true);
        }
        self.ws_tx.take();
        self.message_broadcast_tx.take();
        self.state = TransportState::Disconnected;
        info!("WebSocket transport disconnected");
        Ok(())
    }

    /// Drop this handle's writer so no further input can be sent through it.
    ///
    /// NOTE(review): other clones of the writer (the keepalive task's, when
    /// enabled) keep the write task and socket alive; `disconnect` is what
    /// actually closes the connection.
    async fn end_input(&mut self) -> Result<()> {
        self.ws_tx.take();
        Ok(())
    }
}
impl Drop for WebSocketTransport {
fn drop(&mut self) {
if let Some(ref tx) = self.shutdown_tx {
let _ = tx.send(true);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a transport for a URL the test expects to be accepted.
    fn unconnected(url: &str) -> WebSocketTransport {
        WebSocketTransport::new(url, WebSocketConfig::default())
            .expect("test URL should be accepted")
    }

    #[test]
    fn test_websocket_config_default() {
        let cfg = WebSocketConfig::default();
        assert_eq!(
            (
                cfg.max_reconnect_attempts,
                cfg.base_reconnect_delay_ms,
                cfg.max_reconnect_delay_ms,
                cfg.ping_interval_secs,
                cfg.message_buffer_capacity,
            ),
            (3, 1000, 30000, 10, 1000)
        );
        assert!(cfg.auth_token.is_none());
    }

    #[test]
    fn test_websocket_transport_new_valid_url() {
        let transport = unconnected("ws://localhost:8765");
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_websocket_transport_new_wss_url() {
        assert!(
            WebSocketTransport::new("wss://example.com/ws", WebSocketConfig::default()).is_ok()
        );
    }

    #[test]
    fn test_websocket_transport_new_invalid_scheme() {
        let message = WebSocketTransport::new("http://localhost:8765", WebSocketConfig::default())
            .expect_err("http:// must be rejected")
            .to_string();
        assert!(message.contains("Unsupported URL scheme"));
    }

    #[test]
    fn test_websocket_transport_new_invalid_url() {
        let outcome = WebSocketTransport::new("not a url at all", WebSocketConfig::default());
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_websocket_transport_send_before_connect() {
        let mut transport = unconnected("ws://localhost:9999");
        let failure = transport
            .send_message(InputMessage::user("hello".into(), "".into()))
            .await
            .expect_err("sending before connect must fail");
        assert!(failure.to_string().contains("Not connected"));
    }

    #[tokio::test]
    async fn test_websocket_transport_disconnect_when_not_connected() {
        let mut transport = unconnected("ws://localhost:9999");
        assert!(transport.disconnect().await.is_ok());
    }
}