use crate::error::{MastishkError, validate_dt};
use serde::{Deserialize, Serialize};
/// Sleep/wake stage tracked by [`SleepState`].
///
/// Marked `#[non_exhaustive]`, so downstream matches need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SleepStage {
/// Awake. Stage transitions are not driven while in this stage.
Wake,
/// NREM stage 1: the stage entered on sleep onset.
Nrem1,
/// NREM stage 2: the hub stage that every other sleep stage returns to.
Nrem2,
/// NREM stage 3: the stage with the highest recovery multiplier.
Nrem3,
/// REM sleep: the stage with the highest consolidation rate. Leaving it completes a cycle.
Rem,
}
/// Sleep/wake model state: current stage, adenosine level, sleep debt and
/// per-episode counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SleepState {
/// Current sleep/wake stage.
pub stage: SleepStage,
/// Adenosine level. `tick_adenosine` clamps it to `[0, 1]`.
pub adenosine: f32,
/// Accumulated sleep debt. It grows while awake and is repaid (floored at 0)
/// while asleep.
pub sleep_debt: f32,
/// Hours spent in the current stage. Reset to 0 on every stage change.
pub time_in_stage: f32,
/// Hours slept in the current episode. Advanced by `tick_adenosine` while
/// asleep and reset by `fall_asleep`.
pub total_sleep: f32,
/// Cycles completed in the current episode. Incremented when leaving REM and
/// reset by `fall_asleep`.
pub cycles_completed: u32,
}
impl Default for SleepState {
    /// An awake state with a baseline adenosine level of `0.3` and every
    /// accumulator (debt, stage timer, total sleep, cycle count) at zero.
    fn default() -> Self {
        let baseline_adenosine: f32 = 0.3;
        Self {
            cycles_completed: 0,
            total_sleep: 0.0,
            time_in_stage: 0.0,
            sleep_debt: 0.0,
            adenosine: baseline_adenosine,
            stage: SleepStage::Wake,
        }
    }
}
impl SleepState {
    /// Combined sleep pressure, clamped to `[0, 1]`.
    ///
    /// Weighted mix: 70% adenosine plus 30% of `sleep_debt` normalised by 24.
    #[inline]
    #[must_use]
    pub fn sleep_pressure(&self) -> f32 {
        (self.adenosine * 0.7 + (self.sleep_debt / 24.0) * 0.3).clamp(0.0, 1.0)
    }

    /// Returns `true` in every stage other than [`SleepStage::Wake`].
    #[inline]
    #[must_use]
    pub fn is_asleep(&self) -> bool {
        self.stage != SleepStage::Wake
    }

    /// Per-stage recovery weight.
    ///
    /// The weight is 0 when awake, rises through NREM1 (0.3) and NREM2 (0.6),
    /// and peaks at 1.0 in NREM3. REM gives 0.5.
    #[inline]
    #[must_use]
    pub fn recovery_multiplier(&self) -> f32 {
        match self.stage {
            SleepStage::Wake => 0.0,
            SleepStage::Nrem1 => 0.3,
            SleepStage::Nrem2 => 0.6,
            SleepStage::Nrem3 => 1.0,
            SleepStage::Rem => 0.5,
        }
    }

    /// Per-stage consolidation weight.
    ///
    /// The weight is 0 when awake, rises through NREM1 (0.1), NREM2 (0.3) and
    /// NREM3 (0.7), and peaks at 1.0 in REM.
    #[inline]
    #[must_use]
    pub fn consolidation_rate(&self) -> f32 {
        match self.stage {
            SleepStage::Wake => 0.0,
            SleepStage::Nrem1 => 0.1,
            SleepStage::Nrem2 => 0.3,
            SleepStage::Nrem3 => 0.7,
            SleepStage::Rem => 1.0,
        }
    }

