1use futures::TryStreamExt;
2use tokio_stream::wrappers::ReceiverStream;
3
4use std::{future::Future, sync::Arc, vec};
5use tokio::{
6 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
7 sync::{
8 mpsc::{
9 channel,
10 error::{TryRecvError, TrySendError},
11 Receiver, Sender,
12 },
13 oneshot, Mutex,
14 },
15};
16use tokio_util::compat::*;
17use tracing::*;
18
19use crate::{codec, PixelFormat, Rect, VncEncoding, VncError, VncEvent, X11Event};
20const CHANNEL_SIZE: usize = 4096;
21
22#[cfg(not(target_arch = "wasm32"))]
23use tokio::spawn;
24#[cfg(target_arch = "wasm32")]
25use wasm_bindgen_futures::spawn_local as spawn;
26
27use super::messages::{ClientMsg, ServerMsg};
28
/// Header of a single rectangle inside a framebuffer update message.
///
/// It is decoded from the 12-byte header that precedes each rectangle's
/// payload (see the `From<[u8; 12]>` implementation).
struct ImageRect {
    /// Position and size of the rectangle.
    rect: Rect,
    /// Encoding that selects how the read loop handles this rectangle.
    encoding: VncEncoding,
}
33
34impl From<[u8; 12]> for ImageRect {
35 fn from(buf: [u8; 12]) -> Self {
36 Self {
37 rect: Rect {
38 x: (buf[0] as u16) << 8 | buf[1] as u16,
39 y: (buf[2] as u16) << 8 | buf[3] as u16,
40 width: (buf[4] as u16) << 8 | buf[5] as u16,
41 height: (buf[6] as u16) << 8 | buf[7] as u16,
42 },
43 encoding: ((buf[8] as u32) << 24
44 | (buf[9] as u32) << 16
45 | (buf[10] as u32) << 8
46 | (buf[11] as u32))
47 .into(),
48 }
49 }
50}
51
52impl ImageRect {
53 async fn read<S>(reader: &mut S) -> Result<Self, VncError>
54 where
55 S: AsyncRead + Unpin,
56 {
57 let mut rect_buf = [0_u8; 12];
58 reader.read_exact(&mut rect_buf).await?;
59 Ok(rect_buf.into())
60 }
61}
62
/// State shared by all handles of one running VNC client session.
struct VncInner {
    /// Desktop name reported by the server during initialisation.
    name: String,
    /// Screen size as `(width, height)` reported by the server during initialisation.
    screen: (u16, u16),
    /// Sender for client messages consumed by the network connection task.
    input_ch: Sender<ClientMsg>,
    /// Receiver for events emitted during initialisation and by the decoding task.
    output_ch: Receiver<VncEvent>,
    /// Stop signal for the decoding task; `None` once `close` has consumed it.
    decoding_stop: Option<oneshot::Sender<()>>,
    /// Stop signal for the network connection task; `None` once `close` has consumed it.
    net_conn_stop: Option<oneshot::Sender<()>>,
    /// Set by `close`, or when the event channel is found disconnected.
    closed: bool,
}
72
73impl VncInner {
76 async fn new<S>(
77 mut stream: S,
78 shared: bool,
79 mut pixel_format: Option<PixelFormat>,
80 encodings: Vec<VncEncoding>,
81 ) -> Result<Self, VncError>
82 where
83 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
84 {
85 let (conn_ch_tx, conn_ch_rx) = channel(CHANNEL_SIZE);
86 let (input_ch_tx, input_ch_rx) = channel(CHANNEL_SIZE);
87 let (output_ch_tx, output_ch_rx) = channel(CHANNEL_SIZE);
88 let (decoding_stop_tx, decoding_stop_rx) = oneshot::channel();
89 let (net_conn_stop_tx, net_conn_stop_rx) = oneshot::channel();
90
91 trace!("client init msg");
92 send_client_init(&mut stream, shared).await?;
93
94 trace!("server init msg");
95 let (name, (width, height)) =
96 read_server_init(&mut stream, &mut pixel_format, &|e| async {
97 output_ch_tx.send(e).await?;
98 Ok(())
99 })
100 .await?;
101
102 trace!("client encodings: {:?}", encodings);
103 send_client_encoding(&mut stream, encodings).await?;
104
105 trace!("Require the first frame");
106 input_ch_tx
107 .send(ClientMsg::FramebufferUpdateRequest(
108 Rect {
109 x: 0,
110 y: 0,
111 width,
112 height,
113 },
114 0,
115 ))
116 .await?;
117
118 spawn(async move {
120 trace!("Decoding thread starts");
121 let mut conn_ch_rx = {
122 let conn_ch_rx = ReceiverStream::new(conn_ch_rx).into_async_read();
123 FuturesAsyncReadCompatExt::compat(conn_ch_rx)
124 };
125
126 let output_func = |e| async {
127 output_ch_tx.send(e).await?;
128 Ok(())
129 };
130
131 let pf = pixel_format.as_ref().unwrap();
132 if let Err(e) =
133 asycn_vnc_read_loop(&mut conn_ch_rx, pf, &output_func, decoding_stop_rx).await
134 {
135 if let VncError::IoError(e) = e {
136 if let std::io::ErrorKind::UnexpectedEof = e.kind() {
137 } else {
141 error!("Error occurs during the decoding {:?}", e);
142 let _ = output_func(VncEvent::Error(e.to_string())).await;
143 }
144 } else {
145 error!("Error occurs during the decoding {:?}", e);
146 let _ = output_func(VncEvent::Error(e.to_string())).await;
147 }
148 }
149 trace!("Decoding thread stops");
150 });
151
152 spawn(async move {
154 trace!("Net Connection thread starts");
155 let _ =
156 async_connection_process_loop(stream, input_ch_rx, conn_ch_tx, net_conn_stop_rx)
157 .await;
158 trace!("Net Connection thread stops");
159 });
160
161 info!("VNC Client {name} starts");
162 Ok(Self {
163 name,
164 screen: (width, height),
165 input_ch: input_ch_tx,
166 output_ch: output_ch_rx,
167 decoding_stop: Some(decoding_stop_tx),
168 net_conn_stop: Some(net_conn_stop_tx),
169 closed: false,
170 })
171 }
172
173 async fn input(&mut self, event: X11Event) -> Result<(), VncError> {
174 if self.closed {
175 Err(VncError::ClientNotRunning)
176 } else {
177 let msg = match event {
178 X11Event::Refresh => ClientMsg::FramebufferUpdateRequest(
179 Rect {
180 x: 0,
181 y: 0,
182 width: self.screen.0,
183 height: self.screen.1,
184 },
185 1,
186 ),
187 X11Event::KeyEvent(key) => ClientMsg::KeyEvent(key.keycode, key.down),
188 X11Event::PointerEvent(mouse) => {
189 ClientMsg::PointerEvent(mouse.position_x, mouse.position_y, mouse.bottons)
190 }
191 X11Event::CopyText(text) => ClientMsg::ClientCutText(text),
192 };
193 self.input_ch.send(msg).await?;
194 Ok(())
195 }
196 }
197
198 async fn recv_event(&mut self) -> Result<VncEvent, VncError> {
199 if self.closed {
200 Err(VncError::ClientNotRunning)
201 } else {
202 match self.output_ch.recv().await {
203 Some(e) => Ok(e),
204 None => {
205 self.closed = true;
206 Err(VncError::ClientNotRunning)
207 }
208 }
209 }
210 }
211
212 async fn poll_event(&mut self) -> Result<Option<VncEvent>, VncError> {
213 if self.closed {
214 Err(VncError::ClientNotRunning)
215 } else {
216 match self.output_ch.try_recv() {
217 Err(TryRecvError::Disconnected) => {
218 self.closed = true;
219 Err(VncError::ClientNotRunning)
220 }
221 Err(TryRecvError::Empty) => Ok(None),
222 Ok(e) => Ok(Some(e)),
223 }
224 }
226 }
227
228 fn close(&mut self) -> Result<(), VncError> {
231 if self.net_conn_stop.is_some() {
232 let net_conn_stop: oneshot::Sender<()> = self.net_conn_stop.take().unwrap();
233 let _ = net_conn_stop.send(());
234 }
235 if self.decoding_stop.is_some() {
236 let decoding_stop = self.decoding_stop.take().unwrap();
237 let _ = decoding_stop.send(());
238 }
239 self.closed = true;
240 Ok(())
241 }
242}
243
244impl Drop for VncInner {
245 fn drop(&mut self) {
246 info!("VNC Client {} stops", self.name);
247 let _ = self.close();
248 }
249}
250
/// Handle to a running VNC client session.
///
/// Clones share the same underlying session state.
pub struct VncClient {
    /// Session state, shared between clones and guarded by an async mutex.
    inner: Arc<Mutex<VncInner>>,
}
254
255impl VncClient {
256 pub(super) async fn new<S>(
257 stream: S,
258 shared: bool,
259 pixel_format: Option<PixelFormat>,
260 encodings: Vec<VncEncoding>,
261 ) -> Result<Self, VncError>
262 where
263 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
264 {
265 Ok(Self {
266 inner: Arc::new(Mutex::new(
267 VncInner::new(stream, shared, pixel_format, encodings).await?,
268 )),
269 })
270 }
271
272 pub async fn input(&self, event: X11Event) -> Result<(), VncError> {
275 self.inner.lock().await.input(event).await
276 }
277
278 pub async fn recv_event(&self) -> Result<VncEvent, VncError> {
282 self.inner.lock().await.recv_event().await
283 }
284
285 pub async fn poll_event(&self) -> Result<Option<VncEvent>, VncError> {
288 self.inner.lock().await.poll_event().await
289 }
290
291 pub async fn close(&self) -> Result<(), VncError> {
294 self.inner.lock().await.close()
295 }
296}
297
298impl Clone for VncClient {
299 fn clone(&self) -> Self {
300 Self {
301 inner: self.inner.clone(),
302 }
303 }
304}
305
306async fn send_client_init<S>(stream: &mut S, shared: bool) -> Result<(), VncError>
307where
308 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
309{
310 trace!("Send shared flag: {}", shared);
311 stream.write_u8(shared as u8).await?;
312 Ok(())
313}
314
315async fn read_server_init<S, F, Fut>(
316 stream: &mut S,
317 pf: &mut Option<PixelFormat>,
318 output_func: &F,
319) -> Result<(String, (u16, u16)), VncError>
320where
321 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
322 F: Fn(VncEvent) -> Fut,
323 Fut: Future<Output = Result<(), VncError>>,
324{
325 let screen_width = stream.read_u16().await?;
336 let screen_height = stream.read_u16().await?;
337 let mut send_our_pf = false;
338
339 output_func(VncEvent::SetResolution(
340 (screen_width, screen_height).into(),
341 ))
342 .await?;
343
344 let pixel_format = PixelFormat::read(stream).await?;
345 if pf.is_none() {
346 output_func(VncEvent::SetPixelFormat(pixel_format)).await?;
347 let _ = pf.insert(pixel_format);
348 } else {
349 send_our_pf = true;
350 }
351
352 let name_len = stream.read_u32().await?;
353 let mut name_buf = vec![0_u8; name_len as usize];
354 stream.read_exact(&mut name_buf).await?;
355 let name = String::from_utf8_lossy(&name_buf).into_owned();
356
357 if send_our_pf {
358 trace!("Send customized pixel format {:#?}", pf);
359 ClientMsg::SetPixelFormat(*pf.as_ref().unwrap())
360 .write(stream)
361 .await?;
362 }
363 Ok((name, (screen_width, screen_height)))
364}
365
366async fn send_client_encoding<S>(
367 stream: &mut S,
368 encodings: Vec<VncEncoding>,
369) -> Result<(), VncError>
370where
371 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
372{
373 ClientMsg::SetEncodings(encodings).write(stream).await?;
374 Ok(())
375}
376
377async fn asycn_vnc_read_loop<S, F, Fut>(
378 stream: &mut S,
379 pf: &PixelFormat,
380 output_func: &F,
381 mut stop_ch: oneshot::Receiver<()>,
382) -> Result<(), VncError>
383where
384 S: AsyncRead + Unpin,
385 F: Fn(VncEvent) -> Fut,
386 Fut: Future<Output = Result<(), VncError>>,
387{
388 let mut raw_decoder = codec::RawDecoder::new();
389 let mut zrle_decoder = codec::ZrleDecoder::new();
390 let mut tight_decoder = codec::TightDecoder::new();
391 let mut trle_decoder = codec::TrleDecoder::new();
392 let mut cursor = codec::CursorDecoder::new();
393
394 while let Err(oneshot::error::TryRecvError::Empty) = stop_ch.try_recv() {
396 let server_msg = ServerMsg::read(stream).await?;
397 trace!("Server message got: {:?}", server_msg);
398 match server_msg {
399 ServerMsg::FramebufferUpdate(rect_num) => {
400 for _ in 0..rect_num {
401 let rect = ImageRect::read(stream).await?;
402 match rect.encoding {
405 VncEncoding::Raw => {
406 raw_decoder
407 .decode(pf, &rect.rect, stream, output_func)
408 .await?;
409 }
410 VncEncoding::CopyRect => {
411 let source_x = stream.read_u16().await?;
412 let source_y = stream.read_u16().await?;
413 let mut src_rect = rect.rect;
414 src_rect.x = source_x;
415 src_rect.y = source_y;
416 output_func(VncEvent::Copy(rect.rect, src_rect)).await?;
417 }
418 VncEncoding::Tight => {
419 tight_decoder
420 .decode(pf, &rect.rect, stream, output_func)
421 .await?;
422 }
423 VncEncoding::Trle => {
424 trle_decoder
425 .decode(pf, &rect.rect, stream, output_func)
426 .await?;
427 }
428 VncEncoding::Zrle => {
429 zrle_decoder
430 .decode(pf, &rect.rect, stream, output_func)
431 .await?;
432 }
433 VncEncoding::CursorPseudo => {
434 cursor.decode(pf, &rect.rect, stream, output_func).await?;
435 }
436 VncEncoding::DesktopSizePseudo => {
437 output_func(VncEvent::SetResolution(
438 (rect.rect.width, rect.rect.height).into(),
439 ))
440 .await?;
441 }
442 VncEncoding::LastRectPseudo => {
443 break;
444 }
445 }
446 }
447 }
448 ServerMsg::Bell => {
450 output_func(VncEvent::Bell).await?;
451 }
452 ServerMsg::ServerCutText(text) => {
453 output_func(VncEvent::Text(text)).await?;
454 }
455 }
456 }
457 Ok(())
458}
459
460async fn async_connection_process_loop<S>(
461 mut stream: S,
462 mut input_ch: Receiver<ClientMsg>,
463 conn_ch: Sender<std::io::Result<Vec<u8>>>,
464 mut stop_ch: oneshot::Receiver<()>,
465) -> Result<(), VncError>
466where
467 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
468{
469 let mut buffer = [0; 65535];
470 let mut pending = 0;
471
472 loop {
474 if pending > 0 {
475 match conn_ch.try_send(Ok(buffer[0..pending].to_owned())) {
476 Err(TrySendError::Full(_message)) => (),
477 Err(TrySendError::Closed(_message)) => break,
478 Ok(()) => pending = 0,
479 }
480 }
481
482 tokio::select! {
483 _ = &mut stop_ch => break,
484 result = stream.read(&mut buffer), if pending == 0 => {
485 match result {
486 Ok(nread) => {
487 if nread > 0 {
488 match conn_ch.try_send(Ok(buffer[0..nread].to_owned())) {
489 Err(TrySendError::Full(_message)) => pending = nread,
490 Err(TrySendError::Closed(_message)) => break,
491 Ok(()) => ()
492 }
493 } else {
494 trace!("Net Connection EOF detected");
498 break;
499 }
500 }
501 Err(e) => {
502 error!("{}", e.to_string());
503 break;
504 }
505 }
506 }
507 Some(msg) = input_ch.recv() => {
508 msg.write(&mut stream).await?;
509 }
510 }
511 }
512
513 let _ = conn_ch
515 .send(Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof)))
516 .await;
517
518 Ok(())
519}