use crate::transport::TransportLayer;
use crate::{AhpError, AhpNotification, AhpRequest, AhpResponse, AuthConfig, AuthMethod, Result};
use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
pub struct WebSocketTransport {
write: Arc<
Mutex<futures_util::stream::SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
>,
pending_requests: Arc<Mutex<HashMap<String, tokio::sync::oneshot::Sender<AhpResponse>>>>,
}
impl WebSocketTransport {
pub async fn connect(url: impl Into<String>, auth: Option<AuthConfig>) -> Result<Self> {
let mut url_string = url.into();
if let Some(auth_config) = auth {
match auth_config.method {
AuthMethod::ApiKey { key } => {
url_string = format!("{}?api_key={}", url_string, key);
}
AuthMethod::Bearer { token } => {
url_string = format!("{}?token={}", url_string, token);
}
_ => {}
}
}
let (ws_stream, _) = connect_async(&url_string)
.await
.map_err(|e| AhpError::Transport(format!("WebSocket connection failed: {}", e)))?;
let (write, read) = ws_stream.split();
let transport = Self {
write: Arc::new(Mutex::new(write)),
pending_requests: Arc::new(Mutex::new(HashMap::new())),
};
transport.start_reader(read);
Ok(transport)
}
fn start_reader(
&self,
mut read: futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) {
let pending = self.pending_requests.clone();
tokio::spawn(async move {
while let Some(msg) = read.next().await {
match msg {
Ok(Message::Text(text)) => {
if let Ok(response) = serde_json::from_str::<AhpResponse>(&text) {
let mut pending_guard = pending.lock().await;
if let Some(sender) = pending_guard.remove(&response.id) {
let _ = sender.send(response);
}
}
}
Ok(Message::Close(_)) => break,
Err(_) => break,
_ => {}
}
}
});
}
}
#[async_trait]
impl TransportLayer for WebSocketTransport {
async fn send_request(&self, request: AhpRequest) -> Result<AhpResponse> {
let (tx, rx) = tokio::sync::oneshot::channel();
{
let mut pending = self.pending_requests.lock().await;
pending.insert(request.id.clone(), tx);
}
let json = serde_json::to_string(&request)?;
let mut write = self.write.lock().await;
write
.send(Message::Text(json))
.await
.map_err(|e| AhpError::Transport(format!("Failed to send: {}", e)))?;
drop(write);
match tokio::time::timeout(std::time::Duration::from_millis(10_000), rx).await {
Ok(Ok(response)) => Ok(response),
Ok(Err(_)) => Err(AhpError::ConnectionClosed),
Err(_) => Err(AhpError::Timeout(10_000)),
}
}
async fn send_notification(&self, notification: AhpNotification) -> Result<()> {
let json = serde_json::to_string(¬ification)?;
let mut write = self.write.lock().await;
write
.send(Message::Text(json))
.await
.map_err(|e| AhpError::Transport(format!("Failed to send: {}", e)))?;
Ok(())
}
async fn close(&self) -> Result<()> {
let mut write = self.write.lock().await;
write
.send(Message::Close(None))
.await
.map_err(|e| AhpError::Transport(format!("Failed to close: {}", e)))?;
Ok(())
}
}
pub struct WebSocketServer {
server: Arc<crate::AhpServer>,
}
impl WebSocketServer {
pub fn new(server: Arc<crate::AhpServer>) -> Self {
Self { server }
}
#[cfg(feature = "websocket")]
pub async fn run(self, addr: impl Into<std::net::SocketAddr>) -> Result<()> {
use tokio::net::TcpListener;
use tokio_tungstenite::accept_async;
let listener = TcpListener::bind(addr.into())
.await
.map_err(|e| AhpError::Transport(format!("Failed to bind: {}", e)))?;
tracing::info!(
"WebSocket server listening on {}",
listener.local_addr().unwrap()
);
while let Ok((stream, addr)) = listener.accept().await {
let server = self.server.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, server, addr).await {
tracing::error!("WebSocket connection error: {}", e);
}
});
}
Ok(())
}
}
#[cfg(feature = "websocket")]
async fn handle_connection(
stream: tokio::net::TcpStream,
server: Arc<crate::AhpServer>,
addr: std::net::SocketAddr,
) -> Result<()> {
use tokio_tungstenite::accept_async;
use tokio_tungstenite::tungstenite::Message as WsMessage;
let ws_stream = accept_async(stream)
.await
.map_err(|e| AhpError::Transport(format!("WebSocket handshake failed: {}", e)))?;
tracing::info!("WebSocket connection established: {}", addr);
let (mut write, mut read) = ws_stream.split();
while let Some(msg) = read.next().await {
match msg {
Ok(WsMessage::Text(text)) => {
if let Ok(request) = serde_json::from_str::<AhpRequest>(&text) {
let response = server.handle_request(request).await;
let json =
serde_json::to_string(&response).map_err(|e| AhpError::Serialization(e))?;
write
.send(WsMessage::Text(json))
.await
.map_err(|e| AhpError::Transport(format!("Failed to send: {}", e)))?;
} else if let Ok(notification) = serde_json::from_str::<AhpNotification>(&text) {
let _ = server.handle_notification(notification).await;
}
}
Ok(WsMessage::Close(_)) => break,
Ok(WsMessage::Ping(data)) => {
write
.send(WsMessage::Pong(data))
.await
.map_err(|e| AhpError::Transport(format!("Failed to send pong: {}", e)))?;
}
Err(e) => {
tracing::error!("WebSocket error: {}", e);
break;
}
_ => {}
}
}
tracing::info!("WebSocket connection closed: {}", addr);
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_websocket_url_with_api_key() {
let auth = Some(AuthConfig::api_key("test-key"));
let url = "ws://localhost:8080/ahp";
assert!(auth.is_some());
}
}