bright_ln_client/
websocket.rs

1use std::sync::Arc;
2
3use futures_util::{
4    SinkExt, StreamExt,
5    stream::{SplitSink, SplitStream},
6};
7use httparse::Header;
8use serde::{Serialize, de::DeserializeOwned};
9use tokio::{net::TcpStream, sync::RwLock};
10use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
11
12use bright_ln_models::{LndError, LndResponse};
13
14type LndWebsocketWriterHalf =
15    Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>;
16type LndWebsocketReaderHalf = Option<SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>;
17
18#[derive(Debug, thiserror::Error)]
19pub enum LndWebsocketError {
20    #[error("NoWriter")]
21    NoWriter,
22    #[error("UnparseableMessage")]
23    UnparseableMessage,
24    #[error("WebSocketError: {0}")]
25    Tungstenite(Box<tokio_tungstenite::tungstenite::Error>),
26    #[error("TlsError: {0}")]
27    NativeTls(#[from] native_tls::Error),
28}
29
30type Result<T> = std::result::Result<T, LndWebsocketError>;
31
32#[derive(Clone, Default, Debug)]
33pub struct LndWebsocketWriter(Arc<RwLock<LndWebsocketWriterHalf>>);
34impl LndWebsocketWriter {
35    pub fn new(writer: LndWebsocketWriterHalf) -> Self {
36        Self(Arc::new(RwLock::new(writer)))
37    }
38    pub async fn send<S>(&self, message: S) -> Result<()>
39    where
40        S: Send + Sync + 'static,
41        String: TryFrom<S, Error = serde_json::Error>,
42    {
43        let message_string =
44            String::try_from(message).map_err(|_e| LndWebsocketError::UnparseableMessage)?;
45        let message = Message::Text(message_string.into());
46        let mut writer = self.0.write().await;
47        let writer = writer.as_mut().ok_or(LndWebsocketError::NoWriter)?;
48        writer
49            .send(message)
50            .await
51            .map_err(|e| LndWebsocketError::Tungstenite(Box::new(e)))
52    }
53}
54#[derive(Clone, Default, Debug)]
55pub struct LndWebsocketReader(Arc<RwLock<LndWebsocketReaderHalf>>);
56impl LndWebsocketReader {
57    #[must_use]
58    pub fn new(reader: LndWebsocketReaderHalf) -> Self {
59        Self(Arc::new(RwLock::new(reader)))
60    }
61    pub async fn read<R>(&self) -> Option<LndWebsocketMessage<R>>
62    where
63        R: std::str::FromStr + Send + Sync + 'static + Serialize + DeserializeOwned + Clone,
64    {
65        let value = self.0.write().await.as_mut()?.next().await?;
66        match value {
67            Ok(message) => match message {
68                Message::Text(text) => match text.to_string().parse::<LndResponse<R>>() {
69                    Ok(response) => Some(LndWebsocketMessage::Response(response.inner())),
70                    Err(_e) => {
71                        let lnd_error = text.to_string().parse::<LndError>().ok()?;
72                        Some(LndWebsocketMessage::Error(lnd_error))
73                    }
74                },
75                Message::Ping(_) => Some(LndWebsocketMessage::Ping),
76                _ => None,
77            },
78            Err(e) => Some(LndWebsocketMessage::Error(
79                e.to_string().parse::<LndError>().unwrap_or_default(),
80            )),
81        }
82    }
83}
84
85#[derive(Debug)]
86pub enum LndWebsocketMessage<R> {
87    Response(R),
88    Error(LndError),
89    Ping,
90}
91
92#[derive(Debug, Default)]
93pub struct LndWebsocket {
94    pub receiver: LndWebsocketReader,
95    pub sender: LndWebsocketWriter,
96}
97
98impl LndWebsocket {
99    pub async fn connect(&self, url: String, macaroon: String, request: String) -> Result<Self> {
100        let random_key = b"dGhlIHNhbXBsZSBub25jZQ2342qdfsdgfsdfg";
101        let mut headers = [
102            Header {
103                name: "Grpc-Metadata-macaroon",
104                value: macaroon.as_bytes(),
105            },
106            Header {
107                name: "Sec-WebSocket-Key",
108                value: random_key,
109            },
110            Header {
111                name: "Host",
112                value: url.as_bytes(),
113            },
114            Header {
115                name: "Connection",
116                value: b"Upgrade",
117            },
118            Header {
119                name: "Upgrade",
120                value: b"websocket",
121            },
122            httparse::Header {
123                name: "Sec-WebSocket-Version",
124                value: b"13",
125            },
126        ];
127        let mut req = httparse::Request::new(&mut headers);
128        req.method = Some("GET");
129        req.path = Some(&request);
130        req.version = Some(1);
131
132        // Prepare the websocket connection with SSL
133        let danger_conf = Some(
134            tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
135                .accept_unmasked_frames(true),
136        );
137
138        let tls = native_tls::TlsConnector::builder()
139            .danger_accept_invalid_certs(true)
140            .build()?;
141        let (ws, _response) = tokio_tungstenite::connect_async_tls_with_config(
142            req,
143            danger_conf,
144            false,
145            Some(tokio_tungstenite::Connector::NativeTls(tls)),
146        )
147        .await
148        .map_err(|e| LndWebsocketError::Tungstenite(Box::new(e)))?;
149        let (websocket_sender, websocket_reader) = ws.split();
150        let sender = LndWebsocketWriter::new(Some(websocket_sender));
151        let receiver = LndWebsocketReader::new(Some(websocket_reader));
152        Ok(Self { receiver, sender })
153    }
154}
#[cfg(test)]
mod test {

    use super::{LndWebsocketMessage, Result};
    use bright_ln_models::{LndHodlInvoiceState, LndInvoiceRequestBody};
    use std::io::Read;
    use tracing_test::traced_test;

    /// Creates an invoice over REST, subscribes to it over the websocket and
    /// waits for the first state update.
    ///
    /// Needs network access to the node and an admin macaroon on disk.
    /// NOTE(review): the REST client reads `./admin.macaroon` while the websocket
    /// part reads `../admin.macaroon`; confirm both paths are intended.
    #[tokio::test]
    #[traced_test]
    async fn check_invoice_paid() -> Result<()> {
        let url = "lnd.illuminodes.com";
        let rest_client = crate::LndRestClient::new(url, "./admin.macaroon").expect("No client");
        let invoice_request = LndInvoiceRequestBody {
            value: 1000.to_string(),
            memo: Some("Hello".to_string()),
        };
        let invoice = rest_client
            .get_invoice(invoice_request)
            .await
            .expect("No invoice");
        tracing::info!("Invoice: {:#?}", invoice);

        let query = format!(
            "wss://{}/v2/invoices/subscribe/{}",
            url,
            invoice.r_hash_url_safe().expect("No hash")
        );

        // The macaroon is sent hex-encoded (two lowercase hex digits per byte).
        let mut macaroon_bytes = Vec::new();
        std::fs::File::open("../admin.macaroon")
            .expect("No file")
            .read_to_end(&mut macaroon_bytes)
            .expect("No macaroon");
        let macaroon_hex: String = macaroon_bytes
            .iter()
            .map(|byte| format!("{byte:02x}"))
            .collect();

        let lnd_ws = super::LndWebsocket::default()
            .connect(url.to_string(), macaroon_hex, query)
            .await?;

        // Pings are logged and skipped; the first response ends the test,
        // anything else fails it.
        loop {
            let incoming = lnd_ws.receiver.read::<LndHodlInvoiceState>().await;
            match incoming {
                None => {
                    tracing::info!("None");
                    panic!("None");
                }
                Some(LndWebsocketMessage::Ping) => {
                    tracing::info!("Ping");
                }
                Some(LndWebsocketMessage::Error(e)) => {
                    tracing::error!("Error: {:#?}", e);
                    panic!("Error: {:#?}", e);
                }
                Some(LndWebsocketMessage::Response(state)) => {
                    tracing::info!("State: {:#?}", state);
                    break;
                }
            }
        }
        Ok(())
    }
}