vyre-driver 0.4.1

Driver layer: registry, runtime, pipeline, routing, diagnostics. Substrate-agnostic backend machinery. Part of the vyre GPU compiler.
//! Cross-dispatch fusion decisions shared by concrete backends.

use crate::specialization::SpecMap;

/// One dispatch's pre-fusion description.
///
/// [`FusionPass::decide`] compares two of these. It reads `workgroup_size`,
/// `shared_memory_bytes`, `inputs`, and `outputs`; `id` and `specs` are carried
/// along but are not consulted by `decide`.
#[derive(Debug, Clone)]
pub struct DispatchShape {
    /// Stable id for this dispatch inside the containing program.
    pub id: &'static str,
    /// Workgroup size `[x, y, z]`.
    ///
    /// `FusionPass::decide` requires the upstream and downstream sizes to be
    /// equal and multiplies the three components (saturating) to obtain the
    /// invocation count it checks against the adapter cap.
    pub workgroup_size: [u32; 3],
    /// Per-dispatch shared memory bytes.
    ///
    /// `FusionPass::decide` adds the upstream and downstream values
    /// (saturating) and compares the sum against the adapter cap.
    pub shared_memory_bytes: u32,
    /// Buffers this dispatch reads.
    ///
    /// Names are matched by string equality against the upstream `outputs`.
    pub inputs: Vec<&'static str>,
    /// Buffers this dispatch writes.
    ///
    /// Names are matched by string equality against the downstream `inputs`.
    pub outputs: Vec<&'static str>,
    /// Specialization constants baked into this dispatch.
    pub specs: SpecMap,
}

/// Adapter limits that the generic fusion pass must respect.
///
/// Two ready-made profiles exist: [`FusionCaps::default`] (16 KiB of shared
/// memory, 256 invocations) and [`FusionCaps::high_end`] (128 KiB of shared
/// memory, 1024 invocations).
#[derive(Debug, Clone, Copy)]
pub struct FusionCaps {
    /// Largest amount of workgroup-shared memory, in bytes, the adapter can serve.
    pub max_shared_memory_bytes: u32,
    /// Largest number of invocations permitted in a single workgroup.
    pub max_invocations_per_workgroup: u32,
}

impl FusionCaps {
    /// Bytes per KiB; keeps the shared-memory figures of the profiles readable.
    const KIB: u32 = 1024;

    /// High-end profile for tests and capability probes.
    ///
    /// Returns 128 KiB of shared memory and 1024 invocations per workgroup.
    #[must_use]
    pub const fn high_end() -> Self {
        let max_shared_memory_bytes = 128 * Self::KIB;
        let max_invocations_per_workgroup = 1024;
        Self {
            max_shared_memory_bytes,
            max_invocations_per_workgroup,
        }
    }
}

impl Default for FusionCaps {
    /// Baseline profile: 16 KiB of shared memory and 256 invocations per
    /// workgroup.
    fn default() -> Self {
        FusionCaps {
            max_invocations_per_workgroup: 256,
            max_shared_memory_bytes: 16 * Self::KIB,
        }
    }
}

/// Outcome of the fusion legality analysis for one upstream/downstream pair.
///
/// The [`Display`](std::fmt::Display) form is a kebab-case tag; variants that
/// carry data append it after a colon.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum FusionDecision {
    /// The pair may be fused; the concrete backend may stitch its target modules.
    Accept,
    /// The workgroup sizes differ, or the workgroup exceeds the invocation budget.
    WorkgroupSizeMismatch {
        /// Workgroup size of the upstream dispatch.
        upstream: [u32; 3],
        /// Workgroup size of the downstream dispatch.
        downstream: [u32; 3],
    },
    /// The fused kernel would ask for more shared memory than the adapter allows.
    SharedMemoryBudget {
        /// Combined bytes the fused kernel would request.
        needed: u32,
        /// Adapter cap the request was compared against.
        cap: u32,
    },
    /// A flow-through output is still consumed by a third dispatch.
    OutputConsumedElsewhere,
    /// No buffer flows from upstream outputs to downstream inputs.
    NoPipelineDependency,
}

