use crate::error::BrainVisionError;
use crate::types::*;
/// Reads a little-endian `u32` at `*o` and advances `*o` by 4 on success.
///
/// # Errors
/// Returns `BrainVisionError::Protocol` if fewer than 4 bytes remain; `*o` is
/// left untouched in that case.
fn read_u32_le(buf: &[u8], o: &mut usize) -> Result<u32, BrainVisionError> {
    let start = *o;
    let end = start + 4;
    let bytes: [u8; 4] = buf
        .get(start..end)
        .ok_or_else(|| BrainVisionError::Protocol("u32 out of bounds".into()))?
        .try_into()
        .expect("slice has exactly 4 bytes");
    *o = end;
    Ok(u32::from_le_bytes(bytes))
}
/// Reads a little-endian `i32` at `*o` and advances `*o` by 4 on success.
///
/// # Errors
/// Returns `BrainVisionError::Protocol` if fewer than 4 bytes remain; `*o` is
/// left untouched in that case.
fn read_i32_le(buf: &[u8], o: &mut usize) -> Result<i32, BrainVisionError> {
    let start = *o;
    let end = start + 4;
    let bytes: [u8; 4] = buf
        .get(start..end)
        .ok_or_else(|| BrainVisionError::Protocol("i32 out of bounds".into()))?
        .try_into()
        .expect("slice has exactly 4 bytes");
    *o = end;
    Ok(i32::from_le_bytes(bytes))
}
/// Reads a little-endian IEEE-754 `f64` at `*o` and advances `*o` by 8 on success.
///
/// # Errors
/// Returns `BrainVisionError::Protocol` if fewer than 8 bytes remain; `*o` is
/// left untouched in that case.
fn read_f64_le(buf: &[u8], o: &mut usize) -> Result<f64, BrainVisionError> {
    let start = *o;
    let end = start + 8;
    let bytes: [u8; 8] = buf
        .get(start..end)
        .ok_or_else(|| BrainVisionError::Protocol("f64 out of bounds".into()))?
        .try_into()
        .expect("slice has exactly 8 bytes");
    *o = end;
    Ok(f64::from_le_bytes(bytes))
}
/// Reads a NUL-terminated string starting at `*o`, where the terminator must
/// appear before `end` (and before the end of `buf`).
///
/// Invalid UTF-8 is replaced with U+FFFD. On success `*o` is moved just past
/// the NUL byte.
///
/// # Errors
/// Returns `BrainVisionError::Protocol` if no NUL is found in
/// `buf[*o..min(end, buf.len())]`; `*o` is left untouched in that case.
fn read_cstring(buf: &[u8], o: &mut usize, end: usize) -> Result<String, BrainVisionError> {
    let start = *o;
    // An inverted or out-of-range window simply contains no terminator.
    let window: &[u8] = buf.get(start..end.min(buf.len())).unwrap_or(&[]);
    match window.iter().position(|&byte| byte == 0) {
        Some(len) => {
            let text = String::from_utf8_lossy(&window[..len]).into_owned();
            *o = start + len + 1;
            Ok(text)
        }
        None => Err(BrainVisionError::Protocol("unterminated C string".into())),
    }
}
/// Decodes `n_markers` consecutive marker records starting at byte offset `*o`.
///
/// Each record starts with a 16-byte header of four little-endian fields:
/// `size: u32` (total record length, header included), `position: u32`,
/// `points: u32` and `channel: i32`. Two NUL-terminated strings follow it
/// (`kind`, then `description`), and both must lie inside the record. Bytes
/// between the second NUL and `start + size` are skipped. On success `*o`
/// points just past the last record.
///
/// # Errors
/// Returns `BrainVisionError::Protocol` in these cases:
/// * the remaining bytes cannot hold `n_markers` 16-byte headers;
/// * a record declares `size < 16` or extends past the end of `buf`;
/// * either string is unterminated inside its record.
///
/// On error `*o` may have been advanced part-way.
fn decode_markers(
    buf: &[u8],
    o: &mut usize,
    n_markers: u32,
) -> Result<Vec<Marker>, BrainVisionError> {
    /// Fixed-size part of every record: size, position, points, channel.
    const HEADER_LEN: usize = 16;

    let count = n_markers as usize;
    // `n_markers` comes straight off the wire. Reject counts the remaining bytes
    // cannot possibly satisfy *before* reserving capacity. Otherwise a bogus
    // count would request a multi-gigabyte allocation and abort the process.
    // The multiplication is checked so it cannot wrap on 32-bit targets.
    let remaining = buf.len().saturating_sub(*o);
    let headers_fit = match count.checked_mul(HEADER_LEN) {
        Some(needed) => needed <= remaining,
        None => false,
    };
    if !headers_fit {
        return Err(BrainVisionError::Protocol(
            "not enough bytes for marker headers".into(),
        ));
    }

    // Bounded by `remaining / HEADER_LEN` thanks to the check above.
    let mut markers = Vec::with_capacity(count);
    for _ in 0..count {
        if *o >= buf.len() {
            return Err(BrainVisionError::Protocol(
                "marker starts out of bounds".into(),
            ));
        }
        let marker_start = *o;
        let size = read_u32_le(buf, o)? as usize;
        if size < HEADER_LEN {
            return Err(BrainVisionError::Protocol("marker size < 16".into()));
        }
        // The strings are bounded by the record, not by the whole payload.
        let marker_end = marker_start
            .checked_add(size)
            .ok_or_else(|| BrainVisionError::Protocol("marker size overflow".into()))?;
        if marker_end > buf.len() {
            return Err(BrainVisionError::Protocol("marker overruns payload".into()));
        }
        let position = read_u32_le(buf, o)?;
        let points = read_u32_le(buf, o)?;
        let channel = read_i32_le(buf, o)?;
        let kind = read_cstring(buf, o, marker_end)?;
        let description = read_cstring(buf, o, marker_end)?;
        // Honour the declared size: skip any padding after the description.
        *o = marker_end;
        markers.push(Marker {
            position,
            points,
            channel,
            kind,
            description,
        });
    }
    Ok(markers)
}
/// Decodes the payload of a START message into a `HeaderInfo`.
///
/// Layout (little-endian): `channel_count: u32`, `sampling_interval_us: f64`,
/// then `channel_count` × `f64` per-channel resolutions, then a sequence of
/// NUL-terminated channel names. Name parsing stops at the end of the payload
/// or at the first empty name (a double NUL).
///
/// # Errors
/// Returns `BrainVisionError::Protocol` if a fixed-size field or any of the
/// `channel_count` resolutions runs past the end of the payload, or if a
/// channel name is not NUL-terminated.
fn decode_start(payload: &[u8]) -> Result<HeaderInfo, BrainVisionError> {
    const F64_LEN: usize = 8;

    let mut o = 0usize;
    let channel_count = read_u32_le(payload, &mut o)?;
    let sampling_interval_us = read_f64_le(payload, &mut o)?;

    // `channel_count` is untrusted. Never reserve more slots than the payload
    // could actually hold. An oversized count still fails in `read_f64_le`
    // below, just without a huge up-front allocation.
    let max_resolutions = payload.len().saturating_sub(o) / F64_LEN;
    let mut resolutions_uv = Vec::with_capacity((channel_count as usize).min(max_resolutions));
    for _ in 0..channel_count {
        resolutions_uv.push(read_f64_le(payload, &mut o)?);
    }

    let mut channel_names = Vec::with_capacity(resolutions_uv.len());
    while o < payload.len() {
        let name = read_cstring(payload, &mut o, payload.len())?;
        // An empty name marks the end of the list.
        if name.is_empty() {
            break;
        }
        channel_names.push(name);
    }

    Ok(HeaderInfo {
        channel_count,
        sampling_interval_us,
        resolutions_uv,
        channel_names,
    })
}
/// Decodes the payload of a DATA16 message using the channel layout in `header`.
///
/// Layout (little-endian):
/// * `block: u32`, `points: u32`, `n_markers: u32`;
/// * `points × header.channel_count` `i16` samples, interleaved by point
///   (sample `i` belongs to channel `i % channel_count`);
/// * `n_markers` marker records.
///
/// Each raw sample is multiplied by its channel's entry in
/// `header.resolutions_uv`, or by `1.0` when that entry is missing.
///
/// # Errors
/// Returns `BrainVisionError::Protocol` in these cases:
/// * the 12-byte prefix is truncated;
/// * the payload is too short for the declared samples, including counts so
///   large that their byte size would overflow `usize`;
/// * the marker section is malformed.
fn decode_data_i16(payload: &[u8], header: &HeaderInfo) -> Result<DataBlock, BrainVisionError> {
    const SAMPLE_LEN: usize = 2;

    let mut o = 0usize;
    let block = read_u32_le(payload, &mut o)?;
    let points = read_u32_le(payload, &mut o)?;
    let n_markers = read_u32_le(payload, &mut o)?;
    let n_channels = header.channel_count as usize;

    // `points` comes off the wire, so the size arithmetic is checked. Otherwise
    // a hostile value could wrap around, slip past the bounds check and then
    // panic or over-allocate below.
    let data_end = (points as usize)
        .checked_mul(n_channels)
        .and_then(|n_values| n_values.checked_mul(SAMPLE_LEN))
        .and_then(|n_bytes| o.checked_add(n_bytes))
        .filter(|&end| end <= payload.len())
        .ok_or_else(|| BrainVisionError::Protocol("data16 payload truncated".into()))?;

    // One scale factor per channel, repeated for every point.
    let scales = (0..n_channels)
        .map(|ch| header.resolutions_uv.get(ch).copied().unwrap_or(1.0))
        .cycle();
    let samples_uv: Vec<f64> = payload[o..data_end]
        .chunks_exact(SAMPLE_LEN)
        .zip(scales)
        .map(|(bytes, res)| f64::from(i16::from_le_bytes([bytes[0], bytes[1]])) * res)
        .collect();
    o = data_end;

    let markers = decode_markers(payload, &mut o, n_markers)?;
    Ok(DataBlock {
        block,
        points,
        samples_uv,
        markers,
    })
}
/// Decodes the payload of a DATA32 message using the channel layout in `header`.
///
/// Layout (little-endian):
/// * `block: u32`, `points: u32`, `n_markers: u32`;
/// * `points × header.channel_count` `f32` samples, interleaved by point
///   (sample `i` belongs to channel `i % channel_count`);
/// * `n_markers` marker records.
///
/// Each sample is widened to `f64` and multiplied by its channel's entry in
/// `header.resolutions_uv`, or by `1.0` when that entry is missing.
///
/// # Errors
/// Returns `BrainVisionError::Protocol` in these cases:
/// * the 12-byte prefix is truncated;
/// * the payload is too short for the declared samples, including counts so
///   large that their byte size would overflow `usize`;
/// * the marker section is malformed.
fn decode_data_f32(payload: &[u8], header: &HeaderInfo) -> Result<DataBlock, BrainVisionError> {
    const SAMPLE_LEN: usize = 4;

    let mut o = 0usize;
    let block = read_u32_le(payload, &mut o)?;
    let points = read_u32_le(payload, &mut o)?;
    let n_markers = read_u32_le(payload, &mut o)?;
    let n_channels = header.channel_count as usize;

    // `points` comes off the wire, so the size arithmetic is checked. Otherwise
    // a hostile value could wrap around, slip past the bounds check and then
    // panic or over-allocate below.
    let data_end = (points as usize)
        .checked_mul(n_channels)
        .and_then(|n_values| n_values.checked_mul(SAMPLE_LEN))
        .and_then(|n_bytes| o.checked_add(n_bytes))
        .filter(|&end| end <= payload.len())
        .ok_or_else(|| BrainVisionError::Protocol("data32 payload truncated".into()))?;

    // One scale factor per channel, repeated for every point.
    let scales = (0..n_channels)
        .map(|ch| header.resolutions_uv.get(ch).copied().unwrap_or(1.0))
        .cycle();
    let samples_uv: Vec<f64> = payload[o..data_end]
        .chunks_exact(SAMPLE_LEN)
        .zip(scales)
        .map(|(bytes, res)| {
            f64::from(f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])) * res
        })
        .collect();
    o = data_end;

    let markers = decode_markers(payload, &mut o, n_markers)?;
    Ok(DataBlock {
        block,
        points,
        samples_uv,
        markers,
    })
}
/// Decodes one complete RDA frame (envelope followed by payload) into an `RdaMessage`.
///
/// Envelope layout:
/// * bytes `0..16` hold the message GUID, which selects the payload decoder;
/// * bytes `16..20` hold a little-endian `u32` that must equal `frame.len()`;
/// * the payload begins at `ENVELOPE_LEN`.
///
/// `header_ctx` is only consulted for DATA16/DATA32 frames, whose samples are
/// laid out and scaled according to a previously decoded START header.
///
/// # Errors
/// * `BrainVisionError::InvalidMessage` if the frame is shorter than
///   `ENVELOPE_LEN` or its declared size differs from its actual length.
/// * `BrainVisionError::Protocol` for an unknown GUID, a data frame received
///   with `header_ctx == None`, or a malformed payload.
pub fn decode_frame(
    frame: &[u8],
    header_ctx: Option<&HeaderInfo>,
) -> Result<RdaMessage, BrainVisionError> {
    if frame.len() < ENVELOPE_LEN {
        return Err(BrainVisionError::InvalidMessage("frame too short".into()));
    }
    // NOTE(review): the fixed indexing below assumes ENVELOPE_LEN >= 20
    // (16-byte GUID + 4-byte size field); confirm in `types`.
    let guid: [u8; 16] = frame[..16]
        .try_into()
        .expect("a 16-byte slice converts to [u8; 16]");
    let declared = u32::from_le_bytes([frame[16], frame[17], frame[18], frame[19]]) as usize;
    let actual = frame.len();
    if declared != actual {
        return Err(BrainVisionError::InvalidMessage(format!(
            "size mismatch: header={}, actual={}",
            declared, actual
        )));
    }

    let payload = &frame[ENVELOPE_LEN..];
    // Data frames cannot be interpreted without the channel layout from START.
    let require_header = |context: &'static str| {
        header_ctx.ok_or_else(|| BrainVisionError::Protocol(context.into()))
    };

    if guid == GUID_START {
        decode_start(payload).map(RdaMessage::Start)
    } else if guid == GUID_DATA16 {
        let header = require_header("DATA16 before START")?;
        decode_data_i16(payload, header).map(RdaMessage::Data16)
    } else if guid == GUID_DATA32 {
        let header = require_header("DATA32 before START")?;
        decode_data_f32(payload, header).map(RdaMessage::Data32)
    } else if guid == GUID_STOP {
        Ok(RdaMessage::Stop)
    } else {
        Err(BrainVisionError::Protocol("unknown RDA message GUID".into()))
    }
}
pub fn make_frame(guid: [u8; 16], payload: &[u8]) -> Vec<u8> {
let mut out = Vec::with_capacity(ENVELOPE_LEN + payload.len());
out.extend_from_slice(&guid);
out.extend_from_slice(&((ENVELOPE_LEN + payload.len()) as u32).to_le_bytes());
out.extend_from_slice(payload);
out
}
#[cfg(test)]
mod tests {
    use super::*;

