use serde::{Deserialize, Serialize};
/// Which transfer backend a caller would like to favor: the native path,
/// NIXL, or no explicit preference.
///
/// The default is [`NativeVsNixlPolicy::Automatic`].
// NOTE(review): variant order is significant for serde formats that encode
// variants by index — append new variants rather than reordering.
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum NativeVsNixlPolicy {
    /// Express a preference for the native transfer path.
    PreferNative,
    /// Express a preference for the NIXL transfer path.
    PreferNixl,
    /// No explicit preference; the choice is presumably made by selection
    /// logic elsewhere (not visible here). This is the default variant.
    #[default]
    Automatic,
}
/// Caller-supplied preferences that steer how a transfer path is chosen.
///
/// Construct with `TransferPreferences::default()` / `new()`, one of the
/// named constructors, or the `with_*` builder methods.
// `PartialEq`/`Eq` are derived so preference values can be compared directly;
// every field type already supports equality.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TransferPreferences {
    /// Whether to favor the native path, NIXL, or leave the choice automatic.
    pub native_vs_nixl: NativeVsNixlPolicy,
    /// Whether asynchronous CUDA transfers should be preferred.
    pub prefer_async_cuda: bool,
}
impl Default for TransferPreferences {
    /// Baseline preferences: the policy's own default for the
    /// native-vs-NIXL choice, with asynchronous CUDA preferred.
    fn default() -> Self {
        let native_vs_nixl = NativeVsNixlPolicy::default();
        let prefer_async_cuda = true;
        Self {
            native_vs_nixl,
            prefer_async_cuda,
        }
    }
}
impl TransferPreferences {
    /// Creates preferences with the default settings.
    ///
    /// Equivalent to [`TransferPreferences::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates preferences that favor the native path
    /// ([`NativeVsNixlPolicy::PreferNative`]); every other field keeps its
    /// [`Default`] value.
    #[must_use]
    pub fn prefer_native() -> Self {
        // Built from `default()` so the non-policy fields cannot drift from
        // the baseline defined in the `Default` impl.
        Self::default().with_native_vs_nixl(NativeVsNixlPolicy::PreferNative)
    }

    /// Creates preferences that favor NIXL
    /// ([`NativeVsNixlPolicy::PreferNixl`]); every other field keeps its
    /// [`Default`] value.
    #[must_use]
    pub fn prefer_nixl() -> Self {
        // See `prefer_native`: derive from the single source of defaults.
        Self::default().with_native_vs_nixl(NativeVsNixlPolicy::PreferNixl)
    }

    /// Returns `self` with the native-vs-NIXL policy replaced by `policy`.
    ///
    /// Consumes `self`; the updated value must be used, hence `#[must_use]`.
    #[must_use]
    pub fn with_native_vs_nixl(mut self, policy: NativeVsNixlPolicy) -> Self {
        self.native_vs_nixl = policy;
        self
    }

    /// Returns `self` with `prefer_async_cuda` set to `prefer_async`.
    ///
    /// Consumes `self`; the updated value must be used, hence `#[must_use]`.
    #[must_use]
    pub fn with_async_cuda(mut self, prefer_async: bool) -> Self {
        self.prefer_async_cuda = prefer_async;
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks both fields of `prefs` against the expected values.
    fn check(prefs: &TransferPreferences, policy: NativeVsNixlPolicy, async_cuda: bool) {
        assert_eq!(prefs.native_vs_nixl, policy);
        assert_eq!(prefs.prefer_async_cuda, async_cuda);
    }

    #[test]
    fn test_default_preferences() {
        check(
            &TransferPreferences::default(),
            NativeVsNixlPolicy::Automatic,
            true,
        );
    }

    #[test]
    fn test_prefer_native() {
        check(
            &TransferPreferences::prefer_native(),
            NativeVsNixlPolicy::PreferNative,
            true,
        );
    }

    #[test]
    fn test_prefer_nixl() {
        check(
            &TransferPreferences::prefer_nixl(),
            NativeVsNixlPolicy::PreferNixl,
            true,
        );
    }

    #[test]
    fn test_builder_pattern() {
        let built = TransferPreferences::new()
            .with_native_vs_nixl(NativeVsNixlPolicy::PreferNixl)
            .with_async_cuda(false);
        check(&built, NativeVsNixlPolicy::PreferNixl, false);
    }
}