1use crate::{
2 cipherstate::CipherStates,
3 constants::{CIPHERKEYLEN, MAXDHLEN, MAXMSGLEN, TAGLEN},
4 error::{Error, StateProblem},
5 handshakestate::HandshakeState,
6 params::HandshakePattern,
7 utils::Toggle,
8};
9use core::{convert::TryFrom, fmt};
10
/// Post-handshake session state used to encrypt outgoing and decrypt incoming transport
/// messages.
///
/// Built from a finished [`HandshakeState`] via `TransportState::new` or the
/// `TryFrom<HandshakeState>` impl.
pub struct TransportState {
    /// Cipher states taken from the handshake: `.0` protects initiator-to-responder traffic
    /// and `.1` protects responder-to-initiator traffic.
    cipherstates: CipherStates,
    /// Handshake pattern; consulted (`is_oneway`) to reject sends/receives a one-way
    /// pattern does not allow.
    pattern: HandshakePattern,
    /// Number of meaningful leading bytes in `rs` (the handshake's DH length).
    dh_len: usize,
    /// Fixed-size remote static key buffer carried over from the handshake, if one was set.
    rs: Toggle<[u8; MAXDHLEN]>,
    /// `true` if this side was the handshake initiator.
    initiator: bool,
}
23
24impl TransportState {
25 pub(crate) fn new(handshake: HandshakeState) -> Result<Self, Error> {
26 if !handshake.is_handshake_finished() {
27 return Err(StateProblem::HandshakeNotFinished.into());
28 }
29
30 let dh_len = handshake.dh_len();
31 let HandshakeState { cipherstates, params, rs, initiator, .. } = handshake;
32 let pattern = params.handshake.pattern;
33
34 Ok(TransportState { cipherstates, pattern, dh_len, rs, initiator })
35 }
36
37 #[must_use]
44 pub fn get_remote_static(&self) -> Option<&[u8]> {
45 self.rs.get().map(|rs| &rs[..self.dh_len])
46 }
47
48 pub fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error> {
58 if !self.initiator && self.pattern.is_oneway() {
59 return Err(StateProblem::OneWay.into());
60 } else if payload.len() + TAGLEN > MAXMSGLEN || payload.len() + TAGLEN > message.len() {
61 return Err(Error::Input);
62 }
63
64 let cipher =
65 if self.initiator { &mut self.cipherstates.0 } else { &mut self.cipherstates.1 };
66 cipher.encrypt(payload, message)
67 }
68
69 pub fn read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error> {
81 if message.len() > MAXMSGLEN {
82 Err(Error::Input)
83 } else if self.initiator && self.pattern.is_oneway() {
84 Err(StateProblem::OneWay.into())
85 } else {
86 let cipher =
87 if self.initiator { &mut self.cipherstates.1 } else { &mut self.cipherstates.0 };
88 cipher.decrypt(message, payload)
89 }
90 }
91
92 pub fn rekey_outgoing(&mut self) {
97 if self.initiator {
98 self.cipherstates.rekey_initiator();
99 } else {
100 self.cipherstates.rekey_responder();
101 }
102 }
103
104 pub fn rekey_incoming(&mut self) {
109 if self.initiator {
110 self.cipherstates.rekey_responder();
111 } else {
112 self.cipherstates.rekey_initiator();
113 }
114 }
115
116 pub fn rekey_manually(
118 &mut self,
119 initiator: Option<&[u8; CIPHERKEYLEN]>,
120 responder: Option<&[u8; CIPHERKEYLEN]>,
121 ) {
122 if let Some(key) = initiator {
123 self.rekey_initiator_manually(key);
124 }
125 if let Some(key) = responder {
126 self.rekey_responder_manually(key);
127 }
128 }
129
130 pub fn rekey_initiator_manually(&mut self, key: &[u8; CIPHERKEYLEN]) {
132 self.cipherstates.rekey_initiator_manually(key);
133 }
134
135 pub fn rekey_responder_manually(&mut self, key: &[u8; CIPHERKEYLEN]) {
137 self.cipherstates.rekey_responder_manually(key);
138 }
139
140 pub fn set_receiving_nonce(&mut self, nonce: u64) {
142 if self.initiator {
143 self.cipherstates.1.set_nonce(nonce);
144 } else {
145 self.cipherstates.0.set_nonce(nonce);
146 }
147 }
148
149 #[must_use]
155 pub fn receiving_nonce(&self) -> u64 {
156 if self.initiator {
157 self.cipherstates.1.nonce()
158 } else {
159 self.cipherstates.0.nonce()
160 }
161 }
162
163 #[must_use]
169 pub fn sending_nonce(&self) -> u64 {
170 if self.initiator {
171 self.cipherstates.0.nonce()
172 } else {
173 self.cipherstates.1.nonce()
174 }
175 }
176
177 #[must_use]
179 pub fn is_initiator(&self) -> bool {
180 self.initiator
181 }
182}
183
184impl fmt::Debug for TransportState {
185 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
186 fmt.debug_struct("TransportState").finish()
187 }
188}
189
190impl TryFrom<HandshakeState> for TransportState {
191 type Error = Error;
192
193 fn try_from(old: HandshakeState) -> Result<Self, Self::Error> {
194 TransportState::new(old)
195 }
196}