use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
use axum::extract::MatchedPath;
use bytes::BytesMut;
use chrono::Utc;
use futures::{future::BoxFuture, stream::StreamExt};
use http::header::CONTENT_LENGTH;
use axum::{body::Body, http::Request, response::Response};
use tower::{Layer, Service};
use crate::controller::Controller;
use crate::generic_http::{BodyCapture, GenericRequest};
use crate::transport::Transport;
use crate::{path_hint, GenericSpeakeasySdk};
/// Tower [`Layer`] that wraps services in [`SpeakeasySdkMiddleware`].
///
/// Holds the SDK handle; every wrapped service receives its own clone of it.
#[derive(Clone)]
pub struct SpeakeasySdk<T: Transport + Send + Clone + 'static> {
    sdk: GenericSpeakeasySdk<T>,
}
impl<T> SpeakeasySdk<T>
where
T: Transport + Send + Clone + 'static,
{
pub(crate) fn new(sdk: GenericSpeakeasySdk<T>) -> Self {
Self { sdk }
}
}
impl<S, T> Layer<S> for SpeakeasySdk<T>
where
    T: Transport + Send + Clone + 'static,
{
    type Service = SpeakeasySdkMiddleware<S, T>;

    /// Wraps `inner` in a middleware holding its own clone of this layer's SDK handle.
    fn layer(&self, inner: S) -> Self::Service {
        let sdk = self.sdk.clone();
        SpeakeasySdkMiddleware { inner, sdk }
    }
}
/// Service produced by [`SpeakeasySdk`]: records each request in a Speakeasy
/// controller before delegating to the wrapped service.
#[derive(Clone)]
pub struct SpeakeasySdkMiddleware<S, T> {
    /// Wrapped service that actually handles the request.
    inner: S,
    /// SDK handle from which a controller is built for every request.
    sdk: GenericSpeakeasySdk<T>,
}
impl<S, T> Service<Request<Body>> for SpeakeasySdkMiddleware<S, T>
where
S: Service<Request<Body>, Response = Response> + Send + Clone + 'static,
S::Future: Send + 'static,
T: Transport + Send + Sync + Clone + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut request: Request<Body>) -> Self::Future {
let start_time = Utc::now();
let mut svc = self.inner.clone();
let mut controller = Controller::new(&self.sdk);
Box::pin(async move {
let mut max_reached = false;
let mut captured_body = BytesMut::new();
let mut body = BodyCapture::Empty;
let headers = request.headers();
let path_hint = request
.extensions()
.get::<MatchedPath>()
.map(|path_hint| path_hint::normalize(path_hint.as_str()));
let content_length = headers
.get(CONTENT_LENGTH)
.and_then(|value| value.to_str().unwrap().parse::<usize>().ok())
.unwrap_or_default();
if content_length <= controller.max_capture_size {
if content_length > 0 {
captured_body.reserve(content_length);
}
let payload_stream = request.body_mut();
let (mut payload_sender, payload) = Body::channel();
while let Some(chunk) = payload_stream.next().await {
captured_body.extend_from_slice(&chunk.unwrap());
if captured_body.len() >= controller.max_capture_size {
max_reached = true;
break;
}
}
payload_sender
.send_data(captured_body.clone().freeze())
.await
.unwrap();
if max_reached {
while let Some(chunk) = payload_stream.next().await {
payload_sender.send_data(chunk.unwrap()).await.unwrap();
}
body = BodyCapture::Dropped;
} else if !captured_body.is_empty() {
body = BodyCapture::Captured(captured_body.into_iter().collect());
}
let request_body = request.body_mut();
*request_body = payload;
} else {
body = BodyCapture::Dropped;
}
let generic_request = GenericRequest::new(&request, start_time, path_hint, body);
controller.set_request(generic_request);
let controller_in_arc = Arc::new(RwLock::new(controller));
request.extensions_mut().insert(controller_in_arc.clone());
let mut response = svc.call(request).await?;
response.extensions_mut().insert(controller_in_arc);
Ok(response)
})
}
}