rfb/
server.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4//
5// Copyright 2022 Oxide Computer Company
6
7use std::fmt::Debug;
8use std::io;
9use std::marker::{Send, Sync};
10use std::net::SocketAddr;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use futures::future::Shared;
15use futures::FutureExt;
16use log::{debug, error, info, trace};
17use thiserror::Error;
18use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
19use tokio::net::{TcpListener, TcpStream};
20use tokio::select;
21use tokio::sync::{oneshot, Mutex};
22
23use crate::rfb::{
24    ClientInit, ClientMessage, FramebufferUpdate, KeyEvent, PixelFormat, ProtoVersion,
25    ProtocolError, ReadMessage, SecurityResult, SecurityType, SecurityTypes, ServerInit,
26    WriteMessage,
27};
28
29#[derive(Debug, Error)]
30pub enum HandshakeError {
31    #[error("incompatible protocol versions (client = {client:?}, server = {server:?})")]
32    IncompatibleVersions {
33        client: ProtoVersion,
34        server: ProtoVersion,
35    },
36
37    #[error(
38        "incompatible security types (client choice = {choice:?}, server offered = {offer:?})"
39    )]
40    IncompatibleSecurityTypes {
41        choice: SecurityType,
42        offer: SecurityTypes,
43    },
44
45    #[error(transparent)]
46    Protocol(#[from] ProtocolError),
47}
48
49/// Immutable state
50pub struct VncServerConfig {
51    pub addr: SocketAddr,
52    pub version: ProtoVersion,
53    pub sec_types: SecurityTypes,
54    pub name: String,
55}
56
57/// Mutable state
58pub struct VncServerData {
59    pub width: u16,
60    pub height: u16,
61
62    /// The pixel format of the framebuffer data passed in to the server via
63    /// get_framebuffer_update.
64    pub input_pixel_format: PixelFormat,
65}
66
67pub struct VncServer<S: Server<K>, K = SocketAddr> {
68    /// VNC startup server configuration
69    config: VncServerConfig,
70
71    /// VNC runtime mutable state
72    data: Mutex<VncServerData>,
73
74    /// The underlying [`Server`] implementation
75    pub server: Arc<dyn Fn(&K) -> Arc<S> + Send + Sync>,
76
77    /// One-shot channel used to signal that the server should shut down.
78    stop_ch: Mutex<Option<oneshot::Sender<()>>>,
79}
80
81#[async_trait]
82pub trait Server<K = SocketAddr>: Sync + Send + 'static {
83    async fn get_framebuffer_update(&self) -> FramebufferUpdate;
84    async fn key_event(&self, _ke: KeyEvent) {}
85    async fn stop(&self) {}
86}
87
88impl<S: Server<K>, K: Debug + Send + Sync + 'static> VncServer<S, K> {
89    pub fn new(server: S, config: VncServerConfig, data: VncServerData) -> Arc<Self> {
90        let server = Arc::new(server);
91        Self::new_base(Arc::new(move |_| server.clone()), config, data)
92    }
93    pub fn new_base(
94        server: Arc<dyn Fn(&K) -> Arc<S> + Send + Sync>,
95        config: VncServerConfig,
96        data: VncServerData,
97    ) -> Arc<Self> {
98        assert!(
99            config.sec_types.0.len() > 0,
100            "at least one security type must be defined"
101        );
102        Arc::new(Self {
103            config: config,
104            data: Mutex::new(data),
105            server: server,
106            stop_ch: Mutex::new(None),
107        })
108    }
109
110    pub async fn set_pixel_format(&self, pixel_format: PixelFormat) {
111        let mut locked = self.data.lock().await;
112        locked.input_pixel_format = pixel_format;
113    }
114
115    pub async fn set_resolution(&self, width: u16, height: u16) {
116        let mut locked = self.data.lock().await;
117        locked.width = width;
118        locked.height = height;
119    }
120
121    async fn rfb_handshake(
122        &self,
123        s: &mut (impl AsyncRead + AsyncWrite + Unpin + Send + Sync),
124        addr: &K,
125    ) -> Result<(), HandshakeError> {
126        // ProtocolVersion handshake
127        info!("Tx [{:?}]: ProtoVersion={:?}", addr, self.config.version);
128        self.config.version.write_to(s).await?;
129        let client_version = ProtoVersion::read_from(s).await?;
130        info!("Rx [{:?}]: ClientVersion={:?}", addr, client_version);
131
132        if client_version < self.config.version {
133            let err_str = format!(
134                "[{:?}] unsupported client version={:?} (server version: {:?})",
135                addr, client_version, self.config.version
136            );
137            error!("{}", err_str);
138            return Err(HandshakeError::IncompatibleVersions {
139                client: client_version,
140                server: self.config.version,
141            });
142        }
143
144        // Security Handshake
145        let supported_types = self.config.sec_types.clone();
146        info!("Tx [{:?}]: SecurityTypes={:?}", addr, supported_types);
147        supported_types.write_to(s).await?;
148        let client_choice = SecurityType::read_from(s).await?;
149        info!("Rx [{:?}]: SecurityType Choice={:?}", addr, client_choice);
150        if !self.config.sec_types.0.contains(&client_choice) {
151            info!("Tx [{:?}]: SecurityResult=Failure", addr);
152            let failure = SecurityResult::Failure("unsupported security type".to_string());
153            failure.write_to(s).await?;
154            let err_str = format!("invalid security choice={:?}", client_choice);
155            error!("{}", err_str);
156            return Err(HandshakeError::IncompatibleSecurityTypes {
157                choice: client_choice,
158                offer: self.config.sec_types.clone(),
159            });
160        }
161
162        let res = SecurityResult::Success;
163        info!("Tx: SecurityResult=Success");
164        res.write_to(s).await?;
165
166        Ok(())
167    }
168
169    async fn rfb_initialization(
170        &self,
171        s: &mut (impl AsyncRead + AsyncWrite + Unpin + Send + Sync),
172        addr: &K,
173    ) -> Result<(), ProtocolError> {
174        let client_init = ClientInit::read_from(s).await?;
175        info!("Rx [{:?}]: ClientInit={:?}", addr, client_init);
176        // TODO: decide what to do in exclusive case
177        match client_init.shared {
178            true => {}
179            false => {}
180        }
181
182        let data = self.data.lock().await;
183        let server_init = ServerInit::new(
184            data.width,
185            data.height,
186            self.config.name.clone(),
187            data.input_pixel_format.clone(),
188        );
189        info!("Tx [{:?}]: ServerInit={:#?}", addr, server_init);
190        server_init.write_to(s).await?;
191
192        Ok(())
193    }
194
195    pub async fn handle_conn(
196        &self,
197        s: &mut (impl AsyncRead + AsyncWrite + Unpin + Send + Sync),
198        addr: K,
199        mut close_ch: Shared<oneshot::Receiver<()>>,
200    ) {
201        info!("[{:?}] new connection", addr);
202
203        if let Err(e) = self.rfb_handshake(s, &addr).await {
204            error!("[{:?}] could not complete handshake: {:?}", addr, e);
205            return;
206        }
207
208        if let Err(e) = self.rfb_initialization(s, &addr).await {
209            error!("[{:?}] could not complete handshake: {:?}", addr, e);
210            return;
211        }
212
213        let data = self.data.lock().await;
214        let mut output_pixel_format = data.input_pixel_format.clone();
215        drop(data);
216
217        let server = (self.server)(&addr);
218
219        loop {
220            let req = select! {
221                // Poll in the order written so we check for close first
222                biased;
223
224                _ = &mut close_ch => {
225                    info!("[{:?}] server stopping, closing connection with peer", addr);
226                    let _ = s.shutdown().await;
227                    return;
228                }
229
230                req = ClientMessage::read_from(s) => req,
231            };
232
233            match req {
234                Ok(client_msg) => match client_msg {
235                    ClientMessage::SetPixelFormat(pf) => {
236                        debug!("Rx [{:?}]: SetPixelFormat={:#?}", addr, pf);
237
238                        // TODO: invalid pixel formats?
239                        output_pixel_format = pf;
240                    }
241                    ClientMessage::SetEncodings(e) => {
242                        debug!("Rx [{:?}]: SetEncodings={:?}", addr, e);
243                    }
244                    ClientMessage::FramebufferUpdateRequest(f) => {
245                        debug!("Rx [{:?}]: FramebufferUpdateRequest={:?}", addr, f);
246
247                        let mut fbu = server.get_framebuffer_update().await;
248
249                        let data = self.data.lock().await;
250
251                        // We only need to change pixel formats if the client requested a different
252                        // one than what's specified in the input.
253                        //
254                        // For now, we only support transformations between 4-byte RGB formats, so
255                        // if the requested format isn't one of those, we'll just leave the pixels
256                        // as is.
257                        if data.input_pixel_format != output_pixel_format
258                            && data.input_pixel_format.is_rgb_888()
259                            && output_pixel_format.is_rgb_888()
260                        {
261                            debug!(
262                                "transforming: input={:#?}, output={:#?}",
263                                data.input_pixel_format, output_pixel_format
264                            );
265                            fbu = fbu.transform(&data.input_pixel_format, &output_pixel_format);
266                        } else if !(data.input_pixel_format.is_rgb_888()
267                            && output_pixel_format.is_rgb_888())
268                        {
269                            debug!("cannot transform between pixel formats (not rgb888): input.is_rgb_888()={}, output.is_rgb_888()={}", data.input_pixel_format.is_rgb_888(), output_pixel_format.is_rgb_888());
270                        } else {
271                            debug!("no input transformation needed");
272                        }
273
274                        if let Err(e) = fbu.write_to(s).await {
275                            error!(
276                                "[{:?}] could not write FramebufferUpdateRequest: {:?}",
277                                addr, e
278                            );
279                            return;
280                        }
281                        debug!("Tx [{:?}]: FramebufferUpdate", addr);
282                    }
283                    ClientMessage::KeyEvent(ke) => {
284                        trace!("Rx [{:?}]: KeyEvent={:?}", addr, ke);
285                        server.key_event(ke).await;
286                    }
287                    ClientMessage::PointerEvent(pe) => {
288                        trace!("Rx [{:?}: PointerEvent={:?}", addr, pe);
289                    }
290                    ClientMessage::ClientCutText(t) => {
291                        trace!("Rx [{:?}: ClientCutText={:?}", addr, t);
292                    }
293                },
294                Err(e) => {
295                    error!("[{:?}] error reading client message: {}", addr, e);
296                    return;
297                }
298            }
299        }
300    }
301
302    /// Start listening for incoming connections.
303    pub async fn start(self: &Arc<Self>) -> io::Result<()>  where K: From<SocketAddr>{
304        let listener = TcpListener::bind(self.config.addr).await?;
305
306        // Create a channel to signal the server to stop.
307        let (close_tx, close_rx) = oneshot::channel();
308        assert!(
309            self.stop_ch.lock().await.replace(close_tx).is_none(),
310            "server already started"
311        );
312        let mut close_rx = close_rx.shared();
313
314        loop {
315            let (mut client_sock, client_addr) = select! {
316                // Poll in the order written so we check for close first
317                biased;
318
319                _ = &mut close_rx => {
320                    info!("server stopping");
321                    // self.server.stop().await;
322                    return Ok(());
323                }
324
325                conn = listener.accept() => conn?,
326            };
327
328            let close_rx = close_rx.clone();
329            let server = self.clone();
330            tokio::spawn(async move {
331                server
332                    .handle_conn(&mut client_sock, client_addr.into(), close_rx)
333                    .await;
334            });
335        }
336    }
337
338    /// Stop the server (and disconnect any client) if it's running.
339    pub async fn stop(self: &Arc<Self>) {
340        if let Some(close_tx) = self.stop_ch.lock().await.take() {
341            let _ = close_tx.send(());
342        }
343    }
344}