burn_p2p_auth_github 0.21.0-pre.25

Optional GitHub auth connector for burn_p2p.
Documentation
//! GitHub-backed identity connector implementations for burn_p2p.
#![forbid(unsafe_code)]

use std::collections::BTreeMap;

use burn_p2p_auth_external::ProviderMappedIdentityConnector;
use burn_p2p_core::{AuthProvider, PrincipalId};
use burn_p2p_security::{
    AuthError, CallbackPayload, IdentityConnector, LoginRequest, LoginStart, PrincipalClaims,
    PrincipalSession, StaticPrincipalRecord,
};
use chrono::Duration;

/// GitHub identity connector.
///
/// A newtype over [`ProviderMappedIdentityConnector`]; every operation is
/// forwarded to the wrapped connector.
#[derive(Debug)]
pub struct GitHubIdentityConnector(ProviderMappedIdentityConnector);

impl GitHubIdentityConnector {
    /// Builds a connector bound to [`AuthProvider::GitHub`].
    ///
    /// `session_ttl` and `principals` are passed to the wrapped
    /// [`ProviderMappedIdentityConnector`] unchanged. When `authorize_base_url`
    /// is `None`, `https://github.com/login/oauth/authorize` is used instead.
    /// The token, userinfo, orgs, teams, and repository-access URLs are
    /// initialised to the public `github.com` / `api.github.com` endpoints.
    pub fn new(
        session_ttl: Duration,
        principals: BTreeMap<PrincipalId, StaticPrincipalRecord>,
        authorize_base_url: Option<String>,
    ) -> Self {
        let authorize_url = authorize_base_url
            .unwrap_or_else(|| "https://github.com/login/oauth/authorize".into());

        let inner = ProviderMappedIdentityConnector::new(
            AuthProvider::GitHub,
            session_ttl,
            principals,
            Some(authorize_url),
        )
        .with_token_url(Some("https://github.com/login/oauth/access_token".into()))
        .with_userinfo_url(Some("https://api.github.com/user".into()))
        .with_github_orgs_url(Some("https://api.github.com/user/orgs?per_page=100".into()))
        .with_github_teams_url(Some("https://api.github.com/user/teams?per_page=100".into()))
        .with_github_repo_access_url(Some(
            "https://api.github.com/user/repos?per_page=100&affiliation=owner,collaborator,organization_member"
                .into(),
        ));

        Self(inner)
    }

    /// Returns the connector with its user, org, team, and repository-access
    /// URLs rebuilt on top of `api_base_url`.
    ///
    /// Trailing `/` characters are stripped from the base before the paths are
    /// appended. Passing `None` returns the connector unchanged.
    pub fn with_api_base_url(self, api_base_url: Option<String>) -> Self {
        let base = match api_base_url.as_deref() {
            Some(raw) => raw.trim_end_matches('/'),
            None => return self,
        };

        let userinfo_url = format!("{base}/user");
        let orgs_url = format!("{base}/user/orgs?per_page=100");
        let teams_url = format!("{base}/user/teams?per_page=100");
        let repo_access_url = format!(
            "{base}/user/repos?per_page=100&affiliation=owner,collaborator,organization_member"
        );

        Self(
            self.0
                .with_userinfo_url(Some(userinfo_url))
                .with_github_orgs_url(Some(orgs_url))
                .with_github_teams_url(Some(teams_url))
                .with_github_repo_access_url(Some(repo_access_url)),
        )
    }

    /// Returns the connector with the exchange URL replaced by `exchange_url`.
    pub fn with_exchange_url(self, exchange_url: Option<String>) -> Self {
        Self(self.0.with_exchange_url(exchange_url))
    }

    /// Returns the connector with the token URL replaced by `token_url`.
    pub fn with_token_url(self, token_url: Option<String>) -> Self {
        Self(self.0.with_token_url(token_url))
    }

    /// Returns the connector with the given client credentials.
    pub fn with_client_credentials(
        self,
        client_id: Option<String>,
        client_secret: Option<String>,
    ) -> Self {
        Self(self.0.with_client_credentials(client_id, client_secret))
    }

    /// Returns the connector with an explicit redirect URI.
    pub fn with_redirect_uri(self, redirect_uri: Option<String>) -> Self {
        Self(self.0.with_redirect_uri(redirect_uri))
    }

    /// Returns the connector with the userinfo URL replaced by `userinfo_url`.
    pub fn with_userinfo_url(self, userinfo_url: Option<String>) -> Self {
        Self(self.0.with_userinfo_url(userinfo_url))
    }

