use descartes_components::{ClientEvent, Server, ServerEvent};
use descartes_core::dists::{ConstantServiceTime, ExponentialDistribution, ServiceTimeDistribution};
use descartes_core::{defer_wake, in_scheduler_context, SchedulerHandle as CoreSchedulerHandle};
use descartes_core::{
Component, Key, RequestAttempt, RequestAttemptId, RequestId, Response, Scheduler, SimTime,
Simulation,
};
use http::Request;
use pin_project::pin_project;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::time::Duration;
use tokio::sync::oneshot;
use tower::Service;
use tower_layer::Layer;
use super::{response_to_http, serialize_http_request, ServiceError, SimBody};
/// Simulation component that receives the server's responses and hands each
/// one to the caller waiting on it.
///
/// Its state is shared with [`TowerSchedulerHandle`]: the map of outstanding
/// response senders keyed by attempt id, the in-flight load counter, and the
/// list of wakers parked in `poll_ready`.
pub struct ResponseRouter {
    /// Senders for attempts that have been scheduled but not yet answered.
    pending_responses: Arc<Mutex<HashMap<RequestAttemptId, oneshot::Sender<Response>>>>,
    /// Number of attempts currently in flight.
    current_load: Arc<AtomicUsize>,
    /// Wakers to notify whenever a response arrives.
    wakers: Arc<Mutex<Vec<Waker>>>,
}

impl ResponseRouter {
    /// Builds a router over state shared with the scheduler handle.
    fn new(
        pending_responses: Arc<Mutex<HashMap<RequestAttemptId, oneshot::Sender<Response>>>>,
        current_load: Arc<AtomicUsize>,
        wakers: Arc<Mutex<Vec<Waker>>>,
    ) -> Self {
        ResponseRouter {
            wakers,
            current_load,
            pending_responses,
        }
    }
}
impl Component for ResponseRouter {
    type Event = ClientEvent;

    /// Routes a server response back to the caller that scheduled the attempt.
    ///
    /// On `ResponseReceived`, the sender registered for the response's attempt
    /// id is removed from the pending map. If one was found, the in-flight load
    /// is decremented and the response is forwarded. Every waker parked in
    /// `poll_ready` is then drained and woken. All other client events are
    /// ignored by this component.
    fn process_event(
        &mut self,
        _self_id: Key<Self::Event>,
        event: &Self::Event,
        _scheduler: &mut Scheduler,
    ) {
        match event {
            ClientEvent::ResponseReceived { response } => {
                // The guard is a temporary of this `let`, so the lock is
                // released before the response is sent.
                let sender = self
                    .pending_responses
                    .lock()
                    .unwrap()
                    .remove(&response.attempt_id);

                if let Some(tx) = sender {
                    // Only attempts that were actually scheduled (and thus
                    // counted) are subtracted. An unknown or duplicate attempt
                    // id would otherwise wrap the counter below zero and make
                    // `has_capacity` false forever.
                    self.current_load.fetch_sub(1, Ordering::Relaxed);
                    // The receiver may already be gone (the caller dropped its
                    // future); the response is then simply discarded.
                    let _ = tx.send(response.clone());
                }

                let wakers: Vec<Waker> = std::mem::take(&mut *self.wakers.lock().unwrap());
                for waker in wakers {
                    waker.wake();
                }
            }
            ClientEvent::SendRequest
            | ClientEvent::RequestTimeout { .. }
            | ClientEvent::RetryRequest { .. } => {}
        }
    }
}
/// Cloneable handle through which [`DesService`] submits request attempts to
/// the simulated server and waits for capacity.
///
/// Clones share the pending-response map, the request/attempt id counters, the
/// in-flight load counter and the waker list, all of which are held in `Arc`s.
#[derive(Clone)]
pub struct TowerSchedulerHandle {
    /// Used to schedule events when not already running inside the scheduler.
    scheduler: CoreSchedulerHandle,
    /// Server component that receives `ServerEvent::ProcessRequest`.
    server_key: Key<ServerEvent>,
    /// Router component named as the `client_id` of every request.
    router_key: Key<ClientEvent>,
    /// Response senders for in-flight attempts, keyed by attempt id.
    pending_responses: Arc<Mutex<HashMap<RequestAttemptId, oneshot::Sender<Response>>>>,
    /// Source of fresh request ids (starts at 1).
    next_request_id: Arc<AtomicU64>,
    /// Source of fresh attempt ids (starts at 1).
    next_attempt_id: Arc<AtomicU64>,
    /// Number of attempts scheduled whose response has not been routed back yet.
    current_load: Arc<AtomicUsize>,
    /// `has_capacity` reports `true` while `current_load` is below this value.
    pub capacity: usize,
    /// Wakers of tasks waiting in `poll_ready` for capacity.
    wakers: Arc<Mutex<Vec<Waker>>>,
}

