use std::collections::VecDeque;
use std::io::Read;
use std::net::TcpStream;
use std::time::{Duration, Instant};
use crate::error::BrainVisionError;
use crate::protocol::decode_frame;
use crate::types::*;
/// One "scan" cut from an RDA data block: the samples taken at a single time point.
///
/// `read_scans` fills `data` with one chunk of `channel_count` values taken from the block's
/// `samples_uv` (presumably microvolts, judging by the field name — confirm in `types`).
#[derive(Clone, Debug)]
pub struct Scan {
    /// Sample values for this scan, in the channel order of the data block.
    pub data: Vec<f64>,
}

impl Scan {
    /// Borrows the per-channel sample values of this scan.
    pub fn eeg(&self) -> &[f64] {
        self.data.as_slice()
    }
}
/// What to do with a newly decoded scan once the internal scan buffer has reached
/// `DeviceConfig::max_scan_buffer` entries.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BackpressurePolicy {
    /// Evict the oldest buffered scan to make room for the incoming one.
    DropOldest,
    /// Discard the incoming scan and leave the buffer unchanged.
    DropNewest,
    /// Apply backpressure instead of immediately discarding the incoming scan.
    Block,
}

/// Connection and buffering options for a `BrainVisionDevice`.
#[derive(Clone, Debug)]
pub struct DeviceConfig {
    /// Socket read timeout, applied when connecting and again on every reconnect.
    pub read_timeout: Duration,
    /// Maximum number of scans held in the internal buffer; `0` disables the limit.
    pub max_scan_buffer: usize,
    /// Policy used once `max_scan_buffer` scans are already buffered.
    pub backpressure_policy: BackpressurePolicy,
}

impl Default for DeviceConfig {
    /// A 1.5 s read timeout, a 16 Ki-scan buffer, and `DropOldest` backpressure.
    fn default() -> Self {
        let read_timeout = Duration::from_millis(1_500);
        let max_scan_buffer = 16 * 1024;
        Self {
            read_timeout,
            max_scan_buffer,
            backpressure_policy: BackpressurePolicy::DropOldest,
        }
    }
}
/// Running diagnostics collected while data blocks are read from the device.
#[derive(Clone, Debug, Default)]
pub struct DeviceStats {
    /// Block counter of the most recently processed data block.
    pub last_block: Option<u32>,
    /// Number of blocks inferred missing from gaps in the block counter.
    pub dropped_blocks: u64,
    /// Estimated number of samples lost, derived from block-counter gaps and inter-block timing.
    pub dropped_samples_estimate: u64,
    /// Number of scans discarded by the scan-buffer backpressure policy.
    pub dropped_by_backpressure: u64,
    /// `"kind:description"` of the last marker seen in a data block.
    pub last_marker: Option<String>,
    /// Wall-clock time between the two most recently processed data blocks, measured when the
    /// reader processed them; `None` for the first block after (re)connecting.
    pub last_block_dt: Option<Duration>,
    /// Nominal block duration: points in the block divided by the header's sampling rate.
    pub expected_block_dt: Option<Duration>,
}
/// Hook invoked with each marker carried by an incoming data block.
type MarkerCallback = Box<dyn Fn(&Marker) + Sync + Send>;

/// Blocking client for a BrainVision RDA TCP stream.
///
/// Owns the socket, the most recent header, a buffer of decoded scans, and running
/// [`DeviceStats`].
pub struct BrainVisionDevice {
    /// `"host:port"` used for the initial connection and for every reconnect.
    endpoint: String,
    /// Live TCP connection to the RDA server.
    stream: TcpStream,
    /// Header from the most recent Start message; data frames are decoded against it.
    header: Option<HeaderInfo>,
    /// Scans decoded from a data block but not yet handed out by `next_scan`.
    scan_buf: VecDeque<Scan>,
    /// Running diagnostics exposed through `stats()`.
    stats: DeviceStats,
    /// When the previous data block was processed; cleared on reconnect.
    last_data_at: Option<Instant>,
    /// Point count of the previous data block, used to size dropped-sample estimates.
    last_points: Option<u32>,
    /// Options supplied at connect time.
    config: DeviceConfig,
    /// Optional per-marker hook installed with `set_marker_callback`.
    marker_callback: Option<MarkerCallback>,
}
impl BrainVisionDevice {
    /// Largest RDA frame (envelope included) that [`Self::read_frame`] will accept.
    ///
    /// The frame length is read straight off the wire. Without a cap, a corrupt or
    /// desynchronised stream could request an allocation of up to 4 GiB.
    /// NOTE(review): 64 MiB is far above any block size we expect; raise it if a deployment
    /// legitimately sends larger frames.
    const MAX_FRAME_LEN: usize = 64 * 1024 * 1024;

