thal 0.0.1

Reactive semantic runtime — molecules, reactions, and effect actors for building LLM-backed applications as dataflow programs.
Documentation
//! OAuth 2.0 Device Authorization Grant (RFC 8628).
//!
//! Generic implementation: any provider that publishes a device-flow endpoint
//! pair plus a public client ID is supported via a configured
//! `DeviceFlowConfig`. `setup.rs`'s `github-copilot` and `oauth-custom` arms
//! drive this module.

use crate::Error;
use serde::Deserialize;
use std::time::{Duration, Instant};

/// Provider-specific settings for one device-authorization flow.
///
/// `run_device_flow` reads every field: `name` prefixes all error messages,
/// the two endpoints are POSTed to in turn, `client_id` is sent with both
/// requests, and `scope` (when set) is sent only with the device
/// authorization request.
#[derive(Clone, Debug)]
pub struct DeviceFlowConfig {
    /// Human-readable provider label; used as the prefix of error messages.
    pub name: String,
    /// URL of the device authorization endpoint (step 1 of the flow).
    pub device_authorization_endpoint: String,
    /// URL of the token endpoint that is polled until authorization completes.
    pub token_endpoint: String,
    /// Public OAuth client identifier. No client secret is ever sent.
    pub client_id: String,
    /// Optional scope string sent with the device authorization request.
    /// NOTE(review): presumably space-separated per OAuth convention; confirm
    /// with callers.
    pub scope: Option<String>,
}

/// Token material returned by a completed device flow.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TokenResult {
    /// The issued access token.
    pub access_token: String,
    /// Refresh token, if the provider issued one.
    pub refresh_token: Option<String>,
    /// The token response's `expires_in` value, if present. It is passed
    /// through uninterpreted; RFC 6749 defines it in seconds.
    pub expires_in: Option<u64>,
}

/// JSON body of a successful device authorization response (RFC 8628 §3.2).
///
/// Only `device_code`, `user_code` and `verification_uri` must be present to
/// deserialize. The remaining fields fall back to `None`, and
/// `run_device_flow` substitutes its own defaults for `expires_in` and
/// `interval`.
#[derive(Deserialize)]
struct DeviceCodeResponse {
    /// Opaque code sent back to the token endpoint on every poll.
    device_code: String,
    /// Code the user types in at `verification_uri`.
    user_code: String,
    /// Page the user visits to enter `user_code`.
    verification_uri: String,
    /// Optional URI that already embeds the user code. When present, it is
    /// the one handed to the browser launcher.
    #[serde(default)]
    verification_uri_complete: Option<String>,
    /// Lifetime of the device code, in seconds.
    #[serde(default)]
    expires_in: Option<u64>,
    /// Minimum number of seconds to wait between token polls.
    #[serde(default)]
    interval: Option<u64>,
}

/// Successful token endpoint response.
///
/// Only `access_token` is required. Any other fields the server sends
/// (e.g. `token_type`) are ignored by serde.
#[derive(Deserialize)]
struct TokenSuccess {
    access_token: String,
    #[serde(default)]
    refresh_token: Option<String>,
    #[serde(default)]
    expires_in: Option<u64>,
}

/// OAuth error response body (RFC 6749 §5.2), e.g.
/// `{"error":"authorization_pending"}`.
#[derive(Deserialize)]
struct TokenError {
    /// Machine-readable error code matched by `classify_token_response`.
    error: String,
    /// Optional human-readable detail, appended to abort messages for codes
    /// that have no dedicated message.
    #[serde(default)]
    error_description: Option<String>,
}

/// Decision returned by the polling state machine after parsing one token
/// endpoint response. Pulled out into its own type so the state-machine
/// logic is testable without a network.
#[derive(Debug, PartialEq)]
enum PollDecision {
    /// The server issued a token; polling stops successfully.
    Done(TokenResult),
    /// `authorization_pending`: the user has not finished yet; poll again at
    /// the current interval.
    Continue,
    /// `slow_down`: poll again, but `run_device_flow` first lengthens the
    /// interval by 5 seconds for this and all later polls.
    SlowDown,
    /// Terminal failure; the message becomes the error returned to the caller.
    Abort(String),
}

