use std::borrow::Cow;
use std::fmt;
use std::str::FromStr;
use rand::Rng;
use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
use secrecy::{ExposeSecret, SecretString};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// Message text `"invalid or unavailable pioneer code"`.
///
/// NOTE(review): no use site is visible in this file; the single combined wording
/// presumably keeps callers from distinguishing a malformed code from one that
/// cannot be used — confirm against the handlers that return it.
pub const INVALID_OR_UNAVAILABLE_PIONEER_CODE: &str = "invalid or unavailable pioneer code";
/// Error returned when a string (or a label passed to `PioneerCode::generate`)
/// does not match the pioneer code shape.
///
/// It is a unit struct: neither its `Debug` nor its `Display` output includes
/// the rejected input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PioneerCodeError;
/// Status of a pioneer code.
///
/// The serde wire names (`snake_case`) and the storage strings returned by
/// [`PioneerCodeStatus::as_str`] are the same: `active` and `revoked`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum PioneerCodeStatus {
    /// Stored as `"active"`.
    Active,
    /// Stored as `"revoked"`.
    Revoked,
}

impl PioneerCodeStatus {
    /// Returns the storage string for this status.
    pub fn as_str(self) -> &'static str {
        match self {
            PioneerCodeStatus::Active => "active",
            PioneerCodeStatus::Revoked => "revoked",
        }
    }

    /// Parses a storage string produced by [`PioneerCodeStatus::as_str`].
    ///
    /// Matching is exact and case-sensitive; any other string yields `None`.
    pub fn from_storage_value(value: &str) -> Option<Self> {
        // Derive the accepted strings from `as_str` so the two directions share one table.
        [PioneerCodeStatus::Active, PioneerCodeStatus::Revoked]
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == value)
    }
}

impl fmt::Display for PioneerCodeStatus {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Writes the bare storage string; width/alignment flags are not applied.
        f.write_str(Self::as_str(*self))
    }
}
/// Kind of pioneer code use.
///
/// The serde wire names (`snake_case`) and the storage strings returned by
/// [`PioneerCodeUseKind::as_str`] are the same: `human_github_sign_in` and `agent_api`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum PioneerCodeUseKind {
    /// Stored as `"human_github_sign_in"`.
    HumanGithubSignIn,
    /// Stored as `"agent_api"`.
    AgentApi,
}

impl PioneerCodeUseKind {
    /// Returns the storage string for this use kind.
    pub fn as_str(self) -> &'static str {
        match self {
            PioneerCodeUseKind::HumanGithubSignIn => "human_github_sign_in",
            PioneerCodeUseKind::AgentApi => "agent_api",
        }
    }

    /// Parses a storage string produced by [`PioneerCodeUseKind::as_str`].
    ///
    /// Matching is exact and case-sensitive; any other string yields `None`.
    pub fn from_storage_value(value: &str) -> Option<Self> {
        // Derive the accepted strings from `as_str` so the two directions share one table.
        [PioneerCodeUseKind::HumanGithubSignIn, PioneerCodeUseKind::AgentApi]
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == value)
    }
}

impl fmt::Display for PioneerCodeUseKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Writes the bare storage string; width/alignment flags are not applied.
        f.write_str(Self::as_str(*self))
    }
}
/// Kind of subject associated with a pioneer code.
///
/// The serde wire names (`snake_case`) and the storage strings returned by
/// [`PioneerCodeSubjectKind::as_str`] are the same: `human` and `agent`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum PioneerCodeSubjectKind {
    /// Stored as `"human"`.
    Human,
    /// Stored as `"agent"`.
    Agent,
}

impl PioneerCodeSubjectKind {
    /// Returns the storage string for this subject kind.
    pub fn as_str(self) -> &'static str {
        match self {
            PioneerCodeSubjectKind::Human => "human",
            PioneerCodeSubjectKind::Agent => "agent",
        }
    }

    /// Parses a storage string produced by [`PioneerCodeSubjectKind::as_str`].
    ///
    /// Matching is exact and case-sensitive; any other string yields `None`.
    pub fn from_storage_value(value: &str) -> Option<Self> {
        // Derive the accepted strings from `as_str` so the two directions share one table.
        [PioneerCodeSubjectKind::Human, PioneerCodeSubjectKind::Agent]
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == value)
    }
}

