Skip to main content

ct2rs/sys/
config.rs

1// config.rs
2//
3// Copyright (c) 2023-2024 Junpei Kawamoto
4//
5// This software is released under the MIT License.
6//
7// http://opensource.org/licenses/mit-license.php
8
9//! Configs and associated enums.
10
11use std::fmt::{Debug, Display, Formatter};
12
13use cxx::UniquePtr;
14
15pub use self::ffi::{
16    get_device_count, get_log_level, get_random_seed, set_log_level, set_random_seed, BatchType,
17    ComputeType, Device, LogLevel,
18};
19
#[cxx::bridge]
pub(crate) mod ffi {
    // FFI surface shared with the C++ side. Every enum below is also re-declared as a
    // `type …;` in the `unsafe extern "C++"` block, so variant order and discriminants must stay
    // in sync with the matching C++ enums (NOTE(review): presumably defined in or reachable
    // from `ct2rs/include/config.h` — confirm before reordering or adding variants).

    /// Represents the computing device to be used.
    ///
    /// This enum is a Rust binding to the
    /// [`ctranslate2.Device`](https://opennmt.net/CTranslate2/python/ctranslate2.Device.html),
    /// which can take one of the following two values:
    /// - [`CPU`][Device::CPU]
    /// - [`CUDA`][Device::CUDA]
    ///
    /// The default setting for this enum is [`CPU`][Device::CPU].
    ///
    /// # Examples
    ///
    /// Example of creating a default `Device`:
    ///
    /// ```
    /// use ct2rs::sys::Device;
    ///
    /// let device = Device::default();
    /// # assert_eq!(device, Device::CPU);
    /// ```
    ///
    #[derive(Copy, Clone, Debug)]
    #[repr(i32)]
    enum Device {
        /// Use CPU.
        CPU,
        /// Use GPU (CUDA).
        CUDA,
    }

    /// Model computation type.
    ///
    /// This enum can take one of the following values:
    /// - [`DEFAULT`][ComputeType::DEFAULT]
    /// - [`AUTO`][ComputeType::AUTO]
    /// - [`FLOAT32`][ComputeType::FLOAT32]
    /// - [`INT8`][ComputeType::INT8]
    /// - [`INT8_FLOAT32`][ComputeType::INT8_FLOAT32]
    /// - [`INT8_FLOAT16`][ComputeType::INT8_FLOAT16]
    /// - [`INT8_BFLOAT16`][ComputeType::INT8_BFLOAT16]
    /// - [`INT16`][ComputeType::INT16]
    /// - [`FLOAT16`][ComputeType::FLOAT16]
    /// - [`BFLOAT16`][ComputeType::BFLOAT16]
    ///
    /// The default setting for this enum is [`DEFAULT`][ComputeType::DEFAULT], meaning that unless
    /// specified otherwise, the computation will proceed with the same quantization level as was
    /// used during the model's conversion.
    ///
    /// See also:
    /// [Quantization](https://opennmt.net/CTranslate2/quantization.html#quantize-on-model-loading)
    /// for more details on how quantization affects computation and how it can be applied during
    /// model loading.
    ///
    /// # Examples
    ///
    /// Example of creating a default `ComputeType`:
    ///
    /// ```
    /// use ct2rs::sys::ComputeType;
    ///
    /// let compute_type = ComputeType::default();
    /// # assert_eq!(compute_type, ComputeType::DEFAULT);
    /// ```
    ///
    #[derive(Copy, Clone, Debug)]
    #[repr(i32)]
    enum ComputeType {
        /// Keeps the same quantization that was used during model conversion.
        DEFAULT,
        /// Uses the fastest computation type that is supported on this system and device.
        AUTO,
        /// Utilizes 32-bit floating-point precision.
        FLOAT32,
        /// Uses 8-bit integer precision.
        INT8,
        /// Combines 8-bit integer quantization with 32-bit floating-point computation.
        INT8_FLOAT32,
        /// Combines 8-bit integer quantization with 16-bit floating-point computation.
        INT8_FLOAT16,
        /// Combines 8-bit integer quantization with Brain Floating Point (16-bit) computation.
        INT8_BFLOAT16,
        /// Uses 16-bit integer precision.
        INT16,
        /// Utilizes 16-bit floating-point precision (half precision).
        FLOAT16,
        /// Uses Brain Floating Point (16-bit) precision.
        BFLOAT16,
    }

