use std::collections::HashMap;
use std::sync::Mutex;

use crate::graph::OpKind;
use crate::memory::SizeClassPool;
use crate::OnnxError;

use super::super::types::NodeProfile;
use super::super::Session;
use super::state::SessionRunState;

impl Session {
// GREP_GUARD: all intermediates writes must go through dispatch_node / SessionRunState::insert
/// Sequential execution path using `SessionRunState` for buffer-reuse-aware
/// intermediate storage.
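    ///
    /// Per node, providers are tried in priority order: CUDA (when built in),
    /// then DirectML, then the wgpu GPU path, then a native f16 path under
    /// mixed precision, and finally the CPU operator registry. A provider
    /// that declines an op falls through to the next.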
pub(crate) fn run_sequential_inner(
&self,
state: &mut SessionRunState,
ref_counts: &mut HashMap<String, usize>,
output_set: &std::collections::HashSet<&str>,
) -> Result<(), OnnxError> {
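        // Snapshot resolved shapes up front; a poisoned lock degrades to an
        // empty map rather than aborting the run.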
let resolved = self
.resolved_shapes
.lock()
.map(|s| s.clone())
.unwrap_or_default();
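        // Walk the graph in execution order (sorted_nodes is assumed to be
        // topologically sorted); unknown ops are skipped rather than failing.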
for node in &self.sorted_nodes {
if let OpKind::Unknown(_) = &node.op {
continue;
}
// Determine operator placement based on the configured strategy.
// output_bytes and placement are used by the GPU dispatch block below.
// CUDA and DirectML dispatch check op_placement directly (no size threshold).
#[cfg(feature = "gpu")]
let output_bytes =
Self::estimate_output_bytes(node, state.as_map(), &self.weights, &resolved);
#[cfg(feature = "gpu")]
let placement = crate::execution_providers::decide_placement(
&node.op,
output_bytes,
&self.op_placement,
);
// When no hardware-acceleration feature is active, read op_placement to
// satisfy the compiler (field is always valid, just unused at runtime).
#[cfg(not(any(feature = "gpu", feature = "cuda", feature = "directml")))]
let _ = &self.op_placement;
// CUDA dispatch (only when placement allows)
#[cfg(feature = "cuda")]
{
let try_cuda = self.cuda.is_some()
&& !matches!(
self.op_placement,
crate::execution_providers::OpPlacement::CpuOnly
);
if try_cuda {
if let Some(cuda_ctx) = &self.cuda {
let cuda_start = std::time::Instant::now();
match oxionnx_cuda::try_cuda_dispatch(
node,
&self.weights,
state.as_map(),
cuda_ctx,
) {
Ok(Some(results)) => {
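                                // CUDA handled the node: record timing, store
                                // the outputs, and decrement input refcounts
                                // so dead intermediates can be reclaimed.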
let cuda_elapsed = cuda_start.elapsed();
if let Some(ref profiling) = self.profiling_data {
if let Ok(mut data) = profiling.lock() {
data.push(NodeProfile {
node_name: node.name.clone(),
op_type: node.op.as_str().to_string(),
duration: cuda_elapsed,
output_shapes: results
.iter()
.map(|t| t.shape.clone())
.collect(),
});
}
}
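                                // Inserting through the size-class pool (when
                                // configured) lets freed buffers be recycled.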
let pool = self.pool.as_ref().map(|m| m as &Mutex<SizeClassPool>);
for (name, tensor) in node.outputs.iter().zip(results) {
if !name.is_empty() {
state.insert(name.clone(), tensor, pool);
}
}
self.decrement_refs_state(node, state, ref_counts, output_set);
continue;
}
Ok(None) => {
                                // Op not supported on CUDA; fall through to
                                // the next provider or the CPU path.
}
Err(_e) => {
                                // CUDA dispatch failed; log in debug builds
                                // and fall back to CPU.
#[cfg(debug_assertions)]
tracing::debug!(
op = %node.op.as_str(),
node = %node.name,
err = %_e,
"CUDA dispatch error, falling back to CPU",
);
}
}
}
}
}
            // DirectML dispatch: the D3D12 execution provider, tried before
            // wgpu on Windows.
#[cfg(feature = "directml")]
{
let try_dml = self.dml.is_some()
&& !matches!(
self.op_placement,
crate::execution_providers::OpPlacement::CpuOnly
);
if try_dml {
if let Some(ctx) = &self.dml {
let dml_start = std::time::Instant::now();
match oxionnx_directml::try_directml_dispatch(
node,
&self.weights,
state.as_map(),
ctx,
) {
Ok(Some(results)) => {
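                                // DirectML handled the node: same bookkeeping
                                // as the CUDA path above.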
let dml_elapsed = dml_start.elapsed();
if let Some(ref profiling) = self.profiling_data {
if let Ok(mut data) = profiling.lock() {
data.push(NodeProfile {
node_name: node.name.clone(),
op_type: node.op.as_str().to_string(),
duration: dml_elapsed,
output_shapes: results
.iter()
.map(|t| t.shape.clone())
.collect(),
});
}
}
let pool = self.pool.as_ref().map(|m| m as &Mutex<SizeClassPool>);
for (name, tensor) in node.outputs.iter().zip(results) {
if !name.is_empty() {
state.insert(name.clone(), tensor, pool);
}
}
self.decrement_refs_state(node, state, ref_counts, output_set);
continue;
}
Ok(None) => {
// Op not supported by DirectML — fall through to wgpu/CPU
}
Err(_e) => {
                                // DirectML dispatch error; fall back to
                                // wgpu/CPU (logged in debug builds only).
#[cfg(debug_assertions)]
tracing::debug!(
op = %node.op.as_str(),
node = %node.name,
err = %_e,
"DirectML dispatch error, falling back",
);
}
}
}
}
}
// GPU dispatch (only when placement routes to GPU)
#[cfg(feature = "gpu")]
{
use super::super::gpu_dispatch::{try_gpu_dispatch, GpuExecutionProvider};
use crate::execution_providers::ProviderKind;
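                // Unlike CUDA and DirectML, the wgpu path honors the placement
                // decision computed at the top of the loop, which factors in
                // the estimated output size.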
let try_gpu = matches!(placement, ProviderKind::Gpu);
if try_gpu {
if let Some(gpu_ctx) = &self.gpu {
if let Some(results) =
try_gpu_dispatch(node, &self.weights, state.as_map(), gpu_ctx)?
{
let pool = self.pool.as_ref().map(|m| m as &Mutex<SizeClassPool>);
for (name, tensor) in node.outputs.iter().zip(results) {
if !name.is_empty() {
state.insert(name.clone(), tensor, pool);
}
}
self.decrement_refs_state(node, state, ref_counts, output_set);
continue;
}
                        // GPU dispatch declined this op; run it on the CPU
                        // instead.
if GpuExecutionProvider::is_supported(node.op.as_str()) {
#[cfg(debug_assertions)]
tracing::debug!(
op = %node.op.as_str(),
node = %node.name,
"GPU fallback: fell back to CPU",
);
}
}
}
}
let op_name = node.op.as_str();
// Mixed precision: try native f16 execution for f16-safe element-wise ops
if self.mixed_precision && super::super::mixed_precision::should_use_f16(op_name) {
let input_refs: Vec<&crate::tensor::Tensor> = node
.inputs
.iter()
.filter_map(|name| {
if name.is_empty() {
None
} else {
state.get(name).or_else(|| self.weights.get(name))
}
})
.collect();
let start = std::time::Instant::now();
if let Some(f16_result) =
super::super::mixed_precision::execute_elementwise_f16(op_name, &input_refs)
{
let results = f16_result?;
let elapsed = start.elapsed();
if let Some(ref profiling) = self.profiling_data {
if let Ok(mut data) = profiling.lock() {
data.push(NodeProfile {
node_name: node.name.clone(),
op_type: format!("{op_name}(f16)"),
duration: elapsed,
output_shapes: results.iter().map(|t| t.shape.clone()).collect(),
});
}
}
let pool = self.pool.as_ref().map(|m| m as &Mutex<SizeClassPool>);
for (name, tensor) in node.outputs.iter().zip(results) {
if !name.is_empty() {
state.insert(name.clone(), tensor, pool);
}
}
self.decrement_refs_state(node, state, ref_counts, output_set);
continue;
}
// No native f16 path — fall through to normal execution with f16 rounding
}
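            // CPU path: look up the registered operator and let dispatch_node
            // run it and store the outputs (see GREP_GUARD above).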
let operator = self.registry.get(op_name).ok_or_else(|| {
OnnxError::UnknownOp(format!("No operator registered for '{}'", op_name))
})?;
let elapsed =
self.dispatch_node(node, operator, state, ref_counts, output_set, &resolved)?;
// Mixed precision: round outputs to f16 for f16-safe ops without native f16 path.
// This simulates f16 storage precision for ops that ran in f32.
if self.mixed_precision && super::super::mixed_precision::should_use_f16(op_name) {
let pool = self.pool.as_ref().map(|m| m as &Mutex<SizeClassPool>);
for out_name in &node.outputs {
if out_name.is_empty() {
continue;
}
if let Some(t) = state.take(out_name) {
let rounded = super::super::mixed_precision::round_to_f16_precision(&t);
state.insert(out_name.clone(), rounded, pool);
}
}
}
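            // Record per-node timing for the CPU path, mirroring the
            // provider-specific profiling blocks above.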
if let Some(ref profiling) = self.profiling_data {
if let Ok(mut data) = profiling.lock() {
// Gather output shapes for profiling
let output_shapes: Vec<Vec<usize>> = node
.outputs
.iter()
.filter(|n| !n.is_empty())
.filter_map(|n| state.get(n).map(|t| t.shape.clone()))
.collect();
data.push(NodeProfile {
node_name: node.name.clone(),
op_type: node.op.as_str().to_string(),
duration: elapsed,
output_shapes,
});
}
}
self.decrement_refs_state(node, state, ref_counts, output_set);
}
Ok(())
}
}