pub mod detect;
pub mod event;
pub mod extract;
pub mod mcp;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::broadcast;
use aa_proto::assembly::audit::v1::{
audit_event, AuditEvent, LlmCallDetail, NetworkCallDetail, PolicyViolation, ToolCallDetail,
};
use aa_proto::assembly::common::v1::ActionType;
use aa_runtime::pipeline::event::{EnrichedEvent, EventSource};
use aa_runtime::pipeline::PipelineEvent;
use aa_security::{CredentialFinding, CredentialScanner};
use bytes::Bytes;
use crate::config::CredentialAction;
use crate::error::ProxyError;
use crate::intercept::detect::LlmApiPattern;
use crate::intercept::extract::{extract_anthropic, extract_cohere, extract_openai, ExtractionError, LlmFields};
/// Outcome chosen by `Interceptor::intercept_request` for a scanned body.
///
/// The enum carries no data, so it derives `Copy` and can be passed around by
/// value without explicit clones.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VerdictDecision {
    /// Nothing was reported: the scanner is disabled or the scan came back clean.
    Forward,
    /// Findings were reported under `CredentialAction::RedactOnly`; the verdict
    /// carries a redacted copy of the body.
    ForwardRedacted,
    /// Findings were reported under `CredentialAction::Block`.
    Block,
    /// Findings were reported under `CredentialAction::AlertOnly`; no redacted
    /// body is produced.
    AlertAndForward,
}
/// Result of scanning a body with `Interceptor::intercept_request`.
#[derive(Debug, Clone)]
pub struct InterceptVerdict {
/// How the scanned body should be handled.
pub decision: VerdictDecision,
/// Findings reported by the scanner; empty when the decision is `Forward`.
pub findings: Vec<CredentialFinding>,
/// Redacted replacement body; only populated for `ForwardRedacted`.
pub redacted_body: Option<Bytes>,
}
/// Scans intercepted bodies for credentials and publishes audit events.
pub struct Interceptor {
/// Broadcast channel on which audit `PipelineEvent`s are sent.
event_tx: broadcast::Sender<PipelineEvent>,
/// Credential scanner; `None` disables scanning and redaction entirely.
scanner: Option<CredentialScanner>,
}
impl Interceptor {
/// Scans `body` for credentials and maps the scan outcome plus `action` to a
/// verdict.
///
/// Returns a `Forward` verdict with no findings when the scanner is disabled
/// or the scan is clean. Otherwise the decision follows `action`, and only
/// `CredentialAction::RedactOnly` produces a `redacted_body`.
///
/// The body is decoded with `String::from_utf8_lossy`, so invalid UTF-8
/// sequences are replaced before scanning and the redacted body is built from
/// that decoded text.
pub fn intercept_request(&self, body: &[u8], action: CredentialAction) -> InterceptVerdict {
    // Verdict shared by the "scanner disabled" and "nothing found" paths.
    let pass_through = || InterceptVerdict {
        decision: VerdictDecision::Forward,
        findings: Vec::new(),
        redacted_body: None,
    };

    let scanner = match self.scanner.as_ref() {
        Some(scanner) => scanner,
        None => return pass_through(),
    };

    let payload = String::from_utf8_lossy(body);
    let outcome = scanner.scan(&payload);
    if outcome.is_clean() {
        return pass_through();
    }

    // Resolve the decision and optional replacement body first; the findings
    // are moved into the verdict once, below.
    let (decision, redacted_body) = match action {
        CredentialAction::Block => (VerdictDecision::Block, None),
        CredentialAction::RedactOnly => (
            VerdictDecision::ForwardRedacted,
            Some(Bytes::from(outcome.redact(&payload))),
        ),
        CredentialAction::AlertOnly => (VerdictDecision::AlertAndForward, None),
    };

    InterceptVerdict {
        decision,
        findings: outcome.findings,
        redacted_body,
    }
}
pub fn new(event_tx: broadcast::Sender<PipelineEvent>) -> Self {
Self {
event_tx,
scanner: Some(CredentialScanner::new()),
}
}
/// Creates an interceptor with an explicit scanner choice.
///
/// Passing `None` for `scanner` disables credential scanning and redaction.
pub fn with_scanner(event_tx: broadcast::Sender<PipelineEvent>, scanner: Option<CredentialScanner>) -> Self {
    Self { scanner, event_tx }
}
/// Publishes an audit event recording a policy decision for `host`.
///
/// A denied host is reported as a `PolicyViolation` whose blocked action is
/// `CONNECT {host}`; an allowed host is reported as a successful `https`
/// network call. Both cases use `ActionType::NetworkCall`. The send is
/// best-effort: a failure to deliver the event is ignored.
pub async fn emit_policy_decision(&self, host: &str, denied: bool) {
    let received_at_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis() as i64;

    let detail = if denied {
        audit_event::Detail::Violation(PolicyViolation {
            blocked_action: format!("CONNECT {host}"),
            reason: "host is on the deny list".into(),
            ..Default::default()
        })
    } else {
        audit_event::Detail::Network(NetworkCallDetail {
            host: host.to_string(),
            protocol: "https".into(),
            succeeded: true,
            ..Default::default()
        })
    };

    let enriched = EnrichedEvent {
        inner: AuditEvent {
            action_type: ActionType::NetworkCall.into(),
            detail: Some(detail),
            ..Default::default()
        },
        received_at_ms,
        source: EventSource::Proxy,
        agent_id: String::new(),
        connection_id: 0,
        sequence_number: 0,
    };

    // Best-effort delivery: the send result is intentionally discarded.
    let _ = self.event_tx.send(PipelineEvent::Audit(Box::new(enriched)));
}
/// Returns a redacted copy of `body` when the scanner reports findings.
///
/// Yields `None` when scanning is disabled or the body is clean, which tells
/// the caller the original bytes can be used unchanged. The body is decoded
/// with `String::from_utf8_lossy` before scanning.
pub fn redact_response_body(&self, body: &[u8]) -> Option<Vec<u8>> {
    let scanner = match self.scanner.as_ref() {
        Some(scanner) => scanner,
        None => return None,
    };
    let decoded = String::from_utf8_lossy(body);
    let outcome = scanner.scan(&decoded);
    if outcome.is_clean() {
        None
    } else {
        Some(outcome.redact(&decoded).into_bytes())
    }
}
/// Publishes an audit event recording a decision about an MCP `tools/call`.
///
/// A denied call is reported as a `PolicyViolation` carrying `reason`; the
/// arguments are not attached. An allowed call is reported as a successful
/// `ToolCallDetail` with source `mcp`, whose arguments are redacted first if
/// the scanner reports findings. The send is best-effort: a failure to
/// deliver the event is ignored.
pub async fn emit_mcp_decision(&self, tool_name: &str, args_json: &[u8], denied: bool, reason: &str) {
    let received_at_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis() as i64;

    let detail = if denied {
        audit_event::Detail::Violation(PolicyViolation {
            blocked_action: format!("tools/call {tool_name}"),
            reason: reason.to_string(),
            ..Default::default()
        })
    } else {
        // Same scan-and-redact step used for bodies: `None` means the scanner
        // is disabled or the arguments are clean, so they are copied verbatim.
        let sanitized_args = self
            .redact_response_body(args_json)
            .unwrap_or_else(|| args_json.to_vec());
        audit_event::Detail::ToolCall(ToolCallDetail {
            tool_name: tool_name.to_string(),
            tool_source: "mcp".into(),
            succeeded: true,
            args_json: sanitized_args,
            ..Default::default()
        })
    };

    let enriched = EnrichedEvent {
        inner: AuditEvent {
            action_type: ActionType::ToolCall.into(),
            detail: Some(detail),
            ..Default::default()
        },
        received_at_ms,
        source: EventSource::Proxy,
        agent_id: String::new(),
        connection_id: 0,
        sequence_number: 0,
    };

    // Best-effort delivery: the send result is intentionally discarded.
    let _ = self.event_tx.send(PipelineEvent::Audit(Box::new(enriched)));
}
pub async fn intercept(&self, event: &event::ProxyEvent) -> Result<Option<LlmFields>, ProxyError> {
if event.pattern == LlmApiPattern::Unknown {
tracing::debug!(method = %event.method, path = %event.path, "non-LLM traffic, skipping");
return Ok(None);
}
let raw_body = event.response_body.as_ref().or(event.request_body.as_ref());
let body: Option<bytes::Bytes> = raw_body.map(|b| {
if let Some(scanner) = &self.scanner {
let text = String::from_utf8_lossy(b);
let result = scanner.scan(&text);
if result.is_clean() {
b.clone()
} else {
tracing::warn!(
findings = result.findings.len(),
agent_id = event.agent_id.as_deref().unwrap_or("<unknown>"),
"credentials detected in LLM body, redacting before audit"
);
bytes::Bytes::from(result.redact(&text))
}
} else {
b.clone()
}
});
let fields = match body.as_ref() {
Some(bytes) => match Self::extract_for_pattern(&event.pattern, bytes) {
Ok(f) => Some(f),
Err(e) => {
tracing::warn!(
pattern = ?event.pattern,
error = %e,
"failed to extract LLM fields from body"
);
None
}
},
None => None,
};
tracing::info!(
agent_id = event.agent_id.as_deref().unwrap_or("<unknown>"),
pattern = ?event.pattern,
method = %event.method,
path = %event.path,
model = fields.as_ref().map(|f| f.model.as_str()).unwrap_or("<unknown>"),
messages = fields.as_ref().map(|f| f.messages_count).unwrap_or(0),
"intercepted LLM API call"
);
let pipeline_event = Self::build_pipeline_event(event, fields.as_ref());
let _ = self.event_tx.send(pipeline_event);
Ok(fields)
}
/// Builds the LLM-call audit event for `event`.
///
/// `fields` supplies the model name and token counts; when it is `None` the
/// model is empty and both counts are 0. The provider string is derived from
/// the event's pattern, and the agent id defaults to an empty string when the
/// event carries none.
fn build_pipeline_event(event: &event::ProxyEvent, fields: Option<&LlmFields>) -> PipelineEvent {
    // Milliseconds since the Unix epoch fit in an `i64` for hundreds of
    // millions of years, so this narrowing cast cannot truncate in practice.
    // A clock set before the epoch yields 0 via `unwrap_or_default`.
    let now_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis() as i64;

    let prompt_tokens = fields.and_then(|f| f.prompt_tokens).unwrap_or(0);
    let completion_tokens = fields.and_then(|f| f.completion_tokens).unwrap_or(0);

    let llm_detail = LlmCallDetail {
        model: fields.map(|f| f.model.clone()).unwrap_or_default(),
        // A plain `as i32` cast would wrap an oversized count into a negative
        // number; clamp to `i32::MAX` instead.
        prompt_tokens: i32::try_from(prompt_tokens).unwrap_or(i32::MAX),
        completion_tokens: i32::try_from(completion_tokens).unwrap_or(i32::MAX),
        provider: Self::provider_name(&event.pattern).into(),
        ..Default::default()
    };
    let audit = AuditEvent {
        action_type: ActionType::LlmCall.into(),
        detail: Some(audit_event::Detail::LlmCall(llm_detail)),
        ..Default::default()
    };
    let enriched = EnrichedEvent {
        inner: audit,
        received_at_ms: now_ms,
        source: EventSource::Proxy,
        agent_id: event.agent_id.clone().unwrap_or_default(),
        connection_id: 0,
        sequence_number: 0,
    };
    PipelineEvent::Audit(Box::new(enriched))
}
/// Maps an API pattern to the lowercase provider string recorded in audit
/// events.
fn provider_name(pattern: &LlmApiPattern) -> &'static str {
    match *pattern {
        LlmApiPattern::Unknown => "unknown",
        LlmApiPattern::OpenAi => "openai",
        LlmApiPattern::Anthropic => "anthropic",
        LlmApiPattern::Cohere => "cohere",
    }
}
/// Runs the provider-specific extractor that matches `pattern` over `body`.
///
/// # Errors
///
/// Returns `ExtractionError::UnrecognizedFormat` for
/// `LlmApiPattern::Unknown`; otherwise propagates whatever the selected
/// extractor returns.
fn extract_for_pattern(pattern: &LlmApiPattern, body: &[u8]) -> Result<LlmFields, ExtractionError> {
    match *pattern {
        LlmApiPattern::Unknown => Err(ExtractionError::UnrecognizedFormat {
            reason: "unknown provider".into(),
        }),
        LlmApiPattern::OpenAi => extract_openai(body),
        LlmApiPattern::Anthropic => extract_anthropic(body),
        LlmApiPattern::Cohere => extract_cohere(body),
    }
}
}
#[cfg(test)]
mod tests {
    use std::time::SystemTime;

