use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use sha2::{Digest, Sha256};
use std::{fmt, marker::PhantomData, str::FromStr};
use thiserror::Error;
/// Format marker: the version is rendered as a weak HTTP entity tag, e.g. `W/"abc"`.
///
/// Marker types carry no data. They derive `PartialEq`, `Eq` and `Hash` so that
/// `#[derive(Eq, Hash)]` on types generic over the format (whose generated impls
/// require `Format: Eq` / `Format: Hash`) actually applies to both formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Http;

/// Format marker: the version is rendered as the bare opaque token, e.g. `abc`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Raw;
/// A SCIM resource version, parameterised by its textual representation.
///
/// `Format` is a zero-sized marker type that selects how the version is rendered
/// and parsed. The underlying opaque token is the same for every format, so
/// versions of different formats compare equal when their tokens match.
#[derive(Debug, Clone)]
pub struct ScimVersion<Format> {
    /// The opaque version token, stored without any `W/"..."` ETag wrapping.
    opaque: String,
    /// Zero-sized marker tying the value to its representation format.
    _format: PhantomData<Format>,
}

// `Eq` and `Hash` are written by hand instead of derived: `#[derive]` would add
// `Format: Eq` / `Format: Hash` bounds, making these impls depend on the marker
// types. The hand-written impls hold for every `Format`, and both look only at
// `opaque`, which keeps `Hash` consistent with the opaque-only equality.
impl<Format> Eq for ScimVersion<Format> {}

impl<Format> std::hash::Hash for ScimVersion<Format> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.opaque, state);
    }
}
/// A version whose `Display` form is a weak HTTP entity tag: `W/"<opaque>"`.
pub type HttpVersion = ScimVersion<Http>;
/// A version whose `Display` form is the bare opaque token.
pub type RawVersion = ScimVersion<Raw>;
impl<Format> ScimVersion<Format> {
pub fn from_content(content: &[u8]) -> RawVersion {
let mut hasher = Sha256::new();
hasher.update(content);
let hash = hasher.finalize();
let encoded = BASE64.encode(&hash[..8]);
ScimVersion {
opaque: encoded,
_format: PhantomData,
}
}
pub fn from_hash(hash_string: impl AsRef<str>) -> RawVersion {
ScimVersion {
opaque: hash_string.as_ref().to_string(),
_format: PhantomData,
}
}
pub fn as_str(&self) -> &str {
&self.opaque
}
}
// Raw versions render as the bare opaque token, e.g. `abc123`.
// Formatter width/fill flags are not applied.
impl fmt::Display for ScimVersion<Raw> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.opaque)
    }
}

// HTTP versions render as a weak entity tag, e.g. `W/"abc123"`.
impl fmt::Display for ScimVersion<Http> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("W/\"")?;
        f.write_str(&self.opaque)?;
        f.write_str("\"")
    }
}
// Parses a bare version token. Surrounding whitespace is ignored, and an
// input that is empty after trimming is rejected with `VersionError::ParseError`.
impl FromStr for ScimVersion<Raw> {
    type Err = VersionError;

    fn from_str(version_str: &str) -> Result<Self, Self::Err> {
        match version_str.trim() {
            "" => Err(VersionError::ParseError(
                "Version string cannot be empty".to_string(),
            )),
            token => Ok(ScimVersion {
                opaque: token.to_owned(),
                _format: PhantomData,
            }),
        }
    }
}
// Parses an HTTP entity tag such as `W/"abc123"` (weak) or `"abc123"` (strong).
//
// Surrounding whitespace is ignored and the optional `W/` prefix is stripped;
// weak and strong tags yield the same version. The stored opaque value is the
// text between the double quotes.
//
// Errors: `VersionError::InvalidEtagFormat`, carrying the original input, when
// - the value is not enclosed in double quotes,
// - the quoted value is empty, or
// - the quoted value itself contains a double quote (forbidden inside an
//   entity tag). This rejects list-valued headers such as `"a", "b"`, which
//   would otherwise be read as one bogus tag `a", "b`.
impl FromStr for ScimVersion<Http> {
    type Err = VersionError;

    fn from_str(etag_header: &str) -> Result<Self, Self::Err> {
        let invalid = || VersionError::InvalidEtagFormat(etag_header.to_string());
        let trimmed = etag_header.trim();
        let etag_value = trimmed.strip_prefix("W/").unwrap_or(trimmed);
        let opaque = etag_value
            .strip_prefix('"')
            .and_then(|rest| rest.strip_suffix('"'))
            .ok_or_else(invalid)?;
        if opaque.is_empty() || opaque.contains('"') {
            return Err(invalid());
        }
        Ok(ScimVersion {
            opaque: opaque.to_string(),
            _format: PhantomData,
        })
    }
}
// Re-tags a raw version for HTTP rendering; the opaque token is moved, not copied.
impl From<ScimVersion<Raw>> for ScimVersion<Http> {
    fn from(raw: ScimVersion<Raw>) -> Self {
        let ScimVersion { opaque, .. } = raw;
        Self {
            opaque,
            _format: PhantomData,
        }
    }
}

