use crate::Error;
use serde::Deserialize;
use std::time::{Duration, Instant};
/// Settings for one provider's device authorization flow.
#[derive(Debug, Clone)]
pub struct DeviceFlowConfig {
    /// Label for the provider; used as the prefix of every error message.
    pub name: String,
    /// URL that receives the initial device authorization POST.
    pub device_authorization_endpoint: String,
    /// URL that is polled for the token once the user code has been shown.
    pub token_endpoint: String,
    /// Client identifier sent with both the device and the token request.
    pub client_id: String,
    /// Optional `scope` form parameter; omitted from the request when `None`.
    pub scope: Option<String>,
}
/// Tokens handed back to the caller once the flow succeeds.
///
/// Note that the derived `Debug` prints the token values verbatim, so avoid
/// formatting this type into logs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenResult {
    /// The access token returned by the token endpoint.
    pub access_token: String,
    /// Refresh token, when the server supplied one.
    pub refresh_token: Option<String>,
    /// The `expires_in` value from the token response, when present.
    pub expires_in: Option<u64>,
}
/// JSON body returned by the device authorization endpoint.
#[derive(Deserialize)]
struct DeviceCodeResponse {
    /// Opaque code sent back to the token endpoint on every poll.
    device_code: String,
    /// Short code the user is asked to type in.
    user_code: String,
    /// Page the user is asked to visit.
    verification_uri: String,
    /// Optional alternative link; preferred over `verification_uri` when
    /// launching the browser.
    #[serde(default)]
    verification_uri_complete: Option<String>,
    /// Lifetime of the device code in seconds; the flow assumes 15 minutes
    /// when absent.
    #[serde(default)]
    expires_in: Option<u64>,
    /// Polling interval in seconds; the flow assumes 5 when absent.
    #[serde(default)]
    interval: Option<u64>,
}
/// JSON body of a successful token endpoint response.
///
/// Only `access_token` is mandatory; the other fields deserialize to `None`
/// when missing.
#[derive(Deserialize)]
struct TokenSuccess {
    /// The issued access token.
    access_token: String,
    /// Refresh token, if the server issued one.
    #[serde(default)]
    refresh_token: Option<String>,
    /// The `expires_in` value from the response, if present.
    #[serde(default)]
    expires_in: Option<u64>,
}
/// JSON body of an OAuth error response from the token endpoint.
#[derive(Deserialize)]
struct TokenError {
    /// Machine-readable error code such as `authorization_pending`.
    error: String,
    /// Optional human-readable detail, appended to the message for
    /// unrecognised error codes.
    #[serde(default)]
    error_description: Option<String>,
}
/// What the polling loop should do after inspecting one token response.
#[derive(PartialEq, Debug)]
enum PollDecision {
    /// Tokens were issued; the flow is finished.
    Done(TokenResult),
    /// Authorization is still pending; poll again at the current interval.
    Continue,
    /// The server asked for less frequent polling.
    SlowDown,
    /// The flow cannot succeed; the string is the reason to report.
    Abort(String),
}
/// Turns one token endpoint response into a [`PollDecision`].
///
/// * `status` - HTTP status code of the response.
/// * `body` - raw response body, expected to be JSON.
///
/// A 2xx body that parses as a token response yields `Done`. A 2xx body that
/// is not a token response but *is* a well-formed OAuth error object is
/// classified like any other error, because some authorization servers report
/// `authorization_pending` and similar polling states with a success status.
/// A 2xx body that is neither yields `Abort` with the token parse error.
///
/// For non-2xx responses the body is parsed as an OAuth error; an unparseable
/// body is reported as error code `unknown` with the raw body as description.
fn classify_token_response(status: u16, body: &str) -> PollDecision {
    if (200..300).contains(&status) {
        return match serde_json::from_str::<TokenSuccess>(body) {
            Ok(t) => PollDecision::Done(TokenResult {
                access_token: t.access_token,
                refresh_token: t.refresh_token,
                expires_in: t.expires_in,
            }),
            // Not a token payload: check for an error object before giving up,
            // so a "pending" answer sent with HTTP 200 does not abort the flow.
            Err(e) => match serde_json::from_str::<TokenError>(body) {
                Ok(err) => classify_oauth_error(err),
                Err(_) => PollDecision::Abort(format!("parse token success: {e}")),
            },
        };
    }
    // Built lazily so the body is only copied when parsing actually fails.
    let err = serde_json::from_str::<TokenError>(body).unwrap_or_else(|_| TokenError {
        error: "unknown".into(),
        error_description: Some(body.to_string()),
    });
    classify_oauth_error(err)
}

