blueprint_auth/
models.rs

1use axum::http::uri;
2use base64::Engine;
3use blueprint_core::debug;
4use prost::Message;
5use std::collections::BTreeMap;
6use tracing::instrument;
7
8use crate::{
9    Error,
10    api_tokens::{CUSTOM_ENGINE, GeneratedApiToken},
11    db::{RocksDb, cf},
12    types::{KeyType, ServiceId},
13};
14
15#[derive(prost::Message, Clone)]
16pub struct ApiTokenModel {
17    /// The token ID.
18    #[prost(uint64)]
19    id: u64,
20    /// The token hash.
21    #[prost(string)]
22    token: String,
23    /// The service ID this token is associated with.
24    #[prost(uint64)]
25    service_id: u64,
26    /// The sub-service ID this token is associated with (zero means no sub-service).
27    #[prost(uint64)]
28    sub_service_id: u64,
29    /// The token's expiration time in seconds since the epoch.
30    ///
31    /// Zero means no expiration.
32    #[prost(uint64)]
33    pub expires_at: u64,
34    /// Whether the token is enabled.
35    #[prost(bool)]
36    pub is_enabled: bool,
37    /// Additional headers to be forwarded to the upstream service.
38    #[prost(bytes)]
39    pub additional_headers: Vec<u8>,
40}
41
42/// TLS profile configuration for a service
43#[derive(prost::Message, Clone, serde::Serialize, serde::Deserialize)]
44pub struct TlsProfile {
45    /// Whether TLS is enabled for this service
46    #[prost(bool)]
47    pub tls_enabled: bool,
48    /// Whether client mTLS is required
49    #[prost(bool)]
50    pub require_client_mtls: bool,
51    /// Encrypted server certificate PEM
52    #[prost(bytes)]
53    pub encrypted_server_cert: Vec<u8>,
54    /// Encrypted server private key PEM
55    #[prost(bytes)]
56    pub encrypted_server_key: Vec<u8>,
57    /// Encrypted client CA bundle PEM
58    #[prost(bytes)]
59    pub encrypted_client_ca_bundle: Vec<u8>,
60    /// Encrypted upstream CA bundle PEM
61    #[prost(bytes)]
62    pub encrypted_upstream_ca_bundle: Vec<u8>,
63    /// Encrypted upstream client certificate PEM
64    #[prost(bytes)]
65    pub encrypted_upstream_client_cert: Vec<u8>,
66    /// Encrypted upstream client private key PEM
67    #[prost(bytes)]
68    pub encrypted_upstream_client_key: Vec<u8>,
69    /// Maximum client certificate TTL in hours
70    #[prost(uint32)]
71    pub client_cert_ttl_hours: u32,
72    /// Optional SNI hostname for this service
73    #[prost(string, optional)]
74    pub sni: Option<String>,
75    /// Template to derive subjectAltNames for issued certificates
76    #[prost(string, optional)]
77    pub subject_alt_name_template: Option<String>,
78    /// Allowed DNS names for issued certificates
79    #[prost(string, repeated)]
80    pub allowed_dns_names: Vec<String>,
81}
82impl TlsProfile {
83    /// Convenience constructor for tests and setup code.
84    pub fn new(tls_enabled: bool, require_client_mtls: bool, client_cert_ttl_hours: u32) -> Self {
85        Self {
86            tls_enabled,
87            require_client_mtls,
88            encrypted_server_cert: Vec::new(),
89            encrypted_server_key: Vec::new(),
90            encrypted_client_ca_bundle: Vec::new(),
91            encrypted_upstream_ca_bundle: Vec::new(),
92            encrypted_upstream_client_cert: Vec::new(),
93            encrypted_upstream_client_key: Vec::new(),
94            client_cert_ttl_hours,
95            sni: None,
96            subject_alt_name_template: None,
97            allowed_dns_names: Vec::new(),
98        }
99    }
100}
101
102/// Represents a service model stored in the database.
103#[derive(prost::Message, Clone)]
104pub struct ServiceModel {
105    /// The service API Key prefix.
106    #[prost(string)]
107    pub api_key_prefix: String,
108    /// A List of service owners.
109    #[prost(message, repeated)]
110    pub owners: Vec<ServiceOwnerModel>,
111    /// The service upstream URL.
112    ///
113    /// This what the proxy will use to forward requests to the service.
114    #[prost(string)]
115    pub upstream_url: String,
116    /// TLS profile configuration for this service
117    #[prost(message, optional)]
118    pub tls_profile: Option<TlsProfile>,
119}
120
121/// A service owner model stored in the database.
122#[derive(prost::Message, Clone, PartialEq, Eq)]
123pub struct ServiceOwnerModel {
124    /// The Public key type.
125    ///
126    /// See [`KeyType`] for more details.
127    #[prost(enumeration = "KeyType")]
128    pub key_type: i32,
129    /// The public key bytes.
130    #[prost(bytes)]
131    pub key_bytes: Vec<u8>,
132}
133
134/// TLS certificate metadata for issued certificates
135#[derive(prost::Message, Clone, PartialEq, Eq)]
136pub struct TlsCertMetadata {
137    /// The service ID this certificate belongs to
138    #[prost(uint64)]
139    pub service_id: u64,
140    /// The certificate ID (unique within service)
141    #[prost(string)]
142    pub cert_id: String,
143    /// Certificate PEM encoded
144    #[prost(string)]
145    pub certificate_pem: String,
146    /// Certificate serial number
147    #[prost(string)]
148    pub serial: String,
149    /// Certificate expiration timestamp (seconds since epoch)
150    #[prost(uint64)]
151    pub expires_at: u64,
152    /// Whether the certificate is revoked
153    #[prost(bool)]
154    pub is_revoked: bool,
155    /// Certificate usage (client, server, both)
156    #[prost(string)]
157    pub usage: String,
158    /// Subject common name
159    #[prost(string)]
160    pub common_name: String,
161    /// Subject alternative names
162    #[prost(string, repeated)]
163    pub subject_alt_names: Vec<String>,
164    /// Issuance timestamp (seconds since epoch)
165    #[prost(uint64)]
166    pub issued_at: u64,
167    /// Issuing API key ID (for auditing)
168    #[prost(uint64)]
169    pub issued_by_api_key_id: u64,
170    /// Tenant ID associated with this certificate
171    #[prost(string, optional)]
172    pub tenant_id: Option<String>,
173}
174
175impl ApiTokenModel {
176    /// Find a token by its ID in the database.
177    #[instrument(skip(db), err)]
178    pub fn find_token_id(id: u64, db: &RocksDb) -> Result<Option<Self>, crate::Error> {
179        let cf = db
180            .cf_handle(cf::TOKENS_OPTS_CF)
181            .ok_or(crate::Error::UnknownColumnFamily(cf::TOKENS_OPTS_CF))?;
182        let token_opts_bytes = db.get_pinned_cf(&cf, id.to_be_bytes())?;
183
184        token_opts_bytes
185            .map(|bytes| ApiTokenModel::decode(bytes.as_ref()))
186            .transpose()
187            .map_err(Into::into)
188    }
189
190    /// Checks if the given plaintext matches the stored token hash.
191    #[instrument(skip(self), ret)]
192    pub fn is(&self, plaintext: &str) -> bool {
193        use tiny_keccak::Hasher;
194
195        let mut hasher = tiny_keccak::Keccak::v256();
196        hasher.update(plaintext.as_bytes());
197        let mut output = [0u8; 32];
198        hasher.finalize(&mut output);
199
200        let token_hash = CUSTOM_ENGINE.encode(output);
201
202        debug!(
203            %plaintext,
204            %self.token,
205            %token_hash,
206            token_match = self.token == token_hash,
207            "Checking token match",
208        );
209
210        self.token == token_hash
211    }
212
213    /// Saves the token to the database and returns the ID.
214    pub fn save(&mut self, db: &RocksDb) -> Result<u64, crate::Error> {
215        let tokens_cf = db
216            .cf_handle(cf::TOKENS_OPTS_CF)
217            .ok_or(crate::Error::UnknownColumnFamily(cf::TOKENS_OPTS_CF))?;
218        if self.id != 0 {
219            // update the existing token
220            let token_bytes = self.encode_to_vec();
221            db.put_cf(&tokens_cf, self.id.to_be_bytes(), token_bytes)?;
222            Ok(self.id)
223        } else {
224            self.create(db)
225        }
226    }
227
228    fn create(&mut self, db: &RocksDb) -> Result<u64, crate::Error> {
229        let tokens_cf = db
230            .cf_handle(cf::TOKENS_OPTS_CF)
231            .ok_or(crate::Error::UnknownColumnFamily(cf::TOKENS_OPTS_CF))?;
232
233        let seq_cf = db
234            .cf_handle(cf::SEQ_CF)
235            .ok_or(crate::Error::UnknownColumnFamily(cf::SEQ_CF))?;
236
237        let txn = db.transaction();
238        // Increment the sequence number
239        // internally, the adder merge operator will increment the sequence number
240        let mut retry_count = 0;
241        let max_retries = 10;
242        loop {
243            let result = txn.merge_cf(&seq_cf, b"tokens", 1u64.to_be_bytes());
244            match result {
245                Ok(()) => break,
246                Err(e) if e.kind() == rocksdb::ErrorKind::Busy => {
247                    retry_count += 1;
248                    if retry_count >= max_retries {
249                        return Err(crate::Error::RocksDB(e));
250                    }
251                }
252                Err(e) if e.kind() == rocksdb::ErrorKind::TryAgain => {
253                    retry_count += 1;
254                    if retry_count >= max_retries {
255                        return Err(crate::Error::RocksDB(e));
256                    }
257                }
258                Err(e) => return Err(crate::Error::RocksDB(e)),
259            }
260        }
261
262        let next_id = txn
263            .get_cf(&seq_cf, b"tokens")?
264            .map(|v| {
265                let mut id = [0u8; 8];
266                id.copy_from_slice(&v);
267                u64::from_be_bytes(id)
268            })
269            .unwrap_or(0u64);
270        self.id = next_id;
271        let tokens_bytes = self.encode_to_vec();
272        txn.put_cf(&tokens_cf, next_id.to_be_bytes(), tokens_bytes)?;
273        // commit the transaction
274        txn.commit()?;
275
276        Ok(next_id)
277    }
278
279    /// Returns the token expiration time in milliseconds since the epoch.
280    /// Zero means no expiration.
281    pub fn expires_at(&self) -> Option<u64> {
282        if self.expires_at == 0 {
283            None
284        } else {
285            Some(self.expires_at)
286        }
287    }
288
289    /// Checks if the token is expired.
290    #[instrument(skip(self), ret)]
291    pub fn is_expired(&self) -> bool {
292        if self.expires_at == 0 {
293            return false;
294        }
295        let now = std::time::SystemTime::now();
296        let since_epoch = now
297            .duration_since(std::time::UNIX_EPOCH)
298            .unwrap_or_default();
299
300        let now = since_epoch.as_secs();
301        self.expires_at < now
302    }
303
304    /// Return the service ID associated with this token.
305    pub fn service_id(&self) -> ServiceId {
306        ServiceId::new(self.service_id).with_subservice(self.sub_service_id)
307    }
308
309    /// Get the additional headers as a BTreeMap
310    pub fn get_additional_headers(&self) -> BTreeMap<String, String> {
311        if self.additional_headers.is_empty() {
312            BTreeMap::new()
313        } else {
314            serde_json::from_slice(&self.additional_headers).unwrap_or_default()
315        }
316    }
317
318    /// Set the additional headers from a BTreeMap
319    pub fn set_additional_headers(&mut self, headers: &BTreeMap<String, String>) {
320        self.additional_headers = serde_json::to_vec(headers).unwrap_or_default();
321    }
322}
323
324impl From<&GeneratedApiToken> for ApiTokenModel {
325    fn from(token: &GeneratedApiToken) -> Self {
326        let mut model = Self {
327            id: 0,
328            token: token.token.clone(),
329            service_id: token.service_id.0,
330            sub_service_id: token.service_id.1,
331            expires_at: token.expires_at().unwrap_or(0),
332            is_enabled: true,
333            additional_headers: Vec::new(),
334        };
335        model.set_additional_headers(token.additional_headers());
336        model
337    }
338}
339
340impl ServiceModel {
341    /// Convenience constructor to reduce boilerplate in tests.
342    pub fn new(api_key_prefix: impl Into<String>, upstream_url: impl Into<String>) -> Self {
343        Self {
344            api_key_prefix: api_key_prefix.into(),
345            owners: Vec::new(),
346            upstream_url: upstream_url.into(),
347            tls_profile: None,
348        }
349    }
350
351    /// Find a service by its ID in the database.
352    pub fn find_by_id(id: ServiceId, db: &RocksDb) -> Result<Option<Self>, crate::Error> {
353        let cf = db
354            .cf_handle(cf::SERVICES_USER_KEYS_CF)
355            .ok_or(crate::Error::UnknownColumnFamily(cf::SERVICES_USER_KEYS_CF))?;
356        let service_bytes = db.get_pinned_cf(&cf, id.to_be_bytes())?;
357
358        service_bytes
359            .map(|bytes| ServiceModel::decode(bytes.as_ref()))
360            .transpose()
361            .map_err(Into::into)
362    }
363
364    /// Saves the service to the database at the given ID.
365    pub fn save(&self, id: ServiceId, db: &RocksDb) -> Result<(), crate::Error> {
366        let cf = db
367            .cf_handle(cf::SERVICES_USER_KEYS_CF)
368            .ok_or(crate::Error::UnknownColumnFamily(cf::SERVICES_USER_KEYS_CF))?;
369        let service_bytes = self.encode_to_vec();
370        db.put_cf(&cf, id.to_be_bytes(), service_bytes)?;
371        Ok(())
372    }
373
374    /// Deletes the service from the database.
375    pub fn delete(id: ServiceId, db: &RocksDb) -> Result<(), crate::Error> {
376        let cf = db
377            .cf_handle(cf::SERVICES_USER_KEYS_CF)
378            .ok_or(crate::Error::UnknownColumnFamily(cf::SERVICES_USER_KEYS_CF))?;
379        db.delete_cf(&cf, id.to_be_bytes())?;
380        Ok(())
381    }
382
383    pub fn api_key_prefix(&self) -> &str {
384        &self.api_key_prefix
385    }
386
387    /// Checks if the service has a specific owner.
388    pub fn is_owner(&self, key_type: KeyType, key_bytes: &[u8]) -> bool {
389        self.owners
390            .iter()
391            .any(|owner| owner.key_type == key_type as i32 && owner.key_bytes == key_bytes)
392    }
393
394    /// Adds a new owner to the service.
395    pub fn add_owner(&mut self, key_type: KeyType, key_bytes: Vec<u8>) {
396        let owner = ServiceOwnerModel {
397            key_type: key_type as i32,
398            key_bytes,
399        };
400        self.owners.push(owner);
401    }
402
403    /// Removes an owner from the service.
404    pub fn remove_owner(&mut self, key_type: KeyType, key_bytes: &[u8]) {
405        self.owners
406            .retain(|owner| !(owner.key_type == key_type as i32 && owner.key_bytes == key_bytes));
407    }
408
409    /// Returns the list of owners.
410    pub fn owners(&self) -> &[ServiceOwnerModel] {
411        &self.owners
412    }
413
414    /// Returns the upstream URL.
415    pub fn upstream_url(&self) -> Result<uri::Uri, Error> {
416        self.upstream_url.parse::<uri::Uri>().map_err(Into::into)
417    }
418
419    /// Get the TLS profile for this service
420    pub fn tls_profile(&self) -> Option<&TlsProfile> {
421        self.tls_profile.as_ref()
422    }
423
424    /// Set the TLS profile for this service
425    pub fn set_tls_profile(&mut self, tls_profile: TlsProfile) {
426        self.tls_profile = Some(tls_profile);
427    }
428
429    /// Check if TLS is enabled for this service
430    pub fn is_tls_enabled(&self) -> bool {
431        self.tls_profile
432            .as_ref()
433            .map(|p| p.tls_enabled)
434            .unwrap_or(false)
435    }
436
437    /// Check if client mTLS is required for this service
438    pub fn requires_client_mtls(&self) -> bool {
439        self.tls_profile
440            .as_ref()
441            .map(|p| p.require_client_mtls)
442            .unwrap_or(false)
443    }
444}
445
/// Inputs for creating a new [`TlsCertMetadata`] via `TlsCertMetadata::new`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsCertMetadataConfig {
    /// ID of the service the certificate belongs to.
    pub service_id: u64,
    /// Certificate identifier, unique within its service.
    pub cert_id: String,
    /// PEM-encoded certificate.
    pub certificate_pem: String,
    /// Certificate serial number.
    pub serial: String,
    /// Expiration time in seconds since the Unix epoch.
    pub expires_at: u64,
    /// Certificate usage (client, server, both).
    pub usage: String,
    /// Subject common name.
    pub common_name: String,
    /// ID of the API key that requested the certificate.
    pub issued_by_api_key_id: u64,
}
457
458impl TlsCertMetadata {
459    /// Create a new certificate metadata entry
460    pub fn new(config: TlsCertMetadataConfig) -> Self {
461        Self {
462            service_id: config.service_id,
463            cert_id: config.cert_id,
464            certificate_pem: config.certificate_pem,
465            serial: config.serial,
466            expires_at: config.expires_at,
467            is_revoked: false,
468            usage: config.usage,
469            common_name: config.common_name,
470            subject_alt_names: Vec::new(),
471            issued_at: std::time::SystemTime::now()
472                .duration_since(std::time::UNIX_EPOCH)
473                .unwrap_or_default()
474                .as_secs(),
475            issued_by_api_key_id: config.issued_by_api_key_id,
476            tenant_id: None,
477        }
478    }
479
480    /// Generate the database key for this certificate metadata
481    pub fn db_key(&self) -> Vec<u8> {
482        let mut key = Vec::with_capacity(16 + self.cert_id.len());
483        key.extend_from_slice(&self.service_id.to_be_bytes());
484        key.extend_from_slice(self.cert_id.as_bytes());
485        key
486    }
487
488    /// Save the certificate metadata to the database
489    pub fn save(&self, db: &RocksDb) -> Result<(), crate::Error> {
490        let cf = db.cf_handle(crate::db::cf::TLS_CERT_METADATA_CF).ok_or(
491            crate::Error::UnknownColumnFamily(crate::db::cf::TLS_CERT_METADATA_CF),
492        )?;
493
494        let key = self.db_key();
495        let metadata_bytes = self.encode_to_vec();
496        db.put_cf(&cf, key, metadata_bytes)?;
497        Ok(())
498    }
499
500    /// Find certificate metadata by service ID and certificate ID
501    pub fn find_by_service_and_cert_id(
502        service_id: u64,
503        cert_id: &str,
504        db: &RocksDb,
505    ) -> Result<Option<Self>, crate::Error> {
506        let cf = db.cf_handle(crate::db::cf::TLS_CERT_METADATA_CF).ok_or(
507            crate::Error::UnknownColumnFamily(crate::db::cf::TLS_CERT_METADATA_CF),
508        )?;
509
510        let mut key = Vec::with_capacity(16 + cert_id.len());
511        key.extend_from_slice(&service_id.to_be_bytes());
512        key.extend_from_slice(cert_id.as_bytes());
513
514        let metadata_bytes = db.get_pinned_cf(&cf, key)?;
515
516        metadata_bytes
517            .map(|bytes| TlsCertMetadata::decode(bytes.as_ref()))
518            .transpose()
519            .map_err(Into::into)
520    }
521
522    /// Check if the certificate is expired
523    pub fn is_expired(&self) -> bool {
524        let now = std::time::SystemTime::now()
525            .duration_since(std::time::UNIX_EPOCH)
526            .unwrap_or_default()
527            .as_secs();
528        self.expires_at < now
529    }
530
531    /// Revoke the certificate
532    pub fn revoke(&mut self) {
533        self.is_revoked = true;
534    }
535
536    /// Add a subject alternative name
537    pub fn add_subject_alt_name(&mut self, san: String) {
538        self.subject_alt_names.push(san);
539    }
540
541    /// Set the tenant ID
542    pub fn set_tenant_id(&mut self, tenant_id: String) {
543        self.tenant_id = Some(tenant_id);
544    }
545}
546
547/// Helper functions for TLS asset management
548pub mod tls_assets {
549    use super::*;
550    use crate::db::cf;
551
552    /// Store encrypted TLS asset in the database
553    pub fn store_tls_asset(
554        db: &RocksDb,
555        service_id: u64,
556        asset_type: &str,
557        encrypted_data: &[u8],
558    ) -> Result<(), crate::Error> {
559        let cf = db
560            .cf_handle(cf::TLS_ASSETS_CF)
561            .ok_or(crate::Error::UnknownColumnFamily(cf::TLS_ASSETS_CF))?;
562
563        let key = format!("{service_id}:{asset_type}");
564        db.put_cf(&cf, key.as_bytes(), encrypted_data)?;
565        Ok(())
566    }
567
568    /// Retrieve encrypted TLS asset from the database
569    pub fn get_tls_asset(
570        db: &RocksDb,
571        service_id: u64,
572        asset_type: &str,
573    ) -> Result<Option<Vec<u8>>, crate::Error> {
574        let cf = db
575            .cf_handle(cf::TLS_ASSETS_CF)
576            .ok_or(crate::Error::UnknownColumnFamily(cf::TLS_ASSETS_CF))?;
577
578        let key = format!("{service_id}:{asset_type}");
579        let asset_bytes = db.get_pinned_cf(&cf, key.as_bytes())?;
580        Ok(asset_bytes.map(|bytes| bytes.to_vec()))
581    }
582
583    /// Delete TLS asset from the database
584    pub fn delete_tls_asset(
585        db: &RocksDb,
586        service_id: u64,
587        asset_type: &str,
588    ) -> Result<(), crate::Error> {
589        let cf = db
590            .cf_handle(cf::TLS_ASSETS_CF)
591            .ok_or(crate::Error::UnknownColumnFamily(cf::TLS_ASSETS_CF))?;
592
593        let key = format!("{service_id}:{asset_type}");
594        db.delete_cf(&cf, key.as_bytes())?;
595        Ok(())
596    }
597
598    /// Log certificate issuance for auditing
599    pub fn log_certificate_issuance(
600        db: &RocksDb,
601        service_id: u64,
602        cert_id: &str,
603        api_key_id: u64,
604        tenant_id: Option<&str>,
605    ) -> Result<(), crate::Error> {
606        let cf = db
607            .cf_handle(cf::TLS_ISSUANCE_LOG_CF)
608            .ok_or(crate::Error::UnknownColumnFamily(cf::TLS_ISSUANCE_LOG_CF))?;
609
610        let timestamp = std::time::SystemTime::now()
611            .duration_since(std::time::UNIX_EPOCH)
612            .unwrap_or_default()
613            .as_secs();
614
615        let mut log_entry = Vec::new();
616        log_entry.extend_from_slice(&timestamp.to_be_bytes());
617        log_entry.extend_from_slice(&service_id.to_be_bytes());
618        log_entry.extend_from_slice(cert_id.as_bytes());
619        log_entry.push(0u8); // separator
620        log_entry.extend_from_slice(&api_key_id.to_be_bytes());
621        if let Some(tenant_id) = tenant_id {
622            log_entry.extend_from_slice(tenant_id.as_bytes());
623        }
624
625        let log_key = format!("{timestamp}:{cert_id}");
626        db.put_cf(&cf, log_key.as_bytes(), log_entry)?;
627        Ok(())
628    }
629
630    /// Get certificate issuance log for a service
631    pub fn get_certificate_issuance_log(
632        db: &RocksDb,
633        service_id: u64,
634    ) -> Result<Vec<TlsCertMetadata>, crate::Error> {
635        let cf = db
636            .cf_handle(cf::TLS_CERT_METADATA_CF)
637            .ok_or(crate::Error::UnknownColumnFamily(cf::TLS_CERT_METADATA_CF))?;
638
639        let mut certificates = Vec::new();
640        let prefix = service_id.to_be_bytes();
641
642        // Iterate through all certificates for this service
643        let iter = db.prefix_iterator_cf(&cf, prefix);
644        for item in iter {
645            let (_key, value) = item?;
646            let metadata = TlsCertMetadata::decode(&*value)?;
647            if metadata.service_id == service_id {
648                certificates.push(metadata);
649            }
650        }
651
652        Ok(certificates)
653    }
654}
655
#[cfg(test)]
mod tests {
    use crate::{api_tokens::ApiTokenGenerator, types::ServiceId};

