use std::sync::Arc;
use morok_dtype::DeviceSpec;
use morok_ir::UOp;
use crate::allocator::Allocator;
use crate::error::Result;
/// An executable program, as produced by a [`RuntimeFactory`] from a [`CompiledSpec`].
pub trait Program {
    /// Identifier of this program.
    fn name(&self) -> &str;

    /// Runs the program once.
    ///
    /// - `buffers`: raw pointers to the argument buffers.
    /// - `vals`: integer values for runtime variables. NOTE(review): presumably ordered like
    ///   `CompiledSpec::var_names` — confirm against implementations.
    /// - `global_size` / `local_size`: optional 3-D launch dimensions; how `None` is
    ///   interpreted is left to the implementation.
    ///
    /// # Errors
    ///
    /// Implementations report launch/execution failures through the returned `Result`.
    ///
    /// # Safety
    ///
    /// `buffers` carries raw pointers that the compiler cannot check. The caller must ensure
    /// each pointer is valid (live, large enough, suitably aligned) for every access the
    /// program makes during the call, and that nothing else accesses those buffers in a
    /// conflicting way at the same time. NOTE(review): exact size/alignment requirements are
    /// backend-specific — confirm per implementation.
    unsafe fn execute(
        &self,
        buffers: &[*mut u8],
        vals: &[i64],
        global_size: Option<[usize; 3]>,
        local_size: Option<[usize; 3]>,
    ) -> Result<()>;
}
/// Output of a [`Compiler`]: the program's source and/or binary plus launch metadata.
#[derive(Debug, Clone)]
pub struct CompiledSpec {
    /// Program name.
    pub name: String,
    /// Source text; `None` when the spec was built from a binary via `from_bytes`.
    pub src: Option<String>,
    /// Compiled binary; the source-based constructors leave this empty.
    pub bytes: Vec<u8>,
    /// IR (`UOp` graph) associated with this program.
    pub ast: Arc<UOp>,
    /// Names of runtime variables. NOTE(review): presumably the order expected for `vals`
    /// in `Program::execute` — confirm.
    pub var_names: Vec<String>,
    /// Optional 3-D global launch size.
    pub global_size: Option<[usize; 3]>,
    /// Optional 3-D local launch size.
    pub local_size: Option<[usize; 3]>,
    /// Number of buffer arguments (the constructors set 0 for `from_bytes`).
    pub buf_count: usize,
}
impl CompiledSpec {
    /// Builds a spec from source text with no binary, no variable names and no launch sizes.
    pub fn from_source(name: String, src: String, ast: Arc<UOp>, buf_count: usize) -> Self {
        // Identical to the sized constructor with both launch dimensions left unset.
        Self::from_source_with_sizes(name, src, ast, None, None, buf_count)
    }

    /// Builds a spec from an already-compiled binary.
    ///
    /// `src` is `None`, there are no variable names or launch sizes, and `buf_count` is 0.
    pub fn from_bytes(name: String, bytes: Vec<u8>, ast: Arc<UOp>) -> Self {
        Self {
            name,
            src: None,
            bytes,
            ast,
            var_names: Vec::new(),
            global_size: None,
            local_size: None,
            buf_count: 0,
        }
    }

    /// Builds a spec from source text, recording the given global/local launch sizes.
    ///
    /// The binary and the variable-name list start out empty.
    pub fn from_source_with_sizes(
        name: String,
        src: String,
        ast: Arc<UOp>,
        global_size: Option<[usize; 3]>,
        local_size: Option<[usize; 3]>,
        buf_count: usize,
    ) -> Self {
        Self {
            name,
            src: Some(src),
            bytes: Vec::new(),
            ast,
            var_names: Vec::new(),
            global_size,
            local_size,
            buf_count,
        }
    }
}
/// Compiles a rendered [`ProgramSpec`] into a [`CompiledSpec`].
///
/// Implementors must be `Send + Sync` so one compiler can be shared across threads.
pub trait Compiler: Send + Sync {
    /// Static identifier for this compiler. NOTE(review): presumably used to key cached
    /// compilation results — confirm at call sites.
    fn cache_key(&self) -> &'static str;

    /// Compiles `spec`, returning the compiled artifact or an error.
    fn compile(&self, spec: &ProgramSpec) -> Result<CompiledSpec>;
}
/// Renders an IR graph into a [`ProgramSpec`] (source text plus metadata) for a device.
///
/// Implementors must be `Send + Sync` so one renderer can be shared across threads.
pub trait Renderer: Send + Sync {
    /// The device this renderer targets.
    fn device(&self) -> &DeviceSpec;

    /// Renders `ast` into a program spec.
    ///
    /// NOTE(review): `name`, when `Some`, presumably overrides the generated program name —
    /// confirm against implementations.
    fn render(&self, ast: &Arc<UOp>, name: Option<&str>) -> Result<ProgramSpec>;

