switchgear-service 0.1.0

Service layer and API implementations for Switchgear LNURL load balancer
Documentation
use crate::api::service::ServiceErrorSource;
use crate::components::pool::cln::grpc::config::{
    ClnGrpcClientAuth, ClnGrpcClientAuthPath, ClnGrpcDiscoveryBackendImplementation,
};
use crate::components::pool::error::{LnPoolError, LnPoolErrorSourceKind};
use crate::components::pool::{Bolt11InvoiceDescription, LnFeatures, LnMetrics, LnRpcClient};
use async_trait::async_trait;
use cln_grpc::pb;
use cln_grpc::pb::node_client::NodeClient;
use secp256k1::hashes::hex::DisplayHex;
use sha2::Digest;
use std::fs;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::Mutex;
use tokio::time::timeout;
use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, Identity};
pub use tonic_0_11_0 as tonic;
use url::Url;

/// Raw contents of the three TLS credential files, in the order
/// `(CA certificate, client certificate, client key)`.
type ClientCredentials = (Vec<u8>, Vec<u8>, Vec<u8>);

/// CLN gRPC client for a single backend that keeps one cached connection.
///
/// The connection is created on first use and may be dropped again after a
/// failed call, in which case the next call reconnects.
pub struct DefaultClnGrpcClient {
    /// Time limit passed to the connection attempt and wrapped around each RPC.
    timeout: Duration,
    /// Backend configuration (URL, credential paths, optional TLS domain).
    config: ClnGrpcDiscoveryBackendImplementation,
    /// Feature flags handed out by `get_features`.
    features: Option<LnFeatures>,
    /// Cached connection; `None` until first use and after a disconnect.
    inner: Arc<Mutex<Option<Arc<InnerClnGrpcClient>>>>,
}

#[allow(clippy::result_large_err)]
impl DefaultClnGrpcClient {
    pub fn create(
        timeout: Duration,
        config: ClnGrpcDiscoveryBackendImplementation,
    ) -> Result<Self, LnPoolError> {
        Ok(Self {
            timeout,
            config,
            features: Some(LnFeatures {
                invoice_from_desc_hash: false,
            }),
            inner: Arc::new(Default::default()),
        })
    }

    async fn inner_connect(&self) -> Result<Arc<InnerClnGrpcClient>, LnPoolError> {
        let mut inner = self.inner.lock().await;
        match inner.as_ref() {
            None => {
                let inner_connect = Arc::new(
                    InnerClnGrpcClient::connect(
                        self.timeout,
                        self.config.clone(),
                        self.config.url.clone(),
                    )
                    .await?,
                );
                *inner = Some(inner_connect.clone());
                Ok(inner_connect)
            }
            Some(inner) => Ok(inner.clone()),
        }
    }

    async fn inner_disconnect(&self) {
        let mut inner = self.inner.lock().await;
        *inner = None;
    }
}

#[async_trait]
impl LnRpcClient for DefaultClnGrpcClient {
    type Error = LnPoolError;

    /// Requests a BOLT11 invoice from the backend.
    ///
    /// The call is bounded by the configured timeout. If it fails with a
    /// tonic error or a timeout, the cached connection is dropped so that the
    /// next call reconnects.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting fails, the call times out, or the
    /// underlying invoice request fails.
    async fn get_invoice<'a>(
        &self,
        amount_msat: Option<u64>,
        description: Bolt11InvoiceDescription<'a>,
        expiry_secs: Option<u64>,
    ) -> Result<String, Self::Error> {
        let inner = self.inner_connect().await?;

        // An elapsed timeout is turned into an ordinary error value so that it
        // goes through the same disconnect check as errors from the call.
        let r = timeout(
            self.timeout,
            inner.get_invoice(amount_msat, description, expiry_secs),
        )
        .await
        .unwrap_or_else(|_| {
            Err(LnPoolError::from_timeout_error(
                ServiceErrorSource::Upstream,
                format!(
                    "CLN get invoice from {}, requesting invoice",
                    self.config.url
                ),
            ))
        });

        if let Err(e) = &r {
            if matches!(
                e.source(),
                LnPoolErrorSourceKind::ClnTonicError(_) | LnPoolErrorSourceKind::Timeout
            ) {
                self.inner_disconnect().await;
            }
        }
        r
    }

    /// Fetches metrics from the backend.
    ///
    /// The call is bounded by the configured timeout. If it fails with a
    /// tonic error or a timeout, the cached connection is dropped so that the
    /// next call reconnects.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting fails, the call times out, or the
    /// underlying channel listing fails.
    async fn get_metrics(&self) -> Result<LnMetrics, Self::Error> {
        let inner = self.inner_connect().await?;

        // Do not return early on timeout: the error must reach the disconnect
        // check below, otherwise the cached connection is kept after a timeout.
        let r = timeout(self.timeout, inner.get_metrics())
            .await
            .unwrap_or_else(|_| {
                Err(LnPoolError::from_timeout_error(
                    ServiceErrorSource::Upstream,
                    format!(
                        "CLN get metrics for {}, requesting channels",
                        self.config.url
                    ),
                ))
            });

        if let Err(e) = &r {
            if matches!(
                e.source(),
                LnPoolErrorSourceKind::ClnTonicError(_) | LnPoolErrorSourceKind::Timeout
            ) {
                self.inner_disconnect().await;
            }
        }
        r
    }

    /// Returns the feature flags set when the client was created.
    fn get_features(&self) -> Option<&LnFeatures> {
        self.features.as_ref()
    }
}

