#![allow(clippy::needless_range_loop)]
use pokeys_lib::*;
/// Self-contained checks of byte-level conventions these tests assume for the
/// PoKeys protocol: fixed 64-byte buffers, little-endian multi-byte fields,
/// 1-based pin bitmaps, checksums and small validation helpers.
///
/// NOTE(review): every test here exercises helpers or literals defined inline;
/// nothing from `pokeys_lib` is referenced by this module.
#[cfg(test)]
mod protocol_tests {
    /// Bytes written to the head of a 64-byte request survive; the rest stay zero.
    #[test]
    fn test_request_buffer_formatting() {
        let mut buffer = [0u8; 64];
        let head = [0x00u8, 0x01, 0x02, 0x03, 0x04];
        buffer[..head.len()].copy_from_slice(&head);

        for (&actual, &expected) in buffer.iter().zip(head.iter()) {
            assert_eq!(actual, expected);
        }
        for &byte in &buffer[head.len()..] {
            assert_eq!(byte, 0);
        }
    }

    /// Decodes a device-info style response with little-endian u32 fields.
    #[test]
    fn test_response_parsing() {
        let response: [u8; 15] = [
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x37, 0x06, 0x08,
            0x19,
        ];
        // Four bytes starting at `offset`, interpreted little-endian.
        let u32_at = |offset: usize| {
            u32::from_le_bytes([
                response[offset],
                response[offset + 1],
                response[offset + 2],
                response[offset + 3],
            ])
        };

        let device_type = u32_at(1);
        let firmware = (response[5], response[6]);
        let serial = u32_at(7);
        let (pin_count, pwm_count, analog_inputs, encoders) =
            (response[11], response[12], response[13], response[14]);

        assert_eq!(device_type, 0x04030201);
        assert_eq!(firmware.0, 5);
        assert_eq!(firmware.1, 6);
        assert_eq!(serial, 0x0A090807);
        assert_eq!(pin_count, 55);
        assert_eq!(pwm_count, 6);
        assert_eq!(analog_inputs, 8);
        assert_eq!(encoders, 25);
    }

    /// Sets bits for a handful of 1-based pins and checks set/unset neighbours.
    #[test]
    fn test_pin_state_encoding() {
        // Pin N maps to bit (N - 1) % 8 of byte (N - 1) / 8; pins outside
        // 1..=64 have no slot in the 8-byte bitmap.
        fn slot_of(pin: u32) -> Option<(usize, u32)> {
            if !(1..=64).contains(&pin) {
                return None;
            }
            let index = pin - 1;
            Some(((index / 8) as usize, index % 8))
        }

        let pins_to_set = [1u32, 5, 17, 33, 55];
        let mut pin_states = [0u8; 8];
        for (byte, bit) in pins_to_set.iter().filter_map(|&pin| slot_of(pin)) {
            pin_states[byte] |= 1 << bit;
        }

        // `Some(state)` for pins that have a slot, `None` otherwise.
        let bit_of =
            |pin: u32| slot_of(pin).map(|(byte, bit)| (pin_states[byte] & (1 << bit)) != 0);

        for &pin in &pins_to_set {
            if let Some(is_set) = bit_of(pin) {
                assert!(is_set, "Pin {pin} should be set");
            }
        }

        let test_pins = [2u32, 3, 4, 6, 16, 18, 32, 34, 54, 56];
        for &pin in &test_pins {
            if let Some(is_set) = bit_of(pin) {
                assert!(!is_set, "Pin {pin} should not be set");
            }
        }
    }

    /// 16-bit analog readings round-trip through a byte buffer, low byte first.
    #[test]
    fn test_analog_value_encoding() {
        let analog_values = [0u16, 2048, 4095, 1234, 3456];
        let mut buffer = [0u8; 16];

        for (slot, value) in buffer.chunks_exact_mut(2).zip(analog_values.iter()) {
            slot.copy_from_slice(&value.to_le_bytes());
        }

        let decoded = buffer
            .chunks_exact(2)
            .map(|pair| u16::from_le_bytes([pair[0], pair[1]]));
        for (i, (decoded, &expected)) in decoded.zip(analog_values.iter()).enumerate() {
            assert_eq!(decoded, expected, "Analog value {i} mismatch");
        }
    }

