awaken-server 0.6.0

Multi-protocol HTTP server with SSE, mailbox, and protocol adapters for Awaken
Documentation
use awaken_protocol_a2a::TaskState;
use awaken_server_contract::thread::Thread;
use axum::extract::Query;
use axum::http::{HeaderMap, Uri};
use serde::de::DeserializeOwned;

use crate::app::ProtocolRoutesState;

use super::error::{A2aError, map_a2a_storage_error};
use super::types::A2A_VERSION;

/// Splits `tail` on `/` and keeps only the non-empty pieces, in order.
///
/// Leading, trailing, and repeated slashes therefore never produce empty
/// segments; an empty or all-slash `tail` yields an empty vector.
pub(super) fn parse_a2a_tail(tail: &str) -> Vec<&str> {
    let mut segments = Vec::new();
    for piece in tail.split('/') {
        if !piece.is_empty() {
            segments.push(piece);
        }
    }
    segments
}

pub(super) fn decode_query<T: DeserializeOwned>(uri: &Uri) -> Result<T, A2aError> {
    Query::<T>::try_from_uri(uri)
        .map(|query| query.0)
        .map_err(|err| A2aError::invalid("query", err.to_string()))
}

/// Deserializes a JSON request `body` into `T` after checking that `headers`
/// declare a JSON content type.
///
/// # Errors
///
/// - Whatever `ensure_json_content_type` returns when the `Content-Type`
///   header is missing or unsupported (checked before the body is parsed).
/// - [`A2aError::invalid`] on the `body` field, prefixed with
///   `invalid JSON body:`, when `serde_json` cannot parse `body` as `T`.
pub(super) fn decode_json_body<T: DeserializeOwned>(
    headers: &HeaderMap,
    body: &[u8],
) -> Result<T, A2aError> {
    ensure_json_content_type(headers)?;
    match serde_json::from_slice::<T>(body) {
        Ok(decoded) => Ok(decoded),
        Err(err) => Err(A2aError::invalid(
            "body",
            format!("invalid JSON body: {err}"),
        )),
    }
}

/// Rejects requests whose `Content-Type` is missing or is not an accepted JSON
/// media type.
///
/// Only the media type before the first `;` (i.e. ignoring parameters such as
/// `charset=utf-8`) is compared, trimmed and ASCII-case-insensitively, against
/// `application/json` and `application/a2a+json`.
///
/// # Errors
///
/// Returns [`A2aError::invalid`] on the `contentType` field when the header is
/// absent, blank, unreadable as text (see `forwarded_header`), or names any
/// other media type.
fn ensure_json_content_type(headers: &HeaderMap) -> Result<(), A2aError> {
    // One copy of the message so the "missing" and "wrong type" cases cannot
    // drift apart.
    const CONTENT_TYPE_MESSAGE: &str =
        "Content-Type must be application/json or application/a2a+json";

    let is_json = forwarded_header(headers, "content-type").is_some_and(|content_type| {
        // Strip media-type parameters; with no `;` the whole value is the media type.
        let media_type = content_type
            .split_once(';')
            .map_or(content_type, |(media_type, _params)| media_type)
            .trim();
        media_type.eq_ignore_ascii_case("application/json")
            || media_type.eq_ignore_ascii_case("application/a2a+json")
    });

    if is_json {
        Ok(())
    } else {
        Err(A2aError::invalid("contentType", CONTENT_TYPE_MESSAGE))
    }
}

/// Parses an optional `pageToken` into an unsigned integer offset.
///
/// Surrounding whitespace is ignored. A missing, empty, or whitespace-only
/// token yields offset `0`.
///
/// # Errors
///
/// Returns [`A2aError::invalid`] on the `pageToken` field when the trimmed
/// token is not parseable as a `usize`.
pub(super) fn parse_page_token(page_token: Option<&str>) -> Result<usize, A2aError> {
    let trimmed = page_token.unwrap_or_default().trim();
    if trimmed.is_empty() {
        return Ok(0);
    }
    trimmed.parse::<usize>().map_err(|_| {
        A2aError::invalid("pageToken", "pageToken must be an unsigned integer offset")
    })
}