    /// Advances adenosine, sleep debt and (while asleep) `total_sleep` by
    /// `dt_hours`.
    ///
    /// - Awake: adenosine relaxes exponentially toward 1.0 with an 18.2 h time
    ///   constant, and sleep debt grows by 0.125 per hour.
    /// - Asleep: adenosine decays exponentially with a 4.2 h time constant,
    ///   sleep debt shrinks by 0.25 per hour (floored at 0), and `total_sleep`
    ///   accumulates `dt_hours`.
    ///
    /// Adenosine is clamped to `[0, 1]` afterwards.
    ///
    /// # Errors
    ///
    /// Returns the error from `validate_dt` when `dt_hours` is rejected (for
    /// example a negative step). The state is left unchanged in that case.
    #[inline]
    pub fn tick_adenosine(&mut self, dt_hours: f32) -> Result<(), MastishkError> {
        validate_dt(dt_hours)?;
        tracing::trace!(dt_hours, stage = ?self.stage, adenosine = self.adenosine, "ticking adenosine");
        if self.stage == SleepStage::Wake {
            // Exponential approach toward saturation (1.0).
            let alpha_w = 1.0 - (-dt_hours / 18.2).exp();
            self.adenosine += (1.0 - self.adenosine) * alpha_w;
            self.sleep_debt += dt_hours.max(0.0) * 0.125;
        } else {
            // Exponential clearance toward 0 while asleep.
            self.adenosine *= (-dt_hours / 4.2).exp();
            self.sleep_debt = (self.sleep_debt - dt_hours * 0.25).max(0.0);
            self.total_sleep += dt_hours;
        }
        self.adenosine = self.adenosine.clamp(0.0, 1.0);
        Ok(())
    }

    /// Starts a new sleep episode.
    ///
    /// From [`SleepStage::Wake`] this enters NREM1 and resets the per-episode
    /// counters (`time_in_stage`, `total_sleep`, `cycles_completed`). It does
    /// nothing if already asleep.
    #[inline]
    pub fn fall_asleep(&mut self) {
        if self.stage == SleepStage::Wake {
            self.stage = SleepStage::Nrem1;
            self.time_in_stage = 0.0;
            self.total_sleep = 0.0;
            self.cycles_completed = 0;
            tracing::debug!("sleep onset");
        }
    }

    /// Ends the current sleep episode.
    ///
    /// Switches to [`SleepStage::Wake`] and resets `time_in_stage`.
    /// `total_sleep` and `cycles_completed` are left untouched; only
    /// [`Self::fall_asleep`] resets them. Does nothing if already awake.
    #[inline]
    pub fn wake_up(&mut self) {
        if self.stage != SleepStage::Wake {
            self.stage = SleepStage::Wake;
            self.time_in_stage = 0.0;
            tracing::debug!(
                cycles = self.cycles_completed,
                total_sleep = self.total_sleep,
                "woke up"
            );
        }
    }

