use std::{
env,
sync::{LazyLock, OnceLock},
};
use crate::config::O2P_ROUTE_PREFIX;
use crate::oauth2::discovery::{OidcDiscoveryDocument, OidcDiscoveryError, fetch_oidc_discovery};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ProviderName(&'static str);
impl ProviderName {
pub(crate) const fn from_static(s: &'static str) -> Self {
Self(s)
}
pub(crate) fn from_env_leaked(raw: String) -> Self {
Self(leak_static(raw))
}
pub const fn as_str(&self) -> &'static str {
self.0
}
pub fn from_registered(s: &str) -> Option<Self> {
ProviderKind::from_provider_name(s)
.and_then(provider_for)
.map(|cfg| cfg.provider_name)
}
}
impl std::fmt::Display for ProviderName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.0)
}
}
impl AsRef<str> for ProviderName {
fn as_ref(&self) -> &str {
self.0
}
}
/// Which configured provider a value refers to: the built-in Google provider
/// or one of the numbered custom OIDC slots.
///
/// `ProviderKind::ALL` enumerates every kind (Google first, then the custom
/// slots in numeric order); keep it in sync when adding variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum ProviderKind {
    /// Built-in Google provider, configured by `GOOGLE_PROVIDER`.
    Google,
    /// Custom OIDC provider configured via `OAUTH2_CUSTOM{n}_*` variables.
    Custom(CustomSlot),
}
/// Built-in defaults for a well-known identity provider, selected per custom
/// slot with `OAUTH2_CUSTOM{n}_PRESET`.
///
/// In `build_custom_provider`, each presentation field can be overridden by
/// the slot's matching environment variable; `additional_allowed_origins` is
/// taken from the preset only.
pub(crate) struct ProviderPreset {
    /// Default for `{prefix}_DISPLAY_NAME` (button label).
    pub(crate) display_name: &'static str,
    /// Default for `{prefix}_NAME` (route segment / identifier).
    pub(crate) provider_name: ProviderName,
    /// Default for `{prefix}_ICON_SLUG`.
    pub(crate) icon_slug: &'static str,
    /// Default for `{prefix}_BUTTON_COLOR`.
    pub(crate) button_color: &'static str,
    /// Default for `{prefix}_BUTTON_HOVER_COLOR`.
    pub(crate) button_hover_color: &'static str,
    /// Copied into `ProviderConfig::additional_allowed_origins`.
    pub(crate) additional_allowed_origins: &'static [&'static str],
}
// Built-in presets selectable via `OAUTH2_CUSTOM{n}_PRESET=<key>`. The keys
// accepted by `resolve_preset` are identical to each preset's `provider_name`.
pub(crate) const AUTH0_PRESET: ProviderPreset = ProviderPreset {
    display_name: "Auth0",
    provider_name: ProviderName::from_static("auth0"),
    icon_slug: "auth0",
    button_color: "#eb5424",
    button_hover_color: "#c94419",
    additional_allowed_origins: &[],
};
pub(crate) const KEYCLOAK_PRESET: ProviderPreset = ProviderPreset {
    display_name: "Keycloak",
    provider_name: ProviderName::from_static("keycloak"),
    icon_slug: "keycloak",
    button_color: "#4d4d4d",
    button_hover_color: "#333333",
    additional_allowed_origins: &[],
};
// Microsoft Entra ID. The only preset with an extra allowed origin.
// NOTE(review): presumably needed for consumer Microsoft-account sign-in
// flows that pass through login.live.com — confirm at the use site.
pub(crate) const ENTRA_PRESET: ProviderPreset = ProviderPreset {
    display_name: "Microsoft",
    provider_name: ProviderName::from_static("entra"),
    icon_slug: "entra",
    button_color: "#0078D4",
    button_hover_color: "#005A9E",
    additional_allowed_origins: &["https://login.live.com"],
};
pub(crate) const ZITADEL_PRESET: ProviderPreset = ProviderPreset {
    display_name: "Zitadel",
    provider_name: ProviderName::from_static("zitadel"),
    icon_slug: "zitadel",
    button_color: "#333333",
    button_hover_color: "#1a1a1a",
    additional_allowed_origins: &[],
};
pub(crate) const OKTA_PRESET: ProviderPreset = ProviderPreset {
    display_name: "Okta",
    provider_name: ProviderName::from_static("okta"),
    icon_slug: "okta",
    button_color: "#007dc1",
    button_hover_color: "#005e93",
    additional_allowed_origins: &[],
};
pub(crate) const AUTHENTIK_PRESET: ProviderPreset = ProviderPreset {
    display_name: "Authentik",
    provider_name: ProviderName::from_static("authentik"),
    icon_slug: "authentik",
    button_color: "#fd4b2d",
    button_hover_color: "#e03d1f",
    additional_allowed_origins: &[],
};
pub(crate) const LINE_PRESET: ProviderPreset = ProviderPreset {
    display_name: "LINE",
    provider_name: ProviderName::from_static("line"),
    icon_slug: "line",
    button_color: "#06C755",
    button_hover_color: "#05A647",
    additional_allowed_origins: &[],
};
pub(crate) const APPLE_PRESET: ProviderPreset = ProviderPreset {
    display_name: "Apple",
    provider_name: ProviderName::from_static("apple"),
    icon_slug: "apple",
    button_color: "#000000",
    button_hover_color: "#333333",
    additional_allowed_origins: &[],
};
/// Maps a `{prefix}_PRESET` value to its built-in [`ProviderPreset`].
///
/// Matching is exact and case-sensitive.
///
/// # Errors
/// Returns a message listing every accepted key when `key` is unknown.
fn resolve_preset(key: &str) -> Result<&'static ProviderPreset, String> {
    const PRESETS: &[(&str, &ProviderPreset)] = &[
        ("auth0", &AUTH0_PRESET),
        ("keycloak", &KEYCLOAK_PRESET),
        ("entra", &ENTRA_PRESET),
        ("zitadel", &ZITADEL_PRESET),
        ("okta", &OKTA_PRESET),
        ("authentik", &AUTHENTIK_PRESET),
        ("line", &LINE_PRESET),
        ("apple", &APPLE_PRESET),
    ];

    PRESETS
        .iter()
        .find(|(name, _)| *name == key)
        .map(|&(_, preset)| preset)
        .ok_or_else(|| {
            format!(
                "unknown PRESET '{key}' (expected one of: \
                 auth0, keycloak, entra, zitadel, okta, authentik, line, apple)"
            )
        })
}
/// One of the eight fixed custom-provider slots, configured via
/// `OAUTH2_CUSTOM1_*` through `OAUTH2_CUSTOM8_*`.
///
/// Declaration order matches `CustomSlot::ALL` and the numeric suffix used in
/// each slot's env-var prefix, label, and button CSS class.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum CustomSlot {
    Slot1,
    Slot2,
    Slot3,
    Slot4,
    Slot5,
    Slot6,
    Slot7,
    Slot8,
}
impl CustomSlot {
    /// Every slot, in numeric order (`Slot1` first).
    pub(crate) const ALL: &'static [Self] = &[
        Self::Slot1,
        Self::Slot2,
        Self::Slot3,
        Self::Slot4,
        Self::Slot5,
        Self::Slot6,
        Self::Slot7,
        Self::Slot8,
    ];

    /// Zero-based position of this slot; indexes the per-slot name tables below.
    const fn ordinal(self) -> usize {
        match self {
            Self::Slot1 => 0,
            Self::Slot2 => 1,
            Self::Slot3 => 2,
            Self::Slot4 => 3,
            Self::Slot5 => 4,
            Self::Slot6 => 5,
            Self::Slot7 => 6,
            Self::Slot8 => 7,
        }
    }

    /// Stable identifier (`"custom1"` .. `"custom8"`); also the custom slot's
    /// CSS variable suffix and its `ProviderKind::as_str` value.
    pub(crate) const fn label(self) -> &'static str {
        const LABELS: [&str; 8] = [
            "custom1", "custom2", "custom3", "custom4",
            "custom5", "custom6", "custom7", "custom8",
        ];
        LABELS[self.ordinal()]
    }

    /// Prefix shared by all of this slot's environment variables
    /// (e.g. `OAUTH2_CUSTOM1` for `OAUTH2_CUSTOM1_CLIENT_ID`).
    pub(crate) const fn env_prefix(self) -> &'static str {
        const PREFIXES: [&str; 8] = [
            "OAUTH2_CUSTOM1", "OAUTH2_CUSTOM2", "OAUTH2_CUSTOM3", "OAUTH2_CUSTOM4",
            "OAUTH2_CUSTOM5", "OAUTH2_CUSTOM6", "OAUTH2_CUSTOM7", "OAUTH2_CUSTOM8",
        ];
        PREFIXES[self.ordinal()]
    }

    /// CSS class list for this slot's login button.
    pub(crate) const fn button_class(self) -> &'static str {
        const CLASSES: [&str; 8] = [
            "btn-oauth2 btn-custom1",
            "btn-oauth2 btn-custom2",
            "btn-oauth2 btn-custom3",
            "btn-oauth2 btn-custom4",
            "btn-oauth2 btn-custom5",
            "btn-oauth2 btn-custom6",
            "btn-oauth2 btn-custom7",
            "btn-oauth2 btn-custom8",
        ];
        CLASSES[self.ordinal()]
    }
}
impl ProviderKind {
    /// Every provider kind: Google first, then custom slots 1-8.
    pub(crate) const ALL: &'static [Self] = &[
        Self::Google,
        Self::Custom(CustomSlot::Slot1),
        Self::Custom(CustomSlot::Slot2),
        Self::Custom(CustomSlot::Slot3),
        Self::Custom(CustomSlot::Slot4),
        Self::Custom(CustomSlot::Slot5),
        Self::Custom(CustomSlot::Slot6),
        Self::Custom(CustomSlot::Slot7),
        Self::Custom(CustomSlot::Slot8),
    ];

    /// Environment-variable contract for opt-in providers.
    ///
    /// For a custom slot, returns `Some((trigger, dependents))`. Once `trigger`
    /// (the slot's `CLIENT_ID`) is set, every variable in `dependents` is also
    /// required; compare the panics in `build_custom_provider`. Google returns
    /// `None` because `GOOGLE_PROVIDER` reads its variables unconditionally.
    pub(crate) fn optional_env_contract(&self) -> Option<(&'static str, &'static [&'static str])> {
        let Self::Custom(slot) = self else {
            return None;
        };
        match slot {
            CustomSlot::Slot1 => Some((
                "OAUTH2_CUSTOM1_CLIENT_ID",
                &["OAUTH2_CUSTOM1_CLIENT_SECRET", "OAUTH2_CUSTOM1_ISSUER_URL"],
            )),
            CustomSlot::Slot2 => Some((
                "OAUTH2_CUSTOM2_CLIENT_ID",
                &["OAUTH2_CUSTOM2_CLIENT_SECRET", "OAUTH2_CUSTOM2_ISSUER_URL"],
            )),
            CustomSlot::Slot3 => Some((
                "OAUTH2_CUSTOM3_CLIENT_ID",
                &["OAUTH2_CUSTOM3_CLIENT_SECRET", "OAUTH2_CUSTOM3_ISSUER_URL"],
            )),
            CustomSlot::Slot4 => Some((
                "OAUTH2_CUSTOM4_CLIENT_ID",
                &["OAUTH2_CUSTOM4_CLIENT_SECRET", "OAUTH2_CUSTOM4_ISSUER_URL"],
            )),
            CustomSlot::Slot5 => Some((
                "OAUTH2_CUSTOM5_CLIENT_ID",
                &["OAUTH2_CUSTOM5_CLIENT_SECRET", "OAUTH2_CUSTOM5_ISSUER_URL"],
            )),
            CustomSlot::Slot6 => Some((
                "OAUTH2_CUSTOM6_CLIENT_ID",
                &["OAUTH2_CUSTOM6_CLIENT_SECRET", "OAUTH2_CUSTOM6_ISSUER_URL"],
            )),
            CustomSlot::Slot7 => Some((
                "OAUTH2_CUSTOM7_CLIENT_ID",
                &["OAUTH2_CUSTOM7_CLIENT_SECRET", "OAUTH2_CUSTOM7_ISSUER_URL"],
            )),
            CustomSlot::Slot8 => Some((
                "OAUTH2_CUSTOM8_CLIENT_ID",
                &["OAUTH2_CUSTOM8_CLIENT_SECRET", "OAUTH2_CUSTOM8_ISSUER_URL"],
            )),
        }
    }

    /// Stable identifier: `"google"` or the custom slot's label (`"custom1"`, ...).
    pub(crate) const fn as_str(&self) -> &'static str {
        if let Self::Custom(slot) = self {
            slot.label()
        } else {
            "google"
        }
    }

    /// Resolves a provider name (as used in routes) to its kind.
    ///
    /// `"google"` maps directly. Any other name is compared with each enabled
    /// custom slot's configured `provider_name`, in slot order, and the first
    /// match wins. Inspecting a slot forces its lazy initialization.
    pub(crate) fn from_provider_name(s: &str) -> Option<Self> {
        if s == "google" {
            return Some(Self::Google);
        }
        for &slot in CustomSlot::ALL {
            let candidate = Self::Custom(slot);
            let name_matches =
                provider_for(candidate).is_some_and(|cfg| cfg.provider_name.as_str() == s);
            if name_matches {
                return Some(candidate);
            }
        }
        None
    }
}
/// Public, presentation-oriented description of a provider.
///
/// Its fields mirror the display fields of `ProviderConfig`. Because it is
/// `#[non_exhaustive]`, code outside this crate cannot construct it with a
/// struct literal or destructure it exhaustively.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ProviderInfo {
    /// Provider identifier, as used in the `/oauth2/{name}/...` routes.
    pub provider_name: ProviderName,
    /// Human-readable label.
    pub display_name: &'static str,
    /// CSS class list for the login button.
    pub button_class: &'static str,
    /// Icon identifier.
    pub icon_slug: &'static str,
    /// Button background color, if the provider defines one.
    pub button_color: Option<&'static str>,
    /// Button hover color, if the provider defines one.
    pub button_hover_color: Option<&'static str>,
    /// Suffix for provider-specific CSS variables, if any (presumably used to
    /// name custom properties in the login UI — confirm at the render site).
    pub css_var_suffix: Option<&'static str>,
}
impl std::fmt::Display for ProviderKind {
    /// Writes the identifier from `ProviderKind::as_str` (`"google"`,
    /// `"custom1"`, ...); width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let identifier: &str = self.as_str();
        f.write_str(identifier)
    }
}
/// Fully resolved configuration for one enabled OAuth2/OIDC provider.
///
/// Built from environment variables (by `GOOGLE_PROVIDER` and
/// `build_custom_provider`), stored in `LazyLock` statics, and handed out as
/// `&'static` by `provider_for`.
pub(crate) struct ProviderConfig {
    /// Which provider this is.
    pub(crate) kind: ProviderKind,
    /// OAuth2 client id.
    pub(crate) client_id: String,
    /// OAuth2 client secret.
    pub(crate) client_secret: String,
    /// Issuer URL; the OIDC discovery document is fetched from it.
    pub(crate) issuer_url: String,
    /// Callback URL: `{ORIGIN}{O2P_ROUTE_PREFIX}/oauth2/{provider_name}/authorized`.
    pub(crate) redirect_uri: String,
    /// OIDC `response_mode` requested from the provider (defaults to `form_post`).
    pub(crate) response_mode: String,
    /// Pre-rendered extra authorization-request parameters, each prefixed with
    /// `&` (response_type, scope, response_mode, optional prompt, ...).
    pub(crate) query_string: String,
    /// Discovery document; fetched on first use and then cached.
    pub(crate) discovery: OnceLock<OidcDiscoveryDocument>,
    /// Extra origins supplied by the preset (empty for Google and preset-less slots).
    pub(crate) additional_allowed_origins: Vec<String>,
    /// Route segment / public identifier.
    pub(crate) provider_name: ProviderName,
    /// Human-readable label.
    pub(crate) display_name: &'static str,
    /// CSS class list for the login button.
    pub(crate) button_class: &'static str,
    /// Icon identifier.
    pub(crate) icon_slug: &'static str,
    /// Button color: `None` for Google, always `Some` for custom slots.
    pub(crate) button_color: Option<&'static str>,
    /// Button hover color: `None` for Google, always `Some` for custom slots.
    pub(crate) button_hover_color: Option<&'static str>,
    /// CSS variable suffix: the slot label for custom slots, `None` for Google.
    pub(crate) css_var_suffix: Option<&'static str>,
    /// Parsed `*_STRICT_DISPLAY_CLAIMS` flag (defaults to `true` when unset).
    pub(crate) strict_display_claims: bool,
}
impl ProviderConfig {
    /// Returns the cached discovery document, fetching it from `issuer_url`
    /// the first time.
    ///
    /// There is no lock around the fetch, so concurrent first callers may each
    /// fetch. `OnceLock::set` keeps only the first stored document, and every
    /// caller gets that one back.
    async fn get_or_fetch_discovery(&self) -> Result<&OidcDiscoveryDocument, OidcDiscoveryError> {
        match self.discovery.get() {
            Some(cached) => Ok(cached),
            None => {
                tracing::debug!(
                    provider = %self.kind,
                    "Fetching OIDC discovery for issuer: {}",
                    self.issuer_url
                );
                let fetched = fetch_oidc_discovery(&self.issuer_url).await?;
                // Losing a race with another caller is fine; the cell is
                // populated either way once `set` returns.
                let _ = self.discovery.set(fetched);
                self.discovery.get().ok_or_else(|| {
                    OidcDiscoveryError::CacheError("Failed to cache discovery document".to_string())
                })
            }
        }
    }

