1use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use crate::error::{CompatError, Result};
8use crate::user::UserId;
9
/// Numeric identifier of a stored SSH key record.
pub type SshKeyId = u64;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct SshKey {
16 pub id: SshKeyId,
18 pub user_id: UserId,
20 pub title: String,
22 pub key_type: SshKeyType,
24 pub public_key: String,
26 pub fingerprint: String,
28 pub created_at: u64,
30 pub last_used_at: Option<u64>,
32}
33
34impl SshKey {
35 pub fn new(id: SshKeyId, user_id: UserId, title: String, public_key: String) -> Result<Self> {
37 let (key_type, fingerprint) = parse_and_validate_key(&public_key)?;
38
39 let now = SystemTime::now()
40 .duration_since(UNIX_EPOCH)
41 .unwrap()
42 .as_secs();
43
44 Ok(Self {
45 id,
46 user_id,
47 title,
48 key_type,
49 public_key,
50 fingerprint,
51 created_at: now,
52 last_used_at: None,
53 })
54 }
55
56 pub fn touch(&mut self) {
58 self.last_used_at = Some(
59 SystemTime::now()
60 .duration_since(UNIX_EPOCH)
61 .unwrap()
62 .as_secs(),
63 );
64 }
65
66 pub fn to_response(&self) -> SshKeyResponse {
68 SshKeyResponse {
69 id: self.id,
70 title: self.title.clone(),
71 key_type: self.key_type,
72 key: self.public_key.clone(),
73 fingerprint: self.fingerprint.clone(),
74 created_at: format_timestamp(self.created_at),
75 last_used_at: self.last_used_at.map(format_timestamp),
76 }
77 }
78}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
82#[serde(rename_all = "lowercase")]
83pub enum SshKeyType {
84 Ed25519,
86 Rsa,
88 Ecdsa,
90}
91
92impl std::fmt::Display for SshKeyType {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 match self {
95 Self::Ed25519 => write!(f, "ssh-ed25519"),
96 Self::Rsa => write!(f, "ssh-rsa"),
97 Self::Ecdsa => write!(f, "ecdsa-sha2-nistp256"),
98 }
99 }
100}
101
102fn parse_and_validate_key(key: &str) -> Result<(SshKeyType, String)> {
106 let parts: Vec<&str> = key.split_whitespace().collect();
107
108 if parts.len() < 2 {
109 return Err(CompatError::InvalidSshKey(
110 "key must have at least type and data parts".to_string(),
111 ));
112 }
113
114 let key_type = match parts[0] {
115 "ssh-ed25519" => SshKeyType::Ed25519,
116 "ssh-rsa" => SshKeyType::Rsa,
117 "ecdsa-sha2-nistp256" | "ecdsa-sha2-nistp384" | "ecdsa-sha2-nistp521" => SshKeyType::Ecdsa,
118 other => {
119 return Err(CompatError::InvalidSshKey(format!(
120 "unsupported key type: {}",
121 other
122 )));
123 }
124 };
125
126 let key_data = parts[1];
128 let decoded = base64_decode(key_data)
129 .map_err(|e| CompatError::InvalidSshKey(format!("invalid base64 encoding: {}", e)))?;
130
131 validate_key_data(&decoded, key_type)?;
133
134 let fingerprint = calculate_fingerprint(&decoded);
136
137 Ok((key_type, fingerprint))
138}
139
/// Decodes standard-alphabet, padded base64 text into bytes.
///
/// Surrounding whitespace is trimmed and embedded `\n` / `\r` characters are
/// skipped. What remains must have a length that is a multiple of four, and `=`
/// may appear only as the final one or two characters.
///
/// # Errors
///
/// * `"empty input"` – nothing is left after trimming.
/// * `"invalid base64 length"` – the length is not a multiple of four.
/// * `"invalid base64 padding"` – more than two trailing `=`, or `=` anywhere else.
/// * `"invalid base64 character"` – a byte outside the base64 alphabet.
fn base64_decode(input: &str) -> std::result::Result<Vec<u8>, &'static str> {
    /// Maps one alphabet character to its 6-bit value.
    ///
    /// Trailing padding is stripped before this is called, so any `=` seen here
    /// is misplaced.
    fn sextet(c: u8) -> std::result::Result<u8, &'static str> {
        match c {
            b'A'..=b'Z' => Ok(c - b'A'),
            b'a'..=b'z' => Ok(c - b'a' + 26),
            b'0'..=b'9' => Ok(c - b'0' + 52),
            b'+' => Ok(62),
            b'/' => Ok(63),
            b'=' => Err("invalid base64 padding"),
            _ => Err("invalid base64 character"),
        }
    }

    let input = input.trim();
    if input.is_empty() {
        return Err("empty input");
    }

    let bytes: Vec<u8> = input
        .bytes()
        .filter(|&b| b != b'\n' && b != b'\r')
        .collect();

    if !bytes.chunks_exact(4).remainder().is_empty() {
        return Err("invalid base64 length");
    }

    // Padding is legal only as the last one or two characters of the input.
    let pad = bytes.iter().rev().take_while(|&&b| b == b'=').count();
    if pad > 2 {
        return Err("invalid base64 padding");
    }
    let payload = &bytes[..bytes.len() - pad];

    let mut result = Vec::with_capacity(bytes.len() / 4 * 3);

    // Every group holds four characters except possibly the last, which holds
    // two or three once the padding has been removed.
    for group in payload.chunks(4) {
        let mut values = [0u8; 4];
        for (slot, &c) in values.iter_mut().zip(group) {
            *slot = sextet(c)?;
        }
        let [a, b, c, d] = values;

        result.push((a << 2) | (b >> 4));
        if group.len() > 2 {
            result.push((b << 4) | (c >> 2));
        }
        if group.len() > 3 {
            result.push((c << 6) | d);
        }
    }

    Ok(result)
}
188
189fn validate_key_data(data: &[u8], key_type: SshKeyType) -> Result<()> {
191 if data.len() < 4 {
192 return Err(CompatError::InvalidSshKey("key data too short".to_string()));
193 }
194
195 let type_len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
197
198 if data.len() < 4 + type_len {
199 return Err(CompatError::InvalidSshKey("key data truncated".to_string()));
200 }
201
202 let type_str = std::str::from_utf8(&data[4..4 + type_len])
204 .map_err(|_| CompatError::InvalidSshKey("invalid key type encoding".to_string()))?;
205
206 let expected_type = match key_type {
207 SshKeyType::Ed25519 => "ssh-ed25519",
208 SshKeyType::Rsa => "ssh-rsa",
209 SshKeyType::Ecdsa => {
210 if !type_str.starts_with("ecdsa-sha2-") {
212 return Err(CompatError::InvalidSshKey(format!(
213 "expected ecdsa key type, got: {}",
214 type_str
215 )));
216 }
217 return Ok(());
218 }
219 };
220
221 if type_str != expected_type {
222 return Err(CompatError::InvalidSshKey(format!(
223 "key type mismatch: expected {}, got {}",
224 expected_type, type_str
225 )));
226 }
227
228 Ok(())
229}
230
231fn calculate_fingerprint(data: &[u8]) -> String {
233 let hash = Sha256::digest(data);
234 format!("SHA256:{}", base64_encode(&hash))
235}
236
/// Encodes bytes with the standard base64 alphabet, without `=` padding.
///
/// A final group of one or two input bytes yields two or three output
/// characters respectively; empty input yields an empty string.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);

    for group in data.chunks(3) {
        // Left-align the group in a 24-bit value; missing trailing bytes stay zero.
        let packed = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | (u32::from(byte) << (16 - 8 * i)));

        // n input bytes carry enough bits for n + 1 output characters.
        for position in 0..=group.len() {
            let index = (packed >> (18 - 6 * position)) & 0x3f;
            encoded.push(char::from(ALPHABET[index as usize]));
        }
    }

    encoded
}
261
262#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct SshKeyResponse {
265 pub id: SshKeyId,
267 pub title: String,
269 pub key_type: SshKeyType,
271 pub key: String,
273 pub fingerprint: String,
275 pub created_at: String,
277 #[serde(skip_serializing_if = "Option::is_none")]
279 pub last_used_at: Option<String>,
280}
281
282#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct AddSshKeyRequest {
285 pub title: String,
287 pub key: String,
289}
290
291fn format_timestamp(timestamp: u64) -> String {
293 let secs_per_day = 86400;
294 let secs_per_hour = 3600;
295 let secs_per_min = 60;
296
297 let mut days = timestamp / secs_per_day;
298 let remaining = timestamp % secs_per_day;
299 let hours = remaining / secs_per_hour;
300 let remaining = remaining % secs_per_hour;
301 let minutes = remaining / secs_per_min;
302 let seconds = remaining % secs_per_min;
303
304 let mut year = 1970;
305 loop {
306 let days_in_year = if is_leap_year(year) { 366 } else { 365 };
307 if days < days_in_year {
308 break;
309 }
310 days -= days_in_year;
311 year += 1;
312 }
313
314 let days_in_month = if is_leap_year(year) {
315 [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
316 } else {
317 [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
318 };
319
320 let mut month = 0;
321 for (i, &dim) in days_in_month.iter().enumerate() {
322 if days < dim as u64 {
323 month = i + 1;
324 break;
325 }
326 days -= dim as u64;
327 }
328 let day = days + 1;
329
330 format!(
331 "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
332 year, month, day, hours, minutes, seconds
333 )
334}
335
/// Reports whether `year` is a leap year under the Gregorian rules:
/// divisible by 4, except centuries, unless divisible by 400.
fn is_leap_year(year: u64) -> bool {
    if !year.is_multiple_of(4) {
        return false;
    }
    // Among multiples of four, only centuries not divisible by 400 are excluded.
    !year.is_multiple_of(100) || year.is_multiple_of(400)
}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Ed25519 public key line (type, base64 blob, comment) shared by the tests.
    const TEST_ED25519_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOMqqnkVzrm0SdG6UOoqKLsabgH5C9okWi0dh2l9GKJl test@example.com";

    /// Builds the key record used by the construction and response tests.
    fn sample_key() -> SshKey {
        SshKey::new(1, 1, "My Key".to_string(), TEST_ED25519_KEY.to_string()).unwrap()
    }

    #[test]
    fn test_parse_ed25519_key() {
        let parsed = parse_and_validate_key(TEST_ED25519_KEY);
        assert!(parsed.is_ok());

        let (detected_type, fp) = parsed.unwrap();
        assert_eq!(detected_type, SshKeyType::Ed25519);
        assert!(fp.starts_with("SHA256:"));
    }

    #[test]
    fn test_invalid_key_format() {
        for rejected in ["invalid", "unknown-type AAAAB3NzaC1"] {
            assert!(parse_and_validate_key(rejected).is_err());
        }
    }

    #[test]
    fn test_ssh_key_creation() {
        let key = sample_key();

        assert_eq!(key.id, 1);
        assert_eq!(key.user_id, 1);
        assert_eq!(key.title, "My Key");
        assert_eq!(key.key_type, SshKeyType::Ed25519);
        assert!(key.fingerprint.starts_with("SHA256:"));
    }

    #[test]
    fn test_ssh_key_response() {
        let response = sample_key().to_response();

        assert_eq!(response.id, 1);
        assert_eq!(response.title, "My Key");
        assert_eq!(response.key_type, SshKeyType::Ed25519);
    }

    #[test]
    fn test_key_type_display() {
        let expectations = [
            (SshKeyType::Ed25519, "ssh-ed25519"),
            (SshKeyType::Rsa, "ssh-rsa"),
            (SshKeyType::Ecdsa, "ecdsa-sha2-nistp256"),
        ];
        for (key_type, rendered) in expectations {
            assert_eq!(key_type.to_string(), rendered);
        }
    }

    #[test]
    fn test_base64_roundtrip() {
        let data = b"Hello, World!";

        // The encoder emits no padding, so restore it before decoding.
        let mut padded = base64_encode(data);
        padded.push_str(&"=".repeat((3 - data.len() % 3) % 3));

        let decoded = base64_decode(&padded).unwrap();
        assert_eq!(decoded, data);
    }
}
408}