Skip to main content

fraiseql_wire/connection/
state.rs

1//! Connection state machine
2
3use crate::{Error, Result};
4
/// Lifecycle stage of a connection.
///
/// Which stages may follow which is decided by
/// [`ConnectionState::can_transition_to`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Starting stage; not yet connected.
    Initial,
    /// The startup message has been sent; an authentication request is expected next.
    AwaitingAuth,
    /// Authentication is under way.
    Authenticating,
    /// Ready to accept a query.
    Idle,
    /// A query is in progress.
    QueryInProgress,
    /// Results of the current query are being read.
    ReadingResults,
    /// The connection is closed.
    Closed,
}
29
30impl ConnectionState {
31    /// Check if transition is valid
32    pub const fn can_transition_to(&self, next: ConnectionState) -> bool {
33        use ConnectionState::{
34            Authenticating, AwaitingAuth, Closed, Idle, Initial, QueryInProgress, ReadingResults,
35        };
36
37        matches!(
38            (self, next),
39            (Initial, AwaitingAuth)
40                | (AwaitingAuth, Authenticating)
41                | (Authenticating | ReadingResults, Idle)
42                | (Idle, QueryInProgress)
43                | (QueryInProgress, ReadingResults)
44                | (_, Closed)
45        )
46    }
47
48    /// Transition to new state
49    ///
50    /// # Errors
51    ///
52    /// Returns `Error::InvalidState` if the transition from the current state to `next` is not allowed.
53    pub fn transition(&mut self, next: ConnectionState) -> Result<()> {
54        if !self.can_transition_to(next) {
55            return Err(Error::InvalidState {
56                expected: format!("valid transition from {:?}", self),
57                actual: format!("{:?}", next),
58            });
59        }
60        *self = next;
61        Ok(())
62    }
63}
64
65impl std::fmt::Display for ConnectionState {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            Self::Initial => write!(f, "initial"),
69            Self::AwaitingAuth => write!(f, "awaiting_auth"),
70            Self::Authenticating => write!(f, "authenticating"),
71            Self::Idle => write!(f, "idle"),
72            Self::QueryInProgress => write!(f, "query_in_progress"),
73            Self::ReadingResults => write!(f, "reading_results"),
74            Self::Closed => write!(f, "closed"),
75        }
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, for checking properties that must hold for all states.
    const ALL_STATES: [ConnectionState; 7] = [
        ConnectionState::Initial,
        ConnectionState::AwaitingAuth,
        ConnectionState::Authenticating,
        ConnectionState::Idle,
        ConnectionState::QueryInProgress,
        ConnectionState::ReadingResults,
        ConnectionState::Closed,
    ];

    #[test]
    fn test_valid_transitions() {
        // Walk every allowed edge: startup/auth, one full query cycle, then close.
        let path = [
            ConnectionState::AwaitingAuth,
            ConnectionState::Authenticating,
            ConnectionState::Idle,
            ConnectionState::QueryInProgress,
            ConnectionState::ReadingResults,
            ConnectionState::Idle,
            ConnectionState::Closed,
        ];
        let mut state = ConnectionState::Initial;
        for next in path {
            let from = state;
            assert!(
                state.transition(next).is_ok(),
                "{:?} -> {:?} should be allowed",
                from,
                next
            );
            assert_eq!(state, next);
        }
    }

    #[test]
    fn test_invalid_transition() {
        let mut state = ConnectionState::Initial;
        assert!(state.transition(ConnectionState::Idle).is_err());
        // A rejected transition must leave the state unchanged.
        assert_eq!(state, ConnectionState::Initial);
    }

    #[test]
    fn test_close_from_any_state() {
        for from in ALL_STATES {
            let mut state = from;
            assert!(
                state.transition(ConnectionState::Closed).is_ok(),
                "{:?} -> Closed should be allowed",
                from
            );
            assert_eq!(state, ConnectionState::Closed);
        }
    }

    #[test]
    fn test_closed_is_terminal() {
        // Once closed, the only permitted "move" is to stay closed.
        for next in ALL_STATES {
            assert_eq!(
                ConnectionState::Closed.can_transition_to(next),
                next == ConnectionState::Closed,
                "Closed -> {:?}",
                next
            );
        }
    }

    #[test]
    fn test_display_names() {
        assert_eq!(ConnectionState::AwaitingAuth.to_string(), "awaiting_auth");
        assert_eq!(ConnectionState::QueryInProgress.to_string(), "query_in_progress");
        assert_eq!(ConnectionState::Closed.to_string(), "closed");
    }
}