    /// Authorization endpoint from the discovery document.
    pub(crate) async fn auth_url(&self) -> Result<String, OidcDiscoveryError> {
        self.get_or_fetch_discovery()
            .await
            .map(|doc| doc.authorization_endpoint.clone())
    }

    /// Token endpoint from the discovery document.
    pub(crate) async fn token_url(&self) -> Result<String, OidcDiscoveryError> {
        self.get_or_fetch_discovery()
            .await
            .map(|doc| doc.token_endpoint.clone())
    }

    /// JWKS URI from the discovery document.
    pub(crate) async fn jwks_url(&self) -> Result<String, OidcDiscoveryError> {
        self.get_or_fetch_discovery()
            .await
            .map(|doc| doc.jwks_uri.clone())
    }

    /// Userinfo endpoint from the discovery document.
    pub(crate) async fn userinfo_url(&self) -> Result<String, OidcDiscoveryError> {
        self.get_or_fetch_discovery()
            .await
            .map(|doc| doc.userinfo_endpoint.clone())
    }

    /// `issuer` value advertised by the discovery document.
    pub(crate) async fn expected_issuer(&self) -> Result<String, OidcDiscoveryError> {
        self.get_or_fetch_discovery()
            .await
            .map(|doc| doc.issuer.clone())
    }
}
/// Built-in Google provider, initialized from the environment on first access.
///
/// `OAUTH2_ISSUER_URL`, `OAUTH2_RESPONSE_MODE`, `OAUTH2_SCOPE`, and
/// `OAUTH2_RESPONSE_TYPE` are read without a `GOOGLE` infix.
///
/// # Panics
/// On first access, if any of these holds:
/// - `OAUTH2_GOOGLE_CLIENT_ID`, `OAUTH2_GOOGLE_CLIENT_SECRET`, or `ORIGIN` is unset;
/// - `OAUTH2_RESPONSE_MODE` is not `form_post`/`query` (case-insensitive);
/// - `OAUTH2_GOOGLE_PROMPT` holds an invalid value;
/// - `OAUTH2_GOOGLE_STRICT_DISPLAY_CLAIMS` holds an invalid value.
pub(crate) static GOOGLE_PROVIDER: LazyLock<ProviderConfig> = LazyLock::new(|| {
    let client_id =
        env::var("OAUTH2_GOOGLE_CLIENT_ID").expect("OAUTH2_GOOGLE_CLIENT_ID must be set");
    let client_secret =
        env::var("OAUTH2_GOOGLE_CLIENT_SECRET").expect("OAUTH2_GOOGLE_CLIENT_SECRET must be set");
    let issuer_url = env::var("OAUTH2_ISSUER_URL")
        .unwrap_or_else(|_| String::from("https://accounts.google.com"));

    let origin = env::var("ORIGIN").expect("Missing ORIGIN!");
    let redirect_uri = format!(
        "{}{}/oauth2/google/authorized",
        origin,
        O2P_ROUTE_PREFIX.as_str()
    );

    // Normalize case, but report the value exactly as configured on failure.
    let requested_mode =
        env::var("OAUTH2_RESPONSE_MODE").unwrap_or_else(|_| String::from("form_post"));
    let response_mode = match requested_mode.to_lowercase().as_str() {
        accepted @ ("form_post" | "query") => accepted.to_string(),
        _ => panic!(
            "Invalid OAUTH2_RESPONSE_MODE '{requested_mode}'. Must be 'form_post' or 'query'."
        ),
    };

    let scope = env::var("OAUTH2_SCOPE").unwrap_or_else(|_| String::from("openid+email+profile"));
    let response_type = env::var("OAUTH2_RESPONSE_TYPE").unwrap_or_else(|_| String::from("code"));
    let prompt_segment = match parse_prompt("OAUTH2_GOOGLE_PROMPT") {
        Ok(Some(prompt)) => format!("&prompt={prompt}"),
        Ok(None) => String::new(),
        Err(msg) => panic!("{msg}"),
    };
    let query_string = format!(
        "&response_type={response_type}&scope={scope}&response_mode={response_mode}&access_type=online{prompt_segment}"
    );
    let strict_display_claims = read_strict_display_claims("OAUTH2_GOOGLE_STRICT_DISPLAY_CLAIMS");

    ProviderConfig {
        kind: ProviderKind::Google,
        client_id,
        client_secret,
        issuer_url,
        redirect_uri,
        response_mode,
        query_string,
        discovery: OnceLock::new(),
        additional_allowed_origins: Vec::new(),
        provider_name: ProviderName::from_static("google"),
        display_name: "Google",
        button_class: "btn-oauth2 btn-google",
        icon_slug: "google",
        button_color: None,
        button_hover_color: None,
        css_var_suffix: None,
        strict_display_claims,
    }
});
/// Button color for a custom slot when neither `{prefix}_BUTTON_COLOR` nor a
/// preset supplies one.
const CUSTOM_DEFAULT_BUTTON_COLOR: &str = "#6b7280";
/// Hover color for a custom slot when neither `{prefix}_BUTTON_HOVER_COLOR`
/// nor a preset supplies one.
const CUSTOM_DEFAULT_BUTTON_HOVER_COLOR: &str = "#4b5563";
/// Leaks `s` and returns a `'static` borrow of its contents.
///
/// The memory is never reclaimed, so use this only for values created once
/// while building the process-wide provider configuration.
fn leak_static(s: String) -> &'static str {
    let boxed: Box<str> = s.into_boxed_str();
    Box::leak(boxed)
}
/// Parses a `*_STRICT_DISPLAY_CLAIMS` flag from `env_var`.
///
/// An unset or non-UTF-8 variable means `true`. Otherwise the value must be
/// exactly `"true"` or `"false"` (case-sensitive).
///
/// # Errors
/// Returns a message naming `env_var` for any other value.
fn parse_strict_display_claims(env_var: &str) -> Result<bool, String> {
    let raw = match env::var(env_var) {
        Ok(value) => value,
        Err(_) => return Ok(true),
    };
    match raw.as_str() {
        "true" => Ok(true),
        "false" => Ok(false),
        other => Err(format!(
            "Invalid {env_var} '{other}'. Must be 'true' or 'false'."
        )),
    }
}

