use crate::framer::Framer;
use crate::handshake::{self, HandshakeError};
use po_crypto::aead::SessionCipher;
use po_crypto::identity::{Identity, NodeId};
use po_transport::traits::AsyncFrameTransport;
use po_wire::{FrameHeader, FrameType};
/// Lifecycle phase of a [`Session`].
///
/// A session starts in `New`, passes through `Handshaking` while a handshake
/// is running, becomes `Established` once the handshake succeeds, and ends in
/// `Closed` (optionally via `Closing` while the local close is in flight).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// Freshly constructed; no handshake has been started.
    New,
    /// A handshake has been started and has not completed successfully.
    Handshaking,
    /// The handshake succeeded; encrypted data may be sent.
    Established,
    /// A local close has begun but its Close frame has not been written yet.
    Closing,
    /// The session has ended; `recv` reports end-of-stream.
    Closed,
}
/// One endpoint of a framed session with a single peer.
///
/// Holds the lifecycle state, the framer used to read and write frames, the
/// cipher produced by a successful handshake, the local identity, and what the
/// handshake reported about the peer.
pub struct Session {
/// Current lifecycle phase; starts as `SessionState::New`.
state: SessionState,
/// Reads and writes frames on the transport passed to each call.
framer: Framer,
/// Cipher returned by the handshake; `None` until a handshake succeeds.
cipher: Option<SessionCipher>,
/// Local identity; handed to the handshake and the source of `node_id()`.
identity: Identity,
/// Peer node id returned by the handshake; `None` until a handshake succeeds.
peer_node_id: Option<NodeId>,
/// Peer public-key bytes returned by the handshake; `None` until a handshake
/// succeeds.
peer_pubkey: Option<[u8; 32]>,
}
impl Session {
pub fn new(identity: Identity) -> Self {
Self {
state: SessionState::New,
framer: Framer::new(),
cipher: None,
identity,
peer_node_id: None,
peer_pubkey: None,
}
}
pub fn state(&self) -> SessionState {
self.state
}
pub fn node_id(&self) -> &NodeId {
self.identity.node_id()
}
pub fn peer_node_id(&self) -> Option<&NodeId> {
self.peer_node_id.as_ref()
}
pub async fn handshake_initiator(
&mut self,
transport: &mut dyn AsyncFrameTransport,
) -> Result<(), HandshakeError> {
self.state = SessionState::Handshaking;
let result =
handshake::perform_handshake_initiator(&self.identity, transport, &mut self.framer)
.await?;
self.cipher = Some(result.cipher);
self.peer_pubkey = Some(result.peer_pubkey);
self.peer_node_id = Some(result.peer_node_id);
self.state = SessionState::Established;
Ok(())
}
pub async fn handshake_responder(
&mut self,
transport: &mut dyn AsyncFrameTransport,
) -> Result<(), HandshakeError> {
self.state = SessionState::Handshaking;
let result =
handshake::perform_handshake_responder(&self.identity, transport, &mut self.framer)
.await?;
self.cipher = Some(result.cipher);
self.peer_pubkey = Some(result.peer_pubkey);
self.peer_node_id = Some(result.peer_node_id);
self.state = SessionState::Established;
Ok(())
}
pub async fn send(
&mut self,
transport: &mut dyn AsyncFrameTransport,
channel: u32,
data: &[u8],
) -> Result<(), SessionError> {
if self.state != SessionState::Established {
return Err(SessionError::NotEstablished);
}
let cipher = self.cipher.as_mut().ok_or(SessionError::NoCipher)?;
let header = FrameHeader::data(channel, 0).with_encrypted();
let mut header_buf = [0u8; 32];
let header_len = header
.encode(&mut header_buf)
.map_err(|e| SessionError::Wire(e.to_string()))?;
let aad = &header_buf[..header_len];
let encrypted = cipher
.encrypt(data, aad)
.map_err(|e| SessionError::Crypto(e.to_string()))?;
let final_header = FrameHeader {
payload_len: encrypted.len() as u64,
..header
};
self.framer
.write_frame(transport, &final_header, &encrypted)
.await
.map_err(|e| SessionError::Framer(e.to_string()))?;
Ok(())
}
pub async fn recv(
&mut self,
transport: &mut dyn AsyncFrameTransport,
) -> Result<Option<(u32, Vec<u8>)>, SessionError> {
loop {
if self.state == SessionState::Closed {
return Ok(None);
}
let (header, payload) = match self.framer.read_frame(transport).await {
Ok(Some(frame)) => frame,
Ok(None) => {
self.state = SessionState::Closed;
return Ok(None);
}
Err(e) => return Err(SessionError::Framer(e.to_string())),
};
match header.frame_type {
FrameType::Ping => {
let pong = FrameHeader::control(FrameType::Pong);
self.framer
.write_frame(transport, &pong, &[])
.await
.map_err(|e| SessionError::Framer(e.to_string()))?;
continue; }
FrameType::Pong => continue, FrameType::Close => {
self.state = SessionState::Closed;
return Ok(None);
}
FrameType::Data => {
if header.flags.encrypted {
let cipher = self.cipher.as_ref().ok_or(SessionError::NoCipher)?;
let aad_header = FrameHeader::data(header.channel_id, 0).with_encrypted();
let mut aad_buf = [0u8; 32];
let aad_len = aad_header
.encode(&mut aad_buf)
.map_err(|e| SessionError::Wire(e.to_string()))?;
let decrypted = cipher
.decrypt(&payload, &aad_buf[..aad_len])
.map_err(|e| SessionError::Crypto(e.to_string()))?;
return Ok(Some((header.channel_id, decrypted)));
} else {
return Ok(Some((header.channel_id, payload.to_vec())));
}
}
_ => continue, }
}
}
pub async fn close(
&mut self,
transport: &mut dyn AsyncFrameTransport,
) -> Result<(), SessionError> {
if self.state == SessionState::Closed {
return Ok(());
}
self.state = SessionState::Closing;
let header = FrameHeader::control(FrameType::Close);
self.framer
.write_frame(transport, &header, &[])
.await
.map_err(|e| SessionError::Framer(e.to_string()))?;
self.state = SessionState::Closed;
Ok(())
}
}
/// Errors produced by session send, receive and close operations.
///
/// The payload-carrying variants hold the text of the underlying error.
#[derive(Debug)]
pub enum SessionError {
    /// The operation requires an established session.
    NotEstablished,
    /// No session cipher is installed.
    NoCipher,
    /// Encoding a frame header failed.
    Wire(String),
    /// Encrypting or decrypting a payload failed.
    Crypto(String),
    /// Reading or writing a frame failed.
    Framer(String),
}

impl std::fmt::Display for SessionError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::NotEstablished => {
                f.write_str("session not established (handshake not complete)")
            }
            Self::NoCipher => f.write_str("no session cipher available"),
            // Layer-specific failures carry the underlying error text.
            Self::Wire(detail) => write!(f, "wire error: {detail}"),
            Self::Crypto(detail) => write!(f, "crypto error: {detail}"),
            Self::Framer(detail) => write!(f, "framer error: {detail}"),
        }
    }
}

impl std::error::Error for SessionError {}