use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Protocol version; the `AgentResponse` constructors in this file stamp it
/// into the `version` field of every response they build.
pub const PROTOCOL_VERSION: u32 = 2;
/// Maximum message size: 10 * 1024 * 1024 (10 MiB, assuming the unit is bytes).
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// enforced by the framing/transport code — confirm.
pub const MAX_MESSAGE_SIZE: usize = 10 * 1024 * 1024;
/// Kinds of events defined by the agent protocol.
///
/// Serde renames the variants to snake_case, so `RequestHeaders` is written as
/// `"request_headers"` and `WebSocketFrame` as `"web_socket_frame"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EventType {
/// Configuration event.
/// NOTE(review): no payload type for it is visible in this file — confirm where it lives.
Configure,
/// Request headers; presumably paired with `RequestHeadersEvent`.
RequestHeaders,
/// One chunk of the request body; presumably paired with `RequestBodyChunkEvent`.
RequestBodyChunk,
/// Response headers; presumably paired with `ResponseHeadersEvent`.
ResponseHeaders,
/// One chunk of the response body; presumably paired with `ResponseBodyChunkEvent`.
ResponseBodyChunk,
/// Request finished; presumably paired with `RequestCompleteEvent`.
RequestComplete,
/// A WebSocket frame; presumably paired with `WebSocketFrameEvent`.
WebSocketFrame,
/// Guardrail content inspection; presumably paired with `GuardrailInspectEvent`.
GuardrailInspect,
}
/// Verdict an agent returns for the traffic it inspected.
///
/// `Allow` is the `Default`. No `#[serde(tag = ...)]` is set, so serde's default
/// externally tagged enum representation applies, with snake_case variant names.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum Decision {
/// Let the traffic proceed (the default).
#[default]
Allow,
/// Block the traffic; the fields describe the response to answer with.
Block {
/// Status code of the block response (presumably an HTTP status — confirm).
status: u16,
/// Optional body of the block response.
body: Option<String>,
/// Optional headers of the block response, one value per header name.
headers: Option<HashMap<String, String>>,
},
/// Redirect to another location.
Redirect {
/// Target of the redirect.
url: String,
/// Status code to redirect with (presumably a 3xx HTTP status — confirm).
status: u16,
},
/// Issue a challenge.
/// NOTE(review): the accepted challenge types and parameter keys are not defined in this file.
Challenge {
/// Name of the challenge mechanism.
challenge_type: String,
/// Free-form string parameters for the challenge.
params: HashMap<String, String>,
},
}
/// A single header modification requested by an agent.
///
/// Used for the `request_headers` and `response_headers` lists of `AgentResponse`.
/// Variant names are serialized in snake_case (`set`, `add`, `remove`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HeaderOp {
/// Set header `name` to `value` (presumably replacing existing values — confirm with the consumer).
Set { name: String, value: String },
/// Add `value` under header `name` (presumably keeping existing values — confirm with the consumer).
Add { name: String, value: String },
/// Remove header `name`.
Remove { name: String },
}
/// Instruction describing what to do with one body chunk.
///
/// The `data` field encodes three cases, matching the constructors and
/// predicates implemented on this type: `None` means pass the chunk through,
/// `Some` of an empty string means drop it, and `Some` of anything else is the
/// replacement content.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BodyMutation {
/// Replacement content, or `None` to leave the chunk unchanged.
/// NOTE(review): whether this string is base64 like the chunk events' `data` is not visible here — confirm.
pub data: Option<String>,
/// Index of the chunk this mutation applies to; deserializes to 0 when absent.
#[serde(default)]
pub chunk_index: u32,
}
impl BodyMutation {
    /// Builds a mutation that leaves chunk `chunk_index` untouched (`data` is `None`).
    pub fn pass_through(chunk_index: u32) -> Self {
        Self {
            chunk_index,
            data: None,
        }
    }

    /// Builds a mutation that drops chunk `chunk_index`.
    ///
    /// A drop is encoded as `Some` of an empty string.
    pub fn drop_chunk(chunk_index: u32) -> Self {
        Self::replace(chunk_index, String::new())
    }

