use std::collections::HashMap;
use std::collections::HashSet;
use std::hash::Hash;
use std::hash::Hasher;
use std::sync::Arc;
use std::sync::Mutex;
use crate::StatusCodeError;
use crate::StatusError;
use crate::client::load_balancing::ChannelController;
use crate::client::load_balancing::LbPolicy;
use crate::client::load_balancing::LbState;
use crate::client::load_balancing::PickResult;
use crate::client::load_balancing::Picker;
use crate::client::load_balancing::subchannel::ForwardingSubchannel;
use crate::client::load_balancing::subchannel::Subchannel;
use crate::client::load_balancing::subchannel::SubchannelState;
use crate::client::load_balancing::subchannel::WeakSubchannel;
use crate::client::name_resolution::Address;
use crate::client::name_resolution::ResolverUpdate;
use crate::core::RequestHeaders;
/// An `LbPolicy` wrapper that lets its delegate ask for the same address many
/// times while only one underlying ("internal") subchannel is created per
/// address.
///
/// Every `new_subchannel` call made by the delegate returns a distinct
/// external handle (`SharedSubchannel`) that wraps the shared internal
/// subchannel. Subchannel state updates for an internal subchannel are fanned
/// out to every live external handle, and pickers produced by the delegate are
/// wrapped so that picks resolve back to the internal subchannel.
#[derive(Debug)]
pub(crate) struct SubchannelSharing<T> {
    /// The wrapped policy that actually makes load-balancing decisions.
    delegate: T,
    /// Bookkeeping shared with every controller wrapper and external handle.
    inner: Arc<Mutex<Inner>>,
}
impl<T> SubchannelSharing<T> {
    /// Wraps `delegate` with empty subchannel-sharing bookkeeping.
    pub(crate) fn new(delegate: T) -> Self {
        let bookkeeping = Inner {
            subchannels_by_address: HashMap::new(),
            subchannels_int_to_ext: HashMap::new(),
        };
        Self {
            inner: Arc::new(Mutex::new(bookkeeping)),
            delegate,
        }
    }
}
/// Shared bookkeeping for one `SubchannelSharing` instance. It is guarded by a
/// mutex because external handles consult it from their `Drop` impl.
#[derive(Debug)]
struct Inner {
    /// The single internal subchannel created (via the real channel
    /// controller) for each address currently in use.
    subchannels_by_address: HashMap<Address, Arc<dyn Subchannel>>,
    /// For each internal subchannel: the most recently recorded state, plus
    /// weak references to every external handle that wraps it. The entry and
    /// the matching `subchannels_by_address` entry are removed together once
    /// no external handle is alive.
    subchannels_int_to_ext:
        HashMap<Arc<dyn Subchannel>, (SubchannelState, HashSet<WeakSubchannel>)>,
}
impl<T: LbPolicy> LbPolicy for SubchannelSharing<T> {
type LbConfig = T::LbConfig;
fn resolver_update(
&mut self,
update: ResolverUpdate,
config: Option<&T::LbConfig>,
channel_controller: &mut dyn ChannelController,
) -> Result<(), String> {
let mut channel_controller = SharingChannelController {
balancer_inner: self.inner.clone(),
delegate: channel_controller,
};
self.delegate
.resolver_update(update, config, &mut channel_controller)
}
fn subchannel_update(
&mut self,
subchannel: Arc<dyn Subchannel>,
state: &SubchannelState,
channel_controller: &mut dyn ChannelController,
) {
let mut channel_controller = SharingChannelController {
balancer_inner: self.inner.clone(),
delegate: channel_controller,
};
let mut inner = self.inner.lock().unwrap();
let Some((old_state, subchannel_set)) = inner.subchannels_int_to_ext.get_mut(&subchannel)
else {
return;
};
*old_state = state.clone();
let ext_subchannels: Vec<_> = subchannel_set
.iter()
.filter_map(|weak| weak.upgrade())
.collect();
drop(inner);
for ext in ext_subchannels {
self.delegate
.subchannel_update(ext, state, &mut channel_controller)
}
}
fn work(&mut self, channel_controller: &mut dyn ChannelController) {
let mut channel_controller = SharingChannelController {
balancer_inner: self.inner.clone(),
delegate: channel_controller,
};
self.delegate.work(&mut channel_controller);
}
fn exit_idle(&mut self, channel_controller: &mut dyn ChannelController) {
let mut channel_controller = SharingChannelController {
balancer_inner: self.inner.clone(),
delegate: channel_controller,
};
self.delegate.exit_idle(&mut channel_controller);
}
}
/// External handle handed to the delegate policy in place of a real
/// subchannel. Several handles may wrap the same internal subchannel.
#[derive(Debug)]
struct SharedSubchannel {
    /// The shared internal subchannel created by the real channel controller.
    delegate: Arc<dyn Subchannel>,
    /// Bookkeeping this handle unregisters itself from when dropped.
    balancer_inner: Arc<Mutex<Inner>>,
}
/// Handles compare equal when they wrap equal internal subchannels.
impl PartialEq for SharedSubchannel {
    fn eq(&self, other: &Self) -> bool {
        self.delegate == other.delegate
    }
}

