//! vorma 0.86.0-pre.1
//!
//! Vorma framework.
//!
//! Documentation
use std::collections::BTreeMap;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::{Arc, Mutex};

use http::StatusCode;
use vorma_matcher::Params;
use vorma_tasks::{CancelToken, Error as TaskError, ExecCtx, Result as TaskResult};

use crate::response::{Proxy, merge_proxy_responses};

use super::context::{None, RequestCtx};
use super::error::Error;
use super::request::RawRequest;
use super::task::{proxy_for_task_output, run_with_exec_cancellation};

/// Boxed, `Send` future produced by running a type-erased middleware.
///
/// It may borrow data for `'a` and resolves to `Ok(())` on success or to the
/// task error produced by the middleware.
type MiddlewareFuture<'a, E> =
	Pin<Box<dyn Future<Output = TaskResult<(), E>> + Send + 'a>>;

/// Phantom marker that records `S`, `E` and `O` through a function-pointer
/// type. A `PhantomData` of a `fn() -> …` type is always `Send + Sync` and does
/// not imply ownership of the listed types.
type MiddlewareMarker<S, E, O> = fn() -> (S, E, O);

/// Object-safe interface over a middleware handler. It lets handlers with
/// different closure and output types be stored behind one
/// `Arc<dyn ErasedMiddleware<S, E>>`.
trait ErasedMiddleware<S, E>: Send + Sync {
	/// Runs the middleware with `request_ctx`; the returned future may borrow
	/// `self` for `'run`.
	fn run<'run>(
		&'run self,
		request_ctx: RequestCtx<S, E, None>,
	) -> MiddlewareFuture<'run, E>;
}

/// Adapter that stores an async handler closure `F` so it can be used through
/// the type-erased `ErasedMiddleware` interface.
///
/// `S`, `E` and `O` appear only in the handler's signature, so they are
/// recorded with a `PhantomData` of the `MiddlewareMarker` function-pointer
/// type. That marker is always `Send + Sync` and does not make this struct own
/// values of those types.
///
/// No trait bounds are declared on the struct itself. The `ErasedMiddleware`
/// impl carries exactly the bounds it needs, so the bounds are not duplicated
/// here and every mention of the type does not have to restate them.
struct FnMiddleware<S, E, F, O> {
	/// The user-supplied handler invoked each time the middleware runs.
	handler: F,
	_marker: PhantomData<MiddlewareMarker<S, E, O>>,
}

/// Adapts an async closure to `ErasedMiddleware`. The closure's success value
/// `O` is discarded, so every middleware resolves to `()`.
impl<S, E, F, Fut, O> ErasedMiddleware<S, E> for FnMiddleware<S, E, F, O>
where
	S: Send + Sync + 'static,
	E: Send + Sync + 'static,
	F: Fn(RequestCtx<S, E, None>) -> Fut + Send + Sync + 'static,
	Fut: Future<Output = TaskResult<O, E>> + Send + 'static,
	O: Send + Sync + 'static,
{
	fn run<'a>(&'a self, ctx: RequestCtx<S, E, None>) -> MiddlewareFuture<'a, E> {
		let handler = &self.handler;
		Box::pin(async move {
			// The handler is only called once the boxed future is first polled.
			match handler(ctx).await {
				Ok(_discarded) => Ok(()),
				Err(error) => Err(error),
			}
		})
	}
}

/// Type-erased, cheaply clonable handle to a middleware handler.
pub(crate) struct Middleware<S, E> {
	/// Shared handler; every clone of this `Middleware` points at the same
	/// instance.
	mw: Arc<dyn ErasedMiddleware<S, E>>,
}

impl<S, E> Clone for Middleware<S, E> {
	fn clone(&self) -> Self {
		Self {
			mw: self.mw.clone(),
		}
	}
}

