use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;
use openraft::{Config, Raft};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use super::committer::RaftCommitter;
use super::http_network::HttpRaftNetworkFactory;
use super::log_storage::SqliteRaftLogStorage;
use super::state_machine::YantrikStateMachine;
use super::types::{YantrikNode, YantrikNodeId, YantrikRaftTypeConfig};
use crate::commit::{Applier, MutationCommitter};
use crate::security::cluster_tls::{ClusterTlsConfig, ClusterTlsError};
/// Which replication engine this node runs.
///
/// Serialized in lowercase (`"disabled"`, `"openraft"`). The default is
/// [`RaftClusterMode::Disabled`], so replication is strictly opt-in.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RaftClusterMode {
    /// No Raft replication (single-node deployments).
    #[default]
    Disabled,
    /// Replicate writes through openraft.
    OpenRaft,
}
/// Where request handlers send their writes.
///
/// Serialized in snake_case (`"local_sqlite"`, `"raft_submitter"`). The default is
/// [`HandlerWritePath::LocalSqlite`]; openraft mode requires
/// [`HandlerWritePath::RaftSubmitter`] (enforced by `RaftAssemblyConfig::validate`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HandlerWritePath {
    /// Handlers write straight to the local SQLite store.
    #[default]
    LocalSqlite,
    /// Handlers submit writes through the Raft committer.
    RaftSubmitter,
}
/// Errors raised while validating cluster configuration or assembling the Raft stack.
#[derive(Debug, Error)]
pub enum AssemblyError {
    /// openraft mode was requested without complete mTLS material. `missing`
    /// names the absent field, or the whole `cluster_tls` section.
    #[error(
        "openraft mode requires fully-specified cluster_tls (cert_path, key_path, ca_path) \
         to prevent accidental plaintext cluster traffic; missing field: {missing}"
    )]
    MtlsRequired { missing: &'static str },
    /// A [`ClusterTlsError`], convertible via `?` thanks to `#[from]`.
    #[error("cluster_tls build failed: {0}")]
    ClusterTls(#[from] ClusterTlsError),
    /// The peer HTTP client could not be built. Carries a human-readable message,
    /// which may include the flattened source chain.
    #[error("reqwest client build failed: {0}")]
    ReqwestBuild(String),
    /// openraft rejected its `Config` or `Raft::new` failed.
    #[error("openraft Raft::new fatal: {0}")]
    RaftNew(String),
    /// A PEM file named in `cluster_tls` could not be read from disk.
    #[error("read PEM file `{path}`: {source}")]
    PemRead {
        path: std::path::PathBuf,
        #[source]
        source: std::io::Error,
    },
    /// openraft mode was paired with a handler write path other than `RaftSubmitter`.
    #[error(
        "openraft mode requires handler_write_path = \"raft_submitter\"; got {actual:?}. \
         Configure cluster.handler_write_path = \"raft_submitter\", or set \
         cluster.raft_mode = \"disabled\" for single-node deployments."
    )]
    WritePathMismatch {
        actual: HandlerWritePath,
        expected: HandlerWritePath,
    },
    /// openraft mode was configured with fewer than `need` peers.
    // The message interpolates `need` rather than hard-coding the threshold, so it
    // cannot drift from the check that produced the error.
    #[error(
        "openraft mode requires at least {need} peers (got {have}). \
         A 1-peer cluster has no quorum semantics; configure additional \
         peers in cluster.peers or set cluster.raft_mode = \"disabled\"."
    )]
    InsufficientPeers { have: usize, need: usize },
}
/// Everything [`build_raft_cluster`] needs to assemble a Raft node.
pub struct RaftAssemblyConfig {
    /// Whether Raft replication is enabled at all.
    pub mode: RaftClusterMode,
    /// This node's Raft id (passed to `Raft::new`).
    pub node_id: YantrikNodeId,
    /// This node's own address, e.g. `https://10.0.0.1:7100`.
    pub node_addr: String,
    /// mTLS material for peer traffic; must be fully specified in openraft mode.
    pub cluster_tls: Option<ClusterTlsConfig>,
    /// Cluster peer addresses. The tests list this node's own address among them.
    /// openraft mode requires at least two distinct, non-blank entries.
    pub peers: Vec<String>,
    /// Write path used by request handlers; must be `RaftSubmitter` in openraft mode.
    pub write_path: HandlerWritePath,
    /// Applier handed to the Raft state machine.
    pub applier: Arc<dyn Applier>,
    /// Per-request timeout for the peer HTTP client.
    pub request_timeout: Duration,
    /// openraft tuning; validated during assembly.
    pub openraft_config: Config,
}

