use std::collections::HashMap;
use std::collections::HashSet;
use std::error::Error;
use std::fmt::Debug;
use std::hash::Hash;
use std::mem;
use std::sync::Arc;
use std::sync::Mutex;
use crate::client::ConnectivityState;
use crate::client::load_balancing::ChannelController;
use crate::client::load_balancing::DynLbConfig;
use crate::client::load_balancing::DynLbPolicy;
use crate::client::load_balancing::DynLbPolicyBuilder;
use crate::client::load_balancing::LbPolicyOptions;
use crate::client::load_balancing::LbState;
use crate::client::load_balancing::Subchannel;
use crate::client::load_balancing::SubchannelState;
use crate::client::load_balancing::WorkScheduler;
use crate::client::load_balancing::subchannel::WeakSubchannel;
use crate::client::name_resolution::Address;
use crate::client::name_resolution::ResolverUpdate;
use crate::rt::GrpcRuntime;
/// Manages a set of child LB policies on behalf of a parent policy.
///
/// Each child is keyed by its policy builder's name together with an identifier of
/// type `T`. The manager routes resolver, subchannel and work events to the child
/// they belong to and keeps each child's most recently published `LbState`.
#[derive(Debug)]
pub(crate) struct ChildManager<T: Debug> {
    /// Maps every subchannel a child created to that child's index in `children`.
    subchannel_to_child_idx: HashMap<WeakSubchannel, usize>,
    /// The managed children, in the order of the most recent update.
    children: Vec<Child<T>>,
    /// Indices of children that requested work; shared with each child's
    /// `ChildWorkScheduler`.
    pending_work: Arc<Mutex<HashSet<usize>>>,
    /// Runtime handed to every child policy when it is built.
    runtime: GrpcRuntime,
    /// Set when a child publishes a new picker; read-and-cleared by `child_updated`.
    updated: bool,
    /// The parent's work scheduler, which children's work requests are forwarded to.
    work_scheduler: Arc<dyn WorkScheduler>,
}
/// A single child policy owned by a `ChildManager`.
#[non_exhaustive]
#[derive(Debug)]
pub(crate) struct Child<T> {
    /// Caller-supplied identifier; together with the builder's name it forms the
    /// key used to match this child across updates.
    pub identifier: T,
    /// The builder that created `policy`.
    pub builder: Arc<DynLbPolicyBuilder>,
    /// The latest state (connectivity state and picker) published by the child.
    pub state: LbState,
    /// The child policy itself.
    policy: Box<DynLbPolicy>,
    /// The scheduler handed to `policy`, kept so its index can be updated or
    /// invalidated when children are reordered or removed.
    work_scheduler: Arc<ChildWorkScheduler>,
}
/// One entry of the desired child set passed to `ChildManager::update`.
pub(crate) struct ChildUpdate<'a, T> {
    /// Identifier of the child; combined with the builder's name it selects an
    /// existing child to reuse.
    pub child_identifier: T,
    /// Builder used to create the child when no matching child already exists.
    pub child_policy_builder: Arc<DynLbPolicyBuilder>,
    /// Resolver update and optional config to deliver to the child this round, or
    /// `None` to deliver nothing.
    pub child_update: Option<(ResolverUpdate, Option<&'a DynLbConfig>)>,
}
impl<T> ChildManager<T>
where
    T: Debug + PartialEq + Hash + Eq + Send + Sync + 'static,
{
    /// Creates a `ChildManager` with no children.
    ///
    /// `runtime` is passed to every child policy the manager builds, and
    /// `work_scheduler` is the parent's scheduler that children's work requests are
    /// forwarded to.
    pub fn new(runtime: GrpcRuntime, work_scheduler: Arc<dyn WorkScheduler>) -> Self {
        Self {
            subchannel_to_child_idx: Default::default(),
            children: Default::default(),
            pending_work: Default::default(),
            runtime,
            work_scheduler,
            updated: false,
        }
    }

    /// Returns an iterator over the managed children, in the order established by the
    /// most recent `update` / `retain_children` call.
    pub fn children(&self) -> impl Iterator<Item = &Child<T>> {
        self.children.iter()
    }

    /// Aggregates the children's connectivity states into one state.
    ///
    /// Precedence: `Ready` if any child is ready, otherwise `Connecting` if any child
    /// is connecting, otherwise `Idle` if any child is idle, otherwise
    /// `TransientFailure`. That last case also covers having no children at all.
    pub fn aggregate_states(&self) -> ConnectivityState {
        let mut is_connecting = false;
        let mut is_idle = false;
        for child in &self.children {
            match child.state.connectivity_state {
                ConnectivityState::Ready => {
                    return ConnectivityState::Ready;
                }
                ConnectivityState::Connecting => {
                    is_connecting = true;
                }
                ConnectivityState::Idle => {
                    is_idle = true;
                }
                ConnectivityState::TransientFailure => {}
            }
        }
        if is_connecting {
            ConnectivityState::Connecting
        } else if is_idle {
            ConnectivityState::Idle
        } else {
            ConnectivityState::TransientFailure
        }
    }

    /// Folds the side effects a child produced through a `WrappedController` back into
    /// the manager.
    ///
    /// Every subchannel the child created is recorded as owned by `child_idx`. If the
    /// child published a picker, it becomes the child's state and the manager is marked
    /// as updated.
    fn resolve_child_controller(
        &mut self,
        channel_controller: WrappedController,
        child_idx: usize,
    ) {
        for csc in channel_controller.created_subchannels {
            self.subchannel_to_child_idx
                .insert((&csc).into(), child_idx);
        }
        if let Some(state) = channel_controller.picker_update {
            self.children[child_idx].state = state;
            self.updated = true;
        };
    }

    /// Returns whether any child published a new picker since the previous call, and
    /// clears that flag.
    pub fn child_updated(&mut self) -> bool {
        mem::take(&mut self.updated)
    }

    /// Keeps only the existing children whose `(builder name, identifier)` appears in
    /// `ids_builders`, reordered to match that sequence.
    ///
    /// Entries that match no existing child are ignored; no new policy is built for
    /// them. Children that are not listed are dropped.
    pub fn retain_children(
        &mut self,
        ids_builders: impl IntoIterator<Item = (T, Arc<DynLbPolicyBuilder>)>,
    ) {
        self.reset_children(ids_builders, true);
    }

    /// Rebuilds `self.children` to match `ids_builders`.
    ///
    /// An existing child is reused when its `(builder name, identifier)` matches an
    /// entry. Its subchannel mappings, pending work and scheduler index are moved to
    /// the child's new position. When `retain_only` is false, an unmatched entry gets
    /// a freshly built child. When it is true, an unmatched entry is skipped. Old
    /// children that are not reused have their schedulers invalidated and are dropped.
    fn reset_children(
        &mut self,
        ids_builders: impl IntoIterator<Item = (T, Arc<DynLbPolicyBuilder>)>,
        retain_only: bool,
    ) {
        // The pending-work lock is held for the whole remap. `ChildWorkScheduler::schedule_work`
        // takes this same lock before reading its index, so a concurrent request is
        // recorded either under the old index (and carried over below) or under the new
        // one (or dropped if the child was removed).
        //
        // NOTE(review): `builder.build` below runs while this lock is held. If a policy
        // calls `schedule_work` synchronously from inside its builder, the std `Mutex`
        // re-lock will deadlock or panic. Confirm that builders never do that.
        let mut pending_work = self.pending_work.lock().unwrap();
        let old_pending_work = mem::take(&mut *pending_work);
        let old_children = mem::take(&mut self.children);
        let old_subchannel_child_map = mem::take(&mut self.subchannel_to_child_idx);

        // Group the old subchannels by the index of the child that created them.
        let mut old_child_subchannels: Vec<Vec<WeakSubchannel>> = Vec::new();
        old_child_subchannels.resize_with(old_children.len(), Vec::new);
        for (subchannel, old_idx) in old_subchannel_child_map {
            old_child_subchannels[old_idx].push(subchannel);
        }

        // Key the old children by (builder name, identifier). The identifier slot of the
        // value is reused to remember each child's old index.
        let mut old_children: HashMap<(&'static str, T), _> = old_children
            .into_iter()
            .enumerate()
            .map(|(old_idx, e)| {
                (
                    (e.builder.name(), e.identifier),
                    Child {
                        identifier: old_idx,
                        policy: e.policy,
                        builder: e.builder,
                        state: e.state,
                        work_scheduler: e.work_scheduler,
                    },
                )
            })
            .collect();

        for (identifier, builder) in ids_builders {
            // The child's index is the slot it is about to be pushed into. It must not
            // come from an enumerate counter over `ids_builders`: with `retain_only` set,
            // unmatched entries push nothing. Every later child would then be recorded
            // under an index past its real position.
            let new_idx = self.children.len();
            let k = (builder.name(), identifier);
            if let Some(old_child) = old_children.remove(&k) {
                let old_idx = old_child.identifier;
                for subchannel in mem::take(&mut old_child_subchannels[old_idx]) {
                    self.subchannel_to_child_idx.insert(subchannel, new_idx);
                }
                if old_pending_work.contains(&old_idx) {
                    pending_work.insert(new_idx);
                }
                *old_child.work_scheduler.idx.lock().unwrap() = Some(new_idx);
                self.children.push(Child {
                    builder,
                    identifier: k.1,
                    state: old_child.state,
                    policy: old_child.policy,
                    work_scheduler: old_child.work_scheduler,
                });
            } else if !retain_only {
                let work_scheduler = Arc::new(ChildWorkScheduler {
                    pending_work: self.pending_work.clone(),
                    idx: Mutex::new(Some(new_idx)),
                    work_scheduler: self.work_scheduler.clone(),
                });
                let policy = builder.build(LbPolicyOptions {
                    work_scheduler: work_scheduler.clone(),
                    runtime: self.runtime.clone(),
                });
                self.children.push(Child {
                    builder,
                    identifier: k.1,
                    state: LbState::initial(),
                    policy,
                    work_scheduler,
                });
            };
        }

        // Children that were not reused must stop enqueueing work under stale indices.
        for (_, old_child) in old_children {
            old_child.work_scheduler.invalidate();
        }
    }

    /// Replaces the child set with `child_updates` and delivers each entry's resolver
    /// update, if any, to its child.
    ///
    /// Every child receives its update even if another child fails.
    ///
    /// # Errors
    ///
    /// Returns the `"; "`-joined messages of every child whose `resolver_update`
    /// failed.
    pub fn update<'a>(
        &mut self,
        child_updates: impl IntoIterator<Item = ChildUpdate<'a, T>>,
        channel_controller: &mut dyn ChannelController,
    ) -> Result<(), String> {
        let (ids_builders, updates): (Vec<_>, Vec<_>) = child_updates
            .into_iter()
            .map(|e| ((e.child_identifier, e.child_policy_builder), e.child_update))
            .unzip();
        self.reset_children(ids_builders, false);

        // With `retain_only == false`, `reset_children` pushes exactly one child per
        // entry, in input order, so `updates[i]` belongs to `self.children[i]`.
        let mut errs = vec![];
        for (child_idx, child_update) in updates.into_iter().enumerate() {
            let Some((resolver_update, config)) = child_update else {
                continue;
            };
            let mut channel_controller = WrappedController::new(channel_controller);
            if let Err(err) = self.children[child_idx].policy.resolver_update(
                resolver_update,
                config,
                &mut channel_controller,
            ) {
                errs.push(err);
            }
            self.resolve_child_controller(channel_controller, child_idx);
        }
        Self::join_errors(errs)
    }

    /// Delivers the same resolver update and config to every current child.
    ///
    /// # Errors
    ///
    /// Returns the `"; "`-joined messages of every child whose `resolver_update`
    /// failed. Every child still receives the update.
    pub fn resolver_update(
        &mut self,
        resolver_update: ResolverUpdate,
        config: Option<&DynLbConfig>,
        channel_controller: &mut dyn ChannelController,
    ) -> Result<(), Box<dyn Error + Send + Sync>> {
        let mut errs = Vec::with_capacity(self.children.len());
        for child_idx in 0..self.children.len() {
            let mut channel_controller = WrappedController::new(channel_controller);
            if let Err(err) = self.children[child_idx].policy.resolver_update(
                resolver_update.clone(),
                config,
                &mut channel_controller,
            ) {
                errs.push(err);
            }
            self.resolve_child_controller(channel_controller, child_idx);
        }
        Self::join_errors(errs).map_err(Into::into)
    }

    /// Combines per-child errors into a single message by joining their `to_string()`
    /// forms with `"; "`. Returns `Ok(())` when `errs` is empty.
    fn join_errors<E: ToString>(errs: Vec<E>) -> Result<(), String> {
        if errs.is_empty() {
            return Ok(());
        }
        Err(errs
            .into_iter()
            .map(|e| e.to_string())
            .collect::<Vec<_>>()
            .join("; "))
    }

    /// Forwards a subchannel state change to the child that created `subchannel`.
    ///
    /// Updates for subchannels that no current child owns are ignored. This happens,
    /// for example, when the child was removed by `update`/`retain_children`, because
    /// `reset_children` drops the mappings of removed children.
    pub fn subchannel_update(
        &mut self,
        subchannel: Arc<dyn Subchannel>,
        state: &SubchannelState,
        channel_controller: &mut dyn ChannelController,
    ) {
        let Some(&child_idx) = self
            .subchannel_to_child_idx
            .get(&WeakSubchannel::new(&subchannel))
        else {
            return;
        };
        let policy = &mut self.children[child_idx].policy;
        let mut channel_controller = WrappedController::new(channel_controller);
        policy.subchannel_update(subchannel, state, &mut channel_controller);
        self.resolve_child_controller(channel_controller, child_idx);
    }

    /// Runs `work` on every child that requested it since the last call, and clears the
    /// pending set.
    pub fn work(&mut self, channel_controller: &mut dyn ChannelController) {
        // Take the set in one statement so the lock is released before any child runs;
        // a child may schedule more work from inside `work`.
        let child_idxes = mem::take(&mut *self.pending_work.lock().unwrap());
        for child_idx in child_idxes {
            let mut channel_controller = WrappedController::new(channel_controller);
            self.children[child_idx]
                .policy
                .work(&mut channel_controller);
            self.resolve_child_controller(channel_controller, child_idx);
        }
    }

    /// Asks every child to exit idle.
    pub fn exit_idle(&mut self, channel_controller: &mut dyn ChannelController) {
        for child_idx in 0..self.children.len() {
            let child = &mut self.children[child_idx];
            let mut channel_controller = WrappedController::new(channel_controller);
            child.policy.exit_idle(&mut channel_controller);
            self.resolve_child_controller(channel_controller, child_idx);
        }
    }
}
/// A `ChannelController` lent to one child policy for the duration of a single call.
///
/// Subchannel creation and resolution requests are forwarded to the real controller.
/// The created subchannels are also recorded, and the child's picker update is
/// captured rather than forwarded, so the `ChildManager` can attribute both to the
/// right child.
struct WrappedController<'a> {
    /// The parent's controller that calls are forwarded to.
    channel_controller: &'a mut dyn ChannelController,
    /// Every subchannel created through this wrapper, in creation order.
    created_subchannels: Vec<Arc<dyn Subchannel>>,
    /// The last picker update the child published through this wrapper, if any.
    picker_update: Option<LbState>,
}
impl<'a> WrappedController<'a> {
fn new(channel_controller: &'a mut dyn ChannelController) -> Self {
Self {
channel_controller,
created_subchannels: vec![],
picker_update: None,
}
}
}
impl ChannelController for WrappedController<'_> {
    /// Creates the subchannel through the wrapped controller and remembers it so the
    /// manager can map it back to this child.
    fn new_subchannel(&mut self, address: &Address) -> (Arc<dyn Subchannel>, SubchannelState) {
        let created = self.channel_controller.new_subchannel(address);
        self.created_subchannels.push(Arc::clone(&created.0));
        created
    }

    /// Captures the child's picker instead of forwarding it; a later call overwrites
    /// an earlier one.
    fn update_picker(&mut self, update: LbState) {
        self.picker_update = Some(update);
    }

    /// Passes the resolution request straight through to the wrapped controller.
    fn request_resolution(&mut self) {
        let inner = &mut self.channel_controller;
        inner.request_resolution();
    }
}
/// The `WorkScheduler` handed to each child policy.
///
/// Scheduling work records the child's current index in the shared pending-work set
/// and then notifies the parent scheduler. Once invalidated, scheduling is a no-op.
#[derive(Debug)]
struct ChildWorkScheduler {
    /// The parent's scheduler, notified whenever this child requests work.
    work_scheduler: Arc<dyn WorkScheduler>,
    /// Indices of children with pending work; shared with the owning `ChildManager`.
    pending_work: Arc<Mutex<HashSet<usize>>>,
    /// This child's current index in the manager's child list, or `None` after
    /// `invalidate`.
    idx: Mutex<Option<usize>>,
}
impl WorkScheduler for ChildWorkScheduler {
    /// Marks this child as having pending work and asks the parent to schedule work.
    /// Does nothing once the scheduler has been invalidated.
    fn schedule_work(&self) {
        // Lock order (pending_work, then idx) matches `ChildManager::reset_children`,
        // which holds `pending_work` while rewriting `idx`. Both guards stay held until
        // the end of the function.
        let mut pending = self.pending_work.lock().unwrap();
        let idx_guard = self.idx.lock().unwrap();
        if let Some(child_idx) = *idx_guard {
            pending.insert(child_idx);
            self.work_scheduler.schedule_work();
        }
    }
}
impl ChildWorkScheduler {
    /// Detaches this scheduler from its child so that later `schedule_work` calls do
    /// nothing.
    fn invalidate(&self) {
        let mut idx = self.idx.lock().unwrap();
        *idx = None;
    }
}
#[cfg(test)]
mod test {
    use std::collections::HashMap;
    use std::panic;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::sync::mpsc;
    use crate::client::ConnectivityState;
    use crate::client::load_balancing::ChannelController;
    use crate::client::load_balancing::DynLbConfig;
    use crate::client::load_balancing::DynLbPolicyBuilder;
    use crate::client::load_balancing::GLOBAL_LB_REGISTRY;
    use crate::client::load_balancing::LbState;
    use crate::client::load_balancing::QueuingPicker;
    use crate::client::load_balancing::Subchannel;
    use crate::client::load_balancing::SubchannelState;
    use crate::client::load_balancing::child_manager::ChildManager;
    use crate::client::load_balancing::child_manager::ChildUpdate;
    use crate::client::load_balancing::test_utils::StubPolicyFuncs;
    use crate::client::load_balancing::test_utils::TestChannelController;
    use crate::client::load_balancing::test_utils::TestEvent;
    use crate::client::load_balancing::test_utils::TestWorkScheduler;
    use crate::client::load_balancing::test_utils::{self};
    use crate::client::name_resolution::Address;
    use crate::client::name_resolution::Endpoint;
    use crate::client::name_resolution::ResolverUpdate;
    use crate::rt::default_runtime;

