cdk_payment_processor/proto/
server.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::pin::Pin;
4use std::str::FromStr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use cdk_common::payment::{IncomingPaymentOptions, MintPayment};
9use cdk_common::CurrencyUnit;
10use futures::{Stream, StreamExt};
11use lightning::offers::offer::Offer;
12use serde_json::Value;
13use tokio::sync::{mpsc, Notify};
14use tokio::task::JoinHandle;
15use tokio::time::{sleep, Instant};
16use tokio_stream::wrappers::ReceiverStream;
17use tonic::transport::{Certificate, Identity, Server, ServerTlsConfig};
18use tonic::{async_trait, Request, Response, Status};
19use tracing::instrument;
20
21use super::cdk_payment_processor_server::{CdkPaymentProcessor, CdkPaymentProcessorServer};
22use crate::error::Error;
23use crate::proto::*;
24
/// Boxed, pinned, `Send` stream of incoming-payment responses (or gRPC
/// [`Status`] errors); used as the server-streaming response type of
/// `wait_incoming_payment`.
type ResponseStream =
    Pin<Box<dyn Stream<Item = Result<WaitIncomingPaymentResponse, Status>> + Send>>;
27
/// Payment Processor
///
/// gRPC server wrapper that exposes a [`MintPayment`] backend through the
/// `CdkPaymentProcessor` service. The type is `Clone`; clones share the same
/// backend, shutdown notifier and (if set) server task handle through `Arc`s.
#[derive(Clone)]
pub struct PaymentProcessorServer {
    /// Payment backend every RPC is delegated to.
    inner: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
    /// Address the gRPC server listens on.
    socket_addr: SocketAddr,
    /// Notifier used to ask the server task and the payment-stream tasks to stop.
    shutdown: Arc<Notify>,
    /// Join handle of the spawned server task; `None` until `start` has run.
    handle: Option<Arc<JoinHandle<anyhow::Result<()>>>>,
}
36
37impl PaymentProcessorServer {
38    /// Create new [`PaymentProcessorServer`]
39    pub fn new(
40        payment_processor: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
41        addr: &str,
42        port: u16,
43    ) -> anyhow::Result<Self> {
44        let socket_addr = SocketAddr::new(addr.parse()?, port);
45        Ok(Self {
46            inner: payment_processor,
47            socket_addr,
48            shutdown: Arc::new(Notify::new()),
49            handle: None,
50        })
51    }
52
53    /// Start fake wallet grpc server
54    pub async fn start(&mut self, tls_dir: Option<PathBuf>) -> anyhow::Result<()> {
55        tracing::info!("Starting RPC server {}", self.socket_addr);
56
57        let server = match tls_dir {
58            Some(tls_dir) => {
59                tracing::info!("TLS configuration found, starting secure server");
60
61                // Check for server.pem
62                let server_pem_path = tls_dir.join("server.pem");
63                if !server_pem_path.exists() {
64                    let err_msg = format!(
65                        "TLS certificate file not found: {}",
66                        server_pem_path.display()
67                    );
68                    tracing::error!("{}", err_msg);
69                    return Err(anyhow::anyhow!(err_msg));
70                }
71
72                // Check for server.key
73                let server_key_path = tls_dir.join("server.key");
74                if !server_key_path.exists() {
75                    let err_msg = format!("TLS key file not found: {}", server_key_path.display());
76                    tracing::error!("{}", err_msg);
77                    return Err(anyhow::anyhow!(err_msg));
78                }
79
80                // Check for ca.pem
81                let ca_pem_path = tls_dir.join("ca.pem");
82                if !ca_pem_path.exists() {
83                    let err_msg =
84                        format!("CA certificate file not found: {}", ca_pem_path.display());
85                    tracing::error!("{}", err_msg);
86                    return Err(anyhow::anyhow!(err_msg));
87                }
88
89                let cert = std::fs::read_to_string(&server_pem_path)?;
90                let key = std::fs::read_to_string(&server_key_path)?;
91                let client_ca_cert = std::fs::read_to_string(&ca_pem_path)?;
92
93                let client_ca_cert = Certificate::from_pem(client_ca_cert);
94                let server_identity = Identity::from_pem(cert, key);
95                let tls_config = ServerTlsConfig::new()
96                    .identity(server_identity)
97                    .client_ca_root(client_ca_cert);
98
99                Server::builder()
100                    .tls_config(tls_config)?
101                    .add_service(CdkPaymentProcessorServer::new(self.clone()))
102            }
103            None => {
104                tracing::warn!("No valid TLS configuration found, starting insecure server");
105                Server::builder().add_service(CdkPaymentProcessorServer::new(self.clone()))
106            }
107        };
108
109        let shutdown = self.shutdown.clone();
110        let addr = self.socket_addr;
111
112        self.handle = Some(Arc::new(tokio::spawn(async move {
113            let server = server.serve_with_shutdown(addr, async {
114                shutdown.notified().await;
115            });
116
117            server.await?;
118            Ok(())
119        })));
120
121        Ok(())
122    }
123
124    /// Stop fake wallet grpc server
125    pub async fn stop(&self) -> anyhow::Result<()> {
126        const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
127
128        if let Some(handle) = &self.handle {
129            tracing::info!("Initiating server shutdown");
130            self.shutdown.notify_waiters();
131
132            let start = Instant::now();
133
134            while !handle.is_finished() {
135                if start.elapsed() >= SHUTDOWN_TIMEOUT {
136                    tracing::error!(
137                        "Server shutdown timed out after {} seconds, aborting handle",
138                        SHUTDOWN_TIMEOUT.as_secs()
139                    );
140                    handle.abort();
141                    break;
142                }
143                sleep(Duration::from_millis(100)).await;
144            }
145
146            if handle.is_finished() {
147                tracing::info!("Server shutdown completed successfully");
148            }
149        } else {
150            tracing::info!("No server handle found, nothing to stop");
151        }
152
153        Ok(())
154    }
155}
156
157impl Drop for PaymentProcessorServer {
158    fn drop(&mut self) {
159        tracing::debug!("Dropping payment process server");
160        self.shutdown.notify_one();
161    }
162}
163
164#[async_trait]
165impl CdkPaymentProcessor for PaymentProcessorServer {
166    async fn get_settings(
167        &self,
168        _request: Request<EmptyRequest>,
169    ) -> Result<Response<SettingsResponse>, Status> {
170        let settings: Value = self
171            .inner
172            .get_settings()
173            .await
174            .map_err(|_| Status::internal("Could not get settings"))?;
175
176        Ok(Response::new(SettingsResponse {
177            inner: settings.to_string(),
178        }))
179    }
180
181    async fn create_payment(
182        &self,
183        request: Request<CreatePaymentRequest>,
184    ) -> Result<Response<CreatePaymentResponse>, Status> {
185        let CreatePaymentRequest { unit, options } = request.into_inner();
186
187        let unit = CurrencyUnit::from_str(&unit)
188            .map_err(|_| Status::invalid_argument("Invalid currency unit"))?;
189
190        let options = options.ok_or_else(|| Status::invalid_argument("Missing payment options"))?;
191
192        let proto_options = match options
193            .options
194            .ok_or_else(|| Status::invalid_argument("Missing options"))?
195        {
196            incoming_payment_options::Options::Bolt11(opts) => {
197                IncomingPaymentOptions::Bolt11(cdk_common::payment::Bolt11IncomingPaymentOptions {
198                    description: opts.description,
199                    amount: opts.amount.into(),
200                    unix_expiry: opts.unix_expiry,
201                })
202            }
203            incoming_payment_options::Options::Bolt12(opts) => IncomingPaymentOptions::Bolt12(
204                Box::new(cdk_common::payment::Bolt12IncomingPaymentOptions {
205                    description: opts.description,
206                    amount: opts.amount.map(Into::into),
207                    unix_expiry: opts.unix_expiry,
208                }),
209            ),
210        };
211
212        let invoice_response = self
213            .inner
214            .create_incoming_payment_request(&unit, proto_options)
215            .await
216            .map_err(|_| Status::internal("Could not create invoice"))?;
217
218        Ok(Response::new(invoice_response.into()))
219    }
220
221    async fn get_payment_quote(
222        &self,
223        request: Request<PaymentQuoteRequest>,
224    ) -> Result<Response<PaymentQuoteResponse>, Status> {
225        let request = request.into_inner();
226
227        let unit = CurrencyUnit::from_str(&request.unit)
228            .map_err(|_| Status::invalid_argument("Invalid currency unit"))?;
229
230        let options = match request.request_type() {
231            OutgoingPaymentRequestType::Bolt11Invoice => {
232                let bolt11: cdk_common::Bolt11Invoice =
233                    request.request.parse().map_err(Error::Invoice)?;
234
235                cdk_common::payment::OutgoingPaymentOptions::Bolt11(Box::new(
236                    cdk_common::payment::Bolt11OutgoingPaymentOptions {
237                        bolt11,
238                        max_fee_amount: None,
239                        timeout_secs: None,
240                        melt_options: request.options.map(Into::into),
241                    },
242                ))
243            }
244            OutgoingPaymentRequestType::Bolt12Offer => {
245                // Parse offer to verify it's valid, but store as string
246                let _: Offer = request.request.parse().map_err(|_| Error::Bolt12Parse)?;
247
248                cdk_common::payment::OutgoingPaymentOptions::Bolt12(Box::new(
249                    cdk_common::payment::Bolt12OutgoingPaymentOptions {
250                        offer: Offer::from_str(&request.request).unwrap(),
251                        max_fee_amount: None,
252                        timeout_secs: None,
253                        melt_options: request.options.map(Into::into),
254                    },
255                ))
256            }
257        };
258
259        let payment_quote = self
260            .inner
261            .get_payment_quote(&unit, options)
262            .await
263            .map_err(|err| {
264                tracing::error!("Could not get payment quote: {}", err);
265                Status::internal("Could not get quote")
266            })?;
267
268        Ok(Response::new(payment_quote.into()))
269    }
270
271    async fn make_payment(
272        &self,
273        request: Request<MakePaymentRequest>,
274    ) -> Result<Response<MakePaymentResponse>, Status> {
275        let request = request.into_inner();
276
277        let options = request
278            .payment_options
279            .ok_or_else(|| Status::invalid_argument("Missing payment options"))?;
280
281        let (unit, payment_options) = match options
282            .options
283            .ok_or_else(|| Status::invalid_argument("Missing options"))?
284        {
285            outgoing_payment_variant::Options::Bolt11(opts) => {
286                let bolt11: cdk_common::Bolt11Invoice =
287                    opts.bolt11.parse().map_err(Error::Invoice)?;
288
289                let payment_options = cdk_common::payment::OutgoingPaymentOptions::Bolt11(
290                    Box::new(cdk_common::payment::Bolt11OutgoingPaymentOptions {
291                        bolt11,
292                        max_fee_amount: opts.max_fee_amount.map(Into::into),
293                        timeout_secs: opts.timeout_secs,
294                        melt_options: opts.melt_options.map(Into::into),
295                    }),
296                );
297
298                (CurrencyUnit::Msat, payment_options)
299            }
300            outgoing_payment_variant::Options::Bolt12(opts) => {
301                let offer = Offer::from_str(&opts.offer)
302                    .map_err(|_| Error::Bolt12Parse)
303                    .unwrap();
304
305                let payment_options = cdk_common::payment::OutgoingPaymentOptions::Bolt12(
306                    Box::new(cdk_common::payment::Bolt12OutgoingPaymentOptions {
307                        offer,
308                        max_fee_amount: opts.max_fee_amount.map(Into::into),
309                        timeout_secs: opts.timeout_secs,
310                        melt_options: opts.melt_options.map(Into::into),
311                    }),
312                );
313
314                (CurrencyUnit::Msat, payment_options)
315            }
316        };
317
318        let pay_response = self
319            .inner
320            .make_payment(&unit, payment_options)
321            .await
322            .map_err(|err| {
323                tracing::error!("Could not make payment: {}", err);
324
325                match err {
326                    cdk_common::payment::Error::InvoiceAlreadyPaid => {
327                        Status::already_exists("Payment request already paid")
328                    }
329                    cdk_common::payment::Error::InvoicePaymentPending => {
330                        Status::already_exists("Payment request pending")
331                    }
332                    _ => Status::internal("Could not pay invoice"),
333                }
334            })?;
335
336        Ok(Response::new(pay_response.into()))
337    }
338
339    async fn check_incoming_payment(
340        &self,
341        request: Request<CheckIncomingPaymentRequest>,
342    ) -> Result<Response<CheckIncomingPaymentResponse>, Status> {
343        let request = request.into_inner();
344
345        let payment_identifier = request
346            .request_identifier
347            .ok_or_else(|| Status::invalid_argument("Missing request identifier"))?
348            .try_into()
349            .map_err(|_| Status::invalid_argument("Invalid request identifier"))?;
350
351        let check_responses = self
352            .inner
353            .check_incoming_payment_status(&payment_identifier)
354            .await
355            .map_err(|_| Status::internal("Could not check incoming payment status"))?;
356
357        Ok(Response::new(CheckIncomingPaymentResponse {
358            payments: check_responses.into_iter().map(|r| r.into()).collect(),
359        }))
360    }
361
362    async fn check_outgoing_payment(
363        &self,
364        request: Request<CheckOutgoingPaymentRequest>,
365    ) -> Result<Response<MakePaymentResponse>, Status> {
366        let request = request.into_inner();
367
368        let payment_identifier = request
369            .request_identifier
370            .ok_or_else(|| Status::invalid_argument("Missing request identifier"))?
371            .try_into()
372            .map_err(|_| Status::invalid_argument("Invalid request identifier"))?;
373
374        let check_response = self
375            .inner
376            .check_outgoing_payment(&payment_identifier)
377            .await
378            .map_err(|_| Status::internal("Could not check outgoing payment status"))?;
379
380        Ok(Response::new(check_response.into()))
381    }
382
383    type WaitIncomingPaymentStream = ResponseStream;
384
385    #[allow(clippy::incompatible_msrv)]
386    #[instrument(skip_all)]
387    async fn wait_incoming_payment(
388        &self,
389        _request: Request<EmptyRequest>,
390    ) -> Result<Response<Self::WaitIncomingPaymentStream>, Status> {
391        tracing::debug!("Server waiting for payment stream");
392        let (tx, rx) = mpsc::channel(128);
393
394        let shutdown_clone = self.shutdown.clone();
395        let ln = self.inner.clone();
396        tokio::spawn(async move {
397            loop {
398                tokio::select! {
399                    _ = shutdown_clone.notified() => {
400                        tracing::info!("Shutdown signal received, stopping task");
401                        ln.cancel_wait_invoice();
402                        break;
403                    }
404                    result = ln.wait_payment_event() => {
405                        match result {
406                            Ok(mut stream) => {
407                                while let Some(event) = stream.next().await {
408                                    match event {
409                                        cdk_common::payment::Event::PaymentReceived(payment_response) => {
410                                            match tx.send(Result::<_, Status>::Ok(payment_response.into()))
411                                            .await
412                                            {
413                                                Ok(_) => {
414                                                    // Response was queued to be sent to client
415                                                }
416                                                Err(item) => {
417                                                    tracing::error!("Error adding incoming payment to stream: {}", item);
418                                                    break;
419                                                }
420                                            }
421                                        }
422                                    }
423                                }
424                            }
425                            Err(err) => {
426                                tracing::warn!("Could not get invoice stream: {}", err);
427                                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
428                            }
429                        }
430                    }
431                }
432            }
433        });
434
435        let output_stream = ReceiverStream::new(rx);
436        Ok(Response::new(
437            Box::pin(output_stream) as Self::WaitIncomingPaymentStream
438        ))
439    }
440}