ct2rs 0.9.19

Rust bindings for OpenNMT/CTranslate2
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
#pragma once

#include <algorithm>
#include <limits>

#ifdef CT2_USE_HIP
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp8.h>
#include <thrust/iterator/counting_iterator.h>
#define __nv_bfloat16 __hip_bfloat16
#else
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#endif

#include "ctranslate2/types.h"

#include "utils.h"

// CUDA_CAN_USE_HALF selects whether this file relies on native __half
// operators and intrinsics (hsin, hexp, ...) or on float-based fallbacks.
// It is 1 on HIP, outside of CUDA compilation, in the host pass, and in a
// device pass targeting compute capability 5.3 or newer; 0 otherwise.
#if defined(CT2_USE_HIP) || !defined(__CUDACC__) || !defined(__CUDA_ARCH__)
#  define CUDA_CAN_USE_HALF 1
#elif __CUDA_ARCH__ >= 530
#  define CUDA_CAN_USE_HALF 1
#else
#  define CUDA_CAN_USE_HALF 0
#endif

// CUDA_CAN_USE_BF16_MATH enables the __nv_bfloat16 math specializations.
// It is 1 on HIP; otherwise only under a CUDA compiler, either in the host
// pass or in a device pass targeting compute capability 8.0 or newer.
#if defined(CT2_USE_HIP)
#  define CUDA_CAN_USE_BF16_MATH 1
#elif defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
#  define CUDA_CAN_USE_BF16_MATH 1
#else
#  define CUDA_CAN_USE_BF16_MATH 0
#endif

namespace ctranslate2 {
  namespace cuda {

    // Index type used inside the CUDA kernels of this library. It is kept as a
    // 32-bit unsigned integer on purpose, for performance.
    typedef unsigned int index_t;

    // Upper bound on the number of threads per block (get_block_size caps
    // its result with it).
    constexpr dim_t max_threads = 1024;

    // Upper bound on the number of blocks in a grid: INT32_MAX (2^31 - 1),
    // which is CUDA's limit for gridDim.x.
    // NOTE(review): gridDim.y / gridDim.z are limited to 65535 — confirm
    // callers only apply this cap to the x dimension.
    constexpr dim_t max_blocks = std::numeric_limits<int32_t>::max();

    // Maps an element type used by the library to the type manipulated in
    // device code. By default the type is used as-is.
    template <typename T>
    struct DeviceType {
      typedef T type;
    };

    // float16_t is handled on the device as __half.
    template<>
    struct DeviceType<float16_t> {
      typedef __half type;
    };

    // bfloat16_t is handled on the device as __nv_bfloat16 (which is an alias
    // of __hip_bfloat16 when building with CT2_USE_HIP).
    template<>
    struct DeviceType<bfloat16_t> {
      typedef __nv_bfloat16 type;
    };

    // Shorthand for DeviceType<T>::type.
    template <typename T>
    using device_type = typename DeviceType<T>::type;

    // Reinterpret a pointer to element type T as a pointer to device_type<T>,
    // preserving const-ness. Going through void* with static_cast is exactly
    // what reinterpret_cast between object pointer types is defined to do.
    template <typename T>
    inline const device_type<T>* device_cast(const T* x) {
      const void* raw = x;
      return static_cast<const device_type<T>*>(raw);
    }

    template <typename T>
    inline device_type<T>* device_cast(T* x) {
      void* raw = x;
      return static_cast<device_type<T>*>(raw);
    }

    // Element-wise map: y[i] = op(x[i]) for i in [0, size). Both pointers are
    // converted with device_cast before being passed to thrust::transform
    // through THRUST_CALL.
    template <typename T1, typename T2, typename UnaryFunction>
    inline void unary_transform(const T1* x, T2* y, index_t size, const UnaryFunction& op) {
      const auto* in_begin = device_cast(x);
      const auto* in_end = in_begin + size;
      auto* out_begin = device_cast(y);
      THRUST_CALL(thrust::transform, in_begin, in_end, out_begin, op);
    }

