use std::sync::Arc;
use pocopine_core::{ServerError, ServerResult};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::role::{Permission, Role};
use crate::user::AuthUser;
/// The identity of a caller: either an authenticated [`AuthUser`] or anonymous.
///
/// The user is held behind an [`Arc`], so cloning a `Principal` shares the same
/// `AuthUser` allocation instead of copying it. `Principal::default()` is the
/// anonymous principal.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Principal {
    /// `None` for an anonymous caller, `Some` for an authenticated one.
    user: Option<Arc<AuthUser>>,
}
impl Serialize for Principal {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
use serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct("Principal", 1)?;
state.serialize_field("user", &self.user.as_deref())?;
state.end()
}
}
impl<'de> Deserialize<'de> for Principal {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct Wire {
user: Option<AuthUser>,
}
let wire = Wire::deserialize(deserializer)?;
Ok(Principal {
user: wire.user.map(Arc::new),
})
}
}
impl Principal {
    /// Returns a principal with no user attached (identical to `Principal::default()`).
    pub fn anonymous() -> Self {
        Self::default()
    }

    /// Builds an authenticated principal, moving `user` into a new `Arc`.
    pub fn from_user(user: AuthUser) -> Self {
        Self::from_arc(Arc::new(user))
    }

    /// Builds an authenticated principal that shares an existing `Arc<AuthUser>`.
    pub fn from_arc(user: Arc<AuthUser>) -> Self {
        Self { user: Some(user) }
    }

    /// Returns `true` when a user is attached.
    pub fn is_authenticated(&self) -> bool {
        self.user().is_some()
    }

    /// Borrows the attached user, or `None` for an anonymous principal.
    pub fn user(&self) -> Option<&AuthUser> {
        self.user.as_deref()
    }

    /// Returns another handle to the shared user, if any.
    ///
    /// Cloning an `Option<Arc<_>>` only bumps the reference count; the
    /// `AuthUser` itself is not copied.
    pub fn user_arc(&self) -> Option<Arc<AuthUser>> {
        self.user.clone()
    }

    /// Borrows the attached user, failing for anonymous principals.
    ///
    /// # Errors
    ///
    /// Returns `ServerError::unauthorized("login required")` when no user is attached.
    pub fn require_user(&self) -> ServerResult<&AuthUser> {
        match self.user() {
            Some(user) => Ok(user),
            None => Err(ServerError::unauthorized("login required")),
        }
    }

    /// Returns `true` only if a user is attached and that user has `role`.
    pub fn has_role(&self, role: &Role) -> bool {
        let Some(user) = self.user() else {
            return false;
        };
        user.has_role(role)
    }

    /// Returns `true` only if a user is attached and that user has `permission`.
    pub fn has_permission(&self, permission: &Permission) -> bool {
        let Some(user) = self.user() else {
            return false;
        };
        user.has_permission(permission)
    }
}
impl From<AuthUser> for Principal {
fn from(user: AuthUser) -> Self {
Self::from_user(user)
}
}
/// Converts an already-shared user into an authenticated principal, without
/// copying the `AuthUser`.
impl From<Arc<AuthUser>> for Principal {
    fn from(shared: Arc<AuthUser>) -> Self {
        Self { user: Some(shared) }
    }
}
/// A session record: an identifier, the [`AuthUser`] it belongs to, and an
/// optional expiry.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
    /// Session identifier.
    pub id: String,
    /// The user this session belongs to.
    pub user: AuthUser,
    /// Expiry time in milliseconds (presumably since the Unix epoch; confirm
    /// with callers).
    ///
    /// It is omitted from the serialized output when `None` and defaults to
    /// `None` when absent from the input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub expires_at_ms: Option<u64>,
}
impl Session {
pub fn new(id: impl Into<String>, user: AuthUser) -> Self {
Self {
id: id.into(),
user,
expires_at_ms: None,
}
}
pub fn with_expires_at_ms(mut self, expires_at_ms: u64) -> Self {
self.expires_at_ms = Some(expires_at_ms);
self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Clones of a `Principal` must share one `AuthUser` allocation.
    #[test]
    fn principal_user_clones_are_cheap() {
        let original = Principal::from_user(AuthUser::new("uid-1"));
        let copy = original.clone();

        let (first, second) = (
            original.user_arc().expect("p1 has user"),
            copy.user_arc().expect("p2 has user"),
        );
        assert!(Arc::ptr_eq(&first, &second));
    }

    /// Checks the JSON shape of authenticated and anonymous principals and
    /// that both round-trip.
    #[test]
    fn principal_serialization_emits_option_discriminant() {
        // Authenticated: `user` is a nested object carrying the user's id.
        let authed = Principal::from_user(AuthUser::new("uid-1"));
        let json = serde_json::to_string(&authed).unwrap();
        assert!(json.contains("\"user\":{"), "got: {json}");
        assert!(json.contains("\"id\":\"uid-1\""), "got: {json}");

        let decoded = serde_json::from_str::<Principal>(&json).unwrap();
        assert_eq!(authed, decoded);
        assert_eq!(decoded.user().unwrap().id, "uid-1");

        // Anonymous: `user` is an explicit null and decodes back to anonymous.
        let anonymous = Principal::anonymous();
        let json = serde_json::to_string(&anonymous).unwrap();
        assert_eq!(json, r#"{"user":null}"#);

        let decoded = serde_json::from_str::<Principal>(&json).unwrap();
        assert!(!decoded.is_authenticated());
    }
}