Skip to main content

kontext_dev_sdk/
verify.rs

1use std::collections::HashSet;
2
3use jsonwebtoken::Algorithm;
4use jsonwebtoken::DecodingKey;
5use jsonwebtoken::Validation;
6use serde::{Deserialize, Serialize};
7
/// Settings consumed by [`KontextTokenVerifier`].
#[derive(Debug, Clone)]
pub struct KontextTokenVerifierConfig {
    /// URL of the JSON Web Key Set from which signing keys are looked up by `kid`.
    pub jwks_url: String,
    /// Value the token's `iss` claim must match.
    pub issuer: String,
    /// Value the token's `aud` claim must contain.
    pub audience: String,
    /// Scopes that must all appear in the token's whitespace-delimited `scope` claim.
    pub required_scopes: Vec<String>,
}
15
16#[derive(Clone, Debug, Deserialize, Serialize)]
17#[serde(rename_all = "camelCase")]
18pub struct VerifiedTokenClaims {
19    pub subject: Option<String>,
20    pub issuer: Option<String>,
21    pub audience: Vec<String>,
22    pub scope: Option<String>,
23    pub client_id: Option<String>,
24    pub expires_at: Option<u64>,
25    pub raw: serde_json::Value,
26}
27
28#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
29#[serde(rename_all = "snake_case")]
30pub enum TokenVerificationErrorCode {
31    MissingKid,
32    JwksFetchFailed,
33    SigningKeyNotFound,
34    UnsupportedKey,
35    InvalidToken,
36    MissingScope,
37}
38
39#[derive(Clone, Debug, Deserialize, Serialize)]
40#[serde(rename_all = "camelCase")]
41pub struct TokenVerificationError {
42    pub code: TokenVerificationErrorCode,
43    pub message: String,
44}
45
46#[derive(Clone, Debug, Deserialize, Serialize)]
47#[serde(rename_all = "camelCase")]
48pub struct VerifyResult {
49    pub success: bool,
50    pub claims: Option<VerifiedTokenClaims>,
51    pub error: Option<TokenVerificationError>,
52}
53
54#[derive(Clone, Debug)]
55pub struct JwksClient {
56    jwks_url: String,
57    http: reqwest::Client,
58}
59
60impl JwksClient {
61    pub fn new(jwks_url: impl Into<String>) -> Self {
62        Self {
63            jwks_url: jwks_url.into(),
64            http: reqwest::Client::new(),
65        }
66    }
67
68    async fn fetch(&self) -> Result<JwkSet, TokenVerificationError> {
69        let response = self
70            .http
71            .get(self.jwks_url.as_str())
72            .send()
73            .await
74            .map_err(|err| TokenVerificationError {
75                code: TokenVerificationErrorCode::JwksFetchFailed,
76                message: err.to_string(),
77            })?;
78
79        if !response.status().is_success() {
80            let status = response.status();
81            let body = response.text().await.unwrap_or_default();
82            return Err(TokenVerificationError {
83                code: TokenVerificationErrorCode::JwksFetchFailed,
84                message: format!("{status}: {body}"),
85            });
86        }
87
88        response
89            .json::<JwkSet>()
90            .await
91            .map_err(|err| TokenVerificationError {
92                code: TokenVerificationErrorCode::JwksFetchFailed,
93                message: err.to_string(),
94            })
95    }
96}
97
98#[derive(Clone, Debug)]
99pub struct KontextTokenVerifier {
100    config: KontextTokenVerifierConfig,
101    jwks_client: JwksClient,
102}
103
104impl KontextTokenVerifier {
105    pub fn new(config: KontextTokenVerifierConfig) -> Self {
106        let jwks_client = JwksClient::new(config.jwks_url.clone());
107        Self {
108            config,
109            jwks_client,
110        }
111    }
112
113    pub async fn verify(&self, token: &str) -> VerifyResult {
114        match self.verify_inner(token).await {
115            Ok(claims) => VerifyResult {
116                success: true,
117                claims: Some(claims),
118                error: None,
119            },
120            Err(error) => VerifyResult {
121                success: false,
122                claims: None,
123                error: Some(error),
124            },
125        }
126    }
127
128    async fn verify_inner(
129        &self,
130        token: &str,
131    ) -> Result<VerifiedTokenClaims, TokenVerificationError> {
132        let header = jsonwebtoken::decode_header(token).map_err(|err| TokenVerificationError {
133            code: TokenVerificationErrorCode::InvalidToken,
134            message: err.to_string(),
135        })?;
136
137        let kid = header.kid.ok_or(TokenVerificationError {
138            code: TokenVerificationErrorCode::MissingKid,
139            message: "token header is missing kid".to_string(),
140        })?;
141
142        let jwk_set = self.jwks_client.fetch().await?;
143        let jwk = jwk_set
144            .keys
145            .iter()
146            .find(|k| k.kid.as_deref() == Some(kid.as_str()))
147            .ok_or(TokenVerificationError {
148                code: TokenVerificationErrorCode::SigningKeyNotFound,
149                message: format!("no signing key found for kid={kid}"),
150            })?;
151
152        let decoding_key = jwk.decoding_key()?;
153
154        let mut validation = Validation::new(Algorithm::RS256);
155        validation.validate_exp = true;
156        validation.set_audience(&[self.config.audience.clone()]);
157        validation.set_issuer(&[self.config.issuer.clone()]);
158
159        let payload = jsonwebtoken::decode::<serde_json::Value>(token, &decoding_key, &validation)
160            .map_err(|err| TokenVerificationError {
161                code: TokenVerificationErrorCode::InvalidToken,
162                message: err.to_string(),
163            })?;
164
165        let claims = payload.claims;
166        let token_scopes = claims
167            .get("scope")
168            .and_then(|value| value.as_str())
169            .unwrap_or_default();
170
171        let available_scopes: HashSet<&str> = token_scopes
172            .split_whitespace()
173            .filter(|scope| !scope.is_empty())
174            .collect();
175
176        for required_scope in &self.config.required_scopes {
177            if !available_scopes.contains(required_scope.as_str()) {
178                return Err(TokenVerificationError {
179                    code: TokenVerificationErrorCode::MissingScope,
180                    message: format!("missing required scope `{required_scope}`"),
181                });
182            }
183        }
184
185        let audience = match claims.get("aud") {
186            Some(serde_json::Value::Array(values)) => values
187                .iter()
188                .filter_map(|value| value.as_str().map(ToString::to_string))
189                .collect(),
190            Some(serde_json::Value::String(value)) => vec![value.to_string()],
191            _ => Vec::new(),
192        };
193
194        Ok(VerifiedTokenClaims {
195            subject: claims
196                .get("sub")
197                .and_then(|value| value.as_str())
198                .map(ToString::to_string),
199            issuer: claims
200                .get("iss")
201                .and_then(|value| value.as_str())
202                .map(ToString::to_string),
203            audience,
204            scope: claims
205                .get("scope")
206                .and_then(|value| value.as_str())
207                .map(ToString::to_string),
208            client_id: claims
209                .get("client_id")
210                .and_then(|value| value.as_str())
211                .map(ToString::to_string),
212            expires_at: claims.get("exp").and_then(|value| value.as_u64()),
213            raw: claims,
214        })
215    }
216}
217
218#[derive(Clone, Debug, Deserialize)]
219struct JwkSet {
220    keys: Vec<Jwk>,
221}
222
223#[derive(Clone, Debug, Deserialize)]
224struct Jwk {
225    kty: String,
226    #[allow(dead_code)]
227    alg: Option<String>,
228    kid: Option<String>,
229    n: Option<String>,
230    e: Option<String>,
231}
232
233impl Jwk {
234    fn decoding_key(&self) -> Result<DecodingKey, TokenVerificationError> {
235        if self.kty != "RSA" {
236            return Err(TokenVerificationError {
237                code: TokenVerificationErrorCode::UnsupportedKey,
238                message: format!("unsupported key type `{}`", self.kty),
239            });
240        }
241
242        let n = self.n.as_deref().ok_or(TokenVerificationError {
243            code: TokenVerificationErrorCode::UnsupportedKey,
244            message: "RSA key missing modulus".to_string(),
245        })?;
246
247        let e = self.e.as_deref().ok_or(TokenVerificationError {
248            code: TokenVerificationErrorCode::UnsupportedKey,
249            message: "RSA key missing exponent".to_string(),
250        })?;
251
252        DecodingKey::from_rsa_components(n, e).map_err(|err| TokenVerificationError {
253            code: TokenVerificationErrorCode::UnsupportedKey,
254            message: err.to_string(),
255        })
256    }
257}