    // Element-wise binary map: c[i] = op(a[i], b[i]) for i in [0, size).
    // All pointers are converted with device_cast before being passed to
    // thrust::transform through THRUST_CALL.
    template <typename T1, typename T2, typename T3, typename BinaryFunction>
    inline void binary_transform(const T1* a,
                                 const T2* b,
                                 T3* c,
                                 index_t size,
                                 const BinaryFunction& op) {
      const auto* lhs_begin = device_cast(a);
      const auto* lhs_end = lhs_begin + size;
      const auto* rhs_begin = device_cast(b);
      auto* out_begin = device_cast(c);
      THRUST_CALL(thrust::transform, lhs_begin, lhs_end, rhs_begin, out_begin, op);
    }

    // Element-wise binary map where the first operand is gathered through an
    // index functor: c[i] = op(a[index_a(i)], b[i]) for i in [0, size).
    // index_a maps an output position to a position in `a` (index functors
    // such as repeat_vec have this shape), so `a` does not have to be
    // materialized at full size.
    template <typename T1, typename T2, typename T3, typename BinaryFunction, typename IndexFunction>
    inline void binary_transform(const T1* a,
                                 const T2* b,
                                 T3* c,
                                 index_t size,
                                 const BinaryFunction& op,
                                 const IndexFunction& index_a) {
      thrust::counting_iterator<index_t> positions(0);
      auto gathered_indices = thrust::make_transform_iterator(positions, index_a);
      auto lhs_begin = thrust::make_permutation_iterator(device_cast(a), gathered_indices);
      auto lhs_end = lhs_begin + size;
      THRUST_CALL(thrust::transform, lhs_begin, lhs_end, device_cast(b), device_cast(c), op);
    }

    // Gather copy: y[i] = x[perm_fun(i)] for i in [0, size).
    // perm_fun is a functor that takes an index in the permuted (output)
    // sequence and returns the corresponding index in the original input.
    template <typename T, typename PermFunction>
    inline void permute(const T* x, T* y, index_t size, const PermFunction& perm_fun) {
      auto src_positions = thrust::make_transform_iterator(thrust::counting_iterator<index_t>(0),
                                                           perm_fun);
      auto src_begin = thrust::make_permutation_iterator(device_cast(x), src_positions);
      auto src_end = src_begin + size;
      THRUST_CALL(thrust::copy, src_begin, src_end, device_cast(y));
    }

    // Index functor mapping i to i % size, i.e. cycling through [0, size).
    // Constructed on the host; the call operator is device-only.
    template <typename T>
    class repeat_vec {
    public:
      repeat_vec(T size)
        : _period(size) {
      }

      __device__
      T operator()(const T i) const {
        const T wrapped = i % _period;
        return wrapped;
      }

    private:
      T _period;
    };

    // Index functor mapping i to i / size, so each consecutive run of `size`
    // indices maps to the same value. Constructed on the host; the call
    // operator is device-only.
    template <typename T>
    class repeat_vec_depth {
    public:
      repeat_vec_depth(T size)
        : _run_length(size) {
      }

      __device__
      T operator()(const T i) const {
        const T run = i / _run_length;
        return run;
      }

    private:
      T _run_length;
    };

    // Index functor mapping i to (i / block) % size: every index is first
    // grouped into runs of `block`, and the run number then cycles through
    // [0, size). Constructed on the host; the call operator is device-only.
    template <typename T>
    class repeat_vec_block {
    public:
      repeat_vec_block(T block, T size)
        : _block(block)
        , _size(size) {
      }

      __device__
      T operator()(const T i) const {
        const T run = i / _block;
        return run % _size;
      }

    private:
      T _block;
      T _size;
    };

    // Bind the right argument of a binary operator: bind_right<Op, T>(y) is a
    // unary functor whose call operator returns Op<T>{}(x, y), using a
    // default-constructed Op<T> stored in the object.
    // Constructed on the host; the call operator is device-only.
    template <template <typename> class BinaryFunctor, typename T>
    class bind_right {
    private:
      // Deliberately not `const`: a const data member implicitly deletes the
      // copy-assignment operator, which would make this functor unusable where
      // an assignable function object is needed (e.g. inside iterator adaptors
      // that get assigned). The value is never modified after construction.
      T _y;
      BinaryFunctor<T> _op;
    public:
      bind_right(const T& y)
        : _y(y) {
      }
      __device__ T operator()(const T& x) const {
        return _op(x, _y);
      }
    };

