mod durable_log;
pub mod grpc;
mod log_store;
use std::collections::BTreeMap;
use std::io::Cursor;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use openraft::error::{Infallible, InstallSnapshotError};
use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory};
use openraft::raft::{
AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
VoteRequest, VoteResponse,
};
use openraft::storage::{RaftStateMachine, Snapshot};
use openraft::{
BasicNode, Config, Entry, EntryPayload, LogId, RaftSnapshotBuilder, RaftTypeConfig,
SnapshotMeta, StorageError, StorageIOError, StoredMembership,
};
use quiver_core::WalOp;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
/// Identifier of a Raft node.
pub type NodeId = u64;
/// The `log_store::LogStore` specialised to this crate's [`TypeConfig`].
pub type LogStore = log_store::LogStore<TypeConfig>;
/// The openraft handle specialised to this crate's [`TypeConfig`].
pub type Raft = openraft::Raft<TypeConfig>;
/// openraft's `RaftError` keyed by [`NodeId`]; the extra error type `E` defaults to `Infallible`.
pub type RaftError<E = Infallible> = openraft::error::RaftError<NodeId, E>;
/// openraft's `RPCError` keyed by [`NodeId`] / `BasicNode`, wrapping a [`RaftError`] as the remote error.
pub type RpcError<E = Infallible> = openraft::error::RPCError<NodeId, BasicNode, RaftError<E>>;
// Declares `TypeConfig`: log entries carry a `WalOp` (`D`) and every applied
// entry is answered with a `RaftResponse` (`R`).
openraft::declare_raft_types!(
pub TypeConfig:
D = WalOp,
R = RaftResponse,
);
/// Response produced for each applied log entry; it is a unit struct and carries no data.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RaftResponse;
/// Sink that receives committed operations from the Raft state machine, and
/// that can optionally export / re-import its state for snapshots.
pub trait ApplyOp: Send + Sync + 'static {
/// Applies one committed `WalOp`; an `Err` is reported back to the caller of the state machine.
fn apply(&self, op: WalOp) -> impl std::future::Future<Output = std::io::Result<()>> + Send;
/// Serialises the applier's state into bytes.
///
/// The default implementation returns an empty byte vector.
fn snapshot(&self) -> impl std::future::Future<Output = std::io::Result<Vec<u8>>> + Send {
async { Ok(Vec::new()) }
}
/// Replaces the applier's state with the bytes previously produced by [`ApplyOp::snapshot`].
///
/// The default implementation ignores `data` and succeeds.
fn restore(
&self,
data: Vec<u8>,
) -> impl std::future::Future<Output = std::io::Result<()>> + Send {
// Explicitly discard the argument so the default body does not warn about it.
let _ = data;
async { Ok(()) }
}
}
/// Raft bookkeeping kept by the state machine and embedded in every snapshot payload.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct StateMachineData {
/// Id of the most recent log entry handed to the state machine, if any.
pub last_applied_log: Option<LogId<NodeId>>,
/// Most recent membership configuration seen in an applied membership entry.
pub last_membership: StoredMembership<NodeId, BasicNode>,
}
/// A snapshot held in memory: its openraft metadata plus the serialized payload bytes.
#[derive(Debug, Clone)]
pub struct StoredSnapshot {
/// openraft metadata describing the snapshot (last log id, membership, snapshot id).
pub meta: SnapshotMeta<NodeId, BasicNode>,
/// Serialized snapshot bytes, exactly as handed to / received from openraft.
pub data: Vec<u8>,
}
/// Serialized form of a snapshot: the Raft bookkeeping plus the applier's opaque state bytes.
#[derive(Serialize, Deserialize)]
struct SnapshotPayload {
/// Raft bookkeeping captured when the snapshot was built.
sm: StateMachineData,
/// Bytes returned by `ApplyOp::snapshot`, later passed to `ApplyOp::restore`.
engine: Vec<u8>,
}
/// Raft state machine that forwards committed operations to an [`ApplyOp`]
/// and keeps the latest snapshot in memory.
#[derive(Debug)]
pub struct StateMachineStore<A: ApplyOp> {
/// Receiver of committed operations and provider of snapshot bytes.
pub applier: A,
/// Last-applied log id and membership, guarded by an async lock.
state_machine: RwLock<StateMachineData>,
/// Counter used to make generated snapshot ids distinct.
snapshot_idx: AtomicU64,
/// Most recently built or installed snapshot, if any.
current_snapshot: RwLock<Option<StoredSnapshot>>,
}
impl<A: ApplyOp> StateMachineStore<A> {
pub fn new(applier: A) -> Self {
Self {
applier,
state_machine: RwLock::default(),
snapshot_idx: AtomicU64::new(0),
current_snapshot: RwLock::default(),
}
}
}
impl<A: ApplyOp> RaftSnapshotBuilder<TypeConfig> for Arc<StateMachineStore<A>> {
    /// Builds a snapshot of the Raft bookkeeping together with the applier's state.
    ///
    /// The result is stored as the current snapshot and also returned to the caller.
    ///
    /// # Errors
    ///
    /// Returns a read-state-machine `StorageError` when the applier cannot
    /// produce its snapshot bytes or when the payload cannot be serialized.
    async fn build_snapshot(&mut self) -> Result<Snapshot<TypeConfig>, StorageError<NodeId>> {
        // Keep the read guard alive while the applier exports its state. `apply`
        // holds the write lock for a whole batch, so this guarantees the engine
        // image corresponds exactly to the recorded `last_applied_log` instead of
        // possibly containing entries applied after the bookkeeping was copied.
        let (sm, engine) = {
            let guard = self.state_machine.read().await;
            let engine = self
                .applier
                .snapshot()
                .await
                .map_err(|e| StorageIOError::read_state_machine(&e))?;
            (guard.clone(), engine)
        };
        let last_applied_log = sm.last_applied_log;
        let last_membership = sm.last_membership.clone();

        let data = serde_json::to_vec(&SnapshotPayload { sm, engine })
            .map_err(|e| StorageIOError::read_state_machine(&e))?;

        // The counter only needs to hand out distinct values, so relaxed ordering suffices.
        let snapshot_idx = self.snapshot_idx.fetch_add(1, Ordering::Relaxed) + 1;
        let snapshot_id = match last_applied_log {
            Some(last) => format!("{}-{}-{}", last.leader_id, last.index, snapshot_idx),
            None => format!("--{snapshot_idx}"),
        };
        let meta = SnapshotMeta {
            last_log_id: last_applied_log,
            last_membership,
            snapshot_id,
        };

        // Retain a copy so `get_current_snapshot` can serve it later.
        *self.current_snapshot.write().await = Some(StoredSnapshot {
            meta: meta.clone(),
            data: data.clone(),
        });
        Ok(Snapshot {
            meta,
            snapshot: Box::new(Cursor::new(data)),
        })
    }
}
impl<A: ApplyOp> RaftStateMachine<TypeConfig> for Arc<StateMachineStore<A>> {
    type SnapshotBuilder = Self;

