vnc/client/
connection.rs

1use futures::TryStreamExt;
2use tokio_stream::wrappers::ReceiverStream;
3
4use std::{future::Future, sync::Arc, vec};
5use tokio::{
6    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
7    sync::{
8        mpsc::{
9            channel,
10            error::{TryRecvError, TrySendError},
11            Receiver, Sender,
12        },
13        oneshot, Mutex,
14    },
15};
16use tokio_util::compat::*;
17use tracing::*;
18
19use crate::{codec, PixelFormat, Rect, VncEncoding, VncError, VncEvent, X11Event};
/// Capacity (in messages) of each internal mpsc channel created in `VncInner::new`
/// (the byte bridge, the client-message queue and the event queue).
const CHANNEL_SIZE: usize = 4096;
21
22#[cfg(not(target_arch = "wasm32"))]
23use tokio::spawn;
24#[cfg(target_arch = "wasm32")]
25use wasm_bindgen_futures::spawn_local as spawn;
26
27use super::messages::{ClientMsg, ServerMsg};
28
/// Header of one rectangle inside a FramebufferUpdate message, decoded from the
/// 12 bytes that precede the rectangle's payload.
struct ImageRect {
    /// Area of the framebuffer the rectangle describes.
    rect: Rect,
    /// Encoding (or pseudo-encoding) of the payload that follows the header.
    encoding: VncEncoding,
}
33
34impl From<[u8; 12]> for ImageRect {
35    fn from(buf: [u8; 12]) -> Self {
36        Self {
37            rect: Rect {
38                x: (buf[0] as u16) << 8 | buf[1] as u16,
39                y: (buf[2] as u16) << 8 | buf[3] as u16,
40                width: (buf[4] as u16) << 8 | buf[5] as u16,
41                height: (buf[6] as u16) << 8 | buf[7] as u16,
42            },
43            encoding: ((buf[8] as u32) << 24
44                | (buf[9] as u32) << 16
45                | (buf[10] as u32) << 8
46                | (buf[11] as u32))
47                .into(),
48        }
49    }
50}
51
52impl ImageRect {
53    async fn read<S>(reader: &mut S) -> Result<Self, VncError>
54    where
55        S: AsyncRead + Unpin,
56    {
57        let mut rect_buf = [0_u8; 12];
58        reader.read_exact(&mut rect_buf).await?;
59        Ok(rect_buf.into())
60    }
61}
62
/// State of one running client engine; every `VncClient` clone shares it
/// through `Arc<Mutex<VncInner>>`.
struct VncInner {
    /// Desktop name read from the server during initialisation.
    name: String,
    /// Framebuffer (width, height) read during initialisation.
    /// NOTE(review): only assigned in `new`; later DesktopSize updates do not change it,
    /// so `X11Event::Refresh` keeps requesting the initial size.
    screen: (u16, u16),
    /// Queue of messages for the server; drained by the network task.
    input_ch: Sender<ClientMsg>,
    /// Events produced during initialisation and by the decoding task.
    output_ch: Receiver<VncEvent>,
    /// Stop signal for the decoding task; `None` once it has been sent.
    decoding_stop: Option<oneshot::Sender<()>>,
    /// Stop signal for the network task; `None` once it has been sent.
    net_conn_stop: Option<oneshot::Sender<()>>,
    /// Set by `close` or when the event channel is found disconnected; the public
    /// operations then fail with `ClientNotRunning`.
    closed: bool,
}
72
73/// The instance of a connected vnc client
74
75impl VncInner {
76    async fn new<S>(
77        mut stream: S,
78        shared: bool,
79        mut pixel_format: Option<PixelFormat>,
80        encodings: Vec<VncEncoding>,
81    ) -> Result<Self, VncError>
82    where
83        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
84    {
85        let (conn_ch_tx, conn_ch_rx) = channel(CHANNEL_SIZE);
86        let (input_ch_tx, input_ch_rx) = channel(CHANNEL_SIZE);
87        let (output_ch_tx, output_ch_rx) = channel(CHANNEL_SIZE);
88        let (decoding_stop_tx, decoding_stop_rx) = oneshot::channel();
89        let (net_conn_stop_tx, net_conn_stop_rx) = oneshot::channel();
90
91        trace!("client init msg");
92        send_client_init(&mut stream, shared).await?;
93
94        trace!("server init msg");
95        let (name, (width, height)) =
96            read_server_init(&mut stream, &mut pixel_format, &|e| async {
97                output_ch_tx.send(e).await?;
98                Ok(())
99            })
100            .await?;
101
102        trace!("client encodings: {:?}", encodings);
103        send_client_encoding(&mut stream, encodings).await?;
104
105        trace!("Require the first frame");
106        input_ch_tx
107            .send(ClientMsg::FramebufferUpdateRequest(
108                Rect {
109                    x: 0,
110                    y: 0,
111                    width,
112                    height,
113                },
114                0,
115            ))
116            .await?;
117
118        // start the decoding thread
119        spawn(async move {
120            trace!("Decoding thread starts");
121            let mut conn_ch_rx = {
122                let conn_ch_rx = ReceiverStream::new(conn_ch_rx).into_async_read();
123                FuturesAsyncReadCompatExt::compat(conn_ch_rx)
124            };
125
126            let output_func = |e| async {
127                output_ch_tx.send(e).await?;
128                Ok(())
129            };
130
131            let pf = pixel_format.as_ref().unwrap();
132            if let Err(e) =
133                asycn_vnc_read_loop(&mut conn_ch_rx, pf, &output_func, decoding_stop_rx).await
134            {
135                if let VncError::IoError(e) = e {
136                    if let std::io::ErrorKind::UnexpectedEof = e.kind() {
137                        // this should be a normal case when the network connection disconnects
138                        // and we just send an EOF over the inner bridge between the process thread and the decode thread
139                        // do nothing here
140                    } else {
141                        error!("Error occurs during the decoding {:?}", e);
142                        let _ = output_func(VncEvent::Error(e.to_string())).await;
143                    }
144                } else {
145                    error!("Error occurs during the decoding {:?}", e);
146                    let _ = output_func(VncEvent::Error(e.to_string())).await;
147                }
148            }
149            trace!("Decoding thread stops");
150        });
151
152        // start the traffic process thread
153        spawn(async move {
154            trace!("Net Connection thread starts");
155            let _ =
156                async_connection_process_loop(stream, input_ch_rx, conn_ch_tx, net_conn_stop_rx)
157                    .await;
158            trace!("Net Connection thread stops");
159        });
160
161        info!("VNC Client {name} starts");
162        Ok(Self {
163            name,
164            screen: (width, height),
165            input_ch: input_ch_tx,
166            output_ch: output_ch_rx,
167            decoding_stop: Some(decoding_stop_tx),
168            net_conn_stop: Some(net_conn_stop_tx),
169            closed: false,
170        })
171    }
172
173    async fn input(&mut self, event: X11Event) -> Result<(), VncError> {
174        if self.closed {
175            Err(VncError::ClientNotRunning)
176        } else {
177            let msg = match event {
178                X11Event::Refresh => ClientMsg::FramebufferUpdateRequest(
179                    Rect {
180                        x: 0,
181                        y: 0,
182                        width: self.screen.0,
183                        height: self.screen.1,
184                    },
185                    1,
186                ),
187                X11Event::KeyEvent(key) => ClientMsg::KeyEvent(key.keycode, key.down),
188                X11Event::PointerEvent(mouse) => {
189                    ClientMsg::PointerEvent(mouse.position_x, mouse.position_y, mouse.bottons)
190                }
191                X11Event::CopyText(text) => ClientMsg::ClientCutText(text),
192            };
193            self.input_ch.send(msg).await?;
194            Ok(())
195        }
196    }
197
198    async fn recv_event(&mut self) -> Result<VncEvent, VncError> {
199        if self.closed {
200            Err(VncError::ClientNotRunning)
201        } else {
202            match self.output_ch.recv().await {
203                Some(e) => Ok(e),
204                None => {
205                    self.closed = true;
206                    Err(VncError::ClientNotRunning)
207                }
208            }
209        }
210    }
211
212    async fn poll_event(&mut self) -> Result<Option<VncEvent>, VncError> {
213        if self.closed {
214            Err(VncError::ClientNotRunning)
215        } else {
216            match self.output_ch.try_recv() {
217                Err(TryRecvError::Disconnected) => {
218                    self.closed = true;
219                    Err(VncError::ClientNotRunning)
220                }
221                Err(TryRecvError::Empty) => Ok(None),
222                Ok(e) => Ok(Some(e)),
223            }
224            // Ok(self.output_ch.recv().await)
225        }
226    }
227
228    /// Stop the VNC engine and release resources
229    ///
230    fn close(&mut self) -> Result<(), VncError> {
231        if self.net_conn_stop.is_some() {
232            let net_conn_stop: oneshot::Sender<()> = self.net_conn_stop.take().unwrap();
233            let _ = net_conn_stop.send(());
234        }
235        if self.decoding_stop.is_some() {
236            let decoding_stop = self.decoding_stop.take().unwrap();
237            let _ = decoding_stop.send(());
238        }
239        self.closed = true;
240        Ok(())
241    }
242}
243
244impl Drop for VncInner {
245    fn drop(&mut self) {
246        info!("VNC Client {} stops", self.name);
247        let _ = self.close();
248    }
249}
250
/// The instance of a connected vnc client
///
/// A cheap-to-clone handle: every clone refers to the same engine through
/// `Arc<Mutex<_>>`, so calls made through different clones are serialised by the mutex.
pub struct VncClient {
    inner: Arc<Mutex<VncInner>>,
}
254
255impl VncClient {
256    pub(super) async fn new<S>(
257        stream: S,
258        shared: bool,
259        pixel_format: Option<PixelFormat>,
260        encodings: Vec<VncEncoding>,
261    ) -> Result<Self, VncError>
262    where
263        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
264    {
265        Ok(Self {
266            inner: Arc::new(Mutex::new(
267                VncInner::new(stream, shared, pixel_format, encodings).await?,
268            )),
269        })
270    }
271
272    /// Input a `X11Event` from the frontend
273    ///
274    pub async fn input(&self, event: X11Event) -> Result<(), VncError> {
275        self.inner.lock().await.input(event).await
276    }
277
278    /// Receive a `VncEvent` from the engine
279    /// This function will block until a `VncEvent` is received
280    ///
281    pub async fn recv_event(&self) -> Result<VncEvent, VncError> {
282        self.inner.lock().await.recv_event().await
283    }
284
285    /// polling `VncEvent` from the engine and give it to the client
286    ///
287    pub async fn poll_event(&self) -> Result<Option<VncEvent>, VncError> {
288        self.inner.lock().await.poll_event().await
289    }
290
291    /// Stop the VNC engine and release resources
292    ///
293    pub async fn close(&self) -> Result<(), VncError> {
294        self.inner.lock().await.close()
295    }
296}
297
298impl Clone for VncClient {
299    fn clone(&self) -> Self {
300        Self {
301            inner: self.inner.clone(),
302        }
303    }
304}
305
306async fn send_client_init<S>(stream: &mut S, shared: bool) -> Result<(), VncError>
307where
308    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
309{
310    trace!("Send shared flag: {}", shared);
311    stream.write_u8(shared as u8).await?;
312    Ok(())
313}
314
315async fn read_server_init<S, F, Fut>(
316    stream: &mut S,
317    pf: &mut Option<PixelFormat>,
318    output_func: &F,
319) -> Result<(String, (u16, u16)), VncError>
320where
321    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
322    F: Fn(VncEvent) -> Fut,
323    Fut: Future<Output = Result<(), VncError>>,
324{
325    // +--------------+--------------+------------------------------+
326    // | No. of bytes | Type [Value] | Description                  |
327    // +--------------+--------------+------------------------------+
328    // | 2            | U16          | framebuffer-width in pixels  |
329    // | 2            | U16          | framebuffer-height in pixels |
330    // | 16           | PIXEL_FORMAT | server-pixel-format          |
331    // | 4            | U32          | name-length                  |
332    // | name-length  | U8 array     | name-string                  |
333    // +--------------+--------------+------------------------------+
334
335    let screen_width = stream.read_u16().await?;
336    let screen_height = stream.read_u16().await?;
337    let mut send_our_pf = false;
338
339    output_func(VncEvent::SetResolution(
340        (screen_width, screen_height).into(),
341    ))
342    .await?;
343
344    let pixel_format = PixelFormat::read(stream).await?;
345    if pf.is_none() {
346        output_func(VncEvent::SetPixelFormat(pixel_format)).await?;
347        let _ = pf.insert(pixel_format);
348    } else {
349        send_our_pf = true;
350    }
351
352    let name_len = stream.read_u32().await?;
353    let mut name_buf = vec![0_u8; name_len as usize];
354    stream.read_exact(&mut name_buf).await?;
355    let name = String::from_utf8_lossy(&name_buf).into_owned();
356
357    if send_our_pf {
358        trace!("Send customized pixel format {:#?}", pf);
359        ClientMsg::SetPixelFormat(*pf.as_ref().unwrap())
360            .write(stream)
361            .await?;
362    }
363    Ok((name, (screen_width, screen_height)))
364}
365
366async fn send_client_encoding<S>(
367    stream: &mut S,
368    encodings: Vec<VncEncoding>,
369) -> Result<(), VncError>
370where
371    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
372{
373    ClientMsg::SetEncodings(encodings).write(stream).await?;
374    Ok(())
375}
376
377async fn asycn_vnc_read_loop<S, F, Fut>(
378    stream: &mut S,
379    pf: &PixelFormat,
380    output_func: &F,
381    mut stop_ch: oneshot::Receiver<()>,
382) -> Result<(), VncError>
383where
384    S: AsyncRead + Unpin,
385    F: Fn(VncEvent) -> Fut,
386    Fut: Future<Output = Result<(), VncError>>,
387{
388    let mut raw_decoder = codec::RawDecoder::new();
389    let mut zrle_decoder = codec::ZrleDecoder::new();
390    let mut tight_decoder = codec::TightDecoder::new();
391    let mut trle_decoder = codec::TrleDecoder::new();
392    let mut cursor = codec::CursorDecoder::new();
393
394    // main decoding loop
395    while let Err(oneshot::error::TryRecvError::Empty) = stop_ch.try_recv() {
396        let server_msg = ServerMsg::read(stream).await?;
397        trace!("Server message got: {:?}", server_msg);
398        match server_msg {
399            ServerMsg::FramebufferUpdate(rect_num) => {
400                for _ in 0..rect_num {
401                    let rect = ImageRect::read(stream).await?;
402                    // trace!("Encoding: {:?}", rect.encoding);
403
404                    match rect.encoding {
405                        VncEncoding::Raw => {
406                            raw_decoder
407                                .decode(pf, &rect.rect, stream, output_func)
408                                .await?;
409                        }
410                        VncEncoding::CopyRect => {
411                            let source_x = stream.read_u16().await?;
412                            let source_y = stream.read_u16().await?;
413                            let mut src_rect = rect.rect;
414                            src_rect.x = source_x;
415                            src_rect.y = source_y;
416                            output_func(VncEvent::Copy(rect.rect, src_rect)).await?;
417                        }
418                        VncEncoding::Tight => {
419                            tight_decoder
420                                .decode(pf, &rect.rect, stream, output_func)
421                                .await?;
422                        }
423                        VncEncoding::Trle => {
424                            trle_decoder
425                                .decode(pf, &rect.rect, stream, output_func)
426                                .await?;
427                        }
428                        VncEncoding::Zrle => {
429                            zrle_decoder
430                                .decode(pf, &rect.rect, stream, output_func)
431                                .await?;
432                        }
433                        VncEncoding::CursorPseudo => {
434                            cursor.decode(pf, &rect.rect, stream, output_func).await?;
435                        }
436                        VncEncoding::DesktopSizePseudo => {
437                            output_func(VncEvent::SetResolution(
438                                (rect.rect.width, rect.rect.height).into(),
439                            ))
440                            .await?;
441                        }
442                        VncEncoding::LastRectPseudo => {
443                            break;
444                        }
445                    }
446                }
447            }
448            // SetColorMapEntries,
449            ServerMsg::Bell => {
450                output_func(VncEvent::Bell).await?;
451            }
452            ServerMsg::ServerCutText(text) => {
453                output_func(VncEvent::Text(text)).await?;
454            }
455        }
456    }
457    Ok(())
458}
459
460async fn async_connection_process_loop<S>(
461    mut stream: S,
462    mut input_ch: Receiver<ClientMsg>,
463    conn_ch: Sender<std::io::Result<Vec<u8>>>,
464    mut stop_ch: oneshot::Receiver<()>,
465) -> Result<(), VncError>
466where
467    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
468{
469    let mut buffer = [0; 65535];
470    let mut pending = 0;
471
472    // main traffic loop
473    loop {
474        if pending > 0 {
475            match conn_ch.try_send(Ok(buffer[0..pending].to_owned())) {
476                Err(TrySendError::Full(_message)) => (),
477                Err(TrySendError::Closed(_message)) => break,
478                Ok(()) => pending = 0,
479            }
480        }
481
482        tokio::select! {
483            _ = &mut stop_ch => break,
484            result = stream.read(&mut buffer), if pending == 0 => {
485                match result {
486                    Ok(nread) => {
487                        if nread > 0 {
488                            match conn_ch.try_send(Ok(buffer[0..nread].to_owned())) {
489                                Err(TrySendError::Full(_message)) => pending = nread,
490                                Err(TrySendError::Closed(_message)) => break,
491                                Ok(()) => ()
492                            }
493                        } else {
494                            // According to the tokio's Doc
495                            // https://docs.rs/tokio/latest/tokio/io/trait.AsyncRead.html
496                            // if nread == 0, then EOF is reached
497                            trace!("Net Connection EOF detected");
498                            break;
499                        }
500                    }
501                    Err(e) => {
502                        error!("{}", e.to_string());
503                        break;
504                    }
505                }
506            }
507            Some(msg) = input_ch.recv() => {
508                msg.write(&mut stream).await?;
509            }
510        }
511    }
512
513    // notify the decoding thread
514    let _ = conn_ch
515        .send(Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof)))
516        .await;
517
518    Ok(())
519}