    // Device-only functional operators, similar to the ones provided by Thrust.
    // When CUDA_CAN_USE_HALF is 0, float-based __half specializations of them
    // are defined further down.

    // Returns lhs + rhs.
    template <typename T>
    struct plus {
      __device__ T operator()(const T& lhs, const T& rhs) const {
        const T sum = lhs + rhs;
        return sum;
      }
    };

    // Returns lhs - rhs.
    template <typename T>
    struct minus {
      __device__ T operator()(const T& lhs, const T& rhs) const {
        const T difference = lhs - rhs;
        return difference;
      }
    };

    // Returns lhs * rhs.
    template <typename T>
    struct multiplies {
      __device__ T operator()(const T& lhs, const T& rhs) const {
        const T product = lhs * rhs;
        return product;
      }
    };

    // Returns the larger operand, using only operator<. When lhs < rhs is
    // false (ties, or unordered values such as NaN) lhs is returned.
    template <typename T>
    struct maximum {
      __device__ T operator()(const T& lhs, const T& rhs) const {
        if (lhs < rhs)
          return rhs;
        return lhs;
      }
    };

    // Returns the smaller operand, using only operator<. When lhs < rhs is
    // false (ties, or unordered values such as NaN) rhs is returned.
    template <typename T>
    struct minimum {
      __device__ T operator()(const T& lhs, const T& rhs) const {
        if (lhs < rhs)
          return lhs;
        return rhs;
      }
    };

#if !CUDA_CAN_USE_HALF
    // Without native __half arithmetic (CUDA_CAN_USE_HALF == 0), the generic
    // operators are specialized for __half: operands are converted to float,
    // the operation is done in float, and arithmetic results are converted
    // back to __half.
    template<>
    struct plus<__half> {
      __device__ __half operator()(const __half& lhs, const __half& rhs) const {
        const float a = static_cast<float>(lhs);
        const float b = static_cast<float>(rhs);
        return __half(a + b);
      }
    };

    template<>
    struct minus<__half> {
      __device__ __half operator()(const __half& lhs, const __half& rhs) const {
        const float a = static_cast<float>(lhs);
        const float b = static_cast<float>(rhs);
        return __half(a - b);
      }
    };

    template<>
    struct multiplies<__half> {
      __device__ __half operator()(const __half& lhs, const __half& rhs) const {
        const float a = static_cast<float>(lhs);
        const float b = static_cast<float>(rhs);
        return __half(a * b);
      }
    };

    // Comparisons are evaluated in float; the selected operand itself is
    // returned unchanged.
    template<>
    struct maximum<__half> {
      __device__ __half operator()(const __half& lhs, const __half& rhs) const {
        const bool take_rhs = static_cast<float>(lhs) < static_cast<float>(rhs);
        return take_rhs ? rhs : lhs;
      }
    };

    template<>
    struct minimum<__half> {
      __device__ __half operator()(const __half& lhs, const __half& rhs) const {
        const bool take_lhs = static_cast<float>(lhs) < static_cast<float>(rhs);
        return take_lhs ? lhs : rhs;
      }
    };
#endif

    // ReLU: returns x when x > 0, otherwise 0 (NaN inputs also yield 0, since
    // the comparison is false).
    template <typename T>
    struct relu_func {
      __device__ T operator()(T x) const {
        const T zero = T(0.f);
        if (x > zero)
          return x;
        return zero;
      }
    };

#if !CUDA_CAN_USE_HALF
    // ReLU for __half when native half comparisons are unavailable: the sign
    // test is done in float.
    template<>
    struct relu_func<__half> {
      __device__ __half operator()(__half x) const {
        if (static_cast<float>(x) > 0.f)
          return x;
        return __half(0);
      }
    };
#endif

    // Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
    // The call operator works on float, so half-precision arguments are
    // implicitly promoted; the template parameter T is unused.
    template <typename T>
    struct gelu_func {
      __device__ float operator()(float x) const {
        const float inv_sqrt2 = 0.7071067811865475f;
        const float half_x = 0.5f * x;
        return half_x * (1.f + erff(inv_sqrt2 * x));
      }
    };

    // Tanh approximation of GELU:
    // 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
    // The call operator works on float, so half-precision arguments are
    // implicitly promoted; the template parameter T is unused.
    template <typename T>
    struct gelu_tanh_func {
      __device__ float operator()(float x) const {
        const float sqrt_2_over_pi = 0.7978845608028654f;
        const float inner = x + 0.044715f * powf(x, 3.f);
        return 0.5f * x * (1.f + tanhf(sqrt_2_over_pi * inner));
      }
    };

    // Sigmoid approximation of GELU: x * sigmoid(1.702 * x), computed as
    // x / (1 + exp(-1.702 * x)).
    // The call operator works on float, so half-precision arguments are
    // implicitly promoted; the template parameter T is unused.
    template <typename T>
    struct gelu_sigmoid_func {
      __device__ float operator()(float x) const {
        const float denominator = 1.f + expf(-1.702f * x);
        return x / denominator;
      }
    };

    // Logistic sigmoid: 1 / (1 + exp(-x)).
    // The call operator works on float, so half-precision arguments are
    // implicitly promoted; the template parameter T is unused.
    template <typename T>
    struct sigmoid_func {
      __device__ float operator()(float x) const {
        const float denominator = 1.f + expf(-x);
        return 1.f / denominator;
      }
    };

    // Swish / SiLU: x * sigmoid(x), computed as x / (1 + exp(-x)).
    // The call operator works on float, so half-precision arguments are
    // implicitly promoted; the template parameter T is unused.
    template <typename T>
    struct swish_func {
      __device__ float operator()(float x) const {
        const float denominator = 1.f + expf(-x);
        return x / denominator;
      }
    };

    // Hyperbolic tangent evaluated with tanhf.
    // The call operator works on float, so half-precision arguments are
    // implicitly promoted; the template parameter T is unused.
    template <typename T>
    struct tanh_func {
      __device__ float operator()(float x) const {
        const float result = tanhf(x);
        return result;
      }
    };

    // Generic sine: evaluated with the single-precision sinf, and the float
    // result is converted back to T.
    template <typename T>
    struct sin_func {
      __device__ T operator()(T x) const {
        const float result = sinf(x);
        return result;
      }
    };

    // Generic cosine: evaluated with the single-precision cosf, and the float
    // result is converted back to T.
    template <typename T>
    struct cos_func {
      __device__ T operator()(T x) const {
        const float result = cosf(x);
        return result;
      }
    };

    // Generic exponential: evaluated with the single-precision expf, and the
    // float result is converted back to T.
    template <typename T>
    struct exp_func {
      __device__ T operator()(T x) const {
        const float result = expf(x);
        return result;
      }
    };

    // Generic natural logarithm: evaluated with the single-precision logf, and
    // the float result is converted back to T.
    template <typename T>
    struct log_func {
      __device__ T operator()(T x) const {
        const float result = logf(x);
        return result;
      }
    };

#if CUDA_CAN_USE_HALF
    // __half specializations that call the half-precision intrinsics
    // hsin / hcos / hexp / hlog instead of the generic float-based versions.
    template<>
    struct sin_func<__half> {
      __device__ __half operator()(__half value) const {
        const __half result = hsin(value);
        return result;
      }
    };

    template<>
    struct cos_func<__half> {
      __device__ __half operator()(__half value) const {
        const __half result = hcos(value);
        return result;
      }
    };

    template<>
    struct exp_func<__half> {
      __device__ __half operator()(__half value) const {
        const __half result = hexp(value);
        return result;
      }
    };

    template<>
    struct log_func<__half> {
      __device__ __half operator()(__half value) const {
        const __half result = hlog(value);
        return result;
      }
    };
#endif

#if CUDA_CAN_USE_BF16_MATH
    // __nv_bfloat16 specializations that call the bfloat16 overloads of
    // hsin / hcos / hexp / hlog instead of the generic float-based versions.
    template<>
    struct sin_func<__nv_bfloat16> {
      __device__ __nv_bfloat16 operator()(__nv_bfloat16 value) const {
        const __nv_bfloat16 result = hsin(value);
        return result;
      }
    };

