systemprompt-api 0.1.18

HTTP API server and gateway for systemprompt.io OS
Documentation
use async_trait::async_trait;
use axum::body::Body;
use axum::extract::Request;
use axum::http::HeaderMap;
use systemprompt_identifiers::{AgentName, ContextId, SessionId, TaskId, TraceId, UserId};
use systemprompt_models::execution::{ContextExtractionError, RequestContext};

use super::traits::ContextExtractor;
use crate::services::middleware::context::sources::{
    ContextIdSource, HeaderSource, PayloadSource, TASK_BASED_CONTEXT_MARKER,
};

#[derive(Debug, Clone, Copy)]
pub struct A2aContextExtractor;

impl A2aContextExtractor {
    pub const fn new() -> Self {
        Self
    }

    fn try_from_headers(headers: &HeaderMap) -> Result<RequestContext, ContextExtractionError> {
        let session_id = HeaderSource::extract_required(headers, "x-session-id")?;
        let trace_id = HeaderSource::extract_required(headers, "x-trace-id")?;
        let user_id = HeaderSource::extract_required(headers, "x-user-id")?;
        let context_id = HeaderSource::extract_required(headers, "x-context-id")?;
        let agent_name = HeaderSource::extract_required(headers, "x-agent-name")?;

        let mut context = RequestContext::new(
            SessionId::new(session_id),
            TraceId::new(trace_id),
            ContextId::new(context_id),
            AgentName::new(agent_name),
        )
        .with_user_id(UserId::new(user_id));

        if let Some(task_id_str) = HeaderSource::extract_optional(headers, "x-task-id") {
            context = context.with_task_id(TaskId::new(task_id_str));
        }

        if let Some(auth_token) = Self::extract_bearer_token(headers) {
            context = context.with_auth_token(auth_token);
        }

        Ok(context)
    }

    fn extract_bearer_token(headers: &HeaderMap) -> Option<String> {
        headers
            .get("authorization")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.strip_prefix("Bearer "))
            .map(ToString::to_string)
    }

    fn try_from_payload(
        body_bytes: &[u8],
        headers: &HeaderMap,
    ) -> Result<RequestContext, ContextExtractionError> {
        let context_source = PayloadSource::extract_context_source(body_bytes)?;

        let session_id = HeaderSource::extract_required(headers, "x-session-id")?;
        let trace_id = HeaderSource::extract_required(headers, "x-trace-id")?;
        let user_id = HeaderSource::extract_required(headers, "x-user-id")?;
        let agent_name = HeaderSource::extract_required(headers, "x-agent-name")?;

        let (context_id, task_id) = match context_source {
            ContextIdSource::Direct(id) => {
                (id, HeaderSource::extract_optional(headers, "x-task-id"))
            },
            ContextIdSource::FromTask { task_id } => {
                (TASK_BASED_CONTEXT_MARKER.to_string(), Some(task_id))
            },
        };

        let mut context = RequestContext::new(
            SessionId::new(session_id),
            TraceId::new(trace_id),
            ContextId::new(context_id),
            AgentName::new(agent_name),
        )
        .with_user_id(UserId::new(user_id));

        if let Some(task_id_str) = task_id {
            context = context.with_task_id(TaskId::new(task_id_str));
        }

        if let Some(auth_token) = Self::extract_bearer_token(headers) {
            context = context.with_auth_token(auth_token);
        }

        Ok(context)
    }
}

impl Default for A2aContextExtractor {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ContextExtractor for A2aContextExtractor {
    async fn extract_from_headers(
        &self,
        headers: &HeaderMap,
    ) -> Result<RequestContext, ContextExtractionError> {
        Self::try_from_headers(headers)
    }

    async fn extract_from_request(
        &self,
        request: Request<Body>,
    ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
        let headers = request.headers().clone();

        if let Ok(context) = Self::try_from_headers(&headers) {
            return Ok((context, request));
        }

        let (body_bytes, reconstructed_request) =
            PayloadSource::read_and_reconstruct(request).await?;

        let context = Self::try_from_payload(&body_bytes, &headers)?;

        Ok((context, reconstructed_request))
    }
}