1use crate::base::{Action, Observation, Reward, State};
13use std::collections::BTreeMap;
14use std::fmt::Debug;
15
/// Lifecycle state of an episode, as reported by a snapshot.
///
/// `Running` is the only state in which the episode is still in progress; the
/// other two both count as "done" (see [`EpisodeStatus::is_done`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EpisodeStatus {
    /// The episode is still in progress.
    Running,
    /// The episode has finished.
    // NOTE(review): by the usual RL convention this means the environment
    // itself reached a terminal state — confirm against implementors.
    Terminated,
    /// The episode was cut short.
    // NOTE(review): by the usual RL convention this means an external limit
    // (e.g. a step budget) ended the episode — confirm against implementors.
    Truncated,
}

impl EpisodeStatus {
    /// Returns `true` when the episode is over, i.e. for both
    /// [`EpisodeStatus::Terminated`] and [`EpisodeStatus::Truncated`].
    pub const fn is_done(self) -> bool {
        match self {
            Self::Running => false,
            Self::Terminated | Self::Truncated => true,
        }
    }

    /// Returns `true` only for [`EpisodeStatus::Terminated`].
    pub const fn is_terminated(self) -> bool {
        match self {
            Self::Terminated => true,
            Self::Running | Self::Truncated => false,
        }
    }

    /// Returns `true` only for [`EpisodeStatus::Truncated`].
    pub const fn is_truncated(self) -> bool {
        match self {
            Self::Truncated => true,
            Self::Running | Self::Terminated => false,
        }
    }
}
47
/// Optional, named side-channel values that can accompany a snapshot.
///
/// Both maps are keyed by `&'static str`, and being `BTreeMap`s they iterate
/// in sorted key order.
#[derive(Debug, Clone, Default)]
pub struct SnapshotMetadata {
    /// Named scalar values.
    pub components: BTreeMap<&'static str, f32>,
    /// Named triples of `f32`.
    // NOTE(review): presumably x/y/z coordinates, going by the `xyz` parameter
    // of `with_position` — confirm units/frame with the producers.
    pub positions: BTreeMap<&'static str, [f32; 3]>,
}

impl SnapshotMetadata {
    /// Creates metadata with no components and no positions.
    pub fn new() -> Self {
        Self {
            components: BTreeMap::new(),
            positions: BTreeMap::new(),
        }
    }

    /// Builder step: stores `value` under `key` in `components`, replacing any
    /// value previously stored under the same key, and returns the metadata.
    pub fn with(self, key: &'static str, value: f32) -> Self {
        let Self {
            mut components,
            positions,
        } = self;
        components.insert(key, value);
        Self {
            components,
            positions,
        }
    }

    /// Builder step: stores `xyz` under `key` in `positions`, replacing any
    /// value previously stored under the same key, and returns the metadata.
    pub fn with_position(self, key: &'static str, xyz: [f32; 3]) -> Self {
        let Self {
            components,
            mut positions,
        } = self;
        positions.insert(key, xyz);
        Self {
            components,
            positions,
        }
    }
}
79
/// Errors that environment operations (`reset` / `step`) can report.
#[derive(Debug)]
pub enum EnvironmentError {
    /// An action was rejected; the payload is a human-readable description.
    InvalidAction(String),
    /// Rendering failed; the payload is a human-readable description.
    RenderFailed(String),
    /// An underlying I/O error, kept intact so it can be reached via `source()`.
    IoError(std::io::Error),
}

impl std::error::Error for EnvironmentError {
    /// Exposes the wrapped `std::io::Error` for the `IoError` variant; the
    /// string-carrying variants have no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::IoError(cause) => Some(cause),
            Self::InvalidAction(_) | Self::RenderFailed(_) => None,
        }
    }
}

impl std::fmt::Display for EnvironmentError {
    /// Writes a fixed, variant-specific prefix followed by the payload.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidAction(reason) => write!(f, "Invalid action: {}", reason),
            Self::RenderFailed(reason) => write!(f, "Render failed: {}", reason),
            Self::IoError(cause) => write!(f, "IO operation failed: {}", cause),
        }
    }
}