    template<>
    struct cos_func<__nv_bfloat16> {
      __device__ __nv_bfloat16 operator()(__nv_bfloat16 value) const {
        const __nv_bfloat16 result = hcos(value);
        return result;
      }
    };

    template<>
    struct exp_func<__nv_bfloat16> {
      __device__ __nv_bfloat16 operator()(__nv_bfloat16 value) const {
        const __nv_bfloat16 result = hexp(value);
        return result;
      }
    };

    template<>
    struct log_func<__nv_bfloat16> {
      __device__ __nv_bfloat16 operator()(__nv_bfloat16 value) const {
        const __nv_bfloat16 result = hlog(value);
        return result;
      }
    };
#endif

    // The following helpers are adapted from:
    // https://github.com/pytorch/pytorch/blob/40eff454ce5638fbff638a7f4502e29ffb9a2f0d/aten/src/ATen/native/cuda/SoftMax.cu
    // They support row-wise reductions in which each thread block processes a
    // single row.

#define C10_WARP_SIZE 32

    // Returns a 1D block size for a row of `dim_size` elements when every
    // thread handles ILP elements per step: the smallest power of two that is
    // >= min(dim_size / ILP, max_threads), raised to at least C10_WARP_SIZE.
    template <index_t ILP = 2>
    inline dim3 get_block_size(index_t dim_size) {
      const index_t limit = std::min(dim_size / ILP, static_cast<index_t>(max_threads));
      index_t threads = 1;
      for (; threads < limit; threads <<= 1) {
      }
      // The reduction helpers assume at least one full warp per block.
      return dim3(std::max(threads, static_cast<index_t>(C10_WARP_SIZE)));
    }

    // Reduce one value per thread across the whole thread block with `r`, and
    // return the block-wide result to every thread (adapted from the PyTorch
    // SoftMax blockReduce referenced above).
    //
    // smem:       scratch buffer with room for at least blockDim.x AccumT values,
    //             visible to every thread of the block (presumably __shared__
    //             memory provided by the calling kernel — confirm at call sites).
    //             Its contents are overwritten; smem[0] holds the result on return.
    // val:        this thread's contribution.
    // r:          binary reduction functor, called as r(accumulator, element).
    // defaultVal: identity of `r`; seeds every accumulator.
    //
    // Preconditions implied by the indexing below:
    //   - every thread of the block must call this function, since it
    //     contains __syncthreads();
    //   - blockDim.x is a multiple of C10_WARP_SIZE (a trailing partial warp
    //     would be dropped by the integer division) and at most
    //     C10_WARP_SIZE * C10_WARP_SIZE (only the lanes of the first warp
    //     collect the per-warp partial results).
    template <typename Reduction, typename AccumT>
    __device__ __forceinline__ AccumT block_reduce(AccumT* smem,
                                                   AccumT val,
                                                   const Reduction& r,
                                                   AccumT defaultVal)
    {
      // Barrier before overwriting smem: when reductions are chained, threads
      // may still be reading smem[0] (the broadcast at the end of a previous
      // call) when this call starts writing smem.
      __syncthreads();

      // Publish this thread's value so that the first warp can read it.
      smem[threadIdx.x] = val;

      __syncthreads();

      AccumT warpVal = defaultVal;

      // First warp will perform per-warp reductions for the remaining warps:
      // lane l of warp 0 folds the C10_WARP_SIZE values starting at
      // smem[l * C10_WARP_SIZE].
      // `mask` has one bit per participating lane (lanes 0 .. number of
      // warps - 1), i.e. exactly the lanes that enter the inner branch below.
      // The shift is done on 64 bits so it stays defined with 32 warps.
      // NOTE(review): CUDA's __syncwarp takes an `unsigned` mask (the value
      // fits), while HIP's __syncwarp is understood to require a 64-bit mask —
      // confirm before narrowing this type.
      uint64_t mask = (((uint64_t)1) << (blockDim.x / C10_WARP_SIZE)) - 1;
      if (threadIdx.x < C10_WARP_SIZE) {
        index_t lane = threadIdx.x % C10_WARP_SIZE;
        if (lane < blockDim.x / C10_WARP_SIZE) {
          #pragma unroll
          for (index_t i = 0; i < C10_WARP_SIZE; ++i) {
            warpVal = r(warpVal, smem[lane * C10_WARP_SIZE + i]);
          }
          // Lane l is about to overwrite smem[l], which lane 0 may still be
          // reading in the loop above (it reads smem[0 .. C10_WARP_SIZE - 1]);
          // this warp barrier orders those reads before the writes.
          __syncwarp(mask);
          smem[lane] = warpVal;
        }
      }

      __syncthreads();

      // First thread will perform a reduction of the above per-warp reductions
      AccumT blockVal = defaultVal;

      if (threadIdx.x == 0) {
        for (index_t i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) {
          blockVal = r(blockVal, smem[i]);
        }
        smem[0] = blockVal;
      }

      // Sync and broadcast: every thread then reads the final value from smem[0].
      __syncthreads();
      return smem[0];
    }