impl RaftAssemblyConfig {
    /// Minimum number of distinct peers accepted in openraft mode.
    const MIN_OPENRAFT_PEERS: usize = 2;

    /// Baseline openraft configuration for a production node.
    ///
    /// Pairs `OpenRaft` mode with the `RaftSubmitter` write path, uses a 10s request
    /// timeout, and sets heartbeat 200 / election timeout 800..1600 (openraft takes
    /// these in milliseconds). `cluster_tls` is `None` and `peers` is empty on
    /// purpose: the operator must supply both, and [`Self::validate`] rejects the
    /// empty peer list until they do.
    pub fn production_defaults(node_id: YantrikNodeId, node_addr: String) -> Self {
        Self {
            mode: RaftClusterMode::OpenRaft,
            node_id,
            node_addr,
            cluster_tls: None,
            peers: Vec::new(),
            write_path: HandlerWritePath::RaftSubmitter,
            applier: Arc::new(crate::commit::LocalApplier::new()),
            request_timeout: Duration::from_secs(10),
            openraft_config: Config {
                cluster_name: "yantrikdb".into(),
                heartbeat_interval: 200,
                election_timeout_min: 800,
                election_timeout_max: 1600,
                ..Default::default()
            },
        }
    }

    /// Checks mode-dependent invariants that need no I/O.
    ///
    /// In openraft mode, the write path must be `RaftSubmitter` and there must be at
    /// least [`Self::MIN_OPENRAFT_PEERS`] distinct, non-blank peer addresses.
    /// Disabled mode imposes no constraints. TLS completeness is checked separately
    /// during assembly.
    ///
    /// # Errors
    ///
    /// [`AssemblyError::WritePathMismatch`] or [`AssemblyError::InsufficientPeers`].
    pub(crate) fn validate(&self) -> Result<(), AssemblyError> {
        if self.mode != RaftClusterMode::OpenRaft {
            return Ok(());
        }
        if self.write_path != HandlerWritePath::RaftSubmitter {
            return Err(AssemblyError::WritePathMismatch {
                actual: self.write_path,
                expected: HandlerWritePath::RaftSubmitter,
            });
        }
        // Count distinct, non-blank addresses. Listing the same peer twice, or leaving
        // an empty entry, must not satisfy the quorum requirement.
        let distinct_peers = self
            .peers
            .iter()
            .map(|peer| peer.trim())
            .filter(|peer| !peer.is_empty())
            .collect::<std::collections::BTreeSet<&str>>()
            .len();
        if distinct_peers < Self::MIN_OPENRAFT_PEERS {
            return Err(AssemblyError::InsufficientPeers {
                have: distinct_peers,
                need: Self::MIN_OPENRAFT_PEERS,
            });
        }
        Ok(())
    }
}
/// The assembled Raft stack returned by `build_raft_cluster`.
// Field order is also drop order; `committer` holds its own clone of `raft`.
pub struct RaftAssembly {
    /// Shared handle to the running openraft instance.
    pub raft: Arc<Raft<YantrikRaftTypeConfig>>,
    /// Committer built from a clone of `raft` plus the local committer.
    pub committer: RaftCommitter,
    /// The mTLS HTTP network factory that was handed to `Raft::new` (a clone is kept here).
    pub network_factory: HttpRaftNetworkFactory,
}
/// Ensures openraft mode has a complete mTLS configuration.
///
/// Requires the `cluster_tls` section to be present, with `cert_path`, `key_path`
/// and `ca_path` all set. Fields are checked in that order and the first absent one
/// is reported. On success the same reference is handed back for further use.
///
/// # Errors
///
/// [`AssemblyError::MtlsRequired`] naming the missing section or field.
fn validate_cluster_tls_for_openraft(
    cluster_tls: Option<&ClusterTlsConfig>,
) -> Result<&ClusterTlsConfig, AssemblyError> {
    let tls = match cluster_tls {
        Some(tls) => tls,
        None => {
            return Err(AssemblyError::MtlsRequired {
                missing: "cluster_tls (entire section)",
            })
        }
    };

    // Table of (field name, is it set?) — evaluated in reporting order.
    let required_fields = [
        ("cert_path", tls.cert_path.is_some()),
        ("key_path", tls.key_path.is_some()),
        ("ca_path", tls.ca_path.is_some()),
    ];
    for (missing, present) in required_fields {
        if !present {
            return Err(AssemblyError::MtlsRequired { missing });
        }
    }

    Ok(tls)
}
/// Builds the mTLS `reqwest::Client` used for peer-to-peer Raft traffic.
///
/// Reads the cert, key and CA PEM files named in `cfg`. The cert and key become
/// the client identity, and only the configured CA is trusted (built-in roots are
/// disabled). Every request is bounded by `request_timeout`.
///
/// NOTE(review): this performs blocking `std::fs` reads and is called from the
/// async `build_raft_cluster`. That is acceptable for small startup-time files.
///
/// # Errors
///
/// * [`AssemblyError::MtlsRequired`] if any of the three paths is unset. Callers
///   normally run `validate_cluster_tls_for_openraft` first, so this is a typed
///   error rather than a panic.
/// * [`AssemblyError::PemRead`] if a PEM file cannot be read.
/// * [`AssemblyError::ReqwestBuild`] if the PEM contents are rejected or the client
///   cannot be built.
fn build_reqwest_client_for_cluster(
    cfg: &ClusterTlsConfig,
    request_timeout: Duration,
) -> Result<reqwest::Client, AssemblyError> {
    use std::fmt::Write as _;
    use std::path::{Path, PathBuf};

    // Borrows a required path, turning an unset field into a typed error.
    fn required<'a>(
        path: &'a Option<PathBuf>,
        field: &'static str,
    ) -> Result<&'a Path, AssemblyError> {
        path.as_deref()
            .ok_or(AssemblyError::MtlsRequired { missing: field })
    }

    // Reads a whole PEM file, attaching the offending path to any I/O error.
    fn read_pem(path: &Path) -> Result<Vec<u8>, AssemblyError> {
        std::fs::read(path).map_err(|source| AssemblyError::PemRead {
            path: path.to_path_buf(),
            source,
        })
    }

    let cert_path = required(&cfg.cert_path, "cert_path")?;
    let key_path = required(&cfg.key_path, "key_path")?;
    let ca_path = required(&cfg.ca_path, "ca_path")?;

    let cert_pem = read_pem(cert_path)?;
    let key_pem = read_pem(key_path)?;
    let ca_pem = read_pem(ca_path)?;

    // `Identity::from_pem` takes certificate and private key in one PEM buffer.
    // The separator newline guards against a cert file with no trailing newline.
    // `cert_pem` is not needed afterwards, so it is reused as the buffer.
    let mut bundle = cert_pem;
    bundle.push(b'\n');
    bundle.extend_from_slice(&key_pem);

    let identity = reqwest::Identity::from_pem(&bundle)
        .map_err(|e| AssemblyError::ReqwestBuild(format!("Identity::from_pem: {e}")))?;
    let ca_cert = reqwest::Certificate::from_pem(&ca_pem)
        .map_err(|e| AssemblyError::ReqwestBuild(format!("Certificate::from_pem: {e}")))?;

    // Trust only the cluster CA; the built-in public roots are switched off.
    let mut builder = reqwest::Client::builder()
        .timeout(request_timeout)
        .identity(identity)
        .add_root_certificate(ca_cert)
        .tls_built_in_root_certs(false);
    if cfg.dev_mode {
        tracing::warn!(
            "cluster_tls.dev_mode = true — accepting self-signed peer certs. \
             NEVER set this in production."
        );
        // Per reqwest's docs, this disables certificate validation entirely: any peer
        // certificate is accepted, so the CA pin above is not enforced.
        builder = builder.danger_accept_invalid_certs(true);
    }

    builder.build().map_err(|e| {
        // Append every underlying cause so the root failure (e.g. a TLS backend
        // error) is not hidden behind the top-level message.
        let mut chain = format!("build: {e}");
        let mut cause = std::error::Error::source(&e);
        while let Some(inner) = cause {
            // Writing into a String cannot fail.
            let _ = write!(chain, " / {inner}");
            cause = inner.source();
        }
        AssemblyError::ReqwestBuild(chain)
    })
}
/// Assembles a Raft node: validates config, builds the mTLS peer client, and starts openraft.
///
/// Steps, in order:
/// 1. `cfg.validate()` (write path / peer count), before any file I/O.
/// 2. Check TLS completeness, then build the mTLS `reqwest` client and network factory.
/// 3. Validate the openraft `Config`.
/// 4. Construct the state machine and `Raft`, then wrap `raft` and `local` in a `RaftCommitter`.
///
/// # Errors
///
/// Any [`AssemblyError`] from the steps above. openraft failures surface as
/// [`AssemblyError::RaftNew`].
pub async fn build_raft_cluster(
    cfg: RaftAssemblyConfig,
    log_storage: SqliteRaftLogStorage,
    local: Arc<dyn MutationCommitter>,
) -> Result<RaftAssembly, AssemblyError> {
    // Pure checks first, so misconfiguration is reported before touching PEM files.
    cfg.validate()?;

    let tls = validate_cluster_tls_for_openraft(cfg.cluster_tls.as_ref())?;
    let http_client = build_reqwest_client_for_cluster(tls, cfg.request_timeout)?;
    let network_factory = HttpRaftNetworkFactory::new(http_client, cfg.request_timeout);

    let raft_config = match cfg.openraft_config.validate() {
        Ok(checked) => Arc::new(checked),
        Err(e) => {
            return Err(AssemblyError::RaftNew(format!(
                "openraft Config::validate: {e}"
            )))
        }
    };

    let state_machine = YantrikStateMachine::new(Arc::clone(&local), cfg.applier);
    let raft = Raft::<YantrikRaftTypeConfig>::new(
        cfg.node_id,
        raft_config,
        network_factory.clone(),
        log_storage,
        state_machine,
    )
    .await
    .map(Arc::new)
    .map_err(|e| AssemblyError::RaftNew(e.to_string()))?;

    // Field initializers run in source order: the committer takes its clone of
    // `raft` before `raft` itself is moved into the struct.
    Ok(RaftAssembly {
        committer: RaftCommitter::new(Arc::clone(&raft), local),
        raft,
        network_factory,
    })
}
/// Initializes the Raft cluster with this node as the only initial member.
///
/// The local node id is read from the Raft instance's current metrics and paired
/// with `node_addr`. The resulting one-entry membership map is passed to
/// `Raft::initialize`. Other peers are not included here.
///
/// # Errors
///
/// Propagates the error from `Raft::initialize` unchanged.
pub async fn initialize_single_node(
    assembly: &RaftAssembly,
    node_addr: String,
) -> Result<
    (),
    openraft::error::RaftError<
        YantrikNodeId,
        openraft::error::InitializeError<YantrikNodeId, YantrikNode>,
    >,
