use crate::provider::{GenericProviderConfig, Provider};
use secrecy::Secret;
use std::time::Duration;
use url::Url;
/// Configuration for an OAuth device authorization flow.
///
/// Construct it with [`DeviceFlowConfig::new`] (or `Default`), refine it with the
/// consuming builder methods on this type, and check it with
/// [`DeviceFlowConfig::validate`] before use.
#[derive(Debug, Clone)]
pub struct DeviceFlowConfig {
    /// OAuth client identifier; must be non-empty to pass validation.
    pub client_id: String,
    /// Optional client secret, wrapped in [`Secret`]. Validation requires it when
    /// the generic provider config (if present) or the provider says a secret is needed.
    pub client_secret: Option<Secret<String>>,
    /// Requested scopes. When empty, the generic provider config's default scopes
    /// (if present) or the provider's default scopes are used instead.
    pub scopes: Vec<String>,
    /// Optional redirect URI, stored already parsed.
    pub redirect_uri: Option<Url>,
    /// Base polling interval, 5 seconds by default. A value of exactly 5 seconds
    /// is treated as "not overridden", and the generic config / provider default
    /// applies instead.
    pub poll_interval: Duration,
    /// Maximum number of polling attempts, 60 by default; must be greater than 0.
    pub max_attempts: u32,
    /// Backoff factor, 1.1 by default; validation requires it to be >= 1.0.
    /// NOTE(review): presumably multiplies the poll interval between attempts — confirm in the polling code.
    pub backoff_multiplier: f64,
    /// Maximum poll interval, 30 seconds by default.
    /// NOTE(review): presumably caps the backed-off interval — confirm in the polling code.
    pub max_poll_interval: Duration,
    /// User-Agent value; defaults to `oauth-device-flows/<crate version>`.
    pub user_agent: Option<String>,
    /// Extra header name/value pairs; empty by default.
    /// NOTE(review): presumably attached to each outgoing request — confirm.
    pub additional_headers: std::collections::HashMap<String, String>,
    /// Configuration for a generic (custom) provider. Required when the provider is
    /// [`Provider::Generic`]. When present, it takes precedence over the built-in
    /// provider's default scopes, poll interval, PKCE support and client-secret requirement.
    pub generic_provider_config: Option<GenericProviderConfig>,
    /// Request timeout, 30 seconds by default.
    /// NOTE(review): presumably applied per HTTP request — confirm.
    pub request_timeout: Duration,
    /// Explicit PKCE preference. `None` defers to the generic provider config
    /// (if present) or to the provider.
    pub use_pkce: Option<bool>,
}
impl DeviceFlowConfig {
    /// Poll interval installed by [`DeviceFlowConfig::new`].
    ///
    /// [`DeviceFlowConfig::effective_poll_interval`] also uses it as the
    /// "not overridden" marker. Keeping it as a single constant guarantees the two
    /// sites cannot drift apart.
    const DEFAULT_POLL_INTERVAL: Duration = Duration::from_secs(5);

    /// Creates a configuration populated with library defaults.
    ///
    /// Defaults:
    /// - empty client ID and scopes;
    /// - no client secret or redirect URI;
    /// - a 5 second poll interval, 60 attempts and a 1.1 backoff multiplier;
    /// - a 30 second maximum poll interval and a 30 second request timeout;
    /// - an `oauth-device-flows/<crate version>` user agent;
    /// - no extra headers, no generic provider configuration and no explicit PKCE preference.
    ///
    /// The result does not pass [`DeviceFlowConfig::validate`] until at least a
    /// client ID has been set.
    pub fn new() -> Self {
        Self {
            client_id: String::new(),
            client_secret: None,
            scopes: Vec::new(),
            redirect_uri: None,
            poll_interval: Self::DEFAULT_POLL_INTERVAL,
            max_attempts: 60,
            backoff_multiplier: 1.1,
            max_poll_interval: Duration::from_secs(30),
            user_agent: Some(format!("oauth-device-flows/{}", env!("CARGO_PKG_VERSION"))),
            additional_headers: std::collections::HashMap::new(),
            generic_provider_config: None,
            request_timeout: Duration::from_secs(30),
            use_pkce: None,
        }
    }