    /// Signed 32-bit encoder counts round-trip through a little-endian buffer.
    #[test]
    fn test_encoder_value_encoding() {
        let encoder_values = [0i32, -1, 1, i32::MIN, i32::MAX, 12345, -67890];
        let mut buffer = [0u8; 32];

        for (slot, value) in buffer.chunks_exact_mut(4).zip(encoder_values.iter()) {
            slot.copy_from_slice(&value.to_le_bytes());
        }

        let decoded = buffer
            .chunks_exact(4)
            .map(|quad| i32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]));
        for (i, (decoded, &expected)) in decoded.zip(encoder_values.iter()).enumerate() {
            assert_eq!(decoded, expected, "Encoder value {i} mismatch");
        }
    }

    /// Percent -> raw compare value (truncating) -> percent stays within 0.1.
    #[test]
    fn test_pwm_duty_cycle_encoding() {
        let pwm_period = 1000u16;
        let period = pwm_period as f32;

        for &duty_percent in [0.0f32, 25.0, 50.0, 75.0, 100.0].iter() {
            let raw_value = (duty_percent / 100.0 * period) as u16;
            assert!(raw_value <= pwm_period);

            let round_trip = raw_value as f32 / period * 100.0;
            assert!((round_trip - duty_percent).abs() < 0.1);
        }
    }

    /// Wrapping-sum and XOR checksums over small byte slices.
    #[test]
    fn test_checksum_calculation() {
        // Sum of all bytes modulo 256.
        fn sum_checksum(data: &[u8]) -> u8 {
            data.iter()
                .map(|&b| std::num::Wrapping(b))
                .sum::<std::num::Wrapping<u8>>()
                .0
        }
        // Bitwise XOR of all bytes.
        fn xor_checksum(data: &[u8]) -> u8 {
            data.iter().fold(0, |acc, &b| acc ^ b)
        }

        let test_data = [0x01u8, 0x02, 0x03, 0x04, 0x05];
        assert_eq!(sum_checksum(&test_data), 15);
        assert_eq!(xor_checksum(&test_data), 1);

        let all_zeros = [0u8; 10];
        assert_eq!(sum_checksum(&all_zeros), 0);
        assert_eq!(xor_checksum(&all_zeros), 0);

        // 4 * 0xFF = 1020 wraps to 252; an even count of 0xFF bytes XORs to 0.
        let all_ones = [0xFFu8; 4];
        assert_eq!(sum_checksum(&all_ones), 252);
        assert_eq!(xor_checksum(&all_ones), 0);
    }

    /// The listed basic command ids all sit below 0x80.
    #[test]
    fn test_command_id_constants() {
        let basic_commands = [0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x20, 0x30, 0x40, 0x50];
        for cmd in basic_commands.iter().copied() {
            assert!(cmd < 0x80, "Command ID {cmd} should be < 0x80");
        }
    }

    /// Little- vs big-endian byte order of a u32 and their round trips.
    #[test]
    fn test_multi_byte_value_handling() {
        let test_value = 0x12345678u32;
        let (le_bytes, be_bytes) = (test_value.to_le_bytes(), test_value.to_be_bytes());

        assert_eq!(le_bytes, [0x78, 0x56, 0x34, 0x12]);
        assert_eq!(be_bytes, [0x12, 0x34, 0x56, 0x78]);
        assert_eq!(u32::from_le_bytes(le_bytes), test_value);
        assert_eq!(u32::from_be_bytes(be_bytes), test_value);
    }

    /// Copying into the first half of a buffer leaves the second half untouched.
    #[test]
    fn test_buffer_overflow_protection() {
        let mut buffer = [0u8; 64];
        let data_to_copy = [0xAAu8; 32];
        let (head, tail) = buffer.split_at_mut(data_to_copy.len());
        head.copy_from_slice(&data_to_copy);

        for &byte in head.iter() {
            assert_eq!(byte, 0xAA);
        }
        for &byte in tail.iter() {
            assert_eq!(byte, 0);
        }
    }

    /// Major-version compatibility and (major, minor) "newer than" ordering.
    #[test]
    fn test_protocol_version_handling() {
        struct ProtocolVersion {
            major: u8,
            minor: u8,
        }

        impl ProtocolVersion {
            // Compatible whenever the major versions agree.
            fn is_compatible(&self, other: &ProtocolVersion) -> bool {
                other.major == self.major
            }

            // Lexicographic comparison: major first, then minor.
            fn is_newer(&self, other: &ProtocolVersion) -> bool {
                (self.major, self.minor) > (other.major, other.minor)
            }
        }

        let version = |major, minor| ProtocolVersion { major, minor };
        let (v1_0, v1_1, v2_0) = (version(1, 0), version(1, 1), version(2, 0));

        assert!(v1_0.is_compatible(&v1_1));
        assert!(v1_1.is_compatible(&v1_0));
        assert!(!v1_0.is_compatible(&v2_0));

        assert!(v1_1.is_newer(&v1_0));
        assert!(v2_0.is_newer(&v1_1));
        assert!(!v1_0.is_newer(&v1_1));
    }

    /// Error responses carry status 0xFF followed by a non-zero error code.
    #[test]
    fn test_error_response_parsing() {
        let error_responses = [
            ([0xFF, 0x01, 0x00, 0x00], "Invalid command"),
            ([0xFF, 0x02, 0x00, 0x00], "Invalid parameter"),
            ([0xFF, 0x03, 0x00, 0x00], "Device busy"),
            ([0xFF, 0x04, 0x00, 0x00], "Communication error"),
        ];

        for &(response, description) in error_responses.iter() {
            let [status, error_code, ..] = response;
            assert_eq!(
                status, 0xFF,
                "Error status should be 0xFF for {description}"
            );
            assert!(
                error_code > 0,
                "Error code should be non-zero for {description}"
            );
        }
    }

    /// Elapsed-time check against a long and a short timeout.
    #[test]
    fn test_timeout_handling() {
        use std::time::{Duration, Instant};

        let start = Instant::now();
        let timed_out = |timeout: Duration| start.elapsed() > timeout;

        assert!(!timed_out(Duration::from_millis(1000)));
        std::thread::sleep(Duration::from_millis(2));
        assert!(timed_out(Duration::from_millis(1)));
    }

    /// Range checks for pin numbers, PWM duty cycles and encoder ids.
    #[test]
    fn test_data_validation() {
        // Pins are 1-based and `max_pins` itself is valid.
        fn validate_pin_number(pin: u8, max_pins: u8) -> bool {
            (1..=max_pins).contains(&pin)
        }
        // Duty cycle is a percentage in [0, 100].
        fn validate_pwm_duty_cycle(duty: f32) -> bool {
            (0.0..=100.0).contains(&duty)
        }
        // Encoder ids are 0-based and `max_encoders` itself is out of range.
        fn validate_encoder_id(encoder: u8, max_encoders: u8) -> bool {
            (0..max_encoders).contains(&encoder)
        }

        for &(pin, expected) in [(0u8, false), (1, true), (55, true), (56, false)].iter() {
            assert_eq!(validate_pin_number(pin, 55), expected);
        }

        let duty_cases = [
            (0.0f32, true),
            (50.0, true),
            (100.0, true),
            (-1.0, false),
            (101.0, false),
        ];
        for &(duty, expected) in duty_cases.iter() {
            assert_eq!(validate_pwm_duty_cycle(duty), expected);
        }

        for &(encoder, expected) in [(0u8, true), (24, true), (25, false)].iter() {
            assert_eq!(validate_encoder_id(encoder, 25), expected);
        }
    }

    /// Set/clear/query individual bits of a u32 flag word and round-trip it.
    #[test]
    fn test_bit_field_operations() {
        #[derive(Default)]
        struct DeviceFlags {
            value: u32,
        }

        impl DeviceFlags {
            // Single-bit mask for `bit` (0..=31).
            fn mask(bit: u8) -> u32 {
                1 << bit
            }

            fn set_flag(&mut self, bit: u8) {
                self.value |= Self::mask(bit);
            }

            fn clear_flag(&mut self, bit: u8) {
                self.value &= !Self::mask(bit);
            }

            fn get_flag(&self, bit: u8) -> bool {
                (self.value & Self::mask(bit)) != 0
            }

            fn from_bytes(bytes: &[u8; 4]) -> Self {
                Self {
                    value: u32::from_le_bytes(*bytes),
                }
            }

            fn to_bytes(&self) -> [u8; 4] {
                self.value.to_le_bytes()
            }
        }

        let mut flags = DeviceFlags::default();
        for &bit in [0u8, 15, 31].iter() {
            flags.set_flag(bit);
        }
        for &bit in [0u8, 15, 31].iter() {
            assert!(flags.get_flag(bit));
        }
        assert!(!flags.get_flag(1));
        assert!(!flags.get_flag(16));

        flags.clear_flag(15);
        assert!(!flags.get_flag(15));
        assert!(flags.get_flag(0));
        assert!(flags.get_flag(31));

        let round_tripped = DeviceFlags::from_bytes(&flags.to_bytes());
        assert_eq!(flags.value, round_tripped.value);
    }
}
/// Request/response behaviour of an in-memory mock transport.
#[cfg(test)]
mod communication_tests {
    use super::*;
    use std::time::Duration;