    /// Registers a stub policy named `test_name` and returns an event receiver, a fresh
    /// `ChildManager` and a channel controller that reports to that receiver.
    fn setup(
        funcs: StubPolicyFuncs,
        test_name: &'static str,
    ) -> (
        mpsc::Receiver<TestEvent>,
        ChildManager<Endpoint>,
        Box<dyn ChannelController>,
    ) {
        test_utils::reg_stub_policy(test_name, funcs);
        let (tx_events, rx_events) = mpsc::channel::<TestEvent>();
        let tcc = Box::new(TestChannelController {
            tx_events: tx_events.clone(),
        });
        let child_manager =
            ChildManager::new(default_runtime(), Arc::new(TestWorkScheduler { tx_events }));
        (rx_events, child_manager, tcc)
    }

    /// Builds `n` endpoints with `k` addresses each. Endpoint `i` uses the host
    /// `i+1.i+1.i+1.i+1` and ports `0..k`.
    fn create_n_endpoints_with_k_addresses(n: usize, k: usize) -> Vec<Endpoint> {
        let mut endpoints = Vec::with_capacity(n);
        for i in 0..n {
            let mut addresses: Vec<Address> = Vec::with_capacity(k);
            for j in 0..k {
                addresses.push(Address {
                    address: format!("{}.{}.{}.{}:{}", i + 1, i + 1, i + 1, i + 1, j).into(),
                    ..Default::default()
                });
            }
            endpoints.push(Endpoint {
                addresses,
                ..Default::default()
            })
        }
        endpoints
    }

