Skip to main content

guts_compat/
ssh_key.rs

1//! SSH key management types.
2
3use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use crate::error::{CompatError, Result};
8use crate::user::UserId;
9
/// Unique identifier for an SSH key.
///
/// Supplied by the caller when constructing an [`SshKey`] with [`SshKey::new`].
pub type SshKeyId = u64;
12
/// An SSH public key registered to a user for authentication.
///
/// [`SshKey::new`] validates the key text and derives `key_type` and
/// `fingerprint` from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SshKey {
    /// Unique key ID.
    pub id: SshKeyId,
    /// User who owns this key.
    pub user_id: UserId,
    /// User-provided title/name.
    pub title: String,
    /// Key type (ed25519, rsa, ecdsa).
    pub key_type: SshKeyType,
    /// Full public key line as supplied, e.g. `ssh-ed25519 AAAA... comment`.
    pub public_key: String,
    /// SHA256 fingerprint of the decoded key blob, formatted `SHA256:<unpadded base64>`.
    pub fingerprint: String,
    /// When the key was added, in whole seconds since the Unix epoch.
    pub created_at: u64,
    /// Last time the key was used, in whole seconds since the Unix epoch;
    /// `None` until [`SshKey::touch`] is called.
    pub last_used_at: Option<u64>,
}
33
34impl SshKey {
35    /// Create a new SSH key from a public key string.
36    pub fn new(id: SshKeyId, user_id: UserId, title: String, public_key: String) -> Result<Self> {
37        let (key_type, fingerprint) = parse_and_validate_key(&public_key)?;
38
39        let now = SystemTime::now()
40            .duration_since(UNIX_EPOCH)
41            .unwrap()
42            .as_secs();
43
44        Ok(Self {
45            id,
46            user_id,
47            title,
48            key_type,
49            public_key,
50            fingerprint,
51            created_at: now,
52            last_used_at: None,
53        })
54    }
55
56    /// Update the last_used_at timestamp.
57    pub fn touch(&mut self) {
58        self.last_used_at = Some(
59            SystemTime::now()
60                .duration_since(UNIX_EPOCH)
61                .unwrap()
62                .as_secs(),
63        );
64    }
65
66    /// Convert to API response.
67    pub fn to_response(&self) -> SshKeyResponse {
68        SshKeyResponse {
69            id: self.id,
70            title: self.title.clone(),
71            key_type: self.key_type,
72            key: self.public_key.clone(),
73            fingerprint: self.fingerprint.clone(),
74            created_at: format_timestamp(self.created_at),
75            last_used_at: self.last_used_at.map(format_timestamp),
76        }
77    }
78}
79
80/// SSH key type.
81#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
82#[serde(rename_all = "lowercase")]
83pub enum SshKeyType {
84    /// Ed25519 key (preferred).
85    Ed25519,
86    /// RSA key.
87    Rsa,
88    /// ECDSA key.
89    Ecdsa,
90}
91
92impl std::fmt::Display for SshKeyType {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        match self {
95            Self::Ed25519 => write!(f, "ssh-ed25519"),
96            Self::Rsa => write!(f, "ssh-rsa"),
97            Self::Ecdsa => write!(f, "ecdsa-sha2-nistp256"),
98        }
99    }
100}
101
102/// Parse and validate an SSH public key string.
103///
104/// Returns the key type and SHA256 fingerprint.
105fn parse_and_validate_key(key: &str) -> Result<(SshKeyType, String)> {
106    let parts: Vec<&str> = key.split_whitespace().collect();
107
108    if parts.len() < 2 {
109        return Err(CompatError::InvalidSshKey(
110            "key must have at least type and data parts".to_string(),
111        ));
112    }
113
114    let key_type = match parts[0] {
115        "ssh-ed25519" => SshKeyType::Ed25519,
116        "ssh-rsa" => SshKeyType::Rsa,
117        "ecdsa-sha2-nistp256" | "ecdsa-sha2-nistp384" | "ecdsa-sha2-nistp521" => SshKeyType::Ecdsa,
118        other => {
119            return Err(CompatError::InvalidSshKey(format!(
120                "unsupported key type: {}",
121                other
122            )));
123        }
124    };
125
126    // Validate base64 encoding and calculate fingerprint
127    let key_data = parts[1];
128    let decoded = base64_decode(key_data)
129        .map_err(|e| CompatError::InvalidSshKey(format!("invalid base64 encoding: {}", e)))?;
130
131    // Validate key data has correct structure
132    validate_key_data(&decoded, key_type)?;
133
134    // Calculate SHA256 fingerprint
135    let fingerprint = calculate_fingerprint(&decoded);
136
137    Ok((key_type, fingerprint))
138}
139
/// Decode standard-alphabet base64 (RFC 4648 §4) with `=` padding.
///
/// Leading/trailing whitespace is trimmed and embedded `\r`/`\n` are ignored,
/// so line-wrapped input is accepted. Padding is only legal as one or two `=`
/// at the very end of the input.
///
/// # Errors
///
/// Returns a static description when the input is empty, its length (after
/// removing line breaks) is not a multiple of four, it contains a character
/// outside the base64 alphabet, or its padding is misplaced or excessive.
fn base64_decode(input: &str) -> std::result::Result<Vec<u8>, &'static str> {
    /// Map one base64 alphabet character to its 6-bit value.
    fn sextet(c: u8) -> std::result::Result<u8, &'static str> {
        match c {
            b'A'..=b'Z' => Ok(c - b'A'),
            b'a'..=b'z' => Ok(c - b'a' + 26),
            b'0'..=b'9' => Ok(c - b'0' + 52),
            b'+' => Ok(62),
            b'/' => Ok(63),
            _ => Err("invalid base64 character"),
        }
    }

    let input = input.trim();
    if input.is_empty() {
        return Err("empty input");
    }

    // Line breaks inside wrapped input are not part of the encoding.
    let bytes: Vec<u8> = input
        .bytes()
        .filter(|&b| b != b'\n' && b != b'\r')
        .collect();

    if !bytes.len().is_multiple_of(4) {
        return Err("invalid base64 length");
    }

    // Only the trailing run of '=' is padding; it may be at most two long,
    // and any '=' before it is rejected in the loop below.
    let padding = bytes.iter().rev().take_while(|&&b| b == b'=').count();
    if padding > 2 {
        return Err("invalid base64 padding");
    }
    let payload = &bytes[..bytes.len() - padding];

    let mut out = Vec::with_capacity(payload.len() * 3 / 4);
    // Bit accumulator: each character adds 6 bits and a byte is emitted as
    // soon as 8 are available. Fewer than 8 bits left over at the end are the
    // zero-fill implied by padding and are dropped.
    let mut acc: u32 = 0;
    let mut bits: u32 = 0;
    for &c in payload {
        if c == b'=' {
            return Err("invalid base64 padding");
        }
        acc = (acc << 6) | u32::from(sextet(c)?);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            // `acc` holds exactly `bits + 8` significant bits here, so the
            // shifted value fits in a byte.
            out.push((acc >> bits) as u8);
            acc &= (1 << bits) - 1;
        }
    }

    Ok(out)
}
188
189/// Validate key data structure.
190fn validate_key_data(data: &[u8], key_type: SshKeyType) -> Result<()> {
191    if data.len() < 4 {
192        return Err(CompatError::InvalidSshKey("key data too short".to_string()));
193    }
194
195    // First 4 bytes are length of key type string
196    let type_len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
197
198    if data.len() < 4 + type_len {
199        return Err(CompatError::InvalidSshKey("key data truncated".to_string()));
200    }
201
202    // Verify key type matches
203    let type_str = std::str::from_utf8(&data[4..4 + type_len])
204        .map_err(|_| CompatError::InvalidSshKey("invalid key type encoding".to_string()))?;
205
206    let expected_type = match key_type {
207        SshKeyType::Ed25519 => "ssh-ed25519",
208        SshKeyType::Rsa => "ssh-rsa",
209        SshKeyType::Ecdsa => {
210            // ECDSA can have different curve names
211            if !type_str.starts_with("ecdsa-sha2-") {
212                return Err(CompatError::InvalidSshKey(format!(
213                    "expected ecdsa key type, got: {}",
214                    type_str
215                )));
216            }
217            return Ok(());
218        }
219    };
220
221    if type_str != expected_type {
222        return Err(CompatError::InvalidSshKey(format!(
223            "key type mismatch: expected {}, got {}",
224            expected_type, type_str
225        )));
226    }
227
228    Ok(())
229}
230
231/// Calculate SHA256 fingerprint of key data.
232fn calculate_fingerprint(data: &[u8]) -> String {
233    let hash = Sha256::digest(data);
234    format!("SHA256:{}", base64_encode(&hash))
235}
236
/// Encode bytes as standard-alphabet base64 (RFC 4648 §4) without `=`
/// padding, the form used after `SHA256:` by `calculate_fingerprint`.
///
/// Each 3-byte group becomes 4 characters; a final group of 1 or 2 bytes
/// becomes 2 or 3 characters respectively.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);

    for group in data.chunks(3) {
        // Pack the group into the top 24 bits of a word, zero-filling any
        // missing trailing bytes.
        let word = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | (u32::from(byte) << (16 - 8 * i)));

        // n input bytes carry 8n bits, which need n + 1 six-bit symbols.
        for k in 0..=group.len() {
            let index = (word >> (18 - 6 * k)) & 0x3f;
            encoded.push(char::from(ALPHABET[index as usize]));
        }
    }

    encoded
}
261
/// SSH key representation returned by the API.
///
/// Built by [`SshKey::to_response`], which renders timestamps as ISO 8601 UTC
/// strings of the form `YYYY-MM-DDTHH:MM:SSZ`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SshKeyResponse {
    /// Key ID.
    pub id: SshKeyId,
    /// User-provided title.
    pub title: String,
    /// Key type.
    pub key_type: SshKeyType,
    /// Full public key line.
    pub key: String,
    /// SHA256 fingerprint (`SHA256:<unpadded base64>`).
    pub fingerprint: String,
    /// Creation timestamp.
    pub created_at: String,
    /// Last used timestamp; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_used_at: Option<String>,
}
281
/// Request to add an SSH key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddSshKeyRequest {
    /// Title/name for the key.
    pub title: String,
    /// Full public key string, e.g. `ssh-ed25519 AAAA... comment`.
    pub key: String,
}
290
/// Format a Unix timestamp (whole seconds since 1970-01-01T00:00:00Z) as an
/// ISO 8601 UTC string of the form `YYYY-MM-DDTHH:MM:SSZ`.
///
/// Uses the proleptic Gregorian calendar and ignores leap seconds. The date
/// is found by stepping forward one whole year, then one whole month, at a
/// time.
fn format_timestamp(timestamp: u64) -> String {
    const SECS_PER_MIN: u64 = 60;
    const SECS_PER_HOUR: u64 = 60 * SECS_PER_MIN;
    const SECS_PER_DAY: u64 = 24 * SECS_PER_HOUR;

    let time_of_day = timestamp % SECS_PER_DAY;
    let hours = time_of_day / SECS_PER_HOUR;
    let minutes = time_of_day % SECS_PER_HOUR / SECS_PER_MIN;
    let seconds = time_of_day % SECS_PER_MIN;

    // Zero-based day offset, consumed first by whole years, then whole months.
    let mut day_offset = timestamp / SECS_PER_DAY;

    let days_in_year = |y: u64| -> u64 { if is_leap_year(y) { 366 } else { 365 } };
    let mut year: u64 = 1970;
    while day_offset >= days_in_year(year) {
        day_offset -= days_in_year(year);
        year += 1;
    }

    let february: u64 = if is_leap_year(year) { 29 } else { 28 };
    let month_lengths = [31, february, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    let mut month: u64 = 1;
    for len in month_lengths {
        if day_offset < len {
            break;
        }
        day_offset -= len;
        month += 1;
    }

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year,
        month,
        day_offset + 1,
        hours,
        minutes,
        seconds
    )
}