/// Like [`parse_strict_display_claims`], but panics on an invalid value.
///
/// # Panics
/// If `env_var` is set to something other than `"true"`/`"false"`.
fn read_strict_display_claims(env_var: &str) -> bool {
    match parse_strict_display_claims(env_var) {
        Ok(flag) => flag,
        Err(msg) => panic!("{msg}"),
    }
}

/// Checks `OAUTH2_GOOGLE_STRICT_DISPLAY_CLAIMS` and returns any parse error as
/// `Err` instead of panicking.
pub(crate) fn validate_named_provider_strict_display_claims() -> Result<(), String> {
    parse_strict_display_claims("OAUTH2_GOOGLE_STRICT_DISPLAY_CLAIMS").map(|_| ())
}
/// Reads the OIDC `prompt` parameter from `env_var`.
///
/// * unset or non-UTF-8 → `Some("consent")`
/// * empty string → `None` (omit the `prompt` parameter entirely)
/// * `none` / `login` / `consent` / `select_account` → that value
///
/// # Errors
/// Returns a message naming `env_var` for any other value.
fn parse_prompt(env_var: &str) -> Result<Option<&'static str>, String> {
    const ALLOWED: [&str; 4] = ["none", "login", "consent", "select_account"];

    let raw = match env::var(env_var) {
        Ok(value) => value,
        Err(_) => return Ok(Some("consent")),
    };
    if raw.is_empty() {
        return Ok(None);
    }
    match ALLOWED.iter().find(|&&candidate| candidate == raw) {
        Some(&known) => Ok(Some(known)),
        None => Err(format!(
            "Invalid {env_var} '{raw}'. \
             Must be one of: none, login, consent, select_account \
             (or empty to omit the parameter)."
        )),
    }
}

