1use std::time::{Duration, Instant};
12
13use serde::Deserialize;
14use thiserror::Error;
15
16use crate::Error;
17
// Declares the `TokenProvider` component. Its inner state holds the cache of
// tokens fetched so far, which starts out empty. Methods on `TokenProvider`
// access the inner state through `self.lock(|inner| ...)`.
component! {
    TokenProvider : TokenProviderInner {
        // Cached tokens, appended after each successful fetch.
        tokens: Vec<Token> = vec![],
    }
}
23
24#[derive(Debug, Error)]
25pub enum TokenError {
26 #[error("no tokens available")]
27 Empty,
28}
29
30impl From<TokenError> for Error {
31 fn from(err: TokenError) -> Self {
32 Error::unavailable(err)
33 }
34}
35
/// An access token together with the metadata needed to decide whether it can be reused.
#[derive(Debug, Clone)]
pub struct Token {
    /// The bearer credential itself.
    pub access_token: String,
    /// Lifetime reported by the server, counted from `timestamp`.
    pub expires_in: Duration,
    /// Token type string as reported by the server.
    pub token_type: String,
    /// Scopes this token was granted for.
    pub scopes: Vec<String>,
    /// Local moment at which the token was created.
    pub timestamp: Instant,
}
44
45#[derive(Deserialize)]
46#[serde(rename_all = "camelCase")]
47struct TokenData {
48 access_token: String,
49 expires_in: u64,
50 token_type: String,
51 scope: Vec<String>,
52}
53
54impl TokenProvider {
55 fn find_token(&self, scopes: Vec<&str>) -> Option<usize> {
56 self.lock(|inner| {
57 (0..inner.tokens.len()).find(|&i| inner.tokens[i].in_scopes(scopes.clone()))
58 })
59 }
60
61 pub async fn get_token(&self, scopes: &str) -> Result<Token, Error> {
66 let client_id = self.session().client_id();
67 self.get_token_with_client_id(scopes, &client_id).await
68 }
69
70 pub async fn get_token_with_client_id(
71 &self,
72 scopes: &str,
73 client_id: &str,
74 ) -> Result<Token, Error> {
75 if client_id.is_empty() {
76 return Err(Error::invalid_argument("Client ID cannot be empty"));
77 }
78
79 if let Some(index) = self.find_token(scopes.split(',').collect()) {
80 let cached_token = self.lock(|inner| inner.tokens[index].clone());
81 if cached_token.is_expired() {
82 self.lock(|inner| inner.tokens.remove(index));
83 } else {
84 return Ok(cached_token);
85 }
86 }
87
88 trace!(
89 "Requested token in scopes {:?} unavailable or expired, requesting new token.",
90 scopes
91 );
92
93 let query_uri = format!(
94 "hm://keymaster/token/authenticated?scope={}&client_id={}&device_id={}",
95 scopes,
96 client_id,
97 self.session().device_id(),
98 );
99 let request = self.session().mercury().get(query_uri)?;
100 let response = request.await?;
101 let data = response.payload.first().ok_or(TokenError::Empty)?.to_vec();
102 let token = Token::from_json(String::from_utf8(data)?)?;
103 trace!("Got token: {:#?}", token);
104 self.lock(|inner| inner.tokens.push(token.clone()));
105 Ok(token)
106 }
107}
108
109impl Token {
110 const EXPIRY_THRESHOLD: Duration = Duration::from_secs(10);
111
112 pub fn from_json(body: String) -> Result<Self, Error> {
113 let data: TokenData = serde_json::from_slice(body.as_ref())?;
114 Ok(Self {
115 access_token: data.access_token,
116 expires_in: Duration::from_secs(data.expires_in),
117 token_type: data.token_type,
118 scopes: data.scope,
119 timestamp: Instant::now(),
120 })
121 }
122
123 pub fn is_expired(&self) -> bool {
124 self.timestamp + (self.expires_in.saturating_sub(Self::EXPIRY_THRESHOLD)) < Instant::now()
125 }
126
127 pub fn in_scope(&self, scope: &str) -> bool {
128 for s in &self.scopes {
129 if *s == scope {
130 return true;
131 }
132 }
133 false
134 }
135
136 pub fn in_scopes(&self, scopes: Vec<&str>) -> bool {
137 for s in scopes {
138 if !self.in_scope(s) {
139 return false;
140 }
141 }
142 true
143 }
144}