// cli_engine/auth/dispatcher.rs

use std::sync::{Arc, PoisonError, RwLock};

use async_trait::async_trait;

use super::{AuthProvider, Credential, CredentialRequest};
use crate::middleware::CommandMeta;
use crate::{CliCoreError, Result};

9/// Routes auth operations to registered providers by name.
10///
11/// Clones share the same provider registry, so provider facades and transport
12/// injectors see later registration or replacement.
13#[derive(Clone, Debug, Default)]
14pub struct Dispatcher {
15    inner: Arc<RwLock<DispatcherInner>>,
16}
17
18#[derive(Clone, Debug, Default)]
19struct DispatcherInner {
20    providers: Vec<(String, Arc<dyn AuthProvider>)>,
21}
22
23/// Status row produced while querying all providers.
24#[derive(Clone, Debug)]
25pub struct StatusEntry {
26    /// Provider name.
27    pub provider: String,
28    /// Environment name.
29    pub env: String,
30    /// Cached credential when status succeeded.
31    pub credential: Option<Credential>,
32    /// Status error text when status failed.
33    pub error: Option<String>,
34}
35
36impl Dispatcher {
37    /// Creates an empty dispatcher.
38    #[must_use]
39    pub fn new() -> Self {
40        Self {
41            inner: Arc::new(RwLock::new(DispatcherInner::default())),
42        }
43    }
44
45    /// Registers or replaces a provider under its [`AuthProvider::name`].
46    pub fn register(&mut self, provider: Arc<dyn AuthProvider>) {
47        let name = provider.name().to_owned();
48        let mut inner = self.write_inner();
49        if let Some((_, existing)) = inner
50            .providers
51            .iter_mut()
52            .find(|(existing_name, _)| existing_name == &name)
53        {
54            *existing = provider;
55            return;
56        }
57        inner.providers.push((name, provider));
58    }
59
60    /// Returns provider names in registration order.
61    #[must_use]
62    pub fn registered_names(&self) -> Vec<String> {
63        self.read_inner()
64            .providers
65            .iter()
66            .map(|(name, _)| name.clone())
67            .collect()
68    }
69
70    /// Gets a credential from a named provider.
71    pub async fn get_credential(
72        &self,
73        name: &str,
74        env: &str,
75        command: &str,
76        tier: &str,
77    ) -> Result<Credential> {
78        self.get(name)?.get_credential(env, command, tier).await
79    }
80
81    /// Gets a credential from a named provider, passing the command's full
82    /// [`CredentialRequest`] so metadata-aware providers (e.g. OAuth scope
83    /// step-up) can act on it.
84    pub async fn get_credential_for(
85        &self,
86        name: &str,
87        req: &CredentialRequest<'_>,
88    ) -> Result<Credential> {
89        self.get(name)?.get_credential_for(req).await
90    }
91
92    /// Clears any cached credential, ignoring logout failures, then authenticates.
93    pub async fn login(&self, name: &str, env: &str) -> Result<Credential> {
94        self.login_with_scopes(name, env, &[]).await
95    }
96
97    /// Like [`login`](Dispatcher::login), but requests `additional_scopes` on top
98    /// of the provider's defaults.
99    ///
100    /// The scopes are carried as [`CommandMeta::scopes`] on a synthesized
101    /// request; providers without scope support ignore them.
102    pub async fn login_with_scopes(
103        &self,
104        name: &str,
105        env: &str,
106        additional_scopes: &[String],
107    ) -> Result<Credential> {
108        let provider = self.get(name)?;
109        if let Err(err) = provider.logout(env).await {
110            tracing::debug!(provider = name, error = %err, "ignoring logout error before login");
111        }
112        let mut meta = CommandMeta::default();
113        meta.set_scopes(additional_scopes.to_vec());
114        let req = CredentialRequest::new(env, "", "", &meta);
115        provider.get_credential_for(&req).await
116    }
117
118    /// Gets cached credential status from a named provider.
119    pub async fn status(&self, name: &str, env: &str) -> Result<Credential> {
120        self.get(name)?.status(env).await
121    }
122
123    /// Clears cached credentials for a named provider and environment.
124    pub async fn logout(&self, name: &str, env: &str) -> Result<()> {
125        self.get(name)?.logout(env).await
126    }
127
128    /// Queries every provider for every cached environment it reports.
129    pub async fn all_statuses(&self) -> Vec<StatusEntry> {
130        let mut entries = Vec::new();
131        let providers = self.read_inner().providers.clone();
132        for (name, provider) in providers {
133            let Ok(envs) = provider.list_environments().await else {
134                continue;
135            };
136            for env in envs {
137                match provider.status(&env).await {
138                    Ok(credential) => entries.push(StatusEntry {
139                        provider: name.clone(),
140                        env,
141                        credential: Some(credential),
142                        error: None,
143                    }),
144                    Err(err) => entries.push(StatusEntry {
145                        provider: name.clone(),
146                        env,
147                        credential: None,
148                        error: Some(err.to_string()),
149                    }),
150                }
151            }
152        }
153        entries
154    }
155
156    /// Returns an auth-provider facade backed by this dispatcher.
157    #[must_use]
158    pub fn for_provider(&self, name: impl Into<String>) -> SingleProvider {
159        SingleProvider {
160            dispatcher: self.clone(),
161            name: name.into(),
162        }
163    }
164
165    fn get(&self, name: &str) -> Result<Arc<dyn AuthProvider>> {
166        self.read_inner()
167            .providers
168            .iter()
169            .find(|(existing_name, _)| existing_name == name)
170            .map(|(_, provider)| Arc::clone(provider))
171            .ok_or_else(|| CliCoreError::MissingAuthProvider(name.to_owned()))
172    }
173
174    fn read_inner(&self) -> std::sync::RwLockReadGuard<'_, DispatcherInner> {
175        match self.inner.read() {
176            Ok(guard) => guard,
177            Err(poisoned) => poisoned.into_inner(),
178        }
179    }
180
181    fn write_inner(&self) -> std::sync::RwLockWriteGuard<'_, DispatcherInner> {
182        match self.inner.write() {
183            Ok(guard) => guard,
184            Err(poisoned) => poisoned.into_inner(),
185        }
186    }
187}
188
189/// Single-provider facade over a shared [`Dispatcher`].
190#[derive(Clone, Debug)]
191pub struct SingleProvider {
192    dispatcher: Dispatcher,
193    name: String,
194}
195
196#[async_trait]
197impl AuthProvider for SingleProvider {
198    fn name(&self) -> &str {
199        &self.name
200    }
201
202    async fn get_credential(&self, env: &str, command: &str, tier: &str) -> Result<Credential> {
203        self.dispatcher
204            .get_credential(&self.name, env, command, tier)
205            .await
206    }
207
208    async fn get_credential_for(&self, req: &CredentialRequest<'_>) -> Result<Credential> {
209        self.dispatcher.get_credential_for(&self.name, req).await
210    }
211
212    async fn status(&self, env: &str) -> Result<Credential> {
213        self.dispatcher.status(&self.name, env).await
214    }
215
216    async fn logout(&self, env: &str) -> Result<()> {
217        self.dispatcher.logout(&self.name, env).await
218    }
219
220    async fn list_environments(&self) -> Result<Vec<String>> {
221        self.dispatcher.get(&self.name)?.list_environments().await
222    }
223}