1use std::str::FromStr;
2
3use crate::types::OAuthClientMetadata;
4use crate::{
5 keyset::Keyset,
6 scopes::{Scope, Scopes},
7};
8use jacquard_common::deps::fluent_uri::Uri;
9use jacquard_common::{BosStr, IntoStatic};
10use serde::{Deserialize, Serialize};
11use smol_str::{SmolStr, ToSmolStr};
12use thiserror::Error;
13
14#[derive(Error, Debug)]
16#[non_exhaustive]
17pub enum Error {
18 #[error("`client_id` must be a valid URL")]
20 InvalidClientId,
21 #[error("`grant_types` must include `authorization_code`")]
23 InvalidGrantTypes,
24 #[error("`scope` must not include `atproto`")]
26 InvalidScope,
27 #[error("`redirect_uris` must not be empty")]
29 EmptyRedirectUris,
30 #[error("`private_key_jwt` auth method requires `jwks` keys")]
32 EmptyJwks,
33 #[error(
36 "`private_key_jwt` auth method requires `token_endpoint_auth_signing_alg`, otherwise must not be provided"
37 )]
38 AuthSigningAlg,
39 #[error(transparent)]
41 SerdeHtmlForm(#[from] serde_html_form::ser::Error),
42 #[error(transparent)]
44 LocalhostClient(#[from] LocalhostClientError),
45}
46
47#[derive(Error, Debug)]
52#[non_exhaustive]
53pub enum LocalhostClientError {
54 #[error("invalid redirect_uri: {0}")]
56 Invalid(#[from] jacquard_common::deps::fluent_uri::ParseError),
57 #[error("loopback client_id must use `http:` redirect_uri")]
59 NotHttpScheme,
60 #[error("loopback client_id must not use `localhost` as redirect_uri hostname")]
62 Localhost,
63 #[error("loopback client_id must not use loopback addresses as redirect_uri")]
65 NotLoopbackHost,
66}
67
68pub type Result<T> = core::result::Result<T, Error>;
70
71#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
76#[serde(rename_all = "snake_case")]
77pub enum AuthMethod {
78 None,
80 PrivateKeyJwt,
83}
84
85impl From<AuthMethod> for SmolStr {
86 fn from(value: AuthMethod) -> Self {
87 match value {
88 AuthMethod::None => SmolStr::new_static("none"),
89 AuthMethod::PrivateKeyJwt => SmolStr::new_static("private_key_jwt"),
90 }
91 }
92}
93
94impl From<AuthMethod> for &'static str {
95 fn from(value: AuthMethod) -> Self {
96 match value {
97 AuthMethod::None => "none",
98 AuthMethod::PrivateKeyJwt => "private_key_jwt",
99 }
100 }
101}
102
103#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
105#[serde(rename_all = "snake_case")]
106pub enum GrantType {
107 AuthorizationCode,
109 RefreshToken,
111}
112
113impl From<GrantType> for SmolStr {
114 fn from(value: GrantType) -> Self {
115 match value {
116 GrantType::AuthorizationCode => SmolStr::new_static("authorization_code"),
117 GrantType::RefreshToken => SmolStr::new_static("refresh_token"),
118 }
119 }
120}
121
122impl From<GrantType> for &'static str {
123 fn from(value: GrantType) -> Self {
124 match value {
125 GrantType::AuthorizationCode => "authorization_code",
126 GrantType::RefreshToken => "refresh_token",
127 }
128 }
129}
130
131#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
138pub struct AtprotoClientMetadata<S: BosStr + FromStr + Ord>
139where
140 <S as FromStr>::Err: core::fmt::Debug,
141{
142 pub client_id: Uri<String>,
144 pub client_uri: Option<Uri<String>>,
146 pub redirect_uris: Vec<Uri<String>>,
148 pub grant_types: Vec<GrantType>,
150 pub scopes: Scopes<S>,
152 pub jwks_uri: Option<Uri<String>>,
154 pub client_name: Option<S>,
156 pub logo_uri: Option<Uri<String>>,
158 pub tos_uri: Option<Uri<String>>,
160 pub privacy_policy_uri: Option<Uri<String>>,
162}
163
164impl<S> IntoStatic for AtprotoClientMetadata<S>
165where
166 S: BosStr + IntoStatic + Ord + FromStr + AsRef<str>,
167 <S as FromStr>::Err: core::fmt::Debug,
168 S::Output: BosStr + FromStr + Ord + AsRef<str>,
169 <S::Output as FromStr>::Err: core::fmt::Debug,
170{
171 type Output = AtprotoClientMetadata<S::Output>;
172 fn into_static(self) -> AtprotoClientMetadata<S::Output> {
173 AtprotoClientMetadata {
174 client_id: self.client_id,
175 client_uri: self.client_uri,
176 redirect_uris: self.redirect_uris,
177 grant_types: self.grant_types,
178 scopes: self.scopes.into_static(),
179 jwks_uri: self.jwks_uri,
180 client_name: self.client_name.into_static(),
181 logo_uri: self.logo_uri,
182 tos_uri: self.tos_uri,
183 privacy_policy_uri: None,
184 }
185 }
186}
187
188impl<S> AtprotoClientMetadata<S>
189where
190 S: BosStr + IntoStatic + Ord + FromStr,
191 <S as FromStr>::Err: core::fmt::Debug,
192 S::Output: BosStr + FromStr + Ord,
193 <S::Output as FromStr>::Err: core::fmt::Debug,
194{
195 pub fn with_prod_info(
200 mut self,
201 client_name: S,
202 logo_uri: Option<Uri<String>>,
203 tos_uri: Option<Uri<String>>,
204 privacy_policy_uri: Option<Uri<String>>,
205 ) -> Self {
206 self.client_name = Some(client_name);
207 self.logo_uri = logo_uri;
208 self.tos_uri = tos_uri;
209 self.privacy_policy_uri = privacy_policy_uri;
210 self
211 }
212
213 pub fn with_scopes(mut self, scopes: Scopes<S>) -> Self {
215 self.scopes = scopes;
216 self
217 }
218
219 pub fn with_jwks_uri(mut self, jwks_uri: Uri<String>) -> Self {
221 self.jwks_uri = Some(jwks_uri);
222 self
223 }
224
225 pub fn with_client_name(mut self, client_name: S) -> Self {
227 self.client_name = Some(client_name);
228 self
229 }
230
231 pub fn default_localhost() -> Self
237 where
238 S: From<SmolStr> + AsRef<str>,
239 {
240 let scopes = Scopes::new(SmolStr::new_static("atproto transition:generic"))
241 .expect("valid scopes")
242 .convert();
243 Self::new_localhost(None, Some(scopes))
244 }
245
246 pub fn new(
251 redirect_uris: Vec<Uri<String>>,
252 client_id: Uri<String>,
253 scopes: Option<Scopes<S>>,
254 ) -> AtprotoClientMetadata<S>
255 where
256 S: From<SmolStr> + AsRef<str>,
257 {
258 let default_scopes: Scopes<S> = Scopes::new(SmolStr::new_static("atproto"))
259 .expect("valid scopes")
260 .convert();
261 AtprotoClientMetadata {
262 client_id: client_id.clone(),
263 client_uri: Some(client_id),
264 redirect_uris: redirect_uris,
265 grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken],
266 scopes: scopes.unwrap_or(default_scopes),
267 jwks_uri: None,
268 client_name: None,
269 logo_uri: None,
270 tos_uri: None,
271 privacy_policy_uri: None,
272 }
273 }
274
275 pub fn new_localhost(
282 redirect_uris: Option<Vec<Uri<String>>>,
283 scopes: Option<Scopes<S>>,
284 ) -> AtprotoClientMetadata<S>
285 where
286 S: From<SmolStr> + AsRef<str>,
287 {
288 #[derive(serde::Serialize)]
290 struct Parameters {
291 #[serde(skip_serializing_if = "Option::is_none")]
292 redirect_uri: Option<Vec<SmolStr>>,
293 #[serde(skip_serializing_if = "Option::is_none")]
294 scope: Option<SmolStr>,
295 }
296 let redir_str = redirect_uris.as_ref().map(|uris| {
297 uris.iter()
298 .map(|u| u.as_str().trim_end_matches("/").to_smolstr())
299 .collect()
300 });
301 let query = serde_html_form::to_string(Parameters {
302 redirect_uri: redir_str,
303 scope: scopes.as_ref().map(|s| s.to_normalized_string()),
304 })
305 .ok();
306 let mut client_id = String::from("http://localhost/");
307 if let Some(query) = query
308 && !query.is_empty()
309 {
310 client_id.push_str(&format!("?{query}"));
311 }
312 let default_scopes: Scopes<S> = Scopes::new(SmolStr::new_static("atproto"))
313 .expect("valid scopes")
314 .convert();
315 AtprotoClientMetadata {
316 client_id: Uri::parse(client_id).unwrap(),
317 client_uri: None,
318 redirect_uris: redirect_uris.unwrap_or(vec![
319 Uri::parse("http://127.0.0.1".to_string()).unwrap(),
320 Uri::parse("http://[::1]".to_string()).unwrap(),
321 ]),
322 grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken],
323 scopes: scopes.unwrap_or(default_scopes),
324 jwks_uri: None,
325 client_name: None,
326 logo_uri: None,
327 tos_uri: None,
328 privacy_policy_uri: None,
329 }
330 }
331}
332
333pub fn atproto_client_metadata<S>(
340 metadata: &AtprotoClientMetadata<S>,
341 keyset: &Option<Keyset>,
342) -> Result<OAuthClientMetadata<S>>
343where
344 S: BosStr + Ord + FromStr + Clone,
345 <S as FromStr>::Err: core::fmt::Debug,
346{
347 let is_loopback = metadata.client_id.scheme().as_str() == "http"
348 && metadata.client_id.authority().map(|a| a.host()) == Some("localhost");
349 let application_type = if is_loopback {
350 Some(S::from_static("native"))
351 } else {
352 Some(S::from_static("web"))
353 };
354 if metadata.redirect_uris.is_empty() {
355 return Err(Error::EmptyRedirectUris);
356 }
357 if !metadata.grant_types.contains(&GrantType::AuthorizationCode) {
358 return Err(Error::InvalidGrantTypes);
359 }
360 if !metadata.scopes.grants(&Scope::<S>::Atproto) {
361 return Err(Error::InvalidScope);
362 }
363 let (auth_method, jwks_uri, jwks) = if let Some(keyset) = keyset {
364 let jwks = if metadata.jwks_uri.is_none() {
365 Some(keyset.public_jwks())
366 } else {
367 None
368 };
369 (AuthMethod::PrivateKeyJwt, metadata.jwks_uri.as_ref(), jwks)
370 } else {
371 (AuthMethod::None, None, None)
372 };
373 let client_id = metadata.client_id.as_str();
374 let client_uri = metadata
375 .client_uri
376 .as_ref()
377 .and_then(|u| S::from_str(u.as_str()).ok());
378 let redirect_uris = metadata
379 .redirect_uris
380 .iter()
381 .filter_map(|u| S::from_str(u.as_str()).ok())
382 .collect();
383 let jwks_uri = jwks_uri.as_ref().and_then(|u| S::from_str(u.as_str()).ok());
384 Ok(OAuthClientMetadata {
385 client_id: S::from_str(client_id).unwrap(),
386 client_uri,
387 redirect_uris,
388 application_type: application_type,
389 token_endpoint_auth_method: Some(S::from_static(auth_method.into())),
390 grant_types: Some(
391 metadata
392 .grant_types
393 .iter()
394 .map(|v| S::from_static(v.clone().into()))
395 .collect(),
396 ),
397 response_types: vec![S::from_static("code")],
398 scope: Some(S::from_str(metadata.scopes.to_normalized_string().as_str()).unwrap()),
399 dpop_bound_access_tokens: Some(true),
400 jwks_uri,
401 jwks,
402 token_endpoint_auth_signing_alg: if keyset.is_some() {
403 Some(S::from_static("ES256"))
404 } else {
405 None
406 },
407 client_name: metadata.client_name.as_ref().map(|c| c.clone()),
408 logo_uri: metadata
409 .logo_uri
410 .as_ref()
411 .and_then(|u| S::from_str(u.as_str()).ok()),
412 tos_uri: metadata
413 .tos_uri
414 .as_ref()
415 .and_then(|u| S::from_str(u.as_str()).ok()),
416 privacy_policy_uri: metadata
417 .privacy_policy_uri
418 .as_ref()
419 .and_then(|u| S::from_str(u.as_str()).ok()),
420 })
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use elliptic_curve::SecretKey;
    use jose_jwk::{Jwk, Key, Parameters};
    use p256::pkcs8::DecodePrivateKey;

    /// P-256 private key (PKCS#8 PEM) used to build a test keyset.
    const PRIVATE_KEY: &str = r#"-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgED1AAgC7Fc9kPh5T
4i4Tn+z+tc47W1zYgzXtyjJtD92hRANCAAT80DqC+Z/JpTO7/pkPBmWqIV1IGh1P
gbGGr0pN+oSing7cZ0169JaRHTNh+0LNQXrFobInX6cj95FzEdRyT4T3
-----END PRIVATE KEY-----"#;

    /// Expected document for a loopback client built without a keyset.
    ///
    /// Only `client_id`, `redirect_uris` and `scope` vary between the
    /// loopback tests; everything else is fixed.
    fn expected_loopback(
        client_id: &'static str,
        redirect_uris: &[&'static str],
        scope: &'static str,
    ) -> OAuthClientMetadata<SmolStr> {
        OAuthClientMetadata {
            client_id: SmolStr::new_static(client_id),
            client_uri: None,
            redirect_uris: redirect_uris
                .iter()
                .copied()
                .map(SmolStr::new_static)
                .collect(),
            application_type: Some(SmolStr::new_static("native")),
            scope: Some(SmolStr::new_static(scope)),
            grant_types: Some(vec![
                SmolStr::new_static("authorization_code"),
                SmolStr::new_static("refresh_token"),
            ]),
            response_types: vec![SmolStr::new_static("code")],
            token_endpoint_auth_method: Some(AuthMethod::None.into()),
            dpop_bound_access_tokens: Some(true),
            jwks_uri: None,
            jwks: None,
            token_endpoint_auth_signing_alg: None,
            tos_uri: None,
            privacy_policy_uri: None,
            client_name: None,
            logo_uri: None,
        }
    }

    #[test]
    fn test_localhost_client_metadata_default() {
        assert_eq!(
            atproto_client_metadata(&AtprotoClientMetadata::new_localhost(None, None), &None)
                .unwrap(),
            expected_loopback(
                "http://localhost/",
                &["http://127.0.0.1", "http://[::1]"],
                "atproto",
            )
        );
    }

    #[test]
    fn test_localhost_client_metadata_custom() {
        assert_eq!(
            atproto_client_metadata(
                &AtprotoClientMetadata::new_localhost(
                    Some(vec![
                        Uri::parse("http://127.0.0.1/callback".to_string()).unwrap(),
                        Uri::parse("http://[::1]/callback".to_string()).unwrap(),
                    ]),
                    Some(
                        Scopes::new(SmolStr::from("account:email atproto transition:generic"))
                            .unwrap()
                    )
                ),
                &None
            )
            .expect("failed to convert metadata"),
            expected_loopback(
                "http://localhost/?redirect_uri=http%3A%2F%2F127.0.0.1%2Fcallback&redirect_uri=http%3A%2F%2F%5B%3A%3A1%5D%2Fcallback&scope=account%3Aemail+atproto+transition%3Ageneric",
                &["http://127.0.0.1/callback", "http://[::1]/callback"],
                "account:email atproto transition:generic",
            )
        );
    }

    /// Redirect URIs that break the loopback rules named by
    /// `LocalhostClientError` (wrong scheme, `localhost` host, non-loopback
    /// address). The builder performs no validation, so this pins the current
    /// behaviour: each URI is passed through verbatim, with only the trailing
    /// `/` stripped inside the `client_id` query.
    #[test]
    fn test_localhost_client_metadata_invalid() {
        // (redirect URI handed to the builder, expected `client_id`)
        let cases = [
            (
                "https://127.0.0.1",
                "http://localhost/?redirect_uri=https%3A%2F%2F127.0.0.1",
            ),
            (
                "http://localhost:8000",
                "http://localhost/?redirect_uri=http%3A%2F%2Flocalhost%3A8000",
            ),
            (
                "http://192.168.0.0/",
                "http://localhost/?redirect_uri=http%3A%2F%2F192.168.0.0",
            ),
        ];
        for (redirect_uri, expected_client_id) in cases {
            let out = atproto_client_metadata(
                &AtprotoClientMetadata::new_localhost(
                    Some(vec![Uri::parse(redirect_uri.to_string()).unwrap()]),
                    None,
                ),
                &None,
            )
            .expect("non-conforming redirect_uri should still convert");
            assert_eq!(
                out,
                expected_loopback(expected_client_id, &[redirect_uri], "atproto")
            );
        }
    }

    #[test]
    fn test_client_metadata() {
        let metadata = AtprotoClientMetadata {
            client_id: Uri::parse("https://example.com/client_metadata.json".to_string()).unwrap(),
            client_uri: Some(Uri::parse("https://example.com".to_string()).unwrap()),
            redirect_uris: vec![Uri::parse("https://example.com/callback".to_string()).unwrap()],
            grant_types: vec![GrantType::AuthorizationCode],
            scopes: Scopes::new(SmolStr::new_static("atproto")).unwrap(),
            jwks_uri: None,
            client_name: None,
            logo_uri: None,
            tos_uri: None,
            privacy_policy_uri: None,
        };

        // Without a keyset the conversion succeeds as a public client.
        let public = atproto_client_metadata(&metadata, &None);
        assert!(public.is_ok());

        // With a keyset the client uses `private_key_jwt` and inlines its JWKS.
        let secret_key = SecretKey::<p256::NistP256>::from_pkcs8_pem(PRIVATE_KEY)
            .expect("failed to parse private key");
        let keys = vec![Jwk {
            key: Key::from(&secret_key.into()),
            prm: Parameters {
                kid: Some(String::from("kid00")),
                ..Default::default()
            },
        }];
        let keyset = Keyset::try_from(keys).expect("failed to create keyset");
        let expected_jwks = keyset.public_jwks();
        assert_eq!(
            atproto_client_metadata(&metadata, &Some(keyset))
                .expect("failed to convert metadata"),
            OAuthClientMetadata {
                client_id: SmolStr::new_static("https://example.com/client_metadata.json"),
                client_uri: Some(SmolStr::new_static("https://example.com")),
                redirect_uris: vec![SmolStr::new_static("https://example.com/callback")],
                application_type: Some(SmolStr::new_static("web")),
                scope: Some(SmolStr::new_static("atproto")),
                grant_types: Some(vec![SmolStr::new_static("authorization_code")]),
                token_endpoint_auth_method: Some(AuthMethod::PrivateKeyJwt.into()),
                dpop_bound_access_tokens: Some(true),
                response_types: vec![SmolStr::new_static("code")],
                jwks_uri: None,
                jwks: Some(expected_jwks),
                token_endpoint_auth_signing_alg: Some(SmolStr::new_static("ES256")),
                client_name: None,
                logo_uri: None,
                tos_uri: None,
                privacy_policy_uri: None,
            }
        );
    }
}