// bright_lightning/lnd/websocket.rs

1use std::sync::Arc;
2
3use futures_util::{
4    stream::{SplitSink, SplitStream},
5    SinkExt, StreamExt,
6};
7use httparse::Header;
8use serde::{de::DeserializeOwned, Serialize};
9use tokio::{net::TcpStream, sync::RwLock};
10use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
11
12use crate::{LndError, LndResponse};
13
14type LndWebsocketWriterHalf =
15    Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>;
16type LndWebsocketReaderHalf = Option<SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>;
17
18#[derive(Clone, Default, Debug)]
19pub struct LndWebsocketWriter(Arc<RwLock<LndWebsocketWriterHalf>>);
20impl LndWebsocketWriter {
21    pub fn new(writer: LndWebsocketWriterHalf) -> Self {
22        Self(Arc::new(RwLock::new(writer)))
23    }
24    pub async fn send<S>(&self, message: S) -> anyhow::Result<()>
25    where
26        S: TryInto<String> + Send + Sync + 'static,
27        <S as TryInto<std::string::String>>::Error:
28            std::marker::Send + std::fmt::Debug + std::marker::Sync,
29    {
30        let message_string = message
31            .try_into()
32            .map_err(|_e| anyhow::anyhow!("Could not parse"))?;
33        let message = Message::Text(message_string.into());
34        let mut writer = self.0.write().await;
35        if let Some(writer) = writer.as_mut() {
36            Ok(writer.send(message).await?)
37        } else {
38            Err(anyhow::anyhow!("No writer"))
39        }
40    }
41}
42#[derive(Clone, Default, Debug)]
43pub struct LndWebsocketReader(Arc<RwLock<LndWebsocketReaderHalf>>);
44impl LndWebsocketReader {
45    #[must_use]
46    pub fn new(reader: LndWebsocketReaderHalf) -> Self {
47        Self(Arc::new(RwLock::new(reader)))
48    }
49    pub async fn read<R>(&self) -> Option<LndWebsocketMessage<R>>
50    where
51        R: TryFrom<String>
52            + std::fmt::Display
53            + Send
54            + Sync
55            + 'static
56            + Serialize
57            + DeserializeOwned
58            + Clone,
59        <R as TryFrom<std::string::String>>::Error: std::marker::Send + std::fmt::Debug,
60    {
61        let value = self.0.write().await.as_mut()?.next().await?;
62        match value {
63            Ok(message) => match message {
64                Message::Text(text) => match LndResponse::<R>::try_from(&text.to_string()) {
65                    Ok(response) => Some(LndWebsocketMessage::Response(response.inner())),
66                    Err(_e) => {
67                        let lnd_error = LndError::try_from(text.to_string()).ok()?;
68                        Some(LndWebsocketMessage::Error(lnd_error))
69                    }
70                },
71                Message::Ping(_) => Some(LndWebsocketMessage::Ping),
72                _ => None,
73            },
74            Err(e) => Some(LndWebsocketMessage::Error(
75                LndError::try_from(e.to_string()).unwrap(),
76            )),
77        }
78    }
79}
80
81#[derive(Debug)]
82pub enum LndWebsocketMessage<R> {
83    Response(R),
84    Error(LndError),
85    Ping,
86}
87
88#[derive(Debug, Default)]
89pub struct LndWebsocket {
90    pub receiver: LndWebsocketReader,
91    pub sender: LndWebsocketWriter,
92}
93
94impl LndWebsocket {
95    pub async fn connect(
96        &self,
97        url: String,
98        macaroon: String,
99        request: String,
100    ) -> anyhow::Result<Self> {
101        let random_key = b"dGhlIHNhbXBsZSBub25jZQ2342qdfsdgfsdfg";
102        let mut headers = [
103            Header {
104                name: "Grpc-Metadata-macaroon",
105                value: macaroon.as_bytes(),
106            },
107            Header {
108                name: "Sec-WebSocket-Key",
109                value: random_key,
110            },
111            Header {
112                name: "Host",
113                value: url.as_bytes(),
114            },
115            Header {
116                name: "Connection",
117                value: b"Upgrade",
118            },
119            Header {
120                name: "Upgrade",
121                value: b"websocket",
122            },
123            httparse::Header {
124                name: "Sec-WebSocket-Version",
125                value: b"13",
126            },
127        ];
128        let mut req = httparse::Request::new(&mut headers);
129        req.method = Some("GET");
130        req.path = Some(&request);
131        req.version = Some(1);
132
133        // Prepare the websocket connection with SSL
134        let danger_conf = Some(
135            tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
136                .accept_unmasked_frames(true),
137        );
138
139        let tls = native_tls::TlsConnector::builder()
140            .danger_accept_invalid_certs(true)
141            .build()?;
142        let (ws, _response) = tokio_tungstenite::connect_async_tls_with_config(
143            req,
144            danger_conf,
145            false,
146            Some(tokio_tungstenite::Connector::NativeTls(tls)),
147        )
148        .await?;
149        let (websocket_sender, websocket_reader) = ws.split();
150        let sender = LndWebsocketWriter::new(Some(websocket_sender));
151        let receiver = LndWebsocketReader::new(Some(websocket_reader));
152        Ok(Self { receiver, sender })
153    }
154}
#[cfg(test)]
mod test {

    use super::LndWebsocketMessage;
    use crate::LndHodlInvoiceState;
    use tracing_test::traced_test;

    /// Live integration test: creates an invoice on a real LND node, subscribes to it over the
    /// websocket and succeeds on the first decoded state update.
    ///
    /// Needs network access to `lnd.illuminodes.com` and a local `./admin.macaroon`, so it is
    /// ignored by default; run it with `cargo test -- --ignored`.
    #[tokio::test]
    #[traced_test]
    #[ignore = "requires a live LND node and ./admin.macaroon"]
    async fn check_invoice_paid() -> Result<(), anyhow::Error> {
        let url = "lnd.illuminodes.com";
        let client = crate::lnd::rest_client::LndRestClient::new(url, "./admin.macaroon")?;
        let invoice = client
            .get_invoice(crate::LndInvoiceRequestBody {
                value: 1000.to_string(),
                memo: Some("Hello".to_string()),
            })
            .await?;
        tracing::info!("Invoice: {}", invoice);
        let query = format!(
            "wss://{}/v2/invoices/subscribe/{}",
            url,
            invoice.r_hash_url_safe()
        );
        // Hex-encode the macaroon bytes for the `Grpc-Metadata-macaroon` header.
        let macaroon_hex: String = std::fs::read("./admin.macaroon")?
            .iter()
            .map(|byte| format!("{byte:02x}"))
            .collect();
        let lnd_ws = super::LndWebsocket::default()
            .connect(url.to_string(), macaroon_hex, query)
            .await?;
        loop {
            match lnd_ws.receiver.read::<LndHodlInvoiceState>().await {
                Some(LndWebsocketMessage::Response(state)) => {
                    tracing::info!("State: {}", state);
                    return Ok(());
                }
                Some(LndWebsocketMessage::Ping) => {
                    tracing::info!("Ping");
                }
                Some(LndWebsocketMessage::Error(e)) => {
                    tracing::error!("Error: {}", e);
                    anyhow::bail!("websocket returned an error: {e}");
                }
                None => {
                    tracing::info!("None");
                    anyhow::bail!("websocket yielded no usable message");
                }
            }
        }
    }
}