use std::fmt::{Debug, Display, Formatter};
use cxx::UniquePtr;
pub use self::ffi::{
get_device_count, get_log_level, get_random_seed, set_log_level, set_random_seed, BatchType,
ComputeType, Device, LogLevel,
};
#[cxx::bridge]
pub(crate) mod ffi {
#[derive(Copy, Clone, Debug)]
#[repr(i32)]
enum Device {
CPU,
CUDA,
}
#[derive(Copy, Clone, Debug)]
#[repr(i32)]
enum ComputeType {
DEFAULT,
AUTO,
FLOAT32,
INT8,
INT8_FLOAT32,
INT8_FLOAT16,
INT8_BFLOAT16,
INT16,
FLOAT16,
BFLOAT16,
}
#[derive(Copy, Clone, Debug)]
#[repr(i32)]
enum BatchType {
Examples,
Tokens,
}
#[derive(Copy, Clone, Debug)]
#[repr(i32)]
enum LogLevel {
Off = -3,
Critical = -2,
Error = -1,
Warning = 0,
Info = 1,
Debug = 2,
Trace = 3,
}
unsafe extern "C++" {
include!("ct2rs/include/config.h");
type Device;
type ComputeType;
type ReplicaPoolConfig;
pub type BatchType;
fn replica_pool_config(
num_threads_per_replica: usize,
max_queued_batches: i32,
cpu_core_offset: i32,
) -> UniquePtr<ReplicaPoolConfig>;
pub type Config;
fn config(
device: Device,
compute_type: ComputeType,
device_indices: &[i32],
tensor_parallel: bool,
replica_pool_config: UniquePtr<ReplicaPoolConfig>,
) -> UniquePtr<Config>;
fn get_device_count(device: Device) -> i32;
type LogLevel;
fn set_log_level(level: LogLevel);
fn get_log_level() -> LogLevel;
fn set_random_seed(seed: u32);
fn get_random_seed() -> u32;
}
}
impl Default for Device {
fn default() -> Self {
Self::CPU
}
}
impl Display for Device {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
match *self {
Device::CPU => write!(f, "CPU"),
Device::CUDA => write!(f, "CUDA"),
_ => write!(f, "Unknown"),
}
}
}
impl Default for ComputeType {
fn default() -> Self {
Self::DEFAULT
}
}
impl Display for ComputeType {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match *self {
ComputeType::DEFAULT => write!(f, "default"),
ComputeType::AUTO => write!(f, "auto"),
ComputeType::FLOAT32 => write!(f, "float32"),
ComputeType::INT8 => write!(f, "int8"),
ComputeType::INT8_FLOAT32 => write!(f, "int8_float32"),
ComputeType::INT8_FLOAT16 => write!(f, "int8_float16"),
ComputeType::INT8_BFLOAT16 => write!(f, "int8_bfloat16"),
ComputeType::INT16 => write!(f, "int16"),
ComputeType::FLOAT16 => write!(f, "float16"),
ComputeType::BFLOAT16 => write!(f, "bfloat16"),
_ => write!(f, "unknown"),
}
}
}
impl Default for BatchType {
fn default() -> Self {
Self::Examples
}
}
impl Display for BatchType {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match *self {
BatchType::Examples => write!(f, "examples"),
BatchType::Tokens => write!(f, "tokens"),
_ => write!(f, "unknown"),
}
}
}
impl Default for LogLevel {
fn default() -> Self {
Self::Warning
}
}
impl Display for LogLevel {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match *self {
LogLevel::Off => write!(f, "off"),
LogLevel::Critical => write!(f, "critical"),
LogLevel::Error => write!(f, "error"),
LogLevel::Warning => write!(f, "warning"),
LogLevel::Info => write!(f, "info"),
LogLevel::Debug => write!(f, "debug"),
LogLevel::Trace => write!(f, "trace"),
_ => write!(f, "unknown"),
}
}
}
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct Config {
pub device: Device,
pub compute_type: ComputeType,
pub device_indices: Vec<i32>,
pub tensor_parallel: bool,
pub num_threads_per_replica: usize,
pub max_queued_batches: i32,
pub cpu_core_offset: i32,
}
impl Default for Config {
fn default() -> Self {
Self {
device: Default::default(),
compute_type: Default::default(),
device_indices: vec![0],
tensor_parallel: false,
num_threads_per_replica: 0,
max_queued_batches: 0,
cpu_core_offset: -1,
}
}
}
impl Config {
pub(crate) fn to_ffi(&self) -> UniquePtr<ffi::Config> {
ffi::config(
self.device,
self.compute_type,
self.device_indices.as_slice(),
false,
ffi::replica_pool_config(
self.num_threads_per_replica,
self.max_queued_batches,
self.cpu_core_offset,
),
)
}
}
#[cfg(test)]
mod tests {
use rand::random;
use super::{
get_device_count, get_log_level, get_random_seed, set_log_level, set_random_seed,
BatchType, ComputeType, Config, Device, LogLevel,
};
#[test]
fn test_device_display() {
assert_eq!(format!("{}", Device::CPU), "CPU");
assert_eq!(format!("{}", Device::CUDA), "CUDA");
}
#[test]
fn test_compute_type_display() {
assert_eq!(format!("{}", ComputeType::DEFAULT), "default");
assert_eq!(format!("{}", ComputeType::AUTO), "auto");
assert_eq!(format!("{}", ComputeType::FLOAT32), "float32");
assert_eq!(format!("{}", ComputeType::INT8), "int8");
assert_eq!(format!("{}", ComputeType::INT8_FLOAT32), "int8_float32");
assert_eq!(format!("{}", ComputeType::INT8_FLOAT16), "int8_float16");
assert_eq!(format!("{}", ComputeType::INT8_BFLOAT16), "int8_bfloat16");
assert_eq!(format!("{}", ComputeType::INT16), "int16");
assert_eq!(format!("{}", ComputeType::FLOAT16), "float16");
assert_eq!(format!("{}", ComputeType::BFLOAT16), "bfloat16");
}
#[test]
fn test_batch_type_display() {
assert_eq!(format!("{}", BatchType::Examples), "examples");
assert_eq!(format!("{}", BatchType::Tokens), "tokens");
}
#[test]
fn test_log_level_display() {
assert_eq!(format!("{}", LogLevel::Off), "off");
assert_eq!(format!("{}", LogLevel::Critical), "critical");
assert_eq!(format!("{}", LogLevel::Error), "error");
assert_eq!(format!("{}", LogLevel::Warning), "warning");
assert_eq!(format!("{}", LogLevel::Info), "info");
assert_eq!(format!("{}", LogLevel::Debug), "debug");
assert_eq!(format!("{}", LogLevel::Trace), "trace");
}
#[test]
fn test_config_to_ffi() {
let config = Config::default();
let res = config.to_ffi();
assert!(!res.is_null());
}
#[cfg(not(feature = "cuda"))]
#[test]
fn test_get_device_count() {
assert_eq!(get_device_count(Device::CPU), 1);
assert_eq!(get_device_count(Device::CUDA), 0);
}
#[test]
fn test_default_log_level() {
assert_eq!(LogLevel::default(), LogLevel::Warning);
}
#[test]
fn test_log_level() {
for l in [
LogLevel::Off,
LogLevel::Critical,
LogLevel::Error,
LogLevel::Warning,
LogLevel::Info,
LogLevel::Debug,
LogLevel::Trace,
] {
set_log_level(l);
assert_eq!(get_log_level(), l);
}
}
#[test]
fn test_random_seed() {
let r = random::<u32>();
set_random_seed(r);
assert_eq!(get_random_seed(), r);
}
}