1use base64::Engine;
2use p256::ecdsa::Signature;
3use p256::ecdsa::SigningKey;
4use p256::ecdsa::signature::Signer as _;
5use rand_core::OsRng;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11use tonic::metadata::MetadataValue;
12use tonic::{Request, Status};
13
/// Errors produced while setting up authentication or minting tokens.
#[derive(thiserror::Error, Debug)]
pub enum AuthError {
    /// The provider does not implement the requested operation; the payload names it.
    #[error("Unsupported operation: {0}")]
    Unsupported(String),

    /// The generated public key was not in the expected uncompressed SEC1 form.
    #[error("Failed to generate key")]
    FailedToGenerateKey,

    /// Loading an existing key failed.
    // NOTE(review): not constructed in the visible code — confirm it is still needed.
    #[error("Failed to load key")]
    FailedToLoadKey,

    /// Creating the JWKS directory failed.
    #[error("Failed to create JWKS directory")]
    FailedToCreateJwksDir(#[source] std::io::Error),

    /// The JWKS document could not be encoded as JSON.
    #[error("Failed to serialize JWKS")]
    FailedToSerializeJwks,

    /// Writing the `<key_id>.jwk` file failed.
    #[error("Failed to write JWKS file")]
    FailedToWriteJwksFile(#[source] std::io::Error),

    /// The published key did not show up in `active.jwk` before the timeout elapsed.
    #[error("Timeout waiting for JWK registration")]
    JwksRegistrationTimeout,

    /// Catch-all for failures that carry no further detail (clock before the Unix
    /// epoch, JSON encoding of token parts, a failed background task).
    #[error("Unspecified signing error")]
    Unspecified,
}
41
/// The authentication scheme a provider uses for outgoing requests.
#[derive(Debug)]
pub enum AuthScheme {
    /// No credentials are attached.
    None,
    /// One fixed bearer token is attached to every request.
    Bearer,
    /// Per-audience ES256 JWTs signed with a locally generated key.
    Jwt {
        /// Value placed in the `iss` claim.
        issuer: String,
        /// Directory in which the public key (JWKS) is published.
        jwks_dir: PathBuf,
    },
}
55
/// A token handed out to callers together with the point at which it stops being valid.
pub struct BearerToken {
    /// The token string (for JWT providers, a compact-serialized JWS).
    pub token: String,
    /// Time after which the token should be treated as expired.
    pub expires_at: SystemTime,
}
61
62pub(crate) trait TokenProvider: Send + Sync + 'static {
64 fn auth_scheme(&self) -> &AuthScheme;
65 fn token_for_aud(&self, aud: &str) -> Result<String, AuthError>;
66 fn export_token(&self, _auds: &[&str], _ttl: Duration) -> Result<BearerToken, AuthError> {
67 Err(AuthError::Unsupported("export_token".into()))
68 }
69}
70
/// Token provider for unauthenticated connections.
///
/// It always yields an empty token, so no `authorization` header is added.
#[derive(Clone)]
pub struct NoOpTokenProvider;
74
75impl TokenProvider for NoOpTokenProvider {
76 fn auth_scheme(&self) -> &AuthScheme {
77 &AuthScheme::None
78 }
79 fn token_for_aud(&self, _aud: &str) -> Result<String, AuthError> {
80 Ok(String::new())
81 }
82}
83
/// Token provider that attaches one fixed, pre-shared bearer token to every request,
/// whatever its audience.
#[derive(Clone)]
pub struct BearerTokenProvider {
    /// Token value, sent verbatim after the `Bearer ` prefix.
    token: String,
}

impl BearerTokenProvider {
    /// Creates a provider that always hands out `token`.
    pub fn new(token: String) -> Self {
        BearerTokenProvider { token }
    }
}
96
97impl TokenProvider for BearerTokenProvider {
98 fn auth_scheme(&self) -> &AuthScheme {
99 &AuthScheme::Bearer
100 }
101 fn token_for_aud(&self, _aud: &str) -> Result<String, AuthError> {
102 Ok(self.token.clone())
103 }
104}
105
/// One issuer's entry in a [`GrpcAllowlist`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllowlistEntry {
    /// Issuer (`iss` claim value) this entry applies to.
    pub iss: String,
    /// gRPC method-path patterns this issuer may call,
    /// e.g. `"/android.emulation.control.UiController/.*"`.
    pub allowed: Vec<String>,
    /// Method-path patterns listed as protected for this issuer.
    // NOTE(review): the exact meaning of `protected` is defined by whatever consumes the
    // serialized allowlist — confirm there.
    pub protected: Vec<String>,
}
113
/// gRPC access policy: a list of unprotected method patterns plus per-issuer entries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrpcAllowlist {
    /// Method-path patterns listed as unprotected.
    // NOTE(review): presumably callable without a token — confirm against the consumer.
    pub unprotected: Vec<String>,
    /// Per-issuer allow rules.
    pub allowlist: Vec<AllowlistEntry>,
}
120
121impl GrpcAllowlist {
122 pub fn default_for_issuer(issuer: &str) -> Self {
124 Self {
125 unprotected: vec![],
126 allowlist: vec![
127 AllowlistEntry {
129 iss: "android-studio".to_string(),
130 allowed: vec![
131 "/android.emulation.control.EmulatorController/.*".to_string(),
132 "/android.emulation.control.UiController/.*".to_string(),
133 "/android.emulation.control.SnapshotService/.*".to_string(),
134 "/android.emulation.control.incubating.*".to_string(),
135 ],
136 protected: vec![],
137 },
138 AllowlistEntry {
140 iss: issuer.to_string(),
141 allowed: vec![
142 "/android.emulation.control.EmulatorController/.*".to_string(),
143 "/android.emulation.control.UiController/.*".to_string(),
144 "/android.emulation.control.SnapshotService/.*".to_string(),
145 "/android.emulation.control.incubating.*".to_string(),
146 ],
147 protected: vec![],
148 },
149 ],
150 }
151 }
152}
153
/// JOSE header of the JWTs produced by [`JwtTokenProvider`].
///
/// serde serializes fields in declaration order. That order therefore fixes the header
/// JSON and the signed bytes.
#[derive(Serialize)]
struct JwtHeader {
    /// Signature algorithm; this file always sets `"ES256"`.
    alg: String,
    /// Id of the signing key, matching the `kid` of the published JWK.
    kid: String,
}
160
/// Claims carried by each JWT.
#[derive(Serialize)]
struct JwtClaims {
    /// Issuer.
    iss: String,
    /// Audiences (gRPC method paths for per-call tokens); the claim is omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    aud: Vec<String>,
    /// Expiry, in whole seconds since the Unix epoch.
    exp: u64,
    /// Issued-at time, in whole seconds since the Unix epoch.
    iat: u64,
}
170
/// A public key in JSON Web Key form (this file only produces EC P-256 keys).
#[derive(Serialize, Deserialize)]
struct Jwk {
    /// Key type (`"EC"`).
    kty: String,
    /// Curve name (`"P-256"`).
    crv: String,
    /// X coordinate, base64url without padding.
    x: String,
    /// Y coordinate, base64url without padding.
    y: String,
    /// Key id; also used as the `.jwk` file name.
    kid: String,
    /// Intended use (`"sig"`); renamed because `use` is a Rust keyword.
    #[serde(rename = "use")]
    use_: String,
    /// Algorithm (`"ES256"`).
    alg: String,
    /// Permitted key operations (`["verify"]`).
    key_ops: Vec<String>,
}
184
/// A JSON Web Key Set; the document written to `<key_id>.jwk`.
#[derive(Serialize, Deserialize)]
struct Jwks {
    /// The keys in the set.
    keys: Vec<Jwk>,
}
190
/// A previously minted token for one audience, reused until `expires_at`.
struct CachedToken {
    /// The signed JWT.
    token: String,
    /// Reuse cut-off. It is set earlier than the token's real `exp` so the token is
    /// replaced before it lapses.
    expires_at: SystemTime,
}
196
/// Token provider that signs ES256 JWTs with a P-256 key generated at construction.
///
/// Each token is scoped to one audience and cached per audience.
pub struct JwtTokenProvider {
    /// Always [`AuthScheme::Jwt`], carrying the issuer and JWKS directory.
    auth_scheme: AuthScheme,
    /// Private key used to sign tokens.
    signing_key: SigningKey,
    /// Random UUID for the key; used as the JWT `kid` header and the JWK file name.
    key_id: String,
    /// Value of the `iss` claim.
    issuer: String,
    /// Minted tokens keyed by audience string.
    token_cache: Mutex<HashMap<String, CachedToken>>,
}
205
206impl JwtTokenProvider {
207 pub fn new_and_register(
209 jwks_dir: impl Into<PathBuf>,
210 issuer: impl Into<String>,
211 ) -> Result<Arc<Self>, AuthError> {
212 let jwks_dir = jwks_dir.into();
213 let issuer = issuer.into();
214
215 let signing_key = SigningKey::random(&mut OsRng);
217
218 let verify_key = signing_key.verifying_key();
220 let encoded = verify_key.to_encoded_point(false);
221 let public_key_bytes = encoded.as_bytes();
222
223 let key_id = uuid::Uuid::new_v4().to_string();
225
226 if public_key_bytes.len() != 65 || public_key_bytes[0] != 0x04 {
228 return Err(AuthError::FailedToGenerateKey);
229 }
230
231 let x_bytes = &public_key_bytes[1..33];
232 let y_bytes = &public_key_bytes[33..65];
233
234 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x_bytes);
236 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y_bytes);
237
238 let jwk = Jwk {
239 kid: key_id.clone(),
240 alg: "ES256".to_string(),
241 kty: "EC".to_string(),
242 crv: "P-256".to_string(),
243 x,
244 y,
245 use_: "sig".to_string(),
246 key_ops: vec!["verify".to_string()],
247 };
248
249 let jwks = Jwks { keys: vec![jwk] };
251
252 std::fs::create_dir_all(&jwks_dir).map_err(AuthError::FailedToCreateJwksDir)?;
254 let jwks_path = jwks_dir.join(format!("{}.jwk", key_id));
255 let jwks_json =
256 serde_json::to_string_pretty(&jwks).map_err(|_e| AuthError::FailedToSerializeJwks)?;
257 tracing::debug!("Writing JWK to {}", jwks_path.display());
258 std::fs::write(&jwks_path, jwks_json).map_err(AuthError::FailedToWriteJwksFile)?;
259
260 Ok(Arc::new(Self {
261 auth_scheme: AuthScheme::Jwt {
262 issuer: issuer.to_string(),
263 jwks_dir: jwks_dir.to_path_buf(),
264 },
265 signing_key,
266 key_id,
267 issuer: issuer.to_string(),
268 token_cache: Mutex::new(HashMap::new()),
269 }))
270 }
271
272 pub fn wait_for_activation(
274 self: &Arc<Self>,
275 jwks_dir: &Path,
276 timeout: Duration,
277 ) -> Result<(), AuthError> {
278 let active_jwk_path = jwks_dir.join("active.jwk");
279 let start = std::time::Instant::now();
280
281 loop {
282 if start.elapsed() > timeout {
283 return Err(AuthError::JwksRegistrationTimeout);
284 }
285
286 if let Ok(contents) = std::fs::read_to_string(&active_jwk_path)
288 && let Ok(jwks) = serde_json::from_str::<Jwks>(&contents)
289 {
290 if jwks.keys.iter().any(|k| k.kid == self.key_id) {
292 return Ok(());
293 }
294 }
295
296 std::thread::sleep(Duration::from_millis(100));
297 }
298 }
299
300 fn create_token(&self, auds: &[&str], ttl: Duration) -> Result<String, AuthError> {
301 let now = SystemTime::now()
302 .duration_since(UNIX_EPOCH)
303 .map_err(|_| AuthError::Unspecified)?
304 .as_secs();
305 let exp = now + ttl.as_secs();
306
307 let header = JwtHeader {
308 alg: "ES256".to_string(),
309 kid: self.key_id.clone(),
310 };
311
312 let claims = JwtClaims {
313 iss: self.issuer.clone(),
314 aud: auds.iter().map(|&s| s.to_string()).collect(),
315 exp,
316 iat: now,
317 };
318
319 let header_json = serde_json::to_string(&header).map_err(|_| AuthError::Unspecified)?;
321 let claims_json = serde_json::to_string(&claims).map_err(|_| AuthError::Unspecified)?;
322
323 let header_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(header_json);
324 let claims_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims_json);
325
326 let signing_input = format!("{}.{}", header_b64, claims_b64);
328
329 let sig: Signature = self.signing_key.sign(signing_input.as_bytes());
331 let sig_bytes = sig.to_bytes();
332 let signature_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(sig_bytes);
333
334 let token = format!("{}.{}", signing_input, signature_b64);
336 Ok(token)
337 }
338}
339
340impl TokenProvider for JwtTokenProvider {
341 fn auth_scheme(&self) -> &AuthScheme {
342 &self.auth_scheme
343 }
344 fn token_for_aud(&self, aud: &str) -> Result<String, AuthError> {
345 let mut cache = self.token_cache.lock().unwrap();
346 let now = SystemTime::now();
347
348 if let Some(cached) = cache.get(aud)
350 && cached.expires_at > now
351 {
352 return Ok(cached.token.clone());
353 }
354
355 let ttl = Duration::from_secs(10 * 60);
357 let token = self.create_token(&[aud], ttl)?;
358
359 let expires_at = now + ttl - Duration::from_secs(30);
361 cache.insert(
362 aud.to_string(),
363 CachedToken {
364 token: token.clone(),
365 expires_at,
366 },
367 );
368
369 Ok(token)
370 }
371 fn export_token(&self, auds: &[&str], ttl: Duration) -> Result<BearerToken, AuthError> {
372 let token = self.create_token(auds, ttl)?;
373
374 let expires_at = SystemTime::now() + ttl;
375 Ok(BearerToken { token, expires_at })
376 }
377}
378
/// Cloneable handle to a token provider.
///
/// It is used as a tonic interceptor that attaches `authorization: Bearer <token>`
/// metadata to outgoing requests. Clones share the same underlying provider.
#[derive(Clone)]
pub struct AuthProvider {
    provider: Arc<dyn TokenProvider>,
}
385
386impl AuthProvider {
387 pub(crate) fn new_with_token_provider(provider: Arc<dyn TokenProvider>) -> Self {
388 Self { provider }
389 }
390
391 pub fn new_no_auth() -> Self {
392 Self {
393 provider: Arc::new(NoOpTokenProvider),
394 }
395 }
396 pub fn new_bearer(token: impl Into<String>) -> Self {
397 Self {
398 provider: Arc::new(BearerTokenProvider::new(token.into())),
399 }
400 }
401 pub async fn new_jwt(
402 jwks_dir: impl Into<PathBuf>,
403 issuer: impl Into<String>,
404 ) -> Result<Arc<Self>, AuthError> {
405 let jwks_dir = jwks_dir.into();
406 let issuer = issuer.into();
407 let jwt_provider = tokio::task::spawn_blocking(move || {
408 JwtTokenProvider::new_and_register(jwks_dir, issuer)
409 })
410 .await
411 .map_err(|_e| AuthError::Unspecified)??;
412 Ok(Arc::new(Self {
413 provider: jwt_provider,
414 }))
415 }
416
417 pub fn auth_scheme(&self) -> &AuthScheme {
418 self.provider.auth_scheme()
419 }
420 pub fn export_token(&self, auds: &[&str], ttl: Duration) -> Result<BearerToken, AuthError> {
421 self.provider.export_token(auds, ttl)
422 }
423}
424
425impl tonic::service::Interceptor for AuthProvider {
426 fn call(&mut self, mut req: Request<()>) -> Result<Request<()>, Status> {
427 let method = req
429 .extensions()
430 .get::<tonic::GrpcMethod>()
431 .expect("GrpcMethod missing");
432 let aud = format!("/{}/{}", method.service(), method.method());
433 let token = self
434 .provider
435 .token_for_aud(&aud)
436 .map_err(|e| Status::unauthenticated(format!("Token error: {e}")))?;
437
438 if !token.is_empty() {
440 let value = MetadataValue::try_from(format!("Bearer {token}"))
442 .map_err(|_| Status::internal("Invalid auth metadata"))?;
443
444 req.metadata_mut().insert("authorization", value);
445 }
446 Ok(req)
447 }
448}