    use bytes::Bytes;

    use super::*;
    use crate::intercept::detect::LlmApiPattern;
    use crate::intercept::event::ProxyEvent;

    /// OpenAI-style response carrying a model name and token usage only.
    const OPENAI_USAGE_RESPONSE: &str = r#"{"model":"gpt-4","usage":{"prompt_tokens":10,"completion_tokens":20}}"#;

    /// OpenAI-style request with a single user message.
    const OPENAI_CHAT_REQUEST: &str = r#"{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}"#;

    /// OpenAI-style response with an API-key-shaped string in a `debug` field.
    const OPENAI_LEAKY_RESPONSE: &str = r#"{"model":"gpt-4","usage":{"prompt_tokens":5,"completion_tokens":8},"debug":"sk-proj-aBcDeFgHiJkLmNoPqRsT1234567890abcdef1234567890ab"}"#;

    /// Interceptor with the default scanner; its receiver is dropped when this
    /// helper returns.
    fn make_interceptor() -> Interceptor {
        let (sender, _receiver) = broadcast::channel(16);
        Interceptor::new(sender)
    }

    /// Body-less POST event for `pattern`, attributed to `test-agent`.
    fn make_event(pattern: LlmApiPattern) -> ProxyEvent {
        ProxyEvent {
            agent_id: Some("test-agent".into()),
            pattern,
            method: "POST".into(),
            path: "/v1/chat/completions".into(),
            request_body: None,
            response_body: None,
            timestamp: SystemTime::now(),
        }
    }

