use osproxy_core::{ClusterId, Epoch, IndexName, PartitionId, Target};
use osproxy_spi::{
BodyDoc, BodyTransform, InjectedField, InjectedValue, MigrationPhase, Placement, RequestCtx,
RouteDecision, RoutingSpi, SpiError, TenancySpi,
};
use serde_json::Value;
/// Outcome of resolving a single request.
///
/// Bundles the partition the request was resolved to, the routing decision used
/// for the upstream call, and the migration phase reported by the placement lookup.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Resolved {
    /// Partition the request was resolved to.
    pub partition: PartitionId,
    /// Target, protocol, header ops, body transform and epoch for the upstream call.
    pub decision: RouteDecision,
    /// Migration phase of the partition, as returned alongside its placement.
    pub migration: MigrationPhase,
}
/// Router that answers every routing question by delegating to a [`TenancySpi`]
/// implementation and turning its placements into [`RouteDecision`]s.
#[derive(Debug)]
pub struct TenancyRouter<T> {
    /// The tenancy SPI all partition, placement, admission and doc-ID lookups go to.
    spi: T,
}
impl<T: TenancySpi> TenancyRouter<T> {
    /// Creates a router that delegates every tenancy lookup to `spi`.
    #[must_use]
    pub fn new(spi: T) -> Self {
        Self { spi }
    }

    /// Returns a reference to the wrapped tenancy SPI.
    #[must_use]
    pub fn spi(&self) -> &T {
        &self.spi
    }

    /// Asks the tenancy SPI whether a write to `partition` may proceed at `epoch`.
    ///
    /// The returned flag is the admission decision itself. Dropping it would
    /// silently skip the check, so the result is `#[must_use]`.
    #[must_use = "the write-admission decision must be checked before forwarding the write"]
    pub async fn admit_write(&self, partition: &PartitionId, epoch: Epoch) -> bool {
        self.spi.admit_write(partition, epoch).await
    }

    /// Resolves a request end to end.
    ///
    /// First the partition is derived from the request context and body. Then the
    /// placement of that partition is resolved for the request's logical index.
    ///
    /// # Errors
    ///
    /// Returns `SpiError::UnsupportedEndpoint` when the request's endpoint is not
    /// tenancy-aware. Otherwise it propagates any error from `resolve_partition`
    /// or `resolve_placement`.
    pub async fn resolve(&self, ctx: &RequestCtx<'_>) -> Result<Resolved, SpiError> {
        if !ctx.endpoint().is_tenancy_aware() {
            return Err(SpiError::UnsupportedEndpoint {
                endpoint: ctx.endpoint(),
            });
        }
        let partition = self.resolve_partition(ctx, BodyDoc::new(ctx.body()))?;
        self.resolve_placement(ctx, partition, ctx.logical_index())
            .await
    }

    /// Derives the partition a request belongs to by delegating to the tenancy SPI.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the SPI's `resolve_partition`.
    pub fn resolve_partition(
        &self,
        ctx: &RequestCtx<'_>,
        body: BodyDoc<'_>,
    ) -> Result<PartitionId, SpiError> {
        self.spi.resolve_partition(ctx, body)
    }

    /// Builds the routing outcome for an already-resolved `partition`.
    ///
    /// The SPI's placement decides the target cluster and index. A dedicated
    /// cluster keeps `logical_index` as the index name. The placement also supplies
    /// the upstream endpoint, the epoch and the migration phase. The decision
    /// reuses the request's own protocol and adds no header operations.
    ///
    /// # Errors
    ///
    /// - Propagates any error from the SPI's `placement_for`.
    /// - Fails with `SpiError::IdRuleMissingPartition` when the partition lives in
    ///   a shared index but the SPI's doc-ID rule is absent or its template does
    ///   not reference the partition.
    /// - Fails with `SpiError::PrincipalAttrMissing` or `SpiError::HeaderMissing`
    ///   when an injected field's value is not present on the request.
    pub async fn resolve_placement(
        &self,
        ctx: &RequestCtx<'_>,
        partition: PartitionId,
        logical_index: &str,
    ) -> Result<Resolved, SpiError> {
        let at = self.spi.placement_for(&partition).await?;
        let target = target_for(&at.placement, logical_index).with_endpoint(at.endpoint.clone());
        let body_transform = self.build_transform(&at.placement, &partition, ctx)?;
        let decision = RouteDecision {
            target,
            upstream_protocol: ctx.protocol(),
            header_ops: Vec::new(),
            body_transform,
            epoch: at.epoch,
        };
        Ok(Resolved {
            partition,
            decision,
            migration: at.phase,
        })
    }

