use std::any::Any;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;
use serde::Deserialize;
use serde::Serialize;
use tokio::sync::Notify;
use crate::client::load_balancing::ChannelController;
use crate::client::load_balancing::DynLbConfig;
use crate::client::load_balancing::DynLbPolicy;
use crate::client::load_balancing::LbPolicy;
use crate::client::load_balancing::LbPolicyBuilder;
use crate::client::load_balancing::LbPolicyOptions;
use crate::client::load_balancing::LbState;
use crate::client::load_balancing::ParsedJsonLbConfig;
use crate::client::load_balancing::Subchannel;
use crate::client::load_balancing::SubchannelState;
use crate::client::load_balancing::WorkScheduler;
use crate::client::load_balancing::subchannel::ForwardingSubchannel;
use crate::client::name_resolution::Address;
use crate::client::name_resolution::ResolverUpdate;
use crate::core::RequestHeaders;
pub(crate) fn new_request_headers() -> RequestHeaders {
RequestHeaders::default()
}
/// A fake subchannel for tests.
///
/// It has no underlying subchannel; connection attempts are reported as
/// [`TestEvent::Connect`] on `tx_connect` instead of being performed.
pub(crate) struct TestSubchannel {
    address: Address,
    tx_connect: std::sync::mpsc::Sender<TestEvent>,
}

impl TestSubchannel {
    /// Creates a test subchannel for `address` that reports `connect` calls
    /// on `tx_connect`.
    pub fn new(address: Address, tx_connect: std::sync::mpsc::Sender<TestEvent>) -> Self {
        TestSubchannel { address, tx_connect }
    }
}
impl ForwardingSubchannel for TestSubchannel {
fn delegate(&self) -> &Arc<dyn Subchannel> {
panic!("unsupported operation on a test subchannel");
}
fn address(&self) -> Address {
self.address.clone()
}
fn connect(&self) {
println!("connect called for subchannel {}", self.address);
self.tx_connect
.send(TestEvent::Connect(self.address.clone()))
.unwrap();
}
}
impl Hash for TestSubchannel {
    /// Hashes only the address.
    ///
    /// Equality below is object identity, so two equal values are the same
    /// object and necessarily hash alike, which is what `Hash` requires.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        Hash::hash(&self.address, state);
    }
}

impl PartialEq for TestSubchannel {
    /// Identity comparison: equal only when both references point at the
    /// same `TestSubchannel` object, regardless of address.
    fn eq(&self, other: &Self) -> bool {
        std::ptr::eq(other, self)
    }
}

impl Eq for TestSubchannel {}
/// Events emitted over an mpsc channel by the test doubles in this module
/// (subchannel, channel controller and work scheduler), one per observed call.
pub(crate) enum TestEvent {
    NewSubchannel(Arc<dyn Subchannel>),
    UpdatePicker(LbState),
    RequestResolution,
    Connect(Address),
    ScheduleWork,
}

impl Debug for TestEvent {
    /// Prints the variant name plus a compact summary of its payload
    /// (address or connectivity state), not the full payload.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NewSubchannel(subchannel) => {
                write!(f, "NewSubchannel({})", subchannel.address())
            }
            Self::UpdatePicker(lb_state) => {
                write!(f, "UpdatePicker({})", lb_state.connectivity_state)
            }
            Self::Connect(address) => write!(f, "Connect({:?})", address.address),
            Self::RequestResolution => f.write_str("RequestResolution"),
            Self::ScheduleWork => f.write_str("ScheduleWork"),
        }
    }
}
/// A fake [`ChannelController`] that reports every call it receives as a
/// [`TestEvent`] sent on `tx_events`.
pub(crate) struct TestChannelController {
    pub(crate) tx_events: std::sync::mpsc::Sender<TestEvent>,
}

