use base64::prelude::*;
use rand::distributions::Alphanumeric;
use rand::Rng;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::fmt;
/// A state parameter value carrying a random nonce and, optionally, embedded caller data.
///
/// Either a plain alphanumeric nonce, or (when built via `with_data`) the standard base64
/// encoding of a JSON `{ "nonce": ..., "data": ... }` object.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct StateParam {
    /// The exact state string, as returned by `Display` / `AsRef<str>`.
    value: String,
    /// `true` only for states produced by `with_data`, whose `value` is an encoded payload.
    is_structured: bool,
}
/// JSON payload of a structured state, prior to base64 encoding: a nonce plus caller data.
///
/// Field order is significant: serde serializes fields in declaration order, which fixes
/// the JSON key order as `nonce` then `data`.
#[derive(Deserialize, Serialize)]
struct StructuredState<T> {
    nonce: String,
    data: T,
}
/// Partial view of a structured-state payload that reads only the nonce.
///
/// Lets the nonce be recovered without knowing the embedded data's type; the `data` key is
/// simply ignored because serde skips unknown fields by default.
#[derive(Deserialize)]
struct NonceOnly {
    nonce: String,
}
// Compile-time guarantee that `StateParam` is both `Send` and `Sync`; the build fails if a
// future field change breaks either auto trait.
const _: () = {
    const fn require_send_sync<T: Send + Sync>() {}
    require_send_sync::<StateParam>();
};
impl StateParam {
    /// Number of random alphanumeric characters in every generated nonce.
    const NONCE_LENGTH: usize = 15;

    /// Generates a fresh random nonce of `NONCE_LENGTH` ASCII alphanumeric characters.
    fn generate_nonce() -> String {
        rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(Self::NONCE_LENGTH)
            .map(char::from)
            .collect()
    }

    /// Decodes a structured-state value (standard base64 of a JSON document) to JSON text.
    ///
    /// Returns `None` if `value` is not valid standard base64 or does not decode to UTF-8.
    fn decode_json(value: &str) -> Option<String> {
        let bytes = BASE64_STANDARD.decode(value.as_bytes()).ok()?;
        String::from_utf8(bytes).ok()
    }

    /// Creates a simple (unstructured) state whose value is a fresh 15-character
    /// ASCII alphanumeric nonce.
    #[must_use]
    pub fn new() -> Self {
        Self {
            value: Self::generate_nonce(),
            is_structured: false,
        }
    }

    /// Creates a structured state embedding `data` alongside a fresh random nonce.
    ///
    /// The resulting value is the standard-alphabet base64 encoding of the JSON object
    /// `{"nonce": <nonce>, "data": <data>}`.
    ///
    /// If `data` cannot be serialized to JSON, this falls back to a simple nonce-only state
    /// (like [`StateParam::new`]) rather than an empty state string, so the returned state
    /// always carries a random nonce. In that case [`StateParam::extract_data`] yields `None`.
    #[must_use]
    pub fn with_data<T: Serialize>(data: &T) -> Self {
        let structured = StructuredState {
            nonce: Self::generate_nonce(),
            data,
        };
        match serde_json::to_string(&structured) {
            Ok(json) => Self {
                value: BASE64_STANDARD.encode(json.as_bytes()),
                is_structured: true,
            },
            // Previously an empty string was encoded here, producing an empty state (and an
            // empty "nonce" from `extract_nonce`). Keep the random nonce instead.
            Err(_) => Self {
                value: structured.nonce,
                is_structured: false,
            },
        }
    }

    /// Wraps an existing state string verbatim, without validating or decoding it.
    ///
    /// The result is treated as a simple (unstructured) state; see
    /// [`StateParam::extract_nonce`] for what that implies.
    #[must_use]
    pub fn from_raw(raw: impl Into<String>) -> Self {
        Self {
            value: raw.into(),
            is_structured: false,
        }
    }

    /// Returns the full state value.
    ///
    /// Despite the name, for a structured state (see [`StateParam::with_data`]) this is the
    /// entire base64 payload, not the embedded nonce; use [`StateParam::extract_nonce`] for
    /// the nonce itself.
    #[must_use]
    pub fn nonce(&self) -> &str {
        &self.value
    }

    /// Attempts to recover the data embedded by [`StateParam::with_data`].
    ///
    /// Only the value is inspected, so this also works on states rebuilt with
    /// [`StateParam::from_raw`]. Returns `None` if the value is not standard base64, not
    /// UTF-8, not a `{nonce, data}` JSON object, or if `data` does not deserialize as `T`.
    #[must_use]
    pub fn extract_data<T: DeserializeOwned>(&self) -> Option<T> {
        let json = Self::decode_json(&self.value)?;
        serde_json::from_str::<StructuredState<T>>(&json)
            .ok()
            .map(|structured| structured.data)
    }

