use std::time::Duration;
use serde::{Deserialize, Serialize};
use ureq::http::Response as HttpResponse;
use super::error::AuthError;
use super::server::ServerUrl;
/// Overall per-request timeout in seconds, applied with `timeout_global` in `build_agent`.
const REQUEST_TIMEOUT_SECS: u64 = 30;
/// Successful response body of `GET {server}/auth/login`, decoded by `oauth_start`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct OAuthInit {
    /// Authorization URL returned by the proxy under the JSON key `url`.
    /// In the tests this is a GitHub OAuth authorize URL.
    #[serde(rename = "url")]
    pub authorize_url: String,
}
/// JSON body POSTed to `{server}/auth/cli-token` by `oauth_exchange`.
///
/// Fields borrow from the caller, so building a request allocates nothing.
/// Serde emits fields in declaration order.
#[derive(Debug, Clone, Serialize)]
struct CliTokenRequest<'a> {
    /// Code to exchange for a token (presumably the OAuth callback `code`; confirm with caller).
    code: &'a str,
    /// Sent under the camelCase JSON key `repoFullName`.
    #[serde(rename = "repoFullName")]
    repo_full_name: &'a str,
    /// Optional name; the key is omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<&'a str>,
}
/// Successful response body of `POST {server}/auth/cli-token`, decoded by `oauth_exchange`.
///
/// JSON keys match the field names exactly (snake_case, no renames).
///
/// `Debug` is implemented by hand so that `arta_token` and `jwt` are redacted.
/// A derived `Debug` would leak both credentials into any `{:?}` log line or panic message.
#[derive(Clone, Deserialize)]
pub struct CliTokenResponse {
    /// Token issued by the proxy. Treated as a secret and redacted from `Debug`.
    pub arta_token: String,
    /// JWT issued alongside the token. Treated as a secret and redacted from `Debug`.
    pub jwt: String,
    /// The authenticated user.
    pub user: GitHubUser,
    /// Identifier of the issued token.
    pub token_id: String,
    /// Repository the token was issued for, e.g. `owner/repo`.
    pub repo_full_name: String,
    /// Short display hint (presumably the last four characters of `arta_token`; confirm).
    pub last_4: String,
}

