// ct2rs 0.9.19
//
// Rust bindings for OpenNMT/CTranslate2
// Documentation
#include "ctranslate2/ops/flash_attention.h"

#include "dispatch.h"

namespace ctranslate2 {
  namespace ops {

    // Constructs the op by storing its configuration; no computation happens here.
    //   queries_scale  - scale factor kept in _queries_scale (presumably applied to the
    //                    queries by compute<D> — confirm in the device implementations).
    //   sliding_window - window size kept in _sliding_window.
    //   is_causal      - causal flag kept in _is_causal.
    // NOTE(review): these members are presumably read by compute<D>; verify there.
    FlashAttention::FlashAttention(float queries_scale,
                                   dim_t sliding_window,
                                   bool is_causal)
      : _queries_scale(queries_scale),
        _sliding_window(sliding_window),
        _is_causal(is_causal) {
    }

    // Public entry point of the op.
    // Opens a PROFILE scope labelled "FlashAttention", then uses DEVICE_DISPATCH on
    // queries.device() to call compute<D>, passing every argument through unchanged
    // and in the same order. D is presumably bound by the macro to the device
    // matching queries.device(); the macro definition lives in dispatch.h.
    // Pointer arguments are forwarded as-is (nullptr included); any handling of
    // null optional inputs is up to compute<D>.
    void FlashAttention::operator()(StorageView& queries,
                                    StorageView& keys,
                                    StorageView& values,
                                    StorageView& output,
                                    StorageView* cached_keys,
                                    StorageView* cached_values,
                                    StorageView* attention,
                                    bool return_normalized_attention,
                                    StorageView* rotary_cos,
                                    StorageView* rotary_sin,
                                    const bool rotary_interleave,
                                    StorageView* alibi,
                                    dim_t offset) const {
      PROFILE("FlashAttention");
      // The commas below sit inside the call's parentheses, so the whole
      // compute<D>(...) expression reaches DEVICE_DISPATCH as a single argument.
      DEVICE_DISPATCH(queries.device(),
                      compute<D>(queries,
                                 keys,
                                 values,
                                 output,
                                 cached_keys,
                                 cached_values,
                                 attention,
                                 return_normalized_attention,
                                 rotary_cos,
                                 rotary_sin,
                                 rotary_interleave,
                                 alibi,
                                 offset));
    }

  }
}