pcapsql_core/protocol/
ssh.rs

1//! SSH protocol parser.
2//!
3//! Parses SSH (Secure Shell) protocol identification strings and binary packets,
4//! particularly KEXINIT messages for algorithm negotiation analysis.
5
6use compact_str::CompactString;
7use smallvec::SmallVec;
8
9use super::{FieldValue, ParseContext, ParseResult, Protocol};
10use crate::schema::{DataKind, FieldDescriptor};
11
/// Well-known TCP port on which SSH servers listen (registered with IANA, see RFC 4253 §4.1).
pub const SSH_PORT: u16 = 22;
14
/// SSH message numbers recognised by this parser.
///
/// Grouped by the specification that defines them: transport layer (RFC 4253),
/// user authentication (RFC 4252) and connection protocol (RFC 4254).
mod msg_type {
    // Transport layer, generic messages (1-19).
    pub const SSH_MSG_DISCONNECT: u8 = 1;
    pub const SSH_MSG_IGNORE: u8 = 2;
    pub const SSH_MSG_UNIMPLEMENTED: u8 = 3;
    pub const SSH_MSG_DEBUG: u8 = 4;
    pub const SSH_MSG_SERVICE_REQUEST: u8 = 5;
    pub const SSH_MSG_SERVICE_ACCEPT: u8 = 6;

    // Transport layer, algorithm negotiation (20-29).
    pub const SSH_MSG_KEXINIT: u8 = 20;
    pub const SSH_MSG_NEWKEYS: u8 = 21;

    // Transport layer, key-exchange-method specific (30-49); these are the
    // Diffie-Hellman init/reply numbers.
    pub const SSH_MSG_KEX_DH_INIT: u8 = 30;
    pub const SSH_MSG_KEX_DH_REPLY: u8 = 31;

    // User authentication protocol (50-79).
    pub const SSH_MSG_USERAUTH_REQUEST: u8 = 50;
    pub const SSH_MSG_USERAUTH_FAILURE: u8 = 51;
    pub const SSH_MSG_USERAUTH_SUCCESS: u8 = 52;
    pub const SSH_MSG_USERAUTH_BANNER: u8 = 53;

