plane_common/typed_socket/
client.rs

1use super::{ChannelMessage, Handshake, SocketAction, TypedSocket};
2use crate::controller_address::AuthorizedAddress;
3use crate::exponential_backoff::ExponentialBackoff;
4use crate::names::NodeName;
5use crate::version::plane_version_info;
6use crate::PlaneClientError;
7use futures_util::{SinkExt, StreamExt};
8use std::marker::PhantomData;
9use tokio::net::TcpStream;
10use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
11use tungstenite::handshake::client::generate_key;
12use tungstenite::http::{
13    header::{HeaderValue, AUTHORIZATION},
14    Method, Request,
15};
16use tungstenite::{error::ProtocolError, Message};
17
18type Socket = WebSocketStream<MaybeTlsStream<TcpStream>>;
19
20pub struct TypedSocketConnector<T: ChannelMessage> {
21    authorized_address: AuthorizedAddress,
22    backoff: ExponentialBackoff,
23    _phantom: PhantomData<T>,
24}
25
26impl<T: ChannelMessage> TypedSocketConnector<T> {
27    pub fn new(authorized_address: AuthorizedAddress) -> Self {
28        Self {
29            authorized_address,
30            backoff: ExponentialBackoff::default(),
31            _phantom: PhantomData,
32        }
33    }
34
35    /// Continually retry a connection, with exponential backoff and unlimited
36    /// retries.
37    ///
38    /// This is useful in a connection loop in places that are expected to
39    /// always be connected (e.g. the drone).
40    pub async fn connect_with_retry(&mut self, name: &impl NodeName) -> TypedSocket<T> {
41        loop {
42            self.backoff.wait().await;
43            match self.connect(name).await {
44                Ok(pair) => {
45                    self.backoff.defer_reset();
46                    return pair;
47                }
48                Err(e) => {
49                    tracing::error!(%e, "Error connecting to server; retrying.");
50                }
51            }
52        }
53    }
54
55    pub async fn connect<N: NodeName>(&self, name: &N) -> Result<TypedSocket<T>, PlaneClientError> {
56        let handshake = Handshake {
57            name: name.to_string(),
58            version: plane_version_info(),
59        };
60
61        let req = auth_url_to_request(&self.authorized_address)?;
62        let (mut socket, _) = tokio_tungstenite::connect_async(req).await?;
63
64        socket
65            .send(Message::Text(serde_json::to_string(&handshake)?))
66            .await?;
67
68        let msg = socket.next().await.ok_or(PlaneClientError::ConnectFailed(
69            "Socket closed before handshake received.",
70        ))??;
71        let msg = match msg {
72            Message::Text(msg) => msg,
73            msg => {
74                tracing::error!("Unexpected handshake message: {:?}", msg);
75                return Err(PlaneClientError::ConnectFailed(
76                    "Handshake message was not text.",
77                ));
78            }
79        };
80
81        let remote_handshake: Handshake = serde_json::from_str(&msg)?;
82        tracing::info!(
83            remote_version = %remote_handshake.version.version,
84            remote_hash = %remote_handshake.version.git_hash,
85            remote_name = %remote_handshake.name,
86            "Connected to server"
87        );
88
89        handshake.check_compat(&remote_handshake);
90
91        new_client(socket, remote_handshake).await
92    }
93}
94
95/// Creates a WebSocket request from an AuthorizedAddress.
96fn auth_url_to_request(addr: &AuthorizedAddress) -> Result<Request<()>, PlaneClientError> {
97    let mut request = Request::builder()
98        .method(Method::GET)
99        .uri(addr.url.as_str())
100        .header(
101            "Host",
102            addr.url
103                .host_str()
104                .ok_or(PlaneClientError::BadConfiguration(
105                    "URL does not have a hostname.",
106                ))?
107                .to_string(),
108        )
109        .header("Connection", "Upgrade")
110        .header("Upgrade", "websocket")
111        .header("Sec-WebSocket-Version", "13")
112        .header("Sec-WebSocket-Key", generate_key());
113
114    if let Some(bearer_header) = addr.bearer_header() {
115        request = request.header(
116            AUTHORIZATION,
117            HeaderValue::from_str(&bearer_header).expect("Bearer header is valid"),
118        );
119    }
120
121    Ok(request.body(()).expect("Request is valid"))
122}
123
124async fn new_client<T: ChannelMessage>(
125    mut socket: Socket,
126    remote_handshake: Handshake,
127) -> Result<TypedSocket<T>, PlaneClientError> {
128    let (send_to_client, recv_to_client) = tokio::sync::mpsc::channel::<T::Reply>(100);
129    let (send_from_client, mut recv_from_client) =
130        tokio::sync::mpsc::channel::<SocketAction<T>>(100);
131
132    tokio::spawn(async move {
133        loop {
134            tokio::select! {
135                message = recv_from_client.recv() => {
136                    match message {
137                        None => {
138                            let _ = socket.send(Message::Close(None)).await;
139                            break;
140                        }
141                        Some(SocketAction::Send(message)) => {
142                            let message = serde_json::to_string(&message).expect("Message is always serializable");
143                            if let Err(err) = socket.send(Message::Text(message.clone())).await {
144                                tracing::error!(?err, ?message, "Failed to send message on websocket.");
145                            }
146                        },
147                        Some(SocketAction::Close) => {
148                            recv_from_client.close();
149                        }
150                    }
151                }
152                v = socket.next() => {
153                    match v {
154                        Some(Ok(Message::Text(msg))) => {
155                            let result = match serde_json::from_str(&msg) {
156                                Ok(msg) => msg,
157                                Err(err) => {
158                                    tracing::error!(?err, "Failed to deserialize message.");
159                                    continue;
160                                }
161                            };
162                            if let Err(e) = send_to_client.try_send(result) {
163                                tracing::error!(%e, "Error sending message.");
164                            }
165                        }
166                        Some(Err(tungstenite::Error::Protocol(
167                            ProtocolError::ResetWithoutClosingHandshake,
168                        ))) => {
169                            // This is too common to report (it just means the connection was
170                            // lost instead of gracefully closed).
171                            break;
172                        }
173                        Some(msg) => {
174                            tracing::warn!("Received ignored message: {:?}", msg);
175                        }
176                        None => {
177                            tracing::error!("Connection closed.");
178                            break;
179                        }
180                    }
181                }
182            }
183        }
184    });
185
186    Ok(TypedSocket {
187        send: send_from_client,
188        recv: recv_to_client,
189        remote_handshake,
190    })
191}
192
#[cfg(test)]
mod test {
    use crate::controller_address::AuthorizedAddress;

    /// A URL without userinfo must not produce an `Authorization` header.
    #[test]
    fn test_url_no_token() {
        let addr = AuthorizedAddress::from(url::Url::parse("https://foo.bar.com/").unwrap());
        let request = super::auth_url_to_request(&addr).unwrap();

        let authorization = request.headers().get("Authorization");
        assert!(authorization.is_none());
    }

    /// A token in the URL's userinfo becomes a bearer `Authorization` header
    /// and is not included in the `Host` header.
    #[test]
    fn test_url_with_token() {
        let addr =
            AuthorizedAddress::from(url::Url::parse("https://abcdefg@foo.bar.com/").unwrap());
        let request = super::auth_url_to_request(&addr).unwrap();
        let headers = request.headers();

        let authorization = headers
            .get("Authorization")
            .map(|value| value.to_str().unwrap());
        assert_eq!(authorization, Some("Bearer abcdefg"));

        let host = headers.get("Host").map(|value| value.to_str().unwrap());
        assert_eq!(host, Some("foo.bar.com"));
    }
}