use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use anyhow::{bail, Context, Result};
use serde::Deserialize;
use crate::config::manager::ConfigManager;
use crate::model::types::AuthConfig;
/// Response from the OAuth device authorization endpoint (RFC 8628 §3.2).
///
/// `Debug` is implemented by hand so `device_code` is never printed: the token
/// poll in this module sends only `device_code` and `client_id`, so the device
/// code effectively acts as a credential until it expires.
#[derive(Clone, Deserialize)]
pub struct DeviceCodeResponse {
    /// Code presented to the token endpoint while polling. Redacted in `Debug`.
    pub device_code: String,
    /// Short code the user confirms in the browser.
    pub user_code: String,
    /// Verification URL with the user code already embedded.
    pub verification_uri_complete: String,
    /// Lifetime of `device_code`, in seconds (per RFC 8628).
    pub expires_in: u64,
    /// Suggested seconds between token polls; falls back to `default_interval()`
    /// when the server omits the field.
    #[serde(default = "default_interval")]
    pub interval: u64,
}

impl std::fmt::Debug for DeviceCodeResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DeviceCodeResponse")
            .field("device_code", &format_args!("<redacted>"))
            .field("user_code", &self.user_code)
            .field("verification_uri_complete", &self.verification_uri_complete)
            .field("expires_in", &self.expires_in)
            .field("interval", &self.interval)
            .finish()
    }
}
/// Serde fallback for `DeviceCodeResponse::interval` when the server omits it:
/// poll every 5 seconds.
fn default_interval() -> u64 {
    const FALLBACK_POLL_SECS: u64 = 5;
    FALLBACK_POLL_SECS
}
/// Floor applied to the server-suggested polling interval in `poll_for_token`.
const MIN_POLL_INTERVAL_SECS: u64 = 5;
/// Total per-request timeout configured by `http_client`.
const HTTP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
/// Connection-establishment timeout configured by `http_client`.
const HTTP_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Builds a fresh HTTP client for this module's requests.
///
/// The client sends a `zilliz-cli/<crate version>` User-Agent and applies
/// `HTTP_REQUEST_TIMEOUT` / `HTTP_CONNECT_TIMEOUT`.
///
/// # Errors
/// Returns an error if the underlying client cannot be constructed.
fn http_client() -> Result<reqwest::Client> {
    let user_agent = concat!("zilliz-cli/", env!("CARGO_PKG_VERSION"));
    let builder = reqwest::Client::builder()
        .user_agent(user_agent)
        .connect_timeout(HTTP_CONNECT_TIMEOUT)
        .timeout(HTTP_REQUEST_TIMEOUT);
    builder.build().context("Failed to build HTTP client")
}
/// Successful response from the OAuth token endpoint.
///
/// `Debug` is implemented by hand so the bearer token never ends up in logs
/// or `{:?}` output.
#[derive(Clone, Deserialize)]
pub struct TokenResponse {
    /// OAuth access token issued once the user approves the device.
    pub access_token: String,
}

impl std::fmt::Debug for TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenResponse")
            .field("access_token", &format_args!("<redacted>"))
            .finish()
    }
}
/// Error body returned by the OAuth token endpoint while polling.
#[derive(Debug, Deserialize)]
struct TokenErrorResponse {
    /// OAuth error code (e.g. `authorization_pending`, `slow_down`), matched in `poll_for_token`.
    error: String,
    /// Optional human-readable detail; empty string when the server omits it.
    #[serde(default)]
    // Silences the dead-code lint in case no code path reads this field.
    #[allow(dead_code)]
    error_description: String,
}
/// Cooperative cancellation flag shared between a caller and a long-running task.
///
/// Clones share one underlying flag, so cancelling through any clone is visible
/// through all of them. Nothing ever resets the flag once it is set.
#[derive(Debug, Clone, Default)]
pub struct CancellationToken {
    flag: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Creates a token in the not-cancelled state.
    pub fn new() -> Self {
        let flag = Arc::new(AtomicBool::new(false));
        CancellationToken { flag }
    }

    /// Marks this token (and every clone of it) as cancelled.
    pub fn cancel(&self) {
        let shared: &AtomicBool = &self.flag;
        shared.store(true, Ordering::SeqCst);
    }