/// Whether `year` is a leap year in the Gregorian calendar.
fn is_leap_year(year: u64) -> bool {
    year.is_multiple_of(400) || (year.is_multiple_of(4) && !year.is_multiple_of(100))
}
339
#[cfg(test)]
mod tests {
    use super::*;

    // A valid Ed25519 public key for testing
    const TEST_ED25519_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOMqqnkVzrm0SdG6UOoqKLsabgH5C9okWi0dh2l9GKJl test@example.com";

    #[test]
    fn test_parse_ed25519_key() {
        let result = parse_and_validate_key(TEST_ED25519_KEY);
        assert!(result.is_ok());

        let (key_type, fingerprint) = result.unwrap();
        assert_eq!(key_type, SshKeyType::Ed25519);
        assert!(fingerprint.starts_with("SHA256:"));
    }

    #[test]
    fn test_fingerprint_ignores_comment() {
        let (_, with_comment) = parse_and_validate_key(TEST_ED25519_KEY).unwrap();
        let (bare_key, _) = TEST_ED25519_KEY.rsplit_once(' ').unwrap();
        let (_, without_comment) = parse_and_validate_key(bare_key).unwrap();
        assert_eq!(with_comment, without_comment);
        // A SHA-256 digest is 32 bytes, i.e. 43 characters of unpadded base64.
        assert_eq!(with_comment.len(), "SHA256:".len() + 43);
    }