impl Eq for SharedSubchannel {}

/// Hashes only the wrapped internal subchannel, matching `PartialEq`.
impl Hash for SharedSubchannel {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.delegate, state);
    }
}

/// Exposes the wrapped internal subchannel via `ForwardingSubchannel`.
impl ForwardingSubchannel for SharedSubchannel {
    fn delegate(&self) -> &Arc<dyn Subchannel> {
        &self.delegate
    }
}
impl Drop for SharedSubchannel {
    /// Unregisters this handle. When no live handle for the internal
    /// subchannel remains, both bookkeeping maps forget it, so the next
    /// request for its address creates a fresh internal subchannel.
    fn drop(&mut self) {
        let mut guard = self.balancer_inner.lock().unwrap();
        let (_, live_handles) = guard
            .subchannels_int_to_ext
            .get_mut(&self.delegate)
            .expect("should always find internal subchannel");
        // Handles live inside an `Arc`, whose strong count is already zero when
        // `drop` runs. This sweep therefore also discards the entry for `self`.
        live_handles.retain(|weak| weak.strong_count() != 0);
        if !live_handles.is_empty() {
            return;
        }
        guard.subchannels_int_to_ext.remove(&self.delegate);
        guard
            .subchannels_by_address
            .remove(&self.delegate.address());
    }
}
/// Channel controller handed to the delegate policy for the duration of one
/// call. It intercepts subchannel creation and picker updates, and forwards
/// everything else to the real controller.
struct SharingChannelController<'a> {
    /// Bookkeeping shared with the owning `SubchannelSharing`.
    balancer_inner: Arc<Mutex<Inner>>,
    /// The real channel controller supplied by the parent.
    delegate: &'a mut dyn ChannelController,
}
impl<'a> ChannelController for SharingChannelController<'a> {
    /// Returns a fresh external handle for `address`, backed by the single
    /// internal subchannel shared by every handle for that address.
    ///
    /// The internal subchannel is created through the wrapped controller only
    /// when no tracked internal subchannel exists for `address`. The returned
    /// state is the last state recorded for the internal subchannel, so late
    /// callers observe its current state rather than a fresh one.
    fn new_subchannel(&mut self, address: &Address) -> (Arc<dyn Subchannel>, SubchannelState) {
        // The lock is held across lookup-or-create so that two callers cannot
        // create two internal subchannels for the same address.
        let mut inner = self.balancer_inner.lock().unwrap();
        let mut new_state = None;
        let int_subchannel = inner
            .subchannels_by_address
            .entry(address.clone())
            .or_insert_with(|| {
                let (new_sc, state) = self.delegate.new_subchannel(address);
                new_state = Some(state);
                new_sc
            })
            .clone();
        // Resolve the bookkeeping entry *before* creating the external handle.
        // If the invariant below is violated, we panic while no
        // `SharedSubchannel` exists yet. Otherwise unwinding would drop the new
        // handle, whose `Drop` re-locks `balancer_inner` on this thread while
        // the guard is still held (panic or deadlock per `Mutex::lock` docs).
        let (cached_state, ext_handles) = inner
            .subchannels_int_to_ext
            .entry(int_subchannel.clone())
            .or_insert_with(|| {
                (
                    new_state.expect(
                        "subchannels_by_address and subchannels_int_to_ext must stay in sync",
                    ),
                    HashSet::new(),
                )
            });
        let ext_subchannel: Arc<dyn Subchannel> = Arc::new(SharedSubchannel {
            delegate: int_subchannel,
            balancer_inner: self.balancer_inner.clone(),
        });
        ext_handles.insert((&ext_subchannel).into());
        (ext_subchannel, cached_state.clone())
    }

