metalssh 0.0.1

Experimental SSH implementation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
use std::io::prelude::*;
use std::net::TcpStream;
use std::time::Duration;

use bstr::ByteSlice;
use bytes::BytesMut;
use metalssh::constants::msg::SSH_MSG_KEX_ECDH_REPLY;
use metalssh::constants::msg::SSH_MSG_KEXINIT;
use metalssh::constants::msg::SSH_MSG_NEWKEYS;
use metalssh::crypto::cipher::Cipher;
use metalssh::crypto::cipher::chacha20poly1305::ChaCha20Poly1305;
use metalssh::crypto::hash::HashFn;
use metalssh::crypto::hash::Sha256;
use metalssh::crypto::kdf;
use metalssh::crypto::kex::Initiate;
use metalssh::crypto::kex::curve25519_sha256;
use metalssh::msg::kex_ecdh_init::KexEcdhInitBuilder;
use metalssh::msg::kex_ecdh_reply::KexEcdhReply;
use metalssh::msg::kexinit::Kexinit;
use metalssh::msg::kexinit::KexinitBuilder;
use metalssh::packet::Packet;
use metalssh::proto::SshReadable;
use metalssh::proto::SshWritable;
use scroll::Pread;

/// Mutable per-connection state shared by the receive loop in `main` and the
/// message handlers.
struct ClientState {
    /// Our identification string as sent on the wire, without the trailing
    /// CR LF; hashed as V_C.
    client_id: Vec<u8>,
    /// The server's identification line with surrounding ASCII whitespace
    /// trimmed; hashed as V_S.
    server_id: Vec<u8>,
    /// Payload of the KEXINIT we sent (I_C), recorded when answering the
    /// server's KEXINIT.
    client_kexinit: Option<Vec<u8>>,
    /// Raw payload of the server's KEXINIT (I_S), exactly as received.
    server_kexinit: Option<Vec<u8>>,
    /// Pending curve25519 exchange: created when KEX_ECDH_INIT is sent and
    /// taken when the server's KEX_ECDH_REPLY arrives.
    kex: Option<curve25519_sha256::Initiator>,
    /// Server-to-client cipher, installed once keys are derived from the
    /// KEX_ECDH_REPLY.
    cipher: Option<ChaCha20Poly1305>,
    /// Sequence number handed to the cipher for the next packet received
    /// from the server.
    sequence_number: u32,
    /// Set when the server's NEWKEYS is received (after ours was sent); from
    /// then on every incoming packet is decrypted.
    encryption_active: bool,
}