/// Classifies one token-endpoint response into the next polling step.
///
/// `status` is the HTTP status code and `body` the raw response body.
///
/// RFC 8628 §3.5 servers report pending / slow-down / denied polls as
/// RFC 6749 §5.2 error responses (HTTP 400). Some providers, GitHub's device
/// flow among them, instead answer with HTTP 200 and an `error` field. So a
/// 2xx body that is not a token but *is* an OAuth error body goes through
/// the same error classification instead of being treated as a parse
/// failure.
///
/// Returns:
/// - `Done` for a token.
/// - `Continue` for `authorization_pending`.
/// - `SlowDown` for `slow_down`.
/// - `Abort` with a readable message for everything else, including
///   unparseable bodies.
fn classify_token_response(status: u16, body: &str) -> PollDecision {
    let err = if (200..300).contains(&status) {
        match serde_json::from_str::<TokenSuccess>(body) {
            Ok(t) => {
                return PollDecision::Done(TokenResult {
                    access_token: t.access_token,
                    refresh_token: t.refresh_token,
                    expires_in: t.expires_in,
                });
            }
            // Not a token: check for an error body sent with a 2xx status
            // before giving up.
            Err(parse_err) => match serde_json::from_str::<TokenError>(body) {
                Ok(err) => err,
                Err(_) => {
                    return PollDecision::Abort(format!("parse token success: {parse_err}"));
                }
            },
        }
    } else {
        // Non-JSON error bodies (HTML error pages, proxies) still need a usable
        // diagnostic, so keep the status code and the raw body.
        serde_json::from_str::<TokenError>(body).unwrap_or_else(|_| TokenError {
            error: "unknown".into(),
            error_description: Some(format!("HTTP {status}: {body}")),
        })
    };
    match err.error.as_str() {
        "authorization_pending" => PollDecision::Continue,
        "slow_down" => PollDecision::SlowDown,
        "access_denied" => PollDecision::Abort("authorization denied by user".into()),
        "expired_token" => PollDecision::Abort("device code expired before authorization".into()),
        other => PollDecision::Abort(format!(
            "oauth error: {other}{}",
            err.error_description
                .map(|d| format!(" ({d})"))
                .unwrap_or_default()
        )),
    }
}

pub async fn run_device_flow(config: &DeviceFlowConfig) -> Result<TokenResult, Error> {
    let client = reqwest::Client::new();

    // Step 1: device authorization
    let mut params: Vec<(&str, &str)> = vec![("client_id", config.client_id.as_str())];
    if let Some(scope) = &config.scope {
        params.push(("scope", scope.as_str()));
    }
    let device_resp = client
        .post(&config.device_authorization_endpoint)
        .header("Accept", "application/json")
        .form(&params)
        .send()
        .await
        .map_err(|e| {
            Error::Runtime(format!("{}: device authorization request: {e}", config.name))
        })?;
    if !device_resp.status().is_success() {
        let status = device_resp.status();
        let body = device_resp.text().await.unwrap_or_default();
        return Err(Error::Runtime(format!(
            "{}: device authorization {status}: {body}",
            config.name
        )));
    }
    let device: DeviceCodeResponse = device_resp.json().await.map_err(|e| {
        Error::Runtime(format!("{}: parse device response: {e}", config.name))
    })?;

    // Step 2: tell the user
    eprintln!();
    eprintln!("Visit {} in your browser", device.verification_uri);
    eprintln!("Enter code: {}", device.user_code);
    if let Some(complete) = &device.verification_uri_complete {
        eprintln!("(or open: {complete})");
    }
    let _ = open_browser_best_effort(
        device
            .verification_uri_complete
            .as_deref()
            .unwrap_or(&device.verification_uri),
    );
    let initial_interval = device.interval.unwrap_or(5);
    let lifetime = device.expires_in.unwrap_or(15 * 60);
    eprintln!(
        "\nwaiting for authorization (polling every {initial_interval}s, expires in {}m)…",
        lifetime / 60
    );

    // Step 3: poll
    let mut interval = Duration::from_secs(initial_interval);
    let deadline = Instant::now() + Duration::from_secs(lifetime);

    loop {
        if Instant::now() > deadline {
            return Err(Error::Runtime(format!(
                "{}: device authorization expired before user completed sign-in",
                config.name
            )));
        }
        tokio::time::sleep(interval).await;

        let resp = client
            .post(&config.token_endpoint)
            .header("Accept", "application/json")
            .form(&[
                (
                    "grant_type",
                    "urn:ietf:params:oauth:grant-type:device_code",
                ),
                ("device_code", device.device_code.as_str()),
                ("client_id", config.client_id.as_str()),
            ])
            .send()
            .await
            .map_err(|e| Error::Runtime(format!("{}: token poll request: {e}", config.name)))?;

        let status = resp.status().as_u16();
        let body = resp.text().await.unwrap_or_default();

        match classify_token_response(status, &body) {
            PollDecision::Done(result) => return Ok(result),
            PollDecision::Continue => continue,
            PollDecision::SlowDown => {
                interval += Duration::from_secs(5);
                eprintln!("(server requested slow down; polling every {}s)", interval.as_secs());
            }
            PollDecision::Abort(msg) => {
                return Err(Error::Runtime(format!("{}: {msg}", config.name)))
            }
        }
    }
}

