use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::error::Error as StdError;
use core::fmt;
use core::future::Future;
use core::hash::Hash;
#[cfg(feature = "std")]
use miette::Diagnostic;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use smol_str::SmolStr;
use crate::bos::{BosStr, DefaultStr};
use crate::types::{did::Did, handle::Handle};
#[cfg(feature = "std")]
use std::path::{Path, PathBuf};
#[cfg(not(feature = "std"))]
use maitake_sync::RwLock;
#[cfg(feature = "std")]
use tokio::sync::RwLock;
/// Errors returned by the session stores in this module.
///
/// The trait methods `set`, `del` and `list_keys` return this type, as do the
/// JSON-file accessors. The enum is `#[non_exhaustive]`, so downstream matches
/// need a wildcard arm.
#[derive(Debug, thiserror::Error)]
#[cfg_attr(feature = "std", derive(Diagnostic))]
#[non_exhaustive]
pub enum SessionStoreError {
    /// Filesystem failure while reading or writing a backing file.
    /// Only exists when the `std` feature is enabled.
    #[cfg(feature = "std")]
    #[error("I/O error: {0}")]
    #[cfg_attr(feature = "std", diagnostic(code(jacquard::session_store::io)))]
    Io(#[from] std::io::Error),
    /// JSON encoding or decoding failed (covers both parsing the stored file
    /// and pretty-printing it back out).
    #[error("serialization error: {0}")]
    #[cfg_attr(feature = "std", diagnostic(code(jacquard::session_store::serde)))]
    Serde(#[from] serde_json::Error),
    /// Any other failure, e.g. a store file whose top-level JSON value is not
    /// an object ("invalid store"). Displays the wrapped error unchanged.
    #[error(transparent)]
    #[cfg_attr(feature = "std", diagnostic(code(jacquard::session_store::other)))]
    Other(#[from] Box<dyn StdError + Send + Sync>),
}
/// Identifies one stored session: the owning account's DID plus a session id.
///
/// The derived `PartialOrd`/`Ord` compare fields in declaration order (`did`
/// first, then `session_id`). Keys for the same DID therefore sort together in
/// ordered collections. Do not reorder the fields without considering that.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct SessionKey {
    /// DID of the account the session belongs to.
    pub did: Did,
    /// Caller-chosen id distinguishing sessions that share the same DID.
    pub session_id: SmolStr,
}
impl SessionKey {
    /// Builds a key from a DID and anything convertible into a session id.
    pub fn new(did: Did, session_id: impl Into<SmolStr>) -> Self {
        let session_id: SmolStr = session_id.into();
        Self { did, session_id }
    }

    /// Borrowed view of the DID this key belongs to.
    pub fn did(&self) -> Did<&str> {
        self.did.borrow()
    }

    /// The session id as a plain string slice.
    pub fn session_id(&self) -> &str {
        &self.session_id
    }
}
impl fmt::Display for SessionKey {
    /// Renders the key as `<did>/<session_id>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { did, session_id } = self;
        write!(f, "{did}/{session_id}")
    }
}
impl From<(Did, SmolStr)> for SessionKey {
    /// Builds a key from a `(did, session_id)` pair.
    fn from(pair: (Did, SmolStr)) -> Self {
        Self {
            did: pair.0,
            session_id: pair.1,
        }
    }
}
impl From<SessionKey> for (Did, SmolStr) {
fn from(key: SessionKey) -> Self {
(key.did, key.session_id)
}
}
impl SessionHint<DefaultStr> {
pub fn any() -> Self {
SessionHint::Any
}
pub fn key(key: SessionKey) -> Self {
SessionHint::Key(key)
}
pub fn identifier(identifier: DefaultStr) -> Self {
SessionHint::Identifier(identifier)
}
pub fn did(did: Did<DefaultStr>) -> Self {
SessionHint::Did(did)
}
pub fn handle(handle: Handle<DefaultStr>) -> Self {
SessionHint::Handle(handle)
}
}
impl<'a> SessionHint<&'a str> {
    /// Classifies raw user input as a hint.
    ///
    /// The input is tried as a DID first, then as a handle. If it parses as
    /// neither, it is kept verbatim as an [`SessionHint::Identifier`].
    pub fn from_input(input: &'a str) -> Self {
        if let Ok(did) = Did::new(input) {
            return Self::Did(did);
        }
        match Handle::new(input) {
            Ok(handle) => Self::Handle(handle),
            Err(_) => Self::Identifier(input),
        }
    }

