systemprompt_security/extraction/
header.rs1use axum::http::{HeaderMap, HeaderValue};
2use std::error::Error;
3use std::fmt;
4use systemprompt_identifiers::{
5 AgentName, ContextId, GatewayConversationId, ProviderRequestId, SessionId, TaskId, TraceId,
6 UserId, headers,
7};
8use systemprompt_models::execution::context::RequestContext;
9
/// Error reported when a value cannot be placed into an HTTP header.
///
/// The type carries no payload; its `Display` text is the whole message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeaderInjectionError;

impl fmt::Display for HeaderInjectionError {
    /// Writes the fixed, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Header value contains invalid characters")
    }
}

// No underlying cause is tracked, so the default `Error` methods suffice.
impl Error for HeaderInjectionError {}
20
21#[derive(Debug, Clone, Copy)]
22pub struct HeaderExtractor;
23
24impl HeaderExtractor {
25 pub fn extract_trace_id(headers: &HeaderMap) -> TraceId {
26 Self::extract_header(headers, headers::TRACE_ID)
27 .map_or_else(TraceId::generate, TraceId::new)
28 }
29
30 pub fn extract_context_id(headers: &HeaderMap) -> Option<ContextId> {
31 Self::extract_header(headers, headers::CONTEXT_ID)
32 .filter(|s| !s.is_empty())
33 .and_then(|s| {
34 ContextId::try_new(s)
35 .map_err(|e| {
36 tracing::warn!(error = %e, "Invalid context_id header value, ignoring");
37 e
38 })
39 .ok()
40 })
41 }
42
43 pub fn extract_gateway_conversation_id(headers: &HeaderMap) -> Option<GatewayConversationId> {
44 Self::extract_header(headers, headers::GATEWAY_CONVERSATION_ID)
45 .filter(|s| !s.is_empty())
46 .and_then(|s| GatewayConversationId::try_new(s).ok())
47 }
48
49 pub fn extract_provider_request_id(headers: &HeaderMap) -> Option<ProviderRequestId> {
50 Self::extract_header(headers, headers::PROVIDER_REQUEST_ID)
51 .filter(|s| !s.is_empty())
52 .and_then(|s| ProviderRequestId::try_new(s).ok())
53 }
54
55 pub fn extract_task_id(headers: &HeaderMap) -> Option<TaskId> {
56 Self::extract_header(headers, headers::TASK_ID).map(TaskId::new)
57 }
58
59 pub fn extract_agent_name(headers: &HeaderMap) -> AgentName {
60 Self::extract_header(headers, headers::AGENT_NAME)
61 .map_or_else(AgentName::system, AgentName::new)
62 }
63
64 fn extract_header(headers: &HeaderMap, name: &str) -> Option<String> {
65 headers
66 .get(name)
67 .and_then(|v| {
68 v.to_str()
69 .map_err(|e| {
70 tracing::debug!(error = %e, header = %name, "Header contains non-ASCII characters");
71 e
72 })
73 .ok()
74 })
75 .map(ToString::to_string)
76 }
77}
78
79#[derive(Debug, Clone, Copy)]
80pub struct HeaderInjector;
81
82impl HeaderInjector {
83 pub fn inject_session_id(
84 headers: &mut HeaderMap,
85 session_id: &SessionId,
86 ) -> Result<(), HeaderInjectionError> {
87 Self::inject_header(headers, headers::SESSION_ID, session_id.as_str())
88 }
89
90 pub fn inject_user_id(
91 headers: &mut HeaderMap,
92 user_id: &UserId,
93 ) -> Result<(), HeaderInjectionError> {
94 Self::inject_header(headers, headers::USER_ID, user_id.as_str())
95 }
96
97 pub fn inject_trace_id(
98 headers: &mut HeaderMap,
99 trace_id: &TraceId,
100 ) -> Result<(), HeaderInjectionError> {
101 Self::inject_header(headers, headers::TRACE_ID, trace_id.as_str())
102 }
103
104 pub fn inject_context_id(
105 headers: &mut HeaderMap,
106 context_id: &ContextId,
107 ) -> Result<(), HeaderInjectionError> {
108 Self::inject_header(headers, headers::CONTEXT_ID, context_id.as_str())
109 }
110
111 pub fn inject_gateway_conversation_id(
112 headers: &mut HeaderMap,
113 id: &GatewayConversationId,
114 ) -> Result<(), HeaderInjectionError> {
115 Self::inject_header(headers, headers::GATEWAY_CONVERSATION_ID, id.as_str())
116 }
117
118 pub fn inject_provider_request_id(
119 headers: &mut HeaderMap,
120 id: &ProviderRequestId,
121 ) -> Result<(), HeaderInjectionError> {
122 Self::inject_header(headers, headers::PROVIDER_REQUEST_ID, id.as_str())
123 }
124
125 pub fn inject_task_id(
126 headers: &mut HeaderMap,
127 task_id: &TaskId,
128 ) -> Result<(), HeaderInjectionError> {
129 Self::inject_header(headers, headers::TASK_ID, task_id.as_str())
130 }
131
132 pub fn inject_agent_name(
133 headers: &mut HeaderMap,
134 agent_name: &str,
135 ) -> Result<(), HeaderInjectionError> {
136 Self::inject_header(headers, headers::AGENT_NAME, agent_name)
137 }
138
139 pub fn inject_from_request_context(
140 headers: &mut HeaderMap,
141 ctx: &RequestContext,
142 ) -> Result<(), HeaderInjectionError> {
143 Self::inject_session_id(headers, &ctx.request.session_id)?;
144 Self::inject_user_id(headers, &ctx.auth.user_id)?;
145 Self::inject_trace_id(headers, &ctx.execution.trace_id)?;
146 Self::inject_context_id(headers, &ctx.execution.context_id)?;
147 Self::inject_agent_name(headers, ctx.execution.agent_name.as_str())?;
148 Ok(())
149 }
150
151 fn inject_header(
152 headers: &mut HeaderMap,
153 name: &'static str,
154 value: &str,
155 ) -> Result<(), HeaderInjectionError> {
156 HeaderValue::from_str(value).map_or(Err(HeaderInjectionError), |header_value| {
157 headers.insert(name, header_value);
158 Ok(())
159 })
160 }
161}