use std::borrow::Cow;
/// A resource name (for example `"mydb"`).
///
/// Stored as a `Cow<'static, str>` so `'static` string literals are held
/// without allocating, while owned `String`s are moved in as-is.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
// NOTE(review): `#[non_exhaustive]` only restricts *other* crates; on a
// `pub(crate)` type it has no effect. Kept to avoid churn.
#[non_exhaustive]
pub(crate) struct ResourceName(Cow<'static, str>);

impl ResourceName {
    /// Creates a name from anything convertible into `Cow<'static, str>`
    /// (`&'static str`, `String`, or a `Cow` directly).
    pub fn new(name: impl Into<Cow<'static, str>>) -> Self {
        Self(name.into())
    }

    /// Returns the name as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl From<&'static str> for ResourceName {
    fn from(s: &'static str) -> Self {
        Self::new(s)
    }
}

impl From<String> for ResourceName {
    fn from(s: String) -> Self {
        Self::new(s)
    }
}

impl AsRef<str> for ResourceName {
    fn as_ref(&self) -> &str {
        self.as_str()
    }
}

/// Allows `HashMap<ResourceName, _>` / `HashSet<ResourceName>` to be queried
/// with a plain `&str`, matching `ResourceId`.
///
/// Sound because the derived `Hash`/`Eq` delegate to the inner `Cow<str>`,
/// which hashes and compares exactly like the borrowed `str`.
impl std::borrow::Borrow<str> for ResourceName {
    fn borrow(&self) -> &str {
        self.as_str()
    }
}

impl std::fmt::Display for ResourceName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` honors caller flags such as `{:>10}` or `{:.3}`;
        // `write!(f, "{}", ..)` would discard them.
        f.pad(self.as_str())
    }
}
/// A resource id (RID) string (for example `"abc123"`).
///
/// Stored as a `Cow<'static, str>` so `'static` string literals are held
/// without allocating, while owned `String`s are moved in as-is.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
// NOTE(review): `#[non_exhaustive]` only restricts *other* crates; on a
// `pub(crate)` type it has no effect. Kept to avoid churn.
#[non_exhaustive]
pub(crate) struct ResourceId(Cow<'static, str>);

impl ResourceId {
    /// Creates a RID from anything convertible into `Cow<'static, str>`
    /// (`&'static str`, `String`, or a `Cow` directly).
    pub fn new(rid: impl Into<Cow<'static, str>>) -> Self {
        Self(rid.into())
    }

    /// Returns the RID as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl From<&'static str> for ResourceId {
    fn from(s: &'static str) -> Self {
        Self::new(s)
    }
}

impl From<String> for ResourceId {
    fn from(s: String) -> Self {
        Self::new(s)
    }
}

impl AsRef<str> for ResourceId {
    fn as_ref(&self) -> &str {
        self.as_str()
    }
}

/// Allows RID-keyed maps/sets to be queried with a plain `&str`.
///
/// Sound because the derived `Hash`/`Eq` delegate to the inner `Cow<str>`,
/// which hashes and compares exactly like the borrowed `str`.
impl std::borrow::Borrow<str> for ResourceId {
    fn borrow(&self) -> &str {
        self.as_str()
    }
}

impl std::fmt::Display for ResourceId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` honors caller flags such as `{:>10}` or `{:.3}`;
        // `write!(f, "{}", ..)` would discard them.
        f.pad(self.as_str())
    }
}
/// Addresses a resource either by its name or by its resource id (RID).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) enum ResourceIdentifier {
    /// The resource is addressed by name.
    ByName(ResourceName),
    /// The resource is addressed by RID.
    ByRid(ResourceId),
}

impl ResourceIdentifier {
    /// Builds a name-based identifier from anything convertible into a
    /// [`ResourceName`].
    pub(crate) fn by_name(name: impl Into<ResourceName>) -> Self {
        let resource_name: ResourceName = name.into();
        Self::ByName(resource_name)
    }

    /// Builds a RID-based identifier from anything convertible into a
    /// [`ResourceId`].
    pub(crate) fn by_rid(rid: impl Into<ResourceId>) -> Self {
        let resource_id: ResourceId = rid.into();
        Self::ByRid(resource_id)
    }

    /// Returns the name when this identifier is name-based, otherwise `None`.
    pub(crate) fn name(&self) -> Option<&str> {
        match self {
            Self::ByRid(_) => None,
            Self::ByName(resource_name) => Some(resource_name.as_str()),
        }
    }

    /// Returns the RID when this identifier is RID-based, otherwise `None`.
    pub(crate) fn rid(&self) -> Option<&str> {
        match self {
            Self::ByRid(resource_id) => Some(resource_id.as_str()),
            Self::ByName(_) => None,
        }
    }

    /// `true` exactly when this is the [`ResourceIdentifier::ByName`] variant.
    pub(crate) fn is_by_name(&self) -> bool {
        self.name().is_some()
    }

    /// `true` exactly when this is the [`ResourceIdentifier::ByRid`] variant.
    pub(crate) fn is_by_rid(&self) -> bool {
        self.rid().is_some()
    }
}
// Test-only RID parsing helpers plus unit tests for the identifier newtypes.
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::STANDARD, Engine as _};

    /// Test model of a RID split into its database, container, and document
    /// prefixes, each re-encoded as a standalone RID string.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct ParsedResourceId {
        database_rid: Option<ResourceId>,
        container_rid: Option<ResourceId>,
        document_rid: Option<ResourceId>,
    }

    impl ParsedResourceId {
        /// A result with no components populated.
        fn empty() -> Self {
            Self {
                database_rid: None,
                container_rid: None,
                document_rid: None,
            }
        }

        /// Only the database component is populated.
        fn database(database_rid: ResourceId) -> Self {
            Self {
                database_rid: Some(database_rid),
                container_rid: None,
                document_rid: None,
            }
        }

        /// The database and container components are populated.
        fn container(database_rid: ResourceId, container_rid: ResourceId) -> Self {
            Self {
                database_rid: Some(database_rid),
                container_rid: Some(container_rid),
                document_rid: None,
            }
        }

        /// All three components are populated.
        fn document(
            database_rid: ResourceId,
            container_rid: ResourceId,
            document_rid: ResourceId,
        ) -> Self {
            Self {
                database_rid: Some(database_rid),
                container_rid: Some(container_rid),
                document_rid: Some(document_rid),
            }
        }

        fn database_rid(&self) -> Option<&ResourceId> {
            self.database_rid.as_ref()
        }

        fn container_rid(&self) -> Option<&ResourceId> {
            self.container_rid.as_ref()
        }

        fn document_rid(&self) -> Option<&ResourceId> {
            self.document_rid.as_ref()
        }
    }

    /// Failure modes of the test RID helpers below.
    #[derive(Clone, Debug, PartialEq, Eq)]
    enum RidParseError {
        /// The input string was empty.
        Empty,
        /// The string length or the decoded byte length is not acceptable.
        InvalidLength,
        /// The string is not valid (dash-substituted) Base64.
        InvalidBase64,
    }

    impl std::fmt::Display for RidParseError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Self::Empty => write!(f, "RID string is empty"),
                Self::InvalidLength => write!(f, "RID has invalid byte length"),
                Self::InvalidBase64 => write!(f, "RID contains invalid Base64"),
            }
        }
    }

    impl std::error::Error for RidParseError {}

    /// Decodes a RID string into raw bytes.
    ///
    /// RIDs are standard Base64 with every `/` replaced by `-` (see
    /// [`encode_rid`]), so that substitution is reversed before decoding.
    ///
    /// # Errors
    ///
    /// - [`RidParseError::Empty`] for `""`.
    /// - [`RidParseError::InvalidLength`] when the string length is not a multiple of 4.
    /// - [`RidParseError::InvalidBase64`] when the Base64 decoder rejects the input.
    fn decode_rid(rid: &str) -> Result<Vec<u8>, RidParseError> {
        if rid.is_empty() {
            return Err(RidParseError::Empty);
        }
        if !rid.len().is_multiple_of(4) {
            return Err(RidParseError::InvalidLength);
        }
        let b64 = rid.replace('-', "/");
        STANDARD
            .decode(&b64)
            .map_err(|_| RidParseError::InvalidBase64)
    }

    /// Encodes bytes as a RID: standard Base64 with every `/` replaced by `-`.
    fn encode_rid(bytes: &[u8]) -> String {
        STANDARD.encode(bytes).replace('/', "-")
    }

    /// Returns the database RID (the first 4 bytes) embedded in a container RID.
    ///
    /// # Errors
    ///
    /// Propagates [`decode_rid`] errors. Returns [`RidParseError::InvalidLength`]
    /// when the decoded RID is shorter than 8 bytes or not a multiple of 4 bytes.
    fn extract_database_rid_from_container_rid(
        container_rid: &str,
    ) -> Result<ResourceId, RidParseError> {
        let bytes = decode_rid(container_rid)?;
        if bytes.len() < 8 || !bytes.len().is_multiple_of(4) {
            return Err(RidParseError::InvalidLength);
        }
        let db_bytes = &bytes[0..4];
        Ok(ResourceId::new(encode_rid(db_bytes)))
    }

    /// Returns the container RID (the first 8 bytes) embedded in a document RID.
    ///
    /// # Errors
    ///
    /// Propagates [`decode_rid`] errors. Returns [`RidParseError::InvalidLength`]
    /// when the decoded RID is shorter than 16 bytes or not a multiple of 4 bytes.
    fn extract_container_rid_from_document_rid(
        document_rid: &str,
    ) -> Result<ResourceId, RidParseError> {
        let bytes = decode_rid(document_rid)?;
        if bytes.len() < 16 || !bytes.len().is_multiple_of(4) {
            return Err(RidParseError::InvalidLength);
        }
        let container_bytes = &bytes[0..8];
        Ok(ResourceId::new(encode_rid(container_bytes)))
    }

    /// Splits a RID into its database (bytes `0..4`), container (`0..8`), and
    /// document (`0..16`) prefixes. Each component is populated only when the
    /// decoded RID is long enough to contain it.
    ///
    /// # Errors
    ///
    /// Propagates [`decode_rid`] errors. Returns [`RidParseError::InvalidLength`]
    /// when the decoded length is neither 3 nor a multiple of 4.
    fn parse_rid(rid: &str) -> Result<ParsedResourceId, RidParseError> {
        let bytes = decode_rid(rid)?;
        let len = bytes.len();
        // A 3-byte RID yields no components.
        // NOTE(review): confirm which resource kind these 3-byte RIDs model.
        if len == 3 {
            return Ok(ParsedResourceId::empty());
        }
        if !len.is_multiple_of(4) {
            return Err(RidParseError::InvalidLength);
        }
        let mut parsed = ParsedResourceId::empty();
        if len >= 4 {
            let db_rid = encode_rid(&bytes[0..4]);
            parsed.database_rid = Some(ResourceId::new(db_rid));
        }
        if len >= 8 {
            let container_rid = encode_rid(&bytes[0..8]);
            parsed.container_rid = Some(ResourceId::new(container_rid));
        }
        if len >= 16 {
            let document_rid = encode_rid(&bytes[0..16]);
            parsed.document_rid = Some(ResourceId::new(document_rid));
        }
        Ok(parsed)
    }

    #[test]
    fn resource_name_from_str() {
        let name = ResourceName::from("mydb");
        assert_eq!(name.as_str(), "mydb");
    }

    #[test]
    fn resource_name_from_string() {
        let name = ResourceName::from(String::from("mydb"));
        assert_eq!(name.as_str(), "mydb");
    }

    #[test]
    fn resource_rid_from_str() {
        let rid = ResourceId::from("abc123");
        assert_eq!(rid.as_str(), "abc123");
    }

    #[test]
    fn database_id_by_name() {
        let id = ResourceIdentifier::ByName(ResourceName::from("testdb"));
        assert_eq!(id.name(), Some("testdb"));
        assert_eq!(id.rid(), None);
    }

    #[test]
    fn database_id_by_rid() {
        let id = ResourceIdentifier::ByRid(ResourceId::from("abc123"));
        assert_eq!(id.name(), None);
        assert_eq!(id.rid(), Some("abc123"));
    }

    #[test]
    fn parsed_resource_id_database() {
        let parsed = ParsedResourceId::database(ResourceId::from("db123"));
        assert_eq!(parsed.database_rid().map(|r| r.as_str()), Some("db123"));
        assert!(parsed.container_rid().is_none());
        assert!(parsed.document_rid().is_none());
    }

    #[test]
    fn parsed_resource_id_container() {
        let parsed =
            ParsedResourceId::container(ResourceId::from("db123"), ResourceId::from("coll456"));
        assert_eq!(parsed.database_rid().map(|r| r.as_str()), Some("db123"));
        assert_eq!(parsed.container_rid().map(|r| r.as_str()), Some("coll456"));
        assert!(parsed.document_rid().is_none());
    }

    #[test]
    fn parsed_resource_id_document() {
        let parsed = ParsedResourceId::document(
            ResourceId::from("db123"),
            ResourceId::from("coll456"),
            ResourceId::from("doc789"),
        );
        assert_eq!(parsed.database_rid().map(|r| r.as_str()), Some("db123"));
        assert_eq!(parsed.container_rid().map(|r| r.as_str()), Some("coll456"));
        assert_eq!(parsed.document_rid().map(|r| r.as_str()), Some("doc789"));
    }

    #[test]
    fn decode_and_encode_rid_roundtrip() {
        let db_bytes: [u8; 4] = [0x01, 0x02, 0x03, 0x04];
        let encoded = encode_rid(&db_bytes);
        let decoded = decode_rid(&encoded).unwrap();
        assert_eq!(decoded, db_bytes);
    }

    #[test]
    fn decode_rid_replaces_dash_with_slash() {
        let bytes: [u8; 4] = [0xFF, 0xFF, 0xFF, 0xFF];
        let b64 = STANDARD.encode(bytes);
        let cosmos_rid = b64.replace('/', "-");
        let decoded = decode_rid(&cosmos_rid).unwrap();
        assert_eq!(decoded, bytes);
    }

    #[test]
    fn encode_rid_replaces_slash_with_dash() {
        let bytes: [u8; 4] = [0xFF, 0xFF, 0xFF, 0xFF];
        let encoded = encode_rid(&bytes);
        assert!(!encoded.contains('/'), "encoded RID should not contain '/'");
    }

    #[test]
    fn decode_rid_empty_returns_error() {
        assert_eq!(decode_rid(""), Err(RidParseError::Empty));
    }

    #[test]
    fn decode_rid_invalid_length_returns_error() {
        assert_eq!(decode_rid("abc"), Err(RidParseError::InvalidLength));
    }

    #[test]
    fn decode_rid_invalid_base64_returns_error() {
        // Correct length (4), but '!' is outside the Base64 alphabet.
        assert_eq!(decode_rid("ab!d"), Err(RidParseError::InvalidBase64));
    }

    #[test]
    fn extract_database_rid_from_container_rid_valid() {
        let mut container_bytes = [0u8; 8];
        container_bytes[0..4].copy_from_slice(&[0x0A, 0x0B, 0x0C, 0x0D]);
        container_bytes[4..8].copy_from_slice(&[0x80, 0x01, 0x02, 0x03]);
        let container_rid = encode_rid(&container_bytes);
        let db_rid = extract_database_rid_from_container_rid(&container_rid).unwrap();
        let expected_db_rid = encode_rid(&[0x0A, 0x0B, 0x0C, 0x0D]);
        assert_eq!(db_rid.as_str(), expected_db_rid);
    }

    #[test]
    fn extract_database_rid_from_short_rid_returns_error() {
        let db_bytes: [u8; 4] = [0x01, 0x02, 0x03, 0x04];
        let db_rid = encode_rid(&db_bytes);
        assert_eq!(
            extract_database_rid_from_container_rid(&db_rid),
            Err(RidParseError::InvalidLength)
        );
    }

    #[test]
    fn extract_container_rid_from_document_rid_valid() {
        let mut doc_bytes = [0u8; 16];
        doc_bytes[0..4].copy_from_slice(&[0x0A, 0x0B, 0x0C, 0x0D]);
        doc_bytes[4..8].copy_from_slice(&[0x80, 0x01, 0x02, 0x03]);
        doc_bytes[8..16].copy_from_slice(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x00]);
        let doc_rid = encode_rid(&doc_bytes);
        let container_rid = extract_container_rid_from_document_rid(&doc_rid).unwrap();
        let expected = encode_rid(&doc_bytes[0..8]);
        assert_eq!(container_rid.as_str(), expected);
    }

    #[test]
    fn extract_container_rid_from_short_rid_returns_error() {
        let container_bytes: [u8; 8] = [0x0A, 0x0B, 0x0C, 0x0D, 0x80, 0x01, 0x02, 0x03];
        let container_rid = encode_rid(&container_bytes);
        assert_eq!(
            extract_container_rid_from_document_rid(&container_rid),
            Err(RidParseError::InvalidLength)
        );
    }

    #[test]
    fn parse_rid_database() {
        let db_bytes: [u8; 4] = [0x01, 0x02, 0x03, 0x04];
        let rid_str = encode_rid(&db_bytes);
        let parsed = parse_rid(&rid_str).unwrap();
        assert!(parsed.database_rid().is_some());
        assert!(parsed.container_rid().is_none());
        assert!(parsed.document_rid().is_none());
        assert_eq!(
            parsed,
            ParsedResourceId::database(ResourceId::new(encode_rid(&db_bytes)))
        );
    }

    #[test]
    fn parse_rid_container() {
        let mut bytes = [0u8; 8];
        bytes[0..4].copy_from_slice(&[0x0A, 0x0B, 0x0C, 0x0D]);
        bytes[4..8].copy_from_slice(&[0x80, 0x01, 0x02, 0x03]);
        let rid_str = encode_rid(&bytes);
        let parsed = parse_rid(&rid_str).unwrap();
        assert!(parsed.database_rid().is_some());
        assert!(parsed.container_rid().is_some());
        assert!(parsed.document_rid().is_none());
        let container_rid = parsed.container_rid().unwrap().as_str();
        let db_rid = extract_database_rid_from_container_rid(container_rid).unwrap();
        assert_eq!(db_rid.as_str(), parsed.database_rid().unwrap().as_str());
        assert_eq!(
            parsed,
            ParsedResourceId::container(
                ResourceId::new(encode_rid(&bytes[0..4])),
                ResourceId::new(encode_rid(&bytes[0..8])),
            )
        );
    }

    #[test]
    fn parse_rid_document() {
        let mut bytes = [0u8; 16];
        bytes[0..4].copy_from_slice(&[0x0A, 0x0B, 0x0C, 0x0D]);
        bytes[4..8].copy_from_slice(&[0x80, 0x01, 0x02, 0x03]);
        bytes[8..16].copy_from_slice(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x00]);
        let rid_str = encode_rid(&bytes);
        let parsed = parse_rid(&rid_str).unwrap();
        assert!(parsed.database_rid().is_some());
        assert!(parsed.container_rid().is_some());
        assert!(parsed.document_rid().is_some());
        assert_eq!(
            parsed,
            ParsedResourceId::document(
                ResourceId::new(encode_rid(&bytes[0..4])),
                ResourceId::new(encode_rid(&bytes[0..8])),
                ResourceId::new(encode_rid(&bytes[0..16])),
            )
        );
    }

    #[test]
    fn parse_rid_three_bytes_is_empty() {
        let rid_str = encode_rid(&[0x01, 0x02, 0x03]);
        assert_eq!(parse_rid(&rid_str), Ok(ParsedResourceId::empty()));
    }

    #[test]
    fn parse_rid_rejects_non_multiple_of_four_bytes() {
        // 5 bytes encode to 8 Base64 chars, so only the byte-length check can reject it.
        let rid_str = encode_rid(&[0x01, 0x02, 0x03, 0x04, 0x05]);
        assert_eq!(parse_rid(&rid_str), Err(RidParseError::InvalidLength));
    }
}