//! hyprwire 0.2.7
//!
//! A fast and consistent wire protocol for IPC.
//!
//! Documentation
pub mod bind_protocol;
pub mod fatal_protocol_error;
pub mod generic_protocol_message;
pub mod handshake_ack;
pub mod handshake_begin;
pub mod handshake_protocols;
pub mod hello;
pub mod new_object;
pub mod roundtrip_done;
pub mod roundtrip_request;

use crate::implementation::types;
use crate::message;
use std::fmt::Write;
use std::result;

pub type Result<T> = result::Result<T, message::Error>;

pub trait Message {
    fn data(&self) -> &[u8];

    fn message_type(&self) -> message::MessageType;

    fn fds(&self) -> &[i32] {
        &[]
    }

    fn parse_data(&self) -> String {
        let mut result = String::new();
        let data = self.data();

        let _ = write!(result, "{} ( ", self.message_type());

        let mut first = true;
        let mut needle: usize = 1;
        while needle < data.len() {
            let magic = types::MessageMagic::try_from(data[needle]).unwrap();
            needle += 1;

            match magic {
                types::MessageMagic::TypeSeq => {
                    if !first {
                        result.push_str(", ");
                    }
                    first = false;
                    let bytes: [u8; 4] = data[needle..needle + 4].try_into().unwrap();
                    let value = u32::from_le_bytes(bytes);
                    let _ = write!(result, "seq: {value}");
                    needle += 4;
                }
                types::MessageMagic::TypeUint => {
                    if !first {
                        result.push_str(", ");
                    }
                    first = false;
                    let bytes: [u8; 4] = data[needle..needle + 4].try_into().unwrap();
                    let value = u32::from_le_bytes(bytes);
                    let _ = write!(result, "{value}");
                    needle += 4;
                }
                types::MessageMagic::TypeInt => {
                    if !first {
                        result.push_str(", ");
                    }
                    first = false;
                    let bytes: [u8; 4] = data[needle..needle + 4].try_into().unwrap();
                    let value = i32::from_le_bytes(bytes);
                    let _ = write!(result, "{value}");
                    needle += 4;
                }
                types::MessageMagic::TypeF32 => {
                    if !first {
                        result.push_str(", ");
                    }
                    first = false;
                    let bytes: [u8; 4] = data[needle..needle + 4].try_into().unwrap();
                    let value = f32::from_le_bytes(bytes);
                    let _ = write!(result, "{value}");
                    needle += 4;
                }
                types::MessageMagic::TypeVarchar => {
                    if !first {
                        result.push_str(", ");
                    }
                    first = false;
                    let (len, int_len) = message::parse_var_int(data, needle);
                    if len > 0 {
                        let str_data = &data[needle + int_len..needle + int_len + len];
                        let s = String::from_utf8_lossy(str_data);
                        let _ = write!(result, "\"{s}\"");
                    } else {
                        result.push_str("\"\"");
                    }
                    needle += int_len + len;
                }
                types::MessageMagic::TypeArray => {
                    if !first {
                        result.push_str(", ");
                    }
                    first = false;
                    let type_byte = data[needle];
                    let this_type = types::MessageMagic::try_from(type_byte).unwrap();
                    needle += 1;

                    let (els, int_len) = message::parse_var_int(data, needle);
                    result.push_str("{ ");
                    needle += int_len;

                    for i in 0..els {
                        let (s, len) = format_primitive_type(&data[needle..], this_type).unwrap();

                        needle += len;
                        result.push_str(&s);
                        if i < els - 1 {
                            result.push_str(", ");
                        }
                    }

                    result.push_str(" }");
                }
                types::MessageMagic::TypeObject => {
                    if !first {
                        result.push_str(", ");
                    }
                    first = false;
                    let bytes: [u8; 4] = data[needle..needle + 4].try_into().unwrap();
                    let id = u32::from_le_bytes(bytes);
                    let _ = write!(result, "object({id})");
                    needle += 4;
                }
                types::MessageMagic::TypeFd => {
                    if !first {
                        result.push_str(", ");
                    }
                    first = false;
                    result.push_str("<fd>");
                }
                types::MessageMagic::End | types::MessageMagic::TypeObjectId => {}
            }
        }

        result.push_str(" ) ");
        result
    }
}

