use std::sync::Arc;
use http::HeaderMap;
use parlov_core::ResponseClass;
use crate::chain::{ChainRegistry, Producer, ProducerOutput};
pub mod case_variation;
pub mod double_slash;
pub mod percent_encoding;
pub mod post_to_303;
pub mod protocol_upgrade;
pub mod put_to_303;
pub mod slash_append;
pub mod slash_strip;
pub use case_variation::RdCaseVariation;
pub use double_slash::RdDoubleSlash;
pub use percent_encoding::RdPercentEncoding;
pub use post_to_303::RdPostTo303;
pub use protocol_upgrade::RdProtocolUpgrade;
pub use put_to_303::RdPutTo303;
pub use slash_append::RdSlashAppend;
pub use slash_strip::RdSlashStrip;
/// Producer that feeds redirect chains with the target of a `Location` header.
///
/// It only admits [`ResponseClass::Redirect`]. On extraction it yields
/// [`ProducerOutput::Location`] holding an owned copy of the header value, or
/// `None` when the header is absent or `HeaderValue::to_str` rejects it
/// (that method only accepts visible-ASCII header bytes).
pub(super) struct RdLocationProducer;

impl Producer for RdLocationProducer {
    fn admits(&self, class: ResponseClass) -> bool {
        // Every other response class is rejected.
        matches!(class, ResponseClass::Redirect)
    }

    fn extract(&self, _class: ResponseClass, headers: &HeaderMap) -> Option<ProducerOutput> {
        // The class argument is not consulted here; gating happens in `admits`
        // (presumably called by the chain machinery first — not visible here).
        headers
            .get(http::header::LOCATION)
            .and_then(|raw| raw.to_str().ok())
            .map(|target| ProducerOutput::Location(target.to_owned()))
    }
}
/// Builds the [`ChainRegistry`] wiring up this module's redirect techniques.
///
/// A single `RdLocationProducer` is shared (through an `Arc`) by every edge.
/// Each technique contributes its own Location consumer, for eight
/// producer→consumer registrations in total.
#[must_use]
pub fn rd_chain_registry() -> ChainRegistry {
    let mut registry = ChainRegistry::new();
    let location: Arc<dyn Producer> = Arc::new(RdLocationProducer);

    // Every edge shares the same producer and differs only in its consumer.
    // A local macro (rather than a helper fn) keeps each expansion identical to
    // a hand-written `register` call, so the `Arc<T>` → `Arc<dyn _>` coercion
    // still happens at the call site without naming the consumer trait.
    macro_rules! edge {
        ($consumer:expr) => {
            registry.register(Arc::clone(&location), Arc::new($consumer))
        };
    }

    edge!(slash_append::RdSlashAppendLocationConsumer);
    edge!(slash_strip::RdSlashStripLocationConsumer);
    edge!(case_variation::RdCaseVariationLocationConsumer);
    edge!(double_slash::RdDoubleSlashLocationConsumer);
    edge!(percent_encoding::RdPercentEncodingLocationConsumer);
    edge!(protocol_upgrade::RdProtocolUpgradeLocationConsumer);
    edge!(post_to_303::RdPostTo303LocationConsumer);
    edge!(put_to_303::RdPutTo303LocationConsumer);

    registry
}
// Unit tests for the Location producer and the registry wiring.
#[cfg(test)]
mod tests {
    use super::*;
    use http::{HeaderMap, HeaderValue};

    #[test]
    fn rd_location_producer_admits_redirect_only() {
        let p = RdLocationProducer;
        assert!(p.admits(ResponseClass::Redirect));
        assert!(!p.admits(ResponseClass::Success));
        assert!(!p.admits(ResponseClass::Other));
        assert!(!p.admits(ResponseClass::StructuredError));
        assert!(!p.admits(ResponseClass::PartialContent));
    }

    #[test]
    fn rd_location_producer_extracts_location_header() {
        let p = RdLocationProducer;
        let mut headers = HeaderMap::new();
        headers.insert(
            http::header::LOCATION,
            HeaderValue::from_static("https://example.com/users/123/"),
        );
        // Pattern-match instead of `is_some()` + `unwrap()` so a failure
        // reports the mismatching value rather than a bare `false`.
        match p.extract(ResponseClass::Redirect, &headers) {
            Some(ProducerOutput::Location(loc)) => {
                assert_eq!(loc, "https://example.com/users/123/");
            }
            _ => panic!("expected Some(ProducerOutput::Location(_))"),
        }
    }

    #[test]
    fn rd_location_producer_returns_none_when_no_location() {
        let p = RdLocationProducer;
        let headers = HeaderMap::new();
        assert!(p.extract(ResponseClass::Redirect, &headers).is_none());
    }

    #[test]
    fn rd_chain_registry_has_eight_edges() {
        // One edge per redirect technique consumer registered in `rd_chain_registry`.
        assert_eq!(rd_chain_registry().len(), 8);
    }
}