    /// Computes the body rewrite for `placement`.
    ///
    /// The rewrite combines field injection (configured only on shared-index
    /// placements) with the SPI's doc-ID construction rule, if there is one.
    fn build_transform(
        &self,
        placement: &Placement,
        partition: &PartitionId,
        ctx: &RequestCtx<'_>,
    ) -> Result<BodyTransform, SpiError> {
        // Only shared-index placements carry injected fields. Dedicated placements
        // inject nothing.
        let inject = match placement {
            Placement::SharedIndex { inject, .. } => resolve_inject(inject, partition, ctx)?,
            Placement::DedicatedCluster { .. } | Placement::DedicatedIndex { .. } => Vec::new(),
        };
        let id_rule = self.spi.doc_id_rule();
        // A shared index requires a doc-ID template that references the partition.
        // Presumably this keeps IDs from different partitions from colliding in the
        // one physical index. Reject the placement otherwise.
        if let Placement::SharedIndex { .. } = placement {
            let partition_scoped = id_rule
                .as_ref()
                .is_some_and(|rule| rule.template.references_partition());
            if !partition_scoped {
                return Err(SpiError::IdRuleMissingPartition);
            }
        }
        Ok(match (inject.is_empty(), id_rule) {
            (true, None) => BodyTransform::None,
            (false, None) => BodyTransform::Inject(inject),
            (true, Some(id)) => BodyTransform::ConstructId(id),
            (false, Some(id)) => BodyTransform::Both { inject, id },
        })
    }
}
impl<T: TenancySpi> RoutingSpi for TenancyRouter<T> {
async fn route(&self, ctx: &RequestCtx<'_>) -> Result<RouteDecision, SpiError> {
Ok(self.resolve(ctx).await?.decision)
}
}
/// Routing interface consumed generically by the engine.
///
/// Covers full request resolution, the two halves of that resolution (partition,
/// then placement) separately, write admission, and an optional per-cluster
/// endpoint lookup. [`TenancyRouter`] implements it on top of a tenancy SPI.
#[allow(
    async_fn_in_trait,
    reason = "consumed through generics in the engine, where Send is verified at \
              the spawn site, mirroring TenancySpi/RoutingSpi (docs/02 §2)"
)]
pub trait Router: Send + Sync + 'static {
    /// Resolves a request to its partition, routing decision and migration phase.
    async fn resolve(&self, ctx: &RequestCtx<'_>) -> Result<Resolved, SpiError>;
    /// Derives the partition a request belongs to from its context and body.
    fn resolve_partition(
        &self,
        ctx: &RequestCtx<'_>,
        body: BodyDoc<'_>,
    ) -> Result<PartitionId, SpiError>;
    /// Resolves the placement of an already-known `partition` for `logical_index`.
    async fn resolve_placement(
        &self,
        ctx: &RequestCtx<'_>,
        partition: PartitionId,
        logical_index: &str,
    ) -> Result<Resolved, SpiError>;
    /// Returns whether a write to `partition` at `epoch` is admitted.
    async fn admit_write(&self, partition: &PartitionId, epoch: Epoch) -> bool;
    /// Returns an endpoint for `cluster` if the router knows one.
    ///
    /// The default implementation knows none and always returns `None`.
    fn cluster_endpoint(&self, _cluster: &ClusterId) -> Option<String> {
        None
    }
}
/// [`Router`] implementation that forwards to the inherent `TenancyRouter` API.
///
/// Each forwarding call names the type explicitly (`TenancyRouter::<T>::…`), so the
/// inherent method is invoked rather than this trait method.
impl<T: TenancySpi> Router for TenancyRouter<T> {
    async fn resolve(&self, ctx: &RequestCtx<'_>) -> Result<Resolved, SpiError> {
        TenancyRouter::<T>::resolve(self, ctx).await
    }

