use crate::auth::Principal;
use alloc::collections::BTreeMap;
use alloc::string::String;
use serde::{Deserialize, Serialize};
/// Kind of transport a request is associated with (see `RequestContext::transport`).
///
/// Serialized with serde using lowercase variant names (`rename_all = "lowercase"`),
/// so e.g. `WebSocket` is written as `"websocket"`. `Stdio` is the `Default` variant.
/// The enum is `#[non_exhaustive]`, so `match`es outside this crate need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum TransportType {
/// Default variant; classified as local by `is_local`.
#[default]
Stdio,
/// Classified as network by `is_network`.
Http,
/// Classified as network by `is_network`; serialized as `"websocket"`.
WebSocket,
/// Classified as network by `is_network`.
Tcp,
/// Classified as local by `is_local`.
Unix,
/// Classified as neither network nor local by `is_network` / `is_local`.
Wasm,
/// Classified as local by `is_local`.
Channel,
/// Classified as neither network nor local by `is_network` / `is_local`.
Unknown,
}
impl TransportType {
    /// Returns `true` for the network transports: `Http`, `WebSocket` and `Tcp`.
    ///
    /// `const` so it can be evaluated in constant contexts.
    #[inline]
    #[must_use]
    pub const fn is_network(&self) -> bool {
        matches!(self, Self::Http | Self::WebSocket | Self::Tcp)
    }

    /// Returns `true` for the local transports: `Stdio`, `Unix` and `Channel`.
    ///
    /// `Wasm` and `Unknown` are deliberately neither network nor local, so
    /// `!is_network()` does not imply `is_local()`.
    #[inline]
    #[must_use]
    pub const fn is_local(&self) -> bool {
        matches!(self, Self::Stdio | Self::Unix | Self::Channel)
    }

    /// Returns the lowercase name of the transport (e.g. `"websocket"`).
    ///
    /// This is the text produced by the `Display` impl and, for every variant,
    /// the same name serde uses under `rename_all = "lowercase"`.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Stdio => "stdio",
            Self::Http => "http",
            Self::WebSocket => "websocket",
            Self::Tcp => "tcp",
            Self::Unix => "unix",
            Self::Wasm => "wasm",
            Self::Channel => "channel",
            Self::Unknown => "unknown",
        }
    }
}
impl core::fmt::Display for TransportType {
    /// Writes the lowercase transport name (the same string as `as_str`).
    ///
    /// Uses `Formatter::pad` so the caller's width, fill, alignment and
    /// precision (e.g. `{:>8}`, `{:<5}`) are honoured; `write!(f, "{}", ..)`
    /// would start a fresh format spec and ignore them.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.pad(self.as_str())
    }
}
/// Per-request context: a request id, the transport, string metadata and an
/// optional authenticated principal.
///
/// The derived `Default` yields an empty `request_id`, the default transport
/// (`TransportType::Stdio`), empty metadata and no principal.
#[derive(Debug, Clone, Default)]
pub struct RequestContext {
/// Request identifier; an empty string means "no id" (`has_request_id` returns `false`).
pub request_id: String,
/// Transport this request is associated with.
pub transport: TransportType,
/// Free-form string key/value pairs; the `BTreeMap` keeps them ordered by key.
pub metadata: BTreeMap<String, String>,
/// Authenticated principal, if any; `None` means `is_authenticated` returns `false`.
pub principal: Option<Principal>,
}
impl RequestContext {
    /// Creates a context with the given request id and transport, no metadata
    /// and no principal.
    #[must_use]
    pub fn new(request_id: impl Into<String>, transport: TransportType) -> Self {
        Self {
            request_id: request_id.into(),
            transport,
            metadata: BTreeMap::new(),
            principal: None,
        }
    }

    /// Creates a context with an empty request id for `TransportType::Stdio`.
    #[inline]
    #[must_use]
    pub fn stdio() -> Self {
        Self::new("", TransportType::Stdio)
    }

    /// Creates a context with an empty request id for `TransportType::Http`.
    #[inline]
    #[must_use]
    pub fn http() -> Self {
        Self::new("", TransportType::Http)
    }

    /// Creates a context with an empty request id for `TransportType::WebSocket`.
    #[inline]
    #[must_use]
    pub fn websocket() -> Self {
        Self::new("", TransportType::WebSocket)
    }

    /// Creates a context with an empty request id for `TransportType::Tcp`.
    #[inline]
    #[must_use]
    pub fn tcp() -> Self {
        Self::new("", TransportType::Tcp)
    }