/// Parses a task-status filter value into a [`TaskState`].
///
/// Matching ignores surrounding whitespace and ASCII case. Every state is
/// accepted in its `task_state_*` form and its bare name. The bare
/// `input_required` / `auth_required` names also accept hyphenated spellings.
/// The bare `canceled` also accepts `cancelled`.
///
/// # Errors
///
/// Returns [`A2aError::invalid`] on the `status` field for any other value.
pub(super) fn parse_task_state_filter(raw: &str) -> Result<TaskState, A2aError> {
    let normalized = raw.trim().to_ascii_lowercase();
    let state = match normalized.as_str() {
        "task_state_submitted" | "submitted" => TaskState::Submitted,
        "task_state_working" | "working" => TaskState::Working,
        "task_state_input_required" | "input_required" | "input-required" => {
            TaskState::InputRequired
        }
        "task_state_auth_required" | "auth_required" | "auth-required" => {
            TaskState::AuthRequired
        }
        "task_state_completed" | "completed" => TaskState::Completed,
        "task_state_failed" | "failed" => TaskState::Failed,
        "task_state_canceled" | "canceled" | "cancelled" => TaskState::Canceled,
        "task_state_rejected" | "rejected" => TaskState::Rejected,
        _ => {
            return Err(A2aError::invalid(
                "status",
                "status must be a valid TaskState value",
            ));
        }
    };
    Ok(state)
}

/// Splits a `<task id>:<action>` path segment at its last `:`.
///
/// Only the `cancel` and `subscribe` actions are recognized. The task id is
/// returned as written; it may itself contain `:`, since the split is on the
/// last one.
///
/// # Errors
///
/// - [`A2aError::NotFound`] when `raw` has no `:` or the action is unsupported.
/// - [`A2aError::invalid`] on the `taskId` field when the part before the last
///   `:` is empty or whitespace-only (this check runs before the action check).
pub(super) fn parse_task_action_segment(raw: &str) -> Result<(String, &str), A2aError> {
    // Only captures `raw` by shared reference, so the closure is `Copy` and can be reused.
    let unsupported = || A2aError::NotFound(format!("unsupported A2A task action path: {raw}"));

    let (task_id, action) = raw.rsplit_once(':').ok_or_else(unsupported)?;

    if task_id.trim().is_empty() {
        return Err(A2aError::invalid(
            "taskId",
            "task action path must include a task id before the action suffix",
        ));
    }

    if matches!(action, "cancel" | "subscribe") {
        Ok((task_id.to_owned(), action))
    } else {
        Err(unsupported())
    }
}

/// Returns the whitespace-trimmed value as an owned `String`.
///
/// Returns `None` when `value` is `None` or contains only whitespace.
pub(super) fn trim_to_option(value: Option<&str>) -> Option<String> {
    let trimmed = value?.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}

/// Reads header `name` from `headers` as trimmed text.
///
/// Returns `None` when any of the following holds:
/// - the header is absent;
/// - its value cannot be converted with `HeaderValue::to_str`;
/// - the value is blank after trimming.
///
/// If the header appears more than once, only the value returned by
/// `HeaderMap::get` is read.
pub(super) fn forwarded_header<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
    let header_value = headers.get(name)?;
    let text = header_value.to_str().ok()?.trim();
    if text.is_empty() { None } else { Some(text) }
}

/// Validates the requested A2A protocol version from the query string and then
/// from the `A2A-Version` header.
///
/// The query parameter name is matched ASCII-case-insensitively against
/// `A2A-Version`. Its value is trimmed and ignored when empty, mirroring how
/// `forwarded_header` treats a blank header. As a result, `?A2A-Version=`
/// behaves like omitting the parameter instead of failing with an empty
/// version. Only the first non-empty matching parameter is checked. The header
/// is always checked afterwards via [`ensure_supported_version`].
///
/// NOTE(review): query values are compared without percent-decoding.
///
/// # Errors
///
/// Returns [`A2aError::version_not_supported`] if either source names a version
/// other than [`A2A_VERSION`].
pub(super) fn ensure_supported_version_from_request(
    headers: &HeaderMap,
    uri: &Uri,
) -> Result<(), A2aError> {
    let query_version = uri
        .query()
        .into_iter()
        .flat_map(|query| query.split('&'))
        .filter_map(|pair| pair.split_once('='))
        .filter(|(key, _)| key.eq_ignore_ascii_case("A2A-Version"))
        .map(|(_, value)| value.trim())
        // Blank values are treated as absent, consistent with the header path.
        .find(|value| !value.is_empty());

    if let Some(version) = query_version
        && version != A2A_VERSION
    {
        return Err(A2aError::version_not_supported(version));
    }
    ensure_supported_version(headers)
}

/// Checks the `a2a-version` request header against [`A2A_VERSION`].
///
/// A missing or blank header (as reported by `forwarded_header`) is accepted.
///
/// # Errors
///
/// Returns [`A2aError::version_not_supported`] when the header names a
/// different version.
pub(super) fn ensure_supported_version(headers: &HeaderMap) -> Result<(), A2aError> {
    match forwarded_header(headers, "a2a-version") {
        Some(requested) if requested != A2A_VERSION => {
            Err(A2aError::version_not_supported(requested))
        }
        _ => Ok(()),
    }
}