> {
    // Copy just the id out of the metrics snapshot instead of cloning the whole
    // `RaftMetrics`. The watch borrow is released at the end of this statement,
    // before the `.await` below. (openraft's NodeId bound requires `Copy`.)
    let me = assembly.raft.metrics().borrow().id;
    let nodes = BTreeMap::from([(me, YantrikNode::new(node_addr))]);
    assembly.raft.initialize(nodes).await
}
// Unit tests for cluster assembly: mTLS gating, `RaftAssemblyConfig::validate`
// (write path and peer count), defaults, and the check ordering inside
// `build_raft_cluster`.
#[cfg(test)]
mod tests {
    use super::*;

    // `ClusterTlsConfig::default()`. The test below expects its `cert_path` to be unset.
    fn empty_tls() -> ClusterTlsConfig {
        ClusterTlsConfig::default()
    }

    // Builds a TLS config with the given optional paths and `dev_mode` off.
    // Paths are never read by the validation-only tests.
    fn tls_with(cert: Option<&str>, key: Option<&str>, ca: Option<&str>) -> ClusterTlsConfig {
        ClusterTlsConfig {
            cert_path: cert.map(std::path::PathBuf::from),
            key_path: key.map(std::path::PathBuf::from),
            ca_path: ca.map(std::path::PathBuf::from),
            dev_mode: false,
            rotate_check_secs: 60,
        }
    }

