use core::cmp::Ordering;
use std::cell::Cell;
use std::collections::BTreeMap;
use std::collections::VecDeque;
use std::fmt;
use std::sync::Arc;
use anyhow::anyhow;
use anyhow::bail;
use anyhow::Result;
use log::debug;
use log::trace;
use ordered_float::OrderedFloat;
use scx_utils::ravg::ravg_read;
use scx_utils::LoadAggregator;
use scx_utils::LoadLedger;
use sorted_vec::SortedVec;
use crate::bpf_intf;
use crate::bpf_skel::*;
use crate::stats::DomainStats;
use crate::stats::NodeStats;
use crate::DomainGroup;
/// Weight used in place of per-task / per-bucket weights whenever weighted
/// balancing is disabled (`lb_apply_weight == false`).
const DEFAULT_WEIGHT: f64 = bpf_intf::consts_LB_DEFAULT_WEIGHT as f64;
/// Passed as the final argument of every `ravg_read` call in this file;
/// presumably the number of fractional bits of the running-average encoding
/// (TODO confirm against `scx_utils::ravg`).
const RAVG_FRAC_BITS: u32 = bpf_intf::ravg_consts_RAVG_FRAC_BITS;
/// Samples `CLOCK_MONOTONIC` and flattens the result into a nanosecond count.
///
/// # Panics
///
/// Panics if `clock_gettime` returns a non-zero status.
fn now_monotonic() -> u64 {
    const NSEC_PER_SEC: u64 = 1_000_000_000;

    let mut ts = libc::timespec {
        tv_sec: 0,
        tv_nsec: 0,
    };
    let ret = unsafe { libc::clock_gettime(libc::CLOCK_MONOTONIC, &mut ts) };
    assert!(ret == 0);

    let whole_secs = ts.tv_sec as u64;
    let sub_nanos = ts.tv_nsec as u64;
    whole_secs * NSEC_PER_SEC + sub_nanos
}
/// Classification of a load entity relative to its target average load.
#[derive(Clone, Copy, Debug, PartialEq)]
enum BalanceState {
    /// Imbalance is within the tolerated band.
    Balanced,
    /// Carries more load than the target; a candidate for pushing load away.
    NeedsPush,
    /// Carries less load than the target; a candidate for receiving load.
    NeedsPull,
}

impl fmt::Display for BalanceState {
    /// Writes the fixed upper-case label associated with each state.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Balanced => "BALANCED",
            Self::NeedsPush => "OVER-LOADED",
            Self::NeedsPull => "UNDER-LOADED",
        };
        f.write_str(label)
    }
}
/// Generates `PartialEq`, `Eq`, `PartialOrd` and `Ord` for every listed type
/// by forwarding each comparison to the matching inherent helper on
/// `dyn LoadOrdered`, so the types are ordered purely by `get_load()`.
macro_rules! impl_ord_for_type {
    ($($target:ty),*) => {
        $(
            impl PartialEq for $target {
                fn eq(&self, rhs: &Self) -> bool {
                    <dyn LoadOrdered>::eq(self, rhs)
                }
            }

            impl Eq for $target {}

            impl PartialOrd for $target {
                fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
                    <dyn LoadOrdered>::partial_cmp(self, rhs)
                }
            }

            impl Ord for $target {
                fn cmp(&self, rhs: &Self) -> Ordering {
                    <dyn LoadOrdered>::cmp(self, rhs)
                }
            }
        )*
    };
}
/// Implemented by types that are compared and sorted by a single load value.
trait LoadOrdered {
/// Returns the load used as this value's sort key.
fn get_load(&self) -> OrderedFloat<f64>;
}
/// Comparison helpers shared by every `LoadOrdered` implementor; each one
/// compares nothing but the values returned by `get_load()`.
impl dyn LoadOrdered {
    /// Two values are equal when their loads are equal.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.get_load();
        let rhs = other.get_load();
        lhs == rhs
    }

    /// Partial ordering of the two loads.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let lhs = self.get_load();
        let rhs = other.get_load();
        lhs.partial_cmp(&rhs)
    }

    /// Total ordering of the two loads.
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        let lhs = self.get_load();
        let rhs = other.get_load();
        lhs.cmp(&rhs)
    }
}
/// Load bookkeeping for one balancing unit: its current load, the average it
/// is compared against, and the resulting balance classification.
#[derive(Debug, Clone)]
pub struct LoadEntity {
    /// Fraction of `load_avg` the absolute imbalance must exceed before the
    /// entity stops being considered balanced.
    cost_ratio: f64,
    /// Fraction of the absolute imbalance returned by `push_cutoff`.
    push_max_ratio: f64,
    /// Scale applied to the smaller absolute imbalance in `xfer_between`.
    xfer_ratio: f64,
    /// Current load of the entity.
    load_sum: OrderedFloat<f64>,
    /// Average load the entity is measured against.
    load_avg: f64,
    /// Net load added through `add_load` since construction.
    load_delta: f64,
    /// Classification derived from the last `rebalance`.
    bal_state: BalanceState,
}

