use axum_core::response::{IntoResponseParts, ResponseParts};
use http::header::{HeaderValue, VARY};
use crate::{HxError, extractors, headers};
// Pre-built `Vary` header values, one per htmx request header a response can
// depend on. Each wraps the matching header-name string from `headers`.
// Because these are `const` items, every use produces a fresh `HeaderValue`,
// which is why they can be handed by value to `try_append` below.

/// `Vary` value naming the `HX-Request` header.
const HX_REQUEST: HeaderValue = HeaderValue::from_static(headers::HX_REQUEST_STR);
/// `Vary` value naming the `HX-Target` header.
const HX_TARGET: HeaderValue = HeaderValue::from_static(headers::HX_TARGET_STR);
/// `Vary` value naming the `HX-Trigger` header.
const HX_TRIGGER: HeaderValue = HeaderValue::from_static(headers::HX_TRIGGER_STR);
/// `Vary` value naming the `HX-Trigger-Name` header.
const HX_TRIGGER_NAME: HeaderValue = HeaderValue::from_static(headers::HX_TRIGGER_NAME_STR);
/// Response part that adds the `HX-Request` header name to the response's
/// `Vary` header.
///
/// Return it alongside a response whose content depends on whether the
/// request was made by htmx. Any `Vary` values already present are kept.
/// The type is a zero-sized marker, so it is `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct VaryHxRequest;
impl IntoResponseParts for VaryHxRequest {
type Error = HxError;
fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
res.headers_mut().try_append(VARY, HX_REQUEST)?;
Ok(res)
}
}
impl extractors::HxRequest {
    /// Returns a [`VaryHxRequest`] response part.
    ///
    /// Include it in a handler's response when the output depends on the
    /// `HX-Request` header, so that header name is added to `Vary`.
    /// Calling this without using the result has no effect, hence
    /// `#[must_use]`.
    #[must_use]
    pub const fn vary_response() -> VaryHxRequest {
        VaryHxRequest
    }
}
/// Response part that adds the `HX-Target` header name to the response's
/// `Vary` header.
///
/// Return it alongside a response whose content depends on the element htmx
/// is targeting. Any `Vary` values already present are kept. The type is a
/// zero-sized marker, so it is `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct VaryHxTarget;
impl IntoResponseParts for VaryHxTarget {
type Error = HxError;
fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
res.headers_mut().try_append(VARY, HX_TARGET)?;
Ok(res)
}
}
impl extractors::HxTarget {
    /// Returns a [`VaryHxTarget`] response part.
    ///
    /// Include it in a handler's response when the output depends on the
    /// `HX-Target` header, so that header name is added to `Vary`.
    /// Calling this without using the result has no effect, hence
    /// `#[must_use]`.
    #[must_use]
    pub const fn vary_response() -> VaryHxTarget {
        VaryHxTarget
    }
}
/// Response part that adds the `HX-Trigger` header name to the response's
/// `Vary` header.
///
/// Return it alongside a response whose content depends on the element that
/// triggered the htmx request. Any `Vary` values already present are kept.
/// The type is a zero-sized marker, so it is `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct VaryHxTrigger;
impl IntoResponseParts for VaryHxTrigger {
type Error = HxError;
fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
res.headers_mut().try_append(VARY, HX_TRIGGER)?;
Ok(res)
}
}
impl extractors::HxTrigger {
    /// Returns a [`VaryHxTrigger`] response part.
    ///
    /// Include it in a handler's response when the output depends on the
    /// `HX-Trigger` header, so that header name is added to `Vary`.
    /// Calling this without using the result has no effect, hence
    /// `#[must_use]`.
    #[must_use]
    pub const fn vary_response() -> VaryHxTrigger {
        VaryHxTrigger
    }
}
/// Response part that adds the `HX-Trigger-Name` header name to the
/// response's `Vary` header.
///
/// Return it alongside a response whose content depends on the name of the
/// element that triggered the htmx request. Any `Vary` values already present
/// are kept. The type is a zero-sized marker, so it is `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct VaryHxTriggerName;
impl IntoResponseParts for VaryHxTriggerName {
type Error = HxError;
fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
res.headers_mut().try_append(VARY, HX_TRIGGER_NAME)?;
Ok(res)
}
}
impl extractors::HxTriggerName {
    /// Returns a [`VaryHxTriggerName`] response part.
    ///
    /// Include it in a handler's response when the output depends on the
    /// `HX-Trigger-Name` header, so that header name is added to `Vary`.
    /// Calling this without using the result has no effect, hence
    /// `#[must_use]`.
    #[must_use]
    pub const fn vary_response() -> VaryHxTriggerName {
        VaryHxTriggerName
    }
}
#[cfg(test)]
mod tests {
    use std::collections::hash_set::HashSet;

    use axum::{Router, routing::get};

    use super::*;

    /// Serves `app`, requests `/`, and returns every `Vary` value of the
    /// response.
    async fn vary_values(app: Router) -> Vec<HeaderValue> {
        let server = axum_test::TestServer::new(app).unwrap();
        let resp = server.get("/").await;
        resp.iter_headers_by_name("vary").cloned().collect()
    }

    /// Asserts that `values` holds exactly the entries in `expected`, once each.
    ///
    /// A set comparison alone would hide a value that was appended twice, so
    /// the length is checked too. Order is deliberately not pinned.
    fn assert_exactly<const N: usize>(values: Vec<HeaderValue>, expected: [HeaderValue; N]) {
        assert_eq!(values.len(), N, "unexpected Vary values: {values:?}");
        let unique: HashSet<HeaderValue> = values.into_iter().collect();
        assert_eq!(unique, HashSet::from(expected));
    }

    #[tokio::test]
    async fn multiple_headers() {
        let app = Router::new().route("/", get(|| async { (VaryHxRequest, VaryHxTarget, "foo") }));
        assert_exactly(vary_values(app).await, [HX_REQUEST, HX_TARGET]);
    }

    #[tokio::test]
    async fn all_vary_response_constructors() {
        let app = Router::new().route(
            "/",
            get(|| async {
                (
                    extractors::HxRequest::vary_response(),
                    extractors::HxTarget::vary_response(),
                    extractors::HxTrigger::vary_response(),
                    extractors::HxTriggerName::vary_response(),
                    "foo",
                )
            }),
        );
        assert_exactly(
            vary_values(app).await,
            [HX_REQUEST, HX_TARGET, HX_TRIGGER, HX_TRIGGER_NAME],
        );
    }
}