1use crate::keyset::Keyset;
2use crate::types::{OAuthClientMetadata, TryIntoOAuthClientMetadata};
3use atrium_xrpc::http::uri::{InvalidUri, Scheme, Uri};
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7#[derive(Error, Debug)]
8pub enum Error {
9 #[error("`client_id` must be a valid URL")]
10 InvalidClientId,
11 #[error("`grant_types` must include `authorization_code`")]
12 InvalidGrantTypes,
13 #[error("`scope` must not include `atproto`")]
14 InvalidScope,
15 #[error("`redirect_uris` must not be empty")]
16 EmptyRedirectUris,
17 #[error("`private_key_jwt` auth method requires `jwks` keys")]
18 EmptyJwks,
19 #[error("`private_key_jwt` auth method requires `token_endpoint_auth_signing_alg`, otherwise must not be provided")]
20 AuthSigningAlg,
21 #[error(transparent)]
22 SerdeHtmlForm(#[from] serde_html_form::ser::Error),
23 #[error(transparent)]
24 LocalhostClient(#[from] LocalhostClientError),
25}
26
27#[derive(Error, Debug)]
28pub enum LocalhostClientError {
29 #[error("invalid redirect_uri: {0}")]
30 Invalid(#[from] InvalidUri),
31 #[error("loopback client_id must use `http:` redirect_uri")]
32 NotHttpScheme,
33 #[error("loopback client_id must not use `localhost` as redirect_uri hostname")]
34 Localhost,
35 #[error("loopback client_id must not use loopback addresses as redirect_uri")]
36 NotLoopbackHost,
37}
38
39pub type Result<T> = core::result::Result<T, Error>;
40
41#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
42#[serde(rename_all = "snake_case")]
43pub enum AuthMethod {
44 None,
45 PrivateKeyJwt,
47}
48
49impl From<AuthMethod> for String {
50 fn from(value: AuthMethod) -> Self {
51 match value {
52 AuthMethod::None => String::from("none"),
53 AuthMethod::PrivateKeyJwt => String::from("private_key_jwt"),
54 }
55 }
56}
57
58#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
59#[serde(rename_all = "snake_case")]
60pub enum GrantType {
61 AuthorizationCode,
62 RefreshToken,
63}
64
65impl From<GrantType> for String {
66 fn from(value: GrantType) -> Self {
67 match value {
68 GrantType::AuthorizationCode => String::from("authorization_code"),
69 GrantType::RefreshToken => String::from("refresh_token"),
70 }
71 }
72}
73
74#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
75#[serde(untagged)]
76pub enum Scope {
77 Known(KnownScope),
78 Unknown(String),
79}
80
81#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
82pub enum KnownScope {
83 #[serde(rename = "atproto")]
84 Atproto,
85 #[serde(rename = "transition:generic")]
86 TransitionGeneric,
87 #[serde(rename = "transition:chat.bsky")]
88 TransitionChatBsky,
89}
90
91impl AsRef<str> for Scope {
92 fn as_ref(&self) -> &str {
93 match self {
94 Self::Known(KnownScope::Atproto) => "atproto",
95 Self::Known(KnownScope::TransitionGeneric) => "transition:generic",
96 Self::Known(KnownScope::TransitionChatBsky) => "transition:chat.bsky",
97 Self::Unknown(value) => value,
98 }
99 }
100}
101
102#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
103pub struct AtprotoLocalhostClientMetadata {
104 pub redirect_uris: Option<Vec<String>>,
105 pub scopes: Option<Vec<Scope>>,
106}
107
108#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
109pub struct AtprotoClientMetadata {
110 pub client_id: String,
111 pub client_uri: Option<String>,
112 pub redirect_uris: Vec<String>,
113 pub token_endpoint_auth_method: AuthMethod,
114 pub grant_types: Vec<GrantType>,
115 pub scopes: Vec<Scope>,
116 pub jwks_uri: Option<String>,
117 pub token_endpoint_auth_signing_alg: Option<String>,
118}
119
120impl TryIntoOAuthClientMetadata for AtprotoLocalhostClientMetadata {
121 type Error = Error;
122
123 fn try_into_client_metadata(self, _: &Option<Keyset>) -> Result<OAuthClientMetadata> {
124 if let Some(redirect_uris) = &self.redirect_uris {
126 for redirect_uri in redirect_uris {
127 let uri = redirect_uri.parse::<Uri>().map_err(LocalhostClientError::Invalid)?;
128 if uri.scheme() != Some(&Scheme::HTTP) {
129 return Err(Error::LocalhostClient(LocalhostClientError::NotHttpScheme));
130 }
131 if uri.host() == Some("localhost") {
132 return Err(Error::LocalhostClient(LocalhostClientError::Localhost));
133 }
134 if uri.host().map_or(true, |host| host != "127.0.0.1" && host != "[::1]") {
135 return Err(Error::LocalhostClient(LocalhostClientError::NotLoopbackHost));
136 }
137 }
138 }
139 #[derive(serde::Serialize)]
141 struct Parameters {
142 #[serde(skip_serializing_if = "Option::is_none")]
143 redirect_uri: Option<Vec<String>>,
144 #[serde(skip_serializing_if = "Option::is_none")]
145 scope: Option<String>,
146 }
147 let query = serde_html_form::to_string(Parameters {
148 redirect_uri: self.redirect_uris.clone(),
149 scope: self
150 .scopes
151 .map(|scopes| scopes.iter().map(AsRef::as_ref).collect::<Vec<_>>().join(" ")),
152 })?;
153 let mut client_id = String::from("http://localhost");
154 if !query.is_empty() {
155 client_id.push_str(&format!("?{query}"));
156 }
157 Ok(OAuthClientMetadata {
158 client_id,
159 client_uri: None,
160 redirect_uris: self
161 .redirect_uris
162 .unwrap_or(vec![String::from("http://127.0.0.1/"), String::from("http://[::1]/")]),
163 scope: None,
164 grant_types: None, token_endpoint_auth_method: Some(String::from("none")),
166 dpop_bound_access_tokens: None, jwks_uri: None,
168 jwks: None,
169 token_endpoint_auth_signing_alg: None,
170 })
171 }
172}
173
174impl TryIntoOAuthClientMetadata for AtprotoClientMetadata {
175 type Error = Error;
176
177 fn try_into_client_metadata(self, keyset: &Option<Keyset>) -> Result<OAuthClientMetadata> {
178 if self.client_id.parse::<Uri>().is_err() {
179 return Err(Error::InvalidClientId);
180 }
181 if self.redirect_uris.is_empty() {
182 return Err(Error::EmptyRedirectUris);
183 }
184 if !self.grant_types.contains(&GrantType::AuthorizationCode) {
185 return Err(Error::InvalidGrantTypes);
186 }
187 if !self.scopes.contains(&Scope::Known(KnownScope::Atproto)) {
188 return Err(Error::InvalidScope);
189 }
190 let (jwks_uri, mut jwks) = (self.jwks_uri, None);
191 match self.token_endpoint_auth_method {
192 AuthMethod::None => {
193 if self.token_endpoint_auth_signing_alg.is_some() {
194 return Err(Error::AuthSigningAlg);
195 }
196 }
197 AuthMethod::PrivateKeyJwt => {
198 if let Some(keyset) = keyset {
199 if self.token_endpoint_auth_signing_alg.is_none() {
200 return Err(Error::AuthSigningAlg);
201 }
202 if jwks_uri.is_none() {
203 jwks = Some(keyset.public_jwks());
204 }
205 } else {
206 return Err(Error::EmptyJwks);
207 }
208 }
209 }
210 Ok(OAuthClientMetadata {
211 client_id: self.client_id,
212 client_uri: self.client_uri,
213 redirect_uris: self.redirect_uris,
214 token_endpoint_auth_method: Some(self.token_endpoint_auth_method.into()),
215 grant_types: Some(self.grant_types.into_iter().map(|v| v.into()).collect()),
216 scope: Some(self.scopes.iter().map(AsRef::as_ref).collect::<Vec<_>>().join(" ")),
217 dpop_bound_access_tokens: Some(true),
218 jwks_uri,
219 jwks,
220 token_endpoint_auth_signing_alg: self.token_endpoint_auth_signing_alg,
221 })
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use elliptic_curve::SecretKey;
    use jose_jwk::{Jwk, Key, Parameters};
    use p256::pkcs8::DecodePrivateKey;

    // PKCS#8 PEM private key, parsed as a P-256 key to build the test keyset.
    const PRIVATE_KEY: &str = r#"-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgED1AAgC7Fc9kPh5T
4i4Tn+z+tc47W1zYgzXtyjJtD92hRANCAAT80DqC+Z/JpTO7/pkPBmWqIV1IGh1P
gbGGr0pN+oSing7cZ0169JaRHTNh+0LNQXrFobInX6cj95FzEdRyT4T3
-----END PRIVATE KEY-----"#;

    /// Metadata a loopback client is expected to produce for the given
    /// `client_id` and redirect URIs; every other field is fixed.
    fn expected_localhost_metadata(
        client_id: &str,
        redirect_uris: &[&str],
    ) -> OAuthClientMetadata {
        OAuthClientMetadata {
            client_id: String::from(client_id),
            client_uri: None,
            redirect_uris: redirect_uris.iter().map(|uri| String::from(*uri)).collect(),
            scope: None,
            grant_types: None,
            token_endpoint_auth_method: Some(AuthMethod::None.into()),
            dpop_bound_access_tokens: None,
            jwks_uri: None,
            jwks: None,
            token_endpoint_auth_signing_alg: None,
        }
    }

    /// Converts a loopback client with a single redirect URI and returns the
    /// error the conversion is expected to fail with.
    fn localhost_conversion_error(redirect_uri: &str) -> Error {
        let metadata = AtprotoLocalhostClientMetadata {
            redirect_uris: Some(vec![String::from(redirect_uri)]),
            ..Default::default()
        };
        metadata.try_into_client_metadata(&None).expect_err("expected to fail")
    }

    #[test]
    fn test_localhost_client_metadata_default() {
        let converted = AtprotoLocalhostClientMetadata::default()
            .try_into_client_metadata(&None)
            .expect("failed to convert metadata");
        let expected = expected_localhost_metadata(
            "http://localhost",
            &["http://127.0.0.1/", "http://[::1]/"],
        );
        assert_eq!(converted, expected);
    }

    #[test]
    fn test_localhost_client_metadata_custom() {
        let redirect_uris = ["http://127.0.0.1/callback", "http://[::1]/callback"];
        let metadata = AtprotoLocalhostClientMetadata {
            redirect_uris: Some(redirect_uris.iter().map(|uri| String::from(*uri)).collect()),
            scopes: Some(vec![
                Scope::Known(KnownScope::Atproto),
                Scope::Known(KnownScope::TransitionGeneric),
                Scope::Unknown(String::from("unknown")),
            ]),
        };
        let converted =
            metadata.try_into_client_metadata(&None).expect("failed to convert metadata");
        let expected = expected_localhost_metadata(
            "http://localhost?redirect_uri=http%3A%2F%2F127.0.0.1%2Fcallback&redirect_uri=http%3A%2F%2F%5B%3A%3A1%5D%2Fcallback&scope=atproto+transition%3Ageneric+unknown",
            &redirect_uris,
        );
        assert_eq!(converted, expected);
    }

    #[test]
    fn test_localhost_client_metadata_invalid() {
        // Unparsable URI.
        assert!(matches!(
            localhost_conversion_error("http://"),
            Error::LocalhostClient(LocalhostClientError::Invalid(_))
        ));
        // Wrong scheme.
        assert!(matches!(
            localhost_conversion_error("https://127.0.0.1/"),
            Error::LocalhostClient(LocalhostClientError::NotHttpScheme)
        ));
        // `localhost` hostname instead of an IP literal.
        assert!(matches!(
            localhost_conversion_error("http://localhost:8000/"),
            Error::LocalhostClient(LocalhostClientError::Localhost)
        ));
        // Host that is not a loopback address.
        assert!(matches!(
            localhost_conversion_error("http://192.168.0.0/"),
            Error::LocalhostClient(LocalhostClientError::NotLoopbackHost)
        ));
    }

    #[test]
    fn test_client_metadata() {
        let metadata = AtprotoClientMetadata {
            client_id: String::from("https://example.com/client_metadata.json"),
            client_uri: Some(String::from("https://example.com")),
            redirect_uris: vec![String::from("https://example.com/callback")],
            token_endpoint_auth_method: AuthMethod::PrivateKeyJwt,
            grant_types: vec![GrantType::AuthorizationCode],
            scopes: vec![Scope::Known(KnownScope::Atproto)],
            jwks_uri: None,
            token_endpoint_auth_signing_alg: Some(String::from("ES256")),
        };

        // `private_key_jwt` without a keyset must be rejected.
        let without_keyset = metadata
            .clone()
            .try_into_client_metadata(&None)
            .expect_err("expected to fail");
        assert!(matches!(without_keyset, Error::EmptyJwks));

        // With a keyset, its public keys are embedded as `jwks`.
        let secret_key = SecretKey::<p256::NistP256>::from_pkcs8_pem(PRIVATE_KEY)
            .expect("failed to parse private key");
        let jwk = Jwk {
            key: Key::from(&secret_key.into()),
            prm: Parameters { kid: Some(String::from("kid00")), ..Default::default() },
        };
        let keyset = Keyset::try_from(vec![jwk]).expect("failed to create keyset");
        let converted = metadata
            .try_into_client_metadata(&Some(keyset.clone()))
            .expect("failed to convert metadata");
        let expected = OAuthClientMetadata {
            client_id: String::from("https://example.com/client_metadata.json"),
            client_uri: Some(String::from("https://example.com")),
            redirect_uris: vec![String::from("https://example.com/callback")],
            scope: Some(String::from("atproto")),
            grant_types: Some(vec![String::from("authorization_code")]),
            token_endpoint_auth_method: Some(AuthMethod::PrivateKeyJwt.into()),
            dpop_bound_access_tokens: Some(true),
            jwks_uri: None,
            jwks: Some(keyset.public_jwks()),
            token_endpoint_auth_signing_alg: Some(String::from("ES256")),
        };
        assert_eq!(converted, expected);
    }

    #[test]
    fn test_scope_serde() {
        #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
        struct Scopes {
            scopes: Vec<Scope>,
        }

        let original = Scopes {
            scopes: vec![
                Scope::Known(KnownScope::Atproto),
                Scope::Known(KnownScope::TransitionGeneric),
                Scope::Unknown(String::from("unknown")),
            ],
        };
        let encoded = serde_json::to_string(&original).expect("failed to serialize scopes");
        assert_eq!(encoded, r#"{"scopes":["atproto","transition:generic","unknown"]}"#);

        // Round trip: known strings come back as `Known`, others as `Unknown`.
        let decoded =
            serde_json::from_str::<Scopes>(&encoded).expect("failed to deserialize scopes");
        assert_eq!(decoded, original);
    }
}