chainseeker_server/
web_socket_relay.rs

1use std::sync::Arc;
2use futures_util::{StreamExt, SinkExt};
3use tokio::sync::RwLock;
4use tokio::net::TcpListener;
5use tokio_tungstenite::tungstenite::Message;
6
7#[derive(Debug, Clone)]
8pub struct WebSocketRelay {
9    zmq_endpoint: String,
10    ws_endpoint: String,
11    stop: Arc<RwLock<bool>>,
12    ready: Arc<RwLock<bool>>,
13}
14
15impl WebSocketRelay {
16    pub fn new(zmq_endpoint: &str, ws_endpoint: &str) -> Self {
17        Self {
18            zmq_endpoint: zmq_endpoint.to_string(),
19            ws_endpoint: ws_endpoint.to_string(),
20            stop: Arc::new(RwLock::new(false)),
21            ready: Arc::new(RwLock::new(false)),
22        }
23    }
24    pub async fn run(&self) {
25        let stop = self.stop.clone();
26        tokio::spawn(async move {
27            tokio::signal::ctrl_c().await.expect("Failed to install CTRL+C signal handler.");
28            *stop.write().await = true;
29        });
30        //println!("WebSocketRelay: waiting for a ZeroMQ message...");
31        // Create a WebSocket server.
32        let ws_endpoint = self.ws_endpoint.clone();
33        let (tx, rx) = tokio::sync::watch::channel("".to_string());
34        let ready = self.ready.clone();
35        tokio::spawn(async move {
36            let listener = TcpListener::bind(&ws_endpoint).await.unwrap();
37            println!("WebSocketRelay: listening on {}", ws_endpoint);
38            *ready.write().await = true;
39            loop {
40                if let Ok((stream, _)) = listener.accept().await {
41                    let mut rx = rx.clone();
42                    tokio::spawn(async move {
43                        let addr = stream.peer_addr().unwrap();
44                        let ws_stream = tokio_tungstenite::accept_async(stream).await;
45                        if ws_stream.is_err() {
46                            // Invalid request from client.
47                            return;
48                        }
49                        println!("WebSocketRelay: new connection from {}.", addr);
50                        let (mut write, _read) = ws_stream.unwrap().split();
51                        while rx.changed().await.is_ok() {
52                            let message = (*rx.borrow()).to_string();
53                            match write.send(Message::Text(message)).await {
54                                Ok(_) => {},
55                                // Connection lost.
56                                Err(_) => break,
57                            }
58                        }
59                    });
60                }
61            }
62        });
63        // Connect to ZMQ.
64        let zmq_ctx = zmq::Context::new();
65        let socket = zmq_ctx.socket(zmq::SocketType::SUB).expect("Failed to open a ZeroMQ socket.");
66        socket.connect(&self.zmq_endpoint).expect("Failed to connect to a ZeroMQ endpoint.");
67        socket.set_subscribe(b"hashblock").expect("Failed to subscribe to a ZeroMQ topic.");
68        socket.set_subscribe(b"hashtx").expect("Failed to subscribe to a ZeroMQ topic.");
69        loop {
70            if *self.stop.read().await {
71                break;
72            }
73            let multipart = socket.recv_multipart(zmq::DONTWAIT);
74            match multipart {
75                Ok(multipart) => {
76                    assert_eq!(multipart.len(), 3);
77                    let topic = std::str::from_utf8(&multipart[0]).expect("Failed to decode ZeroMQ topic.").to_string();
78                    let hash = &multipart[1];
79                    //println!("WebSocketRelay: {} {} {}", topic, hex::encode(hash), hex::encode(&multipart[2]));
80                    let json = serde_json::to_string(&vec![topic, hex::encode(hash)]).unwrap();
81                    tx.send(json).unwrap();
82                },
83                Err(_) => {
84                    //println!("WebSockerRelay: failed to receive a message from ZeroMq.");
85                    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
86                    continue;
87                },
88            }
89        }
90        println!("WebSocketRelay stopped.");
91    }
92    pub async fn ready(&self) -> bool {
93        *self.ready.read().await
94    }
95    pub async fn wait_for_ready(&self) {
96        while !self.ready().await {
97            tokio::time::sleep(std::time::Duration::from_millis(1)).await;
98        }
99    }
100    pub async fn stop(&self) {
101        *self.stop.write().await = true;
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Publishes `frames` on `socket` until the relay forwards `expected` to `read`.
    ///
    /// ZeroMQ PUB/SUB drops messages published before the subscriber's subscription has
    /// reached the publisher (the "slow joiner" problem), so a single publish right after
    /// start-up may never be relayed. Republishing avoids that race. Because earlier calls
    /// may have been relayed more than once, messages listed in `stale` are tolerated; any
    /// other message fails the test, as does not seeing `expected` in time.
    async fn publish_until_relayed<S>(
        socket: &zmq::Socket,
        read: &mut S,
        frames: &[Vec<u8>],
        expected: &str,
        stale: &[&str],
    ) where
        S: futures_util::Stream<Item = Result<Message, tokio_tungstenite::tungstenite::Error>>
            + Unpin,
    {
        const ATTEMPTS: usize = 50;
        const WINDOW: std::time::Duration = std::time::Duration::from_millis(100);
        for _ in 0..ATTEMPTS {
            socket.send_multipart(frames.to_vec(), zmq::DONTWAIT).unwrap();
            // Consume everything relayed within the window before publishing again.
            while let Ok(next) = tokio::time::timeout(WINDOW, read.next()).await {
                let data = next.expect("WebSocket closed unexpectedly.").unwrap().into_data();
                let text = String::from_utf8(data).unwrap();
                if text == expected {
                    return;
                }
                assert!(stale.contains(&text.as_str()), "Unexpected WebSocket message: {}", text);
            }
        }
        panic!("{} was not relayed after {} attempts.", expected, ATTEMPTS);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn web_socket_relay() {
        const ZMQ_PORT: u16 = 5555;
        const WS_PORT: u16 = 6666;
        const BLOCK_HASH: &str = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
        const TXID: &str = "fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210";
        // Create ZeroMQ server. Bind by IP: the interface name "lo" only exists on Linux.
        let zmq_ctx = zmq::Context::new();
        let socket = zmq_ctx.socket(zmq::SocketType::PUB).unwrap();
        socket.bind(&format!("tcp://127.0.0.1:{}", ZMQ_PORT)).unwrap();
        println!("ZeroMQ server created.");
        // Run relay.
        let relay = WebSocketRelay::new(
            &format!("tcp://localhost:{}", ZMQ_PORT),
            &format!("localhost:{}", WS_PORT),
        );
        let handle = {
            let relay = relay.clone();
            tokio::spawn(async move {
                relay.run().await;
            })
        };
        // Wait before WebSocketRelay is ready.
        relay.wait_for_ready().await;
        // Create WebSocket client.
        println!("Creating WebSocket client...");
        let (ws_stream, _) = tokio_tungstenite::connect_async(&format!("ws://localhost:{}", WS_PORT))
            .await
            .unwrap();
        let (_write, mut read) = ws_stream.split();
        let block_json = format!("[\"hashblock\",\"{}\"]", BLOCK_HASH);
        let tx_json = format!("[\"hashtx\",\"{}\"]", TXID);
        // Send "hashblock" message.
        println!("Sending \"hashblock\"...");
        let block_frames = vec![
            "hashblock".to_string().into_bytes(),
            hex::decode(BLOCK_HASH).unwrap(),
            0u32.to_le_bytes().to_vec(),
        ];
        publish_until_relayed(&socket, &mut read, &block_frames, &block_json, &[]).await;
        // Send "hashtx" message; late copies of the republished "hashblock" are tolerated.
        println!("Sending \"hashtx\"...");
        let tx_frames = vec![
            "hashtx".to_string().into_bytes(),
            hex::decode(TXID).unwrap(),
            1u32.to_le_bytes().to_vec(),
        ];
        publish_until_relayed(&socket, &mut read, &tx_frames, &tx_json, &[block_json.as_str()])
            .await;
        relay.stop().await;
        handle.await.unwrap();
    }
}