impl fmt::Display for PioneerCodeSubjectKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Writes the bare storage string; width/alignment flags are not applied.
        f.write_str(Self::as_str(*self))
    }
}
impl fmt::Display for PioneerCodeError {
    // Describes the accepted format only; the rejected input is never part of the message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const FORMAT_HINT: &str = "pioneer_code must be 8 lowercase hex chars or <label>-<8 lowercase hex chars>; label may use lowercase letters, digits, or _ and must be at most 6 chars";
        f.write_str(FORMAT_HINT)
    }
}

/// Lets [`PioneerCodeError`] be used wherever a `std::error::Error` is expected
/// (e.g. boxed as `dyn Error`). It has no underlying source.
impl std::error::Error for PioneerCodeError {}
/// A pioneer code whose format has been validated.
///
/// Accepted shapes are either 8 lowercase hex characters (`deadbeef`) or a
/// label of 1–6 characters from `[a-z0-9_]`, a `-`, and 8 lowercase hex
/// characters (`jack_1-deadbeef`). The value is held in a [`SecretString`].
/// `Debug` and `Display` print fixed placeholders, while serde serializes the
/// raw code.
#[derive(Clone)]
pub struct PioneerCode(SecretString);

impl PioneerCode {
    /// Validates `value` and wraps it.
    ///
    /// # Errors
    ///
    /// Returns [`PioneerCodeError`] if `value` does not match the pioneer code shape.
    pub fn try_new(value: impl Into<String>) -> Result<Self, PioneerCodeError> {
        let raw: String = value.into();
        validate_pioneer_code(&raw).map(|()| Self(SecretString::from(raw)))
    }

    /// Generates a new code from 4 random bytes rendered as hex, optionally
    /// prefixed with `label` and a `-`.
    ///
    /// NOTE(review): 4 bytes gives 2^32 possible suffixes per label; if codes
    /// gate access, confirm that lookups are rate-limited.
    ///
    /// # Errors
    ///
    /// Returns [`PioneerCodeError`] if `label` is empty, longer than 6 bytes,
    /// or contains characters outside `[a-z0-9_]`.
    pub fn generate(label: Option<&str>) -> Result<Self, PioneerCodeError> {
        let mut entropy = [0u8; 4];
        rand::rng().fill_bytes(&mut entropy);
        let suffix = hex::encode(entropy);
        let candidate = if let Some(prefix) = label {
            validate_pioneer_label(prefix)?;
            format!("{prefix}-{suffix}")
        } else {
            suffix
        };
        // Route through `try_new` so the full-format check lives in one place.
        Self::try_new(candidate)
    }

    /// Returns the raw code. Avoid logging or displaying the result.
    pub fn expose_secret(&self) -> &str {
        self.0.expose_secret()
    }

    /// Returns the label before the `-`, or `None` for an unlabeled code.
    pub fn label(&self) -> Option<&str> {
        let (prefix, _suffix) = self.expose_secret().split_once('-')?;
        Some(prefix)
    }
}
#[derive(Clone)]
pub struct PioneerCodeInput(SecretString);
impl PioneerCodeInput {
pub fn try_new(value: impl Into<String>) -> Result<Self, PioneerCodeError> {
Ok(Self(SecretString::from(value.into())))
}
pub fn expose_secret(&self) -> &str {
self.0.expose_secret()
}
}
impl fmt::Debug for PioneerCodeInput {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("PioneerCodeInput([redacted])")
}
}
impl Serialize for PioneerCodeInput {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.expose_secret())
}
}
impl<'de> Deserialize<'de> for PioneerCodeInput {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let value = String::deserialize(deserializer)?;
Self::try_new(value).map_err(serde::de::Error::custom)
}
}
impl JsonSchema for PioneerCodeInput {
fn inline_schema() -> bool {
true
}
fn schema_name() -> Cow<'static, str> {
"PioneerCodeInput".into()
}
fn json_schema(_: &mut SchemaGenerator) -> Schema {
json_schema!({ "type": "string" })
}
}
// Both formatters print fixed placeholders and never the code itself.
impl fmt::Debug for PioneerCode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const PLACEHOLDER: &str = "PioneerCode([redacted])";
        f.write_str(PLACEHOLDER)
    }
}