impl ChannelController for TestChannelController {
    /// Creates a [`TestSubchannel`] for `address` and emits
    /// [`TestEvent::NewSubchannel`]. Returns the subchannel together with
    /// `SubchannelState::idle()`.
    ///
    /// The subchannel gets a clone of `tx_events`, so its `connect` events
    /// arrive on the same channel as the controller's own events.
    ///
    /// # Panics
    /// If the receiver for `tx_events` has been dropped.
    fn new_subchannel(&mut self, address: &Address) -> (Arc<dyn Subchannel>, SubchannelState) {
        println!("new_subchannel called for address {}", address);
        let subchannel: Arc<dyn Subchannel> =
            Arc::new(TestSubchannel::new(address.clone(), self.tx_events.clone()));
        self.tx_events
            .send(TestEvent::NewSubchannel(subchannel.clone()))
            .unwrap();
        (subchannel, SubchannelState::idle())
    }

    /// Logs the new connectivity state and emits [`TestEvent::UpdatePicker`]
    /// carrying `update`.
    ///
    /// # Panics
    /// If the receiver for `tx_events` has been dropped.
    fn update_picker(&mut self, update: LbState) {
        println!("update_picker called with {}", update.connectivity_state);
        self.tx_events.send(TestEvent::UpdatePicker(update)).unwrap();
    }

    /// Emits [`TestEvent::RequestResolution`].
    ///
    /// # Panics
    /// If the receiver for `tx_events` has been dropped.
    fn request_resolution(&mut self) {
        self.tx_events.send(TestEvent::RequestResolution).unwrap();
    }
}
/// A fake [`WorkScheduler`] that emits [`TestEvent::ScheduleWork`] on
/// `tx_events` each time work is requested, rather than scheduling anything.
#[derive(Debug)]
pub(crate) struct TestWorkScheduler {
    pub(crate) tx_events: std::sync::mpsc::Sender<TestEvent>,
}

impl WorkScheduler for TestWorkScheduler {
    /// # Panics
    /// If the receiver for `tx_events` has been dropped.
    fn schedule_work(&self) {
        let event = TestEvent::ScheduleWork;
        self.tx_events.send(event).unwrap();
    }
}
/// Callback implementing [`StubPolicy`]'s `resolver_update`.
///
/// It receives the policy's [`StubPolicyData`] mutably, followed by the same
/// arguments as `LbPolicy::resolver_update`; its result is returned as-is.
type ResolverUpdateFn = Arc<
    dyn Fn(
            &mut StubPolicyData,
            ResolverUpdate,
            Option<&DynLbConfig>,
            &mut dyn ChannelController,
        ) -> Result<(), String>
        + Send
        + Sync,
>;
/// Callback implementing [`StubPolicy`]'s `subchannel_update`; receives the
/// policy's [`StubPolicyData`] mutably, then the trait method's arguments.
type SubchannelUpdateFn = Arc<
    dyn Fn(&mut StubPolicyData, Arc<dyn Subchannel>, &SubchannelState, &mut dyn ChannelController)
        + Send
        + Sync,
>;
/// Callback implementing [`StubPolicy`]'s `exit_idle`.
type ExitIdleFn = Arc<dyn Fn(&mut StubPolicyData, &mut dyn ChannelController) + Send + Sync>;
/// Callback implementing [`StubPolicy`]'s `work`.
type WorkFn = Arc<dyn Fn(&mut StubPolicyData, &mut dyn ChannelController) + Send + Sync>;
/// The set of callbacks that give a [`StubPolicy`] its behavior.
///
/// A callback left as `None` makes the matching `LbPolicy` method a no-op
/// (`resolver_update` then returns `Ok(())`).
#[derive(Clone, Default)]
pub(crate) struct StubPolicyFuncs {
    pub resolver_update: Option<ResolverUpdateFn>,
    pub subchannel_update: Option<SubchannelUpdateFn>,
    pub exit_idle: Option<ExitIdleFn>,
    pub work: Option<WorkFn>,
}

impl Debug for StubPolicyFuncs {
    /// `dyn Fn` closures have no `Debug` impl, so a fixed placeholder is
    /// printed instead of the fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("stub funcs")
    }
}
/// Mutable state handed to every [`StubPolicyFuncs`] callback.
#[derive(Debug)]
pub(crate) struct StubPolicyData {
    /// Options the policy was built with.
    pub lb_policy_options: LbPolicyOptions,
    /// Arbitrary per-test state for callbacks to stash; starts out as `None`.
    pub test_data: Option<Box<dyn Any + Send + Sync>>,
}