    #[test]
    fn openraft_mode_rejects_missing_cluster_tls_section() {
        let err = validate_cluster_tls_for_openraft(None).unwrap_err();
        match err {
            AssemblyError::MtlsRequired { missing } => assert!(missing.contains("cluster_tls")),
            other => panic!("expected MtlsRequired, got {other:?}"),
        }
    }

    // With nothing set, the first field checked (cert_path) is the one reported.
    #[test]
    fn openraft_mode_rejects_empty_tls_config() {
        let cfg = empty_tls();
        let err = validate_cluster_tls_for_openraft(Some(&cfg)).unwrap_err();
        match err {
            AssemblyError::MtlsRequired { missing } => assert_eq!(missing, "cert_path"),
            other => panic!("expected MtlsRequired, got {other:?}"),
        }
    }

    #[test]
    fn openraft_mode_rejects_missing_key() {
        let cfg = tls_with(Some("/tmp/cert.pem"), None, Some("/tmp/ca.pem"));
        let err = validate_cluster_tls_for_openraft(Some(&cfg)).unwrap_err();
        match err {
            AssemblyError::MtlsRequired { missing } => assert_eq!(missing, "key_path"),
            other => panic!("expected MtlsRequired, got {other:?}"),
        }
    }

    #[test]
    fn openraft_mode_rejects_missing_ca() {
        let cfg = tls_with(Some("/tmp/cert.pem"), Some("/tmp/key.pem"), None);
        let err = validate_cluster_tls_for_openraft(Some(&cfg)).unwrap_err();
        match err {
            AssemblyError::MtlsRequired { missing } => assert_eq!(missing, "ca_path"),
            other => panic!("expected MtlsRequired, got {other:?}"),
        }
    }

