use std::sync::Arc;
use std::time::Duration;
use anyhow::{anyhow, Result};
use futures_util::{SinkExt, StreamExt};
use log::{debug, info, warn};
use serde_json::Value;
use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::tungstenite::Message;
use crate::protocol::*;
use crate::types::*;
/// Configuration for a [`CortexClient`] connection.
///
/// NOTE(review): the derived `Debug` prints `client_secret` and `license` in
/// clear text — avoid logging this struct with `{:?}`.
#[derive(Debug, Clone)]
pub struct CortexClientConfig {
/// Application client id; sent with `hasAccessRight`, `requestAccess` and `authorize`.
pub client_id: String,
/// Application client secret; sent alongside `client_id`.
pub client_secret: String,
/// License key passed to `authorize` (empty by default).
pub license: String,
/// Debit value passed to `authorize` (default 10).
/// NOTE(review): presumably the number of sessions to debit — confirm against the Cortex API.
pub debit: i64,
/// Headset to use; empty selects the first headset in the `queryHeadsets` result.
pub headset_id: String,
/// When true, after authorization the client refreshes/queries headsets and
/// connects the headset or creates a session automatically.
pub auto_create_session: bool,
/// WebSocket URL of the Cortex service.
pub ws_url: String,
/// When true, each received text frame and each command queued through a
/// `CortexHandle` is printed to stderr.
pub debug_mode: bool,
}
impl Default for CortexClientConfig {
    /// Defaults: standard Cortex URL, `debit = 10`, automatic session creation
    /// on, debug output off, and empty credentials / license / headset id.
    fn default() -> Self {
        Self {
            ws_url: CORTEX_WS_URL.to_string(),
            debit: 10,
            auto_create_session: true,
            debug_mode: false,
            client_id: String::default(),
            client_secret: String::default(),
            license: String::default(),
            headset_id: String::default(),
        }
    }
}
/// Mutable connection state shared between the WebSocket loop, the
/// `CortexClient` accessors and every `CortexHandle`.
///
/// An empty string means the value is not known yet (the loop checks
/// `is_empty()` to decide whether to authorize or query headsets).
struct ClientState {
/// Token returned by `authorize`; empty until authorization succeeds.
auth_token: String,
/// Current session id; set on `createSession` success and cleared when Cortex
/// reports the session closed or the headset disconnected.
session_id: String,
/// Headset chosen from the `queryHeadsets` result.
headset_id: String,
}
/// Entry point for talking to the Cortex service over a WebSocket.
///
/// Create with `CortexClient::new`, then call `CortexClient::connect` to spawn
/// the connection task and obtain an event receiver plus a `CortexHandle`.
pub struct CortexClient {
config: CortexClientConfig,
/// Auth/session/headset state, shared with the connection task and handles.
state: Arc<Mutex<ClientState>>,
/// Slot for the active connection's outgoing command sender; `None` until
/// `connect` installs one.
ws_tx: Arc<Mutex<Option<mpsc::Sender<String>>>>,
}
impl CortexClient {
pub fn new(config: CortexClientConfig) -> Self {
Self {
config,
state: Arc::new(Mutex::new(ClientState {
auth_token: String::new(),
session_id: String::new(),
headset_id: String::new(),
})),
ws_tx: Arc::new(Mutex::new(None)),
}
}
pub async fn connect(&self) -> Result<(mpsc::Receiver<CortexEvent>, CortexHandle)> {
let (event_tx, event_rx) = mpsc::channel::<CortexEvent>(512);
let (cmd_tx, cmd_rx) = mpsc::channel::<String>(64);
{
let mut ws = self.ws_tx.lock().await;
*ws = Some(cmd_tx.clone());
}
let url = self.config.ws_url.clone();
let config = self.config.clone();
let state = Arc::clone(&self.state);
tokio::spawn(async move {
if let Err(e) = run_ws_loop(url, config, state, event_tx.clone(), cmd_rx).await {
warn!("WebSocket loop exited with error: {e}");
let _ = event_tx.send(CortexEvent::Error(e.to_string())).await;
let _ = event_tx.send(CortexEvent::Disconnected).await;
}
});
let handle = CortexHandle {
state: Arc::clone(&self.state),
ws_tx: Arc::clone(&self.ws_tx),
};
Ok((event_rx, handle))
}
pub async fn auth_token(&self) -> String {
self.state.lock().await.auth_token.clone()
}
pub async fn session_id(&self) -> String {
self.state.lock().await.session_id.clone()
}
pub async fn headset_id(&self) -> String {
self.state.lock().await.headset_id.clone()
}
}
/// Cloneable handle for issuing Cortex requests on a connection.
///
/// Returned by `CortexClient::connect`. Request methods read the current
/// auth token / session id / headset id from shared state and queue the JSON
/// request for the connection task; results arrive as `CortexEvent`s on the
/// receiver returned alongside this handle.
#[derive(Clone)]
pub struct CortexHandle {
/// Shared auth/session/headset state (same `Arc` as the owning client).
state: Arc<Mutex<ClientState>>,
/// Shared slot holding the outgoing command sender (same `Arc` as the client).
ws_tx: Arc<Mutex<Option<mpsc::Sender<String>>>>,
}
impl CortexHandle {
pub async fn send_raw(&self, request: Value) -> Result<()> {
let ws = self.ws_tx.lock().await;
if let Some(tx) = ws.as_ref() {
tx.send(request.to_string()).await
.map_err(|e| anyhow!("Failed to send: {e}"))?;
}
Ok(())
}
pub async fn subscribe(&self, streams: &[&str]) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(subscribe(&s.auth_token, &s.session_id, streams)).await
}
pub async fn unsubscribe(&self, streams: &[&str]) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(unsubscribe(&s.auth_token, &s.session_id, streams)).await
}
pub async fn create_record(&self, title: &str, description: &str) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(create_record(&s.auth_token, &s.session_id, title, description)).await
}
pub async fn stop_record(&self) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(stop_record(&s.auth_token, &s.session_id)).await
}
pub async fn export_record(
&self, folder: &str, format: &str, stream_types: &[&str],
record_ids: &[&str], version: &str,
) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(export_record(&s.auth_token, folder, format, stream_types, record_ids, version)).await
}
pub async fn inject_marker(&self, time: f64, value: &str, label: &str) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(inject_marker(&s.auth_token, &s.session_id, time, value, label)).await
}
pub async fn update_marker(&self, marker_id: &str, time: f64) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(update_marker(&s.auth_token, &s.session_id, marker_id, time)).await
}
pub async fn query_profile(&self) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(query_profile(&s.auth_token)).await
}
pub async fn get_current_profile(&self) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(get_current_profile(&s.auth_token, &s.headset_id)).await
}
pub async fn setup_profile(&self, profile_name: &str, status: &str) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(setup_profile(&s.auth_token, &s.headset_id, profile_name, status)).await
}
pub async fn train(&self, detection: &str, action: &str, status: &str) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(train_request(&s.auth_token, &s.session_id, detection, action, status)).await
}
pub async fn get_mc_active_action(&self, profile_name: &str) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(get_mental_command_active_action(&s.auth_token, profile_name)).await
}
pub async fn get_mc_sensitivity(&self, profile_name: &str) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(get_mental_command_sensitivity(&s.auth_token, profile_name)).await
}
pub async fn set_mc_sensitivity(&self, profile_name: &str, values: &[i32]) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(set_mental_command_sensitivity(&s.auth_token, profile_name, &s.session_id, values)).await
}
pub async fn get_mc_brain_map(&self, profile_name: &str) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(get_mental_command_brain_map(&s.auth_token, profile_name, &s.session_id)).await
}
pub async fn get_mc_training_threshold(&self) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(get_mental_command_training_threshold(&s.auth_token, &s.session_id)).await
}
pub async fn query_records(&self, query: Value) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(query_records(&s.auth_token, query)).await
}
pub async fn request_download_records(&self, record_ids: &[&str]) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(request_download_records(&s.auth_token, record_ids)).await
}
pub async fn sync_headset_clock(&self) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(sync_with_headset_clock(&s.headset_id)).await
}
pub async fn close_session(&self) -> Result<()> {
let s = self.state.lock().await;
self.send_raw(close_session(&s.auth_token, &s.session_id)).await
}
pub async fn query_headsets(&self) -> Result<()> {
self.send_raw(query_headsets()).await
}
pub async fn get_cortex_info(&self) -> Result<()> {
self.send_raw(get_cortex_info()).await
}
pub async fn auth_token(&self) -> String {
self.state.lock().await.auth_token.clone()
}
pub async fn session_id(&self) -> String {
self.state.lock().await.session_id.clone()
}
pub async fn headset_id(&self) -> String {
self.state.lock().await.headset_id.clone()
}
}
async fn run_ws_loop(
url: String,
config: CortexClientConfig,
state: Arc<Mutex<ClientState>>,
event_tx: mpsc::Sender<CortexEvent>,
mut cmd_rx: mpsc::Receiver<String>,
) -> Result<()> {
info!("Connecting to Cortex service at {url}");
let tls_connector = build_tls_connector()?;
let connector = tokio_tungstenite::Connector::Rustls(Arc::new(tls_connector));
let (ws_stream, _response) = match tokio_tungstenite::connect_async_tls_with_config(
&url,
None,
false,
Some(connector),
).await {
Ok(pair) => pair,
Err(e) => {
let msg = format!("WebSocket connection failed: {e}");
warn!("{msg}");
return Err(anyhow!("{msg}"));
}
};
info!("WebSocket connected to {url}");
let _ = event_tx.send(CortexEvent::Connected).await;
let (mut write, mut read) = ws_stream.split();
let auth_msg = has_access_right(&config.client_id, &config.client_secret);
write.send(Message::Text(auth_msg.to_string().into())).await
.map_err(|e| anyhow!("Failed to send auth message: {e}"))?;
loop {
tokio::select! {
msg = read.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
let text_str: &str = text.as_ref();
if config.debug_mode {
eprintln!("[emotiv-ws] recv: {text_str}");
}
match serde_json::from_str::<Value>(text_str) {
Ok(recv) => {
let (responses, session_ended) = handle_message(
&recv, &config, &state, &event_tx,
).await;
for resp in responses {
write.send(Message::Text(resp.into())).await
.map_err(|e| anyhow!("WebSocket write failed: {e}"))?;
}
if session_ended {
info!("Session ended — closing WebSocket loop");
break;
}
}
Err(e) => {
warn!("Failed to parse WS message: {e}");
}
}
}
Some(Ok(Message::Close(_))) => {
info!("WebSocket closed by server");
let _ = event_tx.send(CortexEvent::Disconnected).await;
break;
}
Some(Err(e)) => {
warn!("WebSocket error: {e}");
let _ = event_tx.send(CortexEvent::Error(e.to_string())).await;
let _ = event_tx.send(CortexEvent::Disconnected).await;
break;
}
None => {
info!("WebSocket stream ended");
let _ = event_tx.send(CortexEvent::Disconnected).await;
break;
}
_ => {}
}
}
cmd = cmd_rx.recv() => {
match cmd {
Some(msg) => {
if config.debug_mode {
eprintln!("[emotiv-ws] send: {msg}");
}
write.send(Message::Text(msg.into())).await
.map_err(|e| anyhow!("WebSocket write failed: {e}"))?;
}
None => {
info!("Command channel closed");
let _ = event_tx.send(CortexEvent::Disconnected).await;
break;
}
}
}
}
}
Ok(())
}
/// Builds the rustls client configuration used for the Cortex WebSocket.
///
/// SECURITY: certificate verification is disabled through [`AcceptAnyCert`].
/// Traffic is encrypted but the server is NOT authenticated, and this applies
/// to whatever `ws_url` is configured, not only the default local URL.
/// NOTE(review): presumably needed because the local Cortex service presents a
/// certificate that does not chain to public roots — consider pinning that
/// certificate instead of accepting everything.
///
/// Never fails today; the `Result` leaves room for fallible setup.
fn build_tls_connector() -> Result<rustls::ClientConfig> {
    let verifier = Arc::new(AcceptAnyCert);
    let tls_config = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(verifier)
        .with_no_client_auth();
    Ok(tls_config)
}
#[derive(Debug)]
struct AcceptAnyCert;
impl rustls::client::danger::ServerCertVerifier for AcceptAnyCert {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::ED448,
]
}
}
async fn handle_message(
recv: &Value,
config: &CortexClientConfig,
state: &Arc<Mutex<ClientState>>,
event_tx: &mpsc::Sender<CortexEvent>,
) -> (Vec<String>, bool) {
let mut responses = Vec::new();
if recv.get("sid").is_some() {
handle_stream_data(recv, event_tx).await;
return (responses, false);
}
if let Some(result) = recv.get("result") {
if let Some(id) = recv.get("id").and_then(|v| v.as_i64()) {
let resps = handle_result(id, result, config, state, event_tx).await;
responses.extend(resps);
}
return (responses, false);
}
if let Some(error) = recv.get("error") {
let msg = error.get("message").and_then(|v| v.as_str()).unwrap_or("unknown");
let code = error.get("code").and_then(|v| v.as_i64()).unwrap_or(0);
let req_id = recv.get("id").and_then(|v| v.as_i64()).unwrap_or(-1);
eprintln!("[emotiv-ws] ERROR (req_id={req_id}, code={code}): {msg}");
let cortex_err = crate::error::CortexError::from_api_error(code as i32, msg);
let _ = event_tx.send(CortexEvent::Error(
format!("[req_id={req_id}] {cortex_err}")
)).await;
return (responses, false);
}
if let Some(warning) = recv.get("warning") {
let code = warning.get("code").and_then(|v| v.as_i64()).unwrap_or(-1);
let message = warning.get("message").cloned().unwrap_or(Value::Null);
let _ = event_tx.send(CortexEvent::Warning { code, message: message.clone() }).await;
match code {
ACCESS_RIGHT_GRANTED => {
let already_authed = !state.lock().await.auth_token.is_empty();
if !already_authed {
responses.push(authorize(&config.client_id, &config.client_secret, &config.license, config.debit).to_string());
} else {
info!("ACCESS_RIGHT_GRANTED received but already authorized — skipping re-auth");
}
}
HEADSET_CONNECTED => {
let has_session = !state.lock().await.session_id.is_empty();
if !has_session {
responses.push(query_headsets().to_string());
} else {
info!("HEADSET_CONNECTED received but session already active — skipping query");
}
}
CORTEX_STOP_ALL_STREAMS | CORTEX_CLOSE_SESSION => {
let mut s = state.lock().await;
s.session_id.clear();
drop(s);
warn!("Session ended by Cortex service (code={code})");
let _ = event_tx.send(CortexEvent::Disconnected).await;
return (responses, true);
}
HEADSET_DISCONNECTED | HEADSET_CONNECTION_FAILED => {
let headset_msg = message.as_str()
.or_else(|| message.get("headsetId").and_then(|v| v.as_str()))
.unwrap_or("unknown");
warn!("Headset event (code={code}): {headset_msg}");
let mut s = state.lock().await;
s.session_id.clear();
drop(s);
let _ = event_tx.send(CortexEvent::Disconnected).await;
return (responses, true);
}
CORTEX_RECORD_POST_PROCESSING_DONE => {
if let Some(record_id) = message.get("recordId").and_then(|v| v.as_str()) {
let _ = event_tx.send(CortexEvent::RecordPostProcessingDone(record_id.to_string())).await;
}
}
HEADSET_SCANNING_FINISHED => {
let has_session = !state.lock().await.session_id.is_empty();
if !has_session {
responses.push(refresh_headset_list().to_string());
}
}
_ => {}
}
return (responses, false);
}
(responses, false)
}
async fn handle_result(
req_id: i64,
result: &Value,
config: &CortexClientConfig,
state: &Arc<Mutex<ClientState>>,
event_tx: &mpsc::Sender<CortexEvent>,
) -> Vec<String> {
let mut responses = Vec::new();
match req_id {
HAS_ACCESS_RIGHT_ID => {
let granted = result.get("accessGranted").and_then(|v| v.as_bool()).unwrap_or(false);
info!("hasAccessRight: granted={granted}");
if granted {
responses.push(authorize(&config.client_id, &config.client_secret, &config.license, config.debit).to_string());
} else {
responses.push(request_access(&config.client_id, &config.client_secret).to_string());
}
}
REQUEST_ACCESS_ID => {
let granted = result.get("accessGranted").and_then(|v| v.as_bool()).unwrap_or(false);
if granted {
responses.push(authorize(&config.client_id, &config.client_secret, &config.license, config.debit).to_string());
} else {
let msg = result.get("message").and_then(|v| v.as_str()).unwrap_or("Access not granted");
warn!("Access not granted: {msg}");
let _ = event_tx.send(CortexEvent::Error(format!("Access not granted: {msg}"))).await;
}
}
AUTHORIZE_ID => {
if let Some(token) = result.get("cortexToken").and_then(|v| v.as_str()) {
info!("Authorized successfully");
state.lock().await.auth_token = token.to_string();
let _ = event_tx.send(CortexEvent::Authorized).await;
if config.auto_create_session {
responses.push(refresh_headset_list().to_string());
responses.push(query_headsets().to_string());
}
}
}
QUERY_HEADSET_ID => {
if let Some(headsets) = result.as_array() {
let infos: Vec<crate::types::HeadsetInfo> = headsets.iter()
.filter_map(|v| serde_json::from_value(v.clone()).ok())
.collect();
let _ = event_tx.send(CortexEvent::HeadsetsQueried(infos)).await;
if !config.auto_create_session {
return responses;
}
let mut s = state.lock().await;
if headsets.is_empty() {
warn!("No headsets available");
return responses;
}
let target_id = if config.headset_id.is_empty() {
headsets[0].get("id").and_then(|v| v.as_str()).unwrap_or("").to_string()
} else {
config.headset_id.clone()
};
s.headset_id = target_id.clone();
for hs in headsets {
let id = hs.get("id").and_then(|v| v.as_str()).unwrap_or("");
let status = hs.get("status").and_then(|v| v.as_str()).unwrap_or("");
info!("Headset: {id}, status: {status}");
if id == target_id {
match status {
"connected" => {
responses.push(create_session(&s.auth_token, &s.headset_id).to_string());
}
"discovered" => {
responses.push(connect_headset(&s.headset_id).to_string());
}
"connecting" => {
tokio::spawn({
let headset_id = s.headset_id.clone();
async move {
tokio::time::sleep(Duration::from_secs(3)).await;
info!("Would re-query headset {headset_id}");
}
});
}
_ => {
warn!("Unknown headset status: {status}");
}
}
break;
}
}
}
}
CREATE_SESSION_ID => {
if let Some(session_id) = result.get("id").and_then(|v| v.as_str()) {
info!("Session created: {session_id}");
state.lock().await.session_id = session_id.to_string();
let _ = event_tx.send(CortexEvent::SessionCreated(session_id.to_string())).await;
}
}
SUB_REQUEST_ID => {
if let Some(success) = result.get("success").and_then(|v| v.as_array()) {
for stream in success {
let name = stream.get("streamName").and_then(|v| v.as_str()).unwrap_or("");
let cols = stream.get("cols").and_then(|v| v.as_array());
info!("Subscribed to stream: {name}");
if let Some(cols) = cols {
if name != "com" && name != "fac" {
let labels: Vec<String> = cols.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect();
let _ = event_tx.send(CortexEvent::DataLabels(DataLabels {
stream_name: name.to_string(),
labels,
})).await;
}
}
}
}
if let Some(failure) = result.get("failure").and_then(|v| v.as_array()) {
for stream in failure {
let name = stream.get("streamName").and_then(|v| v.as_str()).unwrap_or("");
let code = stream.get("code").and_then(|v| v.as_i64()).unwrap_or(0);
let message = stream.get("message").and_then(|v| v.as_str()).unwrap_or("");
warn!("Failed to subscribe to stream '{name}': code={code} {message}");
let _ = event_tx.send(CortexEvent::Error(
format!("Subscribe '{name}' failed: code={code} {message}")
)).await;
}
}
}
UNSUB_REQUEST_ID => {
if let Some(success) = result.get("success").and_then(|v| v.as_array()) {
for stream in success {
let name = stream.get("streamName").and_then(|v| v.as_str()).unwrap_or("");
info!("Unsubscribed from stream: {name}");
}
}
}
QUERY_PROFILE_ID => {
if let Some(profiles) = result.as_array() {
let names: Vec<String> = profiles.iter()
.filter_map(|p| p.get("name").and_then(|v| v.as_str()).map(String::from))
.collect();
let _ = event_tx.send(CortexEvent::ProfilesQueried(names)).await;
}
}
SETUP_PROFILE_ID => {
let action = result.get("action").and_then(|v| v.as_str()).unwrap_or("");
match action {
"load" => {
info!("Profile loaded");
let _ = event_tx.send(CortexEvent::ProfileLoaded(true)).await;
}
"unload" => {
info!("Profile unloaded");
let _ = event_tx.send(CortexEvent::ProfileLoaded(false)).await;
}
"save" => {
info!("Profile saved");
let _ = event_tx.send(CortexEvent::ProfileSaved).await;
}
"create" => {
if let Some(name) = result.get("name").and_then(|v| v.as_str()) {
info!("Profile created: {name}");
}
}
_ => {}
}
}
CREATE_RECORD_REQUEST_ID => {
if let Some(record) = result.get("record") {
if let Ok(rec) = serde_json::from_value::<Record>(record.clone()) {
let _ = event_tx.send(CortexEvent::RecordCreated(rec)).await;
}
}
}
STOP_RECORD_REQUEST_ID => {
if let Some(record) = result.get("record") {
if let Ok(rec) = serde_json::from_value::<Record>(record.clone()) {
let _ = event_tx.send(CortexEvent::RecordStopped(rec)).await;
}
}
}
EXPORT_RECORD_ID => {
let mut success_ids = Vec::new();
if let Some(success) = result.get("success").and_then(|v| v.as_array()) {
for r in success {
if let Some(id) = r.get("recordId").and_then(|v| v.as_str()) {
success_ids.push(id.to_string());
}
}
}
let _ = event_tx.send(CortexEvent::RecordExported(success_ids)).await;
}
INJECT_MARKER_REQUEST_ID => {
if let Some(marker) = result.get("marker") {
if let Ok(m) = serde_json::from_value::<Marker>(marker.clone()) {
let _ = event_tx.send(CortexEvent::MarkerInjected(m)).await;
}
}
}
UPDATE_MARKER_REQUEST_ID => {
if let Some(marker) = result.get("marker") {
if let Ok(m) = serde_json::from_value::<Marker>(marker.clone()) {
let _ = event_tx.send(CortexEvent::MarkerUpdated(m)).await;
}
}
}
MENTAL_COMMAND_ACTIVE_ACTION_ID => {
let _ = event_tx.send(CortexEvent::McActiveActions(result.clone())).await;
}
SENSITIVITY_REQUEST_ID => {
let _ = event_tx.send(CortexEvent::McSensitivity(result.clone())).await;
}
MENTAL_COMMAND_BRAIN_MAP_ID => {
let _ = event_tx.send(CortexEvent::McBrainMap(result.clone())).await;
}
MENTAL_COMMAND_TRAINING_THRESHOLD => {
let _ = event_tx.send(CortexEvent::McTrainingThreshold(result.clone())).await;
}
QUERY_RECORDS_ID => {
let count = result.get("count").and_then(|v| v.as_u64()).unwrap_or(0);
let records_val = result.get("records").and_then(|v| v.as_array());
let records: Vec<Record> = records_val
.map(|arr| arr.iter().filter_map(|v| serde_json::from_value(v.clone()).ok()).collect())
.unwrap_or_default();
let _ = event_tx.send(CortexEvent::QueryRecordsDone { records, count }).await;
}
REQUEST_DOWNLOAD_RECORDS_ID => {
let _ = event_tx.send(CortexEvent::DownloadRecordsDone(result.clone())).await;
}
GET_CORTEX_INFO_ID => {
let _ = event_tx.send(CortexEvent::CortexInfo(result.clone())).await;
}
SYNC_WITH_HEADSET_CLOCK_ID => {
let _ = event_tx.send(CortexEvent::HeadsetClockSynced(result.clone())).await;
}
_ => {
debug!("Unhandled result for request id={req_id}");
}
}
responses
}
async fn handle_stream_data(recv: &Value, event_tx: &mpsc::Sender<CortexEvent>) {
let time = recv.get("time").and_then(|v| v.as_f64()).unwrap_or(0.0);
if let Some(eeg) = recv.get("eeg").and_then(|v| v.as_array()) {
let samples: Vec<f64> = eeg.iter()
.map(|v| v.as_f64().unwrap_or(f64::NAN))
.collect();
let _ = event_tx.send(CortexEvent::Eeg(EegData { samples, time })).await;
} else if let Some(mot) = recv.get("mot").and_then(|v| v.as_array()) {
let samples: Vec<f64> = mot.iter()
.map(|v| v.as_f64().unwrap_or(f64::NAN))
.collect();
let _ = event_tx.send(CortexEvent::Motion(MotionData { samples, time })).await;
} else if let Some(dev) = recv.get("dev").and_then(|v| v.as_array()) {
let signal = dev.get(1).and_then(|v| v.as_f64()).unwrap_or(0.0);
let cq = dev.get(2).and_then(|v| v.as_array())
.map(|a| a.iter().filter_map(|v| v.as_f64()).collect())
.unwrap_or_default();
let bat = dev.get(3).and_then(|v| v.as_f64()).unwrap_or(0.0);
let _ = event_tx.send(CortexEvent::Dev(DevData {
signal, contact_quality: cq, battery_percent: bat, time,
})).await;
} else if let Some(met) = recv.get("met").and_then(|v| v.as_array()) {
let values: Vec<f64> = met.iter().filter_map(|v| v.as_f64()).collect();
let _ = event_tx.send(CortexEvent::Metrics(MetricsData { values, time })).await;
} else if let Some(pow) = recv.get("pow").and_then(|v| v.as_array()) {
let powers: Vec<f64> = pow.iter().filter_map(|v| v.as_f64()).collect();
let _ = event_tx.send(CortexEvent::BandPower(BandPowerData { powers, time })).await;
} else if let Some(com) = recv.get("com").and_then(|v| v.as_array()) {
let action = com.first().and_then(|v| v.as_str()).unwrap_or("neutral").to_string();
let power = com.get(1).and_then(|v| v.as_f64()).unwrap_or(0.0);
let _ = event_tx.send(CortexEvent::MentalCommand(MentalCommandData {
action, power, time,
})).await;
} else if let Some(fac) = recv.get("fac").and_then(|v| v.as_array()) {
let eye = fac.first().and_then(|v| v.as_str()).unwrap_or("").to_string();
let u_act = fac.get(1).and_then(|v| v.as_str()).unwrap_or("").to_string();
let u_pow = fac.get(2).and_then(|v| v.as_f64()).unwrap_or(0.0);
let l_act = fac.get(3).and_then(|v| v.as_str()).unwrap_or("").to_string();
let l_pow = fac.get(4).and_then(|v| v.as_f64()).unwrap_or(0.0);
let _ = event_tx.send(CortexEvent::FacialExpression(FacialExpressionData {
eye_action: eye, upper_action: u_act, upper_power: u_pow,
lower_action: l_act, lower_power: l_pow, time,
})).await;
} else if let Some(sys) = recv.get("sys").and_then(|v| v.as_array()) {
let events: Vec<Value> = sys.clone();
let _ = event_tx.send(CortexEvent::Sys(SysData { events })).await;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_defaults() {
let config = CortexClientConfig::default();
assert_eq!(config.ws_url, "wss://localhost:6868");
assert_eq!(config.debit, 10);
assert!(config.auto_create_session);
}
#[tokio::test]
async fn test_handle_stream_eeg() {
let (tx, mut rx) = mpsc::channel(16);
let msg = serde_json::json!({
"sid": "test-session",
"eeg": [1.0, 2.0, 3.0, 4.0, 5.0],
"time": 1234567890.123
});
handle_stream_data(&msg, &tx).await;
if let Some(CortexEvent::Eeg(data)) = rx.recv().await {
assert_eq!(data.samples.len(), 5);
assert!((data.time - 1234567890.123).abs() < 0.001);
} else {
panic!("Expected Eeg event");
}
}
#[tokio::test]
async fn test_handle_stream_motion() {
let (tx, mut rx) = mpsc::channel(16);
let msg = serde_json::json!({
"sid": "test-session",
"mot": [0.0, 0.0, 0.5, 0.3, 0.2, 0.1, 0.01, 0.02, -1.0, 50.0, 30.0, 20.0],
"time": 100.0
});
handle_stream_data(&msg, &tx).await;
if let Some(CortexEvent::Motion(data)) = rx.recv().await {
assert_eq!(data.samples.len(), 12);
} else {
panic!("Expected Motion event");
}
}
#[tokio::test]
async fn test_handle_stream_dev() {
let (tx, mut rx) = mpsc::channel(16);
let msg = serde_json::json!({
"sid": "test-session",
"dev": [0, 1.0, [4, 4, 4, 4, 4], 85.0],
"time": 100.0
});
handle_stream_data(&msg, &tx).await;
if let Some(CortexEvent::Dev(data)) = rx.recv().await {
assert!((data.signal - 1.0).abs() < 0.001);
assert_eq!(data.contact_quality.len(), 5);
assert!((data.battery_percent - 85.0).abs() < 0.001);
} else {
panic!("Expected Dev event");
}
}
#[tokio::test]
async fn test_handle_stream_metrics() {
let (tx, mut rx) = mpsc::channel(16);
let msg = serde_json::json!({
"sid": "test-session",
"met": [1.0, 0.5, 1.0, 0.4, 0.3, 1.0, 0.2, 1.0, 0.6, 1.0, 0.5, 1.0, 0.55],
"time": 100.0
});
handle_stream_data(&msg, &tx).await;
if let Some(CortexEvent::Metrics(data)) = rx.recv().await {
assert_eq!(data.values.len(), 13);
} else {
panic!("Expected Metrics event");
}
}
#[tokio::test]
async fn test_handle_stream_com() {
let (tx, mut rx) = mpsc::channel(16);
let msg = serde_json::json!({
"sid": "test-session",
"com": ["push", 0.85],
"time": 100.0
});
handle_stream_data(&msg, &tx).await;
if let Some(CortexEvent::MentalCommand(data)) = rx.recv().await {
assert_eq!(data.action, "push");
assert!((data.power - 0.85).abs() < 0.001);
} else {
panic!("Expected MentalCommand event");
}
}
#[tokio::test]
async fn test_handle_stream_fac() {
let (tx, mut rx) = mpsc::channel(16);
let msg = serde_json::json!({
"sid": "test-session",
"fac": ["blink", "surprise", 0.7, "smile", 0.5],
"time": 100.0
});
handle_stream_data(&msg, &tx).await;
if let Some(CortexEvent::FacialExpression(data)) = rx.recv().await {
assert_eq!(data.eye_action, "blink");
assert_eq!(data.upper_action, "surprise");
assert!((data.upper_power - 0.7).abs() < 0.001);
} else {
panic!("Expected FacialExpression event");
}
}
#[tokio::test]
async fn test_handle_stream_pow() {
let (tx, mut rx) = mpsc::channel(16);
let msg = serde_json::json!({
"sid": "test-session",
"pow": [5.0, 4.0, 3.0, 1.0, 0.5],
"time": 100.0
});
handle_stream_data(&msg, &tx).await;
if let Some(CortexEvent::BandPower(data)) = rx.recv().await {
assert_eq!(data.powers.len(), 5);
} else {
panic!("Expected BandPower event");
}
}
#[tokio::test]
async fn test_handle_stream_sys() {
let (tx, mut rx) = mpsc::channel(16);
let msg = serde_json::json!({
"sid": "test-session",
"sys": ["mentalCommand", "MC_Succeeded"]
});
handle_stream_data(&msg, &tx).await;
if let Some(CortexEvent::Sys(data)) = rx.recv().await {
assert_eq!(data.events.len(), 2);
} else {
panic!("Expected Sys event");
}
}
fn test_config() -> CortexClientConfig {
CortexClientConfig {
client_id: "test_id".into(),
client_secret: "test_secret".into(),
..Default::default()
}
}
fn test_state() -> Arc<Mutex<ClientState>> {
Arc::new(Mutex::new(ClientState {
auth_token: "test_token".into(),
session_id: "test_session".into(),
headset_id: "test_headset".into(),
}))
}
#[tokio::test]
async fn test_handle_has_access_right_granted() {
let (tx, _rx) = mpsc::channel(16);
let config = test_config();
let state = test_state();
let result = serde_json::json!({"accessGranted": true});
let responses = handle_result(HAS_ACCESS_RIGHT_ID, &result, &config, &state, &tx).await;
assert_eq!(responses.len(), 1);
let resp: serde_json::Value = serde_json::from_str(&responses[0]).unwrap();
assert_eq!(resp["method"], "authorize");
}
#[tokio::test]
async fn test_handle_has_access_right_denied() {
let (tx, _rx) = mpsc::channel(16);
let config = test_config();
let state = test_state();
let result = serde_json::json!({"accessGranted": false});
let responses = handle_result(HAS_ACCESS_RIGHT_ID, &result, &config, &state, &tx).await;
assert_eq!(responses.len(), 1);
let resp: serde_json::Value = serde_json::from_str(&responses[0]).unwrap();
assert_eq!(resp["method"], "requestAccess");
}
#[tokio::test]
async fn test_handle_authorize() {
let (tx, mut rx) = mpsc::channel(16);
let config = test_config();
let state = test_state();
let result = serde_json::json!({"cortexToken": "new_token_abc"});
let responses = handle_result(AUTHORIZE_ID, &result, &config, &state, &tx).await;
assert_eq!(state.lock().await.auth_token, "new_token_abc");
if let Some(CortexEvent::Authorized) = rx.recv().await {
} else {
panic!("Expected Authorized event");
}
assert_eq!(responses.len(), 2);
}
#[tokio::test]
async fn test_handle_create_session() {
let (tx, mut rx) = mpsc::channel(16);
let config = test_config();
let state = test_state();
let result = serde_json::json!({"id": "session-xyz"});
let _responses = handle_result(CREATE_SESSION_ID, &result, &config, &state, &tx).await;
assert_eq!(state.lock().await.session_id, "session-xyz");
if let Some(CortexEvent::SessionCreated(id)) = rx.recv().await {
assert_eq!(id, "session-xyz");
} else {
panic!("Expected SessionCreated event");
}
}
#[tokio::test]
async fn test_handle_subscribe() {
let (tx, mut rx) = mpsc::channel(16);
let config = test_config();
let state = test_state();
let result = serde_json::json!({
"success": [
{"streamName": "eeg", "cols": ["AF3", "F7", "F3"]},
{"streamName": "mot", "cols": ["ACCX", "ACCY", "ACCZ"]}
],
"failure": []
});
let _responses = handle_result(SUB_REQUEST_ID, &result, &config, &state, &tx).await;
let ev1 = rx.recv().await.unwrap();
let ev2 = rx.recv().await.unwrap();
let mut labels_received = vec![];
for ev in [ev1, ev2] {
if let CortexEvent::DataLabels(l) = ev {
labels_received.push(l.stream_name);
}
}
labels_received.sort();
assert_eq!(labels_received, vec!["eeg", "mot"]);
}
#[tokio::test]
async fn test_handle_query_profile() {
let (tx, mut rx) = mpsc::channel(16);
let config = test_config();
let state = test_state();
let result = serde_json::json!([
{"name": "profile_a", "readOnly": false},
{"name": "profile_b", "readOnly": true}
]);
let _responses = handle_result(QUERY_PROFILE_ID, &result, &config, &state, &tx).await;
if let Some(CortexEvent::ProfilesQueried(names)) = rx.recv().await {
assert_eq!(names, vec!["profile_a", "profile_b"]);
} else {
panic!("Expected ProfilesQueried event");
}
}
#[tokio::test]
async fn test_handle_setup_profile_load() {
let (tx, mut rx) = mpsc::channel(16);
let config = test_config();
let state = test_state();
let result = serde_json::json!({"action": "load"});
let _responses = handle_result(SETUP_PROFILE_ID, &result, &config, &state, &tx).await;
if let Some(CortexEvent::ProfileLoaded(true)) = rx.recv().await {
} else {
panic!("Expected ProfileLoaded(true)");
}
}
/// A profile-setup result with `action: "save"` must emit `ProfileSaved` as the
/// first event.
#[tokio::test]
async fn test_handle_setup_profile_save() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let result = serde_json::json!({"action": "save"});
    let _responses = handle_result(SETUP_PROFILE_ID, &result, &config, &state, &tx).await;
    // Pure pattern check: `matches!` replaces the empty `if let` arm.
    assert!(
        matches!(rx.recv().await, Some(CortexEvent::ProfileSaved)),
        "Expected ProfileSaved"
    );
}
/// A create-record result must emit `RecordCreated` whose `uuid` and `title`
/// come from the nested `record` object.
#[tokio::test]
async fn test_handle_create_record() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let created = serde_json::json!({
        "record": {
            "uuid": "rec-123",
            "title": "Test",
            "startDatetime": "2026-01-01T00:00:00Z"
        }
    });
    let _ = handle_result(CREATE_RECORD_REQUEST_ID, &created, &config, &state, &tx).await;
    let event = rx.recv().await;
    match event {
        Some(CortexEvent::RecordCreated(record)) => {
            assert_eq!(record.uuid, "rec-123");
            assert_eq!(record.title, "Test");
        }
        _ => panic!("Expected RecordCreated"),
    }
}
/// A stop-record result (a record that now has an `endDatetime`) must emit
/// `RecordStopped` carrying the record's `uuid`.
#[tokio::test]
async fn test_handle_stop_record() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let stopped = serde_json::json!({
        "record": {
            "uuid": "rec-123",
            "title": "Test",
            "startDatetime": "2026-01-01T00:00:00Z",
            "endDatetime": "2026-01-01T00:01:00Z"
        }
    });
    let _ = handle_result(STOP_RECORD_REQUEST_ID, &stopped, &config, &state, &tx).await;
    let event = rx.recv().await;
    match event {
        Some(CortexEvent::RecordStopped(record)) => assert_eq!(record.uuid, "rec-123"),
        _ => panic!("Expected RecordStopped"),
    }
}
/// An export result must emit `RecordExported` listing each `recordId` from
/// `success`, in order.
#[tokio::test]
async fn test_handle_export_record() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let exported = serde_json::json!({
        "success": [{"recordId": "rec-1"}, {"recordId": "rec-2"}],
        "failure": []
    });
    let _ = handle_result(EXPORT_RECORD_ID, &exported, &config, &state, &tx).await;
    match rx.recv().await {
        Some(CortexEvent::RecordExported(record_ids)) => {
            assert_eq!(record_ids, vec!["rec-1", "rec-2"]);
        }
        _ => panic!("Expected RecordExported"),
    }
}
/// An inject-marker result must emit `MarkerInjected` whose `uuid` and `label`
/// come from the nested `marker` object.
#[tokio::test]
async fn test_handle_inject_marker() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let injected = serde_json::json!({
        "marker": {
            "uuid": "mk-001",
            "type": "instance",
            "label": "test_label",
            "value": "test_val",
            "startDatetime": "2026-01-01T00:00:00Z"
        }
    });
    let _ = handle_result(INJECT_MARKER_REQUEST_ID, &injected, &config, &state, &tx).await;
    match rx.recv().await {
        Some(CortexEvent::MarkerInjected(marker)) => {
            assert_eq!(marker.uuid, "mk-001");
            assert_eq!(marker.label, "test_label");
        }
        _ => panic!("Expected MarkerInjected"),
    }
}
/// A query-records result must emit `QueryRecordsDone` with the reported
/// `count` and every record parsed, in order.
#[tokio::test]
async fn test_handle_query_records() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let page = serde_json::json!({
        "count": 2,
        "limit": 100,
        "offset": 0,
        "records": [
            {"uuid": "r1", "title": "First"},
            {"uuid": "r2", "title": "Second"}
        ]
    });
    let _ = handle_result(QUERY_RECORDS_ID, &page, &config, &state, &tx).await;
    match rx.recv().await {
        Some(CortexEvent::QueryRecordsDone { count, records }) => {
            assert_eq!(count, 2);
            assert_eq!(records.len(), 2);
            let first = &records[0];
            let second = &records[1];
            assert_eq!(first.uuid, "r1");
            assert_eq!(second.title, "Second");
        }
        _ => panic!("Expected QueryRecordsDone"),
    }
}
/// A mental-command active-action result must emit `McActiveActions` carrying
/// the raw JSON array unchanged (three entries here).
#[tokio::test]
async fn test_handle_mc_active_actions() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let actions = serde_json::json!(["neutral", "push", "pull"]);
    let _ = handle_result(MENTAL_COMMAND_ACTIVE_ACTION_ID, &actions, &config, &state, &tx).await;
    match rx.recv().await {
        Some(CortexEvent::McActiveActions(payload)) => {
            assert!(payload.is_array());
            let entries = payload.as_array().unwrap();
            assert_eq!(entries.len(), 3);
        }
        _ => panic!("Expected McActiveActions"),
    }
}
/// A sensitivity result must emit `McSensitivity` carrying the raw values;
/// only the first element is checked.
#[tokio::test]
async fn test_handle_mc_sensitivity() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let sensitivity = serde_json::json!([7, 8, 5, 5]);
    let _ = handle_result(SENSITIVITY_REQUEST_ID, &sensitivity, &config, &state, &tx).await;
    match rx.recv().await {
        Some(CortexEvent::McSensitivity(values)) => assert_eq!(values[0], 7),
        _ => panic!("Expected McSensitivity"),
    }
}
/// A warning with code 30 whose message carries a `recordId` must emit both a
/// `Warning { code: 30, .. }` and a `RecordPostProcessingDone` for that record.
#[tokio::test]
async fn test_handle_warning_message() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let msg = serde_json::json!({
        "warning": {
            "code": 30,
            "message": {"recordId": "rec-done"}
        }
    });
    // The responses are irrelevant here. Bind them as `_responses` rather than
    // naming them and discarding them with a trailing `let _ = responses;`.
    let (_responses, _ended) = handle_message(&msg, &config, &state, &tx).await;
    let mut found_warning = false;
    let mut found_ppd = false;
    // Drain every event queued by `handle_message` without blocking.
    while let Ok(ev) = rx.try_recv() {
        match ev {
            CortexEvent::Warning { code, .. } => {
                assert_eq!(code, 30);
                found_warning = true;
            }
            CortexEvent::RecordPostProcessingDone(rid) => {
                assert_eq!(rid, "rec-done");
                found_ppd = true;
            }
            _ => {}
        }
    }
    assert!(found_warning, "Expected Warning event");
    assert!(found_ppd, "Expected RecordPostProcessingDone event");
}
#[tokio::test]
async fn test_handle_error_message() {
let (tx, mut rx) = mpsc::channel(16);
let config = test_config();
let state = test_state();
let msg = serde_json::json!({
"id": 999,
"error": {
"code": -32046,
"message": "Profile access denied"
}
});
let (_responses, _ended) = handle_message(&msg, &config, &state, &tx).await;
if let Some(CortexEvent::Error(e)) = rx.recv().await {
assert!(e.contains("-32046"));
assert!(e.contains("Profile access denied"));
} else {
panic!("Expected Error event");
}
}
/// A stream-data message (has `sid` plus an `eeg` array) must produce no
/// outgoing requests and be routed to an `Eeg` event with the same samples.
#[tokio::test]
async fn test_handle_stream_data_routing() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let sample_msg = serde_json::json!({
        "sid": "ses-001",
        "eeg": [10.0, 20.0],
        "time": 500.0
    });
    let (outgoing, _) = handle_message(&sample_msg, &config, &state, &tx).await;
    assert!(outgoing.is_empty());
    match rx.recv().await {
        Some(CortexEvent::Eeg(eeg)) => assert_eq!(eeg.samples, vec![10.0, 20.0]),
        _ => panic!("Expected Eeg from stream routing"),
    }
}
/// When the configured headset is reported as "connected", the handler must
/// emit `HeadsetsQueried`, send exactly one `createSession` request, and store
/// the headset id in the client state.
#[tokio::test]
async fn test_handle_headset_query_connected() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = CortexClientConfig {
        client_id: String::from("test"),
        client_secret: String::from("test"),
        headset_id: String::from("EPOCX-001"),
        ..Default::default()
    };
    let state = Arc::new(Mutex::new(ClientState {
        auth_token: String::from("tok"),
        session_id: String::new(),
        headset_id: String::new(),
    }));
    let headsets_result = serde_json::json!([
        {"id": "EPOCX-001", "status": "connected", "connectedBy": "dongle"}
    ]);
    let outgoing = handle_result(QUERY_HEADSET_ID, &headsets_result, &config, &state, &tx).await;
    match rx.recv().await {
        Some(CortexEvent::HeadsetsQueried(found)) => {
            assert_eq!(found.len(), 1);
            assert_eq!(found[0].id, "EPOCX-001");
            assert_eq!(found[0].status, "connected");
        }
        _ => panic!("Expected HeadsetsQueried event"),
    }
    assert_eq!(outgoing.len(), 1);
    let request: serde_json::Value = serde_json::from_str(&outgoing[0]).unwrap();
    assert_eq!(request["method"], "createSession");
    assert_eq!(state.lock().await.headset_id, "EPOCX-001");
}
/// When the configured headset is only "discovered", the handler must emit
/// `HeadsetsQueried` and send exactly one `controlDevice` request with
/// `command: "connect"`.
#[tokio::test]
async fn test_handle_headset_query_discovered() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = CortexClientConfig {
        client_id: String::from("test"),
        client_secret: String::from("test"),
        headset_id: String::from("INSIGHT-002"),
        ..Default::default()
    };
    let state = Arc::new(Mutex::new(ClientState {
        auth_token: String::from("tok"),
        session_id: String::new(),
        headset_id: String::new(),
    }));
    let headsets_result = serde_json::json!([
        {"id": "INSIGHT-002", "status": "discovered", "connectedBy": ""}
    ]);
    let outgoing = handle_result(QUERY_HEADSET_ID, &headsets_result, &config, &state, &tx).await;
    match rx.recv().await {
        Some(CortexEvent::HeadsetsQueried(found)) => {
            assert_eq!(found.len(), 1);
            assert_eq!(found[0].id, "INSIGHT-002");
            assert_eq!(found[0].status, "discovered");
        }
        _ => panic!("Expected HeadsetsQueried event"),
    }
    assert_eq!(outgoing.len(), 1);
    let request: serde_json::Value = serde_json::from_str(&outgoing[0]).unwrap();
    assert_eq!(request["method"], "controlDevice");
    assert_eq!(request["params"]["command"], "connect");
}
/// With two headsets reported and no `headset_id` configured, the handler must
/// report both in order, queue exactly one follow-up request, and record
/// `EPOCX-AAA` (the first entry, status "connected") as the active headset.
#[tokio::test]
async fn test_handle_headset_query_multiple_headsets() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = CortexClientConfig {
        client_id: String::from("test"),
        client_secret: String::from("test"),
        ..Default::default()
    };
    let state = Arc::new(Mutex::new(ClientState {
        auth_token: String::from("tok"),
        session_id: String::new(),
        headset_id: String::new(),
    }));
    let headsets_result = serde_json::json!([
        {"id": "EPOCX-AAA", "status": "connected", "connectedBy": "dongle"},
        {"id": "INSIGHT-BBB", "status": "discovered", "connectedBy": ""}
    ]);
    let outgoing = handle_result(QUERY_HEADSET_ID, &headsets_result, &config, &state, &tx).await;
    match rx.recv().await {
        Some(CortexEvent::HeadsetsQueried(found)) => {
            assert_eq!(found.len(), 2);
            assert_eq!(found[0].id, "EPOCX-AAA");
            assert_eq!(found[1].id, "INSIGHT-BBB");
        }
        _ => panic!("Expected HeadsetsQueried event"),
    }
    assert_eq!(outgoing.len(), 1);
    assert_eq!(state.lock().await.headset_id, "EPOCX-AAA");
}
/// An empty headset list must still emit `HeadsetsQueried` (with no entries)
/// and must not produce any outgoing request.
#[tokio::test]
async fn test_handle_headset_query_empty() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let no_headsets = serde_json::json!([]);
    let outgoing = handle_result(QUERY_HEADSET_ID, &no_headsets, &config, &state, &tx).await;
    match rx.recv().await {
        Some(CortexEvent::HeadsetsQueried(found)) => assert!(found.is_empty()),
        _ => panic!("Expected HeadsetsQueried event"),
    }
    assert!(outgoing.is_empty());
}
/// Warning 102 (headset disconnected) must signal session end, emit both a
/// `Warning { code: 102, .. }` and a `Disconnected` event, and clear the
/// stored session id.
#[tokio::test]
async fn test_handle_warning_headset_disconnected() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let warning = serde_json::json!({
        "warning": {
            "code": 102,
            "message": {"headsetId": "EPOCX-001"}
        }
    });
    let (_, session_ended) = handle_message(&warning, &config, &state, &tx).await;
    assert!(session_ended, "HEADSET_DISCONNECTED should signal session ended");

    let mut saw_warning = false;
    let mut saw_disconnect = false;
    // Non-blocking drain of everything queued so far.
    for event in std::iter::from_fn(|| rx.try_recv().ok()) {
        match event {
            CortexEvent::Warning { code, .. } => {
                assert_eq!(code, 102);
                saw_warning = true;
            }
            CortexEvent::Disconnected => saw_disconnect = true,
            _ => {}
        }
    }
    assert!(saw_warning, "Expected Warning event");
    assert!(saw_disconnect, "Expected Disconnected event on headset disconnect");
    assert!(state.lock().await.session_id.is_empty(), "Session should be cleared");
}
/// Warning 103 (headset connection failed, plain-string message) must signal
/// session end and emit both a `Warning { code: 103, .. }` and a `Disconnected`
/// event.
#[tokio::test]
async fn test_handle_warning_headset_connection_failed() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let warning = serde_json::json!({
        "warning": {
            "code": 103,
            "message": "Connection timed out"
        }
    });
    let (_, session_ended) = handle_message(&warning, &config, &state, &tx).await;
    assert!(session_ended, "HEADSET_CONNECTION_FAILED should signal session ended");

    let mut saw_warning = false;
    let mut saw_disconnect = false;
    for event in std::iter::from_fn(|| rx.try_recv().ok()) {
        match event {
            CortexEvent::Warning { code, .. } => {
                assert_eq!(code, 103);
                saw_warning = true;
            }
            CortexEvent::Disconnected => saw_disconnect = true,
            _ => {}
        }
    }
    assert!(saw_warning, "Expected Warning event");
    assert!(saw_disconnect, "Expected Disconnected event on connection failure");
}
/// Warning 0 (stop all streams) must signal session end and clear the session
/// id that `test_state()` starts with.
#[tokio::test]
async fn test_handle_warning_stop_all_streams_clears_session() {
    // Keep the receiver alive (`_rx`, not `_`) so the handler's sends succeed.
    let (tx, _rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    // Precondition: the fixture starts with an active session.
    assert!(!state.lock().await.session_id.is_empty());
    let warning = serde_json::json!({
        "warning": {
            "code": 0,
            "message": "Stop all streams"
        }
    });
    let (_, session_ended) = handle_message(&warning, &config, &state, &tx).await;
    assert!(session_ended, "CORTEX_STOP_ALL_STREAMS should signal session ended");
    assert!(state.lock().await.session_id.is_empty(), "Session should be cleared");
}
/// Warning 1 (session closed) must signal session end, clear the session id,
/// and emit a `Disconnected` event.
#[tokio::test]
async fn test_handle_warning_close_session_ends_session() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    assert!(!state.lock().await.session_id.is_empty());
    let warning = serde_json::json!({
        "warning": {
            "code": 1,
            "message": "Session closed"
        }
    });
    let (_, session_ended) = handle_message(&warning, &config, &state, &tx).await;
    assert!(session_ended, "CORTEX_CLOSE_SESSION should signal session ended");
    assert!(state.lock().await.session_id.is_empty(), "Session should be cleared");
    let saw_disconnect = std::iter::from_fn(|| rx.try_recv().ok())
        .any(|event| matches!(event, CortexEvent::Disconnected));
    assert!(saw_disconnect, "Expected Disconnected event");
}
/// Warning 9 (access granted) must not produce an `authorize` request when the
/// state already holds an auth token. The assertion message implies that
/// `test_state()` is pre-authorized; verify this against the fixture.
#[tokio::test]
async fn test_handle_warning_access_granted_skips_reauth() {
    let (tx, _rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let warning = serde_json::json!({
        "warning": {
            "code": 9,
            "message": "Access granted"
        }
    });
    let (outgoing, _) = handle_message(&warning, &config, &state, &tx).await;
    assert!(outgoing.is_empty(), "Should skip re-auth when already authorized");
}
/// Warning 9 (access granted) with an empty auth token must produce exactly
/// one `authorize` request.
#[tokio::test]
async fn test_handle_warning_access_granted_authorizes_when_no_token() {
    let (tx, _rx) = mpsc::channel(16);
    let config = test_config();
    let unauthorized = Arc::new(Mutex::new(ClientState {
        auth_token: String::new(),
        session_id: String::new(),
        headset_id: String::new(),
    }));
    let warning = serde_json::json!({
        "warning": {
            "code": 9,
            "message": "Access granted"
        }
    });
    let (outgoing, _) = handle_message(&warning, &config, &unauthorized, &tx).await;
    assert_eq!(outgoing.len(), 1);
    let request: serde_json::Value = serde_json::from_str(&outgoing[0]).unwrap();
    assert_eq!(request["method"], "authorize");
}
/// Warning 104 (headset connected) must not trigger a headset query while a
/// session is already active. Other tests assert that `test_state()` starts
/// with a non-empty session id.
#[tokio::test]
async fn test_handle_warning_headset_connected_skips_when_session_active() {
    let (tx, _rx) = mpsc::channel(16);
    let config = test_config();
    let state = test_state();
    let warning = serde_json::json!({
        "warning": {
            "code": 104,
            "message": {"headsetId": "EPOCX-001"}
        }
    });
    let (outgoing, _) = handle_message(&warning, &config, &state, &tx).await;
    assert!(outgoing.is_empty(), "Should skip query when session already active");
}
/// Warning 104 (headset connected) with no active session must produce exactly
/// one `queryHeadsets` request.
#[tokio::test]
async fn test_handle_warning_headset_connected_queries_when_no_session() {
    let (tx, _rx) = mpsc::channel(16);
    let config = test_config();
    let sessionless = Arc::new(Mutex::new(ClientState {
        auth_token: String::from("tok"),
        session_id: String::new(),
        headset_id: String::new(),
    }));
    let warning = serde_json::json!({
        "warning": {
            "code": 104,
            "message": {"headsetId": "EPOCX-001"}
        }
    });
    let (outgoing, _) = handle_message(&warning, &config, &sessionless, &tx).await;
    assert_eq!(outgoing.len(), 1);
    let request: serde_json::Value = serde_json::from_str(&outgoing[0]).unwrap();
    assert_eq!(request["method"], "queryHeadsets");
}
#[tokio::test]
async fn test_handle_subscribe_failure() {
let (tx, mut rx) = mpsc::channel(16);
let config = test_config();
let state = test_state();
let result = serde_json::json!({
"success": [],
"failure": [
{"streamName": "eeg", "code": -32015, "message": "License expired"}
]
});
let _responses = handle_result(SUB_REQUEST_ID, &result, &config, &state, &tx).await;
if let Some(CortexEvent::Error(e)) = rx.recv().await {
assert!(e.contains("eeg"), "Error should mention stream name");
assert!(e.contains("License expired"), "Error should mention reason");
} else {
panic!("Expected Error event for subscribe failure");
}
}
/// With `auto_create_session: false`, the handler must still report every
/// headset. It must send no follow-up requests and leave `headset_id` unset.
#[tokio::test]
async fn test_handle_headset_query_no_auto_connect() {
    let (tx, mut rx) = mpsc::channel(16);
    let config = CortexClientConfig {
        client_id: String::from("test"),
        client_secret: String::from("test"),
        auto_create_session: false,
        ..Default::default()
    };
    let state = Arc::new(Mutex::new(ClientState {
        auth_token: String::from("tok"),
        session_id: String::new(),
        headset_id: String::new(),
    }));
    let headsets_result = serde_json::json!([
        {"id": "EPOCX-AAA", "status": "connected", "connectedBy": "dongle"},
        {"id": "INSIGHT-BBB", "status": "discovered", "connectedBy": ""}
    ]);
    let responses = handle_result(QUERY_HEADSET_ID, &headsets_result, &config, &state, &tx).await;
    match rx.recv().await {
        Some(CortexEvent::HeadsetsQueried(found)) => assert_eq!(found.len(), 2),
        _ => panic!("Expected HeadsetsQueried event"),
    }
    assert!(responses.is_empty(), "expected no side-effect responses, got {responses:?}");
    assert!(state.lock().await.headset_id.is_empty());
}
}