// ct2rs 0.9.19
//
// Rust bindings for OpenNMT/CTranslate2
// Documentation
#include "ctranslate2/layers/attention.h"
#include "ctranslate2/ops/split.h"

#include <algorithm>
#include <cmath>
#include <numeric>
#include <stdexcept>

#include "dispatch.h"
#include "cpu/parallel.h"

namespace ctranslate2 {
  namespace layers {
    // Builds the ALiBi bias table as a CPU StorageView of shape {1, num_heads, 1, key_max_length}.
    //
    // Row h holds slope_h * scale * p for every key position p, where p runs over
    // 0 .. key_max_length - 1 when use_positive_positions is true and over
    // -(key_max_length - 1) .. 0 otherwise.
    //
    // Slopes: with P the largest power of two not exceeding num_heads and r(n) = 2^(-2^-(log2(n) - 3)),
    // the first P heads get r(P)^1 .. r(P)^P; any remaining heads get the odd powers
    // r(2P)^1, r(2P)^3, ... (at most P of them).
    static StorageView build_alibi(dim_t num_heads,
                                   dim_t key_max_length,
                                   bool use_positive_positions,
                                   const float scale) {
      // Ratio of the geometric slope sequence computed for `n` heads.
      const auto slope_ratio = [](const float n) {
        return std::pow(2.f, -std::pow(2.f, -(std::log2f(n) - 3.f)));
      };

      const float pow2_heads_f = std::pow(2.f, std::floor(std::log2f(num_heads)));
      const dim_t pow2_heads = pow2_heads_f;

      std::vector<float> slopes;
      slopes.reserve(pow2_heads);

      const float ratio = slope_ratio(pow2_heads_f);
      for (dim_t k = 1; k <= pow2_heads; ++k)
        slopes.emplace_back(std::pow(ratio, float(k)));

      if (pow2_heads != num_heads) {
        // Heads beyond the power of two take every other slope of the 2P-head sequence.
        const float extra_ratio = slope_ratio(2 * pow2_heads_f);
        const dim_t leftover_heads = std::min(pow2_heads, num_heads - pow2_heads);
        for (dim_t k = 1; k <= 2 * leftover_heads; k += 2)
          slopes.emplace_back(std::pow(extra_ratio, float(k)));
      }

      // Key positions, either 0..L-1 or -(L-1)..0.
      const dim_t first_position = use_positive_positions ? 0 : 1 - key_max_length;
      std::vector<float> positions(key_max_length);
      std::iota(positions.begin(), positions.end(), first_position);

      StorageView alibi({1, num_heads, 1, key_max_length});
      for (dim_t head = 0; head < num_heads; ++head) {
        auto* row = alibi.index<float>({0, head, 0, 0});
        primitives<Device::CPU>::mul(slopes[head] * scale, positions.data(), row, key_max_length);
      }

      return alibi;
    }

    // Loads the attention's linear layers "<scope>/linear_0", "<scope>/linear_1", ...:
    // two layers when self_attention is true, three otherwise. Only the last layer is
    // constructed with the extra (nullptr, true) Dense arguments.
    static std::vector<Dense> make_linear_layers(const models::Model& model,
                                                 const std::string& scope,
                                                 bool self_attention) {
      const dim_t layer_count = self_attention ? 2 : 3;
      const dim_t last = layer_count - 1;
      const auto layer_scope = [&scope](const dim_t index) {
        return scope + "/linear_" + std::to_string(index);
      };

      std::vector<Dense> layers;
      layers.reserve(layer_count);
      for (dim_t index = 0; index < last; ++index)
        layers.emplace_back(model, layer_scope(index));
      layers.emplace_back(model, layer_scope(last), nullptr, true);
      return layers;
    }