    /// Creates one child per endpoint, each keyed by its endpoint and sent a resolver
    /// update containing only that endpoint.
    fn send_resolver_update_to_policy(
        child_manager: &mut ChildManager<Endpoint>,
        endpoints: Vec<Endpoint>,
        builder: Arc<DynLbPolicyBuilder>,
        tcc: &mut dyn ChannelController,
    ) -> Result<(), String> {
        let updates = endpoints.iter().map(|e| ChildUpdate {
            child_identifier: e.clone(),
            child_policy_builder: builder.clone(),
            child_update: Some((
                ResolverUpdate {
                    attributes: crate::attributes::Attributes::default(),
                    endpoints: Ok(vec![e.clone()]),
                    service_config: Ok(None),
                    resolution_note: None,
                },
                None,
            )),
        });
        child_manager.update(updates, tcc)
    }

    fn move_subchannel_to_state(
        child_manager: &mut ChildManager<Endpoint>,
        subchannel: Arc<dyn Subchannel>,
        tcc: &mut dyn ChannelController,
        state: &SubchannelState,
    ) {
        child_manager.subchannel_update(subchannel, state, tcc);
    }

    /// Expects exactly `number_of_subchannels` `NewSubchannel` events and returns the
    /// subchannels in the order they were received.
    fn verify_subchannel_creation_from_policy(
        rx_events: &mut mpsc::Receiver<TestEvent>,
        number_of_subchannels: usize,
    ) -> Vec<Arc<dyn Subchannel>> {
        let mut subchannels = Vec::new();
        for _ in 0..number_of_subchannels {
            match rx_events.recv().unwrap() {
                TestEvent::NewSubchannel(sc) => {
                    subchannels.push(sc);
                }
                other => panic!("unexpected event {:?}", other),
            };
        }
        subchannels
    }

