pipa/http/
connect_state.rs1use std::net::TcpStream;
2use std::os::unix::io::{AsRawFd, RawFd};
3
4use crate::http::conn::{Connection, IoHint};
5
/// Lifecycle stage of a [`ConnectState`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnPhase {
    /// A non-blocking TCP connect was issued and has not been confirmed yet.
    Connecting,
    /// TCP is established and the TLS handshake is being stepped.
    TlsHandshaking,
    /// The finished connection has been returned via `ConnEvent::Connected`.
    Ready,
    /// A step failed; `try_advance` will not make further progress.
    Failed,
}
13
14#[derive(Debug)]
15pub enum ConnEvent {
16 NeedRead,
17 NeedWrite,
18 NeedReadWrite,
19 Connected(Connection),
20 Error(String),
21}
22
23pub struct ConnectState {
24 phase: ConnPhase,
25 stream: Option<TcpStream>,
26 conn: Option<Connection>,
27 use_tls: bool,
28 tls_host: String,
29 extra_roots: Vec<Vec<u8>>,
30}
31
32impl ConnectState {
33 pub fn dummy() -> Self {
34 ConnectState {
35 phase: ConnPhase::Failed,
36 stream: None,
37 conn: None,
38 use_tls: false,
39 tls_host: String::new(),
40 extra_roots: Vec::new(),
41 }
42 }
43
44 pub fn new(
45 host: &str,
46 port: u16,
47 use_tls: bool,
48 extra_roots: Vec<Vec<u8>>,
49 ) -> Result<Self, String> {
50 let stream = Connection::connect_nonblocking(host, port)?;
51 Ok(ConnectState {
52 phase: ConnPhase::Connecting,
53 stream: Some(stream),
54 conn: None,
55 use_tls,
56 tls_host: host.to_string(),
57 extra_roots,
58 })
59 }
60
61 pub fn phase(&self) -> ConnPhase {
62 self.phase
63 }
64
65 pub fn fd(&self) -> Option<RawFd> {
66 self.stream
67 .as_ref()
68 .map(|s| s.as_raw_fd())
69 .or_else(|| self.conn.as_ref().map(|c| c.raw_fd()))
70 }
71
72 pub fn wants_read(&self) -> bool {
73 match self.phase {
74 ConnPhase::Connecting => false,
75 ConnPhase::TlsHandshaking => self.conn.as_ref().map_or(false, |c| c.tls_wants_read()),
76 _ => false,
77 }
78 }
79
80 pub fn wants_write(&self) -> bool {
81 match self.phase {
82 ConnPhase::Connecting => true,
83 ConnPhase::TlsHandshaking => self.conn.as_ref().map_or(false, |c| c.tls_wants_write()),
84 _ => false,
85 }
86 }
87
88 pub fn try_advance(&mut self) -> ConnEvent {
89 match self.phase {
90 ConnPhase::Connecting => {
91 let stream = match self.stream.as_ref() {
92 Some(s) => s,
93 None => {
94 self.phase = ConnPhase::Failed;
95 return ConnEvent::Error("no stream".into());
96 }
97 };
98 match Connection::check_connect(stream) {
99 Ok(()) => {
100 let stream = self.stream.take().unwrap();
101 if self.use_tls {
102 match Connection::start_tls(&self.tls_host, stream, &self.extra_roots) {
103 Ok(conn) => {
104 self.conn = Some(conn);
105 self.phase = ConnPhase::TlsHandshaking;
106 self.try_advance()
107 }
108 Err(e) => {
109 self.phase = ConnPhase::Failed;
110 ConnEvent::Error(e)
111 }
112 }
113 } else {
114 let conn = Connection::Plain(stream);
115 self.phase = ConnPhase::Ready;
116 ConnEvent::Connected(conn)
117 }
118 }
119 Err(e) => {
120 self.phase = ConnPhase::Failed;
121 ConnEvent::Error(e)
122 }
123 }
124 }
125 ConnPhase::TlsHandshaking => {
126 let conn = match self.conn.as_mut() {
127 Some(c) => c,
128 None => {
129 self.phase = ConnPhase::Failed;
130 return ConnEvent::Error("no connection for tls".into());
131 }
132 };
133 match conn.tls_handshake_step() {
134 Ok(IoHint::Ready) => {
135 self.phase = ConnPhase::Ready;
136 let conn = self.conn.take().unwrap();
137 ConnEvent::Connected(conn)
138 }
139 Ok(IoHint::Read) => ConnEvent::NeedRead,
140 Ok(IoHint::Write) => ConnEvent::NeedWrite,
141 Ok(IoHint::ReadWrite) => ConnEvent::NeedReadWrite,
142 Err(e) => {
143 self.phase = ConnPhase::Failed;
144 ConnEvent::Error(e)
145 }
146 }
147 }
148 _ => ConnEvent::Error("invalid state for advance".into()),
149 }
150 }
151}