    /// Scripted, in-memory stand-in for a device connection.
    ///
    /// Each successful `send_request` records the request bytes and returns a
    /// clone of whatever response was last queued with `set_next_response`.
    struct MockCommunication {
        // Set by a successful `connect`, cleared by `disconnect`.
        connected: bool,
        // Bytes of the most recent successful request.
        last_request: Vec<u8>,
        // Response returned (cloned) by every successful request.
        next_response: Vec<u8>,
        // When true, `connect` and (connected) `send_request` return errors.
        should_fail: bool,
        // Sleep applied before each successful request completes.
        delay: Duration,
    }

    impl MockCommunication {
        /// Disconnected mock: no queued response, no failure injection, no delay.
        fn new() -> Self {
            Self {
                connected: false,
                last_request: Vec::new(),
                next_response: Vec::new(),
                should_fail: false,
                delay: Duration::ZERO,
            }
        }

        /// Marks the mock connected unless failure injection is enabled.
        fn connect(&mut self) -> Result<()> {
            if self.should_fail {
                Err(PoKeysError::CannotConnect)
            } else {
                self.connected = true;
                Ok(())
            }
        }

        fn disconnect(&mut self) {
            self.connected = false;
        }

        /// Returns the queued response after recording `request`.
        ///
        /// Errors with `NotConnected` before `connect`, and with `Transfer`
        /// while failure injection is enabled; neither error records the request.
        fn send_request(&mut self, request: &[u8]) -> Result<Vec<u8>> {
            match (self.connected, self.should_fail) {
                (false, _) => Err(PoKeysError::NotConnected),
                (true, true) => Err(PoKeysError::Transfer("Communication failed".to_string())),
                (true, false) => {
                    if !self.delay.is_zero() {
                        std::thread::sleep(self.delay);
                    }
                    self.last_request.clear();
                    self.last_request.extend_from_slice(request);
                    Ok(self.next_response.clone())
                }
            }
        }

