pub mod api;
pub use api::*;
pub mod codec;
pub use codec::*;
use crate::auth::AuthMethod;
use crate::error;
use crate::error::TarantoolError;
use std::collections::HashMap;
use std::io::{Cursor, Read, Seek};
use std::time::Duration;
/// Former name of [`ProtocolError`], kept as an alias for backward compatibility.
#[deprecated(note = "use `ProtocolError` instead")]
pub type Error = ProtocolError;
#[non_exhaustive]
#[derive(thiserror::Error, Debug)]
pub enum ProtocolError {
#[error("message size hint is 0")]
ZeroSizeHint,
#[error("{key} not found in iproto response body, {context}")]
ResponseFieldNotFound {
key: &'static str,
context: &'static str,
},
#[error("{0} is not implemented yet")]
Unimplemented(String),
}
/// Request identifier (the iproto `sync` value) used to pair each response
/// with the request that produced it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SyncIndex(pub(crate) u64);

impl SyncIndex {
    /// Returns the current index and advances `self` by one
    /// (post-increment semantics).
    pub fn next_index(&mut self) -> Self {
        let current = *self;
        self.0 += 1;
        current
    }

    /// Returns the raw numeric value of this index.
    #[inline]
    pub fn get(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}
/// Former name of [`TarantoolError`], kept as an alias for backward compatibility.
#[deprecated(note = "use `TarantoolError` instead")]
pub type ResponseError = TarantoolError;
/// Handshake progress of a `Protocol` connection.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum State {
    /// Waiting for the server greeting.
    Init,
    /// `IPROTO_ID` request sent; waiting for its reply.
    Id,
    /// Authentication request sent; waiting for its reply.
    Auth,
    /// Handshake finished; user requests are sent directly.
    Ready,
}
/// Settings used to build a `Protocol` via `Protocol::with_config`.
#[non_exhaustive]
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub struct Config {
    /// `(user, password)` pair. When set, an authentication request is sent
    /// during the handshake.
    pub creds: Option<(String, String)>,
    /// Authentication mechanism used for that request.
    pub auth_method: AuthMethod,
    /// Not read by `Protocol::with_config`; presumably consumed by the
    /// connection layer (TODO confirm).
    pub connect_timeout: Option<Duration>,
    /// When set, an `IPROTO_ID` request carrying this UUID is sent right after
    /// the greeting, before authentication.
    pub cluster_uuid: Option<String>,
}
/// Client-side iproto protocol state machine.
///
/// It performs no I/O itself: the caller feeds received bytes through
/// `process_incoming` (sized by `read_size_hint`) and drains bytes to send
/// via `take_outgoing_data`.
#[derive(Debug)]
pub struct Protocol {
    /// Current handshake stage.
    state: State,
    /// Size of the next greeting/message body to read. `None` means the next
    /// read is the length prefix of a message.
    msg_size_hint: Option<usize>,
    /// Encoded bytes ready to be sent.
    outgoing: Vec<u8>,
    /// User requests encoded before the handshake finished. They are moved to
    /// `outgoing` once the state becomes `Ready`.
    pending_outgoing: Vec<u8>,
    /// Sync index that will be assigned to the next request.
    sync: SyncIndex,
    /// Received responses keyed by sync. `Ok` holds the message bytes that
    /// follow the header; `Err` holds the decoded server error.
    incoming: HashMap<SyncIndex, Result<Vec<u8>, TarantoolError>>,
    /// `(user, password)` used for authentication, if any.
    creds: Option<(String, String)>,
    /// Authentication mechanism used when `creds` is set.
    auth_method: AuthMethod,
    /// When set, an `IPROTO_ID` request is sent before authentication.
    cluster_uuid: Option<String>,
    /// Salt from the server greeting, kept so authentication can run after the
    /// `IPROTO_ID` reply.
    greeting_salt: Option<Vec<u8>>,
}
impl Default for Protocol {
fn default() -> Self {
Self::new()
}
}
impl Protocol {
    /// Size in bytes of the length prefix in front of every iproto message.
    /// It is a msgpack `u32`: one marker byte followed by four bytes, as
    /// written by `write_to_buffer` and read by `rmp::decode::read_u32` in
    /// `process_incoming`.
    const MSG_SIZE_PREFIX_LEN: usize = 5;

