pub mod config;
pub mod stats;
pub mod targeting;
pub mod crc_corruption;
pub mod delayed_response;
pub mod exception_injection;
pub mod extra_data;
pub mod no_response;
pub mod partial_frame;
pub mod truncated_response;
pub mod wrong_function_code;
pub mod wrong_transaction_id;
pub mod wrong_unit_id;
pub mod connection_disruption;
pub mod rtu_timing;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, trace};
pub use config::*;
pub use stats::{FaultStats, FaultStatsSnapshot};
pub use targeting::FaultTarget;
pub use crc_corruption::CrcCorruptionFault;
pub use delayed_response::DelayedResponseFault;
pub use exception_injection::ExceptionInjectionFault;
pub use extra_data::ExtraDataFault;
pub use no_response::NoResponseFault;
pub use partial_frame::PartialFrameFault;
pub use truncated_response::TruncatedResponseFault;
pub use wrong_function_code::WrongFunctionCodeFault;
pub use wrong_transaction_id::WrongTransactionIdFault;
pub use wrong_unit_id::WrongUnitIdFault;
/// What the caller should do with a response after faults have been applied.
///
/// Values are produced by individual [`ModbusFault`] implementations and by
/// [`FaultPipeline::apply`], which folds the actions of several faults into one.
///
/// NOTE(review): how each variant is finally put on the wire is decided by the
/// transport code that consumes this enum, which is not part of this file.
#[derive(Debug, Clone)]
pub enum FaultAction {
/// Send the carried bytes as the response. Inside the pipeline the payload
/// replaces the working response PDU seen by the next mutation fault.
SendResponse(Vec<u8>),
/// Presumably "send nothing at all" (confirm against the transport handler).
/// The pipeline returns this immediately when any fault yields it.
DropResponse,
/// Wait for `delay`, then send `response`. The pipeline sums the delays of
/// all mutation faults that activated for the same request.
DelayThenSend { delay: Duration, response: Vec<u8> },
/// Send only `bytes`; presumably an incomplete frame (confirm against the
/// transport handler). The pipeline returns this immediately.
SendPartial { bytes: Vec<u8> },
/// Bytes flagged as "raw" by the producing fault; presumably written without
/// further framing (confirm against the transport handler).
SendRawBytes(Vec<u8>),
/// Send `response` using `transaction_id` instead of the request's own id.
OverrideTransactionId {
/// Transaction id to report in place of the original one.
transaction_id: u16,
/// Response bytes to send alongside the overridden id.
response: Vec<u8>,
},
}
/// Transport a request arrived on.
///
/// Stored in [`ModbusFaultContext::transport`] and compared against
/// [`ModbusFault::compatible_transport`] to skip faults that only make sense
/// on the other transport.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportKind {
/// Modbus TCP; contexts are built with [`ModbusFaultContext::tcp`].
Tcp,
/// Modbus RTU; contexts are built with [`ModbusFaultContext::rtu`].
Rtu,
}
/// Everything a fault is given about one request/response exchange.
///
/// The context owns copies of both PDUs, so cloning it copies their bytes.
#[derive(Debug, Clone)]
pub struct ModbusFaultContext {
/// Transport the request arrived on.
pub transport: TransportKind,
/// Unit (slave) id addressed by the request.
pub unit_id: u8,
/// Function code of the request.
pub function_code: u8,
/// Copy of the request PDU bytes supplied by the caller.
pub request_pdu: Vec<u8>,
/// Response PDU the fault should act on. In the mutation phase of
/// [`FaultPipeline::apply`] this holds the output of the previous fault.
pub response_pdu: Vec<u8>,
/// Transaction id of the request; always 0 for contexts built with `rtu`.
pub transaction_id: u16,
/// Set by both constructors to `unit_id == 0`.
pub is_broadcast: bool,
/// Caller-supplied request counter (presumably a running sequence number;
/// its exact meaning is defined by the caller).
pub request_number: u64,
}
impl ModbusFaultContext {
    /// Builds the context for an exchange over Modbus RTU.
    ///
    /// The transaction id is fixed to 0 and `is_broadcast` is derived from
    /// `unit_id == 0`. Both PDU slices are copied into owned vectors.
    pub fn rtu(
        unit_id: u8,
        function_code: u8,
        request_pdu: &[u8],
        response_pdu: &[u8],
        request_number: u64,
    ) -> Self {
        // Identical to a TCP context with transaction id 0, except for the
        // transport tag.
        Self {
            transport: TransportKind::Rtu,
            ..Self::tcp(
                unit_id,
                function_code,
                request_pdu,
                response_pdu,
                0,
                request_number,
            )
        }
    }