impl std::fmt::Debug for CliTokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CliTokenResponse")
            .field("arta_token", &"<redacted>")
            .field("jwt", &"<redacted>")
            .field("user", &self.user)
            .field("token_id", &self.token_id)
            .field("repo_full_name", &self.repo_full_name)
            .field("last_4", &self.last_4)
            .finish()
    }
}
/// User identity embedded in `CliTokenResponse::user`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct GitHubUser {
    /// Numeric user id.
    pub id: u64,
    /// Login handle, e.g. `octocat`.
    pub login: String,
}
/// Asks the proxy to begin an OAuth login and returns the authorization URL it hands back.
///
/// Issues `GET {server}/auth/login?redirect_uri=<percent-encoded {server}/auth/callback>`.
///
/// # Errors
///
/// Transport failures and non-2xx responses are mapped to [`AuthError`] by
/// `consume_response`, as is a 2xx whose body is not a valid [`OAuthInit`].
pub fn oauth_start(server: &ServerUrl) -> Result<OAuthInit, AuthError> {
    let base = server.as_str();
    let callback = format!("{base}/auth/callback");
    let login_url = format!("{base}/auth/login?redirect_uri={}", url_encode(&callback));
    consume_response::<OAuthInit>(build_agent().get(&login_url).call())
}
/// Percent-encodes `s` for use as a URL query value.
///
/// Only RFC 3986 "unreserved" bytes (`A-Z a-z 0-9 - _ . ~`) pass through.
/// Every other byte, including each byte of a multi-byte UTF-8 character, becomes
/// `%XX` with uppercase hex digits.
fn url_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    let is_unreserved =
        |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~');
    s.bytes()
        .fold(String::with_capacity(s.len() + 8), |mut encoded, b| {
            if is_unreserved(b) {
                encoded.push(char::from(b));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(b & 0x0F)]));
            }
            encoded
        })
}
/// Exchanges a login `code` for a CLI token scoped to `repo_full_name`.
///
/// POSTs a JSON [`CliTokenRequest`] to `{server}/auth/cli-token`.
/// `name` is omitted from the body when `None`.
///
/// # Errors
///
/// Transport failures and non-2xx responses are mapped to [`AuthError`] by
/// `consume_response`, as is a 2xx whose body is not a valid [`CliTokenResponse`].
pub fn oauth_exchange(
    server: &ServerUrl,
    code: &str,
    repo_full_name: &str,
    name: Option<&str>,
) -> Result<CliTokenResponse, AuthError> {
    let endpoint = format!("{}/auth/cli-token", server.as_str());
    let payload = CliTokenRequest {
        code,
        repo_full_name,
        name,
    };
    let outcome = build_agent()
        .post(&endpoint)
        .header("Content-Type", "application/json")
        .send_json(&payload);
    consume_response::<CliTokenResponse>(outcome)
}
/// Builds the HTTP agent shared by every proxy call in this module.
///
/// It applies a `REQUEST_TIMEOUT_SECS` global timeout and an `aristo/<crate version>`
/// User-Agent. It also disables ureq's status-as-error behaviour, so 4xx/5xx responses
/// arrive as ordinary responses and `map_response` can classify them.
fn build_agent() -> ureq::Agent {
    let timeout = Duration::from_secs(REQUEST_TIMEOUT_SECS);
    let user_agent = format!("aristo/{}", env!("CARGO_PKG_VERSION"));
    ureq::Agent::config_builder()
        .http_status_as_error(false)
        .timeout_global(Some(timeout))
        .user_agent(user_agent)
        .build()
        .into()
}
fn consume_response<T>(
result: Result<HttpResponse<ureq::Body>, ureq::Error>,
) -> Result<T, AuthError>
where
T: for<'de> serde::Deserialize<'de>,
{
let response = match result {
Ok(r) => r,
Err(e) => return Err(transport_error_to_auth(e)),
};
let status = response.status().as_u16();
let body_text = read_body_capped(response, 64 * 1024);
map_response(status, &body_text)
}
/// Maps an HTTP status and body text to a decoded `T` or an [`AuthError`].
///
/// - 2xx: the body is decoded as JSON `T`. A decode failure is reported as `Malformed`.
/// - 401 / 403: `AuthError::Invalid`.
/// - Other 4xx: `Malformed` with the body's `error`/`message` string, or else
///   `HTTP <status>: <first 200 bytes of body>`.
/// - 5xx: `Malformed` prefixed with `proxy HTTP <status>:`.
/// - Anything else: `Malformed("unexpected HTTP <status> from proxy")`.
pub(crate) fn map_response<T>(status: u16, body: &str) -> Result<T, AuthError>
where
    T: for<'de> serde::Deserialize<'de>,
{
    if (200..300).contains(&status) {
        return serde_json::from_str::<T>(body).map_err(|err| {
            AuthError::Malformed(format!("proxy returned 2xx with unparseable body: {err}"))
        });
    }
    let failure = match status {
        401 | 403 => AuthError::Invalid,
        400..=499 => AuthError::Malformed(
            extract_error_message(body)
                .unwrap_or_else(|| format!("HTTP {status}: {}", truncate(body, 200))),
        ),
        500..=599 => {
            let detail = extract_error_message(body).unwrap_or_else(|| truncate(body, 200));
            AuthError::Malformed(format!("proxy HTTP {status}: {detail}"))
        }
        other => AuthError::Malformed(format!("unexpected HTTP {other} from proxy")),
    };
    Err(failure)
}
/// Pulls a human-readable error string out of a JSON error body.
///
/// Returns the top-level `error` string if present. Otherwise it returns the top-level
/// `message` string. It returns `None` when the body is not JSON or neither key holds
/// a string.
fn extract_error_message(body: &str) -> Option<String> {
    let parsed: serde_json::Value = serde_json::from_str(body).ok()?;
    ["error", "message"]
        .iter()
        .find_map(|key| parsed.get(*key).and_then(serde_json::Value::as_str))
        .map(str::to_owned)
}
/// Returns `s` unchanged if it is at most `max` bytes long. Otherwise it returns a
/// prefix of at most `max` bytes followed by `…`.
///
/// `max` is a byte budget. The cut is moved back to the nearest UTF-8 character
/// boundary, so multi-byte characters are never split. Slicing at an arbitrary byte
/// offset would panic, and callers pass server-supplied text here.
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // Here `max < s.len()`, and offset 0 is always a boundary, so this terminates.
    let mut end = max;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}…", &s[..end])
}
/// Reads at most `cap` bytes of the response body and decodes them as UTF-8.
///
/// This is best-effort. A hard read error ends the read early and whatever was received
/// is kept. Invalid UTF-8, including a multi-byte sequence cut off by the cap, becomes
/// U+FFFD instead of failing.
fn read_body_capped(response: HttpResponse<ureq::Body>, cap: usize) -> String {
    use std::io::Read;
    let mut buf = Vec::with_capacity(cap.min(8 * 1024));
    // `take` enforces the cap. `read_to_end` retries on `ErrorKind::Interrupted`, which
    // the old hand-rolled loop treated as end-of-body. It also leaves any bytes read
    // before a hard error in `buf`, and that error is deliberately ignored.
    let _ = response
        .into_body()
        .into_reader()
        .take(u64::try_from(cap).unwrap_or(u64::MAX))
        .read_to_end(&mut buf);
    String::from_utf8_lossy(&buf).into_owned()
}
/// Wraps a ureq transport error (no HTTP status received) in `AuthError::Malformed`.
///
/// If the error's display text mentions "timed out" or "timeout", the message is
/// labelled as a timeout. Otherwise it is labelled as a generic transport error.
fn transport_error_to_auth(e: ureq::Error) -> AuthError {
    let detail = e.to_string();
    let looks_like_timeout = ["timed out", "timeout"]
        .iter()
        .any(|&needle| detail.contains(needle));
    let label = if looks_like_timeout {
        "proxy request timed out"
    } else {
        "proxy transport error"
    };
    AuthError::Malformed(format!("{label}: {detail}"))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the message inside an `AuthError::Malformed`, panicking on any other variant.
    fn malformed_message(err: AuthError) -> String {
        match err {
            AuthError::Malformed(m) => m,
            other => panic!("expected Malformed, got {other:?}"),
        }
    }

    #[test]
    fn map_response_2xx_decodes_oauth_init() {
        let body =
            r#"{"url":"https://github.com/login/oauth/authorize?client_id=x&redirect_uri=y"}"#;
        let r: OAuthInit = map_response(200, body).expect("decode");
        assert_eq!(
            r.authorize_url,
            "https://github.com/login/oauth/authorize?client_id=x&redirect_uri=y"
        );
    }

    #[test]
    fn map_response_2xx_decodes_cli_token_response() {
        let body = r#"{
            "arta_token": "arta_xyz",
            "jwt": "jwt-blob",
            "user": { "id": 42, "login": "octocat" },
            "token_id": "tk_001",
            "repo_full_name": "owner/repo",
            "last_4": "wxyz"
        }"#;
        let r: CliTokenResponse = map_response(200, body).expect("decode");
        assert_eq!(r.arta_token, "arta_xyz");
        assert_eq!(r.user.login, "octocat");
        assert_eq!(r.user.id, 42);
        assert_eq!(r.last_4, "wxyz");
    }

    #[test]
    fn map_response_401_maps_to_invalid() {
        let r: Result<OAuthInit, _> = map_response(401, r#"{"error":"bad token"}"#);
        assert_eq!(r.unwrap_err(), AuthError::Invalid);
    }

    #[test]
    fn map_response_403_maps_to_invalid() {
        let r: Result<OAuthInit, _> = map_response(403, r#"{"error":"forbidden"}"#);
        assert_eq!(r.unwrap_err(), AuthError::Invalid);
    }

    #[test]
    fn map_response_400_extracts_error_field() {
        let r: Result<CliTokenResponse, _> = map_response(400, r#"{"error":"Missing code"}"#);
        let m = malformed_message(r.unwrap_err());
        assert!(m.contains("Missing code"), "got: {m}");
    }

    #[test]
    fn map_response_400_falls_back_to_message_field() {
        let r: Result<CliTokenResponse, _> = map_response(400, r#"{"message":"bad shape"}"#);
        let m = malformed_message(r.unwrap_err());
        assert!(m.contains("bad shape"), "got: {m}");
    }

    #[test]
    fn map_response_400_prefers_error_over_message() {
        let r: Result<CliTokenResponse, _> =
            map_response(400, r#"{"error":"primary","message":"secondary"}"#);
        assert_eq!(malformed_message(r.unwrap_err()), "primary");
    }

    #[test]
    fn map_response_400_with_non_json_body_uses_truncated_body() {
        let r: Result<CliTokenResponse, _> = map_response(400, "plain text error");
        let m = malformed_message(r.unwrap_err());
        assert!(m.contains("400"), "got: {m}");
        assert!(m.contains("plain text error"), "got: {m}");
    }

    #[test]
    fn map_response_500_maps_to_malformed_with_proxy_label() {
        let r: Result<CliTokenResponse, _> = map_response(500, r#"{"error":"oauth failed"}"#);
        let m = malformed_message(r.unwrap_err());
        assert!(m.contains("500"), "got: {m}");
        assert!(m.contains("oauth failed"), "got: {m}");
    }

    #[test]
    fn map_response_5xx_with_non_json_body_uses_body_text() {
        let r: Result<OAuthInit, _> = map_response(502, "Bad Gateway");
        assert_eq!(malformed_message(r.unwrap_err()), "proxy HTTP 502: Bad Gateway");
    }

    #[test]
    fn map_response_3xx_is_unexpected() {
        let r: Result<OAuthInit, _> = map_response(302, "");
        assert_eq!(
            malformed_message(r.unwrap_err()),
            "unexpected HTTP 302 from proxy"
        );
    }

    #[test]
    fn map_response_2xx_garbage_body_surfaces_malformed() {
        let r: Result<OAuthInit, _> = map_response(200, "not json");
        let m = malformed_message(r.unwrap_err());
        assert!(m.contains("unparseable"), "got: {m}");
    }

    #[test]
    fn url_encode_keeps_unreserved_and_escapes_everything_else() {
        assert_eq!(url_encode(""), "");
        assert_eq!(url_encode("AZaz09-_.~"), "AZaz09-_.~");
        assert_eq!(
            url_encode("https://proxy.example/auth/callback"),
            "https%3A%2F%2Fproxy.example%2Fauth%2Fcallback"
        );
        assert_eq!(url_encode("a b&é"), "a%20b%26%C3%A9");
    }

    #[test]
    fn truncate_short_string_passes_through() {
        assert_eq!(truncate("short", 100), "short");
    }

    #[test]
    fn truncate_long_string_clips_with_ellipsis() {
        let s = "a".repeat(500);
        let t = truncate(&s, 10);
        assert_eq!(t.chars().count(), 11);
        assert!(t.ends_with('…'));
    }
}