    /// Builds a mutation that substitutes `data` for chunk `chunk_index`.
    ///
    /// Passing an empty `data` yields the same value as [`BodyMutation::drop_chunk`].
    pub fn replace(chunk_index: u32, data: String) -> Self {
        let data = Some(data);
        Self { data, chunk_index }
    }

    /// Returns `true` when no replacement data is set.
    pub fn is_pass_through(&self) -> bool {
        self.data.is_none()
    }

    /// Returns `true` when replacement data is present but empty.
    pub fn is_drop(&self) -> bool {
        self.data.as_deref() == Some("")
    }
}
/// Per-request context attached to a `RequestHeadersEvent`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestMetadata {
/// Correlation identifier; the chunk, response and completion events in this
/// file carry a field of the same name (presumably the same value — confirm).
pub correlation_id: String,
/// Request identifier.
/// NOTE(review): how this differs from `correlation_id` is not visible here.
pub request_id: String,
/// Client IP address, as a string.
pub client_ip: String,
/// Client port.
pub client_port: u16,
/// Server name, if known (presumably TLS SNI or the requested host — confirm).
pub server_name: Option<String>,
/// Protocol label (presumably the HTTP version — confirm the exact values).
pub protocol: String,
/// TLS version, if any.
pub tls_version: Option<String>,
/// TLS cipher, if any.
pub tls_cipher: Option<String>,
/// Identifier of the matched route, if any.
pub route_id: Option<String>,
/// Identifier of the selected upstream, if any.
pub upstream_id: Option<String>,
/// Timestamp as a string.
/// NOTE(review): the format is not specified in this file — confirm.
pub timestamp: String,
/// Trace context value (presumably a W3C `traceparent` header — confirm).
/// Left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub traceparent: Option<String>,
}
/// Event describing the head of a request: its metadata, method, URI and headers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestHeadersEvent {
/// Context for the request.
pub metadata: RequestMetadata,
/// Request method, as a string.
pub method: String,
/// Request URI, as a string.
pub uri: String,
/// Request headers; each name maps to a list of values.
pub headers: HashMap<String, Vec<String>>,
}
/// Serializable event carrying one chunk of a request body as text.
///
/// The `From` conversions in this file treat `data` as standard base64 of the
/// raw bytes held by `BinaryRequestBodyChunkEvent`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestBodyChunkEvent {
/// Correlation identifier of the request this chunk belongs to.
pub correlation_id: String,
/// Chunk payload; base64 text when produced from the binary event type.
pub data: String,
/// Whether this is the final chunk.
pub is_last: bool,
/// Total body size, if known.
pub total_size: Option<usize>,
/// Index of this chunk; deserializes to 0 when absent.
#[serde(default)]
pub chunk_index: u32,
/// Byte counter; deserializes to 0 when absent.
/// NOTE(review): whether this is cumulative or per-chunk is decided by the producer — confirm.
#[serde(default)]
pub bytes_received: usize,
}
/// Event describing the head of a response: its status and headers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseHeadersEvent {
/// Correlation identifier of the request this response belongs to.
pub correlation_id: String,
/// Response status code.
pub status: u16,
/// Response headers; each name maps to a list of values.
pub headers: HashMap<String, Vec<String>>,
}
/// Serializable event carrying one chunk of a response body as text.
///
/// The `From` conversions in this file treat `data` as standard base64 of the
/// raw bytes held by `BinaryResponseBodyChunkEvent`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseBodyChunkEvent {
/// Correlation identifier of the request this chunk belongs to.
pub correlation_id: String,
/// Chunk payload; base64 text when produced from the binary event type.
pub data: String,
/// Whether this is the final chunk.
pub is_last: bool,
/// Total body size, if known.
pub total_size: Option<usize>,
/// Index of this chunk; deserializes to 0 when absent.
#[serde(default)]
pub chunk_index: u32,
/// Byte counter; deserializes to 0 when absent.
/// NOTE(review): whether this is cumulative or per-chunk is decided by the producer — confirm.
#[serde(default)]
pub bytes_sent: usize,
}
/// Request body chunk holding its payload as raw `Bytes` rather than text.
///
/// Unlike `RequestBodyChunkEvent`, this type does not derive the serde traits;
/// the `From` impls in this file convert between the two.
#[derive(Debug, Clone)]
pub struct BinaryRequestBodyChunkEvent {
/// Correlation identifier of the request this chunk belongs to.
pub correlation_id: String,
/// Raw chunk payload.
pub data: Bytes,
/// Whether this is the final chunk.
pub is_last: bool,
/// Total body size, if known.
pub total_size: Option<usize>,
/// Index of this chunk.
pub chunk_index: u32,
/// Byte counter; `new` initialises it to the length of `data`.
pub bytes_received: usize,
}
/// Response body chunk holding its payload as raw `Bytes` rather than text.
///
/// Unlike `ResponseBodyChunkEvent`, this type does not derive the serde traits;
/// the `From` impls in this file convert between the two.
#[derive(Debug, Clone)]
pub struct BinaryResponseBodyChunkEvent {
/// Correlation identifier of the request this chunk belongs to.
pub correlation_id: String,
/// Raw chunk payload.
pub data: Bytes,
/// Whether this is the final chunk.
pub is_last: bool,
/// Total body size, if known.
pub total_size: Option<usize>,
/// Index of this chunk.
pub chunk_index: u32,
/// Byte counter; `new` initialises it to the length of `data`.
pub bytes_sent: usize,
}
impl BinaryRequestBodyChunkEvent {
    /// Creates a request body chunk event.
    ///
    /// `total_size` starts as `None` and `bytes_received` is initialised to the
    /// length of this chunk's data; use the `with_*` methods to override them.
    pub fn new(
        correlation_id: impl Into<String>,
        data: impl Into<Bytes>,
        chunk_index: u32,
        is_last: bool,
    ) -> Self {
        let payload: Bytes = data.into();
        let correlation_id: String = correlation_id.into();
        let bytes_received = payload.len();
        Self {
            correlation_id,
            data: payload,
            is_last,
            total_size: None,
            chunk_index,
            bytes_received,
        }
    }

