// systemprompt_security/extraction/header.rs

1use axum::http::{HeaderMap, HeaderValue};
2use std::error::Error;
3use std::fmt;
4use systemprompt_identifiers::{headers, AgentName, ContextId, SessionId, TaskId, TraceId, UserId};
5use systemprompt_models::execution::context::RequestContext;
6
/// Error returned when a value cannot be encoded as an HTTP header value.
///
/// Carries no payload; its `Display` text is a fixed message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeaderInjectionError;

impl fmt::Display for HeaderInjectionError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("Header value contains invalid characters")
    }
}

// Uses the default `source()` (none): this error never wraps another one.
impl Error for HeaderInjectionError {}

18#[derive(Debug, Clone, Copy)]
19pub struct HeaderExtractor;
20
21impl HeaderExtractor {
22    pub fn extract_trace_id(headers: &HeaderMap) -> TraceId {
23        Self::extract_header(headers, headers::TRACE_ID)
24            .map_or_else(TraceId::generate, TraceId::new)
25    }
26
27    pub fn extract_context_id(headers: &HeaderMap) -> ContextId {
28        Self::extract_header(headers, headers::CONTEXT_ID)
29            .filter(|s| !s.is_empty())
30            .map_or_else(ContextId::empty, ContextId::new)
31    }
32
33    pub fn extract_task_id(headers: &HeaderMap) -> Option<TaskId> {
34        Self::extract_header(headers, headers::TASK_ID).map(TaskId::new)
35    }
36
37    pub fn extract_agent_name(headers: &HeaderMap) -> AgentName {
38        Self::extract_header(headers, headers::AGENT_NAME)
39            .map_or_else(AgentName::system, AgentName::new)
40    }
41
42    fn extract_header(headers: &HeaderMap, name: &str) -> Option<String> {
43        headers
44            .get(name)
45            .and_then(|v| {
46                v.to_str()
47                    .map_err(|e| {
48                        tracing::debug!(error = %e, header = %name, "Header contains non-ASCII characters");
49                        e
50                    })
51                    .ok()
52            })
53            .map(ToString::to_string)
54    }
55}
56
57#[derive(Debug, Clone, Copy)]
58pub struct HeaderInjector;
59
60impl HeaderInjector {
61    pub fn inject_session_id(
62        headers: &mut HeaderMap,
63        session_id: &SessionId,
64    ) -> Result<(), HeaderInjectionError> {
65        Self::inject_header(headers, headers::SESSION_ID, session_id.as_str())
66    }
67
68    pub fn inject_user_id(
69        headers: &mut HeaderMap,
70        user_id: &UserId,
71    ) -> Result<(), HeaderInjectionError> {
72        Self::inject_header(headers, headers::USER_ID, user_id.as_str())
73    }
74
75    pub fn inject_trace_id(
76        headers: &mut HeaderMap,
77        trace_id: &TraceId,
78    ) -> Result<(), HeaderInjectionError> {
79        Self::inject_header(headers, headers::TRACE_ID, trace_id.as_str())
80    }
81
82    pub fn inject_context_id(
83        headers: &mut HeaderMap,
84        context_id: &ContextId,
85    ) -> Result<(), HeaderInjectionError> {
86        if context_id.as_str().is_empty() {
87            return Ok(());
88        }
89        Self::inject_header(headers, headers::CONTEXT_ID, context_id.as_str())
90    }
91
92    pub fn inject_task_id(
93        headers: &mut HeaderMap,
94        task_id: &TaskId,
95    ) -> Result<(), HeaderInjectionError> {
96        Self::inject_header(headers, headers::TASK_ID, task_id.as_str())
97    }
98
99    pub fn inject_agent_name(
100        headers: &mut HeaderMap,
101        agent_name: &str,
102    ) -> Result<(), HeaderInjectionError> {
103        Self::inject_header(headers, headers::AGENT_NAME, agent_name)
104    }
105
106    pub fn inject_from_request_context(
107        headers: &mut HeaderMap,
108        ctx: &RequestContext,
109    ) -> Result<(), HeaderInjectionError> {
110        Self::inject_session_id(headers, &ctx.request.session_id)?;
111        Self::inject_user_id(headers, &ctx.auth.user_id)?;
112        Self::inject_trace_id(headers, &ctx.execution.trace_id)?;
113        Self::inject_context_id(headers, &ctx.execution.context_id)?;
114        Self::inject_agent_name(headers, ctx.execution.agent_name.as_str())?;
115        Ok(())
116    }
117
118    fn inject_header(
119        headers: &mut HeaderMap,
120        name: &'static str,
121        value: &str,
122    ) -> Result<(), HeaderInjectionError> {
123        HeaderValue::from_str(value).map_or(Err(HeaderInjectionError), |header_value| {
124            headers.insert(name, header_value);
125            Ok(())
126        })
127    }
128}