use nodedb_cluster::routing::{RoutingTable, vshard_for_collection};
use nodedb_types::id::DatabaseId;
use crate::bridge::physical_plan::PhysicalPlan;
use super::route::{RouteDecision, TaskRoute};
use super::version_set::touched_collections;
/// Turn a physical plan into the task routes that should execute it.
///
/// * No routing table supplied: one task, always [`RouteDecision::Local`],
///   tagged with the plan's primary vShard.
/// * Plans reporting `is_broadcast_scan()`: fanned out to one task per vShard
///   (see `route_broadcast`).
/// * Anything else: one task for the primary vShard, routed via
///   [`resolve_decision`] without a live-leader override.
pub fn route_plan(
    plan: PhysicalPlan,
    local_node_id: u64,
    routing: Option<&RoutingTable>,
    database_id: DatabaseId,
) -> Vec<TaskRoute> {
    match routing {
        Some(table) if plan.is_broadcast_scan() => route_broadcast(plan, local_node_id, table),
        Some(table) => {
            let vshard_id = primary_vshard(&plan, database_id);
            let decision = resolve_decision(vshard_id, local_node_id, Some(table), None);
            vec![TaskRoute {
                plan,
                decision,
                vshard_id,
            }]
        }
        None => {
            // Without a table there is nobody else to send work to.
            let vshard_id = primary_vshard(&plan, database_id);
            vec![TaskRoute {
                plan,
                decision: RouteDecision::Local,
                vshard_id,
            }]
        }
    }
}
/// Decide which node should execute work for `vshard_id`.
///
/// Resolution order:
/// 1. No `routing` table: [`RouteDecision::Local`].
/// 2. If `live_leader_for_group` is supplied and the vShard maps to a group,
///    the callback's answer for that group wins, unless it returns `0`. `0`
///    means "no known leader" and falls through to step 3.
/// 3. The routing table's leader for the vShard. A leader of `0` or a lookup
///    error yields [`RouteDecision::LeaderUnknown`].
///
/// In steps 2 and 3, a leader equal to `local_node_id` yields
/// [`RouteDecision::Local`]. Any other non-zero leader yields
/// [`RouteDecision::Remote`].
pub fn resolve_decision(
    vshard_id: u32,
    local_node_id: u64,
    routing: Option<&RoutingTable>,
    live_leader_for_group: Option<&dyn Fn(u64) -> u64>,
) -> RouteDecision {
    let Some(routing) = routing else {
        return RouteDecision::Local;
    };
    let vshard = u64::from(vshard_id);

    if let Some(live) = live_leader_for_group
        && let Ok(group_id) = routing.group_for_vshard(vshard_id)
    {
        // Check the `0` sentinel before comparing with the local id, matching
        // the table path below. Otherwise a local id of 0 would turn
        // "no live leader" into Local.
        match live(group_id) {
            0 => {} // no live answer: fall back to the routing table
            leader if leader == local_node_id => return RouteDecision::Local,
            leader => {
                return RouteDecision::Remote {
                    node_id: leader,
                    vshard_id: vshard,
                };
            }
        }
    }

    match routing.leader_for_vshard(vshard_id) {
        Ok(0) | Err(_) => RouteDecision::LeaderUnknown { vshard_id: vshard },
        Ok(leader) if leader == local_node_id => RouteDecision::Local,
        Ok(leader) => RouteDecision::Remote {
            node_id: leader,
            vshard_id: vshard,
        },
    }
}
/// Fan a broadcast scan out into one task per vShard.
///
/// Every vShard id in `0..VSHARD_COUNT` gets its own clone of `plan`. Each
/// clone is routed independently through [`resolve_decision`] with no
/// live-leader override. Routes are returned in ascending vShard order.
fn route_broadcast(
    plan: PhysicalPlan,
    local_node_id: u64,
    routing: &RoutingTable,
) -> Vec<TaskRoute> {
    use nodedb_cluster::routing::VSHARD_COUNT;

    (0u32..VSHARD_COUNT)
        .map(|shard| TaskRoute {
            decision: resolve_decision(shard, local_node_id, Some(routing), None),
            plan: plan.clone(),
            vshard_id: shard,
        })
        .collect()
}
/// vShard that owns the first collection yielded by `touched_collections`.
///
/// Plans that touch no collection map to vShard `0`.
// NOTE(review): "first" depends on the iteration order of the collection
// `touched_collections` returns. Confirm that order is deterministic.
fn primary_vshard(plan: &PhysicalPlan, database_id: DatabaseId) -> u32 {
    match touched_collections(plan).into_iter().next() {
        Some(first) => vshard_for_collection(database_id, &first),
        None => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bridge::physical_plan::{DocumentOp, KvOp, PhysicalPlan};

    /// `RoutingTable::uniform` built over node 1 only.
    fn single_node_table() -> RoutingTable {
        RoutingTable::uniform(1, &[1], 1)
    }

    /// `RoutingTable::uniform` built over nodes 1 and 2.
    fn two_node_table() -> RoutingTable {
        RoutingTable::uniform(2, &[1, 2], 1)
    }

    /// Two-node table whose group for vShard 0 has `leader` set explicitly.
    fn table_with_vshard0_leader(leader: u64) -> RoutingTable {
        let mut table = two_node_table();
        let group = table.group_for_vshard(0).unwrap();
        table.set_leader(group, leader);
        table
    }

    #[test]
    fn single_node_routes_locally() {
        let table = single_node_table();
        let plan = PhysicalPlan::Kv(KvOp::Get {
            collection: "users".into(),
            key: vec![],
            rls_filters: vec![],
            surrogate_ceiling: None,
        });
        let routes = route_plan(plan, 1, Some(&table), DatabaseId::DEFAULT);
        assert_eq!(routes.len(), 1);
        assert_eq!(routes[0].decision, RouteDecision::Local);
    }

    #[test]
    fn no_routing_table_routes_locally() {
        let plan = PhysicalPlan::Kv(KvOp::Put {
            collection: "x".into(),
            key: vec![],
            value: vec![],
            ttl_ms: 0,
            surrogate: nodedb_types::Surrogate::ZERO,
        });
        let routes = route_plan(plan, 99, None, DatabaseId::DEFAULT);
        assert_eq!(routes.len(), 1);
        assert_eq!(routes[0].decision, RouteDecision::Local);
    }

    #[test]
    fn remote_route_when_different_leader() {
        let mut table = two_node_table();
        let group = table.group_for_vshard(0).unwrap();
        table.set_leader(group, 2);
        let collection = find_collection_for_vshard(0);
        let plan = PhysicalPlan::Kv(KvOp::Get {
            collection,
            key: vec![],
            rls_filters: vec![],
            surrogate_ceiling: None,
        });
        let routes = route_plan(plan, 1, Some(&table), DatabaseId::DEFAULT);
        assert_eq!(routes.len(), 1);
        match &routes[0].decision {
            RouteDecision::Remote { node_id, .. } => assert_eq!(*node_id, 2),
            other => panic!("expected Remote, got {other:?}"),
        }
    }

    #[test]
    fn broadcast_scan_produces_multiple_routes() {
        let table = two_node_table();
        let plan = PhysicalPlan::Document(DocumentOp::Scan {
            collection: "events".into(),
            limit: 100,
            offset: 0,
            sort_keys: vec![],
            filters: vec![],
            distinct: false,
            projection: vec![],
            computed_columns: vec![],
            window_functions: vec![],
            system_as_of_ms: None,
            valid_at_ms: None,
            prefilter: None,
        });
        let routes = route_plan(plan, 1, Some(&table), DatabaseId::DEFAULT);
        assert_eq!(routes.len(), nodedb_cluster::routing::VSHARD_COUNT as usize);
    }

    #[test]
    fn resolve_without_routing_table_is_local() {
        assert_eq!(resolve_decision(7, 1, None, None), RouteDecision::Local);
    }

    #[test]
    fn live_leader_overrides_routing_table() {
        // Table says node 1 (local) leads, but the live view says node 2.
        let table = table_with_vshard0_leader(1);
        let live = |_group: u64| -> u64 { 2 };
        let decision = resolve_decision(0, 1, Some(&table), Some(&live));
        assert_eq!(
            decision,
            RouteDecision::Remote {
                node_id: 2,
                vshard_id: 0,
            }
        );
    }

    #[test]
    fn live_leader_matching_local_routes_locally() {
        // Table says node 2 leads, but the live view says we (node 1) do.
        let table = table_with_vshard0_leader(2);
        let live = |_group: u64| -> u64 { 1 };
        let decision = resolve_decision(0, 1, Some(&table), Some(&live));
        assert_eq!(decision, RouteDecision::Local);
    }

    #[test]
    fn zero_live_leader_falls_back_to_routing_table() {
        // A live answer of 0 means "unknown", so the table's leader is used.
        let table = table_with_vshard0_leader(2);
        let live = |_group: u64| -> u64 { 0 };
        let decision = resolve_decision(0, 1, Some(&table), Some(&live));
        assert_eq!(
            decision,
            RouteDecision::Remote {
                node_id: 2,
                vshard_id: 0,
            }
        );
    }

    /// Brute-force a collection name that hashes to vShard `target`.
    fn find_collection_for_vshard(target: u32) -> String {
        for i in 0u64.. {
            let name = format!("col_{i}");
            if vshard_for_collection(DatabaseId::DEFAULT, &name) == target {
                return name;
            }
        }
        unreachable!()
    }
}