impl LoadEntity {
    /// Builds an entity with the given ratios and loads and classifies it
    /// immediately.
    fn new(
        cost_ratio: f64,
        push_max_ratio: f64,
        xfer_ratio: f64,
        load_sum: f64,
        load_avg: f64,
    ) -> Self {
        let mut this = Self {
            load_sum: OrderedFloat(load_sum),
            load_avg,
            load_delta: 0.0f64,
            bal_state: BalanceState::Balanced,
            cost_ratio,
            push_max_ratio,
            xfer_ratio,
        };
        // A zero-sized update runs the classification for the starting load.
        this.add_load(0.0f64);
        this
    }

    /// Current load.
    pub fn load_sum(&self) -> f64 {
        self.load_sum.0
    }

    /// Average load this entity is compared against.
    pub fn load_avg(&self) -> f64 {
        self.load_avg
    }

    /// Signed distance of the current load from the average.
    pub fn imbal(&self) -> f64 {
        self.load_sum() - self.load_avg
    }

    /// Net load added via `add_load` so far.
    pub fn delta(&self) -> f64 {
        self.load_delta
    }

    /// Most recently computed balance classification.
    fn state(&self) -> BalanceState {
        self.bal_state
    }

    /// Replaces the current load with `new_load` and reclassifies the entity.
    fn rebalance(&mut self, new_load: f64) {
        self.load_sum = OrderedFloat(new_load);

        let imbal = self.imbal();
        let threshold = self.load_avg * self.cost_ratio;
        let outside_band = imbal.abs() > threshold;

        self.bal_state = match (outside_band, imbal > 0f64) {
            (false, _) => BalanceState::Balanced,
            (true, true) => BalanceState::NeedsPush,
            (true, false) => BalanceState::NeedsPull,
        };
    }

    /// Adds `delta` to the load, reclassifies, and records the delta.
    fn add_load(&mut self, delta: f64) {
        let updated = self.load_sum() + delta;
        self.rebalance(updated);
        self.load_delta += delta;
    }

    /// Absolute imbalance scaled by `push_max_ratio`.
    fn push_cutoff(&self) -> f64 {
        let magnitude = self.imbal().abs();
        magnitude * self.push_max_ratio
    }

    /// The smaller of the two absolute imbalances, scaled by this entity's
    /// `xfer_ratio`.
    fn xfer_between(&self, other: &LoadEntity) -> f64 {
        let mine = self.imbal().abs();
        let theirs = other.imbal().abs();
        mine.min(theirs) * self.xfer_ratio
    }
}
/// Snapshot of one task considered for migration, ordered by `load`.
#[derive(Debug)]
struct TaskInfo {
/// Raw pointer to the task's `task_ctx`, as read from a domain's
/// `active_tasks` array.
taskc_p: *mut types::task_ctx,
/// Duty cycle multiplied by the task weight in effect when it was sampled.
load: OrderedFloat<f64>,
/// Bitmask copied from `task_ctx::dom_mask`; bit `n` is tested before moving
/// the task to domain `n`.
dom_mask: u64,
/// Bitmask copied from `task_ctx::preferred_dom_mask`; consulted by the
/// "preferred domain" task filters.
preferred_dom_mask: u64,
/// Set once the task has been chosen for a transfer so it is not picked again.
migrated: Cell<bool>,
/// Copied from `task_ctx::is_kworker`; such tasks are skipped when
/// `skip_kworkers` is enabled.
is_kworker: bool,
}
impl LoadOrdered for TaskInfo {
fn get_load(&self) -> OrderedFloat<f64> {
self.load
}
}
// Order `TaskInfo` values by their load.
impl_ord_for_type!(TaskInfo);
/// A scheduling domain together with its load accounting and its lazily
/// populated list of migration candidates.
#[derive(Debug)]
struct Domain {
    /// Domain index.
    id: usize,
    /// Whether `tasks` has already been populated for this domain.
    queried_tasks: bool,
    /// Load accounting for the domain.
    load: LoadEntity,
    /// Candidate tasks, kept sorted by load.
    tasks: SortedVec<TaskInfo>,
}

impl Domain {
    const LOAD_IMBAL_HIGH_RATIO: f64 = 0.05;
    const LOAD_IMBAL_XFER_TARGET_RATIO: f64 = 0.50;
    const LOAD_IMBAL_PUSH_MAX_RATIO: f64 = 0.50;

    /// Creates a domain carrying `load_sum`, measured against `load_avg`,
    /// with an empty task list.
    fn new(id: usize, load_sum: f64, load_avg: f64) -> Self {
        let load = LoadEntity::new(
            Self::LOAD_IMBAL_HIGH_RATIO,
            Self::LOAD_IMBAL_PUSH_MAX_RATIO,
            Self::LOAD_IMBAL_XFER_TARGET_RATIO,
            load_sum,
            load_avg,
        );
        Self {
            id,
            queried_tasks: false,
            load,
            tasks: SortedVec::new(),
        }
    }

