ln-websocket-proxy 0.3.0

Websocket-based proxy for connecting to lightning nodes and Mutiny wallets
Documentation
use axum::{
    extract::{
        ws::{Message, WebSocket, WebSocketUpgrade},
        Path, Query, State, TypedHeader,
    },
    response::IntoResponse,
    routing::get,
    Router,
};
use bitcoin_hashes::hex::FromHex;
use bytes::Bytes;
use futures::executor::block_on;
use futures::lock::Mutex;
use ln_websocket_proxy::MutinyProxyCommand;
use serde::Deserialize;
use serde_with::{serde_as, NoneAsEmptyString};
use std::collections::HashMap;
use std::collections::HashSet;
use std::env;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc};
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    net::TcpStream,
};
use tower_http::trace::{DefaultMakeSpan, TraceLayer};

// Length of a compressed secp256k1 public key; peer identifiers and message
// prefixes are this many bytes.
const PUBKEY_BYTES_LEN: usize = 33;

// Maps a websocket owner's identifier to the mpsc sender used to deliver
// commands to that owner's socket and the broadcast sender it uses to
// announce its own disconnection.
pub(crate) type WSMap =
    Arc<Mutex<HashMap<bytes::Bytes, (mpsc::Sender<MutinyWSCommand>, broadcast::Sender<bool>)>>>;

// TODO make all of these required
// can remove serde_with/serde_as afterwards
#[serde_as]
#[derive(Deserialize)]
struct MutinyConnectionParams {
    #[serde_as(as = "NoneAsEmptyString")]
    _message: Option<String>,
    #[serde_as(as = "NoneAsEmptyString")]
    _session_id: Option<String>,
    #[serde_as(as = "NoneAsEmptyString")]
    _signature: Option<String>,
}

#[tokio::main]
async fn main() {
    println!("Running ln-websocket-proxy");
    tracing_subscriber::fmt::init();

    let producer_map: WSMap = Arc::new(Mutex::new(HashMap::new()));

    let app = Router::new()
        .route("/v1/:ip/:port", get(ws_handler))
        .route("/v1/mutiny/:identifier", get(mutiny_ws_handler))
        .with_state(producer_map)
        .layer(
            TraceLayer::new_for_http()
                .make_span_with(DefaultMakeSpan::default().include_headers(true)),
        );

    let port = match env::var("LN_PROXY_PORT") {
        Ok(p) => p.parse().expect("port must be a u16 string"),
        Err(_) => 3001,
    };
    let addr = SocketAddr::from(([0, 0, 0, 0], port));
    tracing::info!("listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
    println!("Stopping ln-websocket-proxy");
}

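// Example (illustrative, not from the crate's docs): with the default port, a
// client that wants to reach a lightning node at 203.0.113.5:9735 would open
//     ws://<proxy-host>:3001/v1/203_0_113_5/9735
// Dots in the IP are encoded as underscores; see format_addr_from_url below.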
async fn ws_handler(
    Path((ip, port)): Path<(String, String)>,
    ws: WebSocketUpgrade,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
) -> impl IntoResponse {
    tracing::info!("ip: {}, port: {}", ip, port);
    if let Some(TypedHeader(user_agent)) = user_agent {
        tracing::info!("`{}` connected", user_agent.as_str());
    }

    ws.protocols(["binary"])
        .on_upgrade(move |socket| handle_socket(socket, ip, port))
}

/// Rebuild a `host:port` address from the URL path segments; dots in the host
/// are encoded as underscores in the path.
fn format_addr_from_url(ip: String, port: String) -> String {
    format!("{}:{}", ip.replace('_', "."), port)
}

// Big help from https://github.com/HsuJv/axum-websockify
async fn handle_socket(mut socket: WebSocket, host: String, port: String) {
    let addr_str = format_addr_from_url(host, port);
    let addrs = addr_str.to_socket_addrs();

    if addrs.is_err() || addrs.as_ref().unwrap().len() == 0 {
        tracing::error!("Could not resolve addr {addr_str}");
        let _ = socket
            .send(Message::Text(format!("Could not resolve addr {addr_str}")))
            .await;
        return;
    }

    let mut addrs = addrs.unwrap();

    let server_stream = addrs.find_map(|addr| {
        let connection = block_on(TcpStream::connect(&addr));
        if let Err(error) = &connection {
            tracing::error!("Could not connect to {addr}: {error}");
        };

        connection.ok()
    });

    if server_stream.is_none() {
        tracing::error!("Could not connect to: {addr_str}");
        let _ = socket
            .send(Message::Text(format!("Could not connect to: {addr_str}")))
            .await;
        return;
    }

    let mut server_stream = server_stream.unwrap();

    let addr = server_stream.peer_addr().unwrap();

    let mut buf = [0u8; 65536]; // the max lightning message size is 65536

    loop {
        tokio::select! {
            res  = socket.recv() => {
                if let Some(msg) = res {
                    if let Ok(Message::Binary(msg)) = msg {
                        tracing::debug!("Received {} bytes, sending to {addr}", msg.len());
                        let _ = server_stream.write_all(&msg).await;
                    }
                } else {
                    tracing::info!("Client close");
                    return;
                }
            },
            res  = server_stream.read(&mut buf) => {
                match res {
                    Ok(n) => {
                        tracing::debug!("Read {:?} from {addr}", n);
                        if 0 != n {
                            let _ = socket.send(Message::Binary(buf[..n].to_vec())).await;
                        } else {
                            return;
                        }
                    },
                    Err(e) => {
                        tracing::info!("Server close with err {:?}", e);
                        return;
                    }
                }
            },
        }
    }
}

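// Example route (illustrative; the identifier is assumed to be the hex of a
// 33-byte node pubkey, as described on handle_mutiny_ws below):
//     ws://<proxy-host>:3001/v1/mutiny/<hex identifier>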
async fn mutiny_ws_handler(
    params: Option<Query<MutinyConnectionParams>>,
    Path(identifier): Path<String>,
    State(state): State<WSMap>,
    ws: WebSocketUpgrade,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
) -> impl IntoResponse {
    tracing::info!("new mutiny websocket handler: {identifier}");
    if let Some(TypedHeader(user_agent)) = user_agent {
        tracing::info!("`{}` connected", user_agent.as_str());
    }

    ws.protocols(["binary"])
        .on_upgrade(move |socket| handle_mutiny_ws(socket, identifier, params, state))
}

/// Commands sent between websocket handler tasks over the channels stored in
/// the WSMap.
#[derive(Debug)]
enum MutinyWSCommand {
    Send { id: Bytes, val: Bytes },
    Disconnect { id: Bytes },
}

/// handle_mutiny_ws will handle mutiny to mutiny (ws to ws) logic.
/// A node pubkey will have a connection URL like: /v1/mutiny/{identifier}
/// where identifier is either going to be arbitrary or based on their node
/// pubkey. Future iterations might want a single identifier for all their
/// nodes. This should be persistent enough to allow others to reconnect.
///
/// Owners:
/// Need to send a message signed with their private key in order to
/// verify that they own the identifier. Afterwards, they will receive
/// all incoming messages. If an owner is already registered, the new
/// connection is killed.
///
/// Sending:
/// Indicate which identifier you would like to message by putting it in
/// the first 33 bytes of the message, followed by the bytes to send. If
/// that peer is not connected or disconnects, the connection should be
/// killed. The proxy will replace those 33 bytes with the identifier of
/// the sender.
///
/// Receiving:
/// You will receive a message whose first 33 bytes are the identifier of
/// the peer that sent it, with the rest of the bytes being the payload.
/// When replying to a received message, set the first 33 bytes to that
/// sender's identifier, i.e. keep the same first 33 bytes.
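///
/// A rough sketch of the framing (illustrative, not taken from the crate's
/// documentation), for peers A and B identified by 33-byte pubkeys:
///
/// ```text
/// A sends to proxy:     [ B's identifier (33 bytes) ][ payload ... ]
/// proxy delivers to B:  [ A's identifier (33 bytes) ][ payload ... ]
/// B replies with:       [ A's identifier (33 bytes) ][ payload ... ]
/// ```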
async fn handle_mutiny_ws(
    mut socket: WebSocket,
    identifier: String,
    _params: Option<Query<MutinyConnectionParams>>,
    state: WSMap,
) {
    // TODO do verification on the params and identifier
    // This is important so that only the node with the
    // private key can read and send messages through
    // this socket.
    #[allow(clippy::redundant_closure)]
    let owner_id_bytes = FromHex::from_hex(identifier.as_str())
        .map(|h: Vec<u8>| bytes::Bytes::from(h))
        .unwrap_or_default();
    if owner_id_bytes.is_empty() {
        tracing::error!("could not parse hex string identifier");
        return;
    }

    // Now create one consumer and a producer that other
    // mutiny websocket connections can reference to send
    // to later. The consumer here is to listen to events
    // that should be sent down the websocket that owns this.
    let (tx, mut rx) = mpsc::channel::<MutinyWSCommand>(32);

    // Create a broadcast channel that this websocket owner can post
    // to in order to indicate that the websocket owner went away and
    // that all previously connected peers need to force a disconnect.
    // The boolean is arbitrary, we just need to send something, consumers
    // should know who this is from and what it means.
    let (bc_tx, _bc_rx1) = broadcast::channel::<bool>(32);

    state
        .lock()
        .await
        .insert(owner_id_bytes.clone(), (tx.clone(), bc_tx.clone()));

    // keep track of the peers that this websocket owner is connected to
    let connected_peers = Arc::new(Mutex::new(HashSet::<bytes::Bytes>::new()));

    tracing::debug!("listening for {identifier} websocket or consumer channel");
    loop {
        tokio::select! {
            // The websocket owner is sending a message to some peer
            // or got disconnected.
            res  = socket.recv() => {
                if let Some(msg) = res {
                    if let Ok(msg_wrapper) = msg {
                        match msg_wrapper {
                            Message::Text(msg) => {
                                let command: MutinyProxyCommand = match serde_json::from_str(&msg) {
                                    Ok(c) => c,
                                    Err(e) => {
                                        tracing::error!("couldn't parse text command from client, ignoring: {e}");
                                        continue;
                                    }
                                };
                                match command {
                                    MutinyProxyCommand::Disconnect { to, from: _from } => {
                                        // ignore the from and take it from our websocket owner
                                        // find out who we are supposed to send this to and get
                                        // producer
                                        let peer_id_bytes = bytes::Bytes::from(to);
                                        if let Some((peer_tx, _bc_tx)) = state.lock().await.get(&peer_id_bytes) {
                                            try_send_disconnect_ws_command(peer_tx.clone(), owner_id_bytes.clone()).await;
                                            connected_peers.lock().await.remove(&peer_id_bytes);
                                        } else {
                                            tracing::error!("peer tried to disconnect from a peer it was not connected to");
                                        }
                                    }
                                };
                            },
                            Message::Binary(msg) => {
                                // parse the first 33 bytes to find the ID to send to
                                if msg.len() < PUBKEY_BYTES_LEN {
                                    tracing::error!("msg not long enough to have pubkey (had {}), ignoring...", msg.len());
                                    continue
                                }
                                let (id_bytes, message_bytes) = msg.split_at(PUBKEY_BYTES_LEN);
                                let peer_id_bytes = bytes::Bytes::from(id_bytes.to_vec());
                                tracing::debug!("received a ws msg from {identifier} to send to {:?}", peer_id_bytes);

                                // find the producer and send down it
                                if let Some((peer_tx, bc_tx)) = state.lock().await.get(&peer_id_bytes) {
                                    match peer_tx.send(MutinyWSCommand::Send { id: owner_id_bytes.clone(), val: bytes::Bytes::from(message_bytes.to_vec()) }).await {
                                        Ok(_) => {
                                            // Keep track that this websocket owner is connected to this
                                            // peer. We will need to know when to send a disconnect cmd
                                            // message back to the websocket owner if this peer goes
                                            // offline.
                                            tracing::debug!("successfully sent msg to {:?}", peer_id_bytes);
                                            listen_for_disconnections(connected_peers.clone(), peer_id_bytes.clone(), bc_tx.subscribe(), tx.clone()).await;
                                        },
                                        Err(e) => {
                                            tracing::error!("could not send message to peer identity: {}", e);
                                            // return a close command, we are having a problem sending
                                            // to the other peer's consumer
                                            try_send_disconnect_ws_command(tx.clone(), peer_id_bytes).await;
                                        },
                                    }
                                } else {
                                    // if no producer, return a close command
                                    tracing::error!("no producer found, sending disconnect");
                                    try_send_disconnect_ws_command(tx.clone(), peer_id_bytes).await;
                                }
                            },
                            _ => {
                                // don't care about pings or others...
                            },
                        };
                    }
                } else {
                    // Websocket owner closed the connection, let's remove the
                    // producer from state. When others try to access producer
                    // again, they will not find it and need to close the conn.
                    //
                    // we should accelerate the disconnection instead of
                    // rely on the next message sent causing a disconnection.
                    try_broadcast_disconnect(bc_tx);
                    state.lock().await.remove(&owner_id_bytes);
                    tracing::info!("Websocket owner closed the connection");
                    return;
                }
            },
            // some peer is trying to send a message to the websocket owner
            // or a disconnection happened and the websocket owner needs to
            // disconnect from that peer.
            res  = rx.recv() => {
                match res {
                    Some(message) => {
                        match message {
                            MutinyWSCommand::Send{id, val} => {
                                tracing::debug!("received a channel msg from {:?} to send to {identifier}", id);
                                // put in first 33 bytes as from ID
                                let mut concat_bytes = id[..].to_vec();
                                let mut val_bytes = val[..].to_vec();
                                concat_bytes.append(&mut val_bytes);
                                match socket.send(Message::Binary(concat_bytes)).await {
                                    Ok(_) => {
                                        // Some other peer has successfully sent a message to this
                                        // websocket owner. We should find the broadcast channel
                                        // for that peer and let this websocket owner listen for
                                        // when it needs to disconnect.
                                        // TODO, but maybe it's not really needed because the
                                        // websocket owner SHOULD send a message back for us to
                                        // consider them connected, in which case the other flow
                                        // should add the listener.
                                        tracing::debug!("sent channel msg down socket from {:?} to {identifier}", id);
                                    },
                                    Err(e) => {
                                        // if we can't send down websocket, kill the connection
                                        // send a disconnection to all peers connected to this peer
                                        tracing::error!("could not send message to ws owner: {}", e);
                                        try_broadcast_disconnect(bc_tx);
                                        state.lock().await.remove(&owner_id_bytes);
                                        return;
                                    },
                                }
                            }
                            MutinyWSCommand::Disconnect{id} => {
                                tracing::debug!("received a channel msg from {:?} to disconnect from {identifier}", id);
                                let disconnect_msg = serde_json::to_string(&MutinyProxyCommand::Disconnect {
                                    to: owner_id_bytes.to_vec(),
                                    from: id.to_vec(),
                                })
                                .unwrap();
                                match socket.send(Message::Text(disconnect_msg)).await {
                                    Ok(_) => (),
                                    Err(e) => {
                                        // if we can't send down websocket, kill the connection
                                        // send a disconnection to all peers connected to this peer
                                        tracing::error!("could not send message to ws owner: {}", e);
                                        try_broadcast_disconnect(bc_tx);
                                        state.lock().await.remove(&owner_id_bytes);
                                        return;
                                    },
                                }
                            }
                        };
                    },
                    None => {
                        // send a disconnection to all peers
                        // that are connected to this peer
                        tracing::info!("channel closed");
                        try_broadcast_disconnect(bc_tx);
                        state.lock().await.remove(&owner_id_bytes);
                        return;
                    }
                }
            },
        }
    }
}

async fn listen_for_disconnections(
    connected_peers: Arc<Mutex<HashSet<bytes::Bytes>>>,
    other_peer: bytes::Bytes,
    mut rx: broadcast::Receiver<bool>,
    tx: mpsc::Sender<MutinyWSCommand>,
) {
    let mut locked_connected_peers = connected_peers.lock().await;
    if locked_connected_peers.contains(&other_peer) {
        return;
    }
    locked_connected_peers.insert(other_peer.clone());
    let listening_connected_peers = connected_peers.clone();
    tokio::spawn(async move {
        match rx.recv().await {
            Ok(_) => {
                // we should send a disconnection message from
                // the other peer to the websocket owner
                // we'll use the websocket command flow since that'll
                // handle the flow just fine
                try_send_disconnect_ws_command(tx.clone(), other_peer.clone()).await;
            }
            Err(e) => {
                // we got an error? well disconnect anyways I guess, but log it!
                tracing::error!(
                    "got an error listening for broadcast messages from {:?}: {}",
                    other_peer,
                    e
                );
                try_send_disconnect_ws_command(tx.clone(), other_peer.clone()).await;
            }
        };
        // should only take one message to know to disconnect
        // so we should remove the peer from owner's connected list
        // this is needed so we can listen again!
        listening_connected_peers.lock().await.remove(&other_peer);
    });
}

fn try_broadcast_disconnect(bc_tx: broadcast::Sender<bool>) {
    match bc_tx.send(true) {
        Ok(_) => (),
        Err(e) => {
            // our best effort was made to inform others that we've
            // disconnected this peer. Log it and move on.
            // We really shouldn't see this happen, would indicate a problem
            // handling channels that we should fix.
            tracing::error!(
                "could not broadcast that we've disconnected websocket owner: {}",
                e
            );
        }
    };
}

async fn try_send_disconnect_ws_command(
    tx: mpsc::Sender<MutinyWSCommand>,
    other_peer: bytes::Bytes,
) {
    match tx
        .send(MutinyWSCommand::Disconnect { id: other_peer })
        .await
    {
        Ok(_) => (),
        Err(e) => {
            tracing::error!("could not send disconnect msg to self: {}", e);
        }
    };
}

#[cfg(test)]
mod tests {
    use crate::format_addr_from_url;

    #[tokio::test]
    async fn test_format_addr_from_url() {
        assert_eq!(
            "127.0.0.1:9000",
            format_addr_from_url(String::from("127_0_0_1"), String::from("9000"))
        )
    }
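
    // A minimal sketch (not part of the original tests) illustrating the
    // 33-byte framing used by handle_mutiny_ws: the first PUBKEY_BYTES_LEN
    // bytes name the peer and the rest is the payload.
    #[test]
    fn test_pubkey_prefix_framing() {
        let peer_id = [2u8; crate::PUBKEY_BYTES_LEN];
        let payload = b"ping".to_vec();

        let mut framed = peer_id.to_vec();
        framed.extend_from_slice(&payload);

        let (id_bytes, message_bytes) = framed.split_at(crate::PUBKEY_BYTES_LEN);
        assert_eq!(id_bytes, &peer_id[..]);
        assert_eq!(message_bytes, payload.as_slice());
    }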
}