cipher-gate 0.3.0

Proxy RPC that routes signing requests to a browser wallet UI
use std::sync::Arc;

use actix_ws::{Message, Session};
use dashmap::DashMap;
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::{Notify, RwLock, oneshot};

use crate::rpc::JsonRpcResponse;

/// Message sent from proxy to frontend requesting a signature.
///
/// Serialized with `serde_json` and pushed to the active WebSocket session by
/// `AppState::send_signing_request`, which always sets `msg_type` to
/// `"signing_request"`. Field order and serde attributes define the JSON
/// shape the frontend receives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SigningRequest {
    /// Proxy-chosen request id. The frontend sends it back in its
    /// `signing_response`, where it is used as the key into `AppState::pending`.
    pub id: String,
    /// Message discriminator, serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub msg_type: String, // "signing_request"
    /// Method name of the request awaiting a signature, as passed by the caller.
    pub method: String,
    /// Parameters of that request, forwarded unchanged.
    pub params: Value,
    /// Optional simulation result shown to the user; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub simulation: Option<crate::proxy::SimulationResult>,
    /// Optional decoded calldata shown to the user; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub decoded_calldata: Option<crate::decode::DecodedCalldata>,
}

/// Error payload the frontend may attach to a `signing_response`.
///
/// `handle_frontend_message` copies `code` and `message` into a
/// `crate::rpc::JsonRpcError` for the waiting caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SigningError {
    /// Numeric error code, passed through unchanged as the JSON-RPC error code.
    pub code: i64,
    /// Human-readable error text, passed through unchanged.
    pub message: String,
}

/// Envelope for parsing incoming WS messages from frontend.
///
/// Internally tagged: the JSON `"type"` field selects the variant, for example
/// `{"type": "account_disconnected"}`. Unknown tags fail to parse and are
/// logged and ignored by `handle_frontend_message`.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum FrontendMessage {
    /// Reply to a `SigningRequest` whose `id` matches. `handle_frontend_message`
    /// treats the reply as a failure whenever `error` is present. Both fields
    /// default to `None` when absent (a JSON `null` also becomes `None`).
    #[serde(rename = "signing_response")]
    SigningResponse {
        id: String,
        #[serde(default)]
        result: Option<Value>,
        #[serde(default)]
        error: Option<SigningError>,
    },
    /// The UI reports a connected wallet address; stored in `AppState::connected_address`.
    #[serde(rename = "account_connected")]
    AccountConnected { address: String },
    /// The UI reports the wallet disconnected; clears `AppState::connected_address`.
    #[serde(rename = "account_disconnected")]
    AccountDisconnected,
}

/// Shared application state between HTTP handlers and WebSocket.
///
/// Handed to the WebSocket handler as `Arc<AppState>`; every mutable field is
/// behind a `DashMap`, a tokio `RwLock`, or a `Notify`, so it can be used from
/// concurrent tasks through a shared reference.
pub struct AppState {
    /// Pending signing requests waiting for frontend response, keyed by the
    /// `SigningRequest::id` sent to the frontend.
    pub pending: DashMap<String, oneshot::Sender<JsonRpcResponse>>,
    /// Currently connected wallet address from frontend (`None` until an
    /// `account_connected` message arrives, and again after disconnect).
    pub connected_address: RwLock<Option<String>>,
    /// Sender half to push messages to the active WebSocket session; `None`
    /// while no frontend session is installed.
    pub ws_sender: RwLock<Option<tokio::sync::mpsc::UnboundedSender<String>>>,
    /// Upstream RPC URL.
    pub rpc_url: String,
    /// Chain ID from upstream RPC (hex string, e.g. "0x1"). When set, it is sent
    /// to each newly connected frontend as a `chain_config` message.
    pub chain_id: RwLock<Option<String>>,
    /// Shared HTTP client for upstream calls.
    pub http_client: reqwest::Client,
    /// Cache: 4-byte selector → function signature string.
    pub selector_cache: DashMap<[u8; 4], String>,
    /// Secret token the frontend must present on the WebSocket handshake.
    pub auth_token: String,
    /// Browser Origins permitted to open the signing WebSocket.
    pub allowed_origins: Vec<String>,
    /// Notified whenever a frontend WebSocket connects, so queued signing
    /// requests can retry sending instead of failing when no UI is open yet.
    pub frontend_connected: Notify,
}

impl AppState {
    /// Create the shared state for a proxy that forwards to `rpc_url`.
    ///
    /// Starts with no frontend session, no pending requests, no connected
    /// address, an unknown chain id and an empty selector cache. A fresh
    /// `reqwest::Client` is built for upstream calls.
    pub fn new(rpc_url: String, auth_token: String, allowed_origins: Vec<String>) -> Self {
        Self {
            pending: DashMap::new(),
            connected_address: RwLock::new(None),
            ws_sender: RwLock::new(None),
            rpc_url,
            chain_id: RwLock::new(None),
            http_client: reqwest::Client::new(),
            selector_cache: DashMap::new(),
            auth_token,
            allowed_origins,
            frontend_connected: Notify::new(),
        }
    }