    /// Stub policy for the aggregate-state tests. On a resolver update it checks that
    /// exactly one endpoint was delivered and creates a subchannel for its first
    /// address. On a subchannel update it publishes a queuing picker carrying the
    /// subchannel's connectivity state.
    fn create_verifying_funcs_for_aggregate_tests() -> StubPolicyFuncs {
        StubPolicyFuncs {
            resolver_update: Some(Arc::new(
                move |_data, update: ResolverUpdate, _, controller| {
                    let mut endpoints = update.endpoints.unwrap();
                    assert_eq!(endpoints.len(), 1);
                    let endpoint = endpoints.pop().unwrap();
                    let _subchannel = controller.new_subchannel(&endpoint.addresses[0]);
                    Ok(())
                },
            )),
            subchannel_update: Some(Arc::new(
                move |_data, _updated_subchannel, state, controller| {
                    controller.update_picker(LbState {
                        connectivity_state: state.connectivity_state,
                        picker: Arc::new(QueuingPicker {}),
                    });
                },
            )),
            ..Default::default()
        }
    }

    /// Creates one single-address child per entry of `states`, moves each child's only
    /// subchannel to the matching state, and asserts that the manager's aggregate
    /// connectivity state equals `expected`.
    fn verify_aggregate_state(
        test_name: &'static str,
        states: &[SubchannelState],
        expected: ConnectivityState,
    ) {
        let (mut rx_events, mut child_manager, mut tcc) =
            setup(create_verifying_funcs_for_aggregate_tests(), test_name);
        let builder: Arc<DynLbPolicyBuilder> = GLOBAL_LB_REGISTRY.get_policy(test_name).unwrap();
        let endpoints = create_n_endpoints_with_k_addresses(states.len(), 1);
        send_resolver_update_to_policy(
            &mut child_manager,
            endpoints.clone(),
            builder,
            tcc.as_mut(),
        )
        .unwrap();

        let mut subchannels = vec![];
        for endpoint in endpoints {
            subchannels.push(
                verify_subchannel_creation_from_policy(&mut rx_events, endpoint.addresses.len())
                    .remove(0),
            );
        }

        for (subchannel, state) in subchannels.into_iter().zip(states) {
            move_subchannel_to_state(&mut child_manager, subchannel, tcc.as_mut(), state);
        }
        assert_eq!(child_manager.aggregate_states(), expected);
    }

