Skip to main content

bunnyapp_license/
lib.rs

1mod claims;
2mod error;
3
4pub use claims::LicenseClaims;
5pub use error::LicenseError;
6
7use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, Validation};
8use reqwest::Client;
9use serde::{Deserialize, Serialize};
10use std::{collections::HashSet, env};
11
12// In tests, swap in a JWKS that contains the test RSA key so the offline path
13// can be exercised end-to-end without touching the real bundled keys.
14#[cfg(not(test))]
15const OFFLINE_JWKS: &str = include_str!("keys/offline_jwks.json");
16#[cfg(test)]
17const OFFLINE_JWKS: &str = include_str!("keys/test_offline_jwks.json");
18
19const ALLOWED_ALGORITHMS: &[Algorithm] = &[
20    Algorithm::RS256,
21    Algorithm::RS384,
22    Algorithm::RS512,
23    Algorithm::ES256,
24    Algorithm::ES384,
25    Algorithm::PS256,
26    Algorithm::PS384,
27    Algorithm::PS512,
28];
29
30const EXPECTED_ISSUER: &str = "https://auth.bunny.com";
31const EXPECTED_AUDIENCE: &str = "bunny-license-key";
32
33#[derive(Serialize)]
34struct ValidateRequest<'a> {
35    #[serde(skip_serializing_if = "Option::is_none")]
36    instance_fingerprint: Option<&'a str>,
37}
38
39#[derive(Deserialize)]
40struct ValidateResponse {
41    token: String,
42}
43
44/// Validate the Bunny license and return the verified JWT claims.
45///
46/// Checks `BUNNY_LICENSE_KEY` first (online mode), then falls back to
47/// `BUNNY_OFFLINE_LICENSE_KEY` (offline / air-gapped mode).
48/// Online mode also requires `BUNNY_HOST` to be set.
49pub async fn validate_license(instance_fingerprint: Option<&str>) -> Result<LicenseClaims, LicenseError> {
50    if let Ok(key) = env::var("BUNNY_LICENSE_KEY") {
51        validate_online(&key, instance_fingerprint).await
52    } else if let Ok(token) = env::var("BUNNY_OFFLINE_LICENSE_KEY") {
53        validate_offline(&token)
54    } else {
55        Err(LicenseError::NoLicenseKeySet)
56    }
57}
58
59fn build_client() -> Result<Client, LicenseError> {
60    let accept_invalid = env::var("BUNNY_DANGER_ACCEPT_INVALID_CERTS")
61        .map(|v| v == "true" || v == "1")
62        .unwrap_or(false);
63    Client::builder()
64        .danger_accept_invalid_certs(accept_invalid)
65        .build()
66        .map_err(LicenseError::Http)
67}
68
69async fn validate_online(license_key: &str, instance_fingerprint: Option<&str>) -> Result<LicenseClaims, LicenseError> {
70    let host = env::var("BUNNY_HOST").map_err(|_| LicenseError::NoHostSet)?;
71    let client = build_client()?;
72
73    let response = client
74        .post(format!("{}/api/license/validate", host))
75        .bearer_auth(license_key)
76        .json(&ValidateRequest { instance_fingerprint })
77        .send()
78        .await?;
79
80    let status = response.status();
81    if !status.is_success() {
82        return Err(LicenseError::ValidationFailed {
83            status: status.as_u16(),
84        });
85    }
86
87    let body: ValidateResponse = response
88        .json()
89        .await
90        .map_err(|_| LicenseError::MissingToken)?;
91
92    let jwks_response = client
93        .get(format!("{}/api/.well-known/jwks.json", host))
94        .send()
95        .await?;
96
97    let jwks: JwkSet = jwks_response
98        .json()
99        .await
100        .map_err(|e| LicenseError::JwksParse(e.to_string()))?;
101
102    verify_jwt(&body.token, &jwks)
103}
104
105fn validate_offline(token: &str) -> Result<LicenseClaims, LicenseError> {
106    let jwks: JwkSet =
107        serde_json::from_str(OFFLINE_JWKS).map_err(|e| LicenseError::JwksParse(e.to_string()))?;
108
109    verify_jwt(token, &jwks)
110}
111
112fn verify_jwt(token: &str, jwks: &JwkSet) -> Result<LicenseClaims, LicenseError> {
113    let header = decode_header(token)?;
114
115    if !ALLOWED_ALGORITHMS.contains(&header.alg) {
116        return Err(LicenseError::UnsupportedAlgorithm(format!(
117            "{:?}",
118            header.alg
119        )));
120    }
121
122    let kid = header.kid.ok_or(LicenseError::MissingKeyId)?;
123    let jwk = jwks
124        .find(&kid)
125        .ok_or_else(|| LicenseError::KeyNotFound(kid.clone()))?;
126
127    let decoding_key = DecodingKey::from_jwk(jwk)?;
128
129    let mut validation = Validation::new(header.alg);
130    validation.validate_exp = true;
131    validation.required_spec_claims = HashSet::from(["exp".to_string()]);
132    validation.set_issuer(&[EXPECTED_ISSUER]);
133    validation.set_audience(&[EXPECTED_AUDIENCE]);
134
135    decode::<LicenseClaims>(token, &decoding_key, &validation)
136        .map(|data| data.claims)
137        .map_err(|e| match e.kind() {
138            jsonwebtoken::errors::ErrorKind::ExpiredSignature => LicenseError::TokenExpired,
139            _ => LicenseError::Jwt(e),
140        })
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use serde_json::{json, Value};
    use serial_test::serial;
    use std::{
        env,
        time::{SystemTime, UNIX_EPOCH},
    };
    use wiremock::{
        matchers::{body_json, header, method, path},
        Mock, MockServer, ResponseTemplate,
    };

    const TEST_KID: &str = "test-key-1";
    const TEST_RSA_N: &str = "pzxnaqHr0kjYR5W3aYq_2QduN5oymUoH3YH6U4dk9W6lra2XmQaTidQs1xHVn79WZOs4CNgU_RvScalyEPaMt0SHta3rwMSdTk2ShfNA831jDwDFpjQGcyAWM3d4IowHFDC6cXkttNKNXGqoDQMq_qXfoHjYAkAVv2O9jyg7mJo8pZeeOxzj8vAtQlUFcDM2x3cHuGJl48DASC1cG2WI606ppz319c7gmwVKnay7vOQScek4ErJ3EMh9AypFHSju3fRVhjobFALO7xwta09CWTk25DHAd3mcYvLviGikOAPnI0bxEPZYS42IXmLG7GyNMa7sYhcvYNsK7m3YtWGbpQ";
    const TEST_PRIVATE_KEY: &str = include_str!("keys/test_private_key.pem");

    // ── fixtures ─────────────────────────────────────────────────────────────

    /// Current time in whole seconds since the Unix epoch.
    fn unix_now() -> u64 {
        let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
        since_epoch.as_secs()
    }

    /// Claims with the expected issuer and audience and an `exp` one hour ahead.
    fn valid_claims_json() -> Value {
        let issued_at = unix_now();
        json!({
            "sub": "cust_test",
            "iat": issued_at,
            "exp": issued_at + 3600,
            "iss": EXPECTED_ISSUER,
            "aud": EXPECTED_AUDIENCE,
            "subscription": { "plan": "pro", "seats": 10 }
        })
    }

    /// Claims with the expected issuer and audience that expired an hour ago.
    fn expired_claims_json() -> Value {
        let now = unix_now();
        json!({
            "sub": "cust_test",
            "iat": now - 7200,
            "exp": now - 3600,
            "iss": EXPECTED_ISSUER,
            "aud": EXPECTED_AUDIENCE,
            "subscription": {}
        })
    }

    /// Sign `claims` with the test RSA private key (RS256), optionally setting `kid`.
    fn make_jwt(claims: &Value, kid: Option<&str>) -> String {
        let mut jwt_header = Header::new(Algorithm::RS256);
        jwt_header.kid = kid.map(str::to_string);
        let signing_key = EncodingKey::from_rsa_pem(TEST_PRIVATE_KEY.as_bytes()).unwrap();
        encode(&jwt_header, claims, &signing_key).unwrap()
    }

    /// Sign `claims` with a shared HMAC secret (HS256) under `TEST_KID`.
    fn make_hs256_jwt(claims: &Value) -> String {
        let mut jwt_header = Header::new(Algorithm::HS256);
        jwt_header.kid = Some(TEST_KID.to_string());
        encode(&jwt_header, claims, &EncodingKey::from_secret(b"secret")).unwrap()
    }

    /// JWKS JSON publishing the test RSA public key under `TEST_KID`.
    fn test_jwks_json() -> Value {
        json!({
            "keys": [{
                "kty": "RSA",
                "use": "sig",
                "alg": "RS256",
                "kid": TEST_KID,
                "n": TEST_RSA_N,
                "e": "AQAB"
            }]
        })
    }

    fn test_jwks() -> JwkSet {
        serde_json::from_value(test_jwks_json()).unwrap()
    }

    /// Start a mock host that already serves `test_jwks_json()` at the JWKS path.
    async fn jwks_server() -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/api/.well-known/jwks.json"))
            .respond_with(ResponseTemplate::new(200).set_body_json(test_jwks_json()))
            .mount(&server)
            .await;
        server
    }

    /// Run `validate_license` in online mode against `host`, then clear the env vars.
    async fn run_online(
        license_key: &str,
        host: &str,
        fingerprint: Option<&str>,
    ) -> Result<LicenseClaims, LicenseError> {
        env::set_var("BUNNY_LICENSE_KEY", license_key);
        env::set_var("BUNNY_HOST", host);
        let outcome = validate_license(fingerprint).await;
        env::remove_var("BUNNY_LICENSE_KEY");
        env::remove_var("BUNNY_HOST");
        outcome
    }

    /// Run `validate_license` in offline mode with `token`, then clear the env var.
    async fn run_offline(token: &str) -> Result<LicenseClaims, LicenseError> {
        env::remove_var("BUNNY_LICENSE_KEY");
        env::set_var("BUNNY_OFFLINE_LICENSE_KEY", token);
        let outcome = validate_license(None).await;
        env::remove_var("BUNNY_OFFLINE_LICENSE_KEY");
        outcome
    }

    // ── verify_jwt (direct, no env vars, sync) ──────────────────────────────

    #[test]
    fn verify_valid_jwt_returns_claims() {
        let token = make_jwt(&valid_claims_json(), Some(TEST_KID));
        let claims = verify_jwt(&token, &test_jwks()).unwrap();
        assert_eq!(claims.subscription, json!({ "plan": "pro", "seats": 10 }));
        assert_eq!(claims.sub.as_deref(), Some("cust_test"));
    }

    #[test]
    fn verify_expired_jwt_returns_token_expired() {
        let token = make_jwt(&expired_claims_json(), Some(TEST_KID));
        let err = verify_jwt(&token, &test_jwks()).unwrap_err();
        assert!(matches!(err, LicenseError::TokenExpired));
    }

    #[test]
    fn verify_hs256_jwt_returns_unsupported_algorithm() {
        let token = make_hs256_jwt(&valid_claims_json());
        let err = verify_jwt(&token, &test_jwks()).unwrap_err();
        assert!(matches!(err, LicenseError::UnsupportedAlgorithm(_)));
    }

    #[test]
    fn verify_unknown_kid_returns_key_not_found() {
        let token = make_jwt(&valid_claims_json(), Some("no-such-key"));
        let err = verify_jwt(&token, &test_jwks()).unwrap_err();
        assert!(matches!(err, LicenseError::KeyNotFound(_)));
    }

    #[test]
    fn verify_missing_kid_returns_missing_key_id() {
        let token = make_jwt(&valid_claims_json(), None);
        let err = verify_jwt(&token, &test_jwks()).unwrap_err();
        assert!(matches!(err, LicenseError::MissingKeyId));
    }

    #[test]
    fn verify_wrong_issuer_returns_jwt_error() {
        let mut claims = valid_claims_json();
        claims["iss"] = json!("https://evil.example.com");
        let token = make_jwt(&claims, Some(TEST_KID));
        let err = verify_jwt(&token, &test_jwks()).unwrap_err();
        assert!(matches!(err, LicenseError::Jwt(_)));
    }

    #[test]
    fn verify_wrong_audience_returns_jwt_error() {
        let mut claims = valid_claims_json();
        claims["aud"] = json!("wrong-audience");
        let token = make_jwt(&claims, Some(TEST_KID));
        let err = verify_jwt(&token, &test_jwks()).unwrap_err();
        assert!(matches!(err, LicenseError::Jwt(_)));
    }

    // ── validate_license routing ─────────────────────────────────────────────

    #[tokio::test]
    #[serial]
    async fn no_env_vars_returns_no_license_key_set() {
        env::remove_var("BUNNY_LICENSE_KEY");
        env::remove_var("BUNNY_OFFLINE_LICENSE_KEY");
        let err = validate_license(None).await.unwrap_err();
        assert!(matches!(err, LicenseError::NoLicenseKeySet));
    }

    #[tokio::test]
    #[serial]
    async fn online_mode_without_host_returns_no_host_set() {
        env::set_var("BUNNY_LICENSE_KEY", "any-key");
        env::remove_var("BUNNY_HOST");
        let outcome = validate_license(None).await;
        env::remove_var("BUNNY_LICENSE_KEY");
        assert!(matches!(outcome.unwrap_err(), LicenseError::NoHostSet));
    }

    // ── offline mode ─────────────────────────────────────────────────────────

    #[tokio::test]
    #[serial]
    async fn offline_valid_jwt_returns_claims() {
        let token = make_jwt(&valid_claims_json(), Some(TEST_KID));
        let claims = run_offline(&token).await.unwrap();
        assert_eq!(claims.subscription, json!({ "plan": "pro", "seats": 10 }));
    }

    #[tokio::test]
    #[serial]
    async fn offline_expired_jwt_returns_token_expired() {
        let token = make_jwt(&expired_claims_json(), Some(TEST_KID));
        let err = run_offline(&token).await.unwrap_err();
        assert!(matches!(err, LicenseError::TokenExpired));
    }

    #[tokio::test]
    #[serial]
    async fn offline_unknown_kid_returns_key_not_found() {
        let token = make_jwt(&valid_claims_json(), Some("rotated-key-not-in-bundle"));
        let err = run_offline(&token).await.unwrap_err();
        assert!(matches!(err, LicenseError::KeyNotFound(_)));
    }

    // ── online mode (mock server) ─────────────────────────────────────────────

    #[tokio::test]
    #[serial]
    async fn online_server_returns_401_yields_validation_failed() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/license/validate"))
            .respond_with(ResponseTemplate::new(401))
            .mount(&server)
            .await;

        let err = run_online("bad-key", &server.uri(), None).await.unwrap_err();
        assert!(matches!(err, LicenseError::ValidationFailed { status: 401 }));
    }

    #[tokio::test]
    #[serial]
    async fn online_valid_flow_returns_claims() {
        let token = make_jwt(&valid_claims_json(), Some(TEST_KID));
        let server = jwks_server().await;
        Mock::given(method("POST"))
            .and(path("/api/license/validate"))
            .and(header("authorization", "Bearer valid-license-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "token": token })))
            .mount(&server)
            .await;

        let claims = run_online("valid-license-key", &server.uri(), None)
            .await
            .unwrap();
        assert_eq!(claims.subscription, json!({ "plan": "pro", "seats": 10 }));
        assert_eq!(claims.sub.as_deref(), Some("cust_test"));
    }

    #[tokio::test]
    #[serial]
    async fn online_sends_instance_fingerprint_in_body() {
        let token = make_jwt(&valid_claims_json(), Some(TEST_KID));
        let server = jwks_server().await;
        Mock::given(method("POST"))
            .and(path("/api/license/validate"))
            .and(body_json(json!({ "instance_fingerprint": "device-abc-123" })))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "token": token })))
            .mount(&server)
            .await;

        run_online("valid-license-key", &server.uri(), Some("device-abc-123"))
            .await
            .unwrap();
    }

    #[tokio::test]
    #[serial]
    async fn online_expired_jwt_from_server_returns_token_expired() {
        let token = make_jwt(&expired_claims_json(), Some(TEST_KID));
        let server = jwks_server().await;
        Mock::given(method("POST"))
            .and(path("/api/license/validate"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "token": token })))
            .mount(&server)
            .await;

        let err = run_online("some-key", &server.uri(), None).await.unwrap_err();
        assert!(matches!(err, LicenseError::TokenExpired));
    }
}