1pub mod circadian;
13pub mod consolidation;
14pub mod rem;
15pub mod spindles;
16pub mod sws;
17
18pub use circadian::{CircadianRhythm, TimeOfDay};
19pub use consolidation::{ConsolidationEvent, MemoryConsolidator};
20pub use rem::{DreamContent, REMSleep};
21pub use spindles::{KComplex, SleepSpindle, SpindleGenerator};
22pub use sws::{SlowWave, SlowWaveSleep};
23
24use serde::{Deserialize, Serialize};
25use std::collections::VecDeque;
26use thiserror::Error;
27
28#[derive(Debug, Error)]
30pub enum SleepError {
31 #[error("Invalid sleep state transition: {0}")]
32 InvalidTransition(String),
33
34 #[error("Sleep cycle error: {0}")]
35 CycleError(String),
36
37 #[error("Consolidation error: {0}")]
38 ConsolidationError(String),
39}
40
41pub type Result<T> = std::result::Result<T, SleepError>;
42
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
45pub enum SleepStage {
46 Wake,
48 N1,
50 N2,
52 N3,
54 REM,
56}
57
58impl SleepStage {
59 pub fn typical_duration(&self) -> f64 {
61 match self {
62 Self::Wake => 0.0,
63 Self::N1 => 5.0,
64 Self::N2 => 20.0,
65 Self::N3 => 30.0,
66 Self::REM => 20.0,
67 }
68 }
69
70 pub fn consolidation_strength(&self) -> f64 {
72 match self {
73 Self::Wake => 0.0,
74 Self::N1 => 0.1,
75 Self::N2 => 0.3,
76 Self::N3 => 1.0, Self::REM => 0.7, }
79 }
80
81 pub fn is_sleeping(&self) -> bool {
83 *self != Self::Wake
84 }
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct SleepConfig {
90 pub cycle_duration_hours: f64,
92 pub cycles_per_night: usize,
94 pub wake_threshold: f64,
96 pub pressure_decay_rate: f64,
98 pub pressure_build_rate: f64,
100 pub rem_rebound_factor: f64,
102}
103
104impl Default for SleepConfig {
105 fn default() -> Self {
106 Self {
107 cycle_duration_hours: 1.5,
108 cycles_per_night: 5,
109 wake_threshold: 1.0,
110 pressure_decay_rate: 0.1,
111 pressure_build_rate: 0.05,
112 rem_rebound_factor: 0.3,
113 }
114 }
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct SleepEvent {
120 pub event_type: SleepEventType,
122 pub timestamp: u64,
124 pub stage: SleepStage,
126 pub data: serde_json::Value,
128}
129
130#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
132pub enum SleepEventType {
133 StageTransition,
134 SpindleBurst,
135 SlowWave,
136 KComplex,
137 REMBurst,
138 Arousal,
139 Consolidation,
140}
141
142pub struct SleepController {
144 config: SleepConfig,
145 current_stage: SleepStage,
147 stage_time: f64,
149 current_cycle: usize,
151 sleep_pressure: f64,
153 circadian: CircadianRhythm,
155 sws: SlowWaveSleep,
157 rem: REMSleep,
159 spindles: SpindleGenerator,
161 consolidator: MemoryConsolidator,
163 events: VecDeque<SleepEvent>,
165 total_sleep_time: f64,
167 time_awake: f64,
169}
170
171impl SleepController {
172 pub fn new() -> Self {
174 Self::with_config(SleepConfig::default())
175 }
176
177 pub fn with_config(config: SleepConfig) -> Self {
179 Self {
180 config,
181 current_stage: SleepStage::Wake,
182 stage_time: 0.0,
183 current_cycle: 0,
184 sleep_pressure: 0.5,
185 circadian: CircadianRhythm::new(),
186 sws: SlowWaveSleep::new(),
187 rem: REMSleep::new(),
188 spindles: SpindleGenerator::new(),
189 consolidator: MemoryConsolidator::new(),
190 events: VecDeque::with_capacity(1000),
191 total_sleep_time: 0.0,
192 time_awake: 480.0, }
194 }
195
196 pub fn step(&mut self, dt_minutes: f64) -> Vec<SleepEvent> {
198 let mut new_events = Vec::new();
199
200 if self.current_stage == SleepStage::Wake {
201 self.time_awake += dt_minutes;
203 self.sleep_pressure += self.config.pressure_build_rate * dt_minutes / 60.0;
204 self.sleep_pressure = self.sleep_pressure.min(2.0);
205 } else {
206 self.stage_time += dt_minutes;
208 self.total_sleep_time += dt_minutes;
209
210 self.sleep_pressure -= self.config.pressure_decay_rate * dt_minutes / 60.0;
212 self.sleep_pressure = self.sleep_pressure.max(0.0);
213
214 match self.current_stage {
216 SleepStage::N2 => {
217 if let Some(spindle) = self.spindles.step(dt_minutes) {
219 new_events.push(SleepEvent {
220 event_type: SleepEventType::SpindleBurst,
221 timestamp: self.now(),
222 stage: self.current_stage,
223 data: serde_json::to_value(&spindle).unwrap_or_default(),
224 });
225
226 let consolidation = self.consolidator.process_spindle(&spindle);
228 new_events.push(SleepEvent {
229 event_type: SleepEventType::Consolidation,
230 timestamp: self.now(),
231 stage: self.current_stage,
232 data: serde_json::to_value(&consolidation).unwrap_or_default(),
233 });
234 }
235 }
236 SleepStage::N3 => {
237 if let Some(wave) = self.sws.step(dt_minutes) {
239 new_events.push(SleepEvent {
240 event_type: SleepEventType::SlowWave,
241 timestamp: self.now(),
242 stage: self.current_stage,
243 data: serde_json::to_value(&wave).unwrap_or_default(),
244 });
245
246 let consolidation =
248 self.consolidator.process_slow_wave(&wave);
249 new_events.push(SleepEvent {
250 event_type: SleepEventType::Consolidation,
251 timestamp: self.now(),
252 stage: self.current_stage,
253 data: serde_json::to_value(&consolidation).unwrap_or_default(),
254 });
255 }
256 }
257 SleepStage::REM => {
258 if let Some(dream) = self.rem.step(dt_minutes) {
260 new_events.push(SleepEvent {
261 event_type: SleepEventType::REMBurst,
262 timestamp: self.now(),
263 stage: self.current_stage,
264 data: serde_json::to_value(&dream).unwrap_or_default(),
265 });
266
267 let consolidation = self.consolidator.process_dream(&dream);
269 new_events.push(SleepEvent {
270 event_type: SleepEventType::Consolidation,
271 timestamp: self.now(),
272 stage: self.current_stage,
273 data: serde_json::to_value(&consolidation).unwrap_or_default(),
274 });
275 }
276 }
277 _ => {}
278 }
279
280 if self.should_transition() {
282 let next_stage = self.next_stage();
283 if next_stage != self.current_stage {
284 new_events.push(SleepEvent {
285 event_type: SleepEventType::StageTransition,
286 timestamp: self.now(),
287 stage: next_stage,
288 data: serde_json::json!({
289 "from": self.current_stage,
290 "to": next_stage,
291 "cycle": self.current_cycle
292 }),
293 });
294
295 self.transition_to(next_stage);
296 }
297 }
298 }
299
300 for event in &new_events {
302 self.events.push_back(event.clone());
303 if self.events.len() > 1000 {
304 self.events.pop_front();
305 }
306 }
307
308 new_events
309 }
310
311 pub fn fall_asleep(&mut self) -> Result<()> {
313 if self.current_stage != SleepStage::Wake {
314 return Err(SleepError::InvalidTransition("Already sleeping".to_string()));
315 }
316
317 self.current_stage = SleepStage::N1;
318 self.stage_time = 0.0;
319 self.current_cycle = 1;
320 self.time_awake = 0.0;
321
322 Ok(())
323 }
324
325 pub fn wake_up(&mut self) -> Result<()> {
327 if self.current_stage == SleepStage::Wake {
328 return Err(SleepError::InvalidTransition("Already awake".to_string()));
329 }
330
331 self.current_stage = SleepStage::Wake;
332 self.stage_time = 0.0;
333
334 Ok(())
335 }
336
337 fn should_transition(&self) -> bool {
339 let typical_duration = self.current_stage.typical_duration();
340 self.stage_time >= typical_duration
341 }
342
343 fn next_stage(&self) -> SleepStage {
345 match self.current_stage {
347 SleepStage::Wake => SleepStage::Wake,
348 SleepStage::N1 => SleepStage::N2,
349 SleepStage::N2 => {
350 if self.current_cycle <= 2 {
352 SleepStage::N3
353 } else {
354 let rem_probability =
356 0.3 + (self.current_cycle as f64 * self.config.rem_rebound_factor);
357 if rand::random::<f64>() < rem_probability {
358 SleepStage::REM
359 } else {
360 SleepStage::N3
361 }
362 }
363 }
364 SleepStage::N3 => {
365 if self.sleep_pressure < 0.3 {
367 SleepStage::REM
368 } else {
369 SleepStage::N2
370 }
371 }
372 SleepStage::REM => {
373 if self.should_wake() {
375 SleepStage::Wake
376 } else {
377 SleepStage::N1
379 }
380 }
381 }
382 }
383
384 fn transition_to(&mut self, stage: SleepStage) {
386 if self.current_stage == SleepStage::REM && stage == SleepStage::N1 {
388 self.current_cycle += 1;
389 }
390
391 self.current_stage = stage;
392 self.stage_time = 0.0;
393 }
394
395 fn should_wake(&self) -> bool {
397 self.current_cycle >= self.config.cycles_per_night
399 && self.sleep_pressure < self.config.wake_threshold
400 }
401
402 fn now(&self) -> u64 {
404 std::time::SystemTime::now()
405 .duration_since(std::time::UNIX_EPOCH)
406 .unwrap_or_default()
407 .as_millis() as u64
408 }
409
410 pub fn current_stage(&self) -> SleepStage {
412 self.current_stage
413 }
414
415 pub fn current_cycle(&self) -> usize {
417 self.current_cycle
418 }
419
420 pub fn sleep_pressure(&self) -> f64 {
422 self.sleep_pressure
423 }
424
425 pub fn total_sleep_time(&self) -> f64 {
427 self.total_sleep_time
428 }
429
430 pub fn time_awake(&self) -> f64 {
432 self.time_awake
433 }
434
435 pub fn should_sleep(&self) -> bool {
437 let circadian_drive = self.circadian.current_sleep_drive();
438 self.sleep_pressure > 0.7 && circadian_drive > 0.5
439 }
440
441 pub fn stats(&self) -> SleepStats {
443 SleepStats {
444 current_stage: self.current_stage,
445 current_cycle: self.current_cycle,
446 total_sleep_time: self.total_sleep_time,
447 time_awake: self.time_awake,
448 sleep_pressure: self.sleep_pressure,
449 consolidation_count: self.consolidator.consolidation_count(),
450 event_count: self.events.len(),
451 }
452 }
453
454 pub fn reset(&mut self) {
456 self.current_stage = SleepStage::Wake;
457 self.stage_time = 0.0;
458 self.current_cycle = 0;
459 self.sleep_pressure = 0.5;
460 self.total_sleep_time = 0.0;
461 self.time_awake = 0.0;
462 self.events.clear();
463 self.consolidator.reset();
464 }
465
466 pub fn add_memories(&mut self, memories: Vec<Vec<f64>>) {
468 self.consolidator.add_memories(memories);
469 }
470
471 pub fn get_consolidated(&self) -> Vec<Vec<f64>> {
473 self.consolidator.get_consolidated()
474 }
475}
476
477impl Default for SleepController {
478 fn default() -> Self {
479 Self::new()
480 }
481}
482
483#[derive(Debug, Clone, Serialize, Deserialize)]
485pub struct SleepStats {
486 pub current_stage: SleepStage,
487 pub current_cycle: usize,
488 pub total_sleep_time: f64,
489 pub time_awake: f64,
490 pub sleep_pressure: f64,
491 pub consolidation_count: usize,
492 pub event_count: usize,
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sleep_controller_creation() {
        let controller = SleepController::new();
        assert_eq!(controller.current_stage(), SleepStage::Wake);
    }

    #[test]
    fn test_fall_asleep() {
        let mut controller = SleepController::new();

        controller.fall_asleep().unwrap();
        assert_eq!(controller.current_stage(), SleepStage::N1);
        assert_eq!(controller.current_cycle(), 1);
    }

    #[test]
    fn test_fall_asleep_twice_is_rejected() {
        let mut controller = SleepController::new();
        controller.fall_asleep().unwrap();

        assert!(matches!(
            controller.fall_asleep(),
            Err(SleepError::InvalidTransition(_))
        ));
        assert_eq!(controller.current_stage(), SleepStage::N1);
    }

    #[test]
    fn test_wake_up() {
        let mut controller = SleepController::new();

        controller.fall_asleep().unwrap();
        controller.wake_up().unwrap();
        assert_eq!(controller.current_stage(), SleepStage::Wake);
    }

    #[test]
    fn test_wake_up_while_awake_is_rejected() {
        let mut controller = SleepController::new();

        assert!(matches!(
            controller.wake_up(),
            Err(SleepError::InvalidTransition(_))
        ));
    }

    #[test]
    fn test_sleep_stages() {
        assert!(SleepStage::N3.consolidation_strength() > SleepStage::N1.consolidation_strength());
        assert!(SleepStage::N3.is_sleeping());
        assert!(!SleepStage::Wake.is_sleeping());
    }

    #[test]
    fn test_n1_transitions_to_n2_after_typical_duration() {
        let mut controller = SleepController::new();
        controller.fall_asleep().unwrap();

        // N1 lasts 5 minutes and has no generator, so the first 4 steps stay put.
        for _ in 0..4 {
            controller.step(1.0);
            assert_eq!(controller.current_stage(), SleepStage::N1);
        }

        let events = controller.step(1.0);
        assert_eq!(controller.current_stage(), SleepStage::N2);
        assert!(events.iter().any(|e| {
            e.event_type == SleepEventType::StageTransition && e.stage == SleepStage::N2
        }));
        assert!(controller.stats().event_count >= 1);
    }

    #[test]
    fn test_sleep_cycle() {
        let mut controller = SleepController::new();
        controller.fall_asleep().unwrap();

        for _ in 0..60 {
            controller.step(1.0);
        }

        assert!(controller.total_sleep_time() >= 60.0);
    }

    #[test]
    fn test_sleep_pressure() {
        let mut controller = SleepController::new();

        for _ in 0..100 {
            controller.step(10.0);
        }

        let pressure_awake = controller.sleep_pressure();

        controller.fall_asleep().unwrap();

        for _ in 0..60 {
            controller.step(1.0);
        }

        assert!(controller.sleep_pressure() < pressure_awake);
    }

    #[test]
    fn test_sleep_pressure_is_capped_while_awake() {
        let mut controller = SleepController::new();

        // 100 hours awake would exceed the cap many times over.
        for _ in 0..100 {
            controller.step(60.0);
        }

        assert_eq!(controller.sleep_pressure(), 2.0);
        // Awake steps never produce events.
        assert!(controller.step(1.0).is_empty());
    }

    #[test]
    fn test_reset_returns_to_wake() {
        let mut controller = SleepController::new();
        controller.fall_asleep().unwrap();
        for _ in 0..10 {
            controller.step(1.0);
        }

        controller.reset();

        assert_eq!(controller.current_stage(), SleepStage::Wake);
        assert_eq!(controller.current_cycle(), 0);
        assert_eq!(controller.total_sleep_time(), 0.0);
        assert_eq!(controller.sleep_pressure(), 0.5);
        assert_eq!(controller.stats().event_count, 0);
    }
}