    #[test]
    fn childmanager_aggregate_state_is_ready_if_any_child_is_ready() {
        verify_aggregate_state(
            "stub-childmanager_aggregate_state_is_ready_if_any_child_is_ready",
            &[
                SubchannelState::transient_failure("n/a"),
                SubchannelState::idle(),
                SubchannelState::connecting(),
                SubchannelState::ready(),
            ],
            ConnectivityState::Ready,
        );
    }

    #[test]
    fn childmanager_aggregate_state_is_connecting_if_no_child_is_ready() {
        verify_aggregate_state(
            "stub-childmanager_aggregate_state_is_connecting_if_no_child_is_ready",
            &[
                SubchannelState::transient_failure("n/a"),
                SubchannelState::idle(),
                SubchannelState::connecting(),
            ],
            ConnectivityState::Connecting,
        );
    }

    #[test]
    fn childmanager_aggregate_state_is_idle_if_only_idle_and_failure() {
        verify_aggregate_state(
            "stub-childmanager_aggregate_state_is_idle_if_only_idle_and_failure",
            &[
                SubchannelState::transient_failure("n/a"),
                SubchannelState::idle(),
            ],
            ConnectivityState::Idle,
        );
    }

    #[test]
    fn childmanager_aggregate_state_is_transient_failure_if_all_children_are() {
        verify_aggregate_state(
            "stub-childmanager_aggregate_state_is_transient_failure_if_all_children_are",
            &[
                SubchannelState::transient_failure("n/a"),
                SubchannelState::transient_failure("n/a"),
            ],
            ConnectivityState::TransientFailure,
        );
    }