/// Chooses which local agent id to expose publicly.
///
/// Prefers `"default"` when it resolves. Otherwise returns the smallest id, in
/// `String` ordering, among the resolver's registered ids that resolves
/// successfully.
///
/// # Errors
///
/// Returns [`A2aError::NotFound`] when no registered id resolves.
pub(super) fn public_agent_id(st: &ProtocolRoutesState) -> Result<String, A2aError> {
    const DEFAULT_AGENT_ID: &str = "default";

    let resolver = &st.run.resolver;
    if resolver.resolve(DEFAULT_AGENT_ID).is_ok() {
        return Ok(DEFAULT_AGENT_ID.to_string());
    }

    let mut candidates = resolver.agent_ids();
    candidates.sort();
    for candidate in candidates {
        if resolver.resolve(&candidate).is_ok() {
            return Ok(candidate);
        }
    }
    Err(A2aError::NotFound(
        "no runnable local agents registered".to_string(),
    ))
}

/// Succeeds when `agent_id` resolves through the run resolver.
///
/// # Errors
///
/// Returns [`A2aError::NotFound`] (`agent not found: <id>`) when resolution
/// fails. The resolver's own error is discarded.
pub(super) fn ensure_runnable_agent(
    st: &ProtocolRoutesState,
    agent_id: &str,
) -> Result<(), A2aError> {
    match st.run.resolver.resolve(agent_id) {
        Ok(_) => Ok(()),
        Err(_) => Err(A2aError::NotFound(format!("agent not found: {agent_id}"))),
    }
}

/// Loads thread `thread_id` from the store and prepares it for a metadata write.
///
/// Returns `(existed, thread)` as produced by
/// `materialize_thread_metadata_projection`, touched with
/// `awaken_server_contract::now_ms()` read after the load completes.
///
/// # Errors
///
/// Returns [`A2aError::Internal`] with the storage error's text when
/// `load_thread` fails.
///
/// NOTE(review): this maps load failures to `Internal` via `to_string()`,
/// whereas `persist_thread_metadata` uses `map_a2a_storage_error`. Confirm the
/// difference is intentional.
pub(super) async fn load_thread_metadata_projection(
    st: &ProtocolRoutesState,
    thread_id: &str,
) -> Result<(bool, Thread), A2aError> {
    let loaded = match st.run.store().load_thread(thread_id).await {
        Ok(maybe_thread) => maybe_thread,
        Err(e) => return Err(A2aError::Internal(e.to_string())),
    };
    let now = awaken_server_contract::now_ms();
    Ok(materialize_thread_metadata_projection(thread_id, loaded, now))
}

/// Pairs a thread with a flag saying whether it already existed.
///
/// Uses `existing` when present. Otherwise it builds a fresh
/// `Thread::with_id(thread_id)`. In both cases it calls `touch(now)` on the
/// thread before returning `(existing.is_some(), thread)`. The caller in this
/// file passes `now_ms()` as `now`.
pub(super) fn materialize_thread_metadata_projection(
    thread_id: &str,
    existing: Option<Thread>,
    now: u64,
) -> (bool, Thread) {
    let (was_present, mut projected) = match existing {
        Some(found) => (true, found),
        None => (false, Thread::with_id(thread_id)),
    };
    projected.touch(now);
    (was_present, projected)
}

/// Writes a thread prepared by `load_thread_metadata_projection` back to the store.
///
/// For a new thread (`exists == false`), the whole `thread` is stored with
/// `save_thread_validated`. Otherwise only `thread.metadata` is sent through
/// `update_thread_metadata`, keyed by `thread_id`.
///
/// NOTE(review): the update path keys on the `thread_id` argument, while the
/// save path uses whatever id `thread` carries. Callers presumably keep these
/// equal; confirm.
///
/// # Errors
///
/// Storage failures are converted with `map_a2a_storage_error`.
pub(super) async fn persist_thread_metadata(
    st: &ProtocolRoutesState,
    thread_id: &str,
    exists: bool,
    thread: Thread,
) -> Result<(), A2aError> {
    if !exists {
        st.run
            .store()
            .save_thread_validated(&thread)
            .await
            .map_err(map_a2a_storage_error)?;
        return Ok(());
    }

    st.run
        .store()
        .update_thread_metadata(thread_id, thread.metadata)
        .await
        .map_err(map_a2a_storage_error)?;
    Ok(())
}