pub mod channel;
pub mod types;
pub mod utils;
pub mod verify;
mod extractors;
use std::{future::Future, pin::Pin};
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use tower::{Layer, Service};
use channel::ChannelState;
use types::PaymentChannelConfig;
use utils::{modify_headers_axum, parse_headers};
use verify::verify_and_update_channel;
use crate::error::AuthError;
#[derive(Clone)]
#[cfg(not(target_arch = "wasm32"))]
#[deprecated(
since = "0.6.0",
note = "Use middleware::PaymentsLayer (unified PipegateMiddlewareLayer alias)"
)]
pub struct PaymentChannelMiddlewareLayer {
pub state: ChannelState,
pub config: PaymentChannelConfig,
}
#[cfg(not(target_arch = "wasm32"))]
#[allow(deprecated)]
impl PaymentChannelMiddlewareLayer {
pub fn new(state: ChannelState, config: PaymentChannelConfig) -> Self {
Self { state, config }
}
}
#[cfg(not(target_arch = "wasm32"))]
#[allow(deprecated)]
impl<S> Layer<S> for PaymentChannelMiddlewareLayer {
type Service = PaymentChannelMiddleware<S>;
fn layer(&self, service: S) -> Self::Service {
PaymentChannelMiddleware {
state: self.state.clone(),
config: self.config.clone(),
inner: service,
}
}
}
#[derive(Clone)]
#[cfg(not(target_arch = "wasm32"))]
#[deprecated(
since = "0.6.0",
note = "Use middleware::Payments<S> (unified PipegateMiddleware alias)"
)]
pub struct PaymentChannelMiddleware<S> {
inner: S,
state: ChannelState,
config: PaymentChannelConfig,
}
#[cfg(not(target_arch = "wasm32"))]
#[allow(deprecated)]
impl<S> Service<Request<Body>> for PaymentChannelMiddleware<S>
where
S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request<Body>) -> Self::Future {
let state = self.state.clone();
let config = self.config.clone();
let payment_amount = self.config.clone().amount;
let mut inner = self.inner.clone();
Box::pin(async move {
println!("\n=== auth_middleware ===");
println!(" === new request ===");
let (parts, body) = request.into_parts();
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes,
Err(_) => {
println!("Failed: Body decode");
return Ok(
AuthError::InvalidRequest("Failed to decode body".to_string())
.into_response(),
);
}
};
let signed_request =
match parse_headers(&parts.headers, body_bytes.to_vec(), payment_amount).await {
Ok(signed_request) => signed_request,
Err(e) => {
println!("Failed: Parse headers");
return Ok(e.into_response());
}
};
let body_bytes = signed_request.body_bytes.clone();
let (updated_channel, verify) =
match verify_and_update_channel(&state, &config, signed_request).await {
Ok((updated_channel, verify)) => (updated_channel, verify),
Err(e) => {
println!("Failed: Verify and update channel");
return Ok(e.into_response());
}
};
state
.channels
.write()
.await
.insert(updated_channel.channel_id, updated_channel.clone());
if verify {
println!("Verified");
println!("=== end middleware check ===");
let request = Request::from_parts(parts, Body::from(body_bytes));
let response = inner.call(request).await?;
let response = modify_headers_axum(response, &updated_channel);
Ok(response)
} else {
Ok(AuthError::InternalError.into_response())
}
})
}
}
#[derive(Clone)]
pub struct PaymentChannelFnMiddlewareState {
state: ChannelState,
config: PaymentChannelConfig,
}
pub async fn payment_channel_auth_fn_middleware(
State(state): State<PaymentChannelFnMiddlewareState>,
request: Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
println!("\n=== auth_middleware ===");
println!(" === new request ===");
let (parts, body) = request.into_parts();
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes,
Err(_) => {
println!("Failed: Body decode");
return Ok(
AuthError::InvalidRequest("Failed to decode body".to_string()).into_response(),
);
}
};
let signed_request =
match parse_headers(&parts.headers, body_bytes.to_vec(), state.config.amount).await {
Ok(signed_request) => signed_request,
Err(e) => {
println!("Failed: Parse headers");
return Ok(e.into_response());
}
};
let body_bytes = signed_request.body_bytes.clone();
let (updated_channel, verify) =
match verify_and_update_channel(&state.state, &state.config, signed_request).await {
Ok((updated_channel, verify)) => (updated_channel, verify),
Err(e) => {
println!("Failed: Verify and update channel");
return Ok(e.into_response());
}
};
state
.state
.channels
.write()
.await
.insert(updated_channel.channel_id, updated_channel.clone());
let request = Request::from_parts(parts, Body::from(body_bytes));
if verify {
println!("Verified");
println!("=== end middleware check ===");
let response = next.run(request).await;
let response = modify_headers_axum(response, &updated_channel);
Ok(response)
} else {
Err(StatusCode::UNAUTHORIZED)
}
}