    /// Wraps the delegate's picker so that picks of external handles are
    /// translated back to internal subchannels, then forwards the update.
    fn update_picker(&mut self, mut update: LbState) {
        update.picker = UnwrapPicker::new_arc(update.picker);
        self.delegate.update_picker(update);
    }

    /// Forwards re-resolution requests unchanged.
    fn request_resolution(&mut self) {
        self.delegate.request_resolution();
    }
}
/// Picker wrapped around the delegate's picker before it is forwarded to the
/// parent controller. It maps the external handles chosen by the delegate back
/// to the shared internal subchannels.
#[derive(Debug)]
struct UnwrapPicker {
    /// The picker produced by the delegate policy.
    delegate: Arc<dyn Picker>,
}

impl UnwrapPicker {
    /// Wraps `delegate` in a reference-counted `UnwrapPicker`.
    fn new_arc(delegate: Arc<dyn Picker>) -> Arc<Self> {
        let wrapper = UnwrapPicker { delegate };
        Arc::new(wrapper)
    }
}
impl Picker for UnwrapPicker {
    /// Delegates the pick. A successful pick must name a `SharedSubchannel`,
    /// which is replaced by the internal subchannel it wraps. Any other
    /// subchannel type yields an `Internal` failure. Non-`Pick` results pass
    /// through untouched.
    fn pick(&self, request: &RequestHeaders) -> PickResult {
        match self.delegate.pick(request) {
            PickResult::Pick(mut pick) => match pick.subchannel.downcast_ref::<SharedSubchannel>() {
                Some(shared) => {
                    pick.subchannel = shared.delegate.clone();
                    PickResult::Pick(pick)
                }
                None => PickResult::Fail(StatusError::new(
                    StatusCodeError::Internal,
                    format!(
                        "received unexpected subchannel type: {:?}",
                        pick.subchannel.type_id()
                    ),
                )),
            },
            other => other,
        }
    }
}
#[cfg(test)]
mod tests {
    use std::fmt::Debug;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::sync::mpsc;
    use super::*;
    use crate::client::ConnectivityState;
    use crate::client::load_balancing::LbPolicy;
    use crate::client::load_balancing::LbPolicyOptions;
    use crate::client::load_balancing::Pick;
    use crate::client::load_balancing::PickResult;
    use crate::client::load_balancing::Picker;
    use crate::client::load_balancing::subchannel::SubchannelState;
    use crate::client::load_balancing::test_utils::StubPolicy;
    use crate::client::load_balancing::test_utils::StubPolicyFuncs;
    use crate::client::load_balancing::test_utils::TestChannelController;
    use crate::client::load_balancing::test_utils::TestEvent;
    use crate::client::load_balancing::test_utils::TestWorkScheduler;
    use crate::client::load_balancing::test_utils::new_request_headers;
    use crate::client::name_resolution::Address;
    use crate::client::name_resolution::ResolverUpdate;
    use crate::metadata::MetadataMap;
    use crate::rt::default_runtime;