    /// Reports whether `cancel` has been called on this token or any clone.
    pub fn is_cancelled(&self) -> bool {
        let shared: &AtomicBool = &self.flag;
        shared.load(Ordering::SeqCst)
    }
}
pub async fn request_device_code(auth_config: &AuthConfig) -> Result<DeviceCodeResponse> {
let client = http_client()?;
let resp: DeviceCodeResponse = client
.post(format!("{}/oauth/device/code", auth_config.auth0_domain))
.form(&[
("client_id", auth_config.client_id.as_str()),
("scope", "openid email profile"),
])
.send()
.await
.context("Failed to request device code")?
.json()
.await
.context("Invalid device code response")?;
Ok(resp)
}
/// Polls the OAuth token endpoint until the user completes the device flow.
///
/// Follows the RFC 8628 §3.4–3.5 polling rules. It waits between requests,
/// using the larger of `interval_secs` and `MIN_POLL_INTERVAL_SECS`. It keeps
/// polling on `authorization_pending`. On `slow_down` it permanently adds
/// 5 seconds to the interval.
///
/// # Arguments
/// * `auth_config` - supplies `auth0_domain` (token endpoint base) and `client_id`.
/// * `device_code` - the device code obtained from `request_device_code`.
/// * `interval_secs` - server-suggested polling interval, in seconds.
/// * `timeout_secs` - overall wall-clock budget for the whole polling loop.
/// * `cancel` - checked every 250 ms while waiting so the user can abort.
///
/// # Errors
/// Returns an error when cancelled, when `timeout_secs` elapses (checked
/// continuously while waiting, so no request is sent after the deadline), on
/// transport failures, on an unparseable success body, and when the server
/// reports `expired_token`, `access_denied`, or any other OAuth error.
pub async fn poll_for_token(
    auth_config: &AuthConfig,
    device_code: &str,
    interval_secs: u64,
    timeout_secs: u64,
    cancel: CancellationToken,
) -> Result<TokenResponse> {
    let client = http_client()?;
    let url = format!("{}/oauth/token", auth_config.auth0_domain);
    let start = std::time::Instant::now();
    let timeout = Duration::from_secs(timeout_secs);
    let mut interval = interval_secs.max(MIN_POLL_INTERVAL_SECS);
    loop {
        wait_before_poll(Duration::from_secs(interval), start, timeout, &cancel).await?;

        let resp = client
            .post(&url)
            .form(&[
                ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
                ("device_code", device_code),
                ("client_id", auth_config.client_id.as_str()),
            ])
            .send()
            .await
            .context("Token poll request failed")?;

        let status = resp.status();
        if status.is_success() {
            return resp
                .json::<TokenResponse>()
                .await
                .context("Invalid token response");
        }

        // A non-JSON error body (e.g. a proxy's HTML 502 page) would otherwise
        // lose all diagnostic value; fall back to reporting the HTTP status.
        let error_resp: TokenErrorResponse =
            resp.json().await.unwrap_or_else(|_| TokenErrorResponse {
                error: format!("HTTP {}", status.as_u16()),
                error_description: String::new(),
            });

        match error_resp.error.as_str() {
            "authorization_pending" => continue,
            "slow_down" => {
                // RFC 8628 §3.5: the interval MUST grow by 5 seconds for this and all
                // later requests. Never cap (a cap could shrink a large server interval);
                // saturate instead of overflowing.
                interval = interval.saturating_add(5);
                continue;
            }
            "expired_token" => bail!("Authentication timed out. Please try again."),
            "access_denied" => bail!("Authentication was denied."),
            other => {
                let description = error_resp.error_description.trim();
                if description.is_empty() {
                    bail!("Authentication failed: {}", other);
                }
                bail!("Authentication failed: {} ({})", other, description);
            }
        }
    }
}

/// Sleeps for `duration` in 250 ms slices, aborting as soon as `cancel` fires
/// or more than `timeout` has elapsed since `start`.
async fn wait_before_poll(
    duration: Duration,
    start: std::time::Instant,
    timeout: Duration,
    cancel: &CancellationToken,
) -> Result<()> {
    const SLICE: Duration = Duration::from_millis(250);
    let mut slept = Duration::ZERO;
    loop {
        if cancel.is_cancelled() {
            bail!("Cancelled by user");
        }
        // Compare elapsed time rather than computing `start + timeout`, which could
        // overflow `Instant` for very large `timeout_secs` values.
        if start.elapsed() > timeout {
            bail!("Authentication timed out. Please try again.");
        }
        if slept >= duration {
            return Ok(());
        }
        tokio::time::sleep(SLICE).await;
        slept += SLICE;
    }
}
/// Saves `api_key` via `ConfigManager::save_api_key_only`.
///
/// # Errors
/// Propagates any error returned by `save_api_key_only` unchanged.
pub fn save_api_key(config_mgr: &ConfigManager, api_key: &str) -> Result<()> {
    config_mgr.save_api_key_only(api_key)?;
    Ok(())
}
/// User and organisation details obtained from the CLI login exchange.
///
/// `Debug` is written by hand so organisation records are summarised by
/// count rather than dumped.
#[derive(Clone)]
pub struct LoginPayload {
    /// As populated by `exchange_token`: `user.userId`, or empty if absent.
    pub user_id: String,
    /// As populated by `exchange_token`: `user.email`, or empty if absent.
    pub email: String,
    /// As populated by `exchange_token`: `user.name`, or empty if absent.
    pub name: String,
    /// Raw organisation objects from the response's `orgs` array.
    pub orgs: Vec<serde_json::Value>,
}

