use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use snafu::ResultExt;
use svod_device::device::Device;
use svod_device::registry::DeviceRegistry;
use svod_device::{Allocator, Buffer, BufferId, CpuTimelineSignal, TimelineSignal};
use svod_dtype::DeviceSpec;
use crate::error::Result;
/// Per-device execution state: the device handle and its allocator, plus a
/// timeline counter paired with a completion signal.
pub struct DeviceContext {
    /// Spec identifying the device (cloned from `device_handle.device` in `new`).
pub device: DeviceSpec,
    /// Handle to the underlying device.
pub device_handle: Arc<Device>,
    /// Completion signal: set through `signal_completion`, waited on through `wait_for`.
pub signal: Arc<dyn TimelineSignal>,
    /// Last timeline value handed out by `next_timeline`; starts at 0 in `new`.
pub timeline: AtomicU64,
    /// Allocator for this device (cloned from `device_handle.allocator` in `new`).
pub allocator: Arc<dyn Allocator>,
}
impl std::fmt::Debug for DeviceContext {
    /// Formats the device spec and the current timeline counter.
    ///
    /// The device handle, signal and allocator are deliberately left out of the
    /// output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let timeline = self.timeline.load(Ordering::Relaxed);
        let mut builder = f.debug_struct("DeviceContext");
        builder.field("device", &self.device);
        builder.field("timeline", &timeline);
        builder.finish()
    }
}
impl DeviceContext {
pub fn new(device: Arc<Device>, signal: Arc<dyn TimelineSignal>) -> Self {
let allocator = device.allocator.clone();
let device_spec = device.device.clone();
Self { device: device_spec, device_handle: device, signal, timeline: AtomicU64::new(0), allocator }
}
pub fn next_timeline(&self) -> u64 {
self.timeline.fetch_add(1, Ordering::Relaxed) + 1
}
pub fn current_timeline(&self) -> u64 {
self.timeline.load(Ordering::Relaxed)
}
pub fn signal_completion(&self, value: u64) {
self.signal.set(value);
}
pub fn wait_for(&self, value: u64) -> Result<()> {
self.signal.wait(value, 0).context(crate::error::DeviceSnafu)?;
Ok(())
}
}
/// How a transfer between two devices is synchronized; see
/// `UnifiedExecutor::sync_strategy` for how one is chosen.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SyncStrategy {
    /// Source and destination are the same device.
    None,
    /// Different devices of the same kind.
    PeerToPeer,
    /// Devices of different kinds; the copy goes through the host.
    CpuMediated,
}
/// A unit of work in an `ExecutionGraph`, bound to a single device.
#[derive(Debug, Clone)]
pub struct ExecutionNode {
    /// Key under which `ExecutionGraph` stores the node; adding another node
    /// with the same id replaces it.
pub id: u64,
    /// Device the node runs on; nodes are grouped by this in `ExecutionGraph::device_groups`.
pub device: DeviceSpec,
    /// Buffers the node reads (presumably — not read in this file; confirm).
pub inputs: Vec<BufferId>,
    /// Buffers the node writes (presumably — not read in this file; confirm).
pub outputs: Vec<BufferId>,
    /// Ids of nodes that must be scheduled before this one; duplicates are
    /// counted once by `ExecutionGraph::compute_parallel_groups`.
pub predecessors: Vec<u64>,
    /// Marks the node as a data transfer rather than a kernel (not read in this file).
pub is_transfer: bool,
    /// Optional buffer-access description for kernel nodes (not read in this file).
pub buffer_access: Option<KernelBufferAccess>,
}
/// Buffer list passed to a kernel, plus the positions within it that are outputs.
/// NOTE(review): semantics inferred from field names; not read in this file — confirm.
#[derive(Debug, Clone)]
pub struct KernelBufferAccess {
    /// Buffers the kernel accesses.
pub buffers: Vec<BufferId>,
    /// Indices into `buffers` that are written (presumably; confirm against users).
pub output_indices: Vec<usize>,
}
/// Dependency graph of `ExecutionNode`s, with a per-device index and the order
/// computed by `compute_parallel_groups`.
#[derive(Debug, Default)]
pub struct ExecutionGraph {
    /// All nodes, keyed by `ExecutionNode::id`.
nodes: HashMap<u64, ExecutionNode>,
    /// Node ids in the order produced by `compute_parallel_groups`.
execution_order: Vec<u64>,
    /// Node ids grouped by the device they run on, in insertion order.
device_groups: HashMap<DeviceSpec, Vec<u64>>,
}
impl ExecutionGraph {
pub fn new() -> Self {
Self::default()
}
pub fn add_node(&mut self, node: ExecutionNode) {
let id = node.id;
let device = node.device.clone();
self.nodes.insert(id, node);
self.device_groups.entry(device).or_default().push(id);
}
pub fn node(&self, id: u64) -> Option<&ExecutionNode> {
self.nodes.get(&id)
}
pub fn nodes(&self) -> impl Iterator<Item = &ExecutionNode> {
self.nodes.values()
}
pub fn compute_parallel_groups(&mut self) -> Vec<Vec<u64>> {
let mut in_degree: HashMap<u64, usize> = HashMap::new();
let mut successors: HashMap<u64, HashSet<u64>> = HashMap::new();
for node in self.nodes.values() {
in_degree.entry(node.id).or_insert(0);
let mut preds: smallvec::SmallVec<[u64; 8]> = node.predecessors.iter().copied().collect();
preds.sort_unstable();
preds.dedup();
for &pred in &preds {
successors.entry(pred).or_default().insert(node.id);
*in_degree.entry(node.id).or_insert(0) += 1;
}
}
let mut groups = Vec::new();
let mut ready: Vec<u64> = in_degree.iter().filter(|&(_, deg)| *deg == 0).map(|(&id, _)| id).collect();
while !ready.is_empty() {
groups.push(ready.clone());
self.execution_order.extend(ready.iter().copied());
let mut next_ready = Vec::new();
for id in ready {
if let Some(succs) = successors.get(&id) {
for &succ in succs {
let deg = in_degree.get_mut(&succ).unwrap();
*deg -= 1;
if *deg == 0 {
next_ready.push(succ);
}
}
}
}
ready = next_ready;
}
groups
}
pub fn device_groups(&self) -> &HashMap<DeviceSpec, Vec<u64>> {
&self.device_groups
}
pub fn is_valid(&self) -> bool {
self.execution_order.len() == self.nodes.len()
}
}
/// Runs kernels and transfers across devices, keeping one `DeviceContext` per device.
pub struct UnifiedExecutor {
    /// Contexts created so far, keyed by device.
contexts: HashMap<DeviceSpec, DeviceContext>,
    /// Registry passed to `crate::DEVICE_FACTORIES` when a device is created.
registry: &'static DeviceRegistry,
}
impl std::fmt::Debug for UnifiedExecutor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("UnifiedExecutor").field("contexts", &self.contexts.keys().collect::<Vec<_>>()).finish()
}
}
impl UnifiedExecutor {
    /// Creates an executor with no device contexts.
    ///
    /// Contexts are created on demand, using `registry` to build the devices.
    pub fn new(registry: &'static DeviceRegistry) -> Self {
        Self { contexts: HashMap::new(), registry }
    }

