1use std::time::{SystemTime, UNIX_EPOCH};
2
3use super::{into_approve_pairing_error, into_verify_pairing_data_error, Error};
4use crate::{
5 credentials::{NodeIdProvider, RuneProvider, TlsConfigProvider},
6 pb::{
7 self,
8 scheduler::{
9 pairing_client::PairingClient, ApprovePairingRequest, GetPairingDataRequest,
10 GetPairingDataResponse,
11 },
12 },
13};
14use bytes::BufMut as _;
15use picky::{pem::Pem, x509::Csr};
16use picky_asn1_x509::{PublicKey, SubjectPublicKeyInfo};
17use ring::{
18 rand,
19 signature::{self, EcdsaKeyPair, KeyPair},
20};
21use rustls_pemfile as pemfile;
22use tonic::transport::Channel;
23
24type Result<T, E = super::Error> = core::result::Result<T, E>;
25
/// Connection-state marker for a [`Client`] that wraps a live `PairingClient<Channel>`.
pub struct Connected(PairingClient<Channel>);
/// Connection-state marker for a [`Client`] that has not been connected yet.
pub struct Unconnected();
28
29pub struct Client<T, C: TlsConfigProvider + RuneProvider + NodeIdProvider> {
30 inner: T,
31 uri: String,
32 creds: C,
33}
34
35impl<C: TlsConfigProvider + RuneProvider + NodeIdProvider> Client<Unconnected, C> {
36 pub fn new(creds: C) -> Result<Client<Unconnected, C>> {
37 Ok(Self {
38 inner: Unconnected(),
39 uri: crate::utils::scheduler_uri(),
40 creds,
41 })
42 }
43
44 pub fn with_uri(mut self, uri: String) -> Client<Unconnected, C> {
45 self.uri = uri;
46 self
47 }
48
49 pub async fn connect(self) -> Result<Client<Connected, C>> {
50 let tls = self.creds.tls_config();
51 let channel = tonic::transport::Endpoint::from_shared(self.uri.clone())?
52 .tls_config(tls.inner)?
53 .tcp_keepalive(Some(crate::TCP_KEEPALIVE))
54 .http2_keep_alive_interval(crate::TCP_KEEPALIVE)
55 .keep_alive_timeout(crate::TCP_KEEPALIVE_TIMEOUT)
56 .keep_alive_while_idle(true)
57 .connect_lazy();
58
59 let inner = PairingClient::new(channel);
60
61 Ok(Client {
62 inner: Connected(inner),
63 uri: self.uri,
64 creds: self.creds,
65 })
66 }
67}
68
69impl<C: TlsConfigProvider + RuneProvider + NodeIdProvider> Client<Connected, C> {
70 pub async fn get_pairing_data(&self, device_id: &str) -> Result<GetPairingDataResponse> {
71 use tokio::time::{sleep, Duration, Instant};
72
73 let deadline = Instant::now() + Duration::from_secs(10);
77
78 loop {
79 let result = self
80 .inner
81 .0
82 .clone()
83 .get_pairing_data(GetPairingDataRequest {
84 device_id: device_id.to_string(),
85 })
86 .await;
87
88 match result {
89 Ok(response) => return Ok(response.into_inner()),
90 Err(_) if Instant::now() < deadline => {
91 sleep(Duration::from_millis(100)).await;
92 continue;
93 }
94 Err(e) => return Err(e.into()),
95 }
96 }
97 }
98
99 pub async fn approve_pairing(
100 &self,
101 device_id: &str,
102 device_name: &str,
103 restrs: &str,
104 ) -> Result<pb::greenlight::Empty> {
105 let timestamp = SystemTime::now()
106 .duration_since(UNIX_EPOCH)
107 .map_err(into_approve_pairing_error)?
108 .as_secs();
109
110 let node_id = self.creds.node_id()?;
111
112 let mut buf = vec![];
114 buf.put(device_id.as_bytes());
115 buf.put_u64(timestamp);
116 buf.put(&node_id[..]);
117 buf.put(device_name.as_bytes());
118 buf.put(restrs.as_bytes());
119
120 let tls = self.creds.tls_config();
121 let tls_key = tls
122 .clone()
123 .private_key
124 .ok_or(Error::BuildClientError("empty tls private key".to_string()))?;
125
126 let key = {
128 let mut key = std::io::Cursor::new(&tls_key);
129 pemfile::pkcs8_private_keys(&mut key)
130 .map_err(into_approve_pairing_error)?
131 .remove(0)
132 };
133 let kp =
134 EcdsaKeyPair::from_pkcs8(&signature::ECDSA_P256_SHA256_FIXED_SIGNING, key.as_ref())
135 .map_err(into_approve_pairing_error)?;
136 let rng = rand::SystemRandom::new();
137 let sig = kp
138 .sign(&rng, &buf)
139 .map_err(into_approve_pairing_error)?
140 .as_ref()
141 .to_vec();
142
143 Ok(self
145 .inner
146 .0
147 .clone()
148 .approve_pairing(ApprovePairingRequest {
149 device_id: device_id.to_string(),
150 timestamp,
151 device_name: device_name.to_string(),
152 restrictions: restrs.to_string(),
153 sig: sig,
154 rune: self.creds.rune(),
155 pubkey: kp.public_key().as_ref().to_vec(),
156 })
157 .await?
158 .into_inner())
159 }
160
161 pub fn verify_pairing_data(data: GetPairingDataResponse) -> Result<()> {
162 let mut crs = std::io::Cursor::new(&data.csr);
163 let pem = Pem::read_from(&mut crs).map_err(into_verify_pairing_data_error)?;
164 let csr = Csr::from_pem(&pem).map_err(into_verify_pairing_data_error)?;
165 let sub_pk_der = csr
166 .public_key()
167 .to_der()
168 .map_err(into_verify_pairing_data_error)?;
169 let sub_pk_info: SubjectPublicKeyInfo =
170 picky_asn1_der::from_bytes(&sub_pk_der).map_err(into_verify_pairing_data_error)?;
171
172 if let PublicKey::Ec(bs) = sub_pk_info.subject_public_key {
173 let pk = hex::encode(bs.0.payload_view());
174
175 if pk == data.device_id
176 && Self::restriction_contains_pubkey_exactly_once(
177 &data.restrictions,
178 &data.device_id,
179 )
180 {
181 Ok(())
182 } else {
183 Err(Error::VerifyPairingDataError(format!(
184 "device id {} does not match public key {}",
185 data.device_id, pk
186 )))
187 }
188 } else {
189 Err(Error::VerifyPairingDataError(format!(
190 "public key is not ecdsa"
191 )))
192 }
193 }
194
195 fn restriction_contains_pubkey_exactly_once(s: &str, pubkey: &str) -> bool {
199 let search_field = format!("pubkey={}", pubkey);
200 match s.find(&search_field) {
201 Some(index) => {
202 if index > 0 && s.chars().nth(index - 1) == Some('|') {
204 return false;
205 }
206
207 let end_index = index + search_field.len();
209 if end_index < s.len() && s.chars().nth(end_index) == Some('|') {
210 return false;
211 }
212
213 s.matches(&search_field).count() == 1
215 }
216 None => false,
217 }
218 }
219}
220
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::{credentials, tls};

    /// Builds a pairing response for the fixed "my-device" test device.
    fn pairing_data(device_id: &str, csr: &str, restrictions: String) -> GetPairingDataResponse {
        GetPairingDataResponse {
            device_id: device_id.to_string(),
            csr: csr.as_bytes().to_vec(),
            device_name: "my-device".to_string(),
            description: String::new(),
            restrictions,
        }
    }

    /// Runs the verifier with the `credentials::Device` provider type.
    fn verify(data: GetPairingDataResponse) -> Result<()> {
        Client::<Connected, credentials::Device>::verify_pairing_data(data)
    }

    #[test]
    fn test_verify_pairing_data() {
        let key_pair = tls::generate_ecdsa_key_pair();
        let device_cert = tls::generate_self_signed_device_cert(
            &hex::encode("00"),
            "my-device",
            vec!["localhost".into()],
            Some(key_pair),
        );
        let csr = device_cert.serialize_request_pem().unwrap();
        let pk = hex::encode(device_cert.get_key_pair().public_key_raw());

        // Matching key and a single stand-alone pubkey restriction.
        assert!(verify(pairing_data(&pk, &csr, format!("pubkey={}", pk))).is_ok());

        // Restriction names a different key.
        assert!(verify(pairing_data(&pk, &csr, format!("pubkey={}", "02000000"))).is_err());

        // The key is only one alternative of a restriction.
        let alternatives = format!("pubkey={}|pubkey=02000000", pk);
        assert!(verify(pairing_data(&pk, &csr, alternatives)).is_err());

        // Device id does not match the key inside the CSR.
        assert!(verify(pairing_data("00", &csr, format!("pubkey={}", pk))).is_err());
    }
}