// cdk_payment_processor/proto/server.rs

use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;

use cdk_common::payment::MintPayment;
use futures::{Stream, StreamExt};
use serde_json::Value;
use tokio::sync::{mpsc, Notify};
use tokio::task::JoinHandle;
use tokio::time::{sleep, Instant};
use tokio_stream::wrappers::ReceiverStream;
use tonic::transport::{Certificate, Identity, Server, ServerTlsConfig};
use tonic::{async_trait, Request, Response, Status};
use tracing::instrument;

use super::cdk_payment_processor_server::{CdkPaymentProcessor, CdkPaymentProcessorServer};
use crate::proto::*;

22type ResponseStream =
23    Pin<Box<dyn Stream<Item = Result<WaitIncomingPaymentResponse, Status>> + Send>>;
24
25/// Payment Processor
26#[derive(Clone)]
27pub struct PaymentProcessorServer {
28    inner: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
29    socket_addr: SocketAddr,
30    shutdown: Arc<Notify>,
31    handle: Option<Arc<JoinHandle<anyhow::Result<()>>>>,
32}
33
34impl PaymentProcessorServer {
35    /// Create new [`PaymentProcessorServer`]
36    pub fn new(
37        payment_processor: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
38        addr: &str,
39        port: u16,
40    ) -> anyhow::Result<Self> {
41        let socket_addr = SocketAddr::new(addr.parse()?, port);
42        Ok(Self {
43            inner: payment_processor,
44            socket_addr,
45            shutdown: Arc::new(Notify::new()),
46            handle: None,
47        })
48    }
49
50    /// Start fake wallet grpc server
51    pub async fn start(&mut self, tls_dir: Option<PathBuf>) -> anyhow::Result<()> {
52        tracing::info!("Starting RPC server {}", self.socket_addr);
53
54        let server = match tls_dir {
55            Some(tls_dir) => {
56                tracing::info!("TLS configuration found, starting secure server");
57
58                // Check for server.pem
59                let server_pem_path = tls_dir.join("server.pem");
60                if !server_pem_path.exists() {
61                    let err_msg = format!(
62                        "TLS certificate file not found: {}",
63                        server_pem_path.display()
64                    );
65                    tracing::error!("{}", err_msg);
66                    return Err(anyhow::anyhow!(err_msg));
67                }
68
69                // Check for server.key
70                let server_key_path = tls_dir.join("server.key");
71                if !server_key_path.exists() {
72                    let err_msg = format!("TLS key file not found: {}", server_key_path.display());
73                    tracing::error!("{}", err_msg);
74                    return Err(anyhow::anyhow!(err_msg));
75                }
76
77                // Check for ca.pem
78                let ca_pem_path = tls_dir.join("ca.pem");
79                if !ca_pem_path.exists() {
80                    let err_msg =
81                        format!("CA certificate file not found: {}", ca_pem_path.display());
82                    tracing::error!("{}", err_msg);
83                    return Err(anyhow::anyhow!(err_msg));
84                }
85
86                let cert = std::fs::read_to_string(&server_pem_path)?;
87                let key = std::fs::read_to_string(&server_key_path)?;
88                let client_ca_cert = std::fs::read_to_string(&ca_pem_path)?;
89
90                let client_ca_cert = Certificate::from_pem(client_ca_cert);
91                let server_identity = Identity::from_pem(cert, key);
92                let tls_config = ServerTlsConfig::new()
93                    .identity(server_identity)
94                    .client_ca_root(client_ca_cert);
95
96                Server::builder()
97                    .tls_config(tls_config)?
98                    .add_service(CdkPaymentProcessorServer::new(self.clone()))
99            }
100            None => {
101                tracing::warn!("No valid TLS configuration found, starting insecure server");
102                Server::builder().add_service(CdkPaymentProcessorServer::new(self.clone()))
103            }
104        };
105
106        let shutdown = self.shutdown.clone();
107        let addr = self.socket_addr;
108
109        self.handle = Some(Arc::new(tokio::spawn(async move {
110            let server = server.serve_with_shutdown(addr, async {
111                shutdown.notified().await;
112            });
113
114            server.await?;
115            Ok(())
116        })));
117
118        Ok(())
119    }
120
121    /// Stop fake wallet grpc server
122    pub async fn stop(&self) -> anyhow::Result<()> {
123        const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
124
125        if let Some(handle) = &self.handle {
126            tracing::info!("Initiating server shutdown");
127            self.shutdown.notify_waiters();
128
129            let start = Instant::now();
130
131            while !handle.is_finished() {
132                if start.elapsed() >= SHUTDOWN_TIMEOUT {
133                    tracing::error!(
134                        "Server shutdown timed out after {} seconds, aborting handle",
135                        SHUTDOWN_TIMEOUT.as_secs()
136                    );
137                    handle.abort();
138                    break;
139                }
140                sleep(Duration::from_millis(100)).await;
141            }
142
143            if handle.is_finished() {
144                tracing::info!("Server shutdown completed successfully");
145            }
146        } else {
147            tracing::info!("No server handle found, nothing to stop");
148        }
149
150        Ok(())
151    }
152}
153
154impl Drop for PaymentProcessorServer {
155    fn drop(&mut self) {
156        tracing::debug!("Dropping payment process server");
157        self.shutdown.notify_one();
158    }
159}
160
161#[async_trait]
162impl CdkPaymentProcessor for PaymentProcessorServer {
163    async fn get_settings(
164        &self,
165        _request: Request<SettingsRequest>,
166    ) -> Result<Response<SettingsResponse>, Status> {
167        let settings: Value = self
168            .inner
169            .get_settings()
170            .await
171            .map_err(|_| Status::internal("Could not get settings"))?;
172
173        Ok(Response::new(SettingsResponse {
174            inner: settings.to_string(),
175        }))
176    }
177
178    async fn create_payment(
179        &self,
180        request: Request<CreatePaymentRequest>,
181    ) -> Result<Response<CreatePaymentResponse>, Status> {
182        let CreatePaymentRequest {
183            amount,
184            unit,
185            description,
186            unix_expiry,
187        } = request.into_inner();
188
189        let unit =
190            CurrencyUnit::from_str(&unit).map_err(|_| Status::invalid_argument("Invalid unit"))?;
191        let invoice_response = self
192            .inner
193            .create_incoming_payment_request(amount.into(), &unit, description, unix_expiry)
194            .await
195            .map_err(|_| Status::internal("Could not create invoice"))?;
196
197        Ok(Response::new(invoice_response.into()))
198    }
199
200    async fn get_payment_quote(
201        &self,
202        request: Request<PaymentQuoteRequest>,
203    ) -> Result<Response<PaymentQuoteResponse>, Status> {
204        let request = request.into_inner();
205
206        let options: Option<cdk_common::MeltOptions> =
207            request.options.as_ref().map(|options| (*options).into());
208
209        let payment_quote = self
210            .inner
211            .get_payment_quote(
212                &request.request,
213                &CurrencyUnit::from_str(&request.unit)
214                    .map_err(|_| Status::invalid_argument("Invalid currency unit"))?,
215                options,
216            )
217            .await
218            .map_err(|err| {
219                tracing::error!("Could not get bolt11 melt quote: {}", err);
220                Status::internal("Could not get melt quote")
221            })?;
222
223        Ok(Response::new(payment_quote.into()))
224    }
225
226    async fn make_payment(
227        &self,
228        request: Request<MakePaymentRequest>,
229    ) -> Result<Response<MakePaymentResponse>, Status> {
230        let request = request.into_inner();
231
232        let pay_invoice = self
233            .inner
234            .make_payment(
235                request
236                    .melt_quote
237                    .ok_or(Status::invalid_argument("Meltquote is required"))?
238                    .try_into()
239                    .map_err(|_err| Status::invalid_argument("Invalid melt quote"))?,
240                request.partial_amount.map(|a| a.into()),
241                request.max_fee_amount.map(|a| a.into()),
242            )
243            .await
244            .map_err(|err| {
245                tracing::error!("Could not make payment: {}", err);
246
247                match err {
248                    cdk_common::payment::Error::InvoiceAlreadyPaid => {
249                        Status::already_exists("Payment request already paid")
250                    }
251                    cdk_common::payment::Error::InvoicePaymentPending => {
252                        Status::already_exists("Payment request pending")
253                    }
254                    _ => Status::internal("Could not pay invoice"),
255                }
256            })?;
257
258        Ok(Response::new(pay_invoice.into()))
259    }
260
261    async fn check_incoming_payment(
262        &self,
263        request: Request<CheckIncomingPaymentRequest>,
264    ) -> Result<Response<CheckIncomingPaymentResponse>, Status> {
265        let request = request.into_inner();
266
267        let check_response = self
268            .inner
269            .check_incoming_payment_status(&request.request_lookup_id)
270            .await
271            .map_err(|_| Status::internal("Could not check incoming payment status"))?;
272
273        Ok(Response::new(CheckIncomingPaymentResponse {
274            status: QuoteState::from(check_response).into(),
275        }))
276    }
277
278    async fn check_outgoing_payment(
279        &self,
280        request: Request<CheckOutgoingPaymentRequest>,
281    ) -> Result<Response<MakePaymentResponse>, Status> {
282        let request = request.into_inner();
283
284        let check_response = self
285            .inner
286            .check_outgoing_payment(&request.request_lookup_id)
287            .await
288            .map_err(|_| Status::internal("Could not check incoming payment status"))?;
289
290        Ok(Response::new(check_response.into()))
291    }
292
293    type WaitIncomingPaymentStream = ResponseStream;
294
295    // Clippy thinks select is not stable but it compiles fine on MSRV (1.63.0)
296    #[allow(clippy::incompatible_msrv)]
297    #[instrument(skip_all)]
298    async fn wait_incoming_payment(
299        &self,
300        _request: Request<WaitIncomingPaymentRequest>,
301    ) -> Result<Response<Self::WaitIncomingPaymentStream>, Status> {
302        tracing::debug!("Server waiting for payment stream");
303        let (tx, rx) = mpsc::channel(128);
304
305        let shutdown_clone = self.shutdown.clone();
306        let ln = self.inner.clone();
307        tokio::spawn(async move {
308            loop {
309                tokio::select! {
310                _ = shutdown_clone.notified() => {
311                    tracing::info!("Shutdown signal received, stopping task for ");
312                    ln.cancel_wait_invoice();
313                    break;
314                }
315                result = ln.wait_any_incoming_payment() => {
316                    match result {
317                        Ok(mut stream) => {
318                            while let Some(request_lookup_id) = stream.next().await {
319                                                match tx.send(Result::<_, Status>::Ok(WaitIncomingPaymentResponse{lookup_id: request_lookup_id} )).await {
320                    Ok(_) => {
321                        // item (server response) was queued to be send to client
322                    }
323                    Err(item) => {
324                        tracing::error!("Error adding incoming payment to stream: {}", item);
325                        break;
326                    }
327                }
328                            }
329                        }
330                        Err(err) => {
331                            tracing::warn!("Could not get invoice stream for {}", err);
332
333                            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
334                        }
335                    }
336                }
337                }
338            }
339        });
340
341        let output_stream = ReceiverStream::new(rx);
342        Ok(Response::new(
343            Box::pin(output_stream) as Self::WaitIncomingPaymentStream
344        ))
345    }
346}