Skip to main content

radixox_common/
lib.rs

1use bytes::BytesMut;
2use monoio::{buf::IoBufMut, io::OwnedReadHalf, net::TcpStream};
3use prost::{EncodeError, Message};
4
5use crate::network::{
6    NetError, ResponseResult, net_response::NetResponseResult, net_success_response::Body,
7};
8
9pub mod protocol;
10
11// ============================================================================
12// NETWORK MODULE - Protobuf types and validation
13// ============================================================================
14
15pub mod network {
16    use bytes::Bytes;
17
18    use crate::{
19        NetValidate,
20        network::net_command::NetAction,
21        parse_response_result,
22        protocol::{Command, CommandAction},
23    };
24
25    // Include generated protobuf code
26    include!(concat!(env!("OUT_DIR"), "/radixox.rs"));
27
28    /// Network-level errors
29    #[derive(Debug)]
30    pub enum NetError {
31        NetError(String),
32        CommandEmpty,
33        GetEmpty,
34        SetEmpty,
35        PrefixNotAscii,
36        KeyNotAscii,
37        ResponseBodyEmpty,
38    }
39
40    /// Validated response from server
41    #[derive(Debug)]
42    pub struct Response {
43        pub command_id: u64,
44        pub result: ResponseResult,
45    }
46
47    /// Response result variants
48    #[derive(Debug)]
49    pub enum ResponseResult {
50        Empty,
51        Err(),
52        Data(Bytes),       // Single value (GET response)
53        Datas(Vec<Bytes>), // Multiple values (GETN response)
54    }
55
56    // ========================================================================
57    // VALIDATION IMPLEMENTATIONS
58    // ========================================================================
59
60    impl NetValidate<Command> for NetCommand {
61        fn validate(self) -> Result<Command, NetError> {
62            let Some(command_action) = self.net_action else {
63                return Err(NetError::CommandEmpty);
64            };
65            Ok(Command::new(command_action.validate()?, self.request_id))
66        }
67    }
68
69    impl NetValidate<CommandAction> for NetAction {
70        fn validate(self) -> Result<CommandAction, NetError> {
71            match self {
72                NetAction::Get(get) => {
73                    CommandAction::get(get.key).map_err(|_| NetError::KeyNotAscii)
74                }
75                NetAction::Set(set) => {
76                    CommandAction::set(set.key, set.value).map_err(|_| NetError::KeyNotAscii)
77                }
78                NetAction::Del(del) => {
79                    CommandAction::del(del.key).map_err(|_| NetError::KeyNotAscii)
80                }
81                NetAction::Getn(getn) => {
82                    CommandAction::getn(getn.prefix).map_err(|_| NetError::PrefixNotAscii)
83                }
84                NetAction::Deln(deln) => {
85                    CommandAction::deln(deln.prefix).map_err(|_| NetError::PrefixNotAscii)
86                }
87            }
88        }
89    }
90
91    impl NetValidate<Response> for NetResponse {
92        fn validate(self) -> Result<Response, NetError> {
93            Ok(Response {
94                result: parse_response_result(self.net_response_result)?,
95                command_id: self.request_id,
96            })
97        }
98    }
99}
100
101// ============================================================================
102// RESPONSE PARSING
103// ============================================================================
104
105fn parse_response_result(net_res: Option<NetResponseResult>) -> Result<ResponseResult, NetError> {
106    let Some(result) = net_res else {
107        return Ok(ResponseResult::Empty);
108    };
109
110    let success_val = match result {
111        NetResponseResult::Error(err) => return Err(NetError::NetError(err.message)),
112        NetResponseResult::Success(success_val) => success_val,
113    };
114
115    let body = success_val.body.ok_or(NetError::ResponseBodyEmpty)?;
116    match body {
117        Body::SingleValue(val) => Ok(ResponseResult::Data(val)),
118        Body::MultiValue(vals) => Ok(ResponseResult::Datas(vals.values)),
119    }
120}
121
122// ============================================================================
123// TRAITS
124// ============================================================================
125
126/// Validate network messages into typed commands
127pub trait NetValidate<T>
128where
129    Self: Sized,
130{
131    fn validate(self) -> Result<T, NetError>;
132}
133
134/// Read messages from TCP stream
135pub trait FromStream
136where
137    Self: Sized,
138{
139    fn from_stream(
140        stream: &mut OwnedReadHalf<TcpStream>,
141        buffer: &mut Vec<u8>,
142    ) -> std::io::Result<Self>;
143}
144
145/// Encode messages for network transmission
146pub trait NetEncode<T: IoBufMut> {
147    fn net_encode(&self, buffer: &mut T) -> Result<(), EncodeError>;
148}
149
150impl<T> NetEncode<BytesMut> for T
151where
152    T: Message,
153{
154    fn net_encode(&self, buffer: &mut BytesMut) -> Result<(), EncodeError> {
155        let start_idx = buffer.len();
156        // Write 4-byte placeholder for message size
157        buffer.extend_from_slice(0u32.to_be_bytes().as_slice());
158        self.encode(buffer)?;
159        // Update size field with actual message length
160        let msg_len_bytes = ((buffer.len() - start_idx - size_of::<u32>()) as u32).to_be_bytes();
161        for i in 0..4 {
162            buffer[i + start_idx] = msg_len_bytes[i];
163        }
164        Ok(())
165    }
166}
167
168// ============================================================================
169// TESTS
170// ============================================================================
171
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use prost::Message;

    use crate::{
        NetEncode,
        network::{NetCommand, net_command::NetAction},
    };

    /// Splits one frame into its 4-byte big-endian length prefix and payload.
    /// Asserts that the prefix matches the payload size, then decodes the payload.
    fn decode_frame(frame: &[u8]) -> NetCommand {
        let (prefix, payload) = frame.split_at(4);
        let declared_len = u32::from_be_bytes(prefix.try_into().expect("4-byte length prefix"));
        assert_eq!(
            declared_len as usize,
            payload.len(),
            "length prefix does not match payload size"
        );
        NetCommand::decode(payload).expect("decoding error")
    }

    /// Encodes `command` into a fresh buffer and checks that the buffer holds
    /// exactly one valid frame that decodes back to `command`.
    fn assert_round_trip(command: &NetCommand) {
        let mut buffer = BytesMut::new();
        command.net_encode(&mut buffer).expect("encoding error");
        assert_eq!(&decode_frame(&buffer), command);
    }

    #[test]
    fn test_encoding_get() {
        let command = NetCommand {
            request_id: 0,
            net_action: Some(NetAction::Get(crate::network::NetGetRequest {
                key: "user:1".into(),
            })),
        };
        assert_round_trip(&command);
    }

    #[test]
    fn test_encoding_getn() {
        let command = NetCommand {
            request_id: 1,
            net_action: Some(NetAction::Getn(crate::network::NetGetNRequest {
                prefix: "user".into(),
            })),
        };
        assert_round_trip(&command);
    }

    #[test]
    fn test_encoding_deln() {
        let command = NetCommand {
            request_id: 2,
            net_action: Some(NetAction::Deln(crate::network::NetDelNRequest {
                prefix: "session".into(),
            })),
        };
        assert_round_trip(&command);
    }

    /// Frames appended to a non-empty buffer must be prefixed at their own offset.
    #[test]
    fn test_encoding_appends_frames() {
        let first = NetCommand {
            request_id: 3,
            net_action: Some(NetAction::Get(crate::network::NetGetRequest {
                key: "user:2".into(),
            })),
        };
        let second = NetCommand {
            request_id: 4,
            net_action: Some(NetAction::Getn(crate::network::NetGetNRequest {
                prefix: "order".into(),
            })),
        };

        let mut buffer = BytesMut::new();
        first.net_encode(&mut buffer).expect("encoding error");
        let first_frame_len = buffer.len();
        second.net_encode(&mut buffer).expect("encoding error");

        let (first_frame, second_frame) = buffer.split_at(first_frame_len);
        assert_eq!(decode_frame(first_frame), first);
        assert_eq!(decode_frame(second_frame), second);
    }
}