use std::borrow::Cow;
use azure_core::fmt::SafeDebug;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::models::{IndexingPolicy, PartitionKeyDefinition, SystemProperties};
/// Time-to-live setting used by [`ContainerProperties::default_ttl`] and
/// [`ContainerProperties::analytical_storage_ttl`].
///
/// The hand-written serde impls in this file map the variants to a nullable integer:
/// `Forever` is `null`, `NoDefault` is `-1`, and `Seconds(n)` is `n`.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a wildcard arm when
/// matching on it.
#[derive(Clone, Default, SafeDebug, PartialEq, Eq)]
#[safe(true)]
#[non_exhaustive]
pub enum TimeToLive {
/// No TTL value on the wire (`null`, or the field left out). This is the `Default` variant.
///
/// NOTE(review): presumably means the service applies no expiry at all — confirm against the
/// Cosmos DB TTL documentation.
#[default]
Forever,
/// Represented on the wire as `-1`.
///
/// NOTE(review): presumably "TTL enabled but no container-level default expiry" — confirm
/// against the Cosmos DB TTL documentation.
NoDefault,
/// A TTL given as a number of seconds (unit taken from the variant name).
///
/// Any `u32` can be stored and serialized, but this type's `Deserialize` impl reads an `i32`
/// and only accepts strictly positive values, so `0` and values above `i32::MAX` do not
/// round-trip through deserialization.
Seconds(u32),
}
impl TimeToLive {
pub fn is_forever(&self) -> bool {
matches!(self, TimeToLive::Forever)
}
}
impl From<u32> for TimeToLive {
fn from(n: u32) -> Self {
TimeToLive::Seconds(n)
}
}
impl Serialize for TimeToLive {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
TimeToLive::Forever => serializer.serialize_none(),
TimeToLive::NoDefault => serializer.serialize_i32(-1),
TimeToLive::Seconds(n) => serializer.serialize_u32(*n),
}
}
}
impl<'de> Deserialize<'de> for TimeToLive {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
match Option::<i32>::deserialize(deserializer)? {
None => Ok(TimeToLive::Forever),
Some(-1) => Ok(TimeToLive::NoDefault),
Some(n) if n > 0 => Ok(TimeToLive::Seconds(n as u32)),
Some(n) => Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Signed(n as i64),
&"a nonzero positive integer or -1",
)),
}
}
}
/// Serializable description of a container: its id, partition key, optional policies, TTL
/// settings and system properties.
///
/// Field names are converted to camelCase on the wire (for example `partition_key` becomes
/// `partitionKey` and `default_ttl` becomes `defaultTtl`). Fields are serialized in declaration
/// order, and the tests in this file assert on the exact JSON string, so reordering fields
/// changes observable output.
///
/// The struct is `#[non_exhaustive]`; use [`ContainerProperties::new`] and the `with_*` methods
/// to build one.
#[derive(Clone, SafeDebug, Deserialize, Serialize, PartialEq, Eq)]
#[safe(true)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ContainerProperties {
/// Identifier of the container. Always serialized.
pub id: Cow<'static, str>,
/// Partition key definition for the container. Always serialized.
pub partition_key: PartitionKeyDefinition,
/// Optional indexing policy; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub indexing_policy: Option<IndexingPolicy>,
/// Optional unique key policy; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub unique_key_policy: Option<UniqueKeyPolicy>,
/// Optional conflict resolution policy; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub conflict_resolution_policy: Option<ConflictResolutionPolicy>,
/// Optional vector embedding policy; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub vector_embedding_policy: Option<VectorEmbeddingPolicy>,
/// Default TTL setting. A missing field deserializes to `TimeToLive::Forever`, and
/// `Forever` is left out when serializing.
#[serde(default)]
#[serde(skip_serializing_if = "TimeToLive::is_forever")]
pub default_ttl: TimeToLive,
/// Analytical storage TTL setting. A missing field deserializes to `TimeToLive::Forever`,
/// and `Forever` is left out when serializing.
#[serde(default)]
#[serde(skip_serializing_if = "TimeToLive::is_forever")]
pub analytical_storage_ttl: TimeToLive,
/// System properties. Flattened, so their fields are read from and written to the same JSON
/// object as the fields above rather than a nested object.
#[serde(flatten)]
pub system_properties: SystemProperties,
}
impl ContainerProperties {
    /// Creates properties for a container with the given `id` and `partition_key`.
    ///
    /// Every optional policy starts as `None`, both TTL settings start as
    /// [`TimeToLive::Forever`], and the system properties take their `Default` value.
    pub fn new(id: impl Into<Cow<'static, str>>, partition_key: PartitionKeyDefinition) -> Self {
        let id = id.into();
        Self {
            id,
            partition_key,
            indexing_policy: None,
            unique_key_policy: None,
            conflict_resolution_policy: None,
            vector_embedding_policy: None,
            default_ttl: TimeToLive::Forever,
            analytical_storage_ttl: TimeToLive::Forever,
            system_properties: SystemProperties::default(),
        }
    }

