murgamu 0.7.3

Murgamü is a NestJS-inspired web framework for Rust.
Documentation
use super::TimeoutConfig;
use crate::mur_http::request::MurRequestContext;
use crate::traits::{MurMiddleware, MurNext};
use crate::types::{MurFuture, MurHttpResponse};
use hyper::StatusCode;
use std::sync::Arc;
use std::time::Duration;

/// Middleware that cuts off request handling once a configured deadline has passed.
///
/// Cloning is cheap: every clone shares the same [`TimeoutConfig`] through an [`Arc`].
#[derive(Clone)]
pub struct MurTimeout {
	/// Shared settings consulted on every request: deadline, skip rules, response
	/// status/message, header and logging switches.
	pub config: Arc<TimeoutConfig>,
}

impl MurTimeout {
	/// Creates a timeout middleware that aborts requests running longer than `timeout`.
	///
	/// Every other setting comes from `TimeoutConfig::new(timeout)`.
	pub fn new(timeout: Duration) -> Self {
		Self::from_config(TimeoutConfig::new(timeout))
	}

	/// Shorthand for [`MurTimeout::new`] with a deadline given in whole seconds.
	pub fn from_secs(secs: u64) -> Self {
		Self::new(Duration::from_secs(secs))
	}

	/// Shorthand for [`MurTimeout::new`] with a deadline given in milliseconds.
	pub fn from_millis(millis: u64) -> Self {
		Self::new(Duration::from_millis(millis))
	}

	/// Wraps an already fully built [`TimeoutConfig`].
	pub fn from_config(config: TimeoutConfig) -> Self {
		Self {
			config: Arc::new(config),
		}
	}

	/// Builder helper: applies `edit` to this instance's configuration and returns `self`.
	///
	/// [`Arc::make_mut`] edits the config in place when this instance is its only owner and
	/// copies it first otherwise, so other clones sharing the same `Arc` never see the change.
	fn modify_config(mut self, edit: impl FnOnce(&mut TimeoutConfig)) -> Self {
		edit(Arc::make_mut(&mut self.config));
		self
	}

	/// Sets the HTTP status code used for the timeout response.
	pub fn with_status(self, status: StatusCode) -> Self {
		self.modify_config(|cfg| cfg.status_code = status)
	}

	/// Responds with `408 Request Timeout` when the deadline is exceeded.
	pub fn request_timeout(self) -> Self {
		self.with_status(StatusCode::REQUEST_TIMEOUT)
	}

	/// Responds with `504 Gateway Timeout` when the deadline is exceeded.
	pub fn gateway_timeout(self) -> Self {
		self.with_status(StatusCode::GATEWAY_TIMEOUT)
	}

	/// Responds with `503 Service Unavailable` when the deadline is exceeded.
	pub fn service_unavailable(self) -> Self {
		self.with_status(StatusCode::SERVICE_UNAVAILABLE)
	}

	/// Overrides the `message` field of the timeout response body
	/// (otherwise `"Request timed out"`).
	pub fn with_message(self, message: impl Into<String>) -> Self {
		let message = message.into();
		self.modify_config(|cfg| cfg.message = Some(message))
	}

	/// Replaces the set of exact paths that bypass the timeout.
	///
	/// Accepts any iterable of string-like values (a `Vec`, an array, an iterator, ...).
	pub fn skip_paths(self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
		self.modify_config(|cfg| cfg.skip_paths = paths.into_iter().map(|p| p.into()).collect())
	}

	/// Adds one exact path that bypasses the timeout, keeping those already configured.
	pub fn skip_path(self, path: impl Into<String>) -> Self {
		let path = path.into();
		self.modify_config(|cfg| cfg.skip_paths.push(path))
	}

	/// Replaces the set of path prefixes that bypass the timeout.
	///
	/// Accepts any iterable of string-like values (a `Vec`, an array, an iterator, ...).
	pub fn skip_prefixes(self, prefixes: impl IntoIterator<Item = impl Into<String>>) -> Self {
		self.modify_config(|cfg| {
			cfg.skip_path_prefixes = prefixes.into_iter().map(|p| p.into()).collect()
		})
	}

	/// Adds one path prefix that bypasses the timeout, keeping those already configured.
	pub fn skip_prefix(self, prefix: impl Into<String>) -> Self {
		let prefix = prefix.into();
		self.modify_config(|cfg| cfg.skip_path_prefixes.push(prefix))
	}