    /// Connects to `host:port` using [`DeviceConfig::default`].
    ///
    /// # Errors
    /// Same as [`Self::connect_with_config`].
    pub fn connect(host: &str, port: u16) -> Result<Self, BrainVisionError> {
        Self::connect_with_config(host, port, DeviceConfig::default())
    }

    /// Connects to `host:port` with an explicit [`DeviceConfig`].
    ///
    /// Only the TCP connection is established here; no RDA message is read. Use
    /// [`Self::wait_for_start`] (or any read method) to obtain the header.
    ///
    /// # Errors
    /// * `NotSupported`: the `sandbox` feature is enabled and the endpoint is not allowlisted.
    /// * `Connection`: the TCP connection could not be established.
    /// * `Io`: the read timeout could not be applied to the socket. The std socket API
    ///   rejects a zero `read_timeout`, for example.
    pub fn connect_with_config(
        host: &str,
        port: u16,
        config: DeviceConfig,
    ) -> Result<Self, BrainVisionError> {
        #[cfg(feature = "sandbox")]
        if !crate::sandbox::endpoint_allowed(host, port) {
            return Err(BrainVisionError::NotSupported(format!(
                "endpoint {host}:{port} not in allowlist; call allow_only_endpoint first"
            )));
        }
        let endpoint = format!("{host}:{port}");
        let stream = Self::open_stream(&endpoint, config.read_timeout)?;
        Ok(Self {
            endpoint,
            stream,
            header: None,
            scan_buf: VecDeque::new(),
            stats: DeviceStats::default(),
            last_data_at: None,
            last_points: None,
            config,
            marker_callback: None,
        })
    }

    /// Opens a TCP connection to `endpoint` and applies `read_timeout` to it.
    ///
    /// Shared by the initial connect and [`Self::reconnect`] so both configure the socket
    /// identically.
    fn open_stream(endpoint: &str, read_timeout: Duration) -> Result<TcpStream, BrainVisionError> {
        let stream = TcpStream::connect(endpoint)
            .map_err(|e| BrainVisionError::Connection(e.to_string()))?;
        stream
            .set_read_timeout(Some(read_timeout))
            .map_err(|e| BrainVisionError::Io(e.to_string()))?;
        Ok(stream)
    }

    /// Connects to `host` on the crate's `RDA_PORT_I16` port with the default configuration.
    ///
    /// NOTE(review): presumably the port serving 16-bit sample data; confirm in `types`.
    ///
    /// # Errors
    /// Same as [`Self::connect_with_config`].
    pub fn connect_default(host: &str) -> Result<Self, BrainVisionError> {
        Self::connect(host, RDA_PORT_I16)
    }

    /// Installs a callback that receives every marker of every data block processed by
    /// [`Self::next_block`] and the scan helpers built on it. Replaces any previous callback.
    pub fn set_marker_callback<F>(&mut self, cb: F)
    where
        F: Fn(&Marker) + Send + Sync + 'static,
    {
        self.marker_callback = Some(Box::new(cb));
    }

    /// Replaces the connection with a fresh one to the same endpoint.
    ///
    /// On success the old socket is dropped, buffered scans are discarded, and the timing and
    /// point-count state used for drop estimation is reset. The cached header and the
    /// cumulative [`DeviceStats`] are kept. On failure the old socket is left in place.
    ///
    /// # Errors
    /// `Connection` or `Io`, as in [`Self::connect_with_config`].
    pub fn reconnect(&mut self) -> Result<(), BrainVisionError> {
        self.stream = Self::open_stream(&self.endpoint, self.config.read_timeout)?;
        self.scan_buf.clear();
        self.last_data_at = None;
        self.last_points = None;
        Ok(())
    }

