1use crate::errors::{ClientCredentialsError, SecureError};
2use crate::jwt::encode;
3use crate::secure::generate_secure_string;
4use cached::proc_macro::cached;
5use cached::TimedSizedCache;
6use chrono::{Duration, Utc};
7use openssl::rsa::Rsa;
8use reqwest::header;
9use serde::{Deserialize, Serialize};
10use tokio::time::sleep;
11
/// Total number of token requests `request_service_token_cached` makes against a
/// rate-limited endpoint before giving up with `RequestLimitReached`.
const AUTHORIZATION_TRIES: u8 = 3;
13
14#[derive(Debug, Deserialize, Serialize)]
26pub struct ClientCredentials {
27 iss: String, sub: String, aud: Vec<String>, iat: i64, exp: i64, jti: String, }
34
35impl ClientCredentials {
36 pub fn new(client_id: &str, aud: &str) -> Self {
39 Self {
40 iss: client_id.to_string(),
41 sub: client_id.to_string(),
42 aud: vec![aud.to_string()],
43 iat: Utc::now().timestamp(),
44 exp: (Utc::now() + Duration::minutes(300)).timestamp(),
45 jti: generate_secure_string(10),
46 }
47 }
48
49 pub fn build_signed_token(
51 &self,
52 kid: &str,
53 rsa_key_pair: Rsa<openssl::pkey::Private>,
54 ) -> Result<String, SecureError> {
55 encode(self, kid, rsa_key_pair)
56 }
57}
58
59#[derive(Debug, Deserialize, Serialize, Clone)]
60pub struct ClientAuthorizationRequest {
61 grant_type: &'static str,
62 client_assertion_type: &'static str,
63 scope: String,
64 client_assertion: String,
65}
66
67impl ClientAuthorizationRequest {
68 const GRANT_TYPE: &'static str = "client_credentials";
69 const CLIENT_ASSERTION_TYPE: &'static str =
70 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
71
72 fn new(scopes: &str, client_assertion: &str) -> Self {
77 Self {
78 grant_type: Self::GRANT_TYPE,
79 client_assertion_type: Self::CLIENT_ASSERTION_TYPE,
80 scope: scopes.to_string(),
81 client_assertion: client_assertion.to_string(),
82 }
83 }
84}
85
86#[derive(Debug, Deserialize, Serialize, Clone)]
87pub struct ClientAuthorizationResponse {
88 pub access_token: String,
89 pub token_type: String,
90 pub expires_in: i64,
91 pub scope: String,
92}
93
94#[cached(
104 result = true, sync_writes = "default", create = "{ TimedSizedCache::with_size_and_lifespan(100, std::time::Duration::from_secs(900)) }", ty = "TimedSizedCache<String, ClientAuthorizationResponse>",
108 convert = r#"{ format!("{}{:?}", client_id, scopes) }"#,
109)]
110pub async fn request_service_token_cached(
111 client_id: &str,
112 platform_token_url: &str,
113 scopes: &str,
114 kid: &str,
115 rsa_key_pair: Rsa<openssl::pkey::Private>,
116) -> Result<ClientAuthorizationResponse, ClientCredentialsError> {
117 let mut count = 0;
118 let credentials = ClientCredentials::new(client_id, platform_token_url);
119 let token = credentials
120 .build_signed_token(kid, rsa_key_pair)
121 .map_err(|e| ClientCredentialsError::RequestFailed(e.to_string()))?;
122 let mut last_error = String::new();
123 while count < AUTHORIZATION_TRIES {
124 match request_service_token(platform_token_url, &token, scopes).await {
125 Ok(response) => return Ok(response),
126 Err(ClientCredentialsError::RateLimited(e)) => {
127 last_error = e.to_string();
128 sleep(std::time::Duration::from_secs(1)).await;
130 }
131 Err(e) => return Err(e),
132 }
133 count += 1;
134 }
135 Err(ClientCredentialsError::RequestLimitReached(last_error))
136}
137
138pub async fn request_service_token(
139 platform_token_url: &str,
140 token: &str,
141 scopes: &str,
142) -> Result<ClientAuthorizationResponse, ClientCredentialsError> {
143 let params = ClientAuthorizationRequest::new(scopes, token);
144 let client = reqwest::Client::new();
145 let response = client
146 .post(platform_token_url)
147 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
148 .form(¶ms)
149 .send()
150 .await
151 .map_err(|e| ClientCredentialsError::RequestFailed(e.to_string()))?;
152 let status = response.status();
153 let body = response
154 .text()
155 .await
156 .map_err(|e| ClientCredentialsError::RequestFailed(e.to_string()))?;
157 if !status.is_success() {
158 if body.contains("rate limit") {
159 return Err(ClientCredentialsError::RateLimited(body));
160 }
161
162 return Err(ClientCredentialsError::RequestFailed(body));
163 }
164
165 let access_token_response: ClientAuthorizationResponse = serde_json::from_str(&body)
166 .map_err(|e| ClientCredentialsError::RequestFailed(e.to_string()))?;
167
168 Ok(access_token_response)
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use openssl::rsa::Rsa;

    /// Scopes requested (and echoed back) by every test.
    const SCOPES: &str = "https://purl.imsglobal.org/spec/lti-ags/scope/lineitem https://purl.imsglobal.org/spec/lti-ags/scope/result/read";
    /// Key id attached to every signed assertion.
    const KID: &str = "test_kid";

    /// Generates a throwaway 2048-bit RSA key pair for signing test assertions.
    fn test_key_pair() -> Rsa<openssl::pkey::Private> {
        Rsa::generate(2048).expect("failed to generate rsa key pair")
    }

    #[tokio::test]
    async fn test_request_service_token_success() {
        let mut server = mockito::Server::new_async().await;
        let platform_token_url = format!("{}/token", server.url());

        let canned = ClientAuthorizationResponse {
            access_token: "fake".to_string(),
            token_type: "Bearer".to_string(),
            expires_in: 3600,
            scope: SCOPES.to_string(),
        };
        let body = serde_json::to_string(&canned).expect("failed to serialize response");
        let mock = server
            .mock("POST", "/token")
            .with_status(200)
            .with_body(body)
            .create();

        let key_pair = test_key_pair();
        let outcome = request_service_token_cached(
            "test_client_id",
            &platform_token_url,
            SCOPES,
            KID,
            key_pair,
        )
        .await;

        mock.assert();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_request_service_token_rate_limited() {
        let mut server = mockito::Server::new_async().await;
        let platform_token_url = format!("{}/token", server.url());

        // Every attempt is rejected, so the endpoint must be hit exactly three times
        // (AUTHORIZATION_TRIES) before the client gives up.
        let mock = server
            .mock("POST", "/token")
            .with_status(429)
            .with_body("rate limit exceeded")
            .expect(3)
            .create();

        let key_pair = test_key_pair();
        let outcome = request_service_token_cached(
            "test_fail_client_id",
            &platform_token_url,
            SCOPES,
            KID,
            key_pair,
        )
        .await;

        mock.assert();
        assert!(matches!(
            outcome,
            Err(ClientCredentialsError::RequestLimitReached(_))
        ));
    }
}