    /// Builds the context for an exchange over Modbus TCP.
    ///
    /// `is_broadcast` is derived from `unit_id == 0`. Both PDU slices are
    /// copied into owned vectors.
    pub fn tcp(
        unit_id: u8,
        function_code: u8,
        request_pdu: &[u8],
        response_pdu: &[u8],
        transaction_id: u16,
        request_number: u64,
    ) -> Self {
        let is_broadcast = unit_id == 0;
        let request_pdu = Vec::from(request_pdu);
        let response_pdu = Vec::from(response_pdu);
        Self {
            transport: TransportKind::Tcp,
            unit_id,
            function_code,
            request_pdu,
            response_pdu,
            transaction_id,
            is_broadcast,
            request_number,
        }
    }
}
/// A single injectable fault, as driven by [`FaultPipeline`].
///
/// All methods take `&self`, so implementations that toggle state or count
/// statistics need interior mutability. The `Send + Sync` bound is what allows
/// an `Arc<dyn ModbusFault>` to be shared between threads.
pub trait ModbusFault: Send + Sync {
/// Static name of this fault; used as a log field and as the key in
/// [`FaultPipeline::all_stats`].
fn fault_type(&self) -> &'static str;
/// Whether the fault is switched on. The pipeline skips disabled faults
/// before calling `should_activate`.
fn is_enabled(&self) -> bool;
/// Switches the fault on or off.
fn set_enabled(&self, enabled: bool);
/// Decides whether the fault fires for this exchange. The pipeline calls
/// `apply` only after this returned `true`.
fn should_activate(&self, ctx: &ModbusFaultContext) -> bool;
/// Produces the action for this exchange. For mutation faults,
/// `ctx.response_pdu` is the response as left by earlier faults.
fn apply(&self, ctx: &ModbusFaultContext) -> FaultAction;
/// Returns a snapshot of this fault's statistics.
fn stats(&self) -> FaultStatsSnapshot;
/// Resets this fault's statistics.
fn reset_stats(&self);
/// `true` if the fault belongs to the short-circuit phase: the pipeline
/// evaluates such faults first and returns the first activated one's action
/// as-is. Defaults to `false` (mutation fault).
fn is_short_circuit(&self) -> bool {
false
}
/// Transport the fault is restricted to, or `None` (the default) if it
/// applies to every transport.
fn compatible_transport(&self) -> Option<TransportKind> {
None
}
}
/// Ordered set of faults evaluated against every request/response exchange.
///
/// See [`FaultPipeline::apply`] for the evaluation rules.
pub struct FaultPipeline {
/// Registered faults; the vector order is the evaluation order.
faults: Vec<Arc<dyn ModbusFault>>,
/// Master switch, atomic so it can be flipped through `&self`. While it is
/// `false`, `apply` returns `None` without consulting any fault.
enabled: std::sync::atomic::AtomicBool,
}
impl FaultPipeline {
    /// Creates an enabled pipeline with no faults registered.
    pub fn new() -> Self {
        Self {
            faults: Vec::new(),
            enabled: std::sync::atomic::AtomicBool::new(true),
        }
    }

    /// Appends `fault` to the end of the evaluation order (builder style).
    pub fn with_fault(mut self, fault: Arc<dyn ModbusFault>) -> Self {
        self.faults.push(fault);
        self
    }

    /// Appends all `faults`, keeping their order (builder style).
    pub fn with_faults(mut self, faults: Vec<Arc<dyn ModbusFault>>) -> Self {
        self.faults.extend(faults);
        self
    }

    /// Returns the state of the pipeline-wide master switch.
    pub fn is_enabled(&self) -> bool {
        self.enabled.load(std::sync::atomic::Ordering::Acquire)
    }

    /// Flips the pipeline-wide master switch. Per-fault switches are untouched.
    pub fn set_enabled(&self, enabled: bool) {
        self.enabled
            .store(enabled, std::sync::atomic::Ordering::Release);
    }

    /// Number of registered faults, whether enabled or not.
    pub fn len(&self) -> usize {
        self.faults.len()
    }