        fn set_next_response(&mut self, response: Vec<u8>) {
            self.next_response = response;
        }

        fn set_should_fail(&mut self, fail: bool) {
            self.should_fail = fail;
        }

        fn set_delay(&mut self, delay: Duration) {
            self.delay = delay;
        }

        fn get_last_request(&self) -> &[u8] {
            &self.last_request
        }
    }

    /// A mock that has already connected successfully.
    fn connected_mock() -> MockCommunication {
        let mut comm = MockCommunication::new();
        comm.connect().unwrap();
        comm
    }

    /// Queues `response`, sends `request`, and unwraps what comes back.
    fn exchange(comm: &mut MockCommunication, request: &[u8], response: Vec<u8>) -> Vec<u8> {
        comm.set_next_response(response);
        comm.send_request(request).unwrap()
    }

    #[test]
    fn test_mock_communication_basic() {
        let mut comm = MockCommunication::new();
        assert!(!comm.connected);
        comm.connect().unwrap();
        assert!(comm.connected);

        let request = [0x00u8, 0x01, 0x02];
        let response = vec![0x10u8, 0x11, 0x12];
        assert_eq!(exchange(&mut comm, &request, response.clone()), response);
        assert_eq!(comm.get_last_request(), &request[..]);

        comm.disconnect();
        assert!(!comm.connected);
    }