    // Validation only checks presence, so the files need not exist.
    #[test]
    fn openraft_mode_accepts_fully_specified_tls() {
        let cfg = tls_with(
            Some("/tmp/cert.pem"),
            Some("/tmp/key.pem"),
            Some("/tmp/ca.pem"),
        );
        validate_cluster_tls_for_openraft(Some(&cfg))
            .expect("fully-specified config must pass validation");
    }

    // End-to-end: an otherwise valid openraft config without TLS must not assemble.
    #[tokio::test]
    async fn build_raft_cluster_fails_on_missing_cluster_tls() {
        let local = Arc::new(crate::commit::LocalSqliteCommitter::open_in_memory().unwrap())
            as Arc<dyn MutationCommitter>;
        let log_storage = SqliteRaftLogStorage::open_in_memory();
        let cfg = RaftAssemblyConfig {
            mode: RaftClusterMode::OpenRaft,
            node_id: YantrikNodeId::new(1),
            node_addr: "https://127.0.0.1:7100".into(),
            cluster_tls: None,
            peers: vec![
                "https://127.0.0.1:7100".into(),
                "https://127.0.0.1:7101".into(),
            ],
            write_path: HandlerWritePath::RaftSubmitter,
            applier: Arc::new(crate::commit::LocalApplier::new()),
            request_timeout: Duration::from_secs(1),
            openraft_config: Config::default(),
        };
        match build_raft_cluster(cfg, log_storage, local).await {
            Err(AssemblyError::MtlsRequired { .. }) => {}
            Err(other) => panic!("expected MtlsRequired, got {other:?}"),
            Ok(_) => panic!("expected MtlsRequired, assembly succeeded"),
        }
    }

