use std::sync::Arc;
use crate::models::{
effective_partition_key::EffectivePartitionKey, partition_key_range::PkRangesResponse,
ContainerReference, PartitionKey,
};
use super::{container_routing_map::ContainerRoutingMap, AsyncCache};
/// Upper bound on fetch round-trips when building one routing map; guards
/// against looping forever if the service never reports Not Modified.
const MAX_FETCH_ITERATIONS: usize = 10;
/// Result of a single partition-key-range fetch round-trip (one page).
#[derive(Debug)]
pub(crate) struct PkRangeFetchResult {
    /// Ranges returned in this page (empty when the service reports 304).
    pub ranges: Vec<crate::models::partition_key_range::PartitionKeyRange>,
    /// Continuation/etag to send on the next fetch, if any.
    pub continuation: Option<String>,
    /// True when the service reported 304 Not Modified for this fetch.
    pub not_modified: bool,
}
/// Per-container cache of partition key range routing maps, used to map
/// partition keys (or range ids) to the server-side ranges that own them.
#[derive(Debug)]
pub(crate) struct PartitionKeyRangeCache {
    // One shared (Arc) routing map per container reference.
    cache: AsyncCache<ContainerReference, ContainerRoutingMap>,
}
impl PartitionKeyRangeCache {
    /// Creates a cache with no routing maps loaded.
    pub fn new() -> Self {
        let cache = AsyncCache::new();
        Self { cache }
    }

    /// Resolves the id of the single range that owns `partition_key`.
    ///
    /// Returns `None` for an empty partition key, when the routing map cannot
    /// be obtained, or when no range covers the computed effective key.
    pub async fn resolve_partition_key_range_id<F, Fut>(
        &self,
        container: &ContainerReference,
        partition_key: &PartitionKey,
        force_refresh: bool,
        fetch_pk_ranges: F,
    ) -> Option<String>
    where
        F: Fn(ContainerReference, Option<String>) -> Fut,
        Fut: std::future::Future<Output = Option<PkRangeFetchResult>>,
    {
        if partition_key.is_empty() {
            return None;
        }
        let definition = container.partition_key_definition();
        let epk = EffectivePartitionKey::compute(
            partition_key.values(),
            definition.kind(),
            definition.version(),
        );
        let routing_map = self
            .try_lookup(container, force_refresh, fetch_pk_ranges)
            .await?;
        let owner = routing_map.get_range_by_effective_partition_key(&epk)?;
        Some(owner.id.clone())
    }

    /// Resolves every range id that may contain documents for `partition_key`.
    ///
    /// A fully-specified key maps to exactly one range; a key whose effective
    /// range spans more than one point (e.g. a prefix) maps to all overlapping
    /// ranges. Returns `None` for an empty key or on lookup failure.
    pub async fn resolve_partition_key_range_ids<F, Fut>(
        &self,
        container: &ContainerReference,
        partition_key: &PartitionKey,
        force_refresh: bool,
        fetch_pk_ranges: F,
    ) -> Option<Vec<String>>
    where
        F: Fn(ContainerReference, Option<String>) -> Fut,
        Fut: std::future::Future<Output = Option<PkRangeFetchResult>>,
    {
        if partition_key.is_empty() {
            return None;
        }
        let definition = container.partition_key_definition();
        let epk_range =
            EffectivePartitionKey::compute_range(partition_key.values(), definition).ok()?;
        if epk_range.start != epk_range.end {
            // Prefix / multi-point key: collect every overlapping range.
            let overlapping = self
                .resolve_overlapping_ranges(
                    container,
                    &epk_range.start..&epk_range.end,
                    force_refresh,
                    fetch_pk_ranges,
                )
                .await?;
            return Some(overlapping.into_iter().map(|r| r.id).collect());
        }
        // Degenerate range: a single effective key, so a single owner.
        let routing_map = self
            .try_lookup(container, force_refresh, fetch_pk_ranges)
            .await?;
        let owner = routing_map.get_range_by_effective_partition_key(&epk_range.start)?;
        Some(vec![owner.id.clone()])
    }

    /// Returns clones of all ranges overlapping `epk_range`, or `None` when
    /// the routing map cannot be obtained.
    pub async fn resolve_overlapping_ranges<F, Fut>(
        &self,
        container: &ContainerReference,
        epk_range: std::ops::Range<&EffectivePartitionKey>,
        force_refresh: bool,
        fetch_pk_ranges: F,
    ) -> Option<Vec<crate::models::partition_key_range::PartitionKeyRange>>
    where
        F: Fn(ContainerReference, Option<String>) -> Fut,
        Fut: std::future::Future<Output = Option<PkRangeFetchResult>>,
    {
        let routing_map = self
            .try_lookup(container, force_refresh, fetch_pk_ranges)
            .await?;
        let overlapping: Vec<_> = routing_map
            .get_overlapping_ranges(epk_range)
            .into_iter()
            .cloned()
            .collect();
        Some(overlapping)
    }

