plane_common/typed_socket/
client.rs1use super::{ChannelMessage, Handshake, SocketAction, TypedSocket};
2use crate::controller_address::AuthorizedAddress;
3use crate::exponential_backoff::ExponentialBackoff;
4use crate::names::NodeName;
5use crate::version::plane_version_info;
6use crate::PlaneClientError;
7use futures_util::{SinkExt, StreamExt};
8use std::marker::PhantomData;
9use tokio::net::TcpStream;
10use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
11use tungstenite::handshake::client::generate_key;
12use tungstenite::http::{
13 header::{HeaderValue, AUTHORIZATION},
14 Method, Request,
15};
16use tungstenite::{error::ProtocolError, Message};
17
18type Socket = WebSocketStream<MaybeTlsStream<TcpStream>>;
19
20pub struct TypedSocketConnector<T: ChannelMessage> {
21 authorized_address: AuthorizedAddress,
22 backoff: ExponentialBackoff,
23 _phantom: PhantomData<T>,
24}
25
26impl<T: ChannelMessage> TypedSocketConnector<T> {
27 pub fn new(authorized_address: AuthorizedAddress) -> Self {
28 Self {
29 authorized_address,
30 backoff: ExponentialBackoff::default(),
31 _phantom: PhantomData,
32 }
33 }
34
35 pub async fn connect_with_retry(&mut self, name: &impl NodeName) -> TypedSocket<T> {
41 loop {
42 self.backoff.wait().await;
43 match self.connect(name).await {
44 Ok(pair) => {
45 self.backoff.defer_reset();
46 return pair;
47 }
48 Err(e) => {
49 tracing::error!(%e, "Error connecting to server; retrying.");
50 }
51 }
52 }
53 }
54
55 pub async fn connect<N: NodeName>(&self, name: &N) -> Result<TypedSocket<T>, PlaneClientError> {
56 let handshake = Handshake {
57 name: name.to_string(),
58 version: plane_version_info(),
59 };
60
61 let req = auth_url_to_request(&self.authorized_address)?;
62 let (mut socket, _) = tokio_tungstenite::connect_async(req).await?;
63
64 socket
65 .send(Message::Text(serde_json::to_string(&handshake)?))
66 .await?;
67
68 let msg = socket.next().await.ok_or(PlaneClientError::ConnectFailed(
69 "Socket closed before handshake received.",
70 ))??;
71 let msg = match msg {
72 Message::Text(msg) => msg,
73 msg => {
74 tracing::error!("Unexpected handshake message: {:?}", msg);
75 return Err(PlaneClientError::ConnectFailed(
76 "Handshake message was not text.",
77 ));
78 }
79 };
80
81 let remote_handshake: Handshake = serde_json::from_str(&msg)?;
82 tracing::info!(
83 remote_version = %remote_handshake.version.version,
84 remote_hash = %remote_handshake.version.git_hash,
85 remote_name = %remote_handshake.name,
86 "Connected to server"
87 );
88
89 handshake.check_compat(&remote_handshake);
90
91 new_client(socket, remote_handshake).await
92 }
93}
94
95fn auth_url_to_request(addr: &AuthorizedAddress) -> Result<Request<()>, PlaneClientError> {
97 let mut request = Request::builder()
98 .method(Method::GET)
99 .uri(addr.url.as_str())
100 .header(
101 "Host",
102 addr.url
103 .host_str()
104 .ok_or(PlaneClientError::BadConfiguration(
105 "URL does not have a hostname.",
106 ))?
107 .to_string(),
108 )
109 .header("Connection", "Upgrade")
110 .header("Upgrade", "websocket")
111 .header("Sec-WebSocket-Version", "13")
112 .header("Sec-WebSocket-Key", generate_key());
113
114 if let Some(bearer_header) = addr.bearer_header() {
115 request = request.header(
116 AUTHORIZATION,
117 HeaderValue::from_str(&bearer_header).expect("Bearer header is valid"),
118 );
119 }
120
121 Ok(request.body(()).expect("Request is valid"))
122}
123
124async fn new_client<T: ChannelMessage>(
125 mut socket: Socket,
126 remote_handshake: Handshake,
127) -> Result<TypedSocket<T>, PlaneClientError> {
128 let (send_to_client, recv_to_client) = tokio::sync::mpsc::channel::<T::Reply>(100);
129 let (send_from_client, mut recv_from_client) =
130 tokio::sync::mpsc::channel::<SocketAction<T>>(100);
131
132 tokio::spawn(async move {
133 loop {
134 tokio::select! {
135 message = recv_from_client.recv() => {
136 match message {
137 None => {
138 let _ = socket.send(Message::Close(None)).await;
139 break;
140 }
141 Some(SocketAction::Send(message)) => {
142 let message = serde_json::to_string(&message).expect("Message is always serializable");
143 if let Err(err) = socket.send(Message::Text(message.clone())).await {
144 tracing::error!(?err, ?message, "Failed to send message on websocket.");
145 }
146 },
147 Some(SocketAction::Close) => {
148 recv_from_client.close();
149 }
150 }
151 }
152 v = socket.next() => {
153 match v {
154 Some(Ok(Message::Text(msg))) => {
155 let result = match serde_json::from_str(&msg) {
156 Ok(msg) => msg,
157 Err(err) => {
158 tracing::error!(?err, "Failed to deserialize message.");
159 continue;
160 }
161 };
162 if let Err(e) = send_to_client.try_send(result) {
163 tracing::error!(%e, "Error sending message.");
164 }
165 }
166 Some(Err(tungstenite::Error::Protocol(
167 ProtocolError::ResetWithoutClosingHandshake,
168 ))) => {
169 break;
172 }
173 Some(msg) => {
174 tracing::warn!("Received ignored message: {:?}", msg);
175 }
176 None => {
177 tracing::error!("Connection closed.");
178 break;
179 }
180 }
181 }
182 }
183 }
184 });
185
186 Ok(TypedSocket {
187 send: send_from_client,
188 recv: recv_to_client,
189 remote_handshake,
190 })
191}
192
#[cfg(test)]
mod test {
    use super::{auth_url_to_request, Request};
    use crate::controller_address::AuthorizedAddress;

    /// Parses `url`, wraps it in an `AuthorizedAddress`, and builds its upgrade request.
    fn request_for(url: &str) -> Request<()> {
        let parsed = url::Url::parse(url).unwrap();
        auth_url_to_request(&AuthorizedAddress::from(parsed)).unwrap()
    }

    /// Returns the value of header `name` on `request` as a string, if it is set.
    fn header<'r>(request: &'r Request<()>, name: &str) -> Option<&'r str> {
        request
            .headers()
            .get(name)
            .map(|value| value.to_str().unwrap())
    }

    #[test]
    fn test_url_no_token() {
        let request = request_for("https://foo.bar.com/");
        assert!(header(&request, "Authorization").is_none());
    }

    #[test]
    fn test_url_with_token() {
        let request = request_for("https://abcdefg@foo.bar.com/");
        assert_eq!(header(&request, "Authorization"), Some("Bearer abcdefg"));
        assert_eq!(header(&request, "Host"), Some("foo.bar.com"));
    }
}