openiap_client/
ws.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
use tracing::{error, debug, trace};
use futures_util::{StreamExt};
use openiap_proto::{errors::OpenIAPError, protos::Envelope};
use prost::Message as _;
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
use std::sync::Arc;
use tokio::sync::{Mutex};
use futures::SinkExt;
use bytes::{BytesMut, BufMut}; // Correct import for BufMut

use crate::Client;

impl Client {
    /// Setup a websocket connection to the server
    // pub async fn setup_ws(&self, strurl: &str) -> Result<(), Box<dyn std::error::Error>> {
    pub async fn setup_ws(&self, strurl: &str) -> Result<(), OpenIAPError> {
        let ws_stream = match connect_async(strurl).await {
            Ok((ws_stream, _)) => ws_stream,
            Err(e) => {
                error!("Failed to connect to websocket: {:?}", e);
                self.set_connected(false, Some(&e.to_string()));
                return Err(OpenIAPError::ClientError(e.to_string()));
            }            
        };
        trace!("WebSocket handshake has been successfully completed");
        let (mut write, mut read) = ws_stream.split();

        self.set_msgcount(-1); // Reset message count

        let envelope_receiver = self.out_envelope_receiver.clone();
        let me = self.clone();
        
        // Spawn sending task
        let sender = tokio::spawn(async move {
            while let Ok(envelope) = envelope_receiver.recv().await {
                if me.is_connected() == false {
                    error!("Failed to send message to websocket: not connected");
                    return;
                }
                let mut envelope = envelope;
                let command = envelope.command.clone();
                
                envelope.seq = me.inc_msgcount();
                if envelope.id.is_empty() {
                    envelope.id = envelope.seq.to_string();
                }

                if envelope.rid.is_empty() {
                    debug!("Send #{} #{} {} message", envelope.seq, envelope.id, command);
                } else {
                    debug!("Send #{} #{} (reply to #{}) {} message", envelope.seq, envelope.id, envelope.rid, command);
                }

                // Encode envelope and prepend length in little-endian
                let mut message = BytesMut::with_capacity(4 + envelope.encoded_len());
                message.put_u32_le(envelope.encoded_len() as u32);
                match envelope.encode(&mut message) {
                    Ok(_) => {},
                    Err(e) => {
                        error!("Failed to encode protobuf message: {:?}", e);
                        me.set_connected(false, Some(&e.to_string()));
                        return;
                    }                    
                };

                // Send the message
                if let Err(e) = write.send(Message::Binary(message.to_vec())).await {
                    error!("Failed to send {} message to websocket: {:?}", command, e);
                    me.set_connected(false, Some(&e.to_string()));
                    return;
                }
            }
        });

        let buffer = Arc::new(Mutex::new(BytesMut::with_capacity(4096))); // Pre-allocate buffer size
        let me = self.clone();

        // Reading task with backpressure handling
        let reader = tokio::spawn({
            let buffer = Arc::clone(&buffer);
            async move {
                while let Some(message) = read.next().await {
                    if me.is_connected() == false {
                        error!("Failed to send message to websocket: not connected");
                        return;
                    }
                    let data = match message {
                        Ok(msg) => msg.into_data(),
                        Err(e) => {
                            error!("Failed to receive message from websocket: {:?}", e);
                            me.set_connected(false, Some(&e.to_string()));
                            return;
                        }
                    };

                    let mut buffer = buffer.lock().await;
                    buffer.extend_from_slice(&data);

                    while buffer.len() >= 4 {
                        let size = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]) as usize;

                        if buffer.len() < 4 + size {
                            break; // Wait for more data
                        }

                        let payload = buffer.split_to(4 + size);
                        let payload = &payload[4..]; // Skip the size bytes

                        match Envelope::decode(payload) {
                            Ok(received) => {
                                me.parse_incomming_envelope(received).await;
                            },
                            Err(e) => {
                                error!("Failed to decode protobuf message: {:?}", e);
                            }
                        }
                    }
                }
            }
        });
        let on_disconnect_receiver = self.on_disconnect_receiver.clone();
        tokio::spawn(async move {
            match on_disconnect_receiver.recv().await {
                Ok(_) => {},
                Err(e) => {
                    error!("Failed to receive on_disconnect signal: {:?}", e);
                }
            };
            trace!("Killing the sender and reader for websocket");
            sender.abort();
            reader.abort();
            trace!("Killed the sender and reader for websocket");
        });
        self.set_connected(true, None);
        Ok(())
    }
}