Skip to main content

fraiseql_wire/connection/state.rs

1//! Connection state machine
2
3use crate::{Error, Result};
4
/// Lifecycle state of a single wire connection.
///
/// Legal moves between states are decided by `ConnectionState::can_transition_to`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// Freshly created; nothing has been sent yet.
    Initial,

    /// SSLRequest has been sent; waiting for the server's `S`/`N` reply.
    NegotiatingTls,

    /// Startup message has been sent; waiting for the server's authentication request.
    AwaitingAuth,

    /// Authentication exchange is underway.
    Authenticating,

    /// Ready to accept a new query.
    Idle,

    /// A query has been issued.
    QueryInProgress,

    /// Query results are being consumed.
    ReadingResults,

    /// The connection has been shut down.
    Closed,
}
32
33impl ConnectionState {
34    /// Check if transition is valid
35    pub fn can_transition_to(&self, next: ConnectionState) -> bool {
36        use ConnectionState::*;
37
38        matches!(
39            (self, next),
40            (Initial, NegotiatingTls)
41                | (Initial, AwaitingAuth)
42                | (NegotiatingTls, AwaitingAuth)
43                | (AwaitingAuth, Authenticating)
44                | (Authenticating, Idle)
45                | (Idle, QueryInProgress)
46                | (QueryInProgress, ReadingResults)
47                | (ReadingResults, Idle)
48                | (_, Closed)
49        )
50    }
51
52    /// Transition to new state
53    pub fn transition(&mut self, next: ConnectionState) -> Result<()> {
54        if !self.can_transition_to(next) {
55            return Err(Error::InvalidState {
56                expected: format!("valid transition from {:?}", self),
57                actual: format!("{:?}", next),
58            });
59        }
60        *self = next;
61        Ok(())
62    }
63}
64
65impl std::fmt::Display for ConnectionState {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            Self::Initial => write!(f, "initial"),
69            Self::NegotiatingTls => write!(f, "negotiating_tls"),
70            Self::AwaitingAuth => write!(f, "awaiting_auth"),
71            Self::Authenticating => write!(f, "authenticating"),
72            Self::Idle => write!(f, "idle"),
73            Self::QueryInProgress => write!(f, "query_in_progress"),
74            Self::ReadingResults => write!(f, "reading_results"),
75            Self::Closed => write!(f, "closed"),
76        }
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant of `ConnectionState`, for properties that must hold in all states.
    const ALL_STATES: [ConnectionState; 8] = [
        ConnectionState::Initial,
        ConnectionState::NegotiatingTls,
        ConnectionState::AwaitingAuth,
        ConnectionState::Authenticating,
        ConnectionState::Idle,
        ConnectionState::QueryInProgress,
        ConnectionState::ReadingResults,
        ConnectionState::Closed,
    ];

    #[test]
    fn test_valid_transitions() {
        let mut state = ConnectionState::Initial;
        assert!(state.transition(ConnectionState::AwaitingAuth).is_ok());
        assert!(state.transition(ConnectionState::Authenticating).is_ok());
        assert!(state.transition(ConnectionState::Idle).is_ok());
    }

    #[test]
    fn test_query_cycle_returns_to_idle() {
        let mut state = ConnectionState::Idle;
        assert!(state.transition(ConnectionState::QueryInProgress).is_ok());
        assert!(state.transition(ConnectionState::ReadingResults).is_ok());
        assert!(state.transition(ConnectionState::Idle).is_ok());
        assert_eq!(state, ConnectionState::Idle);
    }

    #[test]
    fn test_invalid_transition() {
        let mut state = ConnectionState::Initial;
        assert!(state.transition(ConnectionState::Idle).is_err());
    }

    #[test]
    fn test_failed_transition_leaves_state_unchanged() {
        let mut state = ConnectionState::Initial;
        assert!(state.transition(ConnectionState::Idle).is_err());
        assert_eq!(state, ConnectionState::Initial);
    }

    #[test]
    fn test_close_from_any_state() {
        // The name promises "any state", so exercise every variant rather than one.
        for start in ALL_STATES {
            let mut state = start;
            assert!(
                state.transition(ConnectionState::Closed).is_ok(),
                "closing from {:?} should succeed",
                start
            );
            assert_eq!(state, ConnectionState::Closed);
        }
    }

    #[test]
    fn test_tls_negotiation_transitions() {
        let mut state = ConnectionState::Initial;
        assert!(state.transition(ConnectionState::NegotiatingTls).is_ok());
        assert!(state.transition(ConnectionState::AwaitingAuth).is_ok());
    }

    #[test]
    fn test_initial_can_skip_tls_negotiation() {
        // When sslmode=disable, we skip NegotiatingTls
        let mut state = ConnectionState::Initial;
        assert!(state.transition(ConnectionState::AwaitingAuth).is_ok());
    }

    #[test]
    fn test_invalid_tls_transition() {
        let mut state = ConnectionState::Idle;
        assert!(state.transition(ConnectionState::NegotiatingTls).is_err());
    }

    #[test]
    fn test_display_labels() {
        assert_eq!(ConnectionState::Initial.to_string(), "initial");
        assert_eq!(ConnectionState::NegotiatingTls.to_string(), "negotiating_tls");
        assert_eq!(ConnectionState::QueryInProgress.to_string(), "query_in_progress");
        assert_eq!(ConnectionState::Closed.to_string(), "closed");
    }
}