use std::{
cell::RefCell,
rc::Rc,
task::{Context, Poll},
};
use actix_service::Service as ActixService;
use actix_web::{
body::BoxBody,
dev::{ServiceRequest, ServiceResponse},
Error,
};
use http_body::Body as HttpBody;
use tower_service::Service as TowerService;
use crate::compat::tower::{
body::ActixResponseBody,
future::{ActixServiceWrapperFuture, TowerMiddlewareFuture},
request::{http_to_service_request, service_request_to_http},
response::{http_to_service_response, service_response_to_http},
};
use crate::internal::common::{BoxError, StringError, TowerError};
/// Adapts an actix service `S` so that it can be driven through tower's `Service` interface.
///
/// The inner service sits behind an [`Rc`]: every clone of the wrapper shares the same
/// instance. `max_body_bytes` travels with the wrapper and is handed to the request
/// conversion when a request is processed.
pub struct ActixServiceWrapper<S> {
    /// Shared handle to the wrapped actix service.
    pub(crate) service: Rc<S>,
    /// Value forwarded to the `http::Request` → `ServiceRequest` conversion.
    pub(crate) max_body_bytes: usize,
}

impl<S> ActixServiceWrapper<S> {
    /// Wraps `service` in a shared handle and records `max_body_bytes`.
    pub fn new(service: S, max_body_bytes: usize) -> Self {
        let service = Rc::new(service);
        Self {
            service,
            max_body_bytes,
        }
    }
}

// A derived `Clone` would add an `S: Clone` bound; this impl only bumps the `Rc` count,
// so every clone points at the same inner service.
impl<S> Clone for ActixServiceWrapper<S> {
    fn clone(&self) -> Self {
        let Self {
            service,
            max_body_bytes,
        } = self;
        Self {
            service: Rc::clone(service),
            max_body_bytes: *max_body_bytes,
        }
    }
}
/// An error reduced to plain, owned data: an HTTP status code plus a rendered message.
///
/// Holding only a `StatusCode` and a `String` lets it be boxed as a `BoxError`.
/// NOTE(review): presumably this exists because `actix_web::Error` itself cannot be boxed
/// that way — confirm.
#[derive(Debug)]
pub struct ThreadSafeActixError {
    /// Status code carried by the error.
    pub status: actix_web::http::StatusCode,
    /// Human-readable error text; this is exactly what `Display` prints.
    pub message: String,
}

impl std::fmt::Display for ThreadSafeActixError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

/// Marker impl: no `source()` is exposed, because no underlying error value is stored.
impl std::error::Error for ThreadSafeActixError {}
impl<S> TowerService<http::Request<crate::compat::tower::body::ActixRequestBody>>
for ActixServiceWrapper<S>
where
S: ActixService<ServiceRequest, Response = ServiceResponse, Error = Error> + 'static,
S::Future: 'static,
{
type Response = http::Response<ActixResponseBody<BoxBody>>;
type Error = BoxError;
type Future = ActixServiceWrapperFuture;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx).map_err(|e| {
let status = e.as_response_error().status_code();
let message = e.to_string();
Box::new(ThreadSafeActixError { status, message }) as BoxError
})
}
fn call(
&mut self,
req: http::Request<crate::compat::tower::body::ActixRequestBody>,
) -> Self::Future {
let service = self.service.clone();
let max_body_bytes = self.max_body_bytes;
Box::pin(async move {
let service_request =
http_to_service_request(req, max_body_bytes)
.await
.map_err(|e| {
let status = e.as_response_error().status_code();
let message = e.to_string();
Box::new(ThreadSafeActixError { status, message }) as BoxError
})?;
let service_response = service.call(service_request).await.map_err(|e| {
let status = e.as_response_error().status_code();
let message = e.to_string();
Box::new(ThreadSafeActixError { status, message }) as BoxError
})?;
Ok(service_response_to_http(service_response))
})
}
}
/// Adapts a tower service `TS` so it can be used as an actix-web service.
///
/// `TS` lives behind `Rc<RefCell<_>>`. Clones share one instance, and each use borrows it
/// mutably because tower's `Service` methods take `&mut self`.
pub struct TowerMiddlewareService<TS> {
    /// Shared, mutably-borrowable tower service.
    pub(crate) tower_service: Rc<RefCell<TS>>,
    /// Body-size value stored alongside the service.
    /// NOTE(review): it is not read by the actix `Service` impl in this file — confirm where
    /// it is consumed.
    pub(crate) max_body_bytes: usize,
}

impl<TS> TowerMiddlewareService<TS> {
    /// Places `tower_service` into shared, interior-mutable storage and records
    /// `max_body_bytes`.
    pub fn new(tower_service: TS, max_body_bytes: usize) -> Self {
        let shared = RefCell::new(tower_service);
        Self {
            tower_service: Rc::new(shared),
            max_body_bytes,
        }
    }
}

// A derived `Clone` would require `TS: Clone`; here only the `Rc` handle is duplicated,
// so every clone refers to the same tower service.
impl<TS> Clone for TowerMiddlewareService<TS> {
    fn clone(&self) -> Self {
        let tower_service = Rc::clone(&self.tower_service);
        Self {
            tower_service,
            max_body_bytes: self.max_body_bytes,
        }
    }
}
impl<TS, B, E> ActixService<ServiceRequest> for TowerMiddlewareService<TS>
where
TS: TowerService<
http::Request<crate::compat::tower::body::ActixRequestBody>,
Response = http::Response<B>,
Error = E,
> + 'static,
TS::Future: 'static,
B: HttpBody<Data = actix_web::web::Bytes> + 'static,
B::Error: std::fmt::Display + 'static,
E: Into<crate::internal::common::BoxError> + 'static,
{
type Response = ServiceResponse;
type Error = Error;
type Future = TowerMiddlewareFuture;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tower_service.borrow_mut().poll_ready(cx).map_err(|e| {
let boxed: BoxError = e.into();
match boxed.downcast::<ThreadSafeActixError>() {
Ok(wrapped) => {
actix_web::error::InternalError::new(wrapped.message, wrapped.status).into()
}
Err(boxed) => Error::from(TowerError(boxed)),
}
})
}
fn call(&self, req: ServiceRequest) -> Self::Future {
let tower_service = self.tower_service.clone();
let http_request = service_request_to_http(req);
let req_id = http_request
.extensions()
.get::<crate::compat::tower::request::RequestRegistryGuard>()
.expect("RequestRegistryGuard missing from newly created request")
.req_id;
Box::pin(async move {
let call_fut = tower_service.borrow_mut().call(http_request);
let mut http_response = call_fut.await.map_err(|e| {
let boxed: BoxError = e.into();
match boxed.downcast::<ThreadSafeActixError>() {
Ok(wrapped) => actix_web::error::InternalError::new(wrapped.message, wrapped.status).into(),
Err(boxed) => Error::from(TowerError(boxed)),
}
})?;
let guard = http_response
.extensions_mut()
.remove::<crate::compat::tower::request::ResponseRegistryGuard>();
if let Some(g) = guard {
std::mem::forget(g);
}
let actix_req = crate::compat::tower::request::RESPONSE_REGISTRY
.with(|registry| registry.borrow_mut().remove(&req_id))
.or_else(|| {
crate::compat::tower::request::REQUEST_REGISTRY
.with(|registry| registry.borrow_mut().remove(&req_id))
})
.ok_or_else(|| {
Error::from(TowerError(Box::new(StringError(
"actix_tower: HttpRequest not found in registries. \
This is an internal bug; please file an issue."
.to_owned(),
)) as BoxError))
})?;
Ok(http_to_service_response(http_response, actix_req))
})
}
}