    /// Returns the connector with the refresh URL replaced by `refresh_url`.
    pub fn with_refresh_url(self, refresh_url: Option<String>) -> Self {
        Self(self.0.with_refresh_url(refresh_url))
    }

    /// Returns the connector with the revoke URL replaced by `revoke_url`.
    pub fn with_revoke_url(self, revoke_url: Option<String>) -> Self {
        Self(self.0.with_revoke_url(revoke_url))
    }

    /// Returns the connector configured with a JWKS endpoint for OIDC-style
    /// token validation when a GitHub-compatible provider exposes one.
    pub fn with_jwks_url(self, jwks_url: Option<String>) -> Self {
        Self(self.0.with_jwks_url(jwks_url))
    }

    /// Returns the connector configured to persist (or not persist) upstream
    /// provider bearer/session material in durable connector state.
    pub fn with_persist_remote_tokens(self, persist_remote_tokens: bool) -> Self {
        Self(self.0.with_persist_remote_tokens(persist_remote_tokens))
    }
}

/// Every trait method forwards to the wrapped connector without adding any
/// logic of its own.
impl IdentityConnector for GitHubIdentityConnector {
    /// Forwards the login request to the wrapped connector.
    fn begin_login(&self, req: LoginRequest) -> Result<LoginStart, AuthError> {
        let Self(inner) = self;
        inner.begin_login(req)
    }

    /// Forwards the callback payload to the wrapped connector.
    fn complete_login(&self, callback: CallbackPayload) -> Result<PrincipalSession, AuthError> {
        let Self(inner) = self;
        inner.complete_login(callback)
    }

    /// Forwards the session refresh to the wrapped connector.
    fn refresh(&self, session: &PrincipalSession) -> Result<PrincipalSession, AuthError> {
        let Self(inner) = self;
        inner.refresh(session)
    }

    /// Forwards the claims lookup to the wrapped connector.
    fn fetch_claims(&self, session: &PrincipalSession) -> Result<PrincipalClaims, AuthError> {
        let Self(inner) = self;
        inner.fetch_claims(session)
    }

    /// Forwards the session revocation to the wrapped connector.
    fn revoke(&self, session: &PrincipalSession) -> Result<(), AuthError> {
        let Self(inner) = self;
        inner.revoke(session)
    }
}

#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use burn_p2p_core::{
        AuthProvider, ExperimentScope, NetworkId, PeerRole, PeerRoleSet, PrincipalId,
    };
    use burn_p2p_security::{
        CallbackPayload, IdentityConnector, LoginRequest, PrincipalClaims, StaticPrincipalRecord,
    };
    use chrono::{Duration, Utc};

    use crate::GitHubIdentityConnector;

    /// Walks one statically configured principal through `begin_login` and
    /// `complete_login`, checking that both results report the GitHub provider.
    #[test]
    fn github_connector_issues_github_sessions() {
        let issued_at = Utc::now();

        // Static claims for the single principal known to the connector.
        let claims = PrincipalClaims {
            principal_id: PrincipalId::new("alice"),
            provider: AuthProvider::GitHub,
            display_name: "Alice".into(),
            org_memberships: BTreeSet::new(),
            group_memberships: BTreeSet::from(["contributors".into()]),
            granted_roles: PeerRoleSet::new([PeerRole::TrainerGpu]),
            granted_scopes: BTreeSet::from([ExperimentScope::Connect]),
            custom_claims: BTreeMap::new(),
            issued_at,
            expires_at: issued_at + Duration::hours(1),
        };
        let record = StaticPrincipalRecord {
            claims,
            allowed_networks: BTreeSet::from([NetworkId::new("network-a")]),
        };

        let mut principals = BTreeMap::new();
        principals.insert(PrincipalId::new("alice"), record);

        let connector = GitHubIdentityConnector::new(
            Duration::minutes(10),
            principals,
            Some("https://github.example/authorize".into()),
        );

        // Step 1: start the login flow for the known principal.
        let request = LoginRequest {
            network_id: NetworkId::new("network-a"),
            principal_hint: Some("alice".into()),
            requested_scopes: BTreeSet::from([ExperimentScope::Connect]),
        };
        let login = connector.begin_login(request).expect("login");
        assert_eq!(login.provider, AuthProvider::GitHub);

        // Step 2: complete the flow with the identifiers returned by step 1.
        let callback = CallbackPayload {
            login_id: login.login_id,
            state: login.state,
            principal_id: Some(PrincipalId::new("alice")),
            provider_code: None,
        };
        let session = connector.complete_login(callback).expect("session");
        assert_eq!(session.claims.provider, AuthProvider::GitHub);
    }
}