use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::fmt;
use thiserror::Error;
/// Opaque version identifier for a resource.
///
/// The wrapped string is never interpreted by this type: two versions are
/// considered the same exactly when their strings are equal. Values can be
/// derived from content ([`ScimVersion::from_content`]), supplied directly
/// ([`ScimVersion::from_hash`], [`ScimVersion::parse_raw`]) or read from an
/// HTTP ETag header ([`ScimVersion::parse_http_header`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ScimVersion {
// The version token without any ETag decoration (no `W/` prefix, no quotes).
opaque: String,
}
impl ScimVersion {
    /// Derives a version from the raw bytes of a resource.
    ///
    /// The content is hashed with SHA-256 and the first 8 bytes of the digest
    /// are base64-encoded to form the opaque token, so identical content
    /// always produces an identical version.
    pub fn from_content(content: &[u8]) -> Self {
        let mut hasher = Sha256::new();
        hasher.update(content);
        let digest = hasher.finalize();
        // Only a 64-bit prefix of the digest is kept to keep the token short.
        Self {
            opaque: BASE64.encode(&digest[..8]),
        }
    }

    /// Wraps an already-computed version string as-is.
    ///
    /// No trimming or validation is performed; the string becomes the opaque
    /// token verbatim.
    pub fn from_hash(hash_string: impl AsRef<str>) -> Self {
        Self {
            opaque: hash_string.as_ref().to_owned(),
        }
    }

    /// Parses a version out of an HTTP ETag header value.
    ///
    /// Surrounding whitespace is ignored and an optional weak-validator
    /// prefix `W/` is accepted, so both `"abc"` and `W/"abc"` yield the
    /// version `abc`. The text between the outer quotes is taken verbatim;
    /// it is not checked for embedded quote characters.
    ///
    /// # Errors
    ///
    /// Returns [`VersionError::InvalidEtagFormat`] (carrying the original,
    /// untrimmed header) when the value is not wrapped in double quotes or
    /// the quoted token is empty.
    pub fn parse_http_header(etag_header: &str) -> Result<Self, VersionError> {
        let trimmed = etag_header.trim();
        // The weak indicator is optional; fall back to the whole value when absent.
        let etag_value = trimmed.strip_prefix("W/").unwrap_or(trimmed);

        // Peel off the opening and closing quote. A lone `"` fails here because
        // nothing is left to supply the closing quote once the opening one is gone.
        let opaque = etag_value
            .strip_prefix('"')
            .and_then(|rest| rest.strip_suffix('"'))
            .filter(|inner| !inner.is_empty())
            .ok_or_else(|| VersionError::InvalidEtagFormat(etag_header.to_string()))?;

        Ok(Self {
            opaque: opaque.to_owned(),
        })
    }

    /// Parses a bare (undecorated) version string.
    ///
    /// Leading and trailing whitespace is removed; the remainder becomes the
    /// opaque token.
    ///
    /// # Errors
    ///
    /// Returns [`VersionError::ParseError`] when the input is empty or
    /// consists only of whitespace.
    pub fn parse_raw(version_str: &str) -> Result<Self, VersionError> {
        let trimmed = version_str.trim();
        if trimmed.is_empty() {
            return Err(VersionError::ParseError(
                "Version string cannot be empty".to_string(),
            ));
        }
        Ok(Self {
            opaque: trimmed.to_owned(),
        })
    }

    /// Renders the version as a weak ETag header value, i.e. `W/"<token>"`.
    pub fn to_http_header(&self) -> String {
        format!("W/\"{}\"", self.opaque)
    }

    /// Returns `true` when both versions carry the same opaque token.
    pub fn matches(&self, other: &ScimVersion) -> bool {
        self.opaque == other.opaque
    }

    /// Borrows the opaque token without any ETag decoration.
    pub fn as_str(&self) -> &str {
        &self.opaque
    }
}
/// Displays the bare opaque token (no `W/` prefix and no quotes).
impl fmt::Display for ScimVersion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Emit the token verbatim; width/padding flags are intentionally not applied.
        f.write_str(&self.opaque)
    }
}
/// Three-way outcome of an operation that is conditional on a resource version.
///
/// NOTE(review): the operations producing this value live outside this file;
/// presumably they are conditional updates/deletes — confirm against callers.
#[derive(Debug, Clone, PartialEq)]
pub enum ConditionalResult<T> {
// The operation went through; carries its result value.
Success(T),
// The supplied version did not match; carries the details of the conflict.
VersionMismatch(VersionConflict),
// There was no resource to operate on.
NotFound,
}
impl<T> ConditionalResult<T> {
    /// Returns `true` only for the [`ConditionalResult::Success`] variant.
    pub fn is_success(&self) -> bool {
        match self {
            Self::Success(_) => true,
            Self::VersionMismatch(_) | Self::NotFound => false,
        }
    }