/// Checks `OAUTH2_GOOGLE_PROMPT` and returns any parse error as `Err` instead
/// of panicking.
pub(crate) fn validate_named_provider_prompt() -> Result<(), String> {
    parse_prompt("OAUTH2_GOOGLE_PROMPT").map(|_| ())
}
/// Builds the configuration for one custom OIDC slot from its
/// `OAUTH2_CUSTOM{n}_*` environment variables (`prefix = slot.env_prefix()`).
///
/// Returns `None` when `{prefix}_CLIENT_ID` is unset, i.e. the slot is
/// disabled. Each presentation field is resolved in this order:
/// 1. the explicit `{prefix}_*` variable;
/// 2. the preset named by `{prefix}_PRESET`, if any;
/// 3. a built-in fallback.
///
/// # Panics
/// When the slot is enabled and any of these holds:
/// - `{prefix}_CLIENT_SECRET`, `{prefix}_ISSUER_URL`, or `ORIGIN` is missing;
/// - `{prefix}_DISPLAY_NAME` or `{prefix}_NAME` is missing and no preset
///   supplies it;
/// - `{prefix}_PRESET`, `{prefix}_RESPONSE_MODE`,
///   `{prefix}_STRICT_DISPLAY_CLAIMS`, or `{prefix}_PROMPT` holds an invalid
///   value.
fn build_custom_provider(slot: CustomSlot) -> Option<ProviderConfig> {
    let prefix = slot.env_prefix();

    // The slot is enabled only when its client id is present.
    let client_id = env::var(format!("{prefix}_CLIENT_ID")).ok()?;
    let client_secret = env::var(format!("{prefix}_CLIENT_SECRET"))
        .unwrap_or_else(|_| panic!("{prefix}_CLIENT_ID set but {prefix}_CLIENT_SECRET missing"));
    let issuer_url = env::var(format!("{prefix}_ISSUER_URL"))
        .unwrap_or_else(|_| panic!("{prefix}_CLIENT_ID set but {prefix}_ISSUER_URL missing"));

    let preset: Option<&'static ProviderPreset> =
        match env::var(format!("{prefix}_PRESET")).ok().as_deref() {
            None => None,
            Some(key) => {
                Some(resolve_preset(key).unwrap_or_else(|msg| panic!("{prefix}_PRESET: {msg}")))
            }
        };

    let display_name: &'static str = env::var(format!("{prefix}_DISPLAY_NAME"))
        .ok()
        .map(leak_static)
        .or(preset.map(|p| p.display_name))
        .unwrap_or_else(|| {
            panic!("{prefix}_CLIENT_ID set but {prefix}_DISPLAY_NAME missing (no PRESET to supply a default)")
        });
    let provider_name: ProviderName = env::var(format!("{prefix}_NAME"))
        .ok()
        .map(ProviderName::from_env_leaked)
        .or(preset.map(|p| p.provider_name))
        .unwrap_or_else(|| {
            panic!(
                "{prefix}_CLIENT_ID set but {prefix}_NAME missing (no PRESET to supply a default)"
            )
        });

    let origin = env::var("ORIGIN").expect("Missing ORIGIN!");

    // Validate and normalize the same way GOOGLE_PROVIDER does. The original
    // code forwarded this value verbatim into the authorization query string,
    // so typos or odd casing reached the identity provider unchecked.
    let requested_mode = env::var(format!("{prefix}_RESPONSE_MODE"))
        .unwrap_or_else(|_| "form_post".to_string());
    let response_mode = match requested_mode.to_lowercase().as_str() {
        "form_post" => "form_post".to_string(),
        "query" => "query".to_string(),
        _ => panic!(
            "Invalid {prefix}_RESPONSE_MODE '{requested_mode}'. Must be 'form_post' or 'query'."
        ),
    };

    let scope =
        env::var(format!("{prefix}_SCOPE")).unwrap_or_else(|_| "openid+email+profile".to_string());

    // Presentation: explicit env var, then preset, then built-in fallback.
    let button_color: &'static str = env::var(format!("{prefix}_BUTTON_COLOR"))
        .ok()
        .map(leak_static)
        .or(preset.map(|p| p.button_color))
        .unwrap_or(CUSTOM_DEFAULT_BUTTON_COLOR);
    let button_hover_color: &'static str = env::var(format!("{prefix}_BUTTON_HOVER_COLOR"))
        .ok()
        .map(leak_static)
        .or(preset.map(|p| p.button_hover_color))
        .unwrap_or(CUSTOM_DEFAULT_BUTTON_HOVER_COLOR);
    let icon_slug: &'static str = env::var(format!("{prefix}_ICON_SLUG"))
        .ok()
        .map(leak_static)
        .or(preset.map(|p| p.icon_slug))
        .unwrap_or("openid");

    // Extra origins come only from the preset; there is no env override.
    let additional_allowed_origins: Vec<String> = preset
        .map(|p| {
            p.additional_allowed_origins
                .iter()
                .map(|s| s.to_string())
                .collect()
        })
        .unwrap_or_default();

    let strict_display_claims =
        read_strict_display_claims(&format!("{prefix}_STRICT_DISPLAY_CLAIMS"));

    let redirect_uri = format!(
        "{}{}/oauth2/{}/authorized",
        origin,
        O2P_ROUTE_PREFIX.as_str(),
        provider_name
    );

    let prompt = parse_prompt(&format!("{prefix}_PROMPT")).unwrap_or_else(|msg| panic!("{msg}"));
    let prompt_segment = prompt.map(|p| format!("&prompt={p}")).unwrap_or_default();
    let query_string = format!(
        "&response_type=code&scope={}&response_mode={}{}",
        scope, response_mode, prompt_segment
    );

    Some(ProviderConfig {
        kind: ProviderKind::Custom(slot),
        client_id,
        client_secret,
        issuer_url,
        redirect_uri,
        response_mode,
        query_string,
        discovery: OnceLock::new(),
        additional_allowed_origins,
        provider_name,
        display_name,
        button_class: slot.button_class(),
        icon_slug,
        button_color: Some(button_color),
        button_hover_color: Some(button_hover_color),
        css_var_suffix: Some(slot.label()),
        strict_display_claims,
    })
}
// One lazily built config per custom slot, produced by `build_custom_provider`.
// Each is `None` when that slot's `OAUTH2_CUSTOM{n}_CLIENT_ID` is unset. The
// first access panics if the slot is enabled but misconfigured.
pub(crate) static CUSTOM1_PROVIDER: LazyLock<Option<ProviderConfig>> =
    LazyLock::new(|| build_custom_provider(CustomSlot::Slot1));
