#[cfg(feature = "schemars")]
use std::borrow::Cow;
use std::ops::Deref;
use http::StatusCode;
use rustc_hash::FxHashSet;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use url::Url;
use crate::{IndexCapabilities, IndexUrl};
/// Policy for turning an index's HTTP error status into an [`IndexStatusCodeDecision`].
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum IndexStatusCodeStrategy {
    /// Built-in handling: `404` is ignored, every other error status fails
    /// (with `401`/`403` additionally recorded on the index capabilities).
    #[default]
    Default,
    /// Ignore the listed status codes; anything else is handled as in [`Self::Default`].
    IgnoreErrorCodes {
        /// Status codes that yield [`IndexStatusCodeDecision::Ignore`].
        status_codes: FxHashSet<StatusCode>,
    },
}
impl IndexStatusCodeStrategy {
    /// Derive the strategy for an index from its URL.
    ///
    /// Hosts on the `pytorch.org` domain (the apex domain or any subdomain, e.g.
    /// `download.pytorch.org`) ignore `403 Forbidden`; every other host uses
    /// [`Self::Default`].
    pub fn from_index_url(url: &Url) -> Self {
        if url.host_str().is_some_and(Self::is_pytorch_host) {
            // NOTE(review): presumably this index answers 403 instead of 404 for
            // missing packages — confirm before extending the ignored set.
            Self::IgnoreErrorCodes {
                status_codes: FxHashSet::from_iter([StatusCode::FORBIDDEN]),
            }
        } else {
            Self::Default
        }
    }

    /// Returns `true` if `host` is `pytorch.org` itself or a subdomain of it.
    ///
    /// A bare `ends_with("pytorch.org")` would also match unrelated domains such as
    /// `notpytorch.org`, so the suffix is anchored on a label boundary (`.`).
    /// A trailing root dot (`download.pytorch.org.`) is tolerated.
    fn is_pytorch_host(host: &str) -> bool {
        let host = host.strip_suffix('.').unwrap_or(host);
        host == "pytorch.org" || host.ends_with(".pytorch.org")
    }

    /// Build a strategy that ignores exactly the given status codes.
    ///
    /// Codes not in `status_codes` fall back to [`Self::Default`] handling.
    /// Duplicates in the input collapse into a single set entry.
    pub fn from_ignored_error_codes(status_codes: &[SerializableStatusCode]) -> Self {
        Self::IgnoreErrorCodes {
            status_codes: status_codes.iter().map(|code| **code).collect(),
        }
    }

    /// Build a strategy that ignores authentication-related failures:
    /// `401`, `403`, `407` and `511`.
    pub fn ignore_authentication_error_codes() -> Self {
        Self::IgnoreErrorCodes {
            status_codes: FxHashSet::from_iter([
                StatusCode::UNAUTHORIZED,
                StatusCode::FORBIDDEN,
                StatusCode::NETWORK_AUTHENTICATION_REQUIRED,
                StatusCode::PROXY_AUTHENTICATION_REQUIRED,
            ]),
        }
    }