    /// Returns `self` with the indexing policy set to `indexing_policy`.
    pub fn with_indexing_policy(self, indexing_policy: IndexingPolicy) -> Self {
        Self {
            indexing_policy: Some(indexing_policy),
            ..self
        }
    }

    /// Returns `self` with the unique key policy set to `unique_key_policy`.
    pub fn with_unique_key_policy(self, unique_key_policy: UniqueKeyPolicy) -> Self {
        Self {
            unique_key_policy: Some(unique_key_policy),
            ..self
        }
    }

    /// Returns `self` with the conflict resolution policy set to `conflict_resolution_policy`.
    pub fn with_conflict_resolution_policy(
        self,
        conflict_resolution_policy: ConflictResolutionPolicy,
    ) -> Self {
        Self {
            conflict_resolution_policy: Some(conflict_resolution_policy),
            ..self
        }
    }

    /// Returns `self` with the vector embedding policy set to `vector_embedding_policy`.
    pub fn with_vector_embedding_policy(
        self,
        vector_embedding_policy: VectorEmbeddingPolicy,
    ) -> Self {
        Self {
            vector_embedding_policy: Some(vector_embedding_policy),
            ..self
        }
    }

    /// Returns `self` with the default TTL replaced.
    ///
    /// Accepts anything convertible into [`TimeToLive`]; a `u32` becomes
    /// [`TimeToLive::Seconds`].
    pub fn with_default_ttl(self, default_ttl: impl Into<TimeToLive>) -> Self {
        Self {
            default_ttl: default_ttl.into(),
            ..self
        }
    }