    /// Returns the event with `total_size` set to `Some(size)`.
    pub fn with_total_size(self, size: usize) -> Self {
        Self {
            total_size: Some(size),
            ..self
        }
    }

    /// Returns the event with `bytes_received` replaced by `bytes`.
    pub fn with_bytes_received(self, bytes: usize) -> Self {
        Self {
            bytes_received: bytes,
            ..self
        }
    }
}
impl BinaryResponseBodyChunkEvent {
    /// Creates a response body chunk event.
    ///
    /// `total_size` starts as `None` and `bytes_sent` is initialised to the
    /// length of this chunk's data; use the `with_*` methods to override them.
    pub fn new(
        correlation_id: impl Into<String>,
        data: impl Into<Bytes>,
        chunk_index: u32,
        is_last: bool,
    ) -> Self {
        let payload: Bytes = data.into();
        let correlation_id: String = correlation_id.into();
        let bytes_sent = payload.len();
        Self {
            correlation_id,
            data: payload,
            is_last,
            total_size: None,
            chunk_index,
            bytes_sent,
        }
    }

    /// Returns the event with `total_size` set to `Some(size)`.
    pub fn with_total_size(self, size: usize) -> Self {
        Self {
            total_size: Some(size),
            ..self
        }
    }

    /// Returns the event with `bytes_sent` replaced by `bytes`.
    pub fn with_bytes_sent(self, bytes: usize) -> Self {
        Self {
            bytes_sent: bytes,
            ..self
        }
    }
}
impl From<BinaryRequestBodyChunkEvent> for RequestBodyChunkEvent {
fn from(event: BinaryRequestBodyChunkEvent) -> Self {
use base64::{engine::general_purpose::STANDARD, Engine as _};
Self {
correlation_id: event.correlation_id,
data: STANDARD.encode(&event.data),
is_last: event.is_last,
total_size: event.total_size,
chunk_index: event.chunk_index,
bytes_received: event.bytes_received,
}
}
}
/// Builds the binary chunk event from a borrowed serializable one.
///
/// `data` is decoded with the base64 `STANDARD` engine. If it is not valid
/// base64, the string's own bytes are used as the payload instead and no error
/// is reported.
impl From<&RequestBodyChunkEvent> for BinaryRequestBodyChunkEvent {
    fn from(event: &RequestBodyChunkEvent) -> Self {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        let payload = match STANDARD.decode(&event.data) {
            Ok(raw) => Bytes::from(raw),
            // Best-effort fallback: keep the text verbatim when it is not base64.
            Err(_) => Bytes::copy_from_slice(event.data.as_bytes()),
        };

        Self {
            correlation_id: event.correlation_id.clone(),
            data: payload,
            is_last: event.is_last,
            total_size: event.total_size,
            chunk_index: event.chunk_index,
            bytes_received: event.bytes_received,
        }
    }
}
impl From<BinaryResponseBodyChunkEvent> for ResponseBodyChunkEvent {
fn from(event: BinaryResponseBodyChunkEvent) -> Self {
use base64::{engine::general_purpose::STANDARD, Engine as _};
Self {
correlation_id: event.correlation_id,
data: STANDARD.encode(&event.data),
is_last: event.is_last,
total_size: event.total_size,
chunk_index: event.chunk_index,
bytes_sent: event.bytes_sent,
}
}
}
/// Builds the binary chunk event from a borrowed serializable one.
///
/// `data` is decoded with the base64 `STANDARD` engine. If it is not valid
/// base64, the string's own bytes are used as the payload instead and no error
/// is reported.
impl From<&ResponseBodyChunkEvent> for BinaryResponseBodyChunkEvent {
    fn from(event: &ResponseBodyChunkEvent) -> Self {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        let payload = match STANDARD.decode(&event.data) {
            Ok(raw) => Bytes::from(raw),
            // Best-effort fallback: keep the text verbatim when it is not base64.
            Err(_) => Bytes::copy_from_slice(event.data.as_bytes()),
        };

        Self {
            correlation_id: event.correlation_id.clone(),
            data: payload,
            is_last: event.is_last,
            total_size: event.total_size,
            chunk_index: event.chunk_index,
            bytes_sent: event.bytes_sent,
        }
    }
}
/// Summary event for a finished request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestCompleteEvent {
/// Correlation identifier of the request.
pub correlation_id: String,
/// Final status code.
pub status: u16,
/// Duration in milliseconds.
/// NOTE(review): the exact start and end points of the measurement are not visible here.
pub duration_ms: u64,
/// Size of the request body (presumably in bytes — confirm).
pub request_body_size: usize,
/// Size of the response body (presumably in bytes — confirm).
pub response_body_size: usize,
/// Number of upstream attempts.
pub upstream_attempts: u32,
/// Error description, if the request failed.
pub error: Option<String>,
}
/// Event describing a single WebSocket frame.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSocketFrameEvent {
/// Correlation identifier of the connection or request the frame belongs to.
pub correlation_id: String,
/// Frame opcode as a string.
/// NOTE(review): presumably one of the `WebSocketOpcode::as_str` names — confirm with the producer.
pub opcode: String,
/// Frame payload as a string.
/// NOTE(review): the encoding (e.g. base64 for binary frames) is not visible here — confirm.
pub data: String,
/// `true` for frames travelling from the client to the server.
pub client_to_server: bool,
/// Index of the frame.
pub frame_index: u64,
/// FIN flag of the frame.
pub fin: bool,
/// Identifier of the matched route, if any.
pub route_id: Option<String>,
/// Client IP address, as a string.
pub client_ip: String,
}
/// WebSocket frame opcodes.
///
/// The numeric wire values are provided by `as_u8` / `from_u8` (0x0, 0x1, 0x2,
/// 0x8, 0x9, 0xA in declaration order). Serde writes the variants in snake_case,
/// which matches the strings returned by `as_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WebSocketOpcode {
/// Continuation frame (0x0).
Continuation,
/// Text frame (0x1).
Text,
/// Binary frame (0x2).
Binary,
/// Close frame (0x8).
Close,
/// Ping frame (0x9).
Ping,
/// Pong frame (0xA).
Pong,
}
impl WebSocketOpcode {
    /// Returns the lowercase name of the opcode.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Continuation => "continuation",
            Self::Text => "text",
            Self::Binary => "binary",
            Self::Close => "close",
            Self::Ping => "ping",
            Self::Pong => "pong",
        }
    }

    /// Maps a numeric opcode to its variant.
    ///
    /// Returns `None` for any value other than 0x0, 0x1, 0x2, 0x8, 0x9 or 0xA.
    pub fn from_u8(value: u8) -> Option<Self> {
        let opcode = match value {
            0x0 => Self::Continuation,
            0x1 => Self::Text,
            0x2 => Self::Binary,
            0x8 => Self::Close,
            0x9 => Self::Ping,
            0xA => Self::Pong,
            // Everything else is not a known opcode.
            _ => return None,
        };
        Some(opcode)
    }

    /// Returns the numeric opcode; the inverse of [`WebSocketOpcode::from_u8`].
    pub fn as_u8(&self) -> u8 {
        match *self {
            Self::Continuation => 0x0,
            Self::Text => 0x1,
            Self::Binary => 0x2,
            Self::Close => 0x8,
            Self::Ping => 0x9,
            Self::Pong => 0xA,
        }
    }
}
/// Verdict an agent returns for a WebSocket frame.
///
/// `Allow` is the `Default`. Variant names are serialized in snake_case.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum WebSocketDecision {
/// Let the frame through (the default).
#[default]
Allow,
/// Drop the frame.
Drop,
/// Close the connection.
Close {
/// Close code (presumably a WebSocket close status code — confirm).
code: u16,
/// Close reason text.
reason: String,
},
}
/// Response an agent sends back for an event.
///
/// When deserializing, only `version` and `decision` are required; every other
/// field carries `#[serde(default)]` and falls back to its type's default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentResponse {
/// Protocol version; the constructors on this type set it to `PROTOCOL_VERSION`.
pub version: u32,
/// The agent's verdict.
pub decision: Decision,
/// Header operations for the request.
#[serde(default)]
pub request_headers: Vec<HeaderOp>,
/// Header operations for the response.
#[serde(default)]
pub response_headers: Vec<HeaderOp>,
/// Free-form string key/value routing metadata.
#[serde(default)]
pub routing_metadata: HashMap<String, String>,
/// Audit information attached to the decision.
#[serde(default)]
pub audit: AuditMetadata,
/// Set to `true` by `needs_more_data()`.
/// NOTE(review): presumably asks the caller for more body data before a final verdict — confirm.
#[serde(default)]
pub needs_more: bool,
/// Optional mutation for a request body chunk.
#[serde(default)]
pub request_body_mutation: Option<BodyMutation>,
/// Optional mutation for a response body chunk.
#[serde(default)]
pub response_body_mutation: Option<BodyMutation>,
/// Optional verdict for a WebSocket frame.
#[serde(default)]
pub websocket_decision: Option<WebSocketDecision>,
}
impl AgentResponse {
    /// Baseline response: current protocol version, `Decision::Allow`, and every
    /// other field empty, `false` or `None`.
    ///
    /// The other constructors start from this value and override single fields.
    pub fn default_allow() -> Self {
        Self {
            version: PROTOCOL_VERSION,
            decision: Decision::Allow,
            request_headers: Vec::new(),
            response_headers: Vec::new(),
            routing_metadata: HashMap::new(),
            audit: AuditMetadata::default(),
            needs_more: false,
            request_body_mutation: None,
            response_body_mutation: None,
            websocket_decision: None,
        }
    }