// Strips the HTTP representation, keeping the same opaque token.
impl From<ScimVersion<Http>> for ScimVersion<Raw> {
    fn from(http: ScimVersion<Http>) -> Self {
        let ScimVersion { opaque, .. } = http;
        Self {
            opaque,
            _format: PhantomData,
        }
    }
}
// Versions compare by opaque token only, so a raw and an HTTP version holding
// the same token are equal.
impl<F1, F2> PartialEq<ScimVersion<F2>> for ScimVersion<F1> {
    fn eq(&self, other: &ScimVersion<F2>) -> bool {
        self.opaque.as_str() == other.opaque.as_str()
    }
}
// Serializes as the bare opaque string for every format (no `W/"..."` wrapping).
impl<Format> Serialize for ScimVersion<Format> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&self.opaque)
    }
}
// Deserializes from a plain string, stored verbatim as the opaque token.
// NOTE(review): unlike `FromStr`, this neither trims nor rejects empty strings,
// and for `Http` it does not strip `W/"..."`; confirm this is intended.
impl<'de, Format> Deserialize<'de> for ScimVersion<Format> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        String::deserialize(deserializer).map(|opaque| ScimVersion {
            opaque,
            _format: PhantomData,
        })
    }
}
/// Outcome of an operation that was made conditional on a resource version.
///
/// Dropping this value unexamined would silently discard a version conflict,
/// hence `#[must_use]`.
#[must_use = "a ConditionalResult may report a version conflict that should be handled"]
#[derive(Debug, Clone, PartialEq)]
pub enum ConditionalResult<T> {
    /// The operation succeeded and produced `T`.
    Success(T),
    /// The version check failed; carries the conflict details.
    VersionMismatch(VersionConflict),
    /// The target resource was not found.
    NotFound,
}

impl<T> ConditionalResult<T> {
    /// Returns `true` for [`ConditionalResult::Success`].
    pub fn is_success(&self) -> bool {
        matches!(self, Self::Success(_))
    }

    /// Returns `true` for [`ConditionalResult::VersionMismatch`].
    pub fn is_version_mismatch(&self) -> bool {
        matches!(self, Self::VersionMismatch(_))
    }

    /// Returns `true` for [`ConditionalResult::NotFound`].
    pub fn is_not_found(&self) -> bool {
        matches!(self, Self::NotFound)
    }

    /// Consumes the result, returning the success value if there is one.
    pub fn into_success(self) -> Option<T> {
        // Exhaustive on purpose: a new variant must be classified explicitly.
        match self {
            Self::Success(value) => Some(value),
            Self::VersionMismatch(_) | Self::NotFound => None,
        }
    }

    /// Consumes the result, returning the conflict if the version check failed.
    pub fn into_version_conflict(self) -> Option<VersionConflict> {
        match self {
            Self::VersionMismatch(conflict) => Some(conflict),
            Self::Success(_) | Self::NotFound => None,
        }
    }

    /// Transforms the success value with `f`, passing other variants through unchanged.
    pub fn map<U, F>(self, f: F) -> ConditionalResult<U>
    where
        F: FnOnce(T) -> U,
    {
        match self {
            Self::Success(value) => ConditionalResult::Success(f(value)),
            Self::VersionMismatch(conflict) => ConditionalResult::VersionMismatch(conflict),
            Self::NotFound => ConditionalResult::NotFound,
        }
    }
}
/// Details of a failed version check: the version that was expected, the
/// version that was actually found, and a human-readable explanation.
///
/// Serialized with serde; each version field serializes as its bare opaque string.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct VersionConflict {
    /// The version that was expected.
    pub expected: RawVersion,
    /// The version that was actually found.
    pub current: RawVersion,
    /// Explanation appended to the `Display` output.
    pub message: String,
}
impl VersionConflict {
    /// Builds a conflict from the expected and current versions plus a message.
    ///
    /// Both versions accept anything convertible into a [`RawVersion`], so
    /// callers can pass whichever representation they hold.
    pub fn new<E, C>(expected: E, current: C, message: impl Into<String>) -> Self
    where
        E: Into<RawVersion>,
        C: Into<RawVersion>,
    {
        let (expected, current) = (expected.into(), current.into());
        let message = message.into();
        Self {
            expected,
            current,
            message,
        }
    }

    /// Builds a conflict with the stock "modified by another client" message.
    pub fn standard_message<E, C>(expected: E, current: C) -> Self
    where
        E: Into<RawVersion>,
        C: Into<RawVersion>,
    {
        const STANDARD: &str =
            "Resource was modified by another client. Please refresh and try again.";
        Self::new(expected, current, STANDARD)
    }
}
// Renders as: Version conflict: expected '<expected>', found '<current>'. <message>
impl fmt::Display for VersionConflict {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            expected,
            current,
            message,
        } = self;
        write!(
            f,
            "Version conflict: expected '{}', found '{}'. {}",
            expected, current, message
        )
    }
}