    #[test]
    fn test_mock_communication_errors() {
        let mut comm = MockCommunication::new();
        let request = [0x00u8];

        // Sending before connecting is rejected.
        assert!(comm.send_request(&request).is_err());

        // Failure injection blocks connecting...
        comm.set_should_fail(true);
        assert!(comm.connect().is_err());

        // ...and, once connected, blocks transfers too.
        comm.set_should_fail(false);
        comm.connect().unwrap();
        comm.set_should_fail(true);
        assert!(comm.send_request(&request).is_err());
    }

    #[test]
    fn test_mock_communication_timeout_simulation() {
        let mut comm = connected_mock();
        comm.set_delay(Duration::from_millis(100));
        comm.set_next_response(vec![0xFF]);

        let start = std::time::Instant::now();
        comm.send_request(&[0x00]).unwrap();
        assert!(start.elapsed() >= Duration::from_millis(90));
    }

    #[test]
    fn test_request_response_patterns() {
        let mut comm = connected_mock();

        let device_info = (
            vec![0x00u8, 0x00, 0x00, 0x00, 0x00],
            vec![
                0x00u8, 0x0A, 0x00, 0x00, 0x00, 0x01, 0x02, 0x34, 0x12, 0x00, 0x00, 55, 6, 8, 25,
            ],
        );
        let digital_out = (vec![0x03u8, 0x01, 0xFF, 0x00, 0x00], vec![0x00u8]);

        for (request, expected) in [device_info, digital_out].iter() {
            assert_eq!(exchange(&mut comm, request, expected.clone()), *expected);
            assert_eq!(comm.get_last_request(), &request[..]);
        }
    }

    #[test]
    fn test_multiple_requests() {
        let mut comm = connected_mock();
        let requests_responses: [(Vec<u8>, Vec<u8>); 3] = [
            (vec![0x00], vec![0x00, 0x01]),
            (vec![0x01], vec![0x00, 0x02]),
            (vec![0x02], vec![0x00, 0x03]),
        ];

        for (request, expected) in &requests_responses {
            let response = exchange(&mut comm, request, expected.clone());
            assert_eq!(&response, expected);
            assert_eq!(comm.get_last_request(), &request[..]);
        }
    }

    #[test]
    fn test_large_data_transfer() {
        let mut comm = connected_mock();
        let large_request = vec![0xAAu8; 1024];
        let large_response = vec![0x55u8; 2048];

        assert_eq!(
            exchange(&mut comm, &large_request, large_response.clone()),
            large_response
        );
        assert_eq!(comm.get_last_request(), &large_request[..]);
    }

    #[test]
    fn test_empty_data_handling() {
        let mut comm = connected_mock();

        // Empty request, non-empty response.
        let empty_request: Vec<u8> = Vec::new();
        let response = vec![0x00u8];
        assert_eq!(exchange(&mut comm, &empty_request, response.clone()), response);
        assert_eq!(comm.get_last_request(), &empty_request[..]);

        // Non-empty request, empty response.
        let empty_response: Vec<u8> = Vec::new();
        assert_eq!(
            exchange(&mut comm, &[0x00], empty_response.clone()),
            empty_response
        );
    }
}
/// Byte-layout checks for the 64-byte "set device name" request built by
/// `build_set_device_name_packet`.
#[cfg(test)]
mod device_name_tests {
    /// Size of the fixed name field, in bytes.
    const NAME_FIELD_LEN: usize = 20;
    /// Offset of the first name byte within the request.
    const NAME_FIELD_OFFSET: usize = 35;