    // Per-thread partial reduction of data[0, size) with a block-strided
    // access pattern: thread t folds data[t], data[t + blockDim.x],
    // data[t + 2 * blockDim.x], ... (in that order) into an accumulator seeded
    // with defaultVal.
    // The main loop first loads ILP elements and only then reduces them, to
    // expose instruction-level parallelism. The size % (ILP * blockDim.x)
    // leftover elements are handled one per step.
    // Returns this thread's partial value only; the partial values still have
    // to be combined across the block (e.g. with block_reduce).
    template <typename Reduction,
              typename T,
              typename AccumT = T,
              index_t ILP = 2>
    __device__ __forceinline__ AccumT ilp_reduce(const T* data,
                                                 index_t size,
                                                 const Reduction& r,
                                                 AccumT defaultVal)
    {
      const index_t stride = blockDim.x;
      const index_t unrolled_end = size - size % (ILP * stride);

      AccumT acc = defaultVal;
      index_t pos = threadIdx.x;

      // Unrolled body: ILP loads, then ILP reduction steps.
      while (pos < unrolled_end) {
        T chunk[ILP];

        #pragma unroll
        for (index_t k = 0; k < ILP; ++k)
          chunk[k] = data[pos + k * stride];

        #pragma unroll
        for (index_t k = 0; k < ILP; ++k)
          acc = r(acc, chunk[k]);

        pos += stride * ILP;
      }

      // Tail: remaining elements, one per step.
      while (pos < size) {
        acc = r(acc, data[pos]);
        pos += stride;
      }

      return acc;
    }

    // Element-wise pass over one row: output[i] = epilogue(input[i]) for
    // i in [0, depth). Thread t handles indices t, t + blockDim.x,
    // t + 2 * blockDim.x, ... . The main loop loads ILP inputs before
    // writing their ILP outputs; the depth % (ILP * blockDim.x) leftover
    // elements are processed one per step.
    template <typename Epilogue,
              typename scalar_t,
              typename outscalar_t,
              index_t ILP = 2>
    __device__ __forceinline__ void
    apply_epilogue(const scalar_t* input,
                   index_t depth,
                   const Epilogue& epilogue,
                   outscalar_t* output) {
      const index_t stride = blockDim.x;
      const index_t unrolled_end = depth - depth % (ILP * stride);
      index_t pos = threadIdx.x;

      // Unrolled body: ILP loads, then ILP stores.
      while (pos < unrolled_end) {
        scalar_t chunk[ILP];

        #pragma unroll
        for (index_t k = 0; k < ILP; ++k)
          chunk[k] = input[pos + k * stride];

        #pragma unroll
        for (index_t k = 0; k < ILP; ++k)
          output[pos + k * stride] = epilogue(chunk[k]);

        pos += stride * ILP;
      }

      // Tail: remaining elements, one per step.
      while (pos < depth) {
        output[pos] = epilogue(input[pos]);
        pos += stride;
      }
    }

  }
}