/// Chooses the program and arguments that hand `url` to the platform's
/// default URL handler.
///
/// - macOS uses `open`.
/// - Windows uses `rundll32 url.dll,FileProtocolHandler`, which never routes
///   the URL through `cmd.exe`. The URL comes from the server, and `cmd`
///   would interpret `&`, `|` and similar characters in it as shell syntax.
/// - Every other OS (Linux, the BSDs, …) falls back to `xdg-open`. A missing
///   launcher then simply makes the spawn fail instead of breaking
///   compilation on that platform.
fn browser_command(url: &str) -> (&'static str, Vec<&str>) {
    if cfg!(target_os = "macos") {
        ("open", vec![url])
    } else if cfg!(target_os = "windows") {
        ("rundll32", vec!["url.dll,FileProtocolHandler", url])
    } else {
        ("xdg-open", vec![url])
    }
}

/// Tries to open `url` in the user's browser without waiting for the result.
///
/// # Errors
///
/// Returns the I/O error from spawning the launcher, e.g. when it is not
/// installed. Callers treat this as best-effort and may ignore it. The
/// spawned child is not waited on.
fn open_browser_best_effort(url: &str) -> std::io::Result<()> {
    let (program, args) = browser_command(url);
    std::process::Command::new(program).args(args).spawn().map(|_| ())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies `body` and returns the issued token, failing the test on
    /// any other decision.
    fn expect_done(status: u16, body: &str) -> TokenResult {
        match classify_token_response(status, body) {
            PollDecision::Done(token) => token,
            other => panic!("expected Done, got {other:?}"),
        }
    }

    /// Classifies `body` and returns the abort message, failing the test on
    /// any other decision.
    fn expect_abort(status: u16, body: &str) -> String {
        match classify_token_response(status, body) {
            PollDecision::Abort(msg) => msg,
            other => panic!("expected Abort, got {other:?}"),
        }
    }

    #[test]
    fn classifies_success_with_full_fields() {
        let token = expect_done(
            200,
            r#"{"access_token":"sk-x","refresh_token":"r","expires_in":3600}"#,
        );
        assert_eq!(token.access_token, "sk-x");
        assert_eq!(token.refresh_token.as_deref(), Some("r"));
        assert_eq!(token.expires_in, Some(3600));
    }

    #[test]
    fn classifies_success_with_minimal_fields() {
        let token = expect_done(200, r#"{"access_token":"sk-x"}"#);
        assert_eq!(token.access_token, "sk-x");
        assert_eq!(token.refresh_token, None);
        assert_eq!(token.expires_in, None);
    }

    #[test]
    fn classifies_authorization_pending_as_continue() {
        let decision = classify_token_response(400, r#"{"error":"authorization_pending"}"#);
        assert_eq!(decision, PollDecision::Continue);
    }

    #[test]
    fn classifies_slow_down() {
        let decision = classify_token_response(400, r#"{"error":"slow_down"}"#);
        assert_eq!(decision, PollDecision::SlowDown);
    }

    #[test]
    fn classifies_access_denied_as_abort() {
        let msg = expect_abort(400, r#"{"error":"access_denied"}"#);
        assert!(msg.contains("denied"), "unexpected message: {msg}");
    }

    #[test]
    fn classifies_expired_token_as_abort() {
        let msg = expect_abort(400, r#"{"error":"expired_token"}"#);
        assert!(msg.contains("expired"), "unexpected message: {msg}");
    }

    #[test]
    fn classifies_unknown_error_with_description() {
        let msg = expect_abort(
            400,
            r#"{"error":"weird","error_description":"specific message"}"#,
        );
        for needle in ["weird", "specific message"] {
            assert!(msg.contains(needle), "{needle:?} missing from: {msg}");
        }
    }

    #[test]
    fn classifies_garbage_body_as_abort() {
        let msg = expect_abort(400, "<html>500 internal</html>");
        assert!(msg.contains("unknown"), "unexpected message: {msg}");
    }

    #[test]
    fn classifies_garbage_success_body_as_abort() {
        let msg = expect_abort(200, "{not json}");
        assert!(msg.contains("parse"), "unexpected message: {msg}");
    }
}