extern crate rand;
extern crate url;
extern crate serde;
extern crate serde_json;
extern crate time;
use errors::*;
use errors::ErrorKind::*;
use self::rand::{thread_rng, Rng};
use self::serde_json::de;
use self::serde_json::value::Value;
use self::time::{Duration, SteadyTime};
use self::url::{ParseError, ParseResult, SchemeType, Url, UrlParser};
use std::cmp;
use std::io;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpStream;
use std::thread;
const CIRCUIT_BREAKER_WAIT_AFTER_BREAKING_MS: i64 = 2000;
const CIRCUIT_BREAKER_WAIT_BETWEEN_ROUNDS_MS: u32 = 250;
const CIRCUIT_BREAKER_ROUNDS_BEFORE_BREAKING: u32 = 4;
const DEFAULT_NAME: &'static str = "#rustlang";
const DEFAULT_PORT: u16 = 4222;
const URI_SCHEME: &'static str = "nats";
const RETRIES_MAX: u32 = 10;
#[derive(Clone, Debug)]
struct Credentials {
username: String,
password: String
}
#[derive(Clone, Debug)]
struct ServerInfo {
host: String,
port: u16,
credentials: Option<Credentials>,
max_payload: usize
}
#[derive(Debug)]
struct ClientState {
stream_writer: TcpStream,
buf_reader: BufReader<TcpStream>,
max_payload: usize
}
#[derive(Debug)]
pub struct Client {
servers_info: Vec<ServerInfo>,
server_idx: usize,
verbose: bool,
pedantic: bool,
name: String,
state: Option<ClientState>,
circuit_breaker: Option<SteadyTime>,
sid: u64
}
#[derive(Serialize, Debug)]
struct ConnectNoCredentials {
verbose: bool,
pedantic: bool,
name: String
}
#[derive(Serialize, Debug)]
struct ConnectWithCredentials {
verbose: bool,
pedantic: bool,
name: String,
user: String,
pass: String
}
#[derive(Debug)]
pub struct Channel {
sid: u64
}
#[derive(Debug)]
pub struct Event {
subject: String,
channel: Channel,
msg: Vec<u8>,
inbox: Option<String>
}
pub struct Events<'t> {
client: &'t mut Client
}
impl Client {
pub fn new<T: ToStringVec>(uris: T) -> Result<Client, NatsError> {
let mut servers_info = Vec::new();
for uri in uris.to_string_vec() {
let parsed = try!(parse_nats_uri(&uri));
let host = try!(parsed.serialize_host().ok_or((InvalidClientConfig, "Missing host name")));
let port = try!(parsed.port_or_default().ok_or((InvalidClientConfig, "Invalid port number")));
let credentials = match (parsed.username(), parsed.password()) {
(None, None) | (Some(""), None) => None,
(Some(username), Some(password)) => Some(Credentials {
username: username.to_owned(), password: password.to_owned()
}),
(None, Some(_)) => return Err(NatsError::from((InvalidClientConfig, "Username can't be empty"))),
(Some(_), None) => return Err(NatsError::from((InvalidClientConfig, "Password can't be empty"))),
};
servers_info.push(ServerInfo {
host: host,
port: port,
credentials: credentials,
max_payload: 0
})
}
thread_rng().shuffle(&mut servers_info);
Ok(Client {
servers_info: servers_info,
server_idx: 0,
verbose: false,
pedantic: false,
name: DEFAULT_NAME.to_owned(),
state: None,
sid: 1,
circuit_breaker: None
})
}
pub fn set_synchronous(&mut self, synchronous: bool) {
self.verbose = synchronous;
}
pub fn set_name(&mut self, name: &str) {
self.name = name.to_owned();
}
pub fn subscribe(&mut self, subject: &str, queue: Option<&str>) -> Result<Channel, NatsError> {
try!(subject_check(subject));
let sid = self.sid;
let cmd = match queue {
None => format!("SUB {} {}\r\n", subject, sid),
Some(queue) => {
try!(queue_check(queue));
format!("SUB {} {} {}\r\n", subject, queue, sid)
}
};
let verbose = self.verbose;
try!(self.maybe_connect());
let res = self.with_reconnect(|mut state| -> Result<Channel, NatsError> {
try!(state.stream_writer.write_all(cmd.as_bytes()));
try!(wait_ok(&mut state, verbose));
Ok(Channel {
sid: sid
})
});
if res.is_ok() {
self.sid = self.sid.wrapping_add(1);
}
res
}
pub fn unsubscribe(&mut self, channel: Channel) -> Result<(), NatsError> {
let cmd = format!("UNSUB {}\r\n", channel.sid);
let verbose = self.verbose;
try!(self.maybe_connect());
self.with_reconnect(|mut state| -> Result<(), NatsError> {
try!(state.stream_writer.write_all(cmd.as_bytes()));
try!(wait_ok(&mut state, verbose));
Ok(())
})
}
pub fn unsubscribe_after(&mut self, channel: Channel, max: u64) -> Result<(), NatsError> {
let cmd = format!("UNSUB {} {}\r\n", channel.sid, max);
let verbose = self.verbose;
try!(self.maybe_connect());
self.with_reconnect(|mut state| -> Result<(), NatsError> {
try!(state.stream_writer.write_all(cmd.as_bytes()));
try!(wait_ok(&mut state, verbose));
Ok(())
})
}
pub fn publish(&mut self, subject: &str, msg: &[u8]) -> Result<(), NatsError> {
self.publish_with_optional_inbox(subject, msg, None)
}
pub fn make_request(&mut self, subject: &str, msg: &[u8]) -> Result<String, NatsError> {
let mut rng = rand::thread_rng();
let inbox: String = rng.gen_ascii_chars().take(16).collect();
let sid = try!(self.subscribe(&inbox, None));
try!(self.unsubscribe_after(sid, 1));
try!(self.publish_with_optional_inbox(subject, msg, Some(&inbox)));
Ok(inbox)
}
pub fn wait(&mut self) -> Result<Event, NatsError> {
try!(self.maybe_connect());
self.with_reconnect(|mut state| -> Result<Event, NatsError> {
let mut buf_reader = &mut state.buf_reader;
loop {
let mut line = String::new();
match buf_reader.read_line(&mut line) {
Ok(line_len) if line_len < "PING\r\n".len() =>
return Err(NatsError::from((ErrorKind::ServerProtocolError, "Incomplete server response"))),
Err(e) => return Err(NatsError::from(e)),
Ok(_) => { }
};
if line.starts_with("MSG ") {
return wait_read_msg(line, buf_reader)
}
if line != "PING\r\n" {
return Err(NatsError::from((ErrorKind::ServerProtocolError, "Server sent an unexpected response", line)));
}
let cmd = "PONG\r\n";
try!(state.stream_writer.write_all(cmd.as_bytes()));
}
})
}
pub fn events(&mut self) -> Events {
Events {
client: self
}
}
fn try_connect(&mut self) -> io::Result<()> {
let server_info = &mut self.servers_info[self.server_idx];
let stream_reader = try!(TcpStream::connect((&server_info.host as &str, server_info.port)));
let mut stream_writer = try!(stream_reader.try_clone());
let mut buf_reader = BufReader::new(stream_reader);
let mut line = String::new();
match buf_reader.read_line(&mut line) {
Ok(line_len) if line_len < "INFO {}".len() =>
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Unexpected EOF")),
Err(e) => return Err(e),
Ok(_) => { }
};
if line.starts_with("INFO ") == false {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Server INFO not received"));
}
let obj: Value = try!(de::from_str(&line[5..]).
or(Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid JSON object sent by the server"))));
let obj = try!(obj.as_object().
ok_or(io::Error::new(io::ErrorKind::InvalidInput, "Invalid JSON object sent by the server")));
let max_payload = try!(try!(obj.get("max_payload").
ok_or(io::Error::new(io::ErrorKind::InvalidInput, "Server didn't send the max payload size"))).
as_u64().ok_or(io::Error::new(io::ErrorKind::InvalidInput, "Received payload size is not a u64")));
if max_payload < 1 {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid max payload size received"));
}
server_info.max_payload = max_payload as usize;
let auth_required = try!(try!(obj.get("auth_required").
ok_or(io::Error::new(io::ErrorKind::InvalidInput, "Server didn't send auth_required"))).
as_boolean().ok_or(io::Error::new(io::ErrorKind::InvalidInput, "Received auth_required is not a boolean")));
let connect_json = match (auth_required, &server_info.credentials) {
(true, &Some(ref credentials)) => {
let connect = ConnectWithCredentials {
verbose: self.verbose,
pedantic: self.pedantic,
name: self.name.clone(),
user: credentials.username.clone(),
pass: credentials.password.clone()
};
try!(serde_json::to_string(&connect).or(Err(io::Error::new(io::ErrorKind::InvalidInput, "Received auth_required is not a boolean"))))
}
(false, _) | (_, &None) => {
let connect = ConnectNoCredentials {
verbose: self.verbose,
pedantic: self.pedantic,
name: self.name.clone()
};
serde_json::to_string(&connect).unwrap()
}
};
let connect_string = format!("CONNECT {}\nPING\n", connect_json);
let connect_bytes = connect_string.as_bytes();
stream_writer.write_all(connect_bytes).unwrap();
if self.verbose {
let mut line = String::new();
match buf_reader.read_line(&mut line) {
Ok(line_len) if line_len != "+OK\r\n".len() =>
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Unexpected EOF")),
Err(e) => return Err(e),
Ok(_) => { }
};
if line != "+OK\r\n" {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Server +OK not received"));
}
}
let mut line = String::new();
match buf_reader.read_line(&mut line) {
Ok(line_len) if line_len != "PONG\r\n".len() =>
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Unexpected EOF")),
Err(e) => return Err(e),
Ok(_) => { }
};
if line != "PONG\r\n" {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Server PONG not received"));
}
let state = ClientState {
stream_writer: stream_writer,
buf_reader: buf_reader,
max_payload: max_payload as usize
};
self.state = Some(state);
Ok(())
}
fn connect(&mut self) -> Result<(), NatsError> {
if let Some(circuit_breaker) = self.circuit_breaker {
if SteadyTime::now() - circuit_breaker < Duration::milliseconds(CIRCUIT_BREAKER_WAIT_AFTER_BREAKING_MS) {
return Err(NatsError::from((ErrorKind::ServerProtocolError, "Cluster down - Connections are temporarily suspended")));
}
self.circuit_breaker = None;
}
self.state = None;
let servers_count = self.servers_info.len();
for _ in (0..CIRCUIT_BREAKER_ROUNDS_BEFORE_BREAKING) {
for _ in (0..servers_count) {
if self.try_connect().is_ok() {
if self.state.is_none() {
panic!("Inconsistent state");
}
return Ok(());
}
self.server_idx = (self.server_idx + 1) % servers_count;
}
thread::sleep_ms(CIRCUIT_BREAKER_WAIT_BETWEEN_ROUNDS_MS);
}
self.circuit_breaker = Some(SteadyTime::now());
Err(NatsError::from((ErrorKind::ServerProtocolError,
"The entire cluster is down or unreachable")))
}
fn reconnect(&mut self) -> Result<(), NatsError> {
if let Some(mut state) = self.state.take() {
let _ = state.stream_writer.flush();
}
self.connect()
}
fn maybe_connect(&mut self) -> Result<(), NatsError> {
if self.state.is_none() {
return self.connect()
}
Ok(())
}
fn with_reconnect<F, T>(&mut self, f: F) -> Result<T, NatsError> where F: Fn(&mut ClientState) -> Result<T, NatsError> {
let mut res: Result<T, NatsError> = Err(NatsError::from((ErrorKind::IoError, "I/O error")));
for _ in (0..RETRIES_MAX) {
let mut state = self.state.take().unwrap();
res = match f(&mut state) {
e @ Err(_) => match self.reconnect() {
Err(e) => return Err(e),
Ok(_) => e
},
res @ Ok(_) => {
self.state = Some(state);
return res;
}
};
}
res
}
fn publish_with_optional_inbox(&mut self, subject: &str, msg: &[u8], inbox: Option<&str>) -> Result<(), NatsError> {
try!(subject_check(subject));
let msg_len = msg.len();
let cmd = match inbox {
None => format!("PUB {} {}\r\n", subject, msg_len),
Some(inbox) => {
try!(inbox_check(inbox));
format!("PUB {} {} {}\r\n", subject, inbox, msg_len)
}
};
let mut cmd: Vec<u8> = cmd.as_bytes().to_owned();
let cmd_len0 = cmd.len();
cmd.reserve(cmd_len0 + msg_len + 2);
cmd.push_all(msg);
cmd.push(0x0d);
cmd.push(0x0a);
let verbose = self.verbose;
try!(self.maybe_connect());
self.with_reconnect(|mut state| -> Result<(), NatsError> {
let max_payload = state.max_payload;
if cmd.len() > max_payload {
return Err(NatsError::from((ErrorKind::ClientProtocolError, "Message too large",
format!("Maximum payload size is {} bytes", max_payload))));
}
try!(state.stream_writer.write_all(&cmd));
try!(wait_ok(&mut state, verbose));
Ok(())
})
}
}
impl<'t> Iterator for Events<'t> {
type Item = Event;
fn next(&mut self) -> Option<Event> {
let mut client = &mut self.client;
match client.wait() {
Ok(event) => Some(event),
Err(_) => None
}
}
}
pub trait ToStringVec {
fn to_string_vec(self) -> Vec<String>;
}
impl ToStringVec for String {
fn to_string_vec(self) -> Vec<String> {
vec!(self)
}
}
impl<'t> ToStringVec for &'t str {
fn to_string_vec(self) -> Vec<String> {
vec!(self.to_owned())
}
}
impl ToStringVec for Vec<String> {
fn to_string_vec(self) -> Vec<String> {
self
}
}
impl<'t> ToStringVec for Vec<&'t str> {
fn to_string_vec(self) -> Vec<String> {
self.iter().map(|&x| x.to_owned()).collect()
}
}
fn space_check(name: &str, errmsg: &'static str) -> Result<(), NatsError> {
if name.contains(' ') {
return Err(NatsError::from((ErrorKind::ClientProtocolError, errmsg)));
}
Ok(())
}
fn subject_check(subject: &str) -> Result<(), NatsError> {
space_check(subject, "A subject cannot contain spaces")
}
fn inbox_check(inbox: &str) -> Result<(), NatsError> {
space_check(inbox, "An inbox name cannot contain spaces")
}
fn queue_check(queue: &str) -> Result<(), NatsError> {
space_check(queue, "A queue name cannot contain spaces")
}
fn nats_scheme_type_mapper(scheme: &str) -> SchemeType {
match scheme {
URI_SCHEME => SchemeType::Relative(DEFAULT_PORT),
_ => SchemeType::NonRelative
}
}
fn parse_nats_uri(uri: &str) -> ParseResult<Url> {
let mut parser = UrlParser::new();
parser.scheme_type_mapper(nats_scheme_type_mapper);
match parser.parse(uri) {
Ok(res) => {
if res.scheme == URI_SCHEME {
Ok(res)
} else {
Err(ParseError::InvalidScheme)
}
},
e => e
}
}
fn read_exact<R: BufRead + ?Sized>(reader: &mut R, buf: &mut Vec<u8>) -> io::Result<usize> {
let len = buf.len();
let mut to_read = len;
buf.clear();
while to_read > 0 {
let used = {
let buffer = match reader.fill_buf() {
Ok(buffer) => buffer,
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
Err(e) => return Err(e)
};
let used = cmp::min(buffer.len(), to_read);
buf.push_all(&buffer[..used]);
used
};
reader.consume(used);
to_read -= used;
}
Ok(len)
}
fn wait_ok(state: &mut ClientState, verbose: bool) -> Result<(), NatsError> {
if verbose == false {
return Ok(());
}
let mut buf_reader = &mut state.buf_reader;
let mut line = String::new();
match buf_reader.read_line(&mut line) {
Ok(line_len) if line_len < "OK\r\n".len() =>
return Err(NatsError::from((ErrorKind::ServerProtocolError, "Incomplete server response"))),
Err(e) => return Err(NatsError::from(e)),
Ok(_) => { }
};
match line.as_ref() {
"+OK\r\n" => { },
"PING\r\n" => {
let pong = "PONG\r\n".as_bytes();
try!(state.stream_writer.write_all(pong));
},
_ => return Err(NatsError::from((ErrorKind::ServerProtocolError,
"Received unexpected response from the server", line)))
}
Ok(())
}
fn wait_read_msg(line: String, buf_reader: &mut BufReader<TcpStream>) -> Result<Event, NatsError> {
if line.len() < "MSG _ _ _\r\n".len() {
return Err(NatsError::from((ErrorKind::ServerProtocolError, "Incomplete server response", line.clone())));
}
let line = line.trim_right();
let mut parts = line[4..].split(' ');
let subject = try!(parts.next().
ok_or(NatsError::from((ErrorKind::ServerProtocolError, "Unsupported server response", line.to_owned()))));
let sid: u64 = try!(parts.next().
ok_or(NatsError::from((ErrorKind::ServerProtocolError, "Unsupported server response", line.to_owned())))).
parse().unwrap_or(0);
let inbox_or_len_s = try!(parts.next().
ok_or(NatsError::from((ErrorKind::ServerProtocolError, "Unsupported server response", line.to_owned()))));
let mut inbox: Option<String> = None;
let len_s = match parts.next() {
None => inbox_or_len_s,
Some(len_s) => {
inbox = Some(inbox_or_len_s.to_owned());
len_s
}
};
let len: usize = try!(len_s.parse().ok().
ok_or(NatsError::from((ErrorKind::ServerProtocolError, "Suspicous message length",
format!("{} (len: [{}])", line, len_s)))));
let mut msg: Vec<u8> = vec![0; len];
try!(read_exact(buf_reader, &mut msg));
let mut crlf: Vec<u8> = vec![0; 2];
try!(read_exact(buf_reader, &mut crlf));
if crlf[0] != 0x0d || crlf[1] != 0x0a {
return Err(NatsError::from((ErrorKind::ServerProtocolError, "Missing CRLF after a message", line.to_owned())))
}
let event = Event {
subject: subject.to_owned(),
channel: Channel {
sid: sid
},
msg: msg,
inbox: inbox
};
Ok(event)
}
#[test]
fn client_test() {
let mut client = Client::new(vec!("nats://user:password@127.0.0.1")).unwrap();
client.set_synchronous(false);
client.set_name("test");
client.subscribe("chan", None).unwrap();
client.publish("chan", "test".as_bytes()).unwrap();
client.wait().unwrap();
let s = client.subscribe("chan2", Some("queue")).unwrap();
client.unsubscribe(s).unwrap();
client.make_request("chan", "test".as_bytes()).unwrap();
client.wait().unwrap();
client.subscribe("chan.*", None).unwrap();
client.publish("chan", "test1".as_bytes()).unwrap();
client.publish("chan", "test2".as_bytes()).unwrap();
client.publish("chan", "test3".as_bytes()).unwrap();
client.publish("chan.last", "test4".as_bytes()).unwrap();
client.events().find(|event| event.subject == "chan.last").unwrap();
}