    /// Builds a 64-byte "set device name" request.
    ///
    /// Layout produced:
    /// - bytes 0..7: `0xBB, 0x06, 0x01, 0x01, 0x00, 0x00, request_id`
    /// - byte 7: wrapping (mod 256) sum of bytes 0..7
    /// - bytes 35..55: `name` as UTF-8, at most 20 bytes, zero-padded. When the
    ///   20-byte cut would land inside a multi-byte character, that character
    ///   is dropped whole rather than split (and the call does not panic).
    /// - every other byte: zero
    ///
    /// `pub(super)` so the builder can also be exercised from outside this module.
    pub(super) fn build_set_device_name_packet(name: &str, request_id: u8) -> [u8; 64] {
        let mut request = [0u8; 64];
        request[..7].copy_from_slice(&[0xBB, 0x06, 0x01, 0x01, 0x00, 0x00, request_id]);

        // Slicing `&name[..20]` panics when byte 20 falls inside a multi-byte
        // character, so back off to the nearest char boundary instead.
        // Index 0 is always a boundary, so this loop terminates.
        let mut name_len = name.len().min(NAME_FIELD_LEN);
        while !name.is_char_boundary(name_len) {
            name_len -= 1;
        }
        request[NAME_FIELD_OFFSET..NAME_FIELD_OFFSET + name_len]
            .copy_from_slice(&name.as_bytes()[..name_len]);

        // The checksum covers only the header bytes 0..7.
        request[7] = request[..7]
            .iter()
            .fold(0u8, |acc, &b| acc.wrapping_add(b));
        request
    }

    #[test]
    fn test_packet_header_bytes() {
        let pkt = build_set_device_name_packet("Test", 1);
        assert_eq!(pkt[0], 0xBB, "header byte");
        assert_eq!(pkt[1], 0x06, "command byte");
        assert_eq!(pkt[2], 0x01, "write-flag byte");
        assert_eq!(pkt[3], 0x01, "long-name flag");
        assert_eq!(pkt[4], 0x00);
        assert_eq!(pkt[5], 0x00);
    }

    #[test]
    fn test_checksum_calculation() {
        let pkt = build_set_device_name_packet("Hello", 42);
        let expected: u8 = pkt[0..7].iter().fold(0u8, |acc, &b| acc.wrapping_add(b));
        assert_eq!(pkt[7], expected);
    }

    #[test]
    fn test_checksum_changes_with_request_id() {
        let pkt1 = build_set_device_name_packet("Hello", 1);
        let pkt2 = build_set_device_name_packet("Hello", 2);
        assert_ne!(pkt1[7], pkt2[7]);
    }

    #[test]
    fn test_name_placed_at_correct_offset() {
        let name = "MyDevice";
        let pkt = build_set_device_name_packet(name, 1);
        let name_field = &pkt[35..55];
        assert_eq!(&name_field[..name.len()], name.as_bytes());
        for &b in &name_field[name.len()..] {
            assert_eq!(b, 0);
        }
    }

    #[test]
    fn test_name_truncated_to_20_bytes() {
        let long_name = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let pkt = build_set_device_name_packet(long_name, 1);
        let name_field = &pkt[35..55];
        assert_eq!(name_field, b"ABCDEFGHIJKLMNOPQRST");
    }

    #[test]
    fn test_exact_20_char_name_not_truncated() {
        let name = "01234567890123456789";
        let pkt = build_set_device_name_packet(name, 1);
        assert_eq!(&pkt[35..55], name.as_bytes());
    }

    #[test]
    fn test_empty_name_produces_zero_name_field() {
        let pkt = build_set_device_name_packet("", 1);
        assert_eq!(&pkt[35..55], &[0u8; 20]);
    }

    #[test]
    fn test_bytes_outside_name_field_are_zero() {
        let pkt = build_set_device_name_packet("Test", 1);
        for (i, &b) in pkt[8..35].iter().enumerate() {
            assert_eq!(b, 0, "byte {} should be zero", i + 8);
        }
        for (i, &b) in pkt[55..64].iter().enumerate() {
            assert_eq!(b, 0, "byte {} should be zero", i + 55);
        }
    }

    #[test]
    fn test_multibyte_name_truncated_on_char_boundary() {
        // 19 ASCII bytes followed by the 2-byte 'é': byte 20 is mid-character,
        // so the whole 'é' is dropped instead of panicking or being split.
        let pkt = build_set_device_name_packet("ABCDEFGHIJKLMNOPQRSé", 1);
        assert_eq!(&pkt[35..54], b"ABCDEFGHIJKLMNOPQRS");
        assert_eq!(pkt[54], 0);

        // A multi-byte character that fits entirely is kept intact.
        let pkt = build_set_device_name_packet("Ünit", 1);
        assert_eq!(&pkt[35..40], "Ünit".as_bytes());
        assert_eq!(pkt[40], 0);
    }
}