impl StubPolicyData {
    /// Wraps `lb_policy_options` with no test data attached yet.
    pub fn new(lb_policy_options: LbPolicyOptions) -> Self {
        let test_data = None;
        Self {
            lb_policy_options,
            test_data,
        }
    }
}
/// An [`LbPolicy`] whose behavior is supplied entirely by the callbacks in
/// [`StubPolicyFuncs`]. Each trait method forwards to its callback, if set,
/// and otherwise does nothing.
#[derive(Debug)]
pub(crate) struct StubPolicy {
    funcs: StubPolicyFuncs,
    data: StubPolicyData,
}

impl LbPolicy for StubPolicy {
    type LbConfig = DynLbConfig;

    /// Forwards to `funcs.resolver_update` and returns its result, or returns
    /// `Ok(())` when no callback is set.
    fn resolver_update(
        &mut self,
        update: ResolverUpdate,
        config: Option<&DynLbConfig>,
        channel_controller: &mut dyn ChannelController,
    ) -> Result<(), String> {
        // Callbacks are `Fn`, so a shared borrow of `funcs` is enough; it is
        // disjoint from the `&mut self.data` borrow handed to the callback.
        match &self.funcs.resolver_update {
            Some(f) => f(&mut self.data, update, config, channel_controller),
            None => Ok(()),
        }
    }

    /// Forwards to `funcs.subchannel_update`, if set.
    fn subchannel_update(
        &mut self,
        subchannel: Arc<dyn Subchannel>,
        state: &SubchannelState,
        channel_controller: &mut dyn ChannelController,
    ) {
        if let Some(f) = &self.funcs.subchannel_update {
            f(&mut self.data, subchannel, state, channel_controller);
        }
    }

    /// Forwards to `funcs.exit_idle`, if set.
    fn exit_idle(&mut self, channel_controller: &mut dyn ChannelController) {
        if let Some(f) = &self.funcs.exit_idle {
            f(&mut self.data, channel_controller);
        }
    }

    /// Forwards to `funcs.work`, if set.
    fn work(&mut self, channel_controller: &mut dyn ChannelController) {
        if let Some(f) = &self.funcs.work {
            f(&mut self.data, channel_controller);
        }
    }
}

impl StubPolicy {
    /// Creates a stub policy driven by `funcs`, with fresh
    /// [`StubPolicyData`] built from `options`.
    pub(crate) fn new(funcs: StubPolicyFuncs, options: LbPolicyOptions) -> Self {
        Self {
            funcs,
            data: StubPolicyData::new(options),
        }
    }
}
#[derive(Debug)]
pub(crate) struct StubPolicyBuilder {
name: &'static str,
funcs: StubPolicyFuncs,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub(super) struct MockConfig {
shuffle_address_list: Option<bool>,
}
impl LbPolicyBuilder for StubPolicyBuilder {
type LbPolicy = Box<DynLbPolicy>;
fn build(&self, options: LbPolicyOptions) -> Self::LbPolicy {
let data = StubPolicyData::new(options);
Box::new(StubPolicy {
funcs: self.funcs.clone(),
data,
})
}
fn name(&self) -> &'static str {
self.name
}
fn parse_config(&self, config: &ParsedJsonLbConfig) -> Result<Option<DynLbConfig>, String> {
let cfg: MockConfig = match config.convert_to() {
Ok(c) => c,
Err(e) => {
return Err(format!("failed to parse JSON config: {}", e));
}
};
Ok(Some(Arc::new(cfg)))
}
}
/// Registers a [`StubPolicyBuilder`] named `name`, driven by `funcs`, in
/// `super::GLOBAL_LB_REGISTRY`.
pub(crate) fn reg_stub_policy(name: &'static str, funcs: StubPolicyFuncs) {
    let builder = StubPolicyBuilder { name, funcs };
    super::GLOBAL_LB_REGISTRY.add_dyn_builder(Arc::new(builder))
}