    fn test_lb_policy_options(tx_events: mpsc::Sender<TestEvent>) -> LbPolicyOptions {
        LbPolicyOptions {
            work_scheduler: Arc::new(TestWorkScheduler { tx_events }),
            runtime: default_runtime(),
        }
    }

    /// A single request yields an external handle wrapping the internal
    /// subchannel created by the real controller.
    #[test]
    fn test_single_subchannel() {
        let (tx_events, rx_events) = mpsc::channel();
        let mut cc = TestChannelController {
            tx_events: tx_events.clone(),
        };
        let sc_out = Arc::new(Mutex::new(None));
        let sc_out_clone = sc_out.clone();
        let mock = StubPolicy::new(
            StubPolicyFuncs {
                work: Some(Arc::new(move |_data, cc| {
                    let addr = Address {
                        address: "127.0.0.1:80".to_string().into(),
                        ..Default::default()
                    };
                    let sc = cc.new_subchannel(&addr).0;
                    *sc_out_clone.lock().unwrap() = Some(sc);
                })),
                ..Default::default()
            },
            test_lb_policy_options(tx_events.clone()),
        );
        let mut sharing = SubchannelSharing::new(mock);
        sharing.work(&mut cc);
        let event = rx_events.recv().unwrap();
        let TestEvent::NewSubchannel(internal_sc) = event else {
            panic!("expected NewSubchannel")
        };
        let external_sc = sc_out.lock().unwrap().take().unwrap();
        let shared = external_sc.downcast_ref::<SharedSubchannel>().unwrap();
        assert!(Arc::ptr_eq(&shared.delegate, &internal_sc));
    }

    /// Two requests for the same address share one internal subchannel but
    /// receive distinct external handles.
    #[test]
    fn test_multiple_subchannels_same_address() {
        let (tx_events, rx_events) = mpsc::channel();
        let mut cc = TestChannelController {
            tx_events: tx_events.clone(),
        };
        let sc_out1 = Arc::new(Mutex::new(None));
        let sc_out1_clone = sc_out1.clone();
        let sc_out2 = Arc::new(Mutex::new(None));
        let sc_out2_clone = sc_out2.clone();
        let mock = StubPolicy::new(
            StubPolicyFuncs {
                work: Some(Arc::new(move |_data, cc| {
                    let addr = Address {
                        address: "127.0.0.1:80".to_string().into(),
                        ..Default::default()
                    };
                    *sc_out1_clone.lock().unwrap() = Some(cc.new_subchannel(&addr).0);
                    *sc_out2_clone.lock().unwrap() = Some(cc.new_subchannel(&addr).0);
                })),
                ..Default::default()
            },
            test_lb_policy_options(tx_events.clone()),
        );
        let mut sharing = SubchannelSharing::new(mock);
        sharing.work(&mut cc);
        let event = rx_events.recv().unwrap();
        let TestEvent::NewSubchannel(internal_sc) = event else {
            panic!("expected NewSubchannel")
        };
        // Only one internal subchannel may be created for the shared address.
        assert!(rx_events.try_recv().is_err());
        let external_sc1 = sc_out1.lock().unwrap().take().unwrap();
        let external_sc2 = sc_out2.lock().unwrap().take().unwrap();
        let shared1 = external_sc1.downcast_ref::<SharedSubchannel>().unwrap();
        let shared2 = external_sc2.downcast_ref::<SharedSubchannel>().unwrap();
        assert!(Arc::ptr_eq(&shared1.delegate, &internal_sc));
        assert!(Arc::ptr_eq(&shared2.delegate, &internal_sc));
        assert!(!Arc::ptr_eq(&external_sc1, &external_sc2));
    }