    #[test]
    fn test_invalid_key_format() {
        assert!(parse_and_validate_key("").is_err());
        assert!(parse_and_validate_key("invalid").is_err());
        assert!(parse_and_validate_key("unknown-type AAAAB3NzaC1").is_err());
    }

    #[test]
    fn test_rejects_malformed_key_data() {
        // Not base64 at all.
        assert!(parse_and_validate_key("ssh-ed25519 ****").is_err());
        // Decodes to three bytes: too short for the key type length prefix.
        assert!(parse_and_validate_key("ssh-ed25519 AAAA").is_err());
        // Ed25519 key data presented under the RSA algorithm name.
        let mislabelled = TEST_ED25519_KEY.replacen("ssh-ed25519", "ssh-rsa", 1);
        assert!(parse_and_validate_key(&mislabelled).is_err());
    }

    #[test]
    fn test_ssh_key_creation() {
        let key = SshKey::new(1, 1, "My Key".to_string(), TEST_ED25519_KEY.to_string()).unwrap();

        assert_eq!(key.id, 1);
        assert_eq!(key.user_id, 1);
        assert_eq!(key.title, "My Key");
        assert_eq!(key.key_type, SshKeyType::Ed25519);
        assert_eq!(key.public_key, TEST_ED25519_KEY);
        assert!(key.fingerprint.starts_with("SHA256:"));
        assert!(key.last_used_at.is_none());
    }

