// ct2rs/sys/config.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
//
// This software is released under the MIT License.
//
// http://opensource.org/licenses/mit-license.php

//! Configs and associated enums.
10
11use std::fmt::{Debug, Display, Formatter};
12
13use cxx::UniquePtr;
14
15pub use self::ffi::{
16 get_device_count, get_log_level, get_random_seed, set_log_level, set_random_seed, BatchType,
17 ComputeType, Device, LogLevel,
18};
19
#[cxx::bridge]
pub(crate) mod ffi {
    /// Represents the computing device to be used.
    ///
    /// This enum is a Rust binding to the
    /// [`ctranslate2.Device`](https://opennmt.net/CTranslate2/python/ctranslate2.Device.html),
    /// which can take one of the following two values:
    /// - [`CPU`][Device::CPU]
    /// - [`CUDA`][Device::CUDA]
    ///
    /// The default setting for this enum is [`CPU`][Device::CPU].
    ///
    /// # Examples
    ///
    /// Example of creating a default `Device`:
    ///
    /// ```
    /// use ct2rs::sys::Device;
    ///
    /// let device = Device::default();
    /// # assert_eq!(device, Device::CPU);
    /// ```
    ///
    // NOTE: `repr(i32)` and the variant order must stay in sync with the C++
    // definition this shared enum is bridged to — do not reorder variants.
    #[derive(Copy, Clone, Debug)]
    #[repr(i32)]
    enum Device {
        /// Use CPU.
        CPU,
        /// Use GPU (CUDA).
        CUDA,
    }

    /// Model computation type.
    ///
    /// This enum can take one of the following values:
    /// - [`DEFAULT`][ComputeType::DEFAULT]
    /// - [`AUTO`][ComputeType::AUTO]
    /// - [`FLOAT32`][ComputeType::FLOAT32]
    /// - [`INT8`][ComputeType::INT8]
    /// - [`INT8_FLOAT32`][ComputeType::INT8_FLOAT32]
    /// - [`INT8_FLOAT16`][ComputeType::INT8_FLOAT16]
    /// - [`INT8_BFLOAT16`][ComputeType::INT8_BFLOAT16]
    /// - [`INT16`][ComputeType::INT16]
    /// - [`FLOAT16`][ComputeType::FLOAT16]
    /// - [`BFLOAT16`][ComputeType::BFLOAT16]
    ///
    /// The default setting for this enum is [`DEFAULT`][ComputeType::DEFAULT], meaning that unless
    /// specified otherwise, the computation will proceed with the same quantization level as was
    /// used during the model's conversion.
    ///
    /// See also:
    /// [Quantization](https://opennmt.net/CTranslate2/quantization.html#quantize-on-model-loading)
    /// for more details on how quantization affects computation and how it can be applied during
    /// model loading.
    ///
    /// # Examples
    ///
    /// Example of creating a default `ComputeType`:
    ///
    /// ```
    /// use ct2rs::sys::ComputeType;
    ///
    /// let compute_type = ComputeType::default();
    /// # assert_eq!(compute_type, ComputeType::DEFAULT);
    /// ```
    ///
    // NOTE: shared enum bridged to C++; keep variant order and repr unchanged.
    #[derive(Copy, Clone, Debug)]
    #[repr(i32)]
    enum ComputeType {
        /// Keeps the same quantization that was used during model conversion.
        DEFAULT,
        /// Uses the fastest computation type that is supported on this system and device.
        AUTO,
        /// Utilizes 32-bit floating-point precision.
        FLOAT32,
        /// Uses 8-bit integer precision.
        INT8,
        /// Combines 8-bit integer quantization with 32-bit floating-point computation.
        INT8_FLOAT32,
        /// Combines 8-bit integer quantization with 16-bit floating-point computation.
        INT8_FLOAT16,
        /// Combines 8-bit integer quantization with Brain Floating Point (16-bit) computation.
        INT8_BFLOAT16,
        /// Uses 16-bit integer precision.
        INT16,
        /// Utilizes 16-bit floating-point precision (half precision).
        FLOAT16,
        /// Uses Brain Floating Point (16-bit) precision.
        BFLOAT16,
    }

