// nomad_protocol/client/bootstrap.rs
use std::net::SocketAddr;

use thiserror::Error;

12#[derive(Debug, Error)]
14pub enum BootstrapError {
15 #[error("invalid server public key: {0}")]
17 InvalidServerKey(String),
18
19 #[error("key generation failed: {0}")]
21 KeyGenerationFailed(String),
22
23 #[error("invalid key format: {0}")]
25 InvalidKeyFormat(String),
26}
27
/// A client's 32-byte key pair.
///
/// `Clone` duplicates the private key bytes. The hand-written `Debug` impl prints a
/// `[REDACTED]` placeholder in place of the private key.
#[derive(Clone)]
pub struct ClientIdentity {
    /// Secret key bytes; exposed only through the `private_key()` accessor.
    private_key: [u8; 32],
    /// Public key bytes, as set by the constructors.
    public_key: [u8; 32],
}

37impl ClientIdentity {
38 pub fn generate() -> Result<Self, BootstrapError> {
40 use std::time::{SystemTime, UNIX_EPOCH};
43
44 let seed = SystemTime::now()
45 .duration_since(UNIX_EPOCH)
46 .unwrap()
47 .as_nanos() as u64;
48
49 let mut private_key = [0u8; 32];
51 let mut state = seed;
52 for byte in &mut private_key {
53 state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
54 *byte = (state >> 33) as u8;
55 }
56
57 let public_key = private_key; Ok(Self {
61 private_key,
62 public_key,
63 })
64 }
65
66 pub fn from_private_key(private_key: [u8; 32]) -> Result<Self, BootstrapError> {
68 let public_key = private_key; Ok(Self {
72 private_key,
73 public_key,
74 })
75 }
76
77 pub fn private_key(&self) -> &[u8; 32] {
79 &self.private_key
80 }
81
82 pub fn public_key(&self) -> &[u8; 32] {
84 &self.public_key
85 }
86}
87
88impl std::fmt::Debug for ClientIdentity {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("ClientIdentity")
91 .field("public_key", &hex_preview(&self.public_key))
92 .field("private_key", &"[REDACTED]")
93 .finish()
94 }
95}
96
/// Where the client connects and which server public key it expects.
#[derive(Debug, Clone)]
pub struct ServerInfo {
    /// Socket address of the server.
    pub addr: SocketAddr,
    /// The server's 32-byte public key.
    pub public_key: [u8; 32],
}

106impl ServerInfo {
107 pub fn new(addr: SocketAddr, public_key: [u8; 32]) -> Self {
109 Self { addr, public_key }
110 }
111
112 pub fn from_base64_key(addr: SocketAddr, key_base64: &str) -> Result<Self, BootstrapError> {
114 let key_bytes = decode_base64(key_base64)
115 .map_err(|e| BootstrapError::InvalidServerKey(e.to_string()))?;
116
117 if key_bytes.len() != 32 {
118 return Err(BootstrapError::InvalidServerKey(format!(
119 "expected 32 bytes, got {}",
120 key_bytes.len()
121 )));
122 }
123
124 let mut public_key = [0u8; 32];
125 public_key.copy_from_slice(&key_bytes);
126
127 Ok(Self { addr, public_key })
128 }
129}
130
131#[derive(Debug, Clone)]
133pub struct BootstrapConfig {
134 pub identity: Option<ClientIdentity>,
136 pub server: ServerInfo,
138 pub state_type_id: String,
140 pub extensions: Vec<u16>,
142}
143
144#[derive(Debug)]
146pub struct BootstrapResult {
147 pub identity: ClientIdentity,
149 pub extensions: Vec<u16>,
151 pub session_id: [u8; 6],
153}
154
155pub fn prepare_bootstrap(config: BootstrapConfig) -> Result<BootstrapResult, BootstrapError> {
162 let identity = match config.identity {
164 Some(id) => id,
165 None => ClientIdentity::generate()?,
166 };
167
168 Ok(BootstrapResult {
172 identity,
173 extensions: config.extensions,
174 session_id: [0u8; 6], })
176}
177
/// Renders a short lowercase-hex preview of `bytes` for diagnostics.
///
/// Inputs of up to 8 bytes are rendered in full. Longer inputs show only their first
/// 4 bytes followed by `...`.
fn hex_preview(bytes: &[u8]) -> String {
    const SHOW_ALL_UP_TO: usize = 8;
    const PREFIX_BYTES: usize = 4;

    let encode = |chunk: &[u8]| -> String {
        chunk.iter().map(|byte| format!("{:02x}", byte)).collect()
    };

    if bytes.len() > SHOW_ALL_UP_TO {
        encode(&bytes[..PREFIX_BYTES]) + "..."
    } else {
        encode(bytes)
    }
}

191fn decode_base64(input: &str) -> Result<Vec<u8>, BootstrapError> {
192 const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
194
195 let input = input.trim().as_bytes();
196 let mut output = Vec::with_capacity(input.len() * 3 / 4);
197
198 let mut buffer = 0u32;
199 let mut bits = 0u32;
200
201 for &byte in input {
202 if byte == b'=' {
203 break;
204 }
205
206 let value = ALPHABET
207 .iter()
208 .position(|&c| c == byte)
209 .ok_or_else(|| BootstrapError::InvalidKeyFormat("invalid base64 character".into()))?
210 as u32;
211
212 buffer = (buffer << 6) | value;
213 bits += 6;
214
215 if bits >= 8 {
216 bits -= 8;
217 output.push((buffer >> bits) as u8);
218 buffer &= (1 << bits) - 1;
219 }
220 }
221
222 Ok(output)
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identities generated back to back must not share a private key.
    #[test]
    fn test_client_identity_generate() {
        let (first, second) = (
            ClientIdentity::generate().unwrap(),
            ClientIdentity::generate().unwrap(),
        );
        assert_ne!(first.private_key(), second.private_key());
    }

    /// Padded standard base64 decodes to the expected bytes.
    #[test]
    fn test_base64_decode() {
        let decoded = decode_base64("SGVsbG8=").unwrap();
        assert_eq!(decoded, b"Hello");
    }

    /// An all-zero 32-byte key round-trips through `ServerInfo::from_base64_key`.
    #[test]
    fn test_server_info_from_base64() {
        let addr = "127.0.0.1:19999".parse::<SocketAddr>().unwrap();
        let encoded = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";

        let server = ServerInfo::from_base64_key(addr, encoded).unwrap();
        assert_eq!(server.public_key, [0u8; 32]);
    }
}