oauth_device_flows/
token.rs

1//! Token management and refresh functionality
2
3use crate::{
4    config::DeviceFlowConfig,
5    error::{DeviceFlowError, Result},
6    provider::Provider,
7    types::{RefreshTokenRequest, TokenResponse},
8};
9use reqwest::Client;
10use secrecy::ExposeSecret;
11use std::time::Duration;
12use time::OffsetDateTime;
13use url::Url;
14
15/// Manages OAuth tokens including refresh functionality
16#[derive(Debug, Clone)]
17pub struct TokenManager {
18    /// The current token response
19    token: TokenResponse,
20
21    /// The OAuth provider
22    provider: Provider,
23
24    /// Configuration
25    config: DeviceFlowConfig,
26
27    /// HTTP client
28    client: Client,
29}
30
31impl TokenManager {
32    /// Create a new token manager
33    pub fn new(token: TokenResponse, provider: Provider, config: DeviceFlowConfig) -> Result<Self> {
34        let client = Self::build_client(&config)?;
35
36        Ok(Self {
37            token,
38            provider,
39            config,
40            client,
41        })
42    }
43
44    /// Create a token manager from an existing token (for deserialization)
45    pub fn from_token(token: TokenResponse) -> Self {
46        let config = DeviceFlowConfig::new();
47        let client = Self::build_client(&config).unwrap_or_default();
48
49        Self {
50            token,
51            provider: Provider::Microsoft, // Default, should be set properly
52            config,
53            client,
54        }
55    }
56
57    /// Get the current access token
58    pub fn access_token(&self) -> &str {
59        self.token.access_token()
60    }
61
62    /// Get the current token response
63    pub fn token(&self) -> &TokenResponse {
64        &self.token
65    }
66
67    /// Check if the token is expired
68    pub fn is_expired(&self) -> bool {
69        self.token.is_expired()
70    }
71
72    /// Check if the token will expire within the given duration
73    pub fn expires_within(&self, duration: Duration) -> bool {
74        self.token.expires_within(duration)
75    }
76
77    /// Get the remaining lifetime of the token
78    pub fn remaining_lifetime(&self) -> Option<Duration> {
79        self.token.remaining_lifetime()
80    }
81
82    /// Refresh the token if a refresh token is available
83    pub async fn refresh(&mut self) -> Result<()> {
84        let refresh_token = self
85            .token
86            .refresh_token()
87            .ok_or_else(|| DeviceFlowError::other("No refresh token available"))?;
88
89        let new_token = self.refresh_token(refresh_token).await?;
90        self.token = new_token;
91
92        Ok(())
93    }
94
95    /// Refresh the token and return the new token without updating the manager
96    pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse> {
97        let token_endpoint = if let Some(ref config) = self.config.generic_provider_config {
98            config.token_endpoint.clone()
99        } else {
100            Url::parse(self.provider.token_endpoint())
101                .map_err(|e| DeviceFlowError::other(format!("Invalid token endpoint: {e}")))?
102        };
103
104        let request = RefreshTokenRequest {
105            grant_type: "refresh_token".to_string(),
106            refresh_token: refresh_token.to_string(),
107            client_id: self.config.client_id.clone(),
108            scope: None, // Keep original scope
109        };
110
111        let mut req_builder = self.client.post(token_endpoint).form(&request);
112
113        // Add client secret if required
114        if let Some(ref client_secret) = self.config.client_secret {
115            req_builder = req_builder.form(&[("client_secret", client_secret.expose_secret())]);
116        }
117
118        // Add provider-specific headers
119        for (key, value) in self.provider.headers() {
120            req_builder = req_builder.header(key, value);
121        }
122
123        // Add additional headers
124        for (key, value) in &self.config.additional_headers {
125            req_builder = req_builder.header(key, value);
126        }
127
128        let response = req_builder.send().await?;
129
130        if !response.status().is_success() {
131            let error_text = response.text().await?;
132            return Err(DeviceFlowError::other(format!(
133                "Token refresh failed: {error_text}"
134            )));
135        }
136
137        let mut token_response: TokenResponse = response.json().await?;
138
139        // Update the issued_at timestamp
140        token_response.issued_at = OffsetDateTime::now_utc();
141
142        // If no new refresh token was provided, keep the old one
143        if token_response.refresh_token.is_none() {
144            token_response.refresh_token = self.token.refresh_token.clone();
145        }
146
147        Ok(token_response)
148    }
149
150    /// Get a valid access token, refreshing if necessary
151    pub async fn get_valid_token(&mut self) -> Result<&str> {
152        // Check if token is expired or will expire soon (within 5 minutes)
153        if self.expires_within(Duration::from_secs(300)) {
154            self.refresh().await?;
155        }
156
157        Ok(self.access_token())
158    }
159
160    /// Create an authorization header value
161    pub fn authorization_header(&self) -> String {
162        format!("{} {}", self.token.token_type, self.access_token())
163    }
164
165    /// Update the provider (useful when deserializing)
166    pub fn with_provider(mut self, provider: Provider) -> Self {
167        self.provider = provider;
168        self
169    }
170
171    /// Update the configuration (useful when deserializing)
172    pub fn with_config(mut self, config: DeviceFlowConfig) -> Result<Self> {
173        self.client = Self::build_client(&config)?;
174        self.config = config;
175        Ok(self)
176    }
177
178    /// Revoke the token (if supported by the provider)
179    pub async fn revoke(&self) -> Result<()> {
180        let revoke_endpoint = match self.provider {
181            Provider::Microsoft => "https://login.microsoftonline.com/common/oauth2/v2.0/logout",
182            Provider::Google => "https://oauth2.googleapis.com/revoke",
183            Provider::GitHub => {
184                return Err(DeviceFlowError::other(
185                    "GitHub does not support token revocation",
186                ))
187            }
188            Provider::GitLab => "https://gitlab.com/oauth/revoke",
189            Provider::Generic => {
190                return Err(DeviceFlowError::other(
191                    "Revocation not supported for generic provider",
192                ));
193            }
194        };
195
196        let revoke_url = Url::parse(revoke_endpoint)
197            .map_err(|e| DeviceFlowError::other(format!("Invalid revoke endpoint: {e}")))?;
198
199        let form_data = match self.provider {
200            Provider::Google => vec![("token", self.access_token())],
201            Provider::Microsoft => vec![("token", self.access_token())],
202            Provider::GitLab => vec![
203                ("token", self.access_token()),
204                ("client_id", &self.config.client_id),
205            ],
206            _ => {
207                return Err(DeviceFlowError::other(
208                    "Revocation not implemented for this provider",
209                ))
210            }
211        };
212
213        let response = self.client.post(revoke_url).form(&form_data).send().await?;
214
215        if !response.status().is_success() {
216            let error_text = response.text().await?;
217            return Err(DeviceFlowError::other(format!(
218                "Token revocation failed: {error_text}"
219            )));
220        }
221
222        Ok(())
223    }
224
225    /// Build HTTP client with configuration
226    fn build_client(config: &DeviceFlowConfig) -> Result<Client> {
227        let mut client_builder = Client::builder().timeout(config.request_timeout);
228
229        if let Some(ref user_agent) = config.user_agent {
230            client_builder = client_builder.user_agent(user_agent);
231        }
232
233        client_builder.build().map_err(DeviceFlowError::from)
234    }
235}