impl std::fmt::Debug for LoginPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            user_id,
            email,
            name,
            orgs,
        } = self;
        let mut out = f.debug_struct("LoginPayload");
        out.field("user_id", user_id);
        out.field("email", email);
        out.field("name", name);
        out.field("orgs", &format_args!("<{} org(s) redacted>", orgs.len()));
        out.finish()
    }
}
/// Exchanges an OAuth access token for the CLI login payload.
///
/// Sends `POST {login_api}/account/v1/cli/login` with the token as a bearer
/// credential. Envelope keys are accepted in lower-case or PascalCase form
/// (`code`/`Code`, `msg`/`Message`, `data`/`Data`); the lower-case key wins
/// when present. Missing user fields become empty strings and a missing
/// `orgs` array becomes an empty list.
///
/// # Errors
/// Fails if the request cannot be sent, the body cannot be read or is not
/// JSON, or when the HTTP status is not 2xx or the envelope `code` is neither
/// `0` nor `200` (an absent `code` counts as failure).
pub async fn exchange_token(auth_config: &AuthConfig, access_token: &str) -> Result<LoginPayload> {
    /// Returns `obj[lower]`, or `obj[pascal]` when the lower-case key is absent.
    fn either_key<'a>(
        obj: &'a serde_json::Value,
        lower: &str,
        pascal: &str,
    ) -> Option<&'a serde_json::Value> {
        obj.get(lower).or_else(|| obj.get(pascal))
    }

    /// Reads `obj[key]` as a string, yielding `""` when anything is missing.
    fn string_at(obj: Option<&serde_json::Value>, key: &str) -> String {
        obj.and_then(|o| o.get(key))
            .and_then(serde_json::Value::as_str)
            .unwrap_or("")
            .to_string()
    }

    let http = http_client()?;
    let endpoint = format!("{}/account/v1/cli/login", auth_config.login_api);
    let response = http
        .post(&endpoint)
        .header("Authorization", format!("Bearer {}", access_token))
        .send()
        .await
        .context("Failed to exchange token")?;

    let status = response.status();
    let raw = response
        .text()
        .await
        .context("Failed to read login response")?;
    let envelope: serde_json::Value = serde_json::from_str(&raw)
        .with_context(|| format!("Invalid login response (HTTP {})", status.as_u16()))?;

    let code = either_key(&envelope, "code", "Code")
        .and_then(serde_json::Value::as_i64)
        .unwrap_or(-1);
    let accepted = status.is_success() && matches!(code, 0 | 200);
    if !accepted {
        let msg = either_key(&envelope, "msg", "Message")
            .and_then(serde_json::Value::as_str)
            .unwrap_or("Login failed");
        bail!("Login failed ({}): {}", status.as_u16(), msg);
    }

    let data = either_key(&envelope, "data", "Data");
    let user = data.and_then(|d| d.get("user"));
    let orgs = data
        .and_then(|d| d.get("orgs"))
        .and_then(serde_json::Value::as_array)
        .cloned()
        .unwrap_or_default();

    Ok(LoginPayload {
        user_id: string_at(user, "userId"),
        email: string_at(user, "email"),
        name: string_at(user, "name"),
        orgs,
    })
}
/// Opens `url` in the user's default browser without waiting for it to exit.
///
/// Uses `open` on macOS, `explorer.exe` on Windows and `xdg-open` on other
/// Unix-like systems (Linux, the BSDs, ...). The launcher is spawned and not
/// waited on.
///
/// # Errors
/// Returns the spawn error (e.g. `NotFound` when the launcher is not
/// installed). On platforms with no known launcher it returns
/// `ErrorKind::Unsupported` instead of silently reporting success, so callers
/// can fall back to showing the URL.
pub fn open_browser(url: &str) -> std::io::Result<()> {
    let launcher = if cfg!(target_os = "macos") {
        "open"
    } else if cfg!(target_os = "windows") {
        "explorer.exe"
    } else if cfg!(unix) {
        "xdg-open"
    } else {
        return Err(std::io::Error::new(
            std::io::ErrorKind::Unsupported,
            "don't know how to open a browser on this platform",
        ));
    };
    // Passed as a separate argv entry (no shell), so the URL is not re-parsed.
    std::process::Command::new(launcher).arg(url).spawn()?;
    Ok(())
}