atomic_lti/
client_credentials.rs

1use crate::errors::{ClientCredentialsError, SecureError};
2use crate::jwt::encode;
3use crate::secure::generate_secure_string;
4use cached::proc_macro::cached;
5use cached::TimedSizedCache;
6use chrono::{Duration, Utc};
7use openssl::rsa::Rsa;
8use reqwest::header;
9use serde::{Deserialize, Serialize};
10use tokio::time::sleep;
11
/// Total number of token requests `request_service_token_cached` will make while the
/// platform keeps answering with a rate-limit error.
const AUTHORIZATION_TRIES: u8 = 3;
13
14// https://www.imsglobal.org/spec/lti/v1p3/#token-endpoint-claim-and-services
15// When requesting an access token, the client assertion JWT iss and sub must both be the
16// OAuth 2 client_id of the tool as issued by the learning platform during registration.
17// Additional information:
18// https://www.imsglobal.org/spec/security/v1p0/#using-json-web-tokens-with-oauth-2-0-client-credentials-grant
19//
20// Example usage:
21// ```
22// let credentials = ClientCredentials::new(client_id, platform_token_url);
23// let token = credentials.build_signed_token(rsa_key_pair)?;
24// ```
25#[derive(Debug, Deserialize, Serialize)]
26pub struct ClientCredentials {
27  iss: String,      // A unique identifier for the entity that issued the JWT
28  sub: String,      // "client_id" of the OAuth Client
29  aud: Vec<String>, // Authorization server identifier. Usually the token endpoint URL of the authorization server
30  iat: i64,         // Timestamp for when the JWT was created
31  exp: i64, // Timestamp for when the JWT should be treated as having expired (after allowing a margin for clock skew)
32  jti: String, // A unique (potentially reusable) identifier for the token
33}
34
35impl ClientCredentials {
36  // client_id -The LTI tool's client_id as provided by the platform
37  // aud - Authorization server identifier. Usually the token endpoint URL of the authorization server
38  pub fn new(client_id: &str, aud: &str) -> Self {
39    Self {
40      iss: client_id.to_string(),
41      sub: client_id.to_string(),
42      aud: vec![aud.to_string()],
43      iat: Utc::now().timestamp(),
44      exp: (Utc::now() + Duration::minutes(300)).timestamp(),
45      jti: generate_secure_string(10),
46    }
47  }
48
49  // Generate a signed JWT token using the provided RSA key pair
50  pub fn build_signed_token(
51    &self,
52    kid: &str,
53    rsa_key_pair: Rsa<openssl::pkey::Private>,
54  ) -> Result<String, SecureError> {
55    encode(self, kid, rsa_key_pair)
56  }
57}
58
59#[derive(Debug, Deserialize, Serialize, Clone)]
60pub struct ClientAuthorizationRequest {
61  grant_type: &'static str,
62  client_assertion_type: &'static str,
63  scope: String,
64  client_assertion: String,
65}
66
67impl ClientAuthorizationRequest {
68  const GRANT_TYPE: &'static str = "client_credentials";
69  const CLIENT_ASSERTION_TYPE: &'static str =
70    "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
71
72  // Build a new ClientAuthorizationRequest.
73  // scopes - a list of scopes to request access to each separated by a space. For example:
74  //   "https://purl.imsglobal.org/spec/lti-ags/scope/lineitem https://purl.imsglobal.org/spec/lti-ags/scope/result/read"
75  // client_assertion - a signed JWT token generated by the ClientCredentials struct
76  fn new(scopes: &str, client_assertion: &str) -> Self {
77    Self {
78      grant_type: Self::GRANT_TYPE,
79      client_assertion_type: Self::CLIENT_ASSERTION_TYPE,
80      scope: scopes.to_string(),
81      client_assertion: client_assertion.to_string(),
82    }
83  }
84}
85
86#[derive(Debug, Deserialize, Serialize, Clone)]
87pub struct ClientAuthorizationResponse {
88  pub access_token: String,
89  pub token_type: String,
90  pub expires_in: i64,
91  pub scope: String,
92}
93
94// Request a service token capable of making LTI 1.3 service calls to the LMS.
95// This function will cache the resulting token for 15 minutes.
96//
97// Arguments:
98// client_id - The LTI tool's client_id as provided by the platform
99// platform_token_url - The platform's token endpoint URL
100// scopes - a list of scopes to request access to each separated by a space. For example:
101//   "https://purl.imsglobal.org/spec/lti-ags/scope/lineitem https://purl.imsglobal.org/spec/lti-ags/scope/result/read"
102// rsa_key_pair - The RSA key pair used to sign the JWT token
103#[cached(
104  result = true, // Only "Ok" results are cached
105  sync_writes = "default", // When called concurrently, duplicate argument-calls will be synchronized so as to only run once
106  create = "{ TimedSizedCache::with_size_and_lifespan(100, std::time::Duration::from_secs(900)) }", // 15 mins
107  ty = "TimedSizedCache<String, ClientAuthorizationResponse>",
108  convert = r#"{ format!("{}{:?}", client_id, scopes) }"#,
109)]
110pub async fn request_service_token_cached(
111  client_id: &str,
112  platform_token_url: &str,
113  scopes: &str,
114  kid: &str,
115  rsa_key_pair: Rsa<openssl::pkey::Private>,
116) -> Result<ClientAuthorizationResponse, ClientCredentialsError> {
117  let mut count = 0;
118  let credentials = ClientCredentials::new(client_id, platform_token_url);
119  let token = credentials
120    .build_signed_token(kid, rsa_key_pair)
121    .map_err(|e| ClientCredentialsError::RequestFailed(e.to_string()))?;
122  let mut last_error = String::new();
123  while count < AUTHORIZATION_TRIES {
124    match request_service_token(platform_token_url, &token, scopes).await {
125      Ok(response) => return Ok(response),
126      Err(ClientCredentialsError::RateLimited(e)) => {
127        last_error = e.to_string();
128        // Wait 1 second before trying again
129        sleep(std::time::Duration::from_secs(1)).await;
130      }
131      Err(e) => return Err(e),
132    }
133    count += 1;
134  }
135  Err(ClientCredentialsError::RequestLimitReached(last_error))
136}
137
138pub async fn request_service_token(
139  platform_token_url: &str,
140  token: &str,
141  scopes: &str,
142) -> Result<ClientAuthorizationResponse, ClientCredentialsError> {
143  let params = ClientAuthorizationRequest::new(scopes, token);
144  let client = reqwest::Client::new();
145  let response = client
146    .post(platform_token_url)
147    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
148    .form(&params)
149    .send()
150    .await
151    .map_err(|e| ClientCredentialsError::RequestFailed(e.to_string()))?;
152  let status = response.status();
153  let body = response
154    .text()
155    .await
156    .map_err(|e| ClientCredentialsError::RequestFailed(e.to_string()))?;
157  if !status.is_success() {
158    if body.contains("rate limit") {
159      return Err(ClientCredentialsError::RateLimited(body));
160    }
161
162    return Err(ClientCredentialsError::RequestFailed(body));
163  }
164
165  let access_token_response: ClientAuthorizationResponse = serde_json::from_str(&body)
166    .map_err(|e| ClientCredentialsError::RequestFailed(e.to_string()))?;
167
168  Ok(access_token_response)
169}
170
#[cfg(test)]
mod tests {
  use super::*;
  use openssl::rsa::Rsa;