    /// Retargets `taskc` to `other` and shifts `load` from this domain's
    /// accounting to `other`'s.
    ///
    /// # Panics
    ///
    /// Panics if `other.id` does not fit in a `u32`.
    fn transfer_load(&mut self, load: f64, taskc: &mut types::task_ctx, other: &mut Domain) {
        trace!("XFER pid={} dom={}->{}", taskc.pid, self.id, other.id);

        taskc.target_dom = u32::try_from(other.id).unwrap();

        self.load.add_load(-load);
        other.load.add_load(load);
    }

    /// Load amount to aim for when moving work between this domain and `other`.
    fn xfer_between(&self, other: &Domain) -> f64 {
        let theirs = &other.load;
        self.load.xfer_between(theirs)
    }
}

impl LoadOrdered for Domain {
    /// Domains are ordered by their current load sum.
    fn get_load(&self) -> OrderedFloat<f64> {
        self.load.load_sum
    }
}

impl_ord_for_type!(Domain);
/// A NUMA node: aggregate load accounting plus the domains that belong to it.
#[derive(Debug)]
struct NumaNode {
    /// Node index.
    id: usize,
    /// Aggregate load accounting for the node.
    load: LoadEntity,
    /// Domains of this node, kept sorted by load.
    domains: SortedVec<Domain>,
}

impl NumaNode {
    const LOAD_IMBAL_HIGH_RATIO: f64 = 0.17;
    const LOAD_IMBAL_XFER_TARGET_RATIO: f64 = 0.50;
    const LOAD_IMBAL_PUSH_MAX_RATIO: f64 = 0.50;

    /// Creates an empty node with zero load, measured against `numa_load_avg`.
    fn new(id: usize, numa_load_avg: f64) -> Self {
        let load = LoadEntity::new(
            Self::LOAD_IMBAL_HIGH_RATIO,
            Self::LOAD_IMBAL_PUSH_MAX_RATIO,
            Self::LOAD_IMBAL_XFER_TARGET_RATIO,
            0.0f64,
            numa_load_avg,
        );
        Self {
            id,
            load,
            domains: SortedVec::new(),
        }
    }

    /// Creates a domain with the given load, attaches it to this node, and
    /// folds its load into the node total without touching the node's
    /// recorded delta.
    fn allocate_domain(&mut self, id: usize, load: f64, dom_load_avg: f64) {
        self.insert_domain(Domain::new(id, load, dom_load_avg));

        let node_total = self.load.load_sum() + load;
        self.load.rebalance(node_total);
    }

    /// Load amount to aim for when moving work between this node and `other`.
    fn xfer_between(&self, other: &NumaNode) -> f64 {
        let theirs = &other.load;
        self.load.xfer_between(theirs)
    }

    /// Inserts `domain` at its sorted position.
    fn insert_domain(&mut self, domain: Domain) {
        self.domains.insert(domain);
    }

    /// Applies `delta` to the node's load and records it in the node's delta.
    fn update_load(&mut self, delta: f64) {
        self.load.add_load(delta);
    }

    /// Builds the statistics record for this node, with one entry per domain.
    fn stats(&self) -> NodeStats {
        let node_load = &self.load;
        let mut stats = NodeStats::new(
            node_load.load_sum(),
            node_load.imbal(),
            node_load.delta(),
            BTreeMap::new(),
        );

        stats.doms.extend(self.domains.iter().map(|dom| {
            let load = &dom.load;
            (
                dom.id,
                DomainStats::new(load.load_sum(), load.imbal(), load.delta()),
            )
        }));

        stats
    }
}

impl LoadOrdered for NumaNode {
    /// Nodes are ordered by their current load sum.
    fn get_load(&self) -> OrderedFloat<f64> {
        self.load.load_sum
    }
}

