1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
use std::io;

pub use rustls::internal::msgs::handshake::ClientHelloPayload;
use rustls::internal::msgs::{
    deframer::MessageDeframer,
    enums::{ContentType, HandshakeType},
    handshake::HandshakePayload,
    message::MessagePayload,
};
pub use rustls::TLSError;

/// Incremental parser that extracts a TLS ClientHello from a byte stream.
///
/// The record deframer is kept in `self` between calls to [`Parser::parse`], so
/// bytes read by one call stay buffered for later calls; `parse` reports
/// `ParseOutput::Partial` while no complete record is available yet.
pub struct Parser {
    // rustls record deframer: accumulates raw bytes and splits them into TLS records
    // (`frames`), setting `desynced` when the framing is malformed.
    message_deframer: MessageDeframer,
}

/// Outcome of a single [`Parser::parse`] call.
#[derive(Debug)]
pub enum ParseOutput {
    /// The first buffered TLS record was a handshake message carrying a ClientHello.
    Done(ClientHelloPayload),
    /// No complete TLS record is buffered yet; call `parse` again once more input
    /// is available.
    Partial,
    /// The input is not a usable ClientHello; the error describes why (corrupt
    /// framing, undecodable payload, or an unexpected message/handshake type).
    Invalid(TLSError),
}

impl Parser {
    /// Creates a parser with an empty deframing buffer.
    pub fn new() -> Self {
        Self {
            message_deframer: MessageDeframer::new(),
        }
    }

    /// Reads more bytes from `buf` and tries to decode the first complete TLS
    /// record as a ClientHello.
    ///
    /// Returns:
    /// - `ParseOutput::Done` when the first buffered record is a handshake message
    ///   whose payload is a ClientHello;
    /// - `ParseOutput::Invalid` when that record's payload cannot be decoded, it is
    ///   not a handshake record, it is a handshake message of another type, or the
    ///   deframer has desynchronised on malformed record framing;
    /// - `ParseOutput::Partial` when no complete record is buffered yet.
    ///
    /// Only one record is inspected per call. NOTE(review): no handshake-message
    /// joining is done here, so a ClientHello fragmented over several TLS records
    /// is presumably reported as `Invalid` — confirm this is acceptable for callers.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned while reading from `buf`.
    // ref https://github.com/ctz/rustls/blob/v/0.18.0/rustls/src/msgs/deframer.rs#L54-L79
    pub fn parse(&mut self, buf: &mut impl io::Read) -> io::Result<ParseOutput> {
        // NOTE(review): the number of bytes read is discarded, so reaching EOF with
        // no complete record yields `Partial` again; callers must detect EOF
        // themselves to avoid polling forever — confirm against callers.
        self.message_deframer.read(buf)?;

        let mut msg = match self.message_deframer.frames.pop_front() {
            Some(msg) => msg,
            None => {
                return Ok(if self.message_deframer.desynced {
                    // ref https://github.com/ctz/rustls/blob/v/0.18.0/rustls/src/server/mod.rs#L457-L459
                    ParseOutput::Invalid(TLSError::CorruptMessage)
                } else {
                    ParseOutput::Partial
                });
            }
        };

        // https://github.com/ctz/rustls/blob/v/0.18.0/rustls/src/client/mod.rs#L486-L489
        if !msg.decode_payload() {
            return Ok(ParseOutput::Invalid(TLSError::CorruptMessagePayload(
                msg.typ,
            )));
        }

        // `msg` is owned, so the decoded ClientHello is moved out by value instead of
        // being rebuilt field by field with clones of its vectors.
        let content_type = msg.typ;
        let output = match msg.payload {
            MessagePayload::Handshake(hsp) => {
                let handshake_type = hsp.typ;
                match hsp.payload {
                    HandshakePayload::ClientHello(chp) => ParseOutput::Done(chp),
                    // ref https://github.com/ctz/rustls/blob/v/0.18.0/rustls/src/check.rs#L7-L25
                    _ => ParseOutput::Invalid(TLSError::InappropriateHandshakeMessage {
                        expect_types: vec![HandshakeType::ClientHello],
                        got_type: handshake_type,
                    }),
                }
            }
            // ref https://github.com/ctz/rustls/blob/v/0.18.0/rustls/src/check.rs#L7-L25
            _ => ParseOutput::Invalid(TLSError::InappropriateMessage {
                expect_types: vec![ContentType::Handshake],
                got_type: content_type,
            }),
        };

        Ok(output)
    }
}

impl Default for Parser {
    /// Equivalent to [`Parser::new`].
    fn default() -> Self {
        Self::new()
    }
}