1use std::collections::HashSet;
2
3use jsonwebtoken::Algorithm;
4use jsonwebtoken::DecodingKey;
5use jsonwebtoken::Validation;
6use serde::{Deserialize, Serialize};
7
/// Settings consumed by [`KontextTokenVerifier`].
#[derive(Debug, Clone)]
pub struct KontextTokenVerifierConfig {
    /// URL from which the JSON Web Key Set is fetched.
    pub jwks_url: String,
    /// Issuer the token's `iss` claim is validated against.
    pub issuer: String,
    /// Audience the token's `aud` claim is validated against.
    pub audience: String,
    /// Scopes that must all be present in the token's space-delimited `scope` claim.
    pub required_scopes: Vec<String>,
}
15
16#[derive(Clone, Debug, Deserialize, Serialize)]
17#[serde(rename_all = "camelCase")]
18pub struct VerifiedTokenClaims {
19 pub subject: Option<String>,
20 pub issuer: Option<String>,
21 pub audience: Vec<String>,
22 pub scope: Option<String>,
23 pub client_id: Option<String>,
24 pub expires_at: Option<u64>,
25 pub raw: serde_json::Value,
26}
27
28#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
29#[serde(rename_all = "snake_case")]
30pub enum TokenVerificationErrorCode {
31 MissingKid,
32 JwksFetchFailed,
33 SigningKeyNotFound,
34 UnsupportedKey,
35 InvalidToken,
36 MissingScope,
37}
38
39#[derive(Clone, Debug, Deserialize, Serialize)]
40#[serde(rename_all = "camelCase")]
41pub struct TokenVerificationError {
42 pub code: TokenVerificationErrorCode,
43 pub message: String,
44}
45
46#[derive(Clone, Debug, Deserialize, Serialize)]
47#[serde(rename_all = "camelCase")]
48pub struct VerifyResult {
49 pub success: bool,
50 pub claims: Option<VerifiedTokenClaims>,
51 pub error: Option<TokenVerificationError>,
52}
53
54#[derive(Clone, Debug)]
55pub struct JwksClient {
56 jwks_url: String,
57 http: reqwest::Client,
58}
59
60impl JwksClient {
61 pub fn new(jwks_url: impl Into<String>) -> Self {
62 Self {
63 jwks_url: jwks_url.into(),
64 http: reqwest::Client::new(),
65 }
66 }
67
68 async fn fetch(&self) -> Result<JwkSet, TokenVerificationError> {
69 let response = self
70 .http
71 .get(self.jwks_url.as_str())
72 .send()
73 .await
74 .map_err(|err| TokenVerificationError {
75 code: TokenVerificationErrorCode::JwksFetchFailed,
76 message: err.to_string(),
77 })?;
78
79 if !response.status().is_success() {
80 let status = response.status();
81 let body = response.text().await.unwrap_or_default();
82 return Err(TokenVerificationError {
83 code: TokenVerificationErrorCode::JwksFetchFailed,
84 message: format!("{status}: {body}"),
85 });
86 }
87
88 response
89 .json::<JwkSet>()
90 .await
91 .map_err(|err| TokenVerificationError {
92 code: TokenVerificationErrorCode::JwksFetchFailed,
93 message: err.to_string(),
94 })
95 }
96}
97
98#[derive(Clone, Debug)]
99pub struct KontextTokenVerifier {
100 config: KontextTokenVerifierConfig,
101 jwks_client: JwksClient,
102}
103
104impl KontextTokenVerifier {
105 pub fn new(config: KontextTokenVerifierConfig) -> Self {
106 let jwks_client = JwksClient::new(config.jwks_url.clone());
107 Self {
108 config,
109 jwks_client,
110 }
111 }
112
113 pub async fn verify(&self, token: &str) -> VerifyResult {
114 match self.verify_inner(token).await {
115 Ok(claims) => VerifyResult {
116 success: true,
117 claims: Some(claims),
118 error: None,
119 },
120 Err(error) => VerifyResult {
121 success: false,
122 claims: None,
123 error: Some(error),
124 },
125 }
126 }
127
128 async fn verify_inner(
129 &self,
130 token: &str,
131 ) -> Result<VerifiedTokenClaims, TokenVerificationError> {
132 let header = jsonwebtoken::decode_header(token).map_err(|err| TokenVerificationError {
133 code: TokenVerificationErrorCode::InvalidToken,
134 message: err.to_string(),
135 })?;
136
137 let kid = header.kid.ok_or(TokenVerificationError {
138 code: TokenVerificationErrorCode::MissingKid,
139 message: "token header is missing kid".to_string(),
140 })?;
141
142 let jwk_set = self.jwks_client.fetch().await?;
143 let jwk = jwk_set
144 .keys
145 .iter()
146 .find(|k| k.kid.as_deref() == Some(kid.as_str()))
147 .ok_or(TokenVerificationError {
148 code: TokenVerificationErrorCode::SigningKeyNotFound,
149 message: format!("no signing key found for kid={kid}"),
150 })?;
151
152 let decoding_key = jwk.decoding_key()?;
153
154 let mut validation = Validation::new(Algorithm::RS256);
155 validation.validate_exp = true;
156 validation.set_audience(&[self.config.audience.clone()]);
157 validation.set_issuer(&[self.config.issuer.clone()]);
158
159 let payload = jsonwebtoken::decode::<serde_json::Value>(token, &decoding_key, &validation)
160 .map_err(|err| TokenVerificationError {
161 code: TokenVerificationErrorCode::InvalidToken,
162 message: err.to_string(),
163 })?;
164
165 let claims = payload.claims;
166 let token_scopes = claims
167 .get("scope")
168 .and_then(|value| value.as_str())
169 .unwrap_or_default();
170
171 let available_scopes: HashSet<&str> = token_scopes
172 .split_whitespace()
173 .filter(|scope| !scope.is_empty())
174 .collect();
175
176 for required_scope in &self.config.required_scopes {
177 if !available_scopes.contains(required_scope.as_str()) {
178 return Err(TokenVerificationError {
179 code: TokenVerificationErrorCode::MissingScope,
180 message: format!("missing required scope `{required_scope}`"),
181 });
182 }
183 }
184
185 let audience = match claims.get("aud") {
186 Some(serde_json::Value::Array(values)) => values
187 .iter()
188 .filter_map(|value| value.as_str().map(ToString::to_string))
189 .collect(),
190 Some(serde_json::Value::String(value)) => vec![value.to_string()],
191 _ => Vec::new(),
192 };
193
194 Ok(VerifiedTokenClaims {
195 subject: claims
196 .get("sub")
197 .and_then(|value| value.as_str())
198 .map(ToString::to_string),
199 issuer: claims
200 .get("iss")
201 .and_then(|value| value.as_str())
202 .map(ToString::to_string),
203 audience,
204 scope: claims
205 .get("scope")
206 .and_then(|value| value.as_str())
207 .map(ToString::to_string),
208 client_id: claims
209 .get("client_id")
210 .and_then(|value| value.as_str())
211 .map(ToString::to_string),
212 expires_at: claims.get("exp").and_then(|value| value.as_u64()),
213 raw: claims,
214 })
215 }
216}
217
218#[derive(Clone, Debug, Deserialize)]
219struct JwkSet {
220 keys: Vec<Jwk>,
221}
222
223#[derive(Clone, Debug, Deserialize)]
224struct Jwk {
225 kty: String,
226 #[allow(dead_code)]
227 alg: Option<String>,
228 kid: Option<String>,
229 n: Option<String>,
230 e: Option<String>,
231}
232
233impl Jwk {
234 fn decoding_key(&self) -> Result<DecodingKey, TokenVerificationError> {
235 if self.kty != "RSA" {
236 return Err(TokenVerificationError {
237 code: TokenVerificationErrorCode::UnsupportedKey,
238 message: format!("unsupported key type `{}`", self.kty),
239 });
240 }
241
242 let n = self.n.as_deref().ok_or(TokenVerificationError {
243 code: TokenVerificationErrorCode::UnsupportedKey,
244 message: "RSA key missing modulus".to_string(),
245 })?;
246
247 let e = self.e.as_deref().ok_or(TokenVerificationError {
248 code: TokenVerificationErrorCode::UnsupportedKey,
249 message: "RSA key missing exponent".to_string(),
250 })?;
251
252 DecodingKey::from_rsa_components(n, e).map_err(|err| TokenVerificationError {
253 code: TokenVerificationErrorCode::UnsupportedKey,
254 message: err.to_string(),
255 })
256 }
257}