    /// Response with `Decision::Block` carrying `status` and `body` (no headers).
    pub fn block(status: u16, body: Option<String>) -> Self {
        let decision = Decision::Block {
            status,
            body,
            headers: None,
        };
        Self {
            decision,
            ..Self::default_allow()
        }
    }

    /// Response with `Decision::Redirect` to `url` using `status`.
    pub fn redirect(url: String, status: u16) -> Self {
        Self {
            decision: Decision::Redirect { url, status },
            ..Self::default_allow()
        }
    }

    /// Allow response with `needs_more` set to `true`.
    pub fn needs_more_data() -> Self {
        Self {
            needs_more: true,
            ..Self::default_allow()
        }
    }

    /// Allow response whose WebSocket decision is `Allow`.
    pub fn websocket_allow() -> Self {
        Self::default_allow().with_websocket_decision(WebSocketDecision::Allow)
    }

    /// Allow response whose WebSocket decision is `Drop`.
    pub fn websocket_drop() -> Self {
        Self::default_allow().with_websocket_decision(WebSocketDecision::Drop)
    }

    /// Allow response whose WebSocket decision is `Close` with `code` and `reason`.
    pub fn websocket_close(code: u16, reason: String) -> Self {
        Self::default_allow().with_websocket_decision(WebSocketDecision::Close { code, reason })
    }