    /// Sets the OAuth client identifier.
    pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
        self.client_id = client_id.into();
        self
    }

    /// Sets the client secret, wrapping it in [`Secret`].
    pub fn client_secret(mut self, client_secret: impl Into<String>) -> Self {
        self.client_secret = Some(Secret::new(client_secret.into()));
        self
    }

    /// Replaces all requested scopes with the items of `scopes`.
    pub fn scopes<I, S>(mut self, scopes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.scopes = scopes.into_iter().map(Into::into).collect();
        self
    }

    /// Appends a single scope to the requested scopes.
    pub fn scope(mut self, scope: impl Into<String>) -> Self {
        self.scopes.push(scope.into());
        self
    }

    /// Parses and sets the redirect URI.
    ///
    /// # Errors
    ///
    /// Returns the [`url::ParseError`] produced by [`Url::parse`] if `uri` cannot
    /// be parsed. The configuration is consumed in that case.
    pub fn redirect_uri(mut self, uri: impl Into<String>) -> Result<Self, url::ParseError> {
        self.redirect_uri = Some(Url::parse(&uri.into())?);
        Ok(self)
    }

    /// Sets the base polling interval.
    ///
    /// Note: passing exactly 5 seconds is indistinguishable from the default; see
    /// [`DeviceFlowConfig::effective_poll_interval`].
    pub fn poll_interval(mut self, interval: Duration) -> Self {
        self.poll_interval = interval;
        self
    }

    /// Sets the maximum number of polling attempts (must be > 0 to validate).
    pub fn max_attempts(mut self, attempts: u32) -> Self {
        self.max_attempts = attempts;
        self
    }

    /// Sets the backoff multiplier (must be finite and >= 1.0 to validate).
    pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }

    /// Sets the maximum poll interval.
    pub fn max_poll_interval(mut self, interval: Duration) -> Self {
        self.max_poll_interval = interval;
        self
    }

    /// Overrides the User-Agent value.
    pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
        self.user_agent = Some(user_agent.into());
        self
    }

    /// Adds (or replaces, for an identical `key`) an extra header.
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.additional_headers.insert(key.into(), value.into());
        self
    }

    /// Sets the generic provider configuration.
    pub fn generic_provider(mut self, config: GenericProviderConfig) -> Self {
        self.generic_provider_config = Some(config);
        self
    }

    /// Sets the request timeout.
    pub fn request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Forces PKCE on or off, overriding the provider/generic-config default.
    pub fn use_pkce(mut self, use_pkce: bool) -> Self {
        self.use_pkce = Some(use_pkce);
        self
    }

    /// Returns the scopes to request.
    ///
    /// Precedence:
    /// 1. explicitly configured scopes;
    /// 2. the generic provider config's `default_scopes`;
    /// 3. `provider`'s default scopes.
    pub fn effective_scopes(&self, provider: Provider) -> Vec<String> {
        if !self.scopes.is_empty() {
            self.scopes.clone()
        } else if let Some(ref config) = self.generic_provider_config {
            config.default_scopes.clone()
        } else {
            provider
                .default_scopes()
                .into_iter()
                .map(|s| s.to_string())
                .collect()
        }
    }

    /// Returns the poll interval to use.
    ///
    /// If `poll_interval` differs from the 5 second default it wins. Otherwise the
    /// generic provider config's `default_poll_interval` is used, then
    /// `provider`'s default.
    ///
    /// Limitation: because "unset" is detected by comparing against the default
    /// value, explicitly configuring exactly 5 seconds still defers to the
    /// provider/generic-config default.
    pub fn effective_poll_interval(&self, provider: Provider) -> Duration {
        if self.poll_interval != Self::DEFAULT_POLL_INTERVAL {
            self.poll_interval
        } else if let Some(ref config) = self.generic_provider_config {
            config.default_poll_interval
        } else {
            provider.default_poll_interval()
        }
    }

    /// Returns whether PKCE should be used.
    ///
    /// An explicit [`DeviceFlowConfig::use_pkce`] setting wins. Otherwise the
    /// generic provider config's `supports_pkce` is used, then `provider`'s support.
    pub fn should_use_pkce(&self, provider: Provider) -> bool {
        self.use_pkce.unwrap_or_else(|| {
            if let Some(config) = &self.generic_provider_config {
                config.supports_pkce
            } else {
                provider.supports_pkce()
            }
        })
    }

    /// Returns whether a client secret is required.
    ///
    /// The generic provider config's flag is used when present; otherwise
    /// `provider`'s requirement applies.
    pub fn requires_client_secret(&self, provider: Provider) -> bool {
        if let Some(config) = &self.generic_provider_config {
            config.requires_client_secret
        } else {
            provider.requires_client_secret()
        }
    }

    /// Checks that this configuration is usable with `provider`.
    ///
    /// # Errors
    ///
    /// Returns an error built with `DeviceFlowError::invalid_client` when:
    /// - the client ID is empty;
    /// - a client secret is required (see [`DeviceFlowConfig::requires_client_secret`])
    ///   but none is set;
    /// - `provider` is [`Provider::Generic`] and no generic configuration is set;
    /// - `max_attempts` is zero;
    /// - `backoff_multiplier` is NaN, less than `1.0`, or infinite.
    pub fn validate(&self, provider: Provider) -> Result<(), crate::error::DeviceFlowError> {
        use crate::error::DeviceFlowError;

        if self.client_id.is_empty() {
            return Err(DeviceFlowError::invalid_client("Client ID is required"));
        }
        if self.requires_client_secret(provider) && self.client_secret.is_none() {
            return Err(DeviceFlowError::invalid_client(format!(
                "Client secret is required for {provider}"
            )));
        }
        if provider == Provider::Generic && self.generic_provider_config.is_none() {
            return Err(DeviceFlowError::invalid_client(
                "Generic provider configuration is required for generic provider",
            ));
        }
        if self.max_attempts == 0 {
            return Err(DeviceFlowError::invalid_client(
                "Max attempts must be greater than 0",
            ));
        }
        // `NaN < 1.0` is false, so a bare `< 1.0` check would silently accept NaN.
        if self.backoff_multiplier.is_nan() || self.backoff_multiplier < 1.0 {
            return Err(DeviceFlowError::invalid_client(
                "Backoff multiplier must be >= 1.0",
            ));
        }
        // An infinite multiplier makes any backed-off interval non-finite, and
        // float-scaled `Duration` arithmetic (e.g. `Duration::mul_f64`) panics on that.
        if self.backoff_multiplier.is_infinite() {
            return Err(DeviceFlowError::invalid_client(
                "Backoff multiplier must be finite",
            ));
        }
        Ok(())
    }
}
impl Default for DeviceFlowConfig {
fn default() -> Self {
Self::new()
}
}