    fn push_f64(v: &mut Vec<u8>, x: f64) {
        v.extend_from_slice(&x.to_le_bytes());
    }

    /// Encodes one marker record: 16-byte header, then the two NUL-terminated strings.
    fn marker_bytes(position: u32, points: u32, channel: i32, kind: &str, desc: &str) -> Vec<u8> {
        let mut m = Vec::new();
        let size = 16 + kind.len() + 1 + desc.len() + 1;
        m.extend_from_slice(&(size as u32).to_le_bytes());
        m.extend_from_slice(&position.to_le_bytes());
        m.extend_from_slice(&points.to_le_bytes());
        m.extend_from_slice(&channel.to_le_bytes());
        m.extend_from_slice(kind.as_bytes());
        m.push(0);
        m.extend_from_slice(desc.as_bytes());
        m.push(0);
        m
    }

    /// Builds a header with the given per-channel resolutions and synthetic names.
    fn header_with(channel_count: u32, resolutions: Vec<f64>) -> HeaderInfo {
        HeaderInfo {
            channel_count,
            sampling_interval_us: 1000.0,
            resolutions_uv: resolutions,
            channel_names: (0..channel_count).map(|i| format!("Ch{}", i)).collect(),
        }
    }

    /// Encodes the fixed 12-byte prefix shared by DATA16 and DATA32 payloads.
    fn data_prefix(block: u32, points: u32, n_markers: u32) -> Vec<u8> {
        let mut p = Vec::new();
        p.extend_from_slice(&block.to_le_bytes());
        p.extend_from_slice(&points.to_le_bytes());
        p.extend_from_slice(&n_markers.to_le_bytes());
        p
    }

