Skip to main content

po_session/
state.rs

1//! Connection state machine for PO sessions.
2
3use crate::framer::Framer;
4use crate::handshake::{self, HandshakeError};
5use po_crypto::aead::SessionCipher;
6use po_crypto::identity::{Identity, NodeId};
7use po_transport::traits::AsyncFrameTransport;
8use po_wire::{FrameHeader, FrameType};
9
/// Where a PO connection is in its lifecycle.
///
/// On the success path a [`Session`] walks through these variants in
/// declaration order: it starts in `New`, is `Handshaking` while keys are
/// negotiated, `Established` once traffic can be encrypted, `Closing` while a
/// close frame is being sent, and finally `Closed`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// The transport is connected but no handshake has been started.
    New,
    /// A handshake is currently running.
    Handshaking,
    /// The handshake finished and the encrypted session is usable.
    Established,
    /// A graceful close has begun but has not yet completed.
    Closing,
    /// The connection is fully shut down.
    Closed,
}
24
25/// A fully managed PO session over any transport.
26///
27/// Handles the handshake, encryption, and frame IO.
28pub struct Session {
29    state: SessionState,
30    framer: Framer,
31    cipher: Option<SessionCipher>,
32    identity: Identity,
33    peer_node_id: Option<NodeId>,
34    peer_pubkey: Option<[u8; 32]>,
35}
36
37impl Session {
38    /// Create a new session with the given identity.
39    pub fn new(identity: Identity) -> Self {
40        Self {
41            state: SessionState::New,
42            framer: Framer::new(),
43            cipher: None,
44            identity,
45            peer_node_id: None,
46            peer_pubkey: None,
47        }
48    }
49
50    /// Get the current session state.
51    pub fn state(&self) -> SessionState {
52        self.state
53    }
54
55    /// Get our own node ID.
56    pub fn node_id(&self) -> &NodeId {
57        self.identity.node_id()
58    }
59
60    /// Get the peer's node ID (available after handshake).
61    pub fn peer_node_id(&self) -> Option<&NodeId> {
62        self.peer_node_id.as_ref()
63    }
64
65    /// Perform the handshake as the initiator (client).
66    pub async fn handshake_initiator(
67        &mut self,
68        transport: &mut dyn AsyncFrameTransport,
69    ) -> Result<(), HandshakeError> {
70        self.state = SessionState::Handshaking;
71
72        let result =
73            handshake::perform_handshake_initiator(&self.identity, transport, &mut self.framer)
74                .await?;
75
76        self.cipher = Some(result.cipher);
77        self.peer_pubkey = Some(result.peer_pubkey);
78        self.peer_node_id = Some(result.peer_node_id);
79        self.state = SessionState::Established;
80
81        Ok(())
82    }
83
84    /// Perform the handshake as the responder (server).
85    pub async fn handshake_responder(
86        &mut self,
87        transport: &mut dyn AsyncFrameTransport,
88    ) -> Result<(), HandshakeError> {
89        self.state = SessionState::Handshaking;
90
91        let result =
92            handshake::perform_handshake_responder(&self.identity, transport, &mut self.framer)
93                .await?;
94
95        self.cipher = Some(result.cipher);
96        self.peer_pubkey = Some(result.peer_pubkey);
97        self.peer_node_id = Some(result.peer_node_id);
98        self.state = SessionState::Established;
99
100        Ok(())
101    }
102
103    /// Send encrypted application data.
104    pub async fn send(
105        &mut self,
106        transport: &mut dyn AsyncFrameTransport,
107        channel: u32,
108        data: &[u8],
109    ) -> Result<(), SessionError> {
110        if self.state != SessionState::Established {
111            return Err(SessionError::NotEstablished);
112        }
113
114        let cipher = self.cipher.as_mut().ok_or(SessionError::NoCipher)?;
115
116        // Encode header bytes for AAD
117        let header = FrameHeader::data(channel, 0).with_encrypted();
118        let mut header_buf = [0u8; 32];
119        let header_len = header
120            .encode(&mut header_buf)
121            .map_err(|e| SessionError::Wire(e.to_string()))?;
122        let aad = &header_buf[..header_len];
123
124        // Encrypt payload
125        let encrypted = cipher
126            .encrypt(data, aad)
127            .map_err(|e| SessionError::Crypto(e.to_string()))?;
128
129        // Update header with actual encrypted payload length
130        let final_header = FrameHeader {
131            payload_len: encrypted.len() as u64,
132            ..header
133        };
134
135        self.framer
136            .write_frame(transport, &final_header, &encrypted)
137            .await
138            .map_err(|e| SessionError::Framer(e.to_string()))?;
139
140        Ok(())
141    }
142
143    /// Receive the next message. Returns `(channel_id, decrypted_data)`.
144    ///
145    /// Automatically handles control frames (Ping/Pong/Close).
146    pub async fn recv(
147        &mut self,
148        transport: &mut dyn AsyncFrameTransport,
149    ) -> Result<Option<(u32, Vec<u8>)>, SessionError> {
150        loop {
151            if self.state == SessionState::Closed {
152                return Ok(None);
153            }
154
155            let (header, payload) = match self.framer.read_frame(transport).await {
156                Ok(Some(frame)) => frame,
157                Ok(None) => {
158                    self.state = SessionState::Closed;
159                    return Ok(None);
160                }
161                Err(e) => return Err(SessionError::Framer(e.to_string())),
162            };
163
164            // Handle control frames
165            match header.frame_type {
166                FrameType::Ping => {
167                    let pong = FrameHeader::control(FrameType::Pong);
168                    self.framer
169                        .write_frame(transport, &pong, &[])
170                        .await
171                        .map_err(|e| SessionError::Framer(e.to_string()))?;
172                    continue; // Don't return pings to the caller
173                }
174                FrameType::Pong => continue, // Absorb pongs
175                FrameType::Close => {
176                    self.state = SessionState::Closed;
177                    return Ok(None);
178                }
179                FrameType::Data => {
180                    // Decrypt if the frame is marked as encrypted
181                    if header.flags.encrypted {
182                        let cipher = self.cipher.as_ref().ok_or(SessionError::NoCipher)?;
183
184                        // Reconstruct AAD from the header (same process as sender)
185                        let aad_header = FrameHeader::data(header.channel_id, 0).with_encrypted();
186                        let mut aad_buf = [0u8; 32];
187                        let aad_len = aad_header
188                            .encode(&mut aad_buf)
189                            .map_err(|e| SessionError::Wire(e.to_string()))?;
190
191                        let decrypted = cipher
192                            .decrypt(&payload, &aad_buf[..aad_len])
193                            .map_err(|e| SessionError::Crypto(e.to_string()))?;
194
195                        return Ok(Some((header.channel_id, decrypted)));
196                    } else {
197                        return Ok(Some((header.channel_id, payload.to_vec())));
198                    }
199                }
200                _ => continue, // Skip other frame types for now
201            }
202        }
203    }
204
205    /// Send a graceful close frame.
206    pub async fn close(
207        &mut self,
208        transport: &mut dyn AsyncFrameTransport,
209    ) -> Result<(), SessionError> {
210        if self.state == SessionState::Closed {
211            return Ok(());
212        }
213
214        self.state = SessionState::Closing;
215        let header = FrameHeader::control(FrameType::Close);
216        self.framer
217            .write_frame(transport, &header, &[])
218            .await
219            .map_err(|e| SessionError::Framer(e.to_string()))?;
220        self.state = SessionState::Closed;
221
222        Ok(())
223    }
224}
225
/// Errors returned by [`Session`]'s data-path operations.
#[derive(Debug)]
pub enum SessionError {
    /// The operation needs an established session. Returned before the
    /// handshake completes, and also after the session has been closed.
    NotEstablished,
    /// No session cipher is available.
    NoCipher,
    /// Wire-format encoding failed; carries the underlying message.
    Wire(String),
    /// A cryptographic operation or check failed; carries the underlying message.
    Crypto(String),
    /// Reading or writing a frame failed; carries the underlying message.
    Framer(String),
}

impl std::fmt::Display for SessionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SessionError::NotEstablished => {
                f.write_str("session not established (handshake not complete)")
            }
            SessionError::NoCipher => f.write_str("no session cipher available"),
            SessionError::Wire(detail) => write!(f, "wire error: {detail}"),
            SessionError::Crypto(detail) => write!(f, "crypto error: {detail}"),
            SessionError::Framer(detail) => write!(f, "framer error: {detail}"),
        }
    }
}

impl std::error::Error for SessionError {}