// Marker impl so a conflict can be used as a standard error (no source error).
impl std::error::Error for VersionConflict {}
/// Errors returned when parsing a version from a string (see the `FromStr` impls).
#[derive(Debug, Error, Clone, PartialEq)]
pub enum VersionError {
    /// The input was not a valid quoted entity tag; carries the original input.
    #[error("Invalid ETag format: {0}")]
    InvalidEtagFormat(String),
    /// A raw version string was rejected; carries a description of the problem.
    #[error("Failed to parse version: {0}")]
    ParseError(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a raw version from a literal token.
    fn raw(token: &str) -> RawVersion {
        RawVersion::from_hash(token)
    }

    #[test]
    fn test_version_from_content() {
        let same_a = RawVersion::from_content(b"test content");
        let same_b = RawVersion::from_content(b"test content");
        let other = RawVersion::from_content(b"different content");
        // Hashing is deterministic and sensitive to the content.
        assert_eq!(same_a, same_b);
        assert_ne!(same_a, other);
    }

    #[test]
    fn test_version_from_hash() {
        let first = raw("abc123def");
        assert_eq!(first, raw("abc123def"));
        assert_ne!(first, raw("xyz789"));
        assert_eq!(first.as_str(), "abc123def");
    }

    #[test]
    fn test_http_version_parse() {
        let accepted = [("W/\"abc123\"", "abc123"), ("\"xyz789\"", "xyz789")];
        for (header, expected) in accepted {
            let parsed: HttpVersion = header.parse().unwrap();
            assert_eq!(parsed.as_str(), expected);
        }
        for bad in ["invalid", "\"\"", "W/invalid"] {
            assert!(bad.parse::<HttpVersion>().is_err(), "{:?} should be rejected", bad);
        }
    }

    #[test]
    fn test_raw_version_parse() {
        let parsed: RawVersion = "abc123def".parse().unwrap();
        assert_eq!(parsed.as_str(), "abc123def");
        for blank in ["", " "] {
            assert!(blank.parse::<RawVersion>().is_err());
        }
    }

    #[test]
    fn test_format_display() {
        let raw_version = raw("abc123");
        let http_version: HttpVersion = raw_version.clone().into();
        assert_eq!(raw_version.to_string(), "abc123");
        assert_eq!(http_version.to_string(), "W/\"abc123\"");
        // Equality ignores the format marker.
        assert_eq!(raw_version, http_version);
    }

    #[test]
    fn test_conditional_result() {
        let success: ConditionalResult<i32> = ConditionalResult::Success(42);
        let not_found: ConditionalResult<i32> = ConditionalResult::NotFound;
        let conflict: ConditionalResult<i32> = ConditionalResult::VersionMismatch(
            VersionConflict::new(raw("1"), raw("2"), "test conflict"),
        );
        // (is_success, is_version_mismatch, is_not_found) for each variant.
        let flags = |r: &ConditionalResult<i32>| {
            (r.is_success(), r.is_version_mismatch(), r.is_not_found())
        };
        assert_eq!(flags(&success), (true, false, false));
        assert_eq!(flags(&not_found), (false, false, true));
        assert_eq!(flags(&conflict), (false, true, false));
    }

    #[test]
    fn test_conditional_result_map() {
        let doubled = ConditionalResult::<i32>::Success(21).map(|x| x * 2);
        assert_eq!(doubled.into_success(), Some(42));
    }

    #[test]
    fn test_version_conflict() {
        let (expected, current) = (raw("1"), raw("2"));
        let conflict = VersionConflict::new(expected.clone(), current.clone(), "test message");
        assert_eq!(conflict.expected, expected);
        assert_eq!(conflict.current, current);
        assert_eq!(conflict.message, "test message");
    }

    #[test]
    fn test_version_conflict_display() {
        let rendered = VersionConflict::standard_message(raw("old"), raw("new")).to_string();
        for fragment in ["expected 'old'", "found 'new'", "Resource was modified"] {
            assert!(rendered.contains(fragment));
        }
    }

    #[test]
    fn test_version_serialization() {
        let version = raw("test123");
        let json = serde_json::to_string(&version).unwrap();
        assert_eq!(json, "\"test123\"");
        let round_tripped: RawVersion = serde_json::from_str(&json).unwrap();
        assert_eq!(version, round_tripped);
    }

    #[test]
    fn test_version_conflict_serialization() {
        let conflict = VersionConflict::new(raw("1"), raw("2"), "test");
        let json = serde_json::to_string(&conflict).unwrap();
        let round_tripped: VersionConflict = serde_json::from_str(&json).unwrap();
        assert_eq!(conflict, round_tripped);
    }
}