    /// Returns the last applied log id and the last stored membership.
    async fn applied_state(
        &mut self,
    ) -> Result<(Option<LogId<NodeId>>, StoredMembership<NodeId, BasicNode>), StorageError<NodeId>>
    {
        let sm = self.state_machine.read().await;
        Ok((sm.last_applied_log, sm.last_membership.clone()))
    }

    /// Applies `entries` in order and returns one [`RaftResponse`] per entry.
    ///
    /// Normal entries are forwarded to the applier; membership entries update
    /// the stored membership; blank entries only advance the applied position.
    ///
    /// # Errors
    ///
    /// Returns an apply `StorageError` for the first entry the applier rejects.
    /// `last_applied_log` is advanced only after an entry succeeded, so it
    /// never points at an entry the applier did not accept.
    async fn apply<I>(&mut self, entries: I) -> Result<Vec<RaftResponse>, StorageError<NodeId>>
    where
        I: IntoIterator<Item = Entry<TypeConfig>> + Send,
    {
        let mut responses = Vec::new();
        // The write lock is held for the whole batch.
        let mut sm = self.state_machine.write().await;
        for entry in entries {
            let log_id = entry.log_id;
            // The entry is owned, so payloads are moved out rather than cloned.
            match entry.payload {
                EntryPayload::Blank => {}
                EntryPayload::Normal(op) => {
                    self.applier
                        .apply(op)
                        .await
                        .map_err(|e| StorageIOError::apply(log_id, &e))?;
                }
                EntryPayload::Membership(mem) => {
                    sm.last_membership = StoredMembership::new(Some(log_id), mem);
                }
            }
            sm.last_applied_log = Some(log_id);
            responses.push(RaftResponse);
        }
        Ok(responses)
    }

    /// Returns a handle to this same store, which doubles as the snapshot builder.
    async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
        self.clone()
    }

    /// Provides an empty in-memory buffer into which an incoming snapshot is written.
    async fn begin_receiving_snapshot(
        &mut self,
    ) -> Result<Box<<TypeConfig as RaftTypeConfig>::SnapshotData>, StorageError<NodeId>> {
        Ok(Box::new(Cursor::new(Vec::new())))
    }

    /// Installs a received snapshot: restores the applier, then replaces the
    /// Raft bookkeeping and the stored current snapshot.
    ///
    /// # Errors
    ///
    /// Returns a read-snapshot error when the payload cannot be decoded and a
    /// write-snapshot error when the applier fails to restore.
    async fn install_snapshot(
        &mut self,
        meta: &SnapshotMeta<NodeId, BasicNode>,
        snapshot: Box<<TypeConfig as RaftTypeConfig>::SnapshotData>,
    ) -> Result<(), StorageError<NodeId>> {
        let data = snapshot.into_inner();
        let payload: SnapshotPayload = serde_json::from_slice(&data)
            .map_err(|e| StorageIOError::read_snapshot(Some(meta.signature()), &e))?;
        self.applier
            .restore(payload.engine)
            .await
            .map_err(|e| StorageIOError::write_snapshot(Some(meta.signature()), &e))?;
        *self.state_machine.write().await = payload.sm;
        *self.current_snapshot.write().await = Some(StoredSnapshot {
            meta: meta.clone(),
            data,
        });
        Ok(())
    }

    /// Returns a copy of the most recently built or installed snapshot, if any.
    async fn get_current_snapshot(
        &mut self,
    ) -> Result<Option<Snapshot<TypeConfig>>, StorageError<NodeId>> {
        Ok(self
            .current_snapshot
            .read()
            .await
            .as_ref()
            .map(|s| Snapshot {
                meta: s.meta.clone(),
                snapshot: Box::new(Cursor::new(s.data.clone())),
            }))
    }
}
/// Network factory for a single-member group; it is never expected to create a client.
#[derive(Debug, Clone, Default)]
pub struct NoNetwork;
impl RaftNetworkFactory<TypeConfig> for NoNetwork {
type Network = NoConnection;
/// Panics if called: a single-member group has no peers to connect to.
async fn new_client(&mut self, _target: NodeId, _node: &BasicNode) -> Self::Network {
unreachable!("single-member raft (4a) has no peers; real RPC arrives in 4b")
}
}
/// Placeholder connection type for [`NoNetwork`]; every RPC method panics if invoked.
#[derive(Debug, Clone)]
pub struct NoConnection;
impl RaftNetwork<TypeConfig> for NoConnection {
/// Panics if called: a single-member group sends no append-entries RPCs.
async fn append_entries(
&mut self,
_req: AppendEntriesRequest<TypeConfig>,
_option: RPCOption,
) -> Result<AppendEntriesResponse<NodeId>, RpcError> {
unreachable!("single-member raft (4a) sends no RPCs; real network arrives in 4b")
}
/// Panics if called: a single-member group sends no install-snapshot RPCs.
async fn install_snapshot(
&mut self,
_req: InstallSnapshotRequest<TypeConfig>,
_option: RPCOption,
) -> Result<InstallSnapshotResponse<NodeId>, RpcError<InstallSnapshotError>> {
unreachable!("single-member raft (4a) sends no RPCs; real network arrives in 4b")
}
/// Panics if called: a single-member group sends no vote RPCs.
async fn vote(
&mut self,
_req: VoteRequest<NodeId>,
_option: RPCOption,
) -> Result<VoteResponse<NodeId>, RpcError> {
unreachable!("single-member raft (4a) sends no RPCs; real network arrives in 4b")
}
}
/// Starts a Raft group whose only member is `node_id`, applying committed
/// operations through `applier`, and initializes it with that single member.
///
/// # Errors
///
/// Fails when the configuration does not validate, when the Raft instance
/// cannot be created, or when initialization is rejected.
pub async fn start_single_member<A: ApplyOp>(
    node_id: NodeId,
    applier: A,
) -> Result<Raft, Box<dyn std::error::Error + Send + Sync>> {
    let raw_config = Config {
        heartbeat_interval: 250,
        election_timeout_min: 500,
        election_timeout_max: 1000,
        ..Default::default()
    };
    let config = Arc::new(raw_config.validate()?);

    let raft = openraft::Raft::new(
        node_id,
        config,
        NoNetwork,
        LogStore::default(),
        Arc::new(StateMachineStore::new(applier)),
    )
    .await?;

    // The membership consists of this node alone.
    let sole_member = BTreeMap::from([(node_id, BasicNode::default())]);
    raft.initialize(sole_member).await?;
    Ok(raft)
}
/// [`ApplyOp`] implementation backed by a shared `quiver_embed::Database`.
pub struct EngineApplier {
// Shared database handle; guarded by a std (blocking) read-write lock.
db: Arc<std::sync::RwLock<quiver_embed::Database>>,
}
impl EngineApplier {
/// Creates an applier that operates on the given shared database.
pub fn new(db: Arc<std::sync::RwLock<quiver_embed::Database>>) -> Self {
Self { db }
}
}
impl ApplyOp for EngineApplier {
    /// Applies `op` to the database on a blocking thread, under the write lock.
    async fn apply(&self, op: WalOp) -> std::io::Result<()> {
        let db = Arc::clone(&self.db);
        let task = tokio::task::spawn_blocking(move || -> std::io::Result<()> {
            let Ok(mut guard) = db.write() else {
                return Err(std::io::Error::other("database lock poisoned"));
            };
            guard
                .apply_replicated(op)
                .map_err(|e| std::io::Error::other(e.to_string()))
        });
        match task.await {
            Ok(outcome) => outcome,
            Err(e) => Err(std::io::Error::other(format!(
                "blocking apply task failed: {e}"
            ))),
        }
    }