fn main() -> anyhow::Result<()> {
    let mut stream = TcpStream::connect("10.0.0.54:22")?;
    stream.set_read_timeout(Some(Duration::from_secs(5)))?;

    let mut recv_buffer = BytesMut::with_capacity(35000);
    recv_buffer.resize(35000, 0);

    // Read and handle SSH banner
    let n = stream.read(&mut recv_buffer)?;
    let banner = recv_buffer[..n].read_bytes_until(&mut 0, b'\n').unwrap();
    println!("read: {n}, banner len: {}", banner.len());

    let server_id = banner.trim_ascii().to_vec();
    println!("found server banner: '{}'", server_id.as_bstr());

    // Send client banner
    let client_id = b"SSH-2.0-MetalSSH_0.0.0".to_vec();
    stream.write_all(&client_id)?;
    stream.write_all(b"\r\n")?;

    // Initialize client state
    let mut state = ClientState {
        server_id,
        client_id,
        server_kexinit: None,
        client_kexinit: None,
        kex: None,
        cipher: None,
        encryption_active: false,
        sequence_number: 0,
    };

    // Main receive loop
    let mut packet_count = 0;
    loop {
        // Clear and reuse the buffer
        recv_buffer.clear();
        recv_buffer.resize(35000, 0);

        println!("\nWaiting for packet {}...", packet_count + 1);

        // Don't sleep before the first read attempt
        if packet_count > 0 {
            std::thread::sleep(Duration::from_millis(100));
        }

        let n = match stream.read(&mut recv_buffer) {
            Ok(0) => {
                println!("Connection closed by server (EOF)");
                break;
            }
            Ok(n) => {
                println!("Raw bytes received: {:02x?}", &recv_buffer[..n.min(50)]);
                n
            }
            Err(e)
                if e.kind() == std::io::ErrorKind::WouldBlock
                    || e.kind() == std::io::ErrorKind::TimedOut =>
            {
                println!("Read timeout - no more data from server");
                break;
            }
            Err(e) => {
                println!("Read error: {} (kind: {:?})", e, e.kind());
                return Err(e.into());
            }
        };

        packet_count += 1;
        println!("Packet {}: read {n} bytes", packet_count);

        // Process all packets in the buffer (server may pipeline multiple packets)
        let mut offset = 0;
        while offset < n {
            // Check if we have enough bytes for a packet length field
            if n - offset < 4 {
                println!(
                    "Not enough bytes for packet length, remaining: {}",
                    n - offset
                );
                break;
            }

            // Read packet length (may be encrypted for ChaCha20-Poly1305)
            let packet_len = if state.encryption_active {
                // For ChaCha20-Poly1305, the packet length is encrypted
                // We need to decrypt it first
                let cipher = state.cipher.as_ref().expect("Cipher not initialized");
                let encrypted_packet = Packet::new(&recv_buffer[offset..n], 16);
                cipher.decrypt_packet_length(&encrypted_packet, state.sequence_number)?
            } else {
                // No encryption yet, read length directly
                u32::from_be_bytes([
                    recv_buffer[offset],
                    recv_buffer[offset + 1],
                    recv_buffer[offset + 2],
                    recv_buffer[offset + 3],
                ])
            };

            let mac_len = if state.encryption_active { 16 } else { 0 };
            let total_packet_len = 4 + packet_len as usize + mac_len;

            // Check if we have the complete packet
            if offset + total_packet_len > n {
                println!(
                    "Incomplete packet, need {} more bytes",
                    offset + total_packet_len - n
                );
                break;
            }

            println!(
                "Processing packet at offset {}, length {} (encrypted: {})",
                offset, total_packet_len, state.encryption_active
            );

            // Dispatch this packet
            dispatch_message(
                &mut stream,
                &mut state,
                &recv_buffer[offset..offset + total_packet_len],
            )?;

            offset += total_packet_len;
        }
    }

    Ok(())
}

