use std::collections::VecDeque;
use std::ops::Add;
use std::ops::Range;
use std::ops::Sub;
use std::sync::Arc;
use std::sync::Mutex;
use anyhow::Result;
use anyhow::bail;
use futures::channel::oneshot;
use ordered_float::OrderedFloat;
use tracing::debug;
use crate::CancellationContext;
use crate::CancellationContextState;
use crate::EngineEvent;
use crate::Events;
/// A task waiting in the parked queue until enough CPU and memory are available.
struct ParkedTask {
    /// Identifier taken from `LimitsState::next_id`; lets the waiting caller locate its entry.
    id: usize,
    /// Signalled once the task's resources have been reserved and it may run.
    notify: oneshot::Sender<()>,
    /// Number of CPUs the task requested.
    cpu: f64,
    /// Bytes of memory the task requested.
    memory: u64,
}
/// Resource accounting plus the collaborators a limited `TaskManager` reports to.
struct Limits {
    /// Cancellation context consulted while a task is parked and after it finishes.
    cancellation: CancellationContext,
    /// Event sink used to report parking and unparking of tasks.
    events: Events,
    /// Available resources and the parked-task queue, shared behind a mutex.
    state: Arc<Mutex<LimitsState>>,
}
/// Mutable resource accounting guarded by `Limits::state`.
struct LimitsState {
    /// CPUs currently available; wrapped in `OrderedFloat` so it can serve as an `Ord` weight.
    cpu: OrderedFloat<f64>,
    /// Bytes of memory currently available.
    memory: u64,
    /// Tasks waiting for resources; reordered in place when choosing which tasks to unpark.
    parked: VecDeque<ParkedTask>,
    /// Identifier to hand out to the next parked task.
    next_id: usize,
}
impl LimitsState {
fn new(cpu: f64, memory: u64) -> Self {
Self {
next_id: 0,
cpu: OrderedFloat(cpu),
memory,
parked: Default::default(),
}
}
fn unpark_tasks(&mut self) {
if self.parked.is_empty() {
return;
}
debug!(
"attempting to unpark tasks with {cpu} CPUs and {memory} bytes of memory available",
cpu = self.cpu,
memory = self.memory,
);
loop {
let parked = self.parked.make_contiguous();
let cpu_by_memory_len = {
let range = fit_longest_range(parked, self.cpu, |task| OrderedFloat(task.cpu));
fit_longest_range(&mut parked[range], self.memory, |task| task.memory).len()
};
let memory_by_cpu = fit_longest_range(parked, self.memory, |task| task.memory);
let memory_by_cpu = fit_longest_range(&mut parked[memory_by_cpu], self.cpu, |task| {
OrderedFloat(task.cpu)
});
if cpu_by_memory_len == 0 && memory_by_cpu.is_empty() {
break;
}
let range = if memory_by_cpu.len() >= cpu_by_memory_len {
memory_by_cpu
} else {
let range = fit_longest_range(parked, self.cpu, |task| OrderedFloat(task.cpu));
fit_longest_range(&mut parked[range], self.memory, |task| task.memory)
};
assert_eq!(
range.start, 0,
"expected the fit tasks to be at the front of the list"
);
for _ in range {
let task = self.parked.pop_front().unwrap();
debug!(
"unparking task with reservation of {cpu} CPU(s) and {memory} bytes of memory",
cpu = task.cpu,
memory = task.memory,
);
self.cpu -= task.cpu;
self.memory -= task.memory;
let _ = task.notify.send(());
}
}
}
}
/// Gates task execution on available CPU and memory.
///
/// Requests above `max_cpu` / `max_memory` are rejected outright. When `limits` is
/// `None`, every other request runs immediately.
pub struct TaskManager {
    /// Largest CPU count a single task may request.
    max_cpu: f64,
    /// Largest amount of memory, in bytes, a single task may request.
    max_memory: u64,
    /// Shared resource accounting; `None` for an unlimited manager.
    limits: Option<Limits>,
}
impl TaskManager {
    /// Creates a manager that starts with `cpu` CPUs and `memory` bytes available.
    ///
    /// A single task may request at most `max_cpu` CPUs and `max_memory` bytes. Parking
    /// and unparking are reported through `events`, and `cancellation` decides when
    /// parked tasks give up waiting.
    pub fn new(
        cpu: f64,
        max_cpu: f64,
        memory: u64,
        max_memory: u64,
        events: Events,
        cancellation: CancellationContext,
    ) -> Self {
        Self {
            max_cpu,
            max_memory,
            limits: Some(Limits {
                state: Arc::new(Mutex::new(LimitsState::new(cpu, memory))),
                events,
                cancellation,
            }),
        }
    }

