pub fn load_model(
path_to_check_points_folder: String,
) -> Result<(MPNetModel, Tokenizer, MPNetPooler)>
Loads an MPNet model, its tokenizer, and its pooler from the specified folder.
Given a path to a folder containing the checkpoint files, this function builds the paths to the weights file and the tokenizer file, checks that both exist, and then loads the weights, the tokenizer, and the model configuration.
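The path-resolution step described above can be sketched with the standard library alone. This is a minimal illustration, not the crate's code: the file names model.safetensors and tokenizer.json and the helper resolve_checkpoint_files are hypothetical assumptions, so match them to the files your checkpoint folder actually contains.

```rust
use std::io;
use std::path::{Path, PathBuf};

// Hypothetical file names (assumption): the names load_model really uses
// are not documented here.
const WEIGHTS_FILE: &str = "model.safetensors";
const TOKENIZER_FILE: &str = "tokenizer.json";

/// Hypothetical helper: joins the expected file names onto `folder` and
/// returns a NotFound error for the first file that does not exist.
pub fn resolve_checkpoint_files(folder: &str) -> io::Result<(PathBuf, PathBuf)> {
    let base = Path::new(folder);
    let weights = base.join(WEIGHTS_FILE);
    let tokenizer = base.join(TOKENIZER_FILE);

    // Check existence before attempting to load anything.
    for path in [&weights, &tokenizer] {
        if !path.exists() {
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!("missing checkpoint file: {}", path.display()),
            ));
        }
    }
    Ok((weights, tokenizer))
}
```

Failing early with a NotFound error that names the missing file gives a clearer message than letting the weight or tokenizer loader fail later on a bad path.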
§Model Structure
The MPNetModel structure is as follows:
MPNetModel(
  (embeddings): MPNetEmbeddings(
    (word_embeddings): Embedding(30527, 768, padding_idx=1)
    (position_embeddings): Embedding(512, 768, padding_idx=1)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): MPNetEncoder(
    (layer): ModuleList(
      (0-11): 12 x MPNetLayer(
        (attention): MPNetAttention(
          (attn): MPNetSelfAttention(
            (q): Linear(in_features=768, out_features=768, bias=True)
            (k): Linear(in_features=768, out_features=768, bias=True)
            (v): Linear(in_features=768, out_features=768, bias=True)
            (o): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (intermediate): MPNetIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): MPNetOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (relative_attention_bias): Embedding(32, 12)
  )
  (pooler): MPNetPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
§Arguments
path_to_check_points_folder - A String holding the path to the folder containing the model and tokenizer files.
§Returns
Ok((model, tokenizer, pooler)) - A tuple containing the loaded model, tokenizer, and pooler.
Err - An error if the weights or tokenizer file does not exist at the expected path, or if loading the weights, tokenizer, or model configuration fails.
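The returned pooler corresponds to the (pooler) entry in the structure listing: a 768 x 768 linear layer with bias followed by Tanh. The sketch below shows that computation on plain vectors. It assumes, as in the Hugging Face transformers MPNetPooler, that the pooler reads only the first token's hidden state; PoolerSketch is a hypothetical type and not the patentpick MPNetPooler API.

```rust
/// Hypothetical stand-in for the pooler: Linear(d -> d, bias=True) then Tanh.
/// Not the patentpick MPNetPooler type.
pub struct PoolerSketch {
    /// Row-major weights: weight[o][i] maps input feature i to output o.
    pub weight: Vec<Vec<f32>>,
    pub bias: Vec<f32>,
}

impl PoolerSketch {
    /// Takes one hidden-state vector per token and pools the sequence by
    /// applying the dense layer and tanh to the first token only
    /// (assumption, following the Hugging Face implementation).
    pub fn forward(&self, hidden_states: &[Vec<f32>]) -> Vec<f32> {
        let first = &hidden_states[0];
        self.weight
            .iter()
            .zip(&self.bias)
            .map(|(row, b)| {
                // Dot product of one weight row with the first token, plus bias.
                let z: f32 = row.iter().zip(first).map(|(w, x)| w * x).sum::<f32>() + b;
                z.tanh()
            })
            .collect()
    }
}
```

Because only the first token feeds the pooler, the other tokens' hidden states do not affect its output; sentence-embedding pipelines often apply mean pooling over all tokens instead, which is a separate choice from this layer.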
§How to use
use patentpick::mpnet::load_model;
// load_model takes an owned String, so convert the string literal.
let (model, tokenizer, pooler) = load_model("/path/to/model/and/tokenizer".to_string()).unwrap();
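The shapes in the Model Structure listing above also determine how many parameters the loaded weights hold, which can serve as a sanity check when a checkpoint fails to load. The helper functions below are hypothetical; the count covers only the entries printed in the listing (embeddings, 12 encoder layers, the relative attention bias table, and the pooler) and comes to 109,484,928 parameters.

```rust
/// Parameters of a Linear layer with bias: one weight per (in, out) pair
/// plus one bias per output.
pub fn linear(in_f: u64, out_f: u64) -> u64 {
    in_f * out_f + out_f
}

/// Parameters of a LayerNorm with elementwise_affine=True (weight and bias).
pub fn layer_norm(dim: u64) -> u64 {
    2 * dim
}

/// Parameters of an Embedding table.
pub fn embedding(num: u64, dim: u64) -> u64 {
    num * dim
}

/// One MPNetLayer: q, k, v, o projections, the attention LayerNorm, the
/// hidden -> inter -> hidden feed-forward block, and the output LayerNorm.
/// Dropout and GELU carry no parameters.
pub fn mpnet_layer(hidden: u64, inter: u64) -> u64 {
    4 * linear(hidden, hidden)
        + layer_norm(hidden)
        + linear(hidden, inter)
        + linear(inter, hidden)
        + layer_norm(hidden)
}

/// Total over every entry in the structure listing.
pub fn mpnet_total() -> u64 {
    let embeddings = embedding(30527, 768) + embedding(512, 768) + layer_norm(768);
    let encoder = 12 * mpnet_layer(768, 3072) + embedding(32, 12);
    let pooler = linear(768, 768);
    embeddings + encoder + pooler
}
```

Each encoder layer contributes 7,087,872 parameters, so the 12 layers account for roughly 85 million of the total, with the word embedding table (about 23.4 million) as the next largest component.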