sentencepiece_model/
lib.rs

1//! SentencePiece model parser generated from the SentencePiece protobuf definition.
2//!
//! See [`SentencePieceModel`], the entry point for parsing and accessing SentencePiece models.
4//!
5//! ```rust
6//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
7//! use sentencepiece_model::SentencePieceModel;
8//!
9//! let model = SentencePieceModel::from_file("tests/t5-spiece.model")?;
10//! assert_eq!(model.pieces.len(), 32000);
11//! assert_eq!(model.trainer().unwrap().unk_id(), 2);
12//! # Ok(())
13//! # }
14//! ```
15
16#![cfg_attr(not(feature = "std"), no_std)]
17#![cfg_attr(docsrs, feature(doc_auto_cfg))]
18
19extern crate alloc;
20
21use core::ops::Deref;
22
// Message types generated from the SentencePiece protobuf definition.
//
// The Rust source is written into `OUT_DIR` at build time (presumably by a
// build script using prost's code generator — not visible from this file) and
// spliced in verbatim with `include!`. The public types are re-exported below.
mod proto {
    include!(concat!(env!("OUT_DIR"), "/sentencepiece.rs"));
}
26
27pub use proto::model_proto::sentence_piece::Type;
28pub use proto::model_proto::SentencePiece;
29pub use proto::self_test_data::Sample;
30pub use proto::trainer_spec::ModelType;
31pub use proto::{ModelProto, NormalizerSpec, SelfTestData, TrainerSpec};
32
33use prost::bytes::Buf;
34use prost::Message;
35
36/// SentencePiece model.
37/// Provides access to the underlying `sentencepiece` model.
38#[derive(Clone, PartialEq, Debug)]
39#[repr(transparent)]
40pub struct SentencePieceModel {
41    model: ModelProto,
42}
43impl SentencePieceModel {
44    /// Parses a `SentencePieceModel` from a byte slice.
45    #[inline]
46    pub fn from_slice(bytes: impl AsRef<[u8]>) -> Result<Self, prost::DecodeError> {
47        let model = ModelProto::decode(bytes.as_ref())?;
48        Ok(Self { model })
49    }
50
51    /// Parses a `SentencePieceModel` from a reader.
52    #[inline]
53    pub fn from_reader<R: Buf>(reader: &mut R) -> Result<Self, prost::DecodeError> {
54        let model = ModelProto::decode(reader)?;
55        Ok(Self { model })
56    }
57
58    /// Parses a `SentencePieceModel` from a file.
59    #[cfg(feature = "std")]
60    #[inline]
61    pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, std::io::Error> {
62        let bytes = std::fs::read(path)?;
63        Self::from_slice(bytes).map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
64    }
65
66    /// Returns the underlying `ModelProto`.
67    #[inline]
68    pub fn model(&self) -> &ModelProto {
69        &self.model
70    }
71
72    /// Returns the `SentencePiece` list of the model.
73    #[inline]
74    pub fn pieces(&self) -> &[SentencePiece] {
75        &self.model.pieces
76    }
77
78    /// Returns the `TrainerSpec` of the model if it exists.
79    #[inline]
80    pub fn trainer(&self) -> Option<&TrainerSpec> {
81        self.model.trainer_spec.as_ref()
82    }
83
84    /// Returns the `NormalizerSpec` of the model if it exists.
85    #[inline]
86    pub fn normalizer(&self) -> Option<&NormalizerSpec> {
87        self.model.normalizer_spec.as_ref()
88    }
89
90    /// Returns the `DenormalizerSpec` of the model if it exists.
91    #[inline]
92    pub fn denormalizer(&self) -> Option<&NormalizerSpec> {
93        self.model.denormalizer_spec.as_ref()
94    }
95
96    /// Returns the `SelfTestData` of the model if it exists.
97    #[inline]
98    pub fn self_test_data(&self) -> Option<&SelfTestData> {
99        self.model.self_test_data.as_ref()
100    }
101}
102impl Deref for SentencePieceModel {
103    type Target = ModelProto;
104
105    #[inline]
106    fn deref(&self) -> &Self::Target {
107        &self.model
108    }
109}