    /// Different addresses get different internal subchannels.
    #[test]
    fn test_multiple_subchannels_different_addresses() {
        let (tx_events, rx_events) = mpsc::channel();
        let mut cc = TestChannelController {
            tx_events: tx_events.clone(),
        };
        let sc_out1 = Arc::new(Mutex::new(None));
        let sc_out1_clone = sc_out1.clone();
        let sc_out2 = Arc::new(Mutex::new(None));
        let sc_out2_clone = sc_out2.clone();
        let mock = StubPolicy::new(
            StubPolicyFuncs {
                work: Some(Arc::new(move |_data, cc| {
                    let addr1 = Address {
                        address: "127.0.0.1:80".to_string().into(),
                        ..Default::default()
                    };
                    let addr2 = Address {
                        address: "127.0.0.2:80".to_string().into(),
                        ..Default::default()
                    };
                    *sc_out1_clone.lock().unwrap() = Some(cc.new_subchannel(&addr1).0);
                    *sc_out2_clone.lock().unwrap() = Some(cc.new_subchannel(&addr2).0);
                })),
                ..Default::default()
            },
            test_lb_policy_options(tx_events.clone()),
        );
        let mut sharing = SubchannelSharing::new(mock);
        sharing.work(&mut cc);
        let event1 = rx_events.recv().unwrap();
        let event2 = rx_events.recv().unwrap();
        assert!(matches!(event1, TestEvent::NewSubchannel(_)));
        assert!(matches!(event2, TestEvent::NewSubchannel(_)));
        assert!(rx_events.try_recv().is_err());
        let external_sc1 = sc_out1.lock().unwrap().take().unwrap();
        let external_sc2 = sc_out2.lock().unwrap().take().unwrap();
        let shared1 = external_sc1.downcast_ref::<SharedSubchannel>().unwrap();
        let shared2 = external_sc2.downcast_ref::<SharedSubchannel>().unwrap();
        assert!(!Arc::ptr_eq(&shared1.delegate, &shared2.delegate));
    }

    /// Dropping external handles stops updates to them. Once the last handle
    /// is gone, the internal subchannel is forgotten, and the next request
    /// creates a new one.
    #[test]
    fn test_subchannel_cleanup_on_drop() {
        let (tx_events, rx_events) = mpsc::channel();
        let mut cc = TestChannelController {
            tx_events: tx_events.clone(),
        };
        let update_calls = Arc::new(Mutex::new(0));
        let update_calls_clone = update_calls.clone();
        let sc_out1 = Arc::new(Mutex::new(None));
        let sc_out1_clone = sc_out1.clone();
        let sc_out2 = Arc::new(Mutex::new(None));
        let sc_out2_clone = sc_out2.clone();
        let sc_out3 = Arc::new(Mutex::new(None));
        let sc_out3_clone = sc_out3.clone();
        let work_calls = Arc::new(Mutex::new(0));
        let work_calls_clone = work_calls.clone();
        let mock = StubPolicy::new(
            StubPolicyFuncs {
                work: Some(Arc::new(move |_data, cc| {
                    let addr = Address {
                        address: "127.0.0.1:80".to_string().into(),
                        ..Default::default()
                    };
                    let mut num_calls = work_calls_clone.lock().unwrap();
                    if *num_calls == 0 {
                        *sc_out1_clone.lock().unwrap() = Some(cc.new_subchannel(&addr).0);
                        *sc_out2_clone.lock().unwrap() = Some(cc.new_subchannel(&addr).0);
                    } else if *num_calls == 1 {
                        *sc_out3_clone.lock().unwrap() = Some(cc.new_subchannel(&addr).0);
                    }
                    *num_calls += 1;
                })),
                subchannel_update: Some(Arc::new(move |_data, _sc, _state, _cc| {
                    *update_calls_clone.lock().unwrap() += 1;
                })),
                ..Default::default()
            },
            test_lb_policy_options(tx_events.clone()),
        );
        let mut sharing = SubchannelSharing::new(mock);
        sharing.work(&mut cc);
        let _ = rx_events.recv().unwrap();
        let external_sc1 = sc_out1.lock().unwrap().take().unwrap();
        let external_sc2 = sc_out2.lock().unwrap().take().unwrap();
        let internal_sc = external_sc1
            .downcast_ref::<SharedSubchannel>()
            .unwrap()
            .delegate
            .clone();
        let state = SubchannelState::idle();
        sharing.subchannel_update(internal_sc.clone(), &state, &mut cc);
        assert_eq!(*update_calls.lock().unwrap(), 2);
        drop(external_sc1);
        *update_calls.lock().unwrap() = 0;
        sharing.subchannel_update(internal_sc.clone(), &state, &mut cc);
        assert_eq!(*update_calls.lock().unwrap(), 1);
        // References: by-address map, int-to-ext key, external_sc2's delegate,
        // and the local `internal_sc`.
        assert_eq!(Arc::strong_count(&internal_sc), 4);
        drop(external_sc2);
        // Dropping the last handle removes both map entries and its delegate.
        assert_eq!(Arc::strong_count(&internal_sc), 1);
        *update_calls.lock().unwrap() = 0;
        sharing.subchannel_update(internal_sc.clone(), &state, &mut cc);
        assert_eq!(*update_calls.lock().unwrap(), 0);
        sharing.work(&mut cc);
        let event = rx_events.recv().unwrap();
        assert!(matches!(event, TestEvent::NewSubchannel(_)));
        let external_sc3 = sc_out3.lock().unwrap().take().unwrap();
        let shared_sc3 = external_sc3.downcast_ref::<SharedSubchannel>().unwrap();
        assert!(!Arc::ptr_eq(&shared_sc3.delegate, &internal_sc));
    }