/// One established connection to a CLN node, plus the configuration it was
/// created from.
struct InnerClnGrpcClient {
    /// gRPC node client over the connected channel; cloned for each request.
    client: NodeClient<Channel>,
    /// Backend configuration; its URL is used in error context messages.
    config: ClnGrpcDiscoveryBackendImplementation,
}

impl InnerClnGrpcClient {
    /// Loads the TLS credentials named in `config`, connects to `url` with
    /// mutual TLS, and wraps the resulting channel in a node client.
    ///
    /// # Errors
    ///
    /// Returns an error if a credential file cannot be read, `url` is not a
    /// valid endpoint URI, the TLS configuration is rejected, or the
    /// connection attempt fails.
    async fn connect(
        timeout: Duration,
        config: ClnGrpcDiscoveryBackendImplementation,
        url: Url,
    ) -> Result<Self, LnPoolError> {
        // The credentials are only read here, so borrow them rather than clone.
        let ClnGrpcClientAuth::Path(auth) = &config.auth;

        // NOTE(review): this reads files with blocking `std::fs` inside an
        // async fn; presumably acceptable because it only runs on (re)connect.
        let (ca_cert_data, client_cert_data, client_key_data) =
            Self::load_client_credentials(auth)?;

        let endpoint = Channel::from_shared(url.to_string()).map_err(|e| {
            LnPoolError::from_cln_invalid_endpoint_uri(
                e,
                ServiceErrorSource::Internal,
                format!("CLN connecting to endpoint address {url}"),
            )
        })?;

        let channel = Self::connect_with_tls(
            timeout,
            &url,
            endpoint,
            &ca_cert_data,
            &client_cert_data,
            &client_key_data,
            config.domain.as_deref(),
        )
        .await?;

        let client = NodeClient::new(channel);
        Ok(Self { client, config })
    }

    /// Reads the CA certificate, client certificate and client key from the
    /// paths in `auth` and returns their contents in that order.
    ///
    /// # Errors
    ///
    /// Returns an invalid-credentials error naming the file that could not be
    /// read.
    #[allow(clippy::result_large_err)]
    fn load_client_credentials(
        auth: &ClnGrpcClientAuthPath,
    ) -> Result<ClientCredentials, LnPoolError> {
        let ca_cert = Self::read_credential_file(auth, &auth.ca_cert_path, "CA certificate")?;
        let client_cert =
            Self::read_credential_file(auth, &auth.client_cert_path, "client certificate")?;
        let client_key = Self::read_credential_file(auth, &auth.client_key_path, "client key")?;

        Ok((ca_cert, client_cert, client_key))
    }

    /// Reads one credential file, reporting `what` and the path that was
    /// actually attempted if the read fails.
    #[allow(clippy::result_large_err)]
    fn read_credential_file(
        auth: &ClnGrpcClientAuthPath,
        path: impl AsRef<std::path::Path>,
        what: &str,
    ) -> Result<Vec<u8>, LnPoolError> {
        let path = path.as_ref();
        fs::read(path).map_err(|e| {
            LnPoolError::from_invalid_credentials(
                e.to_string(),
                ServiceErrorSource::Internal,
                format!(
                    "loading CLN credentials {auth:?} and reading {what} from path {}",
                    path.to_string_lossy()
                ),
            )
        })
    }

    /// Applies a mutual-TLS configuration to `endpoint` and connects.
    ///
    /// `ca_cert`, `client_cert` and `client_key` are interpreted as PEM data.
    /// When `domain` is given it is set as the TLS domain name on the
    /// configuration. `timeout` is used as the connect timeout and `url` only
    /// appears in error context.
    ///
    /// # Errors
    ///
    /// Returns a transport error if the TLS configuration cannot be applied
    /// or the connection cannot be established.
    async fn connect_with_tls(
        timeout: Duration,
        url: &Url,
        endpoint: Endpoint,
        ca_cert: &[u8],
        client_cert: &[u8],
        client_key: &[u8],
        domain: Option<&str>,
    ) -> Result<Channel, LnPoolError> {
        let ca_cert = Certificate::from_pem(ca_cert);
        let identity = Identity::from_pem(client_cert, client_key);

        let mut tls_config = ClientTlsConfig::new()
            .ca_certificate(ca_cert)
            .identity(identity);

        if let Some(domain) = domain {
            tls_config = tls_config.domain_name(domain);
        }

        let endpoint = endpoint.tls_config(tls_config).map_err(|e| {
            LnPoolError::from_cln_transport_error(
                e,
                ServiceErrorSource::Internal,
                format!("loading CLN TLS configuration into client for {url}"),
            )
        })?;
        endpoint
            .connect_timeout(timeout)
            .connect()
            .await
            .map_err(|e| {
                LnPoolError::from_cln_transport_error(
                    e,
                    ServiceErrorSource::Upstream,
                    format!("connecting CLN client to {url}"),
                )
            })
    }

