Skip to main content

rabia_core/
validation.rs

1use crate::messages::ProtocolMessage;
2use crate::{BatchId, CommandBatch, NodeId, PhaseId, RabiaError, Result};
3use std::time::{SystemTime, UNIX_EPOCH};
4
5pub trait Validator {
6    fn validate(&self) -> Result<()>;
7}
8
/// Limits applied by the validators in this module.
///
/// Note: every validator in this module currently obtains its limits via
/// [`ValidationConfig::default()`]; a custom configuration cannot yet be
/// passed through the [`Validator`] trait.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationConfig {
    /// Maximum number of commands allowed in a single `CommandBatch`.
    pub max_batch_size: usize,
    /// Maximum length of a single command's `data`, as reported by `data.len()`.
    pub max_command_size: usize,
    /// How far (in milliseconds) a timestamp may lead the local clock before it
    /// is rejected as being "in the future".
    pub max_clock_skew_ms: u64,
    /// Smallest accepted phase ID value (inclusive).
    pub min_phase_id: u64,
    /// Largest accepted phase ID value (inclusive).
    pub max_phase_id: u64,
}

impl Default for ValidationConfig {
    /// Default limits: 1000 commands per batch, 1_048_576 units of command data,
    /// one minute of clock skew, and the full `u64` phase range.
    fn default() -> Self {
        Self {
            max_batch_size: 1000,
            max_command_size: 1024 * 1024, // 1_048_576 ("1MB" when `data` is bytes)
            max_clock_skew_ms: 60_000,     // 1 minute
            min_phase_id: 0,
            max_phase_id: u64::MAX,
        }
    }
}
29
30impl Validator for ProtocolMessage {
31    fn validate(&self) -> Result<()> {
32        // Validate timestamp
33        let now = SystemTime::now()
34            .duration_since(UNIX_EPOCH)
35            .unwrap()
36            .as_millis() as u64;
37
38        let config = ValidationConfig::default();
39
40        if self.timestamp > now + config.max_clock_skew_ms {
41            return Err(RabiaError::internal(format!(
42                "Message timestamp {} is too far in the future (current: {})",
43                self.timestamp, now
44            )));
45        }
46
47        if now.saturating_sub(self.timestamp) > config.max_clock_skew_ms * 10 {
48            return Err(RabiaError::internal(format!(
49                "Message timestamp {} is too old (current: {})",
50                self.timestamp, now
51            )));
52        }
53
54        // Validate message type specific fields
55        match &self.message_type {
56            crate::messages::MessageType::Propose(propose) => {
57                validate_phase_id(&propose.phase_id)?;
58                validate_batch_id(&propose.batch_id)?;
59                if let Some(batch) = &propose.batch {
60                    batch.validate()?;
61                }
62            }
63            crate::messages::MessageType::VoteRound1(vote) => {
64                validate_phase_id(&vote.phase_id)?;
65                validate_batch_id(&vote.batch_id)?;
66                validate_node_id(&vote.voter_id)?;
67            }
68            crate::messages::MessageType::VoteRound2(vote) => {
69                validate_phase_id(&vote.phase_id)?;
70                validate_batch_id(&vote.batch_id)?;
71                validate_node_id(&vote.voter_id)?;
72
73                // Validate round1_votes mapping
74                if vote.round1_votes.is_empty() {
75                    return Err(RabiaError::internal(
76                        "Round 2 vote must include round 1 votes".to_string(),
77                    ));
78                }
79            }
80            crate::messages::MessageType::Decision(decision) => {
81                validate_phase_id(&decision.phase_id)?;
82                validate_batch_id(&decision.batch_id)?;
83                if let Some(batch) = &decision.batch {
84                    batch.validate()?;
85                }
86            }
87            crate::messages::MessageType::SyncRequest(request) => {
88                validate_phase_id(&request.requester_phase)?;
89            }
90            crate::messages::MessageType::SyncResponse(response) => {
91                validate_phase_id(&response.responder_phase)?;
92
93                // Validate pending batches
94                for (batch_id, batch) in &response.pending_batches {
95                    validate_batch_id(batch_id)?;
96                    batch.validate()?;
97                }
98            }
99            crate::messages::MessageType::NewBatch(new_batch) => {
100                new_batch.batch.validate()?;
101                validate_node_id(&new_batch.originator)?;
102            }
103            crate::messages::MessageType::HeartBeat(heartbeat) => {
104                validate_phase_id(&heartbeat.current_phase)?;
105                validate_phase_id(&heartbeat.last_committed_phase)?;
106
107                // Committed phase should not be greater than current phase
108                if heartbeat.last_committed_phase > heartbeat.current_phase {
109                    return Err(RabiaError::InvalidStateTransition {
110                        from: format!("committed={}", heartbeat.last_committed_phase),
111                        to: format!("current={}", heartbeat.current_phase),
112                    });
113                }
114            }
115            crate::messages::MessageType::QuorumNotification(notification) => {
116                for node_id in &notification.active_nodes {
117                    validate_node_id(node_id)?;
118                }
119            }
120        }
121
122        Ok(())
123    }
124}
125
126impl Validator for CommandBatch {
127    fn validate(&self) -> Result<()> {
128        let config = ValidationConfig::default();
129
130        // Validate batch size
131        if self.commands.len() > config.max_batch_size {
132            return Err(RabiaError::internal(format!(
133                "Batch size {} exceeds maximum {}",
134                self.commands.len(),
135                config.max_batch_size
136            )));
137        }
138
139        if self.commands.is_empty() {
140            return Err(RabiaError::internal("Batch cannot be empty".to_string()));
141        }
142
143        // Validate individual commands
144        for command in &self.commands {
145            if command.data.len() > config.max_command_size {
146                return Err(RabiaError::internal(format!(
147                    "Command size {} exceeds maximum {}",
148                    command.data.len(),
149                    config.max_command_size
150                )));
151            }
152
153            // Basic command validation
154            if command.data.is_empty() {
155                return Err(RabiaError::internal(
156                    "Command data cannot be empty".to_string(),
157                ));
158            }
159        }
160
161        // Validate checksum
162        let _calculated_checksum = self.checksum();
163        // In a real implementation, we would compare against a stored checksum
164
165        // Validate timestamp
166        let now = SystemTime::now()
167            .duration_since(UNIX_EPOCH)
168            .unwrap()
169            .as_millis() as u64;
170
171        if self.timestamp > now + config.max_clock_skew_ms {
172            return Err(RabiaError::internal(format!(
173                "Batch timestamp {} is too far in the future",
174                self.timestamp
175            )));
176        }
177
178        Ok(())
179    }
180}
181
182fn validate_phase_id(phase_id: &PhaseId) -> Result<()> {
183    let config = ValidationConfig::default();
184    let value = phase_id.value();
185
186    if value < config.min_phase_id || value > config.max_phase_id {
187        return Err(RabiaError::internal(format!(
188            "Phase ID {} is out of valid range [{}, {}]",
189            value, config.min_phase_id, config.max_phase_id
190        )));
191    }
192
193    Ok(())
194}
195
196fn validate_batch_id(_batch_id: &BatchId) -> Result<()> {
197    // BatchId is a UUID, so basic validation is that it's not nil
198    // Additional validation could include checking against known batches
199    Ok(())
200}
201
202fn validate_node_id(_node_id: &NodeId) -> Result<()> {
203    // NodeId is a UUID, so basic validation is that it's not nil
204    // Additional validation could include checking against authorized nodes
205    Ok(())
206}
207
208pub fn validate_message_sequence(previous_phase: PhaseId, current_phase: PhaseId) -> Result<()> {
209    if current_phase.value() <= previous_phase.value() {
210        return Err(RabiaError::InvalidStateTransition {
211            from: format!("phase={}", previous_phase),
212            to: format!("phase={}", current_phase),
213        });
214    }
215
216    // Check for reasonable phase progression (not too large jumps)
217    let jump = current_phase.value() - previous_phase.value();
218    if jump > 1000 {
219        return Err(RabiaError::internal(format!(
220            "Phase jump {} is suspiciously large",
221            jump
222        )));
223    }
224
225    Ok(())
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Command;

    #[test]
    fn test_batch_validation() {
        let commands = vec![Command::new("SET key1 value1"), Command::new("GET key1")];
        let batch = CommandBatch::new(commands);

        assert!(batch.validate().is_ok());
    }

    #[test]
    fn test_empty_batch_validation() {
        let batch = CommandBatch::new(vec![]);
        assert!(batch.validate().is_err());
    }

    /// A batch of exactly `max_batch_size` commands passes; one more is rejected.
    #[test]
    fn test_batch_size_limit() {
        let limit = ValidationConfig::default().max_batch_size;

        let at_limit: Vec<Command> = (0..limit).map(|_| Command::new("SET key value")).collect();
        assert!(CommandBatch::new(at_limit).validate().is_ok());

        let over_limit: Vec<Command> =
            (0..=limit).map(|_| Command::new("SET key value")).collect();
        assert!(CommandBatch::new(over_limit).validate().is_err());
    }

    #[test]
    fn test_phase_sequence_validation() {
        let phase1 = PhaseId::new(1);
        let phase2 = PhaseId::new(2);
        let phase3 = PhaseId::new(1); // Invalid: going backwards

        assert!(validate_message_sequence(phase1, phase2).is_ok());
        assert!(validate_message_sequence(phase2, phase3).is_err());
    }

    /// Repeating the same phase is not progress and must be rejected.
    #[test]
    fn test_phase_sequence_rejects_repeat() {
        assert!(validate_message_sequence(PhaseId::new(5), PhaseId::new(5)).is_err());
    }

    /// A jump of exactly 1000 phases is accepted; 1001 is rejected.
    #[test]
    fn test_phase_sequence_jump_limit() {
        assert!(validate_message_sequence(PhaseId::new(1), PhaseId::new(1001)).is_ok());
        assert!(validate_message_sequence(PhaseId::new(1), PhaseId::new(1002)).is_err());
    }
}