use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use futures_util::{SinkExt, StreamExt};
use serde_json::{json, Value};
use tokio::net::TcpStream;
use tokio::sync::{oneshot, Mutex};
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use super::encoding::{from_hex, to_hex};
use crate::errors::HopError;
/// Concrete WebSocket stream type used by this transport: a tokio `TcpStream`
/// wrapped in tokio-tungstenite's `MaybeTlsStream` (plain or TLS-protected).
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Pool statistics returned by the `hop_submit` and `hop_poolStatus` RPC calls.
///
/// This is a plain value type (three counters), so it is cheap to copy and can
/// be compared directly in callers and tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// Value of the `entryCount` response field (presumably the number of
    /// entries currently held in the pool — confirm against the server spec).
    pub entry_count: u64,
    /// Value of the `totalBytes` response field (presumably the bytes currently
    /// stored — confirm against the server spec).
    pub total_bytes: u64,
    /// Value of the `maxBytes` response field (presumably the pool capacity in
    /// bytes — confirm against the server spec).
    pub max_bytes: u64,
}
/// JSON-RPC 2.0 client that multiplexes requests over one WebSocket connection.
///
/// Requests are tagged with a numeric id; a background reader task matches
/// incoming responses to the waiting caller through `pending`.
pub struct WsJsonRpcTransport {
    /// Source of request ids; starts at 1 and is incremented per call.
    next_id: AtomicU64,
    /// In-flight requests: request id -> channel on which the reader task
    /// delivers the decoded response (or a transport error).
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Value, HopError>>>>>,
    /// Write half of the WebSocket. `None` once `close` has been called.
    write: Arc<Mutex<Option<futures_util::stream::SplitSink<WsStream, Message>>>>,
    /// Handle of the spawned reader task that owns the read half.
    _reader_handle: tokio::task::JoinHandle<()>,
}

impl Drop for WsJsonRpcTransport {
    /// Stops the background reader task.
    ///
    /// Dropping a tokio `JoinHandle` only detaches the task; without an explicit
    /// abort the reader (and the read half of the socket it owns) would outlive
    /// the transport until the remote side closes the connection.
    fn drop(&mut self) {
        self._reader_handle.abort();
    }
}
impl WsJsonRpcTransport {
    /// Opens a WebSocket connection to `endpoint` and starts the reader task
    /// that routes JSON-RPC responses back to their callers.
    ///
    /// # Errors
    ///
    /// Returns [`HopError::Network`] if the WebSocket handshake fails.
    pub async fn connect(endpoint: &str) -> Result<Self, HopError> {
        let (ws_stream, _) = tokio_tungstenite::connect_async(endpoint)
            .await
            .map_err(|e| HopError::Network(format!("WebSocket connect failed: {e}")))?;
        let (write, read) = ws_stream.split();
        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Value, HopError>>>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let pending_clone = Arc::clone(&pending);

        let reader_handle = tokio::spawn(async move {
            let mut read = read;
            // Reason handed to every caller still waiting when the loop stops.
            // Also covers the stream simply ending without a Close frame.
            let mut shutdown_reason = String::from("WebSocket closed");

            while let Some(frame) = read.next().await {
                let text = match frame {
                    Ok(Message::Text(text)) => text,
                    Ok(Message::Close(_)) => break,
                    Err(e) => {
                        shutdown_reason = format!("WebSocket error: {e}");
                        break;
                    }
                    // Non-text frames carry no JSON-RPC payload.
                    _ => continue,
                };
                // Frames that are not JSON, or have no numeric id, cannot be
                // matched to a caller and are skipped.
                let response: Value = match serde_json::from_str(&text) {
                    Ok(value) => value,
                    Err(_) => continue,
                };
                let id = match response.get("id").and_then(Value::as_u64) {
                    Some(id) => id,
                    None => continue,
                };
                // Release the lock before decoding so callers are not blocked.
                let waiter = pending_clone.lock().await.remove(&id);
                if let Some(tx) = waiter {
                    // The receiver may already be gone; nothing to do then.
                    let _ = tx.send(Self::decode_response(response));
                }
            }

            // The connection is finished (Close frame, read error, or end of
            // stream): fail every in-flight request so no caller waits forever.
            let mut map = pending_clone.lock().await;
            for (_, tx) in map.drain() {
                let _ = tx.send(Err(HopError::Network(shutdown_reason.clone())));
            }
        });

        Ok(Self {
            next_id: AtomicU64::new(1),
            pending,
            write: Arc::new(Mutex::new(Some(write))),
            _reader_handle: reader_handle,
        })
    }

    /// Converts a parsed JSON-RPC response object into the caller's result.
    ///
    /// A non-null `error` member becomes a [`HopError`] via `map_rpc_error`;
    /// otherwise the `result` member is moved out and returned. A response
    /// with neither is reported as a malformed response.
    fn decode_response(mut response: Value) -> Result<Value, HopError> {
        use std::convert::TryFrom;

        // JSON-RPC 1.0-style replies carry `"error": null` on success, so only
        // a non-null error member counts as a failure.
        if let Some(err) = response.get("error").filter(|e| !e.is_null()) {
            let raw_code = err.get("code").and_then(Value::as_i64).unwrap_or(0);
            let message = err
                .get("message")
                .and_then(Value::as_str)
                .unwrap_or("Unknown error")
                .to_string();
            return Err(match i32::try_from(raw_code) {
                Ok(code) => map_rpc_error(code, message),
                // A code outside the i32 range must not be truncated into one
                // of the known application codes; report it verbatim instead.
                Err(_) => HopError::Network(format!("RPC error {raw_code}: {message}")),
            });
        }
        response
            .get_mut("result")
            .map(Value::take)
            .ok_or_else(|| HopError::Network("Malformed JSON-RPC response: missing result".into()))
    }