    /// Number of bytes requested for the very first read, which is decoded as
    /// the server greeting. The greeting is not length-prefixed.
    const GREETING_LEN: usize = 128;

    /// Creates a state machine with no credentials, the default auth method
    /// and no cluster UUID. It starts out waiting for the server greeting.
    pub fn new() -> Self {
        Self {
            state: State::Init,
            sync: SyncIndex(0),
            pending_outgoing: Vec::new(),
            creds: None,
            auth_method: AuthMethod::default(),
            outgoing: Vec::new(),
            incoming: HashMap::new(),
            msg_size_hint: Some(Self::GREETING_LEN),
            cluster_uuid: None,
            greeting_salt: None,
        }
    }

    /// Creates a state machine configured from `config`.
    ///
    /// Only `creds`, `auth_method` and `cluster_uuid` are used;
    /// `connect_timeout` is not consumed here.
    pub fn with_config(config: Config) -> Self {
        Self {
            creds: config.creds,
            auth_method: config.auth_method,
            cluster_uuid: config.cluster_uuid,
            ..Self::new()
        }
    }

    /// Returns `true` once the handshake has completed. The handshake is the
    /// greeting, then an optional `IPROTO_ID`, then optional authentication.
    pub fn is_ready(&self) -> bool {
        matches!(self.state, State::Ready)
    }

    /// Encodes `request` and queues it for sending.
    ///
    /// Before the handshake completes, the encoded bytes are held back and
    /// flushed to the outgoing buffer once the connection becomes ready.
    /// Afterwards they go straight to the outgoing buffer.
    ///
    /// Returns the sync index assigned to the request; pass it to
    /// [`Protocol::take_response`].
    ///
    /// # Errors
    /// Returns an error if the request cannot be encoded. In that case the
    /// sync index is not consumed.
    pub fn send_request(&mut self, request: &impl Request) -> Result<SyncIndex, error::Error> {
        let end = self.pending_outgoing.len();
        let mut buf = Cursor::new(&mut self.pending_outgoing);
        buf.set_position(end as u64);
        write_to_buffer(&mut buf, self.sync, request)?;
        self.process_pending_data();
        Ok(self.sync.next_index())
    }

    /// Removes the received response for `sync` and decodes it as the
    /// response of request type `R`.
    ///
    /// Returns `None` if no response for `sync` has been received, or if it
    /// was already taken. A server error reply is returned as
    /// `Some(Err(error::Error::Remote(..)))`.
    pub fn take_response<R: Request>(
        &mut self,
        sync: SyncIndex,
    ) -> Option<Result<R::Response, error::Error>> {
        let response = match self.incoming.remove(&sync)? {
            Ok(response) => response,
            Err(err) => return Some(Err(error::Error::Remote(err))),
        };
        Some(R::decode_response_body(&mut Cursor::new(response)))
    }

    /// Discards the received response for `sync`, if one is stored.
    ///
    /// NOTE: if the response has not arrived yet, it will still be stored
    /// when it does, because nothing records that it was dropped.
    pub fn drop_response(&mut self, sync: SyncIndex) {
        self.incoming.remove(&sync);
    }

    /// Number of bytes the caller should read next and pass to
    /// [`Protocol::process_incoming`].
    ///
    /// This is either the size of the expected greeting or message body, or
    /// the length of the size prefix that precedes the next message.
    pub fn read_size_hint(&self) -> usize {
        self.msg_size_hint.unwrap_or(Self::MSG_SIZE_PREFIX_LEN)
    }

    /// Consumes one chunk of `read_size_hint()` bytes.
    ///
    /// Chunks alternate between a message length prefix and a message body.
    /// The very first chunk is the greeting.
    ///
    /// Returns `Some(sync)` when a response was received in the ready state
    /// and can be fetched with [`Protocol::take_response`].
    ///
    /// # Errors
    /// Fails on a malformed prefix or message, on a zero length prefix
    /// ([`ProtocolError::ZeroSizeHint`]), or on a handshake error reply.
    pub fn process_incoming<R: Read + Seek>(
        &mut self,
        chunk: &mut R,
    ) -> Result<Option<SyncIndex>, error::Error> {
        if self.msg_size_hint.take().is_some() {
            // A body (or the greeting) was expected; the next chunk will be a size prefix.
            return self.process_message(chunk);
        }
        match rmp::decode::read_u32(chunk)? {
            0 => Err(ProtocolError::ZeroSizeHint.into()),
            hint => {
                self.msg_size_hint = Some(hint as usize);
                Ok(None)
            }
        }
    }