    /// Creates a context with an empty request id for `TransportType::Unix`.
    #[inline]
    #[must_use]
    pub fn unix() -> Self {
        Self::new("", TransportType::Unix)
    }

    /// Creates a context with an empty request id for `TransportType::Wasm`.
    #[inline]
    #[must_use]
    pub fn wasm() -> Self {
        Self::new("", TransportType::Wasm)
    }

    /// Creates a context with an empty request id for `TransportType::Channel`.
    #[inline]
    #[must_use]
    pub fn channel() -> Self {
        Self::new("", TransportType::Channel)
    }

    /// Builder: inserts `key -> value` into the metadata, replacing any
    /// existing value for `key`, and returns the context.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Inserts `key -> value` into the metadata in place, replacing any
    /// existing value for `key`.
    pub fn insert_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.metadata.insert(key.into(), value.into());
    }

    /// Returns the metadata value stored under `key`, if any.
    pub fn get_metadata(&self, key: &str) -> Option<&str> {
        self.metadata.get(key).map(String::as_str)
    }

    /// Returns `true` if the metadata contains `key`.
    pub fn has_metadata(&self, key: &str) -> bool {
        self.metadata.contains_key(key)
    }

    /// Builder: replaces the request id and returns the context.
    #[must_use]
    pub fn with_request_id(mut self, id: impl Into<String>) -> Self {
        self.request_id = id.into();
        self
    }

    /// Returns `true` unless the request id is the empty string.
    pub fn has_request_id(&self) -> bool {
        !self.request_id.is_empty()
    }

    /// Builder: attaches `principal` (replacing any previous one) and returns
    /// the context.
    #[must_use]
    pub fn with_principal(mut self, principal: Principal) -> Self {
        self.principal = Some(principal);
        self
    }

    /// Attaches `principal` in place, replacing any previous one.
    pub fn set_principal(&mut self, principal: Principal) {
        self.principal = Some(principal);
    }

    /// Returns the attached principal, if any.
    pub fn principal(&self) -> Option<&Principal> {
        self.principal.as_ref()
    }

    /// Returns `true` if a principal is attached.
    pub fn is_authenticated(&self) -> bool {
        self.principal.is_some()
    }

    /// Returns the attached principal's `subject`, if a principal is attached.
    pub fn subject(&self) -> Option<&str> {
        self.principal.as_ref().map(|p| p.subject.as_str())
    }

    /// Removes any attached principal, making the context unauthenticated.
    pub fn clear_principal(&mut self) {
        self.principal = None;
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `TransportType` and `RequestContext`.
    use super::*;

    #[test]
    fn test_transport_type_display() {
        assert_eq!(TransportType::Stdio.to_string(), "stdio");
        assert_eq!(TransportType::Http.to_string(), "http");
        assert_eq!(TransportType::WebSocket.to_string(), "websocket");
        assert_eq!(TransportType::Tcp.to_string(), "tcp");
        assert_eq!(TransportType::Unix.to_string(), "unix");
        assert_eq!(TransportType::Wasm.to_string(), "wasm");
        assert_eq!(TransportType::Channel.to_string(), "channel");
        assert_eq!(TransportType::Unknown.to_string(), "unknown");
    }

    #[test]
    fn test_transport_type_as_str_matches_display() {
        let all = [
            TransportType::Stdio,
            TransportType::Http,
            TransportType::WebSocket,
            TransportType::Tcp,
            TransportType::Unix,
            TransportType::Wasm,
            TransportType::Channel,
            TransportType::Unknown,
        ];
        for t in &all {
            assert_eq!(t.as_str(), t.to_string());
        }
    }

    #[test]
    fn test_transport_type_default_is_stdio() {
        assert_eq!(TransportType::default(), TransportType::Stdio);
    }

    #[test]
    fn test_transport_type_classification() {
        assert!(TransportType::Http.is_network());
        assert!(TransportType::WebSocket.is_network());
        assert!(TransportType::Tcp.is_network());
        assert!(!TransportType::Stdio.is_network());
        assert!(TransportType::Stdio.is_local());
        assert!(TransportType::Unix.is_local());
        assert!(TransportType::Channel.is_local());
        assert!(!TransportType::Http.is_local());
    }

    #[test]
    fn test_transport_type_neither_network_nor_local() {
        // Wasm and Unknown fall outside both classifications.
        for t in &[TransportType::Wasm, TransportType::Unknown] {
            assert!(!t.is_network());
            assert!(!t.is_local());
        }
    }

    #[test]
    fn test_request_context_new() {
        let ctx = RequestContext::new("test-123", TransportType::Http);
        assert_eq!(ctx.request_id, "test-123");
        assert_eq!(ctx.transport, TransportType::Http);
        assert!(ctx.metadata.is_empty());
    }

    #[test]
    fn test_request_context_factory_methods() {
        assert_eq!(RequestContext::stdio().transport, TransportType::Stdio);
        assert_eq!(RequestContext::http().transport, TransportType::Http);
        assert_eq!(
            RequestContext::websocket().transport,
            TransportType::WebSocket
        );
        assert_eq!(RequestContext::tcp().transport, TransportType::Tcp);
        assert_eq!(RequestContext::wasm().transport, TransportType::Wasm);
    }

    #[test]
    fn test_request_context_metadata() {
        let ctx = RequestContext::new("1", TransportType::Http)
            .with_metadata("key1", "value1")
            .with_metadata("key2", "value2");
        assert_eq!(ctx.get_metadata("key1"), Some("value1"));
        assert_eq!(ctx.get_metadata("key2"), Some("value2"));
        assert_eq!(ctx.get_metadata("key3"), None);
        assert!(ctx.has_metadata("key1"));
        assert!(!ctx.has_metadata("key3"));
    }

    #[test]
    fn test_request_context_metadata_overwrite() {
        // Re-inserting a key replaces its value rather than adding a second entry.
        let ctx = RequestContext::new("1", TransportType::Http)
            .with_metadata("key", "old")
            .with_metadata("key", "new");
        assert_eq!(ctx.get_metadata("key"), Some("new"));
        assert_eq!(ctx.metadata.len(), 1);
        let mut ctx = ctx;
        ctx.insert_metadata("key", "newer");
        assert_eq!(ctx.get_metadata("key"), Some("newer"));
        assert_eq!(ctx.metadata.len(), 1);
    }

    #[test]
    fn test_request_context_mutable_metadata() {
        let mut ctx = RequestContext::new("1", TransportType::Http);
        ctx.insert_metadata("key", "value");
        assert_eq!(ctx.get_metadata("key"), Some("value"));
    }

    #[test]
    fn test_request_context_request_id() {
        let ctx = RequestContext::new("", TransportType::Http);
        assert!(!ctx.has_request_id());
        let ctx = ctx.with_request_id("request-456");
        assert!(ctx.has_request_id());
        assert_eq!(ctx.request_id, "request-456");
    }

    #[test]
    fn test_request_context_default() {
        let ctx = RequestContext::default();
        assert_eq!(ctx.request_id, "");
        assert_eq!(ctx.transport, TransportType::Stdio);
        assert!(ctx.metadata.is_empty());
    }

    #[test]
    fn test_request_context_default_is_anonymous() {
        let ctx = RequestContext::default();
        assert!(!ctx.is_authenticated());
        assert!(!ctx.has_request_id());
        assert!(ctx.subject().is_none());
    }

    #[test]
    fn test_request_context_clone() {
        let ctx1 = RequestContext::new("1", TransportType::Http).with_metadata("key", "value");
        let ctx2 = ctx1.clone();
        assert_eq!(ctx1.request_id, ctx2.request_id);
        assert_eq!(ctx1.transport, ctx2.transport);
        assert_eq!(ctx1.get_metadata("key"), ctx2.get_metadata("key"));
    }

    #[test]
    fn test_request_context_principal() {
        let ctx = RequestContext::new("1", TransportType::Http);
        assert!(!ctx.is_authenticated());
        assert!(ctx.principal().is_none());
        assert!(ctx.subject().is_none());
        let principal = Principal::new("user-123")
            .with_email("user@example.com")
            .with_role("admin");
        let ctx = ctx.with_principal(principal);
        assert!(ctx.is_authenticated());
        assert!(ctx.principal().is_some());
        assert_eq!(ctx.subject(), Some("user-123"));
        assert_eq!(
            ctx.principal().unwrap().email,
            Some("user@example.com".to_string())
        );
        assert!(ctx.principal().unwrap().has_role("admin"));
    }

    #[test]
    fn test_request_context_principal_mutable() {
        let mut ctx = RequestContext::new("1", TransportType::Http);
        assert!(!ctx.is_authenticated());
        ctx.set_principal(Principal::new("user-456"));
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.subject(), Some("user-456"));
        ctx.clear_principal();
        assert!(!ctx.is_authenticated());
    }

    #[test]
    fn test_request_context_clear_principal_resets_accessors() {
        let mut ctx = RequestContext::stdio().with_principal(Principal::new("user-789"));
        assert_eq!(ctx.subject(), Some("user-789"));
        ctx.clear_principal();
        assert!(ctx.principal().is_none());
        assert!(ctx.subject().is_none());
    }
}