1mod claims;
2mod error;
3
4pub use claims::LicenseClaims;
5pub use error::LicenseError;
6
7use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, Validation};
8use reqwest::Client;
9use serde::{Deserialize, Serialize};
10use std::{collections::HashSet, env};
11
12#[cfg(not(test))]
15const OFFLINE_JWKS: &str = include_str!("keys/offline_jwks.json");
16#[cfg(test)]
17const OFFLINE_JWKS: &str = include_str!("keys/test_offline_jwks.json");
18
19const ALLOWED_ALGORITHMS: &[Algorithm] = &[
20 Algorithm::RS256,
21 Algorithm::RS384,
22 Algorithm::RS512,
23 Algorithm::ES256,
24 Algorithm::ES384,
25 Algorithm::PS256,
26 Algorithm::PS384,
27 Algorithm::PS512,
28];
29
/// Value a license token's `aud` claim is checked against.
const EXPECTED_AUDIENCE: &str = "bunny-license-key";
/// Value a license token's `iss` claim is checked against.
const EXPECTED_ISSUER: &str = "https://auth.bunny.com";
32
33#[derive(Serialize)]
34struct ValidateRequest<'a> {
35 #[serde(skip_serializing_if = "Option::is_none")]
36 instance_fingerprint: Option<&'a str>,
37}
38
39#[derive(Deserialize)]
40struct ValidateResponse {
41 token: String,
42}
43
44pub async fn validate_license(instance_fingerprint: Option<&str>) -> Result<LicenseClaims, LicenseError> {
50 if let Ok(key) = env::var("BUNNY_LICENSE_KEY") {
51 validate_online(&key, instance_fingerprint).await
52 } else if let Ok(token) = env::var("BUNNY_OFFLINE_LICENSE_KEY") {
53 validate_offline(&token)
54 } else {
55 Err(LicenseError::NoLicenseKeySet)
56 }
57}
58
59fn build_client() -> Result<Client, LicenseError> {
60 let accept_invalid = env::var("BUNNY_DANGER_ACCEPT_INVALID_CERTS")
61 .map(|v| v == "true" || v == "1")
62 .unwrap_or(false);
63 Client::builder()
64 .danger_accept_invalid_certs(accept_invalid)
65 .build()
66 .map_err(LicenseError::Http)
67}
68
69async fn validate_online(license_key: &str, instance_fingerprint: Option<&str>) -> Result<LicenseClaims, LicenseError> {
70 let host = env::var("BUNNY_HOST").map_err(|_| LicenseError::NoHostSet)?;
71 let client = build_client()?;
72
73 let response = client
74 .post(format!("{}/api/license/validate", host))
75 .bearer_auth(license_key)
76 .json(&ValidateRequest { instance_fingerprint })
77 .send()
78 .await?;
79
80 let status = response.status();
81 if !status.is_success() {
82 return Err(LicenseError::ValidationFailed {
83 status: status.as_u16(),
84 });
85 }
86
87 let body: ValidateResponse = response
88 .json()
89 .await
90 .map_err(|_| LicenseError::MissingToken)?;
91
92 let jwks_response = client
93 .get(format!("{}/api/.well-known/jwks.json", host))
94 .send()
95 .await?;
96
97 let jwks: JwkSet = jwks_response
98 .json()
99 .await
100 .map_err(|e| LicenseError::JwksParse(e.to_string()))?;
101
102 verify_jwt(&body.token, &jwks)
103}
104
105fn validate_offline(token: &str) -> Result<LicenseClaims, LicenseError> {
106 let jwks: JwkSet =
107 serde_json::from_str(OFFLINE_JWKS).map_err(|e| LicenseError::JwksParse(e.to_string()))?;
108
109 verify_jwt(token, &jwks)
110}
111
112fn verify_jwt(token: &str, jwks: &JwkSet) -> Result<LicenseClaims, LicenseError> {
113 let header = decode_header(token)?;
114
115 if !ALLOWED_ALGORITHMS.contains(&header.alg) {
116 return Err(LicenseError::UnsupportedAlgorithm(format!(
117 "{:?}",
118 header.alg
119 )));
120 }
121
122 let kid = header.kid.ok_or(LicenseError::MissingKeyId)?;
123 let jwk = jwks
124 .find(&kid)
125 .ok_or_else(|| LicenseError::KeyNotFound(kid.clone()))?;
126
127 let decoding_key = DecodingKey::from_jwk(jwk)?;
128
129 let mut validation = Validation::new(header.alg);
130 validation.validate_exp = true;
131 validation.required_spec_claims = HashSet::from(["exp".to_string()]);
132 validation.set_issuer(&[EXPECTED_ISSUER]);
133 validation.set_audience(&[EXPECTED_AUDIENCE]);
134
135 decode::<LicenseClaims>(token, &decoding_key, &validation)
136 .map(|data| data.claims)
137 .map_err(|e| match e.kind() {
138 jsonwebtoken::errors::ErrorKind::ExpiredSignature => LicenseError::TokenExpired,
139 _ => LicenseError::Jwt(e),
140 })
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use serde_json::{json, Value};
    use serial_test::serial;
    use std::{
        env,
        ffi::OsString,
        time::{SystemTime, UNIX_EPOCH},
    };
    use wiremock::{
        matchers::{body_json, header, method, path},
        Mock, MockServer, ResponseTemplate,
    };

    const TEST_KID: &str = "test-key-1";
    const TEST_RSA_N: &str = "pzxnaqHr0kjYR5W3aYq_2QduN5oymUoH3YH6U4dk9W6lra2XmQaTidQs1xHVn79WZOs4CNgU_RvScalyEPaMt0SHta3rwMSdTk2ShfNA831jDwDFpjQGcyAWM3d4IowHFDC6cXkttNKNXGqoDQMq_qXfoHjYAkAVv2O9jyg7mJo8pZeeOxzj8vAtQlUFcDM2x3cHuGJl48DASC1cG2WI606ppz319c7gmwVKnay7vOQScek4ErJ3EMh9AypFHSju3fRVhjobFALO7xwta09CWTk25DHAd3mcYvLviGikOAPnI0bxEPZYS42IXmLG7GyNMa7sYhcvYNsK7m3YtWGbpQ";
    const TEST_PRIVATE_KEY: &str = include_str!("keys/test_private_key.pem");

    const LICENSE_KEY_VAR: &str = "BUNNY_LICENSE_KEY";
    const OFFLINE_KEY_VAR: &str = "BUNNY_OFFLINE_LICENSE_KEY";
    const HOST_VAR: &str = "BUNNY_HOST";

    /// Overrides environment variables for the lifetime of the guard and restores their
    /// previous values on drop. Drop also runs while a failing test unwinds, so a failed
    /// assertion can no longer leave variables behind for the next `#[serial]` test.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
    }

    impl EnvGuard {
        /// Applies each override in order: `Some(value)` sets the variable, `None` removes it.
        fn set(overrides: &[(&'static str, Option<&str>)]) -> Self {
            let mut saved = Vec::with_capacity(overrides.len());
            for &(name, value) in overrides {
                saved.push((name, env::var_os(name)));
                match value {
                    Some(value) => env::set_var(name, value),
                    None => env::remove_var(name),
                }
            }
            Self { saved }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Reverse order so a variable overridden twice ends at its pre-guard value.
            for (name, previous) in self.saved.drain(..).rev() {
                match previous {
                    Some(value) => env::set_var(name, value),
                    None => env::remove_var(name),
                }
            }
        }
    }

    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Claims with the expected issuer/audience that stay valid for another hour.
    fn valid_claims_json() -> Value {
        let now = unix_now();
        json!({
            "sub": "cust_test",
            "iat": now,
            "exp": now + 3600,
            "iss": EXPECTED_ISSUER,
            "aud": EXPECTED_AUDIENCE,
            "subscription": { "plan": "pro", "seats": 10 }
        })
    }

    /// Claims with the expected issuer/audience that expired an hour ago.
    fn expired_claims_json() -> Value {
        let now = unix_now();
        json!({
            "sub": "cust_test",
            "iat": now - 7200,
            "exp": now - 3600,
            "iss": EXPECTED_ISSUER,
            "aud": EXPECTED_AUDIENCE,
            "subscription": {}
        })
    }

    /// Signs `claims` with the RS256 test key, optionally tagging the header with `kid`.
    fn make_jwt(claims: &Value, kid: Option<&str>) -> String {
        let mut header = Header::new(Algorithm::RS256);
        header.kid = kid.map(str::to_string);
        encode(
            &header,
            claims,
            &EncodingKey::from_rsa_pem(TEST_PRIVATE_KEY.as_bytes()).unwrap(),
        )
        .unwrap()
    }

    /// Signs `claims` with a shared secret (HS256), an algorithm the verifier must refuse.
    fn make_hs256_jwt(claims: &Value) -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some(TEST_KID.to_string());
        encode(&header, claims, &EncodingKey::from_secret(b"secret")).unwrap()
    }

    /// Public half of the test signing key as a JWKS document.
    fn test_jwks_json() -> Value {
        json!({
            "keys": [{
                "kty": "RSA",
                "use": "sig",
                "alg": "RS256",
                "kid": TEST_KID,
                "n": TEST_RSA_N,
                "e": "AQAB"
            }]
        })
    }

    fn test_jwks() -> JwkSet {
        serde_json::from_value(test_jwks_json()).unwrap()
    }

    /// Serves `test_jwks_json()` from the JWKS endpoint of `server`.
    async fn mount_jwks(server: &MockServer) {
        Mock::given(method("GET"))
            .and(path("/api/.well-known/jwks.json"))
            .respond_with(ResponseTemplate::new(200).set_body_json(test_jwks_json()))
            .mount(server)
            .await;
    }

    /// A 200 response from the validate endpoint that carries `token`.
    fn token_response(token: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(json!({ "token": token }))
    }

    #[test]
    fn verify_valid_jwt_returns_claims() {
        let token = make_jwt(&valid_claims_json(), Some(TEST_KID));
        let result = verify_jwt(&token, &test_jwks()).unwrap();
        assert_eq!(result.subscription, json!({ "plan": "pro", "seats": 10 }));
        assert_eq!(result.sub.as_deref(), Some("cust_test"));
    }

    #[test]
    fn verify_expired_jwt_returns_token_expired() {
        let token = make_jwt(&expired_claims_json(), Some(TEST_KID));
        let err = verify_jwt(&token, &test_jwks()).unwrap_err();
        assert!(matches!(err, LicenseError::TokenExpired));
    }

    #[test]
    fn verify_hs256_jwt_returns_unsupported_algorithm() {
        let token = make_hs256_jwt(&valid_claims_json());
        let err = verify_jwt(&token, &test_jwks()).unwrap_err();
        assert!(matches!(err, LicenseError::UnsupportedAlgorithm(_)));
    }

    #[test]
    fn verify_unknown_kid_returns_key_not_found() {
        let token = make_jwt(&valid_claims_json(), Some("no-such-key"));
        let err = verify_jwt(&token, &test_jwks()).unwrap_err();
        assert!(matches!(err, LicenseError::KeyNotFound(_)));
    }

    #[test]
    fn verify_missing_kid_returns_missing_key_id() {
        let token = make_jwt(&valid_claims_json(), None);
        let err = verify_jwt(&token, &test_jwks()).unwrap_err();
        assert!(matches!(err, LicenseError::MissingKeyId));
    }

    #[test]
    fn verify_wrong_issuer_returns_jwt_error() {
        let mut claims = valid_claims_json();
        claims["iss"] = json!("https://evil.example.com");
        let token = make_jwt(&claims, Some(TEST_KID));
        let err = verify_jwt(&token, &test_jwks()).unwrap_err();
        assert!(matches!(err, LicenseError::Jwt(_)));
    }

    #[test]
    fn verify_wrong_audience_returns_jwt_error() {
        let mut claims = valid_claims_json();
        claims["aud"] = json!("wrong-audience");
        let token = make_jwt(&claims, Some(TEST_KID));
        let err = verify_jwt(&token, &test_jwks()).unwrap_err();
        assert!(matches!(err, LicenseError::Jwt(_)));
    }

    #[tokio::test]
    #[serial]
    async fn no_env_vars_returns_no_license_key_set() {
        let _env = EnvGuard::set(&[(LICENSE_KEY_VAR, None), (OFFLINE_KEY_VAR, None)]);
        let err = validate_license(None).await.unwrap_err();
        assert!(matches!(err, LicenseError::NoLicenseKeySet));
    }

    #[tokio::test]
    #[serial]
    async fn online_mode_without_host_returns_no_host_set() {
        let _env = EnvGuard::set(&[(LICENSE_KEY_VAR, Some("any-key")), (HOST_VAR, None)]);
        let err = validate_license(None).await.unwrap_err();
        assert!(matches!(err, LicenseError::NoHostSet));
    }

    #[tokio::test]
    #[serial]
    async fn offline_valid_jwt_returns_claims() {
        let token = make_jwt(&valid_claims_json(), Some(TEST_KID));
        let _env = EnvGuard::set(&[
            (LICENSE_KEY_VAR, None),
            (OFFLINE_KEY_VAR, Some(token.as_str())),
        ]);
        let claims = validate_license(None).await.unwrap();
        assert_eq!(claims.subscription, json!({ "plan": "pro", "seats": 10 }));
    }

    #[tokio::test]
    #[serial]
    async fn offline_expired_jwt_returns_token_expired() {
        let token = make_jwt(&expired_claims_json(), Some(TEST_KID));
        let _env = EnvGuard::set(&[
            (LICENSE_KEY_VAR, None),
            (OFFLINE_KEY_VAR, Some(token.as_str())),
        ]);
        let err = validate_license(None).await.unwrap_err();
        assert!(matches!(err, LicenseError::TokenExpired));
    }

    #[tokio::test]
    #[serial]
    async fn offline_unknown_kid_returns_key_not_found() {
        let token = make_jwt(&valid_claims_json(), Some("rotated-key-not-in-bundle"));
        let _env = EnvGuard::set(&[
            (LICENSE_KEY_VAR, None),
            (OFFLINE_KEY_VAR, Some(token.as_str())),
        ]);
        let err = validate_license(None).await.unwrap_err();
        assert!(matches!(err, LicenseError::KeyNotFound(_)));
    }

    #[tokio::test]
    #[serial]
    async fn online_server_returns_401_yields_validation_failed() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/license/validate"))
            .respond_with(ResponseTemplate::new(401))
            .mount(&server)
            .await;

        let uri = server.uri();
        let _env = EnvGuard::set(&[
            (LICENSE_KEY_VAR, Some("bad-key")),
            (HOST_VAR, Some(uri.as_str())),
        ]);
        let err = validate_license(None).await.unwrap_err();

        assert!(matches!(err, LicenseError::ValidationFailed { status: 401 }));
    }

    #[tokio::test]
    #[serial]
    async fn online_valid_flow_returns_claims() {
        let server = MockServer::start().await;
        let token = make_jwt(&valid_claims_json(), Some(TEST_KID));

        Mock::given(method("POST"))
            .and(path("/api/license/validate"))
            .and(header("authorization", "Bearer valid-license-key"))
            .respond_with(token_response(&token))
            .mount(&server)
            .await;
        mount_jwks(&server).await;

        let uri = server.uri();
        let _env = EnvGuard::set(&[
            (LICENSE_KEY_VAR, Some("valid-license-key")),
            (HOST_VAR, Some(uri.as_str())),
        ]);
        let claims = validate_license(None).await.unwrap();

        assert_eq!(claims.subscription, json!({ "plan": "pro", "seats": 10 }));
        assert_eq!(claims.sub.as_deref(), Some("cust_test"));
    }

    #[tokio::test]
    #[serial]
    async fn online_sends_instance_fingerprint_in_body() {
        let server = MockServer::start().await;
        let token = make_jwt(&valid_claims_json(), Some(TEST_KID));

        Mock::given(method("POST"))
            .and(path("/api/license/validate"))
            .and(body_json(json!({ "instance_fingerprint": "device-abc-123" })))
            .respond_with(token_response(&token))
            .mount(&server)
            .await;
        mount_jwks(&server).await;

        let uri = server.uri();
        let _env = EnvGuard::set(&[
            (LICENSE_KEY_VAR, Some("valid-license-key")),
            (HOST_VAR, Some(uri.as_str())),
        ]);
        validate_license(Some("device-abc-123")).await.unwrap();
    }

    #[tokio::test]
    #[serial]
    async fn online_expired_jwt_from_server_returns_token_expired() {
        let server = MockServer::start().await;
        let token = make_jwt(&expired_claims_json(), Some(TEST_KID));

        Mock::given(method("POST"))
            .and(path("/api/license/validate"))
            .respond_with(token_response(&token))
            .mount(&server)
            .await;
        mount_jwks(&server).await;

        let uri = server.uri();
        let _env = EnvGuard::set(&[
            (LICENSE_KEY_VAR, Some("some-key")),
            (HOST_VAR, Some(uri.as_str())),
        ]);
        let err = validate_license(None).await.unwrap_err();

        assert!(matches!(err, LicenseError::TokenExpired));
    }
}