pub mod api;
pub mod codec;
use std::cmp;
use std::collections::HashMap;
use std::io::{self, Cursor, Read, Seek};
use std::str::Utf8Error;
use std::vec::Drain;
use api::Request;
/// Errors returned by the protocol state machine ([`Protocol`]) and the
/// message encoding helpers in this module.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Invalid UTF-8, converted from [`Utf8Error`].
    /// NOTE(review): not raised directly in this file; presumably produced
    /// while decoding strings in `codec`/`api` — confirm.
    #[error("utf8 error: {0}")]
    Utf8(#[from] Utf8Error),
    /// MsgPack encoding failure reported by `rmp::encode`
    /// (e.g. writing the length prefix in `write_to_buffer`).
    #[error("failed to encode: {0}")]
    Encode(#[from] rmp::encode::ValueWriteError),
    /// MsgPack value decoding failure reported by `rmp::decode`
    /// (e.g. reading the size prefix in `Protocol::process_incoming`).
    #[error("failed to decode: {0}")]
    Decode(#[from] rmp::decode::ValueReadError),
    /// MsgPack numeric decoding failure reported by `rmp::decode`.
    /// NOTE(review): not raised directly in this file — presumably from `codec`.
    #[error("failed to decode: {0}")]
    DecodeNum(#[from] rmp::decode::NumValueReadError),
    /// The server answered with a non-zero status code; carries the error
    /// decoded from the response body.
    #[error("service responded with error: {0}")]
    Response(#[from] ResponseError),
    /// I/O failure while reading from an incoming chunk.
    #[error("io error: {0}")]
    Io(#[from] io::Error),
    /// The size prefix read in `Protocol::process_incoming` was 0.
    #[error("message size hint is 0")]
    ZeroSizeHint,
    /// A call/eval response body did not contain the DATA field.
    /// NOTE(review): not raised in this file; presumably by `api` decoding.
    #[error("DATA not found in response body but is required for call/eval")]
    ResponseDataNotFound,
    /// Error from the crate-level error type, boxed.
    /// NOTE(review): boxing presumably keeps this enum small — confirm.
    /// Built through the `From<crate::error::Error>` impl.
    #[error("encode/decode error: {0}")]
    Tarantool(Box<crate::error::Error>),
}
impl From<crate::error::Error> for Error {
fn from(err: crate::error::Error) -> Self {
Error::Tarantool(Box::new(err))
}
}
/// Request identifier written into each outgoing request and used as the
/// key under which the matching response is stored and later taken.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SyncIndex(u64);

impl SyncIndex {
    /// Hands out the current index and advances the counter by one
    /// (post-increment): the first call on `SyncIndex(0)` returns `SyncIndex(0)`.
    pub fn next_index(&mut self) -> Self {
        // `SyncIndex` is `Copy`, so this snapshots the value before bumping it.
        let issued = *self;
        self.0 += 1;
        issued
    }
}
/// Error carried by a response whose header has a non-zero status code.
///
/// Produced by `codec::decode_error`; its `Display` output is just `message`.
#[derive(Debug, thiserror::Error, Clone)]
#[error("{message}")]
pub struct ResponseError {
    /// Human-readable error text decoded from the response body.
    pub(crate) message: String,
}
/// Connection handshake state of a [`Protocol`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum State {
    /// Waiting for the server greeting (the first incoming message).
    Init,
    /// Greeting processed and an auth request queued; waiting for the auth response.
    Auth,
    /// Handshake complete: queued requests are released to the outgoing
    /// buffer and responses are collected by sync index.
    Ready,
}
/// Options accepted by [`Protocol::with_config`].
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub struct Config {
    /// Optional `(user, password)` pair. When set, an auth request is sent
    /// right after the server greeting is processed; when `None`, the
    /// connection becomes ready immediately after the greeting.
    pub creds: Option<(String, String)>,
}
/// Client-side protocol state machine.
///
/// It performs no network I/O itself: the caller reads `read_size_hint()`
/// bytes at a time and feeds them to `process_incoming`, and sends whatever
/// `drain_outgoing_data` yields.
#[derive(Debug)]
pub struct Protocol {
    /// Handshake progress.
    state: State,
    /// `Some(n)`: the next chunk is a whole message of `n` bytes (initially the
    /// greeting). `None`: the next chunk is the MsgPack `u32` size prefix.
    msg_size_hint: Option<usize>,
    /// Encoded bytes ready to be sent.
    outgoing: Vec<u8>,
    /// Requests queued before the connection is ready; moved into `outgoing`
    /// once the state becomes `Ready`.
    pending_outgoing: Vec<u8>,
    /// Next sync index to assign to an outgoing request.
    sync: SyncIndex,
    /// Received responses keyed by sync: the raw body on success, or the
    /// decoded server error when the status code was non-zero.
    incoming: HashMap<SyncIndex, Result<Vec<u8>, ResponseError>>,
    /// Optional `(user, password)` used for the auth request after the greeting.
    creds: Option<(String, String)>,
}
impl Default for Protocol {
fn default() -> Self {
Self::new()
}
}
impl Protocol {
pub fn new() -> Self {
Self {
state: State::Init,
sync: SyncIndex(0),
pending_outgoing: Vec::new(),
creds: None,
outgoing: Vec::new(),
incoming: HashMap::new(),
msg_size_hint: Some(128),
}
}
pub fn with_config(config: Config) -> Self {
let mut protocol = Self::new();
protocol.creds = config.creds;
protocol
}
pub fn is_ready(&self) -> bool {
matches!(self.state, State::Ready)
}
pub fn send_request(&mut self, request: &impl Request) -> Result<SyncIndex, Error> {
let end = self.pending_outgoing.len();
let mut buf = Cursor::new(&mut self.pending_outgoing);
buf.set_position(end as u64);
write_to_buffer(&mut buf, self.sync, request)?;
self.process_pending_data();
Ok(self.sync.next_index())
}
pub fn take_response<R: Request>(
&mut self,
sync: SyncIndex,
request: &R,
) -> Option<Result<R::Response, Error>> {
let response = match self.incoming.remove(&sync)? {
Ok(response) => response,
Err(err) => return Some(Err(err.into())),
};
Some(request.decode_body(&mut Cursor::new(response)))
}
pub fn drop_response(&mut self, sync: SyncIndex) {
self.incoming.remove(&sync);
}
pub fn read_size_hint(&self) -> usize {
if let Some(hint) = self.msg_size_hint {
hint
} else {
5
}
}
pub fn process_incoming<R: Read + Seek>(
&mut self,
chunk: &mut R,
) -> Result<Option<SyncIndex>, Error> {
if self.msg_size_hint.is_some() {
self.msg_size_hint = None;
self.process_message(chunk)
} else {
let hint = rmp::decode::read_u32(chunk)?;
if hint > 0 {
self.msg_size_hint = Some(hint as usize);
Ok(None)
} else {
Err(Error::ZeroSizeHint)
}
}
}
fn process_message<R: Read + Seek>(
&mut self,
message: &mut R,
) -> Result<Option<SyncIndex>, Error> {
let sync = match self.state {
State::Init => {
let salt = codec::decode_greeting(message)?;
if let Some((user, pass)) = self.creds.as_ref() {
self.state = State::Auth;
debug_assert!(self.outgoing.is_empty());
let mut buf = Cursor::new(&mut self.outgoing);
let sync = self.sync.next_index();
write_to_buffer(
&mut buf,
sync,
&api::Auth {
user,
pass,
salt: &salt,
},
)?;
} else {
self.state = State::Ready;
}
None
}
State::Auth => {
let header = codec::decode_header(message)?;
if header.status_code != 0 {
return Err(codec::decode_error(message)?.into());
}
self.state = State::Ready;
None
}
State::Ready => {
let header = codec::decode_header(message)?;
let response = if header.status_code != 0 {
Err(codec::decode_error(message)?)
} else {
let mut buf = Vec::new();
message.read_to_end(&mut buf)?;
Ok(buf)
};
self.incoming.insert(header.sync, response);
Some(header.sync)
}
};
self.process_pending_data();
Ok(sync)
}
pub fn ready_outgoing_len(&self) -> usize {
self.outgoing.len()
}
pub fn drain_outgoing_data(&mut self, max: Option<usize>) -> Drain<u8> {
let bound = if let Some(max) = max {
cmp::min(self.ready_outgoing_len(), max)
} else {
self.ready_outgoing_len()
};
self.outgoing.drain(..bound)
}
fn process_pending_data(&mut self) {
if self.is_ready() {
let pending_data = self.pending_outgoing.drain(..);
self.outgoing.extend(pending_data);
}
}
}
fn write_to_buffer(
buffer: &mut Cursor<&mut Vec<u8>>,
sync: SyncIndex,
request: &impl Request,
) -> Result<(), Error> {
let msg_start_offset = buffer.position();
rmp::encode::write_u32(buffer, 0)?;
let payload_start_offset = buffer.position();
request.encode(buffer, sync)?;
let payload_end_offset = buffer.position();
buffer.set_position(msg_start_offset);
rmp::encode::write_u32(buffer, (payload_end_offset - payload_start_offset) as u32)?;
buffer.set_position(payload_end_offset);
Ok(())
}
#[cfg(feature = "internal_test")]
mod tests {
    use super::*;

    /// Builds a 128-byte fake greeting: 63 NUL bytes and a newline, then a
    /// salt string NUL-padded to 127 bytes and terminated by a newline.
    fn fake_greeting() -> Vec<u8> {
        const SALT: &[u8] = b"QK2HoFZGXTXBq2vFj7soCsHqTo6PGTF575ssUBAJLAI=";
        let mut bytes = vec![0u8; 63];
        bytes.push(b'\n');
        bytes.extend_from_slice(SALT);
        bytes.resize(127, 0);
        bytes.push(b'\n');
        bytes
    }

    #[crate::test(tarantool = "crate")]
    fn connection_established() {
        let mut protocol = Protocol::new();

        // Before the greeting: not ready, expecting one 128-byte chunk.
        assert!(!protocol.is_ready());
        assert_eq!(protocol.msg_size_hint, Some(128));
        assert_eq!(protocol.read_size_hint(), 128);

        let mut greeting = Cursor::new(fake_greeting());
        protocol.process_incoming(&mut greeting).unwrap();

        // Without credentials the greeting alone completes the handshake and
        // the next read is a 5-byte size prefix.
        assert_eq!(protocol.msg_size_hint, None);
        assert_eq!(protocol.read_size_hint(), 5);
        assert!(protocol.is_ready());
    }

    #[crate::test(tarantool = "crate")]
    fn send_bytes_generated() {
        let mut protocol = Protocol::new();
        let mut greeting = Cursor::new(fake_greeting());
        protocol.process_incoming(&mut greeting).unwrap();

        // Once ready, a sent request lands directly in the outgoing buffer.
        protocol.send_request(&api::Ping).unwrap();
        assert!(protocol.ready_outgoing_len() > 0);
    }
}