    /// Calls [`Self::reconnect`] up to `retries + 1` times.
    ///
    /// Between failed attempts it sleeps `base_delay`, then `2 * base_delay`, and so on,
    /// doubling with saturation. This blocks the calling thread.
    ///
    /// # Errors
    /// `Connection` if every attempt fails.
    pub fn reconnect_with_backoff(
        &mut self,
        retries: u32,
        base_delay: Duration,
    ) -> Result<(), BrainVisionError> {
        let mut delay = base_delay;
        for attempt in 0..=retries {
            if self.reconnect().is_ok() {
                // Widen before +1 so `retries == u32::MAX` cannot overflow.
                log::info!(
                    "Reconnected to {} on attempt {}",
                    self.endpoint,
                    u64::from(attempt) + 1
                );
                return Ok(());
            }
            if attempt < retries {
                std::thread::sleep(delay);
                delay = delay.saturating_mul(2);
            }
        }
        Err(BrainVisionError::Connection(format!(
            "failed to reconnect to {}",
            self.endpoint
        )))
    }

    /// Header from the most recent Start message, if one has been received.
    pub fn header(&self) -> Option<&HeaderInfo> {
        self.header.as_ref()
    }

    /// Running diagnostics for this connection.
    pub fn stats(&self) -> &DeviceStats {
        &self.stats
    }

    /// Reads one complete frame (envelope plus payload) from the socket.
    ///
    /// Bytes `16..20` of the envelope hold the total frame length, envelope included, as a
    /// little-endian `u32`.
    ///
    /// If a read fails partway through a frame (e.g. on the read timeout), the bytes already
    /// consumed are lost (`read_exact` leaves the amount read unspecified). The stream is then
    /// no longer frame-aligned, so reconnect before reading again; the `*_resilient` readers
    /// do this for transient errors.
    ///
    /// # Errors
    /// * `Protocol` if the declared length is smaller than the envelope or larger than
    ///   [`Self::MAX_FRAME_LEN`].
    /// * Socket errors, converted from `std::io::Error`.
    fn read_frame(&mut self) -> Result<Vec<u8>, BrainVisionError> {
        let mut frame = vec![0u8; ENVELOPE_LEN];
        self.stream.read_exact(&mut frame)?;
        let size =
            u32::from_le_bytes(frame[16..20].try_into().expect("a 4-byte slice")) as usize;
        if size < ENVELOPE_LEN {
            return Err(BrainVisionError::Protocol("message size < envelope".into()));
        }
        if size > Self::MAX_FRAME_LEN {
            return Err(BrainVisionError::Protocol(format!(
                "message size {size} exceeds limit of {} bytes",
                Self::MAX_FRAME_LEN
            )));
        }
        // Grow in place and read the payload straight into the tail (no temporary + copy).
        // For an envelope-only frame the slice is empty and `read_exact` returns immediately.
        frame.resize(size, 0);
        self.stream.read_exact(&mut frame[ENVELOPE_LEN..])?;
        Ok(frame)
    }

    /// Reads and decodes the next message of any kind.
    ///
    /// A Start message replaces the cached header that later data frames are decoded against.
    ///
    /// # Errors
    /// Framing/socket errors from [`Self::read_frame`] and decoding errors from
    /// `decode_frame`.
    pub fn read_message(&mut self) -> Result<RdaMessage, BrainVisionError> {
        let frame = self.read_frame()?;
        let msg = decode_frame(&frame, self.header.as_ref())?;
        if let RdaMessage::Start(h) = &msg {
            self.header = Some(h.clone());
        }
        Ok(msg)
    }

    /// Runs `op`. On a transient error (see `is_transient`), reconnects with backoff and runs
    /// `op` exactly once more. Non-transient errors are returned unchanged.
    fn retry_once_after_reconnect<T>(
        &mut self,
        retries: u32,
        base_delay: Duration,
        mut op: impl FnMut(&mut Self) -> Result<T, BrainVisionError>,
    ) -> Result<T, BrainVisionError> {
        match op(self) {
            Err(e) if is_transient(&e) => {
                self.reconnect_with_backoff(retries, base_delay)?;
                op(self)
            }
            result => result,
        }
    }

    /// Like [`Self::read_message`], but reconnects (with backoff) and retries once on a
    /// transient error.
    ///
    /// # Errors
    /// The reconnect error if reconnecting fails, otherwise the result of the retry.
    pub fn read_message_resilient(
        &mut self,
        retries: u32,
        base_delay: Duration,
    ) -> Result<RdaMessage, BrainVisionError> {
        self.retry_once_after_reconnect(retries, base_delay, Self::read_message)
    }