    // Builds the RotaryEmbeddings module configured by the attributes stored under `scope`.
    //
    // Returns nullptr when "<scope>/rotary_dim" is absent (default -1) or negative.
    // Every other attribute falls back to the default shown next to it. The scaling
    // factor tables are optional variables and are passed on as possibly-null pointers.
    // `transpose` is forwarded unchanged to the RotaryEmbeddings constructor.
    static std::unique_ptr<RotaryEmbeddings> make_rotary_embeddings(const models::Model& model,
                                                                    const std::string& scope,
                                                                    bool transpose) {
      const auto key = [&scope](const char* suffix) { return scope + suffix; };

      const dim_t rotary_dim = model.get_attribute_with_default<int32_t>(key("/rotary_dim"), -1);
      if (rotary_dim < 0)
        return nullptr;

      const bool interleave = model.get_flag_with_default(key("/rotary_interleave"), true);
      const float base = model.get_attribute_with_default<float>(key("/rotary_base"), 10000.f);
      const auto scaling_type = model.get_enum_value<RotaryScalingType>(
        key("/rotary_scaling_type"), -1);
      const auto scaling_factor = model.get_attribute_with_default<float>(
        key("/rotary_scaling_factor"), 1.f);
      const auto long_factor = model.get_variable_if_exists(key("/rotary_scaling_long_factor"));
      const auto short_factor = model.get_variable_if_exists(key("/rotary_scaling_short_factor"));
      const auto original_max_positions = model.get_attribute_with_default<int32_t>(
        key("/original_max_position_embeddings"), 0);
      const auto max_positions = model.get_attribute_with_default<int32_t>(
        key("/max_position_embeddings"), 0);
      const auto high_freq_factor = model.get_attribute_with_default<float>(
        key("/rotary_high_freq_factor"), 4.0);
      const auto low_freq_factor = model.get_attribute_with_default<float>(
        key("/rotary_low_freq_factor"), 1.0);

      // Number of positions the sin/cos tables grow by.
      const dim_t num_initial_positions = 2048;

      return std::make_unique<RotaryEmbeddings>(rotary_dim,
                                                interleave,
                                                scaling_type,
                                                scaling_factor,
                                                base,
                                                num_initial_positions,
                                                long_factor,
                                                short_factor,
                                                low_freq_factor,
                                                high_freq_factor,
                                                original_max_positions,
                                                max_positions,
                                                transpose);
    }


    // Builds an attention layer from the weights and attributes stored under `scope`.
    //
    // - num_heads: total number of heads. Under tensor parallelism each rank keeps
    //   num_heads / nranks of them (via SAFE_DIVIDE).
    // - self_attention: loads 2 linear layers instead of 3 (see make_linear_layers).
    // - alibi: stored as given (presumably non-owning; confirm against the header).
    // - is_flash_attn: rotary embeddings are built with transpose = !is_flash_attn.
    //
    // NOTE(review): several initializers below read members set earlier in this list
    // (_tensor_parallel, _num_heads, _linear, _d_model, _d_head, _multi_query). That is only
    // correct if the header declares the members in this same order; confirm before reordering.
    AttentionLayer::AttentionLayer(const models::Model& model,
                                           const std::string& scope,
                                           dim_t num_heads,
                                           bool self_attention,
                                           bool pre_norm,
                                           bool is_decoder,
                                           Alibi* alibi,
                                           bool is_flash_attn)
      : _tensor_parallel(model.tensor_parallel())
      , _num_heads(_tensor_parallel ? SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks()) : num_heads)
      , _self_attention(self_attention)
      , _is_decoder(is_decoder)
      , _linear(make_linear_layers(model, scope, self_attention))
      // Output size of the last linear layer, split across ranks under tensor parallelism.
      , _d_model(_tensor_parallel ? SAFE_DIVIDE(_linear.back().output_size(),  ScopedMPISetter::getNRanks()) : _linear.back().output_size())
      // Defaults to _d_model / _num_heads unless the model stores an explicit head_dim.
      , _d_head(model.get_attribute_with_default<int32_t >(scope + "/head_dim", _d_model / _num_heads))
      , _pre_norm(pre_norm)
      , _layer_norm(build_optional_layer<LayerNorm>(model, scope + "/layer_norm"))
      , _rotary_embeddings(make_rotary_embeddings(model, scope, !is_flash_attn))
      , _alibi(alibi)
      // Defaults to 1 / sqrt(_d_head) unless the model stores queries_scale.
      , _queries_scale(model.get_attribute_with_default<float>(
                         scope + "/queries_scale",
                         1.f / std::sqrt(static_cast<float>(_d_head))))
      , _multi_query(model.get_flag_with_default(scope + "/multi_query", false))
      // 1 for multi-query attention. Otherwise num_heads_kv (default: the total head count),
      // divided per rank under tensor parallelism.
      // NOTE(review): unlike _num_heads this uses plain division, not SAFE_DIVIDE, so a
      // num_heads_kv that is not a multiple of the rank count is silently truncated.
      , _num_heads_kv(_multi_query
                      ? 1
                      : (_tensor_parallel ? model.get_attribute_with_default<int32_t>(scope + "/num_heads_kv",
                                            _num_heads * ScopedMPISetter::getNRanks()) / ScopedMPISetter::getNRanks()
                      : model.get_attribute_with_default<int32_t>(scope + "/num_heads_kv", _num_heads)))
      // Defaults to 0 (presumably "no sliding window"; confirm where it is used).
      , _sliding_window(model.get_attribute_with_default<int32_t>(scope + "/sliding_window", 0))
    {
    }

