1use self::controller_address::AuthorizedAddress;
2use crate::{
3 names::{BackendName, DroneName},
4 protocol::{MessageFromDns, MessageFromDrone, MessageFromProxy},
5 typed_socket::client::TypedSocketConnector,
6 types::{
7 backend_state::BackendStatusStreamEntry, ClusterName, ClusterState, ConnectRequest,
8 ConnectResponse, DrainResult, DronePoolName, RevokeRequest,
9 },
10};
11use protocol::{ApiError, StatusResponse};
12use reqwest::{Response, StatusCode};
13use serde::de::DeserializeOwned;
14use url::{form_urlencoded, Url};
15
16pub mod controller_address;
17pub mod exponential_backoff;
18pub mod log_types;
19pub mod names;
20pub mod protocol;
21pub mod serialization;
22pub mod sse;
23pub mod typed_socket;
24pub mod types;
25pub mod util;
26pub mod version;
27
28#[derive(thiserror::Error, Debug)]
29pub enum PlaneClientError {
30 #[error("HTTP error: {0}")]
31 Http(#[from] reqwest::Error),
32
33 #[error("URL error: {0}")]
34 Url(#[from] url::ParseError),
35
36 #[error("JSON error: {0}")]
37 Json(#[from] serde_json::Error),
38
39 #[error("Unexpected status code: {0}")]
40 UnexpectedStatus(StatusCode),
41
42 #[error("API error: {0} ({1})")]
43 PlaneError(ApiError, StatusCode),
44
45 #[error("Failed to connect.")]
46 ConnectFailed(&'static str),
47
48 #[error("Bad configuration.")]
49 BadConfiguration(&'static str),
50
51 #[error("WebSocket error: {0}")]
52 Tungstenite(#[from] tokio_tungstenite::tungstenite::Error),
53
54 #[error("Send error")]
55 SendFailed,
56}
57
58#[derive(Clone)]
59pub struct PlaneClient {
60 client: reqwest::Client,
61 controller_address: AuthorizedAddress,
62}
63
64impl PlaneClient {
65 pub fn new(base_url: Url) -> Self {
66 let client = reqwest::Client::new();
67 let controller_address = AuthorizedAddress::from(base_url);
68
69 Self {
70 client,
71 controller_address,
72 }
73 }
74
75 pub async fn status(&self) -> Result<StatusResponse, PlaneClientError> {
76 let addr = self.controller_address.join("/ctrl/status");
77 authed_get(&self.client, &addr).await
78 }
79
80 pub fn drone_connection(
81 &self,
82 cluster: &ClusterName,
83 pool: &DronePoolName,
84 ) -> TypedSocketConnector<MessageFromDrone> {
85 let base_path = format!("/ctrl/c/{}/drone-socket", cluster);
86 let addr = if pool.is_default() {
87 self.controller_address.join(&base_path)
88 } else {
89 let encoded_pool: String =
90 form_urlencoded::byte_serialize(pool.as_str().as_bytes()).collect();
91 self.controller_address
92 .join(&format!("{}?pool={}", base_path, encoded_pool))
93 }
94 .to_websocket_address();
95 TypedSocketConnector::new(addr)
96 }
97
98 pub fn proxy_connection(
99 &self,
100 cluster: &ClusterName,
101 ) -> TypedSocketConnector<MessageFromProxy> {
102 let addr = self
103 .controller_address
104 .join(&format!("/ctrl/c/{}/proxy-socket", cluster))
105 .to_websocket_address();
106 TypedSocketConnector::new(addr)
107 }
108
109 pub fn dns_connection(&self) -> TypedSocketConnector<MessageFromDns> {
110 let url = self
111 .controller_address
112 .join("/ctrl/dns-socket")
113 .to_websocket_address();
114 TypedSocketConnector::new(url)
115 }
116
117 pub async fn connect(
118 &self,
119 connect_request: &ConnectRequest,
120 ) -> Result<ConnectResponse, PlaneClientError> {
121 let addr = self.controller_address.join("/ctrl/connect");
122
123 let response = authed_post(&self.client, &addr, connect_request).await?;
124 Ok(response)
125 }
126
127 pub async fn drain(
128 &self,
129 cluster: &ClusterName,
130 drone: &DroneName,
131 ) -> Result<DrainResult, PlaneClientError> {
132 let addr = self
133 .controller_address
134 .join(&format!("/ctrl/c/{}/d/{}/drain", cluster, drone));
135
136 let result: DrainResult = authed_post(&self.client, &addr, &()).await?;
137 Ok(result)
138 }
139 pub async fn soft_terminate(&self, backend_id: &BackendName) -> Result<(), PlaneClientError> {
140 let addr = self
141 .controller_address
142 .join(&format!("/ctrl/b/{}/soft-terminate", backend_id));
143
144 let _: () = authed_post(&self.client, &addr, &()).await?;
145 Ok(())
146 }
147
148 pub async fn hard_terminate(&self, backend_id: &BackendName) -> Result<(), PlaneClientError> {
149 let addr = self
150 .controller_address
151 .join(&format!("/ctrl/b/{}/hard-terminate", backend_id));
152
153 let _: () = authed_post(&self.client, &addr, &()).await?;
154 Ok(())
155 }
156
157 pub async fn revoke(&self, request: &RevokeRequest) -> Result<(), PlaneClientError> {
158 let addr = self.controller_address.join("/ctrl/revoke");
159
160 let _: () = authed_post(&self.client, &addr, &request).await?;
161 Ok(())
162 }
163
164 pub fn backend_status_url(&self, backend_id: &BackendName) -> Url {
165 self.controller_address
166 .join(&format!("/pub/b/{}/status", backend_id))
167 .url
168 }
169
170 pub async fn backend_status(
171 &self,
172 backend_id: &BackendName,
173 ) -> Result<BackendStatusStreamEntry, PlaneClientError> {
174 let url = self.backend_status_url(backend_id);
175
176 let response = self.client.get(url).send().await?;
177 let status: BackendStatusStreamEntry = get_response(response).await?;
178 Ok(status)
179 }
180
181 pub fn backend_status_stream_url(&self, backend_id: &BackendName) -> Url {
182 self.controller_address
183 .join(&format!("/pub/b/{}/status-stream", backend_id))
184 .url
185 }
186
187 pub async fn backend_status_stream(
188 &self,
189 backend_id: &BackendName,
190 ) -> Result<sse::SseStream<BackendStatusStreamEntry>, PlaneClientError> {
191 let url = self.backend_status_stream_url(backend_id);
192
193 let stream = sse::sse_request(url, self.client.clone()).await?;
194 Ok(stream)
195 }
196
197 pub async fn cluster_state(
198 &self,
199 cluster: &ClusterName,
200 ) -> Result<ClusterState, PlaneClientError> {
201 let url = self
202 .controller_address
203 .join(&format!("/ctrl/c/{}/state", cluster));
204 let cluster_state: ClusterState = authed_get(&self.client, &url).await?;
205 Ok(cluster_state)
206 }
207
208 pub async fn health_check(&self) -> Result<(), PlaneClientError> {
209 let url = self.controller_address.join("/pub/health");
210 self.client.get(url.url).send().await?;
211 Ok(())
212 }
213}
214
215async fn get_response<T: DeserializeOwned>(response: Response) -> Result<T, PlaneClientError> {
216 if response.status().is_success() {
217 Ok(response.json::<T>().await?)
218 } else {
219 let url = response.url().to_string();
220 let status = response.status();
221 if status.is_server_error() {
222 tracing::error!(?url, ?status, "Got 5xx response from Plane API server.");
223 } else {
224 tracing::warn!(
225 ?url,
226 ?status,
227 "Got unsuccessful response from Plane API server."
228 );
229 }
230 if let Ok(api_error) = response.json::<ApiError>().await {
231 Err(PlaneClientError::PlaneError(api_error, status))
232 } else {
233 Err(PlaneClientError::UnexpectedStatus(status))
234 }
235 }
236}
237
238async fn authed_get<T: DeserializeOwned>(
239 client: &reqwest::Client,
240 addr: &AuthorizedAddress,
241) -> Result<T, PlaneClientError> {
242 let mut req = client.get(addr.url.clone());
243 if let Some(header) = addr.bearer_header() {
244 req = req.header("Authorization", header);
245 }
246
247 let response = req.send().await?;
248 get_response(response).await
249}
250
251async fn authed_post<T: DeserializeOwned>(
252 client: &reqwest::Client,
253 addr: &AuthorizedAddress,
254 body: &impl serde::Serialize,
255) -> Result<T, PlaneClientError> {
256 let mut req = client.post(addr.url.clone());
257 if let Some(header) = addr.bearer_header() {
258 req = req.header("Authorization", header);
259 }
260
261 let response = req.json(body).send().await?;
262 get_response(response).await
263}