// cdk_payment_processor/proto/server.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::pin::Pin;
4use std::str::FromStr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use cdk_common::grpc::create_version_check_interceptor;
9use cdk_common::payment::{IncomingPaymentOptions, MintPayment};
10use cdk_common::CurrencyUnit;
11use futures::{Stream, StreamExt};
12use lightning::offers::offer::Offer;
13use tokio::sync::{mpsc, Notify};
14use tokio::task::JoinHandle;
15use tokio::time::{sleep, Instant};
16use tokio_stream::wrappers::ReceiverStream;
17use tonic::transport::{Certificate, Identity, Server, ServerTlsConfig};
18use tonic::{async_trait, Request, Response, Status};
19use tracing::instrument;
20
21use super::cdk_payment_processor_server::{CdkPaymentProcessor, CdkPaymentProcessorServer};
22use crate::error::Error;
23use crate::proto::*;
24
/// Boxed, sendable stream of incoming-payment notifications; used as the
/// `WaitIncomingPaymentStream` type of the `wait_incoming_payment` RPC.
type ResponseStream =
    Pin<Box<dyn Stream<Item = Result<WaitIncomingPaymentResponse, Status>> + Send>>;
27
/// Payment Processor
///
/// gRPC server that exposes a [`MintPayment`] backend through the
/// `CdkPaymentProcessor` service.
#[derive(Clone)]
pub struct PaymentProcessorServer {
    // Backend that performs the actual payment operations.
    inner: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
    // Address the gRPC server binds to when started.
    socket_addr: SocketAddr,
    // Shutdown signal shared (via clones) by the server task and every
    // payment-stream task spawned by `wait_incoming_payment`.
    shutdown: Arc<Notify>,
    // Handle of the spawned server task; `None` until `start` is called.
    handle: Option<Arc<JoinHandle<anyhow::Result<()>>>>,
}
36
37impl std::fmt::Debug for PaymentProcessorServer {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("PaymentProcessorServer")
40            .field("socket_addr", &self.socket_addr)
41            .finish_non_exhaustive()
42    }
43}
44
45impl PaymentProcessorServer {
46    /// Create new [`PaymentProcessorServer`]
47    pub fn new(
48        payment_processor: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
49        addr: &str,
50        port: u16,
51    ) -> anyhow::Result<Self> {
52        let socket_addr = SocketAddr::new(addr.parse()?, port);
53        Ok(Self {
54            inner: payment_processor,
55            socket_addr,
56            shutdown: Arc::new(Notify::new()),
57            handle: None,
58        })
59    }
60
61    /// Start fake wallet grpc server
62    pub async fn start(&mut self, tls_dir: Option<PathBuf>) -> anyhow::Result<()> {
63        tracing::info!("Starting RPC server {}", self.socket_addr);
64
65        let server = match tls_dir {
66            Some(tls_dir) => {
67                tracing::info!("TLS configuration found, starting secure server");
68
69                // Check for server.pem
70                let server_pem_path = tls_dir.join("server.pem");
71                if !server_pem_path.exists() {
72                    let err_msg = format!(
73                        "TLS certificate file not found: {}",
74                        server_pem_path.display()
75                    );
76                    tracing::error!("{}", err_msg);
77                    return Err(anyhow::anyhow!(err_msg));
78                }
79
80                // Check for server.key
81                let server_key_path = tls_dir.join("server.key");
82                if !server_key_path.exists() {
83                    let err_msg = format!("TLS key file not found: {}", server_key_path.display());
84                    tracing::error!("{}", err_msg);
85                    return Err(anyhow::anyhow!(err_msg));
86                }
87
88                // Check for ca.pem
89                let ca_pem_path = tls_dir.join("ca.pem");
90                if !ca_pem_path.exists() {
91                    let err_msg =
92                        format!("CA certificate file not found: {}", ca_pem_path.display());
93                    tracing::error!("{}", err_msg);
94                    return Err(anyhow::anyhow!(err_msg));
95                }
96
97                let cert = std::fs::read_to_string(&server_pem_path)?;
98                let key = std::fs::read_to_string(&server_key_path)?;
99                let client_ca_cert = std::fs::read_to_string(&ca_pem_path)?;
100
101                let client_ca_cert = Certificate::from_pem(client_ca_cert);
102                let server_identity = Identity::from_pem(cert, key);
103                let tls_config = ServerTlsConfig::new()
104                    .identity(server_identity)
105                    .client_ca_root(client_ca_cert);
106
107                Server::builder().tls_config(tls_config)?.add_service(
108                    CdkPaymentProcessorServer::with_interceptor(
109                        self.clone(),
110                        create_version_check_interceptor(
111                            cdk_common::PAYMENT_PROCESSOR_PROTOCOL_VERSION,
112                        ),
113                    ),
114                )
115            }
116            None => {
117                tracing::warn!("No valid TLS configuration found, starting insecure server");
118                Server::builder().add_service(CdkPaymentProcessorServer::with_interceptor(
119                    self.clone(),
120                    create_version_check_interceptor(
121                        cdk_common::PAYMENT_PROCESSOR_PROTOCOL_VERSION,
122                    ),
123                ))
124            }
125        };
126
127        let shutdown = self.shutdown.clone();
128        let addr = self.socket_addr;
129
130        self.handle = Some(Arc::new(tokio::spawn(async move {
131            let server = server.serve_with_shutdown(addr, async {
132                shutdown.notified().await;
133            });
134
135            server.await?;
136            Ok(())
137        })));
138
139        Ok(())
140    }
141
142    /// Stop fake wallet grpc server
143    pub async fn stop(&self) -> anyhow::Result<()> {
144        const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
145
146        if let Some(handle) = &self.handle {
147            tracing::info!("Initiating server shutdown");
148            self.shutdown.notify_waiters();
149
150            let start = Instant::now();
151
152            while !handle.is_finished() {
153                if start.elapsed() >= SHUTDOWN_TIMEOUT {
154                    tracing::error!(
155                        "Server shutdown timed out after {} seconds, aborting handle",
156                        SHUTDOWN_TIMEOUT.as_secs()
157                    );
158                    handle.abort();
159                    break;
160                }
161                sleep(Duration::from_millis(100)).await;
162            }
163
164            if handle.is_finished() {
165                tracing::info!("Server shutdown completed successfully");
166            }
167        } else {
168            tracing::info!("No server handle found, nothing to stop");
169        }
170
171        Ok(())
172    }
173}
174
175impl Drop for PaymentProcessorServer {
176    fn drop(&mut self) {
177        tracing::debug!("Dropping payment process server");
178        self.shutdown.notify_one();
179    }
180}
181
182#[async_trait]
183impl CdkPaymentProcessor for PaymentProcessorServer {
184    async fn get_settings(
185        &self,
186        _request: Request<EmptyRequest>,
187    ) -> Result<Response<SettingsResponse>, Status> {
188        let settings = self
189            .inner
190            .get_settings()
191            .await
192            .map_err(|_| Status::internal("Could not get settings"))?;
193
194        Ok(Response::new(SettingsResponse {
195            unit: settings.unit,
196            bolt11: settings.bolt11.map(|b| super::Bolt11Settings {
197                mpp: b.mpp,
198                amountless: b.amountless,
199                invoice_description: b.invoice_description,
200            }),
201            bolt12: settings.bolt12.map(|b| super::Bolt12Settings {
202                amountless: b.amountless,
203            }),
204            custom: settings.custom,
205        }))
206    }
207
208    async fn create_payment(
209        &self,
210        request: Request<CreatePaymentRequest>,
211    ) -> Result<Response<CreatePaymentResponse>, Status> {
212        let CreatePaymentRequest { unit, options } = request.into_inner();
213
214        let unit = CurrencyUnit::from_str(&unit)
215            .map_err(|_| Status::invalid_argument("Invalid currency unit"))?;
216
217        let options = options.ok_or_else(|| Status::invalid_argument("Missing payment options"))?;
218
219        let proto_options = match options
220            .options
221            .ok_or_else(|| Status::invalid_argument("Missing options"))?
222        {
223            incoming_payment_options::Options::Custom(opts) => IncomingPaymentOptions::Custom(
224                Box::new(cdk_common::payment::CustomIncomingPaymentOptions {
225                    method: "".to_string(),
226                    description: opts.description,
227                    amount: opts.amount.unwrap_or(0).into(),
228                    unix_expiry: opts.unix_expiry,
229                    extra_json: opts.extra_json,
230                }),
231            ),
232            incoming_payment_options::Options::Bolt11(opts) => {
233                IncomingPaymentOptions::Bolt11(cdk_common::payment::Bolt11IncomingPaymentOptions {
234                    description: opts.description,
235                    amount: opts.amount.into(),
236                    unix_expiry: opts.unix_expiry,
237                })
238            }
239            incoming_payment_options::Options::Bolt12(opts) => IncomingPaymentOptions::Bolt12(
240                Box::new(cdk_common::payment::Bolt12IncomingPaymentOptions {
241                    description: opts.description,
242                    amount: opts.amount.map(Into::into),
243                    unix_expiry: opts.unix_expiry,
244                }),
245            ),
246        };
247
248        let invoice_response = self
249            .inner
250            .create_incoming_payment_request(&unit, proto_options)
251            .await
252            .map_err(|_| Status::internal("Could not create invoice"))?;
253
254        Ok(Response::new(invoice_response.into()))
255    }
256
257    async fn get_payment_quote(
258        &self,
259        request: Request<PaymentQuoteRequest>,
260    ) -> Result<Response<PaymentQuoteResponse>, Status> {
261        let request = request.into_inner();
262
263        let unit = CurrencyUnit::from_str(&request.unit)
264            .map_err(|_| Status::invalid_argument("Invalid currency unit"))?;
265
266        let options = match request.request_type() {
267            OutgoingPaymentRequestType::Bolt11Invoice => {
268                let bolt11: cdk_common::Bolt11Invoice =
269                    request.request.parse().map_err(Error::Invoice)?;
270
271                cdk_common::payment::OutgoingPaymentOptions::Bolt11(Box::new(
272                    cdk_common::payment::Bolt11OutgoingPaymentOptions {
273                        bolt11,
274                        max_fee_amount: None,
275                        timeout_secs: None,
276                        melt_options: request.options.map(Into::into),
277                    },
278                ))
279            }
280            OutgoingPaymentRequestType::Bolt12Offer => {
281                // Parse offer to verify it's valid, but store as string
282                let _: Offer = request.request.parse().map_err(|_| Error::Bolt12Parse)?;
283
284                cdk_common::payment::OutgoingPaymentOptions::Bolt12(Box::new(
285                    cdk_common::payment::Bolt12OutgoingPaymentOptions {
286                        offer: Offer::from_str(&request.request)
287                            .expect("Already validated offer above"),
288                        max_fee_amount: None,
289                        timeout_secs: None,
290                        melt_options: request.options.map(Into::into),
291                    },
292                ))
293            }
294            OutgoingPaymentRequestType::Custom => {
295                // Custom payment method - pass request as-is with no validation
296                cdk_common::payment::OutgoingPaymentOptions::Custom(Box::new(
297                    cdk_common::payment::CustomOutgoingPaymentOptions {
298                        method: String::new(), // Will be set from variant
299                        request: request.request.clone(),
300                        max_fee_amount: None,
301                        timeout_secs: None,
302                        melt_options: request.options.map(Into::into),
303                        extra_json: request.extra_json.clone(),
304                    },
305                ))
306            }
307            OutgoingPaymentRequestType::Unspecified => {
308                return Err(Status::invalid_argument("Unspecified payment request type"));
309            }
310        };
311
312        let payment_quote = self
313            .inner
314            .get_payment_quote(&unit, options)
315            .await
316            .map_err(|err| {
317                tracing::error!("Could not get payment quote: {}", err);
318                Status::internal("Could not get quote")
319            })?;
320
321        Ok(Response::new(payment_quote.into()))
322    }
323
324    async fn make_payment(
325        &self,
326        request: Request<MakePaymentRequest>,
327    ) -> Result<Response<MakePaymentResponse>, Status> {
328        let request = request.into_inner();
329
330        let options = request
331            .payment_options
332            .ok_or_else(|| Status::invalid_argument("Missing payment options"))?;
333
334        let (unit, payment_options) = match options
335            .options
336            .ok_or_else(|| Status::invalid_argument("Missing options"))?
337        {
338            outgoing_payment_variant::Options::Bolt11(opts) => {
339                let bolt11: cdk_common::Bolt11Invoice =
340                    opts.bolt11.parse().map_err(Error::Invoice)?;
341
342                let payment_options = cdk_common::payment::OutgoingPaymentOptions::Bolt11(
343                    Box::new(cdk_common::payment::Bolt11OutgoingPaymentOptions {
344                        bolt11,
345                        max_fee_amount: opts.max_fee_amount.map(Into::into),
346                        timeout_secs: opts.timeout_secs,
347                        melt_options: opts.melt_options.map(Into::into),
348                    }),
349                );
350
351                (CurrencyUnit::Msat, payment_options)
352            }
353            outgoing_payment_variant::Options::Bolt12(opts) => {
354                let offer = Offer::from_str(&opts.offer).map_err(|_| Error::Bolt12Parse)?;
355
356                let payment_options = cdk_common::payment::OutgoingPaymentOptions::Bolt12(
357                    Box::new(cdk_common::payment::Bolt12OutgoingPaymentOptions {
358                        offer,
359                        max_fee_amount: opts.max_fee_amount.map(Into::into),
360                        timeout_secs: opts.timeout_secs,
361                        melt_options: opts.melt_options.map(Into::into),
362                    }),
363                );
364
365                (CurrencyUnit::Msat, payment_options)
366            }
367            outgoing_payment_variant::Options::Custom(opts) => {
368                let payment_options = cdk_common::payment::OutgoingPaymentOptions::Custom(
369                    Box::new(cdk_common::payment::CustomOutgoingPaymentOptions {
370                        method: String::new(), // Method will be determined from context
371                        request: opts.offer,   // Reusing offer field for custom request string
372                        max_fee_amount: opts.max_fee_amount.map(Into::into),
373                        timeout_secs: opts.timeout_secs,
374                        melt_options: opts.melt_options.map(Into::into),
375                        extra_json: opts.extra_json,
376                    }),
377                );
378
379                (CurrencyUnit::Msat, payment_options)
380            }
381        };
382
383        let pay_response = self
384            .inner
385            .make_payment(&unit, payment_options)
386            .await
387            .map_err(|err| {
388                tracing::error!("Could not make payment: {}", err);
389
390                match err {
391                    cdk_common::payment::Error::InvoiceAlreadyPaid => {
392                        Status::already_exists("Payment request already paid")
393                    }
394                    cdk_common::payment::Error::InvoicePaymentPending => {
395                        Status::already_exists("Payment request pending")
396                    }
397                    _ => Status::internal("Could not pay invoice"),
398                }
399            })?;
400
401        Ok(Response::new(pay_response.into()))
402    }
403
404    async fn check_incoming_payment(
405        &self,
406        request: Request<CheckIncomingPaymentRequest>,
407    ) -> Result<Response<CheckIncomingPaymentResponse>, Status> {
408        let request = request.into_inner();
409
410        let payment_identifier = request
411            .request_identifier
412            .ok_or_else(|| Status::invalid_argument("Missing request identifier"))?
413            .try_into()
414            .map_err(|_| Status::invalid_argument("Invalid request identifier"))?;
415
416        let check_responses = self
417            .inner
418            .check_incoming_payment_status(&payment_identifier)
419            .await
420            .map_err(|_| Status::internal("Could not check incoming payment status"))?;
421
422        Ok(Response::new(CheckIncomingPaymentResponse {
423            payments: check_responses.into_iter().map(|r| r.into()).collect(),
424        }))
425    }
426
427    async fn check_outgoing_payment(
428        &self,
429        request: Request<CheckOutgoingPaymentRequest>,
430    ) -> Result<Response<MakePaymentResponse>, Status> {
431        let request = request.into_inner();
432
433        let payment_identifier = request
434            .request_identifier
435            .ok_or_else(|| Status::invalid_argument("Missing request identifier"))?
436            .try_into()
437            .map_err(|_| Status::invalid_argument("Invalid request identifier"))?;
438
439        let check_response = self
440            .inner
441            .check_outgoing_payment(&payment_identifier)
442            .await
443            .map_err(|_| Status::internal("Could not check outgoing payment status"))?;
444
445        Ok(Response::new(check_response.into()))
446    }
447
448    type WaitIncomingPaymentStream = ResponseStream;
449
450    #[allow(clippy::incompatible_msrv)]
451    #[instrument(skip_all)]
452    async fn wait_incoming_payment(
453        &self,
454        _request: Request<EmptyRequest>,
455    ) -> Result<Response<Self::WaitIncomingPaymentStream>, Status> {
456        tracing::debug!("Server waiting for payment stream");
457        let (tx, rx) = mpsc::channel(128);
458
459        let shutdown_clone = self.shutdown.clone();
460        let ln = self.inner.clone();
461        tokio::spawn(async move {
462            loop {
463                tokio::select! {
464                    _ = shutdown_clone.notified() => {
465                        tracing::info!("Shutdown signal received, stopping task");
466                        ln.cancel_wait_invoice();
467                        break;
468                    }
469                    result = ln.wait_payment_event() => {
470                        match result {
471                            Ok(mut stream) => {
472                                while let Some(event) = stream.next().await {
473                                    match event {
474                                        cdk_common::payment::Event::PaymentReceived(payment_response) => {
475                                            match tx.send(Result::<_, Status>::Ok(payment_response.into()))
476                                            .await
477                                            {
478                                                Ok(_) => {
479                                                    // Response was queued to be sent to client
480                                                }
481                                                Err(item) => {
482                                                    tracing::error!("Error adding incoming payment to stream: {}", item);
483                                                    break;
484                                                }
485                                            }
486                                        }
487                                    }
488                                }
489                            }
490                            Err(err) => {
491                                tracing::warn!("Could not get invoice stream: {}", err);
492                                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
493                            }
494                        }
495                    }
496                }
497            }
498        });
499
500        let output_stream = ReceiverStream::new(rx);
501        Ok(Response::new(
502            Box::pin(output_stream) as Self::WaitIncomingPaymentStream
503        ))
504    }
505}