1#[cfg(feature = "ssl")]
5use std::error::Error as StdError;
6use std::io::{ErrorKind, Error};
8#[cfg(feature = "ssl")]
9use std::sync::{Arc, Mutex};
10#[cfg(feature = "ssl")]
11use std::result::Result as StdResult;
12use std::io::{Write, Read, Result, BufReader, BufWriter};
13use std::net::{SocketAddr, ToSocketAddrs, TcpStream};
14#[cfg(test)]
15use std::net::Shutdown;
16use std::os::unix::prelude::AsRawFd;
17
18#[cfg(feature = "ssl")]
19use openssl::ssl::{SslConnectorBuilder, SslMethod, SslStream, SSL_VERIFY_PEER, SSL_VERIFY_NONE};
20#[cfg(feature = "ssl")]
21use openssl::error::ErrorStack;
22#[cfg(feature = "ssl")]
23use openssl::x509;
24use std::str::FromStr;
25use net::config;
27use uuid::Uuid;
28use std::time::Duration;
/// A client connection to the configured server: a buffered reader/writer pair
/// over either a plain TCP socket or (with the `ssl` feature) a TLS stream.
pub struct Connection {
    /// Identifier generated when the connection is built (a UUID in URN form).
    id: String,
    /// Buffered read side of the transport.
    pub reader: BufReader<NetStream>,
    /// Buffered write side of the transport.
    pub writer: BufWriter<NetStream>,
    /// Copy of the configuration the connection was opened with; `reconnect` reuses it.
    config: config::Config,
    /// Remote socket address as text, or an empty string if it could not be read.
    peer_address: String,
    /// Local socket address as text, or an empty string if it could not be read.
    local_address: String,
}
43
44impl Connection {
46 fn new(
48 reader: BufReader<NetStream>,
49 writer: BufWriter<NetStream>,
50 config: &config::Config,
51 peer_address: String,
52 local_address: String,
53 ) -> Connection {
54 Connection {
55 id: Uuid::new_v4().to_urn_string(),
56 reader: reader,
57 writer: writer,
58 config: config.clone(),
59 peer_address: peer_address,
60 local_address: local_address,
61 }
62 }
63
64 pub fn get_peer_address(&self) -> &String {
65 &self.peer_address
66 }
67 pub fn get_local_address(&self) -> &String {
68 &self.local_address
69 }
70
71 pub fn connect(config: &config::Config) -> Result<Connection> {
76 if config.use_ssl.unwrap_or(false) {
77 Connection::connect_ssl_internal(config)
78 } else {
79 Connection::connect_internal(config)
80 }
81 }
82
83 pub fn reconnect(&mut self) -> Result<Connection> {
87 if self.config.use_ssl.unwrap_or(false) {
88 Connection::connect_ssl_internal(&self.config)
89 } else {
90 Connection::connect_internal(&self.config)
91 }
92 }
93
94 pub fn id(&self) -> &String {
96 &self.id
97 }
98
99 pub fn is_valid(&self) -> bool {
101 match self.reader.get_ref() {
102 &NetStream::UnsecuredTcpStream(ref tcp) => {
103 debug!("TCP FD:{}", tcp.as_raw_fd());
104 if tcp.as_raw_fd() < 0 { false } else { true }
105 }
106 #[cfg(feature = "ssl")]
107 &NetStream::SslTcpStream(ref ssl) => {
108 let fd = ssl.lock().unwrap().get_ref().as_raw_fd();
109 debug!("SSL FD:{}", fd);
110 if fd < 0 {
111 return false;
112 } else {
113 return true;
114 }
115 }
116 }
117 }
118
119 fn host_to_sock_address(host: &str, port: u16) -> Result<SocketAddr> {
120 let server = match (host, port).to_socket_addrs() {
121 Ok(mut host_iter) => {
122 match host_iter.next() {
123 Some(mut host_addr) => return Ok(host_addr),
124 None => {
125 let err_str = format!("Failed to parse {}:{}. ", host, port);
126 error!("{}", err_str);
127 return Err(Error::new(ErrorKind::Other, err_str));
128 }
129 }
130 }
131 Err(e) => {
132 let err_str = format!("Failed to parse {}:{}. Error:{}", host, port, e);
133 error!("{}", err_str);
134 return Err(Error::new(ErrorKind::Other, err_str));
135 }
136 };
137 let err_str = format!("Failed to parse {}:{}. ", host, port);
138 error!("{}", err_str);
139 return Err(Error::new(ErrorKind::Other, err_str));
140 }
141
142
143 fn connect_internal(config: &config::Config) -> Result<Connection> {
146 let host: &str = &config.server.clone();
147 let port = config.port;
148 error!("Connecting to server {}:{}", host, port);
149 let mut stream_socket;
150
151 let server = try!(Connection::host_to_sock_address(host, port));
152
153 if config.connect_timeout.is_some() {
154 stream_socket = try!(TcpStream::connect_timeout(
155 &server,
156 Duration::from_millis(config.connect_timeout.unwrap()),
157 ));
158 } else {
159 stream_socket = try!(TcpStream::connect(server));
160 }
161 stream_socket.set_nodelay(true);
162 if config.read_timeout.is_some() {
163 stream_socket.set_read_timeout(Some(Duration::from_millis(
164 config.read_timeout.unwrap(),
165 )));
166 }
167 if config.write_timeout.is_some() {
168 stream_socket.set_write_timeout(Some(Duration::from_millis(
169 config.write_timeout.unwrap(),
170 )));
171 }
172
173 let writer_socket = try!(stream_socket.try_clone());
174 let peer_address = match stream_socket.peer_addr() {
175 Ok(sock_addr) => sock_addr.to_string(),
176 Err(_) => String::from(""),
177 };
178 let local_address = match stream_socket.local_addr() {
179 Ok(sock_addr) => sock_addr.to_string(),
180 Err(_) => String::from(""),
181 };
182 Ok(Connection::new(
183 BufReader::new(NetStream::UnsecuredTcpStream(stream_socket)),
184 BufWriter::new(NetStream::UnsecuredTcpStream(writer_socket)),
185 config,
186 peer_address,
187 local_address,
188 ))
189 }
190
191
192
193
194 #[cfg(not(feature = "ssl"))]
196 fn connect_ssl_internal(config: &config::Config) -> Result<Connection> {
197 panic!(
198 "Cannot connect to {}:{} over SSL without compiling with SSL support.",
199 config.server.clone(),
200 config.port
201 )
202 }
203
204 #[cfg(feature = "ssl")]
206 fn connect_ssl_internal(config: &config::Config) -> Result<Connection> {
207 let host: &str = &config.server.clone();
208 let port = config.port;
209 info!("Connecting to server {}:{}", host, port);
210
211 let mut socket;
212 let server = try!(Connection::host_to_sock_address(host, port));
213
214 if config.connect_timeout.is_some() {
215 socket = try!(TcpStream::connect_timeout(
216 &server,
217 Duration::from_millis(config.connect_timeout.unwrap()),
218 ));
219 } else {
220 socket = try!(TcpStream::connect(server));
221 }
222 socket.set_nodelay(true);
223
224 let peer_address = match socket.peer_addr() {
225 Ok(sock_addr) => sock_addr.to_string(),
226 Err(_) => String::from(""),
227 };
228 let local_address = match socket.local_addr() {
229 Ok(sock_addr) => sock_addr.to_string(),
230 Err(_) => String::from(""),
231 };
232
233 if config.read_timeout.is_some() {
234 socket.set_read_timeout(Some(Duration::from_millis(config.read_timeout.unwrap())));
235 }
236 if config.write_timeout.is_some() {
237 socket.set_write_timeout(Some(Duration::from_millis(config.write_timeout.unwrap())));
238 }
239
240
241 let mut ssl_connector_builder = SslConnectorBuilder::new(SslMethod::tls()).unwrap();
242 {
243 let ctx = ssl_connector_builder.builder_mut();
244
245 ctx.set_default_verify_paths().unwrap();
246
247 if config.verify.unwrap_or(false) {
249 ctx.set_verify(SSL_VERIFY_PEER);
250 } else {
251 ctx.set_verify(SSL_VERIFY_NONE);
252 }
253 if config.verify_depth.unwrap_or(0) > 0 {
255 ctx.set_verify_depth(config.verify_depth.unwrap());
256 }
257 if config.certificate_file.is_some() {
258 try!(ssl_to_io(ctx.set_certificate_file(
259 config.certificate_file.as_ref().unwrap(),
260 x509::X509_FILETYPE_PEM,
261 )));
262 }
263 if config.private_key_file.is_some() {
264 try!(ssl_to_io(ctx.set_private_key_file(
265 config.private_key_file.as_ref().unwrap(),
266 x509::X509_FILETYPE_PEM,
267 )));
268 }
269 if config.ca_file.is_some() {
270 try!(ssl_to_io(ctx.set_ca_file(config.ca_file.as_ref().unwrap())));
271 }
272 }
273 let ssl_connector = ssl_connector_builder.build();
274
275 let stream_socket_result =
276 match ssl_connector.connect(&*format!("{}:{}", host, port), socket) {
277 Ok(s) => s,
278 Err(e) => {
279 return Err(Error::new(
280 ErrorKind::Other,
281 &format!(
282 "An SSL error occurred. ({}:{})",
283 e.description(),
284 e.cause().unwrap()
285 )
286 [..],
287 ));
288 }
289 };
290
291
292
293 let stream_socket = Arc::new(Mutex::new(stream_socket_result));
294 let writer_stream = Arc::clone(&stream_socket);
295
296 Ok(Connection::new(
297 BufReader::new(NetStream::SslTcpStream(stream_socket)),
298 BufWriter::new(NetStream::SslTcpStream(writer_stream)),
299 config,
300 peer_address,
301 local_address,
302 ))
303
304
305
306 }
307}
308
309
/// Converts an OpenSSL `ErrorStack` result into an `std::io::Result`, so SSL setup
/// failures can be propagated with `?` alongside ordinary socket errors.
///
/// The message embeds the stack's `Display` output so the individual OpenSSL
/// errors are preserved in the resulting `io::Error` (kind `Other`).
#[cfg(feature = "ssl")]
fn ssl_to_io<T>(res: StdResult<T, ErrorStack>) -> Result<T> {
    res.map_err(|e| {
        Error::new(
            ErrorKind::Other,
            format!("An SSL error occurred. ({})", e),
        )
    })
}
323
324
325
326
/// Transport underneath a `Connection`'s reader and writer.
pub enum NetStream {
    /// Plain, unencrypted TCP socket.
    UnsecuredTcpStream(TcpStream),
    /// TLS stream over TCP. Held in an `Arc<Mutex<_>>` so the reader and writer
    /// halves of a `Connection` can share the one session.
    #[cfg(feature = "ssl")]
    SslTcpStream(Arc<Mutex<SslStream<TcpStream>>>),
}
337impl Read for NetStream {
341 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
342 match self {
343 &mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.read(buf),
344 #[cfg(feature = "ssl")]
345 &mut NetStream::SslTcpStream(ref mut stream) => stream.lock().unwrap().read(buf),
346 }
347 }
348 fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
349 match self {
350 &mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.read_exact(buf),
351 #[cfg(feature = "ssl")]
352 &mut NetStream::SslTcpStream(ref mut stream) => stream.lock().unwrap().read_exact(buf),
353 }
354 }
355}
356impl Write for NetStream {
361 fn write(&mut self, buf: &[u8]) -> Result<(usize)> {
362 match self {
363 &mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.write(buf),
364 #[cfg(feature = "ssl")]
365 &mut NetStream::SslTcpStream(ref mut stream) => {
366 stream.lock().unwrap().write(buf)
368 }
369 }
370
371 }
372 fn write_all(&mut self, buf: &[u8]) -> Result<()> {
373 match self {
374 &mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.write_all(buf),
375 #[cfg(feature = "ssl")]
376 &mut NetStream::SslTcpStream(ref mut stream) => stream.lock().unwrap().write_all(buf),
377 }
378 }
379 fn flush(&mut self) -> Result<()> {
380 match self {
381 &mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.flush(),
382 #[cfg(feature = "ssl")]
383 &mut NetStream::SslTcpStream(ref mut stream) => stream.lock().unwrap().flush(),
384 }
385 }
386}
387
388
/// Test-build teardown: flushes buffered output, then shuts the transport down.
///
/// Every step is best-effort and its error is discarded, since `drop` has no way
/// to report failures.
#[cfg(test)]
impl Drop for Connection {
    fn drop(&mut self) {
        info!(
            "Drop for Connection:Dropping connection id: {}",
            self.id
        );
        // Flush before shutting down: the writer shares the socket with the reader,
        // so after the shutdown below the BufWriter's own flush-on-drop could no
        // longer deliver its buffered bytes.
        let _ = self.writer.flush();
        match *self.reader.get_ref() {
            NetStream::UnsecuredTcpStream(ref stream) => {
                let _ = stream.shutdown(Shutdown::Both);
            }
            #[cfg(feature = "ssl")]
            NetStream::SslTcpStream(ref ssl) => {
                // Skip the TLS shutdown if the mutex is poisoned: panicking inside
                // `drop` while already unwinding would abort the process.
                if let Ok(mut stream) = ssl.lock() {
                    let _ = stream.shutdown();
                }
            }
        }
    }
}