#![allow(dead_code)]
use crate::models::FiniteF64;
use azure_core::http::headers::{AsHeaders, HeaderName, HeaderValue};
use std::{borrow::Cow, hash::Hash};
/// Request header that carries the partition key values, serialized as a JSON array
/// (for example `["tenant",42]`).
pub(crate) const PARTITION_KEY: HeaderName =
HeaderName::from_static("x-ms-documentdb-partitionkey");
/// Request header sent with the value `True` instead of [`PARTITION_KEY`] when the
/// partition key has no components.
pub(crate) const QUERY_ENABLE_CROSS_PARTITION: HeaderName =
HeaderName::from_static("x-ms-documentdb-query-enablecrosspartition");
/// One component of a [`PartitionKey`].
///
/// A component is null, a string, a number, a boolean, or the "undefined" value; there is
/// also an internal infinity sentinel. Values are built through the `From` conversions
/// (string types, numeric primitives, `bool`, and `Option<T>` where `None` becomes null) or
/// with [`PartitionKeyValue::undefined`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct PartitionKeyValue(InnerPartitionKeyValue);
/// Internal representation of a [`PartitionKeyValue`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum InnerPartitionKeyValue {
/// JSON `null`; produced from `Option::None`.
Null,
/// A string. It is borrowed when built from a `&'static str` and owned when built from a
/// `String` or `&String`.
String(Cow<'static, str>),
/// A number stored as `FiniteF64`; produced from every numeric primitive.
Number(FiniteF64),
/// A boolean.
Bool(bool),
/// The "undefined" value; written as `{}` in the partition key header.
Undefined,
/// Internal sentinel. It has no header form: `as_headers` returns an error for it.
Infinity,
}
/// Byte limit applied to string values by the V1 code paths.
///
/// V1 hashing keeps at most this many UTF-8 bytes of a string. V1 binary encoding writes a
/// longer string as only its first `MAX_STRING_BYTES_TO_APPEND + 1` bytes, with no `0x00`
/// terminator. `truncated_for_v1_encoding` also shortens strings to this many bytes.
const MAX_STRING_BYTES_TO_APPEND: usize = 100;
/// One-byte type tags written in front of each value by the hashing and binary encoders.
mod component {
/// Tag for the "undefined" value.
pub const UNDEFINED: u8 = 0x00;
/// Tag for `null`.
pub const NULL: u8 = 0x01;
/// Tag for `false`.
pub const BOOL_FALSE: u8 = 0x02;
/// Tag for `true`.
pub const BOOL_TRUE: u8 = 0x03;
/// Tag preceding a number payload.
pub const NUMBER: u8 = 0x05;
/// Tag preceding a string payload.
pub const STRING: u8 = 0x08;
/// Tag for the internal infinity sentinel.
pub const INFINITY: u8 = 0xFF;
}
impl InnerPartitionKeyValue {
    /// Appends the hashing representation of this value to `writer`.
    ///
    /// Every value starts with its `component` tag byte.
    /// - Numbers are followed by the little-endian bytes of their `f64`.
    /// - Strings are followed by their UTF-8 bytes and then `string_suffix`. When `truncate`
    ///   is set, only the first `MAX_STRING_BYTES_TO_APPEND` bytes are written.
    fn write_for_hashing_core(&self, string_suffix: u8, writer: &mut Vec<u8>, truncate: bool) {
        match self {
            InnerPartitionKeyValue::Undefined => writer.push(component::UNDEFINED),
            InnerPartitionKeyValue::Null => writer.push(component::NULL),
            InnerPartitionKeyValue::Bool(flag) => writer.push(if *flag {
                component::BOOL_TRUE
            } else {
                component::BOOL_FALSE
            }),
            InnerPartitionKeyValue::Number(number) => {
                writer.push(component::NUMBER);
                writer.extend_from_slice(&number.value().to_le_bytes());
            }
            InnerPartitionKeyValue::String(text) => {
                let utf8 = text.as_bytes();
                // Byte-wise cut; `utf8` is a byte slice, so no char-boundary panic is possible.
                let limit = if truncate {
                    utf8.len().min(MAX_STRING_BYTES_TO_APPEND)
                } else {
                    utf8.len()
                };
                writer.push(component::STRING);
                writer.extend_from_slice(&utf8[..limit]);
                writer.push(string_suffix);
            }
            InnerPartitionKeyValue::Infinity => writer.push(component::INFINITY),
        }
    }