    // Connection protocol, channel messages (90-127).
    pub const SSH_MSG_CHANNEL_OPEN: u8 = 90;
    pub const SSH_MSG_CHANNEL_OPEN_CONFIRMATION: u8 = 91;
    pub const SSH_MSG_CHANNEL_OPEN_FAILURE: u8 = 92;
    pub const SSH_MSG_CHANNEL_WINDOW_ADJUST: u8 = 93;
    pub const SSH_MSG_CHANNEL_DATA: u8 = 94;
    pub const SSH_MSG_CHANNEL_EXTENDED_DATA: u8 = 95;
    pub const SSH_MSG_CHANNEL_EOF: u8 = 96;
    pub const SSH_MSG_CHANNEL_CLOSE: u8 = 97;
    pub const SSH_MSG_CHANNEL_REQUEST: u8 = 98;
    pub const SSH_MSG_CHANNEL_SUCCESS: u8 = 99;
    pub const SSH_MSG_CHANNEL_FAILURE: u8 = 100;
}
43
/// Stateless dissector for SSH traffic.
///
/// The type carries no data, so it is `Copy` and a single value can be shared freely.
#[derive(Clone, Copy, Debug)]
pub struct SshProtocol;
47
48impl Protocol for SshProtocol {
49    fn name(&self) -> &'static str {
50        "ssh"
51    }
52
53    fn display_name(&self) -> &'static str {
54        "SSH"
55    }
56
57    fn can_parse(&self, context: &ParseContext) -> Option<u32> {
58        let src_port = context.hint("src_port");
59        let dst_port = context.hint("dst_port");
60
61        // Check for SSH port
62        match (src_port, dst_port) {
63            (Some(p), _) | (_, Some(p)) if p == SSH_PORT as u64 => Some(50),
64            _ => None,
65        }
66    }
67
68    fn parse<'a>(&self, data: &'a [u8], _context: &ParseContext) -> ParseResult<'a> {
69        let mut fields = SmallVec::new();
70
71        // Check if this is an SSH protocol identification string
72        if data.starts_with(b"SSH-") {
73            return parse_protocol_identification(data, &mut fields);
74        }
75
76        // Try to parse as SSH binary packet
77        if data.len() < 5 {
78            return ParseResult::error("SSH packet too short".to_string(), data);
79        }
80
81        let packet_length = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
82
83        // Sanity check on packet length - SSH packets shouldn't exceed 35000 bytes typically
84        if !(2..=35000).contains(&packet_length) {
85            // This might be encrypted traffic or not an SSH packet
86            fields.push(("encrypted", FieldValue::Bool(true)));
87            let remaining_start = data.len().min(4 + packet_length);
88            return ParseResult::success(fields, &data[remaining_start..], SmallVec::new());
89        }
90
91        let padding_length = data[4] as usize;
92        fields.push(("packet_length", FieldValue::UInt32(packet_length as u32)));
93        fields.push(("padding_length", FieldValue::UInt8(padding_length as u8)));
94
95        // Check if we have the full packet
96        if data.len() < 4 + packet_length {
97            return ParseResult::partial(fields, &data[4..], "SSH packet truncated".to_string());
98        }
99
100        // Payload starts at offset 5
101        let payload_length = packet_length.saturating_sub(padding_length + 1);
102        if payload_length == 0 || data.len() < 6 {
103            return ParseResult::success(fields, &data[4 + packet_length..], SmallVec::new());
104        }
105
106        let msg_type = data[5];
107        fields.push(("msg_type", FieldValue::UInt8(msg_type)));
108        fields.push(("msg_type_name", format_msg_type(msg_type)));
109
110        // Parse specific message types
111        let payload = &data[5..5 + payload_length];
112        match msg_type {
113            msg_type::SSH_MSG_KEXINIT => {
114                parse_kexinit_message(payload, &mut fields);
115            }
116            msg_type::SSH_MSG_USERAUTH_REQUEST => {
117                parse_userauth_request(payload, &mut fields);
118            }
119            msg_type::SSH_MSG_CHANNEL_OPEN => {
120                parse_channel_open(payload, &mut fields);
121            }
122            _ => {}
123        }
124
125        let remaining = &data[4 + packet_length..];
126        ParseResult::success(fields, remaining, SmallVec::new())
127    }
128
129    fn schema_fields(&self) -> Vec<FieldDescriptor> {
130        vec![
131            // Protocol identification
132            FieldDescriptor::new("ssh.protocol_version", DataKind::String).set_nullable(true),
133            FieldDescriptor::new("ssh.software_version", DataKind::String).set_nullable(true),
134            FieldDescriptor::new("ssh.comments", DataKind::String).set_nullable(true),
135            // Binary packet
136            FieldDescriptor::new("ssh.packet_length", DataKind::UInt32).set_nullable(true),
137            FieldDescriptor::new("ssh.padding_length", DataKind::UInt8).set_nullable(true),
138            FieldDescriptor::new("ssh.msg_type", DataKind::UInt8).set_nullable(true),
139            FieldDescriptor::new("ssh.msg_type_name", DataKind::String).set_nullable(true),
140            FieldDescriptor::new("ssh.encrypted", DataKind::Bool).set_nullable(true),
141            // KEXINIT
142            FieldDescriptor::new("ssh.kex_algorithms", DataKind::String).set_nullable(true),
143            FieldDescriptor::new("ssh.host_key_algorithms", DataKind::String).set_nullable(true),
144            FieldDescriptor::new("ssh.encryption_algorithms", DataKind::String).set_nullable(true),
145            FieldDescriptor::new("ssh.mac_algorithms", DataKind::String).set_nullable(true),
146            FieldDescriptor::new("ssh.compression_algorithms", DataKind::String).set_nullable(true),
147            // USERAUTH
148            FieldDescriptor::new("ssh.auth_username", DataKind::String).set_nullable(true),
149            FieldDescriptor::new("ssh.auth_service", DataKind::String).set_nullable(true),
150            FieldDescriptor::new("ssh.auth_method", DataKind::String).set_nullable(true),
151            // CHANNEL
152            FieldDescriptor::new("ssh.channel_type", DataKind::String).set_nullable(true),
153            FieldDescriptor::new("ssh.channel_id", DataKind::UInt32).set_nullable(true),
154        ]
155    }
156
157    fn child_protocols(&self) -> &[&'static str] {
158        &[]
159    }
160
161    fn dependencies(&self) -> &'static [&'static str] {
162        &["tcp"]
163    }
164}
165
166/// Parse SSH protocol identification string.
167fn parse_protocol_identification<'a>(
168    data: &'a [u8],
169    fields: &mut SmallVec<[(&'static str, FieldValue<'a>); 16]>,
170) -> ParseResult<'a> {
171    // Find the end of the identification string (CR LF or just LF)
172    let line_end = data.iter().position(|&b| b == b'\n').unwrap_or(data.len());
173    let line = &data[..line_end];
174
175    // Remove trailing CR if present
176    let line = if line.ends_with(b"\r") {
177        &line[..line.len() - 1]
178    } else {
179        line
180    };
181
182    // Parse "SSH-protoversion-softwareversion SP comments"
183    if let Ok(line_str) = std::str::from_utf8(line) {
184        if let Some(content) = line_str.strip_prefix("SSH-") {
185            if let Some(dash_pos) = content.find('-') {
186                let proto_version = &content[..dash_pos];
187                fields.push((
188                    "protocol_version",
189                    FieldValue::OwnedString(CompactString::new(proto_version)),
190                ));
191
192                let rest = &content[dash_pos + 1..];
193
194                // Software version ends at space (if comments follow) or end of line
195                if let Some(space_pos) = rest.find(' ') {
196                    let software_version = &rest[..space_pos];
197                    let comments = rest[space_pos + 1..].trim();
198
199                    fields.push((
200                        "software_version",
201                        FieldValue::OwnedString(CompactString::new(software_version)),
202                    ));
203                    if !comments.is_empty() {
204                        fields.push((
205                            "comments",
206                            FieldValue::OwnedString(CompactString::new(comments)),
207                        ));
208                    }
209                } else {
210                    fields.push((
211                        "software_version",
212                        FieldValue::OwnedString(CompactString::new(rest)),
213                    ));
214                }
215            }
216        }
217    }
218
219    // Remaining data after the identification string
220    let remaining_start = (line_end + 1).min(data.len());
221    ParseResult::success(fields.clone(), &data[remaining_start..], SmallVec::new())
222}
223
224/// Parse KEXINIT message to extract algorithm lists.
225fn parse_kexinit_message(payload: &[u8], fields: &mut SmallVec<[(&'static str, FieldValue); 16]>) {
226    // KEXINIT format:
227    // byte      SSH_MSG_KEXINIT (20) - included in payload
228    // byte[16]  cookie
229    // name-list kex_algorithms
230    // name-list server_host_key_algorithms
231    // ...
232    if payload.len() < 17 {
233        return;
234    }
235
236    let mut offset = 17; // Skip msg_type (1) + cookie (16)
237
238    // Helper to read a name-list
239    let read_name_list = |data: &[u8], off: &mut usize| -> Option<String> {
240        if *off + 4 > data.len() {
241            return None;
242        }
243        let len = u32::from_be_bytes([data[*off], data[*off + 1], data[*off + 2], data[*off + 3]])
244            as usize;
245        *off += 4;
246        if *off + len > data.len() {
247            return None;
248        }
249        let value = std::str::from_utf8(&data[*off..*off + len])
250            .ok()?
251            .to_string();
252        *off += len;
253        Some(value)
254    };
255
256    if let Some(kex_algs) = read_name_list(payload, &mut offset) {
257        if !kex_algs.is_empty() {
258            fields.push((
259                "kex_algorithms",
260                FieldValue::OwnedString(CompactString::new(kex_algs)),
261            ));
262        }
263    }
264
265    if let Some(host_key_algs) = read_name_list(payload, &mut offset) {
266        if !host_key_algs.is_empty() {
267            fields.push((
268                "host_key_algorithms",
269                FieldValue::OwnedString(CompactString::new(host_key_algs)),
270            ));
271        }
272    }
273
274    if let Some(enc_c2s) = read_name_list(payload, &mut offset) {
275        if !enc_c2s.is_empty() {
276            fields.push((
277                "encryption_algorithms",
278                FieldValue::OwnedString(CompactString::new(enc_c2s)),
279            ));
280        }
281    }
282
283    // Skip encryption_algorithms_server_to_client
284    let _ = read_name_list(payload, &mut offset);
285
286    if let Some(mac_c2s) = read_name_list(payload, &mut offset) {
287        if !mac_c2s.is_empty() {
288            fields.push((
289                "mac_algorithms",
290                FieldValue::OwnedString(CompactString::new(mac_c2s)),
291            ));
292        }
293    }
294
295    // Skip mac_algorithms_server_to_client
296    let _ = read_name_list(payload, &mut offset);
297
298    if let Some(comp_c2s) = read_name_list(payload, &mut offset) {
299        if !comp_c2s.is_empty() {
300            fields.push((
301                "compression_algorithms",
302                FieldValue::OwnedString(CompactString::new(comp_c2s)),
303            ));
304        }
305    }
306}
307
308/// Parse USERAUTH_REQUEST message.
309fn parse_userauth_request(payload: &[u8], fields: &mut SmallVec<[(&'static str, FieldValue); 16]>) {
310    if payload.len() < 5 {
311        return;
312    }
313
314    let mut offset = 1; // Skip msg_type
315
316    let read_string = |data: &[u8], off: &mut usize| -> Option<String> {
317        if *off + 4 > data.len() {
318            return None;
319        }
320        let len = u32::from_be_bytes([data[*off], data[*off + 1], data[*off + 2], data[*off + 3]])
321            as usize;
322        *off += 4;
323        if *off + len > data.len() {
324            return None;
325        }
326        let value = std::str::from_utf8(&data[*off..*off + len])
327            .ok()?
328            .to_string();
329        *off += len;
330        Some(value)
331    };
332
333    if let Some(username) = read_string(payload, &mut offset) {
334        if !username.is_empty() {
335            fields.push((
336                "auth_username",
337                FieldValue::OwnedString(CompactString::new(username)),
338            ));
339        }
340    }
341
342    if let Some(service) = read_string(payload, &mut offset) {
343        if !service.is_empty() {
344            fields.push((
345                "auth_service",
346                FieldValue::OwnedString(CompactString::new(service)),
347            ));
348        }
349    }
350
351    if let Some(method) = read_string(payload, &mut offset) {
352        if !method.is_empty() {
353            fields.push((
354                "auth_method",
355                FieldValue::OwnedString(CompactString::new(method)),
356            ));
357        }
358    }
359}
360
361/// Parse CHANNEL_OPEN message.
362fn parse_channel_open(payload: &[u8], fields: &mut SmallVec<[(&'static str, FieldValue); 16]>) {
363    if payload.len() < 5 {
364        return;
365    }
366
367    let mut offset = 1; // Skip msg_type
368
369    // Read channel type string
370    if offset + 4 > payload.len() {
371        return;
372    }
373    let len = u32::from_be_bytes([
374        payload[offset],
375        payload[offset + 1],
376        payload[offset + 2],
377        payload[offset + 3],
378    ]) as usize;
379    offset += 4;
380
381    if offset + len > payload.len() {
382        return;
383    }
384
385    if let Ok(channel_type) = std::str::from_utf8(&payload[offset..offset + len]) {
386        if !channel_type.is_empty() {
387            fields.push((
388                "channel_type",
389                FieldValue::OwnedString(CompactString::new(channel_type)),
390            ));
391        }
392    }
393    offset += len;
394
395    // Read sender channel ID
396    if offset + 4 <= payload.len() {
397        let channel_id = u32::from_be_bytes([
398            payload[offset],
399            payload[offset + 1],
400            payload[offset + 2],
401            payload[offset + 3],
402        ]);
403        fields.push(("channel_id", FieldValue::UInt32(channel_id)));
404    }
405}
406
407/// Format SSH message type as a readable name.
408fn format_msg_type(msg_type: u8) -> FieldValue<'static> {
409    match msg_type {
410        msg_type::SSH_MSG_DISCONNECT => FieldValue::Str("DISCONNECT"),
411        msg_type::SSH_MSG_IGNORE => FieldValue::Str("IGNORE"),
412        msg_type::SSH_MSG_UNIMPLEMENTED => FieldValue::Str("UNIMPLEMENTED"),
413        msg_type::SSH_MSG_DEBUG => FieldValue::Str("DEBUG"),
414        msg_type::SSH_MSG_SERVICE_REQUEST => FieldValue::Str("SERVICE_REQUEST"),
415        msg_type::SSH_MSG_SERVICE_ACCEPT => FieldValue::Str("SERVICE_ACCEPT"),
416        msg_type::SSH_MSG_KEXINIT => FieldValue::Str("KEXINIT"),
417        msg_type::SSH_MSG_NEWKEYS => FieldValue::Str("NEWKEYS"),
418        msg_type::SSH_MSG_KEX_DH_INIT => FieldValue::Str("KEX_DH_INIT"),
419        msg_type::SSH_MSG_KEX_DH_REPLY => FieldValue::Str("KEX_DH_REPLY"),
420        msg_type::SSH_MSG_USERAUTH_REQUEST => FieldValue::Str("USERAUTH_REQUEST"),
421        msg_type::SSH_MSG_USERAUTH_FAILURE => FieldValue::Str("USERAUTH_FAILURE"),
422        msg_type::SSH_MSG_USERAUTH_SUCCESS => FieldValue::Str("USERAUTH_SUCCESS"),
423        msg_type::SSH_MSG_USERAUTH_BANNER => FieldValue::Str("USERAUTH_BANNER"),
424        msg_type::SSH_MSG_CHANNEL_OPEN => FieldValue::Str("CHANNEL_OPEN"),
425        msg_type::SSH_MSG_CHANNEL_OPEN_CONFIRMATION => FieldValue::Str("CHANNEL_OPEN_CONFIRMATION"),
426        msg_type::SSH_MSG_CHANNEL_OPEN_FAILURE => FieldValue::Str("CHANNEL_OPEN_FAILURE"),
427        msg_type::SSH_MSG_CHANNEL_WINDOW_ADJUST => FieldValue::Str("CHANNEL_WINDOW_ADJUST"),
428        msg_type::SSH_MSG_CHANNEL_DATA => FieldValue::Str("CHANNEL_DATA"),
429        msg_type::SSH_MSG_CHANNEL_EXTENDED_DATA => FieldValue::Str("CHANNEL_EXTENDED_DATA"),
430        msg_type::SSH_MSG_CHANNEL_EOF => FieldValue::Str("CHANNEL_EOF"),
431        msg_type::SSH_MSG_CHANNEL_CLOSE => FieldValue::Str("CHANNEL_CLOSE"),
432        msg_type::SSH_MSG_CHANNEL_REQUEST => FieldValue::Str("CHANNEL_REQUEST"),
433        msg_type::SSH_MSG_CHANNEL_SUCCESS => FieldValue::Str("CHANNEL_SUCCESS"),
434        msg_type::SSH_MSG_CHANNEL_FAILURE => FieldValue::Str("CHANNEL_FAILURE"),
435        _ => FieldValue::OwnedString(CompactString::new(format!("UNKNOWN({msg_type})"))),
436    }
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an identification line `SSH-<proto>-<software>[ <comments>]\r\n`.
    fn create_ssh_identification(proto: &str, software: &str, comments: Option<&str>) -> Vec<u8> {
        let mut line = format!("SSH-{proto}-{software}");
        if let Some(comments) = comments {
            line.push(' ');
            line.push_str(comments);
        }
        line.push_str("\r\n");
        line.into_bytes()
    }

    /// Wrap `msg_type` + `payload` in an unencrypted binary packet.
    ///
    /// Padding is `8 - (unpadded % 8)` bytes of zeros, raised to at least 4.
    fn create_ssh_packet(msg_type: u8, payload: &[u8]) -> Vec<u8> {
        // length (4) + padding_length (1) + msg_type (1) + payload
        let unpadded = 6 + payload.len();
        let padding_length = (8 - unpadded % 8).max(4);
        let packet_length = 2 + payload.len() + padding_length;

        let mut packet = Vec::with_capacity(4 + packet_length);
        packet.extend_from_slice(&(packet_length as u32).to_be_bytes());
        packet.push(padding_length as u8);
        packet.push(msg_type);
        packet.extend_from_slice(payload);
        packet.resize(packet.len() + padding_length, 0);
        packet
    }

    /// Build a KEXINIT payload (without the message-type byte): an all-zero
    /// cookie, ten name-lists, `first_kex_packet_follows = 0` and the reserved word.
    fn create_kexinit_payload() -> Vec<u8> {
        const NAME_LISTS: [&str; 10] = [
            "curve25519-sha256,diffie-hellman-group14-sha256",
            "ssh-ed25519,rsa-sha2-512",
            "aes256-gcm@openssh.com,chacha20-poly1305@openssh.com",
            "aes256-gcm@openssh.com,chacha20-poly1305@openssh.com",
            "hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com",
            "hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com",
            "none,zlib@openssh.com",
            "none,zlib@openssh.com",
            "",
            "",
        ];

        let mut payload = vec![0u8; 16]; // Cookie
        for list in NAME_LISTS {
            payload.extend_from_slice(&(list.len() as u32).to_be_bytes());
            payload.extend_from_slice(list.as_bytes());
        }
        payload.push(0); // first_kex_packet_follows
        payload.extend_from_slice(&[0u8; 4]); // reserved
        payload
    }

    /// Parse `packet`, optionally setting the named port hint to 22 first.
    fn run_parser<'a>(packet: &'a [u8], port_hint: Option<&'static str>) -> ParseResult<'a> {
        let mut context = ParseContext::new(1);
        if let Some(hint) = port_hint {
            context.insert_hint(hint, 22);
        }
        SshProtocol.parse(packet, &context)
    }

    /// Expected value for a string field, which the parser stores as an owned string.
    fn owned<'v>(text: &str) -> FieldValue<'v> {
        FieldValue::OwnedString(CompactString::new(text))
    }

    #[test]
    fn test_can_parse_ssh_by_port() {
        let parser = SshProtocol;
        assert!(parser.can_parse(&ParseContext::new(1)).is_none());

        for hint in ["dst_port", "src_port"] {
            let mut context = ParseContext::new(1);
            context.insert_hint(hint, 22);
            assert!(parser.can_parse(&context).is_some(), "{hint}=22 should be claimed");
        }
    }

    #[test]
    fn test_parse_ssh_identification_string() {
        let packet = create_ssh_identification("2.0", "OpenSSH_8.9p1", Some("Ubuntu-3ubuntu0.1"));
        let result = run_parser(&packet, Some("dst_port"));

        assert!(result.is_ok());
        assert_eq!(result.get("protocol_version"), Some(&owned("2.0")));
        assert_eq!(result.get("software_version"), Some(&owned("OpenSSH_8.9p1")));
        assert_eq!(result.get("comments"), Some(&owned("Ubuntu-3ubuntu0.1")));
    }

    #[test]
    fn test_parse_client_identification() {
        let packet = create_ssh_identification("2.0", "libssh2_1.10.0", None);
        let result = run_parser(&packet, Some("dst_port"));

        assert!(result.is_ok());
        assert_eq!(result.get("protocol_version"), Some(&owned("2.0")));
        assert_eq!(result.get("software_version"), Some(&owned("libssh2_1.10.0")));
        assert!(result.get("comments").is_none());
    }

    #[test]
    fn test_parse_server_identification() {
        let packet = create_ssh_identification("2.0", "dropbear_2022.83", None);
        let result = run_parser(&packet, Some("src_port"));

        assert!(result.is_ok());
        assert_eq!(result.get("protocol_version"), Some(&owned("2.0")));
        assert_eq!(result.get("software_version"), Some(&owned("dropbear_2022.83")));
    }

    #[test]
    fn test_protocol_version_extraction() {
        let packet = create_ssh_identification("1.99", "OpenSSH_7.9", None);
        let result = run_parser(&packet, None);

        assert!(result.is_ok());
        assert_eq!(result.get("protocol_version"), Some(&owned("1.99")));
    }

    #[test]
    fn test_software_version_extraction() {
        let packet = create_ssh_identification("2.0", "PuTTY_Release_0.78", None);
        let result = run_parser(&packet, None);

        assert!(result.is_ok());
        assert_eq!(result.get("software_version"), Some(&owned("PuTTY_Release_0.78")));
    }

    #[test]
    fn test_parse_kexinit_message() {
        let packet = create_ssh_packet(msg_type::SSH_MSG_KEXINIT, &create_kexinit_payload());
        let result = run_parser(&packet, Some("dst_port"));

        assert!(result.is_ok());
        assert_eq!(result.get("msg_type"), Some(&FieldValue::UInt8(msg_type::SSH_MSG_KEXINIT)));
        assert_eq!(result.get("msg_type_name"), Some(&FieldValue::Str("KEXINIT")));
    }

    #[test]
    fn test_kex_algorithms_extraction() {
        let packet = create_ssh_packet(msg_type::SSH_MSG_KEXINIT, &create_kexinit_payload());
        let result = run_parser(&packet, Some("dst_port"));

        assert!(result.is_ok());
        assert_eq!(
            result.get("kex_algorithms"),
            Some(&owned("curve25519-sha256,diffie-hellman-group14-sha256"))
        );
    }

    #[test]
    fn test_encryption_algorithms_extraction() {
        let packet = create_ssh_packet(msg_type::SSH_MSG_KEXINIT, &create_kexinit_payload());
        let result = run_parser(&packet, Some("dst_port"));

        assert!(result.is_ok());
        assert_eq!(
            result.get("encryption_algorithms"),
            Some(&owned("aes256-gcm@openssh.com,chacha20-poly1305@openssh.com"))
        );
    }

    #[test]
    fn test_newkeys_detection() {
        let packet = create_ssh_packet(msg_type::SSH_MSG_NEWKEYS, &[]);
        let result = run_parser(&packet, Some("dst_port"));

        assert!(result.is_ok());
        assert_eq!(result.get("msg_type"), Some(&FieldValue::UInt8(msg_type::SSH_MSG_NEWKEYS)));
        assert_eq!(result.get("msg_type_name"), Some(&FieldValue::Str("NEWKEYS")));
    }

    #[test]
    fn test_post_encryption_packet_size() {
        // Length field of 128, a padding-length byte of 16, then 127 bytes of
        // 0xFF standing in for ciphertext.
        let mut packet = 128u32.to_be_bytes().to_vec();
        packet.push(16);
        packet.resize(packet.len() + 127, 0xFF);

        let result = run_parser(&packet, Some("dst_port"));

        assert_eq!(result.get("packet_length"), Some(&FieldValue::UInt32(128)));
    }

    #[test]
    fn test_ssh_schema_fields() {
        let fields = SshProtocol.schema_fields();
        assert!(!fields.is_empty());

        let field_names: Vec<&str> = fields.iter().map(|f| f.name).collect();
        for expected in [
            "ssh.protocol_version",
            "ssh.software_version",
            "ssh.msg_type",
            "ssh.kex_algorithms",
            "ssh.encryption_algorithms",
        ] {
            assert!(field_names.contains(&expected), "missing schema field {expected}");
        }
    }

    #[test]
    fn test_ssh_too_short() {
        let short_packet = [0x00u8, 0x00, 0x00];
        let result = run_parser(&short_packet, None);

        assert!(!result.is_ok());
    }

    #[test]
    fn test_userauth_request_parsing() {
        let mut payload = Vec::new();
        for value in ["testuser", "ssh-connection", "publickey"] {
            payload.extend_from_slice(&(value.len() as u32).to_be_bytes());
            payload.extend_from_slice(value.as_bytes());
        }
        let packet = create_ssh_packet(msg_type::SSH_MSG_USERAUTH_REQUEST, &payload);
        let result = run_parser(&packet, Some("dst_port"));

        assert!(result.is_ok());
        assert_eq!(
            result.get("msg_type"),
            Some(&FieldValue::UInt8(msg_type::SSH_MSG_USERAUTH_REQUEST))
        );
        assert_eq!(result.get("auth_username"), Some(&owned("testuser")));
        assert_eq!(result.get("auth_service"), Some(&owned("ssh-connection")));
        assert_eq!(result.get("auth_method"), Some(&owned("publickey")));
    }

    #[test]
    fn test_channel_open_parsing() {
        let channel_type = b"session";
        let mut payload = Vec::new();
        payload.extend_from_slice(&(channel_type.len() as u32).to_be_bytes());
        payload.extend_from_slice(channel_type);
        // sender channel, initial window size, maximum packet size
        for word in [0u32, 0x0020_0000, 0x0000_8000] {
            payload.extend_from_slice(&word.to_be_bytes());
        }
        let packet = create_ssh_packet(msg_type::SSH_MSG_CHANNEL_OPEN, &payload);
        let result = run_parser(&packet, Some("dst_port"));

        assert!(result.is_ok());
        assert_eq!(result.get("channel_type"), Some(&owned("session")));
        assert_eq!(result.get("channel_id"), Some(&FieldValue::UInt32(0)));
    }
}