1use super::{
17 error::Error,
18 routes::{Route, Routes},
19};
20
21use std::{
22 io::{self, Read, Write},
23 net::{self, TcpStream},
24 process::{Child, Command, Stdio},
25};
26
/// A bidirectional byte channel to a remote host.
///
/// The connection is either a plain TCP stream or a spawned `ssh` client
/// process whose stdin/stdout are used as the transport.
#[derive(Debug)]
pub enum Socket {
    /// A direct TCP connection.
    TcpStream(TcpStream),
    /// A running `ssh` child process acting as a tunnel.
    SshClient(Child),
}
32
33impl From<TcpStream> for Socket {
34 fn from(s: TcpStream) -> Self {
35 Socket::TcpStream(s)
36 }
37}
38
39impl Socket {
40 pub fn borrow_mut_read(&mut self) -> &mut dyn Read {
41 match *self {
42 Socket::TcpStream(ref mut s) => s,
43 Socket::SshClient(ref mut p) => {
44 p.stdout.as_mut().expect("child process's stdout is piped")
45 }
46 }
47 }
48
49 pub fn borrow_mut_write(&mut self) -> &mut dyn Write {
50 match *self {
51 Socket::TcpStream(ref mut s) => s,
52 Socket::SshClient(ref mut p) => {
53 p.stdin.as_mut().expect("child process's stdin is piped")
54 }
55 }
56 }
57}
58
59impl Drop for Socket {
60 fn drop(&mut self) {
61 match *self {
62 Socket::TcpStream(ref mut s) => {
63 let _ignore = s.shutdown(net::Shutdown::Both);
64 }
65 Socket::SshClient(ref mut p) => {
66 if let Ok(Some(_)) = p.try_wait() {
67 } else {
69 let _ = p.kill();
71 }
72 }
73 }
74 }
75}
76
77pub fn connect(mut routes: Routes) -> Result<Socket, Error> {
78 let first_route = routes.next().expect("there is at least one route");
79 match connect_impl(&first_route) {
80 Ok(socket) => Ok(socket),
81 Err(first_err) if first_err.kind() == io::ErrorKind::ConnectionRefused => {
82 for route in routes {
83 match connect_impl(&route) {
84 Ok(socket) => {
85 return Ok(socket);
86 }
87 Err(e) if e.kind() == io::ErrorKind::ConnectionRefused => {
88 continue;
89 }
90 Err(err) => return Err(Error::FailedToConnectToHost(err)),
91 }
92 }
93 Err(Error::FailedToConnectToHost(first_err))
94 }
95 Err(err) => Err(Error::FailedToConnectToHost(err)),
96 }
97}
98
99fn connect_impl(route: &Route) -> Result<Socket, io::Error> {
100 use Route::*;
101 match *route {
102 Direct(ip) => {
103 let s = TcpStream::connect(ip)?;
104 s.set_nodelay(true)?;
105 Ok(Socket::from(s))
106 }
107 Tunneled(ref opts) => {
108 let mut cmd = Command::new("ssh");
109 cmd
110 .stdin(Stdio::piped())
111 .stdout(Stdio::piped())
112 .arg("-x")
113 .arg("-N")
114 .arg("-T")
115 .arg("-o")
116 .arg("ExitOnForwardFailure=yes")
117 .arg("-o")
118 .arg("ClearAllForwardings=yes")
119 .arg("-o")
120 .arg("ConnectTimeout=5")
121 .arg("-W")
122 .arg(format!("{}:{}", opts.host_addr, opts.host_port));
123 if let Some(ref user) = opts.ssh_user {
124 cmd.arg("-l").arg(user);
125 }
126 if let Some(ref port) = opts.ssh_port {
127 cmd.arg("-p").arg(port.to_string());
128 }
129 cmd.arg(opts.ssh_addr.to_string());
130 Ok(Socket::SshClient(cmd.spawn()?))
143 }
144 }
145}