    /// Exports the database's replication snapshot, postcard-encoded, on a
    /// blocking thread under the read lock.
    async fn snapshot(&self) -> std::io::Result<Vec<u8>> {
        let db = Arc::clone(&self.db);
        let task = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
            let Ok(guard) = db.read() else {
                return Err(std::io::Error::other("database lock poisoned"));
            };
            let ops = guard
                .replication_snapshot()
                .map_err(|e| std::io::Error::other(e.to_string()))?;
            postcard::to_allocvec(&ops).map_err(|e| std::io::Error::other(e.to_string()))
        });
        match task.await {
            Ok(outcome) => outcome,
            Err(e) => Err(std::io::Error::other(format!(
                "blocking snapshot task failed: {e}"
            ))),
        }
    }

    /// Decodes `data` into operations, drops every existing collection and
    /// replays the operations, all on a blocking thread under the write lock.
    async fn restore(&self, data: Vec<u8>) -> std::io::Result<()> {
        let db = Arc::clone(&self.db);
        let task = tokio::task::spawn_blocking(move || -> std::io::Result<()> {
            // Decode before taking the lock so a corrupt payload leaves the database untouched.
            let ops: Vec<WalOp> =
                postcard::from_bytes(&data).map_err(|e| std::io::Error::other(e.to_string()))?;
            let Ok(mut guard) = db.write() else {
                return Err(std::io::Error::other("database lock poisoned"));
            };
            for name in guard.collection_names() {
                guard
                    .drop_collection(&name)
                    .map_err(|e| std::io::Error::other(e.to_string()))?;
            }
            for op in ops {
                guard
                    .apply_replicated(op)
                    .map_err(|e| std::io::Error::other(e.to_string()))?;
            }
            Ok(())
        });
        match task.await {
            Ok(outcome) => outcome,
            Err(e) => Err(std::io::Error::other(format!(
                "blocking restore task failed: {e}"
            ))),
        }
    }
}
/// A running Raft member together with the locally tracked member address map.
pub struct RaftShard {
/// Handle to the Raft instance.
pub raft: Raft,
/// Id of this node.
pub node_id: NodeId,
/// Locally tracked map from node id to that node's URL.
pub members: std::sync::RwLock<BTreeMap<NodeId, String>>,
/// Async mutex exposed for callers to serialise work; it is not used within this block.
// NOTE(review): presumably guards collection creation, judging by the name — confirm against callers.
pub create_lock: tokio::sync::Mutex<()>,
}
impl RaftShard {
    /// Returns the locally recorded URL of member `id`, if one is known.
    ///
    /// A poisoned lock is recovered rather than treated as "unknown member",
    /// so lookups stay consistent with the updates made by
    /// [`RaftShard::add_voter`] and [`RaftShard::remove_voter`].
    pub fn member_url(&self, id: NodeId) -> Option<String> {
        self.members
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(&id)
            .cloned()
    }

    /// Adds node `id` at `url` as a learner, promotes it to voter, and records
    /// its URL in the local member map.
    ///
    /// # Errors
    ///
    /// Returns the openraft error if adding the learner or changing the
    /// membership fails; the local map is left untouched in that case.
    pub async fn add_voter(
        &self,
        id: NodeId,
        url: String,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        self.raft
            .add_learner(id, BasicNode::new(url.clone()), true)
            .await?;
        self.raft
            .change_membership(
                openraft::ChangeMembers::AddVoterIds([id].into_iter().collect()),
                true,
            )
            .await?;
        // The membership change already succeeded, so the local map must be
        // updated even if an earlier panic poisoned the lock; otherwise it
        // would silently drift from the Raft membership.
        self.members
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(id, url);
        Ok(())
    }