impl From<std::io::Error> for EnvironmentError {
    /// Wraps an I/O error so `?` can be used on `std::io::Result` values in
    /// functions returning `EnvironmentError`.
    fn from(error: std::io::Error) -> Self {
        Self::IoError(error)
    }
}
125
126pub trait Snapshot<const R: usize>: Debug {
133 type ObservationType: Observation<R>;
135
136 type RewardType: Reward;
138
139 fn observation(&self) -> &Self::ObservationType;
141
142 fn reward(&self) -> &Self::RewardType;
144
145 fn status(&self) -> EpisodeStatus;
147
148 fn is_done(&self) -> bool {
150 self.status().is_done()
151 }
152
153 fn is_terminated(&self) -> bool {
155 self.status().is_terminated()
156 }
157
158 fn is_truncated(&self) -> bool {
160 self.status().is_truncated()
161 }
162
163 fn metadata(&self) -> Option<&SnapshotMetadata> {
165 None
166 }
167}
168
169#[derive(Debug, Clone)]
181pub struct SnapshotBase<const R: usize, ObservationType: Observation<R>, RewardType: Reward> {
182 pub observation: ObservationType,
184 pub reward: RewardType,
186 pub status: EpisodeStatus,
188}
189
190impl<const R: usize, ObservationType: Observation<R>, RewardType: Reward>
191 SnapshotBase<R, ObservationType, RewardType>
192{
193 pub fn running(observation: ObservationType, reward: RewardType) -> Self {
195 Self {
196 observation,
197 reward,
198 status: EpisodeStatus::Running,
199 }
200 }
201
202 pub fn terminated(observation: ObservationType, reward: RewardType) -> Self {
204 Self {
205 observation,
206 reward,
207 status: EpisodeStatus::Terminated,
208 }
209 }
210
211 pub fn truncated(observation: ObservationType, reward: RewardType) -> Self {
213 Self {
214 observation,
215 reward,
216 status: EpisodeStatus::Truncated,
217 }
218 }
219}
220
221impl<const R: usize, ObservationType: Observation<R>, RewardType: Reward> Snapshot<R>
222 for SnapshotBase<R, ObservationType, RewardType>
223{
224 type ObservationType = ObservationType;
225 type RewardType = RewardType;
226
227 fn observation(&self) -> &Self::ObservationType {
228 &self.observation
229 }
230
231 fn reward(&self) -> &Self::RewardType {
232 &self.reward
233 }
234
235 fn status(&self) -> EpisodeStatus {
236 self.status
237 }
238}
239
240pub trait Environment<const R: usize, const SR: usize, const AR: usize> {
260 type StateType: State<SR>;
262
263 type ObservationType: Observation<R>;
265
266 type ActionType: Action<AR>;
268
269 type RewardType: Reward;
271
272 type SnapshotType: Snapshot<R, ObservationType = Self::ObservationType, RewardType = Self::RewardType>;
274
275 fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError>;
286
287 fn step(&mut self, action: Self::ActionType) -> Result<Self::SnapshotType, EnvironmentError>;
300}
301
/// Environments that can be constructed from a single `render` flag, so that
/// generic code can create them without knowing the concrete type.
pub trait ConstructableEnv {
    /// Creates a new instance.
    ///
    /// `render` is passed through to the implementor.
    // NOTE(review): presumably toggles visual rendering — the flag's exact
    // meaning is defined by each implementor; confirm.
    fn new(render: bool) -> Self;
}
324
/// Unit tests built around a tiny one-dimensional "track" environment.
///
/// The mock agent starts at position 3 on a track spanning positions 0..=6.
/// Reaching position 6 is the goal (+1.0, terminated), leaving the track
/// terminates with -1.0, and the episode is truncated after 20 steps.
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::action::DiscreteAction;
    use crate::reward::ScalarReward;

    /// Observation exposing only the agent's position on the track.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
    pub struct MockObservation {
        position: i32,
    }

    impl Observation<1> for MockObservation {
        fn shape() -> [usize; 1] {
            [1]
        }
    }

    /// Full state of the mock environment: the agent's position.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct MockState {
        position: i32,
    }

    impl MockState {
        /// Leftmost valid track position.
        const MIN_POSITION: i32 = 0;
        /// Rightmost valid track position.
        const MAX_POSITION: i32 = 6;

        fn new(position: i32) -> Self {
            Self { position }
        }

        /// Returns `true` when `position` lies on the track.
        fn is_in_bounds(position: i32) -> bool {
            (Self::MIN_POSITION..=Self::MAX_POSITION).contains(&position)
        }
    }

    impl State<1> for MockState {
        type Observation = MockObservation;

        fn numel(&self) -> usize {
            7
        }

        fn shape() -> [usize; 1] {
            [7]
        }

        fn is_valid(&self) -> bool {
            Self::is_in_bounds(self.position)
        }

        fn observe(&self) -> Self::Observation {
            MockObservation {
                position: self.position,
            }
        }
    }

    /// The two moves available on the track.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum MockAction {
        MoveLeft,
        MoveRight,
    }

    impl Action<1> for MockAction {
        fn is_valid(&self) -> bool {
            // Every variant of the enum is a legal move.
            true
        }

        fn shape() -> [usize; 1] {
            [1]
        }
    }

    impl DiscreteAction<1> for MockAction {
        const ACTION_COUNT: usize = 2;

        fn from_index(index: usize) -> Self {
            match index {
                0 => MockAction::MoveLeft,
                1 => MockAction::MoveRight,
                _ => panic!("Unknown action index: {}", index),
            }
        }

        fn to_index(&self) -> usize {
            match self {
                MockAction::MoveLeft => 0,
                MockAction::MoveRight => 1,
            }
        }
    }

    /// Deterministic track environment used to exercise the traits.
    struct MockEnvironment {
        current_state: MockState,
        step_count: usize,
        max_steps: usize,
    }

    impl MockEnvironment {
        const START_STATE: i32 = 3;
        const MAX_STEPS: usize = 20;
        const GOAL_STATE: i32 = 6;

        /// Builds the environment at the start position; the render flag is
        /// ignored by the mock.
        fn with_defaults(_render: bool) -> Self {
            Self {
                current_state: MockState::new(Self::START_STATE),
                step_count: 0,
                max_steps: Self::MAX_STEPS,
            }
        }
    }

    impl ConstructableEnv for MockEnvironment {
        fn new(render: bool) -> Self {
            Self::with_defaults(render)
        }
    }

    impl Environment<1, 1, 1> for MockEnvironment {
        type StateType = MockState;
        type ObservationType = MockObservation;
        type ActionType = MockAction;
        type RewardType = ScalarReward;
        type SnapshotType = SnapshotBase<1, MockObservation, ScalarReward>;

        fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError> {
            self.current_state = MockState::new(Self::START_STATE);
            self.step_count = 0;
            Ok(SnapshotBase::running(
                self.current_state.observe(),
                ScalarReward(0.0),
            ))
        }

        fn step(
            &mut self,
            action: Self::ActionType,
        ) -> Result<Self::SnapshotType, EnvironmentError> {
            if !action.is_valid() {
                // `EnvironmentError`'s `Display` already adds the
                // "Invalid action: " prefix, so only the action is passed.
                return Err(EnvironmentError::InvalidAction(format!("{:?}", action)));
            }

            let delta = match action {
                MockAction::MoveLeft => -1,
                MockAction::MoveRight => 1,
            };
            let next_position = self.current_state.position + delta;

            let (new_state, reward, terminated) = if MockState::is_in_bounds(next_position) {
                let reached_goal = next_position == Self::GOAL_STATE;
                let reward = if reached_goal { 1.0 } else { 0.0 };
                (MockState::new(next_position), reward, reached_goal)
            } else {
                // Leaving the track: stay on the nearest edge, penalise, end.
                let clamped = next_position.clamp(MockState::MIN_POSITION, MockState::MAX_POSITION);
                (MockState::new(clamped), -1.0, true)
            };

            self.current_state = new_state;
            self.step_count += 1;

            // Termination takes precedence over truncation.
            let status = if terminated {
                EpisodeStatus::Terminated
            } else if self.step_count >= self.max_steps {
                EpisodeStatus::Truncated
            } else {
                EpisodeStatus::Running
            };

            Ok(SnapshotBase {
                observation: new_state.observe(),
                reward: ScalarReward(reward),
                status,
            })
        }
    }

    /// Hand-written `Snapshot` implementor carrying extra bookkeeping fields.
    #[derive(Debug, Clone)]
    pub struct CustomSnapshot {
        observation: MockObservation,
        reward: ScalarReward,
        status: EpisodeStatus,
        step_count: usize,
        cumulative_reward: f32,
    }

    impl Snapshot<1> for CustomSnapshot {
        type ObservationType = MockObservation;
        type RewardType = ScalarReward;

        fn observation(&self) -> &MockObservation {
            &self.observation
        }

        fn reward(&self) -> &ScalarReward {
            &self.reward
        }

        fn status(&self) -> EpisodeStatus {
            self.status
        }
    }

    #[test]
    fn test_snapshot_base_creation() {
        let obs = MockObservation { position: 42 };
        let snapshot = SnapshotBase::running(obs, ScalarReward(1.5));

        assert_eq!(snapshot.observation(), &obs);
        assert_eq!(snapshot.reward(), &ScalarReward(1.5));
        assert!(!snapshot.is_done());
        assert_eq!(snapshot.status(), EpisodeStatus::Running);
        // `SnapshotBase` relies on the trait's default `metadata()`.
        assert!(snapshot.metadata().is_none());
    }

    #[test]
    fn test_snapshot_base_terminal() {
        let obs = MockObservation { position: 0 };
        let snapshot = SnapshotBase::terminated(obs, ScalarReward(-1.0));

        assert!(snapshot.is_done());
        assert!(snapshot.is_terminated());
        assert!(!snapshot.is_truncated());
        assert_eq!(snapshot.reward(), &ScalarReward(-1.0));
    }

    #[test]
    fn test_snapshot_base_truncated() {
        let obs = MockObservation { position: 2 };
        let snapshot = SnapshotBase::truncated(obs, ScalarReward(0.0));

        assert!(snapshot.is_done());
        assert!(snapshot.is_truncated());
        assert!(!snapshot.is_terminated());
        assert_eq!(snapshot.status(), EpisodeStatus::Truncated);
    }

    #[test]
    fn test_snapshot_base_clone() {
        let obs = MockObservation { position: 10 };
        let snapshot1 = SnapshotBase::running(obs, ScalarReward(0.5));
        let snapshot2 = snapshot1.clone();

        assert_eq!(snapshot1.observation(), snapshot2.observation());
        assert_eq!(snapshot1.reward(), snapshot2.reward());
        assert_eq!(snapshot1.is_done(), snapshot2.is_done());
    }

    #[test]
    fn test_snapshot_debug() {
        let obs = MockObservation { position: 5 };
        let snapshot = SnapshotBase::terminated(obs, ScalarReward(2.0));
        let debug_str = format!("{:?}", snapshot);

        assert!(debug_str.contains("SnapshotBase"));
        assert!(debug_str.contains("position: 5"));
        assert!(debug_str.contains("reward: ScalarReward(2.0)"));
        assert!(debug_str.contains("Terminated"));
    }

    #[test]
    fn test_custom_snapshot_trait_impl() {
        let snapshot = CustomSnapshot {
            observation: MockObservation { position: 1 },
            reward: ScalarReward(10.0),
            status: EpisodeStatus::Running,
            step_count: 5,
            cumulative_reward: 25.0,
        };

        // Trait-provided accessors and defaults.
        assert_eq!(snapshot.observation().position, 1);
        assert_eq!(snapshot.reward(), &ScalarReward(10.0));
        assert!(!snapshot.is_done());
        assert!(snapshot.metadata().is_none());

        // Extra fields specific to the custom implementor.
        assert_eq!(snapshot.step_count, 5);
        assert_eq!(snapshot.cumulative_reward, 25.0);
    }

    #[test]
    fn test_mock_action_index_round_trip() {
        for index in 0..MockAction::ACTION_COUNT {
            assert_eq!(MockAction::from_index(index).to_index(), index);
        }
    }

    #[test]
    fn test_environment_creation() {
        let env = MockEnvironment::new(false);
        assert_eq!(env.step_count, 0);
    }

    #[test]
    fn test_environment_reset() {
        let mut env = MockEnvironment::new(false);
        let snapshot = env.reset().expect("Reset should succeed");

        assert_eq!(snapshot.observation().position, 3);
        assert_eq!(snapshot.reward(), &ScalarReward(0.0));
        assert!(!snapshot.is_done());
    }

    #[test]
    fn test_environment_step_valid_action() {
        let mut env = MockEnvironment::new(false);
        env.reset().expect("Reset should succeed");

        let action = MockAction::MoveRight;
        let snapshot = env
            .step(action)
            .expect("Step with valid action should succeed");

        assert_eq!(snapshot.observation().position, 4);
        assert_eq!(snapshot.reward(), &ScalarReward(0.0));
    }

    #[test]
    fn test_environment_episode_termination() {
        let mut env = MockEnvironment::new(false);
        env.reset().expect("Reset should succeed");
        env.current_state.position = 0;

        // Six moves to the right take the agent from 0 to the goal at 6.
        for i in 0..6 {
            let action = MockAction::MoveRight;
            let snapshot = env.step(action).expect("Step should succeed");

            if i < 5 {
                assert!(
                    !snapshot.is_done(),
                    "Episode should not be done before reaching goal"
                );
            } else {
                assert!(
                    snapshot.is_done(),
                    "Episode should be done upon reaching goal"
                );
                assert!(snapshot.is_terminated());
                assert_eq!(snapshot.reward(), &ScalarReward(1.0));
            }
        }
    }

    #[test]
    fn test_environment_out_of_bounds_terminates() {
        let mut env = MockEnvironment::new(false);
        env.reset().expect("Reset should succeed");
        env.current_state.position = 0;

        let snapshot = env
            .step(MockAction::MoveLeft)
            .expect("Step should succeed");

        assert!(snapshot.is_terminated());
        assert_eq!(snapshot.observation().position, 0);
        assert_eq!(snapshot.reward(), &ScalarReward(-1.0));
    }

    #[test]
    fn test_environment_truncation_at_step_limit() {
        let mut env = MockEnvironment::new(false);
        env.reset().expect("Reset should succeed");

        // Oscillate between 2 and 3 so neither the goal nor an edge is hit.
        for step in 1..=MockEnvironment::MAX_STEPS {
            let action = if step % 2 == 1 {
                MockAction::MoveLeft
            } else {
                MockAction::MoveRight
            };
            let snapshot = env.step(action).expect("Step should succeed");

            if step < MockEnvironment::MAX_STEPS {
                assert_eq!(snapshot.status(), EpisodeStatus::Running);
            } else {
                assert_eq!(snapshot.status(), EpisodeStatus::Truncated);
                assert!(snapshot.is_done());
                assert!(!snapshot.is_terminated());
            }
        }
    }

    #[test]
    fn test_environment_reset_clears_state() {
        let mut env = MockEnvironment::new(false);

        env.reset().expect("Reset should succeed");
        for _ in 0..5 {
            let action = MockAction::MoveRight;
            env.step(action).expect("Step should succeed");
        }
        assert_eq!(env.step_count, 5);

        let snapshot = env.reset().expect("Second reset should succeed");
        assert_eq!(snapshot.observation().position, 3);
        assert!(!snapshot.is_done());
        assert_eq!(env.step_count, 0);
    }

    #[test]
    fn test_environment_error_display() {
        let error = EnvironmentError::InvalidAction("test action".to_string());
        let display_str = format!("{}", error);
        assert!(display_str.contains("Invalid action"));
        assert!(display_str.contains("test action"));
    }

    #[test]
    fn test_environment_error_io_conversion() {
        let io_error = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let env_error = EnvironmentError::from(io_error);

        assert!(
            matches!(env_error, EnvironmentError::IoError(_)),
            "Expected IoError variant"
        );
    }

    #[test]
    fn test_environment_error_source() {
        use std::error::Error;

        let io_error = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "access denied");
        let env_error = EnvironmentError::IoError(io_error);
        assert!(env_error.source().is_some());

        let plain_error = EnvironmentError::RenderFailed("no display".to_string());
        assert!(plain_error.source().is_none());
    }

    #[test]
    fn test_environment_multiple_episodes() {
        let mut env = MockEnvironment::new(false);

        for _episode in 0..3 {
            let mut snapshot = env.reset().expect("Reset should succeed");
            let mut step = 0;

            while !snapshot.is_done() && step < 5 {
                let action = MockAction::MoveRight;
                snapshot = env.step(action).expect("Step should succeed");
                step += 1;
            }

            // From the start at 3, three moves right reach the goal at 6;
            // every episode must behave the same after a reset.
            assert_eq!(step, 3);
            assert!(snapshot.is_terminated());
            assert_eq!(snapshot.observation().position, 6);
            assert_eq!(snapshot.reward(), &ScalarReward(1.0));
        }
    }

    #[test]
    fn test_snapshot_reward_conversion() {
        let observation = MockObservation { position: 1 };
        let snapshot = SnapshotBase::running(observation, ScalarReward(42.5));

        let reward_as_f32: f32 = (*snapshot.reward()).into();
        assert_eq!(reward_as_f32, 42.5);
    }

    #[test]
    fn test_metadata_default_is_empty() {
        let meta = SnapshotMetadata::default();
        assert!(meta.components.is_empty());
        assert!(meta.positions.is_empty());
    }

    #[test]
    fn test_metadata_builder_components_and_positions() {
        let meta = SnapshotMetadata::new()
            .with("forward", 1.25)
            .with("ctrl", -0.1)
            .with_position("torso", [0.5, 0.0, 1.1])
            .with_position("com", [0.4, 0.0, 0.9]);

        assert_eq!(meta.components.len(), 2);
        assert_eq!(meta.components.get("forward"), Some(&1.25));
        assert_eq!(meta.positions.len(), 2);
        assert_eq!(meta.positions.get("torso"), Some(&[0.5, 0.0, 1.1]));
    }
}