1use crate::messages::ProtocolMessage;
2use crate::{BatchId, CommandBatch, NodeId, PhaseId, RabiaError, Result};
3use std::time::{SystemTime, UNIX_EPOCH};
4
5pub trait Validator {
6 fn validate(&self) -> Result<()>;
7}
8
/// Limits applied by the validators in this module.
///
/// The validators always use [`ValidationConfig::default`].
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Maximum number of commands allowed in a single batch.
    pub max_batch_size: usize,
    /// Maximum length of a single command's data.
    pub max_command_size: usize,
    /// Tolerated difference, in milliseconds, between a timestamp and the local clock.
    pub max_clock_skew_ms: u64,
    /// Smallest accepted phase ID (inclusive).
    pub min_phase_id: u64,
    /// Largest accepted phase ID (inclusive).
    pub max_phase_id: u64,
}

impl Default for ValidationConfig {
    /// Defaults: 1000 commands per batch, 1 MiB per command, 60 s of clock skew,
    /// and the full `u64` range for phase IDs.
    fn default() -> Self {
        const ONE_MIB: usize = 1 << 20;

        ValidationConfig {
            max_batch_size: 1_000,
            max_command_size: ONE_MIB,
            max_clock_skew_ms: 60 * 1_000,
            min_phase_id: u64::MIN,
            max_phase_id: u64::MAX,
        }
    }
}
29
30impl Validator for ProtocolMessage {
31 fn validate(&self) -> Result<()> {
32 let now = SystemTime::now()
34 .duration_since(UNIX_EPOCH)
35 .unwrap()
36 .as_millis() as u64;
37
38 let config = ValidationConfig::default();
39
40 if self.timestamp > now + config.max_clock_skew_ms {
41 return Err(RabiaError::internal(format!(
42 "Message timestamp {} is too far in the future (current: {})",
43 self.timestamp, now
44 )));
45 }
46
47 if now.saturating_sub(self.timestamp) > config.max_clock_skew_ms * 10 {
48 return Err(RabiaError::internal(format!(
49 "Message timestamp {} is too old (current: {})",
50 self.timestamp, now
51 )));
52 }
53
54 match &self.message_type {
56 crate::messages::MessageType::Propose(propose) => {
57 validate_phase_id(&propose.phase_id)?;
58 validate_batch_id(&propose.batch_id)?;
59 if let Some(batch) = &propose.batch {
60 batch.validate()?;
61 }
62 }
63 crate::messages::MessageType::VoteRound1(vote) => {
64 validate_phase_id(&vote.phase_id)?;
65 validate_batch_id(&vote.batch_id)?;
66 validate_node_id(&vote.voter_id)?;
67 }
68 crate::messages::MessageType::VoteRound2(vote) => {
69 validate_phase_id(&vote.phase_id)?;
70 validate_batch_id(&vote.batch_id)?;
71 validate_node_id(&vote.voter_id)?;
72
73 if vote.round1_votes.is_empty() {
75 return Err(RabiaError::internal(
76 "Round 2 vote must include round 1 votes".to_string(),
77 ));
78 }
79 }
80 crate::messages::MessageType::Decision(decision) => {
81 validate_phase_id(&decision.phase_id)?;
82 validate_batch_id(&decision.batch_id)?;
83 if let Some(batch) = &decision.batch {
84 batch.validate()?;
85 }
86 }
87 crate::messages::MessageType::SyncRequest(request) => {
88 validate_phase_id(&request.requester_phase)?;
89 }
90 crate::messages::MessageType::SyncResponse(response) => {
91 validate_phase_id(&response.responder_phase)?;
92
93 for (batch_id, batch) in &response.pending_batches {
95 validate_batch_id(batch_id)?;
96 batch.validate()?;
97 }
98 }
99 crate::messages::MessageType::NewBatch(new_batch) => {
100 new_batch.batch.validate()?;
101 validate_node_id(&new_batch.originator)?;
102 }
103 crate::messages::MessageType::HeartBeat(heartbeat) => {
104 validate_phase_id(&heartbeat.current_phase)?;
105 validate_phase_id(&heartbeat.last_committed_phase)?;
106
107 if heartbeat.last_committed_phase > heartbeat.current_phase {
109 return Err(RabiaError::InvalidStateTransition {
110 from: format!("committed={}", heartbeat.last_committed_phase),
111 to: format!("current={}", heartbeat.current_phase),
112 });
113 }
114 }
115 crate::messages::MessageType::QuorumNotification(notification) => {
116 for node_id in ¬ification.active_nodes {
117 validate_node_id(node_id)?;
118 }
119 }
120 }
121
122 Ok(())
123 }
124}
125
126impl Validator for CommandBatch {
127 fn validate(&self) -> Result<()> {
128 let config = ValidationConfig::default();
129
130 if self.commands.len() > config.max_batch_size {
132 return Err(RabiaError::internal(format!(
133 "Batch size {} exceeds maximum {}",
134 self.commands.len(),
135 config.max_batch_size
136 )));
137 }
138
139 if self.commands.is_empty() {
140 return Err(RabiaError::internal("Batch cannot be empty".to_string()));
141 }
142
143 for command in &self.commands {
145 if command.data.len() > config.max_command_size {
146 return Err(RabiaError::internal(format!(
147 "Command size {} exceeds maximum {}",
148 command.data.len(),
149 config.max_command_size
150 )));
151 }
152
153 if command.data.is_empty() {
155 return Err(RabiaError::internal(
156 "Command data cannot be empty".to_string(),
157 ));
158 }
159 }
160
161 let _calculated_checksum = self.checksum();
163 let now = SystemTime::now()
167 .duration_since(UNIX_EPOCH)
168 .unwrap()
169 .as_millis() as u64;
170
171 if self.timestamp > now + config.max_clock_skew_ms {
172 return Err(RabiaError::internal(format!(
173 "Batch timestamp {} is too far in the future",
174 self.timestamp
175 )));
176 }
177
178 Ok(())
179 }
180}
181
182fn validate_phase_id(phase_id: &PhaseId) -> Result<()> {
183 let config = ValidationConfig::default();
184 let value = phase_id.value();
185
186 if value < config.min_phase_id || value > config.max_phase_id {
187 return Err(RabiaError::internal(format!(
188 "Phase ID {} is out of valid range [{}, {}]",
189 value, config.min_phase_id, config.max_phase_id
190 )));
191 }
192
193 Ok(())
194}
195
196fn validate_batch_id(_batch_id: &BatchId) -> Result<()> {
197 Ok(())
200}
201
202fn validate_node_id(_node_id: &NodeId) -> Result<()> {
203 Ok(())
206}
207
208pub fn validate_message_sequence(previous_phase: PhaseId, current_phase: PhaseId) -> Result<()> {
209 if current_phase.value() <= previous_phase.value() {
210 return Err(RabiaError::InvalidStateTransition {
211 from: format!("phase={}", previous_phase),
212 to: format!("phase={}", current_phase),
213 });
214 }
215
216 let jump = current_phase.value() - previous_phase.value();
218 if jump > 1000 {
219 return Err(RabiaError::internal(format!(
220 "Phase jump {} is suspiciously large",
221 jump
222 )));
223 }
224
225 Ok(())
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Command;

    #[test]
    fn test_batch_validation() {
        let commands = vec![Command::new("SET key1 value1"), Command::new("GET key1")];
        let batch = CommandBatch::new(commands);

        assert!(batch.validate().is_ok());
    }

    #[test]
    fn test_empty_batch_validation() {
        let batch = CommandBatch::new(vec![]);
        assert!(batch.validate().is_err());
    }

    #[test]
    fn test_oversized_batch_validation() {
        // One command more than the configured maximum must be rejected.
        let max = ValidationConfig::default().max_batch_size;
        let commands: Vec<_> = (0..=max).map(|_| Command::new("SET key value")).collect();
        let batch = CommandBatch::new(commands);

        assert!(batch.validate().is_err());
    }

    #[test]
    fn test_phase_sequence_validation() {
        let phase1 = PhaseId::new(1);
        let phase2 = PhaseId::new(2);

        // Advancing by one phase is accepted.
        assert!(validate_message_sequence(phase1, phase2).is_ok());
        // Moving backwards is rejected.
        assert!(validate_message_sequence(phase2, phase1).is_err());
        // Staying on the same phase is rejected.
        assert!(validate_message_sequence(phase1, phase1).is_err());
    }

    #[test]
    fn test_phase_jump_limit() {
        let start = PhaseId::new(10);

        // A jump of exactly 1000 is the largest accepted step.
        assert!(validate_message_sequence(start, PhaseId::new(1010)).is_ok());
        // One more is flagged as suspicious.
        assert!(validate_message_sequence(start, PhaseId::new(1011)).is_err());
    }
}