// net_utils/net/conn.rs

//! Client connection supporting both unsecured (plain TCP) and secured (SSL) connections.
2//#[cfg(feature = "ssl")]
3//use std::borrow::ToOwned;
4#[cfg(feature = "ssl")]
5use std::error::Error as StdError;
6//#[cfg(feature = "ssl")]
7use std::io::{ErrorKind, Error};
8#[cfg(feature = "ssl")]
9use std::sync::{Arc, Mutex};
10#[cfg(feature = "ssl")]
11use std::result::Result as StdResult;
12use std::io::{Write, Read, Result, BufReader, BufWriter};
13use std::net::{SocketAddr, ToSocketAddrs, TcpStream};
14#[cfg(test)]
15use std::net::Shutdown;
16use std::os::unix::prelude::AsRawFd;
17
18#[cfg(feature = "ssl")]
19use openssl::ssl::{SslConnectorBuilder, SslMethod, SslStream, SSL_VERIFY_PEER, SSL_VERIFY_NONE};
20#[cfg(feature = "ssl")]
21use openssl::error::ErrorStack;
22#[cfg(feature = "ssl")]
23use openssl::x509;
24use std::str::FromStr;
25// use std::bool;
26use net::config;
27use uuid::Uuid;
28use std::time::Duration;
29// pub mod config;
30
31/// A Connection object.  Make sure you syncronize if uses in multiple threads
32pub struct Connection {
33    id: String,
34    /// BufReader for NetStream (TCP/SSL)
35    pub reader: BufReader<NetStream>,
36    /// BufWriter for NetStream (TCP/SSL)
37    pub writer: BufWriter<NetStream>,
38    /// Config for connection
39    config: config::Config,
40    peer_address: String,
41    local_address: String,
42}
43
44/// Implementation for Connectio
45impl Connection {
46    /// new function to create default Connection object
47    fn new(
48        reader: BufReader<NetStream>,
49        writer: BufWriter<NetStream>,
50        config: &config::Config,
51        peer_address: String,
52        local_address: String,
53    ) -> Connection {
54        Connection {
55            id: Uuid::new_v4().to_urn_string(),
56            reader: reader,
57            writer: writer,
58            config: config.clone(),
59            peer_address: peer_address,
60            local_address: local_address,
61        }
62    }
63
64    pub fn get_peer_address(&self) -> &String {
65        &self.peer_address
66    }
67    pub fn get_local_address(&self) -> &String {
68        &self.local_address
69    }
70
71    /// connection unique id
72
73    /// Creates a  TCP connection to the specified server.
74
75    pub fn connect(config: &config::Config) -> Result<Connection> {
76        if config.use_ssl.unwrap_or(false) {
77            Connection::connect_ssl_internal(config)
78        } else {
79            Connection::connect_internal(config)
80        }
81    }
82
83    /// Creates a  TCP/SSL connection to the specified server.
84    ///If already connected, it will drop and reconnect
85
86    pub fn reconnect(&mut self) -> Result<Connection> {
87        if self.config.use_ssl.unwrap_or(false) {
88            Connection::connect_ssl_internal(&self.config)
89        } else {
90            Connection::connect_internal(&self.config)
91        }
92    }
93
94    /// Get the connection id
95    pub fn id(&self) -> &String {
96        &self.id
97    }
98
99    /// Is Valid connection
100    pub fn is_valid(&self) -> bool {
101        match self.reader.get_ref() {
102            &NetStream::UnsecuredTcpStream(ref tcp) => {
103                debug!("TCP FD:{}", tcp.as_raw_fd());
104                if tcp.as_raw_fd() < 0 { false } else { true }
105            }
106            #[cfg(feature = "ssl")]
107            &NetStream::SslTcpStream(ref ssl) => {
108                let fd = ssl.lock().unwrap().get_ref().as_raw_fd();
109                debug!("SSL FD:{}", fd);
110                if fd < 0 {
111                    return false;
112                } else {
113                    return true;
114                }
115            }
116        }
117    }
118
119    fn host_to_sock_address(host: &str, port: u16) -> Result<SocketAddr> {
120        let server = match (host, port).to_socket_addrs() {
121            Ok(mut host_iter) => {
122                match host_iter.next() {
123                    Some(mut host_addr) => return Ok(host_addr),
124                    None => {
125                        let err_str = format!("Failed to parse {}:{}. ", host, port);
126                        error!("{}", err_str);
127                        return Err(Error::new(ErrorKind::Other, err_str));
128                    }
129                }
130            } 
131            Err(e) => {
132                let err_str = format!("Failed to parse {}:{}. Error:{}", host, port, e);
133                error!("{}", err_str);
134                return Err(Error::new(ErrorKind::Other, err_str));
135            }
136        };
137        let err_str = format!("Failed to parse {}:{}. ", host, port);
138        error!("{}", err_str);
139        return Err(Error::new(ErrorKind::Other, err_str));
140    }
141
142
143    /// Creates a TCP connection with an optional timeout.
144
145    fn connect_internal(config: &config::Config) -> Result<Connection> {
146        let host: &str = &config.server.clone();
147        let port = config.port;
148        error!("Connecting to server {}:{}", host, port);
149        let mut stream_socket;
150
151        let server = try!(Connection::host_to_sock_address(host, port));
152
153        if config.connect_timeout.is_some() {
154            stream_socket = try!(TcpStream::connect_timeout(
155                &server,
156                Duration::from_millis(config.connect_timeout.unwrap()),
157            ));
158        } else {
159            stream_socket = try!(TcpStream::connect(server));
160        }
161        stream_socket.set_nodelay(true);
162        if config.read_timeout.is_some() {
163            stream_socket.set_read_timeout(Some(Duration::from_millis(
164                config.read_timeout.unwrap(),
165            )));
166        }
167        if config.write_timeout.is_some() {
168            stream_socket.set_write_timeout(Some(Duration::from_millis(
169                config.write_timeout.unwrap(),
170            )));
171        }
172
173        let writer_socket = try!(stream_socket.try_clone());
174        let peer_address = match stream_socket.peer_addr() {
175            Ok(sock_addr) => sock_addr.to_string(),
176            Err(_) => String::from(""),
177        };
178        let local_address = match stream_socket.local_addr() {
179            Ok(sock_addr) => sock_addr.to_string(),
180            Err(_) => String::from(""),
181        };
182        Ok(Connection::new(
183            BufReader::new(NetStream::UnsecuredTcpStream(stream_socket)),
184            BufWriter::new(NetStream::UnsecuredTcpStream(writer_socket)),
185            config,
186            peer_address,
187            local_address,
188        ))
189    }
190
191
192
193
194    /// Panics because SSL support was not included at compilation.
195    #[cfg(not(feature = "ssl"))]
196    fn connect_ssl_internal(config: &config::Config) -> Result<Connection> {
197        panic!(
198            "Cannot connect to {}:{} over SSL without compiling with SSL support.",
199            config.server.clone(),
200            config.port
201        )
202    }
203
204    /// Creates a  TCP connection over SSL.
205    #[cfg(feature = "ssl")]
206    fn connect_ssl_internal(config: &config::Config) -> Result<Connection> {
207        let host: &str = &config.server.clone();
208        let port = config.port;
209        info!("Connecting to server {}:{}", host, port);
210
211        let mut socket;
212        let server = try!(Connection::host_to_sock_address(host, port));
213
214        if config.connect_timeout.is_some() {
215            socket = try!(TcpStream::connect_timeout(
216                &server,
217                Duration::from_millis(config.connect_timeout.unwrap()),
218            ));
219        } else {
220            socket = try!(TcpStream::connect(server));
221        }
222        socket.set_nodelay(true);
223
224        let peer_address = match socket.peer_addr() {
225            Ok(sock_addr) => sock_addr.to_string(),
226            Err(_) => String::from(""),
227        };
228        let local_address = match socket.local_addr() {
229            Ok(sock_addr) => sock_addr.to_string(),
230            Err(_) => String::from(""),
231        };
232
233        if config.read_timeout.is_some() {
234            socket.set_read_timeout(Some(Duration::from_millis(config.read_timeout.unwrap())));
235        }
236        if config.write_timeout.is_some() {
237            socket.set_write_timeout(Some(Duration::from_millis(config.write_timeout.unwrap())));
238        }
239
240
241        let mut ssl_connector_builder = SslConnectorBuilder::new(SslMethod::tls()).unwrap();
242        {
243            let ctx = ssl_connector_builder.builder_mut();
244
245            ctx.set_default_verify_paths().unwrap();
246
247            // verify peer
248            if config.verify.unwrap_or(false) {
249                ctx.set_verify(SSL_VERIFY_PEER);
250            } else {
251                ctx.set_verify(SSL_VERIFY_NONE);
252            }
253            // verify depth
254            if config.verify_depth.unwrap_or(0) > 0 {
255                ctx.set_verify_depth(config.verify_depth.unwrap());
256            }
257            if config.certificate_file.is_some() {
258                try!(ssl_to_io(ctx.set_certificate_file(
259                    config.certificate_file.as_ref().unwrap(),
260                    x509::X509_FILETYPE_PEM,
261                )));
262            }
263            if config.private_key_file.is_some() {
264                try!(ssl_to_io(ctx.set_private_key_file(
265                    config.private_key_file.as_ref().unwrap(),
266                    x509::X509_FILETYPE_PEM,
267                )));
268            }
269            if config.ca_file.is_some() {
270                try!(ssl_to_io(ctx.set_ca_file(config.ca_file.as_ref().unwrap())));
271            }
272        }
273        let ssl_connector = ssl_connector_builder.build();
274
275        let stream_socket_result =
276            match ssl_connector.connect(&*format!("{}:{}", host, port), socket) {
277                Ok(s) => s,
278                Err(e) => {
279                    return Err(Error::new(
280                        ErrorKind::Other,
281                        &format!(
282                            "An SSL error occurred. ({}:{})",
283                            e.description(),
284                            e.cause().unwrap()
285                        )
286                            [..],
287                    ));
288                }
289            };
290
291
292
293        let stream_socket = Arc::new(Mutex::new(stream_socket_result));
294        let writer_stream = Arc::clone(&stream_socket);
295
296        Ok(Connection::new(
297            BufReader::new(NetStream::SslTcpStream(stream_socket)),
298            BufWriter::new(NetStream::SslTcpStream(writer_stream)),
299            config,
300            peer_address,
301            local_address,
302        ))
303
304
305
306    }
307}
308
309
/// Converts a `Result<T, ErrorStack>` from OpenSSL into a `std::io::Result<T>`.
///
/// An error becomes `ErrorKind::Other` with the message
/// `"An SSL error occurred. (<description>)"`. This lets SSL setup failures be
/// propagated with `?` alongside ordinary I/O errors.
#[cfg(feature = "ssl")]
fn ssl_to_io<T>(res: StdResult<T, ErrorStack>) -> Result<T> {
    res.map_err(|e| {
        Error::new(
            ErrorKind::Other,
            format!("An SSL error occurred. ({})", e.description()),
        )
    })
}
323
324
325
326
/// The transport beneath a `Connection`: a plain TCP socket or, when built
/// with the `ssl` feature, an SSL session running over TCP.
pub enum NetStream {
    /// Plain, unencrypted TCP socket.
    UnsecuredTcpStream(TcpStream),
    /// SSL session over TCP, held behind `Arc<Mutex<_>>` so that a reader and
    /// a writer can share the single session.
    /// Only present when compiled with the `ssl` feature.
    #[cfg(feature = "ssl")]
    SslTcpStream(Arc<Mutex<SslStream<TcpStream>>>),
}
337// trait Reader {
338//     fn read(&mut self, buf: &mut [u8]) -> Result<usize>;
339// }
340impl Read for NetStream {
341    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
342        match self {
343            &mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.read(buf),
344            #[cfg(feature = "ssl")]
345            &mut NetStream::SslTcpStream(ref mut stream) => stream.lock().unwrap().read(buf),
346        }
347    }
348    fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
349        match self {
350            &mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.read_exact(buf),
351            #[cfg(feature = "ssl")]
352            &mut NetStream::SslTcpStream(ref mut stream) => stream.lock().unwrap().read_exact(buf),
353        }
354    }
355}
356// trait Writer {
357//     fn write(&mut self, buf: &[u8]) -> Result<()>;
358//     fn write_all(&mut self, buf: &[u8]) -> Result<()>;
359// }
360impl Write for NetStream {
361    fn write(&mut self, buf: &[u8]) -> Result<(usize)> {
362        match self {
363            &mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.write(buf),
364            #[cfg(feature = "ssl")]
365            &mut NetStream::SslTcpStream(ref mut stream) => {
366                // Arc::get_mut(stream).unwrap().write(buf)
367                stream.lock().unwrap().write(buf)
368            }
369        }
370
371    }
372    fn write_all(&mut self, buf: &[u8]) -> Result<()> {
373        match self {
374            &mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.write_all(buf),
375            #[cfg(feature = "ssl")]
376            &mut NetStream::SslTcpStream(ref mut stream) => stream.lock().unwrap().write_all(buf),
377        }
378    }
379    fn flush(&mut self) -> Result<()> {
380        match self {
381            &mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.flush(),
382            #[cfg(feature = "ssl")]
383            &mut NetStream::SslTcpStream(ref mut stream) => stream.lock().unwrap().flush(),
384        }
385    }
386}
387
388
/// Test-build teardown: logs the connection id and explicitly shuts down the
/// reader's stream (read and write halves for TCP, SSL `shutdown` for SSL).
///
/// NOTE(review): this impl is `#[cfg(test)]`, so non-test builds rely only on
/// the streams' own drop behaviour. Shutdown results are deliberately ignored.
#[cfg(test)]
impl Drop for Connection {
    fn drop(&mut self) {
        info!("Drop for Connection:Dropping connection id: {}", self.id);
        match *self.reader.get_mut() {
            NetStream::UnsecuredTcpStream(ref mut tcp) => {
                let _ = tcp.shutdown(Shutdown::Read);
                let _ = tcp.shutdown(Shutdown::Write);
            }
            #[cfg(feature = "ssl")]
            NetStream::SslTcpStream(ref mut shared) => {
                let _ = shared.lock().unwrap().shutdown();
            }
        }
    }
}