    /// Returns the response with `websocket_decision` set to `Some(decision)`.
    pub fn with_websocket_decision(self, decision: WebSocketDecision) -> Self {
        Self {
            websocket_decision: Some(decision),
            ..self
        }
    }

    /// Returns the response with `request_body_mutation` set to `Some(mutation)`.
    pub fn with_request_body_mutation(self, mutation: BodyMutation) -> Self {
        Self {
            request_body_mutation: Some(mutation),
            ..self
        }
    }

    /// Returns the response with `response_body_mutation` set to `Some(mutation)`.
    pub fn with_response_body_mutation(self, mutation: BodyMutation) -> Self {
        Self {
            response_body_mutation: Some(mutation),
            ..self
        }
    }

    /// Returns the response with `needs_more` replaced by the given value.
    pub fn set_needs_more(self, needs_more: bool) -> Self {
        Self { needs_more, ..self }
    }

    /// Returns the response with `op` appended to the request header operations.
    pub fn add_request_header(mut self, op: HeaderOp) -> Self {
        self.request_headers.push(op);
        self
    }

    /// Returns the response with `op` appended to the response header operations.
    pub fn add_response_header(mut self, op: HeaderOp) -> Self {
        self.response_headers.push(op);
        self
    }

    /// Returns the response with its audit metadata replaced by `audit`.
    pub fn with_audit(self, audit: AuditMetadata) -> Self {
        Self { audit, ..self }
    }
}
/// Audit information an agent can attach to its response.
///
/// `Default` yields empty collections and no confidence.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuditMetadata {
/// Free-form tags; empty when absent from the input.
#[serde(default)]
pub tags: Vec<String>,
/// Identifiers of the rules involved; empty when absent from the input.
#[serde(default)]
pub rule_ids: Vec<String>,
/// Optional confidence value. Note this is `f32`, while the guardrail types
/// in this file use `f64`.
/// NOTE(review): the expected range (e.g. 0.0..=1.0) is not enforced here — confirm.
pub confidence: Option<f32>,
/// Reason codes; empty when absent from the input.
#[serde(default)]
pub reason_codes: Vec<String>,
/// Arbitrary JSON values keyed by name; empty when absent from the input.
#[serde(default)]
pub custom: HashMap<String, serde_json::Value>,
}
/// Kind of guardrail inspection requested in a `GuardrailInspectEvent`.
///
/// Serialized in snake_case: `"prompt_injection"` and `"pii_detection"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GuardrailInspectionType {
/// Inspect the content for prompt injection.
PromptInjection,
/// Inspect the content for PII.
PiiDetection,
}
/// Event asking an agent to run a guardrail inspection on some content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GuardrailInspectEvent {
/// Correlation identifier of the request being inspected.
pub correlation_id: String,
/// Which inspection to run.
pub inspection_type: GuardrailInspectionType,
/// The text to inspect.
pub content: String,
/// Optional model name; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
/// Category names; empty when absent from the input.
/// NOTE(review): how an empty list is interpreted is up to the agent — confirm.
#[serde(default)]
pub categories: Vec<String>,
/// Optional route identifier; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub route_id: Option<String>,
/// Free-form string key/value metadata; empty when absent from the input.
#[serde(default)]
pub metadata: HashMap<String, String>,
}
/// Result of a guardrail inspection.
///
/// `Default` is derived and yields the "nothing detected" value: `detected` is
/// `false`, `confidence` is `0.0`, `detections` is empty and `redacted_content`
/// is `None`. These are exactly the values the previous hand-written impl set.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GuardrailResponse {
    /// Whether anything was detected.
    pub detected: bool,
    /// Confidence of the result; deserializes to `0.0` when absent.
    #[serde(default)]
    pub confidence: f64,
    /// Individual detections; empty when absent from the input.
    #[serde(default)]
    pub detections: Vec<GuardrailDetection>,
    /// Optional redacted version of the inspected content; left out of the
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub redacted_content: Option<String>,
}
impl GuardrailResponse {
    /// Returns the "nothing detected" response (same as `Default`).
    pub fn clean() -> Self {
        Default::default()
    }