    /// Advances the sleep-stage state machine by `dt_hours`.
    ///
    /// Transitions (thresholds in hours spent in the current stage):
    /// - NREM1 → NREM2 after 0.1.
    /// - NREM2 → NREM3 or REM after 0.33. NREM3 is chosen while `total_sleep`
    ///   is before the midpoint of the current 1.5 h cycle, REM after it.
    /// - NREM3 → NREM2 after 0.5 in the first two cycles, 0.2 afterwards.
    /// - REM → NREM2 after 0.17 in the first two cycles, 0.33 afterwards.
    ///   Leaving REM increments `cycles_completed`.
    ///
    /// The NREM2 branch reads `total_sleep`, which only
    /// [`Self::tick_adenosine`] advances, so call both each step.
    ///
    /// At most one transition happens per call, and time past a threshold is
    /// discarded. Keep `dt_hours` small relative to the stage durations.
    ///
    /// Does nothing while awake. A non-finite `dt_hours`, or one rejected by
    /// `validate_dt`, is ignored with a warning.
    #[inline]
    pub fn tick_stage_transitions(&mut self, dt_hours: f32) {
        if self.stage == SleepStage::Wake {
            return;
        }
        // Without this guard a NaN dt makes `time_in_stage` NaN, every `>`
        // guard below fails, and the stage machine stalls until the next
        // `wake_up`. A negative dt would rewind the stage timer. Skipping the
        // same steps `tick_adenosine` rejects also keeps `time_in_stage` in
        // step with `total_sleep`.
        if !dt_hours.is_finite() || validate_dt(dt_hours).is_err() {
            tracing::warn!(dt_hours, "ignoring invalid dt in sleep stage transitions");
            return;
        }
        self.time_in_stage += dt_hours;

        // The first two cycles use a longer NREM3 bout and a shorter REM bout
        // than later cycles.
        let early_cycle = self.cycles_completed < 2;
        let nrem3_duration = if early_cycle { 0.5 } else { 0.2 };
        let rem_duration = if early_cycle { 0.17 } else { 0.33 };

        let transition = match self.stage {
            SleepStage::Nrem1 if self.time_in_stage > 0.1 => Some(SleepStage::Nrem2),
            SleepStage::Nrem2 if self.time_in_stage > 0.33 => {
                // Midpoint of cycle k, assuming 1.5 h cycles: descend to NREM3
                // before it, ascend to REM after it.
                let cycle_midpoint = (self.cycles_completed as f32 + 0.5) * 1.5;
                if self.total_sleep < cycle_midpoint {
                    Some(SleepStage::Nrem3)
                } else {
                    Some(SleepStage::Rem)
                }
            }
            SleepStage::Nrem3 if self.time_in_stage > nrem3_duration => Some(SleepStage::Nrem2),
            SleepStage::Rem if self.time_in_stage > rem_duration => {
                // Each completed REM bout closes one cycle. Saturate rather
                // than overflow (which would panic in debug builds).
                self.cycles_completed = self.cycles_completed.saturating_add(1);
                Some(SleepStage::Nrem2)
            }
            _ => None,
        };
        if let Some(next) = transition {
            tracing::trace!(from = ?self.stage, to = ?next, cycle = self.cycles_completed, "sleep stage transition");
            self.stage = next;
            self.time_in_stage = 0.0;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_is_awake() {
        let s = SleepState::default();
        assert_eq!(s.stage, SleepStage::Wake);
        assert!(!s.is_asleep());
    }

    #[test]
    fn test_sleep_pressure_rises() {
        let mut s = SleepState::default();
        let initial = s.sleep_pressure();
        s.tick_adenosine(8.0).unwrap();
        assert!(s.sleep_pressure() > initial);
    }

    #[test]
    fn test_sleep_pressure_is_clamped() {
        // Huge debt would push the raw weighted sum well above 1.0.
        let high = SleepState {
            adenosine: 1.0,
            sleep_debt: 100.0,
            ..Default::default()
        };
        assert!((high.sleep_pressure() - 1.0).abs() < f32::EPSILON);

        let low = SleepState {
            adenosine: 0.0,
            sleep_debt: 0.0,
            ..Default::default()
        };
        assert!(low.sleep_pressure().abs() < f32::EPSILON);
    }

    #[test]
    fn test_adenosine_clears_during_sleep() {
        let mut s = SleepState {
            adenosine: 0.8,
            stage: SleepStage::Nrem3,
            ..Default::default()
        };
        s.tick_adenosine(4.0).unwrap();
        assert!(s.adenosine < 0.8);
    }

    #[test]
    fn test_recovery_multiplier() {
        let mut s = SleepState::default();
        assert!((s.recovery_multiplier() - 0.0).abs() < f32::EPSILON);
        s.stage = SleepStage::Nrem3;
        assert!((s.recovery_multiplier() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_serde_roundtrip() {
        let s = SleepState::default();
        let json = serde_json::to_string(&s).unwrap();
        let s2: SleepState = serde_json::from_str(&json).unwrap();
        assert_eq!(s2.stage, SleepStage::Wake);
    }

    #[test]
    fn test_serde_roundtrip_preserves_all_fields() {
        // Non-default, exactly representable values so every field is checked.
        let s = SleepState {
            stage: SleepStage::Rem,
            adenosine: 0.5,
            sleep_debt: 3.5,
            time_in_stage: 0.25,
            total_sleep: 2.25,
            cycles_completed: 2,
        };
        let json = serde_json::to_string(&s).unwrap();
        let s2: SleepState = serde_json::from_str(&json).unwrap();
        assert_eq!(s2.stage, SleepStage::Rem);
        assert!((s2.adenosine - 0.5).abs() < f32::EPSILON);
        assert!((s2.sleep_debt - 3.5).abs() < f32::EPSILON);
        assert!((s2.time_in_stage - 0.25).abs() < f32::EPSILON);
        assert!((s2.total_sleep - 2.25).abs() < f32::EPSILON);
        assert_eq!(s2.cycles_completed, 2);
    }

    #[test]
    fn test_negative_dt_rejected() {
        let mut s = SleepState::default();
        assert!(s.tick_adenosine(-1.0).is_err());
        // A rejected step must not have mutated anything.
        assert_eq!(s.stage, SleepStage::Wake);
        assert!((s.adenosine - 0.3).abs() < f32::EPSILON);
        assert!(s.sleep_debt.abs() < f32::EPSILON);
    }

    #[test]
    fn test_consolidation_rate() {
        let mut s = SleepState::default();
        assert!((s.consolidation_rate() - 0.0).abs() < f32::EPSILON);
        s.stage = SleepStage::Rem;
        assert!((s.consolidation_rate() - 1.0).abs() < f32::EPSILON);
        s.stage = SleepStage::Nrem3;
        assert!((s.consolidation_rate() - 0.7).abs() < f32::EPSILON);
    }

    #[test]
    fn test_total_sleep_accumulates() {
        let mut s = SleepState {
            stage: SleepStage::Nrem2,
            ..Default::default()
        };
        s.tick_adenosine(4.0).unwrap();
        assert!((s.total_sleep - 4.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_sleep_debt_accumulates_during_wake() {
        let mut s = SleepState::default();
        s.tick_adenosine(8.0).unwrap();
        assert!(s.sleep_debt > 0.0);
    }

    #[test]
    fn test_is_asleep_stages() {
        let mut s = SleepState::default();
        assert!(!s.is_asleep());
        for stage in [
            SleepStage::Nrem1,
            SleepStage::Nrem2,
            SleepStage::Nrem3,
            SleepStage::Rem,
        ] {
            s.stage = stage;
            assert!(s.is_asleep());
        }
    }

    #[test]
    fn test_fall_asleep_transitions_to_nrem1() {
        let mut s = SleepState::default();
        s.fall_asleep();
        assert_eq!(s.stage, SleepStage::Nrem1);
    }

    #[test]
    fn test_fall_asleep_resets_episode_counters() {
        let mut s = SleepState {
            time_in_stage: 0.25,
            total_sleep: 5.0,
            cycles_completed: 3,
            ..Default::default()
        };
        s.fall_asleep();
        assert_eq!(s.stage, SleepStage::Nrem1);
        assert!(s.time_in_stage.abs() < f32::EPSILON);
        assert!(s.total_sleep.abs() < f32::EPSILON);
        assert_eq!(s.cycles_completed, 0);
    }

    #[test]
    fn test_fall_asleep_is_noop_when_asleep() {
        let mut s = SleepState {
            stage: SleepStage::Nrem3,
            time_in_stage: 0.25,
            total_sleep: 1.0,
            cycles_completed: 1,
            ..Default::default()
        };
        s.fall_asleep();
        assert_eq!(s.stage, SleepStage::Nrem3);
        assert!((s.time_in_stage - 0.25).abs() < f32::EPSILON);
        assert!((s.total_sleep - 1.0).abs() < f32::EPSILON);
        assert_eq!(s.cycles_completed, 1);
    }

    #[test]
    fn test_wake_up_transitions_to_wake() {
        let mut s = SleepState {
            stage: SleepStage::Nrem3,
            ..Default::default()
        };
        s.wake_up();
        assert_eq!(s.stage, SleepStage::Wake);
    }

    #[test]
    fn test_wake_up_keeps_episode_totals() {
        let mut s = SleepState {
            stage: SleepStage::Rem,
            time_in_stage: 0.125,
            total_sleep: 3.0,
            cycles_completed: 2,
            ..Default::default()
        };
        s.wake_up();
        assert_eq!(s.stage, SleepStage::Wake);
        assert!(s.time_in_stage.abs() < f32::EPSILON);
        assert!((s.total_sleep - 3.0).abs() < f32::EPSILON);
        assert_eq!(s.cycles_completed, 2);
    }

    #[test]
    fn test_ultradian_cycle_progresses() {
        let mut s = SleepState::default();
        s.fall_asleep();
        for _ in 0..120 {
            s.tick_stage_transitions(1.0 / 60.0);
            s.tick_adenosine(1.0 / 60.0).unwrap();
        }
        assert_ne!(s.stage, SleepStage::Nrem1);
        assert!(s.is_asleep());
    }

    #[test]
    fn test_full_night_completes_cycles() {
        let mut s = SleepState::default();
        s.fall_asleep();
        for _ in 0..480 {
            s.tick_stage_transitions(1.0 / 60.0);
            s.tick_adenosine(1.0 / 60.0).unwrap();
        }
        assert!(s.cycles_completed >= 3, "cycles={}", s.cycles_completed);
    }

    #[test]
    fn test_no_transitions_during_wake() {
        let mut s = SleepState::default();
        s.tick_stage_transitions(1.0);
        assert_eq!(s.stage, SleepStage::Wake);
    }
}