use std::path::PathBuf;
use std::time::Duration;
use chrono::{DateTime, Utc};
use crate::browser::{Browser, DefaultBrowser};
use crate::config::{self, App};
use crate::deviceflow::{DeviceCodeUI, DeviceFlowClient, SimpleDeviceCodeUI};
use crate::error::Error;
use crate::github::GitHubClient;
use crate::keyring::{AccessToken, DEFAULT_SERVICE_KEY, Keyring};
use crate::log::Logger;
/// Parameters for [`Client::get`].
///
/// Every string field treats an empty value as "use the default", so
/// `InputGet::default()` selects everything from configuration and environment.
#[derive(Debug, Clone, Default)]
pub struct InputGet {
    /// Keyring service under which tokens are stored; empty selects `DEFAULT_SERVICE_KEY`.
    pub keyring_service: String,
    /// App to select from the configuration; empty falls back to the `GHTKN_APP`
    /// environment variable (or an empty name if that is unset).
    pub app_name: String,
    /// Configuration file to read; empty resolves the path via `config::get_path`.
    pub config_file_path: String,
    /// Owner passed alongside `app_name` when selecting the app from the configuration.
    pub app_owner: String,
    /// Minimum remaining lifetime a stored token must have to be reused;
    /// `Duration::ZERO` (the default) accepts any token that has not yet expired.
    pub min_expiration: Duration,
}
/// Obtains GitHub access tokens, reusing one stored in the keyring when it is
/// still valid and otherwise running the device flow and storing the result.
///
/// All collaborators can be replaced through the `set_*` methods, e.g. to point
/// the client at a GitHub Enterprise host.
pub struct Client {
    // Callbacks notified about keyring lookups and token expiry; also passed to the device flow.
    logger: Logger,
    // UI object handed to the device flow (presumably shows the user code — confirm in deviceflow).
    device_code_ui: Box<dyn DeviceCodeUI>,
    // Browser handle handed to the device flow (presumably opens the verification page — confirm).
    browser: Box<dyn Browser>,
    // Token storage, keyed by service name and the app's client ID.
    keyring: Keyring,
    // Base URL of the GitHub web host used by the device flow; setters strip trailing '/'.
    github_base_url: String,
    // Base URL of the GitHub REST API used to look up the token's user; setters strip trailing '/'.
    api_base_url: String,
}
impl Client {
pub fn new() -> Self {
Self {
logger: Logger::new(),
device_code_ui: Box::new(SimpleDeviceCodeUI),
browser: Box::new(DefaultBrowser),
keyring: Keyring::new(),
github_base_url: "https://github.com".to_string(),
api_base_url: "https://api.github.com".to_string(),
}
}
pub fn set_logger(&mut self, logger: Logger) {
self.logger = logger;
}
pub fn set_device_code_ui(&mut self, ui: Box<dyn DeviceCodeUI>) {
self.device_code_ui = ui;
}
pub fn set_browser(&mut self, browser: Box<dyn Browser>) {
self.browser = browser;
}
pub fn set_keyring(&mut self, keyring: Keyring) {
self.keyring = keyring;
}
pub fn set_github_base_url(&mut self, url: String) {
self.github_base_url = url.trim_end_matches('/').to_string();
}
pub fn set_api_base_url(&mut self, url: String) {
self.api_base_url = url.trim_end_matches('/').to_string();
}
pub fn token_source(self, input: InputGet) -> TokenSource {
TokenSource::new(self, input)
}
pub async fn get(&self, input: &InputGet) -> crate::Result<(AccessToken, App)> {
let config_path = if input.config_file_path.is_empty() {
config::get_path(|k| std::env::var(k).ok(), std::env::consts::OS)?
} else {
PathBuf::from(&input.config_file_path)
};
let cfg = config::read(&config_path)?
.ok_or_else(|| Error::Config("configuration file is empty".into()))?;
cfg.validate()?;
let app_name = if input.app_name.is_empty() {
std::env::var("GHTKN_APP").unwrap_or_default()
} else {
input.app_name.clone()
};
let app = config::select_app(&cfg, &app_name, &input.app_owner)
.ok_or_else(|| Error::Config("no matching app found".into()))?
.clone();
let service = if input.keyring_service.is_empty() {
DEFAULT_SERVICE_KEY.to_string()
} else {
input.keyring_service.clone()
};
match self
.get_or_create_token(&service, &app, input.min_expiration)
.await
{
Ok(token) => Ok((token, app)),
Err(Error::StoreToken { message, token, .. }) => Err(Error::StoreToken {
message,
token,
app: Box::new(app),
}),
Err(e) => Err(e),
}
}
async fn get_or_create_token(
&self,
service: &str,
app: &App,
min_expiration: Duration,
) -> crate::Result<AccessToken> {
match self.keyring.get(service, &app.client_id) {
Ok(Some(token)) => {
if check_expired(token.expiration_date, min_expiration) {
if let Some(cb) = &self.logger.expire {
cb(token.expiration_date);
}
} else {
return Ok(token);
}
}
Ok(None) => {
if let Some(cb) = &self.logger.access_token_is_not_found_in_keyring {
cb();
}
}
Err(e) => {
if let Some(cb) = &self.logger.failed_to_get_access_token_from_keyring {
cb(&e.to_string());
}
}
}
self.create_token(service, app).await
}
async fn create_token(&self, service: &str, app: &App) -> crate::Result<AccessToken> {
let http_client = reqwest::Client::new();
let df_client = DeviceFlowClient::with_base_url(
http_client,
self.browser.as_ref(),
&self.logger,
self.device_code_ui.as_ref(),
self.github_base_url.clone(),
);
let df_token = df_client.create(&app.client_id).await?;
let gh_client =
GitHubClient::with_base_url(&df_token.access_token, self.api_base_url.clone());
let user = gh_client.get_user().await?;
let kr_token = AccessToken {
access_token: df_token.access_token,
expiration_date: df_token.expiration_date,
login: user.login,
};
if let Err(e) = self.keyring.set(service, &app.client_id, &kr_token) {
return Err(Error::StoreToken {
message: e.to_string(),
token: Box::new(kr_token),
app: Box::new(app.clone()),
});
}
Ok(kr_token)
}
}
impl Default for Client {
fn default() -> Self {
Self::new()
}
}
pub struct TokenSource {
client: Client,
input: InputGet,
cached: tokio::sync::Mutex<Option<String>>,
}
impl TokenSource {
pub fn new(client: Client, input: InputGet) -> Self {
Self {
client,
input,
cached: tokio::sync::Mutex::new(None),
}
}
pub async fn token(&self) -> crate::Result<String> {
let mut cached = self.cached.lock().await;
if let Some(token) = cached.as_ref() {
return Ok(token.clone());
}
let access_token = match self.client.get(&self.input).await {
Ok((token, _)) => token.access_token,
Err(Error::StoreToken { token, message, .. }) => {
tracing::warn!(error = message, "keyring write failed, using token anyway");
token.access_token
}
Err(e) => return Err(e),
};
*cached = Some(access_token.clone());
Ok(access_token)
}
pub async fn token_or_none(&self) -> Option<String> {
match self.token().await {
Ok(token) => Some(token),
Err(e) => {
tracing::warn!(error = %e, "ghtkn token unavailable");
None
}
}
}
}
/// Reports whether a token expiring at `expiration_date` should be treated as expired.
///
/// A token counts as expired unless it stays valid for at least
/// `min_expiration` from now. If `min_expiration` is too large for
/// `chrono::Duration`, or adding it to "now" overflows `DateTime<Utc>`, the
/// requirement can never be met, so the token is reported as expired. It is
/// not treated as valid, and the call does not panic.
fn check_expired(expiration_date: DateTime<Utc>, min_expiration: Duration) -> bool {
    let deadline = chrono::Duration::from_std(min_expiration)
        .ok()
        .and_then(|margin| Utc::now().checked_add_signed(margin));
    match deadline {
        Some(deadline) => deadline > expiration_date,
        // Unrepresentable margin: no token can satisfy it.
        None => true,
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use chrono::{TimeZone, Utc};

    use super::*;

    #[test]
    fn test_check_expired_not_expired() {
        let expiration = Utc::now() + chrono::Duration::hours(8);
        assert!(!check_expired(expiration, Duration::ZERO));
    }

    #[test]
    fn test_check_expired_is_expired() {
        let expiration = Utc::now() - chrono::Duration::hours(1);
        assert!(check_expired(expiration, Duration::ZERO));
    }

    #[test]
    fn test_check_expired_with_min_expiration() {
        let expiration = Utc::now() + chrono::Duration::minutes(5);
        let min_exp = Duration::from_secs(10 * 60);
        assert!(check_expired(expiration, min_exp));
    }

    #[test]
    fn test_check_expired_with_min_expiration_sufficient() {
        let expiration = Utc::now() + chrono::Duration::minutes(20);
        let min_exp = Duration::from_secs(10 * 60);
        assert!(!check_expired(expiration, min_exp));
    }

    // A token that expired at the Unix epoch is long past its deadline.
    #[test]
    fn test_check_expired_epoch_is_expired() {
        let expiration = Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap();
        assert!(check_expired(expiration, Duration::ZERO));
    }

    #[test]
    fn test_input_get_default() {
        let input = InputGet::default();
        assert!(input.keyring_service.is_empty());
        assert!(input.app_name.is_empty());
        assert!(input.config_file_path.is_empty());
        assert!(input.app_owner.is_empty());
        assert_eq!(input.min_expiration, Duration::ZERO);
    }

    #[test]
    fn test_client_new() {
        let client = Client::new();
        assert_eq!(client.github_base_url, "https://github.com");
        assert_eq!(client.api_base_url, "https://api.github.com");
    }

    #[test]
    fn test_client_default() {
        let client = Client::default();
        assert_eq!(client.github_base_url, "https://github.com");
        assert_eq!(client.api_base_url, "https://api.github.com");
    }

    #[test]
    fn test_client_set_logger() {
        let mut client = Client::new();
        let logger = Logger::new();
        client.set_logger(logger);
    }

    #[test]
    fn test_client_set_device_code_ui() {
        let mut client = Client::new();
        client.set_device_code_ui(Box::new(SimpleDeviceCodeUI));
    }

    #[test]
    fn test_client_set_browser() {
        let mut client = Client::new();
        client.set_browser(Box::new(DefaultBrowser));
    }

    #[test]
    fn test_client_set_keyring() {
        let mut client = Client::new();
        client.set_keyring(Keyring::new());
    }

    #[test]
    fn test_client_set_github_base_url_trims_trailing_slash() {
        let mut client = Client::new();
        client.set_github_base_url("https://ghe.example.com/".to_string());
        assert_eq!(client.github_base_url, "https://ghe.example.com");
    }

    #[test]
    fn test_client_set_api_base_url_trims_trailing_slash() {
        let mut client = Client::new();
        client.set_api_base_url("https://ghe.example.com/api/v3/".to_string());
        assert_eq!(client.api_base_url, "https://ghe.example.com/api/v3");
    }

    // Every trailing slash is removed, and URLs without one are kept as-is.
    #[test]
    fn test_client_set_base_urls_trim_all_trailing_slashes() {
        let mut client = Client::new();
        client.set_github_base_url("https://ghe.example.com///".to_string());
        client.set_api_base_url("https://ghe.example.com/api/v3".to_string());
        assert_eq!(client.github_base_url, "https://ghe.example.com");
        assert_eq!(client.api_base_url, "https://ghe.example.com/api/v3");
    }

    #[test]
    fn test_token_source_new() {
        let _ts = TokenSource::new(Client::new(), InputGet::default());
    }
}