1use crate::device::oneapi_device;
24use crate::host::{self, HostBuf};
25use crate::kernels::kernels;
26use rlx_compile::memory::{BufferSlot, MemoryPlan};
27use rlx_ir::op::Activation;
28use rlx_ir::{DType, Dim, Graph, NodeId, Op, RngOptions, Shape};
29use std::collections::HashMap;
30use std::ffi::c_void;
31
32pub const SUPPORTED_OPS: &[rlx_ir::OpKind] = {
36 use rlx_ir::OpKind::*;
37 &[
38 Input,
39 Param,
40 Constant,
41 Cast,
42 StopGradient,
43 Reshape, Binary,
45 Compare,
46 Where,
47 Activation, MatMul,
49 Reduce,
50 Softmax, LayerNorm,
52 RmsNorm,
53 LayerNorm2d, Rope,
55 Attention, FusedAttentionBlock,
59 Transpose,
60 Narrow,
61 Concat,
62 Expand,
63 Gather,
64 Cumsum,
65 Reverse, ArgMax,
67 ArgMin,
68 Pool,
69 ResizeNearest2x,
70 Conv, GroupedMatMul, SelectiveScan, Im2Col,
74 ScatterAdd,
75 TopK, Lstm,
77 Gru,
78 Rnn,
79 Mamba2,
80 GatedDeltaNet,
81 ConvTranspose2d,
82 Fft,
83 DequantMatMul,
84 DequantGroupedMatMul,
85 DequantMoEWeights, RngNormal,
87 RngUniform,
88 Sample, ]
90};
91
92fn native_kernel(op: &Op) -> Option<&'static str> {
96 match op {
97 Op::Binary(_) => Some("binary"),
98 Op::Activation(_) => Some("unary"),
99 Op::MatMul => Some("matmul"),
100 Op::Softmax { .. } => Some("softmax"),
101 Op::RmsNorm { .. } => Some("rmsnorm"),
102 _ => None,
103 }
104}
105
/// A caller-supplied parameter value, kept in the form it was given.
#[derive(Clone)]
enum ParamVal {
    /// Values set through `OneApiExecutable::set_param`.
    F32(Vec<f32>),
    /// Raw bytes set through `OneApiExecutable::set_param_bytes`.
    Bytes(Vec<u8>),
}
111
/// A graph compiled for the oneAPI backend, together with its parameter values
/// and run-time options.
pub struct OneApiExecutable {
    /// Graph after lowering and legalization in `compile_rng`.
    graph: Graph,
    /// Parameter values keyed by the `Param` node name.
    params: HashMap<String, ParamVal>,
    /// Graph output node ids, in `graph.outputs` order.
    output_ids: Vec<NodeId>,
    /// dtype of each entry of `output_ids`.
    output_dtypes: Vec<DType>,
    /// RNG options given at compile time or via `set_rng`.
    /// NOTE(review): not passed to `host::eval` by the run paths in this file.
    rng: RngOptions,
    /// Optional extent given via `set_active_extent`.
    /// NOTE(review): not read by the run paths in this file.
    active_extent: Option<(usize, usize)>,
}
121
// SAFETY: NOTE(review): no justification is recorded for this impl. Every field
// is owned by value (`Graph`, `HashMap`, `Vec`s, `RngOptions`, `Option`), so it
// is sound only if those types (and the element types of the collections)
// hold nothing thread-affine, such as `Rc` or raw pointers shared with another
// owner. If they are already `Send`, this impl is redundant — confirm before
// relying on it.
unsafe impl Send for OneApiExecutable {}
123
124impl OneApiExecutable {
125 pub fn compile(graph: Graph) -> Self {
126 Self::compile_rng(graph, RngOptions::default())
127 }
128
129 pub fn compile_rng(graph: Graph, rng: RngOptions) -> Self {
131 use rlx_opt::pass::Pass as _;
132
133 let graph = rlx_opt::LowerControlFlow.run(graph);
134 let graph = rlx_opt::unfuse::unfuse_attention_block(graph);
138 let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, SUPPORTED_OPS)
139 .unwrap_or_else(|errs| panic!("{}", rlx_opt::format_legalize_error("oneapi", &errs)));
140 let graph = rlx_opt::LegalizeBroadcast.run(graph);
141
142 let output_ids = graph.outputs.clone();
143 let output_dtypes = output_ids
144 .iter()
145 .map(|&id| graph.node(id).shape.dtype())
146 .collect();
147
148 Self {
149 graph,
150 params: HashMap::new(),
151 output_ids,
152 output_dtypes,
153 rng,
154 active_extent: None,
155 }
156 }
157
158 pub fn set_param(&mut self, name: &str, data: &[f32]) {
159 self.params
160 .insert(name.to_string(), ParamVal::F32(data.to_vec()));
161 }
162
163 pub fn set_param_bytes(&mut self, name: &str, data: &[u8]) {
164 self.params
165 .insert(name.to_string(), ParamVal::Bytes(data.to_vec()));
166 }
167
168 pub fn output_dtypes(&self) -> Vec<DType> {
169 self.output_dtypes.clone()
170 }
171
172 pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
173 self.active_extent = extent;
174 }
175
176 pub fn set_rng(&mut self, rng: RngOptions) {
177 self.rng = rng;
178 }
179
180 pub fn rng(&self) -> RngOptions {
181 self.rng
182 }
183
184 pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
185 self.run_read_outputs(inputs, None)
186 }
187
188 pub fn run_read_outputs(
189 &mut self,
190 inputs: &[(&str, &[f32])],
191 read_indices: Option<&[usize]>,
192 ) -> Vec<Vec<f32>> {
193 if oneapi_device().is_some() && kernels().is_some() {
196 self.run_l0(inputs, read_indices)
197 } else {
198 self.run_host(inputs, read_indices)
199 }
200 }
201
202 fn run_host(&self, inputs: &[(&str, &[f32])], read_indices: Option<&[usize]>) -> Vec<Vec<f32>> {
205 let in_map: HashMap<&str, &[f32]> = inputs.iter().copied().collect();
206 let mut f32v: HashMap<NodeId, Vec<f32>> = HashMap::new();
207 let mut bytev: HashMap<NodeId, Vec<u8>> = HashMap::new();
208
209 for node in self.graph.nodes() {
210 let numel = node.shape.num_elements().unwrap_or(0);
211 match &node.op {
212 Op::Input { name } => {
213 let v = in_map
214 .get(name.as_str())
215 .map(|s| s.to_vec())
216 .unwrap_or_else(|| vec![0.0; numel]);
217 f32v.insert(node.id, v);
218 }
219 Op::Param { name } => match self.params.get(name) {
220 Some(ParamVal::F32(v)) => {
221 f32v.insert(node.id, v.clone());
222 }
223 Some(ParamVal::Bytes(b)) => {
224 bytev.insert(node.id, b.clone());
225 }
226 None => {
227 f32v.insert(node.id, vec![0.0; numel]);
228 }
229 },
230 Op::Constant { data } => {
231 if matches!(node.shape.dtype(), DType::U8 | DType::I8) {
232 bytev.insert(node.id, data.clone());
233 } else {
234 f32v.insert(node.id, widen_const_to_f32(data, node.shape.dtype()));
235 }
236 }
237 _ => {
238 let in_specs: Vec<(Shape, HostBuf)> = node
239 .inputs
240 .iter()
241 .map(|&id| {
242 let sh = self.graph.node(id).shape.clone();
243 let buf = if let Some(b) = bytev.get(&id) {
244 HostBuf::Bytes(b.clone())
245 } else {
246 HostBuf::F32(f32v.get(&id).cloned().unwrap_or_default())
247 };
248 (sh, buf)
249 })
250 .collect();
251 let out = host::eval(&node.op, &node.shape, &in_specs);
252 f32v.insert(node.id, out);
253 }
254 }
255 }
256
257 self.read_outputs(read_indices, |id, n| {
258 f32v.get(&id)
259 .map(|v| v[..n.min(v.len())].to_vec())
260 .unwrap_or_else(|| vec![0.0; n])
261 })
262 }
263
264 fn run_l0(
267 &mut self,
268 inputs: &[(&str, &[f32])],
269 read_indices: Option<&[usize]>,
270 ) -> Vec<Vec<f32>> {
271 let dev = oneapi_device().expect("rlx-oneapi: no device");
272 let kerns = kernels().expect("rlx-oneapi: no kernels");
273
274 let plan = plan_f32_uniform(&self.graph, 64);
275 let arena = match crate::arena::Arena::from_plan(&plan) {
276 Ok(a) => a,
277 Err(_) => return self.run_host(inputs, read_indices),
280 };
281
282 for node in self.graph.nodes() {
284 match &node.op {
285 Op::Constant { data } if arena.has(node.id) && !data.is_empty() => {
286 if matches!(node.shape.dtype(), DType::U8 | DType::I8) {
287 arena.write_bytes(node.id, data);
288 } else {
289 arena.write_f32(node.id, &widen_const_to_f32(data, node.shape.dtype()));
290 }
291 }
292 Op::Param { name } => match self.params.get(name) {
293 Some(ParamVal::F32(v)) => arena.write_f32(node.id, v),
294 Some(ParamVal::Bytes(b)) => arena.write_bytes(node.id, b),
295 None => {}
296 },
297 _ => {}
298 }
299 }
300 let in_map: HashMap<&str, &[f32]> = inputs.iter().copied().collect();
301 for node in self.graph.nodes() {
302 if let Op::Input { name } = &node.op {
303 if let Some(data) = in_map.get(name.as_str()) {
304 arena.write_f32(node.id, data);
305 }
306 }
307 }
308
309 let list = dev.create_command_list().expect("rlx-oneapi: command list");
312 for node in self.graph.nodes() {
313 if matches!(
314 node.op,
315 Op::Input { .. }
316 | Op::Param { .. }
317 | Op::Constant { .. }
318 | Op::Reshape { .. }
319 | Op::Cast { .. }
320 | Op::StopGradient
321 ) {
322 continue;
323 }
324 match native_kernel(&node.op) {
325 Some(name) => self.dispatch(dev, kerns, list, name, node, &arena),
326 None => {
327 let in_specs: Vec<(Shape, HostBuf)> = node
329 .inputs
330 .iter()
331 .map(|&id| {
332 let sh = self.graph.node(id).shape.clone();
333 let nn = sh.num_elements().unwrap_or(0);
334 let buf = if matches!(sh.dtype(), DType::U8 | DType::I8) {
335 HostBuf::Bytes(arena.read_bytes(id, nn))
336 } else {
337 HostBuf::F32(arena.read_f32(id, nn))
338 };
339 (sh, buf)
340 })
341 .collect();
342 let out = host::eval(&node.op, &node.shape, &in_specs);
343 arena.write_f32(node.id, &out);
344 }
345 }
346 }
347 dev.execute_sync(list).expect("rlx-oneapi: execute");
348 unsafe {
349 let _ = (dev.lib.command_list_destroy)(list);
350 }
351
352 self.read_outputs(read_indices, |id, n| arena.read_f32(id, n))
353 }
354
355 fn dispatch(
358 &self,
359 dev: &crate::device::OneApiDevice,
360 kerns: &crate::kernels::Kernels,
361 list: crate::level_zero::CommandListHandle,
362 name: &str,
363 node: &rlx_ir::Node,
364 arena: &crate::arena::Arena,
365 ) {
366 let Some(kernel) = kerns.get(name) else {
367 return;
368 };
369 let off = |id: NodeId| arena.elem_offset(id);
370 let out = node.id;
371 let mut args: Vec<KArg> = vec![KArg::Ptr(arena.base_ptr())];
372 let (global, local): (usize, u32) = match &node.op {
373 Op::Binary(op) => {
374 let a = node.inputs[0];
375 let b = node.inputs[1];
376 let n = numel(&dims(&self.graph, out));
377 let an = numel(&dims(&self.graph, a));
378 let bn = numel(&dims(&self.graph, b));
379 args.extend([
380 KArg::U32(n as u32),
381 KArg::U32(off(a)),
382 KArg::U32(off(b)),
383 KArg::U32(off(out)),
384 KArg::U32(if an == n { 0 } else { an as u32 }),
385 KArg::U32(if bn == n { 0 } else { bn as u32 }),
386 KArg::U32(binop_id(*op)),
387 ]);
388 (n, 256)
389 }
390 Op::Activation(act) => {
391 let x = node.inputs[0];
392 let n = numel(&dims(&self.graph, out));
393 args.extend([
394 KArg::U32(n as u32),
395 KArg::U32(off(x)),
396 KArg::U32(off(out)),
397 KArg::U32(act_id(*act)),
398 ]);
399 (n, 256)
400 }
401 Op::MatMul => {
402 let a = node.inputs[0];
403 let b = node.inputs[1];
404 let ad = dims(&self.graph, a);
405 let bd = dims(&self.graph, b);
406 let od = dims(&self.graph, out);
407 let (m, k) = (ad[ad.len() - 2], ad[ad.len() - 1]);
408 let n = bd[bd.len() - 1];
409 let batch = if od.len() > 2 {
410 numel(&od[..od.len() - 2])
411 } else {
412 1
413 };
414 let a_batch = if ad.len() > 2 {
415 numel(&ad[..ad.len() - 2])
416 } else {
417 1
418 };
419 let b_batch = if bd.len() > 2 {
420 numel(&bd[..bd.len() - 2])
421 } else {
422 1
423 };
424 let a_bs = if a_batch <= 1 { 0 } else { m * k };
425 let b_bs = if b_batch <= 1 { 0 } else { k * n };
426 args.extend([
427 KArg::U32(m as u32),
428 KArg::U32(k as u32),
429 KArg::U32(n as u32),
430 KArg::U32(off(a)),
431 KArg::U32(off(b)),
432 KArg::U32(off(out)),
433 KArg::U32(batch as u32),
434 KArg::U32(a_bs as u32),
435 KArg::U32(b_bs as u32),
436 KArg::U32((m * n) as u32),
437 ]);
438 (batch.max(1) * m * n, 64)
439 }
440 Op::Softmax { axis } => {
441 let x = node.inputs[0];
442 let xd = dims(&self.graph, x);
443 let ax = norm_axis(*axis, xd.len());
444 let axis_len = xd[ax];
445 let outer = numel(&xd[..ax]);
446 let inner = numel(&xd[ax + 1..]);
447 args.extend([
448 KArg::U32(outer as u32),
449 KArg::U32(axis_len as u32),
450 KArg::U32(inner as u32),
451 KArg::U32(off(x)),
452 KArg::U32(off(out)),
453 ]);
454 (outer * inner, 256)
455 }
456 Op::RmsNorm { axis, eps } => {
457 let x = node.inputs[0];
458 let gamma = node.inputs[1];
459 let beta = node.inputs[2];
460 let xd = dims(&self.graph, x);
461 let ax = norm_axis(*axis, xd.len());
462 let n = xd[ax];
463 let rows = numel(&xd) / n.max(1);
464 args.extend([
465 KArg::U32(rows as u32),
466 KArg::U32(n as u32),
467 KArg::U32(off(x)),
468 KArg::U32(off(gamma)),
469 KArg::U32(off(beta)),
470 KArg::U32(off(out)),
471 KArg::F32(*eps),
472 ]);
473 (rows, 64)
474 }
475 _ => return,
476 };
477
478 unsafe {
479 let _ = (dev.lib.kernel_set_group_size)(kernel, local, 1, 1);
480 for (i, a) in args.iter().enumerate() {
481 let (size, ptr) = a.as_arg();
482 let _ = (dev.lib.kernel_set_argument_value)(kernel, i as u32, size, ptr);
483 }
484 let groups = crate::level_zero::GroupCount {
485 group_count_x: ceil_div(global, local).max(1),
486 group_count_y: 1,
487 group_count_z: 1,
488 };
489 let _ = (dev.lib.command_list_append_launch_kernel)(
490 list,
491 kernel,
492 &groups,
493 std::ptr::null_mut(),
494 0,
495 std::ptr::null_mut(),
496 );
497 let _ = (dev.lib.command_list_append_barrier)(
499 list,
500 std::ptr::null_mut(),
501 0,
502 std::ptr::null_mut(),
503 );
504 }
505 }
506
507 fn read_outputs(
508 &self,
509 read_indices: Option<&[usize]>,
510 mut read: impl FnMut(NodeId, usize) -> Vec<f32>,
511 ) -> Vec<Vec<f32>> {
512 let want: Vec<usize> = match read_indices {
513 Some(ix) => ix.to_vec(),
514 None => (0..self.output_ids.len()).collect(),
515 };
516 want.into_iter()
517 .filter_map(|i| {
518 let id = *self.output_ids.get(i)?;
519 let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
520 Some(read(id, n))
521 })
522 .collect()
523 }
524
525 pub fn clone_for_cache(&self) -> Self {
528 Self {
529 graph: self.graph.clone(),
530 params: self.params.clone(),
531 output_ids: self.output_ids.clone(),
532 output_dtypes: self.output_dtypes.clone(),
533 rng: self.rng,
534 active_extent: self.active_extent,
535 }
536 }
537}
538
/// A single kernel argument, stored by value so that a pointer to it can be
/// handed to `kernel_set_argument_value`.
enum KArg {
    /// A raw pointer passed by value.
    Ptr(*mut c_void),
    /// A 32-bit unsigned integer.
    U32(u32),
    /// A 32-bit float.
    F32(f32),
}

