use crate::Response;
use std::any::Any;
use std::net::SocketAddr;
use std::sync::Arc;
/// Type-erased values keyed by name. Values are held in `Arc`, so cloning the map only bumps
/// reference counts rather than copying the values.
type AuthMetadata = std::collections::HashMap<String, Arc<dyn Any + Send + Sync>>;

/// Per-connection information handed to `AuthPolicy::can_access`.
#[derive(Default, Clone)]
pub struct ConnectionContext {
    /// Peer address, when one is known.
    pub remote_addr: Option<SocketAddr>,
    /// Arbitrary typed values keyed by name. Prefer [`ConnectionContext::insert`] and
    /// [`ConnectionContext::get`] over manipulating the map directly.
    pub metadata: AuthMetadata,
}

impl ConnectionContext {
    /// Creates an empty context: no address and no metadata.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a context for a peer at `remote_addr`, with no metadata.
    #[must_use]
    pub fn with_addr(remote_addr: SocketAddr) -> Self {
        Self {
            remote_addr: Some(remote_addr),
            ..Self::default()
        }
    }

    /// Stores `value` under `key`, replacing any previous value (of any type) for that key.
    ///
    /// `key` may be anything convertible into a `String`, so both `"user_id"` and
    /// `String::from("user_id")` are accepted.
    pub fn insert<T: Any + Send + Sync>(&mut self, key: impl Into<String>, value: T) {
        self.metadata.insert(key.into(), Arc::new(value));
    }

    /// Returns the value stored under `key` if it exists and is of type `T`.
    ///
    /// Returns `None` both when `key` is absent and when the stored value has a different
    /// type. The two cases are not distinguished.
    #[must_use]
    pub fn get<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
        self.metadata.get(key).and_then(|v| v.downcast_ref::<T>())
    }
}

impl std::fmt::Debug for ConnectionContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Metadata values are `dyn Any` and cannot be formatted, so only the keys are shown.
        // They are sorted so the output does not depend on HashMap iteration order.
        let mut keys: Vec<&str> = self.metadata.keys().map(String::as_str).collect();
        keys.sort_unstable();
        f.debug_struct("ConnectionContext")
            .field("remote_addr", &self.remote_addr)
            .field("metadata_keys", &keys)
            .finish()
    }
}
/// Produces a [`ConnectionContext`] from what the caller knows about a connection.
///
/// `remote_addr` is the peer address if available. `metadata` is an optional type-erased value
/// that implementations may downcast to whatever concrete type they expect.
#[async_trait::async_trait]
pub trait ContextExtractor: Sync + Send {
    /// Builds the context for a single connection.
    async fn extract(
        &self,
        remote_addr: Option<SocketAddr>,
        metadata: Option<Arc<dyn Any + Send + Sync>>,
    ) -> ConnectionContext;
}
pub struct DefaultContextExtractor;
#[async_trait::async_trait]
impl ContextExtractor for DefaultContextExtractor {
async fn extract(
&self,
remote_addr: Option<SocketAddr>,
_metadata: Option<Arc<dyn Any + Send + Sync>>,
) -> ConnectionContext {
ConnectionContext {
remote_addr,
metadata: std::collections::HashMap::new(),
}
}
}
/// Authorization check applied per method call.
///
/// Implementors must provide [`AuthPolicy::can_access`]. [`AuthPolicy::unauthorized_error`]
/// has a default implementation that builds a generic error response.
pub trait AuthPolicy: Sync + Send {
    /// Returns `true` if the connection described by `ctx` may call `method` with `params`.
    fn can_access(
        &self,
        method: &str,
        params: Option<&serde_json::Value>,
        ctx: &ConnectionContext,
    ) -> bool;

    /// Builds an error response, presumably the one returned when
    /// [`AuthPolicy::can_access`] denies a call (confirm against the caller).
    ///
    /// The default ignores `method`. It returns a response whose error has code
    /// `crate::error_codes::INTERNAL_ERROR` and message `"Unauthorized"`, and whose id is
    /// `None`.
    ///
    /// NOTE(review): an authorization failure is reported with an internal-error code.
    /// Confirm whether a dedicated auth error code was intended.
    fn unauthorized_error(&self, method: &str) -> Response {
        // The default doesn't use `method`. Binding it to `_` silences the unused-variable
        // lint without renaming the parameter.
        let _ = method;
        // Construct the response builder first, then the error, matching the order of the
        // original single-expression chain.
        let response = crate::ResponseBuilder::new();
        let error =
            crate::ErrorBuilder::new(crate::error_codes::INTERNAL_ERROR, "Unauthorized").build();
        response.error(error).id(None).build()
    }
}
/// [`AuthPolicy`] that permits every call, whatever the method, params, or connection context.
#[derive(Debug, Clone, Copy, Default)]
pub struct AllowAll;

impl AuthPolicy for AllowAll {
    fn can_access(
        &self,
        _method: &str,
        _params: Option<&serde_json::Value>,
        _ctx: &ConnectionContext,
    ) -> bool {
        true
    }
}
/// [`AuthPolicy`] that rejects every call, whatever the method, params, or connection context.
#[derive(Debug, Clone, Copy, Default)]
pub struct DenyAll;