    /// Looks up a range by its id, or `None` when the routing map cannot be
    /// obtained or no range has that id.
    pub async fn resolve_partition_key_range_by_id<F, Fut>(
        &self,
        container: &ContainerReference,
        partition_key_range_id: &str,
        force_refresh: bool,
        fetch_pk_ranges: F,
    ) -> Option<crate::models::partition_key_range::PartitionKeyRange>
    where
        F: Fn(ContainerReference, Option<String>) -> Fut,
        Fut: std::future::Future<Output = Option<PkRangeFetchResult>>,
    {
        let routing_map = self
            .try_lookup(container, force_refresh, fetch_pk_ranges)
            .await?;
        routing_map.range(partition_key_range_id).cloned()
    }

    /// Returns the container's routing map, fetching (or refreshing, when
    /// `force_refresh` is set) it through `fetch_pk_ranges` as needed.
    pub(crate) async fn try_lookup<F, Fut>(
        &self,
        container: &ContainerReference,
        force_refresh: bool,
        fetch_pk_ranges: F,
    ) -> Option<Arc<ContainerRoutingMap>>
    where
        F: Fn(ContainerReference, Option<String>) -> Fut,
        Fut: std::future::Future<Output = Option<PkRangeFetchResult>>,
    {
        let key = container.clone();
        if !force_refresh {
            // Plain lookup: populate on first use, then serve the cached map.
            let routing_map = self
                .cache
                .get_or_insert_with(key.clone(), || {
                    fetch_and_build_routing_map(key.clone(), None, fetch_pk_ranges)
                })
                .await;
            return Some(routing_map);
        }
        // Refresh path: snapshot the current entry so the refresh predicate
        // can tell whether another task already replaced it in the meantime.
        let previous = self.cache.get(&key).await;
        let observed_etag = previous
            .as_ref()
            .and_then(|m| m.change_feed_next_if_none_match.clone());
        self.cache
            .get_or_refresh_with(
                key.clone(),
                // Refresh when nothing is cached, or when the cached map still
                // carries the etag we observed (nobody refreshed it since).
                |current| match current {
                    None => true,
                    Some(map) => map.change_feed_next_if_none_match == observed_etag,
                },
                || fetch_and_build_routing_map(key.clone(), previous, fetch_pk_ranges),
            )
            .await
    }