impl TowerSchedulerHandle {
    /// Creates a handle over state shared with the response router.
    ///
    /// The request and attempt id counters both start at 1.
    pub fn new(
        scheduler: CoreSchedulerHandle,
        server_key: Key<ServerEvent>,
        router_key: Key<ClientEvent>,
        pending_responses: Arc<Mutex<HashMap<RequestAttemptId, oneshot::Sender<Response>>>>,
        current_load: Arc<AtomicUsize>,
        wakers: Arc<Mutex<Vec<Waker>>>,
        capacity: usize,
    ) -> Self {
        Self {
            scheduler,
            server_key,
            router_key,
            pending_responses,
            next_request_id: Arc::new(AtomicU64::new(1)),
            next_attempt_id: Arc::new(AtomicU64::new(1)),
            current_load,
            capacity,
            wakers,
        }
    }

    /// Returns `true` while the number of in-flight attempts is below `capacity`.
    pub fn has_capacity(&self) -> bool {
        self.current_load.load(Ordering::Relaxed) < self.capacity
    }

    /// Records `attempt` as in flight and sends it to the server.
    ///
    /// `response_tx` is stored under the attempt id *before* the event is
    /// emitted, so the router can find it when the response comes back. Inside
    /// scheduler context the event is handed to `defer_wake`. Otherwise it is
    /// scheduled through the scheduler handle with `SimTime::zero()`
    /// (presumably a zero delay — confirm against `SchedulerHandle::schedule`).
    ///
    /// # Errors
    ///
    /// Never fails at present; the `Result` is part of the public signature.
    pub fn schedule_request(
        &self,
        attempt: RequestAttempt,
        response_tx: oneshot::Sender<Response>,
    ) -> Result<(), ServiceError> {
        self.current_load.fetch_add(1, Ordering::Relaxed);
        self.pending_responses
            .lock()
            .unwrap()
            .insert(attempt.id, response_tx);

        let event = ServerEvent::ProcessRequest {
            attempt,
            client_id: self.router_key,
        };
        if in_scheduler_context() {
            defer_wake(self.server_key, event);
        } else {
            self.scheduler.schedule(SimTime::zero(), self.server_key, event);
        }
        Ok(())
    }

    /// Parks `waker` until the router reports a response.
    ///
    /// A waker that would wake the same task as one already registered is not
    /// added again. Otherwise a task polled repeatedly while at capacity would
    /// push a fresh clone each time, and the list would grow without bound.
    pub fn register_waker(&self, waker: Waker) {
        let mut wakers = self.wakers.lock().unwrap();
        if !wakers.iter().any(|registered| registered.will_wake(&waker)) {
            wakers.push(waker);
        }
    }

    /// Builds a `RequestAttempt` for `req`, always with a fresh attempt id.
    ///
    /// If the request carries retry metadata, its original request id and
    /// attempt number are reused. Otherwise a fresh request id is allocated and
    /// the attempt number is 1. The payload is `serialize_http_request(req)`.
    ///
    /// NOTE(review): the `SimTime` argument is always zero. If it is the
    /// attempt's start time, it does not reflect the current simulated time —
    /// confirm against `RequestAttempt::new`.
    pub fn create_request_attempt(&self, req: &Request<SimBody>) -> RequestAttempt {
        let (request_id, attempt_number) =
            match crate::tower::retry::metadata::get_retry_metadata(req) {
                Some(retry_meta) => (retry_meta.original_request_id, retry_meta.attempt_number),
                None => (
                    RequestId(self.next_request_id.fetch_add(1, Ordering::Relaxed)),
                    1,
                ),
            };
        let attempt_id = RequestAttemptId(self.next_attempt_id.fetch_add(1, Ordering::Relaxed));
        RequestAttempt::new(
            attempt_id,
            request_id,
            attempt_number,
            SimTime::from_duration(Duration::ZERO),
            serialize_http_request(req),
        )
    }
}
/// Builder that configures a simulated server and the stack of Tower layers
/// wrapped around the resulting [`DesService`].
pub struct DesServiceBuilder<L> {
    /// Name given to the simulated `Server` component.
    server_name: String,
    /// Server concurrency, also used as the handle's admission capacity
    /// (default 10).
    thread_capacity: usize,
    /// Service-time model; `None` until one of the service-time setters is called.
    service_time_distribution: Option<Box<dyn ServiceTimeDistribution>>,
    /// Accumulated layer stack applied on top of the base service.
    layer: L,
}

