1use crate::crypto::ellipticcurve::Signature;
7use crate::hash::TranscriptHash;
8use crate::net::record::{Record, RecordPayloadProtection, RecordType};
9use crate::net::server::{ServerHello, ServerConfig};
10use crate::net::{
11 alert::TlsError,
12 client::ClientHello,
13 extensions::{ServerExtensions, SignatureScheme},
14 handshake::{
15 get_finished_handshake, get_verify_data_for_finished, Certificate, Handshake, HandshakeType,
16 },
17};
18use crate::net::{KeySchedule, TlsStream};
19use crate::rand::{RngCore, PRNG, SimpleRng, URandomRng, SeedableRng};
20use crate::utils::{bytes, keylog::KeyLog, log};
21use ibig::{IBig, ibig};
22use std::net::SocketAddr;
23use std::net::TcpListener;
24
25use std::result::Result;
26
27pub struct ServerConnection {
28 server: TcpListener,
29 config: ServerConfig,
30}
31
32impl ServerConnection {
33 pub fn new(server: TcpListener, config: ServerConfig) -> Self {
34 Self { server, config }
35 }
36
37 pub fn accept(&self) -> std::result::Result<(TlsStream, SocketAddr), TlsError> {
38 let (sock, _addr) = match self.server.accept() {
39 Ok(e) => e,
40 Err(e) => {
41 log::error!("TCP accept error: {e:?}");
42 return Err(TlsError::BrokenPipe);
43 }
44 };
45
46 let mut stream = TlsStream::new(sock);
47
48 let mut shs = ServerHandshake::new(&mut stream, &self.config);
49 shs.do_handshake_with_error()?;
50
51 Ok((stream, _addr))
52 }
53}
54
55#[derive(PartialEq, PartialOrd, Clone, Copy, Debug)]
56#[repr(u8)]
57enum ServerHsState {
58 ClientHello,
59 ClientCertificate = 0x10,
60 ClientCertificateVerify,
61 ClientFinished,
62 FinishWithError(TlsError),
63 Ready,
64}
65
66struct ServerHandshake<'a> {
67 stream: &'a mut TlsStream,
68 config: &'a ServerConfig,
69 state: ServerHsState,
70 keylog: Option<KeyLog>,
71 client_cert: Option<Certificate>,
72 certificate_request_context: Option<Vec<u8>>,
73 rng: Box<dyn RngCore<IBig>>,
74 tshash: Option<Box<dyn TranscriptHash>>,
75 tshash_clienthello_serverfinished: Option<Box<dyn TranscriptHash>>,
76}
77
78impl<'a> ServerHandshake<'a> {
79 pub fn new(stream: &'a mut TlsStream, config: &'a ServerConfig) -> Self {
80 Self {
81 stream,
82 config,
83 state: ServerHsState::ClientHello,
84 keylog: None,
85 client_cert: None,
86 certificate_request_context: None,
87 rng: match config.prng_type {
88 PRNG::Simple => Box::new(SimpleRng::from_seed(ibig!(0))),
89 PRNG::URandom => Box::new(URandomRng::new())
90 },
91 tshash: None,
92 tshash_clienthello_serverfinished: None,
93 }
94 }
95 pub fn do_handshake_with_error(&mut self) -> Result<(), TlsError> {
96 if let Err(mut err) = self.do_handshake() {
97 if err < TlsError::NotOfficial {
98 log::error!("<-- Alert ({err:?})");
99 self.stream.write_alert(err)?;
100 }
101 if let TlsError::GotAlert(err_code) = err {
102 err = TlsError::new(err_code);
103 }
104 return Err(err);
105 }
106 Ok(())
107 }
108
109 fn do_handshake(&mut self) -> Result<(), TlsError> {
110 let mut rx_buf: [u8; 4096] = [0; 4096];
111
112 while self.state != ServerHsState::Ready {
113 let n = self.stream.tcp_read(&mut rx_buf)?;
114 let mut consumed_total = 0;
115 while consumed_total < n {
116 let (consumed, record) = Record::from_raw(&rx_buf[consumed_total..n])?;
117 consumed_total += consumed;
118 self.handle_handshake_record(record)?;
119 }
120 self.stream.flush()?;
122 }
123 Ok(())
124 }
125
126 fn handle_handshake_record(&mut self, record: Record) -> Result<(), TlsError> {
127 match record.content_type {
128 RecordType::ChangeCipherSpec => {
129 log::debug!("--> ChangeCipherSpec");
130 if self.state == ServerHsState::ClientHello {
131 return Err(TlsError::UnexpectedMessage);
132 }
133 return Ok(());
134 }
135 RecordType::Alert => {
136 let alert_code = record.fraqment.as_ref()[1];
137 let alert = TlsError::new(alert_code);
138 log::debug!("--> Alert {alert:?}");
139 if self.state != ServerHsState::Ready {
140 log::error!("Handshake aborted by client");
141 }
142 return Err(TlsError::GotAlert(alert_code));
143 }
144 _ => match self.state {
145 ServerHsState::ClientHello => {
146 if record.content_type != RecordType::Handshake {
147 return Err(TlsError::UnexpectedMessage);
148 }
149 self.handle_client_hello(record)?;
150 }
151 ServerHsState::ClientCertificate
152 | ServerHsState::ClientCertificateVerify
153 | ServerHsState::FinishWithError(_)
154 | ServerHsState::ClientFinished => {
155 self.handle_handshake_encrypted_record(record)?;
156 }
157 ServerHsState::Ready => {}
158 },
159 }
160 Ok(())
161 }
162 fn handle_client_hello(&mut self, record: Record) -> Result<(), TlsError> {
163 let handshake = Handshake::from_raw(record.fraqment.as_ref())?;
164
165 if handshake.handshake_type != HandshakeType::ClientHello {
166 return Err(TlsError::UnexpectedMessage);
167 }
168
169 log::debug!("--> ClientHello");
170 let client_hello = ClientHello::from_raw(handshake.fraqment)?;
171
172 let (server_hello, private_key) =
173 ServerHello::from_client_hello(&client_hello, &mut *self.rng, self.config)?;
174
175 let mut tshash = server_hello.cipher_suite.get_tshash()?;
176 tshash.update(record.fraqment.as_ref());
178
179 let handshake_raw =
181 Handshake::to_bytes(HandshakeType::ServerHello, server_hello.as_bytes());
182 tshash.update(&handshake_raw);
183 self.stream
184 .write_record(RecordType::Handshake, &handshake_raw)?;
185
186 if client_hello.legacy_session_id_echo.is_some() {
191 self.stream
192 .write_record(RecordType::ChangeCipherSpec, &[0x01])?;
193 }
194
195 let key_share_entry = match client_hello.get_public_key_share() {
197 Some(kse) => kse,
198 None => return Err(TlsError::HandshakeFailure),
199 };
200 let key_schedule =
201 KeySchedule::from_handshake(tshash.as_ref(), &private_key, key_share_entry)?;
202
203 let cipher = server_hello.cipher_suite.get_cipher()?;
204 let protection = RecordPayloadProtection::new(key_schedule, cipher, false);
205
206 if let Some(filepath) = &self.config.keylog {
207 if protection.is_some() {
208 let protection = protection.as_ref().unwrap();
209 let keylog = KeyLog::new(filepath.to_owned(), client_hello.random);
210 keylog.append_handshake_traffic_secrets(
211 &protection.handshake_keys.server.traffic_secret,
212 &protection.handshake_keys.client.traffic_secret,
213 );
214 self.keylog = Some(keylog);
215 }
216 }
217
218 self.stream.set_protection(protection);
219
220 let encrypted_extensions_raw = ServerExtensions::new().as_bytes();
222 let handshake_raw =
223 Handshake::to_bytes(HandshakeType::EncryptedExtensions, encrypted_extensions_raw);
224
225 log::debug!("<-- EncryptedExtensions");
226 self.stream
227 .write_record(RecordType::Handshake, &handshake_raw)?;
228 tshash.update(&handshake_raw);
229
230 if let Some(client_cert_ca) = &self.config.client_cert_ca {
232 self.certificate_request_context = Some(self.rng.bytes(32));
235 let certificate_request = client_cert_ca
236 .get_certificate_request(self.certificate_request_context.as_ref().unwrap());
237
238 let handshake_raw =
239 Handshake::to_bytes(HandshakeType::CertificateRequest, certificate_request);
240
241 log::debug!("<-- CertificateRequest");
242 self.stream
243 .write_record(RecordType::Handshake, &handshake_raw)?;
244 tshash.update(&handshake_raw);
245 }
246
247 let handshake_raw = Handshake::to_bytes(
249 HandshakeType::Certificate,
250 self.config.cert.get_certificate_for_handshake(vec![0x00]),
251 );
252
253 log::debug!("<-- Certificate");
254 self.stream
255 .write_record(RecordType::Handshake, &handshake_raw)?;
256 tshash.update(&handshake_raw);
257
258 let certificate_verify_raw = self.config.cert.get_certificate_verify_for_handshake(
260 &self.config.privkey,
261 tshash.as_ref(),
262 b"server",
263 )?;
264
265 let handshake_raw =
266 Handshake::to_bytes(HandshakeType::CertificateVerify, certificate_verify_raw);
267
268 tshash.update(&handshake_raw);
269 self.stream
270 .write_record(RecordType::Handshake, &handshake_raw)?;
271 log::debug!("<-- CertificateVerify");
272
273 let handshake_raw = get_finished_handshake(
275 &self
276 .stream
277 .protection
278 .as_ref()
279 .unwrap()
280 .key_schedule
281 .server_handshake_traffic_secret,
282 tshash.as_ref(),
283 )?;
284
285 tshash.update(&handshake_raw);
286 self.stream
287 .write_record(RecordType::Handshake, &handshake_raw)?;
288 log::debug!("<-- ServerFinished");
289
290 self.state = if self.config.client_cert_ca.is_some() {
291 ServerHsState::ClientCertificate
292 } else {
293 ServerHsState::ClientFinished
294 };
295
296 self.tshash = Some(tshash);
297 Ok(())
298 }
299
300 fn handle_handshake_encrypted_record(&mut self, record: Record) -> Result<(), TlsError> {
301 let (content_type, content) = self.stream.protection.as_mut().unwrap().decrypt(record)?;
302
303 let record = Record::new(content_type, crate::net::record::Value::Owned(content));
304
305 if record.content_type != RecordType::Handshake
306 || (self.config.client_cert_ca.is_some() && self.certificate_request_context.is_none())
307 {
308 if record.content_type == RecordType::Alert {
309 return Err(TlsError::GotAlert(record.fraqment.as_ref()[1]));
310 }
311 return Err(TlsError::UnexpectedMessage);
312 }
313
314 match self.state {
315 ServerHsState::ClientCertificate => self.handle_client_certificate(record)?,
316 ServerHsState::ClientCertificateVerify => {
317 self.handle_client_certificate_verify(record)?
318 }
319 ServerHsState::ClientFinished | ServerHsState::FinishWithError(_) => {
320 self.handle_client_finish(record)?
321 }
322 _ => (),
323 }
324 Ok(())
325 }
326
327 pub fn handle_client_certificate(&mut self, record: Record) -> Result<(), TlsError> {
328 let handshake = Handshake::from_raw(record.fraqment.as_ref())?;
329
330 if handshake.handshake_type != HandshakeType::Certificate
331 || self.config.client_cert_ca.is_none()
332 {
333 return Err(TlsError::UnexpectedMessage);
334 }
335
336 self.tshash_clienthello_serverfinished = Some((*self.tshash.as_ref().unwrap()).clone());
337
338 self.tshash
339 .as_mut()
340 .unwrap()
341 .update(record.fraqment.as_ref());
342
343 log::debug!("--> ClientCertificate");
344
345 let cert_request_context_len = handshake.fraqment[0] as usize;
346 let cert_request_context = &handshake.fraqment[1..cert_request_context_len + 1];
347
348 if cert_request_context != self.certificate_request_context.as_ref().unwrap() {
349 return Err(TlsError::HandshakeFailure);
350 }
351
352 let mut certs = match Certificate::from_hello(handshake.fraqment) {
353 Ok(e) => e,
354 Err(e) => {
355 self.state = ServerHsState::FinishWithError(e);
356 return Ok(());
357 }
358 };
359
360 let cert = certs.pop().unwrap();
361
362 if self
364 .config
365 .client_cert_ca
366 .as_ref()
367 .unwrap()
368 .has_signed(&cert)
369 .is_err()
370 {
371 self.state = ServerHsState::FinishWithError(TlsError::UnknownCa)
372 } else {
373 if let Some(f) = self.config.client_cert_custom_verify_fn.as_ref() {
374 if !f(cert.x509.as_ref().unwrap()) {
375 log::debug!("Certificate denied by custom verify function");
376 self.state = ServerHsState::FinishWithError(TlsError::AccessDenied);
377 return Ok(());
378 }
379 }
380 self.client_cert = Some(cert);
381 self.state = ServerHsState::ClientCertificateVerify;
382 }
383
384 Ok(())
385 }
386
387 pub fn handle_client_certificate_verify(&mut self, record: Record) -> Result<(), TlsError> {
388 let handshake = Handshake::from_raw(record.fraqment.as_ref())?;
389
390 if handshake.handshake_type != HandshakeType::CertificateVerify
391 || self.client_cert.is_none()
392 {
393 return Err(TlsError::UnexpectedMessage);
394 }
395
396 log::debug!("--> ClientCertificateVerify");
397
398 let algo = SignatureScheme::new(bytes::to_u16(&handshake.fraqment[0..2]))?;
399
400 let mut consumed = 4; match algo {
403 SignatureScheme::ecdsa_secp256r1_sha256 => {
404 let (signature, size) = match Signature::from_der(&handshake.fraqment[consumed..]) {
405 Ok(e) => e,
406 Err(e) => {
407 self.state = ServerHsState::FinishWithError(e);
408 return Ok(());
409 }
410 };
411
412 consumed += size;
413
414 if !self.client_cert.as_ref().unwrap().is_certificate_valid(
415 signature,
416 self.tshash.as_ref().unwrap().as_ref(),
417 b"client",
418 ) {
419 self.state = ServerHsState::FinishWithError(TlsError::BadCertificate);
420 return Ok(());
421 }
422 }
423 e => todo!("SignatureScheme {e:?} for client cert not implemented yet"),
424 }
425
426 let sign_len = bytes::to_u16(&handshake.fraqment[2..4]) as usize;
427 if sign_len != consumed - 4 {
428 self.state = ServerHsState::FinishWithError(TlsError::BadCertificate);
429 return Ok(());
430 }
431
432 self.tshash
433 .as_mut()
434 .unwrap()
435 .update(record.fraqment.as_ref());
436
437 self.state = ServerHsState::ClientFinished;
438 Ok(())
439 }
440 pub fn handle_client_finish(&mut self, record: Record) -> Result<(), TlsError> {
441 let handshake = Handshake::from_raw(record.fraqment.as_ref())?;
442
443 if handshake.handshake_type != HandshakeType::Finished {
444 if let ServerHsState::FinishWithError(_) = self.state {
445 self.tshash.as_mut().unwrap().update(handshake.as_bytes());
447 return Ok(());
448 }
449 return Err(TlsError::UnexpectedMessage);
450 }
451
452
453 let protection = self.stream.protection.as_mut().unwrap();
454 log::debug!("--> ClientFinished");
455 let fraqment = handshake.fraqment.to_owned();
456
457 let verify_data = Some(get_verify_data_for_finished(
458 &protection.key_schedule.client_handshake_traffic_secret,
459 self.tshash.as_mut().unwrap().as_ref(),
460 )?);
461
462 if fraqment != verify_data.unwrap() {
463 return Err(TlsError::DecryptError);
464 }
465
466 let tshash = if self.tshash_clienthello_serverfinished.is_some() {
468 self.tshash_clienthello_serverfinished.as_mut().unwrap()
469 } else {
470 self.tshash.as_mut().unwrap()
471 };
472
473 protection.generate_application_keys(tshash.as_ref())?;
474
475 if let Some(k) = &self.keylog {
476 k.append_from_record_payload_protection(protection);
477 }
478
479 if let ServerHsState::FinishWithError(err) = self.state {
480 log::error!("Abort connection: {err:?}");
481 return Err(err);
482 }
483
484 self.state = ServerHsState::Ready;
485 Ok(())
486 }
487}