1use super::types::TypeTag;
7use std::collections::HashMap;
8
/// Optional protocol features a peer advertises during negotiation.
///
/// `supports_*` flags describe what a peer is able to do; `requires_*` flags
/// describe what a peer insists on. [`FeatureFlags::intersect`] ANDs the
/// `supports_*` flags and ORs the `requires_*` flags.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FeatureFlags {
    /// Whether nested values are supported (v0.5 presets pair this with the
    /// `NestedRecord` / `NestedArray` type tags).
    pub supports_nested: bool,
    /// Whether streaming is supported.
    pub supports_streaming: bool,
    /// Whether delta encoding is supported — NOTE(review): meaning inferred from the name.
    pub supports_delta: bool,
    /// Whether the "LLB" feature is supported — NOTE(review): LLB is not defined in this file.
    pub supports_llb: bool,
    /// Whether this side requires checksums.
    pub requires_checksums: bool,
    /// Whether this side requires canonical encoding.
    pub requires_canonical: bool,
}
25
26impl FeatureFlags {
27 pub fn new() -> Self {
29 Self {
30 supports_nested: false,
31 supports_streaming: false,
32 supports_delta: false,
33 supports_llb: false,
34 requires_checksums: false,
35 requires_canonical: false,
36 }
37 }
38
39 pub fn v0_5_full() -> Self {
41 Self {
42 supports_nested: true,
43 supports_streaming: true,
44 supports_delta: true,
45 supports_llb: true,
46 requires_checksums: true,
47 requires_canonical: true,
48 }
49 }
50
51 pub fn v0_4_compatible() -> Self {
53 Self {
54 supports_nested: false,
55 supports_streaming: false,
56 supports_delta: false,
57 supports_llb: false,
58 requires_checksums: false,
59 requires_canonical: true,
60 }
61 }
62
63 pub fn intersect(&self, other: &FeatureFlags) -> FeatureFlags {
65 FeatureFlags {
66 supports_nested: self.supports_nested && other.supports_nested,
67 supports_streaming: self.supports_streaming && other.supports_streaming,
68 supports_delta: self.supports_delta && other.supports_delta,
69 supports_llb: self.supports_llb && other.supports_llb,
70 requires_checksums: self.requires_checksums || other.requires_checksums,
71 requires_canonical: self.requires_canonical || other.requires_canonical,
72 }
73 }
74}
75
76impl Default for FeatureFlags {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
/// What one side of the handshake can do: its protocol version, feature flags
/// and the value types it understands.
#[derive(Debug, Clone, PartialEq)]
pub struct Capabilities {
    /// Protocol version byte (the presets use `0x05` for v0.5 and `0x04` for v0.4).
    pub version: u8,
    /// Optional features this side offers or requires.
    pub features: FeatureFlags,
    /// Type tags this side can handle; queried by `supports_type`.
    pub supported_types: Vec<TypeTag>,
}
92
93impl Capabilities {
94 pub fn new(version: u8, features: FeatureFlags, supported_types: Vec<TypeTag>) -> Self {
96 Self {
97 version,
98 features,
99 supported_types,
100 }
101 }
102
103 pub fn v0_5() -> Self {
105 Self {
106 version: 0x05,
107 features: FeatureFlags::v0_5_full(),
108 supported_types: vec![
109 TypeTag::Int,
110 TypeTag::Float,
111 TypeTag::Bool,
112 TypeTag::String,
113 TypeTag::StringArray,
114 TypeTag::NestedRecord,
115 TypeTag::NestedArray,
116 ],
117 }
118 }
119
120 pub fn v0_4() -> Self {
122 Self {
123 version: 0x04,
124 features: FeatureFlags::v0_4_compatible(),
125 supported_types: vec![
126 TypeTag::Int,
127 TypeTag::Float,
128 TypeTag::Bool,
129 TypeTag::String,
130 TypeTag::StringArray,
131 ],
132 }
133 }
134
135 pub fn supports_type(&self, type_tag: TypeTag) -> bool {
137 self.supported_types.contains(&type_tag)
138 }
139}
140
/// Messages exchanged during the schema-negotiation handshake.
///
/// In the flow implemented by `SchemaNegotiator`, the initiator sends
/// `Capabilities`, the responder answers with `CapabilitiesAck`, the initiator
/// then sends `SelectSchema`, and the responder finishes with `Ready`. An
/// `Error` is accepted in any state and moves the receiver to `Failed`.
#[derive(Debug, Clone, PartialEq)]
pub enum NegotiationMessage {
    /// Opening message advertising the sender's capabilities.
    Capabilities {
        /// Sender's protocol version byte; must equal the receiver's.
        version: u8,
        /// Sender's feature flags.
        features: FeatureFlags,
        /// Type tags the sender can handle.
        supported_types: Vec<TypeTag>,
    },

    /// Responder's reply to `Capabilities`; carries no type list.
    CapabilitiesAck {
        /// Responder's protocol version byte.
        version: u8,
        /// Responder's feature flags.
        features: FeatureFlags,
    },

    /// Initiator's schema choice and its FID → field-name mappings.
    SelectSchema {
        /// Schema identifier (the negotiator currently always sends `"default"`).
        schema_id: String,
        /// Field id → field name mappings proposed by the initiator.
        fid_mappings: HashMap<u16, String>,
    },

    /// Responder's final message, assigning the session id.
    Ready {
        /// Session id allocated by the responder.
        session_id: u64,
    },

    /// Aborts the negotiation.
    Error {
        /// Byte-encodable error category.
        code: ErrorCode,
        /// Human-readable description; stored in `NegotiationState::Failed`.
        message: String,
    },
}
184
/// Error categories carried by `NegotiationMessage::Error`, each with a fixed
/// byte value (see `ErrorCode::from_u8` / `ErrorCode::to_u8`).
///
/// `#[repr(u8)]` pins the discriminants to a single byte. `to_u8` is an
/// `as u8` cast, so without this a discriminant above `0xFF` would be silently
/// truncated. With it, such a discriminant is a compile error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ErrorCode {
    /// A FID is mapped to different field names by the two sides.
    FidConflict = 0x01,
    /// A field's type differs between the two sides.
    TypeMismatch = 0x02,
    /// A requested feature is not supported.
    UnsupportedFeature = 0x03,
    /// The two sides speak different protocol versions.
    ProtocolVersionMismatch = 0x04,
    /// A message arrived in a state that does not accept it.
    InvalidState = 0x05,
    /// Catch-all for errors without a dedicated code.
    Generic = 0xFF,
}
201
202impl ErrorCode {
203 pub fn from_u8(byte: u8) -> Option<Self> {
205 match byte {
206 0x01 => Some(ErrorCode::FidConflict),
207 0x02 => Some(ErrorCode::TypeMismatch),
208 0x03 => Some(ErrorCode::UnsupportedFeature),
209 0x04 => Some(ErrorCode::ProtocolVersionMismatch),
210 0x05 => Some(ErrorCode::InvalidState),
211 0xFF => Some(ErrorCode::Generic),
212 _ => None,
213 }
214 }
215
216 pub fn to_u8(self) -> u8 {
218 self as u8
219 }
220}
221
/// Where a `SchemaNegotiator` is in the handshake.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NegotiationState {
    /// Freshly created; nothing sent or received yet.
    Initial,
    /// Initiator has sent `Capabilities` and awaits `CapabilitiesAck`.
    CapabilitiesSent,
    /// Responder has received `Capabilities` and replied with `CapabilitiesAck`.
    CapabilitiesReceived,
    /// Initiator has sent `SelectSchema`, or responder has accepted it and sent `Ready`.
    SchemaSelected,
    /// Initiator has received `Ready`; the session is established.
    Ready,
    /// The peer sent an `Error`; holds the peer's message text.
    Failed(String),
}
238
/// Result of a completed handshake.
///
/// `SchemaNegotiator` builds one when it handles `Ready`; it can also be
/// constructed directly with `NegotiationSession::new`.
#[derive(Debug, Clone, PartialEq)]
pub struct NegotiationSession {
    /// Session id carried in the `Ready` message.
    pub session_id: u64,
    /// Capabilities this side advertised.
    pub local_caps: Capabilities,
    /// Capabilities recorded for the peer.
    pub remote_caps: Capabilities,
    /// `local_caps.features` intersected with `remote_caps.features`.
    pub agreed_features: FeatureFlags,
    /// FID → field-name mappings in effect for the session.
    pub fid_mappings: HashMap<u16, String>,
}
253
254impl NegotiationSession {
255 pub fn new(
257 session_id: u64,
258 local_caps: Capabilities,
259 remote_caps: Capabilities,
260 fid_mappings: HashMap<u16, String>,
261 ) -> Self {
262 let agreed_features = local_caps.features.intersect(&remote_caps.features);
263 Self {
264 session_id,
265 local_caps,
266 remote_caps,
267 agreed_features,
268 fid_mappings,
269 }
270 }
271}
272
/// State machine driving one side of the schema-negotiation handshake.
///
/// The initiating side calls `initiate`. Both sides feed incoming messages to
/// `handle_message`.
#[derive(Debug, Clone)]
pub struct SchemaNegotiator {
    /// Capabilities this side advertises.
    local_capabilities: Capabilities,
    /// Peer capabilities; `None` until `Capabilities` or `CapabilitiesAck` is handled.
    remote_capabilities: Option<Capabilities>,
    /// Current handshake state.
    state: NegotiationState,
    /// Session id the responder hands out in its next `Ready` (starts at 1).
    next_session_id: u64,
    /// Local FID → field-name mappings; the responder replaces them with the
    /// peer's mappings once they pass the conflict check.
    fid_mappings: HashMap<u16, String>,
}
287
/// What the caller should do after `SchemaNegotiator::handle_message`.
#[derive(Debug, Clone, PartialEq)]
pub enum NegotiationResponse {
    /// Send this message to the peer.
    SendMessage(NegotiationMessage),
    /// Handshake finished; here is the established session.
    Complete(NegotiationSession),
    /// The peer reported an error; holds its message text.
    Failed(String),
    /// NOTE(review): not produced by the handlers in this file; presumably
    /// means "nothing to do" — confirm with callers.
    None,
}
300
/// Errors raised locally while negotiating or checking schemas.
#[derive(Debug, Clone, PartialEq)]
pub enum NegotiationError {
    /// The same FID maps to different field names on the two sides.
    FidConflict {
        /// Conflicting field id.
        fid: u16,
        /// Name from the local mappings.
        name1: String,
        /// Name from the remote mappings.
        name2: String,
    },
    /// A field's type differs from what was expected.
    TypeMismatch {
        /// Field id with the mismatch.
        fid: u16,
        /// Type from the expected-types map.
        expected: TypeTag,
        /// Type actually observed.
        found: TypeTag,
    },
    /// A required feature is unavailable (not constructed by the negotiator code in this file).
    UnsupportedFeature {
        /// Name of the missing feature.
        feature: String,
    },
    /// Peer's protocol version differs from ours.
    ProtocolVersionMismatch {
        /// Our version byte.
        local: u8,
        /// Peer's version byte.
        remote: u8,
    },
    /// A call or message arrived in a state that does not accept it.
    InvalidState {
        /// State the negotiator was in.
        current: NegotiationState,
        /// State the call or message requires.
        expected: NegotiationState,
    },
}
342
343impl std::fmt::Display for NegotiationError {
344 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345 match self {
346 NegotiationError::FidConflict { fid, name1, name2 } => {
347 write!(
348 f,
349 "FID conflict: FID {} maps to both '{}' and '{}'",
350 fid, name1, name2
351 )
352 }
353 NegotiationError::TypeMismatch {
354 fid,
355 expected,
356 found,
357 } => {
358 write!(
359 f,
360 "Type mismatch for FID {}: expected {:?}, found {:?}",
361 fid, expected, found
362 )
363 }
364 NegotiationError::UnsupportedFeature { feature } => {
365 write!(f, "Unsupported feature: {}", feature)
366 }
367 NegotiationError::ProtocolVersionMismatch { local, remote } => {
368 write!(
369 f,
370 "Protocol version mismatch: local 0x{:02X}, remote 0x{:02X}",
371 local, remote
372 )
373 }
374 NegotiationError::InvalidState { current, expected } => {
375 write!(
376 f,
377 "Invalid state transition: current {:?}, expected {:?}",
378 current, expected
379 )
380 }
381 }
382 }
383}
384
/// Marks `NegotiationError` as a standard error type. The default trait
/// methods are used, and the message comes from its `Display` impl.
impl std::error::Error for NegotiationError {}
386
387impl SchemaNegotiator {
388 pub fn new(local_capabilities: Capabilities) -> Self {
390 Self {
391 local_capabilities,
392 remote_capabilities: None,
393 state: NegotiationState::Initial,
394 next_session_id: 1,
395 fid_mappings: HashMap::new(),
396 }
397 }
398
399 pub fn v0_5() -> Self {
401 Self::new(Capabilities::v0_5())
402 }
403
404 pub fn v0_4() -> Self {
406 Self::new(Capabilities::v0_4())
407 }
408
409 pub fn with_fid_mappings(mut self, mappings: HashMap<u16, String>) -> Self {
411 self.fid_mappings = mappings;
412 self
413 }
414
415 pub fn initiate(&mut self) -> Result<NegotiationMessage, NegotiationError> {
417 if self.state != NegotiationState::Initial {
418 return Err(NegotiationError::InvalidState {
419 current: self.state.clone(),
420 expected: NegotiationState::Initial,
421 });
422 }
423
424 self.state = NegotiationState::CapabilitiesSent;
425
426 Ok(NegotiationMessage::Capabilities {
427 version: self.local_capabilities.version,
428 features: self.local_capabilities.features.clone(),
429 supported_types: self.local_capabilities.supported_types.clone(),
430 })
431 }
432
433 pub fn handle_message(
435 &mut self,
436 message: NegotiationMessage,
437 ) -> Result<NegotiationResponse, NegotiationError> {
438 match message {
439 NegotiationMessage::Capabilities {
440 version,
441 features,
442 supported_types,
443 } => self.handle_capabilities(version, features, supported_types),
444
445 NegotiationMessage::CapabilitiesAck { version, features } => {
446 self.handle_capabilities_ack(version, features)
447 }
448
449 NegotiationMessage::SelectSchema {
450 schema_id,
451 fid_mappings,
452 } => self.handle_select_schema(schema_id, fid_mappings),
453
454 NegotiationMessage::Ready { session_id } => self.handle_ready(session_id),
455
456 NegotiationMessage::Error { code: _, message } => {
457 self.state = NegotiationState::Failed(message.clone());
458 Ok(NegotiationResponse::Failed(message))
459 }
460 }
461 }
462
463 pub fn is_ready(&self) -> bool {
465 self.state == NegotiationState::Ready
466 }
467
468 pub fn state(&self) -> &NegotiationState {
470 &self.state
471 }
472
473 pub fn local_capabilities(&self) -> &Capabilities {
475 &self.local_capabilities
476 }
477
478 pub fn remote_capabilities(&self) -> Option<&Capabilities> {
480 self.remote_capabilities.as_ref()
481 }
482
483 fn handle_capabilities(
486 &mut self,
487 version: u8,
488 features: FeatureFlags,
489 supported_types: Vec<TypeTag>,
490 ) -> Result<NegotiationResponse, NegotiationError> {
491 if version != self.local_capabilities.version {
493 return Err(NegotiationError::ProtocolVersionMismatch {
494 local: self.local_capabilities.version,
495 remote: version,
496 });
497 }
498
499 self.remote_capabilities = Some(Capabilities::new(
501 version,
502 features.clone(),
503 supported_types,
504 ));
505 self.state = NegotiationState::CapabilitiesReceived;
506
507 Ok(NegotiationResponse::SendMessage(
509 NegotiationMessage::CapabilitiesAck {
510 version: self.local_capabilities.version,
511 features: self.local_capabilities.features.clone(),
512 },
513 ))
514 }
515
516 fn handle_capabilities_ack(
517 &mut self,
518 version: u8,
519 features: FeatureFlags,
520 ) -> Result<NegotiationResponse, NegotiationError> {
521 if self.state != NegotiationState::CapabilitiesSent {
522 return Err(NegotiationError::InvalidState {
523 current: self.state.clone(),
524 expected: NegotiationState::CapabilitiesSent,
525 });
526 }
527
528 if version != self.local_capabilities.version {
530 return Err(NegotiationError::ProtocolVersionMismatch {
531 local: self.local_capabilities.version,
532 remote: version,
533 });
534 }
535
536 self.remote_capabilities = Some(Capabilities::new(
538 version,
539 features,
540 self.local_capabilities.supported_types.clone(),
541 ));
542
543 self.state = NegotiationState::SchemaSelected;
545
546 Ok(NegotiationResponse::SendMessage(
548 NegotiationMessage::SelectSchema {
549 schema_id: "default".to_string(),
550 fid_mappings: self.fid_mappings.clone(),
551 },
552 ))
553 }
554
555 fn handle_select_schema(
556 &mut self,
557 _schema_id: String,
558 fid_mappings: HashMap<u16, String>,
559 ) -> Result<NegotiationResponse, NegotiationError> {
560 if self.state != NegotiationState::CapabilitiesReceived {
561 return Err(NegotiationError::InvalidState {
562 current: self.state.clone(),
563 expected: NegotiationState::CapabilitiesReceived,
564 });
565 }
566
567 self.detect_fid_conflicts(&fid_mappings)?;
569
570 self.fid_mappings = fid_mappings;
572 self.state = NegotiationState::SchemaSelected;
573
574 let session_id = self.next_session_id;
576 self.next_session_id += 1;
577
578 Ok(NegotiationResponse::SendMessage(
579 NegotiationMessage::Ready { session_id },
580 ))
581 }
582
583 fn handle_ready(&mut self, session_id: u64) -> Result<NegotiationResponse, NegotiationError> {
584 if self.state != NegotiationState::SchemaSelected {
585 return Err(NegotiationError::InvalidState {
586 current: self.state.clone(),
587 expected: NegotiationState::SchemaSelected,
588 });
589 }
590
591 self.state = NegotiationState::Ready;
592
593 let remote_caps = self
595 .remote_capabilities
596 .clone()
597 .expect("Remote capabilities should be set");
598
599 let session = NegotiationSession::new(
600 session_id,
601 self.local_capabilities.clone(),
602 remote_caps,
603 self.fid_mappings.clone(),
604 );
605
606 Ok(NegotiationResponse::Complete(session))
607 }
608
609 fn detect_fid_conflicts(
610 &self,
611 remote_mappings: &HashMap<u16, String>,
612 ) -> Result<(), NegotiationError> {
613 for (fid, remote_name) in remote_mappings {
614 if let Some(local_name) = self.fid_mappings.get(fid) {
615 if local_name != remote_name {
616 return Err(NegotiationError::FidConflict {
617 fid: *fid,
618 name1: local_name.clone(),
619 name2: remote_name.clone(),
620 });
621 }
622 }
623 }
624 Ok(())
625 }
626
627 pub fn detect_conflicts(
631 local_mappings: &HashMap<u16, String>,
632 remote_mappings: &HashMap<u16, String>,
633 ) -> Vec<NegotiationError> {
634 let mut conflicts = Vec::new();
635
636 for (fid, remote_name) in remote_mappings {
637 if let Some(local_name) = local_mappings.get(fid) {
638 if local_name != remote_name {
639 conflicts.push(NegotiationError::FidConflict {
640 fid: *fid,
641 name1: local_name.clone(),
642 name2: remote_name.clone(),
643 });
644 }
645 }
646 }
647
648 conflicts
649 }
650
651 pub fn detect_type_mismatches(
655 expected_types: &HashMap<u16, TypeTag>,
656 actual_types: &HashMap<u16, TypeTag>,
657 ) -> Vec<NegotiationError> {
658 let mut mismatches = Vec::new();
659
660 for (fid, actual_type) in actual_types {
661 if let Some(expected_type) = expected_types.get(fid) {
662 if expected_type != actual_type {
663 mismatches.push(NegotiationError::TypeMismatch {
664 fid: *fid,
665 expected: *expected_type,
666 found: *actual_type,
667 });
668 }
669 }
670 }
671
672 mismatches.sort_by(|a, b| {
676 let fid_a = match a {
677 NegotiationError::TypeMismatch { fid, .. } => *fid,
678 _ => 0,
679 };
680 let fid_b = match b {
681 NegotiationError::TypeMismatch { fid, .. } => *fid,
682 _ => 0,
683 };
684 fid_a.cmp(&fid_b)
685 });
686
687 mismatches
688 }
689}
690
#[cfg(test)]
mod tests {
    //! Unit tests for feature flags, capabilities, negotiation messages and error codes.