    // Paths pass the presence check but point at files that cannot be read.
    #[tokio::test]
    async fn build_raft_cluster_fails_on_unreadable_cert_files() {
        let local = Arc::new(crate::commit::LocalSqliteCommitter::open_in_memory().unwrap())
            as Arc<dyn MutationCommitter>;
        let log_storage = SqliteRaftLogStorage::open_in_memory();
        let cluster_tls = tls_with(
            Some("/nonexistent/cert.pem"),
            Some("/nonexistent/key.pem"),
            Some("/nonexistent/ca.pem"),
        );
        let cfg = RaftAssemblyConfig {
            mode: RaftClusterMode::OpenRaft,
            node_id: YantrikNodeId::new(1),
            node_addr: "https://127.0.0.1:7100".into(),
            cluster_tls: Some(cluster_tls),
            peers: vec![
                "https://127.0.0.1:7100".into(),
                "https://127.0.0.1:7101".into(),
            ],
            write_path: HandlerWritePath::RaftSubmitter,
            applier: Arc::new(crate::commit::LocalApplier::new()),
            request_timeout: Duration::from_secs(1),
            openraft_config: Config::default(),
        };
        match build_raft_cluster(cfg, log_storage, local).await {
            Err(AssemblyError::PemRead { .. }) => {}
            Err(other) => panic!("expected PemRead, got {other:?}"),
            Ok(_) => panic!("expected PemRead, assembly succeeded"),
        }
    }

    #[test]
    fn cluster_mode_default_is_disabled() {
        assert_eq!(RaftClusterMode::default(), RaftClusterMode::Disabled);
    }

    #[test]
    fn production_defaults_demand_explicit_tls() {
        let d = RaftAssemblyConfig::production_defaults(
            YantrikNodeId::new(1),
            "https://10.0.0.1:7100".into(),
        );
        assert_eq!(d.mode, RaftClusterMode::OpenRaft);
        assert!(
            d.cluster_tls.is_none(),
            "production_defaults must NOT bake in any cluster_tls — operator supplies it"
        );
    }

    // Config builder for the validate() tests: fully specified (placeholder) TLS,
    // with mode, write path and peers varied per test.
    fn cfg_for(
        mode: RaftClusterMode,
        write_path: HandlerWritePath,
        peers: Vec<String>,
    ) -> RaftAssemblyConfig {
        RaftAssemblyConfig {
            mode,
            node_id: YantrikNodeId::new(1),
            node_addr: "https://10.0.0.1:7100".into(),
            cluster_tls: Some(tls_with(
                Some("/tmp/cert.pem"),
                Some("/tmp/key.pem"),
                Some("/tmp/ca.pem"),
            )),
            peers,
            write_path,
            applier: Arc::new(crate::commit::LocalApplier::new()),
            request_timeout: Duration::from_secs(1),
            openraft_config: Config::default(),
        }
    }

    // Three distinct peer addresses; comfortably above the 2-peer minimum.
    fn three_peer_set() -> Vec<String> {
        vec![
            "https://10.0.0.1:7100".into(),
            "https://10.0.0.2:7100".into(),
            "https://10.0.0.3:7100".into(),
        ]
    }