    /// Decide how to treat `status_code` returned by `index_url`.
    ///
    /// Under [`Self::Default`]:
    /// - `404` → [`IndexStatusCodeDecision::Ignore`];
    /// - `401` → records `index_url` via `capabilities.set_unauthorized`, then fails;
    /// - `403` → records `index_url` via `capabilities.set_forbidden`, then fails;
    /// - anything else → [`IndexStatusCodeDecision::Fail`].
    ///
    /// Under [`Self::IgnoreErrorCodes`], a listed code is ignored without touching
    /// `capabilities`; any other code is handled as under [`Self::Default`].
    pub fn handle_status_code(
        &self,
        status_code: StatusCode,
        index_url: &IndexUrl,
        capabilities: &IndexCapabilities,
    ) -> IndexStatusCodeDecision {
        match self {
            Self::Default => match status_code {
                StatusCode::NOT_FOUND => IndexStatusCodeDecision::Ignore,
                StatusCode::UNAUTHORIZED => {
                    capabilities.set_unauthorized(index_url.clone());
                    IndexStatusCodeDecision::Fail(status_code)
                }
                StatusCode::FORBIDDEN => {
                    capabilities.set_forbidden(index_url.clone());
                    IndexStatusCodeDecision::Fail(status_code)
                }
                _ => IndexStatusCodeDecision::Fail(status_code),
            },
            Self::IgnoreErrorCodes { status_codes } => {
                if status_codes.contains(&status_code) {
                    IndexStatusCodeDecision::Ignore
                } else {
                    // Not explicitly ignored: defer to the built-in handling so
                    // capabilities are still recorded for 401/403.
                    Self::Default.handle_status_code(status_code, index_url, capabilities)
                }
            }
        }
    }
}
/// Outcome produced by [`IndexStatusCodeStrategy::handle_status_code`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum IndexStatusCodeDecision {
    /// The strategy chose to ignore this status code.
    Ignore,
    /// The strategy treats this status code as a failure; the code is carried along.
    Fail(StatusCode),
}
/// A [`StatusCode`] newtype whose serde representation is a plain `u16`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SerializableStatusCode(StatusCode);

impl Deref for SerializableStatusCode {
    type Target = StatusCode;

