use crate::domain::Domain;
use serde::{Deserialize, Serialize};
use std::fmt;
/// Execution model of a kernel.
///
/// `Batch` and `Ring` are the two supported modes; see
/// [`KernelMode::typical_overhead_us`] for the per-mode overhead figure
/// assumed by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum KernelMode {
    /// Batch execution mode (lowercase name `"batch"`).
    Batch,
    /// Ring execution mode (lowercase name `"ring"`).
    Ring,
}
impl KernelMode {
#[must_use]
pub const fn is_batch(&self) -> bool {
matches!(self, KernelMode::Batch)
}
#[must_use]
pub const fn is_ring(&self) -> bool {
matches!(self, KernelMode::Ring)
}
#[must_use]
pub const fn typical_overhead_us(&self) -> f64 {
match self {
KernelMode::Batch => 30.0, KernelMode::Ring => 0.3, }
}
#[must_use]
pub const fn as_str(&self) -> &'static str {
match self {
KernelMode::Batch => "batch",
KernelMode::Ring => "ring",
}
}
}
/// Formats the mode as its lowercase name (see [`KernelMode::as_str`]).
///
/// Uses [`fmt::Formatter::pad`] so the caller's width, fill, alignment and
/// precision flags are honoured, e.g. `format!("{:>6}", KernelMode::Ring)`
/// yields `"  ring"`. Plain `{}` output is unchanged.
impl fmt::Display for KernelMode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write!(f, "{}", ..)` would format with a fresh, empty spec and drop
        // the caller's padding flags; `pad` applies them to the string.
        f.pad(self.as_str())
    }
}
/// Descriptive metadata for a single kernel.
///
/// Can be created with [`KernelMetadata::batch`] / [`KernelMetadata::ring`]
/// (or [`KernelMetadataBuilder`]) and refined with the `with_*` setters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelMetadata {
    /// Kernel identifier; may contain `/`-separated segments, the last of
    /// which is used by `feature_string`.
    pub id: String,
    /// Execution mode (batch or ring).
    pub mode: KernelMode,
    /// Domain the kernel belongs to.
    pub domain: Domain,
    /// Free-form description; empty unless set explicitly.
    pub description: String,
    /// Expected throughput in operations per second (see `with_throughput`).
    pub expected_throughput: u64,
    /// Target latency in microseconds (see `with_latency_us`).
    pub target_latency_us: f64,
    /// Whether the kernel requires GPU-native execution.
    pub requires_gpu_native: bool,
    /// Metadata version; the provided constructors start it at 1.
    pub version: u32,
}
impl KernelMetadata {
    /// Creates batch-mode metadata with the batch defaults: throughput
    /// `10_000`, latency `50.0`, no GPU-native requirement, empty
    /// description and version `1`.
    #[must_use]
    pub fn batch(id: impl Into<String>, domain: Domain) -> Self {
        Self {
            id: id.into(),
            mode: KernelMode::Batch,
            domain,
            description: String::new(),
            expected_throughput: 10_000,
            target_latency_us: 50.0,
            requires_gpu_native: false,
            version: 1,
        }
    }

    /// Creates ring-mode metadata with the ring defaults: throughput
    /// `100_000`, latency `1.0`, GPU-native required, empty description and
    /// version `1`.
    #[must_use]
    pub fn ring(id: impl Into<String>, domain: Domain) -> Self {
        // Only the mode-specific fields differ from the batch shape; the
        // empty description and version 1 are shared.
        Self {
            mode: KernelMode::Ring,
            expected_throughput: 100_000,
            target_latency_us: 1.0,
            requires_gpu_native: true,
            ..Self::batch(id, domain)
        }
    }

    /// Replaces the description.
    #[must_use]
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            description: description.into(),
            ..self
        }
    }

    /// Replaces the expected throughput (operations per second).
    #[must_use]
    pub fn with_throughput(self, ops_per_sec: u64) -> Self {
        Self {
            expected_throughput: ops_per_sec,
            ..self
        }
    }

    /// Replaces the target latency (microseconds).
    #[must_use]
    pub fn with_latency_us(self, latency_us: f64) -> Self {
        Self {
            target_latency_us: latency_us,
            ..self
        }
    }

    /// Sets whether GPU-native execution is required.
    #[must_use]
    pub fn with_gpu_native(self, required: bool) -> Self {
        Self {
            requires_gpu_native: required,
            ..self
        }
    }

    /// Replaces the metadata version.
    #[must_use]
    pub fn with_version(self, version: u32) -> Self {
        Self { version, ..self }
    }

    /// Returns `"<Domain>.<Name>"`, where `<Name>` is the text after the last
    /// `/` in `id` (or the whole `id` if it has none) converted to PascalCase.
    #[must_use]
    pub fn feature_string(&self) -> String {
        // '/' is a single ASCII byte, so `slash + 1` is always a char boundary.
        let last_segment = match self.id.rfind('/') {
            Some(slash) => &self.id[slash + 1..],
            None => &self.id[..],
        };
        format!("{}.{}", self.domain, to_pascal_case(last_segment))
    }

    /// Returns `"<lowercased domain>/<id>"`.
    #[must_use]
    pub fn full_id(&self) -> String {
        let domain_prefix = self.domain.as_str().to_lowercase();
        format!("{}/{}", domain_prefix, self.id)
    }
}
impl Default for KernelMetadata {
fn default() -> Self {
Self::batch("unnamed", Domain::Core)
}
}
/// Converts a `-`/`_`-separated identifier to PascalCase.
///
/// Each non-empty segment has its first character upper-cased (which may
/// expand to several characters, e.g. `ß` becomes `SS`); the rest of the
/// segment is copied unchanged. Empty segments produced by leading,
/// trailing or repeated separators contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split(|c: char| c == '-' || c == '_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.extend(rest);
        }
    }
    pascal
}
/// Step-by-step builder for [`KernelMetadata`].
///
/// `id`, `mode` and `domain` must be set before building. The remaining
/// fields start at: empty `description`, `expected_throughput = 10_000`,
/// `target_latency_us = 50.0`, `requires_gpu_native = false`, `version = 1`.
#[derive(Clone, Debug)]
pub struct KernelMetadataBuilder {
    id: Option<String>,
    mode: Option<KernelMode>,
    domain: Option<Domain>,
    description: String,
    expected_throughput: u64,
    target_latency_us: f64,
    requires_gpu_native: bool,
    version: u32,
}

/// Same starting state as `KernelMetadataBuilder::new()`.
///
/// A derived `Default` would zero `expected_throughput`, `target_latency_us`
/// and `version`, making `KernelMetadataBuilder::default()` build different
/// metadata than `new()`; this manual impl keeps both entry points identical.
impl Default for KernelMetadataBuilder {
    fn default() -> Self {
        Self {
            id: None,
            mode: None,
            domain: None,
            description: String::new(),
            expected_throughput: 10_000,
            target_latency_us: 50.0,
            requires_gpu_native: false,
            version: 1,
        }
    }
}
impl KernelMetadataBuilder {
    /// Creates a builder with no `id`/`mode`/`domain` and the default
    /// tunables: throughput `10_000`, latency `50.0`, version `1`, no
    /// GPU-native requirement and an empty description.
    #[must_use]
    pub fn new() -> Self {
        Self {
            id: None,
            mode: None,
            domain: None,
            description: String::new(),
            expected_throughput: 10_000,
            target_latency_us: 50.0,
            requires_gpu_native: false,
            version: 1,
        }
    }

    /// Sets the (required) kernel id.
    #[must_use]
    pub fn id(self, id: impl Into<String>) -> Self {
        Self {
            id: Some(id.into()),
            ..self
        }
    }

    /// Sets the (required) execution mode.
    #[must_use]
    pub fn mode(self, mode: KernelMode) -> Self {
        Self {
            mode: Some(mode),
            ..self
        }
    }

    /// Sets the (required) domain.
    #[must_use]
    pub fn domain(self, domain: Domain) -> Self {
        Self {
            domain: Some(domain),
            ..self
        }
    }

    /// Sets the description.
    #[must_use]
    pub fn description(self, description: impl Into<String>) -> Self {
        Self {
            description: description.into(),
            ..self
        }
    }

    /// Sets the expected throughput (operations per second).
    #[must_use]
    pub fn throughput(self, ops_per_sec: u64) -> Self {
        Self {
            expected_throughput: ops_per_sec,
            ..self
        }
    }

    /// Sets the target latency (microseconds).
    #[must_use]
    pub fn latency_us(self, latency_us: f64) -> Self {
        Self {
            target_latency_us: latency_us,
            ..self
        }
    }

    /// Sets whether GPU-native execution is required.
    #[must_use]
    pub fn gpu_native(self, required: bool) -> Self {
        Self {
            requires_gpu_native: required,
            ..self
        }
    }

    /// Sets the metadata version.
    #[must_use]
    pub fn version(self, version: u32) -> Self {
        Self { version, ..self }
    }

    /// Builds the metadata.
    ///
    /// # Panics
    ///
    /// Panics if `id`, `mode` or `domain` was never set (checked in that
    /// order). Use [`KernelMetadataBuilder::try_build`] to get `None` instead.
    #[must_use]
    pub fn build(self) -> KernelMetadata {
        let Self {
            id,
            mode,
            domain,
            description,
            expected_throughput,
            target_latency_us,
            requires_gpu_native,
            version,
        } = self;
        let id = id.expect("id is required");
        let mode = mode.expect("mode is required");
        let domain = domain.expect("domain is required");
        KernelMetadata {
            id,
            mode,
            domain,
            description,
            expected_throughput,
            target_latency_us,
            requires_gpu_native,
            version,
        }
    }

    /// Builds the metadata, returning `None` if `id`, `mode` or `domain`
    /// was never set.
    #[must_use]
    pub fn try_build(self) -> Option<KernelMetadata> {
        let id = self.id?;
        let mode = self.mode?;
        let domain = self.domain?;
        Some(KernelMetadata {
            id,
            mode,
            domain,
            description: self.description,
            expected_throughput: self.expected_throughput,
            target_latency_us: self.target_latency_us,
            requires_gpu_native: self.requires_gpu_native,
            version: self.version,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kernel_mode_properties() {
        assert!(KernelMode::Batch.is_batch());
        assert!(!KernelMode::Batch.is_ring());
        assert!(KernelMode::Ring.is_ring());
        assert!(!KernelMode::Ring.is_batch());
    }

    #[test]
    fn test_kernel_mode_str_and_display() {
        assert_eq!(KernelMode::Batch.as_str(), "batch");
        assert_eq!(KernelMode::Ring.as_str(), "ring");
        assert_eq!(KernelMode::Batch.to_string(), "batch");
        assert_eq!(format!("{}", KernelMode::Ring), "ring");
    }

    #[test]
    fn test_typical_overhead_ordering() {
        // Ring mode is modelled as having lower overhead than batch mode.
        assert!(KernelMode::Ring.typical_overhead_us() < KernelMode::Batch.typical_overhead_us());
    }

    #[test]
    fn test_kernel_metadata_batch() {
        let meta = KernelMetadata::batch("pagerank", Domain::GraphAnalytics)
            .with_description("PageRank centrality")
            .with_throughput(100_000)
            .with_latency_us(10.0);
        assert_eq!(meta.id, "pagerank");
        assert_eq!(meta.mode, KernelMode::Batch);
        assert_eq!(meta.domain, Domain::GraphAnalytics);
        assert!(!meta.requires_gpu_native);
    }

    #[test]
    fn test_kernel_metadata_ring() {
        let meta = KernelMetadata::ring("pagerank", Domain::GraphAnalytics);
        assert_eq!(meta.mode, KernelMode::Ring);
        assert!(meta.requires_gpu_native);
    }

    #[test]
    fn test_default_metadata() {
        let meta = KernelMetadata::default();
        assert_eq!(meta.id, "unnamed");
        assert_eq!(meta.mode, KernelMode::Batch);
        assert_eq!(meta.domain, Domain::Core);
        assert_eq!(meta.version, 1);
    }

    #[test]
    fn test_feature_string() {
        let meta = KernelMetadata::ring("pagerank", Domain::GraphAnalytics);
        assert_eq!(meta.feature_string(), "GraphAnalytics.Pagerank");
        let meta = KernelMetadata::ring("graph/degree-centrality", Domain::GraphAnalytics);
        assert_eq!(meta.feature_string(), "GraphAnalytics.DegreeCentrality");
    }

    #[test]
    fn test_feature_string_uses_last_segment() {
        let meta = KernelMetadata::batch("a/b/page_rank", Domain::GraphAnalytics);
        assert_eq!(meta.feature_string(), "GraphAnalytics.PageRank");
    }

    #[test]
    fn test_to_pascal_case() {
        assert_eq!(to_pascal_case("pagerank"), "Pagerank");
        assert_eq!(to_pascal_case("degree-centrality"), "DegreeCentrality");
        assert_eq!(to_pascal_case("snake_case"), "SnakeCase");
        assert_eq!(to_pascal_case("mixed-snake_case"), "MixedSnakeCase");
    }

    #[test]
    fn test_to_pascal_case_edge_cases() {
        assert_eq!(to_pascal_case(""), "");
        // Leading, trailing and repeated separators produce no output.
        assert_eq!(to_pascal_case("--a__b-"), "AB");
    }

    #[test]
    fn test_builder() {
        let meta = KernelMetadataBuilder::new()
            .id("test-kernel")
            .mode(KernelMode::Ring)
            .domain(Domain::Core)
            .throughput(50_000)
            .latency_us(0.5)
            .build();
        assert_eq!(meta.id, "test-kernel");
        assert_eq!(meta.mode, KernelMode::Ring);
        assert_eq!(meta.expected_throughput, 50_000);
    }

    #[test]
    fn test_builder_try_build() {
        assert!(KernelMetadataBuilder::new()
            .id("x")
            .mode(KernelMode::Batch)
            .try_build()
            .is_none());
        assert!(KernelMetadataBuilder::new()
            .mode(KernelMode::Batch)
            .domain(Domain::Core)
            .try_build()
            .is_none());
        let meta = KernelMetadataBuilder::new()
            .id("x")
            .mode(KernelMode::Batch)
            .domain(Domain::Core)
            .try_build()
            .expect("all required fields are set");
        assert_eq!(meta.expected_throughput, 10_000);
        assert_eq!(meta.version, 1);
    }

    #[test]
    #[should_panic(expected = "domain is required")]
    fn test_builder_build_panics_without_domain() {
        let _ = KernelMetadataBuilder::new()
            .id("x")
            .mode(KernelMode::Ring)
            .build();
    }
}