Skip to main content

idun_rs/
cloud.rs

1//! IDUN Cloud WebSocket client for server-side EEG decoding.
2//!
3//! When the local experimental 12-bit decoder fails (or is not trusted),
4//! this module provides a fallback that sends raw BLE packets to the
5//! IDUN Cloud API for authoritative decoding.
6//!
7//! # API token
8//!
9//! You need an **IDUN API token** to use the cloud decoder.
10//! Get one from <https://idun.tech/>, then either:
11//!
12//! - Set the `IDUN_API_TOKEN` environment variable:
13//!   ```bash
14//!   export IDUN_API_TOKEN="your_token"
15//!   ```
16//! - Or pass it directly via [`CloudDecoder::new`]
17//! - Or use the CLI flag: `--token your_token`
18//!
19//! # Protocol
20//!
21//! The IDUN Cloud uses a WebSocket endpoint at `wss://ws-api.idun.cloud`.
22//! Authentication is via query parameter: `?authorization={api_token}`.
23//!
24//! ## Session flow
25//!
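//! 1. Connect the WebSocket, passing the API token as a query parameter
//! 2. Send `startNewRecording` → receive `recordingUpdate` messages carrying the
//!    `recordingId` and the recording status (`NOT_STARTED`, then `ONGOING`)
//! 3. Send `subscribeLiveStreamInsights` for `RAW_EEG` and `FILTERED_EEG`
//! 4. Send `publishRawMeasurements` with base64-encoded raw BLE packets
//! 5. Receive `liveStreamInsights` (decoded EEG) and `realtimePredictionsResponse`
//!    messages; the cloud may deliver them as base64-encoded JSON
//! 6. Send `endOngoingRecording` when done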
32//!
33//! # Usage
34//!
35//! ```no_run
36//! use idun_rs::cloud::CloudDecoder;
37//!
38//! # #[tokio::main]
39//! # async fn main() -> anyhow::Result<()> {
40//! let mut decoder = CloudDecoder::new(
41//!     "my-api-token".to_string(),
42//!     "AA-BB-CC-DD-EE-FF".to_string(),
43//! );
44//! decoder.connect().await?;
45//!
46//! // Send a raw BLE packet for cloud decoding
47//! let raw_packet = vec![0xAA, 0x01, /* ... */];
48//! decoder.send_raw_packet(&raw_packet, 1234567890.0, 0).await?;
49//!
50//! // Receive decoded data
51//! if let Some(decoded) = decoder.recv_decoded().await? {
52//!     println!("Decoded: {:?}", decoded);
53//! }
54//!
55//! decoder.disconnect().await?;
56//! # Ok(())
57//! # }
58//! ```
59
60use std::time::{SystemTime, UNIX_EPOCH};
61
62use anyhow::{anyhow, Result};
63use base64::Engine;
64use futures::{SinkExt, StreamExt};
65use log::{debug, error, info, warn};
66use serde_json::{json, Value};
67use tokio::sync::mpsc;
68use tokio_tungstenite::{connect_async, tungstenite::Message};
69
/// IDUN Cloud WebSocket endpoint. `connect` appends the API token to it as
/// the `authorization` query parameter.
const WS_ENDPOINT: &str = "wss://ws-api.idun.cloud";

/// Value sent in the `platform` field of every outgoing cloud message.
const PLATFORM: &str = "SDK_RUST";
72
/// Milliseconds elapsed since the Unix epoch, according to the system clock.
///
/// A clock set before 1970 yields `0` rather than an error.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
79
80/// Decoded EEG data received from the IDUN Cloud.
81#[derive(Debug, Clone)]
82pub struct CloudDecodedEeg {
83    /// The action that produced this data (e.g. `"liveStreamInsights"`).
84    pub action: String,
85    /// The full JSON message from the cloud (contains decoded samples,
86    /// stream type, timestamps, etc.).
87    pub data: Value,
88    /// Sequence number of the original packet that was sent.
89    pub sequence: Option<u64>,
90}
91
/// Lifecycle of a cloud recording session, as tracked by [`CloudDecoder`].
#[derive(Clone, Debug, PartialEq)]
enum SessionState {
    /// No WebSocket connection.
    Disconnected,
    /// Socket open; no recording has been assigned yet.
    Connected,
    /// The cloud assigned a recording ID (status `NOT_STARTED`) but has not
    /// yet reported data flowing.
    RecordingStarted { recording_id: String },
    /// The recording is live and packets can be published.
    RecordingOngoing { recording_id: String },
    /// The session was ended.
    Ended,
}
110
111/// IDUN Cloud WebSocket client for server-side EEG packet decoding.
112///
113/// Manages the full lifecycle: connect → start recording → subscribe to
114/// live insights → send raw packets → receive decoded data → end recording.
115pub struct CloudDecoder {
116    api_token: String,
117    device_id: String,
118    state: SessionState,
119    /// Channel to send raw messages to the WebSocket writer task.
120    ws_tx: Option<mpsc::Sender<String>>,
121    /// Channel to receive decoded events from the WebSocket reader task.
122    decoded_rx: Option<mpsc::Receiver<CloudDecodedEeg>>,
123    /// Sequence counter for outgoing packets.
124    sequence: u64,
125    /// Whether the live insights subscription has been sent.
126    subscribed: bool,
127}
128
129impl CloudDecoder {
130    /// Create a new cloud decoder.
131    ///
132    /// # Arguments
133    /// * `api_token` — IDUN API token (or set `IDUN_API_TOKEN` env var)
134    /// * `device_id` — Guardian MAC address (format: `"AA-BB-CC-DD-EE-FF"`)
135    pub fn new(api_token: String, device_id: String) -> Self {
136        Self {
137            api_token,
138            device_id,
139            state: SessionState::Disconnected,
140            ws_tx: None,
141            decoded_rx: None,
142            sequence: 0,
143            subscribed: false,
144        }
145    }
146
147    /// Create a cloud decoder from the `IDUN_API_TOKEN` environment variable.
148    ///
149    /// Returns `Err` if the env var is not set.
150    pub fn from_env(device_id: String) -> Result<Self> {
151        let token = std::env::var("IDUN_API_TOKEN")
152            .map_err(|_| anyhow!("IDUN_API_TOKEN environment variable not set. \
153                Set it with: export IDUN_API_TOKEN=your-token"))?;
154        Ok(Self::new(token, device_id))
155    }
156
157    /// Connect to the IDUN Cloud WebSocket and start a recording session.
158    ///
159    /// This method:
160    /// 1. Opens the WebSocket connection with authentication
161    /// 2. Spawns reader/writer tasks
162    /// 3. Sends `startNewRecording`
163    /// 4. Waits for the recording ID to be assigned
164    /// 5. Subscribes to `FILTERED_EEG` live stream insights
165    pub async fn connect(&mut self) -> Result<()> {
166        let ws_url = format!("{}?authorization={}", WS_ENDPOINT, self.api_token);
167
168        info!("[CLOUD] Connecting to IDUN Cloud at {WS_ENDPOINT}…");
169        let (ws_stream, _response) = connect_async(&ws_url).await
170            .map_err(|e| anyhow!("Failed to connect to IDUN Cloud: {e}"))?;
171        info!("[CLOUD] WebSocket connected");
172
173        let (mut ws_write, mut ws_read) = ws_stream.split();
174
175        // Channel for outgoing messages (to the WebSocket writer)
176        let (out_tx, mut out_rx) = mpsc::channel::<String>(256);
177        // Channel for decoded events (from the WebSocket reader)
178        let (decoded_tx, decoded_rx) = mpsc::channel::<CloudDecodedEeg>(256);
179
180        self.ws_tx = Some(out_tx.clone());
181        self.decoded_rx = Some(decoded_rx);
182        self.state = SessionState::Connected;
183
184        // ── Writer task: drain out_tx and send to WebSocket ──────────────
185        tokio::spawn(async move {
186            while let Some(msg) = out_rx.recv().await {
187                if let Err(e) = ws_write.send(Message::Text(msg.into())).await {
188                    error!("[CLOUD] WebSocket write error: {e}");
189                    break;
190                }
191            }
192            debug!("[CLOUD] Writer task ended");
193        });
194
195        // ── Reader task: read from WebSocket and dispatch ────────────────
196        let state_tx = mpsc::channel::<(String, String)>(16);
197        let mut state_rx = state_tx.1;
198        let recording_state_tx = state_tx.0;
199
200        tokio::spawn(async move {
201            while let Some(msg) = ws_read.next().await {
202                match msg {
203                    Ok(Message::Text(text)) => {
204                        // Cloud sends base64-encoded JSON
205                        let json_str = if let Ok(decoded_bytes) =
206                            base64::engine::general_purpose::STANDARD.decode(text.as_bytes())
207                        {
208                            String::from_utf8_lossy(&decoded_bytes).to_string()
209                        } else {
210                            // Maybe it's plain JSON
211                            text.to_string()
212                        };
213
214                        match serde_json::from_str::<Value>(&json_str) {
215                            Ok(event) => {
216                                let action = event
217                                    .get("action")
218                                    .and_then(|a| a.as_str())
219                                    .unwrap_or("")
220                                    .to_string();
221
222                                match action.as_str() {
223                                    "recordingUpdate" => {
224                                        let status = event
225                                            .get("message")
226                                            .and_then(|m| m.get("status"))
227                                            .and_then(|s| s.as_str())
228                                            .unwrap_or("");
229                                        let rec_id = event
230                                            .get("message")
231                                            .and_then(|m| m.get("recordingId"))
232                                            .and_then(|r| r.as_str())
233                                            .unwrap_or("")
234                                            .to_string();
235                                        info!("[CLOUD] Recording update: status={status} id={rec_id}");
236                                        let _ = recording_state_tx
237                                            .send((status.to_string(), rec_id))
238                                            .await;
239                                    }
240                                    "liveStreamInsights" => {
241                                        let seq = event.get("sequence")
242                                            .and_then(|s| s.as_u64());
243                                        let _ = decoded_tx
244                                            .send(CloudDecodedEeg {
245                                                action: action.clone(),
246                                                data: event,
247                                                sequence: seq,
248                                            })
249                                            .await;
250                                    }
251                                    "realtimePredictionsResponse" => {
252                                        let seq = event.get("sequence")
253                                            .and_then(|s| s.as_u64());
254                                        let _ = decoded_tx
255                                            .send(CloudDecodedEeg {
256                                                action: action.clone(),
257                                                data: event,
258                                                sequence: seq,
259                                            })
260                                            .await;
261                                    }
262                                    "clientError" => {
263                                        let msg = event
264                                            .get("message")
265                                            .and_then(|m| m.as_str())
266                                            .unwrap_or("unknown error");
267                                        error!("[CLOUD] Client error: {msg}");
268                                    }
269                                    other => {
270                                        debug!("[CLOUD] Unhandled action: {other}");
271                                    }
272                                }
273                            }
274                            Err(e) => {
275                                debug!("[CLOUD] JSON parse error: {e} | raw: {json_str}");
276                            }
277                        }
278                    }
279                    Ok(Message::Binary(bin)) => {
280                        // Try base64 decode then JSON parse
281                        if let Ok(decoded) =
282                            base64::engine::general_purpose::STANDARD.decode(&bin)
283                        {
284                            let json_str = String::from_utf8_lossy(&decoded);
285                            debug!("[CLOUD] Binary message: {json_str}");
286                        }
287                    }
288                    Ok(Message::Close(_)) => {
289                        info!("[CLOUD] WebSocket closed by server");
290                        break;
291                    }
292                    Err(e) => {
293                        error!("[CLOUD] WebSocket read error: {e}");
294                        break;
295                    }
296                    _ => {}
297                }
298            }
299            debug!("[CLOUD] Reader task ended");
300        });
301
302        // ── Send startNewRecording ───────────────────────────────────────
303        let start_msg = json!({
304            "version": 1,
305            "platform": PLATFORM,
306            "action": "startNewRecording",
307            "deviceId": self.device_id,
308            "deviceTs": now_ms(),
309        });
310        self.send_raw_json(&start_msg).await?;
311        info!("[CLOUD] Sent startNewRecording");
312
313        // ── Wait for recording ID ────────────────────────────────────────
314        let timeout = tokio::time::Duration::from_secs(15);
315        let mut recording_id = String::new();
316
317        let deadline = tokio::time::Instant::now() + timeout;
318        loop {
319            tokio::select! {
320                Some((status, rec_id)) = state_rx.recv() => {
321                    match status.as_str() {
322                        "NOT_STARTED" => {
323                            recording_id = rec_id;
324                            info!("[CLOUD] Recording ID assigned: {recording_id}");
325                            self.state = SessionState::RecordingStarted {
326                                recording_id: recording_id.clone(),
327                            };
328                        }
329                        "ONGOING" => {
330                            if recording_id.is_empty() {
331                                recording_id = rec_id;
332                            }
333                            info!("[CLOUD] Recording is ONGOING");
334                            self.state = SessionState::RecordingOngoing {
335                                recording_id: recording_id.clone(),
336                            };
337                            break;
338                        }
339                        "COMPLETED" | "FAILED" => {
340                            warn!("[CLOUD] Recording ended with status: {status}");
341                            self.state = SessionState::Ended;
342                            return Err(anyhow!("Recording ended unexpectedly: {status}"));
343                        }
344                        _ => {
345                            debug!("[CLOUD] Unexpected recording status: {status}");
346                        }
347                    }
348                }
349                _ = tokio::time::sleep_until(deadline) => {
350                    // If we got a recording ID but never saw ONGOING, proceed anyway
351                    if !recording_id.is_empty() {
352                        info!("[CLOUD] Proceeding with recording ID (no ONGOING received)");
353                        self.state = SessionState::RecordingOngoing {
354                            recording_id: recording_id.clone(),
355                        };
356                        break;
357                    }
358                    return Err(anyhow!("Timed out waiting for recording ID from cloud"));
359                }
360            }
361        }
362
363        // ── Subscribe to live stream insights ────────────────────────────
364        self.subscribe_live_insights(&recording_id).await?;
365
366        // Spawn a task to keep draining state_rx so it doesn't block
367        tokio::spawn(async move {
368            while state_rx.recv().await.is_some() {}
369        });
370
371        Ok(())
372    }
373
374    /// Subscribe to FILTERED_EEG and RAW_EEG live stream insights.
375    async fn subscribe_live_insights(&mut self, recording_id: &str) -> Result<()> {
376        let msg = json!({
377            "version": 1,
378            "platform": PLATFORM,
379            "action": "subscribeLiveStreamInsights",
380            "deviceId": self.device_id,
381            "deviceTs": now_ms(),
382            "recordingId": recording_id,
383            "streamsTypes": ["RAW_EEG", "FILTERED_EEG"],
384        });
385        self.send_raw_json(&msg).await?;
386        self.subscribed = true;
387        info!("[CLOUD] Subscribed to live stream insights (RAW_EEG, FILTERED_EEG)");
388        Ok(())
389    }
390
391    /// Send a raw BLE packet to the cloud for decoding.
392    ///
393    /// The packet is base64-encoded and wrapped in a `publishRawMeasurements`
394    /// message, matching the format used by the official Python SDK.
395    ///
396    /// Returns `Err` if the WebSocket is not connected or the send fails.
397    pub async fn send_raw_packet(
398        &mut self,
399        raw_data: &[u8],
400        device_ts: f64,
401        sequence: u64,
402    ) -> Result<()> {
403        let recording_id = match &self.state {
404            SessionState::RecordingOngoing { recording_id } => recording_id.clone(),
405            SessionState::RecordingStarted { recording_id } => recording_id.clone(),
406            _ => return Err(anyhow!("Cloud session not active (state: {:?})", self.state)),
407        };
408
409        let b64 = base64::engine::general_purpose::STANDARD.encode(raw_data);
410
411        let msg = json!({
412            "version": 1,
413            "platform": PLATFORM,
414            "action": "publishRawMeasurements",
415            "deviceId": self.device_id,
416            "deviceTs": device_ts as u64,
417            "event": b64,
418            "recordingId": recording_id,
419            "sequence": sequence,
420        });
421
422        self.send_raw_json(&msg).await?;
423        self.sequence = sequence + 1;
424        Ok(())
425    }
426
427    /// Try to receive the next decoded EEG event from the cloud.
428    ///
429    /// Returns `Ok(None)` if no data is available yet (non-blocking).
430    /// Returns `Ok(Some(decoded))` with the cloud-decoded data.
431    /// Returns `Err` if the channel is closed (cloud disconnected).
432    pub fn try_recv_decoded(&mut self) -> Result<Option<CloudDecodedEeg>> {
433        if let Some(ref mut rx) = self.decoded_rx {
434            match rx.try_recv() {
435                Ok(decoded) => Ok(Some(decoded)),
436                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
437                Err(mpsc::error::TryRecvError::Disconnected) => {
438                    Err(anyhow!("Cloud decoder channel closed"))
439                }
440            }
441        } else {
442            Ok(None)
443        }
444    }
445
446    /// Receive the next decoded EEG event from the cloud (blocking).
447    ///
448    /// Returns `Ok(None)` if the cloud disconnected.
449    pub async fn recv_decoded(&mut self) -> Result<Option<CloudDecodedEeg>> {
450        if let Some(ref mut rx) = self.decoded_rx {
451            Ok(rx.recv().await)
452        } else {
453            Ok(None)
454        }
455    }
456
457    /// Send a raw JSON message to the cloud WebSocket.
458    async fn send_raw_json(&self, msg: &Value) -> Result<()> {
459        if let Some(ref tx) = self.ws_tx {
460            let text = serde_json::to_string(msg)?;
461            tx.send(text)
462                .await
463                .map_err(|e| anyhow!("Failed to send to cloud: {e}"))?;
464            Ok(())
465        } else {
466            Err(anyhow!("WebSocket not connected"))
467        }
468    }
469
470    /// End the recording session and close the WebSocket.
471    pub async fn disconnect(&mut self) -> Result<()> {
472        let recording_id = match &self.state {
473            SessionState::RecordingOngoing { recording_id }
474            | SessionState::RecordingStarted { recording_id } => recording_id.clone(),
475            _ => {
476                self.state = SessionState::Disconnected;
477                self.ws_tx = None;
478                self.decoded_rx = None;
479                return Ok(());
480            }
481        };
482
483        let msg = json!({
484            "version": 1,
485            "platform": PLATFORM,
486            "action": "endOngoingRecording",
487            "deviceId": self.device_id,
488            "deviceTs": now_ms(),
489            "recordingId": recording_id,
490        });
491
492        if let Err(e) = self.send_raw_json(&msg).await {
493            warn!("[CLOUD] Error sending endOngoingRecording: {e}");
494        } else {
495            info!("[CLOUD] Sent endOngoingRecording for {recording_id}");
496        }
497
498        // Give the server a moment to process
499        tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
500
501        self.state = SessionState::Ended;
502        self.ws_tx = None;
503        self.decoded_rx = None;
504        self.subscribed = false;
505
506        info!("[CLOUD] Disconnected from IDUN Cloud");
507        Ok(())
508    }
509
510    /// Check if the cloud session is active and ready to send/receive data.
511    pub fn is_connected(&self) -> bool {
512        matches!(
513            self.state,
514            SessionState::RecordingOngoing { .. } | SessionState::RecordingStarted { .. }
515        )
516    }
517
518    /// Get the current recording ID, if a session is active.
519    pub fn recording_id(&self) -> Option<&str> {
520        match &self.state {
521            SessionState::RecordingOngoing { recording_id }
522            | SessionState::RecordingStarted { recording_id } => Some(recording_id),
523            _ => None,
524        }
525    }
526}