    /// Specifies how the `max_batch_size` should be calculated.
    ///
    /// This enum can take one of the following two values:
    /// - [`Examples`][BatchType::Examples]
    /// - [`Tokens`][BatchType::Tokens]
    ///
    /// The default setting for this enum is [`Examples`][BatchType::Examples].
    ///
    /// # Examples
    ///
    /// Example of creating a default `BatchType`:
    ///
    /// ```
    /// use ct2rs::sys::BatchType;
    ///
    /// let batch_type = BatchType::default();
    /// # assert_eq!(batch_type, BatchType::Examples);
    /// ```
    #[derive(Copy, Clone, Debug)]
    #[repr(i32)]
    enum BatchType {
        /// The batch size is calculated based on the number of individual examples.
        Examples,
        /// The batch size is calculated based on the total number of tokens across all examples.
        Tokens,
    }

    /// Logging level.
    ///
    /// This enum can take one of the following values:
    /// - [`Off`][LogLevel::Off]
    /// - [`Critical`][LogLevel::Critical]
    /// - [`Error`][LogLevel::Error]
    /// - [`Warning`][LogLevel::Warning]
    /// - [`Info`][LogLevel::Info]
    /// - [`Debug`][LogLevel::Debug]
    /// - [`Trace`][LogLevel::Trace]
    ///
    /// The default setting for this enum is [`Warning`][LogLevel::Warning].
    ///
    /// # Examples
    ///
    /// Example of creating a default `LogLevel`:
    ///
    /// ```
    /// use ct2rs::sys::LogLevel;
    ///
    /// let log_level = LogLevel::default();
    /// # assert_eq!(log_level, LogLevel::Warning);
    /// ```
    // The explicit discriminants (-3..=3) are presumably chosen to match the
    // numeric levels used by the C++ side — do not renumber. TODO confirm
    // against the CTranslate2 C++ headers.
    #[derive(Copy, Clone, Debug)]
    #[repr(i32)]
    enum LogLevel {
        /// Disable all logging.
        Off = -3,
        /// Log critical messages only.
        Critical = -2,
        /// Log errors and above.
        Error = -1,
        /// Log warnings and above (the default).
        Warning = 0,
        /// Log informational messages and above.
        Info = 1,
        /// Log debug messages and above.
        Debug = 2,
        /// Log everything, including trace messages.
        Trace = 3,
    }

