1use std::collections::HashMap;
4
5use tracing::{debug, instrument};
6
7use sapient_core::buffer::{BufferHandle, CpuBuffer};
8use sapient_core::error::{Result, SapientError};
9use sapient_core::{DType, Tensor};
10use sapient_ir::graph::Graph;
11use sapient_ir::node::{Node, NodeId};
12use sapient_ir::op::OpType;
13
14use crate::kernels;
15use crate::pool::PoolAllocator;
16
17pub trait ExecutionBackend: Send + Sync {
24 fn name(&self) -> &str;
26
27 fn allocate(&self, shape: &[usize], dtype: DType) -> Result<BufferHandle>;
29
30 fn execute(&self, graph: &Graph, inputs: HashMap<String, Tensor>) -> Result<Vec<Tensor>>;
33
34 fn supports_op(&self, op: &OpType) -> bool;
36
37 fn is_available() -> bool
39 where
40 Self: Sized,
41 {
42 true
43 }
44}
45
46pub struct CpuBackend {
50 pool: PoolAllocator,
51}
52
53impl std::fmt::Debug for CpuBackend {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.debug_struct("CpuBackend").finish()
56 }
57}
58
59impl Default for CpuBackend {
60 fn default() -> Self {
61 Self::new(256 * 1024 * 1024) }
63}
64
65impl CpuBackend {
66 pub fn new(pool_bytes: usize) -> Self {
68 Self {
69 pool: PoolAllocator::new(pool_bytes),
70 }
71 }
72}
73
74impl ExecutionBackend for CpuBackend {
75 fn name(&self) -> &str {
76 "cpu"
77 }
78
79 fn allocate(&self, shape: &[usize], dtype: DType) -> Result<BufferHandle> {
80 let numel: usize = shape.iter().product();
81 if let Some(handle) = self.pool.acquire(numel, dtype) {
83 return Ok(handle);
84 }
85 let buf = CpuBuffer::zeros(numel, dtype)?;
86 Ok(BufferHandle::new(buf))
87 }
88
89 #[instrument(skip_all, fields(graph = %graph.name))]
90 fn execute(&self, graph: &Graph, inputs: HashMap<String, Tensor>) -> Result<Vec<Tensor>> {
91 let order = graph.topological_order()?;
93
94 let mut values: HashMap<(NodeId, usize), Tensor> = HashMap::new();
96
97 for id in &graph.inputs {
99 if let Some(Node::Input { name, .. }) = graph.get(*id) {
100 if let Some(t) = inputs.get(name) {
101 values.insert((*id, 0), t.clone());
102 }
103 }
104 }
105
106 for id in &order {
107 match graph.get(*id) {
108 Some(Node::Constant { value, .. }) => {
109 values.insert((*id, 0), value.clone());
110 }
111 Some(Node::Input { .. }) => {
112 }
114 Some(Node::Operator {
115 op,
116 inputs: inp_ids,
117 num_outputs,
118 ..
119 }) => {
120 let op = op.clone();
121 let inp_ids = inp_ids.clone();
122 let _num_outputs = *num_outputs;
123
124 let input_tensors: Vec<Tensor> = inp_ids
126 .iter()
127 .map(|&inp| {
128 values.get(&(inp, 0)).cloned().ok_or_else(|| {
129 SapientError::internal(format!("missing value for node {inp}"))
130 })
131 })
132 .collect::<Result<Vec<_>>>()?;
133
134 let outputs = self.dispatch(&op, &input_tensors)?;
136
137 for (i, t) in outputs.into_iter().enumerate() {
138 values.insert((*id, i), t);
139 }
140 }
141 Some(Node::Output { source, .. }) => {
142 if let Some(t) = values.get(&(*source, 0)).cloned() {
144 values.insert((*id, 0), t);
145 }
146 }
147 None => {}
148 }
149 }
150
151 let out_tensors: Vec<Tensor> = graph
153 .outputs
154 .iter()
155 .map(|&oid| {
156 values
157 .get(&(oid, 0))
158 .cloned()
159 .ok_or_else(|| SapientError::internal(format!("output {oid} not computed")))
160 })
161 .collect::<Result<Vec<_>>>()?;
162
163 debug!(
164 outputs = out_tensors.len(),
165 "CpuBackend: execution complete"
166 );
167 Ok(out_tensors)
168 }
169
170 fn supports_op(&self, op: &OpType) -> bool {
171 matches!(
172 op,
173 OpType::MatMul | OpType::Gemm { .. }
174 | OpType::Add | OpType::Sub | OpType::Mul | OpType::Div | OpType::Pow
175 | OpType::Neg | OpType::Abs | OpType::Sqrt | OpType::Exp | OpType::Log
176 | OpType::Relu | OpType::Sigmoid | OpType::Tanh | OpType::Gelu
177 | OpType::LeakyRelu { .. } | OpType::Silu | OpType::HardSwish
178 | OpType::Softmax { .. } | OpType::LogSoftmax { .. }
179 | OpType::LayerNorm { .. } | OpType::RmsNorm { .. }
180 | OpType::Conv2d { .. }
181 | OpType::Reshape | OpType::Transpose { .. } | OpType::Flatten { .. }
182 | OpType::Concat { .. }
183 | OpType::ReduceSum { .. } | OpType::ReduceMean { .. }
184 | OpType::ReduceMax { .. } | OpType::ReduceMin { .. }
185 | OpType::Identity | OpType::Clip { .. }
186 | OpType::Erf | OpType::Floor | OpType::Ceil | OpType::Round
187 | OpType::Embedding { .. }
189 | OpType::MultiHeadAttention { .. }
190 | OpType::GroupedQueryAttention { .. }
191 | OpType::RotaryEmbedding { .. }
192 | OpType::CausalMask
193 | OpType::KVCacheConcat
194 | OpType::RepeatKV { .. }
195 )
196 }
197}
198
199impl CpuBackend {
200 fn dispatch(&self, op: &OpType, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
202 let out = match op {
203 OpType::MatMul => {
205 let a = inputs
206 .get(0)
207 .ok_or_else(|| SapientError::internal("MatMul: missing a"))?;
208 let b = inputs
209 .get(1)
210 .ok_or_else(|| SapientError::internal("MatMul: missing b"))?;
211 vec![kernels::matmul::matmul(a, b)?]
212 }
213 OpType::Gemm {
214 alpha,
215 beta,
216 trans_a,
217 trans_b,
218 } => {
219 let a = inputs
220 .get(0)
221 .ok_or_else(|| SapientError::internal("Gemm: missing a"))?;
222 let b = inputs
223 .get(1)
224 .ok_or_else(|| SapientError::internal("Gemm: missing b"))?;
225 let c = inputs.get(2);
226 vec![kernels::matmul::gemm(
227 a,
228 b,
229 c,
230 alpha.0 as f32,
231 beta.0 as f32,
232 *trans_a,
233 *trans_b,
234 )?]
235 }
236
237 OpType::Add => vec![kernels::elementwise::add(
239 inputs.get(0).unwrap(),
240 inputs.get(1).unwrap(),
241 )?],
242 OpType::Sub => vec![kernels::elementwise::sub(
243 inputs.get(0).unwrap(),
244 inputs.get(1).unwrap(),
245 )?],
246 OpType::Mul => vec![kernels::elementwise::mul(
247 inputs.get(0).unwrap(),
248 inputs.get(1).unwrap(),
249 )?],
250 OpType::Div => vec![kernels::elementwise::div(
251 inputs.get(0).unwrap(),
252 inputs.get(1).unwrap(),
253 )?],
254 OpType::Pow => vec![kernels::elementwise::pow(
255 inputs.get(0).unwrap(),
256 inputs.get(1).unwrap(),
257 )?],
258 OpType::Neg => vec![kernels::elementwise::neg(inputs.get(0).unwrap())?],
259 OpType::Abs => vec![kernels::elementwise::abs(inputs.get(0).unwrap())?],
260 OpType::Sqrt => vec![kernels::elementwise::sqrt(inputs.get(0).unwrap())?],
261 OpType::Exp => vec![kernels::elementwise::exp(inputs.get(0).unwrap())?],
262 OpType::Log => vec![kernels::elementwise::log(inputs.get(0).unwrap())?],
263 OpType::Erf => vec![kernels::elementwise::erf(inputs.get(0).unwrap())?],
264 OpType::Floor => vec![kernels::elementwise::floor(inputs.get(0).unwrap())?],
265 OpType::Ceil => vec![kernels::elementwise::ceil(inputs.get(0).unwrap())?],
266 OpType::Round => vec![kernels::elementwise::round(inputs.get(0).unwrap())?],
267
268 OpType::Relu => vec![kernels::elementwise::relu(inputs.get(0).unwrap())?],
270 OpType::Sigmoid => vec![kernels::elementwise::sigmoid(inputs.get(0).unwrap())?],
271 OpType::Tanh => vec![kernels::elementwise::tanh_act(inputs.get(0).unwrap())?],
272 OpType::Gelu => vec![kernels::elementwise::gelu(inputs.get(0).unwrap())?],
273 OpType::Silu => vec![kernels::elementwise::silu(inputs.get(0).unwrap())?],
274 OpType::HardSwish => vec![kernels::elementwise::hard_swish(inputs.get(0).unwrap())?],
275 OpType::LeakyRelu { alpha } => {
276 vec![kernels::elementwise::leaky_relu(
277 inputs.get(0).unwrap(),
278 alpha.0 as f32,
279 )?]
280 }
281 OpType::Clip { min, max } => {
282 vec![kernels::elementwise::clip(
283 inputs.get(0).unwrap(),
284 min.map(|v| v.0 as f32),
285 max.map(|v| v.0 as f32),
286 )?]
287 }
288
289 OpType::Softmax { axis } => {
291 vec![kernels::softmax::softmax(inputs.get(0).unwrap(), *axis)?]
292 }
293 OpType::LogSoftmax { axis } => {
294 vec![kernels::softmax::log_softmax(
295 inputs.get(0).unwrap(),
296 *axis,
297 )?]
298 }
299 OpType::LayerNorm { axis, epsilon } => {
300 let weight = inputs.get(1);
301 let bias = inputs.get(2);
302 vec![kernels::layernorm::layer_norm(
303 inputs.get(0).unwrap(),
304 weight,
305 bias,
306 *axis,
307 epsilon.0 as f32,
308 )?]
309 }
310 OpType::RmsNorm { epsilon } => {
311 let weight = inputs.get(1);
312 vec![kernels::layernorm::rms_norm(
313 inputs.get(0).unwrap(),
314 weight,
315 epsilon.0 as f32,
316 )?]
317 }
318
319 OpType::Conv2d {
321 kernel_shape,
322 pads,
323 strides,
324 dilations,
325 groups,
326 } => {
327 let x = inputs.get(0).unwrap();
328 let w = inputs.get(1).unwrap();
329 let b = inputs.get(2);
330 vec![kernels::conv2d::conv2d(
331 x,
332 w,
333 b,
334 *kernel_shape,
335 *pads,
336 *strides,
337 *dilations,
338 *groups,
339 )?]
340 }
341
342 OpType::Reshape => {
344 let x = inputs.get(0).unwrap();
345 vec![x.clone()]
349 }
350 OpType::Identity => vec![inputs.get(0).unwrap().clone()],
351
352 OpType::ReduceSum { axes, keep_dims } => {
354 vec![kernels::reduce::reduce_sum(
355 inputs.get(0).unwrap(),
356 axes,
357 *keep_dims,
358 )?]
359 }
360 OpType::ReduceMean { axes, keep_dims } => {
361 vec![kernels::reduce::reduce_mean(
362 inputs.get(0).unwrap(),
363 axes,
364 *keep_dims,
365 )?]
366 }
367 OpType::ReduceMax { axes, keep_dims } => {
368 vec![kernels::reduce::reduce_max(
369 inputs.get(0).unwrap(),
370 axes,
371 *keep_dims,
372 )?]
373 }
374
375 OpType::Embedding { .. } => {
379 let weight = inputs
380 .get(0)
381 .ok_or_else(|| SapientError::internal("Embedding: missing weight"))?;
382 let ids_t = inputs
383 .get(1)
384 .ok_or_else(|| SapientError::internal("Embedding: missing input_ids"))?;
385 let dims = ids_t.shape().dims();
386 let seq_len: usize = dims.iter().product();
387 let hidden = weight.shape().dims()[1];
388 let w = weight.to_f32_vec();
390 let ids: Vec<u32> = if ids_t.dtype() == DType::F32 {
391 ids_t.as_f32_slice().iter().map(|&v| v as u32).collect()
392 } else {
393 ids_t
394 .as_bytes()
395 .chunks_exact(4)
396 .map(|c| u32::from_le_bytes(c.try_into().unwrap()))
397 .collect()
398 };
399 let mut out = vec![0.0f32; seq_len * hidden];
400 for (i, &id) in ids.iter().enumerate() {
401 let row = id as usize * hidden;
402 out[i * hidden..(i + 1) * hidden].copy_from_slice(&w[row..row + hidden]);
403 }
404 let batch = if dims.len() >= 2 { dims[0] } else { 1 };
405 let seq = if dims.len() >= 2 { dims[1] } else { seq_len };
406 vec![Tensor::from_f32(&out, vec![batch, seq, hidden])
407 .map_err(|e| SapientError::internal(e.to_string()))?]
408 }
409
410 OpType::GroupedQueryAttention {
412 n_heads: _,
413 n_kv_heads,
414 head_dim: _,
415 causal,
416 } => {
417 let q = inputs
418 .get(0)
419 .ok_or_else(|| SapientError::internal("GQA: missing Q"))?;
420 let k = inputs
421 .get(1)
422 .ok_or_else(|| SapientError::internal("GQA: missing K"))?;
423 let v = inputs
424 .get(2)
425 .ok_or_else(|| SapientError::internal("GQA: missing V"))?;
426 let mask = if *causal {
427 let seq_q = q.shape().dims().get(2).copied().unwrap_or(1);
428 let seq_k = k.shape().dims().get(2).copied().unwrap_or(1);
429 Some(kernels::attention::causal_mask(seq_q, seq_k))
430 } else {
431 None
432 };
433 vec![kernels::attention::scaled_dot_product_attention(
434 q,
435 k,
436 v,
437 mask.as_ref(),
438 None,
439 *n_kv_heads,
440 )?]
441 }
442
443 OpType::MultiHeadAttention {
445 num_heads,
446 head_dim: _,
447 causal,
448 scale,
449 } => {
450 let q = inputs
451 .get(0)
452 .ok_or_else(|| SapientError::internal("MHA: missing Q"))?;
453 let k = inputs
454 .get(1)
455 .ok_or_else(|| SapientError::internal("MHA: missing K"))?;
456 let v = inputs
457 .get(2)
458 .ok_or_else(|| SapientError::internal("MHA: missing V"))?;
459 let mask = if *causal {
460 let sq = q.shape().dims().get(2).copied().unwrap_or(1);
461 let sk = k.shape().dims().get(2).copied().unwrap_or(1);
462 Some(kernels::attention::causal_mask(sq, sk))
463 } else {
464 None
465 };
466 vec![kernels::attention::scaled_dot_product_attention(
467 q,
468 k,
469 v,
470 mask.as_ref(),
471 scale.map(|s| s.0 as f32),
472 *num_heads,
473 )?]
474 }
475
476 OpType::RotaryEmbedding { base, dim: _ } => {
478 let x = inputs
479 .get(0)
480 .ok_or_else(|| SapientError::internal("RoPE: missing input"))?;
481 let seq_len = x.shape().dims().get(2).copied().unwrap_or(1);
482 let positions: Vec<usize> = (0..seq_len).collect();
483 vec![kernels::rope::apply_rope(x, &positions, base.0 as f32)?]
484 }
485
486 OpType::CausalMask => {
488 let seq = inputs
489 .get(0)
490 .map(|t| t.shape().dims().get(1).copied().unwrap_or(1))
491 .unwrap_or(1);
492 vec![kernels::attention::causal_mask(seq, seq)]
493 }
494
495 OpType::KVCacheConcat | OpType::RepeatKV { .. } => {
497 vec![inputs.get(0).unwrap().clone()]
498 }
499
500 OpType::MoEGate { .. } | OpType::ScaledDotProductAttention { .. } => {
502 vec![inputs.get(0).unwrap().clone()]
503 }
504
505 OpType::ALiBi { .. } => {
507 vec![Tensor::zeros(vec![1], DType::F32).unwrap()]
508 }
509
510 other => {
512 return Err(SapientError::unsupported_op("cpu", &other.to_string()));
513 }
514 };
515 Ok(out)
516 }
517}