    /// Appends the V1 binary encoding of this value to `writer`.
    ///
    /// - Numbers are delegated to [`write_number_v1_binary`].
    /// - Strings are written as the `STRING` tag followed by every UTF-8 byte incremented by
    ///   one. A string of at most `MAX_STRING_BYTES_TO_APPEND` bytes is then terminated with
    ///   `0x00`. A longer string is cut to its first `MAX_STRING_BYTES_TO_APPEND + 1` bytes
    ///   and written without a terminator.
    fn write_for_binary_encoding_v1(&self, writer: &mut Vec<u8>) {
        match self {
            InnerPartitionKeyValue::Undefined => writer.push(component::UNDEFINED),
            InnerPartitionKeyValue::Null => writer.push(component::NULL),
            InnerPartitionKeyValue::Bool(flag) => writer.push(if *flag {
                component::BOOL_TRUE
            } else {
                component::BOOL_FALSE
            }),
            InnerPartitionKeyValue::Number(number) => {
                write_number_v1_binary(number.value(), writer)
            }
            InnerPartitionKeyValue::String(text) => {
                let utf8 = text.as_bytes();
                let terminated = utf8.len() <= MAX_STRING_BYTES_TO_APPEND;
                let kept = if terminated {
                    utf8.len()
                } else {
                    MAX_STRING_BYTES_TO_APPEND + 1
                };
                writer.push(component::STRING);
                // Shifting each byte up by one means no payload byte is 0x00.
                // Valid UTF-8 never contains 0xFF, so the wrapping add never actually wraps.
                writer.extend(utf8[..kept].iter().map(|byte| byte.wrapping_add(1)));
                if terminated {
                    writer.push(0x00);
                }
            }
            InnerPartitionKeyValue::Infinity => writer.push(component::INFINITY),
        }
    }
}
/// Maps an `f64` onto a `u64` whose unsigned ordering follows the numeric ordering of the input.
///
/// - Non-negative values (sign bit clear) get the sign bit set.
/// - Negative values (sign bit set) are replaced by the two's-complement negation of their
///   bit pattern.
///
/// Both `0.0` and `-0.0` map to `0x8000_0000_0000_0000`.
pub(crate) fn encode_double_as_uint64(value: f64) -> u64 {
    const SIGN_BIT: u64 = 1 << 63;
    let bits = value.to_bits();
    if bits & SIGN_BIT == 0 {
        bits | SIGN_BIT
    } else {
        bits.wrapping_neg()
    }
}
/// Appends the V1 binary encoding of a number to `writer`.
///
/// The layout is:
/// 1. the `NUMBER` tag;
/// 2. the most significant byte of [`encode_double_as_uint64`] verbatim;
/// 3. the remaining non-zero high bits in 7-bit chunks, each stored in the top seven bits
///    of a byte.
///
/// Every chunk byte has its low bit set to 1 except the last, whose low bit is 0. When
/// nothing remains after the first byte, a single `0x00` is written.
pub(crate) fn write_number_v1_binary(value: f64, writer: &mut Vec<u8>) {
    writer.push(component::NUMBER);
    let encoded = encode_double_as_uint64(value);
    writer.push((encoded >> 56) as u8);

    let mut rest = encoded << 8;
    // The chunk from the previous iteration; it is only flushed once another chunk follows,
    // because the final chunk needs its low bit cleared instead of set.
    let mut pending: Option<u8> = None;
    while rest != 0 {
        if let Some(chunk) = pending.take() {
            writer.push(chunk);
        }
        pending = Some(((rest >> 56) as u8) | 0x01);
        rest <<= 7;
    }
    writer.push(pending.unwrap_or(0) & 0xFE);
}
impl From<InnerPartitionKeyValue> for PartitionKeyValue {
fn from(value: InnerPartitionKeyValue) -> Self {
PartitionKeyValue(value)
}
}
impl PartitionKeyValue {
    /// Appends the V2 hashing representation of this value to `writer`.
    ///
    /// Strings are written in full and terminated with `0xFF`.
    pub(crate) fn write_for_hashing_v2(&self, writer: &mut Vec<u8>) {
        self.0.write_for_hashing_core(0xFFu8, writer, false)
    }

    /// Appends the V1 hashing representation of this value to `writer`.
    ///
    /// Strings are cut to at most `MAX_STRING_BYTES_TO_APPEND` bytes and terminated with `0x00`.
    pub(crate) fn write_for_hashing_v1(&self, writer: &mut Vec<u8>) {
        self.0.write_for_hashing_core(0x00u8, writer, true)
    }

    /// Appends the V1 binary encoding of this value to `writer`.
    pub(crate) fn write_for_binary_encoding_v1(&self, writer: &mut Vec<u8>) {
        self.0.write_for_binary_encoding_v1(writer)
    }

    /// Returns `true` if this is the internal infinity sentinel.
    pub(crate) fn is_infinity(&self) -> bool {
        matches!(self.0, InnerPartitionKeyValue::Infinity)
    }