    // The layer's output dtype is the output dtype of its last linear layer.
    DataType AttentionLayer::output_type() const {
      const auto& last_linear = _linear.back();
      return last_linear.output_type();
    }

    // Returns _d_model, the (per-rank) model dimension computed in the constructor.
    dim_t AttentionLayer::output_size() const {
      return this->_d_model;
    }

    // Builds a length mask from the per-batch `lengths`, which are read as int32.
    //
    // The result has the dtype and device of `lengths`. Its shape is
    // {batch, num_queries, num_heads} when multi_query is true and
    // {batch, num_heads, num_queries} otherwise. The values are filled by
    // primitives<D>::prepare_length_mask (presumably the number of visible key positions
    // per query, capped causally when mask_future is set; confirm in the primitive).
    StorageView AttentionLayer::prepare_length_mask(const StorageView& lengths,
                                                        const dim_t num_heads,
                                                        const dim_t num_queries,
                                                        const bool mask_future,
                                                        const bool multi_query) {
      const dim_t batch_size = lengths.size();
      const Device device = lengths.device();

      // Multi-query layouts put the head axis last.
      const dim_t middle_dim = multi_query ? num_queries : num_heads;
      const dim_t last_dim = multi_query ? num_heads : num_queries;

      StorageView mask(lengths.dtype(), device);
      mask.resize({batch_size, middle_dim, last_dim});

      const auto* lengths_data = lengths.data<int32_t>();
      auto* mask_data = mask.data<int32_t>();
      DEVICE_DISPATCH(device, (primitives<D>::prepare_length_mask(lengths_data,
                                                                  batch_size,
                                                                  num_heads,
                                                                  num_queries,
                                                                  mask_future,
                                                                  multi_query,
                                                                  mask_data)));
      return mask;
    }


    // Stores the rotary-embedding configuration. The sin/cos tables are built lazily by
    // initialize(). The optional scaling factor tables are copied so this object owns them.
    RotaryEmbeddings::RotaryEmbeddings(const dim_t dim,
                                       const bool interleave,
                                       const RotaryScalingType scaling_type,
                                       const float scaling_factor,
                                       const float base,
                                       const dim_t num_initial_positions,
                                       const StorageView* long_scaling_factor,
                                       const StorageView* short_scaling_factor,
                                       const float low_freq_factor,
                                       const float high_freq_factor,
                                       const dim_t original_max_position_embeddings,
                                       const dim_t max_position_embeddings,
                                       const bool transpose)
      : _dim(dim)
      , _interleave(interleave)
      , _scaling_type(scaling_type)
      , _scaling_factor(scaling_factor)
      , _base(base)
      , _num_initial_positions(num_initial_positions)
      , _rotary_scaling_long_factor(long_scaling_factor ?
                                    std::make_unique<StorageView>(*long_scaling_factor) : nullptr)
      , _rotary_scaling_short_factor(short_scaling_factor ?
                                    std::make_unique<StorageView>(*short_scaling_factor) : nullptr)
      , _rotary_low_freq_factor(low_freq_factor)
      , _rotary_high_freq_factor(high_freq_factor)
      , _original_max_position_embeddings(original_max_position_embeddings)
      , _max_position_embeddings(max_position_embeddings)
      , _rotary_op(dim, interleave)
      , _transpose(transpose)
    {
      // initialize() reads these tables element by element with at<float>(), so keep the
      // owned copies on the CPU.
      const auto ensure_on_cpu = [](auto& factor) {
        if (factor && factor->device() != Device::CPU)
          factor = std::make_unique<StorageView>(factor->to(Device::CPU));
      };
      ensure_on_cpu(_rotary_scaling_long_factor);
      ensure_on_cpu(_rotary_scaling_short_factor);
    }

