use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use reqwest::Client as HttpClient;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use url::Url;
use solid_pod_rs::security::ssrf::is_safe_url;
/// How long a fetched Client Identifier Document is served from the in-memory
/// cache of [`ClientStore`] before `ClientStore::find` fetches it again.
pub const CLIENT_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
/// Client registration request body.
///
/// The field names are the client metadata names of OAuth 2.0 Dynamic Client
/// Registration (RFC 7591). When a field is absent on the wire, optional scalars
/// deserialize as `None` and lists as empty; `None` scalars are omitted when
/// serialized.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RegistrationRequest {
/// When this is an `http://` or `https://` URL, `register_client` resolves it
/// as a Client Identifier Document instead of minting a new client; any other
/// value is ignored.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub client_id: Option<String>,
/// Redirect URIs for the client; `register_client` rejects an empty list.
#[serde(default)]
pub redirect_uris: Vec<String>,
/// Human-readable client name, copied into the stored `ClientDocument`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub client_name: Option<String>,
/// Client home page URL. Not carried into `ClientDocument` (it has no such field).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub client_uri: Option<String>,
/// Client logo URL. Not carried into `ClientDocument`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub logo_uri: Option<String>,
/// Privacy policy URL. Not carried into `ClientDocument`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub policy_uri: Option<String>,
/// Terms-of-service URL. Not carried into `ClientDocument`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tos_uri: Option<String>,
/// Requested scope; `register_client` substitutes `"openid webid"` when absent.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
/// Grant types; `register_client` substitutes
/// `["authorization_code", "refresh_token"]` when empty.
#[serde(default)]
pub grant_types: Vec<String>,
/// Response types; `register_client` substitutes `["code"]` when empty.
#[serde(default)]
pub response_types: Vec<String>,
/// Token endpoint auth method; `register_client` substitutes `"none"` when
/// absent, and issues a client secret for any value other than `"none"`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub token_endpoint_auth_method: Option<String>,
/// Application type; `register_client` substitutes `"web"` when absent.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub application_type: Option<String>,
}
/// A client as stored in and returned by [`ClientStore`].
///
/// Produced either by `register_client` (a dynamically registered client with
/// an opaque `client_…` id) or by fetching a Client Identifier Document (the
/// id is the document URL). `None` optional fields are omitted when serialized.
///
/// NOTE(review): the derived `Debug` prints `client_secret` in clear text;
/// consider a redacting `Debug` impl if these values can reach logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientDocument {
/// Opaque `client_…` id for registered clients, or the document URL for
/// clients resolved from a Client Identifier Document.
pub client_id: String,
/// Issued only to registered clients whose auth method is not `"none"`;
/// always `None` for clients resolved from a document.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_secret: Option<String>,
/// Seconds since the Unix epoch at registration (or at fetch time, for
/// document-backed clients).
pub client_id_issued_at: u64,
/// Redirect URIs the client may use.
pub redirect_uris: Vec<String>,
/// Human-readable client name.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_name: Option<String>,
/// Grant types the client may use.
pub grant_types: Vec<String>,
/// Response types the client may use.
pub response_types: Vec<String>,
/// Token endpoint auth method; always `"none"` for document-backed clients.
pub token_endpoint_auth_method: String,
/// Application type (defaults to `"web"` wherever this module builds one).
#[serde(skip_serializing_if = "Option::is_none")]
pub application_type: Option<String>,
/// Scope string (defaults to `"openid webid"` wherever this module builds one).
#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
/// Document URL for clients resolved from a Client Identifier Document;
/// `None` for dynamically registered clients.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_id_document_url: Option<String>,
}
impl ClientDocument {
    /// Wall-clock time in whole seconds since the Unix epoch.
    ///
    /// Falls back to `0` when the system clock reads earlier than the epoch,
    /// so callers never have to handle an error here.
    fn now_secs() -> u64 {
        use std::time::{SystemTime, UNIX_EPOCH};
        match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs(),
            Err(_) => 0,
        }
    }
}
/// Errors from client registration and Client Identifier Document resolution.
#[derive(Debug, Error)]
pub enum RegError {
/// The registration request itself is invalid (e.g. no `redirect_uris`).
#[error("invalid registration: {0}")]
InvalidRequest(String),
/// The document URL was rejected by the SSRF guard (`is_safe_url`).
#[error("SSRF-blocked: {0}")]
Ssrf(String),
/// The document could not be retrieved (e.g. transport error, non-2xx status,
/// or no HTTP client configured).
#[error("fetch failed: {0}")]
Fetch(String),
/// The document, or its URL, is malformed or inconsistent (e.g. bad JSON,
/// `client_id` mismatch, missing `redirect_uris`).
#[error("invalid client document: {0}")]
InvalidDocument(String),
}
/// In-memory registry of OAuth clients plus a TTL cache of fetched Client
/// Identifier Documents.
///
/// Clones share the same registry and cache, because the state lives behind an `Arc`.
#[derive(Clone)]
pub struct ClientStore {
/// Registered clients and the document cache, guarded by a read/write lock.
inner: Arc<RwLock<ClientStoreInner>>,
/// HTTP client for document fetches; `None` if it could not be built, in
/// which case every fetch fails with `RegError::Fetch`.
http: Option<HttpClient>,
/// When `true`, the `is_safe_url` check on the requested document URL is
/// skipped (set only via `allow_unsafe_urls_for_testing`).
allow_unsafe_urls: bool,
}
impl Default for ClientStore {
fn default() -> Self {
Self::new()
}
}
/// Mutable state shared by all clones of a [`ClientStore`].
#[derive(Default)]
struct ClientStoreInner {
/// Clients added via `ClientStore::insert`, keyed by `client_id`.
registered: HashMap<String, ClientDocument>,
/// Fetched Client Identifier Documents keyed by URL, paired with the
/// `Instant` they were fetched (compared against `CLIENT_CACHE_TTL`).
cache: HashMap<String, (ClientDocument, Instant)>,
}
impl ClientStore {
pub fn new() -> Self {
Self {
inner: Arc::new(RwLock::new(ClientStoreInner::default())),
http: HttpClient::builder()
.timeout(Duration::from_secs(10))
.redirect(reqwest::redirect::Policy::limited(3))
.build()
.ok(),
allow_unsafe_urls: false,
}
}
pub fn with_http(mut self, client: HttpClient) -> Self {
self.http = Some(client);
self
}
#[doc(hidden)]
pub fn allow_unsafe_urls_for_testing(mut self) -> Self {
self.allow_unsafe_urls = true;
self
}
pub fn insert(&self, client: ClientDocument) {
let mut inner = self.inner.write();
inner.registered.insert(client.client_id.clone(), client);
}
pub async fn find(&self, client_id: &str) -> Result<Option<ClientDocument>, RegError> {
if let Some(doc) = self.inner.read().registered.get(client_id).cloned() {
return Ok(Some(doc));
}
{
let inner = self.inner.read();
if let Some((doc, ts)) = inner.cache.get(client_id) {
if ts.elapsed() < CLIENT_CACHE_TTL {
return Ok(Some(doc.clone()));
}
}
}
if client_id.starts_with("http://") || client_id.starts_with("https://") {
let doc = self.fetch_client_document(client_id).await?;
let mut inner = self.inner.write();
inner
.cache
.insert(client_id.to_string(), (doc.clone(), Instant::now()));
return Ok(Some(doc));
}
Ok(None)
}
async fn fetch_client_document(&self, url: &str) -> Result<ClientDocument, RegError> {
if !self.allow_unsafe_urls {
is_safe_url(url).map_err(|e| RegError::Ssrf(e.to_string()))?;
}
let parsed = Url::parse(url)
.map_err(|e| RegError::InvalidDocument(format!("URL parse: {e}")))?;
if !matches!(parsed.scheme(), "http" | "https") {
return Err(RegError::InvalidDocument(format!(
"unsupported scheme: {}",
parsed.scheme()
)));
}
let http = self
.http
.as_ref()
.ok_or_else(|| RegError::Fetch("no HTTP client configured".into()))?;
let resp = http
.get(url)
.header("Accept", "application/ld+json, application/json")
.send()
.await
.map_err(|e| RegError::Fetch(e.to_string()))?;
if !resp.status().is_success() {
return Err(RegError::Fetch(format!(
"HTTP {} from {url}",
resp.status()
)));
}
let body: serde_json::Value = resp
.json()
.await
.map_err(|e| RegError::InvalidDocument(format!("JSON parse: {e}")))?;
if let Some(declared) = body.get("client_id").and_then(|v| v.as_str()) {
if declared != url {
return Err(RegError::InvalidDocument(format!(
"client_id mismatch: document says {declared}, URL is {url}"
)));
}
}
let redirect_uris: Vec<String> = body
.get("redirect_uris")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(str::to_string))
.collect()
})
.unwrap_or_default();
if redirect_uris.is_empty() {
return Err(RegError::InvalidDocument(
"Client Identifier Document is missing redirect_uris".into(),
));
}
let client_name = body
.get("client_name")
.and_then(|v| v.as_str())
.or_else(|| body.get("name").and_then(|v| v.as_str()))
.map(str::to_string);
let scope = body
.get("scope")
.and_then(|v| v.as_str())
.map(str::to_string)
.or_else(|| Some("openid webid".into()));
Ok(ClientDocument {
client_id: url.to_string(),
client_secret: None,
client_id_issued_at: ClientDocument::now_secs(),
redirect_uris,
client_name,
grant_types: vec!["authorization_code".into(), "refresh_token".into()],
response_types: vec!["code".into()],
token_endpoint_auth_method: "none".into(),
application_type: Some("web".into()),
scope,
client_id_document_url: Some(url.to_string()),
})
}
}
pub async fn register_client(
store: &ClientStore,
req: RegistrationRequest,
) -> Result<ClientDocument, RegError> {
if let Some(id) = req.client_id.as_deref() {
if id.starts_with("http://") || id.starts_with("https://") {
if let Some(doc) = store.find(id).await? {
return Ok(doc);
}
return Err(RegError::InvalidDocument(
"Client Identifier Document fetch returned no document".into(),
));
}
}
if req.redirect_uris.is_empty() {
return Err(RegError::InvalidRequest(
"redirect_uris is required for authorization-code flow".into(),
));
}
let id_ts = u128::from(ClientDocument::now_secs()).max(1);
let ts36 = to_base36(id_ts);
let rand_tail: String = rand_base36(8);
let client_id = format!("client_{ts36}_{rand_tail}");
let auth_method = req
.token_endpoint_auth_method
.clone()
.unwrap_or_else(|| "none".into());
let client_secret = if auth_method == "none" {
None
} else {
Some(format!("secret-{}", uuid::Uuid::new_v4()))
};
let grant_types = if req.grant_types.is_empty() {
vec!["authorization_code".into(), "refresh_token".into()]
} else {
req.grant_types.clone()
};
let response_types = if req.response_types.is_empty() {
vec!["code".into()]
} else {
req.response_types.clone()
};
let doc = ClientDocument {
client_id,
client_secret,
client_id_issued_at: ClientDocument::now_secs(),
redirect_uris: req.redirect_uris,
client_name: req.client_name,
grant_types,
response_types,
token_endpoint_auth_method: auth_method,
application_type: req.application_type.or_else(|| Some("web".into())),
scope: req.scope.or_else(|| Some("openid webid".into())),
client_id_document_url: None,
};
store.insert(doc.clone());
Ok(doc)
}
/// Renders `n` in lowercase base 36 (`0-9a-z`), most significant digit first.
///
/// Zero renders as `"0"`; there are never leading zeros otherwise.
fn to_base36(n: u128) -> String {
    const DIGITS: &[u8; 36] = b"0123456789abcdefghijklmnopqrstuvwxyz";
    // Walk n, n/36, n/36², … until the quotient drops below 36; each step
    // contributes its remainder as the next-least-significant digit.
    let least_significant_first: Vec<char> = std::iter::successors(Some(n), |&rest| {
        if rest >= 36 {
            Some(rest / 36)
        } else {
            None
        }
    })
    .map(|rest| char::from(DIGITS[(rest % 36) as usize]))
    .collect();
    least_significant_first.into_iter().rev().collect()
}
/// Returns `len` characters drawn uniformly at random from `0-9a-z`, using
/// the thread-local RNG.
fn rand_base36(len: usize) -> String {
    use rand::Rng;
    const DIGITS: &[u8; 36] = b"0123456789abcdefghijklmnopqrstuvwxyz";
    let mut rng = rand::thread_rng();
    std::iter::repeat_with(|| char::from(DIGITS[rng.gen_range(0..DIGITS.len())]))
        .take(len)
        .collect()
}
// Unit tests. Client Identifier Document tests run against a local wiremock
// server, which is why they opt into `allow_unsafe_urls_for_testing`.
#[cfg(test)]
mod tests {
use super::*;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// A non-URL registration mints a `client_…` id, issues no secret for the
/// default `"none"` auth method, and is findable afterwards.
#[tokio::test]
async fn opaque_registration_assigns_prefixed_client_id() {
let store = ClientStore::new();
let req = RegistrationRequest {
redirect_uris: vec!["https://app.example/cb".into()],
client_name: Some("App".into()),
..Default::default()
};
let doc = register_client(&store, req).await.unwrap();
assert!(doc.client_id.starts_with("client_"));
assert!(doc.client_secret.is_none());
let again = store.find(&doc.client_id).await.unwrap().unwrap();
assert_eq!(again.client_id, doc.client_id);
}
/// An empty `redirect_uris` list is rejected as an invalid request.
#[tokio::test]
async fn registration_without_redirect_uris_is_rejected() {
let store = ClientStore::new();
let err = register_client(
&store,
RegistrationRequest {
..Default::default()
},
)
.await
.unwrap_err();
assert!(matches!(err, RegError::InvalidRequest(_)));
}
/// A URL client id is fetched once and then served from the cache: the mock
/// is registered with `.expect(1)`, which wiremock verifies when `server` is
/// dropped, so the second `find` must not hit the network.
#[tokio::test]
async fn client_identifier_document_is_fetched_and_cached() {
let server = MockServer::start().await;
let cid_url = format!("{}/client#id", server.uri());
let body = serde_json::json!({
"@context": "https://www.w3.org/ns/solid/oidc-context.jsonld",
"client_id": cid_url,
"client_name": "Federated App",
"redirect_uris": ["https://app.example/cb"],
"grant_types": ["authorization_code", "refresh_token"],
"scope": "openid webid profile"
});
Mock::given(method("GET"))
.and(path("/client"))
.respond_with(ResponseTemplate::new(200).set_body_json(body.clone()))
.expect(1) .mount(&server)
.await;
let store = ClientStore::new().allow_unsafe_urls_for_testing();
let doc = store.find(&cid_url).await.unwrap().unwrap();
assert_eq!(doc.client_id, cid_url);
assert_eq!(doc.redirect_uris, vec!["https://app.example/cb".to_string()]);
assert_eq!(doc.client_name.as_deref(), Some("Federated App"));
assert_eq!(doc.client_id_document_url.as_deref(), Some(cid_url.as_str()));
// Second lookup within CLIENT_CACHE_TTL; must come from the cache.
let _ = store.find(&cid_url).await.unwrap().unwrap();
}
/// A document whose `client_id` names a different URL is rejected.
#[tokio::test]
async fn client_identifier_document_rejects_id_mismatch() {
let server = MockServer::start().await;
let cid_url = format!("{}/client", server.uri());
Mock::given(method("GET"))
.and(path("/client"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"client_id": "https://malicious.example/evil",
"redirect_uris": ["https://malicious.example/cb"]
})))
.mount(&server)
.await;
let store = ClientStore::new().allow_unsafe_urls_for_testing();
let err = store.find(&cid_url).await.unwrap_err();
assert!(matches!(err, RegError::InvalidDocument(_)));
}
/// Without the testing override, a loopback URL fails the SSRF check; no
/// mock server is needed because the check runs before any request is sent.
#[tokio::test]
async fn client_identifier_document_rejects_private_ip() {
let store = ClientStore::new();
let err = store.find("http://127.0.0.1/client").await.unwrap_err();
assert!(matches!(err, RegError::Ssrf(_)));
}
/// A document with no `redirect_uris` is rejected.
#[tokio::test]
async fn client_identifier_document_requires_redirect_uris() {
let server = MockServer::start().await;
let cid_url = format!("{}/client", server.uri());
Mock::given(method("GET"))
.and(path("/client"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"client_id": cid_url,
"client_name": "Incomplete"
})))
.mount(&server)
.await;
let store = ClientStore::new().allow_unsafe_urls_for_testing();
let err = store.find(&cid_url).await.unwrap_err();
assert!(matches!(err, RegError::InvalidDocument(_)));
}
/// Boundary values of the base-36 encoder used in minted client ids.
#[test]
fn base36_encode_sanity() {
assert_eq!(to_base36(0), "0");
assert_eq!(to_base36(35), "z");
assert_eq!(to_base36(36), "10");
}
}