    use super::*;

    #[test]
    fn test_feature_flags_new() {
        let flags = FeatureFlags::new();
        assert!(!flags.supports_nested);
        assert!(!flags.supports_streaming);
        assert!(!flags.supports_delta);
        assert!(!flags.supports_llb);
        assert!(!flags.requires_checksums);
        assert!(!flags.requires_canonical);
    }

    #[test]
    fn test_feature_flags_v0_5_full() {
        let flags = FeatureFlags::v0_5_full();
        assert!(flags.supports_nested);
        assert!(flags.supports_streaming);
        assert!(flags.supports_delta);
        assert!(flags.supports_llb);
        assert!(flags.requires_checksums);
        assert!(flags.requires_canonical);
    }

    #[test]
    fn test_feature_flags_v0_4_compatible() {
        let flags = FeatureFlags::v0_4_compatible();
        assert!(!flags.supports_nested);
        assert!(!flags.supports_streaming);
        assert!(!flags.supports_delta);
        assert!(!flags.supports_llb);
        assert!(!flags.requires_checksums);
        assert!(flags.requires_canonical);
    }

    /// `supports_*` flags are ANDed; `requires_*` flags are ORed.
    #[test]
    fn test_feature_flags_intersect() {
        let flags1 = FeatureFlags {
            supports_nested: true,
            supports_streaming: true,
            supports_delta: false,
            supports_llb: true,
            requires_checksums: false,
            requires_canonical: true,
        };

        let flags2 = FeatureFlags {
            supports_nested: true,
            supports_streaming: false,
            supports_delta: true,
            supports_llb: true,
            requires_checksums: true,
            requires_canonical: false,
        };

        let intersection = flags1.intersect(&flags2);
        assert!(intersection.supports_nested);
        assert!(!intersection.supports_streaming);
        assert!(!intersection.supports_delta);
        assert!(intersection.supports_llb);
        // Requirements accumulate: either side requiring a feature is enough.
        assert!(intersection.requires_checksums);
        assert!(intersection.requires_canonical);
    }

