netstring/
lib.rs

1#[macro_use]
2extern crate log;
3
4pub mod channel;
5#[cfg(test)]
6mod syncbuf;
7
8use std::io::{Read, Write, Result};
9use std::io::{Error, ErrorKind};
10use std::net::Shutdown as ShutdownMode;
11use std::ops::Deref;
12
/// Capacity, in bytes, of the scratch buffer `read_length` uses for the
/// decimal length prefix. A 64-bit length needs at most 20 digits, so this
/// leaves ample headroom.
// TODO: get rid of this
const DIGIT_LIMIT: usize = 64;
15
16pub trait ReadNetstring: Shutdown {
17    fn read_netstring(&mut self) -> Result<String>;
18}
19
20pub trait WriteNetstring: Shutdown {
21    fn write_netstring<S: AsRef<str>>(&mut self, value: S) -> Result<()>;
22    fn flush(&mut self) -> Result<()>;
23}
24
/// A stream whose read half, write half, or both can be shut down.
///
/// In-memory implementations may treat this as a no-op.
pub trait Shutdown {
    /// Shuts down the half (or halves) selected by `mode`.
    fn shutdown(&self, mode: ShutdownMode) -> Result<()>;
}
28
29impl<R: Read + Shutdown> ReadNetstring for R {
30    fn read_netstring(&mut self) -> Result<String> {
31        let ln = try!(read_length(self));
32        let mut data = vec![0u8;ln];
33
34        let mut offset = 0usize;
35        let mut done = false;
36
37        while !done {
38            let r = try!(self.read(data[offset..].as_mut()));
39            offset = offset + r;
40            if r == 0 || offset == ln {
41                done = true;
42            }
43        }
44
45        // TODO: there has to be a a cleaner way to do this...
46        // read delimiter ","
47        let mut t = vec![0u8].into_boxed_slice();
48        try!(self.read(t[..].as_mut()));
49
50        // Verify delimiter
51        if t[0] != b',' {
52            return Err(Error::new(ErrorKind::InvalidData, "Expected `,` delimiter."));
53        }
54
55        // return utf8 string
56        match String::from_utf8(data) {
57            Ok(s) => Ok(s),
58            Err(err) => Err(Error::new(ErrorKind::InvalidData, err)),
59        }
60    }
61}
62
63impl<W: Write + Shutdown> WriteNetstring for W {
64    fn write_netstring<S: AsRef<str>>(&mut self, value: S) -> Result<()> {
65        let value = value.as_ref();
66        let s = format!("{}:{},", value.len(), value);
67        try!(self.write_all(s.as_bytes()));
68        Ok(())
69    }
70
71    fn flush(&mut self) -> Result<()> {
72        Write::flush(self)
73    }
74}
75
76impl Shutdown for ::std::os::unix::net::UnixStream {
77    fn shutdown(&self, how: ShutdownMode) -> Result<()> {
78        ::std::os::unix::net::UnixStream::shutdown(self, how)
79    }
80}
81
82impl Shutdown for ::std::net::TcpStream {
83    fn shutdown(&self, how: ShutdownMode) -> Result<()> {
84        ::std::net::TcpStream::shutdown(self, how)
85    }
86}
87
88impl<T: Shutdown> Shutdown for Box<T> {
89    fn shutdown(&self, how: ShutdownMode) -> Result<()> {
90        self.deref().shutdown(how)
91    }
92}
93
94impl<'a> Shutdown for &'a [u8] {
95     fn shutdown(&self, _how: ShutdownMode) -> Result<()> {
96         Ok(())
97     }
98}
99
100impl<T> Shutdown for Vec<T> {
101     fn shutdown(&self, _how: ShutdownMode) -> Result<()> {
102         Ok(())
103     }
104}
105
106impl Shutdown for ::std::io::Sink {
107    fn shutdown(&self, _how: ShutdownMode) -> Result<()> {
108        Ok(())
109    }
110}
111
112fn read_length<R: Read>(r: &mut R) -> Result<usize> {
113    let mut t = [0u8; DIGIT_LIMIT];
114    let mut current = 0usize;
115    let mut done = false;
116    while !done {
117        let r = try!(r.read(t[current..current + 1].as_mut()));
118        if r == 0 {
119            return Err(Error::new(ErrorKind::ConnectionAborted, "Connection closed by target."));
120        }
121        // Reached ":" signaling the end of the length
122        if t[current] == b':' {
123            done = true;
124        } else {
125            current += 1;
126        }
127    }
128
129    let s = match String::from_utf8(t[..current].to_vec()) {
130        Ok(s) => s,
131        Err(err) => return Err(Error::new(ErrorKind::InvalidData, err)),
132    };
133
134    let ln = match s.parse::<u64>() {
135        Ok(x) => x,
136        Err(err) => return Err(Error::new(ErrorKind::InvalidData, err)),
137    };
138
139
140    Ok(ln as usize)
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::ErrorKind;

    #[test]
    fn basic_read() {
        let mut raw = "5:hello,".as_bytes();
        let x = raw.read_netstring().unwrap();
        assert_eq!("hello", x);
    }

    #[test]
    fn basic_write() {
        let mut raw = vec![];
        raw.write_netstring("hello").unwrap();
        assert_eq!(raw, b"5:hello,");
    }

    #[test]
    #[should_panic(expected="Expected `,` delimiter.")]
    fn invalid_delimiter() {
        let mut raw = "5:hello?".as_bytes();
        raw.read_netstring().unwrap();
    }

    #[test]
    #[should_panic(expected="Expected `,` delimiter.")]
    fn longer() {
        let mut raw = "10:hello,".as_bytes();
        raw.read_netstring().unwrap();
    }

    #[test]
    #[should_panic(expected="Expected `,` delimiter.")]
    fn shorter() {
        let mut raw = "2:hello,".as_bytes();
        raw.read_netstring().unwrap();
    }

    #[test]
    #[should_panic(expected="Expected `,` delimiter.")]
    fn missing_delimiter() {
        let mut raw = "5:hello".as_bytes();
        raw.read_netstring().unwrap();
    }

    #[test]
    fn multiple() {
        let mut raw = "5:hello,5:world,10:xxxxxxxxxx,".as_bytes();
        let x1 = raw.read_netstring().unwrap();
        let x2 = raw.read_netstring().unwrap();
        let x3 = raw.read_netstring().unwrap();
        assert_eq!(x1, "hello");
        assert_eq!(x2, "world");
        assert_eq!(x3, "xxxxxxxxxx");
    }

    #[test]
    fn empty_payload() {
        let mut raw = "0:,".as_bytes();
        assert_eq!(raw.read_netstring().unwrap(), "");
    }

    #[test]
    fn round_trip_multibyte() {
        let mut buf = vec![];
        buf.write_netstring("héllo").unwrap();
        // The length prefix counts UTF-8 bytes, not chars.
        assert_eq!(buf, "6:héllo,".as_bytes());
        let mut raw = &buf[..];
        assert_eq!(raw.read_netstring().unwrap(), "héllo");
    }

    #[test]
    fn invalid_utf8_payload() {
        let mut raw: &[u8] = b"1:\xff,";
        let err = raw.read_netstring().unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn non_numeric_length() {
        let mut raw = "x:hi,".as_bytes();
        let err = raw.read_netstring().unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn eof_before_length() {
        let mut raw = "".as_bytes();
        let err = raw.read_netstring().unwrap_err();
        assert_eq!(err.kind(), ErrorKind::ConnectionAborted);
    }
}
192}