1#[macro_use]
2extern crate log;
3
4pub mod channel;
5#[cfg(test)]
6mod syncbuf;
7
8use std::io::{Read, Write, Result};
9use std::io::{Error, ErrorKind};
10use std::net::Shutdown as ShutdownMode;
11use std::ops::Deref;
12
/// Capacity of the byte buffer `read_length` uses to collect the decimal
/// length prefix of a netstring (the bytes before the `:`).
const DIGIT_LIMIT: usize = 64;
15
/// Reads netstring-framed messages of the form `<decimal length>:<payload>,`.
///
/// A blanket implementation in this file covers every `Read + Shutdown` type.
pub trait ReadNetstring: Shutdown {
    /// Reads the next netstring and returns its payload decoded as UTF-8.
    fn read_netstring(&mut self) -> Result<String>;
}
19
/// Writes netstring-framed messages of the form `<byte length>:<payload>,`.
///
/// A blanket implementation in this file covers every `Write + Shutdown` type.
pub trait WriteNetstring: Shutdown {
    /// Encodes `value` as one netstring and writes it to the underlying sink.
    fn write_netstring<S: AsRef<str>>(&mut self, value: S) -> Result<()>;
    /// Flushes the underlying writer.
    fn flush(&mut self) -> Result<()>;
}
24
/// Shuts down the read half, write half, or both halves of a connection.
///
/// Mirrors the inherent `shutdown` of `TcpStream` and `UnixStream`; the
/// in-memory types in this file (`&[u8]`, `Vec<T>`, `io::Sink`) implement it
/// as a no-op so they can be used wherever a netstring stream is expected.
pub trait Shutdown {
    fn shutdown(&self, how: ShutdownMode) -> Result<()>;
}
28
29impl<R: Read + Shutdown> ReadNetstring for R {
30 fn read_netstring(&mut self) -> Result<String> {
31 let ln = try!(read_length(self));
32 let mut data = vec![0u8;ln];
33
34 let mut offset = 0usize;
35 let mut done = false;
36
37 while !done {
38 let r = try!(self.read(data[offset..].as_mut()));
39 offset = offset + r;
40 if r == 0 || offset == ln {
41 done = true;
42 }
43 }
44
45 let mut t = vec![0u8].into_boxed_slice();
48 try!(self.read(t[..].as_mut()));
49
50 if t[0] != b',' {
52 return Err(Error::new(ErrorKind::InvalidData, "Expected `,` delimiter."));
53 }
54
55 match String::from_utf8(data) {
57 Ok(s) => Ok(s),
58 Err(err) => Err(Error::new(ErrorKind::InvalidData, err)),
59 }
60 }
61}
62
63impl<W: Write + Shutdown> WriteNetstring for W {
64 fn write_netstring<S: AsRef<str>>(&mut self, value: S) -> Result<()> {
65 let value = value.as_ref();
66 let s = format!("{}:{},", value.len(), value);
67 try!(self.write_all(s.as_bytes()));
68 Ok(())
69 }
70
71 fn flush(&mut self) -> Result<()> {
72 Write::flush(self)
73 }
74}
75
76impl Shutdown for ::std::os::unix::net::UnixStream {
77 fn shutdown(&self, how: ShutdownMode) -> Result<()> {
78 ::std::os::unix::net::UnixStream::shutdown(self, how)
79 }
80}
81
/// Delegates to the inherent `TcpStream::shutdown`. The qualified path
/// resolves to the inherent method, so this does not recurse into the trait.
impl Shutdown for ::std::net::TcpStream {
    fn shutdown(&self, how: ShutdownMode) -> Result<()> {
        ::std::net::TcpStream::shutdown(self, how)
    }
}
87
88impl<T: Shutdown> Shutdown for Box<T> {
89 fn shutdown(&self, how: ShutdownMode) -> Result<()> {
90 self.deref().shutdown(how)
91 }
92}
93
/// No-op: a byte slice has no connection to shut down. This lets `&[u8]` act
/// as an in-memory netstring source (the tests in this file read from one).
impl<'a> Shutdown for &'a [u8] {
    fn shutdown(&self, _how: ShutdownMode) -> Result<()> {
        Ok(())
    }
}
99
/// No-op: a vector has no connection to shut down. This lets `Vec<u8>` act as
/// an in-memory netstring sink (the tests in this file write into one).
impl<T> Shutdown for Vec<T> {
    fn shutdown(&self, _how: ShutdownMode) -> Result<()> {
        Ok(())
    }
}
105
/// No-op: `io::Sink` discards everything written to it and has no connection
/// to shut down.
impl Shutdown for ::std::io::Sink {
    fn shutdown(&self, _how: ShutdownMode) -> Result<()> {
        Ok(())
    }
}
111
112fn read_length<R: Read>(r: &mut R) -> Result<usize> {
113 let mut t = [0u8; DIGIT_LIMIT];
114 let mut current = 0usize;
115 let mut done = false;
116 while !done {
117 let r = try!(r.read(t[current..current + 1].as_mut()));
118 if r == 0 {
119 return Err(Error::new(ErrorKind::ConnectionAborted, "Connection closed by target."));
120 }
121 if t[current] == b':' {
123 done = true;
124 } else {
125 current += 1;
126 }
127 }
128
129 let s = match String::from_utf8(t[..current].to_vec()) {
130 Ok(s) => s,
131 Err(err) => return Err(Error::new(ErrorKind::InvalidData, err)),
132 };
133
134 let ln = match s.parse::<u64>() {
135 Ok(x) => x,
136 Err(err) => return Err(Error::new(ErrorKind::InvalidData, err)),
137 };
138
139
140 Ok(ln as usize)
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::ErrorKind;

    /// Decodes one netstring from `input` and asserts that decoding fails with
    /// `InvalidData` carrying exactly `message`.
    fn assert_invalid_data(input: &[u8], message: &str) {
        let mut raw = input;
        let err = raw.read_netstring().unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(err.to_string(), message);
    }

    #[test]
    fn basic_read() {
        let mut raw = "5:hello,".as_bytes();
        let x = raw.read_netstring().unwrap();
        assert_eq!("hello", x);
    }

    #[test]
    fn basic_write() {
        let mut raw: Vec<u8> = Vec::new();
        raw.write_netstring("hello").unwrap();
        assert_eq!(raw, b"5:hello,");
    }

    #[test]
    fn invalid_delimiter() {
        assert_invalid_data(b"5:hello?", "Expected `,` delimiter.");
    }

    #[test]
    fn longer() {
        // Declared length exceeds the available payload.
        assert_invalid_data(b"10:hello,", "Expected `,` delimiter.");
    }

    #[test]
    fn shorter() {
        // Declared length is shorter than the payload, so `l` sits where `,` belongs.
        assert_invalid_data(b"2:hello,", "Expected `,` delimiter.");
    }

    #[test]
    fn multiple() {
        let mut raw = "5:hello,5:world,10:xxxxxxxxxx,".as_bytes();
        let x1 = raw.read_netstring().unwrap();
        let x2 = raw.read_netstring().unwrap();
        let x3 = raw.read_netstring().unwrap();
        assert_eq!(x1, "hello");
        assert_eq!(x2, "world");
        assert_eq!(x3, "xxxxxxxxxx");
    }

    #[test]
    fn empty_payload() {
        let mut raw = "0:,".as_bytes();
        assert_eq!(raw.read_netstring().unwrap(), "");
    }

    #[test]
    fn invalid_utf8() {
        let mut raw: &[u8] = b"2:\xff\xfe,";
        assert_eq!(raw.read_netstring().unwrap_err().kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn eof_inside_length_prefix() {
        let mut raw = "12".as_bytes();
        assert_eq!(raw.read_netstring().unwrap_err().kind(), ErrorKind::ConnectionAborted);
    }

    #[test]
    fn non_ascii_round_trip() {
        // The length prefix counts bytes, not characters: "h\u{e9}llo" is 6 bytes.
        let mut encoded: Vec<u8> = Vec::new();
        encoded.write_netstring("h\u{e9}llo").unwrap();
        assert_eq!(encoded, "6:h\u{e9}llo,".as_bytes());

        let mut raw = &encoded[..];
        assert_eq!(raw.read_netstring().unwrap(), "h\u{e9}llo");
    }
}