use std::io::{self, Cursor, Read, Seek, Write};
use std::os::raw::c_char;
use crate::auth::AuthMethod;
use crate::error::Error;
use crate::error::TarantoolError;
use crate::index::IteratorType;
use crate::msgpack;
use crate::network::protocol::ProtocolError;
use crate::tuple::{ToTupleBuffer, Tuple};
use super::SyncIndex;
/// Upper bound, in bytes, of a msgpack string header (marker byte plus up to
/// four length bytes). Used to presize buffers that hold a msgpack-encoded
/// auth string.
const MP_STR_MAX_HEADER_SIZE: usize = 5;
/// Integer keys used in iproto header and body maps.
pub mod iproto_key {
    // Header keys (written by `Header::encode`, read by `Header::decode`).
    pub const REQUEST_TYPE: u8 = 0x00;
    pub const SYNC: u8 = 0x01;
    pub const SCHEMA_VERSION: u8 = 0x05;
    // Request body keys for space/index operations.
    pub const SPACE_ID: u8 = 0x10;
    pub const INDEX_ID: u8 = 0x11;
    pub const LIMIT: u8 = 0x12;
    pub const OFFSET: u8 = 0x13;
    pub const ITERATOR: u8 = 0x14;
    pub const INDEX_BASE: u8 = 0x15;
    pub const KEY: u8 = 0x20;
    pub const TUPLE: u8 = 0x21;
    pub const FUNCTION_NAME: u8 = 0x22;
    pub const USER_NAME: u8 = 0x23;
    pub const EXPR: u8 = 0x27;
    pub const OPS: u8 = 0x28;
    // Response body keys: result rows and the legacy error message.
    pub const DATA: u8 = 0x30;
    pub const ERROR: u8 = 0x31;
    // SQL request/response keys.
    pub const SQL_TEXT: u8 = 0x40;
    pub const SQL_BIND: u8 = 0x41;
    pub const SQL_INFO: u8 = 0x42;
    pub const STMT_ID: u8 = 0x43;
    // Extended error payload in error responses (see `decode_extended_error`).
    pub const ERROR_EXT: u8 = 0x52;
    // Optional key of the `Id` request body (see `encode_id`).
    pub const CLUSTER_UUID: u8 = 0x5c;
}
use iproto_key::*;
// Request/response type codes carried under the `REQUEST_TYPE` header key.
//
// `Error` (bit 15) doubles as a flag: `Header::decode` treats any received type
// with this bit set as an error response and stores the remaining low bits as
// the error code.
crate::define_enum_with_introspection! {
    #[non_exhaustive]
    #[repr(C)]
    pub enum IProtoType {
        Ok = 0,
        Select = 1,
        Insert = 2,
        Replace = 3,
        Update = 4,
        Delete = 5,
        LegacyCall = 6,
        Auth = 7,
        Eval = 8,
        Upsert = 9,
        Call = 10,
        Execute = 11,
        Nop = 12,
        Prepare = 13,
        Begin = 14,
        Commit = 15,
        Rollback = 16,
        // Note: listed out of numeric order (`Id` = 73 precedes `Ping` = 64).
        Id = 73,
        Ping = 64,
        Error = 1 << 15,
    }
}
/// Writes a request header for `request_type` with the given `sync` index.
///
/// Delegates to [`Header::encode`], which emits only the request type and the
/// sync index; error code and schema version are zeroed and not written.
#[inline(always)]
pub fn encode_header(
    stream: &mut impl Write,
    sync: SyncIndex,
    request_type: IProtoType,
) -> Result<(), Error> {
    Header {
        sync,
        iproto_type: request_type as u32,
        error_code: 0,
        schema_version: 0,
    }
    .encode(stream)
}
/// Computes the 20-byte CHAP-SHA1 scramble for `password` and `salt`.
///
/// `scramble = sha1(password) XOR sha1(salt ++ sha1(sha1(password)))`
#[inline]
pub fn chap_sha1_prepare(password: impl AsRef<[u8]>, salt: &[u8; 20]) -> Vec<u8> {
    use sha1::{Digest as Sha1Digest, Sha1};

    // Hashes the concatenation of `parts` with a fresh SHA-1 state.
    let sha1_of = |parts: &[&[u8]]| {
        let mut hasher = Sha1::new();
        for part in parts {
            hasher.update(*part);
        }
        hasher.finalize()
    };

    let password_hash = sha1_of(&[password.as_ref()]);
    let double_hash = sha1_of(&[&password_hash[..]]);
    let salted = sha1_of(&[&salt[..], &double_hash[..]]);

    let scramble: Vec<u8> = password_hash
        .iter()
        .zip(salted.iter())
        .map(|(x, y)| x ^ y)
        .collect();
    debug_assert_eq!(scramble.len(), 20);
    scramble
}
#[inline]
pub fn chap_sha1_auth_data(password: &str, salt: &[u8; 20]) -> Vec<u8> {
let hashed_data = chap_sha1_prepare(password, salt);
let hashed_len = hashed_data.len();
let mut res = Vec::with_capacity(hashed_len + MP_STR_MAX_HEADER_SIZE);
rmp::encode::write_str_len(&mut res, hashed_len as _).expect("Can't fail for a Vec");
res.write_all(&hashed_data).expect("Can't fail for a Vec");
res
}
/// Prepares LDAP auth material: the password bytes, copied unchanged.
#[cfg(feature = "picodata")]
#[inline]
pub fn ldap_prepare(password: impl AsRef<[u8]>) -> Vec<u8> {
    Vec::from(password.as_ref())
}
/// Returns the LDAP auth payload: the password from [`ldap_prepare`] encoded as
/// a msgpack string.
#[cfg(feature = "picodata")]
#[inline]
pub fn ldap_auth_data(password: &str) -> Vec<u8> {
    let raw = ldap_prepare(password);
    let raw_len = raw.len();
    let mut out = Vec::with_capacity(raw_len + MP_STR_MAX_HEADER_SIZE);
    rmp::encode::write_str_len(&mut out, raw_len as u32).expect("Can't fail for a Vec");
    out.extend_from_slice(&raw);
    out
}
/// Computes the MD5 client password: `"md5" + hex(md5(hex(md5(password ++ user)) ++ salt))`.
///
/// Both hex renderings use lowercase digits.
#[cfg(feature = "picodata")]
#[inline]
pub fn md5_prepare(user: &str, password: impl AsRef<[u8]>, salt: &[u8; 4]) -> Vec<u8> {
    use md5::{Digest as Md5Digest, Md5};

    let mut inner = Md5::new();
    inner.update(password);
    inner.update(user);
    let shadow_pass = format!("{:x}", inner.finalize());

    let mut outer = Md5::new();
    outer.update(shadow_pass);
    outer.update(salt);
    format!("md5{:x}", outer.finalize()).into_bytes()
}
/// Returns the MD5 auth payload: the client password from [`md5_prepare`]
/// encoded as a msgpack string.
#[cfg(feature = "picodata")]
#[inline]
pub fn md5_auth_data(user: &str, password: &str, salt: &[u8; 4]) -> Vec<u8> {
    let client_pass = md5_prepare(user, password, salt);
    let mut payload = Vec::with_capacity(MP_STR_MAX_HEADER_SIZE + client_pass.len());
    rmp::encode::write_str_len(&mut payload, client_pass.len() as u32)
        .expect("Can't fail for a Vec");
    payload.extend_from_slice(&client_pass);
    payload
}
/// Encodes the body of an `Auth` request for `user` with the given `method`.
///
/// Body layout: `{USER_NAME: user, TUPLE: [method name, auth data]}`.
/// `salt` must hold at least as many bytes as the method needs (20 for
/// CHAP-SHA1, 4 for MD5); only the leading bytes are used.
///
/// # Errors
/// Fails if `salt` is too short, if the method is SCRAM-SHA256 (unsupported
/// over iproto), or if writing to `stream` fails.
pub fn encode_auth(
    stream: &mut impl Write,
    user: &str,
    password: &str,
    salt: &[u8],
    method: AuthMethod,
) -> Result<(), Error> {
    let auth_data = match method {
        AuthMethod::ChapSha1 => {
            let salt = salt
                .first_chunk()
                .ok_or_else(|| std::io::Error::other("bad salt length (expect 20)"))?;
            chap_sha1_auth_data(password, salt)
        }
        #[cfg(feature = "picodata")]
        AuthMethod::Ldap => ldap_auth_data(password),
        #[cfg(feature = "picodata")]
        AuthMethod::Md5 => {
            let salt = salt
                .first_chunk()
                .ok_or_else(|| std::io::Error::other("bad salt length (expect >= 4)"))?;
            md5_auth_data(user, password, salt)
        }
        #[cfg(feature = "picodata")]
        AuthMethod::ScramSha256 => {
            use crate::error::{BoxError, TarantoolErrorCode};
            return Err(BoxError::new(
                TarantoolErrorCode::UnknownAuthMethod,
                "scram-sha256 over iproto is not supported",
            )
            .into());
        }
    };

    rmp::encode::write_map_len(stream, 2)?;
    rmp::encode::write_pfix(stream, USER_NAME)?;
    rmp::encode::write_str(stream, user)?;
    rmp::encode::write_pfix(stream, TUPLE)?;
    rmp::encode::write_array_len(stream, 2)?;
    rmp::encode::write_str(stream, method.as_str())?;
    // `auth_data` is already a complete msgpack string; copy it verbatim.
    stream.write_all(&auth_data)?;
    Ok(())
}
/// Encodes the body of a `Ping` request: an empty msgpack map.
pub fn encode_ping(stream: &mut impl Write) -> Result<(), Error> {
    rmp::encode::write_map_len(stream, 0)
        .map(drop)
        .map_err(Into::into)
}
pub fn encode_id(stream: &mut impl Write, cluster_uuid: Option<&str>) -> Result<(), Error> {
use iproto_key::CLUSTER_UUID;
if let Some(uuid) = cluster_uuid {
rmp::encode::write_map_len(stream, 1)?;
rmp::encode::write_pfix(stream, CLUSTER_UUID)?;
rmp::encode::write_str(stream, uuid)?;
} else {
rmp::encode::write_map_len(stream, 0)?;
}
Ok(())
}
/// Encodes the body of an SQL `Execute` request:
/// `{SQL_TEXT: sql, SQL_BIND: bind_params}`.
pub fn encode_execute<P>(stream: &mut impl Write, sql: &str, bind_params: &P) -> Result<(), Error>
where
    P: ToTupleBuffer + ?Sized,
{
    use rmp::encode::{write_map_len, write_pfix, write_str};

    write_map_len(stream, 2)?;
    write_pfix(stream, SQL_TEXT)?;
    write_str(stream, sql)?;
    write_pfix(stream, SQL_BIND)?;
    bind_params.write_tuple_data(stream)?;
    Ok(())
}
pub fn encode_call<T>(stream: &mut impl Write, function_name: &str, args: &T) -> Result<(), Error>
where
T: ToTupleBuffer + ?Sized,
{
rmp::encode::write_map_len(stream, 2)?;
rmp::encode::write_pfix(stream, FUNCTION_NAME)?;
rmp::encode::write_str(stream, function_name)?;
rmp::encode::write_pfix(stream, TUPLE)?;
args.write_tuple_data(stream)?;
Ok(())
}
/// Encodes the body of an `Eval` request: `{EXPR: expression, TUPLE: args}`.
pub fn encode_eval<T>(stream: &mut impl Write, expression: &str, args: &T) -> Result<(), Error>
where
    T: ToTupleBuffer + ?Sized,
{
    use rmp::encode::{write_map_len, write_pfix, write_str};

    write_map_len(stream, 2)?;
    write_pfix(stream, EXPR)?;
    write_str(stream, expression)?;
    write_pfix(stream, TUPLE)?;
    args.write_tuple_data(stream)?;
    Ok(())
}
/// Encodes the body of a `Select` request.
///
/// Body layout (in this order): `SPACE_ID`, `INDEX_ID`, `LIMIT`, `OFFSET`,
/// `ITERATOR` (all `u32`), then `KEY` serialized as a tuple.
#[allow(clippy::too_many_arguments)]
pub fn encode_select<K>(
    stream: &mut impl Write,
    space_id: u32,
    index_id: u32,
    limit: u32,
    offset: u32,
    iterator_type: IteratorType,
    key: &K,
) -> Result<(), Error>
where
    K: ToTupleBuffer + ?Sized,
{
    let scalar_fields = [
        (SPACE_ID, space_id),
        (INDEX_ID, index_id),
        (LIMIT, limit),
        (OFFSET, offset),
        (ITERATOR, iterator_type as u32),
    ];

    // Scalar entries plus the trailing KEY entry.
    rmp::encode::write_map_len(stream, scalar_fields.len() as u32 + 1)?;
    for (field, value) in scalar_fields {
        rmp::encode::write_pfix(stream, field)?;
        rmp::encode::write_u32(stream, value)?;
    }
    rmp::encode::write_pfix(stream, KEY)?;
    key.write_tuple_data(stream)?;
    Ok(())
}
/// Encodes the body of an `Insert` request: `{SPACE_ID: space_id, TUPLE: value}`.
pub fn encode_insert<T>(stream: &mut impl Write, space_id: u32, value: &T) -> Result<(), Error>
where
    T: ToTupleBuffer + ?Sized,
{
    use rmp::encode::{write_map_len, write_pfix, write_u32};

    write_map_len(stream, 2)?;
    write_pfix(stream, SPACE_ID)?;
    write_u32(stream, space_id)?;
    write_pfix(stream, TUPLE)?;
    value.write_tuple_data(stream)?;
    Ok(())
}
pub fn encode_replace<T>(stream: &mut impl Write, space_id: u32, value: &T) -> Result<(), Error>
where
T: ToTupleBuffer + ?Sized,
{
rmp::encode::write_map_len(stream, 2)?;
rmp::encode::write_pfix(stream, SPACE_ID)?;
rmp::encode::write_u32(stream, space_id)?;
rmp::encode::write_pfix(stream, TUPLE)?;
value.write_tuple_data(stream)?;
Ok(())
}
/// Encodes the body of an `Update` request:
/// `{SPACE_ID, INDEX_ID, KEY: key, TUPLE: ops}`.
///
/// Note that the update operations are written under the `TUPLE` key.
pub fn encode_update<K, Op>(
    stream: &mut impl Write,
    space_id: u32,
    index_id: u32,
    key: &K,
    ops: &Op,
) -> Result<(), Error>
where
    K: ToTupleBuffer + ?Sized,
    Op: ToTupleBuffer + ?Sized,
{
    use rmp::encode::{write_map_len, write_pfix, write_u32};

    write_map_len(stream, 4)?;
    for (field, id) in [(SPACE_ID, space_id), (INDEX_ID, index_id)] {
        write_pfix(stream, field)?;
        write_u32(stream, id)?;
    }
    write_pfix(stream, KEY)?;
    key.write_tuple_data(stream)?;
    write_pfix(stream, TUPLE)?;
    ops.write_tuple_data(stream)?;
    Ok(())
}
pub fn encode_upsert<T, Op>(
stream: &mut impl Write,
space_id: u32,
index_id: u32,
value: &T,
ops: &Op,
) -> Result<(), Error>
where
T: ToTupleBuffer + ?Sized,
Op: ToTupleBuffer + ?Sized,
{
rmp::encode::write_map_len(stream, 4)?;
rmp::encode::write_pfix(stream, SPACE_ID)?;
rmp::encode::write_u32(stream, space_id)?;
rmp::encode::write_pfix(stream, INDEX_BASE)?;
rmp::encode::write_u32(stream, index_id)?;
rmp::encode::write_pfix(stream, OPS)?;
ops.write_tuple_data(stream)?;
rmp::encode::write_pfix(stream, TUPLE)?;
value.write_tuple_data(stream)?;
Ok(())
}
/// Encodes the body of a `Delete` request: `{SPACE_ID, INDEX_ID, KEY: key}`.
pub fn encode_delete<K>(
    stream: &mut impl Write,
    space_id: u32,
    index_id: u32,
    key: &K,
) -> Result<(), Error>
where
    K: ToTupleBuffer + ?Sized,
{
    use rmp::encode::{write_map_len, write_pfix, write_u32};

    write_map_len(stream, 3)?;
    for (field, id) in [(SPACE_ID, space_id), (INDEX_ID, index_id)] {
        write_pfix(stream, field)?;
        write_u32(stream, id)?;
    }
    write_pfix(stream, KEY)?;
    key.write_tuple_data(stream)?;
    Ok(())
}
/// An iproto message header.
///
/// [`Header::encode`] writes only `iproto_type` and `sync`. [`Header::decode`]
/// fills every field and requires the sync, request type and schema version
/// keys to be present.
#[derive(Debug)]
pub struct Header {
    /// Sync index used to match requests with responses.
    pub sync: SyncIndex,
    /// Raw request/response type. When a decoded type has the
    /// `IProtoType::Error` bit set, this holds exactly `IProtoType::Error as u32`.
    pub iproto_type: u32,
    /// Error code taken from the low bits of an error response type; 0 otherwise.
    pub error_code: u32,
    /// Value of the `SCHEMA_VERSION` header key (0 when encoding requests).
    pub schema_version: u64,
}
impl Header {
pub fn encode(&self, stream: &mut impl Write) -> Result<(), Error> {
rmp::encode::write_map_len(stream, 2)?;
rmp::encode::write_pfix(stream, REQUEST_TYPE)?;
rmp::encode::write_uint(stream, self.iproto_type as _)?;
rmp::encode::write_pfix(stream, SYNC)?;
rmp::encode::write_uint(stream, self.sync.0)?;
Ok(())
}
#[inline(always)]
pub fn encode_from_parts(
stream: &mut impl Write,
sync: SyncIndex,
request_type: IProtoType,
) -> Result<(), Error> {
encode_header(stream, sync, request_type)
}
pub fn decode(stream: &mut (impl Read + Seek)) -> Result<Header, Error> {
let mut sync: Option<u64> = None;
let mut iproto_type: Option<u32> = None;
let mut error_code: u32 = 0;
let mut schema_version: Option<u64> = None;
let map_len = rmp::decode::read_map_len(stream)?;
for _ in 0..map_len {
let key = rmp::decode::read_pfix(stream)?;
match key {
REQUEST_TYPE => {
let r#type: u32 = rmp::decode::read_int(stream)?;
const IPROTO_TYPE_ERROR: u32 = IProtoType::Error as _;
if (r#type & IPROTO_TYPE_ERROR) != 0 {
iproto_type = Some(IPROTO_TYPE_ERROR);
error_code = r#type & !IPROTO_TYPE_ERROR;
} else {
iproto_type = Some(r#type);
}
}
SYNC => sync = Some(rmp::decode::read_int(stream)?),
SCHEMA_VERSION => schema_version = Some(rmp::decode::read_int(stream)?),
_ => msgpack::skip_value(stream)?,
}
}
if sync.is_none() || iproto_type.is_none() || schema_version.is_none() {
return Err(io::Error::from(io::ErrorKind::InvalidData).into());
}
Ok(Header {
sync: SyncIndex(sync.unwrap()),
iproto_type: iproto_type.unwrap(),
error_code,
schema_version: schema_version.unwrap(),
})
}
}
/// A response split into its decoded [`Header`] and a payload of type `T`.
pub struct Response<T> {
    pub header: Header,
    pub payload: T,
}
/// Decodes a message header from `stream`; see [`Header::decode`].
#[inline(always)]
pub fn decode_header(stream: &mut (impl Read + Seek)) -> Result<Header, Error> {
    let header = Header::decode(stream)?;
    Ok(header)
}
/// Keys of the map stored under `ERROR_EXT` (see `decode_extended_error`).
mod extended_error_keys {
    /// Array of error stack nodes.
    pub const STACK: u8 = 0;
}
/// Keys of a single error stack node map (see `decode_error_stack_node`).
mod error_field {
    pub const TYPE: u8 = 0x00;
    pub const FILE: u8 = 0x01;
    pub const LINE: u8 = 0x02;
    pub const MESSAGE: u8 = 0x03;
    pub const ERRNO: u8 = 0x04;
    pub const CODE: u8 = 0x05;
    /// Additional fields, decoded with `rmp_serde`.
    pub const FIELDS: u8 = 0x06;
}
pub fn decode_error(stream: &mut impl Read, header: &Header) -> Result<TarantoolError, Error> {
let mut error = TarantoolError::default();
let map_len = rmp::decode::read_map_len(stream)?;
for _ in 0..map_len {
let key = rmp::decode::read_pfix(stream)?;
match key {
ERROR => {
let message = decode_string(stream)?;
error.message = Some(message.into());
error.code = header.error_code;
}
ERROR_EXT => {
if let Some(e) = decode_extended_error(stream)? {
error = e;
} else {
crate::say_verbose!("empty ERROR_EXT field");
}
}
_ => {
crate::say_verbose!("unhandled iproto key {key} when decoding error");
}
}
}
if error.message.is_none() {
return Err(ProtocolError::ResponseFieldNotFound {
key: "ERROR",
context: "required for error responses",
}
.into());
}
Ok(error)
}
pub fn decode_extended_error(stream: &mut impl Read) -> Result<Option<TarantoolError>, Error> {
let extended_error_n_fields = rmp::decode::read_map_len(stream)? as usize;
if extended_error_n_fields == 0 {
return Ok(None);
}
let mut error_info = None;
for _ in 0..extended_error_n_fields {
let key = rmp::decode::read_pfix(stream)?;
match key {
extended_error_keys::STACK => {
if error_info.is_some() {
crate::say_verbose!("duplicate error stack in response");
}
let error_stack_len = rmp::decode::read_array_len(stream)? as usize;
if error_stack_len == 0 {
continue;
}
let mut stack_nodes = Vec::with_capacity(error_stack_len);
for _ in 0..error_stack_len {
stack_nodes.push(decode_error_stack_node(stream)?);
}
for mut node in stack_nodes.into_iter().rev() {
if let Some(next_node) = error_info {
node.cause = Some(Box::new(next_node));
}
error_info = Some(node);
}
}
_ => {
crate::say_verbose!("unknown extended error key {key}");
}
}
}
Ok(error_info)
}
pub fn decode_error_stack_node(mut stream: &mut impl Read) -> Result<TarantoolError, Error> {
let mut res = TarantoolError::default();
let map_len = rmp::decode::read_map_len(stream)? as usize;
for _ in 0..map_len {
let key = rmp::decode::read_pfix(stream)?;
match key {
error_field::TYPE => {
res.error_type = Some(decode_string(stream)?.into_boxed_str());
}
error_field::FILE => {
res.file = Some(decode_string(stream)?.into_boxed_str());
}
error_field::LINE => {
res.line = Some(rmp::decode::read_int(stream)?);
}
error_field::MESSAGE => {
res.message = Some(decode_string(stream)?.into_boxed_str());
}
error_field::ERRNO => {
let n = rmp::decode::read_int(stream)?;
if n != 0 {
res.errno = Some(n);
}
}
error_field::CODE => {
res.code = rmp::decode::read_int(stream)?;
}
error_field::FIELDS => match rmp_serde::from_read(&mut stream) {
Ok(f) => {
res.fields = f;
}
Err(e) => {
crate::say_verbose!("failed decoding error fields: {e}");
}
},
_ => {
crate::say_verbose!("unexpected error field {key}");
}
}
}
Ok(res)
}
pub fn decode_string(stream: &mut impl Read) -> Result<String, Error> {
let len = rmp::decode::read_str_len(stream)? as usize;
let mut str_buf = vec![0u8; len];
stream.read_exact(&mut str_buf)?;
let res = String::from_utf8(str_buf)?;
Ok(res)
}
pub fn decode_greeting(stream: &mut impl Read) -> Result<Vec<u8>, Error> {
let mut buf = [0; 128];
stream.read_exact(&mut buf)?;
let salt = base64::decode(&buf[64..108]).unwrap();
Ok(salt)
}
pub fn decode_call(buffer: &mut Cursor<Vec<u8>>) -> Result<Tuple, Error> {
let payload_len = rmp::decode::read_map_len(buffer)?;
for _ in 0..payload_len {
let key = rmp::decode::read_pfix(buffer)?;
match key {
DATA => {
return decode_tuple(buffer);
}
_ => {
msgpack::skip_value(buffer)?;
}
};
}
Err(ProtocolError::ResponseFieldNotFound {
key: "DATA",
context: "required for CALL/EVAL responses",
}
.into())
}
pub fn decode_multiple_rows(buffer: &mut Cursor<Vec<u8>>) -> Result<Vec<Tuple>, Error> {
let payload_len = rmp::decode::read_map_len(buffer)?;
for _ in 0..payload_len {
let key = rmp::decode::read_pfix(buffer)?;
match key {
DATA => {
let items_count = rmp::decode::read_array_len(buffer)? as usize;
let mut result = Vec::with_capacity(items_count);
for _ in 0..items_count {
result.push(decode_tuple(buffer)?);
}
return Ok(result);
}
_ => {
msgpack::skip_value(buffer)?;
}
};
}
Ok(vec![])
}
/// Decodes the first row of the `DATA` array in a response body.
///
/// Returns `None` when `DATA` is empty or absent. Any rows after the first are
/// left unread.
pub fn decode_single_row(buffer: &mut Cursor<Vec<u8>>) -> Result<Option<Tuple>, Error> {
    let entry_count = rmp::decode::read_map_len(buffer)?;
    for _ in 0..entry_count {
        match rmp::decode::read_pfix(buffer)? {
            DATA => {
                let row_count = rmp::decode::read_array_len(buffer)? as usize;
                if row_count == 0 {
                    return Ok(None);
                }
                return decode_tuple(buffer).map(Some);
            }
            _ => msgpack::skip_value(buffer)?,
        }
    }
    Ok(None)
}
/// Builds a [`Tuple`] from the msgpack value at the cursor's current position,
/// advancing the cursor past that value.
pub fn decode_tuple(buffer: &mut Cursor<Vec<u8>>) -> Result<Tuple, Error> {
    let start = buffer.position();
    msgpack::skip_value(buffer)?;
    let len = buffer.position() - start;
    let base = buffer.get_mut().as_mut_ptr();
    // SAFETY: `skip_value` just advanced the cursor over `len` bytes starting
    // at `start`, so `start..start + len` lies inside the vector. Assumes
    // `Tuple::from_raw_data` copies the bytes rather than retaining the
    // pointer, since the vector is still owned by `buffer` — TODO confirm.
    let tuple = unsafe { Tuple::from_raw_data(base.add(start as usize) as *mut c_char, len as u32) };
    Ok(tuple)
}
/// Returns the raw bytes of the msgpack value at the cursor's position and
/// advances the cursor past it.
pub fn value_slice(cursor: &mut Cursor<impl AsRef<[u8]>>) -> crate::Result<&[u8]> {
    let begin = cursor.position() as usize;
    msgpack::skip_value(cursor)?;
    let end = cursor.position() as usize;
    let bytes: &[u8] = cursor.get_ref().as_ref();
    Ok(&bytes[begin..end])
}