    /// Per-child state for the schedule-work stub: whether the child has requested work
    /// that has not run yet.
    struct ScheduleWorkStubData {
        requested_work: bool,
    }

    /// Stub policy named `name`. It schedules work on a resolver update when `name` is
    /// a key of the config map, and asserts that no earlier request is still
    /// outstanding.
    fn create_funcs_for_schedule_work_tests(name: &'static str) -> StubPolicyFuncs {
        StubPolicyFuncs {
            resolver_update: Some(Arc::new(move |data, _update, lbcfg, _controller| {
                if data.test_data.is_none() {
                    data.test_data = Some(Box::new(ScheduleWorkStubData {
                        requested_work: false,
                    }));
                }
                let stubdata = data
                    .test_data
                    .as_mut()
                    .unwrap()
                    .downcast_mut::<ScheduleWorkStubData>()
                    .unwrap();
                assert!(!stubdata.requested_work);
                if lbcfg
                    .unwrap()
                    .downcast_ref::<Mutex<HashMap<&'static str, ()>>>()
                    .unwrap()
                    .lock()
                    .unwrap()
                    .contains_key(name)
                {
                    stubdata.requested_work = true;
                    data.lb_policy_options.work_scheduler.schedule_work();
                }
                Ok(())
            })),
            work: Some(Arc::new(move |data, _controller| {
                println!("work called for {name}");
                let stubdata = data
                    .test_data
                    .as_mut()
                    .unwrap()
                    .downcast_mut::<ScheduleWorkStubData>()
                    .unwrap();
                stubdata.requested_work = false;
            })),
            ..Default::default()
        }
    }