    /// Returns the decoded server error if `header` marks an error reply.
    fn handle_error_response<R: Read + Seek>(
        &self,
        message: &mut R,
        header: &codec::Header,
    ) -> Result<(), error::Error> {
        if header.iproto_type == IProtoType::Error as u32 {
            let error = codec::decode_error(message, header)?;
            return Err(error::Error::Remote(error));
        }
        Ok(())
    }

    /// Encodes an authentication request into the (expected empty) outgoing
    /// buffer, consuming a sync index.
    fn send_auth_request(
        &mut self,
        user: &str,
        pass: &str,
        salt: &[u8],
    ) -> Result<(), error::Error> {
        debug_assert!(self.outgoing.is_empty());
        let mut buf = Cursor::new(&mut self.outgoing);
        let sync = self.sync.next_index();
        write_to_buffer(
            &mut buf,
            sync,
            &api::Auth {
                user,
                pass,
                salt,
                method: self.auth_method,
            },
        )
    }

    /// Encodes an `IPROTO_ID` request carrying `cluster_uuid` into the
    /// (expected empty) outgoing buffer, consuming a sync index.
    fn send_id_request(&mut self) -> Result<(), error::Error> {
        debug_assert!(self.outgoing.is_empty());
        let mut buf = Cursor::new(&mut self.outgoing);
        let sync = self.sync.next_index();
        write_to_buffer(
            &mut buf,
            sync,
            &api::Id {
                cluster_uuid: self.cluster_uuid.as_deref(),
            },
        )
    }

    /// Handles one complete greeting or message according to the current state.
    ///
    /// Afterwards, queued requests are flushed if the connection is ready.
    fn process_message<R: Read + Seek>(
        &mut self,
        message: &mut R,
    ) -> Result<Option<SyncIndex>, error::Error> {
        let sync = match self.state {
            State::Init => {
                let salt = codec::decode_greeting(message)?;
                if self.cluster_uuid.is_some() {
                    // Authentication waits for the IPROTO_ID reply, so keep the salt until then.
                    self.greeting_salt = Some(salt);
                    self.state = State::Id;
                    self.send_id_request()?;
                } else if let Some((user, pass)) = self.creds.clone() {
                    self.state = State::Auth;
                    self.send_auth_request(&user, &pass, &salt)?;
                } else {
                    self.state = State::Ready;
                }
                None
            }
            State::Id => {
                let header = codec::Header::decode(message)?;
                if header.iproto_type == IProtoType::Error as u32 {
                    let err = codec::decode_error(message, &header)?;
                    // 20 is ER_INVALID_MSGPACK; every other error aborts the handshake.
                    if err.code != 20 {
                        return Err(error::Error::Remote(err));
                    }
                    crate::say_warn!(
                        "IPROTO_ID: ignoring ER_INVALID_MSGPACK (code 20); vanilla Tarantool likely lacks iproto_key_type entry for CLUSTER_UUID"
                    );
                }
                // The salt is needed at most once more, so release it here.
                let salt = self.greeting_salt.take().unwrap_or_default();
                if let Some((user, pass)) = self.creds.clone() {
                    self.state = State::Auth;
                    self.send_auth_request(&user, &pass, &salt)?;
                } else {
                    self.state = State::Ready;
                }
                None
            }
            State::Auth => {
                let header = codec::Header::decode(message)?;
                self.handle_error_response(message, &header)?;
                self.state = State::Ready;
                None
            }
            State::Ready => {
                let header = codec::Header::decode(message)?;
                let response = if header.iproto_type == IProtoType::Error as u32 {
                    Err(codec::decode_error(message, &header)?)
                } else {
                    // Keep the undecoded body; it is decoded by `take_response`.
                    let mut buf = Vec::new();
                    message.read_to_end(&mut buf)?;
                    Ok(buf)
                };
                self.incoming.insert(header.sync, response);
                Some(header.sync)
            }
        };
        self.process_pending_data();
        Ok(sync)
    }