	/// Turns on `include_timeout_header`. Successful responses then carry the elapsed
	/// handling time under the configured header name.
	pub fn include_header(self) -> Self {
		self.modify_config(|cfg| cfg.include_timeout_header = true)
	}

	/// Enables or disables logging of requests that hit the deadline.
	pub fn log(self, enable: bool) -> Self {
		self.modify_config(|cfg| cfg.log_timeouts = enable)
	}

	/// Returns `true` when `path` should bypass the timeout. That happens when it equals one of
	/// the configured skip paths exactly, or starts with one of the skip prefixes.
	///
	/// Prefix matching is a plain string test, not segment-aware: the prefix `"/health"`
	/// also matches `"/healthz"`.
	pub fn should_skip(&self, path: &str) -> bool {
		let cfg = &*self.config;
		cfg.skip_paths.iter().any(|p| p == path)
			|| cfg
				.skip_path_prefixes
				.iter()
				.any(|p| path.starts_with(p.as_str()))
	}

	/// Builds the JSON error response sent when a request exceeds its deadline.
	///
	/// The body contains `error` (always `"Timeout"`), `message` (the configured message or
	/// `"Request timed out"`), `timeout_seconds` (the deadline as a float) and `status` (the
	/// numeric status code). The response uses the configured status code.
	pub fn build_timeout_response(&self) -> crate::types::MurRes {
		let cfg = &*self.config;
		// Borrow the configured message instead of cloning it; only the JSON value needs it.
		let message = cfg.message.as_deref().unwrap_or("Request timed out");

		MurHttpResponse::status(cfg.status_code).json(serde_json::json!({
			"error": "Timeout",
			"message": message,
			"timeout_seconds": cfg.timeout.as_secs_f64(),
			"status": cfg.status_code.as_u16()
		}))
	}
}

impl Default for MurTimeout {
	fn default() -> Self {
		Self::new(Duration::from_secs(30))
	}
}

impl std::fmt::Debug for MurTimeout {
	/// Renders as `MurTimeout { config: .. }`, using the config's own `Debug` output.
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		let mut builder = f.debug_struct("MurTimeout");
		// `Arc<T>`'s Debug simply forwards to `T`, so format the inner config directly.
		builder.field("config", &*self.config);
		builder.finish()
	}
}

impl MurMiddleware for MurTimeout {
	/// Races the rest of the middleware chain against the configured deadline.
	///
	/// * Paths accepted by [`MurTimeout::should_skip`] are forwarded without any deadline.
	/// * If the chain finishes in time, its result is passed through unchanged. The exception:
	///   when `include_timeout_header` is set and the result is `Ok`, the elapsed time is added
	///   as `"<millis>ms"` under `timeout_header_name`. The header is silently omitted if the
	///   name does not parse as a header name.
	/// * If the deadline passes first, an optional `[TIMEOUT]` line goes to stderr and
	///   [`MurTimeout::build_timeout_response`] is returned instead.
	fn handle(&self, ctx: MurRequestContext, next: MurNext) -> MurFuture {
		let this = self.clone();

		Box::pin(async move {
			let cfg = Arc::clone(&this.config);

			// Copied out up front: `ctx` is moved into the next handler below.
			let path = ctx.path().to_string();
			let method = ctx.method().to_string();

			if this.should_skip(&path) {
				return next.run(ctx).await;
			}

			let started_at = std::time::Instant::now();
			match tokio::time::timeout(cfg.timeout, next.run(ctx)).await {
				Ok(Ok(mut resp)) if cfg.include_timeout_header => {
					let took_ms = started_at.elapsed().as_millis();
					let header_value =
						format!("{}ms", took_ms).parse::<hyper::header::HeaderValue>();
					let header_name =
						hyper::header::HeaderName::from_bytes(cfg.timeout_header_name.as_bytes());
					if let (Ok(value), Ok(name)) = (header_value, header_name) {
						resp.headers_mut().insert(name, value);
					}
					Ok(resp)
				}
				Ok(finished) => finished,
				Err(_elapsed) => {
					if cfg.log_timeouts {
						eprintln!(
							"[TIMEOUT] {} {} exceeded {}ms timeout",
							method,
							path,
							cfg.timeout.as_millis()
						);
					}
					this.build_timeout_response()
				}
			}
		})
	}

	/// Identifier reported for this middleware.
	fn name(&self) -> &str {
		"MurTimeout"
	}
}