    fn resolve_partition(
        &self,
        ctx: &RequestCtx<'_>,
        body: BodyDoc<'_>,
    ) -> Result<PartitionId, SpiError> {
        TenancyRouter::<T>::resolve_partition(self, ctx, body)
    }

    async fn resolve_placement(
        &self,
        ctx: &RequestCtx<'_>,
        partition: PartitionId,
        logical_index: &str,
    ) -> Result<Resolved, SpiError> {
        TenancyRouter::<T>::resolve_placement(self, ctx, partition, logical_index).await
    }

    async fn admit_write(&self, partition: &PartitionId, epoch: Epoch) -> bool {
        TenancyRouter::<T>::admit_write(self, partition, epoch).await
    }

    /// Unlike the trait default, asks the tenancy SPI for the cluster's endpoint.
    fn cluster_endpoint(&self, cluster: &ClusterId) -> Option<String> {
        self.spi.cluster_endpoint(cluster)
    }
}
/// Maps a placement to the concrete cluster/index a request is sent to.
///
/// Index-level placements (dedicated or shared) name their index explicitly. A
/// dedicated cluster has no index of its own, so the request's `logical_index`
/// is used unchanged.
fn target_for(placement: &Placement, logical_index: &str) -> Target {
    match placement {
        Placement::DedicatedIndex { cluster, index }
        | Placement::SharedIndex { cluster, index, .. } => {
            Target::new(cluster.clone(), index.clone())
        }
        Placement::DedicatedCluster { cluster } => {
            let index = IndexName::from(logical_index);
            Target::new(cluster.clone(), index)
        }
    }
}
/// Resolves the configured injected fields against the current request.
///
/// - `PartitionId` fields are passed through untouched; they are not resolved
///   here.
/// - Constants are copied as-is.
/// - Principal attributes and headers are looked up on `ctx` and turned into JSON
///   string constants.
///
/// Stops at the first field whose source value is missing, returning
/// `SpiError::PrincipalAttrMissing` or `SpiError::HeaderMissing`.
fn resolve_inject(
    fields: &[InjectedField],
    _partition: &PartitionId,
    ctx: &RequestCtx<'_>,
) -> Result<Vec<InjectedField>, SpiError> {
    let mut resolved = Vec::with_capacity(fields.len());
    for field in fields {
        let constant = match &field.value {
            InjectedValue::PartitionId => {
                resolved.push(field.clone());
                continue;
            }
            InjectedValue::Constant(value) => value.clone(),
            InjectedValue::FromPrincipal(attr) => match ctx.principal().attr(attr) {
                Some(raw) => Value::String(raw.to_owned()),
                None => return Err(SpiError::PrincipalAttrMissing { attr: attr.clone() }),
            },
            InjectedValue::FromHeader(name) => match ctx.headers().get(name) {
                Some(raw) => Value::String(raw.to_owned()),
                None => {
                    return Err(SpiError::HeaderMissing {
                        header: name.clone(),
                    })
                }
            },
        };
        resolved.push(InjectedField::new(
            field.name.clone(),
            InjectedValue::Constant(constant),
        ));
    }
    Ok(resolved)
}
// Unit tests for this module live in the sibling file `router_tests.rs`.
#[cfg(test)]
#[path = "router_tests.rs"]
mod tests;