    unsafe extern "C++" {
        include!("ct2rs/include/config.h");

        type Device;
        type ComputeType;
        type ReplicaPoolConfig;
        pub type BatchType;

        /// Builds an opaque `ReplicaPoolConfig` from the thread/queue/core settings.
        fn replica_pool_config(
            num_threads_per_replica: usize,
            max_queued_batches: i32,
            cpu_core_offset: i32,
        ) -> UniquePtr<ReplicaPoolConfig>;

        pub type Config;

        /// Builds an opaque C++ `Config` from the given device/compute settings
        /// and a replica-pool configuration.
        fn config(
            device: Device,
            compute_type: ComputeType,
            device_indices: &[i32],
            tensor_parallel: bool,
            replica_pool_config: UniquePtr<ReplicaPoolConfig>,
        ) -> UniquePtr<Config>;

        /// Returns the number of devices.
        fn get_device_count(device: Device) -> i32;

        type LogLevel;

        /// Sets the CTranslate2 logging level.
        ///
        /// # Examples
        /// The following example sets the log level to `Debug`.
        /// ```
        /// use ct2rs::sys::{LogLevel, set_log_level};
        ///
        /// set_log_level(LogLevel::Debug);
        /// ```
        fn set_log_level(level: LogLevel);

        /// Returns the current logging level.
        fn get_log_level() -> LogLevel;

        /// Sets the seed of random generators.
        ///
        /// # Examples
        /// The following example sets the random seed to `12345`.
        /// ```
        /// use ct2rs::sys::set_random_seed;
        ///
        /// set_random_seed(12345);
        /// ```
        fn set_random_seed(seed: u32);

        /// Returns the current seed of random generators.
        fn get_random_seed() -> u32;
    }
}
231
232impl Default for Device {
233 fn default() -> Self {
234 Self::CPU
235 }
236}
237
238impl Display for Device {
239 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
240 match *self {
241 Device::CPU => write!(f, "CPU"),
242 Device::CUDA => write!(f, "CUDA"),
243 _ => write!(f, "Unknown"),
244 }
245 }
246}
247
248impl Default for ComputeType {
249 fn default() -> Self {
250 Self::DEFAULT
251 }
252}
253
254impl Display for ComputeType {
255 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256 match *self {
257 ComputeType::DEFAULT => write!(f, "default"),
258 ComputeType::AUTO => write!(f, "auto"),
259 ComputeType::FLOAT32 => write!(f, "float32"),
260 ComputeType::INT8 => write!(f, "int8"),
261 ComputeType::INT8_FLOAT32 => write!(f, "int8_float32"),
262 ComputeType::INT8_FLOAT16 => write!(f, "int8_float16"),
263 ComputeType::INT8_BFLOAT16 => write!(f, "int8_bfloat16"),
264 ComputeType::INT16 => write!(f, "int16"),
265 ComputeType::FLOAT16 => write!(f, "float16"),
266 ComputeType::BFLOAT16 => write!(f, "bfloat16"),
267 _ => write!(f, "unknown"),
268 }
269 }
270}
271
272impl Default for BatchType {
273 fn default() -> Self {
274 Self::Examples
275 }
276}
277
278impl Display for BatchType {
279 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
280 match *self {
281 BatchType::Examples => write!(f, "examples"),
282 BatchType::Tokens => write!(f, "tokens"),
283 _ => write!(f, "unknown"),
284 }
285 }
286}
287
288impl Default for LogLevel {
289 fn default() -> Self {
290 Self::Warning
291 }
292}
293
294impl Display for LogLevel {
295 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
296 match *self {
297 LogLevel::Off => write!(f, "off"),
298 LogLevel::Critical => write!(f, "critical"),
299 LogLevel::Error => write!(f, "error"),
300 LogLevel::Warning => write!(f, "warning"),
301 LogLevel::Info => write!(f, "info"),
302 LogLevel::Debug => write!(f, "debug"),
303 LogLevel::Trace => write!(f, "trace"),
304 _ => write!(f, "unknown"),
305 }
306 }
307}
308
/// The `Config` structure holds the configuration settings for CTranslate2.
///
/// It is converted into the opaque FFI-level `ffi::Config` via `to_ffi` before
/// being handed to the underlying C++ library.
///
/// # Examples
///
/// Example of creating a default `Config`:
///
/// ```
/// use ct2rs::sys::{ComputeType, Config, Device};
///
/// let config = Config::default();
/// # assert_eq!(config.device, Device::default());
/// # assert_eq!(config.compute_type, ComputeType::default());
/// # assert_eq!(config.device_indices, vec![0]);
/// # assert_eq!(config.tensor_parallel, false);
/// # assert_eq!(config.num_threads_per_replica, 0);
/// # assert_eq!(config.max_queued_batches, 0);
/// # assert_eq!(config.cpu_core_offset, -1);
/// ```
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct Config {
    /// Device to use.
    pub device: Device,
    /// Model computation type.
    pub compute_type: ComputeType,
    /// Device IDs where to place this generator on. (default: `vec![0]`)
    pub device_indices: Vec<i32>,
    /// Run model with tensor parallel mode. (default: false)
    pub tensor_parallel: bool,
    /// Number of threads per translator/generator (0 to use a default value). (default: 0)
    pub num_threads_per_replica: usize,
    /// Maximum numbers of batches in the queue (-1 for unlimited, 0 for an automatic value).
    /// When the queue is full, future requests will block until a free slot is available.
    /// (default: 0)
    pub max_queued_batches: i32,
    /// Presumably the index of the first CPU core to pin worker threads to,
    /// with `-1` leaving thread affinity unset — TODO confirm against the
    /// CTranslate2 `ReplicaPoolConfig` documentation. (default: -1)
    pub cpu_core_offset: i32,
}
346
347impl Default for Config {
348 fn default() -> Self {
349 Self {
350 device: Default::default(),
351 compute_type: Default::default(),
352 device_indices: vec![0],
353 tensor_parallel: false,
354 num_threads_per_replica: 0,
355 max_queued_batches: 0,
356 cpu_core_offset: -1,
357 }
358 }
359}
360
361impl Config {
362 pub(crate) fn to_ffi(&self) -> UniquePtr<ffi::Config> {
363 ffi::config(
364 self.device,
365 self.compute_type,
366 self.device_indices.as_slice(),
367 false,
368 ffi::replica_pool_config(
369 self.num_threads_per_replica,
370 self.max_queued_batches,
371 self.cpu_core_offset,
372 ),
373 )
374 }
375}
376
#[cfg(test)]
mod tests {
    use rand::random;