impl fmt::Display for PioneerCode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const PLACEHOLDER: &str = "[redacted pioneer code]";
        f.write_str(PLACEHOLDER)
    }
}
impl FromStr for PioneerCode {
type Err = PioneerCodeError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
Self::try_new(value.to_string())
}
}
impl Serialize for PioneerCode {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.expose_secret())
}
}
impl<'de> Deserialize<'de> for PioneerCode {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let value = String::deserialize(deserializer)?;
Self::try_new(value).map_err(serde::de::Error::custom)
}
}
impl JsonSchema for PioneerCode {
fn inline_schema() -> bool {
true
}
fn schema_name() -> Cow<'static, str> {
"PioneerCode".into()
}
fn json_schema(_: &mut SchemaGenerator) -> Schema {
json_schema!({
"type": "string",
"pattern": "^([a-z0-9_]{1,6}-)?[0-9a-f]{8}$"
})
}
}
/// Checks the full code shape: `<8 lowercase hex>` or `<label>-<8 lowercase hex>`.
///
/// The input is split at the first `-`. Everything after it must be the
/// 8-character hex part.
fn validate_pioneer_code(value: &str) -> Result<(), PioneerCodeError> {
    match value.split_once('-') {
        // A second `-` can never be valid: labels and hex parts both exclude it.
        Some((_, random_hex)) if random_hex.contains('-') => Err(PioneerCodeError),
        Some((label, random_hex)) => {
            validate_pioneer_label(label)?;
            validate_random_hex(random_hex)
        }
        None => validate_random_hex(value),
    }
}
/// Accepts labels of 1–6 bytes drawn only from `a-z`, `0-9`, and `_`.
fn validate_pioneer_label(label: &str) -> Result<(), PioneerCodeError> {
    // Byte length equals char length here because every allowed byte is ASCII.
    let well_formed = (1..=6).contains(&label.len())
        && label
            .bytes()
            .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'_');
    if well_formed {
        Ok(())
    } else {
        Err(PioneerCodeError)
    }
}
/// Accepts exactly 8 bytes, each a digit or a lowercase `a`–`f`.
fn validate_random_hex(random_hex: &str) -> Result<(), PioneerCodeError> {
    // `is_ascii_hexdigit` would also admit `A`–`F`, which the format forbids.
    let is_lower_hex = |b: u8| b.is_ascii_digit() || (b'a'..=b'f').contains(&b);
    if random_hex.len() == 8 && random_hex.bytes().all(is_lower_hex) {
        Ok(())
    } else {
        Err(PioneerCodeError)
    }
}
#[cfg(test)]
mod tests {
    use super::{
        PioneerCode, PioneerCodeInput, PioneerCodeStatus, PioneerCodeSubjectKind,
        PioneerCodeUseKind,
    };

    #[test]
    fn accepts_plain_and_labeled_codes() {
        let plain = PioneerCode::try_new("deadbeef").expect("plain code should parse");
        assert_eq!(plain.expose_secret(), "deadbeef");
        assert_eq!(plain.label(), None);
        let labeled = PioneerCode::try_new("jack_1-deadbeef").expect("labeled code should parse");
        assert_eq!(labeled.expose_secret(), "jack_1-deadbeef");
        assert_eq!(labeled.label(), Some("jack_1"));
    }

    #[test]
    fn rejects_invalid_codes() {
        for value in [
            "",
            "DEADBEEF",
            "deadbee",
            "deadbeef00",
            "labeltoolong-deadbeef",
            "bad-label-deadbeef",
            "bad!-deadbeef",
            "-deadbeef",
            "jack-DEADBEEF",
            "jack-deadbee!",
        ] {
            assert!(PioneerCode::try_new(value).is_err(), "{value}");
        }
    }

