cdk_payment_processor/proto/
server.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::pin::Pin;
4use std::str::FromStr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use cdk_common::payment::MintPayment;
9use futures::{Stream, StreamExt};
10use serde_json::Value;
11use tokio::sync::{mpsc, Notify};
12use tokio::task::JoinHandle;
13use tokio::time::{sleep, Instant};
14use tokio_stream::wrappers::ReceiverStream;
15use tonic::transport::{Certificate, Identity, Server, ServerTlsConfig};
16use tonic::{async_trait, Request, Response, Status};
17use tracing::instrument;
18
19use super::cdk_payment_processor_server::{CdkPaymentProcessor, CdkPaymentProcessorServer};
20use crate::proto::*;
21
22type ResponseStream =
23    Pin<Box<dyn Stream<Item = Result<WaitIncomingPaymentResponse, Status>> + Send>>;
24
25/// Payment Processor
26#[derive(Clone)]
27pub struct PaymentProcessorServer {
28    inner: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
29    socket_addr: SocketAddr,
30    shutdown: Arc<Notify>,
31    handle: Option<Arc<JoinHandle<anyhow::Result<()>>>>,
32}
33
34impl PaymentProcessorServer {
35    pub fn new(
36        payment_processor: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
37        addr: &str,
38        port: u16,
39    ) -> anyhow::Result<Self> {
40        let socket_addr = SocketAddr::new(addr.parse()?, port);
41        Ok(Self {
42            inner: payment_processor,
43            socket_addr,
44            shutdown: Arc::new(Notify::new()),
45            handle: None,
46        })
47    }
48
49    /// Start fake wallet grpc server
50    pub async fn start(&mut self, tls_dir: Option<PathBuf>) -> anyhow::Result<()> {
51        tracing::info!("Starting RPC server {}", self.socket_addr);
52
53        let server = match tls_dir {
54            Some(tls_dir) => {
55                tracing::info!("TLS configuration found, starting secure server");
56
57                // Check for server.pem
58                let server_pem_path = tls_dir.join("server.pem");
59                if !server_pem_path.exists() {
60                    let err_msg = format!(
61                        "TLS certificate file not found: {}",
62                        server_pem_path.display()
63                    );
64                    tracing::error!("{}", err_msg);
65                    return Err(anyhow::anyhow!(err_msg));
66                }
67
68                // Check for server.key
69                let server_key_path = tls_dir.join("server.key");
70                if !server_key_path.exists() {
71                    let err_msg = format!("TLS key file not found: {}", server_key_path.display());
72                    tracing::error!("{}", err_msg);
73                    return Err(anyhow::anyhow!(err_msg));
74                }
75
76                // Check for ca.pem
77                let ca_pem_path = tls_dir.join("ca.pem");
78                if !ca_pem_path.exists() {
79                    let err_msg =
80                        format!("CA certificate file not found: {}", ca_pem_path.display());
81                    tracing::error!("{}", err_msg);
82                    return Err(anyhow::anyhow!(err_msg));
83                }
84
85                let cert = std::fs::read_to_string(&server_pem_path)?;
86                let key = std::fs::read_to_string(&server_key_path)?;
87                let client_ca_cert = std::fs::read_to_string(&ca_pem_path)?;
88
89                let client_ca_cert = Certificate::from_pem(client_ca_cert);
90                let server_identity = Identity::from_pem(cert, key);
91                let tls_config = ServerTlsConfig::new()
92                    .identity(server_identity)
93                    .client_ca_root(client_ca_cert);
94
95                Server::builder()
96                    .tls_config(tls_config)?
97                    .add_service(CdkPaymentProcessorServer::new(self.clone()))
98            }
99            None => {
100                tracing::warn!("No valid TLS configuration found, starting insecure server");
101                Server::builder().add_service(CdkPaymentProcessorServer::new(self.clone()))
102            }
103        };
104
105        let shutdown = self.shutdown.clone();
106        let addr = self.socket_addr;
107
108        self.handle = Some(Arc::new(tokio::spawn(async move {
109            let server = server.serve_with_shutdown(addr, async {
110                shutdown.notified().await;
111            });
112
113            server.await?;
114            Ok(())
115        })));
116
117        Ok(())
118    }
119
120    /// Stop fake wallet grpc server
121    pub async fn stop(&self) -> anyhow::Result<()> {
122        const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
123
124        if let Some(handle) = &self.handle {
125            tracing::info!("Initiating server shutdown");
126            self.shutdown.notify_waiters();
127
128            let start = Instant::now();
129
130            while !handle.is_finished() {
131                if start.elapsed() >= SHUTDOWN_TIMEOUT {
132                    tracing::error!(
133                        "Server shutdown timed out after {} seconds, aborting handle",
134                        SHUTDOWN_TIMEOUT.as_secs()
135                    );
136                    handle.abort();
137                    break;
138                }
139                sleep(Duration::from_millis(100)).await;
140            }
141
142            if handle.is_finished() {
143                tracing::info!("Server shutdown completed successfully");
144            }
145        } else {
146            tracing::info!("No server handle found, nothing to stop");
147        }
148
149        Ok(())
150    }
151}
152
153impl Drop for PaymentProcessorServer {
154    fn drop(&mut self) {
155        tracing::debug!("Dropping payment process server");
156        self.shutdown.notify_one();
157    }
158}
159
160#[async_trait]
161impl CdkPaymentProcessor for PaymentProcessorServer {
162    async fn get_settings(
163        &self,
164        _request: Request<SettingsRequest>,
165    ) -> Result<Response<SettingsResponse>, Status> {
166        let settings: Value = self
167            .inner
168            .get_settings()
169            .await
170            .map_err(|_| Status::internal("Could not get settings"))?;
171
172        Ok(Response::new(SettingsResponse {
173            inner: settings.to_string(),
174        }))
175    }
176
177    async fn create_payment(
178        &self,
179        request: Request<CreatePaymentRequest>,
180    ) -> Result<Response<CreatePaymentResponse>, Status> {
181        let CreatePaymentRequest {
182            amount,
183            unit,
184            description,
185            unix_expiry,
186        } = request.into_inner();
187
188        let unit =
189            CurrencyUnit::from_str(&unit).map_err(|_| Status::invalid_argument("Invalid unit"))?;
190        let invoice_response = self
191            .inner
192            .create_incoming_payment_request(amount.into(), &unit, description, unix_expiry)
193            .await
194            .map_err(|_| Status::internal("Could not create invoice"))?;
195
196        Ok(Response::new(invoice_response.into()))
197    }
198
199    async fn get_payment_quote(
200        &self,
201        request: Request<PaymentQuoteRequest>,
202    ) -> Result<Response<PaymentQuoteResponse>, Status> {
203        let request = request.into_inner();
204
205        let options: Option<cdk_common::MeltOptions> =
206            request.options.as_ref().map(|options| (*options).into());
207
208        let payment_quote = self
209            .inner
210            .get_payment_quote(
211                &request.request,
212                &CurrencyUnit::from_str(&request.unit)
213                    .map_err(|_| Status::invalid_argument("Invalid currency unit"))?,
214                options,
215            )
216            .await
217            .map_err(|err| {
218                tracing::error!("Could not get bolt11 melt quote: {}", err);
219                Status::internal("Could not get melt quote")
220            })?;
221
222        Ok(Response::new(payment_quote.into()))
223    }
224
225    async fn make_payment(
226        &self,
227        request: Request<MakePaymentRequest>,
228    ) -> Result<Response<MakePaymentResponse>, Status> {
229        let request = request.into_inner();
230
231        let pay_invoice = self
232            .inner
233            .make_payment(
234                request
235                    .melt_quote
236                    .ok_or(Status::invalid_argument("Meltquote is required"))?
237                    .try_into()
238                    .map_err(|_err| Status::invalid_argument("Invalid melt quote"))?,
239                request.partial_amount.map(|a| a.into()),
240                request.max_fee_amount.map(|a| a.into()),
241            )
242            .await
243            .map_err(|err| {
244                tracing::error!("Could not make payment: {}", err);
245
246                match err {
247                    cdk_common::payment::Error::InvoiceAlreadyPaid => {
248                        Status::already_exists("Payment request already paid")
249                    }
250                    cdk_common::payment::Error::InvoicePaymentPending => {
251                        Status::already_exists("Payment request pending")
252                    }
253                    _ => Status::internal("Could not pay invoice"),
254                }
255            })?;
256
257        Ok(Response::new(pay_invoice.into()))
258    }
259
260    async fn check_incoming_payment(
261        &self,
262        request: Request<CheckIncomingPaymentRequest>,
263    ) -> Result<Response<CheckIncomingPaymentResponse>, Status> {
264        let request = request.into_inner();
265
266        let check_response = self
267            .inner
268            .check_incoming_payment_status(&request.request_lookup_id)
269            .await
270            .map_err(|_| Status::internal("Could not check incoming payment status"))?;
271
272        Ok(Response::new(CheckIncomingPaymentResponse {
273            status: QuoteState::from(check_response).into(),
274        }))
275    }
276
277    async fn check_outgoing_payment(
278        &self,
279        request: Request<CheckOutgoingPaymentRequest>,
280    ) -> Result<Response<MakePaymentResponse>, Status> {
281        let request = request.into_inner();
282
283        let check_response = self
284            .inner
285            .check_outgoing_payment(&request.request_lookup_id)
286            .await
287            .map_err(|_| Status::internal("Could not check incoming payment status"))?;
288
289        Ok(Response::new(check_response.into()))
290    }
291
292    type WaitIncomingPaymentStream = ResponseStream;
293
294    // Clippy thinks select is not stable but it compiles fine on MSRV (1.63.0)
295    #[allow(clippy::incompatible_msrv)]
296    #[instrument(skip_all)]
297    async fn wait_incoming_payment(
298        &self,
299        _request: Request<WaitIncomingPaymentRequest>,
300    ) -> Result<Response<Self::WaitIncomingPaymentStream>, Status> {
301        tracing::debug!("Server waiting for payment stream");
302        let (tx, rx) = mpsc::channel(128);
303
304        let shutdown_clone = self.shutdown.clone();
305        let ln = self.inner.clone();
306        tokio::spawn(async move {
307            loop {
308                tokio::select! {
309                _ = shutdown_clone.notified() => {
310                    tracing::info!("Shutdown signal received, stopping task for ");
311                    ln.cancel_wait_invoice();
312                    break;
313                }
314                result = ln.wait_any_incoming_payment() => {
315                    match result {
316                        Ok(mut stream) => {
317                            while let Some(request_lookup_id) = stream.next().await {
318                                                match tx.send(Result::<_, Status>::Ok(WaitIncomingPaymentResponse{lookup_id: request_lookup_id} )).await {
319                    Ok(_) => {
320                        // item (server response) was queued to be send to client
321                    }
322                    Err(item) => {
323                        tracing::error!("Error adding incoming payment to stream: {}", item);
324                        break;
325                    }
326                }
327                            }
328                        }
329                        Err(err) => {
330                            tracing::warn!("Could not get invoice stream for {}", err);
331
332                            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
333                        }
334                    }
335                }
336                }
337            }
338        });
339
340        let output_stream = ReceiverStream::new(rx);
341        Ok(Response::new(
342            Box::pin(output_stream) as Self::WaitIncomingPaymentStream
343        ))
344    }
345}