// flodl_sys/lib.rs
//! Raw FFI bindings to the libtorch C++ shim.
//!
//! Every function that can fail returns a `*mut i8` error string (caller
//! must free it with [`flodl_free_string`]). A null pointer means success.
//!
//! `FlodlTensor` is an opaque `*mut c_void` handle to a heap-allocated
//! `torch::Tensor`. Caller owns it and must free with [`flodl_free_tensor`].
//!
//! NOTE(review): error strings are declared as `*mut i8`. That matches
//! `char*` on the usual x86/ARM-Windows ABIs, but `std::ffi::c_char` is `u8`
//! on some targets (e.g. aarch64-linux) — confirm target set before porting.

use std::ffi::c_void;

/// Opaque handle to a `torch::Tensor` on the C++ side.
pub type FlodlTensor = *mut c_void;

// --- DType constants (must match shim.h) ---
// The numeric values mirror torch's `c10::ScalarType` enum
// (Half = 5, Float = 6, Double = 7, Int = 3, Long = 4, BFloat16 = 15).
pub const FLODL_FLOAT16: i32 = 5;
pub const FLODL_BFLOAT16: i32 = 15;
pub const FLODL_FLOAT32: i32 = 6;
pub const FLODL_FLOAT64: i32 = 7;
pub const FLODL_INT32: i32 = 3;
pub const FLODL_INT64: i32 = 4;

// --- Device constants (must match shim.h) ---
pub const FLODL_CPU: i32 = 0;
pub const FLODL_CUDA: i32 = 1;

