Skip to main content

burn_p2p_auth_github/
lib.rs

1//! GitHub-backed identity connector implementations for burn_p2p.
2#![forbid(unsafe_code)]
3
4use std::collections::BTreeMap;
5
6use burn_p2p_auth_external::ProviderMappedIdentityConnector;
7use burn_p2p_core::{AuthProvider, PrincipalId};
8use burn_p2p_security::{
9    AuthError, CallbackPayload, IdentityConnector, LoginRequest, LoginStart, PrincipalClaims,
10    PrincipalSession, StaticPrincipalRecord,
11};
12use chrono::Duration;
13
14#[derive(Debug)]
15/// Represents a git hub identity connector.
16pub struct GitHubIdentityConnector(ProviderMappedIdentityConnector);
17
18impl GitHubIdentityConnector {
19    /// Creates a new value.
20    pub fn new(
21        session_ttl: Duration,
22        principals: BTreeMap<PrincipalId, StaticPrincipalRecord>,
23        authorize_base_url: Option<String>,
24    ) -> Self {
25        Self(
26            ProviderMappedIdentityConnector::new(
27            AuthProvider::GitHub,
28            session_ttl,
29            principals,
30            authorize_base_url.or_else(|| Some("https://github.com/login/oauth/authorize".into())),
31        )
32        .with_token_url(Some("https://github.com/login/oauth/access_token".into()))
33        .with_userinfo_url(Some("https://api.github.com/user".into()))
34        .with_github_orgs_url(Some("https://api.github.com/user/orgs?per_page=100".into()))
35        .with_github_teams_url(Some("https://api.github.com/user/teams?per_page=100".into()))
36        .with_github_repo_access_url(Some(
37            "https://api.github.com/user/repos?per_page=100&affiliation=owner,collaborator,organization_member"
38                .into(),
39        )),
40        )
41    }
42
43    /// Returns a copy configured with the GitHub API base URL used for live
44    /// user, org, team, and repository-access enrichment.
45    pub fn with_api_base_url(mut self, api_base_url: Option<String>) -> Self {
46        if let Some(api_base_url) = api_base_url {
47            let api_base_url = api_base_url.trim_end_matches('/').to_owned();
48            self.0 = self
49                .0
50                .with_userinfo_url(Some(format!("{api_base_url}/user")))
51                .with_github_orgs_url(Some(format!("{api_base_url}/user/orgs?per_page=100")))
52                .with_github_teams_url(Some(format!("{api_base_url}/user/teams?per_page=100")))
53                .with_github_repo_access_url(Some(format!(
54                    "{api_base_url}/user/repos?per_page=100&affiliation=owner,collaborator,organization_member"
55                )));
56        }
57        self
58    }
59
60    /// Returns a copy configured with the exchange URL.
61    pub fn with_exchange_url(mut self, exchange_url: Option<String>) -> Self {
62        self.0 = self.0.with_exchange_url(exchange_url);
63        self
64    }
65
66    /// Returns a copy configured with the token URL.
67    pub fn with_token_url(mut self, token_url: Option<String>) -> Self {
68        self.0 = self.0.with_token_url(token_url);
69        self
70    }
71
72    /// Returns a copy configured with the client credentials.
73    pub fn with_client_credentials(
74        mut self,
75        client_id: Option<String>,
76        client_secret: Option<String>,
77    ) -> Self {
78        self.0 = self.0.with_client_credentials(client_id, client_secret);
79        self
80    }
81
82    /// Returns a copy configured with an explicit redirect URI.
83    pub fn with_redirect_uri(mut self, redirect_uri: Option<String>) -> Self {
84        self.0 = self.0.with_redirect_uri(redirect_uri);
85        self
86    }
87
88    /// Returns a copy configured with the userinfo URL.
89    pub fn with_userinfo_url(mut self, userinfo_url: Option<String>) -> Self {
90        self.0 = self.0.with_userinfo_url(userinfo_url);
91        self
92    }
93
94    /// Returns a copy configured with the refresh URL.
95    pub fn with_refresh_url(mut self, refresh_url: Option<String>) -> Self {
96        self.0 = self.0.with_refresh_url(refresh_url);
97        self
98    }
99
100    /// Returns a copy configured with the revoke URL.
101    pub fn with_revoke_url(mut self, revoke_url: Option<String>) -> Self {
102        self.0 = self.0.with_revoke_url(revoke_url);
103        self
104    }
105
106    /// Returns a copy configured with a JWKS endpoint for oidc-style token
107    /// validation when a GitHub-compatible provider exposes one.
108    pub fn with_jwks_url(mut self, jwks_url: Option<String>) -> Self {
109        self.0 = self.0.with_jwks_url(jwks_url);
110        self
111    }
112
113    /// Returns a copy configured to persist upstream provider bearer/session
114    /// material in durable connector state.
115    pub fn with_persist_remote_tokens(mut self, persist_remote_tokens: bool) -> Self {
116        self.0 = self.0.with_persist_remote_tokens(persist_remote_tokens);
117        self
118    }
119}
120
121impl IdentityConnector for GitHubIdentityConnector {
122    fn begin_login(&self, req: LoginRequest) -> Result<LoginStart, AuthError> {
123        self.0.begin_login(req)
124    }
125
126    fn complete_login(&self, callback: CallbackPayload) -> Result<PrincipalSession, AuthError> {
127        self.0.complete_login(callback)
128    }
129
130    fn refresh(&self, session: &PrincipalSession) -> Result<PrincipalSession, AuthError> {
131        self.0.refresh(session)
132    }
133
134    fn fetch_claims(&self, session: &PrincipalSession) -> Result<PrincipalClaims, AuthError> {
135        self.0.fetch_claims(session)
136    }
137
138    fn revoke(&self, session: &PrincipalSession) -> Result<(), AuthError> {
139        self.0.revoke(session)
140    }
141}
142
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use burn_p2p_core::{
        AuthProvider, ExperimentScope, NetworkId, PeerRole, PeerRoleSet, PrincipalId,
    };
    use burn_p2p_security::{
        CallbackPayload, IdentityConnector, LoginRequest, PrincipalClaims, StaticPrincipalRecord,
    };
    use chrono::{Duration, Utc};

    use crate::GitHubIdentityConnector;

    /// Builds the static record for principal `alice`.
    ///
    /// Alice is a GitHub trainer in the `contributors` group. She is allowed
    /// on `network-a` only, and her claims are valid for one hour from now.
    fn alice_record() -> StaticPrincipalRecord {
        let issued_at = Utc::now();
        let claims = PrincipalClaims {
            principal_id: PrincipalId::new("alice"),
            provider: AuthProvider::GitHub,
            display_name: "Alice".into(),
            org_memberships: BTreeSet::new(),
            group_memberships: BTreeSet::from(["contributors".into()]),
            granted_roles: PeerRoleSet::new([PeerRole::TrainerGpu]),
            granted_scopes: BTreeSet::from([ExperimentScope::Connect]),
            custom_claims: BTreeMap::new(),
            issued_at,
            expires_at: issued_at + Duration::hours(1),
        };
        StaticPrincipalRecord {
            claims,
            allowed_networks: BTreeSet::from([NetworkId::new("network-a")]),
        }
    }

    #[test]
    fn github_connector_issues_github_sessions() {
        let principals = BTreeMap::from([(PrincipalId::new("alice"), alice_record())]);
        let connector = GitHubIdentityConnector::new(
            Duration::minutes(10),
            principals,
            Some("https://github.example/authorize".into()),
        );

        // Start a login for alice on the network she is allowed to join.
        let request = LoginRequest {
            network_id: NetworkId::new("network-a"),
            principal_hint: Some("alice".into()),
            requested_scopes: BTreeSet::from([ExperimentScope::Connect]),
        };
        let login = connector.begin_login(request).expect("login");
        assert_eq!(login.provider, AuthProvider::GitHub);

        // Echo the login id/state back as the provider callback would.
        let callback = CallbackPayload {
            login_id: login.login_id,
            state: login.state,
            principal_id: Some(PrincipalId::new("alice")),
            provider_code: None,
        };
        let session = connector.complete_login(callback).expect("session");
        assert_eq!(session.claims.provider, AuthProvider::GitHub);
    }
}