    /// Optional backend-specific pattern matcher; the default implementation provides none.
    fn decompositor(&self) -> Option<morok_ir::pattern::TypedPatternMatcher<()>> {
        None
    }
}
/// Shared, thread-safe factory that instantiates an executable [`Program`] from a
/// [`CompiledSpec`], or fails with an error.
pub type RuntimeFactory = Arc<dyn Fn(&CompiledSpec) -> Result<Box<dyn Program>> + Send + Sync>;
/// A (renderer, compiler) pair, as stored in `Device::compilers`.
pub type CompilerPair = (Arc<dyn Renderer>, Arc<dyn Compiler>);
/// Everything needed to allocate for, render/compile for, and run programs on one device.
// Field declaration order is also drop order.
pub struct Device {
    /// The device this bundle targets.
    pub device: DeviceSpec,
    /// Memory allocator for this device.
    pub allocator: Arc<dyn Allocator>,
    /// Available renderer/compiler pairs; `Device::new` seeds it with the default pair only.
    pub compilers: Vec<CompilerPair>,
    /// Default renderer (shared with the first entry of `compilers` when built via `new`).
    pub renderer: Arc<dyn Renderer>,
    /// Default compiler (shared with the first entry of `compilers` when built via `new`).
    pub compiler: Arc<dyn Compiler>,
    /// Factory that turns compiled specs into executable programs.
    pub runtime: RuntimeFactory,
}
impl Device {
pub fn new(
device: DeviceSpec,
allocator: Arc<dyn Allocator>,
renderer: Arc<dyn Renderer>,
compiler: Arc<dyn Compiler>,
runtime: RuntimeFactory,
) -> Self {
let compilers = vec![(renderer.clone(), compiler.clone())];
Self { device, allocator, compilers, renderer, compiler, runtime }
}
pub fn base_device_key(&self) -> &'static str {
self.device.base_type()
}
}
/// A rendered program: source text for a device plus the metadata needed to compile and
/// launch it. Produced by [`Renderer::render`] and consumed by [`Compiler::compile`].
#[derive(Debug, Clone)]
pub struct ProgramSpec {
    /// Program name.
    pub name: String,
    /// Rendered source text.
    pub src: String,
    /// Target device.
    pub device: DeviceSpec,
    /// IR (`UOp` graph) the source was rendered from.
    pub ast: Arc<UOp>,
    /// Optional 3-D global launch size (set by `set_work_sizes`).
    pub global_size: Option<[usize; 3]>,
    /// Optional 3-D local launch size (set by `set_work_sizes`).
    pub local_size: Option<[usize; 3]>,
    /// Runtime variables with their bounds (appended by `add_var`).
    pub vars: Vec<Variable>,
    /// Runtime variable names (set by `set_var_names`).
    pub var_names: Vec<String>,
    /// Buffer index list set by `set_buffer_metadata`. NOTE(review): presumably the
    /// global buffer indices — confirm against the renderer.
    pub globals: Vec<usize>,
    /// Buffer index list set by `set_buffer_metadata`. NOTE(review): presumably the
    /// output buffer indices — confirm.
    pub outs: Vec<usize>,
    /// Buffer index list set by `set_buffer_metadata`. NOTE(review): presumably the
    /// input buffer indices — confirm.
    pub ins: Vec<usize>,
    /// Number of buffer arguments; `new` initializes it to 0 and no setter here changes it.
    pub buf_count: usize,
}
impl ProgramSpec {
    /// Creates a spec with the given identity and source.
    ///
    /// Launch sizes start as `None`, all lists start empty, and `buf_count` starts at 0.
    pub fn new(name: String, src: String, device: DeviceSpec, ast: Arc<UOp>) -> Self {
        Self {
            name,
            src,
            device,
            ast,
            // Launch geometry stays unknown until `set_work_sizes` is called.
            global_size: None,
            local_size: None,
            // Variable and buffer bookkeeping is filled in later through the setters below.
            vars: Vec::new(),
            var_names: Vec::new(),
            globals: Vec::new(),
            outs: Vec::new(),
            ins: Vec::new(),
            buf_count: 0,
        }
    }

    /// Appends `var` to the variable list.
    pub fn add_var(&mut self, var: Variable) {
        self.vars.push(var);
    }

    /// Records both launch dimensions at once, replacing any previous values.
    pub fn set_work_sizes(&mut self, global: [usize; 3], local: [usize; 3]) {
        let (global_size, local_size) = (Some(global), Some(local));
        self.global_size = global_size;
        self.local_size = local_size;
    }

    /// Replaces the list of runtime variable names.
    pub fn set_var_names(&mut self, var_names: Vec<String>) {
        self.var_names = var_names;
    }

    /// Replaces the three buffer index lists (`globals`, `outs`, `ins`).
    pub fn set_buffer_metadata(&mut self, globals: Vec<usize>, outs: Vec<usize>, ins: Vec<usize>) {
        self.ins = ins;
        self.outs = outs;
        self.globals = globals;
    }
}
/// A named runtime variable with `min`/`max` bounds.
///
/// NOTE(review): whether the bounds are inclusive is not visible here — confirm.
#[derive(Debug, Clone)]
pub struct Variable {
    /// Variable name.
    pub name: String,
    /// Lower bound.
    pub min: i64,
    /// Upper bound.
    pub max: i64,
}

impl Variable {
    /// Creates a variable named `name` bounded by `min` and `max`.
    ///
    /// The bounds are stored as given; no `min <= max` check is performed.
    pub fn new(name: String, min: i64, max: i64) -> Self {
        Variable { name, min, max }
    }
}