1use std::io::{BufRead, BufReader, Write};
2use std::net::ToSocketAddrs;
3use std::os::unix::net::UnixStream;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::sync::Once;
7
8use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
9use rustls::{ClientConfig, RootCertStore};
10use rustls_platform_verifier::BuilderVerifierExt;
11use serde_json::Value;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13
/// Errors surfaced by the control-plane client.
#[derive(Debug)]
pub enum CliError {
    /// The engine could not be reached at the given Unix socket path.
    NotRunning(PathBuf),
    /// An underlying I/O operation failed.
    Io(std::io::Error),
    /// The exchange with the engine failed at the transport, TLS or JSON
    /// level; the string carries the human-readable detail.
    Protocol(String),
    /// The engine answered with an `error` object.
    Rpc { code: i64, message: String },
    /// A free-form "not implemented" message, displayed verbatim.
    NotImplemented(String),
}

impl std::fmt::Display for CliError {
    /// Renders the user-facing message for each variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Shown as-is, without any prefix.
            Self::NotImplemented(msg) => f.write_str(msg),
            Self::Rpc { code, message } => write!(f, "engine error ({code}): {message}"),
            Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
            Self::Io(e) => write!(f, "IO error: {e}"),
            Self::NotRunning(p) => write!(
                f,
                "could not connect to engine at {}. Is the engine running?",
                p.display()
            ),
        }
    }
}

// Marker impl: no `source()` is exposed; the I/O cause is already part of the
// `Display` output.
impl std::error::Error for CliError {}