    /// Reads and discards messages until a Start message arrives, then returns its header.
    ///
    /// The header is also cached by [`Self::read_message`]. Blocks until a Start message
    /// arrives or a read fails (e.g. on the socket read timeout).
    pub fn wait_for_start(&mut self) -> Result<HeaderInfo, BrainVisionError> {
        loop {
            if let RdaMessage::Start(h) = self.read_message()? {
                return Ok(h);
            }
        }
    }

    /// Folds one data block into the running [`DeviceStats`] and fires the marker callback.
    fn update_stats_from_block(&mut self, b: &DataBlock) {
        // Block-counter gaps: consecutive blocks should differ by exactly one. Saturating
        // arithmetic keeps these from overflowing (and panicking in debug builds):
        // - a repeated or backwards counter (e.g. after a reconnect) gives a gap of 0;
        // - a previous counter at `u32::MAX` is handled without wrap-around.
        let counter_gap = self
            .stats
            .last_block
            .map_or(0, |prev| u64::from(b.block.saturating_sub(prev).saturating_sub(1)));
        if counter_gap > 0 {
            self.stats.dropped_blocks = self.stats.dropped_blocks.saturating_add(counter_gap);
            if let Some(ppb) = self.last_points {
                // Both factors fit in u32, so the product cannot overflow u64.
                let lost = counter_gap * u64::from(ppb);
                self.stats.dropped_samples_estimate =
                    self.stats.dropped_samples_estimate.saturating_add(lost);
            }
        }
        self.stats.last_block = Some(b.block);
        self.last_points = Some(b.points);

        if let Some(m) = b.markers.last() {
            self.stats.last_marker = Some(format!("{}:{}", m.kind, m.description));
        }
        if let Some(cb) = &self.marker_callback {
            for m in &b.markers {
                cb(m);
            }
        }

        let now = Instant::now();
        self.stats.last_block_dt = self.last_data_at.map(|t| now.saturating_duration_since(t));
        self.last_data_at = Some(now);

        let rate = match &self.header {
            Some(h) => h.sampling_rate_hz(),
            None => return,
        };
        if !(rate.is_finite() && rate > 0.0) {
            return;
        }
        let expected_secs = f64::from(b.points) / rate;
        // `Duration::from_secs_f64` panics on values that are non-finite or exceed
        // `Duration::MAX`; a corrupt header with a tiny sampling rate can produce either.
        if expected_secs >= Duration::MAX.as_secs_f64() {
            return;
        }
        self.stats.expected_block_dt = Some(Duration::from_secs_f64(expected_secs));

        // Timing heuristic: a block arriving much later than nominal suggests lost samples.
        // It is skipped when the block counter already accounted for a gap, so the same loss
        // is not counted twice. Caveat: the gap is measured when the reader processes the
        // block, so a slow consumer inflates it even if no data was lost.
        let observed_secs = match self.stats.last_block_dt {
            Some(dt) if counter_gap == 0 => dt.as_secs_f64(),
            _ => return,
        };
        if observed_secs > expected_secs * 1.5 {
            let extra_samples = ((observed_secs - expected_secs) * rate).max(0.0) as u64;
            self.stats.dropped_samples_estimate =
                self.stats.dropped_samples_estimate.saturating_add(extra_samples);
        }
    }

    /// Returns the next data block, 16- or 32-bit, skipping other message kinds.
    ///
    /// Returns `Ok(None)` when a Stop message arrives. Each returned block updates
    /// [`DeviceStats`] and fires the marker callback.
    pub fn next_block(&mut self) -> Result<Option<DataBlock>, BrainVisionError> {
        loop {
            match self.read_message()? {
                RdaMessage::Data16(b) | RdaMessage::Data32(b) => {
                    self.update_stats_from_block(&b);
                    return Ok(Some(b));
                }
                RdaMessage::Stop => return Ok(None),
                _ => {}
            }
        }
    }

    /// Like [`Self::next_block`], but reconnects (with backoff) and retries once on a
    /// transient error.
    pub fn next_block_resilient(
        &mut self,
        retries: u32,
        base_delay: Duration,
    ) -> Result<Option<DataBlock>, BrainVisionError> {
        self.retry_once_after_reconnect(retries, base_delay, Self::next_block)
    }

