use tracing::{debug, trace};
use crate::{
common::{
ErrorSpecific, Id, PutRequest, PutRequestSpecific, RequestSpecific, RequestTypeSpecific,
},
Node,
};
use super::socket::KrpcSocket;
/// State of a single PUT (store) query sent to a set of remote nodes.
///
/// Tracks the transaction ids of the requests that were sent, how many nodes
/// answered with success, and a tally of error responses grouped by error code.
#[derive(Debug)]
pub struct PutQuery {
    /// Target id of this PUT query.
    pub target: Id,
    /// Number of success responses received so far (incremented by `success`).
    stored_at: u8,
    /// Transaction ids returned by `KrpcSocket::request`, one per request sent.
    inflight_requests: Vec<u32>,
    /// The PUT payload; cloned into every outgoing request.
    pub request: PutRequestSpecific,
    /// `(count, error)` pairs, one per distinct error code, kept ordered by
    /// descending count by `error`.
    errors: Vec<(u8, ErrorSpecific)>,
    /// Nodes contacted in addition to the closest nodes passed to `start`.
    extra_nodes: Box<[Node]>,
}
impl PutQuery {
pub fn new(target: Id, request: PutRequestSpecific, extra_nodes: Option<Box<[Node]>>) -> Self {
Self {
target,
stored_at: 0,
inflight_requests: Vec::new(),
request,
errors: Vec::new(),
extra_nodes: extra_nodes.unwrap_or(Box::new([])),
}
}
pub fn start(
&mut self,
socket: &mut KrpcSocket,
closest_nodes: &[Node],
) -> Result<(), PutError> {
if self.started() {
panic!("should not call PutQuery::start() twice");
};
let target = self.target;
trace!(?target, "PutQuery start");
if closest_nodes.is_empty() {
Err(PutQueryError::NoClosestNodes)?;
}
if closest_nodes.len() > u8::MAX as usize {
panic!("should not send PUT query to more than 256 nodes")
}
for node in closest_nodes.iter().chain(self.extra_nodes.iter()) {
if let Some(token) = node.token() {
let tid = socket.request(
node.address(),
RequestSpecific {
requester_id: Id::random(),
request_type: RequestTypeSpecific::Put(PutRequest {
token,
put_request_type: self.request.clone(),
}),
},
);
self.inflight_requests.push(tid);
}
}
Ok(())
}
pub fn started(&self) -> bool {
!self.inflight_requests.is_empty()
}
pub fn inflight(&self, tid: u32) -> bool {
self.inflight_requests.contains(&tid)
}
pub fn success(&mut self) {
debug!(target = ?self.target, "PutQuery got success response");
self.stored_at += 1
}
pub fn error(&mut self, error: ErrorSpecific) {
debug!(target = ?self.target, ?error, "PutQuery got error");
if let Some(pos) = self
.errors
.iter()
.position(|(_, err)| error.code == err.code)
{
self.errors[pos].0 += 1;
let mut i = pos;
while i > 0 && self.errors[i].0 > self.errors[i - 1].0 {
self.errors.swap(i, i - 1);
i -= 1;
}
} else {
self.errors.push((1, error));
}
}
pub fn tick(&mut self, socket: &KrpcSocket) -> Result<bool, PutError> {
if self.inflight_requests.is_empty() {
return Ok(false);
}
if let Some(most_common_error) = self.majority_nodes_rejected_put_mutable() {
let target = self.target;
debug!(
?target,
?most_common_error,
nodes_count = self.inflight_requests.len(),
"PutQuery for MutableItem was rejected by most nodes with 3xx code."
);
return Err(most_common_error)?;
}
if self.is_done(socket) {
let target = self.target;
if self.stored_at == 0 {
let most_common_error = self.most_common_error();
debug!(
?target,
?most_common_error,
nodes_count = self.inflight_requests.len(),
"Put Query: failed"
);
return Err(most_common_error
.map(|(_, error)| error)
.unwrap_or(PutQueryError::Timeout.into()));
}
debug!(?target, stored_at = ?self.stored_at, "PutQuery Done successfully");
return Ok(true);
}
Ok(false)
}
fn is_done(&self, socket: &KrpcSocket) -> bool {
!self
.inflight_requests
.iter()
.any(|&tid| socket.inflight(tid))
}
fn majority_nodes_rejected_put_mutable(&self) -> Option<ConcurrencyError> {
let half = ((self.inflight_requests.len() / 2) + 1) as u8;
if matches!(self.request, PutRequestSpecific::PutMutable(_)) {
return self.most_common_error().and_then(|(count, error)| {
if count >= half {
if let PutError::Concurrency(err) = error {
Some(err)
} else {
None
}
} else {
None
}
});
};
None
}
fn most_common_error(&self) -> Option<(u8, PutError)> {
self.errors
.first()
.and_then(|(count, error)| match error.code {
301 => Some((*count, PutError::from(ConcurrencyError::CasFailed))),
302 => Some((*count, PutError::from(ConcurrencyError::NotMostRecent))),
_ => None,
})
}
}
#[derive(thiserror::Error, Debug, Clone)]
pub enum PutError {
#[error(transparent)]
Query(#[from] PutQueryError),
#[error(transparent)]
Concurrency(#[from] ConcurrencyError),
}
#[derive(thiserror::Error, Debug, Clone)]
pub enum PutQueryError {
#[error("Failed to find any nodes close to store value at")]
NoClosestNodes,
#[error("Query Error Response")]
ErrorResponse(ErrorSpecific),
#[error("PutQuery timed out with no responses neither success or errors")]
Timeout,
}
#[derive(thiserror::Error, Debug, Clone)]
pub enum ConcurrencyError {
#[error("Conflict risk, try reading most recent item before writing again.")]
ConflictRisk,
#[error("MutableItem::seq is not the most recent, try reading most recent item before writing again.")]
NotMostRecent,
#[error("CAS check failed, try reading most recent item before writing again.")]
CasFailed,
}
#[cfg(test)]
mod tests {
    use super::{ConcurrencyError, PutError, PutQuery};
    use crate::common::{ErrorSpecific, PutMutableRequestArguments, PutRequestSpecific};
    use crate::rpc::socket::KrpcSocket;
    use crate::{MutableItem, SigningKey};

    /// Fixed secret key, so the signed item (and therefore the query target) is
    /// the same on every run.
    const SECRET_KEY: [u8; 32] = [
        56, 171, 62, 85, 105, 58, 155, 209, 189, 8, 59, 109, 137, 84, 84, 201, 221, 115, 7, 228,
        127, 70, 4, 204, 182, 64, 77, 98, 92, 215, 27, 103,
    ];

    /// Error response with code 301, which `PutQuery` maps to `ConcurrencyError::CasFailed`.
    fn cas_failed() -> ErrorSpecific {
        ErrorSpecific {
            code: 301,
            description: "cas failed".to_string(),
        }
    }

    /// Three requests: one success and two CAS failures (a strict majority).
    /// The majority rejection must win over the completed success.
    #[test]
    fn mutable_majority_cas_failure_wins_over_completed_success() {
        let item = MutableItem::new(SigningKey::from_bytes(&SECRET_KEY), b"value", 1002, None);
        let target = *item.target();
        let request =
            PutRequestSpecific::PutMutable(PutMutableRequestArguments::from(item, Some(1000)));

        let mut query = PutQuery::new(target, request, None);
        // Pretend three requests were sent, without touching the network.
        query.inflight_requests = vec![1, 2, 3];
        query.success();
        for _ in 0..2 {
            query.error(cas_failed());
        }

        // A fresh socket has sent nothing, so presumably none of these tids are in flight on it.
        let socket = KrpcSocket::client().unwrap();
        let outcome = query.tick(&socket);
        assert!(
            matches!(
                outcome,
                Err(PutError::Concurrency(ConcurrencyError::CasFailed))
            ),
            "unexpected tick outcome: {:?}",
            outcome
        );
    }
}