systemprompt-security 0.1.21

Security module for systemprompt.io - authentication, authorization, JWT, and token extraction
Documentation
use axum::http::{HeaderMap, HeaderValue};
use std::error::Error;
use std::fmt;
use systemprompt_identifiers::{AgentName, ContextId, SessionId, TaskId, TraceId, UserId, headers};
use systemprompt_models::execution::context::RequestContext;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeaderInjectionError;

impl fmt::Display for HeaderInjectionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Header value contains invalid characters")
    }
}

impl Error for HeaderInjectionError {}

#[derive(Debug, Clone, Copy)]
pub struct HeaderExtractor;

impl HeaderExtractor {
    pub fn extract_trace_id(headers: &HeaderMap) -> TraceId {
        Self::extract_header(headers, headers::TRACE_ID)
            .map_or_else(TraceId::generate, TraceId::new)
    }

    pub fn extract_context_id(headers: &HeaderMap) -> ContextId {
        Self::extract_header(headers, headers::CONTEXT_ID)
            .filter(|s| !s.is_empty())
            .map_or_else(ContextId::empty, ContextId::new)
    }

    pub fn extract_task_id(headers: &HeaderMap) -> Option<TaskId> {
        Self::extract_header(headers, headers::TASK_ID).map(TaskId::new)
    }

    pub fn extract_agent_name(headers: &HeaderMap) -> AgentName {
        Self::extract_header(headers, headers::AGENT_NAME)
            .map_or_else(AgentName::system, AgentName::new)
    }

    fn extract_header(headers: &HeaderMap, name: &str) -> Option<String> {
        headers
            .get(name)
            .and_then(|v| {
                v.to_str()
                    .map_err(|e| {
                        tracing::debug!(error = %e, header = %name, "Header contains non-ASCII characters");
                        e
                    })
                    .ok()
            })
            .map(ToString::to_string)
    }
}

#[derive(Debug, Clone, Copy)]
pub struct HeaderInjector;

impl HeaderInjector {
    pub fn inject_session_id(
        headers: &mut HeaderMap,
        session_id: &SessionId,
    ) -> Result<(), HeaderInjectionError> {
        Self::inject_header(headers, headers::SESSION_ID, session_id.as_str())
    }

    pub fn inject_user_id(
        headers: &mut HeaderMap,
        user_id: &UserId,
    ) -> Result<(), HeaderInjectionError> {
        Self::inject_header(headers, headers::USER_ID, user_id.as_str())
    }

    pub fn inject_trace_id(
        headers: &mut HeaderMap,
        trace_id: &TraceId,
    ) -> Result<(), HeaderInjectionError> {
        Self::inject_header(headers, headers::TRACE_ID, trace_id.as_str())
    }

    pub fn inject_context_id(
        headers: &mut HeaderMap,
        context_id: &ContextId,
    ) -> Result<(), HeaderInjectionError> {
        if context_id.as_str().is_empty() {
            return Ok(());
        }
        Self::inject_header(headers, headers::CONTEXT_ID, context_id.as_str())
    }

    pub fn inject_task_id(
        headers: &mut HeaderMap,
        task_id: &TaskId,
    ) -> Result<(), HeaderInjectionError> {
        Self::inject_header(headers, headers::TASK_ID, task_id.as_str())
    }

    pub fn inject_agent_name(
        headers: &mut HeaderMap,
        agent_name: &str,
    ) -> Result<(), HeaderInjectionError> {
        Self::inject_header(headers, headers::AGENT_NAME, agent_name)
    }

    pub fn inject_from_request_context(
        headers: &mut HeaderMap,
        ctx: &RequestContext,
    ) -> Result<(), HeaderInjectionError> {
        Self::inject_session_id(headers, &ctx.request.session_id)?;
        Self::inject_user_id(headers, &ctx.auth.user_id)?;
        Self::inject_trace_id(headers, &ctx.execution.trace_id)?;
        Self::inject_context_id(headers, &ctx.execution.context_id)?;
        Self::inject_agent_name(headers, ctx.execution.agent_name.as_str())?;
        Ok(())
    }

    fn inject_header(
        headers: &mut HeaderMap,
        name: &'static str,
        value: &str,
    ) -> Result<(), HeaderInjectionError> {
        HeaderValue::from_str(value).map_or(Err(HeaderInjectionError), |header_value| {
            headers.insert(name, header_value);
            Ok(())
        })
    }
}