    /// Borrow the wrapped [`StatusCode`].
    fn deref(&self) -> &StatusCode {
        let Self(inner) = self;
        inner
    }
}
impl Serialize for SerializableStatusCode {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_u16(self.0.as_u16())
}
}
impl<'de> Deserialize<'de> for SerializableStatusCode {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let code = u16::deserialize(deserializer)?;
StatusCode::from_u16(code)
.map(SerializableStatusCode)
.map_err(|_| {
serde::de::Error::custom(format!("{code} is not a valid HTTP status code"))
})
}
}
#[cfg(feature = "schemars")]
impl schemars::JsonSchema for SerializableStatusCode {
    fn schema_name() -> Cow<'static, str> {
        Cow::Borrowed("StatusCode")
    }

    /// Schema for the serialized form: a bare integer HTTP status code.
    fn json_schema(_generator: &mut schemars::generate::SchemaGenerator) -> schemars::Schema {
        // `integer`, not `number`: the value is (de)serialized as a `u16`, so a
        // fractional value such as `403.5` would pass a `number` schema yet be
        // rejected by the `Deserialize` impl.
        schemars::json_schema!({
            "type": "integer",
            "minimum": 100,
            "maximum": 599,
            "description": "HTTP status code (100-599)"
        })
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use url::Url;

    use super::*;

    /// The non-PyTorch index URL shared by most decision tests.
    fn internal_index() -> IndexUrl {
        IndexUrl::parse("https://internal-registry.com/simple", None).unwrap()
    }

    #[test]
    fn test_strategy_normal_registry() {
        let url = Url::from_str("https://internal-registry.com/simple").unwrap();
        let strategy = IndexStatusCodeStrategy::from_index_url(&url);
        assert_eq!(strategy, IndexStatusCodeStrategy::Default);
    }

    #[test]
    fn test_strategy_pytorch_registry() {
        let url = Url::from_str("https://download.pytorch.org/whl/cu118").unwrap();
        let expected = IndexStatusCodeStrategy::IgnoreErrorCodes {
            status_codes: FxHashSet::from_iter([StatusCode::FORBIDDEN]),
        };
        assert_eq!(IndexStatusCodeStrategy::from_index_url(&url), expected);
    }

    #[test]
    fn test_strategy_custom_error_codes() {
        let status_codes = FxHashSet::from_iter([StatusCode::UNAUTHORIZED, StatusCode::FORBIDDEN]);
        let wrapped: Vec<SerializableStatusCode> = status_codes
            .iter()
            .copied()
            .map(SerializableStatusCode)
            .collect();
        let strategy = IndexStatusCodeStrategy::from_ignored_error_codes(&wrapped);
        assert_eq!(
            strategy,
            IndexStatusCodeStrategy::IgnoreErrorCodes { status_codes }
        );
    }

    #[test]
    fn test_decision_default_400() {
        let index_url = internal_index();
        let capabilities = IndexCapabilities::default();
        let decision = IndexStatusCodeStrategy::Default.handle_status_code(
            StatusCode::BAD_REQUEST,
            &index_url,
            &capabilities,
        );
        assert_eq!(
            decision,
            IndexStatusCodeDecision::Fail(StatusCode::BAD_REQUEST)
        );
    }

    #[test]
    fn test_decision_default_401() {
        let index_url = internal_index();
        let capabilities = IndexCapabilities::default();
        let decision = IndexStatusCodeStrategy::Default.handle_status_code(
            StatusCode::UNAUTHORIZED,
            &index_url,
            &capabilities,
        );
        assert_eq!(
            decision,
            IndexStatusCodeDecision::Fail(StatusCode::UNAUTHORIZED)
        );
        assert!(capabilities.unauthorized(&index_url));
        assert!(!capabilities.forbidden(&index_url));
    }

    #[test]
    fn test_decision_default_403() {
        let index_url = internal_index();
        let capabilities = IndexCapabilities::default();
        let decision = IndexStatusCodeStrategy::Default.handle_status_code(
            StatusCode::FORBIDDEN,
            &index_url,
            &capabilities,
        );
        assert_eq!(
            decision,
            IndexStatusCodeDecision::Fail(StatusCode::FORBIDDEN)
        );
        assert!(capabilities.forbidden(&index_url));
        assert!(!capabilities.unauthorized(&index_url));
    }

    #[test]
    fn test_decision_default_404() {
        let index_url = internal_index();
        let capabilities = IndexCapabilities::default();
        let decision = IndexStatusCodeStrategy::Default.handle_status_code(
            StatusCode::NOT_FOUND,
            &index_url,
            &capabilities,
        );
        assert_eq!(decision, IndexStatusCodeDecision::Ignore);
        assert!(!capabilities.forbidden(&index_url));
        assert!(!capabilities.unauthorized(&index_url));
    }

    #[test]
    fn test_decision_pytorch() {
        let index_url = IndexUrl::parse("https://download.pytorch.org/whl/cu118", None).unwrap();
        let strategy = IndexStatusCodeStrategy::from_index_url(&index_url);
        let capabilities = IndexCapabilities::default();
        // Checked in order against the same capabilities instance.
        let cases = [
            (StatusCode::FORBIDDEN, IndexStatusCodeDecision::Ignore),
            (
                StatusCode::UNAUTHORIZED,
                IndexStatusCodeDecision::Fail(StatusCode::UNAUTHORIZED),
            ),
        ];
        for (status_code, expected) in cases {
            let decision = strategy.handle_status_code(status_code, &index_url, &capabilities);
            assert_eq!(decision, expected);
        }
    }

    #[test]
    fn test_decision_multiple_ignored_status_codes() {
        let ignored = [
            StatusCode::UNAUTHORIZED,
            StatusCode::BAD_GATEWAY,
            StatusCode::SERVICE_UNAVAILABLE,
        ];
        let strategy = IndexStatusCodeStrategy::IgnoreErrorCodes {
            status_codes: FxHashSet::from_iter(ignored),
        };
        let index_url = internal_index();
        let capabilities = IndexCapabilities::default();
        for status_code in ignored {
            assert_eq!(
                strategy.handle_status_code(status_code, &index_url, &capabilities),
                IndexStatusCodeDecision::Ignore
            );
        }
        assert_eq!(
            strategy.handle_status_code(StatusCode::FORBIDDEN, &index_url, &capabilities),
            IndexStatusCodeDecision::Fail(StatusCode::FORBIDDEN)
        );
    }
}