    #[test]
    fn pr_6_5_openraft_with_localsqlite_write_path_is_rejected() {
        let cfg = cfg_for(
            RaftClusterMode::OpenRaft,
            HandlerWritePath::LocalSqlite,
            three_peer_set(),
        );
        match cfg.validate() {
            Err(AssemblyError::WritePathMismatch { actual, expected }) => {
                assert_eq!(actual, HandlerWritePath::LocalSqlite);
                assert_eq!(expected, HandlerWritePath::RaftSubmitter);
            }
            other => panic!("expected WritePathMismatch, got {other:?}"),
        }
    }

    #[test]
    fn pr_6_5_openraft_with_empty_peers_is_rejected() {
        let cfg = cfg_for(
            RaftClusterMode::OpenRaft,
            HandlerWritePath::RaftSubmitter,
            vec![],
        );
        match cfg.validate() {
            Err(AssemblyError::InsufficientPeers { have, need }) => {
                assert_eq!(have, 0);
                assert_eq!(need, 2);
            }
            other => panic!("expected InsufficientPeers, got {other:?}"),
        }
    }

    #[test]
    fn pr_6_5_openraft_with_one_peer_is_rejected() {
        let cfg = cfg_for(
            RaftClusterMode::OpenRaft,
            HandlerWritePath::RaftSubmitter,
            vec!["https://10.0.0.1:7100".into()],
        );
        assert!(matches!(
            cfg.validate(),
            Err(AssemblyError::InsufficientPeers { have: 1, need: 2 })
        ));
    }

    // Exactly at the minimum: must pass.
    #[test]
    fn pr_6_5_openraft_with_two_peers_passes() {
        let cfg = cfg_for(
            RaftClusterMode::OpenRaft,
            HandlerWritePath::RaftSubmitter,
            vec![
                "https://10.0.0.1:7100".into(),
                "https://10.0.0.2:7100".into(),
            ],
        );
        cfg.validate()
            .expect("two-peer openraft cluster must validate");
    }

    #[test]
    fn pr_6_5_disabled_mode_does_not_demand_peers() {
        let cfg = cfg_for(
            RaftClusterMode::Disabled,
            HandlerWritePath::LocalSqlite,
            vec![],
        );
        cfg.validate()
            .expect("single-node mode must validate without peers");
    }

    // Pins current behaviour: Disabled mode places no constraint on the write path.
    #[test]
    fn pr_6_5_disabled_mode_with_raft_submitter_is_currently_permitted() {
        let cfg = cfg_for(
            RaftClusterMode::Disabled,
            HandlerWritePath::RaftSubmitter,
            vec![],
        );
        cfg.validate()
            .expect("Disabled+RaftSubmitter is permitted (no-op declaration)");
    }

    // Getting WritePathMismatch (not a TLS/PEM error) shows `build_raft_cluster`
    // runs `validate()` before TLS setup.
    #[tokio::test]
    async fn pr_6_5_build_raft_cluster_runs_validate_first() {
        let local = Arc::new(crate::commit::LocalSqliteCommitter::open_in_memory().unwrap())
            as Arc<dyn MutationCommitter>;
        let log_storage = SqliteRaftLogStorage::open_in_memory();
        let cfg = cfg_for(
            RaftClusterMode::OpenRaft,
            HandlerWritePath::LocalSqlite,
            three_peer_set(),
        );
        match build_raft_cluster(cfg, log_storage, local).await {
            Err(AssemblyError::WritePathMismatch { .. }) => {}
            Err(other) => panic!("expected WritePathMismatch, got {other:?}"),
            Ok(_) => panic!("expected WritePathMismatch, assembly succeeded"),
        }
    }

    #[test]
    fn handler_write_path_default_is_local_sqlite() {
        assert_eq!(HandlerWritePath::default(), HandlerWritePath::LocalSqlite);
    }

    #[test]
    fn production_defaults_pair_openraft_with_raft_submitter() {
        let d = RaftAssemblyConfig::production_defaults(
            YantrikNodeId::new(1),
            "https://10.0.0.1:7100".into(),
        );
        assert_eq!(d.write_path, HandlerWritePath::RaftSubmitter);
        assert!(d.peers.is_empty(), "operator MUST supply peers");
    }
}