impl_ord_for_type!(NumaNode);
/// Userspace load balancer: builds a node -> domain load hierarchy from BPF
/// data and, when enabled, retargets tasks between domains.
pub struct LoadBalancer<'a, 'b> {
/// BPF skeleton whose `rodata`/`bss` maps are read (and task contexts written).
skel: &'a mut BpfSkel<'b>,
/// Domain topology: supplies the node count, the domain-to-node mapping and
/// per-domain contexts.
dom_group: Arc<DomainGroup>,
/// When true, kworker tasks are never selected for migration.
skip_kworkers: bool,
/// Upper bound applied to task weights when weighted balancing is on; starts
/// at `LB_MAX_WEIGHT` and is replaced by the ledger's effective maximum weight.
infeas_threshold: f64,
/// NUMA nodes sorted by load; rebuilt on every `load_balance` call.
nodes: SortedVec<NumaNode>,
/// Whether task/bucket weights are applied; otherwise `DEFAULT_WEIGHT` is used.
lb_apply_weight: bool,
/// Whether `load_balance` migrates tasks or only rebuilds the hierarchy.
balance_load: bool,
}
// Compile-time check that the maximum weight divides evenly into the load
// buckets, which the integer arithmetic in `bucket_range` relies on.
const_assert_eq!(
bpf_intf::consts_LB_MAX_WEIGHT % bpf_intf::consts_LB_LOAD_BUCKETS,
0
);
impl<'a, 'b> LoadBalancer<'a, 'b> {
/// Creates a balancer over `skel` and `dom_group`.
///
/// The node list starts empty and the infeasible-weight threshold starts
/// at `LB_MAX_WEIGHT`.
pub fn new(
    skel: &'a mut BpfSkel<'b>,
    dom_group: Arc<DomainGroup>,
    skip_kworkers: bool,
    lb_apply_weight: bool,
    balance_load: bool,
) -> Self {
    let infeas_threshold = bpf_intf::consts_LB_MAX_WEIGHT as f64;
    let nodes = SortedVec::new();

    Self {
        skel,
        dom_group,
        skip_kworkers,
        infeas_threshold,
        nodes,
        lb_apply_weight,
        balance_load,
    }
}
/// Rebuilds the node/domain load hierarchy and, if balancing is enabled,
/// runs a balancing pass over it.
///
/// # Errors
///
/// Propagates any error from building the hierarchy or from balancing.
pub fn load_balance(&mut self) -> Result<()> {
    self.create_domain_hierarchy()?;

    if !self.balance_load {
        return Ok(());
    }
    self.perform_balancing()
}
/// Returns the statistics of every node currently held, keyed by node id.
pub fn get_stats(&self) -> BTreeMap<usize, NodeStats> {
    let mut per_node = BTreeMap::new();
    for node in self.nodes.iter() {
        per_node.insert(node.id, node.stats());
    }
    per_node
}
/// Rebuilds `self.nodes` from freshly aggregated load data.
///
/// With weighting enabled the ledger's weighted load sums are used and
/// `infeas_threshold` is refreshed; otherwise duty-cycle sums scaled by
/// `DEFAULT_WEIGHT` are used.
///
/// # Errors
///
/// Fails if load aggregation fails, if a domain has no NUMA node, or if a
/// domain maps to a node id outside `0..nr_nodes`.
fn create_domain_hierarchy(&mut self) -> Result<()> {
    let ledger = self.calculate_load_avgs()?;

    let dom_loads: Vec<f64>;
    let total_load: f64;
    if self.lb_apply_weight {
        self.infeas_threshold = ledger.effective_max_weight();
        dom_loads = ledger.dom_load_sums().to_vec();
        total_load = ledger.global_load_sum();
    } else {
        dom_loads = ledger
            .dom_dcycle_sums()
            .iter()
            .map(|&dcycle| DEFAULT_WEIGHT * dcycle)
            .collect();
        total_load = DEFAULT_WEIGHT * ledger.global_dcycle_sum();
    }

    let num_numa_nodes = self.dom_group.nr_nodes();
    let numa_load_avg = total_load / num_numa_nodes as f64;
    let dom_load_avg = total_load / dom_loads.len() as f64;

    let mut nodes: Vec<NumaNode> = Vec::with_capacity(num_numa_nodes);
    for id in 0..num_numa_nodes {
        nodes.push(NumaNode::new(id, numa_load_avg));
    }

    for (dom_id, &load) in dom_loads.iter().enumerate() {
        let numa_id = self
            .dom_group
            .dom_numa_id(&dom_id)
            .ok_or_else(|| anyhow!("Failed to get NUMA ID for domain {}", dom_id))?;

        // `nodes` holds exactly `num_numa_nodes` entries, so a failed lookup
        // means the node id is out of range.
        match nodes.get_mut(numa_id) {
            Some(node) => node.allocate_domain(dom_id, load, dom_load_avg),
            None => bail!("NUMA ID {} exceeds maximum {}", numa_id, num_numa_nodes),
        }
    }

    self.nodes = SortedVec::from_unsorted(nodes);
    Ok(())
}
/// Reads the running-average duty cycle of every load bucket of every
/// domain and aggregates them into a `LoadLedger`.
///
/// Buckets whose duty cycle is exactly zero are skipped.
///
/// # Errors
///
/// Propagates any error returned by `LoadAggregator::record_dom_load`.
///
/// # Panics
///
/// Panics if the skeleton's rodata is unavailable or a domain has no context.
fn calculate_load_avgs(&mut self) -> Result<LoadLedger> {
    const NUM_BUCKETS: u64 = bpf_intf::consts_LB_LOAD_BUCKETS as u64;

    // One timestamp for the whole pass so every bucket is decayed to the
    // same instant.
    let now_mono = now_monotonic();
    let load_half_life = self.skel.maps.rodata_data.as_ref().unwrap().load_half_life;

    // `bool` is `Copy`, so the flag is negated directly instead of cloned.
    let mut aggregator = LoadAggregator::new(self.dom_group.weight(), !self.lb_apply_weight);

    for (dom_id, dom) in self.dom_group.doms() {
        aggregator.init_domain(*dom_id);

        let dom_ctx = dom.ctx().unwrap();
        for bucket in 0..NUM_BUCKETS {
            let bucket_ctx = &dom_ctx.buckets[bucket as usize];
            let rd = &bucket_ctx.rd;
            let duty_cycle = ravg_read(
                rd.val,
                rd.val_at,
                rd.old,
                rd.cur,
                now_mono,
                load_half_life,
                RAVG_FRAC_BITS,
            );
            if duty_cycle == 0.0f64 {
                continue;
            }

            let weight = self.bucket_weight(bucket);
            aggregator.record_dom_load(*dom_id, weight, duty_cycle)?;
        }
    }

    Ok(aggregator.calculate())
}
/// Returns the inclusive `(min, max)` weight range covered by `bucket`.
///
/// # Panics
///
/// Panics if `bucket` is not below `LB_LOAD_BUCKETS`.
fn bucket_range(&self, bucket: u64) -> (f64, f64) {
    const MAX_WEIGHT: u64 = bpf_intf::consts_LB_MAX_WEIGHT as u64;
    const NUM_BUCKETS: u64 = bpf_intf::consts_LB_LOAD_BUCKETS as u64;
    const WEIGHT_PER_BUCKET: u64 = MAX_WEIGHT / NUM_BUCKETS;

    assert!(
        bucket < NUM_BUCKETS,
        "Invalid bucket {}, max {}",
        bucket,
        NUM_BUCKETS
    );

    let lower = 1 + (MAX_WEIGHT * bucket) / NUM_BUCKETS;
    let upper = lower + WEIGHT_PER_BUCKET - 1;
    (lower as f64, upper as f64)
}
/// Returns the weight used to represent `bucket`: the bucket's lower
/// bound plus half of `LB_WEIGHT_PER_BUCKET`, rounded up.
///
/// # Panics
///
/// Panics if `bucket` is out of range (see `bucket_range`).
fn bucket_weight(&self, bucket: u64) -> usize {
    const WEIGHT_PER_BUCKET: f64 = bpf_intf::consts_LB_WEIGHT_PER_BUCKET as f64;

    let (lower, _upper) = self.bucket_range(bucket);
    let midpoint = lower + WEIGHT_PER_BUCKET / 2.0f64;
    midpoint.ceil() as usize
}
/// Fills `dom.tasks` with the tasks recorded in the domain's `active_tasks`
/// array, each tagged with its current load (duty cycle times weight).
///
/// Runs at most once per `Domain`; later calls return immediately. As written
/// this function never returns `Err`.
///
/// # Panics
///
/// Panics if the skeleton's `bss` or `rodata` data is unavailable.
fn populate_tasks_by_load(&mut self, dom: &mut Domain) -> Result<()> {
if dom.queried_tasks {
return Ok(());
}
dom.queried_tasks = true;
const MAX_TPTRS: u64 = bpf_intf::consts_MAX_DOM_ACTIVE_TPTRS as u64;
// NOTE(review): dereferences the raw domain-context pointer stored in bss;
// assumes the pointer for `dom.id` is valid and not aliased mutably — TODO confirm.
let dom_ctx = unsafe { &mut *self.skel.maps.bss_data.as_mut().unwrap().dom_ctxs[dom.id] };
let active_tasks = &mut dom_ctx.active_tasks;
// Snapshot the window [read_idx, write_idx), then mark it consumed by
// advancing read_idx and bumping the generation counter.
let (mut ridx, widx) = (active_tasks.read_idx, active_tasks.write_idx);
active_tasks.read_idx = active_tasks.write_idx;
active_tasks.genn += 1;
// The array holds only MAX_TPTRS slots (indexed modulo MAX_TPTRS below), so
// keep just the newest MAX_TPTRS entries of the window.
if widx - ridx > MAX_TPTRS {
ridx = widx - MAX_TPTRS;
}
let load_half_life = self.skel.maps.rodata_data.as_ref().unwrap().load_half_life;
let now_mono = now_monotonic();
for idx in ridx..widx {
let taskc_p = active_tasks.tasks[(idx % MAX_TPTRS) as usize];
// NOTE(review): assumes every pointer in the window refers to a live
// `task_ctx` — TODO confirm against the BPF side.
let taskc = unsafe { &mut *taskc_p };
// Skip tasks that are no longer targeted at this domain.
if taskc.target_dom as usize != dom.id {
continue;
}
let rd = &taskc.dcyc_rd;
let mut load = ravg_read(
rd.val,
rd.val_at,
rd.old,
rd.cur,
now_mono,
load_half_life,
RAVG_FRAC_BITS,
);
// Weighted mode caps the task weight at the infeasible threshold;
// otherwise every task counts with the default weight.
let weight = if self.lb_apply_weight {
(taskc.weight as f64).min(self.infeas_threshold)
} else {
DEFAULT_WEIGHT
};
load *= weight;
dom.tasks.insert(TaskInfo {
taskc_p,
load: OrderedFloat(load),
dom_mask: taskc.dom_mask,
preferred_dom_mask: taskc.preferred_dom_mask,
migrated: Cell::new(false),
// NOTE(review): assumes `is_kworker` has been initialised by the time
// the task shows up here — TODO confirm.
is_kworker: unsafe { taskc.is_kworker.assume_init() },
});
}
Ok(())
}
/// Returns the first task yielded by `tasks_by_load`, or `None` if it
/// yields nothing.
fn find_first_candidate<'d, I>(tasks_by_load: I) -> Option<&'d TaskInfo>
where
    I: IntoIterator<Item = &'d TaskInfo>,
{
    let mut candidates = tasks_by_load.into_iter();
    candidates.next()
}
/// Tries to move one task from `push_dom` to `pull_dom`.
///
/// Two candidates are considered among the tasks that may run on
/// `pull_dom`, are not skipped kworkers, have not been migrated already and
/// pass `task_filter`: the heaviest task with load `<= to_xfer` and the
/// lightest task with load `>= to_xfer`. Whichever leaves the smaller
/// combined imbalance is moved, provided the move does not make the
/// combined imbalance worse.
///
/// Unlike a destructive filter, candidates that are ineligible for *this*
/// `pull_dom` stay in `push_dom.tasks`, so they remain available when a
/// different pull domain is tried later.
///
/// * `to_push` / `to_pull` - imbalances of the pushing and pulling side
///   (`to_pull` is used by absolute value).
/// * `task_filter` - extra predicate receiving the task and the pull domain id.
/// * `to_xfer` - load amount the transfer should ideally approximate.
///
/// Returns `Ok(Some(load))` with the moved task's load, or `Ok(None)` when
/// no suitable task exists.
///
/// # Errors
///
/// Propagates errors from `populate_tasks_by_load`.
///
/// # Panics
///
/// Panics if `pull_dom.id` does not fit in a `u32`.
fn try_find_move_task(
    &mut self,
    (push_dom, to_push): (&mut Domain, f64),
    (pull_dom, to_pull): (&mut Domain, f64),
    task_filter: impl Fn(&TaskInfo, u32) -> bool,
    to_xfer: f64,
) -> Result<Option<f64>> {
    let to_pull = to_pull.abs();
    let calc_new_imbal = |xfer: f64| (to_push - xfer).abs() + (to_pull - xfer).abs();

    self.populate_tasks_by_load(push_dom)?;

    let pull_dom_id: u32 = pull_dom.id.try_into().unwrap();
    let skip_kworkers = self.skip_kworkers;
    let target = OrderedFloat(to_xfer);

    // Evaluated lazily per task so the sorted task list is neither
    // consumed nor rebuilt.
    let eligible = |task: &TaskInfo| {
        task.dom_mask & (1 << pull_dom_id) != 0
            && !(skip_kworkers && task.is_kworker)
            && !task.migrated.get()
            && task_filter(task, pull_dom_id)
    };

    // `tasks` is sorted ascending by load: scanning backwards finds the
    // heaviest task at or below the target, scanning forwards finds the
    // lightest task at or above it.
    let at_or_below = Self::find_first_candidate(
        push_dom
            .tasks
            .iter()
            .filter(|&task| task.load <= target && eligible(task))
            .rev(),
    );
    let at_or_above = Self::find_first_candidate(
        push_dom
            .tasks
            .iter()
            .filter(|&task| task.load >= target && eligible(task)),
    );

    let (task, new_imbal) = match (at_or_below, at_or_above) {
        (None, None) => return Ok(None),
        (Some(task), None) | (None, Some(task)) => (task, calc_new_imbal(*task.load)),
        (Some(task0), Some(task1)) => {
            let (new_imbal0, new_imbal1) =
                (calc_new_imbal(*task0.load), calc_new_imbal(*task1.load));
            if new_imbal0 <= new_imbal1 {
                (task0, new_imbal0)
            } else {
                (task1, new_imbal1)
            }
        }
    };

    // Refuse a move that would leave the pair more imbalanced than before.
    let old_imbal = to_push + to_pull;
    if old_imbal < new_imbal {
        return Ok(None);
    }

    let load = *task.load;
    let taskc_p = task.taskc_p;
    task.migrated.set(true);

    // NOTE(review): `taskc_p` was read from the domain's active-task array
    // when the task list was populated; assumes it still points at a live
    // `task_ctx` — TODO confirm.
    push_dom.transfer_load(load, unsafe { &mut *taskc_p }, pull_dom);
    Ok(Some(load))
}
/// Tries to move a single task from a `NeedsPush` domain of `push_node` to a
/// `NeedsPull` domain of `pull_node`.
///
/// Push domains are visited from the highest load down and pull domains from
/// the lowest load up; the search stops at the first successful transfer.
/// Returns the load that was moved, or `0.0` if nothing could be moved.
///
/// # Errors
///
/// Fails if `push_node` is not above its average or `pull_node` is not below
/// its average, and propagates errors from `try_find_move_task`.
fn transfer_between_nodes(
&mut self,
push_node: &mut NumaNode,
pull_node: &mut NumaNode,
) -> Result<f64> {
debug!("Inter node {} -> {} started", push_node.id, pull_node.id);
let push_imbal = push_node.load.imbal();
let pull_imbal = pull_node.load.imbal();
let xfer = push_node.xfer_between(pull_node);
if push_imbal <= 0.0f64 || pull_imbal >= 0.0f64 {
bail!(
"push node {}:{}, pull node {}:{}",
push_node.id,
push_imbal,
pull_node.id,
pull_imbal
);
}
// Domains taken out of the sorted containers while being examined; they are
// all re-inserted before returning.
let mut pushers = VecDeque::with_capacity(push_node.domains.len());
let mut pullers = Vec::with_capacity(pull_node.domains.len());
let mut pushed = 0f64;
while push_node.domains.len() > 0 {
// Highest-load domain of the pushing node.
let mut push_dom = push_node.domains.pop().unwrap();
if push_dom.load.state() != BalanceState::NeedsPush {
push_node.domains.insert(push_dom);
break;
}
while pull_node.domains.len() > 0 {
// Lowest-load domain of the pulling node.
let mut pull_dom = pull_node.domains.remove_index(0);
if pull_dom.load.state() != BalanceState::NeedsPull {
pull_node.domains.insert(pull_dom);
break;
}
// Note: the node-level imbalances and transfer target are passed here,
// not the per-domain ones.
// First pass: only tasks that prefer the pull domain.
let mut transferred = self.try_find_move_task(
(&mut push_dom, push_imbal),
(&mut pull_dom, pull_imbal),
|task: &TaskInfo, pull_dom: u32| -> bool {
(task.preferred_dom_mask & (1 << pull_dom)) > 0
},
xfer,
)?;
// Second pass: any task allowed on the pull domain.
if transferred.is_none() {
transferred = self.try_find_move_task(
(&mut push_dom, push_imbal),
(&mut pull_dom, pull_imbal),
|_task: &TaskInfo, _pull_dom: u32| -> bool { true },
xfer,
)?;
}
pullers.push(pull_dom);
if let Some(transferred) = transferred {
pushed = transferred;
// Mirror the domain-level transfer in the node-level accounting.
push_node.update_load(-transferred);
pull_node.update_load(transferred);
break;
}
}
while let Some(puller) = pullers.pop() {
pull_node.domains.insert(puller);
}
pushers.push_back(push_dom);
// One successful transfer per call.
if pushed > 0.0f64 {
break;
}
}
while let Some(pusher) = pushers.pop_front() {
push_node.domains.insert(pusher);
}
Ok(pushed)
}
/// Moves load from `NeedsPush` nodes to `NeedsPull` nodes.
///
/// Pushing nodes are visited from the highest load down. Each one keeps
/// sending to the lowest-loaded pulling nodes until its push cutoff is
/// reached or no pulling node is left.
///
/// # Errors
///
/// Propagates errors from `transfer_between_nodes`.
fn balance_between_nodes(&mut self) -> Result<()> {
    debug!("Node <-> Node LB started");

    let node_count = self.nodes.len();
    // Nodes pulled out of `self.nodes` while being processed; all of them
    // are re-inserted before this function returns successfully.
    let mut pushers = VecDeque::with_capacity(node_count);
    let mut pullers = Vec::with_capacity(node_count);

    while self.nodes.len() >= 2 {
        let mut push_node = self.nodes.pop().unwrap();
        if !matches!(push_node.load.state(), BalanceState::NeedsPush) {
            self.nodes.insert(push_node);
            break;
        }

        let push_cutoff = push_node.load.push_cutoff();
        let mut pushed = 0f64;

        while !self.nodes.is_empty() && pushed < push_cutoff {
            let mut pull_node = self.nodes.remove_index(0);
            let pull_id = pull_node.id;
            if !matches!(pull_node.load.state(), BalanceState::NeedsPull) {
                self.nodes.insert(pull_node);
                break;
            }

            let migrated = self.transfer_between_nodes(&mut push_node, &mut pull_node)?;
            pullers.push(pull_node);

            if migrated > 0.0f64 {
                pushed += migrated;
                debug!(
                    "NODE {} sending {:.06} --> NODE {}",
                    push_node.id, migrated, pull_id
                );
            }
        }

        // Hand the examined pullers back, most recently examined first.
        for puller in pullers.drain(..).rev() {
            self.nodes.insert(puller);
        }

        if pushed > 0.0f64 {
            debug!("NODE {} pushed {:.06} total load", push_node.id, pushed);
        }
        pushers.push_back(push_node);
    }

    for pusher in pushers.drain(..) {
        self.nodes.insert(pusher);
    }

    Ok(())
}
/// Balances load between the domains of a single NUMA node.
///
/// Pushing domains are visited from the highest load down. Each one keeps
/// handing tasks to the lowest-loaded `NeedsPull` domain until its push
/// cutoff is reached or no pulling domain can take a task. For every
/// pairing, tasks that prefer the pull domain are tried first, then any
/// task allowed on it.
///
/// # Errors
///
/// Fails if a pushing domain has a negative imbalance, a pulling domain has
/// a non-negative imbalance, or a transfer reports a non-positive load;
/// also propagates errors from `try_find_move_task`.
fn balance_within_node(&mut self, node: &mut NumaNode) -> Result<()> {
    if node.domains.len() < 2 {
        return Ok(());
    }

    debug!("Intra node {} LB started", node.id);

    // Domains removed from `node.domains` while being processed; all are
    // re-inserted before a successful return.
    let mut pushers = VecDeque::with_capacity(node.domains.len());
    let mut pullers = Vec::new();

    while node.domains.len() >= 2 {
        // The loop guard guarantees at least one domain remains after this
        // pop, so no separate emptiness check is needed.
        let mut push_dom = node
            .domains
            .pop()
            .expect("loop guard guarantees at least two domains");
        if push_dom.load.state() != BalanceState::NeedsPush {
            node.domains.insert(push_dom);
            break;
        }

        let mut pushed = 0.0f64;
        let push_cutoff = push_dom.load.push_cutoff();
        let push_imbal = push_dom.load.imbal();
        if push_imbal < 0.0f64 {
            bail!(
                "Node {} push dom {} had imbal {}",
                node.id,
                push_dom.id,
                push_imbal
            );
        }

        while !node.domains.is_empty() && pushed < push_cutoff {
            let mut pull_dom = node.domains.remove_index(0);
            if pull_dom.load.state() != BalanceState::NeedsPull {
                node.domains.push(pull_dom);
                break;
            }

            let pull_imbal = pull_dom.load.imbal();
            if pull_imbal >= 0.0f64 {
                bail!(
                    "Node {} pull dom {} had imbal {}",
                    node.id,
                    pull_dom.id,
                    pull_imbal
                );
            }

            let xfer = push_dom.xfer_between(&pull_dom);

            // First pass: only tasks that prefer the pull domain.
            let preferred = self.try_find_move_task(
                (&mut push_dom, push_imbal),
                (&mut pull_dom, pull_imbal),
                |task: &TaskInfo, pull_dom: u32| -> bool {
                    (task.preferred_dom_mask & (1 << pull_dom)) > 0
                },
                xfer,
            )?;
            // Second pass: any task allowed on the pull domain.
            let transferred = match preferred {
                Some(load) => Some(load),
                None => self.try_find_move_task(
                    (&mut push_dom, push_imbal),
                    (&mut pull_dom, pull_imbal),
                    |_task: &TaskInfo, _pull_dom: u32| -> bool { true },
                    xfer,
                )?,
            };

            match transferred {
                Some(load) => {
                    if load <= 0.0f64 {
                        bail!("Expected nonzero load transfer")
                    }
                    pushed += load;
                    // Put the puller straight back so it is reconsidered at
                    // its new sorted position.
                    node.domains.insert(pull_dom);
                }
                // Nothing could be moved to this puller; set it aside so
                // the next-lowest domain is tried.
                None => pullers.push(pull_dom),
            }
        }

        while let Some(puller) = pullers.pop() {
            node.domains.insert(puller);
        }

        if pushed > 0.0f64 {
            debug!("DOM {} pushed {:.06} total load", push_dom.id, pushed);
        }
        pushers.push_back(push_dom);
    }

    while let Some(pusher) = pushers.pop_front() {
        node.domains.insert(pusher);
    }

    Ok(())
}
/// Runs one balancing pass: between NUMA nodes first (only when there is
/// more than one), then between the domains inside every node.
///
/// # Errors
///
/// Propagates the first error from either balancing stage.
fn perform_balancing(&mut self) -> Result<()> {
    if self.dom_group.nr_nodes() > 1 {
        self.balance_between_nodes()?;
    }

    debug!("Intra node LBs started");

    // Detach the nodes so each can be borrowed mutably alongside `self`.
    let mut detached = std::mem::take(&mut self.nodes).into_vec();
    for node in &mut detached {
        self.balance_within_node(node)?;
    }
    self.nodes = SortedVec::from_unsorted(detached);

    Ok(())
}
}