    // Rotates `x` in place with the rotary position embedding, for positions starting at `offset`.
    //
    // The sequence length is x.dim(-2) when _transpose is set and x.dim(-3) otherwise. The rotary
    // dimension is _dim, or x.dim(-1) when _dim is 0. The sin/cos tables are (re)built on x's
    // device and dtype whenever they are missing or too short, growing by at least
    // _num_initial_positions. When `fa2` is set, half-width copies of the tables (sliced with
    // ops::Slide(1, 0, dim / 2)) are refreshed alongside, and x itself is only rotated when
    // offset == 0.
    void RotaryEmbeddings::apply(StorageView& x, const dim_t offset, bool fa2) {
      const Device device = x.device();
      const DataType dtype = x.dtype();
      const dim_t seq_len = _transpose ? x.dim(-2) : x.dim(-3);
      const dim_t rot_dim = _dim == 0 ? x.dim(-1) : _dim;

      const bool needs_growth = !_sin || offset + seq_len > _sin.dim(0);
      if (needs_growth) {
        const dim_t cached_positions = _sin ? _sin.dim(0) : 0;
        initialize(std::max(offset + seq_len, cached_positions + _num_initial_positions),
                   rot_dim,
                   device,
                   dtype);

        if (fa2) {
          if (!_sin_half) {
            _sin_half = std::make_unique<StorageView>(dtype, device);
            _cos_half = std::make_unique<StorageView>(dtype, device);
          }
          const ops::Slide half_slice(1, 0, rot_dim / 2);
          half_slice(_cos, *_cos_half);
          half_slice(_sin, *_sin_half);
        }
      }

      // With fa2, only the first step (offset 0) rotates x here.
      if (fa2 && offset != 0)
        return;

      StorageView sin_window(dtype, device);
      StorageView cos_window(dtype, device);
      TYPE_DISPATCH(dtype,
                    {
                      sin_window.view(_sin.index<T>({offset, 0}), {seq_len, rot_dim});
                      cos_window.view(_cos.index<T>({offset, 0}), {seq_len, rot_dim});
                    });

      StorageView rotated(dtype, device);
      _rotary_op(x, sin_window, cos_window, rotated, _transpose);
      x = std::move(rotated);
    }

    // (Re)builds the cached rotary tables _sin and _cos for `num_positions` positions of rotary
    // dimension `dim`, stored on `device` with dtype `dtype`.
    //
    // Frequencies: inv_freq[i] = 1 / base^(2i / dim) for i in [0, dim / 2), adjusted by the
    // scaling type:
    //   - Su:     each frequency is also divided by a per-dimension factor taken from the "long"
    //             table when num_positions > _original_max_position_embeddings, else from the
    //             "short" table.
    //   - Llama3: frequencies with wavelength below the high-frequency cutoff are kept, those
    //             above the low-frequency cutoff are divided by _scaling_factor, and the band in
    //             between is interpolated.
    //   - Linear: positions (not frequencies) are divided by _scaling_factor.
    // The angle table is positions x inv_freq. Each frequency is duplicated: repeated pairwise
    // when _interleave is set, otherwise the two halves are concatenated. Then sin/cos are taken.
    //
    // When both _original_max_position_embeddings and _max_position_embeddings are non-zero and
    // the scaling type is not Llama3, sin and cos are multiplied by
    // sqrt(1 + log(max / original) / log(original)) whenever max / original > 1.
    //
    // Throws std::invalid_argument when Su scaling is configured but the selected factor table
    // is missing or has fewer than dim / 2 values.
    void RotaryEmbeddings::initialize(const dim_t num_positions,
                                      const dim_t dim,
                                      const Device device,
                                      const DataType dtype) {
      StorageView inv_freq({1, dim / 2});
      if (_scaling_type == RotaryScalingType::Su) {
        // The table choice does not depend on the dimension index, so select it once.
        StorageView* scaling_factor = num_positions > _original_max_position_embeddings
          ? _rotary_scaling_long_factor.get()
          : _rotary_scaling_short_factor.get();
        if (!scaling_factor || scaling_factor->size() < inv_freq.size())
          throw std::invalid_argument("RotaryEmbeddings: Su rotary scaling requires a scaling "
                                      "factor table with at least dim / 2 values");
        for (dim_t i = 0; i < inv_freq.size(); ++i)
          inv_freq.at<float>(i) = 1.f / (scaling_factor->at<float>(i) *
                                         (std::pow(_base, float(i * 2) / float(dim))));
      }
      else {
        for (dim_t i = 0; i < inv_freq.size(); ++i)
          inv_freq.at<float>(i) = 1.f / std::pow(_base, float(i * 2) / float(dim));
        if (_scaling_type == RotaryScalingType::Llama3) {
          StorageView new_freqs = inv_freq.sync_copy();

          const auto factor = _scaling_factor;
          const float low_freq_factor = _rotary_low_freq_factor;
          const float high_freq_factor = _rotary_high_freq_factor;
          const auto old_context_len = static_cast< float >(_original_max_position_embeddings);

          float low_freq_wavelen = old_context_len / low_freq_factor;
          float high_freq_wavelen = old_context_len / high_freq_factor;
          for (dim_t i = 0; i < inv_freq.size(); ++i) {
            float wavelen = 2.0f * M_PI / inv_freq.at<float>(i);
            if (wavelen < high_freq_wavelen) {
              // High frequency: keep as copied from inv_freq.
            } else if (wavelen > low_freq_wavelen) {
              new_freqs.at<float>(i) /= factor;
            } else {
              // Interpolate between the scaled and unscaled frequency.
              float smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor);
              auto freq = inv_freq.at<float>(i);
              new_freqs.at<float>(i) = ((1 - smooth) * freq / factor + smooth * freq);
            }
          }
          inv_freq = std::move(new_freqs);
        }
      }
      if (inv_freq.device() != device)
        inv_freq = inv_freq.to(device);