impl Default for DesServiceBuilder<tower_layer::Identity> {
    /// A builder for a server named `"default-server"` with no extra layers.
    fn default() -> Self {
        Self::new(String::from("default-server"))
    }
}

impl DesServiceBuilder<tower_layer::Identity> {
    /// Starts a builder for a server called `server_name`.
    ///
    /// Defaults: 10 threads, no service-time distribution, identity layer.
    pub const fn new(server_name: String) -> Self {
        Self {
            layer: tower_layer::Identity::new(),
            service_time_distribution: None,
            thread_capacity: 10,
            server_name,
        }
    }
}
impl<L> DesServiceBuilder<L> {
pub fn thread_capacity(mut self, capacity: usize) -> Self {
self.thread_capacity = capacity;
self
}
pub fn service_time_distribution<D: ServiceTimeDistribution + 'static>(
mut self,
distribution: D,
) -> Self {
self.service_time_distribution = Some(Box::new(distribution));
self
}
pub fn service_time(self, duration: Duration) -> Self {
self.service_time_distribution(ConstantServiceTime::new(duration))
}
pub fn exponential_service_time(self, mean_service_time: Duration) -> Self {
let rate = 1.0 / mean_service_time.as_secs_f64();
self.service_time_distribution(ExponentialDistribution::new(rate))
}
pub fn layer<T>(self, layer: T) -> DesServiceBuilder<tower_layer::Stack<T, L>> {
DesServiceBuilder {
server_name: self.server_name,
thread_capacity: self.thread_capacity,
service_time_distribution: self.service_time_distribution,
layer: tower_layer::Stack::new(layer, self.layer),
}
}
pub fn option_layer<T>(
self,
layer: Option<T>,
) -> DesServiceBuilder<tower_layer::Stack<tower::util::Either<T, tower_layer::Identity>, L>>
{
match layer {
Some(layer) => self.layer(tower::util::Either::Left(layer)),
None => self.layer(tower::util::Either::Right(tower_layer::Identity::new())),
}
}
pub fn layer_fn<F>(
self,
f: F,
) -> DesServiceBuilder<tower_layer::Stack<tower_layer::LayerFn<F>, L>> {
self.layer(tower_layer::layer_fn(f))
}
pub fn buffer<Request>(
self,
bound: usize,
) -> DesServiceBuilder<tower_layer::Stack<tower::buffer::BufferLayer<Request>, L>> {
self.layer(tower::buffer::BufferLayer::new(bound))
}
pub fn concurrency_limit(
self,
max: usize,
) -> DesServiceBuilder<tower_layer::Stack<super::limit::concurrency::DesConcurrencyLimitLayer, L>>
{
self.layer(super::limit::concurrency::DesConcurrencyLimitLayer::new(
max,
))
}
pub fn load_shed(
self,
) -> DesServiceBuilder<tower_layer::Stack<tower::load_shed::LoadShedLayer, L>> {
self.layer(tower::load_shed::LoadShedLayer::new())
}
pub fn rate_limit(
self,
num: u64,
per: std::time::Duration,
) -> DesServiceBuilder<tower_layer::Stack<super::limit::rate::DesRateLimitLayer, L>> {
let rate = num as f64 / per.as_secs_f64();
self.layer(super::limit::rate::DesRateLimitLayer::new(
rate,
num as usize,
))
}
pub fn retry<P>(
self,
policy: P,
) -> DesServiceBuilder<tower_layer::Stack<super::retry::DesRetryLayer<P>, L>>
where
P: Clone,
{
self.layer(super::retry::DesRetryLayer::new(policy))
}
pub fn timeout(
self,
timeout: std::time::Duration,
) -> DesServiceBuilder<tower_layer::Stack<super::timeout::DesTimeoutLayer, L>> {
self.layer(super::timeout::DesTimeoutLayer::new(timeout))
}
pub fn filter<P>(
self,
predicate: P,
) -> DesServiceBuilder<tower_layer::Stack<tower::filter::FilterLayer<P>, L>> {
self.layer(tower::filter::FilterLayer::new(predicate))
}
pub fn filter_async<P>(
self,
predicate: P,
) -> DesServiceBuilder<tower_layer::Stack<tower::filter::AsyncFilterLayer<P>, L>> {
self.layer(tower::filter::AsyncFilterLayer::new(predicate))
}
pub fn map_request<F, R1, R2>(
self,
f: F,
) -> DesServiceBuilder<tower_layer::Stack<tower::util::MapRequestLayer<F>, L>>
where
F: FnMut(R1) -> R2 + Clone,
{
self.layer(tower::util::MapRequestLayer::new(f))
}
pub fn map_response<F>(
self,
f: F,
) -> DesServiceBuilder<tower_layer::Stack<tower::util::MapResponseLayer<F>, L>> {
self.layer(tower::util::MapResponseLayer::new(f))
}
pub fn map_err<F>(
self,
f: F,
) -> DesServiceBuilder<tower_layer::Stack<tower::util::MapErrLayer<F>, L>> {
self.layer(tower::util::MapErrLayer::new(f))
}
pub fn map_future<F>(
self,
f: F,
) -> DesServiceBuilder<tower_layer::Stack<tower::util::MapFutureLayer<F>, L>> {
self.layer(tower::util::MapFutureLayer::new(f))
}
pub fn then<F>(
self,
f: F,
) -> DesServiceBuilder<tower_layer::Stack<tower::util::ThenLayer<F>, L>> {
self.layer(tower::util::ThenLayer::new(f))
}
pub fn and_then<F>(
self,
f: F,
) -> DesServiceBuilder<tower_layer::Stack<tower::util::AndThenLayer<F>, L>> {
self.layer(tower::util::AndThenLayer::new(f))
}
pub fn map_result<F>(
self,
f: F,
) -> DesServiceBuilder<tower_layer::Stack<tower::util::MapResultLayer<F>, L>> {
self.layer(tower::util::MapResultLayer::new(f))
}
pub fn circuit_breaker(
self,
failure_threshold: usize,
recovery_timeout: std::time::Duration,
) -> DesServiceBuilder<tower_layer::Stack<super::circuit_breaker::DesCircuitBreakerLayer, L>>
{
self.layer(super::circuit_breaker::DesCircuitBreakerLayer::new(
failure_threshold,
recovery_timeout,
))
}
pub fn global_concurrency_limit(
self,
state: std::sync::Arc<super::limit::global_concurrency::GlobalConcurrencyLimitState>,
) -> DesServiceBuilder<
tower_layer::Stack<super::limit::global_concurrency::DesGlobalConcurrencyLimitLayer, L>,
> {
self.layer(super::limit::global_concurrency::DesGlobalConcurrencyLimitLayer::new(state))
}
pub fn hedge(
self,
delay: std::time::Duration,
) -> DesServiceBuilder<tower_layer::Stack<super::hedge::DesHedgeLayer, L>> {
self.layer(super::hedge::DesHedgeLayer::new(delay, 2))
}
pub fn into_inner(self) -> L {
self.layer
}
pub fn service<S>(&self, service: S) -> L::Service
where
L: tower_layer::Layer<S>,
{
self.layer.layer(service)
}
pub fn service_fn<F>(self, f: F) -> L::Service
where
L: tower_layer::Layer<tower::util::ServiceFn<F>>,
{
self.service(tower::util::service_fn(f))
}
#[inline]
pub fn check_clone(self) -> Self
where
Self: Clone,
{
self
}
#[inline]
pub fn check_service_clone<S>(self) -> Self
where
L: tower_layer::Layer<S>,
L::Service: Clone,
{
self
}
#[inline]
pub fn check_service<S, T, U, E>(self) -> Self
where
L: tower_layer::Layer<S>,
L::Service: tower::Service<T, Response = U, Error = E>,
{
self
}
pub fn build(self, simulation: &mut Simulation) -> Result<L::Service, ServiceError>
where
L: tower_layer::Layer<DesService>,
{
let service_time_distribution = self
.service_time_distribution
.unwrap_or_else(|| Box::new(ConstantServiceTime::new(Duration::from_millis(100))));
let pending_responses = Arc::new(Mutex::new(HashMap::new()));
let current_load = Arc::new(AtomicUsize::new(0));
let wakers = Arc::new(Mutex::new(Vec::new()));
let router = ResponseRouter::new(
pending_responses.clone(),
current_load.clone(),
wakers.clone(),
);
let router_key = simulation.add_component(router);
let server = Server::new(
self.server_name.clone(),
self.thread_capacity,
service_time_distribution,
);
let server_key = simulation.add_component(server);
let scheduler = simulation.scheduler_handle();
let handle = TowerSchedulerHandle::new(
scheduler,
server_key,
router_key,
pending_responses,
current_load,
wakers,
self.thread_capacity,
);
let base_service = DesService::new(handle);
Ok(self.layer.layer(base_service))
}
}
impl<L: std::fmt::Debug> std::fmt::Debug for DesServiceBuilder<L> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DesServiceBuilder")
.field("server_name", &self.server_name)
.field("thread_capacity", &self.thread_capacity)
.field(
"has_service_time_distribution",
&self.service_time_distribution.is_some(),
)
.field("layer", &self.layer)
.finish()
}
}
impl<S, L> Layer<S> for DesServiceBuilder<L>
where
L: Layer<S>,
{
type Service = L::Service;
fn layer(&self, inner: S) -> Self::Service {
self.layer.layer(inner)
}
}
/// Innermost Tower service that forwards HTTP requests into the simulation
/// through a [`TowerSchedulerHandle`].
#[derive(Clone)]
pub struct DesService {
    /// Shared handle used to check capacity and schedule attempts.
    scheduler_handle: TowerSchedulerHandle,
}