    #[test]
    fn test_feature_flags_intersect_all_enabled() {
        let flags1 = FeatureFlags::v0_5_full();
        let flags2 = FeatureFlags::v0_5_full();

        let intersection = flags1.intersect(&flags2);
        assert!(intersection.supports_nested);
        assert!(intersection.supports_streaming);
        assert!(intersection.supports_delta);
        assert!(intersection.supports_llb);
        assert!(intersection.requires_checksums);
        assert!(intersection.requires_canonical);
    }

    #[test]
    fn test_feature_flags_intersect_all_disabled() {
        let flags1 = FeatureFlags::new();
        let flags2 = FeatureFlags::new();

        let intersection = flags1.intersect(&flags2);
        assert!(!intersection.supports_nested);
        assert!(!intersection.supports_streaming);
        assert!(!intersection.supports_delta);
        assert!(!intersection.supports_llb);
        assert!(!intersection.requires_checksums);
        assert!(!intersection.requires_canonical);
    }

    #[test]
    fn test_feature_flags_intersect_v0_5_with_v0_4() {
        let v0_5 = FeatureFlags::v0_5_full();
        let v0_4 = FeatureFlags::v0_4_compatible();

        let intersection = v0_5.intersect(&v0_4);
        assert!(!intersection.supports_nested);
        assert!(!intersection.supports_streaming);
        assert!(!intersection.supports_delta);
        assert!(!intersection.supports_llb);
        // v0.5 requires checksums and both require canonical encoding.
        assert!(intersection.requires_checksums);
        assert!(intersection.requires_canonical);
    }

