switchgear-service 0.1.13

Service layer and API implementations for Switchgear LNURL load balancer
Documentation
use crate::axum::extract::uuid::UuidParam;
use crate::lnurl::pay::error::LnUrlPayServiceError;
use axum::{
    extract::{FromRequestParts, Request},
    response::{IntoResponse, Response},
};
use std::collections::HashSet;
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::{Layer, Service};

#[derive(Clone)]
pub struct PartitionsLayer {
    partitions: Arc<HashSet<String>>,
}

impl PartitionsLayer {
    pub fn new(partitions: Arc<HashSet<String>>) -> Self {
        Self { partitions }
    }
}

impl<S> Layer<S> for PartitionsLayer {
    type Service = PartitionsService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        PartitionsService {
            inner,
            partitions: self.partitions.clone(),
        }
    }
}

#[derive(Clone)]
pub struct PartitionsService<S> {
    inner: S,
    partitions: Arc<HashSet<String>>,
}

impl<S> Service<Request> for PartitionsService<S>
where
    S: Service<Request, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Infallible>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request) -> Self::Future {
        let not_ready_inner = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, not_ready_inner);

        let partitions = self.partitions.clone();

        Box::pin(async move {
            let (mut parts, body) = req.into_parts();

            let uuid_param = match UuidParam::from_request_parts(&mut parts, &()).await {
                Ok(param) => param,
                Err(status) => {
                    return Ok(status.into_response());
                }
            };

            if partitions.contains(&uuid_param.partition) {
                let req = Request::from_parts(parts, body);
                inner.call(req).await
            } else {
                let error_response =
                    LnUrlPayServiceError::not_found(format!("offer not found: {}", &uuid_param.id));
                Ok(error_response.into_response())
            }
        })
    }
}