awaken-server 0.6.0

Multi-protocol HTTP server with SSE, mailbox, and protocol adapters for Awaken
Documentation
use awaken_protocol_a2a::{
    Artifact, StreamResponse, TaskArtifactUpdateEvent, TaskStatus, TaskStatusUpdateEvent,
};

use super::types::TaskSnapshot;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum InitialStreamEvent {
    TaskSnapshot,
    StatusUpdate,
}

#[derive(Debug)]
pub(super) struct TaskStreamProjector {
    initial: InitialStreamEvent,
    delivered_initial: bool,
    last_status: Option<TaskStatus>,
    last_artifacts: Vec<Artifact>,
}

impl TaskStreamProjector {
    pub(super) fn new(initial: InitialStreamEvent) -> Self {
        Self {
            initial,
            delivered_initial: false,
            last_status: None,
            last_artifacts: Vec::new(),
        }
    }

    pub(super) fn project(&mut self, snapshot: &TaskSnapshot) -> Vec<StreamResponse> {
        let mut responses = Vec::new();
        if !self.delivered_initial {
            responses.push(match self.initial {
                InitialStreamEvent::TaskSnapshot => StreamResponse {
                    task: Some(snapshot.task.clone()),
                    ..Default::default()
                },
                InitialStreamEvent::StatusUpdate => status_update_response(snapshot),
            });
            self.delivered_initial = true;
            self.last_status = Some(snapshot.task.status.clone());
            self.last_artifacts = snapshot.task.artifacts.clone();
            return responses;
        }

        if self.last_status.as_ref() != Some(&snapshot.task.status) {
            responses.push(status_update_response(snapshot));
            self.last_status = Some(snapshot.task.status.clone());
        }

        if snapshot.task.artifacts != self.last_artifacts {
            let total = snapshot.task.artifacts.len();
            responses.extend(snapshot.task.artifacts.iter().cloned().enumerate().map(
                |(index, artifact)| StreamResponse {
                    artifact_update: Some(TaskArtifactUpdateEvent {
                        task_id: snapshot.task.id.clone(),
                        context_id: snapshot.task.context_id.clone(),
                        artifact,
                        append: Some(false),
                        last_chunk: Some(index + 1 == total),
                        metadata: None,
                    }),
                    ..Default::default()
                },
            ));
            self.last_artifacts = snapshot.task.artifacts.clone();
        }

        responses
    }
}

fn status_update_response(snapshot: &TaskSnapshot) -> StreamResponse {
    StreamResponse {
        status_update: Some(TaskStatusUpdateEvent {
            task_id: snapshot.task.id.clone(),
            context_id: snapshot.task.context_id.clone(),
            status: snapshot.task.status.clone(),
            metadata: None,
        }),
        ..Default::default()
    }
}

#[cfg(test)]
mod tests {
    use awaken_protocol_a2a::{Artifact, Part, Task, TaskState, TaskStatus};

    use super::*;

    fn snapshot(state: TaskState, artifacts: Vec<Artifact>) -> TaskSnapshot {
        TaskSnapshot {
            task: Task {
                id: "task_1".into(),
                context_id: "thread_1".into(),
                status: TaskStatus {
                    state,
                    message: None,
                    timestamp: None,
                },
                artifacts,
                history: Vec::new(),
                metadata: None,
            },
            updated_at_ms: 0,
            current_agent_id: None,
        }
    }

    fn artifact(id: &str, text: &str) -> Artifact {
        Artifact {
            artifact_id: id.into(),
            name: None,
            description: None,
            parts: vec![Part::text(text)],
            metadata: None,
        }
    }

    #[test]
    fn emits_initial_task_snapshot_for_stream() {
        let mut projector = TaskStreamProjector::new(InitialStreamEvent::TaskSnapshot);

        let responses = projector.project(&snapshot(TaskState::Submitted, Vec::new()));

        assert_eq!(responses.len(), 1);
        assert_eq!(
            responses[0].task.as_ref().map(|task| task.id.as_str()),
            Some("task_1")
        );
        assert!(responses[0].status_update.is_none());
    }

    #[test]
    fn emits_initial_status_update_for_push() {
        let mut projector = TaskStreamProjector::new(InitialStreamEvent::StatusUpdate);

        let responses = projector.project(&snapshot(TaskState::Submitted, Vec::new()));

        assert_eq!(responses.len(), 1);
        assert_eq!(
            responses[0]
                .status_update
                .as_ref()
                .map(|event| event.status.state),
            Some(TaskState::Submitted)
        );
        assert!(responses[0].task.is_none());
    }

    #[test]
    fn emits_only_changed_status_after_initial_event() {
        let mut projector = TaskStreamProjector::new(InitialStreamEvent::TaskSnapshot);

        projector.project(&snapshot(TaskState::Submitted, Vec::new()));
        assert!(
            projector
                .project(&snapshot(TaskState::Submitted, Vec::new()))
                .is_empty()
        );

        let responses = projector.project(&snapshot(TaskState::Working, Vec::new()));

        assert_eq!(responses.len(), 1);
        assert_eq!(
            responses[0]
                .status_update
                .as_ref()
                .map(|event| event.status.state),
            Some(TaskState::Working)
        );
    }

    #[test]
    fn emits_full_artifact_replacement_when_artifacts_change() {
        let mut projector = TaskStreamProjector::new(InitialStreamEvent::TaskSnapshot);
        projector.project(&snapshot(TaskState::Working, Vec::new()));

        let responses = projector.project(&snapshot(
            TaskState::Working,
            vec![artifact("a1", "first"), artifact("a2", "second")],
        ));

        assert_eq!(responses.len(), 2);
        assert_eq!(
            responses[0]
                .artifact_update
                .as_ref()
                .map(|event| (event.artifact.artifact_id.as_str(), event.last_chunk)),
            Some(("a1", Some(false)))
        );
        assert_eq!(
            responses[1]
                .artifact_update
                .as_ref()
                .map(|event| (event.artifact.artifact_id.as_str(), event.last_chunk)),
            Some(("a2", Some(true)))
        );
    }
}