motorcortex_rust/
request.rs

1use crate::get_parameter_value::GetParameterValue;
2use crate::motorcortex_msg::{
3    DataType, LoginMsg, LogoutMsg, SetParameterMsg, StatusCode, StatusMsg,
4};
5use crate::parameter_tree::ParameterTree;
6use crate::set_parameter_value::SetParameterValue;
7use crate::{
8    get_hash, get_hash_size, GetParameterMsg, GetParameterTreeMsg, Hash, ParameterMsg,
9    ParameterTreeMsg,
10};
11use nng_c_sys::nng_close;
12use prost::{DecodeError, Message};
13use std::ffi::{CStr, CString};
14use std::ptr;
15
/// Represents the possible states of the connection.
///
/// The enum is plain data: it can be copied, compared and printed with
/// `{:?}`.
///
/// NOTE(review): nothing in this part of the file constructs or matches on
/// these variants, so their exact meaning is implied only by their names --
/// confirm against the code that reports connection state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    Connecting,
    ConnectionOk,
    ConnectionLost,
    ConnectionFailed,
    Disconnecting,
    Disconnected,
}
26
/// Options used to configure the connection settings.
///
/// The value is plain data and can be cloned, compared and debug-printed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionOptions {
    /// Path to the TLS certificate used for secure communication.
    ///
    /// An empty string means that no TLS configuration is created when
    /// connecting.
    pub certificate: String,
    /// Connection timeout in milliseconds.
    ///
    /// NOTE(review): this value is not read anywhere in this file.
    pub conn_timeout_ms: u32,
    /// I/O timeout in milliseconds.
    ///
    /// NOTE(review): this value is not read anywhere in this file.
    pub io_timeout_ms: u32,
}

