Skip to main content

cdk_payment_processor/proto/
server.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::pin::Pin;
4use std::str::FromStr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use cdk_common::grpc::create_version_check_interceptor;
9use cdk_common::payment::{IncomingPaymentOptions, MintPayment};
10use cdk_common::{CurrencyUnit, QuoteId};
11use futures::{Stream, StreamExt};
12use lightning::offers::offer::Offer;
13use tokio::sync::{mpsc, Notify};
14use tokio::task::JoinHandle;
15use tokio::time::{sleep, Instant};
16use tokio_stream::wrappers::ReceiverStream;
17use tonic::transport::{Certificate, Identity, Server, ServerTlsConfig};
18use tonic::{async_trait, Request, Response, Status};
19use tracing::instrument;
20
21use super::cdk_payment_processor_server::{CdkPaymentProcessor, CdkPaymentProcessorServer};
22use crate::error::Error;
23use crate::proto::{TryFromProtoAmount, *};
24
25type ResponseStream = Pin<Box<dyn Stream<Item = Result<PaymentEventResponse, Status>> + Send>>;
26
27/// Payment Processor
28#[derive(Clone)]
29pub struct PaymentProcessorServer {
30    inner: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
31    socket_addr: SocketAddr,
32    shutdown: Arc<Notify>,
33    handle: Option<Arc<JoinHandle<anyhow::Result<()>>>>,
34}
35
36impl std::fmt::Debug for PaymentProcessorServer {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("PaymentProcessorServer")
39            .field("socket_addr", &self.socket_addr)
40            .finish_non_exhaustive()
41    }
42}
43
44impl PaymentProcessorServer {
45    /// Create new [`PaymentProcessorServer`]
46    pub fn new(
47        payment_processor: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
48        addr: &str,
49        port: u16,
50    ) -> anyhow::Result<Self> {
51        let socket_addr = SocketAddr::new(addr.parse()?, port);
52        Ok(Self {
53            inner: payment_processor,
54            socket_addr,
55            shutdown: Arc::new(Notify::new()),
56            handle: None,
57        })
58    }
59
60    /// Start fake wallet grpc server
61    pub async fn start(&mut self, tls_dir: Option<PathBuf>) -> anyhow::Result<()> {
62        tracing::info!("Starting RPC server {}", self.socket_addr);
63
64        let server = match tls_dir {
65            Some(tls_dir) => {
66                tracing::info!("TLS configuration found, starting secure server");
67
68                // Check for server.pem
69                let server_pem_path = tls_dir.join("server.pem");
70                if !server_pem_path.exists() {
71                    let err_msg = format!(
72                        "TLS certificate file not found: {}",
73                        server_pem_path.display()
74                    );
75                    tracing::error!("{}", err_msg);
76                    return Err(anyhow::anyhow!(err_msg));
77                }
78
79                // Check for server.key
80                let server_key_path = tls_dir.join("server.key");
81                if !server_key_path.exists() {
82                    let err_msg = format!("TLS key file not found: {}", server_key_path.display());
83                    tracing::error!("{}", err_msg);
84                    return Err(anyhow::anyhow!(err_msg));
85                }
86
87                // Check for ca.pem
88                let ca_pem_path = tls_dir.join("ca.pem");
89                if !ca_pem_path.exists() {
90                    let err_msg =
91                        format!("CA certificate file not found: {}", ca_pem_path.display());
92                    tracing::error!("{}", err_msg);
93                    return Err(anyhow::anyhow!(err_msg));
94                }
95
96                let cert = std::fs::read_to_string(&server_pem_path)?;
97                let key = std::fs::read_to_string(&server_key_path)?;
98                let client_ca_cert = std::fs::read_to_string(&ca_pem_path)?;
99
100                let client_ca_cert = Certificate::from_pem(client_ca_cert);
101                let server_identity = Identity::from_pem(cert, key);
102                let tls_config = ServerTlsConfig::new()
103                    .identity(server_identity)
104                    .client_ca_root(client_ca_cert);
105
106                Server::builder().tls_config(tls_config)?.add_service(
107                    CdkPaymentProcessorServer::with_interceptor(
108                        self.clone(),
109                        create_version_check_interceptor(
110                            cdk_common::grpc::VERSION_HEADER,
111                            cdk_common::PAYMENT_PROCESSOR_PROTOCOL_VERSION,
112                        ),
113                    ),
114                )
115            }
116            None => {
117                tracing::warn!("No valid TLS configuration found, starting insecure server");
118                Server::builder().add_service(CdkPaymentProcessorServer::with_interceptor(
119                    self.clone(),
120                    create_version_check_interceptor(
121                        cdk_common::grpc::VERSION_HEADER,
122                        cdk_common::PAYMENT_PROCESSOR_PROTOCOL_VERSION,
123                    ),
124                ))
125            }
126        };
127
128        let shutdown = self.shutdown.clone();
129        let addr = self.socket_addr;
130
131        self.handle = Some(Arc::new(tokio::spawn(async move {
132            let server = server.serve_with_shutdown(addr, async {
133                shutdown.notified().await;
134            });
135
136            server.await?;
137            Ok(())
138        })));
139
140        Ok(())
141    }
142
143    /// Stop fake wallet grpc server
144    pub async fn stop(&self) -> anyhow::Result<()> {
145        const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
146
147        if let Some(handle) = &self.handle {
148            tracing::info!("Initiating server shutdown");
149            self.shutdown.notify_waiters();
150
151            let start = Instant::now();
152
153            while !handle.is_finished() {
154                if start.elapsed() >= SHUTDOWN_TIMEOUT {
155                    tracing::error!(
156                        "Server shutdown timed out after {} seconds, aborting handle",
157                        SHUTDOWN_TIMEOUT.as_secs()
158                    );
159                    handle.abort();
160                    break;
161                }
162                sleep(Duration::from_millis(100)).await;
163            }
164
165            if handle.is_finished() {
166                tracing::info!("Server shutdown completed successfully");
167            }
168        } else {
169            tracing::info!("No server handle found, nothing to stop");
170        }
171
172        Ok(())
173    }
174}
175
176impl Drop for PaymentProcessorServer {
177    fn drop(&mut self) {
178        tracing::debug!("Dropping payment process server");
179        self.shutdown.notify_one();
180    }
181}
182
183#[async_trait]
184impl CdkPaymentProcessor for PaymentProcessorServer {
185    async fn get_settings(
186        &self,
187        _request: Request<EmptyRequest>,
188    ) -> Result<Response<SettingsResponse>, Status> {
189        let settings = self
190            .inner
191            .get_settings()
192            .await
193            .map_err(|_| Status::internal("Could not get settings"))?;
194
195        Ok(Response::new(SettingsResponse {
196            unit: settings.unit,
197            bolt11: settings.bolt11.map(|b| super::Bolt11Settings {
198                mpp: b.mpp,
199                amountless: b.amountless,
200                invoice_description: b.invoice_description,
201            }),
202            bolt12: settings.bolt12.map(|b| super::Bolt12Settings {
203                amountless: b.amountless,
204            }),
205            onchain: settings.onchain.map(|o| super::OnchainSettings {
206                confirmations: o.confirmations,
207                min_receive_amount_sat: o.min_receive_amount_sat,
208                min_send_amount_sat: o.min_send_amount_sat,
209            }),
210            custom: settings.custom,
211        }))
212    }
213
214    async fn create_payment(
215        &self,
216        request: Request<CreatePaymentRequest>,
217    ) -> Result<Response<CreatePaymentResponse>, Status> {
218        let CreatePaymentRequest { options, .. } = request.into_inner();
219
220        let options = options.ok_or_else(|| Status::invalid_argument("Missing payment options"))?;
221
222        let proto_options = match options
223            .options
224            .ok_or_else(|| Status::invalid_argument("Missing options"))?
225        {
226            incoming_payment_options::Options::Custom(opts) => {
227                let amount = opts
228                    .amount
229                    .ok_or_else(|| Status::invalid_argument("Missing amount"))?
230                    .try_into()
231                    .map_err(|_| Status::invalid_argument("Invalid amount"))?;
232                IncomingPaymentOptions::Custom(Box::new(
233                    cdk_common::payment::CustomIncomingPaymentOptions {
234                        method: "".to_string(),
235                        description: opts.description,
236                        amount,
237                        unix_expiry: opts.unix_expiry,
238                        extra_json: opts.extra_json,
239                    },
240                ))
241            }
242            incoming_payment_options::Options::Bolt11(opts) => {
243                let amount = opts
244                    .amount
245                    .ok_or_else(|| Status::invalid_argument("Missing amount"))?
246                    .try_into()
247                    .map_err(|_| Status::invalid_argument("Invalid amount"))?;
248                IncomingPaymentOptions::Bolt11(cdk_common::payment::Bolt11IncomingPaymentOptions {
249                    description: opts.description,
250                    amount,
251                    unix_expiry: opts.unix_expiry,
252                })
253            }
254            incoming_payment_options::Options::Bolt12(opts) => {
255                let amount: Option<cdk_common::Amount<CurrencyUnit>> = match opts.amount {
256                    Some(a) => Some(
257                        a.try_into()
258                            .map_err(|_| Status::invalid_argument("Invalid amount"))?,
259                    ),
260                    None => None,
261                };
262                IncomingPaymentOptions::Bolt12(Box::new(
263                    cdk_common::payment::Bolt12IncomingPaymentOptions {
264                        description: opts.description,
265                        amount,
266                        unix_expiry: opts.unix_expiry,
267                    },
268                ))
269            }
270            incoming_payment_options::Options::Onchain(opts) => IncomingPaymentOptions::Onchain(
271                cdk_common::payment::OnchainIncomingPaymentOptions {
272                    quote_id: opts.quote_id.parse().map_err(|_| {
273                        Status::invalid_argument("Invalid quote_id in Onchain options")
274                    })?,
275                },
276            ),
277        };
278
279        let invoice_response = self
280            .inner
281            .create_incoming_payment_request(proto_options)
282            .await
283            .map_err(|_| Status::internal("Could not create invoice"))?;
284
285        Ok(Response::new(invoice_response.into()))
286    }
287
288    async fn get_payment_quote(
289        &self,
290        request: Request<PaymentQuoteRequest>,
291    ) -> Result<Response<PaymentQuoteResponse>, Status> {
292        let request = request.into_inner();
293
294        let unit = CurrencyUnit::from_str(&request.unit)
295            .map_err(|_| Status::invalid_argument("Invalid currency unit"))?;
296
297        let quote_id = parse_quote_id(&request.quote_id)?;
298
299        let options = match request.request_type() {
300            OutgoingPaymentRequestType::Bolt11Invoice => {
301                let bolt11: cdk_common::Bolt11Invoice =
302                    request.request.parse().map_err(Error::Invoice)?;
303
304                cdk_common::payment::OutgoingPaymentOptions::Bolt11(Box::new(
305                    cdk_common::payment::Bolt11OutgoingPaymentOptions {
306                        bolt11,
307                        max_fee_amount: None,
308                        timeout_secs: None,
309                        melt_options: request.options.map(Into::into),
310                        quote_id,
311                    },
312                ))
313            }
314            OutgoingPaymentRequestType::Bolt12Offer => {
315                // Parse offer to verify it's valid, but store as string
316                let _: Offer = request.request.parse().map_err(|_| Error::Bolt12Parse)?;
317
318                cdk_common::payment::OutgoingPaymentOptions::Bolt12(Box::new(
319                    cdk_common::payment::Bolt12OutgoingPaymentOptions {
320                        offer: Offer::from_str(&request.request)
321                            .expect("Already validated offer above"),
322                        max_fee_amount: None,
323                        timeout_secs: None,
324                        melt_options: request.options.map(Into::into),
325                        quote_id,
326                    },
327                ))
328            }
329            OutgoingPaymentRequestType::Custom => {
330                // Custom payment method - pass request as-is with no validation
331                cdk_common::payment::OutgoingPaymentOptions::Custom(Box::new(
332                    cdk_common::payment::CustomOutgoingPaymentOptions {
333                        method: String::new(), // Will be set from variant
334                        request: request.request.clone(),
335                        max_fee_amount: None,
336                        timeout_secs: None,
337                        melt_options: request.options.map(Into::into),
338                        extra_json: request.extra_json.clone(),
339                        quote_id,
340                    },
341                ))
342            }
343            OutgoingPaymentRequestType::Onchain => {
344                let opts = request.onchain_options.ok_or_else(|| {
345                    Status::invalid_argument("Missing onchain_options for onchain quote")
346                })?;
347                let amount = opts
348                    .amount
349                    .ok_or_else(|| Status::invalid_argument("Missing amount in onchain quote"))?
350                    .try_into()
351                    .map_err(|_| Status::invalid_argument("Invalid amount"))?;
352                let max_fee_amount = opts
353                    .max_fee_amount
354                    .try_from_proto()
355                    .map_err(|_| Status::invalid_argument("Invalid max_fee_amount"))?;
356                let onchain_quote_id = parse_quote_id(&opts.quote_id)?;
357                if onchain_quote_id != quote_id {
358                    return Err(Status::invalid_argument(
359                        "quote_id does not match onchain_options quote_id",
360                    ));
361                }
362
363                cdk_common::payment::OutgoingPaymentOptions::Onchain(Box::new(
364                    cdk_common::payment::OnchainOutgoingPaymentOptions {
365                        address: opts.address,
366                        amount,
367                        max_fee_amount,
368                        quote_id,
369                        fee_index: opts.fee_index,
370                        metadata: opts.metadata,
371                    },
372                ))
373            }
374            OutgoingPaymentRequestType::Unspecified => {
375                return Err(Status::invalid_argument("Unspecified payment request type"));
376            }
377        };
378
379        let payment_quote = self
380            .inner
381            .get_payment_quote(&unit, options)
382            .await
383            .map_err(|err| {
384                tracing::error!("Could not get payment quote: {}", err);
385                Status::internal("Could not get quote")
386            })?;
387
388        Ok(Response::new(payment_quote.into()))
389    }
390
391    async fn make_payment(
392        &self,
393        request: Request<MakePaymentRequest>,
394    ) -> Result<Response<MakePaymentResponse>, Status> {
395        let request = request.into_inner();
396
397        let unit = CurrencyUnit::from_str(&request.unit)
398            .map_err(|_| Status::invalid_argument("Invalid currency unit"))?;
399
400        let options = request
401            .payment_options
402            .ok_or_else(|| Status::invalid_argument("Missing payment options"))?;
403
404        let payment_options = match options
405            .options
406            .ok_or_else(|| Status::invalid_argument("Missing options"))?
407        {
408            outgoing_payment_variant::Options::Bolt11(opts) => {
409                let bolt11: cdk_common::Bolt11Invoice =
410                    opts.bolt11.parse().map_err(Error::Invoice)?;
411
412                let max_fee_amount = opts
413                    .max_fee_amount
414                    .try_from_proto()
415                    .map_err(|_| Status::invalid_argument("Invalid max_fee_amount"))?;
416                let quote_id = parse_quote_id(&opts.quote_id)?;
417
418                cdk_common::payment::OutgoingPaymentOptions::Bolt11(Box::new(
419                    cdk_common::payment::Bolt11OutgoingPaymentOptions {
420                        bolt11,
421                        max_fee_amount,
422                        timeout_secs: opts.timeout_secs,
423                        melt_options: opts.melt_options.map(Into::into),
424                        quote_id,
425                    },
426                ))
427            }
428            outgoing_payment_variant::Options::Bolt12(opts) => {
429                let offer = Offer::from_str(&opts.offer).map_err(|_| Error::Bolt12Parse)?;
430
431                let max_fee_amount = opts
432                    .max_fee_amount
433                    .try_from_proto()
434                    .map_err(|_| Status::invalid_argument("Invalid max_fee_amount"))?;
435                let quote_id = parse_quote_id(&opts.quote_id)?;
436
437                cdk_common::payment::OutgoingPaymentOptions::Bolt12(Box::new(
438                    cdk_common::payment::Bolt12OutgoingPaymentOptions {
439                        offer,
440                        max_fee_amount,
441                        timeout_secs: opts.timeout_secs,
442                        melt_options: opts.melt_options.map(Into::into),
443                        quote_id,
444                    },
445                ))
446            }
447            outgoing_payment_variant::Options::Custom(opts) => {
448                let max_fee_amount = opts
449                    .max_fee_amount
450                    .try_from_proto()
451                    .map_err(|_| Status::invalid_argument("Invalid max_fee_amount"))?;
452                let quote_id = parse_quote_id(&opts.quote_id)?;
453
454                cdk_common::payment::OutgoingPaymentOptions::Custom(Box::new(
455                    cdk_common::payment::CustomOutgoingPaymentOptions {
456                        method: String::new(), // Method will be determined from context
457                        request: opts.offer,   // Reusing offer field for custom request string
458                        max_fee_amount,
459                        timeout_secs: opts.timeout_secs,
460                        melt_options: opts.melt_options.map(Into::into),
461                        extra_json: opts.extra_json,
462                        quote_id,
463                    },
464                ))
465            }
466            outgoing_payment_variant::Options::Onchain(opts) => {
467                let amount = opts
468                    .amount
469                    .ok_or_else(|| Status::invalid_argument("Missing amount"))?
470                    .try_into()
471                    .map_err(|_| Status::invalid_argument("Invalid amount"))?;
472
473                let max_fee_amount = opts
474                    .max_fee_amount
475                    .try_from_proto()
476                    .map_err(|_| Status::invalid_argument("Invalid max_fee_amount"))?;
477
478                cdk_common::payment::OutgoingPaymentOptions::Onchain(Box::new(
479                    cdk_common::payment::OnchainOutgoingPaymentOptions {
480                        address: opts.address,
481                        amount,
482                        max_fee_amount,
483                        quote_id: opts.quote_id.parse().map_err(|_| {
484                            Status::invalid_argument("Invalid quote_id in Onchain options")
485                        })?,
486                        fee_index: opts.fee_index,
487                        metadata: opts.metadata,
488                    },
489                ))
490            }
491        };
492
493        let pay_response = self
494            .inner
495            .make_payment(&unit, payment_options)
496            .await
497            .map_err(|err| {
498                tracing::error!("Could not make payment: {}", err);
499
500                match err {
501                    cdk_common::payment::Error::InvoiceAlreadyPaid => {
502                        Status::already_exists("Payment request already paid")
503                    }
504                    cdk_common::payment::Error::InvoicePaymentPending => {
505                        Status::already_exists("Payment request pending")
506                    }
507                    _ => Status::internal("Could not pay invoice"),
508                }
509            })?;
510
511        Ok(Response::new(pay_response.into()))
512    }
513
514    async fn check_incoming_payment(
515        &self,
516        request: Request<CheckIncomingPaymentRequest>,
517    ) -> Result<Response<CheckIncomingPaymentResponse>, Status> {
518        let request = request.into_inner();
519
520        let payment_identifier = request
521            .request_identifier
522            .ok_or_else(|| Status::invalid_argument("Missing request identifier"))?
523            .try_into()
524            .map_err(|_| Status::invalid_argument("Invalid request identifier"))?;
525
526        let check_responses = self
527            .inner
528            .check_incoming_payment_status(&payment_identifier)
529            .await
530            .map_err(|_| Status::internal("Could not check incoming payment status"))?;
531
532        Ok(Response::new(CheckIncomingPaymentResponse {
533            payments: check_responses.into_iter().map(|r| r.into()).collect(),
534        }))
535    }
536
537    async fn check_outgoing_payment(
538        &self,
539        request: Request<CheckOutgoingPaymentRequest>,
540    ) -> Result<Response<MakePaymentResponse>, Status> {
541        let request = request.into_inner();
542
543        let payment_identifier = request
544            .request_identifier
545            .ok_or_else(|| Status::invalid_argument("Missing request identifier"))?
546            .try_into()
547            .map_err(|_| Status::invalid_argument("Invalid request identifier"))?;
548
549        let check_response = self
550            .inner
551            .check_outgoing_payment(&payment_identifier)
552            .await
553            .map_err(|_| Status::internal("Could not check outgoing payment status"))?;
554
555        Ok(Response::new(check_response.into()))
556    }
557
558    type WaitPaymentEventStream = ResponseStream;
559
560    #[allow(clippy::incompatible_msrv)]
561    #[instrument(skip_all)]
562    async fn wait_payment_event(
563        &self,
564        _request: Request<EmptyRequest>,
565    ) -> Result<Response<Self::WaitPaymentEventStream>, Status> {
566        tracing::debug!("Server waiting for payment stream");
567        let (tx, rx) = mpsc::channel(128);
568
569        let shutdown_clone = self.shutdown.clone();
570        let ln = self.inner.clone();
571        tokio::spawn(async move {
572            loop {
573                tokio::select! {
574                    _ = shutdown_clone.notified() => {
575                        tracing::info!("Shutdown signal received, stopping task");
576                        ln.cancel_payment_event_stream();
577                        break;
578                    }
579                    result = ln.wait_payment_event() => {
580                        match result {
581                            Ok(mut stream) => {
582                                while let Some(event) = stream.next().await {
583                                    match tx.send(Result::<_, Status>::Ok(event.into())).await {
584                                        Ok(_) => {
585                                            // Response was queued to be sent to client
586                                        }
587                                        Err(item) => {
588                                            tracing::error!("Error adding payment event to stream: {}", item);
589                                            break;
590                                        }
591                                    }
592                                }
593                            }
594                            Err(err) => {
595                                tracing::warn!("Could not get invoice stream: {}", err);
596                                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
597                            }
598                        }
599                    }
600                }
601            }
602        });
603
604        let output_stream = ReceiverStream::new(rx);
605        Ok(Response::new(
606            Box::pin(output_stream) as Self::WaitPaymentEventStream
607        ))
608    }
609}
610
611fn parse_quote_id(s: &str) -> Result<QuoteId, Status> {
612    s.parse()
613        .map_err(|err| Status::invalid_argument(format!("Invalid quote_id: {err}")))
614}