    /// Returns `self` with the analytical storage TTL replaced.
    ///
    /// Accepts anything convertible into [`TimeToLive`]; a `u32` becomes
    /// [`TimeToLive::Seconds`].
    pub fn with_analytical_storage_ttl(
        self,
        analytical_storage_ttl: impl Into<TimeToLive>,
    ) -> Self {
        Self {
            analytical_storage_ttl: analytical_storage_ttl.into(),
            ..self
        }
    }
}
/// Vector embedding policy of a container: a list of [`VectorEmbedding`] entries.
///
/// NOTE(review): the type is `#[non_exhaustive]` and no constructor is visible in this file, so
/// code outside the crate presumably can only obtain a value through deserialization — confirm.
#[derive(Clone, SafeDebug, Deserialize, Serialize, PartialEq, Eq)]
#[safe(true)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct VectorEmbeddingPolicy {
/// The embeddings in this policy. Serialized under the key `vectorEmbeddings`.
#[serde(rename = "vectorEmbeddings")]
pub embeddings: Vec<VectorEmbedding>,
}
/// A single entry of a [`VectorEmbeddingPolicy`].
///
/// Field names are converted to camelCase on the wire (`dataType`, `distanceFunction`).
#[derive(Clone, SafeDebug, Deserialize, Serialize, PartialEq, Eq)]
#[safe(true)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct VectorEmbedding {
/// Path associated with this embedding.
/// NOTE(review): presumably the document path of the property holding the vector — confirm.
pub path: String,
/// Element data type of the vector.
pub data_type: VectorDataType,
/// Dimension count for the vector (meaning taken from the field name).
pub dimensions: u32,
/// Distance function associated with this embedding.
pub distance_function: VectorDistanceFunction,
}
/// Data type named by [`VectorEmbedding::data_type`].
///
/// Variants are serialized in camelCase: `float16`, `float32`, `uint8`, `int8`.
#[derive(Clone, SafeDebug, Deserialize, Serialize, PartialEq, Eq)]
#[safe(true)]
#[serde(rename_all = "camelCase")]
pub enum VectorDataType {
/// Serialized as `float16`.
Float16,
/// Serialized as `float32`.
Float32,
/// Serialized as `uint8`.
Uint8,
/// Serialized as `int8`.
Int8,
}
/// Distance function named by [`VectorEmbedding::distance_function`].
///
/// Variants are serialized in camelCase, except `DotProduct`, which has an explicit all-lowercase
/// name.
#[derive(Clone, SafeDebug, Deserialize, Serialize, PartialEq, Eq)]
#[safe(true)]
#[serde(rename_all = "camelCase")]
pub enum VectorDistanceFunction {
/// Serialized as `euclidean`.
Euclidean,
/// Serialized as `cosine`.
Cosine,
/// Serialized as `dotproduct` (explicit rename; the camelCase rule would give `dotProduct`).
#[serde(rename = "dotproduct")]
DotProduct,
}
/// Unique key policy of a container: a list of [`UniqueKey`] entries.
///
/// NOTE(review): the type is `#[non_exhaustive]` and no constructor is visible in this file, so
/// code outside the crate presumably can only obtain a value through deserialization — confirm.
#[derive(Clone, SafeDebug, Deserialize, Serialize, PartialEq, Eq)]
#[safe(true)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct UniqueKeyPolicy {
/// The unique keys in this policy. Serialized under the key `uniqueKeys`.
pub unique_keys: Vec<UniqueKey>,
}
/// A single entry of a [`UniqueKeyPolicy`].
#[derive(Clone, SafeDebug, Deserialize, Serialize, PartialEq, Eq)]
#[safe(true)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct UniqueKey {
/// The paths that make up this key. Serialized under the key `paths`.
/// NOTE(review): presumably the combination of these paths must be unique — confirm.
pub paths: Vec<String>,
}
/// Conflict resolution policy of a container.
///
/// All three fields are required when deserializing; none has a serde default.
#[derive(Clone, SafeDebug, Deserialize, Serialize, PartialEq, Eq)]
#[safe(true)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ConflictResolutionPolicy {
/// The resolution mode. Serialized under the key `mode`.
pub mode: ConflictResolutionMode,
/// Serialized under the key `conflictResolutionPath`.
/// NOTE(review): presumably the property path used by `LastWriterWins` — confirm.
#[serde(rename = "conflictResolutionPath")]
pub resolution_path: String,
/// Serialized under the key `conflictResolutionProcedure`.
/// NOTE(review): presumably the procedure used by `Custom` — confirm.
#[serde(rename = "conflictResolutionProcedure")]
pub resolution_procedure: String,
}
/// Mode named by [`ConflictResolutionPolicy::mode`].
///
/// Unlike the other enums in this file, variants are serialized in PascalCase, i.e. exactly as
/// spelled here.
#[derive(Clone, SafeDebug, Deserialize, Serialize, PartialEq, Eq)]
#[safe(true)]
#[serde(rename_all = "PascalCase")]
pub enum ConflictResolutionMode {
/// Serialized as `LastWriterWins`.
LastWriterWins,
/// Serialized as `Custom`.
Custom,
}
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::TimeToLive;
    use crate::models::ContainerProperties;

    /// Single-field wrapper carrying the same serde attributes that `ContainerProperties` puts on
    /// its TTL fields, so `TimeToLive` is exercised the way it appears inside a struct.
    #[derive(Debug, Deserialize, Serialize)]
    struct TtlHolder {
        #[serde(default)]
        #[serde(skip_serializing_if = "TimeToLive::is_forever")]
        pub ttl: TimeToLive,
    }

    /// Wraps `ttl` in a [`TtlHolder`] and renders it as a JSON string.
    fn holder_to_json(ttl: TimeToLive) -> String {
        serde_json::to_string(&TtlHolder { ttl }).unwrap()
    }

    /// Attempts to parse `json` as a [`TtlHolder`].
    fn holder_from_json(json: &str) -> serde_json::Result<TtlHolder> {
        serde_json::from_str(json)
    }

    // --- TimeToLive inside a struct: serialization ---

    #[test]
    fn serialize_ttl_seconds() {
        assert_eq!(r#"{"ttl":4200}"#, holder_to_json(TimeToLive::Seconds(4200)));
    }

    #[test]
    fn serialize_ttl_forever() {
        // `Forever` is skipped entirely rather than written as `null`.
        assert_eq!(r#"{}"#, holder_to_json(TimeToLive::Forever));
    }

    #[test]
    fn serialize_ttl_no_default() {
        assert_eq!(r#"{"ttl":-1}"#, holder_to_json(TimeToLive::NoDefault));
    }

    // --- TimeToLive inside a struct: deserialization ---

    #[test]
    fn deserialize_ttl_seconds() {
        let value = holder_from_json(r#"{"ttl":4200}"#).unwrap();
        assert_eq!(TimeToLive::Seconds(4200), value.ttl);
    }

    #[test]
    fn deserialize_ttl_missing() {
        let value = holder_from_json(r#"{}"#).unwrap();
        assert_eq!(TimeToLive::Forever, value.ttl);
    }

    #[test]
    fn deserialize_ttl_null() {
        let value = holder_from_json(r#"{"ttl":null}"#).unwrap();
        assert_eq!(TimeToLive::Forever, value.ttl);
    }

    #[test]
    fn deserialize_ttl_negative_one() {
        let value = holder_from_json(r#"{"ttl":-1}"#).unwrap();
        assert_eq!(TimeToLive::NoDefault, value.ttl);
    }

    #[test]
    fn deserialize_ttl_zero() {
        assert!(holder_from_json(r#"{"ttl":0}"#).is_err());
    }

    #[test]
    fn deserialize_ttl_invalid_negative() {
        assert!(holder_from_json(r#"{"ttl":-2}"#).is_err());
    }

    #[test]
    fn deserialize_ttl_overflow() {
        // One past `i32::MAX`: rejected because the value is read as an `i32`.
        assert!(holder_from_json(r#"{"ttl":2147483648}"#).is_err());
    }

    // --- TimeToLive as a bare value ---

    #[test]
    fn serialize_ttl_seconds_value() {
        let json = serde_json::to_string(&TimeToLive::Seconds(86400)).unwrap();
        assert_eq!("86400", json);
    }

    #[test]
    fn serialize_ttl_no_default_value() {
        let json = serde_json::to_string(&TimeToLive::NoDefault).unwrap();
        assert_eq!("-1", json);
    }

    #[test]
    fn serialize_ttl_forever_value() {
        let json = serde_json::to_string(&TimeToLive::Forever).unwrap();
        assert_eq!("null", json);
    }

    #[test]
    fn from_u32_produces_seconds() {
        assert_eq!(TimeToLive::Seconds(30), TimeToLive::from(30u32));
    }

    #[test]
    fn default_ttl_is_forever() {
        assert_eq!(TimeToLive::Forever, TimeToLive::default());
        assert!(TimeToLive::default().is_forever());
    }

    // --- ContainerProperties ---

    #[test]
    fn deserialize_container_properties_with_ttl_negative_one() {
        let json = r#"{
            "id": "MyContainer",
            "partitionKey": {"paths": ["/pk"], "kind": "Hash", "version": 2},
            "defaultTtl": -1
        }"#;
        let props: ContainerProperties = serde_json::from_str(json).unwrap();
        assert_eq!(TimeToLive::NoDefault, props.default_ttl);
        assert_eq!(TimeToLive::Forever, props.analytical_storage_ttl);
    }

    #[test]
    fn deserialize_container_properties_with_ttl_seconds() {
        let json = r#"{
            "id": "MyContainer",
            "partitionKey": {"paths": ["/pk"], "kind": "Hash", "version": 2},
            "defaultTtl": 3600,
            "analyticalStorageTtl": -1
        }"#;
        let props: ContainerProperties = serde_json::from_str(json).unwrap();
        assert_eq!(TimeToLive::Seconds(3600), props.default_ttl);
        assert_eq!(TimeToLive::NoDefault, props.analytical_storage_ttl);
    }

    #[test]
    fn container_properties_default_serialization() {
        let properties = ContainerProperties::new("MyContainer", "/partitionKey".into());
        let json = serde_json::to_string(&properties).unwrap();
        assert_eq!(
            r#"{"id":"MyContainer","partitionKey":{"paths":["/partitionKey"],"kind":"Hash","version":2}}"#,
            json
        );
    }
}