    /// Creates and stores a context for `device_spec`, building the device with
    /// `crate::DEVICE_FACTORIES`.
    ///
    /// Does nothing if a context for this device already exists.
    ///
    /// # Errors
    /// Propagates any error from creating the device.
    pub fn add_device(&mut self, device_spec: DeviceSpec) -> Result<()> {
        if self.contexts.contains_key(&device_spec) {
            return Ok(());
        }
        let device = crate::DEVICE_FACTORIES.device(&device_spec, self.registry)?;
        // Every device kind currently uses a CPU-side timeline signal. The CUDA arm
        // is presumably kept separate as the spot for a device-native signal.
        let signal: Arc<dyn TimelineSignal> = match &device_spec {
            DeviceSpec::Cpu => Arc::new(CpuTimelineSignal::new()),
            #[cfg(feature = "cuda")]
            DeviceSpec::Cuda { .. } => Arc::new(CpuTimelineSignal::new()),
            _ => Arc::new(CpuTimelineSignal::new()),
        };
        self.contexts.insert(device_spec, DeviceContext::new(device, signal));
        Ok(())
    }

    /// Returns the context for `device`, if one has been created.
    pub fn context(&self, device: &DeviceSpec) -> Option<&DeviceContext> {
        self.contexts.get(device)
    }

    /// Returns the context for `device` mutably, if one has been created.
    pub fn context_mut(&mut self, device: &DeviceSpec) -> Option<&mut DeviceContext> {
        self.contexts.get_mut(device)
    }

    /// Chooses how a transfer from `from` to `to` is synchronized.
    ///
    /// The choice depends on how the two specs compare:
    /// - equal specs give [`SyncStrategy::None`];
    /// - the same enum variant (same device kind) gives [`SyncStrategy::PeerToPeer`];
    /// - anything else gives [`SyncStrategy::CpuMediated`].
    pub fn sync_strategy(from: &DeviceSpec, to: &DeviceSpec) -> SyncStrategy {
        if from == to {
            SyncStrategy::None
        } else if std::mem::discriminant(from) == std::mem::discriminant(to) {
            SyncStrategy::PeerToPeer
        } else {
            SyncStrategy::CpuMediated
        }
    }

