Skip to main content

palladium_cli/
client.rs

1use std::io::{BufRead, BufReader, Write};
2use std::net::ToSocketAddrs;
3use std::os::unix::net::UnixStream;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::sync::Once;
7
8use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
9use rustls::{ClientConfig, RootCertStore};
10use rustls_platform_verifier::BuilderVerifierExt;
11use serde_json::Value;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13
/// Failure modes the CLI can hit while talking to the engine.
#[derive(Debug)]
pub enum CliError {
    /// Nothing is accepting connections at this socket path.
    NotRunning(PathBuf),
    /// An I/O operation against the engine failed.
    Io(std::io::Error),
    /// The engine's reply did not match the expected wire format.
    Protocol(String),
    /// The engine answered with a JSON-RPC `error` object.
    Rpc { code: i64, message: String },
    /// The requested operation is not available; the message is shown verbatim.
    NotImplemented(String),
}

impl std::fmt::Display for CliError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotRunning(path) => {
                let shown = path.display();
                write!(f, "could not connect to engine at {shown}. Is the engine running?")
            }
            Self::Io(err) => write!(f, "IO error: {err}"),
            Self::Protocol(detail) => write!(f, "protocol error: {detail}"),
            Self::Rpc { code, message } => write!(f, "engine error ({code}): {message}"),
            Self::NotImplemented(detail) => f.write_str(detail),
        }
    }
}

impl std::error::Error for CliError {}

impl From<std::io::Error> for CliError {
    /// Wraps a raw I/O error so `?` can be used on `std::io` results.
    fn from(err: std::io::Error) -> Self {
        Self::Io(err)
    }
}
52
/// Synchronous JSON-RPC client for the engine's control plane.
///
/// The transport is selected by [`Endpoint`]: a Unix domain socket, TLS over
/// TCP, or QUIC. Each [`ControlPlaneClient::call`] opens a fresh connection,
/// sends one newline-terminated JSON request, and reads back one response.
#[derive(Debug)]
pub struct ControlPlaneClient {
    /// Where and how requests are sent.
    endpoint: Endpoint,
    /// Id of the most recently sent request (0 before the first call).
    next_id: u64,
}

/// Transport used to reach the engine's control plane.
#[derive(Debug, Clone)]
pub enum Endpoint {
    /// Unix domain socket at the given filesystem path.
    Unix(PathBuf),
    /// TLS (rustls) over a TCP connection to `host:port`.
    Tcp {
        host: String,
        port: u16,
        tls: TlsClientConfig,
    },
    /// QUIC (s2n-quic) connection to `host:port`, negotiating ALPN `pd-control`.
    Quic {
        host: String,
        port: u16,
        tls: TlsClientConfig,
    },
}