    #[test]
    fn test_negotiation_session_agreed_features() {
        let local_caps = Capabilities {
            version: 0x05,
            features: FeatureFlags {
                supports_nested: true,
                supports_streaming: true,
                supports_delta: false,
                supports_llb: true,
                requires_checksums: false,
                requires_canonical: true,
            },
            supported_types: vec![TypeTag::Int],
        };

        let remote_caps = Capabilities {
            version: 0x05,
            features: FeatureFlags {
                supports_nested: true,
                supports_streaming: false,
                supports_delta: true,
                supports_llb: true,
                requires_checksums: true,
                requires_canonical: false,
            },
            supported_types: vec![TypeTag::Int],
        };

        let session = NegotiationSession::new(1, local_caps, remote_caps, HashMap::new());

        assert!(session.agreed_features.supports_nested);
        assert!(!session.agreed_features.supports_streaming);
        assert!(!session.agreed_features.supports_delta);
        assert!(session.agreed_features.supports_llb);
        assert!(session.agreed_features.requires_checksums);
        assert!(session.agreed_features.requires_canonical);
    }

    #[test]
    fn test_capabilities_new() {
        let features = FeatureFlags::new();
        let types = vec![TypeTag::Int, TypeTag::String];
        let caps = Capabilities::new(0x05, features.clone(), types.clone());

        assert_eq!(caps.version, 0x05);
        assert_eq!(caps.features, features);
        assert_eq!(caps.supported_types, types);
    }