pub(crate) static CUSTOM2_PROVIDER: LazyLock<Option<ProviderConfig>> =
    LazyLock::new(|| build_custom_provider(CustomSlot::Slot2));
pub(crate) static CUSTOM3_PROVIDER: LazyLock<Option<ProviderConfig>> =
    LazyLock::new(|| build_custom_provider(CustomSlot::Slot3));
pub(crate) static CUSTOM4_PROVIDER: LazyLock<Option<ProviderConfig>> =
    LazyLock::new(|| build_custom_provider(CustomSlot::Slot4));
pub(crate) static CUSTOM5_PROVIDER: LazyLock<Option<ProviderConfig>> =
    LazyLock::new(|| build_custom_provider(CustomSlot::Slot5));
pub(crate) static CUSTOM6_PROVIDER: LazyLock<Option<ProviderConfig>> =
    LazyLock::new(|| build_custom_provider(CustomSlot::Slot6));
pub(crate) static CUSTOM7_PROVIDER: LazyLock<Option<ProviderConfig>> =
    LazyLock::new(|| build_custom_provider(CustomSlot::Slot7));
pub(crate) static CUSTOM8_PROVIDER: LazyLock<Option<ProviderConfig>> =
    LazyLock::new(|| build_custom_provider(CustomSlot::Slot8));