    /// Specifies how the `max_batch_size` should be calculated.
    ///
    /// This enum can take one of the following two values:
    /// - [`Examples`][BatchType::Examples]
    /// - [`Tokens`][BatchType::Tokens]
    ///
    /// The default setting for this enum is [`Examples`][BatchType::Examples].
    ///
    /// # Examples
    ///
    /// Example of creating a default `BatchType`:
    ///
    /// ```
    /// use ct2rs::sys::BatchType;
    ///
    /// let batch_type = BatchType::default();
    /// # assert_eq!(batch_type, BatchType::Examples);
    /// ```
    #[derive(Copy, Clone, Debug)]
    #[repr(i32)]
    enum BatchType {
        /// The batch size is calculated based on the number of individual examples.
        Examples,
        /// The batch size is calculated based on the total number of tokens across all examples.
        Tokens,
    }

    /// Logging level.
    ///
    /// This enum can take one of the following values:
    /// - [`Off`][LogLevel::Off]
    /// - [`Critical`][LogLevel::Critical]
    /// - [`Error`][LogLevel::Error]
    /// - [`Warning`][LogLevel::Warning]
    /// - [`Info`][LogLevel::Info]
    /// - [`Debug`][LogLevel::Debug]
    /// - [`Trace`][LogLevel::Trace]
    ///
    /// The default setting for this enum is [`Warning`][LogLevel::Warning].
    ///
    /// # Examples
    ///
    /// Example of creating a default `LogLevel`:
    ///
    /// ```
    /// use ct2rs::sys::LogLevel;
    ///
    /// let log_level = LogLevel::default();
    /// # assert_eq!(log_level, LogLevel::Warning);
    /// ```
    #[derive(Copy, Clone, Debug)]
    #[repr(i32)]
    enum LogLevel {
        // Discriminants are explicit, running from -3 (`Off`) to 3 (`Trace`).
        // NOTE(review): presumably they must equal CTranslate2's C++ log-level values and
        // larger values mean more verbose output — confirm against the C++ enum.
        /// The `off` level (discriminant `-3`).
        Off = -3,
        /// The `critical` level (discriminant `-2`).
        Critical = -2,
        /// The `error` level (discriminant `-1`).
        Error = -1,
        /// The `warning` level (discriminant `0`); this is the `Default` value.
        Warning = 0,
        /// The `info` level (discriminant `1`).
        Info = 1,
        /// The `debug` level (discriminant `2`).
        Debug = 2,
        /// The `trace` level (discriminant `3`).
        Trace = 3,
    }