    /// `true` when no fault is registered.
    pub fn is_empty(&self) -> bool {
        self.faults.is_empty()
    }

    /// Cheap pre-filter shared by both phases of [`FaultPipeline::apply`].
    ///
    /// A fault is a candidate when it is enabled, belongs to the requested
    /// phase (`short_circuit`), and is not restricted to another transport.
    /// `should_activate` is deliberately NOT called here: implementations may
    /// count every call, so it must only run for faults that pass this filter.
    fn is_candidate(fault: &dyn ModbusFault, ctx: &ModbusFaultContext, short_circuit: bool) -> bool {
        if !fault.is_enabled() || fault.is_short_circuit() != short_circuit {
            return false;
        }
        match fault.compatible_transport() {
            Some(required) => required == ctx.transport,
            None => true,
        }
    }

    /// Runs the pipeline for one exchange.
    ///
    /// Returns `None` when the pipeline is disabled or no fault activated, in
    /// which case the caller should send the normal response.
    ///
    /// Evaluation happens in two phases, each in registration order:
    ///
    /// 1. **Short-circuit faults.** The first one that activates decides the
    ///    outcome; its action is returned unchanged and nothing else runs.
    /// 2. **Mutation faults.** Every activating fault is applied to the
    ///    response left by the previous one. Delays are summed, the most
    ///    recent transaction-id override is remembered, and a raw-bytes result
    ///    marks the final output as raw. A `DropResponse` or `SendPartial`
    ///    from any mutation fault is returned immediately.
    ///
    /// The combined result is chosen with this precedence: accumulated delay
    /// (`DelayThenSend`), raw bytes (`SendRawBytes`), transaction-id override
    /// (`OverrideTransactionId`), plain `SendResponse`. Because
    /// `DelayThenSend` has no field for them, the raw flag and the
    /// transaction-id override are not carried when a delay is present.
    pub fn apply(&self, ctx: &ModbusFaultContext) -> Option<FaultAction> {
        if !self.is_enabled() {
            return None;
        }

        // Phase 1: the first activated short-circuit fault wins outright.
        for fault in &self.faults {
            if Self::is_candidate(fault.as_ref(), ctx, true) && fault.should_activate(ctx) {
                let action = fault.apply(ctx);
                debug!(
                    fault_type = fault.fault_type(),
                    unit_id = ctx.unit_id,
                    fc = ctx.function_code,
                    "Short-circuit fault activated"
                );
                return Some(action);
            }
        }

        // Phase 2: chain mutation faults. The working context is created on
        // first activation only, so the common "nothing fired" path performs
        // no allocation, and later faults reuse it instead of re-cloning both
        // PDUs per fault.
        let mut working: Option<ModbusFaultContext> = None;
        let mut total_delay = Duration::ZERO;
        let mut override_tid: Option<u16> = None;
        let mut send_raw = false;

        for fault in &self.faults {
            if !Self::is_candidate(fault.as_ref(), ctx, false) || !fault.should_activate(ctx) {
                continue;
            }

            let sub_ctx = working.get_or_insert_with(|| ctx.clone());
            let action = fault.apply(&*sub_ctx);
            trace!(
                fault_type = fault.fault_type(),
                unit_id = ctx.unit_id,
                fc = ctx.function_code,
                "Mutation fault activated"
            );

            match action {
                FaultAction::SendResponse(pdu) => {
                    sub_ctx.response_pdu = pdu;
                }
                FaultAction::DelayThenSend { delay, response } => {
                    total_delay += delay;
                    sub_ctx.response_pdu = response;
                }
                FaultAction::OverrideTransactionId {
                    transaction_id,
                    response,
                } => {
                    override_tid = Some(transaction_id);
                    sub_ctx.response_pdu = response;
                }
                FaultAction::SendRawBytes(bytes) => {
                    sub_ctx.response_pdu = bytes;
                    send_raw = true;
                }
                // Terminal outcomes: nothing sensible can be chained after them.
                FaultAction::DropResponse => {
                    return Some(FaultAction::DropResponse);
                }
                FaultAction::SendPartial { bytes } => {
                    return Some(FaultAction::SendPartial { bytes });
                }
            }
        }

        // `working` is only populated once a mutation fault has activated.
        let response = working?.response_pdu;

        let action = if total_delay > Duration::ZERO {
            FaultAction::DelayThenSend {
                delay: total_delay,
                response,
            }
        } else if send_raw {
            FaultAction::SendRawBytes(response)
        } else if let Some(transaction_id) = override_tid {
            FaultAction::OverrideTransactionId {
                transaction_id,
                response,
            }
        } else {
            FaultAction::SendResponse(response)
        };

        debug!(
            unit_id = ctx.unit_id,
            fc = ctx.function_code,
            "Fault pipeline produced action"
        );
        Some(action)
    }