  /// Key ID used for every signed assertion in these tests.
  const KID: &str = "test_kid";
  /// Scopes requested from (and echoed back by) the mocked token endpoint.
  const SCOPES: &str = "https://purl.imsglobal.org/spec/lti-ags/scope/lineitem https://purl.imsglobal.org/spec/lti-ags/scope/result/read";

  /// Generate a throwaway 2048-bit signing key.
  fn signing_key() -> Rsa<openssl::pkey::Private> {
    Rsa::generate(2048).expect("failed to generate rsa key pair")
  }

  #[tokio::test]
  async fn test_request_service_token_success() {
    let mut server = mockito::Server::new_async().await;
    let token_url = format!("{}/token", server.url());
    let canned = ClientAuthorizationResponse {
      access_token: "fake".to_string(),
      token_type: "Bearer".to_string(),
      expires_in: 3600,
      scope: SCOPES.to_string(),
    };
    let body = serde_json::to_string(&canned).expect("failed to serialize response");
    let mock = server
      .mock("POST", "/token")
      .with_status(200)
      .with_body(body)
      .create();

    let key = signing_key();
    let result = request_service_token_cached("test_client_id", &token_url, SCOPES, KID, key).await;

    mock.assert();
    assert!(result.is_ok());
  }

  #[tokio::test]
  async fn test_request_service_token_rate_limited() {
    let mut server = mockito::Server::new_async().await;
    let token_url = format!("{}/token", server.url());
    // Every attempt is rejected, so the endpoint must be hit exactly 3 times.
    let mock = server
      .mock("POST", "/token")
      .with_status(429)
      .with_body("rate limit exceeded")
      .expect(3)
      .create();

    let key = signing_key();
    let result =
      request_service_token_cached("test_fail_client_id", &token_url, SCOPES, KID, key).await;

    mock.assert();
    assert!(matches!(
      result,
      Err(ClientCredentialsError::RequestLimitReached(_))
    ));
  }
}