    unsafe extern "C++" {
        include!("ct2rs/include/config.h");

        // Re-declaring the shared enums here makes cxx bind them to C++ enums supplied via
        // the included header (cxx "extern enums") instead of emitting new C++ definitions.
        type Device;
        type ComputeType;
        // Opaque C++ type; Rust only ever holds it behind a `UniquePtr`.
        type ReplicaPoolConfig;
        pub type BatchType;

        /// Creates a C++ `ReplicaPoolConfig` from the per-replica thread count, the queued
        /// batch limit, and the CPU core offset, returned behind a `UniquePtr`.
        fn replica_pool_config(
            num_threads_per_replica: usize,
            max_queued_batches: i32,
            cpu_core_offset: i32,
        ) -> UniquePtr<ReplicaPoolConfig>;

        // Opaque C++ configuration object produced by `config`.
        pub type Config;

        /// Creates a C++ `Config` for the given device, computation type, device indices and
        /// tensor-parallel flag. `replica_pool_config` is passed by value as a `UniquePtr`, so
        /// ownership of it moves into this call.
        fn config(
            device: Device,
            compute_type: ComputeType,
            device_indices: &[i32],
            tensor_parallel: bool,
            replica_pool_config: UniquePtr<ReplicaPoolConfig>,
        ) -> UniquePtr<Config>;

        /// Returns the number of devices of the given kind.
        fn get_device_count(device: Device) -> i32;

        type LogLevel;

        /// Sets the CTranslate2 logging level.
        ///
        /// # Examples
        /// The following example sets the log level to `Debug`.
        /// ```
        /// use ct2rs::sys::{LogLevel, set_log_level};
        ///
        /// set_log_level(LogLevel::Debug);
        /// ```
        fn set_log_level(level: LogLevel);

        /// Returns the current logging level.
        fn get_log_level() -> LogLevel;

        /// Sets the seed of random generators.
        ///
        /// # Examples
        /// The following example sets the random seed to `12345`.
        /// ```
        /// use ct2rs::sys::set_random_seed;
        ///
        /// set_random_seed(12345);
        /// ```
        fn set_random_seed(seed: u32);

        /// Returns the current seed of random generators.
        fn get_random_seed() -> u32;
    }
}
231
232impl Default for Device {
233    fn default() -> Self {
234        Self::CPU
235    }
236}
237
238impl Display for Device {
239    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
240        match *self {
241            Device::CPU => write!(f, "CPU"),
242            Device::CUDA => write!(f, "CUDA"),
243            _ => write!(f, "Unknown"),
244        }
245    }
246}
247
248impl Default for ComputeType {
249    fn default() -> Self {
250        Self::DEFAULT
251    }
252}
253
254impl Display for ComputeType {
255    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256        match *self {
257            ComputeType::DEFAULT => write!(f, "default"),
258            ComputeType::AUTO => write!(f, "auto"),
259            ComputeType::FLOAT32 => write!(f, "float32"),
260            ComputeType::INT8 => write!(f, "int8"),
261            ComputeType::INT8_FLOAT32 => write!(f, "int8_float32"),
262            ComputeType::INT8_FLOAT16 => write!(f, "int8_float16"),
263            ComputeType::INT8_BFLOAT16 => write!(f, "int8_bfloat16"),
264            ComputeType::INT16 => write!(f, "int16"),
265            ComputeType::FLOAT16 => write!(f, "float16"),
266            ComputeType::BFLOAT16 => write!(f, "bfloat16"),
267            _ => write!(f, "unknown"),
268        }
269    }
270}
271
272impl Default for BatchType {
273    fn default() -> Self {
274        Self::Examples
275    }
276}
277
278impl Display for BatchType {
279    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
280        match *self {
281            BatchType::Examples => write!(f, "examples"),
282            BatchType::Tokens => write!(f, "tokens"),
283            _ => write!(f, "unknown"),
284        }
285    }
286}
287
288impl Default for LogLevel {
289    fn default() -> Self {
290        Self::Warning
291    }
292}
293
294impl Display for LogLevel {
295    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
296        match *self {
297            LogLevel::Off => write!(f, "off"),
298            LogLevel::Critical => write!(f, "critical"),
299            LogLevel::Error => write!(f, "error"),
300            LogLevel::Warning => write!(f, "warning"),
301            LogLevel::Info => write!(f, "info"),
302            LogLevel::Debug => write!(f, "debug"),
303            LogLevel::Trace => write!(f, "trace"),
304            _ => write!(f, "unknown"),
305        }
306    }
307}
308
309/// The `Config` structure holds the configuration settings for CTranslator2.
310///
311/// # Examples
312///
313/// Example of creating a default `Config`:
314///
315/// ```
316/// use ct2rs::sys::{ComputeType, Config, Device};
317///
318/// let config = Config::default();
319/// # assert_eq!(config.device, Device::default());
320/// # assert_eq!(config.compute_type, ComputeType::default());
321/// # assert_eq!(config.device_indices, vec![0]);
322/// # assert_eq!(config.tensor_parallel, false);
323/// # assert_eq!(config.num_threads_per_replica, 0);
324/// # assert_eq!(config.max_queued_batches, 0);
325/// # assert_eq!(config.cpu_core_offset, -1);
326/// ```
327#[derive(PartialEq, Eq, Clone, Debug)]
328pub struct Config {
329    /// Device to use.
330    pub device: Device,
331    /// Model computation type.
332    pub compute_type: ComputeType,
333    /// Device IDs where to place this generator on. (default: `vec![0]`)
334    pub device_indices: Vec<i32>,
335    /// Run model with tensor parallel mode. (default: false)
336    pub tensor_parallel: bool,
337    /// Number of threads per translator/generator (0 to use a default value). (default: 0)
338    pub num_threads_per_replica: usize,
339    /// Maximum numbers of batches in the queue (-1 for unlimited, 0 for an automatic value).
340    /// When the queue is full, future requests will block until a free slot is available.
341    /// (default: 0)
342    pub max_queued_batches: i32,
343    /// (default: -1)
344    pub cpu_core_offset: i32,
345}
346
347impl Default for Config {
348    fn default() -> Self {
349        Self {
350            device: Default::default(),
351            compute_type: Default::default(),
352            device_indices: vec![0],
353            tensor_parallel: false,
354            num_threads_per_replica: 0,
355            max_queued_batches: 0,
356            cpu_core_offset: -1,
357        }
358    }
359}
360
361impl Config {
362    pub(crate) fn to_ffi(&self) -> UniquePtr<ffi::Config> {
363        ffi::config(
364            self.device,
365            self.compute_type,
366            self.device_indices.as_slice(),
367            false,
368            ffi::replica_pool_config(
369                self.num_threads_per_replica,
370                self.max_queued_batches,
371                self.cpu_core_offset,
372            ),
373        )
374    }
375}
376
#[cfg(test)]
mod tests {
    use rand::random;

