1use crate::{
16 error::{Error, Result},
17 parser::{
18 header::{parse_magic_and_version, CassandraVersion},
19 vint::{parse_vint, parse_vint_length},
20 },
21};
22use nom::{
23 bytes::complete::take,
24 number::complete::{be_u16, be_u32, be_u64, le_u32},
25 IResult,
26};
27use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29
/// Kind of SSTable component whose header can be described by a [`ComponentHeaderSpec`].
///
/// `HeaderSpecRegistry::new` registers default specs only for `Data`, `Index` and `Summary`.
/// Looking up any other kind in a default registry yields an unsupported-format error.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SSTableComponentType {
    /// Data component (default spec: magic/version prefix, then table identity fields).
    Data,
    /// Index component (default spec: no magic prefix; begins with a `version` field).
    Index,
    /// Summary component (default spec: no magic prefix; see also `parse_summary_header`).
    Summary,
    /// Statistics component; no default header spec is registered for it.
    Statistics,
    /// Compression-info component; no default header spec is registered for it.
    CompressionInfo,
    /// Filter component; no default header spec is registered for it.
    Filter,
}
46
/// Declarative description of how to parse one component's header.
///
/// Consumed by `parse_component_header`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentHeaderSpec {
    /// Component this spec describes; copied into `ParsedHeader::component_type`.
    pub component_type: SSTableComponentType,
    /// When `true`, a magic/version prefix is parsed before the layout's fields.
    pub has_magic_number: bool,
    /// Expected magic, compared against a big-endian `u32` at the start of the header.
    /// Only consulted when `has_magic_number` is `true`; `None` means the prefix is decoded by
    /// `parse_magic_and_version` instead.
    pub magic_number: Option<u32>,
    /// Declared lower bound on the header version.
    /// NOTE(review): not checked by `parse_component_header` in this file.
    pub min_version: u32,
    /// Declared upper bound on the header version.
    /// NOTE(review): not checked by `parse_component_header` in this file.
    pub max_version: u32,
    /// Fields that follow the (optional) magic prefix, plus the header size bounds.
    pub field_layout: HeaderFieldLayout,
}
63
/// Ordered field list and size bounds for a component header.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeaderFieldLayout {
    /// Fields of the header, decoded strictly in this order.
    pub fields: Vec<HeaderField>,
    /// Minimum number of input bytes required before parsing is attempted.
    pub min_size: usize,
    /// Maximum number of bytes the header (including any magic/version prefix) may consume.
    pub max_size: usize,
}
74
/// One named field in a header layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeaderField {
    /// Key under which the decoded value is stored in `ParsedHeader::fields`.
    pub name: String,
    /// Wire encoding of the field.
    pub field_type: HeaderFieldType,
    /// Whether the field may be absent.
    /// NOTE(review): the parser in this file never consults this flag; every field is parsed.
    pub optional: bool,
    /// Constraints checked after the value has been decoded, if any.
    pub validation: Option<FieldValidation>,
}
87
/// Wire encoding of a header field, as decoded by `parse_header_field`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HeaderFieldType {
    /// A single unsigned byte.
    U8,
    /// Big-endian `u16`.
    U16BE,
    /// Big-endian `u32`.
    U32BE,
    /// Big-endian `u64`.
    U64BE,
    /// Little-endian `u32`.
    U32LE,
    /// Variable-length integer decoded by `parse_vint`.
    VInt,
    /// String prefixed by a single length byte (so at most 255 bytes); must be valid UTF-8.
    /// NOTE(review): unlike `VBytes`, the length is not vint-encoded — confirm this matches the
    /// on-disk format.
    VString,
    /// Exactly the given number of raw bytes.
    FixedBytes(usize),
    /// Raw bytes prefixed by a length decoded with `parse_vint_length`.
    VBytes,
    /// Element count (decoded with `parse_vint_length`) followed by that many elements.
    Array(Box<HeaderFieldType>),
    /// Entry count (decoded with `parse_vint_length`) followed by key/value pairs; keys must
    /// decode to strings.
    Map(Box<HeaderFieldType>, Box<HeaderFieldType>),
}
114
/// Optional constraints applied to a decoded field value by `validate_field_value`.
///
/// A constraint set to `None` is not checked.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldValidation {
    /// Inclusive lower bound for numeric values.
    pub min_value: Option<u64>,
    /// Inclusive upper bound for numeric values.
    pub max_value: Option<u64>,
    /// Exhaustive set of permitted numeric values.
    pub allowed_values: Option<Vec<u64>>,
    /// Maximum length, in bytes, of string and byte values.
    pub max_length: Option<usize>,
}
127
/// Result of parsing a component header with `parse_component_header`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedHeader {
    /// Component the header was parsed as.
    pub component_type: SSTableComponentType,
    /// Version reported by `parse_magic_and_version`, or `CassandraVersion::Legacy` when that
    /// parser was not used.
    pub cassandra_version: CassandraVersion,
    /// Format version taken from the magic/version prefix or, for specs without a prefix, from
    /// a `U32` field named `"version"` (defaulting to 1 when absent).
    pub format_version: u32,
    /// Decoded field values keyed by field name.
    pub fields: HashMap<String, HeaderFieldValue>,
    /// Number of bytes consumed, including any magic/version prefix.
    pub header_size: usize,
}
142
/// A decoded header field value; each variant corresponds to one or more `HeaderFieldType`s.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HeaderFieldValue {
    /// Decoded from `HeaderFieldType::U8`.
    U8(u8),
    /// Decoded from `HeaderFieldType::U16BE`.
    U16(u16),
    /// Decoded from `HeaderFieldType::U32BE` or `HeaderFieldType::U32LE`.
    U32(u32),
    /// Decoded from `HeaderFieldType::U64BE`.
    U64(u64),
    /// Decoded from `HeaderFieldType::VInt`.
    VInt(i64),
    /// Decoded from `HeaderFieldType::VString`.
    String(String),
    /// Decoded from `HeaderFieldType::FixedBytes` or `HeaderFieldType::VBytes`.
    Bytes(Vec<u8>),
    /// Decoded from `HeaderFieldType::Array`.
    Array(Vec<HeaderFieldValue>),
    /// Decoded from `HeaderFieldType::Map`; keys are the decoded string keys.
    Map(HashMap<String, HeaderFieldValue>),
}
156
157impl HeaderFieldValue {
158 pub fn as_u32(&self) -> Result<u32> {
160 match self {
161 HeaderFieldValue::U32(v) => Ok(*v),
162 _ => Err(Error::corruption("Expected u32 field value".to_string())),
163 }
164 }
165
166 pub fn as_u64(&self) -> Result<u64> {
168 match self {
169 HeaderFieldValue::U64(v) => Ok(*v),
170 _ => Err(Error::corruption("Expected u64 field value".to_string())),
171 }
172 }
173
174 pub fn as_string(&self) -> Result<&str> {
176 match self {
177 HeaderFieldValue::String(s) => Ok(s),
178 _ => Err(Error::corruption("Expected string field value".to_string())),
179 }
180 }
181
182 pub fn as_bytes(&self) -> Result<&[u8]> {
184 match self {
185 HeaderFieldValue::Bytes(b) => Ok(b),
186 _ => Err(Error::corruption("Expected bytes field value".to_string())),
187 }
188 }
189}
190
/// Lookup table from component type to the [`ComponentHeaderSpec`] used to parse its header.
///
/// `HeaderSpecRegistry::new` (and `Default`) pre-populate the built-in specs;
/// `get_global_registry` returns a lazily-initialised shared instance.
pub struct HeaderSpecRegistry {
    /// At most one spec per component type; components without an entry cannot be parsed.
    specs: HashMap<SSTableComponentType, ComponentHeaderSpec>,
}
195
196impl Default for HeaderSpecRegistry {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
202impl HeaderSpecRegistry {
203 pub fn new() -> Self {
205 let mut registry = Self {
206 specs: HashMap::new(),
207 };
208 registry.register_default_specs();
209 registry
210 }
211
212 fn register_default_specs(&mut self) {
214 self.specs.insert(
216 SSTableComponentType::Data,
217 ComponentHeaderSpec {
218 component_type: SSTableComponentType::Data,
219 has_magic_number: true,
220 magic_number: None, min_version: 1,
222 max_version: 10,
223 field_layout: HeaderFieldLayout {
224 fields: vec![
225 HeaderField {
226 name: "table_id".to_string(),
227 field_type: HeaderFieldType::FixedBytes(16),
228 optional: false,
229 validation: None,
230 },
231 HeaderField {
232 name: "keyspace".to_string(),
233 field_type: HeaderFieldType::VString,
234 optional: false,
235 validation: Some(FieldValidation {
236 min_value: None,
237 max_value: None,
238 allowed_values: None,
239 max_length: Some(256),
240 }),
241 },
242 HeaderField {
243 name: "table_name".to_string(),
244 field_type: HeaderFieldType::VString,
245 optional: false,
246 validation: Some(FieldValidation {
247 min_value: None,
248 max_value: None,
249 allowed_values: None,
250 max_length: Some(256),
251 }),
252 },
253 HeaderField {
254 name: "generation".to_string(),
255 field_type: HeaderFieldType::U64BE,
256 optional: false,
257 validation: Some(FieldValidation {
258 min_value: Some(1),
259 max_value: Some(u64::MAX),
260 allowed_values: None,
261 max_length: None,
262 }),
263 },
264 ],
265 min_size: 32,
266 max_size: 1024,
267 },
268 },
269 );
270
271 self.specs.insert(
273 SSTableComponentType::Index,
274 ComponentHeaderSpec {
275 component_type: SSTableComponentType::Index,
276 has_magic_number: false, magic_number: None,
278 min_version: 1,
279 max_version: 10,
280 field_layout: HeaderFieldLayout {
281 fields: vec![
282 HeaderField {
283 name: "version".to_string(),
284 field_type: HeaderFieldType::U32BE,
285 optional: false,
286 validation: Some(FieldValidation {
287 min_value: Some(1),
288 max_value: Some(10),
289 allowed_values: None,
290 max_length: None,
291 }),
292 },
293 HeaderField {
294 name: "entry_count".to_string(),
295 field_type: HeaderFieldType::U32BE,
296 optional: false,
297 validation: Some(FieldValidation {
298 min_value: Some(0),
299 max_value: Some(100_000_000),
300 allowed_values: None,
301 max_length: None,
302 }),
303 },
304 HeaderField {
305 name: "data_size".to_string(),
306 field_type: HeaderFieldType::U64BE,
307 optional: false,
308 validation: Some(FieldValidation {
309 min_value: Some(0),
310 max_value: Some(1_000_000_000_000), allowed_values: None,
312 max_length: None,
313 }),
314 },
315 HeaderField {
316 name: "checksum".to_string(),
317 field_type: HeaderFieldType::U32BE,
318 optional: false,
319 validation: None,
320 },
321 ],
322 min_size: 16,
323 max_size: 64,
324 },
325 },
326 );
327
328 self.specs.insert(
330 SSTableComponentType::Summary,
331 ComponentHeaderSpec {
332 component_type: SSTableComponentType::Summary,
333 has_magic_number: false, magic_number: None,
335 min_version: 1,
336 max_version: 10,
337 field_layout: HeaderFieldLayout {
338 fields: vec![
339 HeaderField {
340 name: "version".to_string(),
341 field_type: HeaderFieldType::U32BE,
342 optional: false,
343 validation: Some(FieldValidation {
344 min_value: Some(1),
345 max_value: Some(10),
346 allowed_values: None,
347 max_length: None,
348 }),
349 },
350 HeaderField {
351 name: "entry_count".to_string(),
352 field_type: HeaderFieldType::U32BE,
353 optional: false,
354 validation: Some(FieldValidation {
355 min_value: Some(0),
356 max_value: Some(100_000_000),
357 allowed_values: None,
358 max_length: None,
359 }),
360 },
361 HeaderField {
362 name: "sampling_rate".to_string(),
363 field_type: HeaderFieldType::U32BE,
364 optional: false,
365 validation: Some(FieldValidation {
366 min_value: Some(1),
367 max_value: Some(1_000_000),
368 allowed_values: None,
369 max_length: None,
370 }),
371 },
372 HeaderField {
373 name: "min_token".to_string(),
374 field_type: HeaderFieldType::U64BE,
375 optional: false,
376 validation: None,
377 },
378 HeaderField {
379 name: "max_token".to_string(),
380 field_type: HeaderFieldType::U64BE,
381 optional: false,
382 validation: None,
383 },
384 HeaderField {
385 name: "data_size".to_string(),
386 field_type: HeaderFieldType::U64BE,
387 optional: false,
388 validation: Some(FieldValidation {
389 min_value: Some(1),
390 max_value: Some(1_000_000_000),
391 allowed_values: None,
392 max_length: None,
393 }),
394 },
395 HeaderField {
396 name: "checksum".to_string(),
397 field_type: HeaderFieldType::U32BE,
398 optional: false,
399 validation: None,
400 },
401 ],
402 min_size: 32,
403 max_size: 1024,
404 },
405 },
406 );
407 }
408
409 pub fn get_spec(&self, component_type: SSTableComponentType) -> Result<&ComponentHeaderSpec> {
411 self.specs.get(&component_type).ok_or_else(|| {
412 Error::unsupported_format(format!(
413 "No specification for component: {:?}",
414 component_type
415 ))
416 })
417 }
418
419 pub fn parse_header(
421 &self,
422 input: &[u8],
423 component_type: SSTableComponentType,
424 ) -> Result<ParsedHeader> {
425 let spec = self.get_spec(component_type)?;
426 parse_component_header(input, spec)
427 }
428}
429
430pub fn parse_component_header(input: &[u8], spec: &ComponentHeaderSpec) -> Result<ParsedHeader> {
432 let original_input = input;
433 let mut remaining = input;
434 let mut fields = HashMap::new();
435
436 if input.len() < spec.field_layout.min_size {
438 return Err(Error::corruption(format!(
439 "Insufficient data for {:?} header: need {} bytes, have {}",
440 spec.component_type,
441 spec.field_layout.min_size,
442 input.len()
443 )));
444 }
445
446 let (cassandra_version, format_version) = if spec.has_magic_number {
448 if let Some(expected_magic) = spec.magic_number {
450 if remaining.len() < 4 {
452 return Err(Error::corruption(
453 "Insufficient data for magic number".to_string(),
454 ));
455 }
456 let (new_remaining, magic) = be_u32::<_, nom::error::Error<&[u8]>>(remaining)
457 .map_err(|e| Error::corruption(format!("Failed to parse magic: {:?}", e)))?;
458 if magic != expected_magic {
459 return Err(Error::corruption(format!(
460 "Magic number mismatch: expected 0x{:08X}, got 0x{:08X}",
461 expected_magic, magic
462 )));
463 }
464 remaining = new_remaining;
465
466 let (new_remaining, version) = be_u32::<_, nom::error::Error<&[u8]>>(remaining)
468 .map_err(|e| Error::corruption(format!("Failed to parse version: {:?}", e)))?;
469 remaining = new_remaining;
470
471 (CassandraVersion::Legacy, version as u16)
472 } else {
473 let (new_remaining, (version, format_ver)) = parse_magic_and_version(remaining)
475 .map_err(|e| {
476 Error::corruption(format!("Failed to parse magic/version: {:?}", e))
477 })?;
478 remaining = new_remaining;
479 (version, format_ver)
480 }
481 } else {
482 (CassandraVersion::Legacy, 1u16) };
486
487 for field in &spec.field_layout.fields {
489 let (new_remaining, value) =
490 parse_header_field(remaining, &field.field_type, &field.validation).map_err(|e| {
491 Error::corruption(format!("Failed to parse field '{}': {:?}", field.name, e))
492 })?;
493
494 remaining = new_remaining;
495 fields.insert(field.name.clone(), value);
496 }
497
498 let header_size = original_input.len() - remaining.len();
500
501 if header_size > spec.field_layout.max_size {
503 return Err(Error::corruption(format!(
504 "Header size {} exceeds maximum {} for {:?}",
505 header_size, spec.field_layout.max_size, spec.component_type
506 )));
507 }
508
509 let actual_format_version = if !spec.has_magic_number {
511 if let Some(HeaderFieldValue::U32(version)) = fields.get("version") {
512 *version as u16
513 } else {
514 format_version
515 }
516 } else {
517 format_version
518 };
519
520 Ok(ParsedHeader {
521 component_type: spec.component_type,
522 cassandra_version,
523 format_version: actual_format_version.into(),
524 fields,
525 header_size,
526 })
527}
528
529fn parse_header_field<'a>(
531 input: &'a [u8],
532 field_type: &HeaderFieldType,
533 validation: &Option<FieldValidation>,
534) -> IResult<&'a [u8], HeaderFieldValue> {
535 use nom::error::{Error as NomError, ErrorKind};
536
537 let (remaining, value) = match field_type {
538 HeaderFieldType::U8 => {
539 let (remaining, val) = nom::number::complete::be_u8(input)?;
540 (remaining, HeaderFieldValue::U8(val))
541 }
542 HeaderFieldType::U16BE => {
543 let (remaining, val) = be_u16(input)?;
544 (remaining, HeaderFieldValue::U16(val))
545 }
546 HeaderFieldType::U32BE => {
547 let (remaining, val) = be_u32(input)?;
548 (remaining, HeaderFieldValue::U32(val))
549 }
550 HeaderFieldType::U64BE => {
551 let (remaining, val) = be_u64(input)?;
552 (remaining, HeaderFieldValue::U64(val))
553 }
554 HeaderFieldType::U32LE => {
555 let (remaining, val) = le_u32(input)?;
556 (remaining, HeaderFieldValue::U32(val))
557 }
558 HeaderFieldType::VInt => {
559 let (remaining, val) = parse_vint(input)?;
560 (remaining, HeaderFieldValue::VInt(val))
561 }
562 HeaderFieldType::VString => {
563 if input.is_empty() {
565 return Err(nom::Err::Error(NomError::new(input, ErrorKind::Eof)));
566 }
567 let len = input[0] as usize;
568 if input.len() < 1 + len {
569 return Err(nom::Err::Error(NomError::new(input, ErrorKind::Eof)));
570 }
571 let (remaining, _) = take(1usize)(input)?; let (remaining, bytes) = take(len)(remaining)?;
573 let string = String::from_utf8(bytes.to_vec())
574 .map_err(|_| nom::Err::Error(NomError::new(input, ErrorKind::Verify)))?;
575 (remaining, HeaderFieldValue::String(string))
576 }
577 HeaderFieldType::FixedBytes(size) => {
578 let (remaining, bytes) = take(*size)(input)?;
579 (remaining, HeaderFieldValue::Bytes(bytes.to_vec()))
580 }
581 HeaderFieldType::VBytes => {
582 let (remaining, len) = parse_vint_length(input)?;
583 let (remaining, bytes) = take(len)(remaining)?;
584 (remaining, HeaderFieldValue::Bytes(bytes.to_vec()))
585 }
586 HeaderFieldType::Array(element_type) => {
587 let (remaining, count) = parse_vint_length(input)?;
588 let mut elements = Vec::new();
589 let mut current = remaining;
590
591 for _ in 0..count {
592 let (new_current, element) = parse_header_field(current, element_type, &None)?;
593 elements.push(element);
594 current = new_current;
595 }
596
597 (current, HeaderFieldValue::Array(elements))
598 }
599 HeaderFieldType::Map(key_type, value_type) => {
600 let (remaining, count) = parse_vint_length(input)?;
601 let mut map = HashMap::new();
602 let mut current = remaining;
603
604 for _ in 0..count {
605 let (new_current, key) = parse_header_field(current, key_type, &None)?;
606 let (new_current, value) = parse_header_field(new_current, value_type, &None)?;
607
608 let key_str = match key {
609 HeaderFieldValue::String(s) => s,
610 _ => return Err(nom::Err::Error(NomError::new(input, ErrorKind::Verify))),
611 };
612
613 map.insert(key_str, value);
614 current = new_current;
615 }
616
617 (current, HeaderFieldValue::Map(map))
618 }
619 };
620
621 if let Some(validation) = validation {
623 validate_field_value(&value, validation)
624 .map_err(|_| nom::Err::Error(NomError::new(input, ErrorKind::Verify)))?;
625 }
626
627 Ok((remaining, value))
628}
629
630fn validate_field_value(value: &HeaderFieldValue, validation: &FieldValidation) -> Result<()> {
632 if validation.min_value.is_some() || validation.max_value.is_some() {
634 let num_value = match value {
635 HeaderFieldValue::U8(v) => *v as u64,
636 HeaderFieldValue::U16(v) => *v as u64,
637 HeaderFieldValue::U32(v) => *v as u64,
638 HeaderFieldValue::U64(v) => *v,
639 HeaderFieldValue::VInt(v) => {
640 if *v < 0 {
642 if let Some(min) = validation.min_value {
644 if *v < (min as i64) {
645 return Err(Error::corruption(format!(
646 "Field value {} below minimum {}",
647 v, min
648 )));
649 }
650 }
651 return Ok(()); } else {
653 *v as u64
654 }
655 }
656 _ => return Ok(()), };
658
659 if let Some(min) = validation.min_value {
660 if num_value < min {
661 return Err(Error::corruption(format!(
662 "Field value {} below minimum {}",
663 num_value, min
664 )));
665 }
666 }
667
668 if let Some(max) = validation.max_value {
669 if num_value > max {
670 return Err(Error::corruption(format!(
671 "Field value {} above maximum {}",
672 num_value, max
673 )));
674 }
675 }
676 }
677
678 if let Some(max_len) = validation.max_length {
680 let actual_len = match value {
681 HeaderFieldValue::String(s) => s.len(),
682 HeaderFieldValue::Bytes(b) => b.len(),
683 _ => return Ok(()), };
685
686 if actual_len > max_len {
687 return Err(Error::corruption(format!(
688 "Field length {} exceeds maximum {}",
689 actual_len, max_len
690 )));
691 }
692 }
693
694 if let Some(allowed) = &validation.allowed_values {
696 let num_value = match value {
697 HeaderFieldValue::U8(v) => *v as u64,
698 HeaderFieldValue::U16(v) => *v as u64,
699 HeaderFieldValue::U32(v) => *v as u64,
700 HeaderFieldValue::U64(v) => *v,
701 HeaderFieldValue::VInt(v) => (*v).unsigned_abs(),
702 _ => return Ok(()), };
704
705 if !allowed.contains(&num_value) {
706 return Err(Error::corruption(format!(
707 "Field value {} not in allowed values: {:?}",
708 num_value, allowed
709 )));
710 }
711 }
712
713 Ok(())
714}
715
716impl HeaderSpecRegistry {
718 pub fn parse_data_header(&self, input: &[u8]) -> Result<ParsedHeader> {
720 self.parse_header(input, SSTableComponentType::Data)
721 }
722
723 pub fn parse_index_header(&self, input: &[u8]) -> Result<ParsedHeader> {
725 self.parse_header(input, SSTableComponentType::Index)
726 }
727
728 pub fn parse_summary_header(&self, input: &[u8]) -> Result<ParsedHeader> {
730 if input.len() < 4 {
731 return Err(Error::corruption(
732 "Insufficient data for Summary.db header".to_string(),
733 ));
734 }
735
736 let potential_magic = u32::from_be_bytes([input[0], input[1], input[2], input[3]]);
738 if potential_magic == 0x43515354 {
739 let mut magic_spec = self.get_spec(SSTableComponentType::Summary)?.clone();
741 magic_spec.has_magic_number = true;
742 magic_spec.magic_number = Some(0x43515354); magic_spec
745 .field_layout
746 .fields
747 .retain(|f| f.name != "version");
748 parse_component_header(input, &magic_spec)
749 } else {
750 self.parse_header(input, SSTableComponentType::Summary)
752 }
753 }
754}
755
756static GLOBAL_REGISTRY: std::sync::OnceLock<HeaderSpecRegistry> = std::sync::OnceLock::new();
758
759pub fn get_global_registry() -> &'static HeaderSpecRegistry {
761 GLOBAL_REGISTRY.get_or_init(HeaderSpecRegistry::new)
762}
763
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a validation rule from its four optional constraints.
    fn rule(
        min_value: Option<u64>,
        max_value: Option<u64>,
        allowed_values: Option<Vec<u64>>,
        max_length: Option<usize>,
    ) -> FieldValidation {
        FieldValidation {
            min_value,
            max_value,
            allowed_values,
            max_length,
        }
    }

    #[test]
    fn test_registry_creation() {
        let registry = HeaderSpecRegistry::new();
        for component in [
            SSTableComponentType::Data,
            SSTableComponentType::Index,
            SSTableComponentType::Summary,
        ] {
            assert!(
                registry.get_spec(component).is_ok(),
                "missing spec for {:?}",
                component
            );
        }
    }

    #[test]
    fn test_field_validation() {
        let bounds = rule(Some(1), Some(100), None, None);
        for (raw, should_pass) in [(50u32, true), (0, false), (101, false)] {
            let outcome = validate_field_value(&HeaderFieldValue::U32(raw), &bounds);
            assert_eq!(outcome.is_ok(), should_pass, "value {}", raw);
        }
    }

    #[test]
    fn test_string_length_validation() {
        let limit = rule(None, None, None, Some(5));
        for (text, should_pass) in [("test", true), ("toolong", false)] {
            let outcome = validate_field_value(&HeaderFieldValue::String(text.to_string()), &limit);
            assert_eq!(outcome.is_ok(), should_pass, "value {:?}", text);
        }
    }

    #[test]
    fn test_allowed_values_validation() {
        let allowed = rule(None, None, Some(vec![1, 2, 3]), None);
        for (raw, should_pass) in [(2u32, true), (4, false)] {
            let outcome = validate_field_value(&HeaderFieldValue::U32(raw), &allowed);
            assert_eq!(outcome.is_ok(), should_pass, "value {}", raw);
        }
    }
}