// SAFETY: every signature below must stay byte-for-byte in sync with the
// corresponding declaration in shim.h — these are an ABI contract, not an
// ordinary Rust API. All fns are unsafe to call (2024-edition extern block).
unsafe extern "C" {
    // --- Tensor creation ---

    pub fn flodl_zeros(
        shape: *mut i64, ndim: i32, dtype: i32,
        device_type: i32, device_index: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_ones(
        shape: *mut i64, ndim: i32, dtype: i32,
        device_type: i32, device_index: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_rand(
        shape: *mut i64, ndim: i32, dtype: i32,
        device_type: i32, device_index: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_randn(
        shape: *mut i64, ndim: i32, dtype: i32,
        device_type: i32, device_index: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_from_blob(
        data: *mut c_void, shape: *mut i64, ndim: i32,
        dtype: i32, device_type: i32, device_index: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_linspace(
        start: f64, end: f64, steps: i64,
        dtype: i32, device_type: i32, device_index: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_arange(
        start: f64, end: f64, step: f64,
        dtype: i32, device_type: i32, device_index: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_expand(
        t: FlodlTensor, new_shape: *mut i64, ndim: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Tensor lifecycle ---

    pub fn flodl_free_tensor(t: FlodlTensor);
    pub fn flodl_shallow_clone(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;

    // --- Tensor metadata ---

    pub fn flodl_ndim(t: FlodlTensor) -> i32;
    pub fn flodl_shape(t: FlodlTensor, dim: i32) -> i64;
    pub fn flodl_dtype(t: FlodlTensor) -> i32;
    pub fn flodl_device_type(t: FlodlTensor) -> i32;
    pub fn flodl_device_index(t: FlodlTensor) -> i32;
    pub fn flodl_numel(t: FlodlTensor) -> i64;

    // --- Data access ---

    pub fn flodl_copy_data(
        t: FlodlTensor, buffer: *mut c_void, buffer_bytes: i64,
    ) -> *mut i8;

    // --- Arithmetic ---

    pub fn flodl_add(a: FlodlTensor, b: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_sub(a: FlodlTensor, b: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_mul(a: FlodlTensor, b: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_div(a: FlodlTensor, b: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_matmul(a: FlodlTensor, b: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;

    pub fn flodl_add_scalar(
        t: FlodlTensor, scalar: f64, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_mul_scalar(
        t: FlodlTensor, scalar: f64, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_div_scalar(
        t: FlodlTensor, scalar: f64, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_neg(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;

    // --- Activations ---

    pub fn flodl_relu(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_sigmoid(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    // `_op` suffix avoids clashing with the C `tanh` from libm.
    pub fn flodl_tanh_op(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_softmax(t: FlodlTensor, dim: i32, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_log_softmax(t: FlodlTensor, dim: i32, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_gelu(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_silu(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;

    // --- Layer normalization ---

    /// Writes three tensors (`output`, `mean`, `rstd`); caller frees each.
    pub fn flodl_native_layer_norm(
        input: FlodlTensor, weight: FlodlTensor, bias: FlodlTensor,
        normalized_size: i64, eps: f64,
        output: *mut FlodlTensor, mean: *mut FlodlTensor, rstd: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Element-wise math ---

    pub fn flodl_exp(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_log(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_sqrt(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_abs(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_triu(t: FlodlTensor, diagonal: i64, result: *mut FlodlTensor) -> *mut i8;

    pub fn flodl_pow_scalar(
        t: FlodlTensor, exponent: f64, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_clamp(
        t: FlodlTensor, min_val: f64, max_val: f64, result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Reductions ---

    pub fn flodl_sum(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_mean(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;

    pub fn flodl_sum_dim(
        t: FlodlTensor, dim: i32, keepdim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_mean_dim(
        t: FlodlTensor, dim: i32, keepdim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_min(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_max(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_norm(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;

    pub fn flodl_min_dim(
        t: FlodlTensor, dim: i32, keepdim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_max_dim(
        t: FlodlTensor, dim: i32, keepdim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_argmax(
        t: FlodlTensor, dim: i32, keepdim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Comparison (return float masks: 0.0 or 1.0) ---

    pub fn flodl_gt_scalar(
        t: FlodlTensor, scalar: f64, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_ge_scalar(
        t: FlodlTensor, scalar: f64, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_le_scalar(
        t: FlodlTensor, scalar: f64, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_lt_scalar(
        t: FlodlTensor, scalar: f64, result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Shape operations ---

    pub fn flodl_reshape(
        t: FlodlTensor, shape: *mut i64, ndim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_transpose(
        t: FlodlTensor, dim0: i32, dim1: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_permute(
        t: FlodlTensor, dims: *mut i64, ndim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_select(
        t: FlodlTensor, dim: i32, index: i64, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_narrow(
        t: FlodlTensor, dim: i32, start: i64, length: i64,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_squeeze(
        t: FlodlTensor, dim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_unsqueeze(
        t: FlodlTensor, dim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_flatten(
        t: FlodlTensor, start_dim: i32, end_dim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Scatter ---

    pub fn flodl_select_scatter(
        input: FlodlTensor, src: FlodlTensor, dim: i32, index: i64,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_narrow_scatter(
        input: FlodlTensor, src: FlodlTensor, dim: i32, start: i64,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Indexing ---

    pub fn flodl_index_select(
        t: FlodlTensor, dim: i32, index: FlodlTensor,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_index_add(
        t: FlodlTensor, dim: i32, index: FlodlTensor, src: FlodlTensor,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Concatenation ---

    pub fn flodl_cat2(
        a: FlodlTensor, b: FlodlTensor, dim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_cat(
        tensors: *mut FlodlTensor, count: i32, dim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_stack(
        tensors: *mut FlodlTensor, count: i32, dim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Conditional ---

    pub fn flodl_where(
        condition: FlodlTensor, x: FlodlTensor, y: FlodlTensor,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Like constructors ---

    pub fn flodl_zeros_like(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_ones_like(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;

    // --- Convolution ---

    pub fn flodl_conv2d(
        input: FlodlTensor, weight: FlodlTensor, bias: FlodlTensor,
        stride: *mut i64, padding: *mut i64, dilation: *mut i64,
        groups: i64, result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Transposed convolution ---

    pub fn flodl_conv_transpose2d(
        input: FlodlTensor, weight: FlodlTensor, bias: FlodlTensor,
        stride: *mut i64, padding: *mut i64,
        output_padding: *mut i64, dilation: *mut i64,
        groups: i64, result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Pooling ---

    pub fn flodl_max_pool2d(
        input: FlodlTensor, kernel_size: *mut i64,
        stride: *mut i64, padding: *mut i64, dilation: *mut i64,
        ceil_mode: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_adaptive_avg_pool2d(
        input: FlodlTensor, output_size: *mut i64,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Grid sampling ---

    pub fn flodl_grid_sample(
        input: FlodlTensor, grid: FlodlTensor,
        mode: i32, padding_mode: i32, align_corners: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Device ---

    pub fn flodl_to_device(
        t: FlodlTensor, device_type: i32, device_index: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_to_device_async(
        t: FlodlTensor, device_type: i32, device_index: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_cuda_is_available() -> i32;
    pub fn flodl_cuda_device_count() -> i32;
    pub fn flodl_force_cuda_link() -> i32;
    pub fn flodl_set_current_device(device_index: i32);
    pub fn flodl_get_current_device() -> i32;
    pub fn flodl_cuda_synchronize(device_index: i32);

    // --- CUDA memory/utilization (monitor support) ---

    pub fn flodl_cuda_mem_info(
        device_index: i32, used_bytes: *mut u64, total_bytes: *mut u64,
    ) -> *mut i8;

    pub fn flodl_cuda_alloc_bytes(
        device_index: i32, allocated_bytes: *mut u64,
    ) -> *mut i8;

    pub fn flodl_cuda_active_bytes(
        device_index: i32, active_bytes: *mut u64,
    ) -> *mut i8;

    pub fn flodl_cuda_peak_active_bytes(
        device_index: i32, peak_bytes: *mut u64,
    ) -> *mut i8;

    pub fn flodl_cuda_peak_reserved_bytes(
        device_index: i32, peak_bytes: *mut u64,
    ) -> *mut i8;

    pub fn flodl_cuda_reset_peak_stats(device_index: i32);

    pub fn flodl_cuda_empty_cache();

    pub fn flodl_cuda_utilization(device_index: i32) -> i32;

    /// Writes a NUL-terminated name into `buf` (at most `buf_len` bytes).
    pub fn flodl_cuda_device_name(
        device_index: i32, buf: *mut i8, buf_len: i32,
    ) -> *mut i8;

    // --- Dtype casting ---

    pub fn flodl_to_dtype(
        t: FlodlTensor, dtype: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_all_finite(t: FlodlTensor, result: *mut i32) -> *mut i8;

    // --- Comparison (tensor-tensor, return float masks: 0.0 or 1.0) ---

    pub fn flodl_gt_tensor(
        a: FlodlTensor, b: FlodlTensor, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_lt_tensor(
        a: FlodlTensor, b: FlodlTensor, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_ge_tensor(
        a: FlodlTensor, b: FlodlTensor, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_le_tensor(
        a: FlodlTensor, b: FlodlTensor, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_eq_tensor(
        a: FlodlTensor, b: FlodlTensor, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_ne_tensor(
        a: FlodlTensor, b: FlodlTensor, result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Additional reductions ---

    pub fn flodl_argmin(
        t: FlodlTensor, dim: i32, keepdim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_var(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    // `_op` suffix avoids clashing with C/C++ `std` symbols.
    pub fn flodl_std_op(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;

    pub fn flodl_var_dim(
        t: FlodlTensor, dim: i32, keepdim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_std_dim(
        t: FlodlTensor, dim: i32, keepdim: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Element-wise math (trig, rounding, sign) ---

    pub fn flodl_sin(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_cos(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_sign(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_floor(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_ceil(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_round(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_reciprocal(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;

    // --- Advanced indexing ---

    pub fn flodl_gather(
        t: FlodlTensor, dim: i32, index: FlodlTensor,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_scatter_add(
        t: FlodlTensor, dim: i32, index: FlodlTensor, src: FlodlTensor,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Sorting ---

    /// Writes two tensors (`values`, `indices`); caller frees each.
    pub fn flodl_topk(
        t: FlodlTensor, k: i64, dim: i32, largest: i32, sorted: i32,
        values: *mut FlodlTensor, indices: *mut FlodlTensor,
    ) -> *mut i8;

    /// Writes two tensors (`values`, `indices`); caller frees each.
    pub fn flodl_sort(
        t: FlodlTensor, dim: i32, descending: i32,
        values: *mut FlodlTensor, indices: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Tensor creation (additional) ---

    pub fn flodl_eye(
        n: i64, dtype: i32, device_type: i32, device_index: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_full(
        shape: *mut i64, ndim: i32, value: f64, dtype: i32,
        device_type: i32, device_index: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Shape operations (additional) ---

    // NOTE(review): `results` is a C-allocated array of tensor handles —
    // confirm against shim.h who frees the array itself vs its elements.
    pub fn flodl_chunk(
        t: FlodlTensor, chunks: i32, dim: i32,
        results: *mut *mut FlodlTensor, count: *mut i32,
    ) -> *mut i8;

    pub fn flodl_repeat(
        t: FlodlTensor, repeats: *mut i64, ndim: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_pad(
        t: FlodlTensor, padding: *mut i64, pad_len: i32, value: f64,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Autograd ---

    pub fn flodl_set_requires_grad(
        t: FlodlTensor, requires_grad: i32, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_requires_grad(t: FlodlTensor) -> i32;

    pub fn flodl_backward(t: FlodlTensor) -> *mut i8;

    pub fn flodl_grad(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;

    pub fn flodl_set_grad(t: FlodlTensor, grad: FlodlTensor) -> *mut i8;

    pub fn flodl_zero_grad(t: FlodlTensor) -> *mut i8;

    pub fn flodl_detach(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;

    pub fn flodl_detach_(t: FlodlTensor) -> *mut i8;

    pub fn flodl_is_leaf(t: FlodlTensor) -> i32;

    // --- Autograd context ---

    // RAII-style guard pair: delete must be called exactly once per new.
    pub fn flodl_no_grad_guard_new() -> *mut c_void;
    pub fn flodl_no_grad_guard_delete(guard: *mut c_void);
    pub fn flodl_is_grad_enabled() -> i32;

    // --- Autocast (automatic mixed precision) ---

    pub fn flodl_autocast_guard_new(device_type: i32, dtype: i32) -> *mut c_void;
    pub fn flodl_autocast_guard_delete(guard: *mut c_void);
    pub fn flodl_is_autocast_enabled(device_type: i32) -> i32;

    // --- Meshgrid ---

    pub fn flodl_meshgrid(
        tensors: *mut FlodlTensor, count: i32,
        results: *mut *mut FlodlTensor, result_count: *mut i32,
    ) -> *mut i8;

    // --- Pairwise distance ---

    pub fn flodl_cdist(
        x: FlodlTensor, y: FlodlTensor, p: f64,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Fused ops ---

    pub fn flodl_linear(
        input: FlodlTensor, weight: FlodlTensor, bias: FlodlTensor,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_gru_cell(
        input: FlodlTensor, hx: FlodlTensor,
        w_ih: FlodlTensor, w_hh: FlodlTensor,
        b_ih: FlodlTensor, b_hh: FlodlTensor,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_lstm_cell(
        input: FlodlTensor, hx: FlodlTensor, cx: FlodlTensor,
        w_ih: FlodlTensor, w_hh: FlodlTensor,
        b_ih: FlodlTensor, b_hh: FlodlTensor,
        h_out: *mut FlodlTensor, c_out: *mut FlodlTensor,
    ) -> *mut i8;

    // --- cuDNN benchmark ---

    pub fn flodl_set_cudnn_benchmark(enable: i32);

    // --- RNG seed ---

    pub fn flodl_manual_seed(seed: u64);
    pub fn flodl_cuda_manual_seed_all(seed: u64);

    // --- In-place operations ---

    pub fn flodl_add_(t: FlodlTensor, other: FlodlTensor) -> *mut i8;
    pub fn flodl_sub_(t: FlodlTensor, other: FlodlTensor) -> *mut i8;
    pub fn flodl_mul_scalar_(t: FlodlTensor, scalar: f64) -> *mut i8;
    pub fn flodl_add_scalar_(t: FlodlTensor, scalar: f64) -> *mut i8;
    pub fn flodl_zero_(t: FlodlTensor) -> *mut i8;

    // --- Fused Adam step ---

    pub fn flodl_adam_step(
        param: FlodlTensor, grad: FlodlTensor,
        m: FlodlTensor, v: FlodlTensor,
        lr: f64, beta1: f64, beta2: f64, eps: f64,
        weight_decay: f64, step: i64,
    ) -> *mut i8;

    // --- Batched Adam step ---

    pub fn flodl_adam_step_batched(
        params: *mut FlodlTensor, grads: *mut FlodlTensor,
        ms: *mut FlodlTensor, vs: *mut FlodlTensor,
        lrs: *mut f64, count: i32,
        beta1: f64, beta2: f64, eps: f64,
        weight_decay: f64, step: i64,
    ) -> *mut i8;

    // --- Fused Adam/AdamW (multi-tensor kernel) ---

    pub fn flodl_fused_adam_(
        params: *mut FlodlTensor, grads: *mut FlodlTensor,
        exp_avgs: *mut FlodlTensor, exp_avg_sqs: *mut FlodlTensor,
        count: i32, lr: f64,
        beta1: f64, beta2: f64, eps: f64,
        weight_decay: f64, step: i64,
        grad_scale: FlodlTensor, found_inf: FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_fused_adamw_(
        params: *mut FlodlTensor, grads: *mut FlodlTensor,
        exp_avgs: *mut FlodlTensor, exp_avg_sqs: *mut FlodlTensor,
        count: i32, lr: f64,
        beta1: f64, beta2: f64, eps: f64,
        weight_decay: f64, step: i64,
        grad_scale: FlodlTensor, found_inf: FlodlTensor,
    ) -> *mut i8;

    // --- Pinned memory ---

    pub fn flodl_pin_memory(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_is_pinned(t: FlodlTensor) -> i32;

    // --- Memory diagnostics ---

    pub fn flodl_malloc_trim() -> i32;

    // --- Zero grad (set_to_none) ---

    pub fn flodl_zero_grad_set_to_none(t: FlodlTensor);

    // --- Fused clip_grad_norm ---

    pub fn flodl_clip_grad_norm(
        params: *mut FlodlTensor, count: i32,
        max_norm: f64, total_norm_out: *mut f64,
    ) -> *mut i8;

    // --- Multi-tensor foreach operations ---

    pub fn flodl_foreach_add_scalar_(
        tensors: *mut FlodlTensor, count: i32, scalar: f64,
    ) -> *mut i8;

    pub fn flodl_foreach_mul_scalar_(
        tensors: *mut FlodlTensor, count: i32, scalar: f64,
    ) -> *mut i8;

    pub fn flodl_foreach_zero_(
        tensors: *mut FlodlTensor, count: i32,
    ) -> *mut i8;

    pub fn flodl_foreach_add_list_(
        tensors1: *mut FlodlTensor, tensors2: *mut FlodlTensor,
        count: i32, alpha: f64,
    ) -> *mut i8;

    pub fn flodl_foreach_norm(
        tensors: *mut FlodlTensor, count: i32, ord: f64,
        results: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_foreach_lerp_scalar_(
        tensors1: *mut FlodlTensor, tensors2: *mut FlodlTensor,
        count: i32, weight: f64,
    ) -> *mut i8;

    pub fn flodl_foreach_sqrt_(
        tensors: *mut FlodlTensor, count: i32,
    ) -> *mut i8;

    // --- Autograd diagnostics ---

    pub fn flodl_autograd_node_count(t: FlodlTensor) -> i64;

    // --- Fused loss functions ---

    pub fn flodl_mse_loss(
        pred: FlodlTensor, target: FlodlTensor,
        reduction: i64, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_cross_entropy_loss(
        pred: FlodlTensor, target: FlodlTensor,
        reduction: i64, ignore_index: i64, label_smoothing: f64,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_bce_with_logits_loss(
        pred: FlodlTensor, target: FlodlTensor,
        reduction: i64, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_l1_loss(
        pred: FlodlTensor, target: FlodlTensor,
        reduction: i64, result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_smooth_l1_loss(
        pred: FlodlTensor, target: FlodlTensor,
        reduction: i64, beta: f64,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_kl_div_loss(
        input: FlodlTensor, target: FlodlTensor,
        reduction: i64, log_target: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Fused batch normalization ---

    pub fn flodl_batch_norm(
        input: FlodlTensor, weight: FlodlTensor,
        bias: FlodlTensor, running_mean: FlodlTensor,
        running_var: FlodlTensor, training: i32,
        momentum: f64, eps: f64,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- Fused dropout ---

    pub fn flodl_dropout(
        input: FlodlTensor, p: f64, training: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    pub fn flodl_feature_dropout(
        input: FlodlTensor, p: f64, training: i32,
        result: *mut FlodlTensor,
    ) -> *mut i8;

    // --- In-place copy ---

    pub fn flodl_copy_(dst: FlodlTensor, src: FlodlTensor, non_blocking: i32) -> *mut i8;

    // --- Memory format ---

    pub fn flodl_to_channels_last(t: FlodlTensor, result: *mut FlodlTensor) -> *mut i8;
    pub fn flodl_is_channels_last(t: FlodlTensor) -> i32;

    // --- CUDA Graphs ---

    pub fn flodl_cuda_graph_new(graph_out: *mut *mut c_void) -> *mut i8;
    // Pool ids are passed as two u64 halves (hi/lo) across the FFI boundary.
    pub fn flodl_cuda_graph_capture_begin(
        graph: *mut c_void, pool_hi: u64, pool_lo: u64, mode: i32,
    ) -> *mut i8;
    pub fn flodl_cuda_graph_capture_end(graph: *mut c_void) -> *mut i8;
    pub fn flodl_cuda_graph_replay(graph: *mut c_void) -> *mut i8;
    pub fn flodl_cuda_graph_reset(graph: *mut c_void) -> *mut i8;
    pub fn flodl_cuda_graph_delete(graph: *mut c_void);
    pub fn flodl_cuda_graph_pool(
        graph: *mut c_void, pool_hi: *mut u64, pool_lo: *mut u64,
    );
    pub fn flodl_cuda_graph_pool_handle(pool_hi: *mut u64, pool_lo: *mut u64);

    // --- Utility ---

    /// Frees an error string previously returned by any fallible shim call.
    pub fn flodl_free_string(s: *mut i8);
}