Skip to main content

ct2rs/sys/
types.rs

1// types.rs
2//
3// Copyright (c) 2023-2024 Junpei Kawamoto
4//
5// This software is released under the MIT License.
6//
7// http://opensource.org/licenses/mit-license.php
8
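//! Common structures shared with the C++ side through the `cxx` bridge, together with
//! helpers for converting them to and from plain Rust collections.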
10
11use std::fmt::{Debug, Formatter};
12
13pub use self::ffi::GenerationStepResult;
14pub(super) use self::ffi::{VecStr, VecString, VecUSize};
15
#[cxx::bridge]
pub(crate) mod ffi {
    /// The result for a single generation step.
    ///
    /// This struct is a Rust binding to the
    /// [`ctranslate2.GenerationStepResult`](https://opennmt.net/CTranslate2/python/ctranslate2.GenerationStepResult.html).
    // NOTE(review): this is a cxx shared struct, so its field list/order is mirrored on the
    // C++ side; do not reorder or rename fields without regenerating the bridge.
    #[derive(Clone, Debug)]
    pub struct GenerationStepResult {
        /// The decoding step.
        pub step: usize,
        /// The batch index.
        pub batch_id: usize,
        /// ID of the generated token.
        pub token_id: usize,
        /// Index of the hypothesis in the batch.
        pub hypothesis_id: usize,
        /// String value of the generated token.
        pub token: String,
        /// `true` if a score is given for this token (see `score`).
        pub has_score: bool,
        /// Score of the token.
        ///
        /// NOTE(review): presumably only meaningful when `has_score` is `true` — confirm
        /// against the C++ side.
        pub score: f32,
        /// Whether this step is the last generation step for this batch.
        pub is_last: bool,
    }

    /// Bridge wrapper around an owned list of strings (`Vec<String>`).
    ///
    /// Wrapping the vector in a shared struct lets nested lists (e.g. `Vec<VecString>`)
    /// be named in this bridge; presumably cxx cannot express `Vec<Vec<String>>` directly.
    #[derive(PartialEq, Clone)]
    pub struct VecString {
        /// The wrapped strings.
        v: Vec<String>,
    }

    /// Bridge wrapper around a list of borrowed string slices (`Vec<&str>`).
    ///
    /// `'a` is the lifetime of the borrowed strings.
    #[derive(PartialEq, Clone)]
    pub struct VecStr<'a> {
        /// The wrapped string slices.
        v: Vec<&'a str>,
    }

    /// Bridge wrapper around a list of indices/sizes (`Vec<usize>`).
    #[derive(PartialEq, Clone)]
    pub struct VecUSize {
        /// The wrapped values.
        v: Vec<usize>,
    }

    // Never constructed or used from Rust. It appears to exist only so that the bridge
    // mentions `Vec<VecString>`, `Vec<VecStr>` and `Vec<VecUSize>`, presumably so that cxx
    // generates `rust::Vec<T>` support for these element types on the C++ side.
    // TODO(review): confirm against the C++ code that consumes these vectors.
    struct _dummy<'a> {
        _vec_string: Vec<VecString>,
        _vec_str: Vec<VecStr<'a>>,
        _vec_usize: Vec<VecUSize>,
    }
}
63
64#[inline]
65pub(crate) fn vec_ffi_vecstr<T: AsRef<str>>(src: &[Vec<T>]) -> Vec<VecStr> {
66    src.iter()
67        .map(|v| VecStr {
68            v: v.iter().map(AsRef::as_ref).collect(),
69        })
70        .collect()
71}
72
73impl From<VecString> for Vec<String> {
74    fn from(value: VecString) -> Self {
75        value.v
76    }
77}
78
79impl From<Vec<String>> for VecString {
80    fn from(v: Vec<String>) -> Self {
81        Self { v }
82    }
83}
84
85impl Debug for VecString {
86    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
87        self.v.fmt(f)
88    }
89}
90
91impl<'a> Debug for VecStr<'a> {
92    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
93        self.v.fmt(f)
94    }
95}
96
97impl From<VecUSize> for Vec<usize> {
98    fn from(value: VecUSize) -> Self {
99        value.v
100    }
101}
102
103impl From<Vec<usize>> for VecUSize {
104    fn from(v: Vec<usize>) -> Self {
105        Self { v }
106    }
107}
108
109impl Debug for VecUSize {
110    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
111        self.v.fmt(f)
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::ffi::{VecString, VecUSize};
    use super::vec_ffi_vecstr;

    /// Every inner list is converted, in order, with its elements preserved.
    #[test]
    fn str_vectors() {
        let data = vec![vec!["a", "b", "c"], vec!["1", "2"]];
        let converted = vec_ffi_vecstr(&data);

        assert_eq!(converted.len(), data.len());
        for (wrapped, expected) in converted.iter().zip(data.iter()) {
            assert_eq!(wrapped.v.len(), expected.len());
            for (got, want) in wrapped.v.iter().zip(expected) {
                assert_eq!(got, want);
            }
        }
    }

    /// Empty inner lists stay empty but still produce one entry each.
    #[test]
    fn empty_inner_vectors() {
        let data: Vec<Vec<&str>> = vec![Vec::new(), Vec::new()];
        let converted = vec_ffi_vecstr(&data);

        assert_eq!(converted.len(), data.len());
        assert!(converted.iter().all(|wrapped| wrapped.v.is_empty()));
    }

    /// An empty outer slice produces an empty result.
    #[test]
    fn empty_vectors() {
        let data: Vec<Vec<&str>> = Vec::new();
        assert!(vec_ffi_vecstr(&data).is_empty());
    }

    #[test]
    fn from_vec_string() {
        let expected: Vec<String> = ["a", "b"].iter().map(|s| s.to_string()).collect();
        let wrapped = VecString {
            v: expected.clone(),
        };

        assert_eq!(Vec::<String>::from(wrapped), expected);
    }

    #[test]
    fn into_vec_string() {
        let strings: Vec<String> = ["a", "b"].iter().map(|s| s.to_string()).collect();
        let wrapped: VecString = strings.clone().into();

        assert_eq!(wrapped, VecString { v: strings });
    }

    #[test]
    fn from_vec_usize() {
        let expected: Vec<usize> = (1..=3).collect();
        let wrapped = VecUSize {
            v: expected.clone(),
        };

        assert_eq!(Vec::<usize>::from(wrapped), expected);
    }

    #[test]
    fn into_vec_usize() {
        let values: Vec<usize> = (1..=3).collect();
        let wrapped: VecUSize = values.clone().into();

        assert_eq!(wrapped, VecUSize { v: values });
    }
}
187}