tokenizers/decoders/
sequence.rs1use crate::decoders::DecoderWrapper;
2use crate::tokenizer::{Decoder, Result};
3use crate::utils::macro_rules_attribute;
4use serde::{Deserialize, Serialize};
5
6#[derive(Clone, Debug)]
7#[macro_rules_attribute(impl_serde_type!)]
8pub struct Sequence {
9 decoders: Vec<DecoderWrapper>,
10}
11
12impl Sequence {
13 pub fn new(decoders: Vec<DecoderWrapper>) -> Self {
14 Self { decoders }
15 }
16
17 pub fn get_decoders(&self) -> &[DecoderWrapper] {
18 &self.decoders
19 }
20
21 pub fn get_decoders_mut(&mut self) -> &mut [DecoderWrapper] {
22 &mut self.decoders
23 }
24}
25
26impl Decoder for Sequence {
27 fn decode_chain(&self, mut tokens: Vec<String>) -> Result<Vec<String>> {
28 for decoder in &self.decoders {
29 tokens = decoder.decode_chain(tokens)?;
30 }
31 Ok(tokens)
32 }
33}
34
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoders::ctc::CTC;
    use crate::pre_tokenizers::metaspace::Metaspace;

    /// Turns string literals into the owned token vector that `Decoder` methods take.
    fn to_tokens(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn sequence_basic() {
        let decoders = vec![
            DecoderWrapper::CTC(CTC::default()),
            DecoderWrapper::Metaspace(Metaspace::default()),
        ];
        let decoder = Sequence::new(decoders);
        let tokens = to_tokens(&["▁", "▁", "H", "H", "i", "i", "▁", "y", "o", "u"]);
        let out_tokens = decoder.decode(tokens).unwrap();
        assert_eq!(out_tokens, "Hi you");
    }

    #[test]
    fn sequence_empty_is_identity() {
        // A sequence with no inner decoders must hand the tokens back untouched.
        let decoder = Sequence::new(vec![]);
        assert!(decoder.get_decoders().is_empty());
        let tokens = to_tokens(&["▁Hello", "▁", "world"]);
        assert_eq!(decoder.decode_chain(tokens.clone()).unwrap(), tokens);
    }
}