    /// A state update for the internal subchannel reaches every live handle.
    #[test]
    fn test_subchannel_update_broadcasts() {
        let (tx_events, rx_events) = mpsc::channel();
        let mut cc = TestChannelController {
            tx_events: tx_events.clone(),
        };
        let update_calls = Arc::new(Mutex::new(0));
        let update_calls_clone = update_calls.clone();
        let sc_out1 = Arc::new(Mutex::new(None));
        let sc_out1_clone = sc_out1.clone();
        let sc_out2 = Arc::new(Mutex::new(None));
        let sc_out2_clone = sc_out2.clone();
        let mock = StubPolicy::new(
            StubPolicyFuncs {
                work: Some(Arc::new(move |_data, cc| {
                    let addr = Address {
                        address: "127.0.0.1:80".to_string().into(),
                        ..Default::default()
                    };
                    *sc_out1_clone.lock().unwrap() = Some(cc.new_subchannel(&addr).0);
                    *sc_out2_clone.lock().unwrap() = Some(cc.new_subchannel(&addr).0);
                })),
                subchannel_update: Some(Arc::new(move |_data, _sc, _state, _cc| {
                    *update_calls_clone.lock().unwrap() += 1;
                })),
                ..Default::default()
            },
            test_lb_policy_options(tx_events.clone()),
        );
        let mut sharing = SubchannelSharing::new(mock);
        sharing.work(&mut cc);
        let _ = rx_events.recv().unwrap();
        let external_sc1 = sc_out1.lock().unwrap().take().unwrap();
        let external_sc2 = sc_out2.lock().unwrap().take().unwrap();
        let internal_sc = external_sc1
            .downcast_ref::<SharedSubchannel>()
            .unwrap()
            .delegate
            .clone();
        let state = SubchannelState::idle();
        sharing.subchannel_update(internal_sc.clone(), &state, &mut cc);
        assert_eq!(*update_calls.lock().unwrap(), 2);
        drop(external_sc1);
        sharing.subchannel_update(internal_sc, &state, &mut cc);
        assert_eq!(*update_calls.lock().unwrap(), 3);
    }