fn dispatch_message(
    stream: &mut TcpStream,
    state: &mut ClientState,
    data: &[u8],
) -> anyhow::Result<()> {
    // Determine MAC length based on encryption state
    let mac_len = if state.encryption_active { 16 } else { 0 };

    // If encrypted, we need to decrypt first
    if state.encryption_active {
        let cipher = state.cipher.as_ref().expect("Cipher not initialized");

        // Make a mutable copy for decryption
        let mut decrypted_data = data.to_vec();
        let mut decrypted_packet = Packet::new(&mut decrypted_data, mac_len);

        cipher.decrypt_packet(&mut decrypted_packet, state.sequence_number)?;
        state.sequence_number += 1;

        // Now parse the decrypted packet
        let packet = Packet::new(&decrypted_data, mac_len);
        println!("\n✓ Packet decrypted successfully!");
        dbg!(&packet);

        let payload = packet.payload()?;
        let msg_type: u8 = payload.pread(0).unwrap();

        match msg_type {
            _ => {
                println!("Received encrypted message type: {}", msg_type);
                println!("Decrypted payload: {:02x?}", payload);
            }
        }

        return Ok(());
    }

    // Unencrypted packet processing
    let packet = Packet::new(data, mac_len);
    dbg!(&packet);

    let payload = packet.payload().unwrap();

    // Read message type
    let msg_type: u8 = payload.pread(0).unwrap();

    match msg_type {
        SSH_MSG_KEXINIT => {
            // Store the raw server KEXINIT packet for later hash calculation
            state.server_kexinit = Some(payload.to_vec());
            // Parse the server's KEXINIT using our new Kexinit struct
            let server_kexinit = Kexinit::new(payload)?;
            println!("Server KEXINIT parsed:");
            dbg!(&server_kexinit);

            // Parse the server's kex algorithms to modify them
            // Note: kex_algorithms() returns the raw name-list contents (without length
            // prefix) so we parse it directly by splitting on commas
            let server_kex_algs = server_kexinit.kex_algorithms();
            let mut kex_algs_vec = server_kex_algs.split_str(",").collect::<Vec<_>>();

            println!("\nServer's kex algorithms:");
            for alg in &kex_algs_vec {
                println!("  {}", alg.as_bstr());
            }

            // Reorder the kex algorithms to prioritize curve25519-sha256
            // Find and move curve25519-sha256 variants to the front
            let prioritized = [
                b"curve25519-sha256".as_slice(),
                b"curve25519-sha256@libssh.org".as_slice(),
            ];
            for priority_alg in prioritized.iter().rev() {
                if let Some(pos) = kex_algs_vec.iter().position(|&alg| alg == *priority_alg) {
                    let alg = kex_algs_vec.remove(pos);
                    kex_algs_vec.insert(0, alg);
                }
            }

            // Replace server-specific extensions with client equivalents
            // ext-info-s -> ext-info-c
            // kex-strict-s-v00@openssh.com -> kex-strict-c-v00@openssh.com
            for alg in kex_algs_vec.iter_mut() {
                if *alg == b"kex-strict-s-v00@openssh.com" {
                    *alg = b"kex-strict-c-v00@openssh.com";
                } else if *alg == b"ext-info-s" {
                    *alg = b"ext-info-c";
                }
            }

            // Build the modified kex algorithms name-list
            let mut kex_algs_string = Vec::new();
            for (i, alg) in kex_algs_vec.iter().enumerate() {
                if i > 0 {
                    kex_algs_string.push(b',');
                }
                kex_algs_string.extend_from_slice(alg);
            }

            println!("\nClient's reordered kex algorithms:");
            for alg in &kex_algs_vec {
                println!("  {}", alg.as_bstr());
            }

            // Generate a random cookie for the client
            // In a real implementation, this should be cryptographically random
            // For now, we'll use a simple pattern
            let client_cookie = b"\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42";

            // Build our client KEXINIT response using the modified values
            let client_kexinit = KexinitBuilder::new()
                .cookie(client_cookie)
                .kex_algorithms(&kex_algs_string)
                .server_host_key_algorithms(server_kexinit.server_host_key_algorithms())
                .encryption_algorithms_client_to_server(
                    server_kexinit.encryption_algorithms_client_to_server(),
                )
                .encryption_algorithms_server_to_client(
                    server_kexinit.encryption_algorithms_server_to_client(),
                )
                .mac_algorithms_client_to_server(server_kexinit.mac_algorithms_client_to_server())
                .mac_algorithms_server_to_client(server_kexinit.mac_algorithms_server_to_client())
                .compression_algorithms_client_to_server(
                    server_kexinit.compression_algorithms_client_to_server(),
                )
                .compression_algorithms_server_to_client(
                    server_kexinit.compression_algorithms_server_to_client(),
                )
                .languages_client_to_server(server_kexinit.languages_client_to_server())
                .languages_server_to_client(server_kexinit.languages_server_to_client())
                .first_kex_packet_follows(server_kexinit.first_kex_packet_follows())
                .build()?;

            println!(
                "\nClient KEXINIT built, size: {} bytes",
                client_kexinit.len()
            );

            // Store the client KEXINIT for later hash calculation
            state.client_kexinit = Some(client_kexinit.clone());

            // Wrap the KEXINIT payload in an SSH packet using Packet::from_payload
            let kexinit_packet = Packet::from_payload(&client_kexinit, 8, 0);
            println!(
                "KEXINIT packet: length={}, padding={}",
                kexinit_packet.packet_length()?,
                kexinit_packet.padding_length()?
            );

            // Send the packet
            let packet_data = kexinit_packet.as_ref();
            println!(
                "\nSending client KEXINIT packet, total size: {} bytes",
                packet_data.len()
            );
            println!(
                "First 100 bytes: {:02x?}",
                &packet_data[..packet_data.len().min(100)]
            );
            stream.write_all(packet_data)?;
            stream.flush()?;
            println!("Client KEXINIT sent successfully!");

            // Try to peek if there's data available
            println!("Checking for immediate server response...");
            std::thread::sleep(Duration::from_millis(200));

            // Set a very short timeout to check for immediate response
            stream.set_read_timeout(Some(Duration::from_millis(500)))?;

            // Now send KEX_ECDH_INIT with our curve25519 public key
            println!("\n=== Generating curve25519 keypair ===");
            let kex = curve25519_sha256::Initiator::new()?;
            let public_key_bytes = kex.public_key();
            println!("Generated public key: {} bytes", public_key_bytes.len());
            println!("Public key bytes: {:02x?}", public_key_bytes);

            // Build KEX_ECDH_INIT message using the builder API
            println!("\n=== Building KEX_ECDH_INIT packet ===");
            let kex_init_payload = KexEcdhInitBuilder::new().q_c(public_key_bytes).build()?;

            println!("Payload size: {} bytes", kex_init_payload.len());

            // Store the kex for later use in KEX_ECDH_REPLY
            state.kex = Some(kex);

            // Wrap in SSH packet using Packet::from_payload
            let kex_packet = Packet::from_payload(&kex_init_payload, 8, 0);
            println!(
                "KEX_ECDH_INIT packet: length={}, padding={}",
                kex_packet.packet_length()?,
                kex_packet.padding_length()?
            );

            // Send KEX_ECDH_INIT packet
            let kex_packet_data = kex_packet.as_ref();
            println!(
                "\nSending KEX_ECDH_INIT packet, total size: {} bytes",
                kex_packet_data.len()
            );
            println!(
                "First 50 bytes: {:02x?}",
                &kex_packet_data[..kex_packet_data.len().min(50)]
            );
            stream.write_all(kex_packet_data)?;
            stream.flush()?;
            println!("KEX_ECDH_INIT sent successfully!");

            // Reset timeout for reading server response
            stream.set_read_timeout(Some(Duration::from_secs(5)))?;
        }
        SSH_MSG_KEX_ECDH_REPLY => {
            // Parse the server's KEX_ECDH_REPLY using our KexEcdhReply struct
            let kex_ecdh_reply = KexEcdhReply::new(payload)?;
            println!("\n=== Received KEX_ECDH_REPLY ===");
            println!(
                "Server's public host key (K_S): {} bytes",
                kex_ecdh_reply.k_s().len()
            );
            println!(
                "Server's ephemeral public key (Q_S): {} bytes",
                kex_ecdh_reply.q_s().len()
            );
            println!("Signature: {} bytes", kex_ecdh_reply.signature().len());
            dbg!(&kex_ecdh_reply);

            // Calculate shared secret K using ECDH
            println!("\n=== Calculating shared secret K ===");
            let kex = state.kex.take().expect("Kex not found");

            // Save client public key before consuming kex
            let client_public_key = kex.public_key().to_vec();

            let shared_secret = kex.agree(kex_ecdh_reply.q_s())?;

            println!(
                "Shared secret K: {} bytes (mpint encoded)",
                shared_secret.len()
            );

            // Compute exchange hash H according to RFC 4253 Section 8
            println!("\n=== Computing exchange hash H ===");

            let client_kexinit = state
                .client_kexinit
                .as_ref()
                .expect("Client KEXINIT not found");
            let server_kexinit = state
                .server_kexinit
                .as_ref()
                .expect("Server KEXINIT not found");

            // TODO: Having to encode these fields back with their length prefixes exposes
            // the want to have a method on the message that returns the wire encoding
            // before parsing, ie its length prefix + bytes.

            // Encode each field as SSH byte-strings into temporary buffers
            let mut v_c = vec![0u8; 4 + state.client_id.len()];
            v_c.write_byte_string(&state.client_id, &mut 0)?;

            let mut v_s = vec![0u8; 4 + state.server_id.len()];
            v_s.write_byte_string(&state.server_id, &mut 0)?;

            let mut i_c = vec![0u8; 4 + client_kexinit.len()];
            i_c.write_byte_string(client_kexinit, &mut 0)?;

            let mut i_s = vec![0u8; 4 + server_kexinit.len()];
            i_s.write_byte_string(server_kexinit, &mut 0)?;

            let mut k_s = vec![0u8; 4 + kex_ecdh_reply.k_s().len()];
            k_s.write_byte_string(kex_ecdh_reply.k_s(), &mut 0)?;

            let mut q_c = vec![0u8; 4 + client_public_key.len()];
            q_c.write_byte_string(&client_public_key, &mut 0)?;

            let mut q_s = vec![0u8; 4 + kex_ecdh_reply.q_s().len()];
            q_s.write_byte_string(kex_ecdh_reply.q_s(), &mut 0)?;

            // Hash with SHA-256, passing each encoded field separately
            let h = Sha256::exchange_hash(&v_c, &v_s, &i_c, &i_s, &k_s, &q_c, &q_s, &shared_secret);
            let h_bytes = &h;
            println!("Exchange hash H: {} bytes", h_bytes.len());
            println!("H = {:02x?}", h_bytes);

            // Derive encryption keys using SSH KDF (RFC 4253 Section 7.2)
            println!("\n=== Deriving encryption keys ===");

            // Initial IV client to server: HASH(K || H || "A" || session_id)
            // Initial IV server to client: HASH(K || H || "B" || session_id)
            // Encryption key client to server: HASH(K || H || "C" || session_id)
            // Encryption key server to client: HASH(K || H || "D" || session_id)
            // Integrity key client to server: HASH(K || H || "E" || session_id)
            // Integrity key server to client: HASH(K || H || "F" || session_id)

            // For ChaCha20-Poly1305, we need 64 bytes total (32 for main key + 32 for
            // poly1305 key) We'll derive key material for server to client
            // direction

            let session_id = h_bytes; // First H becomes the session_id

            // Use ssh_kdf to derive the encryption key with our new Sha256 type
            let mut key_s2c = [0u8; 64];
            kdf::ssh_kdf::<Sha256, 32>(
                &shared_secret,
                h_bytes,
                kdf::key_type::ENCRYPTION_KEY_SERVER_TO_CLIENT,
                session_id,
                &mut key_s2c,
            );
            println!("Encryption key (server to client): {} bytes", key_s2c.len());

            // Create ChaCha20-Poly1305 cipher with the derived key material
            let mut key_material = [0u8; 64];
            key_material.copy_from_slice(&key_s2c);
            state.cipher = Some(ChaCha20Poly1305::new(key_material));

            println!("ChaCha20-Poly1305 cipher initialized!");

            // Send SSH_MSG_NEWKEYS
            println!("\n=== Sending SSH_MSG_NEWKEYS ===");
            let newkeys_payload = vec![SSH_MSG_NEWKEYS];
            let newkeys_packet = Packet::from_payload(&newkeys_payload, 8, 0);
            stream.write_all(newkeys_packet.as_ref())?;
            stream.flush()?;
            println!("SSH_MSG_NEWKEYS sent!");

            // The server should send NEWKEYS immediately, so don't wait
            stream.set_read_timeout(Some(Duration::from_millis(100)))?;
        }
        SSH_MSG_NEWKEYS => {
            println!("\n=== Received SSH_MSG_NEWKEYS ===");
            println!("Server has switched to encrypted mode!");
            println!("Activating ChaCha20-Poly1305 encryption for all subsequent packets");

            // Now activate encryption - all subsequent packets will be encrypted
            state.encryption_active = true;
        }
        _ => {
            println!("Received unhandled message type: {}", msg_type);
            println!("Packet details:");
            dbg!(&packet);
        }
    }

    Ok(())
}