anothertls/net/server/
connection.rs

1/*
2 * Copyright (c) 2023, Tobias Müller <git@tsmr.eu>
3 *
4 */
5
6use crate::crypto::ellipticcurve::Signature;
7use crate::hash::TranscriptHash;
8use crate::net::record::{Record, RecordPayloadProtection, RecordType};
9use crate::net::server::{ServerHello, ServerConfig};
10use crate::net::{
11    alert::TlsError,
12    client::ClientHello,
13    extensions::{ServerExtensions, SignatureScheme},
14    handshake::{
15        get_finished_handshake, get_verify_data_for_finished, Certificate, Handshake, HandshakeType,
16    },
17};
18use crate::net::{KeySchedule, TlsStream};
19use crate::rand::{RngCore, PRNG, SimpleRng, URandomRng, SeedableRng};
20use crate::utils::{bytes, keylog::KeyLog, log};
21use ibig::{IBig, ibig};
22use std::net::SocketAddr;
23use std::net::TcpListener;
24
25use std::result::Result;
26
27pub struct ServerConnection {
28    server: TcpListener,
29    config: ServerConfig,
30}
31
32impl ServerConnection {
33    pub fn new(server: TcpListener, config: ServerConfig) -> Self {
34        Self { server, config }
35    }
36
37    pub fn accept(&self) -> std::result::Result<(TlsStream, SocketAddr), TlsError> {
38        let (sock, _addr) = match self.server.accept() {
39            Ok(e) => e,
40            Err(e) => {
41                log::error!("TCP accept error: {e:?}");
42                return Err(TlsError::BrokenPipe);
43            }
44        };
45
46        let mut stream = TlsStream::new(sock);
47
48        let mut shs = ServerHandshake::new(&mut stream, &self.config);
49        shs.do_handshake_with_error()?;
50
51        Ok((stream, _addr))
52    }
53}
54
55#[derive(PartialEq, PartialOrd, Clone, Copy, Debug)]
56#[repr(u8)]
57enum ServerHsState {
58    ClientHello,
59    ClientCertificate = 0x10,
60    ClientCertificateVerify,
61    ClientFinished,
62    FinishWithError(TlsError),
63    Ready,
64}
65
66struct ServerHandshake<'a> {
67    stream: &'a mut TlsStream,
68    config: &'a ServerConfig,
69    state: ServerHsState,
70    keylog: Option<KeyLog>,
71    client_cert: Option<Certificate>,
72    certificate_request_context: Option<Vec<u8>>,
73    rng: Box<dyn RngCore<IBig>>,
74    tshash: Option<Box<dyn TranscriptHash>>,
75    tshash_clienthello_serverfinished: Option<Box<dyn TranscriptHash>>,
76}
77
78impl<'a> ServerHandshake<'a> {
79    pub fn new(stream: &'a mut TlsStream, config: &'a ServerConfig) -> Self {
80        Self {
81            stream,
82            config,
83            state: ServerHsState::ClientHello,
84            keylog: None,
85            client_cert: None,
86            certificate_request_context: None,
87            rng: match config.prng_type {
88                PRNG::Simple => Box::new(SimpleRng::from_seed(ibig!(0))),
89                PRNG::URandom  => Box::new(URandomRng::new())
90            },
91            tshash: None,
92            tshash_clienthello_serverfinished: None,
93        }
94    }
95    pub fn do_handshake_with_error(&mut self) -> Result<(), TlsError> {
96        if let Err(mut err) = self.do_handshake() {
97            if err < TlsError::NotOfficial {
98                log::error!("<-- Alert ({err:?})");
99                self.stream.write_alert(err)?;
100            }
101            if let TlsError::GotAlert(err_code) = err {
102                err = TlsError::new(err_code);
103            }
104            return Err(err);
105        }
106        Ok(())
107    }
108
109    fn do_handshake(&mut self) -> Result<(), TlsError> {
110        let mut rx_buf: [u8; 4096] = [0; 4096];
111
112        while self.state != ServerHsState::Ready {
113            let n = self.stream.tcp_read(&mut rx_buf)?;
114            let mut consumed_total = 0;
115            while consumed_total < n {
116                let (consumed, record) = Record::from_raw(&rx_buf[consumed_total..n])?;
117                consumed_total += consumed;
118                self.handle_handshake_record(record)?;
119            }
120            // send server handshake records to the client
121            self.stream.flush()?;
122        }
123        Ok(())
124    }
125
126    fn handle_handshake_record(&mut self, record: Record) -> Result<(), TlsError> {
127        match record.content_type {
128            RecordType::ChangeCipherSpec => {
129                log::debug!("--> ChangeCipherSpec");
130                if self.state == ServerHsState::ClientHello {
131                    return Err(TlsError::UnexpectedMessage);
132                }
133                return Ok(());
134            }
135            RecordType::Alert => {
136                let alert_code = record.fraqment.as_ref()[1];
137                let alert = TlsError::new(alert_code);
138                log::debug!("--> Alert {alert:?}");
139                if self.state != ServerHsState::Ready {
140                    log::error!("Handshake aborted by client");
141                }
142                return Err(TlsError::GotAlert(alert_code));
143            }
144            _ => match self.state {
145                ServerHsState::ClientHello => {
146                    if record.content_type != RecordType::Handshake {
147                        return Err(TlsError::UnexpectedMessage);
148                    }
149                    self.handle_client_hello(record)?;
150                }
151                ServerHsState::ClientCertificate
152                | ServerHsState::ClientCertificateVerify
153                | ServerHsState::FinishWithError(_)
154                | ServerHsState::ClientFinished => {
155                    self.handle_handshake_encrypted_record(record)?;
156                }
157                ServerHsState::Ready => {}
158            },
159        }
160        Ok(())
161    }
162    fn handle_client_hello(&mut self, record: Record) -> Result<(), TlsError> {
163        let handshake = Handshake::from_raw(record.fraqment.as_ref())?;
164
165        if handshake.handshake_type != HandshakeType::ClientHello {
166            return Err(TlsError::UnexpectedMessage);
167        }
168
169        log::debug!("--> ClientHello");
170        let client_hello = ClientHello::from_raw(handshake.fraqment)?;
171
172        let (server_hello, private_key) =
173            ServerHello::from_client_hello(&client_hello, &mut *self.rng, self.config)?;
174
175        let mut tshash = server_hello.cipher_suite.get_tshash()?;
176        // Add ClientHello
177        tshash.update(record.fraqment.as_ref());
178
179        // -- ServerHello --
180        let handshake_raw =
181            Handshake::to_bytes(HandshakeType::ServerHello, server_hello.as_bytes());
182        tshash.update(&handshake_raw);
183        self.stream
184            .write_record(RecordType::Handshake, &handshake_raw)?;
185
186        // -- Change Cipher Spec --
187        // Either side can send change_cipher_spec at any time during the handshake, as they
188        // must be ignored by the peer, but if the client sends a non-empty session ID, the
189        // server MUST send the change_cipher_spec as described in this appendix.
190        if client_hello.legacy_session_id_echo.is_some() {
191            self.stream
192                .write_record(RecordType::ChangeCipherSpec, &[0x01])?;
193        }
194
195        // -- Handshake Keys Calc --
196        let key_share_entry = match client_hello.get_public_key_share() {
197            Some(kse) => kse,
198            None => return Err(TlsError::HandshakeFailure),
199        };
200        let key_schedule =
201            KeySchedule::from_handshake(tshash.as_ref(), &private_key, key_share_entry)?;
202
203        let cipher = server_hello.cipher_suite.get_cipher()?;
204        let protection = RecordPayloadProtection::new(key_schedule, cipher, false);
205
206        if let Some(filepath) = &self.config.keylog {
207            if protection.is_some() {
208                let protection = protection.as_ref().unwrap();
209                let keylog = KeyLog::new(filepath.to_owned(), client_hello.random);
210                keylog.append_handshake_traffic_secrets(
211                    &protection.handshake_keys.server.traffic_secret,
212                    &protection.handshake_keys.client.traffic_secret,
213                );
214                self.keylog = Some(keylog);
215            }
216        }
217
218        self.stream.set_protection(protection);
219
220        // -- EncryptedExtensions --
221        let encrypted_extensions_raw = ServerExtensions::new().as_bytes();
222        let handshake_raw =
223            Handshake::to_bytes(HandshakeType::EncryptedExtensions, encrypted_extensions_raw);
224
225        log::debug!("<-- EncryptedExtensions");
226        self.stream
227            .write_record(RecordType::Handshake, &handshake_raw)?;
228        tshash.update(&handshake_raw);
229
230        // -- Certificate Request --
231        if let Some(client_cert_ca) = &self.config.client_cert_ca {
232            // prevent an attacker who has temporary access to the client's
233            // private key from pre-computing valid CertificateVerify messages
234            self.certificate_request_context = Some(self.rng.bytes(32));
235            let certificate_request = client_cert_ca
236                .get_certificate_request(self.certificate_request_context.as_ref().unwrap());
237
238            let handshake_raw =
239                Handshake::to_bytes(HandshakeType::CertificateRequest, certificate_request);
240
241            log::debug!("<-- CertificateRequest");
242            self.stream
243                .write_record(RecordType::Handshake, &handshake_raw)?;
244            tshash.update(&handshake_raw);
245        }
246
247        // -- Server Certificate --
248        let handshake_raw = Handshake::to_bytes(
249            HandshakeType::Certificate,
250            self.config.cert.get_certificate_for_handshake(vec![0x00]),
251        );
252
253        log::debug!("<-- Certificate");
254        self.stream
255            .write_record(RecordType::Handshake, &handshake_raw)?;
256        tshash.update(&handshake_raw);
257
258        // -- Server Certificate Verify --
259        let certificate_verify_raw = self.config.cert.get_certificate_verify_for_handshake(
260            &self.config.privkey,
261            tshash.as_ref(),
262            b"server",
263        )?;
264
265        let handshake_raw =
266            Handshake::to_bytes(HandshakeType::CertificateVerify, certificate_verify_raw);
267
268        tshash.update(&handshake_raw);
269        self.stream
270            .write_record(RecordType::Handshake, &handshake_raw)?;
271        log::debug!("<-- CertificateVerify");
272
273        // -- FINISHED --
274        let handshake_raw = get_finished_handshake(
275            &self
276                .stream
277                .protection
278                .as_ref()
279                .unwrap()
280                .key_schedule
281                .server_handshake_traffic_secret,
282            tshash.as_ref(),
283        )?;
284
285        tshash.update(&handshake_raw);
286        self.stream
287            .write_record(RecordType::Handshake, &handshake_raw)?;
288        log::debug!("<-- ServerFinished");
289
290        self.state = if self.config.client_cert_ca.is_some() {
291            ServerHsState::ClientCertificate
292        } else {
293            ServerHsState::ClientFinished
294        };
295
296        self.tshash = Some(tshash);
297        Ok(())
298    }
299
300    fn handle_handshake_encrypted_record(&mut self, record: Record) -> Result<(), TlsError> {
301        let (content_type, content) = self.stream.protection.as_mut().unwrap().decrypt(record)?;
302
303        let record = Record::new(content_type, crate::net::record::Value::Owned(content));
304
305        if record.content_type != RecordType::Handshake
306            || (self.config.client_cert_ca.is_some() && self.certificate_request_context.is_none())
307        {
308            if record.content_type == RecordType::Alert {
309                return Err(TlsError::GotAlert(record.fraqment.as_ref()[1]));
310            }
311            return Err(TlsError::UnexpectedMessage);
312        }
313
314        match self.state {
315            ServerHsState::ClientCertificate => self.handle_client_certificate(record)?,
316            ServerHsState::ClientCertificateVerify => {
317                self.handle_client_certificate_verify(record)?
318            }
319            ServerHsState::ClientFinished | ServerHsState::FinishWithError(_) => {
320                self.handle_client_finish(record)?
321            }
322            _ => (),
323        }
324        Ok(())
325    }
326
327    pub fn handle_client_certificate(&mut self, record: Record) -> Result<(), TlsError> {
328        let handshake = Handshake::from_raw(record.fraqment.as_ref())?;
329
330        if handshake.handshake_type != HandshakeType::Certificate
331            || self.config.client_cert_ca.is_none()
332        {
333            return Err(TlsError::UnexpectedMessage);
334        }
335
336        self.tshash_clienthello_serverfinished = Some((*self.tshash.as_ref().unwrap()).clone());
337
338        self.tshash
339            .as_mut()
340            .unwrap()
341            .update(record.fraqment.as_ref());
342
343        log::debug!("--> ClientCertificate");
344
345        let cert_request_context_len = handshake.fraqment[0] as usize;
346        let cert_request_context = &handshake.fraqment[1..cert_request_context_len + 1];
347
348        if cert_request_context != self.certificate_request_context.as_ref().unwrap() {
349            return Err(TlsError::HandshakeFailure);
350        }
351
352        let mut certs = match Certificate::from_hello(handshake.fraqment) {
353            Ok(e) => e,
354            Err(e) => {
355                self.state = ServerHsState::FinishWithError(e);
356                return Ok(());
357            }
358        };
359
360        let cert = certs.pop().unwrap();
361
362        // Validate client cert against the CA
363        if self
364            .config
365            .client_cert_ca
366            .as_ref()
367            .unwrap()
368            .has_signed(&cert)
369            .is_err()
370        {
371            self.state = ServerHsState::FinishWithError(TlsError::UnknownCa)
372        } else {
373            if let Some(f) = self.config.client_cert_custom_verify_fn.as_ref() {
374                if !f(cert.x509.as_ref().unwrap()) {
375                    log::debug!("Certificate denied by custom verify function");
376                    self.state = ServerHsState::FinishWithError(TlsError::AccessDenied);
377                    return Ok(());
378                }
379            }
380            self.client_cert = Some(cert);
381            self.state = ServerHsState::ClientCertificateVerify;
382        }
383
384        Ok(())
385    }
386
387    pub fn handle_client_certificate_verify(&mut self, record: Record) -> Result<(), TlsError> {
388        let handshake = Handshake::from_raw(record.fraqment.as_ref())?;
389
390        if handshake.handshake_type != HandshakeType::CertificateVerify
391            || self.client_cert.is_none()
392        {
393            return Err(TlsError::UnexpectedMessage);
394        }
395
396        log::debug!("--> ClientCertificateVerify");
397
398        let algo = SignatureScheme::new(bytes::to_u16(&handshake.fraqment[0..2]))?;
399
400        let mut consumed = 4; // algo and len
401
402        match algo {
403            SignatureScheme::ecdsa_secp256r1_sha256 => {
404                let (signature, size) = match Signature::from_der(&handshake.fraqment[consumed..]) {
405                    Ok(e) => e,
406                    Err(e) => {
407                        self.state = ServerHsState::FinishWithError(e);
408                        return Ok(());
409                    }
410                };
411
412                consumed += size;
413
414                if !self.client_cert.as_ref().unwrap().is_certificate_valid(
415                    signature,
416                    self.tshash.as_ref().unwrap().as_ref(),
417                    b"client",
418                ) {
419                    self.state = ServerHsState::FinishWithError(TlsError::BadCertificate);
420                    return Ok(());
421                }
422            }
423            e => todo!("SignatureScheme {e:?} for client cert not implemented yet"),
424        }
425
426        let sign_len = bytes::to_u16(&handshake.fraqment[2..4]) as usize;
427        if sign_len != consumed - 4 {
428            self.state = ServerHsState::FinishWithError(TlsError::BadCertificate);
429            return Ok(());
430        }
431
432        self.tshash
433            .as_mut()
434            .unwrap()
435            .update(record.fraqment.as_ref());
436
437        self.state = ServerHsState::ClientFinished;
438        Ok(())
439    }
440    pub fn handle_client_finish(&mut self, record: Record) -> Result<(), TlsError> {
441        let handshake = Handshake::from_raw(record.fraqment.as_ref())?;
442
443        if handshake.handshake_type != HandshakeType::Finished {
444            if let ServerHsState::FinishWithError(_) = self.state {
445                // When error in Certificate, then CertificateVerify will follow
446                self.tshash.as_mut().unwrap().update(handshake.as_bytes());
447                return Ok(());
448            }
449            return Err(TlsError::UnexpectedMessage);
450        }
451
452
453        let protection = self.stream.protection.as_mut().unwrap();
454        log::debug!("--> ClientFinished");
455        let fraqment = handshake.fraqment.to_owned();
456
457        let verify_data = Some(get_verify_data_for_finished(
458            &protection.key_schedule.client_handshake_traffic_secret,
459            self.tshash.as_mut().unwrap().as_ref(),
460        )?);
461
462        if fraqment != verify_data.unwrap() {
463            return Err(TlsError::DecryptError);
464        }
465
466        // Derive-Secret: ClientHello..server Finished
467        let tshash = if self.tshash_clienthello_serverfinished.is_some() {
468            self.tshash_clienthello_serverfinished.as_mut().unwrap()
469        } else {
470            self.tshash.as_mut().unwrap()
471        };
472
473        protection.generate_application_keys(tshash.as_ref())?;
474
475        if let Some(k) = &self.keylog {
476            k.append_from_record_payload_protection(protection);
477        }
478
479        if let ServerHsState::FinishWithError(err) = self.state {
480            log::error!("Abort connection: {err:?}");
481            return Err(err);
482        }
483
484        self.state = ServerHsState::Ready;
485        Ok(())
486    }
487}