    /// Like [`Self::from_input`], but maps missing input to [`SessionHint::Any`].
    pub fn from_optional_input(input: Option<&'a str>) -> Self {
        input.map_or(Self::Any, Self::from_input)
    }
}
/// Describes which stored session a caller would like to use.
///
/// `match_session_key` can resolve `Any`, `Did` and `Key` directly against a
/// list of keys. It returns `None` for `Handle` and `Identifier` because it
/// performs no resolution of those.
///
/// Variant order is part of the derived serde representation for formats that
/// encode variant indices, so keep it stable.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SessionHint<S: BosStr = DefaultStr> {
    /// No preference: take the first available session.
    Any,
    /// A session belonging to this DID.
    Did(Did<S>),
    /// A session belonging to the account with this handle.
    Handle(Handle<S>),
    /// Exactly this session key.
    Key(SessionKey),
    /// Free-form identifier. `SessionHint::from_input` produces this when the
    /// input is neither a valid DID nor a valid handle.
    Identifier(S),
}
/// Picks the first key in `keys` that satisfies `hint`, without any network
/// or identity resolution.
///
/// - `Any` yields the first key, if there is one.
/// - `Did` yields the first key whose DID equals the hinted DID.
/// - `Key` yields the first key equal to the hinted key.
/// - `Handle` and `Identifier` always yield `None`; they would need a resolver.
pub fn match_session_key<I, S>(hint: &SessionHint<S>, keys: I) -> Option<SessionKey>
where
    I: IntoIterator<Item = SessionKey>,
    S: BosStr,
{
    match hint {
        // Unresolvable here; the key list is never touched.
        SessionHint::Handle(_) | SessionHint::Identifier(_) => None,
        SessionHint::Key(wanted) => keys.into_iter().find(|candidate| candidate == wanted),
        SessionHint::Did(wanted) => keys
            .into_iter()
            .find(|candidate| candidate.did.as_str() == wanted.as_ref()),
        SessionHint::Any => {
            let mut candidates = keys.into_iter();
            candidates.next()
        }
    }
}
/// Chooses a stored session that satisfies a [`SessionHint`].
///
/// `M` is whatever the implementor uses to identify the chosen session. The
/// in-memory store in this module uses [`SessionKey`] and returns `Ok(None)`
/// when no stored key matches the hint.
// On non-wasm32 targets `trait_variant::make(Send)` is applied, presumably so
// the returned futures must be `Send`; on wasm32 the trait is used as written.
// NOTE(review): confirm exact semantics against the trait_variant docs.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait SessionSelector<M>: Send + Sync {
    /// Error produced when selection itself fails (as opposed to "no match").
    type Error;
    /// Selects a session matching `hint`, or `Ok(None)` if none fits.
    fn select_session<S: BosStr + Send + Sync>(
        &self,
        hint: &SessionHint<S>,
    ) -> impl Future<Output = Result<Option<M>, Self::Error>>;
}
/// Asynchronous key/value storage for sessions of type `T` keyed by `K`.
///
/// Only `get`, `set` and `del` are required. `list_keys` has a default that
/// returns an empty list. Stores that keep the default cannot be enumerated,
/// so key-list based selection finds nothing in them.
// On non-wasm32 targets `trait_variant::make(Send)` is applied, presumably so
// the returned futures must be `Send`. NOTE(review): confirm against trait_variant docs.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait SessionStore<K, T>: Send + Sync
where
    K: Eq + Hash,
    T: Clone,
{
    /// Returns the session stored under `key`, or `None` if there is none.
    fn get(&self, key: &K) -> impl Future<Output = Option<T>>;
    /// Stores `session` under `key`.
    fn set(&self, key: K, session: T) -> impl Future<Output = Result<(), SessionStoreError>>;
    /// Removes the session stored under `key`.
    fn del(&self, key: &K) -> impl Future<Output = Result<(), SessionStoreError>>;
    /// Lists every stored key.
    ///
    /// The default implementation returns an empty list; override it in
    /// stores that can enumerate their contents.
    fn list_keys(&self) -> impl Future<Output = Result<Vec<K>, SessionStoreError>>
    where
        K: Clone,
    {
        async { Ok(Vec::new()) }
    }
}
/// In-memory session store backed by a shared, lock-protected `BTreeMap`.
///
/// Clones share the same underlying map through an `Arc`. A session written
/// through one handle is visible through every clone.
pub struct MemorySessionStore<K, T>(Arc<RwLock<BTreeMap<K, T>>>);

// Implemented by hand instead of `#[derive(Clone)]`: the derive would require
// `K: Clone` and `T: Clone`, although cloning only bumps the `Arc` refcount.
impl<K, T> Clone for MemorySessionStore<K, T> {
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}