impl ConnectionOptions {
    /// Creates a new `ConnectionOptions` instance.
    ///
    /// # Arguments:
    /// * `certificate` - Path to the TLS certificate (empty to connect without TLS).
    /// * `conn_timeout_ms` - Timeout value for connection establishment, in milliseconds.
    /// * `io_timeout_ms` - Timeout value for I/O operations, in milliseconds.
    ///
    /// # Returns:
    /// * The options with every field stored exactly as given.
    pub fn new(certificate: String, conn_timeout_ms: u32, io_timeout_ms: u32) -> Self {
        Self {
            certificate,
            conn_timeout_ms,
            io_timeout_ms,
        }
    }
}
52
53/// Represents a request client that manages socket-based connections
54/// and interacts with the server.
55pub struct Request {
56    /// NNG socket instance for communications.
57    sock: Option<nng_c_sys::nng_socket>,
58    /// Optional TLS configuration for secure communication.
59    tls_cfg: Option<*mut nng_c_sys::nng_tls_config>,
60    /// Optional NNG dialer for managing the connection lifecycle.
61    dialer: Option<nng_c_sys::nng_dialer>,
62    /// Client's local representation of the parameter hierarchy.
63    parameter_tree: ParameterTree,
64}
65
66impl Request {
67    /// Creates a new `Request` instance with uninitialized members.
68    pub fn new() -> Self {
69        Self {
70            sock: None,
71            tls_cfg: None,
72            dialer: None,
73            parameter_tree: ParameterTree::new(),
74        }
75    }
76
77    /// Establishes a connection to a specified server URL using the given configuration options.
78    ///
79    /// # Arguments:
80    /// * `url` - The server's address (e.g., "tcp://127.0.0.1:5555").
81    /// * `connection_options` - The connection settings, including TLS certificate and timeouts.
82    ///
83    /// # Returns:
84    /// * `Ok(())` if the connection is successful.
85    /// * `Err(String)` with the error message if the operation fails.
86    pub fn connect(
87        &mut self,
88        url: &str,
89        connection_options: ConnectionOptions,
90    ) -> Result<(), String> {
91        self.sock = unsafe {
92            // Create socket
93            let mut sock: nng_c_sys::nng_socket = std::mem::zeroed();
94            let rv = nng_c_sys::nng_req0_open(&mut sock);
95            if rv != 0 {
96                return Err(format!(
97                    "Failed to open req0 socket: {} ({})",
98                    self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
99                    rv
100                ));
101            }
102
103            Some(sock)
104        };
105
106        if !connection_options.certificate.is_empty() {
107            let ca_file = CString::new(connection_options.certificate).unwrap();
108            self.tls_cfg = unsafe {
109                // Create tls config
110                let mut tls_cfg: *mut nng_c_sys::nng_tls_config = ptr::null_mut();
111                let mut rv = nng_c_sys::nng_tls_config_alloc(
112                    &mut tls_cfg,
113                    nng_c_sys::nng_tls_mode::NNG_TLS_MODE_CLIENT,
114                );
115                if rv != 0 {
116                    return Err(format!(
117                        "Failed to allocate certificate: {} ({})",
118                        self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
119                        rv
120                    ));
121                }
122
123                rv = nng_c_sys::nng_tls_config_ca_file(tls_cfg, ca_file.as_ptr());
124                if rv != 0 {
125                    return Err(format!(
126                        "Failed to load certificate from path {}, error: {} ({})",
127                        ca_file.into_string().unwrap(),
128                        self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
129                        rv
130                    ));
131                }
132
133                rv = nng_c_sys::nng_socket_set_ptr(
134                    self.sock.unwrap(),
135                    nng_c_sys::NNG_OPT_TLS_CONFIG.as_ptr() as *const i8,
136                    tls_cfg as *mut _,
137                );
138
139                if rv != 0 {
140                    return Err(format!(
141                        "Failed to apply certificate to the socket: {} ({})",
142                        self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
143                        rv
144                    ));
145                }
146                Some(tls_cfg)
147            };
148        }
149
150        self.dialer = unsafe {
151            let mut dialer: nng_c_sys::nng_dialer = std::mem::zeroed();
152            let c_url = CString::new(url).unwrap();
153            let rv = nng_c_sys::nng_dialer_create(&mut dialer, self.sock.unwrap(), c_url.as_ptr());
154            if rv != 0 {
155                return Err(format!(
156                    "Failed to create dealer: {} ({})",
157                    self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
158                    rv
159                ));
160            }
161
162            let rv = nng_c_sys::nng_dialer_start(dialer, 0);
163            if rv != 0 {
164                return Err(format!(
165                    "Failed to start dealer: {} ({})",
166                    self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
167                    rv
168                ));
169            }
170
171            Some(dialer)
172        };
173
174        Ok(())
175    }
176
177    /// Disconnects the current connection and frees the associated resources.
178    ///
179    /// # Returns:
180    /// * `Ok(())` if the disconnection is successful.
181    /// * `Err(String)` with the error message if the operation fails.
182    pub fn disconnect(&mut self) -> Result<(), String> {
183        unsafe {
184            // Close the socket if it exists
185            if let Some(sock) = self.sock.take() {
186                let rv = nng_close(sock);
187                if rv != 0 {
188                    return Err(format!(
189                        "Failed to close socket: {} ({})",
190                        self.nng_error_to_string(nng_c_sys::nng_strerror(rv)),
191                        rv
192                    ));
193                }
194            }
195
196            // Free the TLS configuration if it exists
197            if let Some(tls_cfg) = self.tls_cfg.take() {
198                nng_c_sys::nng_tls_config_free(tls_cfg);
199            }
200        }
201        Ok(())
202    }
203
204    /// Sends a login message to the server with the specified username and password.
205    ///
206    /// # Arguments:
207    /// * `username` - The username to authenticate with.
208    /// * `password` - The password for authentication.
209    ///
210    /// # Returns:
211    /// * `Ok(StatusCode)` - The status code representing the login response.
212    /// * `Err(String)` - An error message if the login fails.
213    pub fn login(&self, username: String, password: String) -> Result<StatusCode, String> {
214        let login_msg = LoginMsg {
215            header: None,
216            login: username,
217            password,
218        };
219
220        let buffer = Self::encode_with_hash(&login_msg)
221            .map_err(|e| format!("Failed to encode LoginMsg: {:?}", e))?;
222        self.send_message(&buffer)?;
223
224        let slice = self
225            .receive_message()
226            .map_err(|_| "Failed to receive status message".to_string())?;
227        let msg = Self::decode_status_msg(slice)
228            .map_err(|e| format!("Failed to decode status message: {:?}", e))?;
229
230        Ok(StatusCode::try_from(msg.status).unwrap())
231    }
232
233    /// Sends a logout message to the server.
234    ///
235    /// # Returns:
236    /// * `Ok(StatusCode)` - The status code representing the logout response.
237    /// * `Err(String)` - An error message if the logout fails.
238    pub fn logout(&self) -> Result<StatusCode, String> {
239        let logout_msg = LogoutMsg { header: None };
240
241        let buffer = Self::encode_with_hash(&logout_msg)
242            .map_err(|e| format!("Failed to encode LogoutMsg: {:?}", e))?;
243        self.send_message(&buffer)?;
244
245        let slice = self
246            .receive_message()
247            .map_err(|_| "Failed to receive status message".to_string())?;
248        let msg = Self::decode_status_msg(slice)
249            .map_err(|e| format!("Failed to decode status message: {:?}", e))?;
250
251        Ok(StatusCode::try_from(msg.status).unwrap())
252    }
253
254    /// Requests and updates the client's parameter tree from the server.
255    ///
256    /// # Returns:
257    /// * `Ok(StatusCode)` - The status code indicating the result of the request.
258    /// * `Err(String)` - An error message if the request fails.
259    pub fn request_parameter_tree(&mut self) -> Result<StatusCode, String> {
260        match self.get_parameter_tree() {
261            Ok((status_code, parameter_tree)) => {
262                self.parameter_tree = parameter_tree;
263                Ok(status_code)
264            }
265            Err(e) => Err(e),
266        }
267    }
268
269    /// Updates a specific parameter on the server with the provided value.
270    ///
271    /// # Arguments:
272    /// * `path` - The hierarchical path to the parameter being updated.
273    /// * `value` - The new value for the parameter. Must implement `SetParameterValue`.
274    ///
275    /// # Returns:
276    /// * `Ok(StatusCode)` - The status code after setting the parameter.
277    /// * `Err(String)` - An error message if the update fails.
278    pub fn set_parameter<V>(&self, path: &str, value: V) -> Result<StatusCode, String>
279    where
280        V: SetParameterValue + Default + PartialEq,
281    {
282        let data_type = self
283            .parameter_tree
284            .get_parameter_data_type(&path)
285            .ok_or((
286                StatusCode::WrongParameterPath,
287                format!("Parameter data type not found for path: {}", path),
288            ))
289            .unwrap();
290
291        let mut msg = SetParameterMsg {
292            header: None,
293            offset: None,
294            path: path.to_string(),
295            value: vec![],
296        };
297
298        match DataType::try_from(data_type as i32).unwrap() {
299            DataType::Bool => {
300                msg.value = value.to_bytes_as_bool();
301            }
302            DataType::Int8 => {
303                msg.value = value.to_bytes_as_i8();
304            }
305            DataType::Uint8 => {
306                msg.value = value.to_bytes_as_u8();
307            }
308            DataType::Int16 => {
309                msg.value = value.to_bytes_as_i16();
310            }
311            DataType::Uint16 => {
312                msg.value = value.to_bytes_as_u16();
313            }
314            DataType::Int32 => {
315                msg.value = value.to_bytes_as_i32();
316            }
317            DataType::Uint32 => {
318                msg.value = value.to_bytes_as_u32();
319            }
320            DataType::Int64 => {
321                msg.value.extend(value.to_bytes_as_i64());
322            }
323            DataType::Uint64 => {
324                msg.value = value.to_bytes_as_u64();
325            }
326            DataType::Float => {
327                msg.value.extend(value.to_bytes_as_f32());
328            }
329            DataType::Double => {
330                msg.value.extend(value.to_bytes_as_f64());
331            }
332            DataType::String => {
333                msg.value.extend(value.to_bytes_as_string());
334            }
335            _ => println!("Not implemented"),
336        }
337
338        let buffer = Self::encode_with_hash(&msg)
339            .map_err(|e| format!("Failed to encode SetParameter: {:?}", e))?;
340        self.send_message(&buffer)?;
341
342        let slice = self
343            .receive_message()
344            .map_err(|_| "Failed to receive status message".to_string())?;
345        let msg = Self::decode_status_msg(slice)
346            .map_err(|e| format!("Failed to decode status message: {:?}", e))?;
347
348        Ok(StatusCode::try_from(msg.status).unwrap())
349    }
350
351    /// Retrieves the value of a parameter from the server for the given path.
352    ///
353    /// # Arguments:
354    /// * `path` - The hierarchical path of the parameter to retrieve.
355    ///
356    /// # Type Parameters:
357    /// * `V` - The expected value type of the parameter. This type must implement the `GetParameterValue` trait
358    ///         to properly decode the parameter's value from the server response.
359    ///
360    /// # Returns:
361    /// * `Ok(V)` - The parameter value successfully retrieved and decoded into the type `V`.
362    /// * `Err(String)` - An error message if the retrieval or decoding fails.
363    ///
364    /// # Errors:
365    /// This function will return an error if:
366    /// * The parameter's data type cannot be identified.
367    /// * The path to the parameter is invalid or non-existent.
368    /// * There is an issue encoding the request or decoding the response from the server.
369    ///
370    /// # Example:
371    /// ```rust
372    /// # use your_crate::Request;
373    /// # // Assume `Request::get_parameter` is implemented as shown.
374    /// let request = Request::new();
375    /// let connection_options = ConnectionOptions::new(
376    ///     "cert.pem".to_string(),
377    ///     1000,
378    ///     1000,
379    /// );
380    /// request.connect("wss://127.0.0.1:5568", connection_options).unwrap();
381    /// request.request_parameter_tree().unwrap();
382    /// // Retrieve a numeric parameter
383    /// let param_value: i32 = request.get_parameter("parameter.path").unwrap();
384    /// println!("Parameter value: {}", param_value);
385    /// ```
386    pub fn get_parameter<V>(&self, path: &str) -> Result<V, String>
387    where
388        V: GetParameterValue,
389    {
390        let data_type = self
391            .parameter_tree
392            .get_parameter_data_type(&path)
393            .ok_or((
394                StatusCode::WrongParameterPath,
395                format!("Parameter data type not found for path: {}", path),
396            ))
397            .unwrap();
398
399        let msg = GetParameterMsg {
400            header: None,
401            path: path.to_string(),
402        };
403
404        let buffer = Self::encode_with_hash(&msg)
405            .map_err(|e| format!("Failed to encode GetParameter: {:?}", e))?;
406        self.send_message(&buffer)?;
407
408        let slice = self
409            .receive_message()
410            .map_err(|_| "Failed to receive parameter message".to_string())?;
411        let msg = Self::decode_parameter_msg(slice)
412            .map_err(|e| format!("Failed to decode parameter message: {:?}", e))?;
413        
414        match DataType::try_from(data_type as i32).unwrap() {
415            DataType::Bool => {
416                return V::from_bytes_as_bool(&msg.value);
417            }
418            DataType::Int8 => {
419                return V::from_bytes_as_i8(&msg.value);
420            }
421            DataType::Uint8 => {
422                return V::from_bytes_as_u8(&msg.value);
423            }
424            DataType::Int16 => {
425                return V::from_bytes_as_i16(&msg.value);
426            }
427            DataType::Uint16 => {
428                return V::from_bytes_as_u16(&msg.value);
429            }
430            DataType::Int32 => {
431                return V::from_bytes_as_i32(&msg.value);
432            }
433            DataType::Uint32 => {
434                return V::from_bytes_as_u32(&msg.value);
435            }
436            DataType::Int64 => {
437                return V::from_bytes_as_i64(&msg.value);
438            }
439            DataType::Uint64 => {
440                return V::from_bytes_as_u64(&msg.value);
441            }
442            DataType::Float => {
443                return V::from_bytes_as_f32(&msg.value);
444            }
445            DataType::Double => {
446                return V::from_bytes_as_f64(&msg.value);
447            }
448            DataType::String => {
449                return V::from_bytes_as_string(&msg.value);
450            }
451            _ => println!("Not implemented"),
452        }
453
454        Err("Not implemented".to_string())
455        // V::decode_from_bytes(&msg.value)
456    }
457
458    /// Retrieves the server's parameter tree.
459    ///
460    /// This function sends a `GetParameterTreeMsg` request to the server and decodes the response
461    /// into a `ParameterTree` object along with a status code indicating the result.
462    ///
463    /// # Returns:
464    /// * `Ok((StatusCode, ParameterTree))` - A tuple containing the status code and the retrieved `ParameterTree`.
465    /// * `Err(String)` - An error message if retrieving or decoding the parameter tree fails.
466    pub fn get_parameter_tree(&self) -> Result<(StatusCode, ParameterTree), String> {
467        let get_parameter_tree = GetParameterTreeMsg { header: None };
468        let buffer = Self::encode_with_hash(&get_parameter_tree)
469            .map_err(|e| format!("Failed to encode GetParameterTreeMsg: {:?}", e))?;
470        self.send_message(&buffer)?;
471
472        let slice = self
473            .receive_message()
474            .map_err(|_| "Failed to receive parameter tree message".to_string())?;
475        let msg = Self::decode_parameter_tree_msg(slice)
476            .map_err(|e| format!("Failed to decode parameter tree message: {:?}", e))?;
477
478        match ParameterTree::from_message(msg) {
479            Some(parameter_tree) => Ok((StatusCode::Ok, parameter_tree)),
480            None => Err("Failed to create ParameterTree: Invalid status code.".to_string()),
481        }
482    }
483
484    /// Utility function to convert NNG error messages to human-readable Rust strings.
485    ///
486    /// # Arguments:
487    /// * `err_msg` - A pointer to the raw NNG error message (C string).
488    ///
489    /// # Returns:
490    /// * A `String` representation of the error.
491    fn nng_error_to_string(&self, err_msg: *const core::ffi::c_char) -> String {
492        unsafe {
493            if err_msg.is_null() {
494                // If the pointer is null, return a default error message
495                return "Unknown error".to_string();
496            }
497
498            // Convert the `*const c_char` into a &CStr
499            let c_str = CStr::from_ptr(err_msg);
500
501            // Safely convert the C string into a Rust String
502            c_str.to_string_lossy().into_owned()
503        }
504    }
505
506    /// Encodes a message with its associated hash into a byte buffer for transport.
507    ///
508    /// # Arguments:
509    /// * `message` - The message to encode. Must implement `prost::Message` and `Hash`.
510    ///
511    /// # Returns:
512    /// * `Ok(Vec<u8>)` - The encoded message as a byte buffer.
513    /// * `Err(String)` - An error message if encoding fails.
514    fn encode_with_hash<M: Message + Hash>(message: &M) -> Result<Vec<u8>, String> {
515        let mut buffer: Vec<u8> = Vec::new();
516        buffer.extend(get_hash::<M>().to_le_bytes());
517        message
518            .encode(&mut buffer)
519            .map_err(|e| format!("Failed to encode message: {:?}", e))?;
520
521        Ok(buffer)
522    }
523
524    /// Decodes a byte slice into a specific Protobuf message type.
525    ///
526    /// # Arguments:
527    /// * `reply_slice` - The received byte slice containing the message data.
528    ///
529    /// # Returns:
530    /// * `Ok(T)` - The decoded message of type `T`.
531    /// * `Err(DecodeError)` - An error if decoding fails.
532    pub fn decode_message<T: Message + Default + Hash>(
533        reply_slice: &[u8],
534    ) -> Result<T, DecodeError> {
535        let hash_size = get_hash_size();
536
537        // Verify the provided slice is large enough to contain the hash
538        if hash_size > reply_slice.len() {
539            return Err(DecodeError::new("Invalid message length, hash missing"));
540        }
541
542        // Extract the hash from the slice
543        let provided_hash = u32::from_le_bytes(
544            reply_slice[..hash_size]
545                .try_into()
546                .map_err(|_| DecodeError::new("Failed to extract hash"))?,
547        );
548
549        // Validate the provided hash against the expected one for the generic type T
550        if provided_hash != get_hash::<T>() {
551            return Err(DecodeError::new("Invalid message hash"));
552        }
553
554        // Decode the rest of the slice into a Protobuf message
555        let decode_slice = &reply_slice[hash_size..];
556        T::decode(decode_slice)
557    }
558
559    /// Decodes a `ParameterTreeMsg` from a byte slice.
560    ///
561    /// This is a helper function for decoding parameter tree messages sent from the server.
562    ///
563    /// # Arguments:
564    /// * `reply_slice` - The received buffer containing the encoded `ParameterTreeMsg`.
565    ///
566    /// # Returns:
567    /// * `Ok(ParameterTreeMsg)` - The decoded parameter tree message.
568    /// * `Err(DecodeError)` - An error if the decoding fails.
569    fn decode_parameter_tree_msg(reply_slice: &[u8]) -> Result<ParameterTreeMsg, DecodeError> {
570        Self::decode_message::<ParameterTreeMsg>(reply_slice)
571    }
572
573    /// Decodes a `StatusMsg` from a received response slice.
574    ///
575    /// This is a helper function for decoding `StatusMsg` responses from the server.
576    ///
577    /// # Arguments:
578    /// * `reply_slice` - The received buffer containing the encoded `StatusMsg`.
579    ///
580    /// # Returns:
581    /// * `Ok(StatusMsg)` - The decoded status message.
582    /// * `Err(DecodeError)` - An error if decoding the `StatusMsg` fails.
583    fn decode_status_msg(reply_slice: &[u8]) -> Result<StatusMsg, DecodeError> {
584        Self::decode_message::<StatusMsg>(reply_slice)
585    }
586
587    fn decode_parameter_msg(reply_slice: &[u8]) -> Result<ParameterMsg, DecodeError> {
588        Self::decode_message::<ParameterMsg>(reply_slice)
589    }
590
591    /// Receives a raw message from the server using the NNG socket.
592    ///
593    /// This function allocates a buffer to receive the message
594    /// and returns the message data as a byte slice.
595    ///
596    /// # Returns:
597    /// * `Ok(&[u8])` - A byte slice containing the received message data.
598    /// * `Err(String)` - An error message if reception fails.
599    fn receive_message(&self) -> Result<&[u8], String> {
600        unsafe {
601            // Allocate memory to receive the response
602            let mut reply_ptr = ptr::null_mut();
603            let mut reply_size = 0usize;
604
605            // Use nng_recv to wait for and receive the response
606            let rv = nng_c_sys::nng_recv(
607                self.sock.unwrap(),
608                &mut reply_ptr as *mut _ as *mut std::ffi::c_void,
609                &mut reply_size,
610                nng_c_sys::NNG_FLAG_ALLOC,
611            );
612
613            // Check for errors in the reception operation
614            if rv != 0 {
615                return Err(format!(
616                    "Failed to receive response via NNG. Error code: {}",
617                    rv
618                ));
619            }
620            Ok(std::slice::from_raw_parts(
621                reply_ptr as *const u8,
622                reply_size,
623            ))
624        }
625    }
626
627    /// Sends the provided data buffer to the server using the NNG socket.
628    ///
629    /// This function sends a message buffer over the active NNG socket connection.
630    ///
631    /// # Arguments:
632    /// * `buffer` - A byte slice containing the message to send.
633    ///
634    /// # Returns:
635    /// * `Ok(())` - If the message was successfully sent.
636    /// * `Err(String)` - If sending the message fails.
637    fn send_message(&self, buffer: &[u8]) -> Result<(), String> {
638        unsafe {
639            // Create raw pointer for the buffer
640            let data_ptr = buffer.as_ptr() as *mut std::ffi::c_void;
641            let data_len = buffer.len();
642
643            // Use nng_send API
644            let sock = self.sock.ok_or("Socket is not available. Connect first.")?;
645            let rv = nng_c_sys::nng_send(sock, data_ptr, data_len, 0);
646
647            // Check if the send operation was successful
648            if rv != 0 {
649                return Err(format!(
650                    "Failed to send message via NNG. Error code: {}",
651                    rv
652                ));
653            }
654        }
655
656        Ok(())
657    }
658}