/// Maps a parsed OAuth error object onto the matching [`PollDecision`].
fn classify_oauth_error(err: TokenError) -> PollDecision {
    match err.error.as_str() {
        "authorization_pending" => PollDecision::Continue,
        "slow_down" => PollDecision::SlowDown,
        "access_denied" => PollDecision::Abort("authorization denied by user".into()),
        "expired_token" => {
            PollDecision::Abort("device code expired before authorization".into())
        }
        other => PollDecision::Abort(format!(
            "oauth error: {other}{}",
            err.error_description
                .map(|d| format!(" ({d})"))
                .unwrap_or_default()
        )),
    }
}
/// Runs the OAuth 2.0 device authorization flow described by `config`.
///
/// The function requests a device code, prints the verification URL and user
/// code to stderr, makes a best-effort attempt to open the browser, and then
/// polls the token endpoint until tokens are issued, the server reports a
/// terminal error, or the device code's lifetime has elapsed.
///
/// # Errors
///
/// Returns `Error::Runtime`, with the message prefixed by `config.name`, when:
/// * the device authorization request fails, returns a non-success status, or
///   its body cannot be parsed;
/// * a token poll request fails to send;
/// * the token endpoint answers with a terminal error;
/// * the deadline passes before the user completes sign-in.
pub async fn run_device_flow(config: &DeviceFlowConfig) -> Result<TokenResult, Error> {
    let client = reqwest::Client::new();

    let mut params: Vec<(&str, &str)> = vec![("client_id", config.client_id.as_str())];
    if let Some(scope) = &config.scope {
        params.push(("scope", scope.as_str()));
    }

    let device_resp = client
        .post(&config.device_authorization_endpoint)
        .header("Accept", "application/json")
        .form(&params)
        .send()
        .await
        .map_err(|e| {
            Error::Runtime(format!("{}: device authorization request: {e}", config.name))
        })?;
    if !device_resp.status().is_success() {
        let status = device_resp.status();
        let body = device_resp.text().await.unwrap_or_default();
        return Err(Error::Runtime(format!(
            "{}: device authorization {status}: {body}",
            config.name
        )));
    }
    let device: DeviceCodeResponse = device_resp.json().await.map_err(|e| {
        Error::Runtime(format!("{}: parse device response: {e}", config.name))
    })?;

    eprintln!();
    eprintln!("Visit {} in your browser", device.verification_uri);
    eprintln!("Enter code: {}", device.user_code);
    if let Some(complete) = &device.verification_uri_complete {
        eprintln!("(or open: {complete})");
    }
    // Opening the browser is a convenience only; the instructions above are
    // already printed, so a failure here is deliberately ignored.
    let _ = open_browser_best_effort(
        device
            .verification_uri_complete
            .as_deref()
            .unwrap_or(&device.verification_uri),
    );

    // A server-supplied interval of 0 would make the loop poll with no delay,
    // so enforce a floor of one second.
    let initial_interval = device.interval.unwrap_or(5).max(1);
    let lifetime = device.expires_in.unwrap_or(15 * 60);
    eprintln!(
        "\nwaiting for authorization (polling every {initial_interval}s, expires in {}m)…",
        lifetime / 60
    );

    let mut interval = Duration::from_secs(initial_interval);
    // `Instant + Duration` panics on overflow. With an absurdly large
    // `expires_in` there is simply no local deadline; the server's own
    // `expired_token` answer still terminates the loop.
    let deadline = Instant::now().checked_add(Duration::from_secs(lifetime));

    loop {
        if matches!(deadline, Some(d) if Instant::now() > d) {
            return Err(Error::Runtime(format!(
                "{}: device authorization expired before user completed sign-in",
                config.name
            )));
        }
        // Sleep first: the user needs time to act before the first poll.
        tokio::time::sleep(interval).await;

        let resp = client
            .post(&config.token_endpoint)
            .header("Accept", "application/json")
            .form(&[
                (
                    "grant_type",
                    "urn:ietf:params:oauth:grant-type:device_code",
                ),
                ("device_code", device.device_code.as_str()),
                ("client_id", config.client_id.as_str()),
            ])
            .send()
            .await
            .map_err(|e| Error::Runtime(format!("{}: token poll request: {e}", config.name)))?;
        let status = resp.status().as_u16();
        // An unreadable body is classified as an empty one.
        let body = resp.text().await.unwrap_or_default();

        match classify_token_response(status, &body) {
            PollDecision::Done(result) => return Ok(result),
            PollDecision::Continue => continue,
            PollDecision::SlowDown => {
                // Back off by five seconds for all subsequent polls.
                interval += Duration::from_secs(5);
                eprintln!(
                    "(server requested slow down; polling every {}s)",
                    interval.as_secs()
                );
            }
            PollDecision::Abort(msg) => {
                return Err(Error::Runtime(format!("{}: {msg}", config.name)))
            }
        }
    }
}
/// Tries to open `url` with the platform's default URL handler.
///
/// Spawns `rundll32 url.dll,FileProtocolHandler` on Windows, `open` on macOS
/// and `xdg-open` on every other target. The child process is not waited on.
///
/// # Errors
///
/// Returns the `std::io::Error` from spawning the helper process (for example
/// when the program is not installed). Whether the browser actually opened is
/// not checked.
fn open_browser_best_effort(url: &str) -> std::io::Result<()> {
    // Avoid `cmd /C start <url>` on Windows: cmd.exe parses `&` in a query
    // string as a command separator, which truncates the URL and runs the
    // remainder as a command. rundll32 receives the URL as a plain argument.
    #[cfg(target_os = "windows")]
    let cmd = "rundll32";
    #[cfg(target_os = "windows")]
    let args: &[&str] = &["url.dll,FileProtocolHandler", url];

    #[cfg(target_os = "macos")]
    let cmd = "open";
    // Fallback for Linux and any other target, so `cmd` is always defined and
    // the file compiles everywhere.
    #[cfg(not(any(target_os = "windows", target_os = "macos")))]
    let cmd = "xdg-open";
    #[cfg(not(target_os = "windows"))]
    let args: &[&str] = &[url];

    std::process::Command::new(cmd).args(args).spawn().map(|_| ())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies the response and returns the token; panics on any other decision.
    fn expect_done(status: u16, body: &str) -> TokenResult {
        match classify_token_response(status, body) {
            PollDecision::Done(token) => token,
            other => panic!("expected Done, got {other:?}"),
        }
    }

    /// Classifies the response and returns the abort message; panics on any other decision.
    fn expect_abort(status: u16, body: &str) -> String {
        match classify_token_response(status, body) {
            PollDecision::Abort(reason) => reason,
            other => panic!("expected Abort, got {other:?}"),
        }
    }

    #[test]
    fn classifies_success_with_full_fields() {
        let token = expect_done(
            200,
            r#"{"access_token":"sk-x","refresh_token":"r","expires_in":3600}"#,
        );
        assert_eq!(token.access_token, "sk-x");
        assert_eq!(token.refresh_token.as_deref(), Some("r"));
        assert_eq!(token.expires_in, Some(3600));
    }

    #[test]
    fn classifies_success_with_minimal_fields() {
        let token = expect_done(200, r#"{"access_token":"sk-x"}"#);
        assert_eq!(token.access_token, "sk-x");
        assert!(token.refresh_token.is_none());
        assert!(token.expires_in.is_none());
    }

    #[test]
    fn classifies_authorization_pending_as_continue() {
        let decision = classify_token_response(400, r#"{"error":"authorization_pending"}"#);
        assert_eq!(decision, PollDecision::Continue);
    }

    #[test]
    fn classifies_slow_down() {
        let decision = classify_token_response(400, r#"{"error":"slow_down"}"#);
        assert_eq!(decision, PollDecision::SlowDown);
    }

    #[test]
    fn classifies_access_denied_as_abort() {
        let reason = expect_abort(400, r#"{"error":"access_denied"}"#);
        assert!(reason.contains("denied"));
    }

    #[test]
    fn classifies_expired_token_as_abort() {
        let reason = expect_abort(400, r#"{"error":"expired_token"}"#);
        assert!(reason.contains("expired"));
    }

    #[test]
    fn classifies_unknown_error_with_description() {
        let reason = expect_abort(
            400,
            r#"{"error":"weird","error_description":"specific message"}"#,
        );
        assert!(reason.contains("weird"));
        assert!(reason.contains("specific message"));
    }

    #[test]
    fn classifies_garbage_body_as_abort() {
        let reason = expect_abort(400, "<html>500 internal</html>");
        assert!(reason.contains("unknown"));
    }

    #[test]
    fn classifies_garbage_success_body_as_abort() {
        let reason = expect_abort(200, "{not json}");
        assert!(reason.contains("parse"));
    }
}