use std::collections::HashMap;
#[cfg(feature = "ws")]
use std::sync::Arc;
#[cfg(feature = "ws")]
use tokio::sync::Mutex;
/// Per-request context: free-form string metadata (looked up by [`Context::get`],
/// [`Context::header`] and [`Context::env`]) plus optional user and request identifiers.
#[derive(Debug, Clone, Default)]
pub struct Context {
    metadata: HashMap<String, String>,
    user_id: Option<String>,
    request_id: Option<String>,
}

impl Context {
    /// Creates an empty context: no metadata, no user id, no request id.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a context seeded with `metadata`; user id and request id start unset.
    pub fn with_metadata(metadata: HashMap<String, String>) -> Self {
        Self {
            metadata,
            ..Default::default()
        }
    }

    /// Returns the metadata value stored under exactly `key` (case-sensitive).
    pub fn get(&self, key: &str) -> Option<&str> {
        self.metadata.get(key).map(String::as_str)
    }

    /// Inserts `value` under `key`, replacing any previous value for that exact key.
    pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.metadata.insert(key.into(), value.into());
    }

    /// Returns the user id, if one has been set.
    pub fn user_id(&self) -> Option<&str> {
        self.user_id.as_deref()
    }

    /// Sets (or replaces) the user id.
    pub fn set_user_id(&mut self, user_id: impl Into<String>) {
        self.user_id = Some(user_id.into());
    }

    /// Returns the request id, if one has been set.
    pub fn request_id(&self) -> Option<&str> {
        self.request_id.as_deref()
    }

    /// Sets (or replaces) the request id.
    pub fn set_request_id(&mut self, request_id: impl Into<String>) {
        self.request_id = Some(request_id.into());
    }

    /// Borrows the full metadata map.
    pub fn metadata(&self) -> &HashMap<String, String> {
        &self.metadata
    }

    /// Looks up a metadata entry by `name`, ignoring letter case.
    ///
    /// A key that equals `name` exactly always wins. Otherwise, if several keys match
    /// case-insensitively (e.g. both `X-Token` and `x-token` were set), the value of the
    /// smallest such key (byte-wise `String` ordering) is returned. The result therefore
    /// never depends on `HashMap` iteration order.
    pub fn header(&self, name: &str) -> Option<&str> {
        // Lowercased view of a string as a char stream; avoids allocating a new
        // `String` for every key compared.
        fn lowered(s: &str) -> impl Iterator<Item = char> + '_ {
            s.chars().flat_map(char::to_lowercase)
        }

        // Fast path: O(1) hash lookup, and it makes an exact-case key authoritative.
        if let Some(value) = self.metadata.get(name) {
            return Some(value.as_str());
        }

        self.metadata
            .iter()
            .filter(|(key, _)| lowered(key).eq(lowered(name)))
            // Deterministic tie-break among case variants of the same name.
            .min_by_key(|&(key, _)| key)
            .map(|(_, value)| value.as_str())
    }

    /// Case-insensitive lookup of the `authorization` entry.
    pub fn authorization(&self) -> Option<&str> {
        self.header("authorization")
    }

    /// Case-insensitive lookup of the `content-type` entry.
    pub fn content_type(&self) -> Option<&str> {
        self.header("content-type")
    }

    /// Returns the metadata value stored under the key `env:{name}` (case-sensitive).
    pub fn env(&self, name: &str) -> Option<&str> {
        self.get(&format!("env:{name}"))
    }
}
/// Cloneable handle for writing frames to the outgoing half of a WebSocket.
///
/// Every clone shares one sink behind an `Arc<Mutex<_>>`. The async mutex is held for
/// the whole of each send, so frames written through different clones are serialised.
#[cfg(feature = "ws")]
#[derive(Clone)]
pub struct WsSender {
    sender: Arc<
        Mutex<futures::stream::SplitSink<axum::extract::ws::WebSocket, axum::extract::ws::Message>>,
    >,
}

#[cfg(feature = "ws")]
impl WsSender {
    #[doc(hidden)]
    pub fn new(
        sender: futures::stream::SplitSink<
            axum::extract::ws::WebSocket,
            axum::extract::ws::Message,
        >,
    ) -> Self {
        let shared = Arc::new(Mutex::new(sender));
        Self { sender: shared }
    }

    /// Locks the shared sink, writes `frame`, and turns any transport error into the
    /// string `"{context}: {error}"`.
    async fn dispatch(
        &self,
        frame: axum::extract::ws::Message,
        context: &str,
    ) -> Result<(), String> {
        use futures::sink::SinkExt;
        let mut sink = self.sender.lock().await;
        match sink.send(frame).await {
            Ok(()) => Ok(()),
            Err(err) => Err(format!("{context}: {err}")),
        }
    }

