hightower_wireguard/
lib.rs

1use thiserror::Error;
2
3/// Cryptographic primitives for WireGuard protocol
4pub mod crypto;
5/// Handshake initiator implementation
6pub mod initiator;
7/// WireGuard message structures
8pub mod messages;
9/// High-level protocol implementation
10pub mod protocol;
11/// Replay protection implementation
12pub mod replay;
13/// Handshake responder implementation
14pub mod responder;
15/// WireGuard connection implementation (UDP-based Connection and Stream)
16#[cfg(feature = "transport")]
17pub mod connection;
18
19/// Errors that can occur during WireGuard protocol operations
20#[derive(Error, Debug)]
21pub enum WireGuardError {
22    #[error("Cryptographic error: {0}")]
23    CryptoError(String),
24    #[error("Protocol error: {0}")]
25    ProtocolError(String),
26    #[error("Invalid key length")]
27    InvalidKeyLength,
28    #[error("Authentication failed")]
29    AuthenticationFailed,
30}
31
32pub type Result<T> = std::result::Result<T, WireGuardError>;
33
#[cfg(test)]
mod tests {
    use super::*;
    use crypto::{aead_decrypt, aead_encrypt, dh_generate};
    use protocol::{PeerInfo, WireGuardProtocol};

    /// Runs a handshake without a preshared key, then sends one transport
    /// message in each direction using the derived session keys.
    #[test]
    fn test_full_handshake_and_message_exchange() {
        // Static key pairs for both peers.
        let (alice_private, alice_public) = dh_generate();
        let (bob_private, bob_public) = dh_generate();

        let mut alice = WireGuardProtocol::new(Some(alice_private));
        let mut bob = WireGuardProtocol::new(Some(bob_private));

        // Each side registers the other's public key before handshaking.
        alice.add_peer(PeerInfo {
            public_key: bob_public,
            preshared_key: None,
            endpoint: None,
            allowed_ips: Vec::new(),
            persistent_keepalive: None,
        });
        bob.add_peer(PeerInfo {
            public_key: alice_public,
            preshared_key: None,
            endpoint: None,
            allowed_ips: Vec::new(),
            persistent_keepalive: None,
        });

        // Initiation (Alice -> Bob), response (Bob -> Alice), completion.
        let initiation = alice
            .initiate_handshake(&bob_public)
            .expect("Alice should build a handshake initiation");
        let response = bob
            .process_initiation(&initiation)
            .expect("Bob should accept Alice's initiation");
        let peer_key = alice
            .process_response(&response)
            .expect("Alice should accept Bob's response");
        assert_eq!(peer_key, bob_public);

        // Both sides look up their session by the response's sender index.
        // Key material is deliberately not printed to test output.
        let alice_session = alice
            .get_session(response.sender)
            .expect("Alice should have a session after the handshake");
        let bob_session = bob
            .get_session(response.sender)
            .expect("Bob should have a session after the handshake");

        // Alice -> Bob: Alice's send key must pair with Bob's receive key.
        let message = b"Hello from Alice to Bob!";
        let counter = 0u64;
        let encrypted = aead_encrypt(&alice_session.keys.send_key, counter, message, &[])
            .expect("encryption with Alice's send key should succeed");
        let decrypted = aead_decrypt(&bob_session.keys.recv_key, counter, &encrypted, &[])
            .expect("Bob's receive key should decrypt Alice's message");
        assert_eq!(decrypted, message);

        // Bob -> Alice: the reverse key pair; this direction starts its own
        // counter at 0.
        let reply = b"Hello back from Bob to Alice!";
        let reply_counter = 0u64;
        let encrypted_reply = aead_encrypt(&bob_session.keys.send_key, reply_counter, reply, &[])
            .expect("encryption with Bob's send key should succeed");
        let decrypted_reply = aead_decrypt(
            &alice_session.keys.recv_key,
            reply_counter,
            &encrypted_reply,
            &[],
        )
        .expect("Alice's receive key should decrypt Bob's reply");
        assert_eq!(decrypted_reply, reply);
    }

