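//! `agent_feed_auth_github` (`lib.rs`)
//!
//! GitHub sign-in helpers for the feed: CLI loopback login (authorize URL,
//! callback parsing, state check), browser sign-in URLs, and local session storage.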

use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fs;
use std::io::Read;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use time::{Duration, OffsetDateTime};

9#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
10pub struct GithubProfile {
11    pub id: String,
12    pub login: String,
13    pub name: Option<String>,
14    pub avatar_url: Option<String>,
15}
16
17#[derive(Debug, thiserror::Error)]
18pub enum GithubAuthError {
19    #[error("github auth callback path was not /callback/github")]
20    InvalidCallbackPath,
21    #[error("github auth callback state mismatch")]
22    StateMismatch,
23    #[error("github auth callback was missing {0}")]
24    MissingCallbackField(&'static str),
25    #[error("github auth callback contained invalid github_user_id: {0}")]
26    InvalidGithubUserId(String),
27    #[error("github auth callback contained invalid expires_at: {0}")]
28    InvalidExpiresAt(String),
29    #[error("io failed: {0}")]
30    Io(#[from] std::io::Error),
31    #[error("json failed: {0}")]
32    Json(#[from] serde_json::Error),
33}
34
35#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
36pub struct GithubCliAuthConfig {
37    pub edge_base_url: String,
38    pub callback_bind: SocketAddr,
39    pub callback_path: String,
40    pub requested_scopes: Vec<String>,
41    pub github_org: Option<String>,
42}
43
44impl Default for GithubCliAuthConfig {
45    fn default() -> Self {
46        Self {
47            edge_base_url: "https://api.feed.aberration.technology".to_string(),
48            callback_bind: SocketAddr::from(([127, 0, 0, 1], 0)),
49            callback_path: "/callback/github".to_string(),
50            requested_scopes: vec!["read:user".to_string()],
51            github_org: None,
52        }
53    }
54}
55
56#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
57pub struct GithubCliLoginStart {
58    pub authorize_url: String,
59    pub callback_url: String,
60    pub state: String,
61    pub bind: SocketAddr,
62}
63
64#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
65pub struct GithubCliCallback {
66    pub state: String,
67    pub github_user_id: u64,
68    pub login: String,
69    pub name: Option<String>,
70    pub avatar_url: Option<String>,
71    pub session_token: Option<String>,
72    pub scopes: Vec<String>,
73    pub github_orgs: Vec<String>,
74    pub expires_at: OffsetDateTime,
75}
76
77#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
78pub struct GithubAuthSession {
79    pub provider: String,
80    pub github_user_id: u64,
81    pub login: String,
82    pub name: Option<String>,
83    pub avatar_url: Option<String>,
84    pub session_token: Option<String>,
85    #[serde(default)]
86    pub scopes: Vec<String>,
87    #[serde(default)]
88    pub github_orgs: Vec<String>,
89    #[serde(with = "time::serde::rfc3339")]
90    pub issued_at: OffsetDateTime,
91    #[serde(with = "time::serde::rfc3339")]
92    pub expires_at: OffsetDateTime,
93    pub edge_base_url: String,
94}
95
96impl GithubAuthSession {
97    #[must_use]
98    pub fn profile(&self) -> GithubProfile {
99        GithubProfile {
100            id: self.github_user_id.to_string(),
101            login: self.login.clone(),
102            name: self.name.clone(),
103            avatar_url: self.avatar_url.clone(),
104        }
105    }
106
107    #[must_use]
108    pub fn is_expired_at(&self, now: OffsetDateTime) -> bool {
109        self.expires_at <= now
110    }
111}
112
/// JSON file holding at most one persisted [`GithubAuthSession`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct GithubSessionStore {
    /// Location of the session JSON file.
    pub path: PathBuf,
}

118impl GithubSessionStore {
119    #[must_use]
120    pub fn new(path: impl Into<PathBuf>) -> Self {
121        Self { path: path.into() }
122    }
123
124    pub fn load(&self) -> Result<Option<GithubAuthSession>, GithubAuthError> {
125        if !self.path.exists() {
126            return Ok(None);
127        }
128        let input = fs::read_to_string(&self.path)?;
129        Ok(Some(serde_json::from_str(&input)?))
130    }
131
132    pub fn save(&self, session: &GithubAuthSession) -> Result<(), GithubAuthError> {
133        if let Some(parent) = self.path.parent() {
134            fs::create_dir_all(parent)?;
135        }
136        let tmp = self.path.with_extension("json.tmp");
137        fs::write(&tmp, serde_json::to_vec_pretty(session)?)?;
138        set_owner_only_permissions(&tmp)?;
139        fs::rename(tmp, &self.path)?;
140        Ok(())
141    }
142
143    pub fn delete(&self) -> Result<bool, GithubAuthError> {
144        if self.path.exists() {
145            fs::remove_file(&self.path)?;
146            Ok(true)
147        } else {
148            Ok(false)
149        }
150    }
151}
152
153pub fn begin_cli_login(
154    config: &GithubCliAuthConfig,
155    bind: SocketAddr,
156) -> Result<GithubCliLoginStart, GithubAuthError> {
157    begin_cli_login_with_state(config, bind, &generate_state()?)
158}
159
160pub fn begin_cli_login_with_state(
161    config: &GithubCliAuthConfig,
162    bind: SocketAddr,
163    state: &str,
164) -> Result<GithubCliLoginStart, GithubAuthError> {
165    let callback_url = format!(
166        "http://{}{}",
167        bind,
168        normalize_callback_path(&config.callback_path)
169    );
170    let mut query = vec![
171        ("client", "feed-cli".to_string()),
172        ("state", state.to_string()),
173        ("redirect_uri", callback_url.clone()),
174    ];
175    if !config.requested_scopes.is_empty() {
176        query.push(("scope", config.requested_scopes.join(" ")));
177    }
178    if let Some(org) = config
179        .github_org
180        .as_deref()
181        .filter(|org| !org.trim().is_empty())
182    {
183        query.push(("org", org.trim().to_string()));
184    }
185    let authorize_url = format!(
186        "{}/auth/github?{}",
187        config.edge_base_url.trim_end_matches('/'),
188        encode_query(&query)
189    );
190    Ok(GithubCliLoginStart {
191        authorize_url,
192        callback_url,
193        state: state.to_string(),
194        bind,
195    })
196}
197
198pub fn complete_cli_login(
199    start: &GithubCliLoginStart,
200    callback: GithubCliCallback,
201    edge_base_url: impl Into<String>,
202) -> Result<GithubAuthSession, GithubAuthError> {
203    if callback.state != start.state {
204        return Err(GithubAuthError::StateMismatch);
205    }
206    Ok(GithubAuthSession {
207        provider: "github".to_string(),
208        github_user_id: callback.github_user_id,
209        login: callback.login,
210        name: callback.name,
211        avatar_url: callback.avatar_url,
212        session_token: callback.session_token,
213        scopes: callback.scopes,
214        github_orgs: callback.github_orgs,
215        issued_at: OffsetDateTime::now_utc(),
216        expires_at: callback.expires_at,
217        edge_base_url: edge_base_url.into(),
218    })
219}
220
221pub fn parse_cli_callback_request(
222    request_target: &str,
223) -> Result<GithubCliCallback, GithubAuthError> {
224    let (path, query) = request_target
225        .split_once('?')
226        .unwrap_or((request_target, ""));
227    if path != "/callback/github" {
228        return Err(GithubAuthError::InvalidCallbackPath);
229    }
230    let params = parse_query(query);
231    let state = required(&params, "state")?.to_string();
232    let id = required(&params, "github_user_id")
233        .or_else(|_| required(&params, "id"))?
234        .parse::<u64>()
235        .map_err(|err| GithubAuthError::InvalidGithubUserId(err.to_string()))?;
236    let login = required(&params, "login")?.to_string();
237    let expires_at = params
238        .get("expires_at")
239        .map(|value| {
240            OffsetDateTime::parse(value, &time::format_description::well_known::Rfc3339)
241                .map_err(|err| GithubAuthError::InvalidExpiresAt(err.to_string()))
242        })
243        .transpose()?
244        .unwrap_or_else(|| OffsetDateTime::now_utc() + Duration::days(7));
245
246    Ok(GithubCliCallback {
247        state,
248        github_user_id: id,
249        login,
250        name: params
251            .get("name")
252            .cloned()
253            .filter(|value| !value.is_empty()),
254        avatar_url: params
255            .get("avatar_url")
256            .or_else(|| params.get("avatar"))
257            .cloned()
258            .filter(|value| !value.is_empty()),
259        session_token: params
260            .get("session")
261            .or_else(|| params.get("session_token"))
262            .or_else(|| params.get("grant"))
263            .cloned()
264            .filter(|value| !value.is_empty()),
265        scopes: split_claims(
266            params.get("scopes").map(String::as_str).unwrap_or_default(),
267            ' ',
268        ),
269        github_orgs: split_claims(
270            params
271                .get("github_orgs")
272                .or_else(|| params.get("orgs"))
273                .map(String::as_str)
274                .unwrap_or_default(),
275            ',',
276        ),
277        expires_at,
278    })
279}
280
/// Splits `value` on `separator`, trims each piece, and keeps the non-empty ones.
fn split_claims(value: &str, separator: char) -> Vec<String> {
    let mut claims = Vec::new();
    for piece in value.split(separator) {
        let claim = piece.trim();
        if !claim.is_empty() {
            claims.push(claim.to_owned());
        }
    }
    claims
}

290pub fn browser_sign_in_url(edge_base_url: &str, return_to: &str) -> String {
291    format!(
292        "{}/auth/github?{}",
293        edge_base_url.trim_end_matches('/'),
294        encode_query(&[
295            ("client", "feed-browser".to_string()),
296            ("return_to", return_to.to_string()),
297        ])
298    )
299}
300
301fn required<'a>(
302    params: &'a BTreeMap<String, String>,
303    key: &'static str,
304) -> Result<&'a str, GithubAuthError> {
305    params
306        .get(key)
307        .map(String::as_str)
308        .filter(|value| !value.is_empty())
309        .ok_or(GithubAuthError::MissingCallbackField(key))
310}
311
/// Returns `path` with a leading `/`, prepending one only when it is missing.
fn normalize_callback_path(path: &str) -> String {
    let mut normalized = String::with_capacity(path.len() + 1);
    if !path.starts_with('/') {
        normalized.push('/');
    }
    normalized.push_str(path);
    normalized
}

320fn generate_state() -> Result<String, GithubAuthError> {
321    let mut bytes = [0u8; 24];
322    match fs::File::open("/dev/urandom").and_then(|mut file| file.read_exact(&mut bytes)) {
323        Ok(()) => Ok(hex(&bytes)),
324        Err(_) => {
325            let now = OffsetDateTime::now_utc().unix_timestamp_nanos();
326            let pid = std::process::id();
327            Ok(hex(format!("{now}:{pid}").as_bytes()))
328        }
329    }
330}
331
332fn encode_query(params: &[(&str, String)]) -> String {
333    params
334        .iter()
335        .map(|(key, value)| format!("{}={}", url_encode(key), url_encode(value)))
336        .collect::<Vec<_>>()
337        .join("&")
338}
339
340fn parse_query(query: &str) -> BTreeMap<String, String> {
341    query
342        .split('&')
343        .filter(|part| !part.is_empty())
344        .filter_map(|part| {
345            let (key, value) = part.split_once('=').unwrap_or((part, ""));
346            Some((url_decode(key)?, url_decode(value)?))
347        })
348        .collect()
349}
350
/// Percent-encodes `value` for use in a URL query component.
///
/// Unreserved bytes (`A-Z a-z 0-9 - . _ ~`) pass through unchanged. Every other
/// byte of the UTF-8 encoding, space included, becomes `%XX` with uppercase hex.
/// Writes into one preallocated buffer instead of allocating per byte.
fn url_encode(value: &str) -> String {
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(value.len());
    for byte in value.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_UPPER[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_UPPER[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}

/// Decodes a percent-encoded query component, treating `+` as a space.
///
/// Returns `None` when a `%` is not followed by exactly two hex digits, or when
/// the decoded bytes are not valid UTF-8. Escape digits are validated strictly;
/// `u8::from_str_radix` would also accept a sign, so `%+A` used to decode to `\n`.
fn url_decode(value: &str) -> Option<String> {
    /// Value of one ASCII hex digit (either case), or `None` for anything else.
    fn hex_value(digit: u8) -> Option<u8> {
        match digit {
            b'0'..=b'9' => Some(digit - b'0'),
            b'a'..=b'f' => Some(digit - b'a' + 10),
            b'A'..=b'F' => Some(digit - b'A' + 10),
            _ => None,
        }
    }

    let mut decoded = Vec::with_capacity(value.len());
    let mut bytes = value.bytes();
    while let Some(byte) = bytes.next() {
        match byte {
            b'%' => {
                let hi = hex_value(bytes.next()?)?;
                let lo = hex_value(bytes.next()?)?;
                decoded.push((hi << 4) | lo);
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
    }
    String::from_utf8(decoded).ok()
}

/// Renders `bytes` as lowercase hexadecimal, two digits per byte.
fn hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}

387fn set_owner_only_permissions(path: &Path) -> Result<(), GithubAuthError> {
388    #[cfg(unix)]
389    {
390        use std::os::unix::fs::PermissionsExt;
391        fs::set_permissions(path, fs::Permissions::from_mode(0o600))?;
392    }
393    #[cfg(not(unix))]
394    {
395        let _ = path;
396    }
397    Ok(())
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed loopback address shared by the login tests.
    fn loopback() -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], 49152))
    }

    /// Config pointing at a test edge host, with `tweak` applied on top.
    fn edge_config(tweak: impl FnOnce(&mut GithubCliAuthConfig)) -> GithubCliAuthConfig {
        let mut config = GithubCliAuthConfig {
            edge_base_url: "https://edge.example".to_string(),
            ..GithubCliAuthConfig::default()
        };
        tweak(&mut config);
        config
    }

    /// The authorize URL carries client, encoded state and the loopback redirect.
    #[test]
    fn cli_login_start_uses_loopback_redirect_and_state() {
        let config = edge_config(|_| {});
        let start = begin_cli_login_with_state(&config, loopback(), "state one")
            .expect("login starts");
        let url = &start.authorize_url;

        assert!(url.starts_with("https://edge.example/auth/github?"));
        for fragment in [
            "client=feed-cli",
            "state=state%20one",
            "redirect_uri=http%3A%2F%2F127.0.0.1%3A49152%2Fcallback%2Fgithub",
        ] {
            assert!(url.contains(fragment), "missing {fragment} in {url}");
        }
        assert_eq!(start.callback_url, "http://127.0.0.1:49152/callback/github");
    }

    /// Requested scopes and an org are forwarded as query parameters.
    #[test]
    fn cli_login_start_can_request_org_authorization_scope() {
        let config = edge_config(|config| {
            config.requested_scopes = vec!["read:user".to_string(), "read:org".to_string()];
            config.github_org = Some("aberration-technology".to_string());
        });
        let start = begin_cli_login_with_state(&config, loopback(), "state one")
            .expect("login starts");

        assert!(start
            .authorize_url
            .contains("scope=read%3Auser%20read%3Aorg"));
        assert!(start.authorize_url.contains("org=aberration-technology"));
    }

    /// A full callback query decodes into profile, token and claim fields.
    #[test]
    fn callback_parses_profile_and_session() {
        let target = concat!(
            "/callback/github?state=s1&github_user_id=123&login=mosure&name=mosure",
            "&avatar_url=%2Favatar%2Fgithub%2F123&session=grant",
            "&scopes=read%3Auser%20read%3Aorg&github_orgs=aberration-technology",
        );
        let callback = parse_cli_callback_request(target).expect("callback parses");

        assert_eq!(callback.state, "s1");
        assert_eq!(callback.github_user_id, 123);
        assert_eq!(callback.login, "mosure");
        assert_eq!(callback.avatar_url.as_deref(), Some("/avatar/github/123"));
        assert_eq!(callback.session_token.as_deref(), Some("grant"));
        assert_eq!(callback.scopes, vec!["read:user", "read:org"]);
        assert_eq!(callback.github_orgs, vec!["aberration-technology"]);
    }

    /// A callback whose state differs from the login start is rejected.
    #[test]
    fn complete_login_rejects_state_mismatch() {
        let start = GithubCliLoginStart {
            authorize_url: String::new(),
            callback_url: String::new(),
            state: "expected".to_string(),
            bind: loopback(),
        };
        let callback = GithubCliCallback {
            state: "other".to_string(),
            github_user_id: 123,
            login: "mosure".to_string(),
            name: None,
            avatar_url: None,
            session_token: None,
            scopes: Vec::new(),
            github_orgs: Vec::new(),
            expires_at: OffsetDateTime::now_utc() + Duration::hours(1),
        };

        let outcome = complete_cli_login(&start, callback, "https://edge.example");
        assert!(matches!(outcome, Err(GithubAuthError::StateMismatch)));
    }

    /// Browser sign-in URLs trim the base slash and encode `return_to`.
    #[test]
    fn browser_sign_in_targets_edge_auth() {
        let url = browser_sign_in_url("https://edge.example/", "https://app.example/mosure?all");

        assert!(url.starts_with("https://edge.example/auth/github?"));
        assert!(url.contains("client=feed-browser"));
        assert!(url.contains("return_to=https%3A%2F%2Fapp.example%2Fmosure%3Fall"));
    }
}