impl KArg {
    /// Returns the argument's size in bytes and a pointer to its stored value.
    ///
    /// The pointer borrows from `self`; it is valid only while `self` is alive
    /// and not moved.
    fn as_arg(&self) -> (usize, *const c_void) {
        match self {
            Self::Ptr(p) => (
                std::mem::size_of_val(p),
                (p as *const *mut c_void).cast::<c_void>(),
            ),
            Self::U32(v) => (std::mem::size_of_val(v), (v as *const u32).cast::<c_void>()),
            Self::F32(v) => (std::mem::size_of_val(v), (v as *const f32).cast::<c_void>()),
        }
    }
}
561}
562
563fn plan_f32_uniform(graph: &Graph, align: usize) -> MemoryPlan {
566 let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
567 let mut schedule = Vec::with_capacity(graph.nodes().len());
568 let mut cursor = 0usize;
569 for node in graph.nodes() {
570 if matches!(
571 node.op,
572 Op::Reshape { .. } | Op::Cast { .. } | Op::StopGradient
573 ) {
574 if let Some(in_id) = node.inputs.first() {
575 if let Some(slot) = assignments.get(in_id) {
576 let aliased = slot.clone();
577 assignments.insert(node.id, aliased);
578 schedule.push(node.id);
579 continue;
580 }
581 }
582 }
583 let elems = node.shape.num_elements().unwrap_or(0);
584 let bytes = (elems * 4).max(4);
585 let aligned = bytes.div_ceil(align) * align;
586 assignments.insert(
587 node.id,
588 BufferSlot {
589 offset: cursor,
590 size: aligned,
591 },
592 );
593 schedule.push(node.id);
594 cursor += aligned;
595 }
596 MemoryPlan {
597 arena_size: cursor.max(align),
598 assignments,
599 schedule,
600 }
601}
602
603fn dims(graph: &Graph, id: NodeId) -> Vec<usize> {
606 graph
607 .node(id)
608 .shape
609 .dims()
610 .iter()
611 .map(|d| match d {
612 Dim::Static(s) => *s,
613 _ => 0,
614 })
615 .collect()
616}
617
/// Number of elements described by the extents in `d`.
///
/// This is the plain product: an empty slice (a scalar) yields 1, the empty
/// product, and any zero extent yields 0.
fn numel(d: &[usize]) -> usize {
    d.iter().product()
}
623
/// Normalizes a possibly-negative `axis` into an index in `0..rank`.
///
/// Negative axes count from the end and clamp at 0. Non-negative axes clamp to
/// the last axis. For `rank == 0` the result is always 0.
fn norm_axis(axis: i32, rank: usize) -> usize {
    match usize::try_from(axis) {
        Ok(forward) => forward.min(rank.saturating_sub(1)),
        Err(_) => (rank as i32 + axis).max(0) as usize,
    }
}
631
/// Computes `ceil(n / d)` in 64-bit arithmetic and narrows the result to `u32`.
///
/// # Panics
/// Panics if `d` is zero.
fn ceil_div(n: usize, d: u32) -> u32 {
    let (num, den) = (n as u64, u64::from(d));
    let quotient = num / den + u64::from(num % den != 0);
    quotient as u32
}
635
636fn act_id(a: Activation) -> u32 {
637 match a {
638 Activation::Gelu => 0,
639 Activation::GeluApprox => 1,
640 Activation::Silu => 2,
641 Activation::Relu => 3,
642 Activation::Sigmoid => 4,
643 Activation::Tanh => 5,
644 Activation::Exp => 6,
645 Activation::Log => 7,
646 Activation::Sqrt => 8,
647 Activation::Rsqrt => 9,
648 Activation::Neg => 10,
649 Activation::Abs => 11,
650 Activation::Sin => 12,
651 Activation::Cos => 13,
652 Activation::Tan => 14,
653 Activation::Atan => 15,
654 Activation::Round => 16,
655 }
656}
657
658fn binop_id(op: rlx_ir::op::BinaryOp) -> u32 {
659 use rlx_ir::op::BinaryOp::*;
660 match op {
661 Add => 0,
662 Sub => 1,
663 Mul => 2,
664 Div => 3,
665 Max => 4,
666 Min => 5,
667 Pow => 6,
668 }
669}
670
671fn widen_const_to_f32(data: &[u8], dt: DType) -> Vec<f32> {
673 match dt {
674 DType::F32 => data
675 .chunks_exact(4)
676 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
677 .collect(),
678 DType::F16 => data
679 .chunks_exact(2)
680 .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f32())
681 .collect(),
682 DType::BF16 => data
683 .chunks_exact(2)
684 .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
685 .collect(),
686 DType::F64 => data
687 .chunks_exact(8)
688 .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32)
689 .collect(),
690 DType::I64 => data
691 .chunks_exact(8)
692 .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32)
693 .collect(),
694 DType::I32 | DType::U32 => data
695 .chunks_exact(4)
696 .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
697 .collect(),
698 DType::I16 => data
699 .chunks_exact(2)
700 .map(|c| i16::from_le_bytes([c[0], c[1]]) as f32)
701 .collect(),
702 DType::I8 => data.iter().map(|&b| b as i8 as f32).collect(),
703 DType::U8 | DType::Bool => data.iter().map(|&b| b as f32).collect(),
704 DType::C64 => data
705 .chunks_exact(4)
706 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
707 .collect(),
708 }
709}