#![macro_use]
#[cfg(feature = "cache-aio")]
use crate::cmd::CommandCacheConfig;
use crate::cmd::{cmd, cmd_len, Cmd};
use crate::connection::ConnectionLike;
use crate::errors::ErrorKind;
use crate::types::{
from_owned_redis_value, FromRedisValue, HashSet, RedisResult, ToRedisArgs, Value,
};
/// A batch of commands that is sent to the server in one round trip.
///
/// Commands are queued with the builder methods generated by
/// `implement_pipeline_commands!` (`cmd`, `arg`, `ignore`, ...) and executed with
/// `query` / `exec` (or their async counterparts).
#[derive(Clone)]
pub struct Pipeline {
    // Queued commands, in the order they are encoded and sent.
    pub(crate) commands: Vec<Cmd>,
    // When `true`, the commands are wrapped in `MULTI` / `EXEC` on execution.
    pub(crate) transaction_mode: bool,
    // Indices into `commands` whose replies are dropped from the final result.
    pub(crate) ignored_commands: HashSet<usize>,
}
impl Pipeline {
pub fn new() -> Pipeline {
Self::with_capacity(0)
}
pub fn with_capacity(capacity: usize) -> Pipeline {
Pipeline {
commands: Vec::with_capacity(capacity),
transaction_mode: false,
ignored_commands: HashSet::new(),
}
}
#[inline]
pub fn atomic(&mut self) -> &mut Pipeline {
self.transaction_mode = true;
self
}
pub fn is_transaction(&self) -> bool {
self.transaction_mode
}
pub fn get_packed_pipeline(&self) -> Vec<u8> {
encode_pipeline(&self.commands, self.transaction_mode)
}
pub fn len(&self) -> usize {
self.commands.len()
}
pub fn is_empty(&self) -> bool {
self.commands.is_empty()
}
#[inline]
pub fn query<T: FromRedisValue>(&self, con: &mut dyn ConnectionLike) -> RedisResult<T> {
if !con.supports_pipelining() {
fail!((
ErrorKind::Client,
"This connection does not support pipelining."
));
}
let response = if self.commands.is_empty() {
vec![]
} else if self.transaction_mode {
con.req_packed_commands(
&encode_pipeline(&self.commands, true),
self.commands.len() + 1,
1,
)?
} else {
con.req_packed_commands(
&encode_pipeline(&self.commands, false),
0,
self.commands.len(),
)?
};
self.complete_request(response)
}
#[inline]
#[cfg(feature = "aio")]
pub async fn query_async<T: FromRedisValue>(
&self,
con: &mut impl crate::aio::ConnectionLike,
) -> RedisResult<T> {
let response = if self.commands.is_empty() {
vec![]
} else if self.transaction_mode {
con.req_packed_commands(self, self.commands.len() + 1, 1)
.await?
} else {
con.req_packed_commands(self, 0, self.commands.len())
.await?
};
self.complete_request(response)
}
#[inline]
pub fn exec(&self, con: &mut dyn ConnectionLike) -> RedisResult<()> {
self.query::<()>(con)
}
#[cfg(feature = "aio")]
pub async fn exec_async(&self, con: &mut impl crate::aio::ConnectionLike) -> RedisResult<()> {
self.query_async::<()>(con).await
}
fn complete_request<T: FromRedisValue>(&self, mut response: Vec<Value>) -> RedisResult<T> {
let response = if self.is_transaction() {
match response.pop() {
Some(Value::Nil) => {
return Ok(from_owned_redis_value(Value::Nil)?);
}
Some(Value::Array(items)) => items,
_ => {
return Err((
ErrorKind::UnexpectedReturnType,
"Invalid response when parsing multi response",
)
.into());
}
}
} else {
response
};
self.compose_response(response)
}
}
/// Serializes `cmds` into a freshly allocated buffer, wrapping them in
/// `MULTI` / `EXEC` when `atomic` is set.
fn encode_pipeline(cmds: &[Cmd], atomic: bool) -> Vec<u8> {
    let mut buffer = Vec::new();
    write_pipeline(&mut buffer, cmds, atomic);
    buffer
}
/// Appends the packed encoding of `cmds` to `rv`.
///
/// When `atomic` is set the commands are surrounded by `MULTI` and `EXEC`.
/// The exact number of bytes needed is reserved up front so the individual
/// writes never reallocate.
fn write_pipeline(rv: &mut Vec<u8>, cmds: &[Cmd], atomic: bool) {
    let body_len: usize = cmds.iter().map(cmd_len).sum();

    // Optional transaction envelope: (opening MULTI, closing EXEC).
    let envelope = if atomic {
        Some((cmd("MULTI"), cmd("EXEC")))
    } else {
        None
    };
    let envelope_len = match &envelope {
        Some((multi, exec)) => cmd_len(multi) + cmd_len(exec),
        None => 0,
    };
    rv.reserve(envelope_len + body_len);

    if let Some((multi, _)) = &envelope {
        multi.write_packed_command_preallocated(rv);
    }
    for command in cmds {
        command.write_packed_command_preallocated(rv);
    }
    if let Some((_, exec)) = &envelope {
        exec.write_packed_command_preallocated(rv);
    }
}
// Generates the command-building API shared by pipeline-like types, plus a
// `Default` impl.
//
// The target type must provide a `commands: Vec<Cmd>` field, an
// `ignored_commands: HashSet<usize>` field and an inherent `new()` constructor.
macro_rules! implement_pipeline_commands {
    ($struct_name:ident) => {
        impl $struct_name {
            /// Appends an already-built command to the pipeline.
            #[inline]
            pub fn add_command(&mut self, cmd: Cmd) -> &mut Self {
                self.commands.push(cmd);
                self
            }

            /// Starts a new command named `name`; follow up with
            /// [`arg`](Self::arg) to add its arguments.
            #[inline]
            pub fn cmd(&mut self, name: &str) -> &mut Self {
                self.add_command(cmd(name))
            }

            /// Iterates over the queued commands in insertion order.
            pub fn cmd_iter(&self) -> impl Iterator<Item = &Cmd> {
                self.commands.iter()
            }

            /// Marks the most recently added command so that its reply is dropped
            /// from the pipeline's result. A server error from an ignored command
            /// still fails the pipeline. Does nothing if no command was added yet.
            #[inline]
            pub fn ignore(&mut self) -> &mut Self {
                if let Some(last) = self.commands.len().checked_sub(1) {
                    self.ignored_commands.insert(last);
                }
                self
            }

            /// Appends `arg` to the most recently added command.
            ///
            /// # Panics
            ///
            /// Panics if no command has been added yet.
            #[inline]
            pub fn arg<T: ToRedisArgs>(&mut self, arg: T) -> &mut Self {
                self.get_last_command().arg(arg);
                self
            }

            /// Removes every queued command and every ignore mark so the value can
            /// be reused while keeping its allocations. Other settings are not reset.
            #[inline]
            pub fn clear(&mut self) {
                self.commands.clear();
                self.ignored_commands.clear();
            }

            /// Returns the most recently added command.
            ///
            /// # Panics
            ///
            /// Panics with "No command on stack" if no command has been added yet.
            #[inline]
            fn get_last_command(&mut self) -> &mut Cmd {
                self.commands.last_mut().expect("No command on stack")
            }

            /// Drops the replies whose position matches an ignored command.
            fn filter_ignored_results(&self, resp: Vec<Value>) -> Vec<Value> {
                resp.into_iter()
                    .enumerate()
                    .filter_map(|(index, result)| {
                        (!self.ignored_commands.contains(&index)).then_some(result)
                    })
                    .collect()
            }

            /// Converts per-command replies into `T`.
            ///
            /// If any reply, including the reply of an ignored command, is a
            /// `Value::ServerError`, the whole call fails with a pipeline error that
            /// records each failing reply's index. Otherwise the ignored replies are
            /// dropped and the remainder is parsed as a single array.
            fn compose_response<T: FromRedisValue>(&self, response: Vec<Value>) -> RedisResult<T> {
                let server_errors: Vec<_> = response
                    .iter()
                    .enumerate()
                    .filter_map(|(index, value)| match value {
                        Value::ServerError(error) => Some((index, error.clone())),
                        _ => None,
                    })
                    .collect();
                if server_errors.is_empty() {
                    Ok(from_owned_redis_value(
                        Value::Array(self.filter_ignored_results(response)).extract_error()?,
                    )?)
                } else {
                    Err(crate::RedisError::pipeline(server_errors))
                }
            }
        }

        impl Default for $struct_name {
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
implement_pipeline_commands!(Pipeline);

impl Pipeline {
    /// Attaches a client-side cache configuration to the most recently added command.
    ///
    /// # Panics
    ///
    /// Panics if no command has been added to the pipeline yet.
    #[cfg(feature = "cache-aio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "cache-aio")))]
    pub fn set_cache_config(&mut self, command_cache_config: CommandCacheConfig) -> &mut Self {
        self.get_last_command()
            .set_cache_config(command_cache_config);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        errors::{Repr, ServerError},
        pipe, ServerErrorKind,
    };

    /// Four commands; the second (index 1) and third (index 2) are ignored.
    fn test_pipe() -> Pipeline {
        let mut built = pipe();
        built.cmd("FOO");
        built.cmd("BAR").ignore();
        built.cmd("baz").ignore();
        built.cmd("barvaz");
        built
    }

    /// An arbitrary server error reply used to simulate failing commands.
    fn server_error() -> Value {
        let repr = Repr::Known {
            kind: ServerErrorKind::CrossSlot,
            detail: None,
        };
        Value::ServerError(ServerError(repr))
    }

    #[test]
    fn test_pipeline_passes_values_only_from_non_ignored_commands() {
        let responses = vec![Value::Int(1), Value::Int(2), Value::Int(3), Value::Okay];
        let kept = vec![String::from("1"), String::from("OK")];
        assert_eq!(test_pipe().complete_request(responses), Ok(kept));
    }

    #[test]
    fn test_pipeline_passes_errors_from_ignored_commands() {
        let responses = vec![Value::Okay, server_error(), Value::Okay, server_error()];
        let message = test_pipe()
            .compose_response::<Vec<Value>>(responses)
            .unwrap_err()
            .to_string();
        for needle in ["Index 1", "Index 3"] {
            assert!(message.contains(needle), "{message}");
        }
    }
}