    use super::*;

    /// Builds an unsaved, enabled token model with the given stored hash and expiration.
    fn model_with(token: String, expires_at: u64) -> ApiTokenModel {
        ApiTokenModel {
            id: 0,
            token,
            service_id: 1,
            sub_service_id: 0,
            expires_at,
            is_enabled: true,
            additional_headers: Vec::new(),
        }
    }

    /// Hashes `plaintext` the same way `ApiTokenModel::is` hashes its input.
    fn hash_plaintext(plaintext: &str) -> String {
        use tiny_keccak::Hasher;

        let mut hasher = tiny_keccak::Keccak::v256();
        hasher.update(plaintext.as_bytes());
        let mut output = [0u8; 32];
        hasher.finalize(&mut output);
        CUSTOM_ENGINE.encode(output)
    }

    #[test]
    fn token_generator() {
        let mut rng = blueprint_std::BlueprintRng::new();
        let tmp_dir = tempfile::tempdir().unwrap();
        let db = RocksDb::open(tmp_dir.path(), &Default::default()).unwrap();
        let service_id = ServiceId::new(1);
        let generator = ApiTokenGenerator::new();
        let token = generator.generate_token(service_id, &mut rng);
        let mut token = ApiTokenModel::from(&token);

        // Save the token to the database
        let id = token.save(&db).unwrap();
        assert_eq!(id, 1);

        // Find the token by ID
        let found_token = ApiTokenModel::find_token_id(id, &db).unwrap();
        assert!(found_token.is_some());
        let found_token = found_token.unwrap();
        assert_eq!(found_token.id, id);
        assert_eq!(found_token.token, token.token);
        assert_eq!(found_token.expires_at, token.expires_at);
        assert_eq!(found_token.is_enabled, token.is_enabled);
    }