    /// Event for `pattern` whose only body is the response `json`.
    fn event_with_response(pattern: LlmApiPattern, json: &'static str) -> ProxyEvent {
        let mut event = make_event(pattern);
        event.response_body = Some(Bytes::from(json));
        event
    }

    /// Event for `pattern` whose only body is the request `json`.
    fn event_with_request(pattern: LlmApiPattern, json: &'static str) -> ProxyEvent {
        let mut event = make_event(pattern);
        event.request_body = Some(Bytes::from(json));
        event
    }

    #[tokio::test]
    async fn intercept_openai_event_succeeds() {
        let interceptor = make_interceptor();
        let event = make_event(LlmApiPattern::OpenAi);
        let outcome = interceptor.intercept(&event).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn intercept_anthropic_event_succeeds() {
        let interceptor = make_interceptor();
        let event = make_event(LlmApiPattern::Anthropic);
        let outcome = interceptor.intercept(&event).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn intercept_unknown_returns_none() {
        let interceptor = make_interceptor();
        let event = make_event(LlmApiPattern::Unknown);
        let extracted = interceptor.intercept(&event).await.unwrap();
        assert!(extracted.is_none(), "unknown pattern should skip extraction");
    }

    #[tokio::test]
    async fn intercept_with_no_agent_id_succeeds() {
        let interceptor = make_interceptor();
        let mut anonymous = make_event(LlmApiPattern::OpenAi);
        anonymous.agent_id = None;
        assert!(interceptor.intercept(&anonymous).await.is_ok());
    }

    #[tokio::test]
    async fn intercept_openai_with_body_extracts_fields() {
        let interceptor = make_interceptor();
        let event = event_with_response(LlmApiPattern::OpenAi, OPENAI_USAGE_RESPONSE);
        let extracted = interceptor.intercept(&event).await.unwrap().unwrap();
        assert_eq!(extracted.model, "gpt-4");
        assert_eq!(extracted.prompt_tokens, Some(10));
        assert_eq!(extracted.completion_tokens, Some(20));
    }

    #[tokio::test]
    async fn intercept_anthropic_with_body_extracts_fields() {
        let interceptor = make_interceptor();
        let event = event_with_response(
            LlmApiPattern::Anthropic,
            r#"{"model":"claude-3-opus-20240229","usage":{"input_tokens":15,"output_tokens":30}}"#,
        );
        let extracted = interceptor.intercept(&event).await.unwrap().unwrap();
        assert_eq!(extracted.model, "claude-3-opus-20240229");
        assert_eq!(extracted.prompt_tokens, Some(15));
        assert_eq!(extracted.completion_tokens, Some(30));
    }

    #[tokio::test]
    async fn intercept_cohere_with_body_extracts_fields() {
        let interceptor = make_interceptor();
        let event = event_with_response(
            LlmApiPattern::Cohere,
            r#"{"model":"command-r-plus","message":"hello","meta":{"tokens":{"input_tokens":5,"output_tokens":12}}}"#,
        );
        let extracted = interceptor.intercept(&event).await.unwrap().unwrap();
        assert_eq!(extracted.model, "command-r-plus");
        assert_eq!(extracted.prompt_tokens, Some(5));
        assert_eq!(extracted.completion_tokens, Some(12));
        assert_eq!(extracted.messages_count, 1);
    }

    #[tokio::test]
    async fn intercept_prefers_response_body_over_request() {
        let interceptor = make_interceptor();
        // Both bodies are present; the usage figures exist only in the response.
        let mut event = event_with_response(LlmApiPattern::OpenAi, OPENAI_USAGE_RESPONSE);
        event.request_body = Some(Bytes::from(OPENAI_CHAT_REQUEST));
        let extracted = interceptor.intercept(&event).await.unwrap().unwrap();
        assert_eq!(extracted.prompt_tokens, Some(10));
        assert_eq!(extracted.completion_tokens, Some(20));
        assert_eq!(extracted.messages_count, 0);
    }

    #[tokio::test]
    async fn intercept_falls_back_to_request_body() {
        let interceptor = make_interceptor();
        let event = event_with_request(LlmApiPattern::OpenAi, OPENAI_CHAT_REQUEST);
        let extracted = interceptor.intercept(&event).await.unwrap().unwrap();
        assert_eq!(extracted.model, "gpt-4");
        assert_eq!(extracted.messages_count, 1);
        assert_eq!(extracted.prompt_tokens, None);
    }

    #[tokio::test]
    async fn intercept_with_none_body_returns_none() {
        let interceptor = make_interceptor();
        let bodyless = make_event(LlmApiPattern::OpenAi);
        let extracted = interceptor.intercept(&bodyless).await.unwrap();
        assert!(extracted.is_none());
    }

    #[tokio::test]
    async fn intercept_with_malformed_body_returns_none() {
        let interceptor = make_interceptor();
        let event = event_with_response(LlmApiPattern::OpenAi, "not json");
        let extracted = interceptor.intercept(&event).await.unwrap();
        assert!(extracted.is_none());
    }

    #[tokio::test]
    async fn non_llm_traffic_emits_no_pipeline_event() {
        let (sender, mut receiver) = broadcast::channel(16);
        let interceptor = Interceptor::new(sender);
        let event = make_event(LlmApiPattern::Unknown);
        interceptor.intercept(&event).await.unwrap();
        assert!(receiver.try_recv().is_err(), "non-LLM traffic must not emit a pipeline event");
    }

    #[tokio::test]
    async fn llm_traffic_emits_pipeline_event_with_correct_fields() {
        let (sender, mut receiver) = broadcast::channel(16);
        let interceptor = Interceptor::new(sender);
        let event = event_with_response(LlmApiPattern::OpenAi, OPENAI_USAGE_RESPONSE);
        interceptor.intercept(&event).await.unwrap();

        let enriched = match receiver.try_recv().expect("should have received a pipeline event") {
            PipelineEvent::Audit(enriched) => enriched,
            other => panic!("expected Audit event, got {other:?}"),
        };
        assert_eq!(enriched.source, EventSource::Proxy);
        assert_eq!(enriched.agent_id, "test-agent");

        let llm = match enriched.inner.detail.expect("detail must be set") {
            aa_proto::assembly::audit::v1::audit_event::Detail::LlmCall(llm) => llm,
            other => panic!("expected LlmCall detail, got {other:?}"),
        };
        assert_eq!(llm.model, "gpt-4");
        assert_eq!(llm.prompt_tokens, 10);
        assert_eq!(llm.completion_tokens, 20);
        assert_eq!(llm.provider, "openai");
    }

    #[tokio::test]
    async fn intercept_redacts_credentials_from_body() {
        let interceptor = make_interceptor();
        let event = event_with_request(
            LlmApiPattern::OpenAi,
            r#"{"model":"gpt-4","messages":[{"role":"user","content":"my key is sk-proj-aBcDeFgHiJkLmNoPqRsT1234567890abcdef1234567890ab"}]}"#,
        );
        let extracted = interceptor.intercept(&event).await.unwrap().unwrap();
        assert_eq!(extracted.model, "gpt-4");
        assert_eq!(extracted.messages_count, 1);
    }

    #[tokio::test]
    async fn intercept_credential_body_emits_redacted_event() {
        let (sender, mut receiver) = broadcast::channel(16);
        let interceptor = Interceptor::new(sender);
        let event = event_with_response(LlmApiPattern::OpenAi, OPENAI_LEAKY_RESPONSE);
        let extracted = interceptor.intercept(&event).await.unwrap().unwrap();
        assert_eq!(extracted.model, "gpt-4");
        assert_eq!(extracted.prompt_tokens, Some(5));

        let emitted = receiver.try_recv().expect("should receive pipeline event");
        let rendered = format!("{emitted:?}");
        assert!(
            !rendered.contains("sk-proj-"),
            "pipeline event must not contain raw credential"
        );
    }

    #[tokio::test]
    async fn intercept_with_scanner_disabled_skips_redaction() {
        let (sender, _receiver) = broadcast::channel(16);
        let interceptor = Interceptor::with_scanner(sender, None);
        let event = event_with_response(LlmApiPattern::OpenAi, OPENAI_LEAKY_RESPONSE);
        let extracted = interceptor.intercept(&event).await.unwrap().unwrap();
        assert_eq!(extracted.model, "gpt-4");
        assert_eq!(extracted.prompt_tokens, Some(5));
    }
}