      StorageView t({num_positions, 1});
      for (dim_t i = 0; i < t.size(); ++i)
        t.at<float>(i) = _scaling_type != RotaryScalingType::Linear ? i : float(i) / _scaling_factor;
      if (t.device() != device)
        t = t.to(device);

      StorageView freqs(device);
      ops::MatMul()(t, inv_freq, freqs);

      if (_interleave)
        freqs.expand_dims(-1);

      StorageView emb(device);
      ops::Concat(-1)({&freqs, &freqs}, emb);

      if (_interleave) {
        emb.reshape({num_positions, dim});
      }

      // Both tables are still float32 here (computed from float32 positions and frequencies).
      StorageView sin(device);
      ops::Sin()(emb, sin);
      StorageView cos(device);
      ops::Cos()(emb, cos);

      if (_original_max_position_embeddings != 0 && _max_position_embeddings != 0 && _scaling_type != RotaryScalingType::Llama3) {
        // Floating-point ratio: integer division would truncate e.g. 6144 / 4096 to 1 and
        // silently drop the scaling.
        const double scale = static_cast<double>(_max_position_embeddings)
                             / static_cast<double>(_original_max_position_embeddings);
        // For scale <= 1 the factor is exactly 1, so the multiply would be a no-op.
        if (scale > 1) {
          const float attention_factor = static_cast<float>(
            std::sqrt(1 + std::log(scale)
                      / std::log(static_cast<double>(_original_max_position_embeddings))));
          // Applied in float32, before the dtype conversion below, so the scalar's dtype
          // matches the tables it multiplies.
          const StorageView factor(attention_factor);
          ops::Mul()(sin, factor, sin);
          ops::Mul()(cos, factor, cos);
        }
      }

      if (sin.dtype() == dtype)
        _sin = std::move(sin);
      else
        _sin = sin.to(dtype);

      if (cos.dtype() == dtype)
        _cos = std::move(cos);
      else
        _cos = cos.to(dtype);
    }


    // Stores the ALiBi configuration. The bias table itself is built lazily by apply().
    //
    // - use_positive_positions: key positions run 0..L-1 instead of -(L-1)..0. Also forwarded
    //   to the alibi op.
    // - scale_alibi: when true, apply() multiplies the slopes by the `scale` it receives.
    // - num_initial_positions: growth step, in key positions, used when apply() rebuilds the table.
    Alibi::Alibi(const bool use_positive_positions, const bool scale_alibi, const dim_t num_initial_positions)
      : _use_positive_positions(use_positive_positions)
      , _num_initial_positions(num_initial_positions)
      , _scale_alibi(scale_alibi)
      , _alibi_op(use_positive_positions)
    {
    }

    // Adds the ALiBi position bias to the attention scores `x` in place.
    //
    // x.dim(1) is read as the number of heads and x.dim(-1) as the key length. The cached bias
    // table is rebuilt on x's device and dtype whenever the key length exceeds its current
    // length. When scale_alibi is set, the slopes are multiplied by `scale`. This only happens
    // at rebuild time: a different `scale` passed later while the table is cached has no effect.
    void Alibi::apply(StorageView& x, const float scale) {
      const dim_t cur_length = _alibi ? _alibi.dim(-1) : 0;
      const dim_t key_length = x.dim(-1);

      if (key_length > cur_length) {
        const dim_t num_heads = x.dim(1);
        // Grow by at least one step to amortize rebuilds, but never build a table shorter
        // than the current key length. A single call can jump past one step, e.g. a long
        // prompt on the first call.
        const dim_t new_length = std::max(key_length, cur_length + _num_initial_positions);
        _alibi = build_alibi(num_heads, new_length, _use_positive_positions, _scale_alibi ? scale : 1);
        _alibi.move_to(x.device(), x.dtype());
      }

      _alibi_op(x, _alibi, x);
    }

  }
}