    /// Send a signing request to the frontend via WebSocket.
    ///
    /// The request is registered in `pending` under `request_id` *before* it is
    /// queued for the socket, so a response that comes back immediately still
    /// finds its waiter.
    ///
    /// Returns a oneshot receiver that resolves with the frontend's response.
    ///
    /// Returns `None` when:
    /// * no frontend session is installed;
    /// * the request cannot be serialized;
    /// * the session's outbound channel is already closed.
    ///
    /// In the last case the `pending` entry is removed again, so a request
    /// nobody can answer does not linger in the map.
    pub async fn send_signing_request(
        &self,
        request_id: String,
        method: String,
        params: Value,
        simulation: Option<crate::proxy::SimulationResult>,
        decoded_calldata: Option<crate::decode::DecodedCalldata>,
    ) -> Option<oneshot::Receiver<JsonRpcResponse>> {
        let ws_sender = self.ws_sender.read().await;
        let sender = ws_sender.as_ref()?;

        let msg = SigningRequest {
            id: request_id.clone(),
            msg_type: "signing_request".into(),
            method,
            params,
            simulation,
            decoded_calldata,
        };

        let json = serde_json::to_string(&msg).ok()?;

        // Register first: the frontend may answer before `send` returns.
        let (tx, rx) = oneshot::channel();
        self.pending.insert(request_id.clone(), tx);

        if sender.send(json).is_err() {
            // The session's forwarding task (and its receiver) is gone, so this
            // request can never be answered; drop the registration instead of
            // leaking it.
            self.pending.remove(&request_id);
            return None;
        }

        Some(rx)
    }
}

/// Handle the WebSocket connection lifecycle.
pub async fn handle_ws_connection(
    state: Arc<AppState>,
    mut session: Session,
    mut msg_stream: actix_ws::MessageStream,
) {
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();

    // Store the sender so HTTP handlers can push messages
    {
        let mut ws = state.ws_sender.write().await;
        *ws = Some(tx);
    }

    crate::log::ws_connected();

    // Wake any signing requests that were queued waiting for a frontend.
    state.frontend_connected.notify_waiters();

    // Send chain config to frontend on connect
    {
        let chain_id = state.chain_id.read().await;
        if let Some(ref cid) = *chain_id {
            let msg = serde_json::json!({
                "type": "chain_config",
                "chainId": cid
            });
            let _ = session
                .text(bytestring::ByteString::from(msg.to_string()))
                .await;
        }
    }

    let state_clone = state.clone();
    let mut session_clone = session.clone();

    // Task: forward outbound messages from channel to WS
    let send_task = tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            if session_clone
                .text(bytestring::ByteString::from(msg))
                .await
                .is_err()
            {
                break;
            }
        }
    });

    // Process incoming messages from frontend
    while let Some(Ok(msg)) = msg_stream.next().await {
        match msg {
            Message::Text(text) => {
                handle_frontend_message(&state_clone, text.to_string()).await;
            }
            Message::Close(reason) => {
                let _ = session.close(reason).await;
                break;
            }
            Message::Ping(bytes) => {
                let _ = session.pong(&bytes).await;
            }
            _ => {}
        }
    }

    // Cleanup
    send_task.abort();
    {
        let mut ws = state.ws_sender.write().await;
        *ws = None;
    }
    {
        let mut addr = state.connected_address.write().await;
        *addr = None;
    }

    crate::log::ws_disconnected();
}

/// Parse one text frame from the frontend and apply it to `state`.
///
/// * `signing_response`: removes the matching entry from `state.pending` and
///   resolves its oneshot with a JSON-RPC response. The response is an error
///   response when the frontend supplied `error`, otherwise a success response.
///   A success with no (or a `null`) result is reported as `result: null`,
///   because JSON-RPC requires `result` on success and serde maps `null` to
///   `None`. Responses for unknown ids are logged and dropped.
/// * `account_connected` / `account_disconnected`: updates
///   `state.connected_address` and logs the change.
///
/// Frames that fail to parse are logged and ignored.
async fn handle_frontend_message(state: &AppState, text: String) {
    let msg: FrontendMessage = match serde_json::from_str(&text) {
        Ok(m) => m,
        Err(e) => {
            eprintln!("  warn: invalid message from frontend: {e}");
            return;
        }
    };

    match msg {
        FrontendMessage::SigningResponse { id, result, error } => {
            let Some((_, tx)) = state.pending.remove(&id) else {
                eprintln!("  warn: signing response for unknown request: {id}");
                return;
            };

            // The original JSON-RPC id is not known here; the caller sets it.
            let response = match error {
                Some(err) => JsonRpcResponse {
                    jsonrpc: "2.0".into(),
                    result: None,
                    error: Some(crate::rpc::JsonRpcError {
                        code: err.code,
                        message: err.message,
                        data: None,
                    }),
                    id: Value::Null,
                },
                None => JsonRpcResponse {
                    jsonrpc: "2.0".into(),
                    // Keep a successful `null` result as an explicit null instead
                    // of losing the `result` member entirely.
                    result: Some(result.unwrap_or(Value::Null)),
                    error: None,
                    id: Value::Null,
                },
            };

            // The requester may already have stopped waiting; nothing to do then.
            let _ = tx.send(response);
        }
        FrontendMessage::AccountConnected { address } => {
            crate::log::wallet_connected(&address);
            *state.connected_address.write().await = Some(address);
        }
        FrontendMessage::AccountDisconnected => {
            crate::log::wallet_disconnected();
            *state.connected_address.write().await = None;
        }
    }
}