// motorcortex_rust/request.rs

use std::ffi::{CStr, CString};
use std::ptr;

use nng_c_sys::nng_close;
use prost::{DecodeError, Message};

use crate::get_parameter_value::GetParameterValue;
use crate::motorcortex_msg::{
    DataType, LoginMsg, LogoutMsg, SetParameterMsg, StatusCode, StatusMsg,
};
use crate::parameter_tree::ParameterTree;
use crate::set_parameter_value::SetParameterValue;
use crate::{
    get_hash, get_hash_size, GetParameterMsg, GetParameterTreeMsg, Hash, ParameterMsg,
    ParameterTreeMsg,
};

/// Represents the possible states of the connection.
///
/// The type is `Copy` and comparable, so a state can be passed by value and checked with
/// `==` / `matches!`.
///
/// NOTE(review): nothing in `request.rs` produces a `ConnectionState`; confirm the exact
/// meaning of each variant against the code that reports it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    Connecting,
    ConnectionOk,
    ConnectionLost,
    ConnectionFailed,
    Disconnecting,
    Disconnected,
}

/// Options used to configure the connection settings.
///
/// Plain configuration data: it can be cloned, compared and printed for diagnostics.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionOptions {
    /// Path to the CA certificate file used for TLS. `Request::connect` only sets up TLS
    /// when this string is non-empty.
    pub certificate: String,
    /// Connection timeout in milliseconds.
    ///
    /// NOTE(review): `Request::connect` in this file does not read this field.
    pub conn_timeout_ms: u32,
    /// I/O timeout in milliseconds.
    pub io_timeout_ms: u32,
}