    /// Returns a copy of this value shortened for V1 encoding.
    ///
    /// Strings longer than `MAX_STRING_BYTES_TO_APPEND` bytes are cut to at most that many
    /// bytes. The cut is moved back to the nearest UTF-8 character boundary, so a multi-byte
    /// character straddling the limit is dropped whole. The previous version panicked here
    /// instead. All other values are cloned unchanged.
    ///
    /// NOTE(review): confirm against the service's V1 truncation rule for strings whose
    /// limit byte falls inside a multi-byte character.
    pub(crate) fn truncated_for_v1_encoding(&self) -> PartitionKeyValue {
        match &self.0 {
            InnerPartitionKeyValue::String(s) if s.len() > MAX_STRING_BYTES_TO_APPEND => {
                // Index 0 is always a char boundary, so this loop terminates without underflow.
                let mut end = MAX_STRING_BYTES_TO_APPEND;
                while !s.is_char_boundary(end) {
                    end -= 1;
                }
                InnerPartitionKeyValue::String(Cow::Owned(s[..end].to_owned())).into()
            }
            _ => self.clone(),
        }
    }

    /// Creates the infinity sentinel (test-only helper).
    #[cfg(test)]
    pub(crate) fn infinity() -> Self {
        InnerPartitionKeyValue::Infinity.into()
    }

    /// Creates the "undefined" partition key value.
    ///
    /// In the partition key header it is written as `{}`.
    pub fn undefined() -> Self {
        InnerPartitionKeyValue::Undefined.into()
    }
}
impl From<&'static str> for PartitionKeyValue {
fn from(value: &'static str) -> Self {
InnerPartitionKeyValue::String(Cow::Borrowed(value)).into()
}
}
impl From<String> for PartitionKeyValue {
fn from(value: String) -> Self {
InnerPartitionKeyValue::String(Cow::Owned(value)).into()
}
}
impl From<&String> for PartitionKeyValue {
fn from(value: &String) -> Self {
InnerPartitionKeyValue::String(Cow::Owned(value.clone())).into()
}
}
impl From<Cow<'static, str>> for PartitionKeyValue {
fn from(value: Cow<'static, str>) -> Self {
InnerPartitionKeyValue::String(value).into()
}
}
macro_rules! impl_from_number {
($source_type:ty) => {
impl From<$source_type> for PartitionKeyValue {
fn from(value: $source_type) -> Self {
InnerPartitionKeyValue::Number(FiniteF64::new_strict(value as f64)).into()
}
}
};
}
impl_from_number!(i8);
impl_from_number!(i16);
impl_from_number!(i32);
impl_from_number!(i64);
impl_from_number!(isize);
impl_from_number!(u8);
impl_from_number!(u16);
impl_from_number!(u32);
impl_from_number!(u64);
impl_from_number!(usize);
impl_from_number!(f32);
impl_from_number!(f64);
impl From<bool> for PartitionKeyValue {
fn from(value: bool) -> Self {
InnerPartitionKeyValue::Bool(value).into()
}
}
impl<T: Into<PartitionKeyValue>> From<Option<T>> for PartitionKeyValue {
fn from(value: Option<T>) -> Self {
match value {
Some(v) => v.into(),
None => InnerPartitionKeyValue::Null.into(),
}
}
}
/// A partition key made of zero or more [`PartitionKeyValue`] components, one per level.
///
/// An empty key ([`PartitionKey::EMPTY`]) is sent as a cross-partition request instead of a
/// partition key header.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct PartitionKey(Vec<PartitionKeyValue>);

impl Default for PartitionKey {
    /// Returns [`PartitionKey::EMPTY`].
    fn default() -> Self {
        PartitionKey::EMPTY
    }
}

impl PartitionKey {
    /// The partition key with no components.
    pub const EMPTY: PartitionKey = PartitionKey(Vec::new());

    /// Creates a single-level partition key from one value.
    pub(crate) fn new(value: impl Into<PartitionKeyValue>) -> Self {
        let component: PartitionKeyValue = value.into();
        Self(vec![component])
    }

    /// Returns `true` if the key has no components.
    pub fn is_empty(&self) -> bool {
        self.values().is_empty()
    }

    /// Returns the number of components (levels) in the key.
    pub fn len(&self) -> usize {
        self.values().len()
    }

