1use crate::framer::Framer;
4use crate::handshake::{self, HandshakeError};
5use po_crypto::aead::SessionCipher;
6use po_crypto::identity::{Identity, NodeId};
7use po_transport::traits::AsyncFrameTransport;
8use po_wire::{FrameHeader, FrameType};
9
/// Lifecycle phase of a [`Session`].
///
/// The variants are listed in the order a session normally moves through them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SessionState {
    /// Freshly constructed; no handshake has been started yet.
    New,
    /// A handshake has been started and has not (yet) completed successfully.
    Handshaking,
    /// The handshake succeeded and a session cipher is installed.
    Established,
    /// A local close was requested and the close frame is being written.
    Closing,
    /// The session is finished: a close frame was sent or received, or the
    /// framer reported that no more frames are available.
    Closed,
}
24
25pub struct Session {
29 state: SessionState,
30 framer: Framer,
31 cipher: Option<SessionCipher>,
32 identity: Identity,
33 peer_node_id: Option<NodeId>,
34 peer_pubkey: Option<[u8; 32]>,
35}
36
37impl Session {
38 pub fn new(identity: Identity) -> Self {
40 Self {
41 state: SessionState::New,
42 framer: Framer::new(),
43 cipher: None,
44 identity,
45 peer_node_id: None,
46 peer_pubkey: None,
47 }
48 }
49
50 pub fn state(&self) -> SessionState {
52 self.state
53 }
54
55 pub fn node_id(&self) -> &NodeId {
57 self.identity.node_id()
58 }
59
60 pub fn peer_node_id(&self) -> Option<&NodeId> {
62 self.peer_node_id.as_ref()
63 }
64
65 pub async fn handshake_initiator(
67 &mut self,
68 transport: &mut dyn AsyncFrameTransport,
69 ) -> Result<(), HandshakeError> {
70 self.state = SessionState::Handshaking;
71
72 let result =
73 handshake::perform_handshake_initiator(&self.identity, transport, &mut self.framer)
74 .await?;
75
76 self.cipher = Some(result.cipher);
77 self.peer_pubkey = Some(result.peer_pubkey);
78 self.peer_node_id = Some(result.peer_node_id);
79 self.state = SessionState::Established;
80
81 Ok(())
82 }
83
84 pub async fn handshake_responder(
86 &mut self,
87 transport: &mut dyn AsyncFrameTransport,
88 ) -> Result<(), HandshakeError> {
89 self.state = SessionState::Handshaking;
90
91 let result =
92 handshake::perform_handshake_responder(&self.identity, transport, &mut self.framer)
93 .await?;
94
95 self.cipher = Some(result.cipher);
96 self.peer_pubkey = Some(result.peer_pubkey);
97 self.peer_node_id = Some(result.peer_node_id);
98 self.state = SessionState::Established;
99
100 Ok(())
101 }
102
103 pub async fn send(
105 &mut self,
106 transport: &mut dyn AsyncFrameTransport,
107 channel: u32,
108 data: &[u8],
109 ) -> Result<(), SessionError> {
110 if self.state != SessionState::Established {
111 return Err(SessionError::NotEstablished);
112 }
113
114 let cipher = self.cipher.as_mut().ok_or(SessionError::NoCipher)?;
115
116 let header = FrameHeader::data(channel, 0).with_encrypted();
118 let mut header_buf = [0u8; 32];
119 let header_len = header
120 .encode(&mut header_buf)
121 .map_err(|e| SessionError::Wire(e.to_string()))?;
122 let aad = &header_buf[..header_len];
123
124 let encrypted = cipher
126 .encrypt(data, aad)
127 .map_err(|e| SessionError::Crypto(e.to_string()))?;
128
129 let final_header = FrameHeader {
131 payload_len: encrypted.len() as u64,
132 ..header
133 };
134
135 self.framer
136 .write_frame(transport, &final_header, &encrypted)
137 .await
138 .map_err(|e| SessionError::Framer(e.to_string()))?;
139
140 Ok(())
141 }
142
143 pub async fn recv(
147 &mut self,
148 transport: &mut dyn AsyncFrameTransport,
149 ) -> Result<Option<(u32, Vec<u8>)>, SessionError> {
150 loop {
151 if self.state == SessionState::Closed {
152 return Ok(None);
153 }
154
155 let (header, payload) = match self.framer.read_frame(transport).await {
156 Ok(Some(frame)) => frame,
157 Ok(None) => {
158 self.state = SessionState::Closed;
159 return Ok(None);
160 }
161 Err(e) => return Err(SessionError::Framer(e.to_string())),
162 };
163
164 match header.frame_type {
166 FrameType::Ping => {
167 let pong = FrameHeader::control(FrameType::Pong);
168 self.framer
169 .write_frame(transport, &pong, &[])
170 .await
171 .map_err(|e| SessionError::Framer(e.to_string()))?;
172 continue; }
174 FrameType::Pong => continue, FrameType::Close => {
176 self.state = SessionState::Closed;
177 return Ok(None);
178 }
179 FrameType::Data => {
180 if header.flags.encrypted {
182 let cipher = self.cipher.as_ref().ok_or(SessionError::NoCipher)?;
183
184 let aad_header = FrameHeader::data(header.channel_id, 0).with_encrypted();
186 let mut aad_buf = [0u8; 32];
187 let aad_len = aad_header
188 .encode(&mut aad_buf)
189 .map_err(|e| SessionError::Wire(e.to_string()))?;
190
191 let decrypted = cipher
192 .decrypt(&payload, &aad_buf[..aad_len])
193 .map_err(|e| SessionError::Crypto(e.to_string()))?;
194
195 return Ok(Some((header.channel_id, decrypted)));
196 } else {
197 return Ok(Some((header.channel_id, payload.to_vec())));
198 }
199 }
200 _ => continue, }
202 }
203 }
204
205 pub async fn close(
207 &mut self,
208 transport: &mut dyn AsyncFrameTransport,
209 ) -> Result<(), SessionError> {
210 if self.state == SessionState::Closed {
211 return Ok(());
212 }
213
214 self.state = SessionState::Closing;
215 let header = FrameHeader::control(FrameType::Close);
216 self.framer
217 .write_frame(transport, &header, &[])
218 .await
219 .map_err(|e| SessionError::Framer(e.to_string()))?;
220 self.state = SessionState::Closed;
221
222 Ok(())
223 }
224}
225
/// Errors produced by the data-path operations of a session (send, receive,
/// close).
///
/// The `String` payloads hold the display text of the underlying error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionError {
    /// The operation requires an established session.
    NotEstablished,
    /// No session cipher is installed.
    NoCipher,
    /// A frame header could not be encoded.
    Wire(String),
    /// Encryption or decryption failed.
    Crypto(String),
    /// Reading or writing a frame failed.
    Framer(String),
}

impl std::fmt::Display for SessionError {
    /// Writes a short human-readable description; wrapped errors are appended
    /// after a category prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotEstablished => {
                f.write_str("session not established (handshake not complete)")
            }
            Self::NoCipher => f.write_str("no session cipher available"),
            Self::Wire(e) => write!(f, "wire error: {e}"),
            Self::Crypto(e) => write!(f, "crypto error: {e}"),
            Self::Framer(e) => write!(f, "framer error: {e}"),
        }
    }
}

// No `source()` override: the underlying errors are kept only as text.
impl std::error::Error for SessionError {}