    /// Sends an invoice request to the node and returns the BOLT11 string
    /// from the response.
    ///
    /// * `amount_msat` - a fixed amount, or `None` to request an "any amount"
    ///   invoice.
    /// * `description` - `Direct` sends the text as-is; `DirectIntoHash` sends
    ///   the text with `deschashonly` set; `Hash` is rejected.
    /// * `expiry_secs` - passed through as the request's `expiry`.
    ///
    /// The label is the description text (or, for `DirectIntoHash`, the
    /// lowercase hex SHA-256 of it) followed by `:` and the current time in
    /// nanoseconds since the Unix epoch.
    ///
    /// # Errors
    ///
    /// Returns an invalid-configuration error for `Hash` descriptions or if
    /// the system clock is before the Unix epoch, and a tonic error if the
    /// request fails.
    async fn get_invoice<'a>(
        &self,
        amount_msat: Option<u64>,
        description: Bolt11InvoiceDescription<'a>,
        expiry_secs: Option<u64>,
    ) -> Result<String, LnPoolError> {
        let (description_str, deschashonly, label) = match description {
            Bolt11InvoiceDescription::Direct(d) => {
                // Use the direct description as the label
                (d.to_string(), Some(false), d.to_string())
            }
            Bolt11InvoiceDescription::DirectIntoHash(d) => {
                let hash = sha2::Sha256::digest(d.as_bytes()).to_vec();
                (d.to_string(), Some(true), hash.to_lower_hex_string())
            }
            Bolt11InvoiceDescription::Hash(_) => {
                return Err(LnPoolError::from_invalid_configuration(
                    "hash descriptions unsupported".to_string(),
                    ServiceErrorSource::Internal,
                    format!(
                        "CLN get invoice from {}, parsing invoice description",
                        self.config.url
                    ),
                ))
            }
        };

        let now = SystemTime::now().duration_since(UNIX_EPOCH).map_err(|e| {
            LnPoolError::from_invalid_configuration(
                e.to_string(),
                ServiceErrorSource::Internal,
                format!(
                    "CLN get invoice from {}, getting current time for label",
                    self.config.url
                ),
            )
        })?;
        // Presumably the timestamp suffix keeps labels distinct when the same
        // description is invoiced more than once.
        let label = format!("{label}:{}", now.as_nanos());

        let amount = match amount_msat {
            Some(msat) => pb::amount_or_any::Value::Amount(pb::Amount { msat }),
            None => pb::amount_or_any::Value::Any(true),
        };

        let request = pb::InvoiceRequest {
            amount_msat: Some(pb::AmountOrAny {
                value: Some(amount),
            }),
            description: description_str,
            label,
            deschashonly,
            expiry: expiry_secs,
            ..Default::default()
        };

        Ok(self
            .client
            .clone()
            .invoice(request)
            .await
            .map_err(|e| {
                LnPoolError::from_cln_tonic_error(
                    e,
                    format!(
                        "CLN get invoice from {}, requesting invoice",
                        self.config.url
                    ),
                )
            })?
            .into_inner()
            .bolt11)
    }

    /// Lists the node's peer channels and reports the total receivable
    /// amount across channels in the `ChanneldNormal` state.
    ///
    /// Channels without a `receivable_msat` value count as zero. A
    /// successful response always yields `healthy: true`.
    ///
    /// # Errors
    ///
    /// Returns a tonic error if the channel listing request fails.
    async fn get_metrics(&self) -> Result<LnMetrics, LnPoolError> {
        // Get channel information
        let channels_request = pb::ListpeerchannelsRequest { id: None };
        let channels_response = self
            .client
            .clone()
            .list_peer_channels(channels_request)
            .await
            .map_err(|e| {
                LnPoolError::from_cln_tonic_error_with_esource(
                    e,
                    ServiceErrorSource::Upstream,
                    format!(
                        "CLN get metrics for {}, requesting channels",
                        self.config.url
                    ),
                )
            })?
            .into_inner();

        // Saturate instead of overflowing: the values come from the remote
        // node, and a plain `+=` would panic in debug builds or wrap in
        // release builds on an oversized total.
        let node_effective_inbound_msat = channels_response
            .channels
            .iter()
            .filter(|channel| channel.state == pb::ChannelState::ChanneldNormal as i32)
            .map(|channel| channel.receivable_msat.as_ref().map_or(0, |a| a.msat))
            .fold(0u64, u64::saturating_add);

        Ok(LnMetrics {
            healthy: true,
            node_effective_inbound_msat,
        })
    }
}