impl<K, T> Default for MemorySessionStore<K, T> {
    /// Creates an empty store.
    fn default() -> Self {
        Self(Arc::new(RwLock::new(BTreeMap::new())))
    }
}
impl<K, T> SessionStore<K, T> for MemorySessionStore<K, T>
where
    K: Eq + Hash + Send + Sync + Ord,
    T: Clone + Send + Sync,
{
    /// Returns a clone of the session under `key`, taken under a read lock.
    async fn get(&self, key: &K) -> Option<T> {
        let sessions = self.0.read().await;
        sessions.get(key).cloned()
    }

    /// Inserts or replaces the session under `key`. Never fails.
    async fn set(&self, key: K, session: T) -> Result<(), SessionStoreError> {
        let mut sessions = self.0.write().await;
        let _previous = sessions.insert(key, session);
        Ok(())
    }

    /// Removes the session under `key`, if present. Never fails.
    async fn del(&self, key: &K) -> Result<(), SessionStoreError> {
        let mut sessions = self.0.write().await;
        let _removed = sessions.remove(key);
        Ok(())
    }

    /// Snapshot of all stored keys, in ascending key order.
    async fn list_keys(&self) -> Result<Vec<K>, SessionStoreError>
    where
        K: Clone,
    {
        let sessions = self.0.read().await;
        let snapshot: Vec<K> = sessions.keys().cloned().collect();
        Ok(snapshot)
    }
}
impl<T> SessionSelector<SessionKey> for MemorySessionStore<SessionKey, T>
where
    T: Clone + Send + Sync,
{
    type Error = SessionStoreError;

    /// Matches `hint` against a snapshot of the stored keys.
    ///
    /// Keys are listed in ascending order, so `SessionHint::Any` picks the
    /// smallest key.
    async fn select_session<S: BosStr + Send + Sync>(
        &self,
        hint: &SessionHint<S>,
    ) -> Result<Option<SessionKey>, Self::Error> {
        let stored_keys = self.list_keys().await?;
        let chosen = match_session_key(hint, stored_keys);
        Ok(chosen)
    }
}
/// Store of string-keyed JSON values persisted in a single JSON file.
///
/// The file holds one top-level JSON object. Every accessor re-reads and
/// re-parses the file; every mutation rewrites the whole file. No file locking
/// is performed, so concurrent writers (other processes or other instances
/// pointing at the same path) can overwrite each other's updates.
#[cfg(feature = "std")]
#[derive(Clone, Debug)]
pub struct FileTokenStore {
    /// Location of the backing JSON file.
    pub path: PathBuf,
}
#[cfg(feature = "std")]
impl FileTokenStore {
    /// Opens (creating if necessary) a store backed by the JSON file at `path`.
    ///
    /// Missing parent directories are created. If the file does not exist yet,
    /// it is created containing an empty JSON object (`{}`). An existing file is
    /// left untouched.
    ///
    /// # Errors
    ///
    /// Returns [`SessionStoreError::Io`] if the directories or the file cannot
    /// be created.
    pub fn try_new(path: impl AsRef<Path>) -> Result<Self, SessionStoreError> {
        use std::io::Write as _;

        let path = path.as_ref();
        if let Some(parent) = path.parent() {
            // `create_dir_all` is a no-op for directories that already exist.
            if !parent.as_os_str().is_empty() {
                std::fs::create_dir_all(parent)?;
            }
        }
        // `create_new` makes "create only if absent" atomic. A separate
        // `exists()` check followed by `write` could truncate a store that
        // another process created and populated in between.
        match std::fs::OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(path)
        {
            Ok(mut file) => file.write_all(b"{}")?,
            Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {}
            Err(err) => return Err(err.into()),
        }
        Ok(Self {
            path: path.to_path_buf(),
        })
    }

    /// Infallible convenience wrapper around [`FileTokenStore::try_new`].
    ///
    /// # Panics
    ///
    /// Panics if `try_new` fails, e.g. because the file or its parent
    /// directories cannot be created.
    pub fn new(path: impl AsRef<Path>) -> Self {
        Self::try_new(path).expect("failed to initialize FileTokenStore")
    }
}
#[cfg(feature = "std")]
impl FileTokenStore {
    /// Reads the backing file and parses it as a JSON object.
    ///
    /// Shared by every accessor so they agree on what counts as a corrupt
    /// store. Any top-level value other than an object is rejected as
    /// "invalid store".
    fn read_object(&self) -> Result<serde_json::Map<String, Value>, SessionStoreError> {
        let contents = std::fs::read_to_string(&self.path)?;
        match serde_json::from_str::<Value>(&contents)? {
            Value::Object(map) => Ok(map),
            _ => Err(SessionStoreError::Other("invalid store".into())),
        }
    }