impl<S, E> Middleware<S, E>
where
	S: Send + Sync + 'static,
	E: Send + Sync + 'static,
{
	pub(crate) fn new<F, Fut, O>(handler: F) -> Self
	where
		F: Fn(RequestCtx<S, E, None>) -> Fut + Send + Sync + 'static,
		Fut: Future<Output = TaskResult<O, E>> + Send + 'static,
		O: Send + Sync + 'static,
	{
		Self {
			mw: Arc::new(FnMiddleware::<S, E, F, O> {
				handler,
				_marker: PhantomData,
			}),
		}
	}
}

/// A middleware scheduled for one request, paired with the route pattern that
/// caused it to match.
pub(in crate::mux) struct MiddlewareInvocation<S, E> {
	/// The middleware to run.
	entry: Middleware<S, E>,
	/// Pattern string handed to the middleware through its `RequestCtx`.
	matched_pattern: String,
}

impl<S, E> MiddlewareInvocation<S, E> {
	pub(in crate::mux) fn new(entry: &Middleware<S, E>, matched_pattern: &str) -> Self {
		Self {
			entry: entry.clone(),
			matched_pattern: matched_pattern.to_owned(),
		}
	}
}

/// What one spawned middleware task reports back once it finishes.
struct MiddlewareOutput<E> {
	/// Position of the middleware in the invocation list.
	index: usize,
	/// Copy of the response proxy the middleware wrote to. When the
	/// middleware failed, this copy has been passed through
	/// `proxy_for_task_output`.
	proxy: Proxy,
	/// Error returned by the middleware itself, if any.
	error: Option<TaskError<E>>,
}

/// Field-by-field clone. It is only available when the contained task error
/// is itself clonable; `E` itself is not required to be `Clone`.
impl<E> Clone for MiddlewareOutput<E>
where
	TaskError<E>: Clone,
{
	fn clone(&self) -> Self {
		let Self {
			index,
			proxy,
			error,
		} = self;
		Self {
			index: *index,
			proxy: proxy.clone(),
			error: error.clone(),
		}
	}
}

pub(in crate::mux) async fn run_middleware_entries<S, E>(
	request: &RawRequest,
	state: Arc<S>,
	exec_ctx: ExecCtx<E>,
	public_filemap: Arc<BTreeMap<String, String>>,
	params: Params,
	splat_values: Vec<String>,
	middleware_entries: Vec<MiddlewareInvocation<S, E>>,
) -> Result<Proxy, Error>
where
	S: Send + Sync + 'static,
	E: Send + Sync + 'static,
{
	if middleware_entries.is_empty() {
		return Ok(Proxy::new());
	}
	let middleware_exec_ctxs = (0..middleware_entries.len())
		.map(|_| exec_ctx.child())
		.collect::<Vec<_>>();
	let cancel_middleware = Arc::new(
		middleware_exec_ctxs
			.iter()
			.map(|exec_ctx| exec_ctx.cancel_token().clone())
			.collect::<Vec<_>>(),
	);
	let mut handles = Vec::with_capacity(middleware_exec_ctxs.len());

	for (index, (invocation, exec_ctx)) in middleware_entries
		.into_iter()
		.zip(middleware_exec_ctxs)
		.enumerate()
	{
		let proxy = Arc::new(Mutex::new(Proxy::new()));
		let ctx = RequestCtx {
			matched_pattern: invocation.matched_pattern,
			params: params.clone(),
			splat_values: splat_values.clone(),
			state: state.clone(),
			exec_ctx: exec_ctx.clone(),
			public_filemap: public_filemap.clone(),
			response_proxy: proxy.clone(),
			request: request.clone(),
			input: None,
		};
		let middleware = invocation.entry.mw.clone();
		let cancel_middleware = cancel_middleware.clone();
		handles.push((
			index,
			tokio::spawn(async move {
				let ctx = ctx.clone_for_task(exec_ctx.clone(), None);
				run_with_exec_cancellation(&exec_ctx, async move {
					let output = middleware.run(ctx).await;
					let proxy = proxy.lock().expect("response proxy lock poisoned").clone();
					let should_cancel_later = output.is_err() || proxy.is_terminal_response();
					if should_cancel_later {
						cancel_later(&cancel_middleware, index);
					}
					match output {
						Ok(()) => MiddlewareOutput {
							index,
							proxy,
							error: Option::None,
						},
						Err(error) => MiddlewareOutput {
							index,
							proxy: proxy_for_task_output(true, proxy),
							error: Some(error),
						},
					}
				})
				.await
			}),
		));
	}

	let mut outputs = Vec::with_capacity(handles.len());
	let mut errors = Vec::new();
	for (index, handle) in handles {
		let output = handle
			.await
			.map_err(|error| Error::TaskJoin(error.to_string()))?;
		match output {
			Ok(output) => {
				if let Some(error) = output.error.clone() {
					cancel_later(&cancel_middleware, output.index);
					errors.push((output.index, error));
				}
				outputs.push(output);
			}
			Err(error) => {
				cancel_later(&cancel_middleware, index);
				errors.push((index, error));
			}
		}
	}

	outputs.sort_by_key(|output| output.index);
	if let Some(first_terminal_index) = first_terminal_middleware_index(&outputs, &errors) {
		let eligible_proxies = outputs
			.iter()
			.filter(|output| output.index <= first_terminal_index)
			.map(|output| output.proxy.clone())
			.collect::<Vec<_>>();
		let mut merged_terminal_proxy = merge_owned_proxy_responses(eligible_proxies);
		if merged_terminal_proxy.is_terminal_response() {
			return Ok(merged_terminal_proxy);
		}
		set_internal_server_error(&mut merged_terminal_proxy);
		return Ok(merged_terminal_proxy);
	}
	Ok(merge_owned_proxy_responses(
		outputs.into_iter().map(|output| output.proxy).collect(),
	))
}

