switchgear-service 0.1.0

Service layer and API implementations for the Switchgear LNURL load balancer
Documentation
use crate::api::discovery::{DiscoveryBackend, DiscoveryBackendImplementation};
use crate::api::offer::Offer;
use crate::api::service::ServiceErrorSource;
use crate::components::pool::cln::grpc::client::DefaultClnGrpcClient;
use crate::components::pool::error::{LnPoolError, LnPoolErrorSourceKind};
use crate::components::pool::lnd::grpc::client::DefaultLndGrpcClient;
use crate::components::pool::{
    Bolt11InvoiceDescription, LnClientPool, LnMetrics, LnMetricsCache, LnRpcClient,
};
use async_trait::async_trait;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::{Arc, Mutex};
use std::time::Duration;

type LnClientMap<K> =
    HashMap<K, Arc<Box<dyn LnRpcClient<Error = LnPoolError> + Send + Sync + 'static>>>;

/// A pool of Lightning node RPC clients keyed by `K`, together with a cache of
/// the most recently fetched [`LnMetrics`] for each key.
///
/// Cloning the pool is cheap: the client map and the metrics cache sit behind
/// `Arc`, so all clones share the same underlying state.
///
/// No trait bounds are declared on `K` here. Each `impl` block states the
/// bounds it needs (`Hash + Eq`, `Debug`, `Send + Sync`, ...) itself.
#[derive(Clone)]
pub struct DefaultLnClientPool<K> {
    /// Timeout handed to each client constructor when a backend is connected.
    timeout: Duration,
    /// Connected clients, keyed by backend key.
    pool: Arc<Mutex<LnClientMap<K>>>,
    /// Last metrics successfully fetched for each backend key.
    metrics_cache: Arc<Mutex<HashMap<K, LnMetrics>>>,
}

impl<K> DefaultLnClientPool<K>
where
    K: Clone + std::hash::Hash + Eq + Debug,
{
    pub fn new(timeout: Duration) -> DefaultLnClientPool<K> {
        Self {
            timeout,
            pool: Default::default(),
            metrics_cache: Default::default(),
        }
    }

    async fn get_client(
        &self,
        key: &K,
    ) -> Result<Arc<Box<dyn LnRpcClient<Error = LnPoolError> + Send + Sync + 'static>>, LnPoolError>
    {
        let pool = self.pool.lock().map_err(|e| {
            LnPoolError::new(
                LnPoolErrorSourceKind::Generic,
                ServiceErrorSource::Internal,
                e.to_string(),
            )
        })?;
        let client = pool.get(key).ok_or_else(|| {
            LnPoolError::from_invalid_configuration(
                format!("client for key: {key:?} not found in pool"),
                ServiceErrorSource::Internal,
                format!("fetching client from pool for key: {key:?}"),
            )
        })?;
        Ok(client.clone())
    }
}

#[async_trait]
impl<K> LnClientPool for DefaultLnClientPool<K>
where
    K: Clone + std::hash::Hash + Eq + Send + Sync + Debug + 'static,
{
    type Error = LnPoolError;
    type Key = K;

    async fn get_invoice(
        &self,
        offer: &Offer,
        key: &Self::Key,
        amount_msat: Option<u64>,
        expiry_secs: Option<u64>,
    ) -> Result<String, Self::Error> {
        let client = self.get_client(key).await?;

        let capabilities = client.get_features();

        let invoice_from_desc_hash =
            capabilities.map_or_else(|| false, |c| c.invoice_from_desc_hash);

        let description = if invoice_from_desc_hash {
            Bolt11InvoiceDescription::Hash(&offer.metadata_json_hash)
        } else {
            Bolt11InvoiceDescription::DirectIntoHash(offer.metadata_json_string.as_str())
        };

        Ok(client
            .get_invoice(amount_msat, description, expiry_secs)
            .await?)
    }

    async fn get_metrics(&self, key: &Self::Key) -> Result<LnMetrics, Self::Error> {
        let client = self.get_client(key).await?;

        let metrics = client.get_metrics().await?;

        let mut cache = self.metrics_cache.lock().map_err(|e| {
            LnPoolError::new(
                LnPoolErrorSourceKind::Generic,
                ServiceErrorSource::Internal,
                e.to_string(),
            )
        })?;

        cache.insert(key.clone(), metrics.clone());
        Ok(metrics)
    }

    fn connect(&self, key: Self::Key, backend: &DiscoveryBackend) -> Result<(), Self::Error> {
        let client: Box<dyn LnRpcClient<Error = LnPoolError> + std::marker::Send + Sync> =
            match &backend.backend.implementation {
                DiscoveryBackendImplementation::ClnGrpc(c) => {
                    Box::new(DefaultClnGrpcClient::create(self.timeout, c.clone())?)
                }
                DiscoveryBackendImplementation::LndGrpc(c) => {
                    Box::new(DefaultLndGrpcClient::create(self.timeout, c.clone())?)
                }
                DiscoveryBackendImplementation::RemoteHttp => {
                    return Err(LnPoolError::new(
                        LnPoolErrorSourceKind::Generic,
                        ServiceErrorSource::Internal,
                        "RemoteHttp backends not available",
                    ));
                }
            };

        let mut pool = self.pool.lock().map_err(|e| {
            LnPoolError::new(
                LnPoolErrorSourceKind::Generic,
                ServiceErrorSource::Internal,
                e.to_string(),
            )
        })?;
        pool.insert(key, Arc::new(client));

        Ok(())
    }
}

impl<K: Clone + std::hash::Hash + Eq> LnMetricsCache for DefaultLnClientPool<K> {
    type Key = K;

    /// Returns a copy of the metrics most recently cached for `key`.
    ///
    /// Yields `None` when nothing is cached for `key`. It also yields `None`,
    /// without reporting an error, when the metrics-cache mutex is poisoned.
    fn get_cached_metrics(&self, key: &K) -> Option<LnMetrics> {
        self.metrics_cache
            .lock()
            .ok()
            .and_then(|cache| cache.get(key).cloned())
    }
}