    /// Builds a [`PoolStatus`] from a JSON object with `entryCount`,
    /// `totalBytes` and `maxBytes` members; missing or non-integer members
    /// default to 0.
    fn parse_pool_status(status: &Value) -> PoolStatus {
        let field = |key: &str| status.get(key).and_then(Value::as_u64).unwrap_or(0);
        PoolStatus {
            entry_count: field("entryCount"),
            total_bytes: field("totalBytes"),
            max_bytes: field("maxBytes"),
        }
    }

    /// Sends one JSON-RPC 2.0 request and waits for the matching response.
    ///
    /// # Errors
    ///
    /// Returns [`HopError::Network`] if the transport has been closed, the
    /// frame cannot be sent, or the connection ends before a reply arrives;
    /// returns the mapped RPC error if the server replies with an error.
    async fn call(&self, method: &str, params: Value) -> Result<Value, HopError> {
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let request = json!({
            "jsonrpc": "2.0",
            "id": id,
            "method": method,
            "params": params,
        });

        // Register before sending so a fast reply always finds its waiter.
        let (tx, rx) = oneshot::channel();
        self.pending.lock().await.insert(id, tx);

        let mut guard = self.write.lock().await;
        let sent = match guard.as_mut() {
            Some(writer) => writer
                .send(Message::Text(request.to_string()))
                .await
                .map_err(|e| HopError::Network(format!("Send failed: {e}"))),
            None => Err(HopError::Network("Transport destroyed".into())),
        };
        drop(guard);

        if let Err(err) = sent {
            // No reply can arrive for a request that was never sent; remove
            // the waiter so it does not accumulate in the map.
            self.pending.lock().await.remove(&id);
            return Err(err);
        }

        rx.await
            .map_err(|_| HopError::Network("Channel closed".into()))?
    }

    /// Calls `hop_submit` with the hex-encoded `data`, `recipients` and
    /// `proof`, and returns the `poolStatus` object from the reply.
    ///
    /// # Errors
    ///
    /// Propagates any error from the RPC call, and returns
    /// [`HopError::Network`] if the reply has no `poolStatus` member.
    pub async fn submit(
        &self,
        data: &[u8],
        recipients: &[Vec<u8>],
        proof: &[u8],
    ) -> Result<PoolStatus, HopError> {
        let recipients_hex: Vec<String> = recipients.iter().map(|r| to_hex(r)).collect();
        let result = self
            .call(
                "hop_submit",
                json!([to_hex(data), recipients_hex, to_hex(proof)]),
            )
            .await?;
        let pool = result
            .get("poolStatus")
            .ok_or_else(|| HopError::Network("Missing poolStatus in response".into()))?;
        Ok(Self::parse_pool_status(pool))
    }

    /// Calls `hop_claim` with the hex-encoded `hash` and `signature` and
    /// returns the hex-decoded bytes of the reply.
    ///
    /// # Errors
    ///
    /// Propagates any error from the RPC call, and returns
    /// [`HopError::Network`] if the reply is not a string or is not valid hex.
    pub async fn claim(&self, hash: &[u8], signature: &[u8]) -> Result<Vec<u8>, HopError> {
        let result = self
            .call("hop_claim", json!([to_hex(hash), to_hex(signature)]))
            .await?;
        let hex_str = result
            .as_str()
            .ok_or_else(|| HopError::Network("Expected hex string in claim response".into()))?;
        from_hex(hex_str).map_err(|e| HopError::Network(format!("Invalid hex in response: {e}")))
    }

    /// Calls `hop_poolStatus` and returns the reported pool counters.
    ///
    /// # Errors
    ///
    /// Propagates any error from the RPC call.
    pub async fn pool_status(&self) -> Result<PoolStatus, HopError> {
        let result = self.call("hop_poolStatus", json!([])).await?;
        Ok(Self::parse_pool_status(&result))
    }

    /// Closes the write half of the WebSocket.
    ///
    /// The sink is taken out of the transport, so later calls fail with
    /// "Transport destroyed" and calling `close` again is a no-op. Errors
    /// from the close handshake are ignored (best effort).
    pub async fn close(&self) {
        let mut guard = self.write.lock().await;
        if let Some(mut writer) = guard.take() {
            let _ = writer.close().await;
        }
    }
}
fn map_rpc_error(code: i32, message: String) -> HopError {
match code {
1001 => HopError::DataTooLarge(message),
1002 => HopError::PoolFull(message),
1003 => HopError::NotFound(message),
1004 => HopError::InvalidTicket(message),
1005 => HopError::QuotaExceeded(message),
_ => HopError::Network(format!("RPC error {code}: {message}")),
}
}