    /// Overwrites the backing file with `store` as pretty-printed JSON.
    ///
    /// The file is rewritten in place (not replaced via rename), so its
    /// existing permissions are kept.
    fn write_object(
        &self,
        store: &serde_json::Map<String, Value>,
    ) -> Result<(), SessionStoreError> {
        std::fs::write(&self.path, serde_json::to_string_pretty(store)?)?;
        Ok(())
    }

    /// Returns the value stored under `key`, or `None` if the key is absent.
    ///
    /// # Errors
    ///
    /// I/O or JSON errors while reading the file. Also returns
    /// [`SessionStoreError::Other`] when the file is not a JSON object, the same
    /// as the other accessors (previously this case silently yielded `Ok(None)`).
    pub fn get_value(&self, key: &str) -> Result<Option<Value>, SessionStoreError> {
        // The map is discarded afterwards, so move the value out instead of cloning.
        Ok(self.read_object()?.remove(key))
    }

    /// Inserts or replaces `value` under `key` and rewrites the file.
    ///
    /// # Errors
    ///
    /// I/O or JSON errors, or [`SessionStoreError::Other`] if the file is not a
    /// JSON object.
    pub fn set_value(&self, key: impl Into<String>, value: Value) -> Result<(), SessionStoreError> {
        let mut store = self.read_object()?;
        store.insert(key.into(), value);
        self.write_object(&store)
    }

    /// Removes `key` (if present) and rewrites the file.
    ///
    /// # Errors
    ///
    /// I/O or JSON errors, or [`SessionStoreError::Other`] if the file is not a
    /// JSON object.
    pub fn remove_value(&self, key: &str) -> Result<(), SessionStoreError> {
        let mut store = self.read_object()?;
        store.remove(key);
        self.write_object(&store)
    }

    /// Returns every `(key, value)` pair currently in the file.
    ///
    /// # Errors
    ///
    /// I/O or JSON errors, or [`SessionStoreError::Other`] if the file is not a
    /// JSON object.
    pub fn entries(&self) -> Result<Vec<(String, Value)>, SessionStoreError> {
        Ok(self.read_object()?.into_iter().collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::ToString;

    /// Builds a key from a static DID literal and a session id.
    fn session_key(did: &'static str, session_id: &str) -> SessionKey {
        SessionKey::new(Did::new_static(did).unwrap(), session_id)
    }

    #[test]
    fn session_key_display_uses_slash_separator() {
        let rendered = session_key("did:plc:alice", "session_1").to_string();
        assert_eq!(rendered, "did:plc:alice/session_1");
    }

    #[tokio::test]
    async fn memory_store_lists_keys() {
        let store: MemorySessionStore<SessionKey, String> = MemorySessionStore::default();
        let key = session_key("did:plc:alice", "session");
        store.set(key.clone(), "value".to_string()).await.unwrap();
        let listed = store.list_keys().await.unwrap();
        assert_eq!(listed, vec![key]);
    }

    /// Store that overrides only the required methods, leaving `list_keys` defaulted.
    struct EmptyStore;

    impl SessionStore<SessionKey, String> for EmptyStore {
        async fn get(&self, _key: &SessionKey) -> Option<String> {
            None
        }

        async fn set(&self, _key: SessionKey, _session: String) -> Result<(), SessionStoreError> {
            Ok(())
        }

        async fn del(&self, _key: &SessionKey) -> Result<(), SessionStoreError> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn default_list_keys_is_empty() {
        let listed = EmptyStore.list_keys().await.unwrap();
        assert!(listed.is_empty());
    }

    #[test]
    fn match_session_key_is_resolver_free() {
        let alice = session_key("did:plc:alice", "a");
        let bob = session_key("did:plc:bob", "b");
        let carol = session_key("did:plc:carol", "c");
        let keys = vec![alice.clone(), bob.clone()];

        // Hints that can be matched directly against the key list.
        assert_eq!(
            match_session_key(&SessionHint::any(), keys.clone()),
            Some(alice.clone())
        );
        assert_eq!(
            match_session_key(&SessionHint::Did(bob.did.clone()), keys.clone()),
            Some(bob.clone())
        );
        assert_eq!(
            match_session_key(&SessionHint::key(bob.clone()), keys.clone()),
            Some(bob)
        );
        assert_eq!(
            match_session_key(&SessionHint::key(carol), keys.clone()),
            None
        );
        assert_eq!(match_session_key(&SessionHint::any(), Vec::new()), None);

        // Hints that would need a resolver never match.
        let handle_hint = SessionHint::<DefaultStr>::Handle(
            Handle::new_static("alice.example.com").unwrap(),
        );
        assert_eq!(match_session_key(&handle_hint, keys.clone()), None);

        let identifier_hint = SessionHint::Identifier(SmolStr::new("alice@example.com"));
        assert_eq!(match_session_key(&identifier_hint, keys), None);
    }
}