    /// Returns the device shared by all `buffers` (per their allocators).
    ///
    /// Returns `None` if `buffers` is empty or the buffers live on different
    /// devices.
    pub fn single_device_check(&self, buffers: &[&Buffer]) -> Option<DeviceSpec> {
        let (first, rest) = buffers.split_first()?;
        let device = first.allocator().device_spec();
        if rest.iter().all(|buffer| buffer.allocator().device_spec() == device) {
            Some(device)
        } else {
            None
        }
    }

    /// Waits until every context has signaled the last timeline value it reserved.
    ///
    /// # Errors
    /// Stops at, and returns, the first wait error.
    pub fn synchronize_all(&self) -> Result<()> {
        self.contexts.values().try_for_each(Self::wait_until_idle)
    }

    /// Runs `execute_fn` for `device` and signals a fresh timeline value for it.
    ///
    /// A context for `device` is created on demand. Returns the timeline value
    /// that was signaled.
    ///
    /// # Errors
    /// Fails if the device cannot be created or `execute_fn` fails. In either
    /// case no timeline value is reserved.
    pub fn execute_kernel<F>(&mut self, device: &DeviceSpec, execute_fn: F) -> Result<u64>
    where
        F: FnOnce() -> Result<()>,
    {
        self.ensure_context(device)?;
        // Reserve the timeline value only after the work succeeds. A value that
        // is reserved but never signaled would leave a gap that later waits (for
        // example `synchronize_all`) can never see satisfied.
        execute_fn()?;
        Ok(self.signal_next(device))
    }

    /// Copies `src` into `dst` and signals a fresh timeline value on `dst_device`.
    ///
    /// Contexts for both devices are created on demand. For a CPU-mediated copy
    /// (devices of different kinds), all work already reserved on both devices
    /// is waited for before the host performs the copy. Same-device and
    /// same-kind copies do not wait.
    ///
    /// Returns the timeline value signaled on `dst_device`.
    ///
    /// # Errors
    /// Fails if a device cannot be created, a wait fails, or the copy fails. In
    /// those cases no timeline value is reserved.
    pub fn execute_transfer(
        &mut self,
        src: &Buffer,
        dst: &mut Buffer,
        src_device: &DeviceSpec,
        dst_device: &DeviceSpec,
    ) -> Result<u64> {
        self.ensure_context(src_device)?;
        self.ensure_context(dst_device)?;
        match Self::sync_strategy(src_device, dst_device) {
            SyncStrategy::None | SyncStrategy::PeerToPeer => {}
            SyncStrategy::CpuMediated => {
                // Both waits happen before any timeline value is reserved for this
                // transfer, so neither can target the transfer's own (unsignaled)
                // value. src must finish producing its data; dst must finish
                // any prior work before its buffer is overwritten.
                for device in [src_device, dst_device] {
                    if let Some(ctx) = self.contexts.get(device) {
                        Self::wait_until_idle(ctx)?;
                    }
                }
            }
        }
        dst.copy_from(src).context(crate::error::DeviceSnafu)?;
        Ok(self.signal_next(dst_device))
    }

    /// Creates a context for `device` unless one already exists.
    fn ensure_context(&mut self, device: &DeviceSpec) -> Result<()> {
        if self.contexts.contains_key(device) {
            Ok(())
        } else {
            self.add_device(device.clone())
        }
    }

    /// Reserves the next timeline value on `device`, signals it, and returns it.
    fn signal_next(&self, device: &DeviceSpec) -> u64 {
        let ctx = self
            .contexts
            .get(device)
            .expect("context is created by ensure_context before any work is signaled");
        let timeline = ctx.next_timeline();
        ctx.signal_completion(timeline);
        timeline
    }

    /// Waits for the last timeline value reserved on `ctx`, if any was reserved.
    fn wait_until_idle(ctx: &DeviceContext) -> Result<()> {
        let reserved = ctx.current_timeline();
        if reserved > 0 {
            ctx.wait_for(reserved)?;
        }
        Ok(())
    }
}
/// Process-wide executor, created on first access from the default device registry.
static EXECUTOR: once_cell::sync::Lazy<parking_lot::Mutex<UnifiedExecutor>> =
    once_cell::sync::Lazy::new(|| {
        let registry = svod_device::registry::registry();
        parking_lot::Mutex::new(UnifiedExecutor::new(registry))
    });

/// Locks the process-wide [`UnifiedExecutor`] and returns the guard.
///
/// The lock is held until the guard is dropped. `parking_lot::Mutex` is not
/// reentrant, so calling this again on the same thread while still holding the
/// guard deadlocks.
pub fn global_executor() -> parking_lot::MutexGuard<'static, UnifiedExecutor> {
    let executor: &'static parking_lot::Mutex<UnifiedExecutor> = &EXECUTOR;
    executor.lock()
}
// Unit tests live in a separate file, test/unit/executor.rs, resolved relative
// to the directory containing this source file.
#[cfg(test)]
#[path = "test/unit/executor.rs"]
mod tests;