burn_p2p_auth_github/
lib.rs1#![forbid(unsafe_code)]
3
4use std::collections::BTreeMap;
5
6use burn_p2p_auth_external::ProviderMappedIdentityConnector;
7use burn_p2p_core::{AuthProvider, PrincipalId};
8use burn_p2p_security::{
9 AuthError, CallbackPayload, IdentityConnector, LoginRequest, LoginStart, PrincipalClaims,
10 PrincipalSession, StaticPrincipalRecord,
11};
12use chrono::Duration;
13
14#[derive(Debug)]
15pub struct GitHubIdentityConnector(ProviderMappedIdentityConnector);
17
18impl GitHubIdentityConnector {
19 pub fn new(
21 session_ttl: Duration,
22 principals: BTreeMap<PrincipalId, StaticPrincipalRecord>,
23 authorize_base_url: Option<String>,
24 ) -> Self {
25 Self(
26 ProviderMappedIdentityConnector::new(
27 AuthProvider::GitHub,
28 session_ttl,
29 principals,
30 authorize_base_url.or_else(|| Some("https://github.com/login/oauth/authorize".into())),
31 )
32 .with_token_url(Some("https://github.com/login/oauth/access_token".into()))
33 .with_userinfo_url(Some("https://api.github.com/user".into()))
34 .with_github_orgs_url(Some("https://api.github.com/user/orgs?per_page=100".into()))
35 .with_github_teams_url(Some("https://api.github.com/user/teams?per_page=100".into()))
36 .with_github_repo_access_url(Some(
37 "https://api.github.com/user/repos?per_page=100&affiliation=owner,collaborator,organization_member"
38 .into(),
39 )),
40 )
41 }
42
43 pub fn with_api_base_url(mut self, api_base_url: Option<String>) -> Self {
46 if let Some(api_base_url) = api_base_url {
47 let api_base_url = api_base_url.trim_end_matches('/').to_owned();
48 self.0 = self
49 .0
50 .with_userinfo_url(Some(format!("{api_base_url}/user")))
51 .with_github_orgs_url(Some(format!("{api_base_url}/user/orgs?per_page=100")))
52 .with_github_teams_url(Some(format!("{api_base_url}/user/teams?per_page=100")))
53 .with_github_repo_access_url(Some(format!(
54 "{api_base_url}/user/repos?per_page=100&affiliation=owner,collaborator,organization_member"
55 )));
56 }
57 self
58 }
59
60 pub fn with_exchange_url(mut self, exchange_url: Option<String>) -> Self {
62 self.0 = self.0.with_exchange_url(exchange_url);
63 self
64 }
65
66 pub fn with_token_url(mut self, token_url: Option<String>) -> Self {
68 self.0 = self.0.with_token_url(token_url);
69 self
70 }
71
72 pub fn with_client_credentials(
74 mut self,
75 client_id: Option<String>,
76 client_secret: Option<String>,
77 ) -> Self {
78 self.0 = self.0.with_client_credentials(client_id, client_secret);
79 self
80 }
81
82 pub fn with_redirect_uri(mut self, redirect_uri: Option<String>) -> Self {
84 self.0 = self.0.with_redirect_uri(redirect_uri);
85 self
86 }
87
88 pub fn with_userinfo_url(mut self, userinfo_url: Option<String>) -> Self {
90 self.0 = self.0.with_userinfo_url(userinfo_url);
91 self
92 }
93
94 pub fn with_refresh_url(mut self, refresh_url: Option<String>) -> Self {
96 self.0 = self.0.with_refresh_url(refresh_url);
97 self
98 }
99
100 pub fn with_revoke_url(mut self, revoke_url: Option<String>) -> Self {
102 self.0 = self.0.with_revoke_url(revoke_url);
103 self
104 }
105
106 pub fn with_jwks_url(mut self, jwks_url: Option<String>) -> Self {
109 self.0 = self.0.with_jwks_url(jwks_url);
110 self
111 }
112
113 pub fn with_persist_remote_tokens(mut self, persist_remote_tokens: bool) -> Self {
116 self.0 = self.0.with_persist_remote_tokens(persist_remote_tokens);
117 self
118 }
119}
120
121impl IdentityConnector for GitHubIdentityConnector {
122 fn begin_login(&self, req: LoginRequest) -> Result<LoginStart, AuthError> {
123 self.0.begin_login(req)
124 }
125
126 fn complete_login(&self, callback: CallbackPayload) -> Result<PrincipalSession, AuthError> {
127 self.0.complete_login(callback)
128 }
129
130 fn refresh(&self, session: &PrincipalSession) -> Result<PrincipalSession, AuthError> {
131 self.0.refresh(session)
132 }
133
134 fn fetch_claims(&self, session: &PrincipalSession) -> Result<PrincipalClaims, AuthError> {
135 self.0.fetch_claims(session)
136 }
137
138 fn revoke(&self, session: &PrincipalSession) -> Result<(), AuthError> {
139 self.0.revoke(session)
140 }
141}
142
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use burn_p2p_core::{
        AuthProvider, ExperimentScope, NetworkId, PeerRole, PeerRoleSet, PrincipalId,
    };
    use burn_p2p_security::{
        CallbackPayload, IdentityConnector, LoginRequest, PrincipalClaims, StaticPrincipalRecord,
    };
    use chrono::{Duration, Utc};

    use crate::GitHubIdentityConnector;

    /// Runs a full begin/complete login round-trip for one statically
    /// registered principal and checks that both the login start and the
    /// resulting session are tagged with the GitHub provider.
    #[test]
    fn github_connector_issues_github_sessions() {
        let now = Utc::now();
        // Fresh ids are built on demand so the fixture needs no `Clone`.
        let alice = || PrincipalId::new("alice");
        let network_a = || NetworkId::new("network-a");

        let claims = PrincipalClaims {
            principal_id: alice(),
            provider: AuthProvider::GitHub,
            display_name: "Alice".into(),
            org_memberships: BTreeSet::new(),
            group_memberships: BTreeSet::from(["contributors".into()]),
            granted_roles: PeerRoleSet::new([PeerRole::TrainerGpu]),
            granted_scopes: BTreeSet::from([ExperimentScope::Connect]),
            custom_claims: BTreeMap::new(),
            issued_at: now,
            expires_at: now + Duration::hours(1),
        };
        let record = StaticPrincipalRecord {
            claims,
            allowed_networks: BTreeSet::from([network_a()]),
        };
        let principals = BTreeMap::from([(alice(), record)]);

        let connector = GitHubIdentityConnector::new(
            Duration::minutes(10),
            principals,
            Some("https://github.example/authorize".into()),
        );

        // Step 1: start the login for the registered principal.
        let request = LoginRequest {
            network_id: network_a(),
            principal_hint: Some("alice".into()),
            requested_scopes: BTreeSet::from([ExperimentScope::Connect]),
        };
        let login = connector.begin_login(request).expect("login");
        assert_eq!(login.provider, AuthProvider::GitHub);

        // Step 2: finish it with the id/state issued in step 1 and no provider code.
        let callback = CallbackPayload {
            login_id: login.login_id,
            state: login.state,
            principal_id: Some(alice()),
            provider_code: None,
        };
        let session = connector.complete_login(callback).expect("session");
        assert_eq!(session.claims.provider, AuthProvider::GitHub);
    }
}