    #[test]
    fn generated_labeled_code_keeps_label() {
        let code = PioneerCode::generate(Some("jack")).expect("generated code should be valid");
        assert_eq!(code.label(), Some("jack"));
        assert!(code.expose_secret().starts_with("jack-"));
    }

    #[test]
    fn generated_plain_code_is_valid_and_unlabeled() {
        let code = PioneerCode::generate(None).expect("generated code should be valid");
        assert_eq!(code.label(), None);
        assert_eq!(code.expose_secret().len(), 8);
        assert!(PioneerCode::try_new(code.expose_secret()).is_ok());
    }

    #[test]
    fn generate_rejects_invalid_labels() {
        for label in ["", "toolong", "Jack", "a-b", "bad!"] {
            assert!(PioneerCode::generate(Some(label)).is_err(), "{label}");
        }
    }

    // Redaction is a security property: these formatters must never print the code.
    #[test]
    fn formatting_never_reveals_the_code() {
        let code = PioneerCode::try_new("jack-deadbeef").expect("valid code should parse");
        assert_eq!(format!("{code:?}"), "PioneerCode([redacted])");
        assert_eq!(code.to_string(), "[redacted pioneer code]");
        let input = PioneerCodeInput::try_new("jack-deadbeef").expect("input is never rejected");
        assert_eq!(format!("{input:?}"), "PioneerCodeInput([redacted])");
    }

    #[test]
    fn serde_uses_string_wire_shape() {
        let code: PioneerCode =
            serde_json::from_str("\"deadbeef\"").expect("valid code should deserialize");
        assert_eq!(code.expose_secret(), "deadbeef");
        assert_eq!(
            serde_json::to_string(&code).expect("code should serialize"),
            "\"deadbeef\""
        );
    }

    #[test]
    fn serde_rejects_malformed_codes_but_input_accepts_them() {
        assert!(serde_json::from_str::<PioneerCode>("\"DEADBEEF\"").is_err());
        let input: PioneerCodeInput =
            serde_json::from_str("\"DEADBEEF\"").expect("input should deserialize");
        assert_eq!(input.expose_secret(), "DEADBEEF");
        assert_eq!(
            serde_json::to_string(&input).expect("input should serialize"),
            "\"DEADBEEF\""
        );
    }

    #[test]
    fn parse_matches_try_new() {
        let code: PioneerCode = "jack-deadbeef".parse().expect("valid code should parse");
        assert_eq!(code.label(), Some("jack"));
        assert!("nope".parse::<PioneerCode>().is_err());
    }

    // Storage strings, Display output, and serde names must all agree.
    #[test]
    fn enum_storage_values_round_trip_and_match_serde() {
        for status in [PioneerCodeStatus::Active, PioneerCodeStatus::Revoked] {
            assert_eq!(PioneerCodeStatus::from_storage_value(status.as_str()), Some(status));
            assert_eq!(status.to_string(), status.as_str());
            assert_eq!(
                serde_json::to_string(&status).expect("status should serialize"),
                format!("\"{status}\"")
            );
        }
        assert_eq!(PioneerCodeStatus::from_storage_value("Active"), None);

        for kind in [PioneerCodeUseKind::HumanGithubSignIn, PioneerCodeUseKind::AgentApi] {
            assert_eq!(PioneerCodeUseKind::from_storage_value(kind.as_str()), Some(kind));
            assert_eq!(kind.to_string(), kind.as_str());
            assert_eq!(
                serde_json::to_string(&kind).expect("use kind should serialize"),
                format!("\"{kind}\"")
            );
        }
        assert_eq!(PioneerCodeUseKind::from_storage_value("agent"), None);

        for kind in [PioneerCodeSubjectKind::Human, PioneerCodeSubjectKind::Agent] {
            assert_eq!(PioneerCodeSubjectKind::from_storage_value(kind.as_str()), Some(kind));
            assert_eq!(kind.to_string(), kind.as_str());
            assert_eq!(
                serde_json::to_string(&kind).expect("subject kind should serialize"),
                format!("\"{kind}\"")
            );
        }
        assert_eq!(PioneerCodeSubjectKind::from_storage_value(""), None);
    }
}