use std::{collections::HashSet, fmt::Debug, hash::Hash};
use super::{
InstructionFrameInteraction, MemoryAccessDependency, MemoryAccessType, ScheduledGraphNode,
};
/// A kind of access to some shared resource that can be sorted into reads and writes,
/// so that [`DependencyQueue`] can derive ordering dependencies from a sequence of accesses.
pub(super) trait Access: Copy + Eq + Hash + Debug {
    /// Whoever performs an access; passed to [`Access::classify`].
    type Action: Copy + Eq + Hash + Debug;
    /// Payload recorded for an access classified as a read.
    type Read: Copy + Eq + Hash + Debug;
    /// Payload recorded for an access classified as a write.
    type Write: Copy + Eq + Hash + Debug;
    /// Value handed back to callers for each recorded read or write that an access depends on.
    type Dependency: Copy + Eq + Hash + Debug;

    /// Sort this access kind, performed by `action`, into a read or a write.
    fn classify(self, action: Self::Action) -> AccessingAction<Self>;

    /// The writer considered to precede every recorded access, or `None` if there is none.
    fn initial_writer() -> Option<Self::Write>;

    /// Turn a recorded read into a dependency.
    fn read_dependency(read: Self::Read) -> Self::Dependency;

    /// Turn a recorded write into a dependency.
    fn write_dependency(write: Self::Write) -> Self::Dependency;
}
/// Outcome of [`Access::classify`]: an access is either a read or a write, carrying the
/// implementor's payload for that kind of access.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(super) enum AccessingAction<A: Access> {
    /// The access was classified as a read.
    Read(A::Read),
    /// The access was classified as a write.
    Write(A::Write),
}
impl Access for MemoryAccessType {
type Action = ScheduledGraphNode;
type Read = ScheduledGraphNode;
type Write = MemoryAccessDependency;
type Dependency = MemoryAccessDependency;
#[inline]
fn initial_writer() -> Option<<Self as Access>::Write> {
None
}
#[inline]
fn classify(self, action: Self::Action) -> AccessingAction<Self> {
match self {
Self::Read => AccessingAction::Read(action),
Self::Write | Self::Capture => AccessingAction::Write(MemoryAccessDependency {
access_type: self,
node_id: action,
}),
}
}
#[inline]
fn read_dependency(read: <Self as Access>::Read) -> Self::Dependency {
MemoryAccessDependency {
access_type: Self::Read,
node_id: read,
}
}
#[inline]
fn write_dependency(write: <Self as Access>::Write) -> Self::Dependency {
write
}
}
impl Access for InstructionFrameInteraction {
type Action = ScheduledGraphNode;
type Read = ScheduledGraphNode;
type Write = ScheduledGraphNode;
type Dependency = ScheduledGraphNode;
#[inline]
fn initial_writer() -> Option<Self::Write> {
Some(ScheduledGraphNode::BlockStart)
}
#[inline]
fn classify(self, action: Self::Action) -> AccessingAction<Self> {
match self {
Self::Blocking => AccessingAction::Read(action),
Self::Using => AccessingAction::Write(action),
}
}
#[inline]
fn read_dependency(read: Self::Read) -> Self::Dependency {
read
}
#[inline]
fn write_dependency(write: Self::Write) -> Self::Dependency {
write
}
}
/// Bookkeeping for a sequence of accesses of kind `A`: the latest write (if any) and every
/// read recorded since that write.
pub(super) struct DependencyQueue<A: Access> {
    /// Latest recorded write; a new queue starts from [`Access::initial_writer`].
    write: Option<A::Write>,
    /// Reads recorded since `write`; drained when the next write is recorded.
    reads: HashSet<A::Read>,
}
impl<A: Access> DependencyQueue<A> {
    /// Create a queue with no pending reads whose latest writer is
    /// [`Access::initial_writer`].
    pub(super) fn new() -> Self {
        Self {
            write: A::initial_writer(),
            reads: HashSet::new(),
        }
    }

    /// Record that `action` performs an access of kind `access_type` and return the
    /// dependencies of that access on previously recorded ones.
    ///
    /// * A read depends on the latest write (if any) and joins the pending reads.
    /// * A write depends on the latest write (if any) and on every pending read; it clears
    ///   the pending reads and becomes the new latest write.
    pub(super) fn record_access_and_get_dependencies(
        &mut self,
        action: A::Action,
        access_type: A,
    ) -> HashSet<A::Dependency> {
        let mut dependencies = HashSet::new();
        match access_type.classify(action) {
            AccessingAction::Read(read) => {
                dependencies.extend(self.write.map(A::write_dependency));
                self.reads.insert(read);
            }
            AccessingAction::Write(write) => {
                dependencies.extend(self.reads.drain().map(A::read_dependency));
                // `replace` installs the new writer and hands back the previous one,
                // which this write must also follow.
                dependencies.extend(self.write.replace(write).map(A::write_dependency));
            }
        }
        dependencies
    }

    /// Consume the queue, returning dependencies on every pending read plus the latest
    /// write (if any).
    pub(super) fn into_pending_dependencies(self) -> HashSet<A::Dependency> {
        let Self { write, reads } = self;
        let mut pending: HashSet<A::Dependency> =
            reads.into_iter().map(A::read_dependency).collect();
        pending.extend(write.map(A::write_dependency));
        pending
    }
}
impl<A: Access> Default for DependencyQueue<A> {
fn default() -> Self {
Self::new()
}
}