    /// Reads the next data block and splits it into scans of `channel_count` samples each.
    ///
    /// Returns an empty vector when a Stop message arrives.
    ///
    /// # Errors
    /// * Errors from [`Self::next_block`].
    /// * `Protocol` if no header has been received.
    /// * `Protocol` if the header reports zero channels.
    /// * `Protocol` if the sample count is not a multiple of the channel count.
    pub fn read_scans(&mut self) -> Result<Vec<Scan>, BrainVisionError> {
        let b = match self.next_block()? {
            None => return Ok(Vec::new()),
            Some(b) => b,
        };
        let channels = self
            .header
            .as_ref()
            .map(|h| h.channel_count as usize)
            .ok_or_else(|| BrainVisionError::Protocol("no header context".into()))?;
        // A zero chunk size panics, and a trailing partial chunk would produce a scan with the
        // wrong channel count. Both mean the header and the data disagree.
        if channels == 0 {
            return Err(BrainVisionError::Protocol(
                "header reports zero channels".into(),
            ));
        }
        let chunks = b.samples_uv.chunks_exact(channels);
        if !chunks.remainder().is_empty() {
            return Err(BrainVisionError::Protocol(format!(
                "{} samples is not a multiple of {} channels",
                b.samples_uv.len(),
                channels
            )));
        }
        Ok(chunks.map(|c| Scan { data: c.to_vec() }).collect())
    }

    /// Appends `scan` to the scan buffer, honouring `max_scan_buffer` (`0` = unbounded)
    /// and the configured [`BackpressurePolicy`].
    fn enqueue_scan(&mut self, scan: Scan) {
        let cap = self.config.max_scan_buffer;
        if cap == 0 || self.scan_buf.len() < cap {
            self.scan_buf.push_back(scan);
            return;
        }
        match self.config.backpressure_policy {
            BackpressurePolicy::DropOldest => {
                self.scan_buf.pop_front();
                self.scan_buf.push_back(scan);
                self.stats.dropped_by_backpressure += 1;
            }
            BackpressurePolicy::DropNewest => {
                self.stats.dropped_by_backpressure += 1;
            }
            BackpressurePolicy::Block => {
                // Only this same `&mut self` drains the buffer (via `next_scan`), so waiting
                // here could never free space. `next_scan` stops pulling from the socket until
                // the buffer is empty, which makes the TCP connection the real backpressure.
                // Keep the scan: the buffer may exceed `cap` by at most the rest of one block.
                self.scan_buf.push_back(scan);
            }
        }
    }

    /// Returns the next scan, reading a new data block only when the buffer is empty.
    ///
    /// Returns `Ok(None)` when a Stop message arrives.
    pub fn next_scan(&mut self) -> Result<Option<Scan>, BrainVisionError> {
        if let Some(s) = self.scan_buf.pop_front() {
            return Ok(Some(s));
        }
        let scans = self.read_scans()?;
        if scans.is_empty() {
            return Ok(None);
        }
        for s in scans {
            self.enqueue_scan(s);
        }
        Ok(self.scan_buf.pop_front())
    }

    /// Collects up to `n_scans` scans.
    ///
    /// If no header has been seen yet, it first waits for a Start message. Fewer scans are
    /// returned if a Stop message arrives first.
    pub fn capture(&mut self, n_scans: u32) -> Result<Vec<Scan>, BrainVisionError> {
        if self.header.is_none() {
            let _ = self.wait_for_start()?;
        }
        let mut out = Vec::with_capacity(n_scans as usize);
        while out.len() < n_scans as usize {
            match self.next_scan()? {
                Some(s) => out.push(s),
                None => break,
            }
        }
        Ok(out)
    }

    /// Consumes the device; dropping it closes the underlying socket.
    pub fn close(self) {
        drop(self);
    }
}
/// Classifies errors a reconnect-and-retry may cure: timeouts, I/O failures and connection
/// failures.
///
/// Any other error (e.g. a protocol violation) is passed straight back to the caller by the
/// `*_resilient` readers.
fn is_transient(e: &BrainVisionError) -> bool {
    matches!(
        e,
        BrainVisionError::Timeout
            | BrainVisionError::Io(_)
            | BrainVisionError::Connection(_)
    )
}