    use super::{
        get_device_count, get_log_level, get_random_seed, set_log_level, set_random_seed,
        BatchType, ComputeType, Config, Device, LogLevel,
    };

    #[test]
    fn test_device_display() {
        assert_eq!(format!("{}", Device::CPU), "CPU");
        assert_eq!(format!("{}", Device::CUDA), "CUDA");
    }

    #[test]
    fn test_compute_type_display() {
        assert_eq!(format!("{}", ComputeType::DEFAULT), "default");
        assert_eq!(format!("{}", ComputeType::AUTO), "auto");
        assert_eq!(format!("{}", ComputeType::FLOAT32), "float32");
        assert_eq!(format!("{}", ComputeType::INT8), "int8");
        assert_eq!(format!("{}", ComputeType::INT8_FLOAT32), "int8_float32");
        assert_eq!(format!("{}", ComputeType::INT8_FLOAT16), "int8_float16");
        assert_eq!(format!("{}", ComputeType::INT8_BFLOAT16), "int8_bfloat16");
        assert_eq!(format!("{}", ComputeType::INT16), "int16");
        assert_eq!(format!("{}", ComputeType::FLOAT16), "float16");
        assert_eq!(format!("{}", ComputeType::BFLOAT16), "bfloat16");
    }

    #[test]
    fn test_batch_type_display() {
        assert_eq!(format!("{}", BatchType::Examples), "examples");
        assert_eq!(format!("{}", BatchType::Tokens), "tokens");
    }

    #[test]
    fn test_log_level_display() {
        assert_eq!(format!("{}", LogLevel::Off), "off");
        assert_eq!(format!("{}", LogLevel::Critical), "critical");
        assert_eq!(format!("{}", LogLevel::Error), "error");
        assert_eq!(format!("{}", LogLevel::Warning), "warning");
        assert_eq!(format!("{}", LogLevel::Info), "info");
        assert_eq!(format!("{}", LogLevel::Debug), "debug");
        assert_eq!(format!("{}", LogLevel::Trace), "trace");
    }

    #[test]
    fn test_enum_defaults() {
        assert_eq!(Device::default(), Device::CPU);
        assert_eq!(ComputeType::default(), ComputeType::DEFAULT);
        assert_eq!(BatchType::default(), BatchType::Examples);
    }

    #[test]
    fn test_config_to_ffi() {
        let config = Config::default();
        let res = config.to_ffi();

        assert!(!res.is_null());
    }

    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_get_device_count() {
        assert_eq!(get_device_count(Device::CPU), 1);
        assert_eq!(get_device_count(Device::CUDA), 0);
    }

    #[test]
    fn test_default_log_level() {
        assert_eq!(LogLevel::default(), LogLevel::Warning);
    }

    #[test]
    fn test_log_level() {
        // `set_log_level` takes no handle, so the level is shared global state. The test
        // harness runs tests in parallel, so restore the level seen on entry instead of
        // leaving it at whatever the loop set last (`Trace`).
        let original = get_log_level();
        for l in [
            LogLevel::Off,
            LogLevel::Critical,
            LogLevel::Error,
            LogLevel::Warning,
            LogLevel::Info,
            LogLevel::Debug,
            LogLevel::Trace,
        ] {
            set_log_level(l);
            assert_eq!(get_log_level(), l);
        }
        set_log_level(original);
    }

    #[test]
    fn test_random_seed() {
        let r = random::<u32>();
        set_random_seed(r);
        assert_eq!(get_random_seed(), r);
    }
}