use super::types::*;
use anyhow::{Context, Result};
use futures_util::{SinkExt, StreamExt};
use parking_lot::RwLock;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::{Mutex, oneshot};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
use tracing::{debug, error, warn};
use uuid::Uuid;
/// Stream type stored by the client: the value `connect_async` yields in `connect`.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// RPC client that sends requests over a WebSocket and matches responses to
/// callers by request id.
///
/// All mutable state sits behind `Arc`s so the background reader task spawned
/// by `start_handler` can share it with the client.
pub struct WebSocketRpcClient {
/// Server address handed to `connect_async`.
url: String,
/// The open socket, or `None` before `connect` / after `disconnect`.
ws: Arc<Mutex<Option<WsStream>>>,
/// In-flight calls keyed by request id; the sender completes the matching `call`.
pending: Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
/// Loop flag for the background reader task; cleared to make it stop.
running: Arc<Mutex<bool>>,
/// Optional channel that receives `ScriptOutput` values parsed from
/// "output" / "script_output" notifications.
pub output_sender: Arc<RwLock<Option<flume::Sender<ScriptOutput>>>>,
}
impl WebSocketRpcClient {
/// Builds a disconnected client for `url`.
///
/// No socket is opened here: the socket slot starts empty, the pending-call
/// table is empty, the reader flag is `false`, and no output channel is set.
pub fn new(url: String) -> Self {
    let ws = Arc::new(Mutex::new(None));
    let pending = Arc::new(Mutex::new(HashMap::new()));
    let running = Arc::new(Mutex::new(false));
    let output_sender = Arc::new(RwLock::new(None));
    Self {
        url,
        ws,
        pending,
        running,
        output_sender,
    }
}
/// Installs `sender` as the destination for script output notifications,
/// replacing any sender that was set before.
pub fn set_output_channel(&self, sender: flume::Sender<ScriptOutput>) {
    let mut slot = self.output_sender.write();
    *slot = Some(sender);
}
/// Opens the WebSocket connection to `self.url` and starts the background
/// reader task.
///
/// Returns `Ok(())` without doing anything when `is_connected` already
/// reports a live connection.
///
/// # Errors
/// Returns an error (with added context) when `connect_async` fails.
pub async fn connect(&self) -> Result<()> {
    // Diagnostics go through `tracing` rather than stdout so they respect the
    // application's log configuration.
    debug!("Connecting to WebSocket RPC server: {}", self.url);
    if self.is_connected().await {
        debug!("Already connected, skipping reconnection");
        return Ok(());
    }

    let (ws_stream, _) = connect_async(&self.url)
        .await
        .context("Failed to connect to WebSocket server")?;

    // Publish the socket before raising the flag: the reader loop only runs
    // while `running` is true and expects the socket slot to be populated.
    *self.ws.lock().await = Some(ws_stream);
    *self.running.lock().await = true;
    self.start_handler().await;

    debug!("Connected to WebSocket RPC server");
    Ok(())
}
pub async fn disconnect(&self) -> Result<()> {
*self.running.lock().await = false;
if let Some(mut ws) = self.ws.lock().await.take() {
ws.close(None).await?;
}
let mut pending = self.pending.lock().await;
for (_, tx) in pending.drain() {
let _ = tx.send(Value::Null);
}
Ok(())
}
pub async fn call(&self, method: &str, params: Value) -> Result<Value> {
println!("DEBUG: call() invoked with method: {}", method);
let id = Uuid::new_v4().to_string();
let request = RpcRequest::new(id.clone(), method.to_string(), params);
let message = serde_json::to_string(&request)?;
println!("DEBUG: Sending RPC request: {}", message);
let (tx, rx) = oneshot::channel();
self.pending.lock().await.insert(id.clone(), tx);
let mut ws_guard = self.ws.lock().await;
if let Some(ws) = ws_guard.as_mut() {
println!("DEBUG: Sending message via WebSocket");
ws.send(Message::text(message)).await?;
println!("DEBUG: Message sent successfully");
} else {
println!("DEBUG: WebSocket not connected!");
return Err(anyhow::anyhow!("Not connected to WebSocket server"));
}
drop(ws_guard);
println!("DEBUG: Waiting for response with timeout...");
match tokio::time::timeout(std::time::Duration::from_secs(30), rx).await {
Ok(Ok(response)) => {
println!("DEBUG: Got response!");
Ok(response)
}
Ok(Err(_)) => {
println!("DEBUG: RPC call cancelled");
Err(anyhow::anyhow!("RPC call cancelled"))
}
Err(_) => {
println!("DEBUG: RPC call timed out after 30 seconds");
self.pending.lock().await.remove(&id);
Err(anyhow::anyhow!("RPC call timed out"))
}
}
}
/// Spawns the background task that reads frames from the socket and
/// dispatches text frames to `handle_message`.
///
/// The task runs while the shared `running` flag is true. It clears the
/// flag and exits on a close frame, on a read error, or when the stream
/// ends.
async fn start_handler(&self) {
    let ws = self.ws.clone();
    let pending = self.pending.clone();
    let running = self.running.clone();
    let output_sender = self.output_sender.clone();

    tokio::spawn(async move {
        while *running.lock().await {
            // Outer `None`: nothing to read this round (poll timed out or no
            // socket is installed). `Some(None)`: the stream itself ended.
            // The socket lock is held for at most one 100 ms poll so `call`
            // can take it to send in between polls.
            let polled = {
                let mut ws_guard = ws.lock().await;
                match ws_guard.as_mut() {
                    Some(ws_stream) => tokio::time::timeout(
                        std::time::Duration::from_millis(100),
                        ws_stream.next(),
                    )
                    .await
                    .ok(),
                    None => None,
                }
            };

            match polled {
                Some(Some(Ok(Message::Text(text)))) => {
                    WebSocketRpcClient::handle_message(
                        text.to_string(),
                        &pending,
                        output_sender.clone(),
                    )
                    .await;
                }
                Some(Some(Ok(Message::Close(_)))) => {
                    warn!("WebSocket connection closed");
                    *running.lock().await = false;
                    break;
                }
                Some(Some(Err(e))) => {
                    error!("WebSocket error: {}", e);
                    *running.lock().await = false;
                    break;
                }
                Some(None) => {
                    // An exhausted stream yields nothing further; treating it
                    // as an idle poll would spin forever with `running` true.
                    warn!("WebSocket stream ended");
                    *running.lock().await = false;
                    break;
                }
                None => {
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                }
                // Other frame kinds are ignored.
                Some(Some(Ok(_))) => {}
            }
        }
    });
}
async fn handle_message(
text: String,
pending: &Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
output_sender: Arc<RwLock<Option<flume::Sender<ScriptOutput>>>>,
) {
match serde_json::from_str::<WebSocketMessage>(&text) {
Ok(WebSocketMessage::Response(response)) => {
if let Some(tx) = pending.lock().await.remove(&response.id) {
if let Some(result) = response.result {
let _ = tx.send(result);
} else if let Some(error) = response.error {
let _ = tx.send(json!({
"error": {
"code": error.code,
"message": error.message,
"data": error.data
}
}));
} else {
let _ = tx.send(Value::Null);
}
}
}
Ok(WebSocketMessage::Notification(notification)) => {
if notification.method == "output" || notification.method == "script_output" {
if let Some(sender) = &*output_sender.read() {
let output_data = if notification.method == "script_output" {
notification.params
} else {
notification.params
};
if let Ok(output) = serde_json::from_value::<ScriptOutput>(output_data) {
debug!(
"Received async output from script: {} on port {}",
output.actor_id, output.port
);
let _ = sender.send(output);
}
}
} else if notification.method == "log" {
if let Some(log_msg) = notification.params.as_str() {
debug!("Script log: {}", log_msg);
}
} else if notification.method == "state_update" {
debug!("Script state update: {:?}", notification.params);
}
}
Err(e) => {
warn!("Failed to parse WebSocket message: {}", e);
}
}
}
/// Reports whether the client considers itself connected: the reader flag
/// is set and a socket is installed.
pub async fn is_connected(&self) -> bool {
    // Check the flag first. When it is false the answer is already known and
    // there is no need to wait for the socket lock, which the reader task
    // holds while it polls.
    if !*self.running.lock().await {
        return false;
    }
    self.ws.lock().await.is_some()
}
pub async fn ensure_connected(&self) -> Result<()> {
println!("DEBUG: ensure_connected - checking connection");
let connected = self.is_connected().await;
println!("DEBUG: is_connected = {}", connected);
if !connected {
println!("DEBUG: Not connected, attempting to connect...");
self.connect().await?;
println!("DEBUG: Connected successfully");
}
Ok(())
}
}
/// Best-effort shutdown when the client is dropped: clears the reader flag
/// and closes the socket on a background task.
impl Drop for WebSocketRpcClient {
    fn drop(&mut self) {
        // `tokio::spawn` panics when no runtime is in context, for example
        // when the client is dropped after the runtime has shut down or on a
        // plain thread. A panic inside `drop` can abort the process, so look
        // the runtime up fallibly.
        match tokio::runtime::Handle::try_current() {
            Ok(handle) => {
                let ws = self.ws.clone();
                let running = self.running.clone();
                handle.spawn(async move {
                    *running.lock().await = false;
                    if let Some(mut ws_stream) = ws.lock().await.take() {
                        let _ = ws_stream.close(None).await;
                    }
                });
            }
            Err(_) => {
                // No runtime is available to run an async close. Clear the
                // flag if the lock is free and skip the close handshake.
                if let Ok(mut flag) = self.running.try_lock() {
                    *flag = false;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A serialized request carries the JSON-RPC version, the id and the method.
    #[tokio::test]
    async fn test_rpc_request_serialization() {
        let params = json!({"foo": "bar"});
        let request = RpcRequest::new(String::from("test-id"), String::from("process"), params);
        let wire = serde_json::to_string(&request).unwrap();

        let expected_fragments = [
            "\"jsonrpc\":\"2.0\"",
            "\"id\":\"test-id\"",
            "\"method\":\"process\"",
        ];
        for fragment in expected_fragments {
            assert!(wire.contains(fragment));
        }
    }

    /// A response with a `result` member parses with the id set and no error.
    #[tokio::test]
    async fn test_rpc_response_deserialization() {
        let raw = r#"{
"jsonrpc": "2.0",
"id": "test-id",
"result": {"status": "ok"}
}"#;
        let parsed = serde_json::from_str::<RpcResponse>(raw).unwrap();

        assert_eq!(parsed.id, "test-id");
        assert!(parsed.result.is_some());
        assert!(parsed.error.is_none());
    }
}