covert_system/expiration_manager/
mod.rs

1pub mod clock;
2mod lease;
3
4use std::collections::HashMap;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use chrono::Duration;
9use covert_types::auth::AuthPolicy;
10use covert_types::error::ApiError;
11use covert_types::methods::psql::RenewLeaseResponse;
12use covert_types::methods::RenewLeaseParams;
13use covert_types::request::{Operation, Request};
14use covert_types::state::StorageState;
15use covert_types::ttl::calculate_ttl;
16use futures::stream::FuturesOrdered;
17use futures::{Future, StreamExt};
18use hyper::http;
19use tokio::sync::RwLock;
20use tokio::sync::{mpsc, Notify};
21use tokio::time::timeout;
22use tracing::{debug, error, info};
23use uuid::Uuid;
24
25use crate::error::{Error, ErrorType};
26use crate::repos::Repos;
27
28use self::clock::Clock;
29pub use self::lease::LeaseEntry;
30
31use super::router::Router;
32
33/// The expiration manager is resposible for revoking and renewing leases.
34pub struct ExpirationManager {
35    /// Used to notify the revocation worker of when new leases are registered
36    background_task: Notify,
37    /// Router to send revoke / renew requests to the backends
38    router: Arc<Router>,
39    /// Storage
40    repos: Repos,
41    /// Shutdown listener
42    shutdown_rx: Arc<RwLock<mpsc::Receiver<()>>>,
43    /// Shutdown transmitter
44    shutdown_tx: mpsc::Sender<()>,
45    /// Time before retrying a failed revocation
46    revocation_retry_timeout: Duration,
47    /// Max number of revoke requests before the lease is deleted
48    revocation_max_retries: u32,
49    /// Timeout for the revoke endpoint
50    revocation_timeout: std::time::Duration,
51    /// Number of leases the revocation worker should try to revoke at the same time
52    revocation_worker_concurrency: usize,
53    /// Provides time information. Gives us deterministic time in tests.
54    clock: Arc<dyn Clock>,
55}
56
57impl ExpirationManager {
58    /// Create a new expiration manager.
59    pub fn new(router: Arc<Router>, repos: Repos, clock: impl Clock) -> Self {
60        let (tx, rx) = mpsc::channel(1);
61
62        ExpirationManager {
63            background_task: Notify::new(),
64            router,
65            repos,
66            shutdown_rx: Arc::new(RwLock::new(rx)),
67            shutdown_tx: tx,
68            revocation_retry_timeout: Duration::seconds(5),
69            revocation_max_retries: 10,
70            revocation_timeout: std::time::Duration::from_secs(10),
71            revocation_worker_concurrency: 100,
72            clock: Arc::new(clock),
73        }
74    }
75
76    /// Register a new [`LeaseEntry`].
77    ///
78    /// This is the only way to register new leases, leases should *not* be inserted
79    /// directly to the [`LeaseStore`] without going throught the expiration manager.
80    pub async fn register(&self, le: LeaseEntry) -> Result<(), Error> {
81        self.repos.lease.create(&le).await?;
82        // Let the revocation worker know about the lease.
83        self.background_task.notify_one();
84        Ok(())
85    }
86
87    /// Revoke all leases issued by mounts under a given path prefix.
88    pub async fn revoke_leases_by_mount_prefix(
89        &self,
90        prefix: &str,
91        namespace_id: &str,
92    ) -> Result<Vec<LeaseEntry>, Error> {
93        let leases = self
94            .repos
95            .lease
96            .list_by_mount_prefix(prefix, namespace_id)
97            .await?;
98
99        let mut revoke_futures = FuturesOrdered::new();
100
101        for lease in leases {
102            revoke_futures
103                .push_back(async move { self.revoke_lease_entry(&lease).await.map(|_| lease) });
104        }
105
106        let revoked_leases = revoke_futures
107            .collect::<Vec<_>>()
108            .await
109            .into_iter()
110            .filter_map(|res| match res {
111                Ok(le) => Some(le),
112                Err(error) => {
113                    tracing::error!(?error, "Failed to revoke lease");
114                    None
115                }
116            })
117            .collect::<Vec<_>>();
118
119        Ok(revoked_leases)
120    }
121
122    /// List all leases issued by mounts under a given path prefix.
123    pub async fn list_by_mount_prefix(
124        &self,
125        prefix: &str,
126        namespace_id: &str,
127    ) -> Result<Vec<LeaseEntry>, Error> {
128        self.repos
129            .lease
130            .list_by_mount_prefix(prefix, namespace_id)
131            .await
132    }
133
134    /// Lookup a lease by its id.
135    pub async fn lookup(
136        &self,
137        lease_id: &str,
138        namespace_id: &str,
139    ) -> Result<Option<LeaseEntry>, Error> {
140        self.repos.lease.lookup(lease_id, namespace_id).await
141    }
142
143    /// Revoke a lease by its id.
144    pub async fn revoke_lease_entry_by_id(
145        &self,
146        lease_id: &str,
147        namespace_id: &str,
148    ) -> Result<LeaseEntry, Error> {
149        let le = self
150            .lookup(lease_id, namespace_id)
151            .await?
152            .ok_or_else(|| ErrorType::NotFound(format!("Lease `{lease_id}` not found")))?;
153
154        self.revoke_lease_entry(&le)
155            .await
156            .map(|_| le)
157            .map_err(|error| {
158                tracing::error!(?error, lease_id, "Unable to revoke lease.");
159                ErrorType::RevokeLease {
160                    source: Box::new(error),
161                    lease_id: lease_id.to_string(),
162                }
163                .into()
164            })
165    }
166
167    /// Send a revoke request to the backend that is resposible for revoking the
168    /// leased data.
169    #[tracing::instrument(skip_all, fields(lease_id = le.id, issued_mount_path = le.issued_mount_path))]
170    async fn send_lease_revoke_request(&self, le: &LeaseEntry) -> Result<(), ApiError> {
171        let ns = self
172            .repos
173            .namespace
174            .lookup(&le.namespace_id)
175            .await?
176            .ok_or_else(ApiError::internal_error)?;
177        let ns_path = self.repos.namespace.get_full_path(&le.namespace_id).await?;
178
179        // Perform revocation
180        let mut extensions = http::Extensions::new();
181        extensions.insert(AuthPolicy::Authenticated);
182        extensions.insert(StorageState::Unsealed);
183        extensions.insert(ns);
184
185        let revoke_path = le.revoke_path.as_ref().map_or_else(
186            || "sys/token/revoke".into(),
187            |revoke_path| format!("{}{revoke_path}", le.issued_mount_path),
188        );
189
190        let req = Request {
191            id: Uuid::default(),
192            namespace: ns_path.split('/').map(From::from).collect(),
193            operation: Operation::Revoke,
194            path: revoke_path,
195            data: le.revoke_data.clone().into(),
196            extensions,
197            token: None,
198            params: Vec::default(),
199            query_string: String::default(),
200            headers: HashMap::default(),
201        };
202
203        match timeout(self.revocation_timeout, self.router.route(req)).await {
204            Ok(backend_resp) => backend_resp.map(|_| ()).map_err(|error| {
205                tracing::error!(?error, "Backend failed to revoke lease");
206                error
207            }),
208            Err(_) => Err(ApiError::timeout()),
209        }
210    }
211
212    /// Renew a lease by its id.
213    #[allow(clippy::too_many_lines)]
214    pub async fn renew_lease_entry(
215        &self,
216        lease_id: &str,
217        namespace_id: &str,
218        ttl: Option<std::time::Duration>,
219    ) -> Result<LeaseEntry, Error> {
220        let mut le = self
221            .repos
222            .lease
223            .lookup(lease_id, namespace_id)
224            .await?
225            .ok_or_else(|| ErrorType::NotFound(format!("Lease `{lease_id}` not found")))?;
226        let mount_config = self
227            .repos
228            .mount
229            .get_by_path(&le.issued_mount_path, &le.namespace_id)
230            .await?
231            .ok_or_else(|| ErrorType::MountNotFound {
232                path: le.issued_mount_path.clone(),
233            })?
234            .config;
235
236        let ttl =
237            calculate_ttl(self.clock.now(), le.issued_at, &mount_config, ttl).map_err(|_| {
238                ErrorType::InternalError(anyhow::Error::msg(
239                    "Failed to calculate TTL when renewing lease",
240                ))
241            })?;
242
243        let ns = self
244            .repos
245            .namespace
246            .lookup(&le.namespace_id)
247            .await?
248            .ok_or_else(|| {
249                ErrorType::InternalError(anyhow::Error::msg("Unable to find namespace for lease"))
250            })?;
251        let ns_path = self.repos.namespace.get_full_path(&le.namespace_id).await?;
252
253        // Perform renewal
254        let mut extensions = http::Extensions::new();
255        extensions.insert(AuthPolicy::Authenticated);
256        extensions.insert(StorageState::Unsealed);
257        extensions.insert(ns);
258
259        let renew_path = le.renew_path.as_ref().map_or_else(
260            || "sys/token/renew".into(),
261            |renew_path| format!("{}{renew_path}", le.issued_mount_path),
262        );
263
264        let data = RenewLeaseParams {
265            ttl: ttl
266                .to_std()
267                .map_err(|_| ErrorType::BadRequest("Bad renew TTL".into()))?,
268            data: le.renew_data.clone(),
269        };
270        let data = serde_json::to_vec(&data)
271            .map_err(|_| ErrorType::BadRequest("Bad renew payload".into()))?;
272
273        let req = Request {
274            id: Uuid::default(),
275            operation: Operation::Renew,
276            namespace: ns_path.split('/').map(From::from).collect(),
277            path: renew_path,
278            data: data.into(),
279            extensions,
280            token: None,
281            params: Vec::default(),
282            query_string: String::default(),
283            headers: HashMap::default(),
284        };
285
286        let router = Arc::clone(&self.router);
287
288        let renew_path = req.path.clone();
289        match router.route(req).await {
290            Ok(resp) => {
291                let resp = resp.response.data::<RenewLeaseResponse>().map_err(|_| {
292                    ErrorType::InternalError(anyhow::Error::msg("Unexpected renew response"))
293                })?;
294
295                let ttl = calculate_ttl(
296                    self.clock.now(),
297                    le.issued_at,
298                    &mount_config,
299                    Some(resp.ttl),
300                )
301                .map_err(|_| {
302                    ErrorType::InternalError(anyhow::Error::msg(
303                        "Failed to calculate TTL when renewing lease",
304                    ))
305                })?;
306
307                let now = self.clock.now();
308
309                le.expires_at = now + ttl;
310                le.last_renewal_time = now;
311                self.repos
312                    .lease
313                    .renew(
314                        lease_id,
315                        &le.namespace_id,
316                        le.expires_at,
317                        le.last_renewal_time,
318                    )
319                    .await?;
320
321                Ok(le)
322            }
323            Err(error) => {
324                tracing::error!(?error, lease_id, renew_path, "Unable to renew lease.");
325                Err(ErrorType::RenewLease {
326                    source: Box::new(error),
327                    lease_id: lease_id.to_string(),
328                }
329                .into())
330            }
331        }
332    }
333
334    /// Start the revocation worker.
335    #[tracing::instrument(skip(self), name = "start_expiration_manager")]
336    pub async fn start(&self) -> Result<(), Error> {
337        let mut shutdown_rx = self.shutdown_rx.write().await;
338
339        loop {
340            let now = self.clock.now();
341            #[allow(clippy::cast_possible_truncation)]
342            let leases = match self
343                .repos
344                .lease
345                .pull(self.revocation_worker_concurrency as u32, now)
346                .await
347            {
348                Ok(leases) => leases,
349                Err(error) => {
350                    error!(?error, "Failed to pull leases for revocation");
351                    tokio::time::sleep(std::time::Duration::from_millis(500)).await;
352                    continue;
353                }
354            };
355
356            let number_of_leases = leases.len();
357            debug!("Fetched {} leases ready for revocation", number_of_leases);
358            if number_of_leases == 0 {
359                // TODO: this might need more care to ensure no leases are lost
360                let next_lease_fut = self
361                    .repos
362                    .lease
363                    .peek()
364                    .await?
365                    .map(|le| le.expires_at - self.clock.now())
366                    .and_then(|duration| duration.to_std().ok())
367                    .map_or_else::<Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>, _, _>(
368                        || Box::pin(std::future::pending()),
369                        |duration| self.clock.sleep(duration),
370                    );
371
372                tokio::select! {
373                        // If new lease is registered
374                        _ = self.background_task.notified() => {
375                            continue;
376                        }
377                        // Future that resolves when the next lease is ready
378                        // to be revoked
379                        _ = next_lease_fut => {
380                            continue;
381                        }
382                        // Break loop on shutdown signal
383                        _ = shutdown_rx.recv() => {
384                            break;
385                        }
386                }
387            }
388
389            futures::stream::iter(leases)
390                .for_each_concurrent(self.revocation_worker_concurrency, |le| async move {
391                    // Errors are handled by this function, no more logging
392                    // or error handling is required at this point.
393                    let _ = self.revoke_lease_entry(&le).await;
394                })
395                .await;
396        }
397
398        info!("Expiration manager shutting down");
399        Ok(())
400    }
401
402    /// Perform revocation of the [`LeaseEntry`].
403    #[tracing::instrument(skip_all, fields(lease_id = le.id, mount_path = le.issued_mount_path))]
404    async fn revoke_lease_entry(&self, le: &LeaseEntry) -> Result<(), Error> {
405        let res = self.send_lease_revoke_request(le).await;
406        match res {
407            Ok(_) => {
408                self.repos
409                    .lease
410                    .delete(&le.id, &le.namespace_id)
411                    .await
412                    .map_err(|error| {
413                        // **NOTE**: This means that revoke endpoints should be idempotent as
414                        // this will trigger a new revoke request to be sent even though
415                        // the lease was just revoked from the backend
416                        error!(?error, "Failed to delete lease from the lease store");
417                        error
418                    })
419                    .map(|_| ())
420            }
421            Err(error) => {
422                error!(?error, "failed to revoke lease entry from backend");
423                // TODO: why is +1 needed here
424                if le.failed_revocation_attempts + 1 >= self.revocation_max_retries {
425                    // Delete from store
426                    if let Err(error) = self.repos.lease.delete(&le.id, &le.namespace_id).await {
427                        error!(?error, "failed to delete lease from store that has passed max number of revocation retries");
428                    };
429                } else {
430                    // Increase failed count
431                    if let Err(error) = self
432                        .repos
433                        .lease
434                        .increment_failed_revocation_attempts(
435                            &le.id,
436                            &le.namespace_id,
437                            // TODO: exp backoff and configure revocation_retry_timeout
438                            le.expires_at + self.revocation_retry_timeout,
439                        )
440                        .await
441                    {
442                        error!(?error, "failed to delete lease from store that has passed max number of revocation retries");
443                    }
444                }
445                Err(ErrorType::InternalError(error.into()).into())
446            }
447        }
448    }
449
450    /// Shutdown the expiration manager.
451    #[tracing::instrument(skip(self), name = "stop_expiration_manager")]
452    pub async fn stop(&self) {
453        // TODO: wait for expiration manager to shutdown fully.
454        let _ = self.shutdown_tx.send(()).await;
455    }
456}
457
458#[cfg(test)]
459mod tests {
460    use chrono::{DateTime, Utc};
461    use covert_framework::{Backend, SyncService};
462    use covert_types::{
463        backend::{BackendCategory, BackendType},
464        mount::{MountConfig, MountEntry},
465        response::Response,
466    };
467    use sqlx::SqlitePool;
468
469    use crate::{
470        expiration_manager::clock::test::TestClock,
471        repos::{mount::tests::pool, namespace::Namespace},
472        system::SYSTEM_MOUNT_PATH,
473    };
474
475    use super::*;
476
477    async fn secret_engine_handle(
478        req: Request,
479        recorder: Arc<RequestRecorder>,
480        renew_ttl: Option<std::time::Duration>,
481        clock: TestClock,
482    ) -> Result<Response, ApiError> {
483        let mut requests = recorder.0.write().await;
484        requests.push(RequestInfo {
485            path: req.path.clone(),
486            operation: req.operation,
487            reveived_at: Some(clock.now()),
488        });
489        drop(requests);
490
491        if req.path == "creds" {
492            match req.operation {
493                Operation::Revoke => Ok(Response::ok()),
494                Operation::Renew => {
495                    let data = RenewLeaseResponse {
496                        ttl: renew_ttl.unwrap(),
497                    };
498                    Ok(Response::Raw(serde_json::to_value(data).unwrap()))
499                }
500                _ => Err(ApiError::not_found()),
501            }
502        } else if req.path == "creds-slow" {
503            tokio::time::sleep(std::time::Duration::from_millis(200)).await;
504            match req.operation {
505                Operation::Revoke => Ok(Response::ok()),
506                Operation::Renew => {
507                    let data = RenewLeaseResponse {
508                        ttl: renew_ttl.unwrap(),
509                    };
510                    Ok(Response::Raw(serde_json::to_value(data).unwrap()))
511                }
512                _ => Err(ApiError::not_found()),
513            }
514        } else {
515            Err(ApiError::not_found())
516        }
517    }
518
519    async fn system_handle(
520        req: Request,
521        recorder: Arc<RequestRecorder>,
522        renew_ttl: Option<std::time::Duration>,
523        clock: TestClock,
524    ) -> Result<Response, ApiError> {
525        let mut requests = recorder.0.write().await;
526        requests.push(RequestInfo {
527            path: req.path.clone(),
528            operation: req.operation,
529            reveived_at: Some(clock.now()),
530        });
531        drop(requests);
532
533        if req.path == "token/revoke" {
534            match req.operation {
535                Operation::Revoke => Ok(Response::ok()),
536                _ => Err(ApiError::not_found()),
537            }
538        } else if req.path == "token/renew" {
539            let data = RenewLeaseResponse {
540                ttl: renew_ttl.unwrap(),
541            };
542            match req.operation {
543                Operation::Renew => Ok(Response::Raw(serde_json::to_value(data).unwrap())),
544                _ => Err(ApiError::not_found()),
545            }
546        } else {
547            Err(ApiError::not_found())
548        }
549    }
550
551    #[derive(Debug, Clone)]
552    pub struct RequestInfo {
553        pub path: String,
554        pub operation: Operation,
555        pub reveived_at: Option<DateTime<Utc>>,
556    }
557
558    impl PartialEq for RequestInfo {
559        fn eq(&self, other: &Self) -> bool {
560            if self.path != other.path {
561                return false;
562            }
563
564            if self.operation != other.operation {
565                return false;
566            }
567
568            let Some(received_at) = self.reveived_at else {
569                return true;
570            };
571            let Some(other_received_at) = other.reveived_at else {
572                return true;
573            };
574
575            received_at == other_received_at
576        }
577    }
578
579    pub struct RequestRecorder(RwLock<Vec<RequestInfo>>);
580
581    async fn advance(clock: &TestClock, duration: Duration) {
582        clock.advance(duration.num_milliseconds());
583        // Yield and give some time for expiration manager to wake up and revoke
584        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
585    }
586
587    async fn advance_to(clock: &TestClock, duration: DateTime<Utc>) {
588        clock.set(duration.timestamp_millis());
589        // Yield and give some time for expiration manager to wake up and revoke
590        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
591    }
592
593    #[tokio::test]
594    async fn revoke_secret_after_ttl_expires() {
595        let recorder = Arc::new(RequestRecorder(RwLock::new(Vec::new())));
596        let clock = TestClock::new();
597
598        let pool = Arc::new(pool().await);
599        let u_pool = SqlitePool::connect(":memory:").await.unwrap();
600        let repos = Repos::new(pool, u_pool);
601
602        let ns = Namespace {
603            id: Uuid::new_v4().to_string(),
604            name: "root".to_string(),
605            parent_namespace_id: None,
606        };
607        repos.namespace.create(&ns).await.unwrap();
608
609        let router = Arc::new(Router::new(repos.mount.clone()));
610        let exp_m = Arc::new(ExpirationManager::new(
611            Arc::clone(&router),
612            repos.clone(),
613            clock.clone(),
614        ));
615
616        let expiration_manager = Arc::clone(&exp_m);
617        tokio::spawn(async move {
618            expiration_manager.start().await.unwrap();
619        });
620        tokio::task::yield_now().await;
621
622        // Setup mount
623        let me = MountEntry {
624            id: Uuid::new_v4(),
625            backend_type: BackendType::Postgres,
626            config: MountConfig::default(),
627            path: "foo/".to_string(),
628            namespace_id: ns.id.clone(),
629        };
630        repos.mount.create(&me).await.unwrap();
631
632        let recorder_moved = Arc::clone(&recorder);
633        let clock_moved = clock.clone();
634        let handler = SyncService::new(tower::service_fn(move |req| {
635            let recorder = Arc::clone(&recorder_moved);
636            let clock = clock_moved.clone();
637            async move { secret_engine_handle(req, recorder, None, clock).await }
638        }));
639        let backend = Arc::new(Backend {
640            category: BackendCategory::Logical,
641            migrations: vec![],
642            variant: me.backend_type,
643            handler,
644        });
645
646        router.mount(me.id, Arc::clone(&backend));
647
648        let ttl = Duration::hours(4);
649        let le = LeaseEntry::new(
650            me.path.clone(),
651            Some("creds".into()),
652            &(),
653            Some("creds".into()),
654            &(),
655            clock.now(),
656            ttl,
657            ns.id.clone(),
658        )
659        .unwrap();
660
661        assert!(exp_m.register(le.clone()).await.is_ok());
662        let leases = repos.lease.list().await.unwrap();
663        assert_eq!(leases, vec![le.clone()]);
664        let next_lease = repos.lease.peek().await.unwrap();
665        assert_eq!(next_lease, Some(le.clone()));
666
667        // Wait ttl - 1 hours and it should still be there
668        advance_to(&clock, le.expires_at - Duration::hours(1)).await;
669        let leases = repos.lease.list().await.unwrap();
670        assert_eq!(leases, vec![le.clone()]);
671
672        // Go to expire time
673        advance_to(&clock, le.expires_at).await;
674
675        let requests = recorder.0.read().await;
676        assert_eq!(
677            *requests,
678            vec![RequestInfo {
679                path: "creds".into(),
680                operation: Operation::Revoke,
681                reveived_at: Some(clock.now())
682            }]
683        );
684
685        let leases = repos.lease.list().await.unwrap();
686        assert_eq!(leases, vec![]);
687    }
688
689    #[tokio::test]
690    async fn revoke_token_after_ttl_expires() {
691        let clock = TestClock::new();
692        let recorder = Arc::new(RequestRecorder(RwLock::new(Vec::new())));
693
694        let pool = Arc::new(pool().await);
695        let u_pool = SqlitePool::connect(":memory:").await.unwrap();
696        let repos = Repos::new(pool, u_pool);
697
698        let ns = Namespace {
699            id: Uuid::new_v4().to_string(),
700            name: "root".to_string(),
701            parent_namespace_id: None,
702        };
703        repos.namespace.create(&ns).await.unwrap();
704
705        let router = Arc::new(Router::new(repos.mount.clone()));
706        let exp_m = Arc::new(ExpirationManager::new(
707            Arc::clone(&router),
708            repos.clone(),
709            clock.clone(),
710        ));
711
712        let expiration_manager = Arc::clone(&exp_m);
713        tokio::spawn(async move {
714            expiration_manager.start().await.unwrap();
715        });
716        tokio::task::yield_now().await;
717
718        // Setup system mount
719        let me = MountEntry {
720            id: Uuid::new_v4(),
721            backend_type: BackendType::System,
722            config: MountConfig::default(),
723            path: SYSTEM_MOUNT_PATH.to_string(),
724            namespace_id: ns.id.clone(),
725        };
726        repos.mount.create(&me).await.unwrap();
727
728        let recorder_moved = Arc::clone(&recorder);
729        let clock_moved = clock.clone();
730        let handler = SyncService::new(tower::service_fn(move |req| {
731            let recorder = Arc::clone(&recorder_moved);
732            let clock = clock_moved.clone();
733            async move { system_handle(req, recorder, None, clock).await }
734        }));
735        let backend = Arc::new(Backend {
736            category: BackendCategory::Logical,
737            migrations: vec![],
738            variant: me.backend_type,
739            handler,
740        });
741        router.mount_system(backend);
742
743        let ttl = Duration::hours(4);
744        let le = LeaseEntry::new(
745            me.path.clone(),
746            None,
747            &(),
748            None,
749            &(),
750            clock.now(),
751            ttl,
752            ns.id.clone(),
753        )
754        .unwrap();
755
756        assert!(exp_m.register(le.clone()).await.is_ok());
757        let leases = repos.lease.list().await.unwrap();
758        assert_eq!(leases, vec![le.clone()]);
759
760        // Wait ttl - 1 hours and it should still be there
761        advance_to(&clock, le.expires_at - Duration::hours(1)).await;
762        let leases = repos.lease.list().await.unwrap();
763        assert_eq!(leases, vec![le.clone()]);
764
765        // Go to revocation time
766        advance_to(&clock, le.expires_at).await;
767
768        let requests = recorder.0.read().await;
769        assert_eq!(
770            *requests,
771            vec![RequestInfo {
772                path: "token/revoke".into(),
773                operation: Operation::Revoke,
774                reveived_at: Some(le.expires_at)
775            }]
776        );
777
778        let leases = repos.lease.list().await.unwrap();
779        assert_eq!(leases, vec![]);
780    }
781
782    #[tokio::test]
783    async fn revoke_before_ttl_expires() {
784        let clock = TestClock::new();
785        let recorder = Arc::new(RequestRecorder(RwLock::new(Vec::new())));
786
787        let pool = Arc::new(pool().await);
788        let u_pool = SqlitePool::connect(":memory:").await.unwrap();
789        let repos = Repos::new(pool, u_pool);
790
791        let ns = Namespace {
792            id: Uuid::new_v4().to_string(),
793            name: "root".to_string(),
794            parent_namespace_id: None,
795        };
796        repos.namespace.create(&ns).await.unwrap();
797
798        let router = Arc::new(Router::new(repos.mount.clone()));
799        let exp_m = Arc::new(ExpirationManager::new(
800            Arc::clone(&router),
801            repos.clone(),
802            clock.clone(),
803        ));
804
805        let expiration_manager = Arc::clone(&exp_m);
806        tokio::spawn(async move {
807            expiration_manager.start().await.unwrap();
808        });
809        tokio::task::yield_now().await;
810
811        // Setup system mount
812        let me = MountEntry {
813            id: Uuid::new_v4(),
814            backend_type: BackendType::System,
815            config: MountConfig::default(),
816            path: SYSTEM_MOUNT_PATH.to_string(),
817            namespace_id: ns.id.clone(),
818        };
819        repos.mount.create(&me).await.unwrap();
820
821        let recorder_moved = Arc::clone(&recorder);
822        let clock_moved = clock.clone();
823        let handler = SyncService::new(tower::service_fn(move |req| {
824            let recorder = Arc::clone(&recorder_moved);
825            let clock = clock_moved.clone();
826            async move { system_handle(req, recorder, None, clock).await }
827        }));
828        let backend = Arc::new(Backend {
829            category: BackendCategory::Logical,
830            migrations: vec![],
831            variant: me.backend_type,
832            handler,
833        });
834        router.mount_system(backend);
835
836        let ttl = Duration::hours(4);
837        let le = LeaseEntry::new(
838            me.path.clone(),
839            None,
840            &(),
841            None,
842            &(),
843            clock.now(),
844            ttl,
845            ns.id.clone(),
846        )
847        .unwrap();
848
849        assert!(exp_m.register(le.clone()).await.is_ok());
850        let leases = repos.lease.list().await.unwrap();
851        assert_eq!(leases, vec![le.clone()]);
852        assert!(exp_m
853            .revoke_lease_entry_by_id(le.id(), &le.namespace_id)
854            .await
855            .is_ok());
856        let leases = repos.lease.list().await.unwrap();
857        assert_eq!(leases, vec![]);
858
859        advance_to(&clock, le.expires_at).await;
860
861        let requests = recorder.0.read().await;
862        assert_eq!(
863            *requests,
864            vec![RequestInfo {
865                path: "token/revoke".into(),
866                operation: Operation::Revoke,
867                reveived_at: None
868            }]
869        );
870
871        // Sanity test that leases is still empty
872        let leases = repos.lease.list().await.unwrap();
873        assert_eq!(leases, vec![]);
874    }
875
876    #[tokio::test]
877    #[allow(clippy::too_many_lines)]
878    async fn renew() {
879        let clock = TestClock::new();
880        let recorder = Arc::new(RequestRecorder(RwLock::new(Vec::new())));
881
882        let pool = Arc::new(pool().await);
883        let u_pool = SqlitePool::connect(":memory:").await.unwrap();
884        let repos = Repos::new(pool, u_pool);
885
886        let ns = Namespace {
887            id: Uuid::new_v4().to_string(),
888            name: "root".to_string(),
889            parent_namespace_id: None,
890        };
891        repos.namespace.create(&ns).await.unwrap();
892
893        let router = Arc::new(Router::new(repos.mount.clone()));
894        let exp_m = Arc::new(ExpirationManager::new(
895            Arc::clone(&router),
896            repos.clone(),
897            clock.clone(),
898        ));
899
900        let expiration_manager = Arc::clone(&exp_m);
901        tokio::spawn(async move {
902            expiration_manager.start().await.unwrap();
903        });
904        tokio::task::yield_now().await;
905
906        let mount_config = MountConfig {
907            max_lease_ttl: std::time::Duration::from_secs(3600 * 24),
908            ..Default::default()
909        };
910        let me = MountEntry {
911            id: Uuid::new_v4(),
912            backend_type: BackendType::Postgres,
913            config: mount_config,
914            path: "psql/".into(),
915            namespace_id: ns.id.clone(),
916        };
917        repos.mount.create(&me).await.unwrap();
918
919        let renew_ttl = Duration::hours(2);
920
921        let recorder_moved = Arc::clone(&recorder);
922        let clock_moved = clock.clone();
923        let handler = SyncService::new(tower::service_fn(move |req| {
924            let recorder = Arc::clone(&recorder_moved);
925            let clock = clock_moved.clone();
926            async move {
927                secret_engine_handle(req, recorder, Some(renew_ttl.to_std().unwrap()), clock).await
928            }
929        }));
930        let backend = Arc::new(Backend {
931            category: BackendCategory::Logical,
932            migrations: vec![],
933            variant: me.backend_type,
934            handler,
935        });
936
937        router.mount(me.id, Arc::clone(&backend));
938
939        let ttl = Duration::hours(4);
940        let le = LeaseEntry::new(
941            me.path.clone(),
942            Some("creds".into()),
943            &(),
944            Some("creds".into()),
945            &(),
946            clock.now(),
947            ttl,
948            ns.id.clone(),
949        )
950        .unwrap();
951
952        assert!(exp_m.register(le.clone()).await.is_ok());
953        let leases = repos.lease.list().await.unwrap();
954        assert_eq!(leases, vec![le.clone()]);
955
956        // 1 hour before it expires. Lets renew!
957        advance_to(&clock, le.expires_at - Duration::hours(1)).await;
958
959        // Renew
960        let new_le = exp_m
961            .renew_lease_entry(le.id(), &le.namespace_id, None)
962            .await
963            .unwrap();
964        let leases = repos.lease.list().await.unwrap();
965        assert_eq!(leases, vec![le.clone()]);
966        let new_expire_time = new_le.expires_at;
967
968        // Advance 1 hours until the original revocation time.
969        advance_to(&clock, le.expires_at).await;
970
971        // Still not revoked
972        let requests = recorder.0.read().await;
973        assert_eq!(
974            *requests,
975            vec![RequestInfo {
976                path: "creds".into(),
977                operation: Operation::Renew,
978                reveived_at: None
979            }]
980        );
981        drop(requests);
982
983        // Advance until the new revocation time.
984        advance_to(&clock, new_expire_time).await;
985
986        // Now it should be revoked
987        let requests = recorder.0.read().await;
988        assert_eq!(
989            *requests,
990            vec![
991                RequestInfo {
992                    path: "creds".into(),
993                    operation: Operation::Renew,
994                    reveived_at: None
995                },
996                RequestInfo {
997                    path: "creds".into(),
998                    operation: Operation::Revoke,
999                    reveived_at: Some(new_expire_time)
1000                }
1001            ]
1002        );
1003        drop(requests);
1004
1005        // Sanity test that leases is still empty
1006        let leases = repos.lease.list().await.unwrap();
1007        assert_eq!(leases, vec![]);
1008    }
1009
1010    #[tokio::test]
1011    async fn retry_failed_revocation() {
1012        let clock = TestClock::new();
1013        let recorder = Arc::new(RequestRecorder(RwLock::new(Vec::new())));
1014
1015        let pool = Arc::new(pool().await);
1016        let u_pool = SqlitePool::connect(":memory:").await.unwrap();
1017        let repos = Repos::new(pool, u_pool);
1018
1019        let ns = Namespace {
1020            id: Uuid::new_v4().to_string(),
1021            name: "root".to_string(),
1022            parent_namespace_id: None,
1023        };
1024        repos.namespace.create(&ns).await.unwrap();
1025
1026        let router = Arc::new(Router::new(repos.mount.clone()));
1027        let mut exp_m = ExpirationManager::new(Arc::clone(&router), repos.clone(), clock.clone());
1028        exp_m.revocation_retry_timeout = Duration::milliseconds(10);
1029        exp_m.revocation_max_retries = 5;
1030        let exp_m = Arc::new(exp_m);
1031
1032        let expiration_manager = Arc::clone(&exp_m);
1033        tokio::spawn(async move {
1034            expiration_manager.start().await.unwrap();
1035        });
1036        tokio::task::yield_now().await;
1037
1038        let me = MountEntry {
1039            id: Uuid::new_v4(),
1040            backend_type: BackendType::Postgres,
1041            config: MountConfig::default(),
1042            path: "psql/".into(),
1043            namespace_id: ns.id.clone(),
1044        };
1045        repos.mount.create(&me).await.unwrap();
1046
1047        let recorder_moved = Arc::clone(&recorder);
1048        let clock_moved = clock.clone();
1049        let handler = SyncService::new(tower::service_fn(move |req| {
1050            let recorder = Arc::clone(&recorder_moved);
1051            let clock = clock_moved.clone();
1052            async move { secret_engine_handle(req, recorder, None, clock).await }
1053        }));
1054        let backend = Arc::new(Backend {
1055            category: BackendCategory::Logical,
1056            migrations: vec![],
1057            variant: me.backend_type,
1058            handler,
1059        });
1060
1061        router.mount(me.id, Arc::clone(&backend));
1062
1063        let ttl = Duration::hours(4);
1064        let le = LeaseEntry::new(
1065            me.path.clone(),
1066            // This will ensure revocation returns 404 error
1067            Some("invalid-revoke-path".into()),
1068            &(),
1069            Some("creds".into()),
1070            &(),
1071            clock.now(),
1072            ttl,
1073            ns.id.clone(),
1074        )
1075        .unwrap();
1076
1077        assert!(exp_m.register(le.clone()).await.is_ok());
1078
1079        // Wait until revocation time
1080        advance_to(&clock, le.expires_at).await;
1081
1082        for _ in 0..exp_m.revocation_max_retries * 2 {
1083            advance(&clock, exp_m.revocation_retry_timeout).await;
1084        }
1085
1086        let requests = recorder.0.read().await;
1087        assert_eq!(requests.len(), exp_m.revocation_max_retries as usize);
1088        drop(requests);
1089
1090        // Lease should be deleted
1091        let leases = repos.lease.list().await.unwrap();
1092        assert_eq!(leases, vec![]);
1093    }
1094
1095    #[tokio::test]
1096    async fn revoke_for_mount() {
1097        let clock = TestClock::new();
1098        let recorder = Arc::new(RequestRecorder(RwLock::new(Vec::new())));
1099
1100        let pool = Arc::new(pool().await);
1101        let u_pool = SqlitePool::connect(":memory:").await.unwrap();
1102        let repos = Repos::new(pool, u_pool);
1103
1104        let ns = Namespace {
1105            id: Uuid::new_v4().to_string(),
1106            name: "root".to_string(),
1107            parent_namespace_id: None,
1108        };
1109        repos.namespace.create(&ns).await.unwrap();
1110
1111        let router = Arc::new(Router::new(repos.mount.clone()));
1112        let exp_m = Arc::new(ExpirationManager::new(
1113            Arc::clone(&router),
1114            repos.clone(),
1115            clock.clone(),
1116        ));
1117
1118        let expiration_manager = Arc::clone(&exp_m);
1119        tokio::spawn(async move {
1120            expiration_manager.start().await.unwrap();
1121        });
1122        tokio::task::yield_now().await;
1123
1124        let me = MountEntry {
1125            id: Uuid::new_v4(),
1126            backend_type: BackendType::Postgres,
1127            config: MountConfig::default(),
1128            path: "psql/".into(),
1129            namespace_id: ns.id.clone(),
1130        };
1131        repos.mount.create(&me).await.unwrap();
1132
1133        let recorder_moved = Arc::clone(&recorder);
1134        let clock_moved = clock.clone();
1135        let handler = SyncService::new(tower::service_fn(move |req| {
1136            let recorder = Arc::clone(&recorder_moved);
1137            let clock = clock_moved.clone();
1138            async move { secret_engine_handle(req, recorder, None, clock).await }
1139        }));
1140        let backend = Arc::new(Backend {
1141            category: BackendCategory::Logical,
1142            migrations: vec![],
1143            variant: me.backend_type,
1144            handler,
1145        });
1146
1147        router.mount(me.id, Arc::clone(&backend));
1148
1149        let ttl = Duration::hours(4);
1150        let lease_count = 50;
1151
1152        for _ in 0..lease_count {
1153            let le = LeaseEntry::new(
1154                me.path.clone(),
1155                Some("creds".into()),
1156                &(),
1157                Some("creds".into()),
1158                &(),
1159                clock.now(),
1160                ttl,
1161                ns.id.clone(),
1162            )
1163            .unwrap();
1164            assert!(exp_m.register(le.clone()).await.is_ok());
1165        }
1166
1167        assert_eq!(
1168            exp_m
1169                .revoke_leases_by_mount_prefix(&me.path, &ns.id)
1170                .await
1171                .unwrap()
1172                .len(),
1173            lease_count
1174        );
1175
1176        // Leases should be empty
1177        let leases = repos.lease.list().await.unwrap();
1178        assert_eq!(leases, vec![]);
1179
1180        // Number of leases revoked == number of requests
1181        let requests = recorder.0.read().await;
1182        assert_eq!(requests.len(), lease_count);
1183        drop(requests);
1184
1185        // Go to revocation time
1186        advance(&clock, ttl).await;
1187
1188        // No new requests has been sent
1189        let requests = recorder.0.read().await;
1190        assert_eq!(requests.len(), lease_count);
1191        drop(requests);
1192    }
1193
1194    #[tokio::test]
1195    async fn slow_revoke_endpoint_does_not_halt_other_revocations() {
1196        let clock = TestClock::new();
1197        let recorder = Arc::new(RequestRecorder(RwLock::new(Vec::new())));
1198
1199        let pool = Arc::new(pool().await);
1200        let u_pool = SqlitePool::connect(":memory:").await.unwrap();
1201        let repos = Repos::new(pool, u_pool);
1202
1203        let ns = Namespace {
1204            id: Uuid::new_v4().to_string(),
1205            name: "root".to_string(),
1206            parent_namespace_id: None,
1207        };
1208        repos.namespace.create(&ns).await.unwrap();
1209
1210        let router = Arc::new(Router::new(repos.mount.clone()));
1211        let mut exp_m = ExpirationManager::new(Arc::clone(&router), repos.clone(), clock.clone());
1212        exp_m.revocation_max_retries = 3;
1213        exp_m.revocation_timeout = std::time::Duration::from_millis(10);
1214        let exp_m = Arc::new(exp_m);
1215
1216        let expiration_manager = Arc::clone(&exp_m);
1217        tokio::spawn(async move {
1218            expiration_manager.start().await.unwrap();
1219        });
1220        tokio::task::yield_now().await;
1221
1222        let me = MountEntry {
1223            id: Uuid::new_v4(),
1224            backend_type: BackendType::Postgres,
1225            config: MountConfig::default(),
1226            path: "psql/".into(),
1227            namespace_id: ns.id.clone(),
1228        };
1229        repos.mount.create(&me).await.unwrap();
1230
1231        let recorder_moved = Arc::clone(&recorder);
1232        let clock_moved = clock.clone();
1233        let handler = SyncService::new(tower::service_fn(move |req| {
1234            let recorder = Arc::clone(&recorder_moved);
1235            let clock = clock_moved.clone();
1236            async move { secret_engine_handle(req, recorder, None, clock).await }
1237        }));
1238        let backend = Arc::new(Backend {
1239            category: BackendCategory::Logical,
1240            migrations: vec![],
1241            variant: me.backend_type,
1242            handler,
1243        });
1244
1245        router.mount(me.id, Arc::clone(&backend));
1246
1247        let ttl = Duration::hours(4);
1248        let fast_lease_revocation_time = clock.now() + ttl;
1249        let lease_count = 5;
1250
1251        for i in 0..lease_count {
1252            let le = LeaseEntry::new(
1253                me.path.clone(),
1254                Some("creds".into()),
1255                &i,
1256                Some("creds".into()),
1257                &i,
1258                clock.now(),
1259                ttl,
1260                ns.id.clone(),
1261            )
1262            .unwrap();
1263            assert!(exp_m.register(le.clone()).await.is_ok());
1264        }
1265
1266        // Register lease that will be slow to revoke
1267        let slow_lease = LeaseEntry::new(
1268            me.path.clone(),
1269            Some("creds-slow".into()),
1270            &(),
1271            Some("creds".into()),
1272            &(),
1273            clock.now(),
1274            ttl - Duration::milliseconds(2),
1275            ns.id.clone(),
1276        )
1277        .unwrap();
1278        assert!(exp_m.register(slow_lease.clone()).await.is_ok());
1279
1280        advance_to(&clock, slow_lease.expires_at).await;
1281
1282        // All leases still stored
1283        let leases = repos.lease.list().await.unwrap();
1284        assert_eq!(leases.len(), lease_count + 1);
1285
1286        // Advance to time where fast leases will be revoked
1287        advance_to(&clock, fast_lease_revocation_time).await;
1288
1289        // All fast lease revocations are gone now
1290        let leases = repos.lease.list().await.unwrap();
1291        assert_eq!(leases.len(), 1);
1292    }
1293}