    use super::{
        get_device_count, get_log_level, get_random_seed, set_log_level, set_random_seed,
        BatchType, ComputeType, Config, Device, LogLevel,
    };

    /// `Device` values render as their uppercase names.
    #[test]
    fn test_device_display() {
        assert_eq!(format!("{}", Device::CPU), "CPU");
        assert_eq!(format!("{}", Device::CUDA), "CUDA");
    }

    /// Every `ComputeType` variant renders as its lowercase CTranslate2 name.
    #[test]
    fn test_compute_type_display() {
        assert_eq!(format!("{}", ComputeType::DEFAULT), "default");
        assert_eq!(format!("{}", ComputeType::AUTO), "auto");
        assert_eq!(format!("{}", ComputeType::FLOAT32), "float32");
        assert_eq!(format!("{}", ComputeType::INT8), "int8");
        assert_eq!(format!("{}", ComputeType::INT8_FLOAT32), "int8_float32");
        assert_eq!(format!("{}", ComputeType::INT8_FLOAT16), "int8_float16");
        assert_eq!(format!("{}", ComputeType::INT8_BFLOAT16), "int8_bfloat16");
        assert_eq!(format!("{}", ComputeType::INT16), "int16");
        assert_eq!(format!("{}", ComputeType::FLOAT16), "float16");
        assert_eq!(format!("{}", ComputeType::BFLOAT16), "bfloat16");
    }

    /// `BatchType` values render as lowercase names.
    #[test]
    fn test_batch_type_display() {
        assert_eq!(format!("{}", BatchType::Examples), "examples");
        assert_eq!(format!("{}", BatchType::Tokens), "tokens");
    }

    /// `LogLevel` values render as lowercase names.
    #[test]
    fn test_log_level_display() {
        assert_eq!(format!("{}", LogLevel::Off), "off");
        assert_eq!(format!("{}", LogLevel::Critical), "critical");
        assert_eq!(format!("{}", LogLevel::Error), "error");
        assert_eq!(format!("{}", LogLevel::Warning), "warning");
        assert_eq!(format!("{}", LogLevel::Info), "info");
        assert_eq!(format!("{}", LogLevel::Debug), "debug");
        assert_eq!(format!("{}", LogLevel::Trace), "trace");
    }

    /// A default `Config` converts into a non-null FFI config.
    #[test]
    fn test_config_to_ffi() {
        let config = Config::default();
        let res = config.to_ffi();

        assert!(!res.is_null());
    }

    // Only meaningful without CUDA: a CPU is always counted as one device and
    // no CUDA devices are expected.
    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_get_device_count() {
        assert_eq!(get_device_count(Device::CPU), 1);
        assert_eq!(get_device_count(Device::CUDA), 0);
    }

    #[test]
    fn test_default_log_level() {
        assert_eq!(LogLevel::default(), LogLevel::Warning);
    }

    /// Round-trips every log level through the global setter/getter.
    /// NOTE: mutates process-wide logging state.
    #[test]
    fn test_log_level() {
        for l in [
            LogLevel::Off,
            LogLevel::Critical,
            LogLevel::Error,
            LogLevel::Warning,
            LogLevel::Info,
            LogLevel::Debug,
            LogLevel::Trace,
        ] {
            set_log_level(l);
            assert_eq!(get_log_level(), l);
        }
    }

    /// Round-trips a random seed through the global setter/getter.
    /// NOTE: mutates process-wide RNG state.
    #[test]
    fn test_random_seed() {
        let r = random::<u32>();
        set_random_seed(r);
        assert_eq!(get_random_seed(), r);
    }
}
465}