impl From<std::io::Error> for CliError {
    /// Lets `?` lift plain I/O errors into [`CliError::Io`].
    fn from(e: std::io::Error) -> Self {
        Self::Io(e)
    }
}
52
/// Client handle for issuing requests to the engine's control plane.
///
/// The handle only stores where to send requests and a request counter.
#[derive(Debug)]
pub struct ControlPlaneClient {
    /// Transport and address used for every request.
    endpoint: Endpoint,
    /// Request counter: starts at 0 and is incremented before each call, so
    /// the first request is sent with `id` 1.
    next_id: u64,
}
62
/// Where and how the control plane is reached.
#[derive(Debug, Clone)]
pub enum Endpoint {
    /// Unix domain socket at the given filesystem path.
    Unix(PathBuf),
    /// TLS over TCP to `host:port`, using the given client TLS settings.
    Tcp {
        host: String,
        port: u16,
        tls: TlsClientConfig,
    },
    /// QUIC to `host:port`, using the given client TLS settings.
    Quic {
        host: String,
        port: u16,
        tls: TlsClientConfig,
    },
}
77
/// File locations and overrides used to build the TLS client configuration.
#[derive(Debug, Clone)]
pub struct TlsClientConfig {
    /// PEM file holding the client certificate chain presented to the server.
    pub cert: PathBuf,
    /// PEM file holding the private key that matches `cert`.
    pub key: PathBuf,
    /// Optional PEM file of CA certificates used as the trust roots for the
    /// server; when `None`, the platform verifier is used instead.
    pub ca: Option<PathBuf>,
    /// Optional TLS server name override; when `None`, the endpoint host is
    /// used as the server name.
    pub sni: Option<String>,
}
85
86impl ControlPlaneClient {
87 pub fn connect(socket_path: &Path) -> Result<Self, CliError> {
91 if !socket_path.exists() {
92 return Err(CliError::NotRunning(socket_path.to_owned()));
93 }
94 Ok(Self {
95 endpoint: Endpoint::Unix(socket_path.to_owned()),
96 next_id: 0,
97 })
98 }
99
100 pub fn connect_endpoint(endpoint: &Endpoint) -> Result<Self, CliError> {
101 match endpoint {
102 Endpoint::Unix(path) => Self::connect(path),
103 Endpoint::Tcp { .. } | Endpoint::Quic { .. } => Ok(Self {
104 endpoint: endpoint.clone(),
105 next_id: 0,
106 }),
107 }
108 }
109
110 pub fn call(&mut self, method: &str, params: Value) -> Result<Value, CliError> {
114 self.next_id += 1;
115 let req = serde_json::json!({
116 "id": self.next_id,
117 "method": method,
118 "params": params,
119 });
120
121 match &self.endpoint {
122 Endpoint::Unix(path) => call_unix(path, &req),
123 Endpoint::Tcp { host, port, tls } => call_tcp(host, *port, tls, &req),
124 Endpoint::Quic { host, port, tls } => call_quic(host, *port, tls, &req),
125 }
126 }
127}
128
129fn call_unix(path: &Path, req: &Value) -> Result<Value, CliError> {
130 let mut stream =
131 UnixStream::connect(path).map_err(|_| CliError::NotRunning(path.to_owned()))?;
132 let mut body = req.to_string();
133 body.push('\n');
134 stream.write_all(body.as_bytes())?;
135 let mut reader = BufReader::new(stream);
136 let mut line = String::new();
137 reader.read_line(&mut line)?;
138 parse_response(&line)
139}
140
141fn call_tcp(host: &str, port: u16, tls: &TlsClientConfig, req: &Value) -> Result<Value, CliError> {
142 let addr = format!("{host}:{port}");
143 let stream = std::net::TcpStream::connect(addr)?;
144 stream.set_nodelay(true).ok();
145 let server_name = tls.sni.clone().unwrap_or_else(|| host.to_string());
146 let config = build_client_config(tls, &server_name)?;
147 let name = ServerName::try_from(server_name)
148 .map_err(|_| CliError::Protocol("invalid TLS server name".to_string()))?;
149 let conn = rustls::ClientConnection::new(Arc::new(config), name)
150 .map_err(|e| CliError::Protocol(e.to_string()))?;
151 let mut tls_stream = rustls::StreamOwned::new(conn, stream);
152
153 let mut body = req.to_string();
154 body.push('\n');
155 tls_stream.write_all(body.as_bytes())?;
156 let mut reader = BufReader::new(tls_stream);
157 let mut line = String::new();
158 reader.read_line(&mut line)?;
159 parse_response(&line)
160}
161
162fn call_quic(host: &str, port: u16, tls: &TlsClientConfig, req: &Value) -> Result<Value, CliError> {
163 let host = host.to_string();
164 let server_name = tls.sni.clone().unwrap_or_else(|| host.clone());
165 let config = build_client_config(tls, &server_name)?;
166 let mut client_crypto = config;
167 client_crypto.alpn_protocols = vec![b"pd-control".to_vec()];
168
169 let client_tls: s2n_quic::provider::tls::rustls::Client = client_crypto.into();
170
171 let rt = tokio::runtime::Builder::new_current_thread()
172 .enable_all()
173 .build()
174 .map_err(CliError::Io)?;
175
176 rt.block_on(async move {
177 let addr = (host.as_str(), port)
178 .to_socket_addrs()
179 .map_err(|e| CliError::Protocol(e.to_string()))?
180 .next()
181 .ok_or_else(|| CliError::Protocol("invalid host:port".to_string()))?;
182
183 let client = s2n_quic::Client::builder()
184 .with_tls(client_tls)
185 .map_err(|e| CliError::Protocol(e.to_string()))?
186 .with_io("0.0.0.0:0")
187 .map_err(|e| CliError::Protocol(e.to_string()))?
188 .start()
189 .map_err(|e| CliError::Protocol(e.to_string()))?;
190
191 let connect = s2n_quic::client::Connect::new(addr).with_server_name(server_name);
192 let mut conn = client
193 .connect(connect)
194 .await
195 .map_err(|e| CliError::Protocol(e.to_string()))?;
196
197 let mut stream = conn
198 .open_bidirectional_stream()
199 .await
200 .map_err(|e| CliError::Protocol(e.to_string()))?;
201
202 let mut body = req.to_string();
203 body.push('\n');
204 stream
205 .write_all(body.as_bytes())
206 .await
207 .map_err(|e| CliError::Protocol(e.to_string()))?;
208 stream
209 .close()
210 .await
211 .map_err(|e| CliError::Protocol(e.to_string()))?;
212
213 let mut buf = Vec::new();
214 stream
215 .read_to_end(&mut buf)
216 .await
217 .map_err(|e| CliError::Protocol(e.to_string()))?;
218
219 let line = String::from_utf8_lossy(&buf);
220 parse_response(&line)
221 })
222}
223
224fn parse_response(line: &str) -> Result<Value, CliError> {
225 let resp: Value =
226 serde_json::from_str(line.trim()).map_err(|e| CliError::Protocol(e.to_string()))?;
227 if let Some(err) = resp.get("error") {
228 let code = err["code"].as_i64().unwrap_or(-1);
229 let message = err["message"]
230 .as_str()
231 .unwrap_or("unknown error")
232 .to_string();
233 return Err(CliError::Rpc { code, message });
234 }
235 Ok(resp["result"].clone())
236}
237
238fn build_client_config(
239 tls: &TlsClientConfig,
240 _server_name: &str,
241) -> Result<ClientConfig, CliError> {
242 ensure_rustls_provider();
243 let cert_chain = read_certs(&tls.cert)?;
244 let key = read_key(&tls.key)?;
245
246 let builder = if let Some(ca_path) = &tls.ca {
247 let mut roots = RootCertStore::empty();
248 for cert in read_certs(ca_path)? {
249 roots
250 .add(cert)
251 .map_err(|_| CliError::Protocol("invalid CA cert".to_string()))?;
252 }
253 ClientConfig::builder().with_root_certificates(roots)
254 } else {
255 ClientConfig::builder()
256 .with_platform_verifier()
257 .map_err(|e| CliError::Protocol(e.to_string()))?
258 };
259
260 builder
261 .with_client_auth_cert(cert_chain, key)
262 .map_err(|e| CliError::Protocol(e.to_string()))
263}
264
/// Installs a process-wide default rustls crypto provider, at most once.
///
/// `aws-lc-rs` is chosen when that feature is enabled, otherwise `ring` when
/// that feature is enabled; with neither feature nothing is installed. The
/// result of the installation is deliberately ignored.
fn ensure_rustls_provider() {
    /// Performs the feature-gated installation.
    fn install_default_provider() {
        #[cfg(feature = "aws-lc-rs")]
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        #[cfg(all(not(feature = "aws-lc-rs"), feature = "ring"))]
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    static INSTALL: Once = Once::new();
    INSTALL.call_once(install_default_provider);
}
274
275fn read_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, CliError> {
276 let file = std::fs::File::open(path)?;
277 let mut reader = BufReader::new(file);
278 let certs = rustls_pemfile::certs(&mut reader)
279 .collect::<Result<Vec<_>, std::io::Error>>()
280 .map_err(CliError::Io)?;
281 if certs.is_empty() {
282 return Err(CliError::Protocol("no certificates found".to_string()));
283 }
284 Ok(certs.into_iter().map(|c| c.into_owned()).collect())
285}
286
287fn read_key(path: &Path) -> Result<PrivateKeyDer<'static>, CliError> {
288 let file = std::fs::File::open(path)?;
289 let mut reader = BufReader::new(file);
290 let key = rustls_pemfile::private_key(&mut reader)
291 .map_err(|_| CliError::Protocol("invalid private key".to_string()))?;
292 match key {
293 Some(k) => Ok(k),
294 None => Err(CliError::Protocol("no private key found".to_string())),
295 }
296}