    /// Removes voter `id` from the Raft membership and from the local member map.
    ///
    /// # Errors
    ///
    /// Returns the openraft error if the membership change fails; the local
    /// map is left untouched in that case.
    pub async fn remove_voter(
        &self,
        id: NodeId,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        self.raft
            .change_membership(
                openraft::ChangeMembers::RemoveVoters([id].into_iter().collect()),
                false,
            )
            .await?;
        // Same reasoning as in `add_voter`: never skip the bookkeeping because of poison.
        self.members
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(&id);
        Ok(())
    }
}
/// Starts a Raft member backed by a durable log in `log_dir` and the gRPC network.
///
/// The node whose id is the smallest key in `members` initializes the group
/// when the Raft instance reports it is not yet initialized; a failure of that
/// initialization is only logged at debug level.
///
/// # Errors
///
/// Fails when the configuration does not validate, the durable log cannot be
/// opened, the Raft instance cannot be created, or the initialization check fails.
pub async fn start_member(
    node_id: NodeId,
    members: BTreeMap<NodeId, String>,
    applier: EngineApplier,
    log_dir: &std::path::Path,
) -> Result<RaftShard, Box<dyn std::error::Error + Send + Sync>> {
    let raw_config = Config {
        heartbeat_interval: 250,
        election_timeout_min: 500,
        election_timeout_max: 1000,
        ..Default::default()
    };
    let config = Arc::new(raw_config.validate()?);
    let log_store = durable_log::DurableLogStore::open(log_dir)?;
    let state_machine = Arc::new(StateMachineStore::new(applier));
    let raft = openraft::Raft::new(
        node_id,
        config,
        grpc::GrpcRaftNetwork,
        log_store,
        state_machine,
    )
    .await?;

    // Only the lowest-numbered member bootstraps the group.
    let is_bootstrap_node = members.keys().next() == Some(&node_id);
    if is_bootstrap_node && !raft.is_initialized().await? {
        let mut nodes: BTreeMap<NodeId, BasicNode> = BTreeMap::new();
        for (id, url) in &members {
            nodes.insert(*id, BasicNode::new(url.clone()));
        }
        if let Err(e) = raft.initialize(nodes).await {
            tracing::debug!(error = %e, "raft initialize (already bootstrapped?)");
        }
    }

    Ok(RaftShard {
        raft,
        node_id,
        members: std::sync::RwLock::new(members),
        create_lock: tokio::sync::Mutex::new(()),
    })
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use openraft::ServerState;
use quiver_embed::{Database, Descriptor, DistanceMetric, Dtype, SearchParams};
use tokio::sync::Mutex;
use super::*;
/// Applier that accepts every operation and does nothing; it relies on the
/// trait's default `snapshot` / `restore`.
struct NoopApplier;
impl ApplyOp for NoopApplier {
async fn apply(&self, _op: WalOp) -> std::io::Result<()> {
Ok(())
}
}
struct EngineApplier(Arc<Mutex<Database>>);
impl ApplyOp for EngineApplier {
async fn apply(&self, op: WalOp) -> std::io::Result<()> {
self.0
.lock()
.await
.apply_replicated(op)
.map_err(|e| std::io::Error::other(e.to_string()))
}
async fn snapshot(&self) -> std::io::Result<Vec<u8>> {
let ops = self
.0
.lock()
.await
.replication_snapshot()
.map_err(|e| std::io::Error::other(e.to_string()))?;
postcard::to_allocvec(&ops).map_err(|e| std::io::Error::other(e.to_string()))
}
async fn restore(&self, data: Vec<u8>) -> std::io::Result<()> {
let ops: Vec<WalOp> =
postcard::from_bytes(&data).map_err(|e| std::io::Error::other(e.to_string()))?;
let mut db = self.0.lock().await;
for name in db.collection_names() {
db.drop_collection(&name)
.map_err(|e| std::io::Error::other(e.to_string()))?;
}
for op in ops {
db.apply_replicated(op)
.map_err(|e| std::io::Error::other(e.to_string()))?;
}
Ok(())
}
}
/// A single-member group becomes leader and applies every committed op to the target engine.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn single_member_group_applies_committed_ops_to_engine() {
    // Produce the operation stream from a scratch source database.
    let src_dir = tempfile::tempdir().unwrap();
    let mut src = Database::open(src_dir.path()).unwrap();
    src.create_collection("docs", Descriptor::new(4, Dtype::F32, DistanceMetric::L2))
        .unwrap();
    src.upsert(
        "docs",
        "a",
        &[1.0, 0.0, 0.0, 0.0],
        &serde_json::json!({"t": "a"}),
    )
    .unwrap();
    src.upsert(
        "docs",
        "b",
        &[0.0, 1.0, 0.0, 0.0],
        &serde_json::json!({"t": "b"}),
    )
    .unwrap();
    let ops = src.replication_snapshot().unwrap();
    assert!(ops.len() >= 3, "create-collection + two upserts");

    // Replicate the operations through a single-member group into a fresh database.
    let tgt_dir = tempfile::tempdir().unwrap();
    let target = Arc::new(Mutex::new(Database::open(tgt_dir.path()).unwrap()));
    let raft = start_single_member(1, EngineApplier(target.clone()))
        .await
        .unwrap();
    raft.wait(Some(Duration::from_secs(10)))
        .state(ServerState::Leader, "single member becomes leader")
        .await
        .unwrap();
    for op in ops {
        raft.client_write(op).await.unwrap();
    }

    let params = SearchParams {
        k: 2,
        ef_search: 16,
        with_payload: false,
        with_vector: false,
        filter: None,
    };
    // The parameters are passed by reference (`&params`).
    let hits = target
        .lock()
        .await
        .search("docs", &[1.0, 0.0, 0.0, 0.0], &params)
        .unwrap();
    assert_eq!(hits.first().map(|m| m.id.as_str()), Some("a"));
    let ids: HashSet<_> = hits.iter().map(|m| m.id.clone()).collect();
    assert!(ids.contains("a") && ids.contains("b"), "both points served");
    raft.shutdown().await.unwrap();
}
/// Applier whose `apply` always fails, used to check that apply errors propagate.
struct FailingApplier;
impl ApplyOp for FailingApplier {
async fn apply(&self, _op: WalOp) -> std::io::Result<()> {
Err(std::io::Error::other("simulated engine apply fault"))
}
}
/// A write whose apply step fails must be reported to the client as an error.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn apply_failure_surfaces_not_swallowed() {
    let raft = start_single_member(1, FailingApplier).await.unwrap();
    let leadership = raft
        .wait(Some(Duration::from_secs(10)))
        .state(ServerState::Leader, "leader")
        .await;
    leadership.unwrap();

    let doomed = WalOp::Delete {
        collection_id: quiver_core::CollectionId(1),
        external_id: "x".to_owned(),
    };
    let outcome = raft.client_write(doomed).await;
    assert!(
        outcome.is_err(),
        "an engine apply fault must surface, not be swallowed"
    );
}
/// Exercises the state machine's snapshot API: build from empty, build after
/// an applied position exists, then install a foreign snapshot.
#[tokio::test]
async fn state_machine_snapshot_roundtrip() {
    use openraft::SnapshotMeta;
    use std::io::Cursor;

    let store = Arc::new(StateMachineStore::new(NoopApplier));

    // A fresh store has applied nothing and holds no snapshot.
    let (applied, _membership) = store.clone().applied_state().await.unwrap();
    assert!(applied.is_none());
    let before_build = store.clone().get_current_snapshot().await.unwrap();
    assert!(before_build.is_none());

    // Building from the empty state yields an id but no last log id.
    let mut builder = store.clone().get_snapshot_builder().await;
    let first = builder.build_snapshot().await.unwrap();
    assert!(!first.meta.snapshot_id.is_empty());
    assert!(first.meta.last_log_id.is_none());
    let after_build = store.clone().get_current_snapshot().await.unwrap();
    assert!(after_build.is_some());

    // Pretend an entry was applied, then build again.
    {
        let mut bookkeeping = store.state_machine.write().await;
        bookkeeping.last_applied_log = Some(openraft::LogId::new(
            openraft::CommittedLeaderId::new(1, 1),
            7,
        ));
    }
    let mut second_builder = store.clone().get_snapshot_builder().await;
    let second = second_builder.build_snapshot().await.unwrap();
    assert!(second.meta.last_log_id.is_some());

    // Install a snapshot that was produced elsewhere.
    let payload = SnapshotPayload {
        sm: StateMachineData::default(),
        engine: Vec::new(),
    };
    let bytes = serde_json::to_vec(&payload).unwrap();
    let mut receiver = store.clone();
    let mut incoming = receiver.begin_receiving_snapshot().await.unwrap();
    *incoming = Cursor::new(bytes);
    let meta = SnapshotMeta {
        last_log_id: None,
        last_membership: StoredMembership::default(),
        snapshot_id: "installed".to_owned(),
    };
    receiver.install_snapshot(&meta, incoming).await.unwrap();

    let current = store.clone().get_current_snapshot().await.unwrap().unwrap();
    assert_eq!(current.meta.snapshot_id, "installed", "install replaced it");
}
/// A snapshot built on one state machine and installed on another carries the
/// engine contents across.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn snapshot_transfers_engine_state_to_a_fresh_voter() {
    // Populate the source engine directly.
    let src_dir = tempfile::tempdir().unwrap();
    let src = Arc::new(Mutex::new(Database::open(src_dir.path()).unwrap()));
    {
        let mut db = src.lock().await;
        db.create_collection("docs", Descriptor::new(4, Dtype::F32, DistanceMetric::L2))
            .unwrap();
        db.upsert("docs", "a", &[1.0, 0.0, 0.0, 0.0], &serde_json::json!({}))
            .unwrap();
        db.upsert("docs", "b", &[0.0, 1.0, 0.0, 0.0], &serde_json::json!({}))
            .unwrap();
    }

    // Build a snapshot from the source state machine.
    let sm_src = Arc::new(StateMachineStore::new(EngineApplier(src.clone())));
    let snap = sm_src
        .clone()
        .get_snapshot_builder()
        .await
        .build_snapshot()
        .await
        .unwrap();
    let bytes = snap.snapshot.into_inner();

    // Install it on a state machine backed by an empty engine.
    let tgt_dir = tempfile::tempdir().unwrap();
    let tgt = Arc::new(Mutex::new(Database::open(tgt_dir.path()).unwrap()));
    let mut receiver = Arc::new(StateMachineStore::new(EngineApplier(tgt.clone())));
    receiver
        .install_snapshot(&snap.meta, Box::new(std::io::Cursor::new(bytes)))
        .await
        .unwrap();

    let params = SearchParams {
        k: 2,
        ef_search: 16,
        with_payload: false,
        with_vector: false,
        filter: None,
    };
    // The parameters are passed by reference (`&params`).
    let hits = tgt
        .lock()
        .await
        .search("docs", &[1.0, 0.0, 0.0, 0.0], &params)
        .unwrap();
    let ids: std::collections::HashSet<_> = hits.iter().map(|m| m.id.clone()).collect();
    assert!(
        ids.contains("a") && ids.contains("b"),
        "the snapshot transferred the engine state to a fresh voter"
    );
}
use std::collections::BTreeMap;
use std::sync::Mutex as StdMutex;
use openraft::BasicNode;
use openraft::error::{InstallSnapshotError, RPCError, RemoteError, Unreachable};
use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory};
use openraft::raft::{
AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest,
InstallSnapshotResponse, VoteRequest, VoteResponse,
};
/// In-process registry mapping node ids to Raft handles; a node absent from
/// the registry is treated by [`Link`] as unreachable.
#[derive(Clone, Default)]
struct Switchboard {
    nodes: Arc<StdMutex<BTreeMap<NodeId, Raft>>>,
}
impl Switchboard {
    /// Makes `raft` reachable under `id`, replacing any previous entry.
    fn register(&self, id: NodeId, raft: Raft) {
        let mut registry = self.nodes.lock().unwrap();
        registry.insert(id, raft);
    }