    /// Runs the handshake with the same preshared key configured on both peers
    /// and checks that the resulting transport keys actually interoperate.
    #[test]
    fn test_handshake_with_preshared_key() {
        let psk = [42u8; 32]; // Shared preshared key

        let (alice_private, alice_public) = dh_generate();
        let (bob_private, bob_public) = dh_generate();

        let mut alice = WireGuardProtocol::new(Some(alice_private));
        let mut bob = WireGuardProtocol::new(Some(bob_private));

        // Add peers with PSK
        alice.add_peer(PeerInfo {
            public_key: bob_public,
            preshared_key: Some(psk),
            endpoint: None,
            allowed_ips: Vec::new(),
            persistent_keepalive: None,
        });
        bob.add_peer(PeerInfo {
            public_key: alice_public,
            preshared_key: Some(psk),
            endpoint: None,
            allowed_ips: Vec::new(),
            persistent_keepalive: None,
        });

        // Perform handshake
        let initiation = alice
            .initiate_handshake(&bob_public)
            .expect("Alice should build a PSK handshake initiation");
        let response = bob
            .process_initiation(&initiation)
            .expect("Bob should accept Alice's PSK initiation");
        let peer_key = alice
            .process_response(&response)
            .expect("Alice should accept Bob's PSK response");
        assert_eq!(peer_key, bob_public);

        let alice_session = alice
            .get_session(response.sender)
            .expect("Alice should have a session after the PSK handshake");
        let bob_session = bob
            .get_session(response.sender)
            .expect("Bob should have a session after the PSK handshake");

        // Session existence alone does not prove both sides derived the same
        // keys; round-trip a message to confirm they pair up.
        let message = b"Hello over a PSK session!";
        let counter = 0u64;
        let encrypted = aead_encrypt(&alice_session.keys.send_key, counter, message, &[])
            .expect("encryption with Alice's PSK send key should succeed");
        let decrypted = aead_decrypt(&bob_session.keys.recv_key, counter, &encrypted, &[])
            .expect("Bob's PSK receive key should decrypt Alice's message");
        assert_eq!(decrypted, message);
    }

    /// Exercises `check_replay` on Bob's side of an established session:
    /// fresh counters are accepted, repeats are rejected, and counters that
    /// fall far enough behind the highest seen one are rejected.
    #[test]
    fn test_replay_protection() {
        // Generate keys for both peers
        let (alice_private, alice_public) = dh_generate();
        let (bob_private, bob_public) = dh_generate();

        // Create protocol instances
        let mut alice = WireGuardProtocol::new(Some(alice_private));
        let mut bob = WireGuardProtocol::new(Some(bob_private));

        // Add each other as peers
        alice.add_peer(PeerInfo {
            public_key: bob_public,
            preshared_key: None,
            endpoint: None,
            allowed_ips: Vec::new(),
            persistent_keepalive: None,
        });
        bob.add_peer(PeerInfo {
            public_key: alice_public,
            preshared_key: None,
            endpoint: None,
            allowed_ips: Vec::new(),
            persistent_keepalive: None,
        });

        // Complete handshake
        let initiation = alice
            .initiate_handshake(&bob_public)
            .expect("Alice should build a handshake initiation");
        let response = bob
            .process_initiation(&initiation)
            .expect("Bob should accept Alice's initiation");
        alice
            .process_response(&response)
            .expect("Alice should accept Bob's response");

        let session_id = response.sender;

        // Test sequential counters
        assert!(bob.check_replay(session_id, 1).is_ok());
        assert!(bob.check_replay(session_id, 2).is_ok());
        assert!(bob.check_replay(session_id, 3).is_ok());

        // Test replay detection
        assert!(bob.check_replay(session_id, 2).is_err()); // Replay of counter 2
        assert!(bob.check_replay(session_id, 1).is_err()); // Replay of counter 1

        // Test out-of-order within window
        assert!(bob.check_replay(session_id, 10).is_ok());
        assert!(bob.check_replay(session_id, 5).is_ok());
        assert!(bob.check_replay(session_id, 8).is_ok());

        // Test replay of out-of-order packets
        assert!(bob.check_replay(session_id, 5).is_err());
        assert!(bob.check_replay(session_id, 8).is_err());

        // Test counter too old (outside window).
        // NOTE(review): this assumes the replay window is narrower than
        // 3000 - 500 = 2500 counters; confirm against the `replay` module
        // if its window size changes.
        assert!(bob.check_replay(session_id, 3000).is_ok());
        assert!(bob.check_replay(session_id, 500).is_err()); // Too old
    }
}