1use std::io::{Read, Write};
2use std::net::TcpStream;
3use std::sync::Arc;
4
5use log::info;
6use rand::RngCore;
7use rustls::client::{ServerCertVerified, ServerCertVerifier};
8use rustls::{Certificate, ClientConfig, ClientConnection, ServerName, StreamOwned};
9
10use crate::error::{ensure, Error, Result};
11use crate::util::{
12 duration_from_timeout_ms, put_u16, put_u32, read_u16, read_u32, resolve_socket_addrs, Bytes,
13};
14
/// Length of a fixed-size cell body: everything after the 4-byte circuit ID and the
/// 1-byte command (`TorChannel` pads fixed cells to exactly this many bytes).
pub const K_CELL_BODY_LEN: usize = 509;
/// Length of the relay header at the start of a relay cell body. Per tor-spec the
/// fields are: relay command (1), recognized (2), stream ID (2), digest (4), length (2).
pub const K_RELAY_HEADER_LEN: usize = 1 + 2 + 2 + 4 + 2;
/// Maximum relay payload that fits in one cell once the relay header is accounted for.
pub const K_RELAY_PAYLOAD_LEN: usize = K_CELL_BODY_LEN - K_RELAY_HEADER_LEN;
18
// Link-level cell command codes: the byte that follows the circuit ID in every cell.
// Values follow tor-spec.

/// `RELAY` cell: carries an onion-encrypted relay message along a circuit.
pub const CMD_RELAY: u8 = 3;
/// `DESTROY` cell: tears a circuit down.
pub const CMD_DESTROY: u8 = 4;
/// `CREATE_FAST` cell: builds a first hop without a public-key handshake.
pub const CMD_CREATE_FAST: u8 = 5;
/// `CREATED_FAST` cell: the relay's answer to `CREATE_FAST`.
pub const CMD_CREATED_FAST: u8 = 6;
/// `VERSIONS` cell: link-protocol negotiation; always encoded as a variable-length cell.
pub const CMD_VERSIONS: u8 = 7;
/// `NETINFO` cell: time/address exchange that finishes the link handshake.
pub const CMD_NETINFO: u8 = 8;
/// `RELAY_EARLY` cell: a relay cell that counts against the per-circuit early-cell limit.
pub const CMD_RELAY_EARLY: u8 = 9;
// NOTE(review): the two constants below carry `dead_code` allows, presumably because
// nothing in the crate uses them yet; they are kept so the command table is complete.
/// `CREATE2` cell: builds a hop using an extensible handshake.
#[allow(dead_code)]
pub const CMD_CREATE2: u8 = 10;
/// `CREATED2` cell: the relay's answer to `CREATE2`.
#[allow(dead_code)]
pub const CMD_CREATED2: u8 = 11;
30
// Relay-message command codes, carried in the first byte of a relay header inside
// RELAY / RELAY_EARLY cell bodies. Values follow tor-spec (1..=15) and rend-spec (33..=40).

