use std::collections::HashMap;
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
use crate::batch::batch_executor::SharedSlice;
use crate::cluster::Node;
use crate::commands::{self, Command};
use crate::errors::{ErrorKind, Result, ResultExt};
use crate::net::Connection;
use crate::policy::{BatchPolicy, Policy, PolicyLike};
use crate::{value, BatchRead, Record, ResultCode, Value};
/// One record parsed out of a batch response, tagged with the position of
/// the `BatchRead` entry it belongs to so the result can be written back
/// into the right slot.
struct BatchRecord {
    /// Index of the corresponding entry in the original batch request.
    batch_index: usize,
    /// `None` when the server reported the key as not found.
    record: Option<Record>,
}
/// Command that executes a portion of a batch-read request against a single
/// cluster node and writes the returned records back into the shared batch.
pub struct BatchReadCommand<'a, 'b> {
    /// Batch policy governing retries, timeouts and request encoding.
    policy: &'b BatchPolicy,
    /// The node this command is pinned to.
    pub node: Arc<Node>,
    /// Shared view of all batch entries; matching records are stored into it.
    batch_reads: SharedSlice<BatchRead<'a>>,
    /// Indices into `batch_reads` that this command is responsible for.
    offsets: Vec<usize>,
}
impl<'a, 'b> BatchReadCommand<'a, 'b> {
pub fn new(
policy: &'b BatchPolicy,
node: Arc<Node>,
batch_reads: SharedSlice<BatchRead<'a>>,
offsets: Vec<usize>,
) -> Self {
BatchReadCommand {
policy,
node,
batch_reads,
offsets,
}
}
pub fn execute(&mut self) -> Result<()> {
let mut iterations = 0;
let base_policy = self.policy.base();
let deadline = base_policy.deadline();
loop {
iterations += 1;
if let Some(max_retries) = base_policy.max_retries() {
if iterations > max_retries + 1 {
bail!(ErrorKind::Connection(format!(
"Timeout after {} tries",
iterations
)));
}
}
if iterations > 1 {
if let Some(sleep_between_retries) = base_policy.sleep_between_retries() {
thread::sleep(sleep_between_retries);
}
}
if let Some(deadline) = deadline {
if Instant::now() > deadline {
break;
}
}
let node = match self.get_node() {
Ok(node) => node,
Err(_) => continue, };
let mut conn = match node.get_connection(base_policy.timeout()) {
Ok(conn) => conn,
Err(err) => {
warn!("Node {}: {}", node, err);
continue;
}
};
self.prepare_buffer(&mut conn)
.chain_err(|| "Failed to prepare send buffer")?;
self.write_timeout(&mut conn, base_policy.timeout())
.chain_err(|| "Failed to set timeout for send buffer")?;
if let Err(err) = self.write_buffer(&mut conn) {
conn.invalidate();
warn!("Node {}: {}", node, err);
continue;
}
if let Err(err) = self.parse_result(&mut conn) {
if !commands::keep_connection(&err) {
conn.invalidate();
}
return Err(err);
}
return Ok(());
}
bail!(ErrorKind::Connection("Timeout".to_string()))
}
fn parse_group(&mut self, conn: &mut Connection, size: usize) -> Result<bool> {
while conn.bytes_read() < size {
conn.read_buffer(commands::buffer::MSG_REMAINING_HEADER_SIZE as usize)?;
match self.parse_record(conn)? {
None => return Ok(false),
Some(batch_record) => {
let batch_read = self
.batch_reads
.get_mut(batch_record.batch_index)
.expect("Invalid batch index");
batch_read.record = batch_record.record;
}
}
}
Ok(true)
}
fn parse_record(&mut self, conn: &mut Connection) -> Result<Option<BatchRecord>> {
let found_key = match ResultCode::from(conn.buffer.read_u8(Some(5))?) {
ResultCode::Ok => true,
ResultCode::KeyNotFoundError => false,
rc => bail!(ErrorKind::ServerError(rc)),
};
let info3 = conn.buffer.read_u8(Some(3))?;
if info3 & commands::buffer::INFO3_LAST == commands::buffer::INFO3_LAST {
return Ok(None);
}
conn.buffer.skip(6)?;
let generation = conn.buffer.read_u32(None)?;
let expiration = conn.buffer.read_u32(None)?;
let batch_index = conn.buffer.read_u32(None)?;
let field_count = conn.buffer.read_u16(None)? as usize; let op_count = conn.buffer.read_u16(None)? as usize;
let key = commands::StreamCommand::parse_key(conn, field_count)?;
let record = if found_key {
let mut bins: HashMap<String, Value> = HashMap::with_capacity(op_count);
for _ in 0..op_count {
conn.read_buffer(8)?;
let op_size = conn.buffer.read_u32(None)? as usize;
conn.buffer.skip(1)?;
let particle_type = conn.buffer.read_u8(None)?;
conn.buffer.skip(1)?;
let name_size = conn.buffer.read_u8(None)? as usize;
conn.read_buffer(name_size)?;
let name = conn.buffer.read_str(name_size)?;
let particle_bytes_size = op_size - (4 + name_size);
conn.read_buffer(particle_bytes_size)?;
let value =
value::bytes_to_particle(particle_type, &mut conn.buffer, particle_bytes_size)?;
bins.insert(name, value);
}
Some(Record::new(Some(key), bins, generation, expiration))
} else {
None
};
Ok(Some(BatchRecord {
batch_index: batch_index as usize,
record,
}))
}
}
impl<'a, 'b> commands::Command for BatchReadCommand<'a, 'b> {
    /// Stamps the socket timeout into the send buffer.
    fn write_timeout(&mut self, conn: &mut Connection, timeout: Option<Duration>) -> Result<()> {
        conn.buffer.write_timeout(timeout);
        Ok(())
    }

    /// Flushes the prepared request buffer out to the wire.
    fn write_buffer(&mut self, conn: &mut Connection) -> Result<()> {
        conn.flush()
    }

    /// Serializes the batch-read request for this command's offsets into the
    /// connection's buffer.
    fn prepare_buffer(&mut self, conn: &mut Connection) -> Result<()> {
        let reads = self.batch_reads.clone();
        conn.buffer.set_batch_read(self.policy, reads, &self.offsets)
    }

    /// Batch commands are pinned up front to a single node.
    fn get_node(&self) -> Result<Arc<Node>> {
        Ok(Arc::clone(&self.node))
    }

    /// Reads response groups until `parse_group` reports the final record.
    fn parse_result(&mut self, conn: &mut Connection) -> Result<()> {
        loop {
            conn.read_buffer(8)?;
            let group_size = conn.buffer.read_msg_size(None)?;
            conn.bookmark();
            if group_size > 0 {
                // `parse_group` yields false once the last-record marker
                // has been consumed.
                let more = self.parse_group(conn, group_size as usize)?;
                if !more {
                    return Ok(());
                }
            }
        }
    }
}