atrium_oauth/
atproto.rs

1use crate::keyset::Keyset;
2use crate::types::{OAuthClientMetadata, TryIntoOAuthClientMetadata};
3use atrium_xrpc::http::uri::{InvalidUri, Scheme, Uri};
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7#[derive(Error, Debug)]
8pub enum Error {
9    #[error("`client_id` must be a valid URL")]
10    InvalidClientId,
11    #[error("`grant_types` must include `authorization_code`")]
12    InvalidGrantTypes,
13    #[error("`scope` must not include `atproto`")]
14    InvalidScope,
15    #[error("`redirect_uris` must not be empty")]
16    EmptyRedirectUris,
17    #[error("`private_key_jwt` auth method requires `jwks` keys")]
18    EmptyJwks,
19    #[error("`private_key_jwt` auth method requires `token_endpoint_auth_signing_alg`, otherwise must not be provided")]
20    AuthSigningAlg,
21    #[error(transparent)]
22    SerdeHtmlForm(#[from] serde_html_form::ser::Error),
23    #[error(transparent)]
24    LocalhostClient(#[from] LocalhostClientError),
25}
26
27#[derive(Error, Debug)]
28pub enum LocalhostClientError {
29    #[error("invalid redirect_uri: {0}")]
30    Invalid(#[from] InvalidUri),
31    #[error("loopback client_id must use `http:` redirect_uri")]
32    NotHttpScheme,
33    #[error("loopback client_id must not use `localhost` as redirect_uri hostname")]
34    Localhost,
35    #[error("loopback client_id must not use loopback addresses as redirect_uri")]
36    NotLoopbackHost,
37}
38
39pub type Result<T> = core::result::Result<T, Error>;
40
41#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
42#[serde(rename_all = "snake_case")]
43pub enum AuthMethod {
44    None,
45    // https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication
46    PrivateKeyJwt,
47}
48
49impl From<AuthMethod> for String {
50    fn from(value: AuthMethod) -> Self {
51        match value {
52            AuthMethod::None => String::from("none"),
53            AuthMethod::PrivateKeyJwt => String::from("private_key_jwt"),
54        }
55    }
56}
57
58#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
59#[serde(rename_all = "snake_case")]
60pub enum GrantType {
61    AuthorizationCode,
62    RefreshToken,
63}
64
65impl From<GrantType> for String {
66    fn from(value: GrantType) -> Self {
67        match value {
68            GrantType::AuthorizationCode => String::from("authorization_code"),
69            GrantType::RefreshToken => String::from("refresh_token"),
70        }
71    }
72}
73
74#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
75#[serde(untagged)]
76pub enum Scope {
77    Known(KnownScope),
78    Unknown(String),
79}
80
81#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
82pub enum KnownScope {
83    #[serde(rename = "atproto")]
84    Atproto,
85    #[serde(rename = "transition:generic")]
86    TransitionGeneric,
87    #[serde(rename = "transition:chat.bsky")]
88    TransitionChatBsky,
89}
90
91impl AsRef<str> for Scope {
92    fn as_ref(&self) -> &str {
93        match self {
94            Self::Known(KnownScope::Atproto) => "atproto",
95            Self::Known(KnownScope::TransitionGeneric) => "transition:generic",
96            Self::Known(KnownScope::TransitionChatBsky) => "transition:chat.bsky",
97            Self::Unknown(value) => value,
98        }
99    }
100}
101
102#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
103pub struct AtprotoLocalhostClientMetadata {
104    pub redirect_uris: Option<Vec<String>>,
105    pub scopes: Option<Vec<Scope>>,
106}
107
108#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
109pub struct AtprotoClientMetadata {
110    pub client_id: String,
111    pub client_uri: Option<String>,
112    pub redirect_uris: Vec<String>,
113    pub token_endpoint_auth_method: AuthMethod,
114    pub grant_types: Vec<GrantType>,
115    pub scopes: Vec<Scope>,
116    pub jwks_uri: Option<String>,
117    pub token_endpoint_auth_signing_alg: Option<String>,
118}
119
120impl TryIntoOAuthClientMetadata for AtprotoLocalhostClientMetadata {
121    type Error = Error;
122
123    fn try_into_client_metadata(self, _: &Option<Keyset>) -> Result<OAuthClientMetadata> {
124        // validate redirect_uris
125        if let Some(redirect_uris) = &self.redirect_uris {
126            for redirect_uri in redirect_uris {
127                let uri = redirect_uri.parse::<Uri>().map_err(LocalhostClientError::Invalid)?;
128                if uri.scheme() != Some(&Scheme::HTTP) {
129                    return Err(Error::LocalhostClient(LocalhostClientError::NotHttpScheme));
130                }
131                if uri.host() == Some("localhost") {
132                    return Err(Error::LocalhostClient(LocalhostClientError::Localhost));
133                }
134                if uri.host().map_or(true, |host| host != "127.0.0.1" && host != "[::1]") {
135                    return Err(Error::LocalhostClient(LocalhostClientError::NotLoopbackHost));
136                }
137            }
138        }
139        // determine client_id
140        #[derive(serde::Serialize)]
141        struct Parameters {
142            #[serde(skip_serializing_if = "Option::is_none")]
143            redirect_uri: Option<Vec<String>>,
144            #[serde(skip_serializing_if = "Option::is_none")]
145            scope: Option<String>,
146        }
147        let query = serde_html_form::to_string(Parameters {
148            redirect_uri: self.redirect_uris.clone(),
149            scope: self
150                .scopes
151                .map(|scopes| scopes.iter().map(AsRef::as_ref).collect::<Vec<_>>().join(" ")),
152        })?;
153        let mut client_id = String::from("http://localhost");
154        if !query.is_empty() {
155            client_id.push_str(&format!("?{query}"));
156        }
157        Ok(OAuthClientMetadata {
158            client_id,
159            client_uri: None,
160            redirect_uris: self
161                .redirect_uris
162                .unwrap_or(vec![String::from("http://127.0.0.1/"), String::from("http://[::1]/")]),
163            scope: None,
164            grant_types: None, // will be set to `authorization_code` and `refresh_token`
165            token_endpoint_auth_method: Some(String::from("none")),
166            dpop_bound_access_tokens: None, // will be set to `true`
167            jwks_uri: None,
168            jwks: None,
169            token_endpoint_auth_signing_alg: None,
170        })
171    }
172}
173
174impl TryIntoOAuthClientMetadata for AtprotoClientMetadata {
175    type Error = Error;
176
177    fn try_into_client_metadata(self, keyset: &Option<Keyset>) -> Result<OAuthClientMetadata> {
178        if self.client_id.parse::<Uri>().is_err() {
179            return Err(Error::InvalidClientId);
180        }
181        if self.redirect_uris.is_empty() {
182            return Err(Error::EmptyRedirectUris);
183        }
184        if !self.grant_types.contains(&GrantType::AuthorizationCode) {
185            return Err(Error::InvalidGrantTypes);
186        }
187        if !self.scopes.contains(&Scope::Known(KnownScope::Atproto)) {
188            return Err(Error::InvalidScope);
189        }
190        let (jwks_uri, mut jwks) = (self.jwks_uri, None);
191        match self.token_endpoint_auth_method {
192            AuthMethod::None => {
193                if self.token_endpoint_auth_signing_alg.is_some() {
194                    return Err(Error::AuthSigningAlg);
195                }
196            }
197            AuthMethod::PrivateKeyJwt => {
198                if let Some(keyset) = keyset {
199                    if self.token_endpoint_auth_signing_alg.is_none() {
200                        return Err(Error::AuthSigningAlg);
201                    }
202                    if jwks_uri.is_none() {
203                        jwks = Some(keyset.public_jwks());
204                    }
205                } else {
206                    return Err(Error::EmptyJwks);
207                }
208            }
209        }
210        Ok(OAuthClientMetadata {
211            client_id: self.client_id,
212            client_uri: self.client_uri,
213            redirect_uris: self.redirect_uris,
214            token_endpoint_auth_method: Some(self.token_endpoint_auth_method.into()),
215            grant_types: Some(self.grant_types.into_iter().map(|v| v.into()).collect()),
216            scope: Some(self.scopes.iter().map(AsRef::as_ref).collect::<Vec<_>>().join(" ")),
217            dpop_bound_access_tokens: Some(true),
218            jwks_uri,
219            jwks,
220            token_endpoint_auth_signing_alg: self.token_endpoint_auth_signing_alg,
221        })
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use elliptic_curve::SecretKey;
    use jose_jwk::{Jwk, Key, Parameters};
    use p256::pkcs8::DecodePrivateKey;

    /// P-256 private key (PKCS#8 PEM) used to build a throwaway keyset.
    const PRIVATE_KEY: &str = r#"-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgED1AAgC7Fc9kPh5T
4i4Tn+z+tc47W1zYgzXtyjJtD92hRANCAAT80DqC+Z/JpTO7/pkPBmWqIV1IGh1P
gbGGr0pN+oSing7cZ0169JaRHTNh+0LNQXrFobInX6cj95FzEdRyT4T3
-----END PRIVATE KEY-----"#;

    /// Expected loopback output; only `client_id` and `redirect_uris` vary between cases.
    fn expected_loopback(client_id: &str, redirect_uris: &[&str]) -> OAuthClientMetadata {
        OAuthClientMetadata {
            client_id: client_id.to_owned(),
            client_uri: None,
            redirect_uris: redirect_uris.iter().map(|uri| uri.to_string()).collect(),
            scope: None,
            grant_types: None,
            token_endpoint_auth_method: Some(AuthMethod::None.into()),
            dpop_bound_access_tokens: None,
            jwks_uri: None,
            jwks: None,
            token_endpoint_auth_signing_alg: None,
        }
    }

    /// Converts loopback metadata with a single redirect URI and returns the expected error.
    fn loopback_error(redirect_uri: &str) -> Error {
        let metadata = AtprotoLocalhostClientMetadata {
            redirect_uris: Some(vec![redirect_uri.to_owned()]),
            ..Default::default()
        };
        metadata.try_into_client_metadata(&None).expect_err("expected to fail")
    }

    #[test]
    fn test_localhost_client_metadata_default() {
        let converted = AtprotoLocalhostClientMetadata::default()
            .try_into_client_metadata(&None)
            .expect("failed to convert metadata");
        assert_eq!(
            converted,
            expected_loopback("http://localhost", &["http://127.0.0.1/", "http://[::1]/"])
        );
    }

    #[test]
    fn test_localhost_client_metadata_custom() {
        let redirect_uris = ["http://127.0.0.1/callback", "http://[::1]/callback"];
        let metadata = AtprotoLocalhostClientMetadata {
            redirect_uris: Some(redirect_uris.iter().map(|uri| uri.to_string()).collect()),
            scopes: Some(vec![
                Scope::Known(KnownScope::Atproto),
                Scope::Known(KnownScope::TransitionGeneric),
                Scope::Unknown(String::from("unknown")),
            ]),
        };
        let converted =
            metadata.try_into_client_metadata(&None).expect("failed to convert metadata");
        assert_eq!(
            converted,
            expected_loopback(
                "http://localhost?redirect_uri=http%3A%2F%2F127.0.0.1%2Fcallback&redirect_uri=http%3A%2F%2F%5B%3A%3A1%5D%2Fcallback&scope=atproto+transition%3Ageneric+unknown",
                &redirect_uris,
            )
        );
    }

    #[test]
    fn test_localhost_client_metadata_invalid() {
        assert!(matches!(
            loopback_error("http://"),
            Error::LocalhostClient(LocalhostClientError::Invalid(_))
        ));
        assert!(matches!(
            loopback_error("https://127.0.0.1/"),
            Error::LocalhostClient(LocalhostClientError::NotHttpScheme)
        ));
        assert!(matches!(
            loopback_error("http://localhost:8000/"),
            Error::LocalhostClient(LocalhostClientError::Localhost)
        ));
        assert!(matches!(
            loopback_error("http://192.168.0.0/"),
            Error::LocalhostClient(LocalhostClientError::NotLoopbackHost)
        ));
    }

    #[test]
    fn test_client_metadata() {
        let metadata = AtprotoClientMetadata {
            client_id: String::from("https://example.com/client_metadata.json"),
            client_uri: Some(String::from("https://example.com")),
            redirect_uris: vec![String::from("https://example.com/callback")],
            token_endpoint_auth_method: AuthMethod::PrivateKeyJwt,
            grant_types: vec![GrantType::AuthorizationCode],
            scopes: vec![Scope::Known(KnownScope::Atproto)],
            jwks_uri: None,
            token_endpoint_auth_signing_alg: Some(String::from("ES256")),
        };

        // `private_key_jwt` without a keyset must be rejected.
        let err = metadata.clone().try_into_client_metadata(&None).expect_err("expected to fail");
        assert!(matches!(err, Error::EmptyJwks));

        // With a keyset and no `jwks_uri`, the public JWKS is embedded.
        let secret_key = SecretKey::<p256::NistP256>::from_pkcs8_pem(PRIVATE_KEY)
            .expect("failed to parse private key");
        let keys = vec![Jwk {
            key: Key::from(&secret_key.into()),
            prm: Parameters { kid: Some(String::from("kid00")), ..Default::default() },
        }];
        let keyset = Keyset::try_from(keys).expect("failed to create keyset");
        let converted = metadata
            .try_into_client_metadata(&Some(keyset.clone()))
            .expect("failed to convert metadata");
        assert_eq!(
            converted,
            OAuthClientMetadata {
                client_id: String::from("https://example.com/client_metadata.json"),
                client_uri: Some(String::from("https://example.com")),
                redirect_uris: vec![String::from("https://example.com/callback")],
                scope: Some(String::from("atproto")),
                grant_types: Some(vec![String::from("authorization_code")]),
                token_endpoint_auth_method: Some(AuthMethod::PrivateKeyJwt.into()),
                dpop_bound_access_tokens: Some(true),
                jwks_uri: None,
                jwks: Some(keyset.public_jwks()),
                token_endpoint_auth_signing_alg: Some(String::from("ES256")),
            }
        );
    }

    #[test]
    fn test_scope_serde() {
        #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
        struct Scopes {
            scopes: Vec<Scope>,
        }

        let original = Scopes {
            scopes: vec![
                Scope::Known(KnownScope::Atproto),
                Scope::Known(KnownScope::TransitionGeneric),
                Scope::Unknown(String::from("unknown")),
            ],
        };
        let json = serde_json::to_string(&original).expect("failed to serialize scopes");
        assert_eq!(json, r#"{"scopes":["atproto","transition:generic","unknown"]}"#);
        let round_tripped =
            serde_json::from_str::<Scopes>(&json).expect("failed to deserialize scopes");
        assert_eq!(round_tripped, original);
    }
}