rfb/
server.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4//
5// Copyright 2022 Oxide Computer Company
6
7use std::io;
8use std::marker::{Send, Sync};
9use std::net::SocketAddr;
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use futures::future::Shared;
14use futures::FutureExt;
15use log::{debug, error, info, trace};
16use thiserror::Error;
17use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
18use tokio::net::{TcpListener, TcpStream};
19use tokio::select;
20use tokio::sync::{oneshot, Mutex};
21
22use crate::rfb::{
23    ClientInit, ClientMessage, FramebufferUpdate, KeyEvent, PixelFormat, ProtoVersion,
24    ProtocolError, ReadMessage, SecurityResult, SecurityType, SecurityTypes, ServerInit,
25    WriteMessage,
26};
27
28#[derive(Debug, Error)]
29pub enum HandshakeError {
30    #[error("incompatible protocol versions (client = {client:?}, server = {server:?})")]
31    IncompatibleVersions {
32        client: ProtoVersion,
33        server: ProtoVersion,
34    },
35
36    #[error(
37        "incompatible security types (client choice = {choice:?}, server offered = {offer:?})"
38    )]
39    IncompatibleSecurityTypes {
40        choice: SecurityType,
41        offer: SecurityTypes,
42    },
43
44    #[error(transparent)]
45    Protocol(#[from] ProtocolError),
46}
47
48/// Immutable state
49pub struct VncServerConfig {
50    pub addr: SocketAddr,
51    pub version: ProtoVersion,
52    pub sec_types: SecurityTypes,
53    pub name: String,
54}
55
56/// Mutable state
57pub struct VncServerData {
58    pub width: u16,
59    pub height: u16,
60
61    /// The pixel format of the framebuffer data passed in to the server via
62    /// get_framebuffer_update.
63    pub input_pixel_format: PixelFormat,
64}
65
66pub struct VncServer<S: Server> {
67    /// VNC startup server configuration
68    config: VncServerConfig,
69
70    /// VNC runtime mutable state
71    data: Mutex<VncServerData>,
72
73    /// The underlying [`Server`] implementation
74    pub server: S,
75
76    /// One-shot channel used to signal that the server should shut down.
77    stop_ch: Mutex<Option<oneshot::Sender<()>>>,
78}
79
80#[async_trait]
81pub trait Server: Sync + Send + 'static {
82    async fn get_framebuffer_update(&self) -> FramebufferUpdate;
83    async fn key_event(&self, _ke: KeyEvent) {}
84    async fn stop(&self) {}
85}
86
87impl<S: Server> VncServer<S> {
88    pub fn new(server: S, config: VncServerConfig, data: VncServerData) -> Arc<Self> {
89        assert!(
90            config.sec_types.0.len() > 0,
91            "at least one security type must be defined"
92        );
93        Arc::new(Self {
94            config: config,
95            data: Mutex::new(data),
96            server: server,
97            stop_ch: Mutex::new(None),
98        })
99    }
100
101    pub async fn set_pixel_format(&self, pixel_format: PixelFormat) {
102        let mut locked = self.data.lock().await;
103        locked.input_pixel_format = pixel_format;
104    }
105
106    pub async fn set_resolution(&self, width: u16, height: u16) {
107        let mut locked = self.data.lock().await;
108        locked.width = width;
109        locked.height = height;
110    }
111
112    async fn rfb_handshake(
113        &self,
114        s: &mut (impl AsyncRead + AsyncWrite + Unpin + Send + Sync),
115        addr: SocketAddr,
116    ) -> Result<(), HandshakeError> {
117        // ProtocolVersion handshake
118        info!("Tx [{:?}]: ProtoVersion={:?}", addr, self.config.version);
119        self.config.version.write_to(s).await?;
120        let client_version = ProtoVersion::read_from(s).await?;
121        info!("Rx [{:?}]: ClientVersion={:?}", addr, client_version);
122
123        if client_version < self.config.version {
124            let err_str = format!(
125                "[{:?}] unsupported client version={:?} (server version: {:?})",
126                addr, client_version, self.config.version
127            );
128            error!("{}", err_str);
129            return Err(HandshakeError::IncompatibleVersions {
130                client: client_version,
131                server: self.config.version,
132            });
133        }
134
135        // Security Handshake
136        let supported_types = self.config.sec_types.clone();
137        info!("Tx [{:?}]: SecurityTypes={:?}", addr, supported_types);
138        supported_types.write_to(s).await?;
139        let client_choice = SecurityType::read_from(s).await?;
140        info!("Rx [{:?}]: SecurityType Choice={:?}", addr, client_choice);
141        if !self.config.sec_types.0.contains(&client_choice) {
142            info!("Tx [{:?}]: SecurityResult=Failure", addr);
143            let failure = SecurityResult::Failure("unsupported security type".to_string());
144            failure.write_to(s).await?;
145            let err_str = format!("invalid security choice={:?}", client_choice);
146            error!("{}", err_str);
147            return Err(HandshakeError::IncompatibleSecurityTypes {
148                choice: client_choice,
149                offer: self.config.sec_types.clone(),
150            });
151        }
152
153        let res = SecurityResult::Success;
154        info!("Tx: SecurityResult=Success");
155        res.write_to(s).await?;
156
157        Ok(())
158    }
159
160    async fn rfb_initialization(
161        &self,
162        s: &mut (impl AsyncRead + AsyncWrite + Unpin + Send + Sync),
163        addr: SocketAddr,
164    ) -> Result<(), ProtocolError> {
165        let client_init = ClientInit::read_from(s).await?;
166        info!("Rx [{:?}]: ClientInit={:?}", addr, client_init);
167        // TODO: decide what to do in exclusive case
168        match client_init.shared {
169            true => {}
170            false => {}
171        }
172
173        let data = self.data.lock().await;
174        let server_init = ServerInit::new(
175            data.width,
176            data.height,
177            self.config.name.clone(),
178            data.input_pixel_format.clone(),
179        );
180        info!("Tx [{:?}]: ServerInit={:#?}", addr, server_init);
181        server_init.write_to(s).await?;
182
183        Ok(())
184    }
185
186    pub async fn handle_conn(
187        &self,
188        s: &mut (impl AsyncRead + AsyncWrite + Unpin + Send + Sync),
189        addr: SocketAddr,
190        mut close_ch: Shared<oneshot::Receiver<()>>,
191    ) {
192        info!("[{:?}] new connection", addr);
193
194        if let Err(e) = self.rfb_handshake(s, addr).await {
195            error!("[{:?}] could not complete handshake: {:?}", addr, e);
196            return;
197        }
198
199        if let Err(e) = self.rfb_initialization(s, addr).await {
200            error!("[{:?}] could not complete handshake: {:?}", addr, e);
201            return;
202        }
203
204        let data = self.data.lock().await;
205        let mut output_pixel_format = data.input_pixel_format.clone();
206        drop(data);
207
208        loop {
209            let req = select! {
210                // Poll in the order written so we check for close first
211                biased;
212
213                _ = &mut close_ch => {
214                    info!("[{:?}] server stopping, closing connection with peer", addr);
215                    let _ = s.shutdown().await;
216                    return;
217                }
218
219                req = ClientMessage::read_from(s) => req,
220            };
221
222            match req {
223                Ok(client_msg) => match client_msg {
224                    ClientMessage::SetPixelFormat(pf) => {
225                        debug!("Rx [{:?}]: SetPixelFormat={:#?}", addr, pf);
226
227                        // TODO: invalid pixel formats?
228                        output_pixel_format = pf;
229                    }
230                    ClientMessage::SetEncodings(e) => {
231                        debug!("Rx [{:?}]: SetEncodings={:?}", addr, e);
232                    }
233                    ClientMessage::FramebufferUpdateRequest(f) => {
234                        debug!("Rx [{:?}]: FramebufferUpdateRequest={:?}", addr, f);
235
236                        let mut fbu = self.server.get_framebuffer_update().await;
237
238                        let data = self.data.lock().await;
239
240                        // We only need to change pixel formats if the client requested a different
241                        // one than what's specified in the input.
242                        //
243                        // For now, we only support transformations between 4-byte RGB formats, so
244                        // if the requested format isn't one of those, we'll just leave the pixels
245                        // as is.
246                        if data.input_pixel_format != output_pixel_format
247                            && data.input_pixel_format.is_rgb_888()
248                            && output_pixel_format.is_rgb_888()
249                        {
250                            debug!(
251                                "transforming: input={:#?}, output={:#?}",
252                                data.input_pixel_format, output_pixel_format
253                            );
254                            fbu = fbu.transform(&data.input_pixel_format, &output_pixel_format);
255                        } else if !(data.input_pixel_format.is_rgb_888()
256                            && output_pixel_format.is_rgb_888())
257                        {
258                            debug!("cannot transform between pixel formats (not rgb888): input.is_rgb_888()={}, output.is_rgb_888()={}", data.input_pixel_format.is_rgb_888(), output_pixel_format.is_rgb_888());
259                        } else {
260                            debug!("no input transformation needed");
261                        }
262
263                        if let Err(e) = fbu.write_to(s).await {
264                            error!(
265                                "[{:?}] could not write FramebufferUpdateRequest: {:?}",
266                                addr, e
267                            );
268                            return;
269                        }
270                        debug!("Tx [{:?}]: FramebufferUpdate", addr);
271                    }
272                    ClientMessage::KeyEvent(ke) => {
273                        trace!("Rx [{:?}]: KeyEvent={:?}", addr, ke);
274                        self.server.key_event(ke).await;
275                    }
276                    ClientMessage::PointerEvent(pe) => {
277                        trace!("Rx [{:?}: PointerEvent={:?}", addr, pe);
278                    }
279                    ClientMessage::ClientCutText(t) => {
280                        trace!("Rx [{:?}: ClientCutText={:?}", addr, t);
281                    }
282                },
283                Err(e) => {
284                    error!("[{:?}] error reading client message: {}", addr, e);
285                    return;
286                }
287            }
288        }
289    }
290
291    /// Start listening for incoming connections.
292    pub async fn start(self: &Arc<Self>) -> io::Result<()> {
293        let listener = TcpListener::bind(self.config.addr).await?;
294
295        // Create a channel to signal the server to stop.
296        let (close_tx, close_rx) = oneshot::channel();
297        assert!(
298            self.stop_ch.lock().await.replace(close_tx).is_none(),
299            "server already started"
300        );
301        let mut close_rx = close_rx.shared();
302
303        loop {
304            let (mut client_sock, client_addr) = select! {
305                // Poll in the order written so we check for close first
306                biased;
307
308                _ = &mut close_rx => {
309                    info!("server stopping");
310                    self.server.stop().await;
311                    return Ok(());
312                }
313
314                conn = listener.accept() => conn?,
315            };
316
317            let close_rx = close_rx.clone();
318            let server = self.clone();
319            tokio::spawn(async move {
320                server
321                    .handle_conn(&mut client_sock, client_addr, close_rx)
322                    .await;
323            });
324        }
325    }
326
327    /// Stop the server (and disconnect any client) if it's running.
328    pub async fn stop(self: &Arc<Self>) {
329        if let Some(close_tx) = self.stop_ch.lock().await.take() {
330            let _ = close_tx.send(());
331        }
332    }
333}