    #[test]
    fn test_touch_sets_last_used() {
        let mut key =
            SshKey::new(1, 1, "My Key".to_string(), TEST_ED25519_KEY.to_string()).unwrap();
        key.touch();
        assert!(key.last_used_at.is_some());
    }

    #[test]
    fn test_ssh_key_response() {
        let key = SshKey::new(1, 1, "My Key".to_string(), TEST_ED25519_KEY.to_string()).unwrap();
        let response = key.to_response();

        assert_eq!(response.id, 1);
        assert_eq!(response.title, "My Key");
        assert_eq!(response.key_type, SshKeyType::Ed25519);
        assert_eq!(response.key, TEST_ED25519_KEY);
        assert_eq!(response.fingerprint, key.fingerprint);
        assert_eq!(response.created_at, format_timestamp(key.created_at));
        assert!(response.last_used_at.is_none());
    }

    #[test]
    fn test_key_type_display() {
        assert_eq!(SshKeyType::Ed25519.to_string(), "ssh-ed25519");
        assert_eq!(SshKeyType::Rsa.to_string(), "ssh-rsa");
        assert_eq!(SshKeyType::Ecdsa.to_string(), "ecdsa-sha2-nistp256");
    }

    #[test]
    fn test_format_timestamp() {
        assert_eq!(format_timestamp(0), "1970-01-01T00:00:00Z");
        assert_eq!(format_timestamp(951_827_696), "2000-02-29T12:34:56Z");
        assert_eq!(format_timestamp(1_700_000_000), "2023-11-14T22:13:20Z");
        assert_eq!(format_timestamp(1_704_067_199), "2023-12-31T23:59:59Z");
    }

    #[test]
    fn test_is_leap_year() {
        assert!(is_leap_year(2000));
        assert!(is_leap_year(2024));
        assert!(!is_leap_year(1900));
        assert!(!is_leap_year(2023));
    }

    #[test]
    fn test_base64_encode_rfc4648_vectors() {
        // RFC 4648 §10 test vectors, with padding stripped.
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(b"f"), "Zg");
        assert_eq!(base64_encode(b"fo"), "Zm8");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_encode(b"foob"), "Zm9vYg");
        assert_eq!(base64_encode(b"fooba"), "Zm9vYmE");
        assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
    }

    #[test]
    fn test_base64_decode_vectors() {
        assert_eq!(base64_decode("Zg==").unwrap(), b"f".to_vec());
        assert_eq!(base64_decode("Zm9vYmE=").unwrap(), b"fooba".to_vec());
        assert_eq!(base64_decode("Zm9vYmFy").unwrap(), b"foobar".to_vec());
        assert!(base64_decode("").is_err());
        assert!(base64_decode("Zm9").is_err());
        assert!(base64_decode("Zm9!").is_err());
    }

    #[test]
    fn test_base64_roundtrip() {
        let data = b"Hello, World!";
        let encoded = base64_encode(data);
        // Note: Our encode doesn't add padding, so we add it for decode
        let padded = format!(
            "{}{}",
            encoded,
            match data.len() % 3 {
                1 => "==",
                2 => "=",
                _ => "",
            }
        );
        let decoded = base64_decode(&padded).unwrap();
        assert_eq!(decoded, data);
    }
}