    #[test]
    fn test_decode_start() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&2u32.to_le_bytes());
        push_f64(&mut payload, 2000.0);
        push_f64(&mut payload, 0.1);
        push_f64(&mut payload, 0.1);
        payload.extend_from_slice(b"Cz\0Fz\0\0");
        let frame = make_frame(GUID_START, &payload);
        let msg = decode_frame(&frame, None).unwrap();
        match msg {
            RdaMessage::Start(h) => {
                assert_eq!(h.channel_count, 2);
                assert_eq!(h.channel_names, vec!["Cz", "Fz"]);
                assert!((h.sampling_rate_hz() - 500.0).abs() < 1e-9);
            }
            _ => panic!(),
        }
    }

    #[test]
    fn test_decode_data16() {
        let header = HeaderInfo {
            channel_count: 2,
            sampling_interval_us: 2000.0,
            resolutions_uv: vec![0.1, 0.1],
            channel_names: vec!["Cz".into(), "Fz".into()],
        };
        let mut payload = Vec::new();
        payload.extend_from_slice(&1u32.to_le_bytes());
        payload.extend_from_slice(&2u32.to_le_bytes());
        payload.extend_from_slice(&0u32.to_le_bytes());
        for x in [100i16, -100, 200, -200] {
            payload.extend_from_slice(&x.to_le_bytes());
        }
        let frame = make_frame(GUID_DATA16, &payload);
        let msg = decode_frame(&frame, Some(&header)).unwrap();
        match msg {
            RdaMessage::Data16(d) => {
                assert_eq!(d.block, 1);
                assert_eq!(d.points, 2);
                assert_eq!(d.samples_uv.len(), 4);
                assert!((d.samples_uv[0] - 10.0).abs() < 1e-9);
                assert!((d.samples_uv[1] + 10.0).abs() < 1e-9);
            }
            _ => panic!(),
        }
    }

    #[test]
    fn test_decode_data32() {
        let header = HeaderInfo {
            channel_count: 2,
            sampling_interval_us: 2000.0,
            resolutions_uv: vec![1.0, 1.0],
            channel_names: vec!["Cz".into(), "Fz".into()],
        };
        let mut payload = Vec::new();
        payload.extend_from_slice(&3u32.to_le_bytes());
        payload.extend_from_slice(&2u32.to_le_bytes());
        payload.extend_from_slice(&0u32.to_le_bytes());
        for x in [1.5f32, -2.0, 3.0, -4.5] {
            payload.extend_from_slice(&x.to_le_bytes());
        }
        let frame = make_frame(GUID_DATA32, &payload);
        let msg = decode_frame(&frame, Some(&header)).unwrap();
        match msg {
            RdaMessage::Data32(d) => {
                assert_eq!(d.block, 3);
                assert_eq!(d.samples_uv.len(), 4);
                assert!((d.samples_uv[0] - 1.5).abs() < 1e-9);
                assert!((d.samples_uv[3] + 4.5).abs() < 1e-9);
            }
            _ => panic!(),
        }
    }

    #[test]
    fn test_decode_multiple_markers() {
        let header = HeaderInfo {
            channel_count: 1,
            sampling_interval_us: 1000.0,
            resolutions_uv: vec![1.0],
            channel_names: vec!["Cz".into()],
        };
        let mut payload = Vec::new();
        payload.extend_from_slice(&1u32.to_le_bytes());
        payload.extend_from_slice(&1u32.to_le_bytes());
        payload.extend_from_slice(&2u32.to_le_bytes());
        payload.extend_from_slice(&10i16.to_le_bytes());
        payload.extend_from_slice(&marker_bytes(10, 1, -1, "Stimulus", "S 11"));
        payload.extend_from_slice(&marker_bytes(20, 1, -1, "Response", "R 1"));
        let frame = make_frame(GUID_DATA16, &payload);
        let msg = decode_frame(&frame, Some(&header)).unwrap();
        match msg {
            RdaMessage::Data16(d) => {
                assert_eq!(d.markers.len(), 2);
                assert_eq!(d.markers[0].kind, "Stimulus");
                assert_eq!(d.markers[1].description, "R 1");
            }
            _ => panic!(),
        }
    }

    #[test]
    fn test_marker_malformed_size() {
        let header = HeaderInfo {
            channel_count: 1,
            sampling_interval_us: 1000.0,
            resolutions_uv: vec![1.0],
            channel_names: vec!["Cz".into()],
        };
        let mut payload = Vec::new();
        payload.extend_from_slice(&1u32.to_le_bytes());
        payload.extend_from_slice(&1u32.to_le_bytes());
        payload.extend_from_slice(&1u32.to_le_bytes());
        payload.extend_from_slice(&10i16.to_le_bytes());
        payload.extend_from_slice(&8u32.to_le_bytes());
        let frame = make_frame(GUID_DATA16, &payload);
        assert!(decode_frame(&frame, Some(&header)).is_err());
    }

    #[test]
    fn test_unknown_guid() {
        let frame = make_frame([1u8; 16], &[]);
        assert!(decode_frame(&frame, None).is_err());
    }

    #[test]
    fn test_decode_stop() {
        let frame = make_frame(GUID_STOP, &[]);
        assert!(matches!(
            decode_frame(&frame, None).unwrap(),
            RdaMessage::Stop
        ));
    }

    #[test]
    fn test_size_mismatch() {
        let mut frame = make_frame(GUID_STOP, &[]);
        frame[16..20].copy_from_slice(&999u32.to_le_bytes());
        assert!(decode_frame(&frame, None).is_err());
    }

    #[test]
    fn test_frame_too_short_fuzz() {
        for n in 0..ENVELOPE_LEN {
            let frame = vec![0u8; n];
            assert!(decode_frame(&frame, None).is_err());
        }
    }

    #[test]
    fn test_make_frame_layout() {
        let frame = make_frame(GUID_STOP, b"xyz");
        assert_eq!(&frame[..16], &GUID_STOP[..]);
        let size = u32::from_le_bytes([frame[16], frame[17], frame[18], frame[19]]) as usize;
        assert_eq!(size, frame.len());
        assert_eq!(&frame[frame.len() - 3..], &b"xyz"[..]);
    }

    #[test]
    fn test_data_before_start_is_rejected() {
        let payload = data_prefix(1, 0, 0);
        for guid in [GUID_DATA16, GUID_DATA32] {
            let frame = make_frame(guid, &payload);
            assert!(decode_frame(&frame, None).is_err());
        }
    }

    #[test]
    fn test_decode_data16_applies_per_channel_resolution() {
        // Samples are interleaved by point: [p0c0, p0c1, p1c0, p1c1].
        let header = header_with(2, vec![0.1, 0.5]);
        let mut payload = data_prefix(1, 2, 0);
        for x in [100i16, -100, 200, -200] {
            payload.extend_from_slice(&x.to_le_bytes());
        }
        let frame = make_frame(GUID_DATA16, &payload);
        match decode_frame(&frame, Some(&header)).unwrap() {
            RdaMessage::Data16(d) => {
                let expected = [10.0, -50.0, 20.0, -100.0];
                assert_eq!(d.samples_uv.len(), expected.len());
                for (got, want) in d.samples_uv.iter().zip(expected.iter()) {
                    assert!((got - want).abs() < 1e-9, "got {}, want {}", got, want);
                }
            }
            _ => panic!("expected Data16"),
        }
    }

    #[test]
    fn test_missing_resolution_defaults_to_one() {
        let header = header_with(2, vec![0.5]);
        let mut payload = data_prefix(7, 1, 0);
        for x in [4.0f32, 3.0] {
            payload.extend_from_slice(&x.to_le_bytes());
        }
        let frame = make_frame(GUID_DATA32, &payload);
        match decode_frame(&frame, Some(&header)).unwrap() {
            RdaMessage::Data32(d) => {
                assert_eq!(d.block, 7);
                assert_eq!(d.samples_uv.len(), 2);
                assert!((d.samples_uv[0] - 2.0).abs() < 1e-9);
                assert!((d.samples_uv[1] - 3.0).abs() < 1e-9);
            }
            _ => panic!("expected Data32"),
        }
    }

    #[test]
    fn test_data16_truncated_samples() {
        // 2 points x 2 channels need 8 sample bytes; only 2 are present.
        let header = header_with(2, vec![1.0, 1.0]);
        let mut payload = data_prefix(1, 2, 0);
        payload.extend_from_slice(&100i16.to_le_bytes());
        let frame = make_frame(GUID_DATA16, &payload);
        assert!(decode_frame(&frame, Some(&header)).is_err());
    }

    #[test]
    fn test_marker_padding_is_skipped() {
        let header = header_with(1, vec![1.0]);
        let mut payload = data_prefix(1, 1, 2);
        payload.extend_from_slice(&10i16.to_le_bytes());
        // First record declares 2 extra bytes of padding after its strings.
        let mut padded = marker_bytes(10, 1, -1, "Stimulus", "S 11");
        let padded_size = padded.len() as u32 + 2;
        padded[..4].copy_from_slice(&padded_size.to_le_bytes());
        padded.extend_from_slice(&[0xAA, 0xBB]);
        payload.extend_from_slice(&padded);
        payload.extend_from_slice(&marker_bytes(20, 1, 0, "Response", "R 1"));
        let frame = make_frame(GUID_DATA16, &payload);
        match decode_frame(&frame, Some(&header)).unwrap() {
            RdaMessage::Data16(d) => {
                assert_eq!(d.markers.len(), 2);
                assert_eq!(d.markers[0].description, "S 11");
                assert_eq!(d.markers[1].position, 20);
                assert_eq!(d.markers[1].channel, 0);
                assert_eq!(d.markers[1].kind, "Response");
            }
            _ => panic!("expected Data16"),
        }
    }

    #[test]
    fn test_marker_overruns_payload() {
        let header = header_with(1, vec![1.0]);
        let mut payload = data_prefix(1, 1, 1);
        payload.extend_from_slice(&10i16.to_le_bytes());
        let mut marker = marker_bytes(10, 1, -1, "Stimulus", "S 11");
        marker[..4].copy_from_slice(&100u32.to_le_bytes());
        payload.extend_from_slice(&marker);
        let frame = make_frame(GUID_DATA16, &payload);
        assert!(decode_frame(&frame, Some(&header)).is_err());
    }

    #[test]
    fn test_marker_string_not_terminated_within_record() {
        let header = header_with(1, vec![1.0]);
        let mut payload = data_prefix(1, 1, 1);
        payload.extend_from_slice(&10i16.to_le_bytes());
        // Declared size ends mid-description, so its NUL lies outside the record.
        let mut marker = marker_bytes(10, 1, -1, "Stimulus", "S 11");
        let short = (16 + "Stimulus".len() + 1 + 2) as u32;
        marker[..4].copy_from_slice(&short.to_le_bytes());
        payload.extend_from_slice(&marker);
        let frame = make_frame(GUID_DATA16, &payload);
        assert!(decode_frame(&frame, Some(&header)).is_err());
    }

    #[test]
    fn test_start_truncated_resolutions() {
        // Declares 3 channels but carries only one resolution.
        let mut payload = Vec::new();
        payload.extend_from_slice(&3u32.to_le_bytes());
        push_f64(&mut payload, 2000.0);
        push_f64(&mut payload, 0.1);
        let frame = make_frame(GUID_START, &payload);
        assert!(decode_frame(&frame, None).is_err());
    }

    #[test]
    fn test_start_unterminated_channel_name() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&2u32.to_le_bytes());
        push_f64(&mut payload, 2000.0);
        push_f64(&mut payload, 0.1);
        push_f64(&mut payload, 0.1);
        payload.extend_from_slice(b"Cz\0Fz");
        let frame = make_frame(GUID_START, &payload);
        assert!(decode_frame(&frame, None).is_err());
    }

    #[test]
    fn test_start_names_stop_at_payload_end() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&2u32.to_le_bytes());
        push_f64(&mut payload, 2000.0);
        push_f64(&mut payload, 0.1);
        push_f64(&mut payload, 0.1);
        payload.extend_from_slice(b"Cz\0Fz\0");
        let frame = make_frame(GUID_START, &payload);
        match decode_frame(&frame, None).unwrap() {
            RdaMessage::Start(h) => assert_eq!(h.channel_names, vec!["Cz", "Fz"]),
            _ => panic!("expected Start"),
        }
    }
}