    #[test]
    fn test_capabilities_v0_5() {
        let caps = Capabilities::v0_5();
        assert_eq!(caps.version, 0x05);
        assert!(caps.features.supports_nested);
        assert!(caps.supports_type(TypeTag::NestedRecord));
        assert!(caps.supports_type(TypeTag::NestedArray));
    }

    #[test]
    fn test_capabilities_v0_4() {
        let caps = Capabilities::v0_4();
        assert_eq!(caps.version, 0x04);
        assert!(!caps.features.supports_nested);
        assert!(!caps.supports_type(TypeTag::NestedRecord));
        assert!(!caps.supports_type(TypeTag::NestedArray));
        assert!(caps.supports_type(TypeTag::Int));
        assert!(caps.supports_type(TypeTag::String));
    }

    #[test]
    fn test_capabilities_supports_type() {
        let caps = Capabilities::v0_5();
        assert!(caps.supports_type(TypeTag::Int));
        assert!(caps.supports_type(TypeTag::Float));
        assert!(caps.supports_type(TypeTag::Bool));
        assert!(caps.supports_type(TypeTag::String));
        assert!(caps.supports_type(TypeTag::StringArray));
        assert!(caps.supports_type(TypeTag::NestedRecord));
        assert!(caps.supports_type(TypeTag::NestedArray));
        assert!(!caps.supports_type(TypeTag::Reserved09));
    }

    #[test]
    fn test_negotiation_message_capabilities() {
        let msg = NegotiationMessage::Capabilities {
            version: 0x05,
            features: FeatureFlags::v0_5_full(),
            supported_types: vec![TypeTag::Int, TypeTag::String],
        };

        match msg {
            NegotiationMessage::Capabilities { version, .. } => {
                assert_eq!(version, 0x05);
            }
            _ => panic!("Expected Capabilities variant"),
        }
    }

    #[test]
    fn test_negotiation_message_capabilities_ack() {
        let msg = NegotiationMessage::CapabilitiesAck {
            version: 0x05,
            features: FeatureFlags::new(),
        };

        match msg {
            NegotiationMessage::CapabilitiesAck { version, .. } => {
                assert_eq!(version, 0x05);
            }
            _ => panic!("Expected CapabilitiesAck variant"),
        }
    }

    #[test]
    fn test_negotiation_message_select_schema() {
        let mut mappings = HashMap::new();
        mappings.insert(1, "user_id".to_string());
        mappings.insert(2, "username".to_string());

        let msg = NegotiationMessage::SelectSchema {
            schema_id: "user_schema_v1".to_string(),
            fid_mappings: mappings.clone(),
        };

        match msg {
            NegotiationMessage::SelectSchema {
                schema_id,
                fid_mappings,
            } => {
                assert_eq!(schema_id, "user_schema_v1");
                assert_eq!(fid_mappings.len(), 2);
                assert_eq!(fid_mappings.get(&1), Some(&"user_id".to_string()));
            }
            _ => panic!("Expected SelectSchema variant"),
        }
    }

    #[test]
    fn test_negotiation_message_ready() {
        let msg = NegotiationMessage::Ready { session_id: 12345 };

        match msg {
            NegotiationMessage::Ready { session_id } => {
                assert_eq!(session_id, 12345);
            }
            _ => panic!("Expected Ready variant"),
        }
    }

    #[test]
    fn test_negotiation_message_error() {
        let msg = NegotiationMessage::Error {
            code: ErrorCode::FidConflict,
            message: "FID 7 conflict".to_string(),
        };

        match msg {
            NegotiationMessage::Error { code, message } => {
                assert_eq!(code, ErrorCode::FidConflict);
                assert_eq!(message, "FID 7 conflict");
            }
            _ => panic!("Expected Error variant"),
        }
    }

    #[test]
    fn test_error_code_from_u8() {
        assert_eq!(ErrorCode::from_u8(0x01), Some(ErrorCode::FidConflict));
        assert_eq!(ErrorCode::from_u8(0x02), Some(ErrorCode::TypeMismatch));
        assert_eq!(
            ErrorCode::from_u8(0x03),
            Some(ErrorCode::UnsupportedFeature)
        );
        assert_eq!(
            ErrorCode::from_u8(0x04),
            Some(ErrorCode::ProtocolVersionMismatch)
        );
        assert_eq!(ErrorCode::from_u8(0x05), Some(ErrorCode::InvalidState));
        assert_eq!(ErrorCode::from_u8(0xFF), Some(ErrorCode::Generic));
        assert_eq!(ErrorCode::from_u8(0x99), None);
    }

    #[test]
    fn test_error_code_to_u8() {
        assert_eq!(ErrorCode::FidConflict.to_u8(), 0x01);
        assert_eq!(ErrorCode::TypeMismatch.to_u8(), 0x02);
        assert_eq!(ErrorCode::UnsupportedFeature.to_u8(), 0x03);
        assert_eq!(ErrorCode::ProtocolVersionMismatch.to_u8(), 0x04);
        assert_eq!(ErrorCode::InvalidState.to_u8(), 0x05);
        assert_eq!(ErrorCode::Generic.to_u8(), 0xFF);
    }