    /// Builds a response reporting a single detection.
    ///
    /// The response confidence is the detection's confidence, or `1.0` when the
    /// detection carries none.
    pub fn with_detection(detection: GuardrailDetection) -> Self {
        let confidence = detection.confidence.unwrap_or(1.0);
        Self {
            detected: true,
            confidence,
            detections: vec![detection],
            redacted_content: None,
        }
    }

    /// Appends `detection` and marks the response as detected.
    ///
    /// The response confidence becomes the maximum of its current value and the
    /// detection's confidence; a detection without confidence leaves it unchanged.
    pub fn add_detection(&mut self, detection: GuardrailDetection) {
        self.confidence = match detection.confidence {
            Some(reported) => self.confidence.max(reported),
            None => self.confidence,
        };
        self.detections.push(detection);
        self.detected = true;
    }
}
/// A single finding reported by a guardrail inspection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GuardrailDetection {
/// Category of the finding.
pub category: String,
/// Human-readable description of the finding.
pub description: String,
/// Severity; deserializes to `DetectionSeverity::Medium` (the enum's default) when absent.
#[serde(default)]
pub severity: DetectionSeverity,
/// Optional confidence; left out of the serialized output when `None`.
/// NOTE(review): the expected range (e.g. 0.0..=1.0) is not enforced here — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub confidence: Option<f64>,
/// Optional location of the finding; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub span: Option<TextSpan>,
}
impl GuardrailDetection {
    /// Creates a detection with `Medium` severity and no confidence or span.
    pub fn new(category: impl Into<String>, description: impl Into<String>) -> Self {
        let category: String = category.into();
        let description: String = description.into();
        Self {
            category,
            description,
            severity: DetectionSeverity::Medium,
            confidence: None,
            span: None,
        }
    }