    #[test]
    fn token_with_headers() {
        use std::collections::BTreeMap;

        let mut rng = blueprint_std::BlueprintRng::new();
        let tmp_dir = tempfile::tempdir().unwrap();
        let db = RocksDb::open(tmp_dir.path(), &Default::default()).unwrap();
        let service_id = ServiceId::new(1);
        let generator = ApiTokenGenerator::new();

        // Create headers
        let mut headers = BTreeMap::new();
        headers.insert("X-Tenant-Id".to_string(), "tenant123".to_string());
        headers.insert("X-User-Type".to_string(), "premium".to_string());

        // Generate token with headers
        let token = generator.generate_token_with_expiration_and_headers(
            service_id,
            0,
            headers.clone(),
            &mut rng,
        );
        let mut token_model = ApiTokenModel::from(&token);

        // Save the token to the database
        let id = token_model.save(&db).unwrap();

        // Find the token by ID
        let found_token = ApiTokenModel::find_token_id(id, &db).unwrap().unwrap();

        // Verify headers are preserved
        let found_headers = found_token.get_additional_headers();
        assert_eq!(found_headers, headers);
    }

    #[test]
    fn test_additional_headers_methods() {
        use std::collections::BTreeMap;

        let mut token_model = model_with("test".to_string(), 0);

        // Test empty headers
        assert!(token_model.get_additional_headers().is_empty());

        // Set headers
        let mut headers = BTreeMap::new();
        headers.insert("X-Test".to_string(), "value".to_string());
        token_model.set_additional_headers(&headers);

        // Get headers back
        let retrieved = token_model.get_additional_headers();
        assert_eq!(retrieved, headers);
    }

    #[test]
    fn malformed_additional_headers_decode_to_empty() {
        let mut token_model = model_with("test".to_string(), 0);
        token_model.additional_headers = b"not json".to_vec();
        assert!(token_model.get_additional_headers().is_empty());
    }

    #[test]
    fn token_hash_matching() {
        let token_model = model_with(hash_plaintext("secret-token"), 0);
        assert!(token_model.is("secret-token"));
        assert!(!token_model.is("secret-tokem"));
        assert!(!token_model.is(""));
    }

    #[test]
    fn token_expiration() {
        let never = model_with(String::new(), 0);
        assert_eq!(never.expires_at(), None);
        assert!(!never.is_expired());

        let long_ago = model_with(String::new(), 1);
        assert_eq!(long_ago.expires_at(), Some(1));
        assert!(long_ago.is_expired());

        let far_future = model_with(String::new(), u64::MAX);
        assert_eq!(far_future.expires_at(), Some(u64::MAX));
        assert!(!far_future.is_expired());
    }
}