divvunspell 1.0.0-beta.3

Spell checking library for ZHFST/BHFST spellers, with case handling and tokenization support.
Documentation
use std::path::Path;
use std::sync::Arc;

use parking_lot::Mutex;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{LocalResource, Resource};
use rust_bert::RustBertError;

use super::Predictor;

pub struct Gpt2Predictor {
    model: Mutex<TextGenerationModel>,
}

impl Gpt2Predictor {
    pub fn new(model_path: &Path) -> Result<Self, RustBertError> {
        let config_resource = Resource::Local(LocalResource {
            local_path: model_path.join("config.json"),
        });
        let vocab_resource = Resource::Local(LocalResource {
            local_path: model_path.join("vocab.json"),
        });
        let merges_resource = Resource::Local(LocalResource {
            local_path: model_path.join("merges.txt"),
        });
        let weights_resource = Resource::Local(LocalResource {
            local_path: model_path.join("rust_model.ot"),
        });

        let generate_config = TextGenerationConfig {
            model_resource: weights_resource,
            vocab_resource: vocab_resource,
            merges_resource: merges_resource,
            config_resource: config_resource,
            model_type: ModelType::GPT2,
            max_length: 24,
            do_sample: true,
            num_beams: 1,
            temperature: 1.1,
            num_return_sequences: 1,
            ..Default::default()
        };
        let model = Mutex::new(TextGenerationModel::new(generate_config)?);
        Ok(Self { model })
    }

    fn generate(&self, raw_input: &str) -> Vec<String> {
        let guard = self.model.lock();
        guard.generate(&[raw_input], None)
    }
}

impl Predictor for Gpt2Predictor {
    fn predict(self: Arc<Self>, raw_input: &str) -> Vec<String> {
        self.generate(raw_input)
    }
}