    /// Picks of external handles are translated to the internal subchannel.
    #[test]
    fn test_picker_unwraps_shared_subchannel() {
        let (tx_events, rx_events) = mpsc::channel();
        let mut cc = TestChannelController {
            tx_events: tx_events.clone(),
        };
        let sc_out = Arc::new(Mutex::new(None));
        let sc_out_clone = sc_out.clone();
        let mock = StubPolicy::new(
            StubPolicyFuncs {
                work: Some(Arc::new(move |_data, cc| {
                    let addr = Address {
                        address: "127.0.0.1:80".to_string().into(),
                        ..Default::default()
                    };
                    let sc = cc.new_subchannel(&addr).0;
                    *sc_out_clone.lock().unwrap() = Some(sc.clone());
                    #[derive(Debug)]
                    struct MockPicker {
                        sc: Arc<dyn Subchannel>,
                    }
                    impl Picker for MockPicker {
                        fn pick(&self, _req: &RequestHeaders) -> PickResult {
                            PickResult::Pick(Pick {
                                subchannel: self.sc.clone(),
                                metadata: MetadataMap::new(),
                                on_complete: None,
                            })
                        }
                    }
                    cc.update_picker(LbState {
                        connectivity_state: ConnectivityState::Ready,
                        picker: Arc::new(MockPicker { sc }),
                    });
                })),
                ..Default::default()
            },
            test_lb_policy_options(tx_events.clone()),
        );
        let mut sharing = SubchannelSharing::new(mock);
        sharing.work(&mut cc);
        let _ = rx_events.recv().unwrap();
        let event = rx_events.recv().unwrap();
        let TestEvent::UpdatePicker(state) = event else {
            panic!("expected UpdatePicker")
        };
        let req = new_request_headers();
        let result = state.picker.pick(&req);
        let PickResult::Pick(pick) = result else {
            panic!("expected Pick")
        };
        let external_sc = sc_out.lock().unwrap().take().unwrap();
        let shared = external_sc.downcast_ref::<SharedSubchannel>().unwrap();
        assert!(Arc::ptr_eq(&pick.subchannel, &shared.delegate));
    }

    /// resolver_update, work, exit_idle and request_resolution pass through.
    #[test]
    fn test_delegates_other_methods() {
        let (tx_events, rx_events) = mpsc::channel();
        let mut cc = TestChannelController {
            tx_events: tx_events.clone(),
        };
        let called = Arc::new(Mutex::new(vec![]));
        let mock = StubPolicy::new(
            StubPolicyFuncs {
                resolver_update: Some(Arc::new({
                    let called_clone = called.clone();
                    move |_data, _update, _config, _cc| {
                        called_clone.lock().unwrap().push("resolver_update");
                        Ok(())
                    }
                })),
                work: Some(Arc::new({
                    let called_clone = called.clone();
                    move |_data, cc| {
                        called_clone.lock().unwrap().push("work");
                        cc.request_resolution();
                    }
                })),
                exit_idle: Some(Arc::new({
                    let called_clone = called.clone();
                    move |_data, _cc| called_clone.lock().unwrap().push("exit_idle")
                })),
                ..Default::default()
            },
            test_lb_policy_options(tx_events.clone()),
        );
        let mut sharing = SubchannelSharing::new(mock);
        let update = ResolverUpdate::default();
        sharing.resolver_update(update, None, &mut cc).unwrap();
        sharing.work(&mut cc);
        sharing.exit_idle(&mut cc);
        assert_eq!(
            *called.lock().unwrap(),
            vec!["resolver_update", "work", "exit_idle"]
        );
        let event = rx_events.recv().unwrap();
        assert!(matches!(event, TestEvent::RequestResolution));
    }

