1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
//! CPU implementation of conditional operations.
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::ConditionalOps;
use crate::runtime::cpu::{
CpuClient, CpuRuntime,
helpers::{dispatch_dtype, ensure_contiguous},
kernels,
};
use crate::tensor::Tensor;
/// ConditionalOps implementation for CPU runtime.
impl ConditionalOps<CpuRuntime> for CpuClient {
fn where_cond(
&self,
cond: &Tensor<CpuRuntime>,
x: &Tensor<CpuRuntime>,
y: &Tensor<CpuRuntime>,
) -> Result<Tensor<CpuRuntime>> {
use crate::ops::broadcast_shape;
// Validate that x and y have the same dtype
if x.dtype() != y.dtype() {
return Err(Error::DTypeMismatch {
lhs: x.dtype(),
rhs: y.dtype(),
});
}
let dtype = x.dtype();
let cond_dtype = cond.dtype();
// Compute broadcast shape (cond, x, y) -> out
let xy_shape =
broadcast_shape(x.shape(), y.shape()).ok_or_else(|| Error::BroadcastError {
lhs: x.shape().to_vec(),
rhs: y.shape().to_vec(),
})?;
let out_shape =
broadcast_shape(cond.shape(), &xy_shape).ok_or_else(|| Error::BroadcastError {
lhs: cond.shape().to_vec(),
rhs: xy_shape.clone(),
})?;
let out = Tensor::<CpuRuntime>::empty(&out_shape, dtype, &self.device);
let out_ptr = out.ptr();
// Fast path: all same shape, use simple kernel
if cond.shape() == x.shape() && x.shape() == y.shape() {
let cond_contig = ensure_contiguous(cond);
let x_contig = ensure_contiguous(x);
let y_contig = ensure_contiguous(y);
let cond_ptr = cond_contig.ptr();
let x_ptr = x_contig.ptr();
let y_ptr = y_contig.ptr();
let numel = x.numel();
// Double dispatch: cond dtype and value dtype
// For U8 condition, use optimized SIMD kernel
if cond_dtype == DType::U8 {
dispatch_dtype!(dtype, T => {
unsafe {
kernels::where_kernel::<T>(
cond_ptr as *const u8,
x_ptr as *const T,
y_ptr as *const T,
out_ptr as *mut T,
numel,
);
}
}, "where_cond");
} else {
// Generic kernel for any condition dtype (non-zero = true)
dispatch_dtype!(cond_dtype, C => {
dispatch_dtype!(dtype, T => {
unsafe {
kernels::where_kernel_generic::<C, T>(
cond_ptr as *const C,
x_ptr as *const T,
y_ptr as *const T,
out_ptr as *mut T,
numel,
);
}
}, "where_cond");
}, "where_cond");
}
} else {
// Broadcasting path: use strided kernel
// Broadcast all inputs to output shape (zero-copy views with stride 0 for broadcast dims)
let cond_broadcast = cond.broadcast_to(&out_shape)?;
let x_broadcast = x.broadcast_to(&out_shape)?;
let y_broadcast = y.broadcast_to(&out_shape)?;
let cond_ptr = cond_broadcast.ptr();
let x_ptr = x_broadcast.ptr();
let y_ptr = y_broadcast.ptr();
// Get strides from broadcast layouts
let cond_strides: Vec<isize> = cond_broadcast.layout().strides().to_vec();
let x_strides: Vec<isize> = x_broadcast.layout().strides().to_vec();
let y_strides: Vec<isize> = y_broadcast.layout().strides().to_vec();
let cond_offset = cond_broadcast.layout().offset();
let x_offset = x_broadcast.layout().offset();
let y_offset = y_broadcast.layout().offset();
// For U8 condition, use optimized kernel
if cond_dtype == DType::U8 {
dispatch_dtype!(dtype, T => {
unsafe {
kernels::where_strided_kernel::<T>(
cond_ptr as *const u8,
x_ptr as *const T,
y_ptr as *const T,
out_ptr as *mut T,
&out_shape,
&cond_strides,
&x_strides,
&y_strides,
cond_offset,
x_offset,
y_offset,
);
}
}, "where_cond");
} else {
// Generic kernel for any condition dtype
dispatch_dtype!(cond_dtype, C => {
dispatch_dtype!(dtype, T => {
unsafe {
kernels::where_strided_kernel_generic::<C, T>(
cond_ptr as *const C,
x_ptr as *const T,
y_ptr as *const T,
out_ptr as *mut T,
&out_shape,
&cond_strides,
&x_strides,
&y_strides,
cond_offset,
x_offset,
y_offset,
);
}
}, "where_cond");
}, "where_cond");
}
}
Ok(out)
}
}