    /// Returns `(fault_type, stats snapshot)` for every registered fault, in
    /// registration order. Faults sharing a type name appear once each.
    pub fn all_stats(&self) -> Vec<(&'static str, FaultStatsSnapshot)> {
        self.faults
            .iter()
            .map(|f| (f.fault_type(), f.stats()))
            .collect()
    }

    /// Resets the statistics of every registered fault.
    pub fn reset_all_stats(&self) {
        for fault in &self.faults {
            fault.reset_stats();
        }
    }

    /// Builds a pipeline from configuration.
    ///
    /// Faults are created in the order they appear in `config.faults`.
    ///
    /// When `config.enabled` is `false` the returned pipeline is disabled
    /// **and empty**: no fault is constructed, so calling
    /// [`FaultPipeline::set_enabled`] with `true` later yields a pipeline
    /// that still does nothing.
    pub fn from_config(config: &FaultInjectionConfig) -> Self {
        let mut pipeline = Self::new();
        if !config.enabled {
            pipeline.set_enabled(false);
            return pipeline;
        }

        for fault_cfg in &config.faults {
            let target = fault_cfg.target.clone();
            let fault: Arc<dyn ModbusFault> = match fault_cfg.fault_type {
                FaultType::CrcCorruption => {
                    Arc::new(CrcCorruptionFault::from_config(&fault_cfg.config, target))
                }
                FaultType::WrongUnitId => {
                    Arc::new(WrongUnitIdFault::from_config(&fault_cfg.config, target))
                }
                FaultType::WrongFunctionCode => {
                    Arc::new(WrongFunctionCodeFault::from_config(&fault_cfg.config, target))
                }
                FaultType::WrongTransactionId => {
                    Arc::new(WrongTransactionIdFault::from_config(&fault_cfg.config, target))
                }
                FaultType::TruncatedResponse => {
                    Arc::new(TruncatedResponseFault::from_config(&fault_cfg.config, target))
                }
                FaultType::ExtraData => {
                    Arc::new(ExtraDataFault::from_config(&fault_cfg.config, target))
                }
                FaultType::DelayedResponse => {
                    Arc::new(DelayedResponseFault::from_config(&fault_cfg.config, target))
                }
                // The only fault built from its target alone; it takes no
                // per-fault configuration.
                FaultType::NoResponse => Arc::new(NoResponseFault::new(target)),
                FaultType::ExceptionInjection => {
                    Arc::new(ExceptionInjectionFault::from_config(&fault_cfg.config, target))
                }
                FaultType::PartialFrame => {
                    Arc::new(PartialFrameFault::from_config(&fault_cfg.config, target))
                }
            };
            pipeline.faults.push(fault);
        }
        pipeline
    }
}
impl Default for FaultPipeline {
fn default() -> Self {
Self::new()
}
}
/// Manual `Debug`: trait objects in `faults` are not `Debug`, so the pipeline
/// is rendered as its master switch, the fault count, and the list of fault
/// type names in evaluation order.
impl std::fmt::Debug for FaultPipeline {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("FaultPipeline");
        out.field("enabled", &self.is_enabled());
        out.field("fault_count", &self.faults.len());

        let mut fault_names: Vec<&'static str> = Vec::with_capacity(self.faults.len());
        for fault in &self.faults {
            fault_names.push(fault.fault_type());
        }
        out.field("faults", &fault_names);

        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Short-circuit test double: fires for every request and always asks for
    /// the response to be dropped.
    struct AlwaysDropFault {
        stats: FaultStats,
    }

    impl AlwaysDropFault {
        fn new() -> Self {
            let stats = FaultStats::new();
            Self { stats }
        }
    }

    impl ModbusFault for AlwaysDropFault {
        fn fault_type(&self) -> &'static str {
            "always_drop"
        }

        fn is_enabled(&self) -> bool {
            self.stats.is_enabled()
        }