    /// Creating (and immediately dropping) a subchannel from inside the
    /// delegate's subchannel_update must not deadlock on the shared mutex. The
    /// test passes simply by returning.
    #[test]
    fn test_subchannel_update_deadlock() {
        let (tx_events, rx_events) = mpsc::channel();
        let mut cc = TestChannelController {
            tx_events: tx_events.clone(),
        };
        let sc_out1 = Arc::new(Mutex::new(None));
        let sc_out1_clone = sc_out1.clone();
        let mock = StubPolicy::new(
            StubPolicyFuncs {
                work: Some(Arc::new(move |_data, cc| {
                    let addr = Address {
                        address: "127.0.0.1:80".to_string().into(),
                        ..Default::default()
                    };
                    *sc_out1_clone.lock().unwrap() = Some(cc.new_subchannel(&addr).0);
                })),
                subchannel_update: Some(Arc::new(move |_data, _sc, _state, cc| {
                    let addr = Address {
                        address: "127.0.0.2:80".to_string().into(),
                        ..Default::default()
                    };
                    cc.new_subchannel(&addr);
                })),
                ..Default::default()
            },
            test_lb_policy_options(tx_events.clone()),
        );
        let mut sharing = SubchannelSharing::new(mock);
        sharing.work(&mut cc);
        let event = rx_events.recv().unwrap();
        let TestEvent::NewSubchannel(_int_sc) = event else {
            panic!("expected NewSubchannel")
        };
        let external_sc = sc_out1.lock().unwrap().take().unwrap();
        let internal_sc = external_sc
            .downcast_ref::<SharedSubchannel>()
            .unwrap()
            .delegate
            .clone();
        let state = SubchannelState::idle();
        sharing.subchannel_update(internal_sc.clone(), &state, &mut cc);
    }

    /// A handle created later for an existing internal subchannel reports the
    /// most recently recorded state, not a fresh one.
    #[test]
    fn test_new_subchannel_state() {
        let (tx_events, rx_events) = mpsc::channel();
        let mut cc = TestChannelController {
            tx_events: tx_events.clone(),
        };
        let (tx_work, rx_work) =
            mpsc::channel::<Box<dyn FnOnce(&mut dyn ChannelController) + Send>>();
        let rx_work = Mutex::new(rx_work);
        let mock = StubPolicy::new(
            StubPolicyFuncs {
                work: Some(Arc::new(move |_data, cc| {
                    (rx_work.lock().unwrap().recv().unwrap())(cc);
                })),
                ..Default::default()
            },
            test_lb_policy_options(tx_events.clone()),
        );
        let mut sharing = SubchannelSharing::new(mock);
        let addr = Address {
            address: "127.0.0.2:80".to_string().into(),
            ..Default::default()
        };
        // Keeps the first external handle alive so the internal subchannel,
        // and its recorded state, survive across the later `work` calls.
        let sc1 = Arc::new(Mutex::new(None));
        let sc1_clone = sc1.clone();
        let addr_clone = addr.clone();
        tx_work
            .send(Box::new(move |cc| {
                let (sc, state) = cc.new_subchannel(&addr_clone);
                assert_eq!(state.connectivity_state, ConnectivityState::Idle);
                *sc1_clone.lock().unwrap() = Some(sc);
            }))
            .unwrap();
        sharing.work(&mut cc);
        let event = rx_events.recv().unwrap();
        let TestEvent::NewSubchannel(int_sc) = event else {
            panic!("expected NewSubchannel")
        };
        sharing.subchannel_update(int_sc.clone(), &SubchannelState::connecting(), &mut cc);
        let addr_clone = addr.clone();
        tx_work
            .send(Box::new(move |cc| {
                let (_sc, state) = cc.new_subchannel(&addr_clone);
                assert_eq!(state.connectivity_state, ConnectivityState::Connecting);
            }))
            .unwrap();
        sharing.work(&mut cc);
        sharing.subchannel_update(int_sc.clone(), &SubchannelState::ready(), &mut cc);
        let addr_clone = addr.clone();
        tx_work
            .send(Box::new(move |cc| {
                let (_sc, state) = cc.new_subchannel(&addr_clone);
                assert_eq!(state.connectivity_state, ConnectivityState::Ready);
            }))
            .unwrap();
        sharing.work(&mut cc);
    }
}