    /// Removes `id` from the registry so that it can no longer be reached.
    fn kill(&self, id: NodeId) {
        let mut registry = self.nodes.lock().unwrap();
        registry.remove(&id);
    }

    /// Returns a clone of the handle registered under `id`, if any.
    fn handle(&self, id: NodeId) -> Option<Raft> {
        let registry = self.nodes.lock().unwrap();
        registry.get(&id).cloned()
    }
}
impl RaftNetworkFactory<TypeConfig> for Switchboard {
    type Network = Link;

    /// Creates a link to `target` that resolves the peer through this switchboard.
    async fn new_client(&mut self, target: NodeId, _node: &BasicNode) -> Link {
        let board = self.clone();
        Link { target, board }
    }
}
/// Connection to one peer, resolved through the [`Switchboard`] on each call.
struct Link {
    target: NodeId,
    board: Switchboard,
}
impl Link {
    /// Looks up the peer's Raft handle, reporting `Unreachable` when the peer
    /// is not registered.
    #[allow(clippy::result_large_err)]
    fn target(&self) -> Result<Raft, RPCError<NodeId, BasicNode, RaftError>> {
        match self.board.handle(self.target) {
            Some(peer) => Ok(peer),
            None => Err(RPCError::Unreachable(Unreachable::new(
                &std::io::Error::other("node down"),
            ))),
        }
    }
}
impl RaftNetwork<TypeConfig> for Link {
    /// Forwards an append-entries RPC to the peer's Raft handle.
    async fn append_entries(
        &mut self,
        rpc: AppendEntriesRequest<TypeConfig>,
        _option: RPCOption,
    ) -> Result<AppendEntriesResponse<NodeId>, RpcError> {
        let peer = self.target()?;
        match peer.append_entries(rpc).await {
            Ok(reply) => Ok(reply),
            Err(e) => Err(RPCError::RemoteError(RemoteError::new(self.target, e))),
        }
    }