    /// Creates a manager that only enforces the per-task maxima and otherwise runs every
    /// task immediately.
    pub fn new_unlimited(max_cpu: f64, max_memory: u64) -> Self {
        Self {
            max_cpu,
            max_memory,
            limits: None,
        }
    }

    /// Runs `task` once `cpu` CPUs and `memory` bytes of memory can be reserved for it.
    ///
    /// If the resources are not available, the task is parked until other tasks release
    /// enough of them. Parking emits `EngineEvent::TaskParked`, and waking emits
    /// `EngineEvent::TaskUnparked`. If the first cancellation token fires (or the parked
    /// queue is cleared because of cancellation) while the task is parked, the task is
    /// not run and `Ok(None)` is returned. After the task finishes, its reservation is
    /// released. Then either the parked queue is cleared (if cancellation has begun) or
    /// as many parked tasks as fit are woken.
    ///
    /// NOTE: if the returned future is dropped before it completes, any resources it has
    /// already reserved are not returned.
    ///
    /// # Errors
    ///
    /// Fails if `cpu` is negative or not finite, or if the request exceeds the configured
    /// per-task maximum CPU or memory. Otherwise returns whatever `task` returns.
    pub async fn run<T, O>(&self, cpu: f64, memory: u64, task: T) -> Result<Option<O>>
    where
        T: Future<Output = Result<Option<O>>>,
    {
        // NaN would poison the shared CPU counter, and a negative request would inflate it.
        if !cpu.is_finite() || cpu < 0.0 {
            bail!("requested task CPU count of {cpu} is not a finite, non-negative number");
        }
        if cpu > self.max_cpu {
            bail!(
                "requested task CPU count of {cpu} exceeds the maximum CPU count of {max_cpu}",
                max_cpu = self.max_cpu
            );
        }
        if memory > self.max_memory {
            bail!(
                "requested task memory of {memory} byte{s} exceeds the maximum memory of \
                 {max_memory}",
                s = if memory == 1 { "" } else { "s" },
                max_memory = self.max_memory
            );
        }

        let Some(limits) = &self.limits else {
            return task.await;
        };

        // Either reserve the resources immediately or enqueue a parked entry.
        let mut parked = {
            let mut state = limits.state.lock().expect("failed to lock state");
            if cpu > state.cpu.into() || memory > state.memory {
                debug!(
                    "parking task due to insufficient resources: task requests {cpu} \
                     CPU(s) and {memory} bytes of memory but there are only \
                     {cpu_remaining} CPU(s) and {memory_remaining} bytes of memory \
                     available",
                    cpu_remaining = state.cpu,
                    memory_remaining = state.memory
                );
                let (notify_tx, notify_rx) = oneshot::channel();
                let id = state.next_id;
                state.next_id += 1;
                state.parked.push_back(ParkedTask {
                    id,
                    cpu,
                    memory,
                    notify: notify_tx,
                });
                Some((notify_rx, id))
            } else {
                state.cpu -= cpu;
                state.memory -= memory;
                debug!(
                    "running task with {cpu} CPUs and {memory} bytes of memory remaining",
                    cpu = state.cpu,
                    memory = state.memory
                );
                None
            }
        };

        let canceled = match &mut parked {
            Some((notify, _)) => {
                if let Some(sender) = limits.events.engine() {
                    let _ = sender.send(EngineEvent::TaskParked);
                }
                let token = limits.cancellation.first();
                let canceled = tokio::select! {
                    biased;
                    _ = token.cancelled() => true,
                    // The sender is only dropped without notifying when the parked queue
                    // is cleared because cancellation has begun.
                    r = notify => r.is_err(),
                };
                if let Some(sender) = limits.events.engine() {
                    let _ = sender.send(EngineEvent::TaskUnparked { canceled });
                }
                canceled
            }
            None => false,
        };

        let res = if canceled { Ok(None) } else { task.await };

        let mut state = limits.state.lock().expect("failed to lock state");
        // Work out whether resources were actually reserved for this task. Notifications
        // are only sent while the state lock is held, so `try_recv` is exact here.
        let reserved = match (&mut parked, canceled) {
            (Some((notify, id)), true) => match notify.try_recv() {
                // Unparked at the same moment cancellation won the select.
                Ok(Some(())) => true,
                // Still queued: withdraw the entry so it is never unparked later.
                Ok(None) => {
                    let id = *id;
                    state.parked.retain(|t| t.id != id);
                    false
                }
                // The queue was cleared; nothing was reserved.
                Err(_) => false,
            },
            _ => true,
        };
        if reserved {
            state.cpu += cpu;
            state.memory += memory;
        }

        if limits.cancellation.state() != CancellationContextState::NotCanceled {
            state.parked.clear();
        } else {
            state.unpark_tasks();
        }
        res
    }
}
/// Moves the largest possible number of elements whose combined weight does not exceed
/// `total_weight` to the front of `slice`, and returns their index range.
///
/// Elements are taken lightest-first, which maximizes how many fit. The returned range
/// always starts at `0`. The rest of `slice` is reordered as well: it ends up sorted by
/// weight, and elements of equal weight keep their original relative order.
/// `weight_fn` may be called more than once per element.
///
/// # Panics
///
/// Panics if `total_weight` is less than `W::default()`.
fn fit_longest_range<T, F, W>(slice: &mut [T], total_weight: W, mut weight_fn: F) -> Range<usize>
where
    F: FnMut(&T) -> W,
    W: Ord + Add<Output = W> + Sub<Output = W> + Default,
{
    assert!(
        total_weight >= W::default(),
        "total weight cannot be negative"
    );

    // A stable sort is O(n log n) even when many elements share a weight, and it keeps
    // equally weighted elements in their original (queue) order.
    slice.sort_by_cached_key(&mut weight_fn);

    let mut remaining = total_weight;
    let mut end = 0;
    for item in slice.iter() {
        let weight = weight_fn(item);
        // Elements are in ascending weight order, so once one does not fit, none of the
        // following ones can either.
        if weight > remaining {
            break;
        }
        remaining = remaining - weight;
        end += 1;
    }
    0..end
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn fit_empty_slice() {
        let r = fit_longest_range(&mut [], 100, |i| *i);
        assert!(r.is_empty());
    }

    #[test]
    #[should_panic(expected = "total weight cannot be negative")]
    fn fit_negative_panic() {
        fit_longest_range(&mut [0], -1, |i| *i);
    }

    #[test]
    fn no_fit() {
        let r = fit_longest_range(&mut [100, 101, 102], 99, |i| *i);
        assert!(r.is_empty());
    }

    #[test]
    fn fit_all() {
        let r = fit_longest_range(&mut [1, 2, 3, 4, 5], 15, |i| *i);
        assert_eq!(r.len(), 5);
        let r = fit_longest_range(&mut [5, 4, 3, 2, 1], 20, |i| *i);
        assert_eq!(r.len(), 5);
    }

    #[test]
    fn fit_some() {
        let s = &mut [8, 2, 2, 3, 2, 1, 2, 4, 1];
        let r = fit_longest_range(s, 10, |i| *i);
        assert_eq!(r.len(), 6);
        assert_eq!(s[r.start..r.end].iter().copied().sum::<i32>(), 10);
        assert!(s[r.end..].contains(&8));
        assert!(s[r.end..].contains(&4));
        assert!(s[r.end..].contains(&3));
    }

    /// A zero budget still admits elements that weigh nothing.
    #[test]
    fn fit_zero_budget() {
        let s = &mut [0, 1, 0, 2];
        let r = fit_longest_range(s, 0, |i| *i);
        assert_eq!(r, 0..2);
        assert!(s[r].iter().all(|w| *w == 0));
    }

    /// Many identical weights (e.g. tasks all requesting the same CPU count).
    #[test]
    fn fit_many_equal_weights() {
        let s = &mut [3u64; 1000];
        let r = fit_longest_range(s, 100, |i| *i);
        assert_eq!(r, 0..33);
    }

    /// `OrderedFloat` is the weight type used for CPU counts.
    #[test]
    fn fit_float_weights() {
        let s = &mut [OrderedFloat(1.0), OrderedFloat(0.25), OrderedFloat(0.5)];
        let r = fit_longest_range(s, OrderedFloat(0.75), |w| *w);
        assert_eq!(r, 0..2);
    }
}