    /// Returns the detection with its severity replaced.
    pub fn with_severity(self, severity: DetectionSeverity) -> Self {
        Self { severity, ..self }
    }

    /// Returns the detection with `confidence` set to `Some(confidence)`.
    pub fn with_confidence(self, confidence: f64) -> Self {
        Self {
            confidence: Some(confidence),
            ..self
        }
    }

    /// Returns the detection with `span` set to the given `start`/`end` pair.
    ///
    /// The values are stored as given; no ordering or bounds check is performed.
    pub fn with_span(self, start: usize, end: usize) -> Self {
        let span = TextSpan { start, end };
        Self {
            span: Some(span),
            ..self
        }
    }
}
/// Location of a detection, expressed as a `start`/`end` pair.
///
/// NOTE(review): whether these are byte or character offsets into the inspected
/// content, and whether `end` is exclusive, is not specified here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextSpan {
/// Start position.
pub start: usize,
/// End position.
pub end: usize,
}
/// Severity of a guardrail detection.
///
/// `Medium` is the `Default`. Serialized in lowercase (`"low"`, `"medium"`,
/// `"high"`, `"critical"`). The type derives equality but not ordering, so
/// severities cannot be compared with `<` / `>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum DetectionSeverity {
/// Low severity.
Low,
/// Medium severity (the default).
#[default]
Medium,
/// High severity.
High,
/// Critical severity.
Critical,
}