// sbert 0.4.1
//
// Rust implementation of Sentence Bert (SBert).
// Documentation
use std::path::PathBuf;

use rust_tokenizers::tokenizer::BertTokenizer;
use rust_tokenizers::tokenizer::TruncationStrategy;
use tch::Tensor;

use crate::tokenizers::Tokenizer;
use crate::Error;

/// Tokenizer backend that wraps a `rust_tokenizers` [`BertTokenizer`].
///
/// Instances are built through the crate's `Tokenizer` trait implementation
/// for this type.
pub struct RustTokenizers {
    /// The wrapped BERT tokenizer that performs the actual encoding.
    tokenizer: BertTokenizer,
}

impl Tokenizer for RustTokenizers {
    /// Builds a tokenizer from the file at `path`.
    ///
    /// The path is converted to a string lossily before being handed to
    /// `BertTokenizer::from_file`, so any non-UTF-8 bytes in it are replaced
    /// with U+FFFD.
    ///
    /// # Errors
    ///
    /// Propagates (converted into the crate's [`Error`]) whatever
    /// `BertTokenizer::from_file` reports.
    fn new<P: Into<PathBuf>>(path: P) -> Result<Self, Error>
    where
        Self: Sized,
    {
        let path: PathBuf = path.into();
        // NOTE(review): the two `false` flags are presumably `lower_case` and
        // `strip_accents` — confirm against the rust_tokenizers version in use.
        let tokenizer = BertTokenizer::from_file(&path.to_string_lossy(), false, false)?;

        Ok(Self { tokenizer })
    }

    /// Not implemented for this backend: the input is ignored and an empty
    /// vector is always returned.
    fn pre_tokenize<S: AsRef<str>>(&self, _input: &[S]) -> Vec<Vec<String>> {
        Vec::new()
    }

    /// Encodes a batch of sentences into `(token_ids, attention_masks)`.
    ///
    /// Both vectors hold one tensor per encoded input. Every sequence is
    /// right-padded with `0` up to the length of the longest sequence in the
    /// batch. The matching mask holds `1` for each non-zero token id and `0`
    /// otherwise.
    ///
    /// Because the mask is derived from the ids themselves, a token whose id
    /// is `0` is masked out even when it is not padding.
    fn tokenize<S: AsRef<str>>(&self, input: &[S]) -> (Vec<Tensor>, Vec<Tensor>) {
        use rust_tokenizers::tokenizer::Tokenizer;

        // Second argument of `encode_list`; presumably the maximum number of
        // tokens kept per sequence — TODO confirm against rust_tokenizers docs.
        const MAX_SEQUENCE_LENGTH: usize = 128;

        let encoded = self.tokenizer.encode_list(
            input,
            MAX_SEQUENCE_LENGTH,
            &TruncationStrategy::LongestFirst,
            0,
        );

        // Length every sequence is padded to; `0` when nothing was encoded.
        let max_len = encoded
            .iter()
            .map(|item| item.token_ids.len())
            .max()
            .unwrap_or(0);

        encoded
            .into_iter()
            .map(|item| {
                let mut token_ids = item.token_ids;
                // `max_len` is the maximum over the batch, so this only ever grows.
                token_ids.resize(max_len, 0);

                let mask: Vec<i64> = token_ids.iter().map(|&id| i64::from(id != 0)).collect();

                (Tensor::from_slice(&token_ids), Tensor::from_slice(&mask))
            })
            .unzip()
    }
}