/// Returns the index after which middleware results are no longer merged.
/// Returns `Option::None` when no middleware produced a terminal response or
/// an error.
///
/// The cutoff is the smaller of two indices:
/// - the index of the first entry of `outputs` (in slice order) whose proxy
///   holds a terminal response;
/// - the lowest index of an error that is not a cancellation.
///
/// If neither exists, the lowest index among the remaining (cancellation)
/// errors is returned, if there are any.
fn first_terminal_middleware_index<E>(
	outputs: &[MiddlewareOutput<E>],
	errors: &[(usize, TaskError<E>)],
) -> Option<usize> {
	let terminal_at = outputs
		.iter()
		.find_map(|output| output.proxy.is_terminal_response().then_some(output.index));
	let failed_at = errors
		.iter()
		.filter_map(|(index, error)| (!error.is_cancelled()).then_some(*index))
		.min();

	terminal_at
		.into_iter()
		.chain(failed_at)
		.min()
		.or_else(|| errors.iter().map(|&(index, _)| index).min())
}

/// Cancels every token whose position is strictly greater than `own_index`.
/// The token at `own_index` and all earlier ones are left untouched.
fn cancel_later(tokens: &[CancelToken], own_index: usize) {
	tokens
		.iter()
		.enumerate()
		.filter(|&(position, _)| position > own_index)
		.for_each(|(_, later)| {
			later.cancel();
		});
}

/// Sets the proxy's status to `500` with the reason text
/// `"Internal Server Error"`.
fn set_internal_server_error(proxy: &mut Proxy) {
	let reason = String::from("Internal Server Error");
	proxy.set_status(StatusCode::INTERNAL_SERVER_ERROR, Some(reason));
}

/// Merges owned proxies by lending each one as `Some(&proxy)` to
/// `merge_proxy_responses`, preserving input order.
pub(in crate::mux) fn merge_owned_proxy_responses(proxies: Vec<Proxy>) -> Proxy {
	let borrowed: Vec<Option<&Proxy>> = proxies.iter().map(Option::Some).collect();
	merge_proxy_responses(&borrowed)
}