/// Open a new stream to a target via the exit.
pub const RELAY_BEGIN: u8 = 1;
/// Stream payload bytes.
pub const RELAY_DATA: u8 = 2;
/// Close a stream.
pub const RELAY_END: u8 = 3;
/// The exit's acknowledgement that a `RELAY_BEGIN` succeeded.
pub const RELAY_CONNECTED: u8 = 4;
/// Flow-control window increment.
pub const RELAY_SENDME: u8 = 5;
/// Open a stream to the relay's own directory service.
pub const RELAY_BEGIN_DIR: u8 = 13;
/// Ask the last hop to extend the circuit by one more relay.
pub const RELAY_EXTEND2: u8 = 14;
/// Reply to `RELAY_EXTEND2`.
pub const RELAY_EXTENDED2: u8 = 15;
/// Ask a relay to act as the rendezvous point for an onion-service connection.
pub const RELAY_ESTABLISH_RENDEZVOUS: u8 = 33;
/// Introduction request sent to an onion service's introduction point.
pub const RELAY_INTRODUCE1: u8 = 34;
/// Rendezvous point forwarding the service's handshake back to the client.
pub const RELAY_RENDEZVOUS2: u8 = 37;
/// Acknowledgement of `RELAY_ESTABLISH_RENDEZVOUS`.
pub const RELAY_RENDEZVOUS_ESTABLISHED: u8 = 39;
/// Introduction point's acknowledgement of `RELAY_INTRODUCE1`.
pub const RELAY_INTRODUCE_ACK: u8 = 40;
44
45#[derive(Clone, Debug, Eq, PartialEq)]
46pub struct Cell {
47 pub circ_id: u32,
48 pub cmd: u8,
49 pub body: Bytes,
50}
51
52#[derive(Debug)]
53struct NoCertificateVerification;
54
55impl ServerCertVerifier for NoCertificateVerification {
56 fn verify_server_cert(
57 &self,
58 _end_entity: &Certificate,
59 _intermediates: &[Certificate],
60 _server_name: &ServerName,
61 _scts: &mut dyn Iterator<Item = &[u8]>,
62 _ocsp_response: &[u8],
63 _now: std::time::SystemTime,
64 ) -> std::result::Result<ServerCertVerified, rustls::Error> {
65 Ok(ServerCertVerified::assertion())
66 }
67}
68
/// Opens a TCP connection to `host:port`, trying each resolved address in order.
///
/// `timeout_ms` is converted by `duration_from_timeout_ms`. It is used both as the
/// per-address connect timeout and as the read and write timeout of the returned stream.
///
/// # Errors
/// Fails if the timeout is rejected, if resolution fails, or if setting the timeouts on
/// a freshly connected stream fails. In that last case the remaining addresses are not
/// tried. Also fails if no address accepts a connection; the message then carries the
/// last connect error, or "no address found" when resolution yielded nothing.
pub fn connect_tcp(host: &str, port: u16, timeout_ms: i32) -> Result<TcpStream> {
    let timeout = duration_from_timeout_ms(timeout_ms)?;
    let candidates = resolve_socket_addrs(host, port)
        .map_err(|e| Error::new(format!("getaddrinfo failed for {host}: {e}")))?;

    let mut last_error = String::from("no address found");
    for candidate in candidates {
        let stream = match TcpStream::connect_timeout(&candidate, timeout) {
            Ok(stream) => stream,
            Err(e) => {
                // Remember why this address failed and move on to the next one.
                last_error = e.to_string();
                continue;
            }
        };
        stream.set_read_timeout(Some(timeout))?;
        stream.set_write_timeout(Some(timeout))?;
        return Ok(stream);
    }

    Err(Error::new(format!(
        "tcp connect failed to {host}:{port}: {last_error}"
    )))
}
88
89pub fn write_all_fd(stream: &mut TcpStream, data: &[u8]) -> Result<()> {
90 stream.write_all(data)?;
91 Ok(())
92}
93
94pub fn read_all_fd(stream: &mut TcpStream, limit: usize) -> Result<Bytes> {
95 let mut out = Bytes::new();
96 let mut buf = [0u8; 8192];
97 while out.len() < limit {
98 let n = stream.read(&mut buf)?;
99 if n == 0 {
100 break;
101 }
102 out.extend_from_slice(&buf[..n]);
103 }
104 Ok(out)
105}
106
107pub struct TorChannel {
108 tls: StreamOwned<ClientConnection, TcpStream>,
109 link_version: i32,
110}
111
112impl TorChannel {
113 pub fn new(host: String, port: u16, timeout_ms: i32) -> Result<Self> {
114 info!("opening Tor TLS channel to {host}:{port}");
115 let tls = Self::init_tls(&host, port, timeout_ms)?;
116 let mut ch = Self {
117 tls,
118 link_version: 4,
119 };
120 ch.negotiate()?;
121 Ok(ch)
122 }
123
124 pub fn write_cell(&mut self, circ_id: u32, cmd: u8, body: &[u8]) -> Result<()> {
125 let mut out = Bytes::new();
126 if cmd == CMD_VERSIONS {
127 put_u16(&mut out, 0);
128 out.push(cmd);
129 put_u16(&mut out, body.len() as u16);
130 out.extend_from_slice(body);
131 } else if cmd >= 128 {
132 put_u32(&mut out, circ_id);
133 out.push(cmd);
134 put_u16(&mut out, body.len() as u16);
135 out.extend_from_slice(body);
136 } else {
137 ensure(body.len() <= K_CELL_BODY_LEN, "fixed cell body too large")?;
138 put_u32(&mut out, circ_id);
139 out.push(cmd);
140 out.extend_from_slice(body);
141 out.resize(4 + 1 + K_CELL_BODY_LEN, 0);
142 }
143 self.tls.write_all(&out)?;
144 self.tls.flush()?;
145 Ok(())
146 }
147
148 pub fn read_cell(&mut self) -> Result<Cell> {
149 self.read_cell_with_circ_len(4)
150 }
151
152 pub fn new_circ_id(&self) -> u32 {
153 let mut id = rand::rngs::OsRng.next_u32() | 0x8000_0000;
154 if id == 0 {
155 id = 0x8000_0001;
156 }
157 id
158 }
159
160 fn init_tls(
161 host: &str,
162 port: u16,
163 timeout_ms: i32,
164 ) -> Result<StreamOwned<ClientConnection, TcpStream>> {
165 let stream = connect_tcp(host, port, timeout_ms)?;
166 let verifier = Arc::new(NoCertificateVerification);
167 let config = ClientConfig::builder()
168 .with_safe_defaults()
169 .with_custom_certificate_verifier(verifier)
170 .with_no_client_auth();
171 let server_name = ServerName::try_from("ignored.invalid")
172 .map_err(|_| Error::new("bad TLS server name"))?;
173 let conn = ClientConnection::new(Arc::new(config), server_name)?;
174 Ok(StreamOwned::new(conn, stream))
175 }
176
177 fn negotiate(&mut self) -> Result<()> {
178 let mut versions = Bytes::new();
179 put_u16(&mut versions, 4);
180 put_u16(&mut versions, 5);
181 let mut out = Bytes::new();
182 put_u16(&mut out, 0);
183 out.push(CMD_VERSIONS);
184 put_u16(&mut out, versions.len() as u16);
185 out.extend_from_slice(&versions);
186 self.tls.write_all(&out)?;
187 self.tls.flush()?;
188
189 let v = self.read_cell_with_circ_len(2)?;
190 ensure(v.cmd == CMD_VERSIONS, "first Tor cell was not VERSIONS")?;
191 let mut best = 0;
192 let mut i = 0;
193 while i + 1 < v.body.len() {
194 let peer = ((v.body[i] as i32) << 8) | v.body[i + 1] as i32;
195 if (peer == 4 || peer == 5) && peer > best {
196 best = peer;
197 }
198 i += 2;
199 }
200 ensure(best >= 4, "relay does not support link protocol 4+")?;
201 self.link_version = best;
202 info!("negotiated Tor link protocol v{best}");
203 let mut got_netinfo = false;
204 for _ in 0..16 {
205 let c = self.read_cell()?;
206 if c.cmd == CMD_NETINFO {
207 got_netinfo = true;
208 break;
209 }
210 }
211 ensure(got_netinfo, "relay did not send NETINFO")?;
212 self.send_netinfo()
213 }
214
215 fn send_netinfo(&mut self) -> Result<()> {
216 let mut body = vec![0; K_CELL_BODY_LEN];
217 let now = std::time::SystemTime::now()
218 .duration_since(std::time::UNIX_EPOCH)
219 .unwrap_or_default()
220 .as_secs() as u32;
221 body[0..4].copy_from_slice(&now.to_be_bytes());
222 body[4] = 4;
223 body[5] = 4;
224 body[10] = 0;
225 self.write_cell(0, CMD_NETINFO, &body)
226 }
227
228 fn read_cell_with_circ_len(&mut self, circ_len: usize) -> Result<Cell> {
229 let mut header = vec![0; circ_len + 1];
230 self.tls.read_exact(&mut header)?;
231 let circ_id = if circ_len == 2 {
232 read_u16(&header, 0)? as u32
233 } else {
234 read_u32(&header, 0)?
235 };
236 let cmd = header[circ_len];
237 let variable = cmd == CMD_VERSIONS || cmd >= 128;
238 let body = if variable {
239 let mut lenb = [0u8; 2];
240 self.tls.read_exact(&mut lenb)?;
241 let len = read_u16(&lenb, 0)? as usize;
242 let mut body = vec![0; len];
243 self.tls.read_exact(&mut body)?;
244 body
245 } else {
246 let mut body = vec![0; K_CELL_BODY_LEN];
247 self.tls.read_exact(&mut body)?;
248 body
249 };
250 Ok(Cell { circ_id, cmd, body })
251 }
252}