use std::sync::Arc;
use recoverable::RecoveryKind;
use templated_uri::{BaseUri, Uri};
use super::RouterContext;
use crate::error_labels::{LABEL_URI_CONFLICT, LABEL_URI_MISSING};
use crate::{HttpError, HttpRequest};
/// How a [`Router`] settles a clash between the base URI already carried by the
/// target URI and the base URI produced by routing.
///
/// The policy is consulted only when *both* base URIs are present. No equality
/// check is made between them, so the policy applies even if they happen to match.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum BaseUriConflict {
    /// Keep the base URI that is already on the target URI (the default).
    #[default]
    UseOriginal,
    /// Replace the target's base URI with the one produced by routing.
    UseRouted,
    /// Reject the URI with a validation error labelled as a URI conflict.
    Fail,
}
/// Chooses the base URI that an outgoing request's target URI is resolved against.
///
/// The resolver strategy is held behind an [`Arc`], so cloning a router is cheap
/// and all clones share the same strategy.
#[derive(Clone, Debug, Default)]
pub struct Router {
    // Strategy that produces the routed base URI (`Resolver::Empty` for `Router::default()`).
    resolver: Arc<Resolver>,
    // Applied when the target URI already has a base URI and routing produced one too.
    conflict_policy: BaseUriConflict,
}
/// Request extension recording a request's URI before and after routing.
///
/// Keeping the caller's original URI lets repeated routing passes (for example
/// across retries) start from that URI instead of the already-rewritten one.
#[derive(Debug, Clone)]
pub struct RequestUris {
    // URI as the caller set it, before any routing was applied.
    original: Uri,
    // URI produced by the latest routing pass; `None` until one has been recorded.
    routed: Option<Uri>,
}
impl RequestUris {
    /// Creates the record for a request whose caller-provided URI is `original`.
    ///
    /// No routed URI is recorded yet.
    #[must_use]
    pub fn new(original: Uri) -> Self {
        let routed = None;
        Self { original, routed }
    }

    /// Returns the URI the request carried before routing.
    #[must_use]
    pub fn original(&self) -> &Uri {
        &self.original
    }

    /// Returns the URI most recently stored with [`RequestUris::set_routed`], if any.
    #[must_use]
    pub fn routed(&self) -> Option<&Uri> {
        Option::as_ref(&self.routed)
    }