impl AuthPolicy for DenyAll {
    fn can_access(
        &self,
        _method: &str,
        _params: Option<&serde_json::Value>,
        _ctx: &ConnectionContext,
    ) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allow_all() {
        let policy = AllowAll;
        let ctx = ConnectionContext::new();
        assert!(policy.can_access("any_method", None, &ctx));
        assert!(policy.can_access(
            "another_method",
            Some(&serde_json::json!({"key": "value"})),
            &ctx
        ));
    }

    #[test]
    fn test_deny_all() {
        let policy = DenyAll;
        let ctx = ConnectionContext::new();
        assert!(!policy.can_access("any_method", None, &ctx));
        assert!(!policy.can_access(
            "another_method",
            Some(&serde_json::json!({"key": "value"})),
            &ctx
        ));
    }

    #[test]
    fn test_connection_context() {
        let mut ctx = ConnectionContext::new();
        ctx.insert("user_id".to_string(), 42u64);
        assert_eq!(ctx.get::<u64>("user_id"), Some(&42));
        // A type mismatch and a missing key both yield `None`.
        assert_eq!(ctx.get::<String>("user_id"), None);
        assert_eq!(ctx.get::<u64>("other"), None);
    }

    #[test]
    fn test_connection_context_insert_overwrites() {
        let mut ctx = ConnectionContext::new();
        ctx.insert("role".to_string(), String::from("user"));
        ctx.insert("role".to_string(), 7u8);
        // The second insert replaces the first, even though the value type changed.
        assert_eq!(ctx.get::<String>("role"), None);
        assert_eq!(ctx.get::<u8>("role"), Some(&7));
        assert_eq!(ctx.metadata.len(), 1);
    }

    #[test]
    fn test_connection_context_with_addr() {
        let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let ctx = ConnectionContext::with_addr(addr);
        assert_eq!(ctx.remote_addr, Some(addr));
        assert!(ctx.metadata.is_empty());
    }

    #[test]
    fn test_connection_context_default() {
        let ctx = ConnectionContext::default();
        assert!(ctx.remote_addr.is_none());
        assert!(ctx.metadata.is_empty());
    }

    #[test]
    fn test_connection_context_multiple_metadata() {
        let mut ctx = ConnectionContext::new();
        ctx.insert("user_id".to_string(), 123u64);
        ctx.insert("username".to_string(), String::from("alice"));
        ctx.insert("is_admin".to_string(), true);
        assert_eq!(ctx.get::<u64>("user_id"), Some(&123));
        assert_eq!(ctx.get::<String>("username"), Some(&String::from("alice")));
        assert_eq!(ctx.get::<bool>("is_admin"), Some(&true));
    }

    #[test]
    fn test_allow_all_unauthorized_error() {
        let policy = AllowAll;
        let response = policy.unauthorized_error("test_method");
        let error = response
            .error
            .expect("default unauthorized_error must attach an error object");
        assert_eq!(error.code, crate::error_codes::INTERNAL_ERROR);
        assert_eq!(error.message, "Unauthorized");
    }

    #[test]
    fn test_deny_all_unauthorized_error() {
        // DenyAll relies on the trait's default, so it must produce the same error as AllowAll.
        let policy = DenyAll;
        let response = policy.unauthorized_error("blocked_method");
        let error = response
            .error
            .expect("default unauthorized_error must attach an error object");
        assert_eq!(error.code, crate::error_codes::INTERNAL_ERROR);
        assert_eq!(error.message, "Unauthorized");
    }

    #[tokio::test]
    async fn test_default_context_extractor() {
        let extractor = DefaultContextExtractor;
        let addr = SocketAddr::from(([192, 168, 1, 1], 9000));
        let ctx = extractor.extract(Some(addr), None).await;
        assert_eq!(ctx.remote_addr, Some(addr));
        assert!(ctx.metadata.is_empty());
    }

    #[tokio::test]
    async fn test_default_context_extractor_no_addr() {
        let extractor = DefaultContextExtractor;
        let ctx = extractor.extract(None, None).await;
        assert!(ctx.remote_addr.is_none());
        assert!(ctx.metadata.is_empty());
    }

    #[tokio::test]
    async fn test_default_context_extractor_with_metadata() {
        let extractor = DefaultContextExtractor;
        let addr = SocketAddr::from(([10, 0, 0, 1], 3000));
        let metadata: Arc<dyn Any + Send + Sync> = Arc::new(String::from("test"));
        let ctx = extractor.extract(Some(addr), Some(metadata)).await;
        assert_eq!(ctx.remote_addr, Some(addr));
        // The default extractor ignores the opaque metadata argument.
        assert!(ctx.metadata.is_empty());
    }

    #[test]
    fn test_connection_context_clone() {
        let mut ctx1 = ConnectionContext::new();
        ctx1.insert("key".to_string(), 100u32);
        let ctx2 = ctx1.clone();
        assert_eq!(ctx2.get::<u32>("key"), Some(&100));
        // Cloning shares the underlying `Arc`-held values rather than copying them.
        assert!(Arc::ptr_eq(&ctx1.metadata["key"], &ctx2.metadata["key"]));
    }
}