/// Client-side TLS material shared by the TCP and QUIC transports.
#[derive(Debug, Clone)]
pub struct TlsClientConfig {
    /// PEM file with the client certificate chain presented to the server.
    pub cert: PathBuf,
    /// PEM file with the client's private key.
    pub key: PathBuf,
    /// PEM bundle of CA certificates to trust; when `None`, the platform
    /// certificate verifier is used instead.
    pub ca: Option<PathBuf>,
    /// TLS server name to present and verify; when `None`, the endpoint's
    /// `host` is used.
    pub sni: Option<String>,
}
85
86impl ControlPlaneClient {
87    /// Create a client for the given socket path.
88    ///
89    /// Returns `CliError::NotRunning` if the socket file does not exist.
90    pub fn connect(socket_path: &Path) -> Result<Self, CliError> {
91        if !socket_path.exists() {
92            return Err(CliError::NotRunning(socket_path.to_owned()));
93        }
94        Ok(Self {
95            endpoint: Endpoint::Unix(socket_path.to_owned()),
96            next_id: 0,
97        })
98    }
99
100    pub fn connect_endpoint(endpoint: &Endpoint) -> Result<Self, CliError> {
101        match endpoint {
102            Endpoint::Unix(path) => Self::connect(path),
103            Endpoint::Tcp { .. } | Endpoint::Quic { .. } => Ok(Self {
104                endpoint: endpoint.clone(),
105                next_id: 0,
106            }),
107        }
108    }
109
110    /// Send a JSON-RPC request and return the `result` field on success.
111    ///
112    /// Returns `CliError::Rpc` if the server sends an `error` object.
113    pub fn call(&mut self, method: &str, params: Value) -> Result<Value, CliError> {
114        self.next_id += 1;
115        let req = serde_json::json!({
116            "id": self.next_id,
117            "method": method,
118            "params": params,
119        });
120
121        match &self.endpoint {
122            Endpoint::Unix(path) => call_unix(path, &req),
123            Endpoint::Tcp { host, port, tls } => call_tcp(host, *port, tls, &req),
124            Endpoint::Quic { host, port, tls } => call_quic(host, *port, tls, &req),
125        }
126    }
127}
128
129fn call_unix(path: &Path, req: &Value) -> Result<Value, CliError> {
130    let mut stream =
131        UnixStream::connect(path).map_err(|_| CliError::NotRunning(path.to_owned()))?;
132    let mut body = req.to_string();
133    body.push('\n');
134    stream.write_all(body.as_bytes())?;
135    let mut reader = BufReader::new(stream);
136    let mut line = String::new();
137    reader.read_line(&mut line)?;
138    parse_response(&line)
139}
140
141fn call_tcp(host: &str, port: u16, tls: &TlsClientConfig, req: &Value) -> Result<Value, CliError> {
142    let addr = format!("{host}:{port}");
143    let stream = std::net::TcpStream::connect(addr)?;
144    stream.set_nodelay(true).ok();
145    let server_name = tls.sni.clone().unwrap_or_else(|| host.to_string());
146    let config = build_client_config(tls, &server_name)?;
147    let name = ServerName::try_from(server_name)
148        .map_err(|_| CliError::Protocol("invalid TLS server name".to_string()))?;
149    let conn = rustls::ClientConnection::new(Arc::new(config), name)
150        .map_err(|e| CliError::Protocol(e.to_string()))?;
151    let mut tls_stream = rustls::StreamOwned::new(conn, stream);
152
153    let mut body = req.to_string();
154    body.push('\n');
155    tls_stream.write_all(body.as_bytes())?;
156    let mut reader = BufReader::new(tls_stream);
157    let mut line = String::new();
158    reader.read_line(&mut line)?;
159    parse_response(&line)
160}
161
162fn call_quic(host: &str, port: u16, tls: &TlsClientConfig, req: &Value) -> Result<Value, CliError> {
163    let host = host.to_string();
164    let server_name = tls.sni.clone().unwrap_or_else(|| host.clone());
165    let config = build_client_config(tls, &server_name)?;
166    let mut client_crypto = config;
167    client_crypto.alpn_protocols = vec![b"pd-control".to_vec()];
168
169    let client_tls: s2n_quic::provider::tls::rustls::Client = client_crypto.into();
170
171    let rt = tokio::runtime::Builder::new_current_thread()
172        .enable_all()
173        .build()
174        .map_err(CliError::Io)?;
175
176    rt.block_on(async move {
177        let addr = (host.as_str(), port)
178            .to_socket_addrs()
179            .map_err(|e| CliError::Protocol(e.to_string()))?
180            .next()
181            .ok_or_else(|| CliError::Protocol("invalid host:port".to_string()))?;
182
183        let client = s2n_quic::Client::builder()
184            .with_tls(client_tls)
185            .map_err(|e| CliError::Protocol(e.to_string()))?
186            .with_io("0.0.0.0:0")
187            .map_err(|e| CliError::Protocol(e.to_string()))?
188            .start()
189            .map_err(|e| CliError::Protocol(e.to_string()))?;
190
191        let connect = s2n_quic::client::Connect::new(addr).with_server_name(server_name);
192        let mut conn = client
193            .connect(connect)
194            .await
195            .map_err(|e| CliError::Protocol(e.to_string()))?;
196
197        let mut stream = conn
198            .open_bidirectional_stream()
199            .await
200            .map_err(|e| CliError::Protocol(e.to_string()))?;
201
202        let mut body = req.to_string();
203        body.push('\n');
204        stream
205            .write_all(body.as_bytes())
206            .await
207            .map_err(|e| CliError::Protocol(e.to_string()))?;
208        stream
209            .close()
210            .await
211            .map_err(|e| CliError::Protocol(e.to_string()))?;
212
213        let mut buf = Vec::new();
214        stream
215            .read_to_end(&mut buf)
216            .await
217            .map_err(|e| CliError::Protocol(e.to_string()))?;
218
219        let line = String::from_utf8_lossy(&buf);
220        parse_response(&line)
221    })
222}
223
224fn parse_response(line: &str) -> Result<Value, CliError> {
225    let resp: Value =
226        serde_json::from_str(line.trim()).map_err(|e| CliError::Protocol(e.to_string()))?;
227    if let Some(err) = resp.get("error") {
228        let code = err["code"].as_i64().unwrap_or(-1);
229        let message = err["message"]
230            .as_str()
231            .unwrap_or("unknown error")
232            .to_string();
233        return Err(CliError::Rpc { code, message });
234    }
235    Ok(resp["result"].clone())
236}
237
238fn build_client_config(
239    tls: &TlsClientConfig,
240    _server_name: &str,
241) -> Result<ClientConfig, CliError> {
242    ensure_rustls_provider();
243    let cert_chain = read_certs(&tls.cert)?;
244    let key = read_key(&tls.key)?;
245
246    let builder = if let Some(ca_path) = &tls.ca {
247        let mut roots = RootCertStore::empty();
248        for cert in read_certs(ca_path)? {
249            roots
250                .add(cert)
251                .map_err(|_| CliError::Protocol("invalid CA cert".to_string()))?;
252        }
253        ClientConfig::builder().with_root_certificates(roots)
254    } else {
255        ClientConfig::builder()
256            .with_platform_verifier()
257            .map_err(|e| CliError::Protocol(e.to_string()))?
258    };
259
260    builder
261        .with_client_auth_cert(cert_chain, key)
262        .map_err(|e| CliError::Protocol(e.to_string()))
263}
264
/// Installs a process-wide rustls crypto provider at most once per process.
///
/// The provider is chosen from this crate's cargo features: `aws-lc-rs` takes
/// precedence over `ring`. With neither feature enabled this is a no-op.
fn ensure_rustls_provider() {
    static PROVIDER_INSTALLED: Once = Once::new();

    fn install_default_provider() {
        // The install result is deliberately ignored. NOTE(review): rustls reports
        // an error here when a default provider is already set elsewhere, which is
        // acceptable for this best-effort setup.
        #[cfg(feature = "aws-lc-rs")]
        {
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        }
        #[cfg(all(not(feature = "aws-lc-rs"), feature = "ring"))]
        {
            let _ = rustls::crypto::ring::default_provider().install_default();
        }
    }

    PROVIDER_INSTALLED.call_once(install_default_provider);
}
274
275fn read_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, CliError> {
276    let file = std::fs::File::open(path)?;
277    let mut reader = BufReader::new(file);
278    let certs = rustls_pemfile::certs(&mut reader)
279        .collect::<Result<Vec<_>, std::io::Error>>()
280        .map_err(CliError::Io)?;
281    if certs.is_empty() {
282        return Err(CliError::Protocol("no certificates found".to_string()));
283    }
284    Ok(certs.into_iter().map(|c| c.into_owned()).collect())
285}
286
287fn read_key(path: &Path) -> Result<PrivateKeyDer<'static>, CliError> {
288    let file = std::fs::File::open(path)?;
289    let mut reader = BufReader::new(file);
290    let key = rustls_pemfile::private_key(&mut reader)
291        .map_err(|_| CliError::Protocol("invalid private key".to_string()))?;
292    match key {
293        Some(k) => Ok(k),
294        None => Err(CliError::Protocol("no private key found".to_string())),
295    }
296}