    /// Stores `routed` as the routed URI, replacing any previously stored value.
    pub fn set_routed(&mut self, routed: Uri) {
        self.routed = Some(routed);
    }
}
impl Router {
#[must_use]
pub fn fixed(base_uri: BaseUri) -> Self {
Self {
resolver: Arc::new(Resolver::Fixed(base_uri)),
conflict_policy: BaseUriConflict::default(),
}
}
#[must_use]
pub fn fallback(primary: BaseUri, fallback: BaseUri) -> Self {
Self::custom(
move |context| Some(if use_fallback(context) { fallback.clone() } else { primary.clone() }),
true,
)
.conflict_policy(BaseUriConflict::UseRouted)
}
#[must_use]
pub fn custom<F>(resolver: F, has_alternatives: bool) -> Self
where
F: Fn(&RouterContext) -> Option<BaseUri> + Send + Sync + 'static,
{
Self {
resolver: Arc::new(Resolver::Custom {
resolver: Arc::new(resolver),
has_alternatives,
}),
conflict_policy: BaseUriConflict::default(),
}
}
#[must_use]
pub fn conflict_policy(mut self, policy: BaseUriConflict) -> Self {
self.conflict_policy = policy;
self
}
#[must_use]
pub fn has_alternatives(&self) -> bool {
match self.resolver.as_ref() {
Resolver::Empty | Resolver::Fixed(_) => false,
Resolver::Custom { has_alternatives, .. } => *has_alternatives,
}
}
#[expect(
clippy::needless_pass_by_value,
reason = "while not consuming the context, we might do it at some point"
)]
pub fn resolve_uri(&self, context: RouterContext, uri: Uri) -> Result<Uri, HttpError> {
let (original, path) = uri.into_parts();
let routed = self.resolve_base_uri(&context);
let Some(routed) = routed else {
let Some(original) = original else {
return Err(HttpError::validation_with_label(
"the target URI has no base URI and the routing did not produce one; \
provide a base URI on the target or configure the router to resolve one",
LABEL_URI_MISSING,
));
};
return Ok(Uri::from_parts(Some(original), path));
};
let Some(original) = original else {
return Ok(Uri::from_parts(routed, path));
};
let chosen = match self.conflict_policy {
BaseUriConflict::UseOriginal => original,
BaseUriConflict::UseRouted => routed,
BaseUriConflict::Fail => {
return Err(HttpError::validation_with_label(
"target URI already has a base URI; routing produced a conflicting base URI",
LABEL_URI_CONFLICT,
));
}
};
Ok(Uri::from_parts(chosen, path))
}
pub fn resolve_request_uri(&self, context: RouterContext, request: &mut HttpRequest) -> Result<(), HttpError> {
let original: Uri = match request.extensions().get::<RequestUris>() {
Some(uris) => uris.original().clone(),
None => request.uri().clone().try_into()?,
};
let resolved = self.resolve_uri(context.with_request(request), original.clone())?;
let http_uri = resolved.clone().try_into()?;
*request.uri_mut() = http_uri;
request
.extensions_mut()
.get_or_insert_with(|| RequestUris::new(original))
.set_routed(resolved);
Ok(())
}
fn resolve_base_uri(&self, context: &RouterContext) -> Option<BaseUri> {
match self.resolver.as_ref() {
Resolver::Empty => None,
Resolver::Fixed(base_uri) => Some(base_uri.clone()),
Resolver::Custom { resolver, .. } => resolver(context),
}
}
}
/// Type-erased resolver callback stored by the custom routing strategy.
type RouterFn = dyn Fn(&RouterContext) -> Option<BaseUri> + Sync + Send + 'static;
/// Strategy a `Router` uses to produce a routed base URI.
#[derive(Default)]
enum Resolver {
    /// Produces no base URI, so the target URI must supply its own (the default).
    #[default]
    Empty,
    /// Always produces a clone of the same base URI.
    Fixed(BaseUri),
    /// Delegates to a caller-supplied callback.
    Custom {
        /// Callback invoked per resolution; `None` means routing contributes no base URI.
        resolver: Arc<RouterFn>,
        /// Flag supplied at construction and reported back by `Router::has_alternatives`.
        has_alternatives: bool,
    },
}
/// Manual `Debug` because the custom callback is opaque. Only its
/// `has_alternatives` flag is printed, and the output is marked non-exhaustive.
impl std::fmt::Debug for Resolver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Custom { has_alternatives, .. } => {
                let mut builder = f.debug_struct("Custom");
                builder.field("has_alternatives", has_alternatives);
                builder.finish_non_exhaustive()
            }
            Self::Fixed(base_uri) => {
                let mut builder = f.debug_tuple("Fixed");
                builder.field(base_uri);
                builder.finish()
            }
            Self::Empty => f.write_str("Empty"),
        }
    }
}
/// Decides whether `Router::fallback` should route to its fallback base URI.
///
/// The first attempt (attempt `0`) always uses the primary, even if it is also the
/// last attempt. A later attempt uses the fallback if it is the last attempt, or if
/// the previous recovery was classified as [`RecoveryKind::Unavailable`].
fn use_fallback(context: &RouterContext) -> bool {
    let is_retry = context.attempt() != 0;
    // Short-circuiting keeps the original evaluation order of the checks.
    is_retry
        && (context.is_last_attempt()
            || context
                .previous_recovery()
                .is_some_and(|info| info.kind() == RecoveryKind::Unavailable))
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    // Unit tests for `Router` resolution, conflict policies, fallback selection and
    // the `RequestUris` request extension.
    use ohno::Labeled;

    use super::*;

    fn target_with_base() -> Uri {
        "https://existing.example.com/items".parse().unwrap()
    }

    fn target_without_base() -> Uri {
        "/v1/items".parse().unwrap()
    }

    #[test]
    fn default_passes_target_through_when_target_has_base() {
        let router = Router::default();
        let with_base = router.resolve_uri(RouterContext::new(), target_with_base()).unwrap();
        assert_eq!(with_base.to_string().declassify_into(), "https://existing.example.com/items");
    }

    #[test]
    fn default_errors_when_target_has_no_base() {
        let router = Router::default();
        let err = router.resolve_uri(RouterContext::new(), target_without_base()).unwrap_err();
        assert_eq!(err.label(), "uri_missing");
    }

    #[test]
    fn fixed_attaches_when_target_has_none() {
        let router = Router::fixed(BaseUri::from_static("https://api.example.com"));
        let resolved = router.resolve_uri(RouterContext::new(), target_without_base()).unwrap();
        assert_eq!(resolved.to_string().declassify_into(), "https://api.example.com/v1/items");
    }

    #[test]
    fn custom_resolver_returning_none_passes_through() {
        let router = Router::custom(|_| None, false);
        let resolved = router.resolve_uri(RouterContext::new(), target_with_base()).unwrap();
        assert_eq!(resolved.to_string().declassify_into(), "https://existing.example.com/items");
    }

    #[test]
    fn custom_resolver_returning_some_is_used() {
        let router = Router::custom(|_| Some(BaseUri::from_static("https://api.example.com")), false);
        let resolved = router.resolve_uri(RouterContext::new(), target_without_base()).unwrap();
        assert_eq!(resolved.to_string().declassify_into(), "https://api.example.com/v1/items");
    }

    #[test]
    fn keep_existing_is_default_on_conflict() {
        let router = Router::fixed(BaseUri::from_static("https://api.example.com"));
        let resolved = router.resolve_uri(RouterContext::new(), target_with_base()).unwrap();
        assert_eq!(resolved.to_string().declassify_into(), "https://existing.example.com/items");
    }

    #[test]
    fn use_routed_replaces_original_base_uri() {
        let router = Router::fixed(BaseUri::from_static("https://api.example.com")).conflict_policy(BaseUriConflict::UseRouted);
        let resolved = router.resolve_uri(RouterContext::new(), target_with_base()).unwrap();
        assert_eq!(resolved.to_string().declassify_into(), "https://api.example.com/items");
    }

    #[test]
    fn fail_returns_error_on_conflict() {
        let router = Router::fixed(BaseUri::from_static("https://api.example.com")).conflict_policy(BaseUriConflict::Fail);
        let err = router.resolve_uri(RouterContext::new(), target_with_base()).unwrap_err();
        assert_eq!(err.label(), "uri_conflict");
    }

    #[test]
    fn missing_base_uri_errors_regardless_of_policy() {
        for policy in [BaseUriConflict::UseOriginal, BaseUriConflict::UseRouted, BaseUriConflict::Fail] {
            let router = Router::default().conflict_policy(policy);
            let err = router.resolve_uri(RouterContext::new(), Uri::default()).unwrap_err();
            assert_eq!(err.label(), "uri_missing", "empty Uri with policy {policy:?}");
            let err = router.resolve_uri(RouterContext::new(), target_without_base()).unwrap_err();
            assert_eq!(err.label(), "uri_missing", "path-only Uri with policy {policy:?}");
        }
    }

    #[test]
    fn fail_does_not_trigger_without_conflict() {
        let router = Router::fixed(BaseUri::from_static("https://api.example.com")).conflict_policy(BaseUriConflict::Fail);
        let resolved = router.resolve_uri(RouterContext::new(), target_without_base()).unwrap();
        assert_eq!(resolved.to_string().declassify_into(), "https://api.example.com/v1/items");
    }

    #[test]
    fn fallback_uses_primary_without_previous_recovery() {
        let router = Router::fallback(
            BaseUri::from_static("https://primary.example.com"),
            BaseUri::from_static("https://fallback.example.com"),
        );
        let resolved = router.resolve_uri(RouterContext::new(), target_without_base()).unwrap();
        assert_eq!(resolved.to_string().declassify_into(), "https://primary.example.com/v1/items");
    }

    #[test]
    fn fallback_uses_primary_when_previous_recovery_is_not_unavailable() {
        let router = Router::fallback(
            BaseUri::from_static("https://primary.example.com"),
            BaseUri::from_static("https://fallback.example.com"),
        );
        let ctx = RouterContext::new().with_previous_recovery(recoverable::RecoveryInfo::retry());
        let resolved = router.resolve_uri(ctx, target_without_base()).unwrap();
        assert_eq!(resolved.to_string().declassify_into(), "https://primary.example.com/v1/items");
    }

    #[test]
    fn fallback_uses_fallback_when_previous_recovery_is_unavailable() {
        let router = Router::fallback(
            BaseUri::from_static("https://primary.example.com"),
            BaseUri::from_static("https://fallback.example.com"),
        );
        let ctx = RouterContext::new()
            .with_previous_recovery(recoverable::RecoveryInfo::unavailable())
            .with_attempt(1, false);
        let resolved = router.resolve_uri(ctx, target_without_base()).unwrap();
        assert_eq!(resolved.to_string().declassify_into(), "https://fallback.example.com/v1/items");
    }

    #[test]
    fn fallback_uses_fallback_on_last_attempt_after_first() {
        let router = Router::fallback(
            BaseUri::from_static("https://primary.example.com"),
            BaseUri::from_static("https://fallback.example.com"),
        );
        let ctx = RouterContext::new().with_attempt(2, true);
        let resolved = router.resolve_uri(ctx, target_without_base()).unwrap();
        assert_eq!(resolved.to_string().declassify_into(), "https://fallback.example.com/v1/items");
    }

    #[test]
    fn fallback_uses_primary_on_first_attempt_even_when_last() {
        let router = Router::fallback(
            BaseUri::from_static("https://primary.example.com"),
            BaseUri::from_static("https://fallback.example.com"),
        );
        let ctx = RouterContext::new().with_attempt(0, true);
        let resolved = router.resolve_uri(ctx, target_without_base()).unwrap();
        assert_eq!(resolved.to_string().declassify_into(), "https://primary.example.com/v1/items");
    }

    #[test]
    fn fallback_switches_endpoint_across_in_place_resolve_request_uri_calls() {
        let router = Router::fallback(
            BaseUri::from_static("https://primary.example.com"),
            BaseUri::from_static("https://fallback.example.com"),
        );
        let mut request = crate::HttpRequestBuilder::new_fake().get("/v1/items").build().unwrap();
        router.resolve_request_uri(RouterContext::new(), &mut request).unwrap();
        assert_eq!(request.uri().to_string(), "https://primary.example.com/v1/items");
        let ctx = RouterContext::new()
            .with_previous_recovery(recoverable::RecoveryInfo::unavailable())
            .with_attempt(1, false);
        router.resolve_request_uri(ctx, &mut request).unwrap();
        assert_eq!(request.uri().to_string(), "https://fallback.example.com/v1/items");
    }

    #[test]
    fn fallback_does_not_duplicate_base_path_across_in_place_resolve_request_uri_calls() {
        let router = Router::fallback(
            BaseUri::from_static("https://primary.example.com/api/v1/"),
            BaseUri::from_static("https://fallback.example.com/api/"),
        );
        let mut request = crate::HttpRequestBuilder::new_fake().get("/items").build().unwrap();
        router.resolve_request_uri(RouterContext::new(), &mut request).unwrap();
        assert_eq!(request.uri().to_string(), "https://primary.example.com/api/v1/items");
        let ctx = RouterContext::new()
            .with_previous_recovery(recoverable::RecoveryInfo::unavailable())
            .with_attempt(1, false);
        router.resolve_request_uri(ctx, &mut request).unwrap();
        assert_eq!(request.uri().to_string(), "https://fallback.example.com/api/items");
        let ctx = RouterContext::new().with_attempt(2, false);
        router.resolve_request_uri(ctx, &mut request).unwrap();
        assert_eq!(request.uri().to_string(), "https://primary.example.com/api/v1/items");
    }

    #[test]
    fn fallback_uses_primary_on_non_last_attempt() {
        let router = Router::fallback(
            BaseUri::from_static("https://primary.example.com"),
            BaseUri::from_static("https://fallback.example.com"),
        );
        let ctx = RouterContext::new().with_attempt(1, false);
        let resolved = router.resolve_uri(ctx, target_without_base()).unwrap();
        assert_eq!(resolved.to_string().declassify_into(), "https://primary.example.com/v1/items");
    }

    #[test]
    fn assert_router_size() {
        static_assertions::assert_eq_size!(Router, [u8; 16]);
    }

    #[test]
    fn default_has_no_alternatives() {
        assert!(!Router::default().has_alternatives());
    }

    #[test]
    fn fixed_has_no_alternatives() {
        let router = Router::fixed(BaseUri::from_static("https://api.example.com"));
        assert!(!router.has_alternatives());
    }

    #[test]
    fn fallback_has_alternatives() {
        let router = Router::fallback(
            BaseUri::from_static("https://primary.example.com"),
            BaseUri::from_static("https://fallback.example.com"),
        );
        assert!(router.has_alternatives());
    }

    #[test]
    fn custom_has_alternatives() {
        let router = Router::custom(|_| None, true);
        assert!(router.has_alternatives());
    }

    #[test]
    fn custom_without_alternatives_reports_false() {
        let router = Router::custom(|_| None, false);
        assert!(!router.has_alternatives());
    }

    #[test]
    fn resolve_request_uri_falls_back_to_request_uri_without_request_uris_extension() {
        let router = Router::fixed(BaseUri::from_static("https://api.example.com"));
        let body = crate::HttpBodyBuilder::new_fake().empty();
        let mut request = http::Request::new(body);
        *request.uri_mut() = http::Uri::from_static("/v1/items");
        assert!(
            request.extensions().get::<RequestUris>().is_none(),
            "precondition: no RequestUris extension"
        );
        router.resolve_request_uri(RouterContext::new(), &mut request).unwrap();
        assert_eq!(request.uri().to_string(), "https://api.example.com/v1/items");
        let uris = request
            .extensions()
            .get::<RequestUris>()
            .expect("resolve_request_uri must attach a RequestUris extension for hand-built requests");
        assert_eq!(uris.original().to_string().declassify_ref(), "/v1/items");
        assert_eq!(
            uris.routed().expect("routed must be populated").to_string().declassify_ref(),
            "https://api.example.com/v1/items"
        );
    }

    #[test]
    fn resolve_request_uri_attaches_base_uri() {
        let router = Router::fixed(BaseUri::from_static("https://api.example.com"));
        let mut request = crate::HttpRequestBuilder::new_fake().get("/v1/items").build().unwrap();
        router.resolve_request_uri(RouterContext::new(), &mut request).unwrap();
        assert_eq!(request.uri().to_string(), "https://api.example.com/v1/items");
    }

    #[test]
    fn resolve_request_uri_attaches_resolved_uri_extension() {
        let router = Router::fixed(BaseUri::from_static("https://api.example.com"));
        let mut request = crate::HttpRequestBuilder::new_fake().get("/v1/items").build().unwrap();
        router.resolve_request_uri(RouterContext::new(), &mut request).unwrap();
        let uris = request
            .extensions()
            .get::<RequestUris>()
            .expect("resolve_request_uri must keep the RequestUris extension");
        assert_eq!(uris.original().to_string().declassify_ref(), "/v1/items");
        assert_eq!(
            uris.routed().expect("routed must be populated").to_string().declassify_ref(),
            "https://api.example.com/v1/items"
        );
    }

    #[test]
    fn resolve_request_uri_preserves_original_across_repeated_calls() {
        // Keep a handle to the counter (rather than moving the only one into the
        // closure) so the test can verify how often the resolver actually ran.
        let calls = Arc::new(std::sync::Mutex::new(0_usize));
        let resolver_calls = Arc::clone(&calls);
        let router = Router::custom(
            move |_| {
                let mut count = resolver_calls.lock().unwrap();
                let base = if *count == 0 {
                    BaseUri::from_static("https://first.example.com")
                } else {
                    BaseUri::from_static("https://second.example.com")
                };
                *count += 1;
                Some(base)
            },
            true,
        );
        let mut request = crate::HttpRequestBuilder::new_fake().get("/v1/items").build().unwrap();
        router.resolve_request_uri(RouterContext::new(), &mut request).unwrap();
        assert_eq!(request.uri().to_string(), "https://first.example.com/v1/items");
        router.resolve_request_uri(RouterContext::new(), &mut request).unwrap();
        assert_eq!(request.uri().to_string(), "https://second.example.com/v1/items");
        let uris = request.extensions().get::<RequestUris>().unwrap();
        assert_eq!(uris.original().to_string().declassify_ref(), "/v1/items");
        assert_eq!(
            *calls.lock().unwrap(),
            2,
            "the resolver must run exactly once per resolve_request_uri call"
        );
    }

    #[test]
    fn resolve_request_uri_keeps_existing_base_uri_by_default() {
        let router = Router::fixed(BaseUri::from_static("https://api.example.com"));
        let mut request = crate::HttpRequestBuilder::new_fake()
            .get("https://existing.example.com/items")
            .build()
            .unwrap();
        router.resolve_request_uri(RouterContext::new(), &mut request).unwrap();
        assert_eq!(request.uri().to_string(), "https://existing.example.com/items");
    }

    #[test]
    fn resolve_request_uri_returns_error_on_conflict_when_policy_is_fail() {
        let router = Router::fixed(BaseUri::from_static("https://api.example.com")).conflict_policy(BaseUriConflict::Fail);
        let mut request = crate::HttpRequestBuilder::new_fake()
            .get("https://existing.example.com/items")
            .build()
            .unwrap();
        let err = router.resolve_request_uri(RouterContext::new(), &mut request).unwrap_err();
        assert_eq!(err.label(), "uri_conflict");
    }

    #[test]
    fn resolve_request_uri_preserves_original_uri_on_failure() {
        let router = Router::fixed(BaseUri::from_static("https://api.example.com")).conflict_policy(BaseUriConflict::Fail);
        let mut request = crate::HttpRequestBuilder::new_fake()
            .get("https://existing.example.com/items")
            .build()
            .unwrap();
        let original_uri = request.uri().clone();
        let _ = router.resolve_request_uri(RouterContext::new(), &mut request).unwrap_err();
        assert_eq!(request.uri(), &original_uri);
    }

    #[test]
    fn resolver_debug_format() {
        assert_eq!(format!("{:?}", Resolver::Empty), "Empty");
        let fixed = Resolver::Fixed(BaseUri::from_static("https://api.example.com"));
        assert!(format!("{fixed:?}").starts_with("Fixed("));
        let custom = Resolver::Custom {
            resolver: Arc::new(|_| None),
            has_alternatives: true,
        };
        let custom_debug = format!("{custom:?}");
        assert!(custom_debug.starts_with("Custom"));
        assert!(custom_debug.contains("has_alternatives: true"));
    }
}