    #[test]
    fn childmanager_schedule_work_works() {
        let name1 = "childmanager_schedule_work_works-one";
        let name2 = "childmanager_schedule_work_works-two";
        test_utils::reg_stub_policy(name1, create_funcs_for_schedule_work_tests(name1));
        test_utils::reg_stub_policy(name2, create_funcs_for_schedule_work_tests(name2));
        let (tx_events, rx_events) = mpsc::channel::<TestEvent>();
        let mut tcc = TestChannelController {
            tx_events: tx_events.clone(),
        };
        let names = [name1, name2];
        let mut child_manager =
            ChildManager::new(default_runtime(), Arc::new(TestWorkScheduler { tx_events }));
        // The shared config lists the names of the children that should request work.
        let cfg = Arc::new(Mutex::new(HashMap::<&'static str, ()>::new())) as DynLbConfig;
        let children = cfg
            .downcast_ref::<Mutex<HashMap<&'static str, ()>>>()
            .unwrap();
        children.lock().unwrap().insert(name1, ());
        let updates = names.iter().map(|name| {
            let child_policy_builder: Arc<DynLbPolicyBuilder> =
                GLOBAL_LB_REGISTRY.get_policy(name).unwrap();
            ChildUpdate {
                child_identifier: (),
                child_policy_builder,
                child_update: Some((ResolverUpdate::default(), Some(&cfg))),
            }
        });

        // Only the first child requests work.
        child_manager.update(updates.clone(), &mut tcc).unwrap();
        match rx_events.recv().unwrap() {
            TestEvent::ScheduleWork => {}
            other => panic!("unexpected event {:?}", other),
        };
        assert_eq!(child_manager.pending_work.lock().unwrap().len(), 1);
        let idx = *child_manager
            .pending_work
            .lock()
            .unwrap()
            .iter()
            .next()
            .unwrap();
        assert_eq!(child_manager.children[idx].builder.name(), name1);
        child_manager.work(&mut tcc);
        assert_eq!(child_manager.pending_work.lock().unwrap().len(), 0);

        // Now both children request work.
        children.lock().unwrap().insert(name2, ());
        child_manager.update(updates.clone(), &mut tcc).unwrap();
        match rx_events.recv().unwrap() {
            TestEvent::ScheduleWork => {}
            other => panic!("unexpected event {:?}", other),
        };
        assert_eq!(child_manager.pending_work.lock().unwrap().len(), 2);
        child_manager.work(&mut tcc);
        assert_eq!(child_manager.pending_work.lock().unwrap().len(), 0);

        // A further update succeeds only if `work` cleared every outstanding request.
        child_manager.update(updates, &mut tcc).unwrap();
    }
}