    /// Returns `true` only for the [`ConditionalResult::VersionMismatch`] variant.
    pub fn is_version_mismatch(&self) -> bool {
        match self {
            Self::VersionMismatch(_) => true,
            Self::Success(_) | Self::NotFound => false,
        }
    }

    /// Returns `true` only for the [`ConditionalResult::NotFound`] variant.
    pub fn is_not_found(&self) -> bool {
        match self {
            Self::NotFound => true,
            Self::Success(_) | Self::VersionMismatch(_) => false,
        }
    }

    /// Consumes the result, yielding the success value if there is one.
    pub fn into_success(self) -> Option<T> {
        if let Self::Success(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Consumes the result, yielding the conflict details if the version
    /// did not match.
    pub fn into_version_conflict(self) -> Option<VersionConflict> {
        if let Self::VersionMismatch(details) = self {
            Some(details)
        } else {
            None
        }
    }

    /// Transforms the success value with `f`, passing the other variants
    /// through unchanged. `f` is not called unless the result is a success.
    pub fn map<U, F>(self, f: F) -> ConditionalResult<U>
    where
        F: FnOnce(T) -> U,
    {
        match self {
            Self::NotFound => ConditionalResult::NotFound,
            Self::VersionMismatch(details) => ConditionalResult::VersionMismatch(details),
            Self::Success(inner) => ConditionalResult::Success(f(inner)),
        }
    }
}
/// Details of a version mismatch: which version was expected, which one was
/// actually found, and a human-readable explanation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct VersionConflict {
// Version the caller expected the resource to have.
pub expected: ScimVersion,
// Version that was found instead.
pub current: ScimVersion,
// Free-form explanation; appended to the `Display` output.
pub message: String,
}
impl VersionConflict {
    /// Builds a conflict from the two versions and a custom message.
    pub fn new(expected: ScimVersion, current: ScimVersion, message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            expected,
            current,
            message,
        }
    }