impl DesService {
    /// Creates a service that submits its requests through `scheduler_handle`.
    pub fn new(scheduler_handle: TowerSchedulerHandle) -> Self {
        DesService { scheduler_handle }
    }
}
/// Future returned by [`DesService::call`]; resolves once the simulated
/// server's response is delivered on the oneshot channel.
#[pin_project]
pub struct DesServiceFuture {
    /// Receiving end of the channel the response is sent on.
    #[pin]
    receiver: oneshot::Receiver<Response>,
}

impl Future for DesServiceFuture {
    type Output = Result<http::Response<SimBody>, ServiceError>;

    // A delivered response is converted with `response_to_http`. A closed
    // channel (the sender was dropped without sending) becomes
    // `ServiceError::Cancelled`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let receiver = self.project().receiver;
        receiver.poll(cx).map(|delivered| -> Self::Output {
            match delivered {
                Ok(response) => {
                    let converted = response_to_http(response)?;
                    Ok(converted)
                }
                Err(_) => Err(ServiceError::Cancelled),
            }
        })
    }
}
impl Service<Request<SimBody>> for DesService {
    type Response = http::Response<SimBody>;
    type Error = ServiceError;
    type Future = DesServiceFuture;

    /// Ready while the handle reports spare capacity. Otherwise the task's
    /// waker is parked until a response frees a slot.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.scheduler_handle.has_capacity() {
            return Poll::Ready(Ok(()));
        }

        self.scheduler_handle.register_waker(cx.waker().clone());

        // Re-check after registering. A response may have freed capacity and
        // drained the waker list between the check above and the
        // registration; without this, nobody would wake us. The wakers mutex
        // orders the router's decrement before this load.
        if self.scheduler_handle.has_capacity() {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }

    /// Converts `req` into a request attempt, schedules it on the simulated
    /// server, and returns a future resolving to the server's response.
    ///
    /// If scheduling fails, the future resolves to a synthetic 500 error
    /// response instead.
    /// NOTE(review): capacity is checked in `poll_ready` but not reserved
    /// here, so clones of this service can together exceed it.
    fn call(&mut self, req: Request<SimBody>) -> Self::Future {
        let attempt = http_to_request_attempt(&self.scheduler_handle, req);
        let (tx, rx) = oneshot::channel();
        if self.scheduler_handle.schedule_request(attempt, tx).is_err() {
            let (error_tx, error_rx) = oneshot::channel();
            let _ = error_tx.send(Response::error(
                RequestAttemptId(0),
                RequestId(0),
                SimTime::from_duration(Duration::ZERO),
                500,
                "Failed to schedule request".to_string(),
            ));
            return DesServiceFuture { receiver: error_rx };
        }
        DesServiceFuture { receiver: rx }
    }
}
fn http_to_request_attempt(handle: &TowerSchedulerHandle, req: Request<SimBody>) -> RequestAttempt {
handle.create_request_attempt(&req)
}