//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! Fused image-unfold + linear-projection patch embedding for vision
//! transformers.
//!
//! A ViT patch embedding takes an image, cuts it into non-overlapping
//! `patch_h × patch_w` tiles, flattens each tile into a
//! `in_ch · patch_h · patch_w` vector, and linearly projects every
//! vector into the model's hidden dimension. Done as two ops, this
//! materialises the unfolded `[num_patches, patch_dim]` tensor in global
//! memory, which is pure bandwidth waste because each unfolded value is
//! read exactly once by the projection GEMM. This kernel fuses the unfold
//! and the projection: each thread gathers its patch's pixels straight
//! from the image and dots them with one weight row, with no
//! intermediate buffer.
//!
//! It differs from `conv2d` in layout, not arithmetic — `conv2d` keeps
//! the NCHW image convention and writes NCHW output; `patch_embed` takes
//! the same NCHW image but treats the weight as a flat linear matrix
//! `[hidden, patch_dim]` and writes transformer-token output
//! `[num_patches, hidden]`, which is what a ViT block consumes directly.
//!
//! Layouts (NCHW image, flat linear weight):
//!
//!   image   [in_ch, in_h, in_w]                  T   (single image)
//!   weight  [hidden, in_ch * patch_h * patch_w]  T
//!   bias    [hidden]                             T
//!   out     [num_patches, hidden]                T
//!
//!   patches_h   = in_h / patch_h
//!   patches_w   = in_w / patch_w
//!   num_patches = patches_h * patches_w
//!   patch_dim   = in_ch * patch_h * patch_w
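//!
//! As a worked instance of the shape arithmetic above, the sketch below
//! computes the derived sizes on the host. `patch_dims` is a hypothetical
//! helper used only for illustration, not part of this module, and it
//! assumes `in_h` and `in_w` are exact multiples of the patch size.
//!
//! ```rust
//! /// Hypothetical helper: returns `(num_patches, patch_dim)`.
//! fn patch_dims(
//!     in_ch: usize,
//!     in_h: usize,
//!     in_w: usize,
//!     patch_h: usize,
//!     patch_w: usize,
//! ) -> (usize, usize) {
//!     // Assumption: the image tiles exactly; leftover pixels are not handled.
//!     assert!(in_h % patch_h == 0 && in_w % patch_w == 0);
//!     let patches_h = in_h / patch_h;
//!     let patches_w = in_w / patch_w;
//!     (patches_h * patches_w, in_ch * patch_h * patch_w)
//! }
//! ```
//!
//! For a `3 × 224 × 224` image with `16 × 16` patches, this gives a
//! `14 × 14` grid: 196 patches, each with `patch_dim = 768`.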
//!
//! One thread per output element `(patch, h)`, where `patch` indexes the
//! flattened patch grid (row-major over the `patches_h × patches_w`
//! grid) and `h` indexes the hidden dimension. The thread walks the
//! patch's `in_ch × patch_h × patch_w` pixels, multiplying each by the
//! matching entry of weight row `h` and accumulating in fp32. Generic
//! over the element type `T`.
//!
//! Element order within a patch matches PyTorch `unfold` / `nn.Conv2d`
//! flattening: within a weight row, the column for `(ic, py, px)` is
//! `ic*patch_h*patch_w + py*patch_w + px`.
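//!
//! The per-thread walk and the column indexing described above can be
//! written as a scalar CPU reference. This is a minimal sketch under
//! stated assumptions, not the kernel itself: `patch_embed_ref` is a
//! hypothetical name, it uses `f32` in place of the generic `T`, and it
//! assumes the bias is added once after the accumulation.
//!
//! ```rust
//! /// Hypothetical scalar reference for one output element `(patch, h)`.
//! fn patch_embed_ref(
//!     image: &[f32],  // [in_ch, in_h, in_w]
//!     weight: &[f32], // [hidden, patch_dim]
//!     bias: &[f32],   // [hidden]
//!     in_ch: usize,
//!     in_h: usize,
//!     in_w: usize,
//!     patch_h: usize,
//!     patch_w: usize,
//!     patch: usize,
//!     h: usize,
//! ) -> f32 {
//!     let patches_w = in_w / patch_w;
//!     let patch_dim = in_ch * patch_h * patch_w;
//!     // Row-major position of this patch in the patch grid.
//!     let (pr, pc) = (patch / patches_w, patch % patches_w);
//!     let mut acc = 0.0f32;
//!     for ic in 0..in_ch {
//!         for py in 0..patch_h {
//!             for px in 0..patch_w {
//!                 // Address of the pixel inside the full image.
//!                 let y = pr * patch_h + py;
//!                 let x = pc * patch_w + px;
//!                 let pixel = image[(ic * in_h + y) * in_w + x];
//!                 // Matching column in weight row `h`.
//!                 let col = ic * patch_h * patch_w + py * patch_w + px;
//!                 acc += pixel * weight[h * patch_dim + col];
//!             }
//!         }
//!     }
//!     // Assumption: bias is added once, after the dot product.
//!     acc + bias[h]
//! }
//! ```
//!
//! Calling it for every `(patch, h)` pair fills the `[num_patches, hidden]`
//! output, which can serve as a host-side comparison for the GPU result.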
//!
//! Codegen-only. Correctness validated by `patch_embed_gpu_correctness`.
use ;