    #[test]
    fn test_error_code_round_trip() {
        let codes = vec![
            ErrorCode::FidConflict,
            ErrorCode::TypeMismatch,
            ErrorCode::UnsupportedFeature,
            ErrorCode::ProtocolVersionMismatch,
            ErrorCode::InvalidState,
            ErrorCode::Generic,
        ];

        for code in codes {
            let byte = code.to_u8();
            let parsed = ErrorCode::from_u8(byte).unwrap();
            assert_eq!(parsed, code);
        }
    }
}
1014
#[test]
fn test_detect_type_mismatches_no_mismatches() {
    // Comparing a type map with itself can never produce a mismatch.
    let types = HashMap::from([(1, TypeTag::Int), (2, TypeTag::String)]);

    let mismatches = SchemaNegotiator::detect_type_mismatches(&types, &types);
    assert!(mismatches.is_empty());
}
1028
#[test]
fn test_detect_type_mismatches_single_mismatch() {
    let expected = HashMap::from([(1, TypeTag::Int), (2, TypeTag::String)]);
    // FID 1 arrives as Float instead of Int; FID 2 agrees.
    let actual = HashMap::from([(1, TypeTag::Float), (2, TypeTag::String)]);

    let mismatches = SchemaNegotiator::detect_type_mismatches(&expected, &actual);
    assert_eq!(mismatches.len(), 1);

    if let NegotiationError::TypeMismatch {
        fid,
        expected,
        found,
    } = &mismatches[0]
    {
        assert_eq!(*fid, 1);
        assert_eq!(*expected, TypeTag::Int);
        assert_eq!(*found, TypeTag::Float);
    } else {
        panic!("Expected TypeMismatch");
    }
}
1055
#[test]
fn test_detect_type_mismatches_multiple_mismatches() {
    let mut expected = HashMap::new();
    expected.insert(1, TypeTag::Int);
    expected.insert(2, TypeTag::String);
    expected.insert(3, TypeTag::Bool);

    // FIDs 1 and 2 disagree with the expectation; FID 3 agrees.
    let mut actual = HashMap::new();
    actual.insert(1, TypeTag::Float);
    actual.insert(2, TypeTag::Bool);
    actual.insert(3, TypeTag::Bool);

    let mismatches = SchemaNegotiator::detect_type_mismatches(&expected, &actual);

    // detect_type_mismatches sorts by FID, so the exact sequence is deterministic;
    // pin it instead of only the count.
    assert_eq!(
        mismatches,
        vec![
            NegotiationError::TypeMismatch {
                fid: 1,
                expected: TypeTag::Int,
                found: TypeTag::Float,
            },
            NegotiationError::TypeMismatch {
                fid: 2,
                expected: TypeTag::String,
                found: TypeTag::Bool,
            },
        ]
    );
}
1071
#[test]
fn test_detect_type_mismatches_partial_overlap() {
    let expected = HashMap::from([(1, TypeTag::Int), (2, TypeTag::String)]);
    let actual = HashMap::from([(2, TypeTag::String), (3, TypeTag::Bool)]);

    // Only FID 2 appears on both sides, and it agrees; one-sided FIDs are ignored.
    let mismatches = SchemaNegotiator::detect_type_mismatches(&expected, &actual);
    assert!(mismatches.is_empty());
}
1085
#[test]
fn test_detect_type_mismatches_empty_types() {
    let nothing: HashMap<u16, TypeTag> = HashMap::new();

    assert!(SchemaNegotiator::detect_type_mismatches(&nothing, &nothing).is_empty());
}
1094
#[test]
fn test_detect_type_mismatches_nested_types() {
    let expected = HashMap::from([(1, TypeTag::NestedRecord), (2, TypeTag::NestedArray)]);
    // FID 1 degrades from a nested record to a plain string; FID 2 agrees.
    let actual = HashMap::from([(1, TypeTag::String), (2, TypeTag::NestedArray)]);

    let mismatches = SchemaNegotiator::detect_type_mismatches(&expected, &actual);
    assert_eq!(mismatches.len(), 1);

    if let NegotiationError::TypeMismatch {
        fid,
        expected,
        found,
    } = &mismatches[0]
    {
        assert_eq!(*fid, 1);
        assert_eq!(*expected, TypeTag::NestedRecord);
        assert_eq!(*found, TypeTag::String);
    } else {
        panic!("Expected TypeMismatch");
    }
}
1121
#[test]
fn test_detect_conflicts_no_conflicts() {
    // Identical mappings on both sides cannot conflict.
    let shared = HashMap::from([(1, "user_id".to_string()), (2, "username".to_string())]);

    let conflicts = SchemaNegotiator::detect_conflicts(&shared, &shared);
    assert!(conflicts.is_empty());
}
1135
#[test]
fn test_detect_conflicts_single_conflict() {
    let local = HashMap::from([(1, "user_id".to_string()), (2, "username".to_string())]);
    // The peer spells FID 1 differently; FID 2 agrees.
    let remote = HashMap::from([(1, "userId".to_string()), (2, "username".to_string())]);

    let conflicts = SchemaNegotiator::detect_conflicts(&local, &remote);
    assert_eq!(conflicts.len(), 1);

    if let NegotiationError::FidConflict { fid, name1, name2 } = &conflicts[0] {
        assert_eq!(*fid, 1);
        assert_eq!(name1, "user_id");
        assert_eq!(name2, "userId");
    } else {
        panic!("Expected FidConflict");
    }
}
1158
#[test]
fn test_detect_conflicts_multiple_conflicts() {
    let mut local = HashMap::new();
    local.insert(1, "user_id".to_string());
    local.insert(2, "username".to_string());
    local.insert(3, "email".to_string());

    // FIDs 1 and 2 are renamed by the peer; FID 3 agrees.
    let mut remote = HashMap::new();
    remote.insert(1, "userId".to_string());
    remote.insert(2, "userName".to_string());
    remote.insert(3, "email".to_string());

    let conflicts = SchemaNegotiator::detect_conflicts(&local, &remote);
    assert_eq!(conflicts.len(), 2);

    // Pin *which* FIDs conflict, independent of the order they are reported in.
    let mut conflicting_fids: Vec<u16> = conflicts
        .iter()
        .map(|conflict| match conflict {
            NegotiationError::FidConflict { fid, .. } => *fid,
            other => panic!("Expected FidConflict, got {:?}", other),
        })
        .collect();
    conflicting_fids.sort_unstable();
    assert_eq!(conflicting_fids, vec![1, 2]);
}
1174
#[test]
fn test_detect_conflicts_partial_overlap() {
    let local = HashMap::from([(1, "user_id".to_string()), (2, "username".to_string())]);
    let remote = HashMap::from([(2, "username".to_string()), (3, "email".to_string())]);

    // Only FID 2 is shared and it agrees; FIDs 1 and 3 exist on one side only.
    assert!(SchemaNegotiator::detect_conflicts(&local, &remote).is_empty());
}
1188
#[test]
fn test_detect_conflicts_empty_mappings() {
    let nothing: HashMap<u16, String> = HashMap::new();

    assert!(SchemaNegotiator::detect_conflicts(&nothing, &nothing).is_empty());
}
1197
#[test]
fn test_negotiation_state_equality() {
    let initial = NegotiationState::Initial;

    assert_eq!(initial, NegotiationState::Initial);
    assert_eq!(
        NegotiationState::CapabilitiesSent,
        NegotiationState::CapabilitiesSent
    );
    assert_ne!(initial, NegotiationState::Ready);
}
1207
#[test]
fn test_negotiation_session_new() {
    let caps = Capabilities::v0_5();
    let mappings = HashMap::from([(1, "user_id".to_string())]);

    let session = NegotiationSession::new(123, caps.clone(), caps.clone(), mappings.clone());

    assert_eq!(session.session_id, 123);
    assert_eq!(session.local_caps, caps);
    assert_eq!(session.remote_caps, caps);
    assert_eq!(session.fid_mappings, mappings);
    // Both sides are full v0.5, so nested support survives the intersection.
    assert!(session.agreed_features.supports_nested);
}
1228
#[test]
fn test_schema_negotiator_new() {
    let expected_caps = Capabilities::v0_5();
    let negotiator = SchemaNegotiator::new(expected_caps.clone());

    // A fresh negotiator: nothing exchanged, nothing established.
    assert!(!negotiator.is_ready());
    assert!(negotiator.remote_capabilities().is_none());
    assert_eq!(negotiator.state(), &NegotiationState::Initial);
    assert_eq!(negotiator.local_capabilities(), &expected_caps);
}
1239
#[test]
fn test_schema_negotiator_v0_5() {
    let negotiator = SchemaNegotiator::v0_5();
    let caps = negotiator.local_capabilities();

    assert_eq!(caps.version, 0x05);
    assert!(caps.features.supports_nested);
}
1246
#[test]
fn test_schema_negotiator_v0_4() {
    let negotiator = SchemaNegotiator::v0_4();
    let caps = negotiator.local_capabilities();

    assert_eq!(caps.version, 0x04);
    assert!(!caps.features.supports_nested);
}
1253
#[test]
fn test_schema_negotiator_with_fid_mappings() {
    let mappings: HashMap<u16, String> = [(1, "user_id"), (2, "username")]
        .iter()
        .map(|&(fid, name)| (fid, name.to_string()))
        .collect();

    let negotiator = SchemaNegotiator::v0_5().with_fid_mappings(mappings.clone());
    // Reading the private field is fine: this test lives in the same module.
    assert_eq!(negotiator.fid_mappings, mappings);
}
1263
#[test]
fn test_schema_negotiator_initiate() {
    let mut negotiator = SchemaNegotiator::v0_5();
    let outcome = negotiator.initiate();

    assert!(outcome.is_ok());
    assert_eq!(negotiator.state(), &NegotiationState::CapabilitiesSent);

    if let NegotiationMessage::Capabilities { version, .. } = outcome.unwrap() {
        assert_eq!(version, 0x05);
    } else {
        panic!("Expected Capabilities message");
    }
}
1279
#[test]
fn test_schema_negotiator_initiate_invalid_state() {
    let mut negotiator = SchemaNegotiator::v0_5();
    negotiator.initiate().unwrap();

    // A second initiate() is rejected because the negotiator already left Initial.
    let second = negotiator.initiate();
    assert!(second.is_err());
    if !matches!(second, Err(NegotiationError::InvalidState { .. })) {
        panic!("Expected InvalidState error");
    }
}
1293
#[test]
fn test_schema_negotiator_handle_capabilities() {
    let mut negotiator = SchemaNegotiator::v0_5();
    let incoming = NegotiationMessage::Capabilities {
        version: 0x05,
        features: FeatureFlags::v0_5_full(),
        supported_types: vec![TypeTag::Int, TypeTag::String],
    };

    let reply = negotiator.handle_message(incoming);
    assert!(reply.is_ok());
    assert_eq!(negotiator.state(), &NegotiationState::CapabilitiesReceived);

    if let NegotiationResponse::SendMessage(NegotiationMessage::CapabilitiesAck {
        version, ..
    }) = reply.unwrap()
    {
        assert_eq!(version, 0x05);
    } else {
        panic!("Expected SendMessage with CapabilitiesAck");
    }
}
1317
/// A v0.5 negotiator rejects a v0.4 peer's `Capabilities` with `ProtocolVersionMismatch`,
/// reporting its own version as `local` and the peer's as `remote`.
#[test]
fn test_schema_negotiator_handle_capabilities_version_mismatch() {
    let mut negotiator = SchemaNegotiator::v0_5();

    let msg = NegotiationMessage::Capabilities {
        version: 0x04,
        features: FeatureFlags::v0_4_compatible(),
        supported_types: vec![TypeTag::Int],
    };

    // Match directly so an unexpected error variant (or Ok) is reported on failure.
    match negotiator.handle_message(msg) {
        Err(NegotiationError::ProtocolVersionMismatch { local, remote }) => {
            assert_eq!(local, 0x05);
            assert_eq!(remote, 0x04);
        }
        Err(other) => panic!("Expected ProtocolVersionMismatch error, got {:?}", other),
        Ok(_) => panic!("Expected ProtocolVersionMismatch error, got Ok"),
    }
}
1337}
1338
/// After `initiate()`, a v0.5 `CapabilitiesAck` from the peer makes the initiator send
/// `SelectSchema` and move to `SchemaSelected`.
#[test]
fn test_schema_negotiator_handle_capabilities_ack() {
    let mut negotiator = SchemaNegotiator::v0_5();
    negotiator.initiate().unwrap();

    let msg = NegotiationMessage::CapabilitiesAck {
        version: 0x05,
        features: FeatureFlags::v0_5_full(),
    };

    // `expect` surfaces the error's Debug output if handling fails.
    let response = negotiator
        .handle_message(msg)
        .expect("handle_message(CapabilitiesAck) should succeed after initiate()");
    assert_eq!(negotiator.state(), &NegotiationState::SchemaSelected);

    match response {
        NegotiationResponse::SendMessage(NegotiationMessage::SelectSchema { .. }) => {}
        _ => panic!("Expected SendMessage with SelectSchema"),
    }
}
1357}
1358
/// A responder in `CapabilitiesReceived` accepts `SelectSchema`, replies `Ready` with
/// session id 1, and moves to `SchemaSelected`.
#[test]
fn test_schema_negotiator_handle_select_schema() {
    let mut negotiator = SchemaNegotiator::v0_5();

    // Put the negotiator directly into the post-capabilities state instead of replaying the
    // handshake messages.
    negotiator.state = NegotiationState::CapabilitiesReceived;
    negotiator.remote_capabilities = Some(Capabilities::v0_5());

    let mut mappings = HashMap::new();
    mappings.insert(1, "user_id".to_string());

    let msg = NegotiationMessage::SelectSchema {
        schema_id: "test_schema".to_string(),
        fid_mappings: mappings,
    };

    // `expect` surfaces the error's Debug output if handling fails.
    let response = negotiator
        .handle_message(msg)
        .expect("handle_message(SelectSchema) should succeed");
    assert_eq!(negotiator.state(), &NegotiationState::SchemaSelected);

    match response {
        NegotiationResponse::SendMessage(NegotiationMessage::Ready { session_id }) => {
            assert_eq!(session_id, 1);
        }
        _ => panic!("Expected SendMessage with Ready"),
    }
}
1385}
1386
/// A `SelectSchema` whose mapping gives a FID a different name than the local mapping fails
/// with `FidConflict`, naming the local field as `name1` and the remote field as `name2`.
#[test]
fn test_schema_negotiator_handle_select_schema_fid_conflict() {
    let mut local_mappings = HashMap::new();
    local_mappings.insert(1, "user_id".to_string());

    let mut negotiator = SchemaNegotiator::v0_5().with_fid_mappings(local_mappings);

    // Skip the handshake and jump straight to the state in which `SelectSchema` is handled.
    negotiator.state = NegotiationState::CapabilitiesReceived;
    negotiator.remote_capabilities = Some(Capabilities::v0_5());

    // Same FID 1 as the local mapping, but a different field name.
    let mut remote_mappings = HashMap::new();
    remote_mappings.insert(1, "username".to_string());

    let msg = NegotiationMessage::SelectSchema {
        schema_id: "test_schema".to_string(),
        fid_mappings: remote_mappings,
    };

    // Match directly so an unexpected error variant (or Ok) is reported on failure.
    match negotiator.handle_message(msg) {
        Err(NegotiationError::FidConflict { fid, name1, name2 }) => {
            assert_eq!(fid, 1);
            assert_eq!(name1, "user_id");
            assert_eq!(name2, "username");
        }
        Err(other) => panic!("Expected FidConflict error, got {:?}", other),
        Ok(_) => panic!("Expected FidConflict error, got Ok"),
    }
}
1416}
1417
/// A negotiator in `SchemaSelected` that receives `Ready` completes the handshake: it moves to
/// `Ready` and returns a session carrying the peer-supplied session id.
#[test]
fn test_schema_negotiator_handle_ready() {
    let mut negotiator = SchemaNegotiator::v0_5();

    // Put the negotiator directly into the state that precedes `Ready`.
    negotiator.state = NegotiationState::SchemaSelected;
    negotiator.remote_capabilities = Some(Capabilities::v0_5());

    let msg = NegotiationMessage::Ready { session_id: 42 };

    // `expect` surfaces the error's Debug output if handling fails.
    let response = negotiator
        .handle_message(msg)
        .expect("handle_message(Ready) should succeed from SchemaSelected");
    assert_eq!(negotiator.state(), &NegotiationState::Ready);
    assert!(negotiator.is_ready());

    match response {
        NegotiationResponse::Complete(session) => {
            assert_eq!(session.session_id, 42);
        }
        _ => panic!("Expected Complete response"),
    }
}
1439}
1440
/// A peer `Error` message is returned as `Ok(NegotiationResponse::Failed(..))`, not as `Err`.
/// It also moves the negotiator to `Failed`, carrying the peer's message text.
#[test]
fn test_schema_negotiator_handle_error() {
    let mut negotiator = SchemaNegotiator::v0_5();

    let msg = NegotiationMessage::Error {
        code: ErrorCode::Generic,
        message: "Test error".to_string(),
    };

    // `expect` surfaces the error's Debug output if handling unexpectedly returns Err.
    let response = negotiator
        .handle_message(msg)
        .expect("a peer Error message should be handled, not returned as Err");

    // `reason` avoids shadowing the outer `msg` binding.
    match negotiator.state() {
        NegotiationState::Failed(reason) => {
            assert_eq!(reason, "Test error");
        }
        _ => panic!("Expected Failed state"),
    }

    match response {
        NegotiationResponse::Failed(reason) => {
            assert_eq!(reason, "Test error");
        }
        _ => panic!("Expected Failed response"),
    }
}
1466}
1467
/// The `Display` text of `NegotiationError::FidConflict` mentions the FID and both field names.
#[test]
fn test_negotiation_error_display() {
    let err = NegotiationError::FidConflict {
        fid: 7,
        name1: "field_a".to_string(),
        name2: "field_b".to_string(),
    };
    // `to_string()` goes through `Display`, exactly like `format!("{}", err)`.
    let rendered = err.to_string();

    for needle in &["FID 7", "field_a", "field_b"] {
        assert!(
            rendered.contains(*needle),
            "expected {:?} in Display output {:?}",
            needle,
            rendered
        );
    }
}
1479}
1480
/// End-to-end handshake between two v0.5 negotiators sharing the same FID mappings:
/// Capabilities -> CapabilitiesAck -> SelectSchema -> Ready, with the client ending in `Ready`.
#[test]
fn test_full_negotiation_flow_client_initiated() {
    let mut client_mappings = HashMap::new();
    client_mappings.insert(1, "user_id".to_string());

    // Both peers use identical FID mappings, so the SelectSchema step should not hit a
    // FidConflict. The last use moves the map instead of cloning it again.
    let mut client = SchemaNegotiator::v0_5().with_fid_mappings(client_mappings.clone());
    let mut server = SchemaNegotiator::v0_5().with_fid_mappings(client_mappings);

    // Step 1: client -> server, Capabilities.
    let caps_msg = client.initiate().expect("client initiate() should succeed");
    assert_eq!(client.state(), &NegotiationState::CapabilitiesSent);

    // Step 2: server handles Capabilities and answers with CapabilitiesAck.
    let server_response = server
        .handle_message(caps_msg)
        .expect("server should accept the client's Capabilities");
    assert_eq!(server.state(), &NegotiationState::CapabilitiesReceived);

    let ack_msg = match server_response {
        NegotiationResponse::SendMessage(msg) => msg,
        _ => panic!("Expected SendMessage"),
    };

    // Step 3: client handles CapabilitiesAck and answers with SelectSchema.
    let client_response = client
        .handle_message(ack_msg)
        .expect("client should accept the server's CapabilitiesAck");
    assert_eq!(client.state(), &NegotiationState::SchemaSelected);

    let select_msg = match client_response {
        NegotiationResponse::SendMessage(msg) => msg,
        _ => panic!("Expected SendMessage"),
    };

    // Step 4: server handles SelectSchema and answers with Ready.
    let server_response = server
        .handle_message(select_msg)
        .expect("server should accept the client's SelectSchema");
    assert_eq!(server.state(), &NegotiationState::SchemaSelected);

    let ready_msg = match server_response {
        NegotiationResponse::SendMessage(msg) => msg,
        _ => panic!("Expected SendMessage"),
    };

    // Step 5: client handles Ready and completes the session.
    let client_response = client
        .handle_message(ready_msg)
        .expect("client should accept the server's Ready");
    assert_eq!(client.state(), &NegotiationState::Ready);
    assert!(client.is_ready());

    match client_response {
        NegotiationResponse::Complete(session) => {
            assert_eq!(session.session_id, 1);
            assert!(session.agreed_features.supports_nested);
        }
        _ => panic!("Expected Complete response"),
    }
}
1536}