use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Path, Query, State, TypedHeader,
},
response::IntoResponse,
routing::get,
Router,
};
use bitcoin_hashes::hex::FromHex;
use bytes::Bytes;
use futures::executor::block_on;
use futures::lock::Mutex;
use ln_websocket_proxy::MutinyProxyCommand;
use serde::Deserialize;
use serde_with::{serde_as, NoneAsEmptyString};
use std::collections::HashMap;
use std::collections::HashSet;
use std::env;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};
use tower_http::trace::{DefaultMakeSpan, TraceLayer};
/// Length, in bytes, of the peer-id prefix carried at the start of every binary
/// Mutiny frame; frames shorter than this are rejected. Presumably a compressed
/// public key — TODO confirm against the client protocol.
const PUBKEY_BYTES_LEN: usize = 33;
/// Routing table shared by every Mutiny connection.
///
/// Maps a connection's decoded identifier bytes to the sender for that
/// connection's command channel and to its disconnect broadcaster. Other
/// connections use the broadcaster to learn when it goes away.
pub(crate) type WSMap =
    Arc<Mutex<HashMap<bytes::Bytes, (mpsc::Sender<MutinyWSCommand>, broadcast::Sender<bool>)>>>;
/// Optional query-string parameters accepted on `/v1/mutiny/:identifier`.
///
/// Each field maps an empty string to `None` (`NoneAsEmptyString`). The values
/// are currently not read by the connection handler.
///
/// NOTE(review): serde uses the Rust field names as keys, so these are read
/// from `_message`, `_session_id` and `_signature` (leading underscore
/// included). Confirm that this matches what clients actually send.
#[serde_as]
#[derive(Deserialize)]
struct MutinyConnectionParams {
    /// Raw `_message` query value, if non-empty.
    #[serde_as(as = "NoneAsEmptyString")]
    _message: Option<String>,
    /// Raw `_session_id` query value, if non-empty.
    #[serde_as(as = "NoneAsEmptyString")]
    _session_id: Option<String>,
    /// Raw `_signature` query value, if non-empty.
    #[serde_as(as = "NoneAsEmptyString")]
    _signature: Option<String>,
}
/// Starts the proxy server on `0.0.0.0:<port>`.
///
/// The port is read from `LN_PROXY_PORT`, falling back to 3001 when the
/// variable is unset (or not valid unicode).
///
/// Routes:
/// * `/v1/:ip/:port` bridges a WebSocket to a raw TCP endpoint (`ws_handler`).
/// * `/v1/mutiny/:identifier` relays between Mutiny clients
///   (`mutiny_ws_handler`). All connections share one routing map.
///
/// # Panics
/// Panics if `LN_PROXY_PORT` holds a value that does not parse as a `u16`, or
/// if the server fails.
#[tokio::main]
async fn main() {
    println!("Running ln-websocket-proxy");
    tracing_subscriber::fmt::init();

    // Per-request tracing spans that include the request headers.
    let trace_layer = TraceLayer::new_for_http()
        .make_span_with(DefaultMakeSpan::default().include_headers(true));
    let producer_map: WSMap = Arc::new(Mutex::new(HashMap::new()));
    let app = Router::new()
        .route("/v1/:ip/:port", get(ws_handler))
        .route("/v1/mutiny/:identifier", get(mutiny_ws_handler))
        .with_state(producer_map)
        .layer(trace_layer);

    let port: u16 = env::var("LN_PROXY_PORT")
        .map_or(3001, |raw| raw.parse().expect("port must be a u16 string"));
    let addr = SocketAddr::from(([0, 0, 0, 0], port));
    tracing::info!("listening on {addr}");

    let server = axum::Server::bind(&addr).serve(app.into_make_service());
    server.await.unwrap();
    println!("Stopping websocket-tcp-proxy");
}
/// Handles `/v1/:ip/:port`.
///
/// Logs the target and the client's user agent (when present), then upgrades
/// the request to a WebSocket using the `binary` subprotocol. The socket is
/// bridged to `ip:port` by `handle_socket`.
async fn ws_handler(
    Path((ip, port)): Path<(String, String)>,
    ws: WebSocketUpgrade,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
) -> impl IntoResponse {
    tracing::info!("ip: {ip}, port: {port}");
    if let Some(TypedHeader(agent)) = &user_agent {
        tracing::info!("`{}` connected", agent.as_str());
    }
    let upgrade = ws.protocols(["binary"]);
    upgrade.on_upgrade(move |socket| handle_socket(socket, ip, port))
}
/// Builds a `host:port` string from URL path segments.
///
/// The host segment encodes `.` as `_` so it fits in a single path component.
/// Every `_` is turned back into `.`; the port is appended unchanged.
/// For example, `127_0_0_1` + `9000` -> `127.0.0.1:9000`.
fn format_addr_from_url(ip: String, port: String) -> String {
    let host: String = ip
        .chars()
        .map(|c| if c == '_' { '.' } else { c })
        .collect();
    format!("{host}:{port}")
}
async fn handle_socket(mut socket: WebSocket, host: String, port: String) {
let addr_str = format_addr_from_url(host, port);
let addrs = addr_str.to_socket_addrs();
if addrs.is_err() || addrs.as_ref().unwrap().len() == 0 {
tracing::error!("Could not resolve addr {addr_str}");
let _ = socket
.send(Message::Text(format!("Could not resolve addr {addr_str}")))
.await;
return;
}
let mut addrs = addrs.unwrap();
let server_stream = addrs.find_map(|addr| {
let connection = block_on(TcpStream::connect(&addr));
if let Err(error) = &connection {
tracing::error!("Could not connect to {addr}: {error}");
};
connection.ok()
});
if server_stream.is_none() {
tracing::error!("Could not connect to: {addr_str}");
let _ = socket
.send(Message::Text(format!("Could not connect to: {addr_str}")))
.await;
return;
}
let mut server_stream = server_stream.unwrap();
let addr = server_stream.peer_addr().unwrap();
let mut buf = [0u8; 65536];
loop {
tokio::select! {
res = socket.recv() => {
if let Some(msg) = res {
if let Ok(Message::Binary(msg)) = msg {
tracing::debug!("Received {}, sending to {addr}", &msg.len());
let _ = server_stream.write_all(&msg).await;
}
} else {
tracing::info!("Client close");
return;
}
},
res = server_stream.read(&mut buf) => {
match res {
Ok(n) => {
tracing::debug!("Read {:?} from {addr}", n);
if 0 != n {
let _ = socket.send(Message::Binary(buf[..n].to_vec())).await;
} else {
return ;
}
},
Err(e) => {
tracing::info!("Server close with err {:?}", e);
return;
}
}
},
}
}
}
/// Handles `/v1/mutiny/:identifier`.
///
/// Logs the identifier and the client's user agent (when present), then
/// upgrades to a WebSocket with the `binary` subprotocol. The socket is handed
/// to `handle_mutiny_ws` together with the optional query parameters and the
/// shared routing map.
async fn mutiny_ws_handler(
    params: Option<Query<MutinyConnectionParams>>,
    Path(identifier): Path<String>,
    State(state): State<WSMap>,
    ws: WebSocketUpgrade,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
) -> impl IntoResponse {
    tracing::info!("new mutiny websocket handler: {}", identifier);
    if let Some(TypedHeader(agent)) = &user_agent {
        tracing::info!("`{}` connected", agent.as_str());
    }
    let upgrade = ws.protocols(["binary"]);
    upgrade.on_upgrade(move |socket| handle_mutiny_ws(socket, identifier, params, state))
}
/// Commands queued on a Mutiny connection's channel.
///
/// Peers queue them, and so do the connection's own disconnect listeners. The
/// connection's loop turns each one into a frame sent to its client.
#[derive(Debug)]
enum MutinyWSCommand {
    /// Forward `val` to the client as a binary frame prefixed with the sender `id`.
    Send { id: Bytes, val: Bytes },
    /// Tell the client, via a JSON text frame, that peer `id` disconnected.
    Disconnect { id: Bytes },
}
async fn handle_mutiny_ws(
mut socket: WebSocket,
identifier: String,
_params: Option<Query<MutinyConnectionParams>>,
state: WSMap,
) {
#[allow(clippy::redundant_closure)]
let owner_id_bytes = FromHex::from_hex(identifier.as_str())
.map(|h: Vec<u8>| bytes::Bytes::from(h))
.unwrap_or_default();
if owner_id_bytes.is_empty() {
tracing::error!("could not parse hex string identifier");
return;
}
let (tx, mut rx) = mpsc::channel::<MutinyWSCommand>(32);
let (bc_tx, _bc_rx1) = broadcast::channel::<bool>(32);
state
.lock()
.await
.insert(owner_id_bytes.clone(), (tx.clone(), bc_tx.clone()));
let connected_peers = Arc::new(Mutex::new(HashSet::<bytes::Bytes>::new()));
tracing::debug!("listening for {identifier} websocket or consumer channel",);
loop {
tokio::select! {
res = socket.recv() => {
if let Some(msg) = res {
if let Ok(msg_wrapper) = msg {
match msg_wrapper {
Message::Text(msg) => {
let command: MutinyProxyCommand = match serde_json::from_str(&msg) {
Ok(c) => c,
Err(e) => {
tracing::error!("couldn't parse text command from client, ignoring: {e}");
continue;
}
};
match command {
MutinyProxyCommand::Disconnect { to, from: _from } => {
let peer_id_bytes = bytes::Bytes::from(to);
if let Some((peer_tx, _bc_tx)) = state.lock().await.get(&peer_id_bytes) {
try_send_disconnect_ws_command(peer_tx.clone(), owner_id_bytes.clone()).await;
connected_peers.lock().await.remove(&peer_id_bytes);
} else {
tracing::error!("peer tried to disconnect someone not connected to");
}
}
};
},
Message::Binary(msg) => {
if msg.len() < PUBKEY_BYTES_LEN {
tracing::error!("msg not long enough to have pubkey (had {}), ignoring...", msg.len());
continue
}
let (id_bytes, message_bytes) = msg.split_at(PUBKEY_BYTES_LEN);
let peer_id_bytes = bytes::Bytes::from(id_bytes.to_vec());
tracing::debug!("received a ws msg from {identifier} to send to {:?}", peer_id_bytes);
if let Some((peer_tx, bc_tx)) = state.lock().await.get(&peer_id_bytes) {
match peer_tx.send(MutinyWSCommand::Send { id: owner_id_bytes.clone(), val: bytes::Bytes::from(message_bytes.to_vec()) }).await {
Ok(_) => {
tracing::debug!("successfully sent msg to {:?}", peer_id_bytes);
listen_for_disconnections(connected_peers.clone(), peer_id_bytes.clone(), bc_tx.subscribe(), tx.clone()).await;
},
Err(e) => {
tracing::error!("could not send message to peer identity: {}", e);
try_send_disconnect_ws_command(tx.clone(), peer_id_bytes).await;
},
}
} else {
tracing::error!("no producer found, sending disconnect");
try_send_disconnect_ws_command(tx.clone(), peer_id_bytes).await;
}
},
_ => {
},
};
}
} else {
try_broadcast_disconnect(bc_tx);
state.lock().await.remove(&owner_id_bytes);
tracing::info!("Websocket owner closed the connection");
return;
}
},
res = rx.recv() => {
match res {
Some(message) => {
match message {
MutinyWSCommand::Send{id, val} => {
tracing::debug!("received a channel msg from {:?} to send to {identifier}", id);
let mut concat_bytes = id[..].to_vec();
let mut val_bytes = val[..].to_vec();
concat_bytes.append(&mut val_bytes);
match socket.send(Message::Binary(concat_bytes)).await {
Ok(_) => {
tracing::debug!("sent channel msg down socket from {:?} to to {identifier}", id);
},
Err(e) => {
tracing::error!("could not send message to ws owner: {}", e);
try_broadcast_disconnect(bc_tx);
state.lock().await.remove(&owner_id_bytes);
return;
},
}
}
MutinyWSCommand::Disconnect{id} => {
tracing::debug!("received a channel msg from {:?} to disconnect from {identifier}", id);
match socket.send(Message::Text(serde_json::to_string(&MutinyProxyCommand::Disconnect{to: owner_id_bytes.to_vec(), from: id.to_vec()}).unwrap())).await {
Ok(_) => (),
Err(e) => {
tracing::error!("could not send message to ws owner: {}", e);
try_broadcast_disconnect(bc_tx);
state.lock().await.remove(&owner_id_bytes);
return;
},
}
}
};
},
None => {
tracing::info!("channel closed");
try_broadcast_disconnect(bc_tx);
state.lock().await.remove(&owner_id_bytes);
return;
}
}
},
}
}
}
async fn listen_for_disconnections(
connected_peers: Arc<Mutex<HashSet<bytes::Bytes>>>,
other_peer: bytes::Bytes,
mut rx: broadcast::Receiver<bool>,
tx: mpsc::Sender<MutinyWSCommand>,
) {
let mut locked_connected_peers = connected_peers.lock().await;
if locked_connected_peers.contains(&other_peer) {
return;
}
locked_connected_peers.insert(other_peer.clone());
let listening_connected_peers = connected_peers.clone();
tokio::spawn(async move {
match rx.recv().await {
Ok(_) => {
try_send_disconnect_ws_command(tx.clone(), other_peer.clone()).await;
}
Err(e) => {
tracing::error!(
"got an error listening for broadcast messages from {:?}: {}",
other_peer,
e
);
try_send_disconnect_ws_command(tx.clone(), other_peer.clone()).await;
}
};
listening_connected_peers.lock().await.remove(&other_peer);
});
}
/// Broadcasts `true` on `bc_tx` to tell subscribers that this websocket owner
/// has disconnected.
///
/// A send error (tokio reports one when there are no live receivers) is
/// logged rather than returned.
fn try_broadcast_disconnect(bc_tx: broadcast::Sender<bool>) {
    if let Err(e) = bc_tx.send(true) {
        tracing::error!(
            "could not broadcast that we've disconnected websocket owner: {}",
            e
        );
    }
}
/// Queues `MutinyWSCommand::Disconnect { id: other_peer }` on `tx`.
///
/// The send waits for channel capacity. If the receiving side is gone, the
/// failure is logged and otherwise ignored.
async fn try_send_disconnect_ws_command(
    tx: mpsc::Sender<MutinyWSCommand>,
    other_peer: bytes::Bytes,
) {
    let command = MutinyWSCommand::Disconnect { id: other_peer };
    if let Err(e) = tx.send(command).await {
        tracing::error!("could not send disconnect msg to self: {}", e);
    }
}
#[cfg(test)]
mod tests {
    use super::format_addr_from_url;

    /// Underscores in the host path segment are turned back into dots.
    #[test]
    fn test_format_addr_from_url() {
        let formatted = format_addr_from_url("127_0_0_1".to_string(), "9000".to_string());
        assert_eq!("127.0.0.1:9000", formatted);
    }
}