    /// Forwards a vote RPC to the peer's Raft handle.
    async fn vote(
        &mut self,
        rpc: VoteRequest<NodeId>,
        _option: RPCOption,
    ) -> Result<VoteResponse<NodeId>, RpcError> {
        let peer = self.target()?;
        match peer.vote(rpc).await {
            Ok(reply) => Ok(reply),
            Err(e) => Err(RPCError::RemoteError(RemoteError::new(self.target, e))),
        }
    }

    /// Forwards an install-snapshot RPC to the peer's Raft handle.
    ///
    /// The lookup is spelled out here because this method's error type differs
    /// from the one `Link::target` returns.
    async fn install_snapshot(
        &mut self,
        rpc: InstallSnapshotRequest<TypeConfig>,
        _option: RPCOption,
    ) -> Result<InstallSnapshotResponse<NodeId>, RpcError<InstallSnapshotError>> {
        let Some(peer) = self.board.handle(self.target) else {
            return Err(RPCError::Unreachable(Unreachable::new(
                &std::io::Error::other("node down"),
            )));
        };
        match peer.install_snapshot(rpc).await {
            Ok(reply) => Ok(reply),
            Err(e) => Err(RPCError::RemoteError(RemoteError::new(self.target, e))),
        }
    }
}
/// One cluster member in the tests: its id, Raft handle and backing engine.
struct Voter {
id: NodeId,
raft: Raft,
engine: Arc<Mutex<Database>>,
// Held only so the temporary directory lives as long as the voter.
_dir: tempfile::TempDir,
}
/// Boots one Raft node per id, wires them through a shared [`Switchboard`],
/// and initializes the group from the first node with all ids as members.
async fn boot_cluster(ids: &[NodeId]) -> (Switchboard, Vec<Voter>) {
    let board = Switchboard::default();
    let mut voters = Vec::new();
    for &id in ids {
        let dir = tempfile::tempdir().unwrap();
        let engine = Arc::new(Mutex::new(Database::open(dir.path()).unwrap()));
        let raw_config = Config {
            heartbeat_interval: 100,
            election_timeout_min: 300,
            election_timeout_max: 600,
            ..Default::default()
        };
        let config = Arc::new(raw_config.validate().unwrap());
        let sm = Arc::new(StateMachineStore::new(EngineApplier(engine.clone())));
        let raft = openraft::Raft::new(id, config, board.clone(), LogStore::default(), sm)
            .await
            .unwrap();
        board.register(id, raft.clone());
        voters.push(Voter {
            id,
            raft,
            engine,
            _dir: dir,
        });
    }

    let mut members: BTreeMap<NodeId, BasicNode> = BTreeMap::new();
    for &id in ids {
        members.insert(id, BasicNode::default());
    }
    voters[0].raft.initialize(members).await.unwrap();
    (board, voters)
}
fn current_leader(board: &Switchboard, voters: &[Voter]) -> Option<NodeId> {
for v in voters {
if board.handle(v.id).is_none() {
continue; }
let leader = v.raft.metrics().borrow().current_leader;
if let Some(leader) = leader
&& board.handle(leader).is_some()
{
return Some(leader);
}
}
None
}
/// Writes `op` through whichever voter is currently the reachable leader,
/// retrying up to 100 times with a 40 ms pause; panics when no attempt succeeds.
async fn commit(board: &Switchboard, voters: &[Voter], op: &WalOp) {
    for _attempt in 0..100 {
        let leader = current_leader(board, voters)
            .and_then(|leader_id| voters.iter().find(|v| v.id == leader_id));
        let accepted = match leader {
            Some(leader) => leader.raft.client_write(op.clone()).await.is_ok(),
            None => false,
        };
        if accepted {
            return;
        }
        tokio::time::sleep(Duration::from_millis(40)).await;
    }
    panic!("no leader committed the op within the budget");
}
/// Polls `engine` until a search of the "docs" collection for `query` returns
/// `want_id`, checking up to 200 times with a 40 ms pause; panics otherwise.
async fn await_serves(engine: &Arc<Mutex<Database>>, query: &[f32], want_id: &str) {
    let params = SearchParams {
        k: 5,
        filter: None,
        ef_search: 32,
        with_payload: false,
        with_vector: false,
    };
    for _ in 0..200 {
        // Search errors are treated like "not there yet" and simply retried.
        // The parameters are passed by reference (`&params`).
        if let Ok(hits) = engine.lock().await.search("docs", query, &params)
            && hits.iter().any(|m| m.id == want_id)
        {
            return;
        }
        tokio::time::sleep(Duration::from_millis(40)).await;
    }
    panic!("engine never served {want_id}");
}
/// Builds the replication op stream for a "docs" collection holding `points`,
/// by populating a scratch database and exporting its replication snapshot.
fn collection_ops(points: &[(&str, [f32; 4])]) -> Vec<WalOp> {
    let scratch_dir = tempfile::tempdir().unwrap();
    let mut scratch = Database::open(scratch_dir.path()).unwrap();
    let descriptor = Descriptor::new(4, Dtype::F32, DistanceMetric::L2);
    scratch.create_collection("docs", descriptor).unwrap();
    let empty_payload = serde_json::json!({});
    for (id, vector) in points {
        scratch.upsert("docs", id, vector, &empty_payload).unwrap();
    }
    scratch.replication_snapshot().unwrap()
}
/// Ops committed in a three-member group become searchable on every voter.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn three_member_group_applies_on_every_voter() {
    let (board, voters) = boot_cluster(&[1, 2, 3]).await;
    let bootstrap = &voters[0].raft;
    bootstrap
        .wait(Some(Duration::from_secs(10)))
        .state(ServerState::Leader, "bootstrap leader")
        .await
        .unwrap();

    let a = [1.0, 0.0, 0.0, 0.0];
    let b = [0.0, 1.0, 0.0, 0.0];
    let ops = collection_ops(&[("a", a), ("b", b)]);
    for op in ops {
        commit(&board, &voters, &op).await;
    }

    // Every voter must serve both points before anything is shut down.
    for voter in voters.iter() {
        await_serves(&voter.engine, &a, "a").await;
        await_serves(&voter.engine, &b, "b").await;
    }
    for voter in voters.iter() {
        voter.raft.shutdown().await.unwrap();
    }
}
/// After the leader is killed, the survivors still serve every acknowledged
/// write, accept a new one, and agree with a single-node reference database.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn leader_failure_preserves_acknowledged_writes() {
    let (board, voters) = boot_cluster(&[1, 2, 3]).await;
    voters[0]
        .raft
        .wait(Some(Duration::from_secs(10)))
        .state(ServerState::Leader, "bootstrap leader")
        .await
        .unwrap();

    let a = [1.0, 0.0, 0.0, 0.0];
    let b = [0.0, 1.0, 0.0, 0.0];
    let ops = collection_ops(&[("a", a), ("b", b)]);
    let coll_id = ops
        .iter()
        .find_map(|op| match op {
            WalOp::CreateCollection { collection_id, .. } => Some(*collection_id),
            _ => None,
        })
        .expect("create-collection op");
    for op in &ops {
        commit(&board, &voters, op).await;
    }

    // Take the current leader off the network and stop it.
    let dead = current_leader(&board, &voters).expect("a leader exists");
    board.kill(dead);
    if let Some(v) = voters.iter().find(|v| v.id == dead) {
        v.raft.shutdown().await.unwrap();
    }
    let survivors: Vec<&Voter> = voters.iter().filter(|v| v.id != dead).collect();

    // A further write must still commit through the remaining majority.
    let c = [0.0f32, 0.0, 1.0, 0.0];
    let c_op = WalOp::Upsert {
        collection_id: coll_id,
        external_id: "c".to_owned(),
        vector: c.iter().flat_map(|f| f.to_le_bytes()).collect(),
        payload: b"{}".to_vec(),
    };
    commit(&board, &voters, &c_op).await;
    for v in &survivors {
        await_serves(&v.engine, &a, "a").await;
        await_serves(&v.engine, &b, "b").await;
        await_serves(&v.engine, &c, "c").await;
    }

    // Reference database built without Raft from the same three points.
    let truth_dir = tempfile::tempdir().unwrap();
    let mut truth = Database::open(truth_dir.path()).unwrap();
    for op in collection_ops(&[("a", a), ("b", b), ("c", c)]) {
        truth.apply_replicated(op).unwrap();
    }
    let params = SearchParams {
        k: 1,
        filter: None,
        ef_search: 16,
        with_payload: false,
        with_vector: false,
    };
    // The parameters are passed by reference (`&params`) in both searches.
    for (q, want) in [(a, "a"), (b, "b"), (c, "c")] {
        let truth_top = truth.search("docs", &q, &params).unwrap()[0].id.clone();
        assert_eq!(truth_top, want);
        let survivor_top = survivors[0]
            .engine
            .lock()
            .await
            .search("docs", &q, &params)
            .unwrap()[0]
            .id
            .clone();
        assert_eq!(
            survivor_top, truth_top,
            "survivor matches single-node truth"
        );
    }
    for v in survivors {
        v.raft.shutdown().await.unwrap();
    }
}
/// With two of three voters gone, a write on the remaining node must not
/// succeed within the timeout, while previously committed data stays served.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn a_minority_cannot_commit_a_write() {
    let (board, voters) = boot_cluster(&[1, 2, 3]).await;
    voters[0]
        .raft
        .wait(Some(Duration::from_secs(10)))
        .state(ServerState::Leader, "bootstrap leader")
        .await
        .unwrap();

    let a = [1.0f32, 0.0, 0.0, 0.0];
    let ops = collection_ops(&[("a", a)]);
    let coll_id = ops
        .iter()
        .find_map(|op| {
            if let WalOp::CreateCollection { collection_id, .. } = op {
                Some(*collection_id)
            } else {
                None
            }
        })
        .expect("create-collection op");
    for op in ops.iter() {
        commit(&board, &voters, op).await;
    }
    for voter in voters.iter() {
        await_serves(&voter.engine, &a, "a").await;
    }

    // Remove every voter except the first, leaving it without a quorum.
    for gone in voters.iter().skip(1) {
        board.kill(gone.id);
        gone.raft.shutdown().await.unwrap();
    }
    let survivor = &voters[0];

    let b_bytes: Vec<u8> = [0.0f32, 1.0, 0.0, 0.0]
        .iter()
        .flat_map(|f| f.to_le_bytes())
        .collect();
    let op = WalOp::Upsert {
        collection_id: coll_id,
        external_id: "b".to_owned(),
        vector: b_bytes,
        payload: b"{}".to_vec(),
    };
    let attempt = survivor.raft.client_write(op);
    let committed = tokio::time::timeout(Duration::from_secs(3), attempt).await;
    // Either the timeout elapsed or the write itself failed; success is the only bad outcome.
    assert!(
        !matches!(committed, Ok(Ok(_))),
        "a minority of one committed a write — split-brain"
    );

    await_serves(&survivor.engine, &a, "a").await;
    survivor.raft.shutdown().await.unwrap();
}
/// A node added after the leader has snapshotted and purged its log still
/// ends up serving the data written before it joined.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn a_new_voter_catches_up_via_snapshot_after_compaction() {
    let board = Switchboard::default();
    let cfg = Arc::new(
        Config {
            heartbeat_interval: 100,
            election_timeout_min: 300,
            election_timeout_max: 600,
            max_in_snapshot_log_to_keep: 0,
            purge_batch_size: 1,
            ..Default::default()
        }
        .validate()
        .unwrap(),
    );

    // Node 1 starts alone and becomes leader.
    let dir1 = tempfile::tempdir().unwrap();
    let e1 = Arc::new(Mutex::new(Database::open(dir1.path()).unwrap()));
    let r1 = openraft::Raft::new(
        1,
        cfg.clone(),
        board.clone(),
        LogStore::default(),
        Arc::new(StateMachineStore::new(EngineApplier(e1.clone()))),
    )
    .await
    .unwrap();
    board.register(1, r1.clone());
    let mut members = BTreeMap::new();
    members.insert(1, BasicNode::default());
    r1.initialize(members).await.unwrap();
    r1.wait(Some(Duration::from_secs(10)))
        .state(ServerState::Leader, "leader")
        .await
        .unwrap();

    let a = [1.0f32, 0.0, 0.0, 0.0];
    let b = [0.0f32, 1.0, 0.0, 0.0];
    for op in collection_ops(&[("a", a), ("b", b)]) {
        r1.client_write(op).await.unwrap();
    }

    // Snapshot, then purge the log up to the snapshot.
    r1.trigger().snapshot().await.unwrap();
    // Bounded wait (500 x 20 ms), like `commit` / `await_serves`, so a snapshot
    // that never materialises fails the test instead of hanging the suite.
    let mut snapshot_index = None;
    for _ in 0..500 {
        if let Some(s) = r1.metrics().borrow().snapshot {
            snapshot_index = Some(s.index);
            break;
        }
        tokio::time::sleep(Duration::from_millis(20)).await;
    }
    let snap_index = snapshot_index.expect("snapshot never appeared in the leader's metrics");
    r1.trigger().purge_log(snap_index).await.unwrap();

    // Node 2 joins with an empty engine and an empty log.
    let dir2 = tempfile::tempdir().unwrap();
    let e2 = Arc::new(Mutex::new(Database::open(dir2.path()).unwrap()));
    let r2 = openraft::Raft::new(
        2,
        cfg.clone(),
        board.clone(),
        LogStore::default(),
        Arc::new(StateMachineStore::new(EngineApplier(e2.clone()))),
    )
    .await
    .unwrap();
    board.register(2, r2.clone());
    r1.add_learner(2, BasicNode::default(), true).await.unwrap();

    await_serves(&e2, &a, "a").await;
    await_serves(&e2, &b, "b").await;
    r1.shutdown().await.unwrap();
    r2.shutdown().await.unwrap();
}
/// A voter that is taken off the switchboard misses a write, and serves it
/// after being registered again.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn a_partitioned_voter_rejoins_and_catches_up() {
    let (board, voters) = boot_cluster(&[1, 2, 3]).await;
    let bootstrap = &voters[0].raft;
    bootstrap
        .wait(Some(Duration::from_secs(10)))
        .state(ServerState::Leader, "bootstrap leader")
        .await
        .unwrap();

    let a = [1.0f32, 0.0, 0.0, 0.0];
    let b = [0.0f32, 1.0, 0.0, 0.0];
    let ops = collection_ops(&[("a", a)]);
    let coll_id = ops
        .iter()
        .find_map(|op| {
            if let WalOp::CreateCollection { collection_id, .. } = op {
                Some(*collection_id)
            } else {
                None
            }
        })
        .expect("create-collection op");
    for op in ops.iter() {
        commit(&board, &voters, op).await;
    }
    for voter in voters.iter() {
        await_serves(&voter.engine, &a, "a").await;
    }

    // Cut node 3 off, then commit a write it cannot see.
    let isolated = 3;
    board.kill(isolated);
    let b_bytes: Vec<u8> = b.iter().flat_map(|f| f.to_le_bytes()).collect();
    let b_op = WalOp::Upsert {
        collection_id: coll_id,
        external_id: "b".to_owned(),
        vector: b_bytes,
        payload: b"{}".to_vec(),
    };
    commit(&board, &voters, &b_op).await;
    for connected in voters.iter().filter(|v| v.id != isolated) {
        await_serves(&connected.engine, &b, "b").await;
    }

    // Reconnect node 3 (the third voter) and expect it to catch up.
    let rejoining = &voters[2];
    board.register(isolated, rejoining.raft.clone());
    await_serves(&rejoining.engine, &a, "a").await;
    await_serves(&rejoining.engine, &b, "b").await;

    for voter in voters.iter() {
        voter.raft.shutdown().await.unwrap();
    }
}
}