systemprompt-api 0.1.18

HTTP API server and gateway for systemprompt.io OS
Documentation
use async_trait::async_trait;
use axum::http::HeaderMap;
use systemprompt_identifiers::{AgentName, ContextId, SessionId, TaskId, TraceId, UserId};
use systemprompt_models::auth::UserType;
use systemprompt_models::execution::{ContextExtractionError, RequestContext};

use super::traits::ContextExtractor;

/// Stateless [`ContextExtractor`] that builds a [`RequestContext`] from the
/// `x-*` request headers (`x-session-id`, `x-trace-id`, `x-user-id`,
/// `x-context-id`, `x-agent-name`, and the optional `x-task-id`).
///
/// The type carries no data, so it is cheap to copy and can be created with
/// either `new()` or `Default::default()`.
#[derive(Debug, Clone, Copy)]
pub struct HeaderContextExtractor;

impl HeaderContextExtractor {
    /// Creates the extractor. It holds no state, so this is usable in `const`
    /// contexts.
    pub const fn new() -> Self {
        Self
    }

    /// Reads the header `name` and returns its value as an owned `String`.
    ///
    /// # Errors
    ///
    /// * [`ContextExtractionError::MissingHeader`] when `name` is not present
    ///   in `headers`.
    /// * [`ContextExtractionError::InvalidHeaderValue`] when the value cannot
    ///   be viewed as a string (i.e. the header value's `to_str` fails); the
    ///   underlying error text is carried in `reason`.
    fn extract_required_header(
        headers: &HeaderMap,
        name: &str,
    ) -> Result<String, ContextExtractionError> {
        let Some(raw) = headers.get(name) else {
            return Err(ContextExtractionError::MissingHeader(name.to_string()));
        };

        match raw.to_str() {
            Ok(text) => Ok(text.to_owned()),
            Err(e) => Err(ContextExtractionError::InvalidHeaderValue {
                header: name.to_string(),
                reason: e.to_string(),
            }),
        }
    }

    /// Reads the header `name` if it is available as a string.
    ///
    /// Returns `None` both when the header is absent and when its value cannot
    /// be converted with `to_str`; the two cases are not distinguished.
    fn extract_optional_header(headers: &HeaderMap, name: &str) -> Option<String> {
        let raw = headers.get(name)?;
        raw.to_str().ok().map(str::to_owned)
    }
}

impl Default for HeaderContextExtractor {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ContextExtractor for HeaderContextExtractor {
    /// Builds a full [`RequestContext`] from request headers.
    ///
    /// Required headers, read in this order: `x-session-id`, `x-trace-id`,
    /// `x-user-id`, `x-context-id`, `x-agent-name`. When `x-task-id` is
    /// present and readable as a string, it is attached as the task id.
    ///
    /// # Errors
    ///
    /// Returns the error for the first required header (in the order above)
    /// that is missing or whose value cannot be read as a string.
    async fn extract_from_headers(
        &self,
        headers: &HeaderMap,
    ) -> Result<RequestContext, ContextExtractionError> {
        // Read every required header before constructing anything, so the
        // first failing header decides which error is returned.
        let session = Self::extract_required_header(headers, "x-session-id")?;
        let trace = Self::extract_required_header(headers, "x-trace-id")?;
        let user = Self::extract_required_header(headers, "x-user-id")?;
        let context_id = Self::extract_required_header(headers, "x-context-id")?;
        let agent = Self::extract_required_header(headers, "x-agent-name")?;

        let base = RequestContext::new(
            SessionId::new(session),
            TraceId::new(trace),
            ContextId::new(context_id),
            AgentName::new(agent),
        )
        .with_user_id(UserId::new(user));

        // The task id is optional: attach it only when the header is usable.
        let request_context = match Self::extract_optional_header(headers, "x-task-id") {
            Some(task) => base.with_task_id(TaskId::new(task)),
            None => base,
        };

        Ok(request_context)
    }

    /// Builds a [`RequestContext`] for a user without reading a context id.
    ///
    /// Required headers, read in this order: `x-session-id`, `x-trace-id`,
    /// `x-user-id`, `x-agent-name`. The `x-context-id` and `x-task-id`
    /// headers are not consulted; the context id is created from an empty
    /// string and the user type is set to [`UserType::User`].
    ///
    /// # Errors
    ///
    /// Returns the error for the first required header (in the order above)
    /// that is missing or whose value cannot be read as a string.
    async fn extract_user_only(
        &self,
        headers: &HeaderMap,
    ) -> Result<RequestContext, ContextExtractionError> {
        let session = Self::extract_required_header(headers, "x-session-id")?;
        let trace = Self::extract_required_header(headers, "x-trace-id")?;
        let user = Self::extract_required_header(headers, "x-user-id")?;
        let agent = Self::extract_required_header(headers, "x-agent-name")?;

        let session_id = SessionId::new(session);
        let trace_id = TraceId::new(trace);
        // No context header is read on this path; the id is left empty.
        let context_id = ContextId::new(String::new());
        let agent_name = AgentName::new(agent);

        Ok(
            RequestContext::new(session_id, trace_id, context_id, agent_name)
                .with_user_id(UserId::new(user))
                .with_user_type(UserType::User),
        )
    }
}