        fn set_enabled(&self, enabled: bool) {
            self.stats.set_enabled(enabled);
        }

        fn should_activate(&self, _ctx: &ModbusFaultContext) -> bool {
            // Count the check so tests can tell whether the fault was consulted.
            self.stats.record_check();
            true
        }

        fn apply(&self, _ctx: &ModbusFaultContext) -> FaultAction {
            self.stats.record_activation();
            self.stats.record_affected();
            FaultAction::DropResponse
        }

        fn stats(&self) -> FaultStatsSnapshot {
            self.stats.snapshot()
        }

        fn reset_stats(&self) {
            self.stats.reset();
        }

        fn is_short_circuit(&self) -> bool {
            true
        }
    }

    /// Mutation test double: fires for every request and puts one fixed byte
    /// in front of the response it is handed.
    struct PrependByteFault {
        byte: u8,
        stats: FaultStats,
    }

    impl PrependByteFault {
        fn new(byte: u8) -> Self {
            let stats = FaultStats::new();
            Self { byte, stats }
        }
    }

    impl ModbusFault for PrependByteFault {
        fn fault_type(&self) -> &'static str {
            "prepend_byte"
        }

        fn is_enabled(&self) -> bool {
            self.stats.is_enabled()
        }

        fn set_enabled(&self, enabled: bool) {
            self.stats.set_enabled(enabled);
        }

        fn should_activate(&self, _ctx: &ModbusFaultContext) -> bool {
            self.stats.record_check();
            true
        }

        fn apply(&self, ctx: &ModbusFaultContext) -> FaultAction {
            self.stats.record_activation();
            self.stats.record_affected();
            let mut mutated = Vec::with_capacity(ctx.response_pdu.len() + 1);
            mutated.push(self.byte);
            mutated.extend_from_slice(&ctx.response_pdu);
            FaultAction::SendResponse(mutated)
        }

        fn stats(&self) -> FaultStatsSnapshot {
            self.stats.snapshot()
        }

        fn reset_stats(&self) {
            self.stats.reset();
        }
    }

    /// TCP context shared by most tests: unit 1, function code 0x03,
    /// transaction id 1, request number 1.
    fn test_ctx() -> ModbusFaultContext {
        let request = [0x03, 0x00, 0x00, 0x00, 0x01];
        let response = [0x03, 0x02, 0x00, 0x64];
        ModbusFaultContext::tcp(1, 0x03, &request, &response, 1, 1)
    }

    /// A pipeline without faults reports itself empty and never intervenes.
    #[test]
    fn test_empty_pipeline() {
        let empty = FaultPipeline::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert!(empty.apply(&test_ctx()).is_none());
    }

    /// The master switch suppresses even a fault that always fires.
    #[test]
    fn test_pipeline_disabled() {
        let dropper = Arc::new(AlwaysDropFault::new());
        let pipeline = FaultPipeline::new().with_fault(dropper);
        pipeline.set_enabled(false);
        assert!(pipeline.apply(&test_ctx()).is_none());
    }

    /// A short-circuit fault wins over a mutation fault registered before it,
    /// and the mutation fault is never consulted.
    #[test]
    fn test_short_circuit_priority() {
        let dropper = Arc::new(AlwaysDropFault::new());
        let prepender = Arc::new(PrependByteFault::new(0xFF));
        let pipeline = FaultPipeline::new()
            .with_fault(prepender.clone())
            .with_fault(dropper.clone());

        let outcome = pipeline.apply(&test_ctx());
        assert!(matches!(outcome, Some(FaultAction::DropResponse)));

        let drop_stats = dropper.stats();
        assert_eq!(drop_stats.checks, 1);
        assert_eq!(drop_stats.activations, 1);
        assert_eq!(prepender.stats().checks, 0);
    }

    /// Mutation faults are chained: the second sees the first one's output.
    #[test]
    fn test_mutation_chaining() {
        let first = Arc::new(PrependByteFault::new(0xAA));
        let second = Arc::new(PrependByteFault::new(0xBB));
        let pipeline = FaultPipeline::new()
            .with_fault(first.clone())
            .with_fault(second.clone());

        let ctx = test_ctx();
        let outcome = pipeline.apply(&ctx).unwrap();
        if let FaultAction::SendResponse(pdu) = outcome {
            // The later fault's byte ends up outermost.
            assert_eq!(pdu[0], 0xBB);
            assert_eq!(pdu[1], 0xAA);
            assert_eq!(pdu[2..], ctx.response_pdu[..]);
        } else {
            panic!("Expected SendResponse");
        }
    }

    /// A disabled fault is skipped before `should_activate` is called.
    #[test]
    fn test_disabled_fault_skipped() {
        let prepender = Arc::new(PrependByteFault::new(0xFF));
        prepender.set_enabled(false);
        let pipeline = FaultPipeline::new().with_fault(prepender.clone());

        assert!(pipeline.apply(&test_ctx()).is_none());
        assert_eq!(prepender.stats().checks, 0);
    }

    /// `all_stats` reports one entry per registered fault, in order.
    #[test]
    fn test_all_stats() {
        let first = Arc::new(PrependByteFault::new(0xAA));
        let second = Arc::new(PrependByteFault::new(0xBB));
        let pipeline = FaultPipeline::new().with_fault(first).with_fault(second);
        let _ = pipeline.apply(&test_ctx());

        let stats = pipeline.all_stats();
        assert_eq!(stats.len(), 2);
        for (name, snapshot) in &stats {
            assert_eq!(*name, "prepend_byte");
            assert_eq!(snapshot.activations, 1);
        }
    }

    /// `reset_all_stats` clears the counters of registered faults.
    #[test]
    fn test_reset_all_stats() {
        let prepender = Arc::new(PrependByteFault::new(0xAA));
        let pipeline = FaultPipeline::new().with_fault(prepender.clone());

        let _ = pipeline.apply(&test_ctx());
        assert_eq!(prepender.stats().activations, 1);

        pipeline.reset_all_stats();
        assert_eq!(prepender.stats().activations, 0);
    }

    /// The RTU constructor tags the transport and zeroes the transaction id.
    #[test]
    fn test_rtu_context() {
        let rtu = ModbusFaultContext::rtu(1, 0x03, &[0x03], &[0x03, 0x02], 5);
        assert_eq!(rtu.transport, TransportKind::Rtu);
        assert_eq!(rtu.unit_id, 1);
        assert_eq!(rtu.transaction_id, 0);
        assert!(!rtu.is_broadcast);
        assert_eq!(rtu.request_number, 5);
    }

    /// The TCP constructor keeps the transaction id and flags unit 0 as broadcast.
    #[test]
    fn test_tcp_context() {
        let tcp = ModbusFaultContext::tcp(0, 0x03, &[0x03], &[0x03, 0x02], 42, 10);
        assert_eq!(tcp.transport, TransportKind::Tcp);
        assert_eq!(tcp.unit_id, 0);
        assert!(tcp.is_broadcast);
        assert_eq!(tcp.transaction_id, 42);
        assert_eq!(tcp.request_number, 10);
    }

    /// A fault restricted to RTU is ignored for TCP contexts and applied for
    /// RTU contexts.
    #[test]
    fn test_transport_compatibility() {
        struct RtuOnlyFault {
            stats: FaultStats,
        }

        impl ModbusFault for RtuOnlyFault {
            fn fault_type(&self) -> &'static str {
                "rtu_only"
            }

            fn is_enabled(&self) -> bool {
                self.stats.is_enabled()
            }

            fn set_enabled(&self, enabled: bool) {
                self.stats.set_enabled(enabled);
            }

            fn should_activate(&self, _: &ModbusFaultContext) -> bool {
                self.stats.record_check();
                true
            }

            fn apply(&self, ctx: &ModbusFaultContext) -> FaultAction {
                self.stats.record_activation();
                FaultAction::SendResponse(ctx.response_pdu.clone())
            }

            fn stats(&self) -> FaultStatsSnapshot {
                self.stats.snapshot()
            }

            fn reset_stats(&self) {
                self.stats.reset();
            }

            fn compatible_transport(&self) -> Option<TransportKind> {
                Some(TransportKind::Rtu)
            }
        }

        let rtu_only = Arc::new(RtuOnlyFault {
            stats: FaultStats::new(),
        });
        let pipeline = FaultPipeline::new().with_fault(rtu_only);

        let over_tcp = test_ctx();
        assert!(pipeline.apply(&over_tcp).is_none());

        let over_rtu = ModbusFaultContext::rtu(1, 0x03, &[0x03], &[0x03, 0x02], 1);
        assert!(pipeline.apply(&over_rtu).is_some());
    }
}