    /// Number of bytes currently ready to be sent.
    pub fn ready_outgoing_len(&self) -> usize {
        self.outgoing.len()
    }

    /// Takes all bytes ready to be sent, leaving the outgoing buffer empty.
    pub fn take_outgoing_data(&mut self) -> Vec<u8> {
        std::mem::take(&mut self.outgoing)
    }

    /// Once ready, moves requests queued during the handshake to the outgoing buffer.
    fn process_pending_data(&mut self) {
        if self.is_ready() {
            let mut pending_data = std::mem::take(&mut self.pending_outgoing);
            self.outgoing.append(&mut pending_data);
        }
    }
}
/// Appends one length-prefixed iproto message, starting at the cursor's
/// current position.
///
/// The frame is a msgpack `u32` holding the payload length, followed by the
/// payload produced by `request.encode(.., sync)`. On success the cursor
/// points just past the payload.
///
/// Callers are expected to position the cursor at the end of the vector;
/// bytes after the cursor would be overwritten.
///
/// # Errors
/// Fails if encoding fails, or if the payload is too large for a `u32`
/// length prefix. On failure the vector is truncated back to its original
/// length and the cursor is restored. This way a half-written frame (a
/// zero-length header plus a partial payload) never remains in the buffer
/// to corrupt subsequent messages.
pub(crate) fn write_to_buffer(
    buffer: &mut Cursor<&mut Vec<u8>>,
    sync: SyncIndex,
    request: &impl Request,
) -> Result<(), error::Error> {
    let msg_start_offset = buffer.position();
    let original_len = buffer.get_ref().len();
    let result = encode_framed_message(buffer, msg_start_offset, sync, request);
    if result.is_err() {
        buffer.get_mut().truncate(original_len);
        buffer.set_position(msg_start_offset);
    }
    result
}

/// Writes a placeholder length, then the payload, then patches in the real
/// length. It does not clean up on failure; `write_to_buffer` does that.
fn encode_framed_message(
    buffer: &mut Cursor<&mut Vec<u8>>,
    msg_start_offset: u64,
    sync: SyncIndex,
    request: &impl Request,
) -> Result<(), error::Error> {
    rmp::encode::write_u32(buffer, 0)?;
    let payload_start_offset = buffer.position();
    request.encode(buffer, sync)?;
    let payload_end_offset = buffer.position();
    let payload_len = payload_end_offset - payload_start_offset;
    // An `as u32` cast would silently write a wrong length for huge payloads.
    if payload_len > u64::from(u32::MAX) {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "iproto message payload does not fit into a u32 length prefix",
        )
        .into());
    }
    buffer.set_position(msg_start_offset);
    rmp::encode::write_u32(buffer, payload_len as u32)?;
    buffer.set_position(payload_end_offset);
    Ok(())
}
#[cfg(feature = "internal_test")]
mod tests {
    use super::*;

    /// Builds a 128-byte server greeting. The first 64-byte line is zero
    /// filled and the second 64-byte line carries a base64 salt; each line
    /// ends with `\n`.
    fn fake_greeting() -> Vec<u8> {
        const SALT: &[u8] = b"QK2HoFZGXTXBq2vFj7soCsHqTo6PGTF575ssUBAJLAI=";
        let mut greeting = vec![0u8; 63];
        greeting.push(b'\n');
        greeting.extend_from_slice(SALT);
        greeting.resize(127, 0);
        greeting.push(b'\n');
        greeting
    }

    #[crate::test(tarantool = "crate")]
    fn connection_established() {
        let mut conn = Protocol::new();
        assert!(!conn.is_ready());
        assert_eq!(conn.msg_size_hint, Some(128));
        assert_eq!(conn.read_size_hint(), 128);

        let mut greeting = Cursor::new(fake_greeting());
        conn.process_incoming(&mut greeting).unwrap();

        assert_eq!(conn.msg_size_hint, None);
        assert_eq!(conn.read_size_hint(), 5);
        assert!(conn.is_ready());
    }

    #[crate::test(tarantool = "crate")]
    fn send_bytes_generated() {
        let mut conn = Protocol::new();
        let mut greeting = Cursor::new(fake_greeting());
        conn.process_incoming(&mut greeting).unwrap();

        conn.send_request(&api::Ping).unwrap();
        assert!(conn.ready_outgoing_len() > 0);
    }
}