Skip to main content

android_emulator/
auth.rs

1use base64::Engine;
2use p256::ecdsa::Signature;
3use p256::ecdsa::SigningKey;
4use p256::ecdsa::signature::Signer as _;
5use rand_core::OsRng;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11use tonic::metadata::MetadataValue;
12use tonic::{Request, Status};
13
14/// Authentication error type
15#[derive(thiserror::Error, Debug)]
16pub enum AuthError {
17    #[error("Unsupported operation: {0}")]
18    Unsupported(String),
19
20    #[error("Failed to generate key")]
21    FailedToGenerateKey,
22
23    #[error("Failed to load key")]
24    FailedToLoadKey,
25
26    #[error("Failed to create JWKS directory")]
27    FailedToCreateJwksDir(#[source] std::io::Error),
28
29    #[error("Failed to serialize JWKS")]
30    FailedToSerializeJwks,
31
32    #[error("Failed to write JWKS file")]
33    FailedToWriteJwksFile(#[source] std::io::Error),
34
35    #[error("Timeout waiting for JWK registration")]
36    JwksRegistrationTimeout,
37
38    #[error("Unspecified signing error")]
39    Unspecified,
40}
41
/// How outgoing gRPC requests to the emulator are authenticated.
#[derive(Debug)]
pub enum AuthScheme {
    /// Requests carry no `authorization` header at all.
    None,
    /// Every request carries the same static bearer token.
    Bearer,
    /// Every request carries an ES256-signed JWT whose `iss` claim is
    /// `issuer` and whose `aud` claim is the gRPC method being called; the
    /// signing key's public half was written to `jwks_dir` so the emulator
    /// can register it.
    Jwt {
        /// Value placed in the `iss` claim of every token.
        issuer: String,
        /// Directory the public key (as a JWKS file) was written to.
        jwks_dir: PathBuf,
    },
}
55
/// A standalone token together with the moment it stops being valid.
pub struct BearerToken {
    /// The encoded token string.
    pub token: String,
    /// Point in time after which the token must be treated as expired.
    pub expires_at: SystemTime,
}
61
62/// Trait for providing JWT tokens
63pub(crate) trait TokenProvider: Send + Sync + 'static {
64    fn auth_scheme(&self) -> &AuthScheme;
65    fn token_for_aud(&self, aud: &str) -> Result<String, AuthError>;
66    fn export_token(&self, _auds: &[&str], _ttl: Duration) -> Result<BearerToken, AuthError> {
67        Err(AuthError::Unsupported("export_token".into()))
68    }
69}
70
71/// No-op token provider for unauthenticated connections
72#[derive(Clone)]
73pub struct NoOpTokenProvider;
74
75impl TokenProvider for NoOpTokenProvider {
76    fn auth_scheme(&self) -> &AuthScheme {
77        &AuthScheme::None
78    }
79    fn token_for_aud(&self, _aud: &str) -> Result<String, AuthError> {
80        Ok(String::new())
81    }
82}
83
84/// Static bearer token provider for emulators with grpc.token, or for clients
85/// that have been given a static JWT token to use.
86#[derive(Clone)]
87pub struct BearerTokenProvider {
88    token: String,
89}
90
91impl BearerTokenProvider {
92    pub fn new(token: String) -> Self {
93        Self { token }
94    }
95}
96
97impl TokenProvider for BearerTokenProvider {
98    fn auth_scheme(&self) -> &AuthScheme {
99        &AuthScheme::Bearer
100    }
101    fn token_for_aud(&self, _aud: &str) -> Result<String, AuthError> {
102        Ok(self.token.clone())
103    }
104}
105
106/// Allowlist entry for JWT authentication
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct AllowlistEntry {
109    pub iss: String,
110    pub allowed: Vec<String>,
111    pub protected: Vec<String>,
112}
113
114/// gRPC allowlist configuration
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct GrpcAllowlist {
117    pub unprotected: Vec<String>,
118    pub allowlist: Vec<AllowlistEntry>,
119}
120
121impl GrpcAllowlist {
122    /// Create a default allowlist for the given issuer
123    pub fn default_for_issuer(issuer: &str) -> Self {
124        Self {
125            unprotected: vec![],
126            allowlist: vec![
127                // Always include android-studio to not break embedded emulator
128                AllowlistEntry {
129                    iss: "android-studio".to_string(),
130                    allowed: vec![
131                        "/android.emulation.control.EmulatorController/.*".to_string(),
132                        "/android.emulation.control.UiController/.*".to_string(),
133                        "/android.emulation.control.SnapshotService/.*".to_string(),
134                        "/android.emulation.control.incubating.*".to_string(),
135                    ],
136                    protected: vec![],
137                },
138                // Add custom issuer if it's not android-studio
139                AllowlistEntry {
140                    iss: issuer.to_string(),
141                    allowed: vec![
142                        "/android.emulation.control.EmulatorController/.*".to_string(),
143                        "/android.emulation.control.UiController/.*".to_string(),
144                        "/android.emulation.control.SnapshotService/.*".to_string(),
145                        "/android.emulation.control.incubating.*".to_string(),
146                    ],
147                    protected: vec![],
148                },
149            ],
150        }
151    }
152}
153
154/// JWT header for ES256 (ECDSA P-256)
155#[derive(Serialize)]
156struct JwtHeader {
157    alg: String,
158    kid: String,
159}
160
161/// JWT claims
162#[derive(Serialize)]
163struct JwtClaims {
164    iss: String,
165    #[serde(skip_serializing_if = "Vec::is_empty")]
166    aud: Vec<String>,
167    exp: u64,
168    iat: u64,
169}
170
171/// JWK format for ECDSA P-256 public key
172#[derive(Serialize, Deserialize)]
173struct Jwk {
174    kty: String,
175    crv: String,
176    x: String,
177    y: String,
178    kid: String,
179    #[serde(rename = "use")]
180    use_: String,
181    alg: String,
182    key_ops: Vec<String>,
183}
184
185/// JWKS file format
186#[derive(Serialize, Deserialize)]
187struct Jwks {
188    keys: Vec<Jwk>,
189}
190
191/// Cached token with expiration
192struct CachedToken {
193    token: String,
194    expires_at: SystemTime,
195}
196
197/// JWT ES256 token provider implementation
198pub struct JwtTokenProvider {
199    auth_scheme: AuthScheme,
200    signing_key: SigningKey,
201    key_id: String,
202    issuer: String,
203    token_cache: Mutex<HashMap<String, CachedToken>>,
204}
205
206impl JwtTokenProvider {
207    /// Create a new JWT ES256 token provider and register the key with the emulator
208    pub fn new_and_register(
209        jwks_dir: impl Into<PathBuf>,
210        issuer: impl Into<String>,
211    ) -> Result<Arc<Self>, AuthError> {
212        let jwks_dir = jwks_dir.into();
213        let issuer = issuer.into();
214
215        // Generate P-256 signing key (private key)
216        let signing_key = SigningKey::random(&mut OsRng);
217
218        // Public key and x/y extraction (uncompressed SEC1 point)
219        let verify_key = signing_key.verifying_key();
220        let encoded = verify_key.to_encoded_point(false);
221        let public_key_bytes = encoded.as_bytes();
222
223        // Generate a random key ID
224        let key_id = uuid::Uuid::new_v4().to_string();
225
226        // For ECDSA P-256, the public key is 65 bytes: 0x04 + 32 bytes X + 32 bytes Y
227        if public_key_bytes.len() != 65 || public_key_bytes[0] != 0x04 {
228            return Err(AuthError::FailedToGenerateKey);
229        }
230
231        let x_bytes = &public_key_bytes[1..33];
232        let y_bytes = &public_key_bytes[33..65];
233
234        // Base64 URL encode without padding
235        let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x_bytes);
236        let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y_bytes);
237
238        let jwk = Jwk {
239            kid: key_id.clone(),
240            alg: "ES256".to_string(),
241            kty: "EC".to_string(),
242            crv: "P-256".to_string(),
243            x,
244            y,
245            use_: "sig".to_string(),
246            key_ops: vec!["verify".to_string()],
247        };
248
249        // Create JWKS with single key
250        let jwks = Jwks { keys: vec![jwk] };
251
252        // Write JWKS file
253        std::fs::create_dir_all(&jwks_dir).map_err(AuthError::FailedToCreateJwksDir)?;
254        let jwks_path = jwks_dir.join(format!("{}.jwk", key_id));
255        let jwks_json =
256            serde_json::to_string_pretty(&jwks).map_err(|_e| AuthError::FailedToSerializeJwks)?;
257        tracing::debug!("Writing JWK to {}", jwks_path.display());
258        std::fs::write(&jwks_path, jwks_json).map_err(AuthError::FailedToWriteJwksFile)?;
259
260        Ok(Arc::new(Self {
261            auth_scheme: AuthScheme::Jwt {
262                issuer: issuer.to_string(),
263                jwks_dir: jwks_dir.to_path_buf(),
264            },
265            signing_key,
266            key_id,
267            issuer: issuer.to_string(),
268            token_cache: Mutex::new(HashMap::new()),
269        }))
270    }
271
272    /// Wait for the emulator to activate this key
273    pub fn wait_for_activation(
274        self: &Arc<Self>,
275        jwks_dir: &Path,
276        timeout: Duration,
277    ) -> Result<(), AuthError> {
278        let active_jwk_path = jwks_dir.join("active.jwk");
279        let start = std::time::Instant::now();
280
281        loop {
282            if start.elapsed() > timeout {
283                return Err(AuthError::JwksRegistrationTimeout);
284            }
285
286            // Read active.jwk if it exists
287            if let Ok(contents) = std::fs::read_to_string(&active_jwk_path)
288                && let Ok(jwks) = serde_json::from_str::<Jwks>(&contents)
289            {
290                // Check if our key ID is in the active keys
291                if jwks.keys.iter().any(|k| k.kid == self.key_id) {
292                    return Ok(());
293                }
294            }
295
296            std::thread::sleep(Duration::from_millis(100));
297        }
298    }
299
300    fn create_token(&self, auds: &[&str], ttl: Duration) -> Result<String, AuthError> {
301        let now = SystemTime::now()
302            .duration_since(UNIX_EPOCH)
303            .map_err(|_| AuthError::Unspecified)?
304            .as_secs();
305        let exp = now + ttl.as_secs();
306
307        let header = JwtHeader {
308            alg: "ES256".to_string(),
309            kid: self.key_id.clone(),
310        };
311
312        let claims = JwtClaims {
313            iss: self.issuer.clone(),
314            aud: auds.iter().map(|&s| s.to_string()).collect(),
315            exp,
316            iat: now,
317        };
318
319        // Encode header and claims
320        let header_json = serde_json::to_string(&header).map_err(|_| AuthError::Unspecified)?;
321        let claims_json = serde_json::to_string(&claims).map_err(|_| AuthError::Unspecified)?;
322
323        let header_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(header_json);
324        let claims_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims_json);
325
326        // Create signing input
327        let signing_input = format!("{}.{}", header_b64, claims_b64);
328
329        // Sign with ECDSA P-256
330        let sig: Signature = self.signing_key.sign(signing_input.as_bytes());
331        let sig_bytes = sig.to_bytes();
332        let signature_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(sig_bytes);
333
334        // Create final JWT
335        let token = format!("{}.{}", signing_input, signature_b64);
336        Ok(token)
337    }
338}
339
340impl TokenProvider for JwtTokenProvider {
341    fn auth_scheme(&self) -> &AuthScheme {
342        &self.auth_scheme
343    }
344    fn token_for_aud(&self, aud: &str) -> Result<String, AuthError> {
345        let mut cache = self.token_cache.lock().unwrap();
346        let now = SystemTime::now();
347
348        // Check if we have a valid cached token
349        if let Some(cached) = cache.get(aud)
350            && cached.expires_at > now
351        {
352            return Ok(cached.token.clone());
353        }
354
355        // Create new token with 10 minute TTL
356        let ttl = Duration::from_secs(10 * 60);
357        let token = self.create_token(&[aud], ttl)?;
358
359        // Cache it (expires 30 seconds before actual expiration for safety)
360        let expires_at = now + ttl - Duration::from_secs(30);
361        cache.insert(
362            aud.to_string(),
363            CachedToken {
364                token: token.clone(),
365                expires_at,
366            },
367        );
368
369        Ok(token)
370    }
371    fn export_token(&self, auds: &[&str], ttl: Duration) -> Result<BearerToken, AuthError> {
372        let token = self.create_token(auds, ttl)?;
373
374        let expires_at = SystemTime::now() + ttl;
375        Ok(BearerToken { token, expires_at })
376    }
377}
378
379/// An authentication interceptor for gRPC requests, responsible for adding
380/// an appropriate authorization header to each request.
381#[derive(Clone)]
382pub struct AuthProvider {
383    provider: Arc<dyn TokenProvider>,
384}
385
386impl AuthProvider {
387    pub(crate) fn new_with_token_provider(provider: Arc<dyn TokenProvider>) -> Self {
388        Self { provider }
389    }
390
391    pub fn new_no_auth() -> Self {
392        Self {
393            provider: Arc::new(NoOpTokenProvider),
394        }
395    }
396    pub fn new_bearer(token: impl Into<String>) -> Self {
397        Self {
398            provider: Arc::new(BearerTokenProvider::new(token.into())),
399        }
400    }
401    pub async fn new_jwt(
402        jwks_dir: impl Into<PathBuf>,
403        issuer: impl Into<String>,
404    ) -> Result<Arc<Self>, AuthError> {
405        let jwks_dir = jwks_dir.into();
406        let issuer = issuer.into();
407        let jwt_provider = tokio::task::spawn_blocking(move || {
408            JwtTokenProvider::new_and_register(jwks_dir, issuer)
409        })
410        .await
411        .map_err(|_e| AuthError::Unspecified)??;
412        Ok(Arc::new(Self {
413            provider: jwt_provider,
414        }))
415    }
416
417    pub fn auth_scheme(&self) -> &AuthScheme {
418        self.provider.auth_scheme()
419    }
420    pub fn export_token(&self, auds: &[&str], ttl: Duration) -> Result<BearerToken, AuthError> {
421        self.provider.export_token(auds, ttl)
422    }
423}
424
425impl tonic::service::Interceptor for AuthProvider {
426    fn call(&mut self, mut req: Request<()>) -> Result<Request<()>, Status> {
427        // The audience for JWT is the gRPC method path
428        let method = req
429            .extensions()
430            .get::<tonic::GrpcMethod>()
431            .expect("GrpcMethod missing");
432        let aud = format!("/{}/{}", method.service(), method.method());
433        let token = self
434            .provider
435            .token_for_aud(&aud)
436            .map_err(|e| Status::unauthenticated(format!("Token error: {e}")))?;
437
438        // Only add authorization header if we have a token (skip for no-op provider)
439        if !token.is_empty() {
440            // Add Bearer token to authorization header
441            let value = MetadataValue::try_from(format!("Bearer {token}"))
442                .map_err(|_| Status::internal("Invalid auth metadata"))?;
443
444            req.metadata_mut().insert("authorization", value);
445        }
446        Ok(req)
447    }
448}