Skip to main content

pipa/http/
connect_state.rs

1use std::net::TcpStream;
2use std::os::unix::io::{AsRawFd, RawFd};
3
4use crate::http::conn::{Connection, IoHint};
5
/// Lifecycle phase of a [`ConnectState`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnPhase {
    /// The non-blocking TCP connect has been started but not yet confirmed.
    Connecting,
    /// TCP is connected and the TLS handshake is in progress.
    TlsHandshaking,
    /// The connection is established and has been handed to the caller.
    Ready,
    /// An error occurred; the state machine will not advance any further.
    Failed,
}
13
14#[derive(Debug)]
15pub enum ConnEvent {
16    NeedRead,
17    NeedWrite,
18    NeedReadWrite,
19    Connected(Connection),
20    Error(String),
21}
22
23pub struct ConnectState {
24    phase: ConnPhase,
25    stream: Option<TcpStream>,
26    conn: Option<Connection>,
27    use_tls: bool,
28    tls_host: String,
29    extra_roots: Vec<Vec<u8>>,
30}
31
32impl ConnectState {
33    pub fn dummy() -> Self {
34        ConnectState {
35            phase: ConnPhase::Failed,
36            stream: None,
37            conn: None,
38            use_tls: false,
39            tls_host: String::new(),
40            extra_roots: Vec::new(),
41        }
42    }
43
44    pub fn new(
45        host: &str,
46        port: u16,
47        use_tls: bool,
48        extra_roots: Vec<Vec<u8>>,
49    ) -> Result<Self, String> {
50        let stream = Connection::connect_nonblocking(host, port)?;
51        Ok(ConnectState {
52            phase: ConnPhase::Connecting,
53            stream: Some(stream),
54            conn: None,
55            use_tls,
56            tls_host: host.to_string(),
57            extra_roots,
58        })
59    }
60
61    pub fn phase(&self) -> ConnPhase {
62        self.phase
63    }
64
65    pub fn fd(&self) -> Option<RawFd> {
66        self.stream
67            .as_ref()
68            .map(|s| s.as_raw_fd())
69            .or_else(|| self.conn.as_ref().map(|c| c.raw_fd()))
70    }
71
72    pub fn wants_read(&self) -> bool {
73        match self.phase {
74            ConnPhase::Connecting => false,
75            ConnPhase::TlsHandshaking => self.conn.as_ref().map_or(false, |c| c.tls_wants_read()),
76            _ => false,
77        }
78    }
79
80    pub fn wants_write(&self) -> bool {
81        match self.phase {
82            ConnPhase::Connecting => true,
83            ConnPhase::TlsHandshaking => self.conn.as_ref().map_or(false, |c| c.tls_wants_write()),
84            _ => false,
85        }
86    }
87
88    pub fn try_advance(&mut self) -> ConnEvent {
89        match self.phase {
90            ConnPhase::Connecting => {
91                let stream = match self.stream.as_ref() {
92                    Some(s) => s,
93                    None => {
94                        self.phase = ConnPhase::Failed;
95                        return ConnEvent::Error("no stream".into());
96                    }
97                };
98                match Connection::check_connect(stream) {
99                    Ok(()) => {
100                        let stream = self.stream.take().unwrap();
101                        if self.use_tls {
102                            match Connection::start_tls(&self.tls_host, stream, &self.extra_roots) {
103                                Ok(conn) => {
104                                    self.conn = Some(conn);
105                                    self.phase = ConnPhase::TlsHandshaking;
106                                    self.try_advance()
107                                }
108                                Err(e) => {
109                                    self.phase = ConnPhase::Failed;
110                                    ConnEvent::Error(e)
111                                }
112                            }
113                        } else {
114                            let conn = Connection::Plain(stream);
115                            self.phase = ConnPhase::Ready;
116                            ConnEvent::Connected(conn)
117                        }
118                    }
119                    Err(e) => {
120                        self.phase = ConnPhase::Failed;
121                        ConnEvent::Error(e)
122                    }
123                }
124            }
125            ConnPhase::TlsHandshaking => {
126                let conn = match self.conn.as_mut() {
127                    Some(c) => c,
128                    None => {
129                        self.phase = ConnPhase::Failed;
130                        return ConnEvent::Error("no connection for tls".into());
131                    }
132                };
133                match conn.tls_handshake_step() {
134                    Ok(IoHint::Ready) => {
135                        self.phase = ConnPhase::Ready;
136                        let conn = self.conn.take().unwrap();
137                        ConnEvent::Connected(conn)
138                    }
139                    Ok(IoHint::Read) => ConnEvent::NeedRead,
140                    Ok(IoHint::Write) => ConnEvent::NeedWrite,
141                    Ok(IoHint::ReadWrite) => ConnEvent::NeedReadWrite,
142                    Err(e) => {
143                        self.phase = ConnPhase::Failed;
144                        ConnEvent::Error(e)
145                    }
146                }
147            }
148            _ => ConnEvent::Error("invalid state for advance".into()),
149        }
150    }
151}