Function load_model

pub fn load_model(
    path_to_check_points_folder: String,
) -> Result<(MPNetModel, Tokenizer, MPNetPooler)>

Loads a model, tokenizer, and pooler from the specified folder.

This function takes a path to a folder containing the model and tokenizer files. It constructs the paths to the weight and tokenizer files, ensures they exist, and then loads the weights, tokenizer, and model configuration.
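The path construction and existence check described above can be sketched as follows. This is a minimal illustration using only the standard library, not the crate's actual code: the file names `model.safetensors` and `tokenizer.json` and the helper `resolve_checkpoint_files` are assumptions made for the example.

```rust
use std::io;
use std::path::{Path, PathBuf};

// Hypothetical file names: the names the crate actually expects are not
// stated in this documentation.
const WEIGHTS_FILE: &str = "model.safetensors";
const TOKENIZER_FILE: &str = "tokenizer.json";

/// Builds the weight and tokenizer paths under `folder` and returns them
/// only if both files exist; otherwise returns a `NotFound` error naming
/// the first missing file.
fn resolve_checkpoint_files(folder: &str) -> io::Result<(PathBuf, PathBuf)> {
    let root = Path::new(folder);
    let weights = root.join(WEIGHTS_FILE);
    let tokenizer = root.join(TOKENIZER_FILE);

    for file in [&weights, &tokenizer] {
        if !file.is_file() {
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!("missing checkpoint file: {}", file.display()),
            ));
        }
    }
    Ok((weights, tokenizer))
}
```

Checking both paths before any loading starts means a misconfigured folder fails early with a message that names the missing file, rather than partway through deserializing the weights.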

§Model Structure

The MPNetModel structure is as follows:

MPNetModel(
    (embeddings): MPNetEmbeddings(
        (word_embeddings): Embedding(30527, 768, padding_idx=1)
        (position_embeddings): Embedding(512, 768, padding_idx=1)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): MPNetEncoder(
        (layer): ModuleList(
            (0-11): 12 x MPNetLayer(
                (attention): MPNetAttention(
                    (attn): MPNetSelfAttention(
                        (q): Linear(in_features=768, out_features=768, bias=True)
                        (k): Linear(in_features=768, out_features=768, bias=True)
                        (v): Linear(in_features=768, out_features=768, bias=True)
                        (o): Linear(in_features=768, out_features=768, bias=True)
                        (dropout): Dropout(p=0.1, inplace=False)
                    )
                    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                )
                (intermediate): MPNetIntermediate(
                    (dense): Linear(in_features=768, out_features=3072, bias=True)
                    (intermediate_act_fn): GELUActivation()
                )
                (output): MPNetOutput(
                    (dense): Linear(in_features=3072, out_features=768, bias=True)
                    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                )
            )
        )
        (relative_attention_bias): Embedding(32, 12)
    )
    (pooler): MPNetPooler(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (activation): Tanh()
    )
)
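
As a sanity check on the structure above, the number of parameters can be tallied directly from the printed layer shapes. The sketch below uses only the dimensions listed in the structure; the helper names (`linear`, `layer_norm`, `embedding`, `mpnet_parameter_count`) are hypothetical and not part of the crate.

```rust
/// Parameters of a Linear layer with bias: weight matrix plus bias vector.
fn linear(in_features: u64, out_features: u64) -> u64 {
    in_features * out_features + out_features
}

/// Parameters of a LayerNorm with elementwise affine: weight plus bias.
fn layer_norm(dim: u64) -> u64 {
    2 * dim
}

/// Parameters of an Embedding table.
fn embedding(rows: u64, dim: u64) -> u64 {
    rows * dim
}

/// Sums the parameters of every module in the printed structure.
fn mpnet_parameter_count() -> u64 {
    let hidden: u64 = 768;
    let ffn: u64 = 3072;
    let layers: u64 = 12;

    // word_embeddings + position_embeddings + LayerNorm
    let embeddings = embedding(30527, hidden) + embedding(512, hidden) + layer_norm(hidden);
    // q, k, v, o projections + attention LayerNorm
    let attention = 4 * linear(hidden, hidden) + layer_norm(hidden);
    // intermediate dense + output dense + output LayerNorm
    let feed_forward = linear(hidden, ffn) + linear(ffn, hidden) + layer_norm(hidden);
    // 12 identical layers + the relative_attention_bias table
    let encoder = layers * (attention + feed_forward) + embedding(32, 12);
    // pooler dense (Tanh has no parameters)
    let pooler = linear(hidden, hidden);

    embeddings + encoder + pooler
}
```

With the shapes as printed, this comes to 109,484,928 parameters, of which the word embedding table alone accounts for about 23.4 million.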

§Arguments

  • path_to_check_points_folder - A string that holds the path to the folder containing the model and tokenizer files.

§Returns

  • Ok((model, tokenizer, pooler)) - A tuple containing the loaded model, tokenizer, and pooler.
  • Err - An error if the specified paths do not exist, or if there is an issue loading the weights, tokenizer, or model configuration.

§How to use

use patentpick::mpnet::load_model;
// `load_model` takes an owned `String`, so convert the string literal first.
let (model, tokenizer, pooler) = load_model("/path/to/model/and/tokenizer".to_string()).unwrap();