37impl ConnectionOptions {
38    /// Creates a new `ConnectionOptions` instance.
39    ///
40    /// # Arguments:
41    /// * `certificate` - Path to the TLS certificate.
42    /// * `conn_timeout_ms` - Timeout value for connection establishment, in milliseconds.
43    /// * `io_timeout_ms` - Timeout value for I/O operations, in milliseconds.
44    pub fn new(certificate: String, conn_timeout_ms: u32, io_timeout_ms: u32) -> Self {
45        Self {
46            certificate,
47            conn_timeout_ms,
48            io_timeout_ms,
49        }
50    }
51}
52
53/// Represents a request client that manages socket-based connections
54/// and interacts with the server.
55pub struct Request {
56    /// NNG socket instance for communications.
57    sock: Option<nng_c_sys::nng_socket>,
58    /// Optional TLS configuration for secure communication.
59    tls_cfg: Option<*mut nng_c_sys::nng_tls_config>,
60    /// Optional NNG dialer for managing the connection lifecycle.
61    dialer: Option<nng_c_sys::nng_dialer>,
62    /// Client's local representation of the parameter hierarchy.
63    parameter_tree: ParameterTree,
64}
65
66impl Request {
67    /// Creates a new `Request` instance with uninitialized members.
68    pub fn new() -> Self {
69        Self {
70            sock: None,
71            tls_cfg: None,
72            dialer: None,
73            parameter_tree: ParameterTree::new(),
74        }
75    }
76
77    /// Establishes a connection to a specified server URL using the given configuration options.
78    ///
79    /// # Arguments:
80    /// * `url` - The server's address (e.g., "tcp://127.0.0.1:5555").
81    /// * `connection_options` - The connection settings, including TLS certificate and timeouts.
82    ///
83    /// # Returns:
84    /// * `Ok(())` if the connection is successful.
85    /// * `Err(String)` with the error message if the operation fails.
86    pub fn connect(
87        &mut self,
88        url: &str,
89        connection_options: ConnectionOptions,
90    ) -> Result<(), String> {
91        self.sock = unsafe {
92            // Create socket
93            let mut sock: nng_c_sys::nng_socket = std::mem::zeroed();
94            let rv = nng_c_sys::nng_req0_open(&mut sock);
95            if rv != 0 {
96                return Err(format!(
97                    "Failed to open req0 socket: {} ({})",
98                    self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
99                    rv
100                ));
101            }
102
103            Some(sock)
104        };
105
106        if !connection_options.certificate.is_empty() {
107            let ca_file = CString::new(connection_options.certificate).unwrap();
108            self.tls_cfg = unsafe {
109                // Create tls config
110                let mut tls_cfg: *mut nng_c_sys::nng_tls_config = ptr::null_mut();
111                let mut rv = nng_c_sys::nng_tls_config_alloc(
112                    &mut tls_cfg,
113                    nng_c_sys::nng_tls_mode::NNG_TLS_MODE_CLIENT,
114                );
115                if rv != 0 {
116                    return Err(format!(
117                        "Failed to allocate certificate: {} ({})",
118                        self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
119                        rv
120                    ));
121                }
122
123                rv = nng_c_sys::nng_tls_config_ca_file(tls_cfg, ca_file.as_ptr());
124                if rv != 0 {
125                    return Err(format!(
126                        "Failed to load certificate from path {}, error: {} ({})",
127                        ca_file.into_string().unwrap(),
128                        self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
129                        rv
130                    ));
131                }
132
133                rv = nng_c_sys::nng_socket_set_ptr(
134                    self.sock.unwrap(),
135                    nng_c_sys::NNG_OPT_TLS_CONFIG.as_ptr() as *const i8,
136                    tls_cfg as *mut _,
137                );
138
139                if rv != 0 {
140                    return Err(format!(
141                        "Failed to apply certificate to the socket: {} ({})",
142                        self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
143                        rv
144                    ));
145                }
146                Some(tls_cfg)
147            };
148        }
149
150        self.dialer = unsafe {
151            let mut dialer: nng_c_sys::nng_dialer = std::mem::zeroed();
152            let c_url = CString::new(url).unwrap();
153            let rv = nng_c_sys::nng_dialer_create(&mut dialer, self.sock.unwrap(), c_url.as_ptr());
154            if rv != 0 {
155                return Err(format!(
156                    "Failed to create dealer: {} ({})",
157                    self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
158                    rv
159                ));
160            }
161
162            let rv = nng_c_sys::nng_dialer_start(dialer, 0);
163            if rv != 0 {
164                return Err(format!(
165                    "Failed to start dealer: {} ({})",
166                    self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
167                    rv
168                ));
169            }
170
171            Some(dialer)
172        };
173
174        Ok(())
175    }
176
177    /// Disconnects the current connection and frees the associated resources.
178    ///
179    /// # Returns:
180    /// * `Ok(())` if the disconnection is successful.
181    /// * `Err(String)` with the error message if the operation fails.
182    pub fn disconnect(&mut self) -> Result<(), String> {
183        unsafe {
184            // Close the socket if it exists
185            if let Some(sock) = self.sock.take() {
186                let rv = nng_close(sock);
187                if rv != 0 {
188                    return Err(format!(
189                        "Failed to close socket: {} ({})",
190                        self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
191                        rv
192                    ));
193                }
194            }
195
196            // Free the TLS configuration if it exists
197            if let Some(tls_cfg) = self.tls_cfg.take() {
198                nng_c_sys::nng_tls_config_free(tls_cfg);
199            }
200        }
201        Ok(())
202    }
203
204    /// Sends a login message to the server with the specified username and password.
205    ///
206    /// # Arguments:
207    /// * `username` - The username to authenticate with.
208    /// * `password` - The password for authentication.
209    ///
210    /// # Returns:
211    /// * `Ok(StatusCode)` - The status code representing the login response.
212    /// * `Err(String)` - An error message if the login fails.
213    pub fn login(&self, username: String, password: String) -> Result<StatusCode, String> {
214        let login_msg = LoginMsg {
215            header: None,
216            login: username,
217            password,
218        };
219
220        let buffer = Self::encode_with_hash(&login_msg)
221            .map_err(|e| format!("Failed to encode LoginMsg: {:?}", e))?;
222        self.send_message(&buffer)?;
223
224        let slice = self
225            .receive_message()
226            .map_err(|_| "Failed to receive status message".to_string())?;
227        let msg = Self::decode_status_msg(slice)
228            .map_err(|e| format!("Failed to decode status message: {:?}", e))?;
229
230        Ok(StatusCode::try_from(msg.status).unwrap())
231    }
232
233    /// Sends a logout message to the server.
234    ///
235    /// # Returns:
236    /// * `Ok(StatusCode)` - The status code representing the logout response.
237    /// * `Err(String)` - An error message if the logout fails.
238    pub fn logout(&self) -> Result<StatusCode, String> {
239        let logout_msg = LogoutMsg { header: None };
240
241        let buffer = Self::encode_with_hash(&logout_msg)
242            .map_err(|e| format!("Failed to encode LogoutMsg: {:?}", e))?;
243        self.send_message(&buffer)?;
244
245        let slice = self
246            .receive_message()
247            .map_err(|_| "Failed to receive status message".to_string())?;
248        let msg = Self::decode_status_msg(slice)
249            .map_err(|e| format!("Failed to decode status message: {:?}", e))?;
250
251        Ok(StatusCode::try_from(msg.status).unwrap())
252    }
253
254    /// Requests and updates the client's parameter tree from the server.
255    ///
256    /// # Returns:
257    /// * `Ok(StatusCode)` - The status code indicating the result of the request.
258    /// * `Err(String)` - An error message if the request fails.
259    pub fn request_parameter_tree(&mut self) -> Result<StatusCode, String> {
260        match self.get_parameter_tree() {
261            Ok((status_code, parameter_tree)) => {
262                self.parameter_tree = parameter_tree;
263                Ok(status_code)
264            }
265            Err(e) => Err(e),
266        }
267    }
268
269    /// Updates a specific parameter on the server with the provided value.
270    ///
271    /// # Arguments:
272    /// * `path` - The hierarchical path to the parameter being updated.
273    /// * `value` - The new value for the parameter. Must implement `SetParameterValue`.
274    ///
275    /// # Returns:
276    /// * `Ok(StatusCode)` - The status code after setting the parameter.
277    /// * `Err(String)` - An error message if the update fails.
278    pub fn set_parameter<V>(&self, path: &str, value: V) -> Result<StatusCode, String>
279    where
280        V: SetParameterValue + Default + PartialEq,
281    {
282        let data_type = self
283            .parameter_tree
284            .get_parameter_data_type(&path)
285            .ok_or((
286                StatusCode::WrongParameterPath,
287                format!("Parameter data type not found for path: {}", path),
288            ))
289            .unwrap();
290
291        let mut msg = SetParameterMsg {
292            header: None,
293            offset: None,
294            path: path.to_string(),
295            value: vec![],
296        };
297
298        match DataType::try_from(data_type as i32).unwrap() {
299            DataType::Bool => {
300                msg.value = value.to_bytes_as_bool();
301            }
302            DataType::Int8 => {
303                msg.value = value.to_bytes_as_i8();
304            }
305            DataType::Uint8 => {
306                msg.value = value.to_bytes_as_u8();
307            }
308            DataType::Int16 => {
309                msg.value = value.to_bytes_as_i16();
310            }
311            DataType::Uint16 => {
312                msg.value = value.to_bytes_as_u16();
313            }
314            DataType::Int32 => {
315                msg.value = value.to_bytes_as_i32();
316            }
317            DataType::Uint32 => {
318                msg.value = value.to_bytes_as_u32();
319            }
320            DataType::Int64 => {
321                msg.value.extend(value.to_bytes_as_i64());
322            }
323            DataType::Uint64 => {
324                msg.value = value.to_bytes_as_u64();
325            }
326            DataType::Float => {
327                msg.value.extend(value.to_bytes_as_f32());
328            }
329            DataType::Double => {
330                msg.value.extend(value.to_bytes_as_f64());
331            }
332            DataType::String => {
333                msg.value.extend(value.to_bytes_as_string());
334            }
335            _ => println!("Not implemented"),
336        }
337
338        let buffer = Self::encode_with_hash(&msg)
339            .map_err(|e| format!("Failed to encode SetParameter: {:?}", e))?;
340        self.send_message(&buffer)?;
341
342        let slice = self
343            .receive_message()
344            .map_err(|_| "Failed to receive status message".to_string())?;
345        let msg = Self::decode_status_msg(slice)
346            .map_err(|e| format!("Failed to decode status message: {:?}", e))?;
347
348        Ok(StatusCode::try_from(msg.status).unwrap())
349    }
350
351    /// Retrieves the value of a parameter from the server for the given path.
352    ///
353    /// # Arguments:
354    /// * `path` - The hierarchical path of the parameter to retrieve.
355    ///
356    /// # Type Parameters:
357    /// * `V` - The expected value type of the parameter. This type must implement the `GetParameterValue` trait
358    ///         to properly decode the parameter's value from the server response.
359    ///
360    /// # Returns:
361    /// * `Ok(V)` - The parameter value successfully retrieved and decoded into the type `V`.
362    /// * `Err(String)` - An error message if the retrieval or decoding fails.
363    ///
364    /// # Errors:
365    /// This function will return an error if:
366    /// * The parameter's data type cannot be identified.
367    /// * The path to the parameter is invalid or non-existent.
368    /// * There is an issue encoding the request or decoding the response from the server.
369    ///
370    /// # Example:
371    /// ```rust
372    /// # use your_crate::Request;
373    /// # // Assume `Request::get_parameter` is implemented as shown.
374    /// let request = Request::new();
375    /// let connection_options = ConnectionOptions::new(
376    ///     "cert.pem".to_string(),
377    ///     1000,
378    ///     1000,
379    /// );
380    /// request.connect("tcp://127.0.0.1:5555", connection_options).unwrap();
381    ///
382    /// // Retrieve a numeric parameter
383    /// let param_value: i32 = request.get_parameter("parameter.path").unwrap();
384    /// println!("Parameter value: {}", param_value);
385    /// ```
386    pub fn get_parameter<V>(&self, path: &str) -> Result<V, String>
387    where
388        V: GetParameterValue,
389    {
390        let data_type = self
391            .parameter_tree
392            .get_parameter_data_type(&path)
393            .ok_or((
394                StatusCode::WrongParameterPath,
395                format!("Parameter data type not found for path: {}", path),
396            ))
397            .unwrap();
398
399        let msg = GetParameterMsg {
400            header: None,
401            path: path.to_string(),
402        };
403
404        let buffer = Self::encode_with_hash(&msg)
405            .map_err(|e| format!("Failed to encode GetParameter: {:?}", e))?;
406        self.send_message(&buffer)?;
407
408        let slice = self
409            .receive_message()
410            .map_err(|_| "Failed to receive parameter message".to_string())?;
411        let msg = Self::decode_parameter_msg(slice)
412            .map_err(|e| format!("Failed to decode parameter message: {:?}", e))?;
413
414        println!("{:?}", msg);
415
416        match DataType::try_from(data_type as i32).unwrap() {
417            DataType::Bool => {
418                return V::from_bytes_as_bool(&msg.value);
419            }
420            DataType::Int8 => {
421                return V::from_bytes_as_i8(&msg.value);
422            }
423            DataType::Uint8 => {
424                return V::from_bytes_as_u8(&msg.value);
425            }
426            DataType::Int16 => {
427                return V::from_bytes_as_i16(&msg.value);
428            }
429            DataType::Uint16 => {
430                return V::from_bytes_as_u16(&msg.value);
431            }
432            DataType::Int32 => {
433                return V::from_bytes_as_i32(&msg.value);
434            }
435            DataType::Uint32 => {
436                return V::from_bytes_as_u32(&msg.value);
437            }
438            DataType::Int64 => {
439                return V::from_bytes_as_i64(&msg.value);
440            }
441            DataType::Uint64 => {
442                return V::from_bytes_as_u64(&msg.value);
443            }
444            DataType::Float => {
445                return V::from_bytes_as_f32(&msg.value);
446            }
447            DataType::Double => {
448                return V::from_bytes_as_f64(&msg.value);
449            }
450            DataType::String => {
451                return V::from_bytes_as_string(&msg.value);
452            }
453            _ => println!("Not implemented"),
454        }
455
456        Err("Not implemented".to_string())
457        // V::decode_from_bytes(&msg.value)
458    }
459
460    /// Retrieves the server's parameter tree.
461    ///
462    /// This function sends a `GetParameterTreeMsg` request to the server and decodes the response
463    /// into a `ParameterTree` object along with a status code indicating the result.
464    ///
465    /// # Returns:
466    /// * `Ok((StatusCode, ParameterTree))` - A tuple containing the status code and the retrieved `ParameterTree`.
467    /// * `Err(String)` - An error message if retrieving or decoding the parameter tree fails.
468    pub fn get_parameter_tree(&self) -> Result<(StatusCode, ParameterTree), String> {
469        let get_parameter_tree = GetParameterTreeMsg { header: None };
470        let buffer = Self::encode_with_hash(&get_parameter_tree)
471            .map_err(|e| format!("Failed to encode GetParameterTreeMsg: {:?}", e))?;
472        self.send_message(&buffer)?;
473
474        let slice = self
475            .receive_message()
476            .map_err(|_| "Failed to receive parameter tree message".to_string())?;
477        let msg = Self::decode_parameter_tree_msg(slice)
478            .map_err(|e| format!("Failed to decode parameter tree message: {:?}", e))?;
479
480        match ParameterTree::from_message(msg) {
481            Some(parameter_tree) => Ok((StatusCode::Ok, parameter_tree)),
482            None => Err("Failed to create ParameterTree: Invalid status code.".to_string()),
483        }
484    }
485
486    /// Utility function to convert NNG error messages to human-readable Rust strings.
487    ///
488    /// # Arguments:
489    /// * `err_msg` - A pointer to the raw NNG error message (C string).
490    ///
491    /// # Returns:
492    /// * A `String` representation of the error.
493    fn nng_error_to_string(&self, err_msg: *const core::ffi::c_char) -> String {
494        unsafe {
495            if err_msg.is_null() {
496                // If the pointer is null, return a default error message
497                return "Unknown error".to_string();
498            }
499
500            // Convert the `*const c_char` into a &CStr
501            let c_str = CStr::from_ptr(err_msg);
502
503            // Safely convert the C string into a Rust String
504            c_str.to_string_lossy().into_owned()
505        }
506    }
507
508    /// Encodes a message with its associated hash into a byte buffer for transport.
509    ///
510    /// # Arguments:
511    /// * `message` - The message to encode. Must implement `prost::Message` and `Hash`.
512    ///
513    /// # Returns:
514    /// * `Ok(Vec<u8>)` - The encoded message as a byte buffer.
515    /// * `Err(String)` - An error message if encoding fails.
516    fn encode_with_hash<M: Message + Hash>(message: &M) -> Result<Vec<u8>, String> {
517        let mut buffer: Vec<u8> = Vec::new();
518        buffer.extend(get_hash::<M>().to_le_bytes());
519        message
520            .encode(&mut buffer)
521            .map_err(|e| format!("Failed to encode message: {:?}", e))?;
522
523        Ok(buffer)
524    }
525
526    /// Decodes a byte slice into a specific Protobuf message type.
527    ///
528    /// # Arguments:
529    /// * `reply_slice` - The received byte slice containing the message data.
530    ///
531    /// # Returns:
532    /// * `Ok(T)` - The decoded message of type `T`.
533    /// * `Err(DecodeError)` - An error if decoding fails.
534    pub fn decode_message<T: Message + Default + Hash>(
535        reply_slice: &[u8],
536    ) -> Result<T, DecodeError> {
537        let hash_size = get_hash_size();
538
539        // Verify the provided slice is large enough to contain the hash
540        if hash_size > reply_slice.len() {
541            return Err(DecodeError::new("Invalid message length, hash missing"));
542        }
543
544        // Extract the hash from the slice
545        let provided_hash = u32::from_le_bytes(
546            reply_slice[..hash_size]
547                .try_into()
548                .map_err(|_| DecodeError::new("Failed to extract hash"))?,
549        );
550
551        // Validate the provided hash against the expected one for the generic type T
552        if provided_hash != get_hash::<T>() {
553            return Err(DecodeError::new("Invalid message hash"));
554        }
555
556        // Decode the rest of the slice into a Protobuf message
557        let decode_slice = &reply_slice[hash_size..];
558        T::decode(decode_slice)
559    }
560
561    /// Decodes a `ParameterTreeMsg` from a byte slice.
562    ///
563    /// This is a helper function for decoding parameter tree messages sent from the server.
564    ///
565    /// # Arguments:
566    /// * `reply_slice` - The received buffer containing the encoded `ParameterTreeMsg`.
567    ///
568    /// # Returns:
569    /// * `Ok(ParameterTreeMsg)` - The decoded parameter tree message.
570    /// * `Err(DecodeError)` - An error if the decoding fails.
571    fn decode_parameter_tree_msg(reply_slice: &[u8]) -> Result<ParameterTreeMsg, DecodeError> {
572        Self::decode_message::<ParameterTreeMsg>(reply_slice)
573    }
574
575    /// Decodes a `StatusMsg` from a received response slice.
576    ///
577    /// This is a helper function for decoding `StatusMsg` responses from the server.
578    ///
579    /// # Arguments:
580    /// * `reply_slice` - The received buffer containing the encoded `StatusMsg`.
581    ///
582    /// # Returns:
583    /// * `Ok(StatusMsg)` - The decoded status message.
584    /// * `Err(DecodeError)` - An error if decoding the `StatusMsg` fails.
585    fn decode_status_msg(reply_slice: &[u8]) -> Result<StatusMsg, DecodeError> {
586        Self::decode_message::<StatusMsg>(reply_slice)
587    }
588
589    fn decode_parameter_msg(reply_slice: &[u8]) -> Result<ParameterMsg, DecodeError> {
590        Self::decode_message::<ParameterMsg>(reply_slice)
591    }
592
593    /// Receives a raw message from the server using the NNG socket.
594    ///
595    /// This function allocates a buffer to receive the message
596    /// and returns the message data as a byte slice.
597    ///
598    /// # Returns:
599    /// * `Ok(&[u8])` - A byte slice containing the received message data.
600    /// * `Err(String)` - An error message if reception fails.
601    fn receive_message(&self) -> Result<&[u8], String> {
602        unsafe {
603            // Allocate memory to receive the response
604            let mut reply_ptr = ptr::null_mut();
605            let mut reply_size = 0usize;
606
607            // Use nng_recv to wait for and receive the response
608            let rv = nng_c_sys::nng_recv(
609                self.sock.unwrap(),
610                &mut reply_ptr as *mut _ as *mut std::ffi::c_void,
611                &mut reply_size,
612                nng_c_sys::NNG_FLAG_ALLOC,
613            );
614
615            // Check for errors in the reception operation
616            if rv != 0 {
617                return Err(format!(
618                    "Failed to receive response via NNG. Error code: {}",
619                    rv
620                ));
621            }
622            Ok(std::slice::from_raw_parts(
623                reply_ptr as *const u8,
624                reply_size,
625            ))
626        }
627    }
628
629    /// Sends the provided data buffer to the server using the NNG socket.
630    ///
631    /// This function sends a message buffer over the active NNG socket connection.
632    ///
633    /// # Arguments:
634    /// * `buffer` - A byte slice containing the message to send.
635    ///
636    /// # Returns:
637    /// * `Ok(())` - If the message was successfully sent.
638    /// * `Err(String)` - If sending the message fails.
639    fn send_message(&self, buffer: &[u8]) -> Result<(), String> {
640        unsafe {
641            // Create raw pointer for the buffer
642            let data_ptr = buffer.as_ptr() as *mut std::ffi::c_void;
643            let data_len = buffer.len();
644
645            // Use nng_send API
646            let sock = self.sock.ok_or("Socket is not available. Connect first.")?;
647            let rv = nng_c_sys::nng_send(sock, data_ptr, data_len, 0);
648
649            // Check if the send operation was successful
650            if rv != 0 {
651                return Err(format!(
652                    "Failed to send message via NNG. Error code: {}",
653                    rv
654                ));
655            }
656        }
657
658        Ok(())
659    }
660}