    /// Returns the components in level order.
    pub fn values(&self) -> &[PartitionKeyValue] {
        self.0.as_slice()
    }
}
impl AsHeaders for PartitionKey {
type Error = azure_core::Error;
type Iter = std::iter::Once<(HeaderName, HeaderValue)>;
fn as_headers(&self) -> Result<Self::Iter, Self::Error> {
if self.0.is_empty() {
return Ok(std::iter::once((
QUERY_ENABLE_CROSS_PARTITION,
HeaderValue::from_static("True"),
)));
}
let mut json = String::new();
let mut utf_buf = [0; 2]; json.push('[');
for key in &self.0 {
match &key.0 {
InnerPartitionKeyValue::Null => json.push_str("null"),
InnerPartitionKeyValue::String(ref string_key) => {
json.push('"');
for char in string_key.chars() {
match char {
'\x08' => json.push_str(r#"\b"#),
'\x0c' => json.push_str(r#"\f"#),
'\n' => json.push_str(r#"\n"#),
'\r' => json.push_str(r#"\r"#),
'\t' => json.push_str(r#"\t"#),
'"' => json.push_str(r#"\""#),
'\\' => json.push_str(r#"\\"#),
c if c.is_ascii() && !c.is_control() => json.push(c),
c if c.is_ascii() => {
json.push_str(&format!("\\u{:04x}", c as u32));
}
c => {
let encoded = c.encode_utf16(&mut utf_buf);
for code_unit in encoded {
json.push_str(&format!(r#"\u{:04x}"#, code_unit));
}
}
}
}
json.push('"');
}
InnerPartitionKeyValue::Number(num) => {
let val = num.value();
if val.fract() == 0.0 && val.abs() < (i64::MAX as f64) {
json.push_str(&format!("{}", val as i64));
} else {
json.push_str(&format!("{}", val));
}
}
InnerPartitionKeyValue::Bool(b) => {
json.push_str(if *b { "true" } else { "false" });
}
InnerPartitionKeyValue::Infinity => {
return Err(azure_core::Error::new(
azure_core::error::ErrorKind::Other,
"Infinity is not a valid partition key value for serialization",
));
}
InnerPartitionKeyValue::Undefined => {
json.push_str("{}");
}
}
json.push(',');
}
json.pop();
json.push(']');
Ok(std::iter::once((
PARTITION_KEY,
HeaderValue::from_cow(json),
)))
}
}
impl<T: Into<PartitionKeyValue>> From<T> for PartitionKey {
fn from(value: T) -> Self {
Self::new(value)
}
}
impl From<()> for PartitionKey {
fn from(_: ()) -> Self {
PartitionKey::EMPTY
}
}
impl From<Vec<PartitionKeyValue>> for PartitionKey {
fn from(values: Vec<PartitionKeyValue>) -> Self {
assert!(
values.len() <= 3,
"Partition keys can have at most 3 levels, got {}",
values.len()
);
PartitionKey(values)
}
}
impl<T1, T2> From<(T1, T2)> for PartitionKey
where
T1: Into<PartitionKeyValue>,
T2: Into<PartitionKeyValue>,
{
fn from((v1, v2): (T1, T2)) -> Self {
Self(vec![v1.into(), v2.into()])
}
}
impl<T1, T2, T3> From<(T1, T2, T3)> for PartitionKey
where
T1: Into<PartitionKeyValue>,
T2: Into<PartitionKeyValue>,
T3: Into<PartitionKeyValue>,
{
fn from((v1, v2, v3): (T1, T2, T3)) -> Self {
Self(vec![v1.into(), v2.into(), v3.into()])
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn single_partition_key() {
        let key = PartitionKey::from("test");
        assert!(!key.is_empty());
        assert_eq!(key.len(), 1);
    }

    #[test]
    fn numeric_partition_key() {
        // Integer literal, explicit i64, and f64 all become single-level keys.
        let keys = [
            PartitionKey::from(42),
            PartitionKey::from(42i64),
            PartitionKey::from(1.5f64),
        ];
        for key in &keys {
            assert_eq!(key.len(), 1);
        }
    }

    #[test]
    fn hierarchical_partition_key() {
        assert_eq!(PartitionKey::from(("tenant", "user", 42)).len(), 3);
    }

    #[test]
    fn empty_partition_key() {
        let key = PartitionKey::EMPTY;
        assert_eq!(key.len(), 0);
        assert!(key.is_empty());
    }

    #[test]
    fn default_is_empty() {
        assert_eq!(PartitionKey::default(), PartitionKey::EMPTY);
    }

    #[test]
    fn unit_converts_to_empty() {
        let key = PartitionKey::from(());
        assert_eq!(key, PartitionKey::EMPTY);
        assert!(key.is_empty());
        assert_eq!(key.len(), 0);
    }

    #[test]
    fn null_partition_key_value() {
        // `None` still counts as one (null) component.
        let key = PartitionKey::from(None::<String>);
        assert_eq!(key.len(), 1);
    }

    #[test]
    #[should_panic(expected = "at most 3 levels")]
    fn too_many_levels() {
        let values: Vec<PartitionKeyValue> = ["a", "b", "c", "d"]
            .iter()
            .map(|&level| PartitionKeyValue::from(level))
            .collect();
        let _pk = PartitionKey::from(values);
    }
}