fn format_primitive_type(s: &[u8], r#type: types::MessageMagic) -> Result<(String, usize)> {
    match r#type {
        types::MessageMagic::TypeUint => {
            let bytes: [u8; 4] = s
                .get(0..4)
                .ok_or(message::Error::UnexpectedEof)?
                .try_into()
                .unwrap();
            let value = u32::from_le_bytes(bytes);
            Ok((value.to_string(), 4))
        }
        types::MessageMagic::TypeInt => {
            let bytes: [u8; 4] = s
                .get(0..4)
                .ok_or(message::Error::UnexpectedEof)?
                .try_into()
                .unwrap();
            let value = i32::from_le_bytes(bytes);
            Ok((value.to_string(), 4))
        }
        types::MessageMagic::TypeF32 => {
            let bytes: [u8; 4] = s
                .get(0..4)
                .ok_or(message::Error::UnexpectedEof)?
                .try_into()
                .unwrap();
            let value = f32::from_le_bytes(bytes);
            Ok((value.to_string(), 4))
        }
        types::MessageMagic::TypeFd => Ok(("<fd>".to_string(), 0)),
        types::MessageMagic::TypeObject => {
            let bytes: [u8; 4] = s
                .get(0..4)
                .ok_or(message::Error::UnexpectedEof)?
                .try_into()
                .unwrap();
            let id = u32::from_le_bytes(bytes);
            let obj_str = if id == 0 {
                "null".to_string()
            } else {
                id.to_string()
            };
            Ok((format!("object: {obj_str}"), 4))
        }
        types::MessageMagic::TypeVarchar => {
            let (len, int_len) = crate::message::parse_var_int(s, 0);
            let str_data = s
                .get(int_len..int_len + len)
                .ok_or(message::Error::UnexpectedEof)?;
            let value = String::from_utf8(str_data.to_vec())
                .map_err(|_| message::Error::MalformedMessage)?;
            Ok((format!("\"{value}\""), len + int_len))
        }
        _ => Err(message::Error::MalformedMessage),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::implementation::types;
    use crate::message::messages::{
        bind_protocol::BindProtocol, fatal_protocol_error::FatalProtocolError,
        handshake_ack::HandshakeAck, handshake_protocols::HandshakeProtocols, hello::Hello,
        new_object::NewObject, roundtrip_done::RoundtripDone, roundtrip_request::RoundtripRequest,
    };

    /// Minimal `Message` backed by a borrowed byte slice, used to feed
    /// hand-built payloads to `parse_data`.
    struct TestMessage<'a> {
        data: &'a [u8],
        message_type: message::MessageType,
    }

    impl Message for TestMessage<'_> {
        fn data(&self) -> &[u8] {
            self.data
        }
        fn message_type(&self) -> message::MessageType {
            self.message_type
        }
    }

    /// Wraps `bytes` as a `GenericProtocolMessage` test message.
    fn generic(bytes: &[u8]) -> TestMessage<'_> {
        TestMessage {
            data: bytes,
            message_type: message::MessageType::GenericProtocolMessage,
        }
    }

    #[test]
    fn parse_data_integer_types() {
        let bytes: &[u8] = &[
            message::MessageType::GenericProtocolMessage as u8,
            types::MessageMagic::TypeSeq as u8,
            0x01,
            0x00,
            0x00,
            0x00,
            types::MessageMagic::TypeInt as u8,
            0x01,
            0x00,
            0x00,
            0x00,
            types::MessageMagic::TypeF32 as u8,
            0x01,
            0x00,
            0x00,
            0x00,
            types::MessageMagic::End as u8,
        ];
        let data = generic(bytes).parse_data();
        let expected_f32 = f32::from_le_bytes([0x01, 0x00, 0x00, 0x00]);
        assert_eq!(
            data,
            format!("GenericProtocolMessage ( seq: 1, 1, {expected_f32} ) ")
        );
    }

    #[test]
    fn parse_data_object_uint_and_fd() {
        let bytes: &[u8] = &[
            message::MessageType::GenericProtocolMessage as u8,
            types::MessageMagic::TypeObject as u8,
            0x05,
            0x00,
            0x00,
            0x00,
            types::MessageMagic::TypeUint as u8,
            0x02,
            0x00,
            0x00,
            0x00,
            types::MessageMagic::TypeFd as u8,
            types::MessageMagic::End as u8,
        ];
        assert_eq!(
            generic(bytes).parse_data(),
            "GenericProtocolMessage ( object(5), 2, <fd> ) "
        );
    }

    #[test]
    fn parse_data_uint_array() {
        let bytes: &[u8] = &[
            message::MessageType::GenericProtocolMessage as u8,
            types::MessageMagic::TypeArray as u8,
            types::MessageMagic::TypeUint as u8,
            // Element count (single-byte var-int).
            0x02,
            0x01,
            0x00,
            0x00,
            0x00,
            0x02,
            0x00,
            0x00,
            0x00,
            types::MessageMagic::End as u8,
        ];
        assert_eq!(
            generic(bytes).parse_data(),
            "GenericProtocolMessage ( { 1, 2 } ) "
        );
    }

    #[test]
    fn parse_data_hello_contains_type_and_payload() {
        let msg = Hello::new();
        let parsed = msg.parse_data();
        assert!(parsed.contains("Sup"), "missing type in: {parsed}");
        assert!(parsed.contains("\"VAX\""), "missing payload in: {parsed}");
    }

    #[test]
    fn parse_data_handshake_ack_contains_type_and_version() {
        let msg = HandshakeAck::new(7);
        let parsed = msg.parse_data();
        assert!(parsed.contains("HandshakeAck"), "missing type in: {parsed}");
        assert!(parsed.contains('7'), "missing version in: {parsed}");
    }

    #[test]
    fn parse_data_handshake_protocols_contains_protocol_names() {
        let msg = HandshakeProtocols::new(&["proto@1", "second@2"]);
        let parsed = msg.parse_data();
        assert!(
            parsed.contains("HandshakeProtocols"),
            "missing type in: {parsed}"
        );
        assert!(
            parsed.contains("\"proto@1\""),
            "missing proto@1 in: {parsed}"
        );
        assert!(
            parsed.contains("\"second@2\""),
            "missing second@2 in: {parsed}"
        );
    }

    #[test]
    fn parse_data_bind_protocol_contains_core_fields() {
        let msg = BindProtocol::new("my_proto", 12, 3);
        let parsed = msg.parse_data();
        assert!(parsed.contains("BindProtocol"), "missing type in: {parsed}");
        assert!(parsed.contains("12"), "missing seq in: {parsed}");
        assert!(
            parsed.contains("\"my_proto\""),
            "missing protocol in: {parsed}"
        );
        assert!(parsed.contains('3'), "missing version in: {parsed}");
    }

    #[test]
    fn parse_data_new_object_contains_object_and_seq() {
        let msg = NewObject::new(9, 77);
        let parsed = msg.parse_data();
        assert!(parsed.contains("NewObject"), "missing type in: {parsed}");
        assert!(parsed.contains("77"), "missing id in: {parsed}");
        assert!(parsed.contains('9'), "missing seq in: {parsed}");
    }

    #[test]
    fn parse_data_fatal_error_contains_identifiers_and_message() {
        let msg = FatalProtocolError::new(0, 123, "oops");
        let parsed = msg.parse_data();
        assert!(
            parsed.contains("FatalProtocolError"),
            "missing type in: {parsed}"
        );
        assert!(parsed.contains("123"), "missing error_id in: {parsed}");
        assert!(
            parsed.contains("\"oops\""),
            "missing error_msg in: {parsed}"
        );
    }

    #[test]
    fn parse_data_roundtrip_messages_contain_type_and_sequence() {
        let req = RoundtripRequest::new(777);
        let done = RoundtripDone::new(888);

        let req_parsed = req.parse_data();
        let done_parsed = done.parse_data();

        assert!(
            req_parsed.contains("RoundtripRequest"),
            "missing type in: {req_parsed}"
        );
        assert!(req_parsed.contains("777"), "missing seq in: {req_parsed}");

        assert!(
            done_parsed.contains("RoundtripDone"),
            "missing type in: {done_parsed}"
        );
        assert!(done_parsed.contains("888"), "missing seq in: {done_parsed}");
    }

    #[test]
    fn parse_data_varchar_empty() {
        let bytes: &[u8] = &[
            message::MessageType::GenericProtocolMessage as u8,
            types::MessageMagic::TypeVarchar as u8,
            0x00,
            types::MessageMagic::End as u8,
        ];
        let data = generic(bytes).parse_data();
        assert_eq!(data, "GenericProtocolMessage ( \"\" ) ");
    }
}