impl std::fmt::Display for FusionDecision {
    /// Writes the decision's tag, plus its payload for the data-carrying variants.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Payload variants format and return immediately; the unit variants
        // reduce to a fixed tag that is written once at the end.
        let tag = match self {
            Self::WorkgroupSizeMismatch {
                upstream,
                downstream,
            } => {
                return write!(
                    out,
                    "workgroup-size-mismatch:{upstream:?}->{downstream:?}"
                );
            }
            Self::SharedMemoryBudget { needed, cap } => {
                return write!(out, "shared-memory-budget:{needed}/{cap}");
            }
            Self::Accept => "accept",
            Self::OutputConsumedElsewhere => "output-consumed-elsewhere",
            Self::NoPipelineDependency => "no-pipeline-dependency",
        };
        out.write_str(tag)
    }
}

/// Pure cross-dispatch fusion analysis.
///
/// A stateless unit type that namespaces the analysis; its `decide` function
/// takes no `self`, so no instance is needed to call it.
#[derive(Debug, Clone, Copy)]
pub struct FusionPass;

impl FusionPass {
    /// Decide whether `upstream` -> `downstream` is legal to fuse.
    ///
    /// The checks below run in order and the first one that fails determines
    /// the result:
    ///
    /// 1. Both dispatches must have the same workgroup size, and the saturating
    ///    product of its three components must not exceed
    ///    `caps.max_invocations_per_workgroup`. Either failure yields
    ///    [`FusionDecision::WorkgroupSizeMismatch`].
    /// 2. The saturating sum of both `shared_memory_bytes` values must not
    ///    exceed `caps.max_shared_memory_bytes`, otherwise
    ///    [`FusionDecision::SharedMemoryBudget`] is returned.
    /// 3. At least one upstream output must also be a downstream input,
    ///    otherwise [`FusionDecision::NoPipelineDependency`] is returned.
    /// 4. None of those flow-through buffers may be listed in
    ///    `other_consumers`, otherwise
    ///    [`FusionDecision::OutputConsumedElsewhere`] is returned.
    ///
    /// When every check passes the result is [`FusionDecision::Accept`].
    /// Buffer names are compared by string equality.
    #[must_use]
    pub fn decide(
        upstream: &DispatchShape,
        downstream: &DispatchShape,
        caps: FusionCaps,
        other_consumers: &[&str],
    ) -> FusionDecision {
        let up_size = upstream.workgroup_size;
        let down_size = downstream.workgroup_size;
        let invocations = up_size
            .iter()
            .fold(1_u32, |product, &dim| product.saturating_mul(dim));
        // A size mismatch and an over-budget workgroup are reported through
        // the same variant, so one guard covers both.
        if up_size != down_size || invocations > caps.max_invocations_per_workgroup {
            return FusionDecision::WorkgroupSizeMismatch {
                upstream: up_size,
                downstream: down_size,
            };
        }

        let cap = caps.max_shared_memory_bytes;
        let needed = upstream
            .shared_memory_bytes
            .saturating_add(downstream.shared_memory_bytes);
        if needed > cap {
            return FusionDecision::SharedMemoryBudget { needed, cap };
        }

        // Upstream outputs that the downstream dispatch actually reads.
        let mut flow_through = upstream
            .outputs
            .iter()
            .filter(|buffer| downstream.inputs.contains(*buffer))
            .peekable();
        if flow_through.peek().is_none() {
            return FusionDecision::NoPipelineDependency;
        }
        if flow_through.any(|buffer| other_consumers.contains(buffer)) {
            return FusionDecision::OutputConsumedElsewhere;
        }
        FusionDecision::Accept
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `[64, 1, 1]` dispatch with 1 KiB of shared memory and the
    /// given buffer names; individual tests override fields as needed.
    fn dispatch(
        id: &'static str,
        inputs: &[&'static str],
        outputs: &[&'static str],
    ) -> DispatchShape {
        DispatchShape {
            id,
            workgroup_size: [64, 1, 1],
            shared_memory_bytes: 1024,
            inputs: inputs.to_vec(),
            outputs: outputs.to_vec(),
            specs: SpecMap::new(),
        }
    }

    #[test]
    fn straight_producer_consumer_fuses() {
        let up = dispatch("load", &["in"], &["stage"]);
        let down = dispatch("xor", &["stage"], &["out"]);
        assert_eq!(
            FusionPass::decide(&up, &down, FusionCaps::high_end(), &[]),
            FusionDecision::Accept
        );
    }

    #[test]
    fn third_consumer_rejects() {
        let up = dispatch("a", &[], &["x"]);
        let down = dispatch("b", &["x"], &[]);
        assert_eq!(
            FusionPass::decide(&up, &down, FusionCaps::high_end(), &["x"]),
            FusionDecision::OutputConsumedElsewhere
        );
    }

    #[test]
    fn third_consumer_of_non_flowing_output_is_ignored() {
        // Only buffers that flow into `down` are checked against the other
        // consumers; "y" does not flow, so its extra consumer is irrelevant.
        let up = dispatch("a", &[], &["x", "y"]);
        let down = dispatch("b", &["x"], &[]);
        assert_eq!(
            FusionPass::decide(&up, &down, FusionCaps::high_end(), &["y"]),
            FusionDecision::Accept
        );
    }

    #[test]
    fn differing_workgroup_sizes_reject() {
        let up = dispatch("a", &[], &["x"]);
        let mut down = dispatch("b", &["x"], &[]);
        down.workgroup_size = [32, 1, 1];
        assert_eq!(
            FusionPass::decide(&up, &down, FusionCaps::high_end(), &[]),
            FusionDecision::WorkgroupSizeMismatch {
                upstream: [64, 1, 1],
                downstream: [32, 1, 1],
            }
        );
    }

    #[test]
    fn invocation_budget_rejects() {
        // 64 * 8 * 1 = 512 invocations exceeds the default cap of 256.
        let mut up = dispatch("a", &[], &["x"]);
        let mut down = dispatch("b", &["x"], &[]);
        up.workgroup_size = [64, 8, 1];
        down.workgroup_size = [64, 8, 1];
        assert_eq!(
            FusionPass::decide(&up, &down, FusionCaps::default(), &[]),
            FusionDecision::WorkgroupSizeMismatch {
                upstream: [64, 8, 1],
                downstream: [64, 8, 1],
            }
        );
    }

    #[test]
    fn shared_memory_budget_rejects() {
        // Each dispatch asks for 1024 bytes, so the fused kernel needs 2048.
        let caps = FusionCaps {
            max_shared_memory_bytes: 1536,
            ..FusionCaps::high_end()
        };
        let up = dispatch("a", &[], &["x"]);
        let down = dispatch("b", &["x"], &[]);
        assert_eq!(
            FusionPass::decide(&up, &down, caps, &[]),
            FusionDecision::SharedMemoryBudget {
                needed: 2048,
                cap: 1536,
            }
        );
    }

    #[test]
    fn unrelated_dispatches_reject() {
        let up = dispatch("a", &["in"], &["x"]);
        let down = dispatch("b", &["y"], &["out"]);
        assert_eq!(
            FusionPass::decide(&up, &down, FusionCaps::high_end(), &[]),
            FusionDecision::NoPipelineDependency
        );
    }

    #[test]
    fn fusion_decision_formats_human_string() {
        assert_eq!(
            FusionDecision::WorkgroupSizeMismatch {
                upstream: [8, 4, 1],
                downstream: [4, 4, 1]
            }
            .to_string(),
            "workgroup-size-mismatch:[8, 4, 1]->[4, 4, 1]"
        );
        assert_eq!(
            FusionDecision::SharedMemoryBudget {
                needed: 4096,
                cap: 2048
            }
            .to_string(),
            "shared-memory-budget:4096/2048"
        );
        assert_eq!(FusionDecision::Accept.to_string(), "accept");
        assert_eq!(
            FusionDecision::OutputConsumedElsewhere.to_string(),
            "output-consumed-elsewhere"
        );
        assert_eq!(
            FusionDecision::NoPipelineDependency.to_string(),
            "no-pipeline-dependency"
        );
    }
}