    /// Sends `text` as a single text frame.
    ///
    /// # Errors
    /// Returns `"Failed to send WebSocket message: …"` if the sink rejects the frame.
    pub async fn send(&self, text: impl Into<String>) -> Result<(), String> {
        let payload: String = text.into();
        self.dispatch(
            axum::extract::ws::Message::Text(payload.into()),
            "Failed to send WebSocket message",
        )
        .await
    }

    /// Serializes `value` with `serde_json` and sends the result as a text frame.
    ///
    /// # Errors
    /// Returns `"Failed to serialize JSON: …"` when serialization fails; otherwise any
    /// error from [`WsSender::send`].
    pub async fn send_json<T: serde::Serialize>(&self, value: &T) -> Result<(), String> {
        match serde_json::to_string(value) {
            Ok(json) => self.send(json).await,
            Err(e) => Err(format!("Failed to serialize JSON: {}", e)),
        }
    }

    /// Writes a `Close` frame carrying no close code or reason.
    ///
    /// Only the frame is sent; the sink itself is not closed here.
    ///
    /// # Errors
    /// Returns `"Failed to close WebSocket: …"` if the sink rejects the frame.
    pub async fn close(&self) -> Result<(), String> {
        self.dispatch(
            axum::extract::ws::Message::Close(None),
            "Failed to close WebSocket",
        )
        .await
    }
}
/// Transparent newtype around a `T`.
///
/// NOTE(review): the name suggests path parameters extracted from a request; confirm
/// against the code that constructs it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Path<T>(pub T);

impl<T> Path<T> {
    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        let Path(inner) = self;
        inner
    }
}

impl<T> std::ops::Deref for Path<T> {
    type Target = T;

    /// Borrows the wrapped value, so `T`'s methods are callable on the wrapper.
    fn deref(&self) -> &T {
        let Path(inner) = self;
        inner
    }
}
/// Transparent newtype around a `T`.
///
/// NOTE(review): the name suggests query-string parameters extracted from a request;
/// confirm against the code that constructs it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Query<T>(pub T);

impl<T> Query<T> {
    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        let Query(inner) = self;
        inner
    }
}

impl<T> std::ops::Deref for Query<T> {
    type Target = T;

    /// Borrows the wrapped value, so `T`'s methods are callable on the wrapper.
    fn deref(&self) -> &T {
        let Query(inner) = self;
        inner
    }
}
/// Transparent newtype around a `T`.
///
/// NOTE(review): the name suggests a JSON request body; confirm against the code that
/// constructs it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Json<T>(pub T);

impl<T> Json<T> {
    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        let Json(inner) = self;
        inner
    }
}

impl<T> std::ops::Deref for Json<T> {
    type Target = T;

    /// Borrows the wrapped value, so `T`'s methods are callable on the wrapper.
    fn deref(&self) -> &T {
        let Json(inner) = self;
        inner
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_context_metadata() {
        let mut ctx = Context::new();
        ctx.set("Content-Type", "application/json");
        ctx.set("X-Request-Id", "abc123");
        assert_eq!(ctx.get("Content-Type"), Some("application/json"));
        // `get` is an exact, case-sensitive lookup; `header` is not.
        assert_eq!(ctx.get("content-type"), None);
        assert_eq!(ctx.header("content-type"), Some("application/json"));
        assert_eq!(ctx.header("x-request-id"), Some("abc123"));
        assert_eq!(ctx.content_type(), Some("application/json"));
        assert_eq!(ctx.authorization(), None);
        assert_eq!(ctx.header("missing"), None);
        assert_eq!(ctx.metadata().len(), 2);
    }

    #[test]
    fn test_context_user() {
        let mut ctx = Context::new();
        assert!(ctx.user_id().is_none());
        ctx.set_user_id("user_123");
        assert_eq!(ctx.user_id(), Some("user_123"));
    }

    #[test]
    fn test_context_request_id() {
        let mut ctx = Context::new();
        assert!(ctx.request_id().is_none());
        ctx.set_request_id("req_1");
        assert_eq!(ctx.request_id(), Some("req_1"));
    }

    #[test]
    fn test_context_env_and_with_metadata() {
        let mut meta = HashMap::new();
        meta.insert("env:API_KEY".to_string(), "secret".to_string());
        let ctx = Context::with_metadata(meta);
        assert_eq!(ctx.env("API_KEY"), Some("secret"));
        assert_eq!(ctx.env("MISSING"), None);
        assert!(ctx.user_id().is_none());
        assert!(ctx.request_id().is_none());
    }

    #[test]
    fn test_wrappers_deref_and_into_inner() {
        let path = Path(42u32);
        assert_eq!(*path, 42);
        assert_eq!(path.into_inner(), 42);

        let query = Query(String::from("q"));
        assert_eq!(query.len(), 1);
        assert_eq!(query.into_inner(), "q");

        let json = Json(vec![1, 2]);
        assert_eq!(json.len(), 2);
        assert_eq!(json.into_inner(), vec![1, 2]);
    }
}