Skip to main content

cdk_payment_processor/proto/
server.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::pin::Pin;
4use std::str::FromStr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use cdk_common::grpc::create_version_check_interceptor;
9use cdk_common::payment::{IncomingPaymentOptions, MintPayment};
10use cdk_common::CurrencyUnit;
11use futures::{Stream, StreamExt};
12use lightning::offers::offer::Offer;
13use tokio::sync::{mpsc, Notify};
14use tokio::task::JoinHandle;
15use tokio::time::{sleep, Instant};
16use tokio_stream::wrappers::ReceiverStream;
17use tonic::transport::{Certificate, Identity, Server, ServerTlsConfig};
18use tonic::{async_trait, Request, Response, Status};
19use tracing::instrument;
20
21use super::cdk_payment_processor_server::{CdkPaymentProcessor, CdkPaymentProcessorServer};
22use crate::error::Error;
23use crate::proto::{TryFromProtoAmount, *};
24
25type ResponseStream =
26    Pin<Box<dyn Stream<Item = Result<WaitIncomingPaymentResponse, Status>> + Send>>;
27
28/// Payment Processor
29#[derive(Clone)]
30pub struct PaymentProcessorServer {
31    inner: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
32    socket_addr: SocketAddr,
33    shutdown: Arc<Notify>,
34    handle: Option<Arc<JoinHandle<anyhow::Result<()>>>>,
35}
36
37impl std::fmt::Debug for PaymentProcessorServer {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("PaymentProcessorServer")
40            .field("socket_addr", &self.socket_addr)
41            .finish_non_exhaustive()
42    }
43}
44
45impl PaymentProcessorServer {
46    /// Create new [`PaymentProcessorServer`]
47    pub fn new(
48        payment_processor: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
49        addr: &str,
50        port: u16,
51    ) -> anyhow::Result<Self> {
52        let socket_addr = SocketAddr::new(addr.parse()?, port);
53        Ok(Self {
54            inner: payment_processor,
55            socket_addr,
56            shutdown: Arc::new(Notify::new()),
57            handle: None,
58        })
59    }
60
61    /// Start fake wallet grpc server
62    pub async fn start(&mut self, tls_dir: Option<PathBuf>) -> anyhow::Result<()> {
63        tracing::info!("Starting RPC server {}", self.socket_addr);
64
65        let server = match tls_dir {
66            Some(tls_dir) => {
67                tracing::info!("TLS configuration found, starting secure server");
68
69                // Check for server.pem
70                let server_pem_path = tls_dir.join("server.pem");
71                if !server_pem_path.exists() {
72                    let err_msg = format!(
73                        "TLS certificate file not found: {}",
74                        server_pem_path.display()
75                    );
76                    tracing::error!("{}", err_msg);
77                    return Err(anyhow::anyhow!(err_msg));
78                }
79
80                // Check for server.key
81                let server_key_path = tls_dir.join("server.key");
82                if !server_key_path.exists() {
83                    let err_msg = format!("TLS key file not found: {}", server_key_path.display());
84                    tracing::error!("{}", err_msg);
85                    return Err(anyhow::anyhow!(err_msg));
86                }
87
88                // Check for ca.pem
89                let ca_pem_path = tls_dir.join("ca.pem");
90                if !ca_pem_path.exists() {
91                    let err_msg =
92                        format!("CA certificate file not found: {}", ca_pem_path.display());
93                    tracing::error!("{}", err_msg);
94                    return Err(anyhow::anyhow!(err_msg));
95                }
96
97                let cert = std::fs::read_to_string(&server_pem_path)?;
98                let key = std::fs::read_to_string(&server_key_path)?;
99                let client_ca_cert = std::fs::read_to_string(&ca_pem_path)?;
100
101                let client_ca_cert = Certificate::from_pem(client_ca_cert);
102                let server_identity = Identity::from_pem(cert, key);
103                let tls_config = ServerTlsConfig::new()
104                    .identity(server_identity)
105                    .client_ca_root(client_ca_cert);
106
107                Server::builder().tls_config(tls_config)?.add_service(
108                    CdkPaymentProcessorServer::with_interceptor(
109                        self.clone(),
110                        create_version_check_interceptor(
111                            cdk_common::grpc::VERSION_HEADER,
112                            cdk_common::PAYMENT_PROCESSOR_PROTOCOL_VERSION,
113                        ),
114                    ),
115                )
116            }
117            None => {
118                tracing::warn!("No valid TLS configuration found, starting insecure server");
119                Server::builder().add_service(CdkPaymentProcessorServer::with_interceptor(
120                    self.clone(),
121                    create_version_check_interceptor(
122                        cdk_common::grpc::VERSION_HEADER,
123                        cdk_common::PAYMENT_PROCESSOR_PROTOCOL_VERSION,
124                    ),
125                ))
126            }
127        };
128
129        let shutdown = self.shutdown.clone();
130        let addr = self.socket_addr;
131
132        self.handle = Some(Arc::new(tokio::spawn(async move {
133            let server = server.serve_with_shutdown(addr, async {
134                shutdown.notified().await;
135            });
136
137            server.await?;
138            Ok(())
139        })));
140
141        Ok(())
142    }
143
144    /// Stop fake wallet grpc server
145    pub async fn stop(&self) -> anyhow::Result<()> {
146        const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
147
148        if let Some(handle) = &self.handle {
149            tracing::info!("Initiating server shutdown");
150            self.shutdown.notify_waiters();
151
152            let start = Instant::now();
153
154            while !handle.is_finished() {
155                if start.elapsed() >= SHUTDOWN_TIMEOUT {
156                    tracing::error!(
157                        "Server shutdown timed out after {} seconds, aborting handle",
158                        SHUTDOWN_TIMEOUT.as_secs()
159                    );
160                    handle.abort();
161                    break;
162                }
163                sleep(Duration::from_millis(100)).await;
164            }
165
166            if handle.is_finished() {
167                tracing::info!("Server shutdown completed successfully");
168            }
169        } else {
170            tracing::info!("No server handle found, nothing to stop");
171        }
172
173        Ok(())
174    }
175}
176
177impl Drop for PaymentProcessorServer {
178    fn drop(&mut self) {
179        tracing::debug!("Dropping payment process server");
180        self.shutdown.notify_one();
181    }
182}
183
184#[async_trait]
185impl CdkPaymentProcessor for PaymentProcessorServer {
186    async fn get_settings(
187        &self,
188        _request: Request<EmptyRequest>,
189    ) -> Result<Response<SettingsResponse>, Status> {
190        let settings = self
191            .inner
192            .get_settings()
193            .await
194            .map_err(|_| Status::internal("Could not get settings"))?;
195
196        Ok(Response::new(SettingsResponse {
197            unit: settings.unit,
198            bolt11: settings.bolt11.map(|b| super::Bolt11Settings {
199                mpp: b.mpp,
200                amountless: b.amountless,
201                invoice_description: b.invoice_description,
202            }),
203            bolt12: settings.bolt12.map(|b| super::Bolt12Settings {
204                amountless: b.amountless,
205            }),
206            custom: settings.custom,
207        }))
208    }
209
210    async fn create_payment(
211        &self,
212        request: Request<CreatePaymentRequest>,
213    ) -> Result<Response<CreatePaymentResponse>, Status> {
214        let CreatePaymentRequest { options, .. } = request.into_inner();
215
216        let options = options.ok_or_else(|| Status::invalid_argument("Missing payment options"))?;
217
218        let proto_options = match options
219            .options
220            .ok_or_else(|| Status::invalid_argument("Missing options"))?
221        {
222            incoming_payment_options::Options::Custom(opts) => {
223                let amount = opts
224                    .amount
225                    .ok_or_else(|| Status::invalid_argument("Missing amount"))?
226                    .try_into()
227                    .map_err(|_| Status::invalid_argument("Invalid amount"))?;
228                IncomingPaymentOptions::Custom(Box::new(
229                    cdk_common::payment::CustomIncomingPaymentOptions {
230                        method: "".to_string(),
231                        description: opts.description,
232                        amount,
233                        unix_expiry: opts.unix_expiry,
234                        extra_json: opts.extra_json,
235                    },
236                ))
237            }
238            incoming_payment_options::Options::Bolt11(opts) => {
239                let amount = opts
240                    .amount
241                    .ok_or_else(|| Status::invalid_argument("Missing amount"))?
242                    .try_into()
243                    .map_err(|_| Status::invalid_argument("Invalid amount"))?;
244                IncomingPaymentOptions::Bolt11(cdk_common::payment::Bolt11IncomingPaymentOptions {
245                    description: opts.description,
246                    amount,
247                    unix_expiry: opts.unix_expiry,
248                })
249            }
250            incoming_payment_options::Options::Bolt12(opts) => {
251                let amount: Option<cdk_common::Amount<CurrencyUnit>> = match opts.amount {
252                    Some(a) => Some(
253                        a.try_into()
254                            .map_err(|_| Status::invalid_argument("Invalid amount"))?,
255                    ),
256                    None => None,
257                };
258                IncomingPaymentOptions::Bolt12(Box::new(
259                    cdk_common::payment::Bolt12IncomingPaymentOptions {
260                        description: opts.description,
261                        amount,
262                        unix_expiry: opts.unix_expiry,
263                    },
264                ))
265            }
266        };
267
268        let invoice_response = self
269            .inner
270            .create_incoming_payment_request(proto_options)
271            .await
272            .map_err(|_| Status::internal("Could not create invoice"))?;
273
274        Ok(Response::new(invoice_response.into()))
275    }
276
277    async fn get_payment_quote(
278        &self,
279        request: Request<PaymentQuoteRequest>,
280    ) -> Result<Response<PaymentQuoteResponse>, Status> {
281        let request = request.into_inner();
282
283        let unit = CurrencyUnit::from_str(&request.unit)
284            .map_err(|_| Status::invalid_argument("Invalid currency unit"))?;
285
286        let options = match request.request_type() {
287            OutgoingPaymentRequestType::Bolt11Invoice => {
288                let bolt11: cdk_common::Bolt11Invoice =
289                    request.request.parse().map_err(Error::Invoice)?;
290
291                cdk_common::payment::OutgoingPaymentOptions::Bolt11(Box::new(
292                    cdk_common::payment::Bolt11OutgoingPaymentOptions {
293                        bolt11,
294                        max_fee_amount: None,
295                        timeout_secs: None,
296                        melt_options: request.options.map(Into::into),
297                    },
298                ))
299            }
300            OutgoingPaymentRequestType::Bolt12Offer => {
301                // Parse offer to verify it's valid, but store as string
302                let _: Offer = request.request.parse().map_err(|_| Error::Bolt12Parse)?;
303
304                cdk_common::payment::OutgoingPaymentOptions::Bolt12(Box::new(
305                    cdk_common::payment::Bolt12OutgoingPaymentOptions {
306                        offer: Offer::from_str(&request.request)
307                            .expect("Already validated offer above"),
308                        max_fee_amount: None,
309                        timeout_secs: None,
310                        melt_options: request.options.map(Into::into),
311                    },
312                ))
313            }
314            OutgoingPaymentRequestType::Custom => {
315                // Custom payment method - pass request as-is with no validation
316                cdk_common::payment::OutgoingPaymentOptions::Custom(Box::new(
317                    cdk_common::payment::CustomOutgoingPaymentOptions {
318                        method: String::new(), // Will be set from variant
319                        request: request.request.clone(),
320                        max_fee_amount: None,
321                        timeout_secs: None,
322                        melt_options: request.options.map(Into::into),
323                        extra_json: request.extra_json.clone(),
324                    },
325                ))
326            }
327            OutgoingPaymentRequestType::Unspecified => {
328                return Err(Status::invalid_argument("Unspecified payment request type"));
329            }
330        };
331
332        let payment_quote = self
333            .inner
334            .get_payment_quote(&unit, options)
335            .await
336            .map_err(|err| {
337                tracing::error!("Could not get payment quote: {}", err);
338                Status::internal("Could not get quote")
339            })?;
340
341        Ok(Response::new(payment_quote.into()))
342    }
343
344    async fn make_payment(
345        &self,
346        request: Request<MakePaymentRequest>,
347    ) -> Result<Response<MakePaymentResponse>, Status> {
348        let request = request.into_inner();
349
350        let unit = CurrencyUnit::from_str(&request.unit)
351            .map_err(|_| Status::invalid_argument("Invalid currency unit"))?;
352
353        let options = request
354            .payment_options
355            .ok_or_else(|| Status::invalid_argument("Missing payment options"))?;
356
357        let payment_options = match options
358            .options
359            .ok_or_else(|| Status::invalid_argument("Missing options"))?
360        {
361            outgoing_payment_variant::Options::Bolt11(opts) => {
362                let bolt11: cdk_common::Bolt11Invoice =
363                    opts.bolt11.parse().map_err(Error::Invoice)?;
364
365                let max_fee_amount = opts
366                    .max_fee_amount
367                    .try_from_proto()
368                    .map_err(|_| Status::invalid_argument("Invalid max_fee_amount"))?;
369
370                cdk_common::payment::OutgoingPaymentOptions::Bolt11(Box::new(
371                    cdk_common::payment::Bolt11OutgoingPaymentOptions {
372                        bolt11,
373                        max_fee_amount,
374                        timeout_secs: opts.timeout_secs,
375                        melt_options: opts.melt_options.map(Into::into),
376                    },
377                ))
378            }
379            outgoing_payment_variant::Options::Bolt12(opts) => {
380                let offer = Offer::from_str(&opts.offer).map_err(|_| Error::Bolt12Parse)?;
381
382                let max_fee_amount = opts
383                    .max_fee_amount
384                    .try_from_proto()
385                    .map_err(|_| Status::invalid_argument("Invalid max_fee_amount"))?;
386
387                cdk_common::payment::OutgoingPaymentOptions::Bolt12(Box::new(
388                    cdk_common::payment::Bolt12OutgoingPaymentOptions {
389                        offer,
390                        max_fee_amount,
391                        timeout_secs: opts.timeout_secs,
392                        melt_options: opts.melt_options.map(Into::into),
393                    },
394                ))
395            }
396            outgoing_payment_variant::Options::Custom(opts) => {
397                let max_fee_amount = opts
398                    .max_fee_amount
399                    .try_from_proto()
400                    .map_err(|_| Status::invalid_argument("Invalid max_fee_amount"))?;
401
402                cdk_common::payment::OutgoingPaymentOptions::Custom(Box::new(
403                    cdk_common::payment::CustomOutgoingPaymentOptions {
404                        method: String::new(), // Method will be determined from context
405                        request: opts.offer,   // Reusing offer field for custom request string
406                        max_fee_amount,
407                        timeout_secs: opts.timeout_secs,
408                        melt_options: opts.melt_options.map(Into::into),
409                        extra_json: opts.extra_json,
410                    },
411                ))
412            }
413        };
414
415        let pay_response = self
416            .inner
417            .make_payment(&unit, payment_options)
418            .await
419            .map_err(|err| {
420                tracing::error!("Could not make payment: {}", err);
421
422                match err {
423                    cdk_common::payment::Error::InvoiceAlreadyPaid => {
424                        Status::already_exists("Payment request already paid")
425                    }
426                    cdk_common::payment::Error::InvoicePaymentPending => {
427                        Status::already_exists("Payment request pending")
428                    }
429                    _ => Status::internal("Could not pay invoice"),
430                }
431            })?;
432
433        Ok(Response::new(pay_response.into()))
434    }
435
436    async fn check_incoming_payment(
437        &self,
438        request: Request<CheckIncomingPaymentRequest>,
439    ) -> Result<Response<CheckIncomingPaymentResponse>, Status> {
440        let request = request.into_inner();
441
442        let payment_identifier = request
443            .request_identifier
444            .ok_or_else(|| Status::invalid_argument("Missing request identifier"))?
445            .try_into()
446            .map_err(|_| Status::invalid_argument("Invalid request identifier"))?;
447
448        let check_responses = self
449            .inner
450            .check_incoming_payment_status(&payment_identifier)
451            .await
452            .map_err(|_| Status::internal("Could not check incoming payment status"))?;
453
454        Ok(Response::new(CheckIncomingPaymentResponse {
455            payments: check_responses.into_iter().map(|r| r.into()).collect(),
456        }))
457    }
458
459    async fn check_outgoing_payment(
460        &self,
461        request: Request<CheckOutgoingPaymentRequest>,
462    ) -> Result<Response<MakePaymentResponse>, Status> {
463        let request = request.into_inner();
464
465        let payment_identifier = request
466            .request_identifier
467            .ok_or_else(|| Status::invalid_argument("Missing request identifier"))?
468            .try_into()
469            .map_err(|_| Status::invalid_argument("Invalid request identifier"))?;
470
471        let check_response = self
472            .inner
473            .check_outgoing_payment(&payment_identifier)
474            .await
475            .map_err(|_| Status::internal("Could not check outgoing payment status"))?;
476
477        Ok(Response::new(check_response.into()))
478    }
479
480    type WaitIncomingPaymentStream = ResponseStream;
481
482    #[allow(clippy::incompatible_msrv)]
483    #[instrument(skip_all)]
484    async fn wait_incoming_payment(
485        &self,
486        _request: Request<EmptyRequest>,
487    ) -> Result<Response<Self::WaitIncomingPaymentStream>, Status> {
488        tracing::debug!("Server waiting for payment stream");
489        let (tx, rx) = mpsc::channel(128);
490
491        let shutdown_clone = self.shutdown.clone();
492        let ln = self.inner.clone();
493        tokio::spawn(async move {
494            loop {
495                tokio::select! {
496                    _ = shutdown_clone.notified() => {
497                        tracing::info!("Shutdown signal received, stopping task");
498                        ln.cancel_wait_invoice();
499                        break;
500                    }
501                    result = ln.wait_payment_event() => {
502                        match result {
503                            Ok(mut stream) => {
504                                while let Some(event) = stream.next().await {
505                                    match event {
506                                        cdk_common::payment::Event::PaymentReceived(payment_response) => {
507                                            match tx.send(Result::<_, Status>::Ok(payment_response.into()))
508                                            .await
509                                            {
510                                                Ok(_) => {
511                                                    // Response was queued to be sent to client
512                                                }
513                                                Err(item) => {
514                                                    tracing::error!("Error adding incoming payment to stream: {}", item);
515                                                    break;
516                                                }
517                                            }
518                                        }
519                                    }
520                                }
521                            }
522                            Err(err) => {
523                                tracing::warn!("Could not get invoice stream: {}", err);
524                                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
525                            }
526                        }
527                    }
528                }
529            }
530        });
531
532        let output_stream = ReceiverStream::new(rx);
533        Ok(Response::new(
534            Box::pin(output_stream) as Self::WaitIncomingPaymentStream
535        ))
536    }
537}