    /// Builds a conflict carrying the default "modified by another client"
    /// message.
    pub fn standard_message(expected: ScimVersion, current: ScimVersion) -> Self {
        const STANDARD_TEXT: &str =
            "Resource was modified by another client. Please refresh and try again.";
        Self::new(expected, current, STANDARD_TEXT)
    }
}
/// Formats the conflict as a single line naming both versions, followed by
/// the attached message.
impl fmt::Display for VersionConflict {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            expected,
            current,
            message,
        } = self;
        f.write_fmt(format_args!(
            "Version conflict: expected '{}', found '{}'. {}",
            expected, current, message
        ))
    }
}
// Lets a `VersionConflict` be used wherever a `std::error::Error` is expected.
// The body is empty: the existing `Debug` and `Display` impls satisfy the
// trait, and no underlying source error is exposed.
impl std::error::Error for VersionConflict {}
/// Errors produced while parsing a [`ScimVersion`] from text.
#[derive(Debug, Error, Clone, PartialEq)]
pub enum VersionError {
// Returned by `ScimVersion::parse_http_header`; carries the offending header
// value exactly as it was passed in.
#[error("Invalid ETag format: {0}")]
InvalidEtagFormat(String),
// Returned by `ScimVersion::parse_raw`; carries a description of the problem.
#[error("Failed to parse version: {0}")]
ParseError(String),
}
// Unit tests covering version construction, parsing, comparison, the
// conditional-result helpers, and serde round-trips.
#[cfg(test)]
mod tests {
use super::*;
// Hashing is deterministic: same content -> same version, different content
// -> different version.
#[test]
fn test_version_from_content() {
let content = br#"{"id":"123","userName":"john.doe"}"#;
let version = ScimVersion::from_content(content);
let version2 = ScimVersion::from_content(content);
assert_eq!(version, version2);
let different_content = br#"{"id":"123","userName":"jane.doe"}"#;
let different_version = ScimVersion::from_content(different_content);
assert_ne!(version, different_version);
}
// `from_hash` stores the string verbatim and the header form wraps it as a
// weak ETag.
#[test]
fn test_version_from_hash() {
let hash_string = "abc123def456";
let version = ScimVersion::from_hash(hash_string);
assert_eq!(version.as_str(), hash_string);
assert_eq!(version.to_http_header(), "W/\"abc123def456\"");
let version2 = ScimVersion::from_hash("different123");
assert_ne!(version, version2);
}
// Strong and weak ETags both parse; unquoted, empty-quoted and empty inputs
// are rejected.
#[test]
fn test_version_parse_http_header() {
let version = ScimVersion::parse_http_header("\"abc123\"").unwrap();
assert_eq!(version.as_str(), "abc123");
let weak_version = ScimVersion::parse_http_header("W/\"abc123\"").unwrap();
assert_eq!(weak_version.as_str(), "abc123");
assert!(ScimVersion::parse_http_header("abc123").is_err());
assert!(ScimVersion::parse_http_header("\"\"").is_err());
assert!(ScimVersion::parse_http_header("").is_err());
}
// Raw parsing trims whitespace, rejects blank input, and yields a version
// comparable with one parsed from a header.
#[test]
fn test_version_parse_raw() {
let version = ScimVersion::parse_raw("abc123def").unwrap();
assert_eq!(version.as_str(), "abc123def");
let trimmed_version = ScimVersion::parse_raw(" xyz789 ").unwrap();
assert_eq!(trimmed_version.as_str(), "xyz789");
assert!(ScimVersion::parse_raw("").is_err());
assert!(ScimVersion::parse_raw(" ").is_err());
let raw_version = ScimVersion::parse_raw("test123").unwrap();
let http_version = ScimVersion::parse_http_header("W/\"test123\"").unwrap();
assert!(raw_version.matches(&http_version));
}
// `matches` agrees with content equality.
#[test]
fn test_version_matches() {
let content = br#"{"id":"123","data":"test"}"#;
let v1 = ScimVersion::from_content(content);
let v2 = ScimVersion::from_content(content);
let v3 = ScimVersion::from_content(br#"{"id":"456","data":"test"}"#);
assert!(v1.matches(&v2));
assert!(!v1.matches(&v3));
}
// A version survives being rendered to a header and parsed back.
#[test]
fn test_version_round_trip() {
let content = br#"{"id":"test","version":"round-trip"}"#;
let original = ScimVersion::from_content(content);
let etag = original.to_http_header();
let parsed = ScimVersion::parse_http_header(&etag).unwrap();
assert_eq!(original, parsed);
}
// Each variant answers `true` to its own predicate.
#[test]
fn test_conditional_result() {
let success: ConditionalResult<i32> = ConditionalResult::Success(42);
assert!(success.is_success());
assert_eq!(success.into_success(), Some(42));
let conflict = ConditionalResult::<i32>::VersionMismatch(VersionConflict::new(
ScimVersion::from_hash("version1"),
ScimVersion::from_hash("version2"),
"test conflict",
));
assert!(conflict.is_version_mismatch());
let not_found: ConditionalResult<i32> = ConditionalResult::NotFound;
assert!(not_found.is_not_found());
}
// `map` transforms the success payload.
#[test]
fn test_conditional_result_map() {
let success: ConditionalResult<i32> = ConditionalResult::Success(42);
let mapped = success.map(|x| x.to_string());
assert_eq!(mapped.into_success(), Some("42".to_string()));
}
// `standard_message` keeps both versions and supplies a non-empty message.
#[test]
fn test_version_conflict() {
let conflict = VersionConflict::standard_message(
ScimVersion::from_hash("version1"),
ScimVersion::from_hash("version2"),
);
assert_eq!(conflict.expected.as_str(), "version1");
assert_eq!(conflict.current.as_str(), "version2");
assert!(!conflict.message.is_empty());
}
// The `Display` output mentions both versions and the custom message.
#[test]
fn test_version_conflict_display() {
let conflict = VersionConflict::new(
ScimVersion::from_hash("old-hash"),
ScimVersion::from_hash("new-hash"),
"Custom message",
);
let display = format!("{}", conflict);
assert!(display.contains("old-hash"));
assert!(display.contains("new-hash"));
assert!(display.contains("Custom message"));
}
// A version survives a JSON round-trip through serde.
#[test]
fn test_version_serialization() {
let content = br#"{"test":"serialization"}"#;
let version = ScimVersion::from_content(content);
let json = serde_json::to_string(&version).unwrap();
let deserialized: ScimVersion = serde_json::from_str(&json).unwrap();
assert_eq!(version, deserialized);
}
// A conflict survives a JSON round-trip through serde.
#[test]
fn test_version_conflict_serialization() {
let conflict = VersionConflict::new(
ScimVersion::from_hash("hash-v1"),
ScimVersion::from_hash("hash-v2"),
"Serialization test conflict",
);
let json = serde_json::to_string(&conflict).unwrap();
let deserialized: VersionConflict = serde_json::from_str(&json).unwrap();
assert_eq!(conflict, deserialized);
}
}