    /// Returns the random nonce.
    ///
    /// For a state created by [`StateParam::with_data`], the payload is decoded and the
    /// embedded nonce is returned; if decoding fails, the full value is returned instead.
    /// For every other state — including one built with [`StateParam::from_raw`], even when
    /// the raw string is a structured payload — the full value is returned unchanged.
    #[must_use]
    pub fn extract_nonce(&self) -> String {
        if !self.is_structured {
            return self.value.clone();
        }
        Self::decode_json(&self.value)
            .and_then(|json| serde_json::from_str::<NonceOnly>(&json).ok())
            .map_or_else(|| self.value.clone(), |parsed| parsed.nonce)
    }
}
impl Default for StateParam {
fn default() -> Self {
Self::new()
}
}
impl fmt::Display for StateParam {
    /// Writes the full state value verbatim; formatter width/fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.value)
    }
}
impl AsRef<str> for StateParam {
    /// Borrows the full state value as a string slice.
    fn as_ref(&self) -> &str {
        self.value.as_str()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Nonce length produced by `StateParam::new` and embedded by `StateParam::with_data`.
    const EXPECTED_NONCE_LEN: usize = 15;

    /// True when every character of `s` is an ASCII letter or digit.
    fn is_ascii_alnum(s: &str) -> bool {
        s.chars().all(|c| c.is_ascii_alphanumeric())
    }

    #[test]
    fn test_new_generates_15_char_alphanumeric_nonce() {
        let fresh = StateParam::new();
        let raw = fresh.nonce();
        assert_eq!(raw.len(), EXPECTED_NONCE_LEN);
        assert!(is_ascii_alnum(raw));
    }

    #[test]
    fn test_new_generates_unique_nonces() {
        let (first, second) = (StateParam::new(), StateParam::new());
        assert_ne!(first.nonce(), second.nonce());
    }

    #[test]
    fn test_with_data_embeds_json_in_base64() {
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct TestData {
            return_url: String,
        }

        let payload = TestData {
            return_url: "/dashboard".to_string(),
        };
        let state = StateParam::with_data(&payload);
        let bytes = BASE64_STANDARD
            .decode(state.as_ref().as_bytes())
            .expect("structured state should be valid base64");
        let json = String::from_utf8(bytes).unwrap();
        for needle in ["nonce", "data", "/dashboard"] {
            assert!(json.contains(needle), "missing {:?} in {}", needle, json);
        }
    }

    #[test]
    fn test_from_raw_wraps_string_correctly() {
        const RAW: &str = "custom-state-123";
        let state = StateParam::from_raw(RAW);
        assert_eq!(state.as_ref(), RAW);
        assert_eq!(state.nonce(), RAW);
    }

    #[test]
    fn test_nonce_returns_value_for_simple_state() {
        assert_eq!(StateParam::new().nonce().len(), EXPECTED_NONCE_LEN);
    }

    #[test]
    fn test_nonce_returns_full_value_for_structured_state() {
        let state = StateParam::with_data(&"test");
        // `nonce()` exposes the whole encoded payload; `extract_nonce()` digs out the nonce.
        assert!(state.nonce().len() > EXPECTED_NONCE_LEN);
        assert_eq!(state.extract_nonce().len(), EXPECTED_NONCE_LEN);
    }

    #[test]
    fn test_extract_data_returns_embedded_data() {
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct FlowData {
            user_id: u64,
            redirect_to: String,
        }

        let original = FlowData {
            user_id: 12345,
            redirect_to: "/admin/orders".to_string(),
        };
        let state = StateParam::with_data(&original);
        assert_eq!(state.extract_data::<FlowData>(), Some(original));
    }

    #[test]
    fn test_extract_data_returns_none_for_simple_state() {
        #[derive(Deserialize)]
        struct SomeData {
            #[allow(dead_code)]
            field: String,
        }

        let extracted = StateParam::new().extract_data::<SomeData>();
        assert!(extracted.is_none());
    }

    #[test]
    fn test_extract_data_returns_none_for_type_mismatch() {
        #[derive(Serialize)]
        struct DataA {
            field_a: String,
        }

        #[derive(Deserialize)]
        struct DataB {
            #[allow(dead_code)]
            field_b: i32,
        }

        let state = StateParam::with_data(&DataA {
            field_a: "test".to_string(),
        });
        assert!(state.extract_data::<DataB>().is_none());
    }

    #[test]
    fn test_display_returns_full_state_string() {
        let raw = StateParam::from_raw("display-test");
        assert_eq!(format!("{}", raw), "display-test");

        let generated = StateParam::new();
        assert_eq!(format!("{}", generated), generated.as_ref());
    }

    #[test]
    fn test_as_ref_provides_string_slice() {
        let state = StateParam::from_raw("ref-test");
        let borrowed: &str = state.as_ref();
        assert_eq!(borrowed, "ref-test");
    }

    #[test]
    fn test_with_data_handles_various_types() {
        let text = StateParam::with_data(&"simple string");
        assert_eq!(
            text.extract_data::<String>(),
            Some("simple string".to_string())
        );

        let number = StateParam::with_data(&42i32);
        assert_eq!(number.extract_data::<i32>(), Some(42));

        let list = StateParam::with_data(&vec![1, 2, 3]);
        assert_eq!(list.extract_data::<Vec<i32>>(), Some(vec![1, 2, 3]));
    }

    #[test]
    fn test_extract_nonce_from_structured_state() {
        #[derive(Serialize)]
        struct Data {
            value: i32,
        }

        let nonce = StateParam::with_data(&Data { value: 42 }).extract_nonce();
        assert_eq!(nonce.len(), EXPECTED_NONCE_LEN);
        assert!(is_ascii_alnum(&nonce));
    }

    #[test]
    fn test_state_param_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<StateParam>();
    }

    #[test]
    fn test_state_param_clone() {
        let original = StateParam::new();
        assert_eq!(original.clone(), original);
    }

    #[test]
    fn test_state_param_eq() {
        let same_a = StateParam::from_raw("same");
        let same_b = StateParam::from_raw("same");
        let other = StateParam::from_raw("different");
        assert_eq!(same_a, same_b);
        assert_ne!(same_a, other);
    }

    #[test]
    fn test_state_param_default() {
        assert_eq!(StateParam::default().nonce().len(), EXPECTED_NONCE_LEN);
    }
}