/// Names a custom slot's `{prefix}_NAME` may not use (checked by
/// `validate_custom_slots`).
///
/// `"google"` belongs to the built-in provider.
/// NOTE(review): the remaining entries look like fixed route segments that
/// share the `/oauth2/` namespace with `/oauth2/{name}/authorized`. Keep this
/// list in sync with the router.
pub(crate) const RESERVED_PROVIDER_NAMES: &[&str] = &[
    "google",
    "authorized",
    "accounts",
    "fedcm",
    "popup_close",
    "oauth2.js",
    "select",
];
/// Returns `true` when `s` is a non-empty string of `[a-z0-9_-]` characters.
///
/// Used to validate both custom provider names and icon slugs.
fn is_valid_custom_provider_name(s: &str) -> bool {
    if s.is_empty() {
        return false;
    }
    s.bytes()
        .all(|b| matches!(b, b'a'..=b'z' | b'0'..=b'9' | b'_' | b'-'))
}
/// Accepts a restricted set of CSS colors.
///
/// Accepted forms are `#` followed by exactly 3, 4, 6, or 8 hex digits, or a
/// keyword of 3-30 lowercase ASCII letters. Anything else, including
/// functional notation such as `rgb(...)`, is rejected.
fn is_valid_css_color(s: &str) -> bool {
    if let Some(hex) = s.strip_prefix('#') {
        return matches!(hex.len(), 3 | 4 | 6 | 8) && hex.bytes().all(|b| b.is_ascii_hexdigit());
    }
    (3..=30).contains(&s.len()) && s.bytes().all(|b| b.is_ascii_lowercase())
}
/// Validates every enabled custom slot's built configuration.
///
/// For each slot, in slot order, it checks that:
/// 1. the provider name matches `[a-z0-9_-]+`;
/// 2. the name is not in [`RESERVED_PROVIDER_NAMES`];
/// 3. no earlier slot already uses the name;
/// 4. both button colors pass `is_valid_css_color`;
/// 5. the icon slug matches `[a-z0-9_-]+`.
///
/// Visiting a slot forces its lazy initialization, which panics on an
/// incomplete slot configuration.
///
/// # Errors
/// Returns the first violation found, naming the offending env var.
pub(crate) fn validate_custom_slots() -> Result<(), String> {
    let mut claimed: Vec<(CustomSlot, ProviderName)> = Vec::with_capacity(CustomSlot::ALL.len());
    let enabled = CustomSlot::ALL
        .iter()
        .filter_map(|&slot| provider_for(ProviderKind::Custom(slot)).map(|cfg| (slot, cfg)));

    for (slot, cfg) in enabled {
        let prefix = slot.env_prefix();
        let name = cfg.provider_name;

        if !is_valid_custom_provider_name(name.as_str()) {
            return Err(format!(
                "{}_NAME='{}' is invalid: must match [a-z0-9_-]+",
                prefix, name
            ));
        }
        if RESERVED_PROVIDER_NAMES.contains(&name.as_str()) {
            return Err(format!(
                "{}_NAME='{}' collides with a reserved name",
                prefix, name
            ));
        }
        if let Some(&(earlier, _)) = claimed.iter().find(|(_, taken)| *taken == name) {
            return Err(format!(
                "{}_NAME='{}' collides with {}_NAME",
                prefix,
                name,
                earlier.env_prefix()
            ));
        }
        claimed.push((slot, name));

        let colors = [
            ("BUTTON_COLOR", cfg.button_color),
            ("BUTTON_HOVER_COLOR", cfg.button_hover_color),
        ];
        for (field, value) in colors {
            if let Some(color) = value.filter(|c| !is_valid_css_color(c)) {
                return Err(format!(
                    "{}_{}='{}' is invalid: expected '#rgb[a]', '#rrggbb[aa]', or a CSS color keyword (3-30 lowercase letters)",
                    prefix, field, color
                ));
            }
        }

        if !is_valid_custom_provider_name(cfg.icon_slug) {
            return Err(format!(
                "{}_ICON_SLUG='{}' is invalid: must match [a-z0-9_-]+",
                prefix, cfg.icon_slug
            ));
        }
    }
    Ok(())
}
pub(crate) fn validate_custom_slot_preset_shape() -> Result<(), String> {
for &slot in CustomSlot::ALL {
let prefix = slot.env_prefix();
if env::var(format!("{prefix}_CLIENT_ID")).is_err() {
continue;
}
let preset_key = env::var(format!("{prefix}_PRESET")).ok();
let has_preset = match preset_key.as_deref() {
None => false,
Some(key) => {
resolve_preset(key).map_err(|msg| format!("{prefix}_PRESET: {msg}"))?;
true
}
};
if !has_preset {
if env::var(format!("{prefix}_DISPLAY_NAME")).is_err() {
return Err(format!(
"{prefix}_CLIENT_ID is set without {prefix}_PRESET; {prefix}_DISPLAY_NAME is required"
));
}
if env::var(format!("{prefix}_NAME")).is_err() {
return Err(format!(
"{prefix}_CLIENT_ID is set without {prefix}_PRESET; {prefix}_NAME is required"
));
}
}
parse_prompt(&format!("{prefix}_PROMPT"))?;
}
Ok(())
}
pub(crate) fn provider_for(kind: ProviderKind) -> Option<&'static ProviderConfig> {
match kind {
ProviderKind::Google => Some(&GOOGLE_PROVIDER),
ProviderKind::Custom(slot) => match slot {
CustomSlot::Slot1 => CUSTOM1_PROVIDER.as_ref(),
CustomSlot::Slot2 => CUSTOM2_PROVIDER.as_ref(),
CustomSlot::Slot3 => CUSTOM3_PROVIDER.as_ref(),
CustomSlot::Slot4 => CUSTOM4_PROVIDER.as_ref(),
CustomSlot::Slot5 => CUSTOM5_PROVIDER.as_ref(),
CustomSlot::Slot6 => CUSTOM6_PROVIDER.as_ref(),
CustomSlot::Slot7 => CUSTOM7_PROVIDER.as_ref(),
CustomSlot::Slot8 => CUSTOM8_PROVIDER.as_ref(),
},
}
}
// Unit tests for this module live in the separate `tests` module file.
#[cfg(test)]
mod tests;