    /// Drops any cached routing map for `container`.
    pub async fn invalidate(&self, container: &ContainerReference) {
        self.cache.invalidate(container).await;
    }
}
/// Fetches partition key ranges page-by-page and builds a routing map.
///
/// When `previous_routing_map` is given, the fetch resumes from its stored
/// change-feed etag so the service can answer 304 Not Modified or return only
/// changed ranges, and the results are merged incrementally into the previous
/// map. On merge failure (or when nothing changed) the previous map is
/// returned unchanged; without a previous map, a map is built from scratch.
async fn fetch_and_build_routing_map<F, Fut>(
    container: ContainerReference,
    previous_routing_map: Option<Arc<ContainerRoutingMap>>,
    fetch_pk_ranges: F,
) -> ContainerRoutingMap
where
    F: Fn(ContainerReference, Option<String>) -> Fut,
    Fut: std::future::Future<Output = Option<PkRangeFetchResult>>,
{
    let mut all_ranges = Vec::new();
    // Resume from the previous map's etag (if any) so unchanged data yields 304.
    let mut continuation = previous_routing_map
        .as_ref()
        .and_then(|m| m.change_feed_next_if_none_match.clone());
    let mut received_not_modified = false;
    let mut iterations_completed = 0;
    for iteration in 0..MAX_FETCH_ITERATIONS {
        iterations_completed = iteration + 1;
        tracing::trace!(
            iteration,
            has_continuation = continuation.is_some(),
            "Fetching partition key ranges"
        );
        let result = match fetch_pk_ranges(container.clone(), continuation.clone()).await {
            Some(r) => r,
            None => {
                tracing::warn!(
                    "Failed to fetch partition key ranges from service (iteration {})",
                    iteration
                );
                // Fall back to the previous map on a (possibly transient)
                // fetch failure instead of clobbering a good cached map with
                // an empty one — consistent with the merge-failure handling
                // below.
                return match previous_routing_map.as_deref() {
                    Some(prev) => prev.clone(),
                    None => ContainerRoutingMap::empty(),
                };
            }
        };
        continuation = result.continuation;
        if result.not_modified {
            tracing::trace!(iteration, "Service returned 304 Not Modified");
            received_not_modified = true;
            break;
        }
        tracing::trace!(
            iteration,
            range_count = result.ranges.len(),
            "Received partition key ranges"
        );
        all_ranges.extend(result.ranges);
    }
    tracing::debug!(
        iterations = iterations_completed,
        total_ranges = all_ranges.len(),
        not_modified = received_not_modified,
        "Partition key range fetch loop completed"
    );
    // The loop only exits early on Not Modified, so reaching here without it
    // means every iteration was used up.
    if !received_not_modified && !all_ranges.is_empty() {
        tracing::warn!(
            "Partition key range fetch loop reached MAX_FETCH_ITERATIONS ({}) without \
            receiving Not Modified; routing map may be built from partial data",
            MAX_FETCH_ITERATIONS
        );
    }
    if let Some(prev) = previous_routing_map {
        if all_ranges.is_empty() {
            // Nothing changed (or nothing arrived): keep the previous map.
            return (*prev).clone();
        }
        return match prev.try_combine(all_ranges, continuation) {
            Ok(Some(map)) => map,
            Ok(None) => {
                tracing::warn!(
                    "Incremental routing map merge incomplete; falling back to previous map"
                );
                (*prev).clone()
            }
            Err(e) => {
                tracing::warn!(
                    "Incremental routing map merge failed: {}; falling back to previous map",
                    e
                );
                (*prev).clone()
            }
        };
    }
    match ContainerRoutingMap::try_create(all_ranges, None, continuation) {
        Ok(Some(map)) => map,
        Ok(None) => {
            tracing::warn!("Partition key range fetch returned empty set");
            ContainerRoutingMap::empty()
        }
        Err(e) => {
            tracing::warn!("Partition key ranges invalid: {}", e);
            ContainerRoutingMap::empty()
        }
    }
}
pub(crate) fn parse_pk_ranges_response(
body: &[u8],
) -> Option<Vec<crate::models::partition_key_range::PartitionKeyRange>> {
let response: PkRangesResponse = serde_json::from_slice(body).ok()?;
Some(response.partition_key_ranges)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::partition_key_range::PartitionKeyRange as PkRange;

    /// Builds a `ContainerReference` against a dummy account, with the given
    /// partition key definition JSON. Shared by every test below (the
    /// hash-partitioned tests previously duplicated this body inline).
    fn make_container(pk_json: &str) -> ContainerReference {
        let account = crate::models::AccountReference::with_master_key(
            url::Url::parse("https://test.documents.azure.com:443/").unwrap(),
            "key",
        );
        let container_props = crate::models::ContainerProperties {
            id: "testcontainer".into(),
            partition_key: serde_json::from_str(pk_json).unwrap(),
            system_properties: Default::default(),
        };
        ContainerReference::new(
            account,
            "testdb",
            "testdb_rid",
            "testcontainer",
            "testcontainer_rid",
            &container_props,
        )
    }

    fn test_ranges() -> Vec<PkRange> {
        vec![PkRange::new("0".into(), "", "FF")]
    }

    /// Single-range fetch stub: the first call (no continuation) returns one
    /// range plus an etag; the follow-up call with the etag reports 304.
    async fn test_fetch(
        _container: ContainerReference,
        continuation: Option<String>,
    ) -> Option<PkRangeFetchResult> {
        if continuation.is_some() {
            Some(PkRangeFetchResult {
                ranges: vec![],
                continuation,
                not_modified: true,
            })
        } else {
            Some(PkRangeFetchResult {
                ranges: test_ranges(),
                continuation: Some("test-etag".to_string()),
                not_modified: false,
            })
        }
    }

    /// Two-range fetch stub: splits the key space at "80"; otherwise behaves
    /// like `test_fetch`.
    async fn two_range_fetch(
        _container: ContainerReference,
        continuation: Option<String>,
    ) -> Option<PkRangeFetchResult> {
        if continuation.is_some() {
            Some(PkRangeFetchResult {
                ranges: vec![],
                continuation,
                not_modified: true,
            })
        } else {
            Some(PkRangeFetchResult {
                ranges: vec![
                    PkRange::new("0".into(), "", "80"),
                    PkRange::new("1".into(), "80", "FF"),
                ],
                continuation: Some("test-etag".to_string()),
                not_modified: false,
            })
        }
    }

    #[tokio::test]
    async fn resolve_returns_range_id() {
        let cache = PartitionKeyRangeCache::new();
        let container = make_container(r#"{"paths":["/pk"],"version":2}"#);
        let pk = PartitionKey::from("hello");
        let range_id = cache
            .resolve_partition_key_range_id(&container, &pk, false, test_fetch)
            .await;
        assert_eq!(range_id.as_deref(), Some("0"));
    }

    #[tokio::test]
    async fn empty_pk_returns_none() {
        let cache = PartitionKeyRangeCache::new();
        let container = make_container(r#"{"paths":["/pk"],"version":2}"#);
        let pk = PartitionKey::EMPTY;
        let range_id = cache
            .resolve_partition_key_range_id(&container, &pk, false, test_fetch)
            .await;
        assert!(range_id.is_none());
    }

    #[tokio::test]
    async fn force_refresh_uses_incremental_merge() {
        let cache = PartitionKeyRangeCache::new();
        let container = make_container(r#"{"paths":["/pk"],"version":2}"#);
        let pk = PartitionKey::from("hello");
        // Initial population.
        let range_id = cache
            .resolve_partition_key_range_id(&container, &pk, false, test_fetch)
            .await;
        assert_eq!(range_id.as_deref(), Some("0"));
        // Forced refresh resumes from the cached etag and keeps resolving.
        let range_id = cache
            .resolve_partition_key_range_id(&container, &pk, true, test_fetch)
            .await;
        assert_eq!(range_id.as_deref(), Some("0"));
    }

    #[test]
    fn parse_pk_ranges_response_test() {
        let body = br#"{
            "PartitionKeyRanges": [
                {"id": "0", "_rid": "rid0", "minInclusive": "", "maxExclusive": "FF"}
            ]
        }"#;
        let ranges = parse_pk_ranges_response(body).unwrap();
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].id, "0");
    }

    #[tokio::test]
    async fn resolve_ids_empty_pk_returns_none() {
        let cache = PartitionKeyRangeCache::new();
        let container = make_container(
            r#"{"paths":["/tenantId","/userId","/sessionId"],"kind":"MultiHash","version":2}"#,
        );
        let result = cache
            .resolve_partition_key_range_ids(&container, &PartitionKey::EMPTY, false, test_fetch)
            .await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn resolve_ids_full_multihash_returns_single_id() {
        let cache = PartitionKeyRangeCache::new();
        let container =
            make_container(r#"{"paths":["/tenantId","/userId"],"kind":"MultiHash","version":2}"#);
        let pk = PartitionKey::from(("tenant1", "user1"));
        let result = cache
            .resolve_partition_key_range_ids(&container, &pk, false, test_fetch)
            .await;
        assert_eq!(result, Some(vec!["0".to_string()]));
    }

    #[tokio::test]
    async fn resolve_ids_prefix_multihash_returns_multiple_ids() {
        let cache = PartitionKeyRangeCache::new();
        let container = make_container(
            r#"{"paths":["/tenantId","/userId","/sessionId"],"kind":"MultiHash","version":2}"#,
        );
        // Prefix key: only one of the three paths supplied.
        let pk = PartitionKey::from("tenant1");
        let result = cache
            .resolve_partition_key_range_ids(&container, &pk, false, two_range_fetch)
            .await;
        assert_eq!(result, Some(vec!["0".to_string()]));
    }

    #[tokio::test]
    async fn resolve_ids_non_multihash_returns_single_id() {
        let cache = PartitionKeyRangeCache::new();
        let container = make_container(r#"{"paths":["/pk"],"version":2}"#);
        let pk = PartitionKey::from("hello");
        let result = cache
            .resolve_partition_key_range_ids(&container, &pk, false, test_fetch)
            .await;
        assert_eq!(result, Some(vec!["0".to_string()]));
    }

    #[tokio::test]
    async fn resolve_ids_matches_single_resolve() {
        let cache = PartitionKeyRangeCache::new();
        let container =
            make_container(r#"{"paths":["/tenantId","/userId"],"kind":"MultiHash","version":2}"#);
        let pk = PartitionKey::from(("tenant1", "user1"));
        let single = cache
            .resolve_partition_key_range_id(&container, &pk, false, test_fetch)
            .await;
        let plural = cache
            .resolve_partition_key_range_ids(&container, &pk, false, test_fetch)
            .await;
        assert_eq!(single.as_deref(), Some("0"));
        assert_eq!(plural.as_deref(), Some(vec!["0".to_string()].as_slice()));
    }
}