use std::convert::Infallible;
use std::sync::Arc;
use axum::Extension;
use axum::extract::{Path, Query, State};
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{Json, Router};
use bytes::Bytes;
use futures::StreamExt;
use serde_json::{Value, json};
use awaken_server_contract::ScopeContext;
use awaken_server_contract::contract::content::{ContentBlock, extract_text};
use awaken_server_contract::contract::event_sink::EventSink;
use awaken_server_contract::contract::message::{Message, MessageRecord, Role, ToolCall};
use awaken_server_contract::contract::storage::{MessageOrder, MessageQuery};
use awaken_server_contract::registry_spec::AgentSpec;
use crate::app::ProtocolRoutesState as S;
use crate::http_run::{format_relay_error, wire_sse_relay};
use crate::http_sse::{sse_body_stream, sse_response};
use crate::routes::{ApiError, map_mailbox_error};
use crate::transport::channel_sink::BoundedChannelEventSink;
use crate::transport::replay_buffer::EventReplayBuffer;
use awaken_runtime::RunActivation;
use awaken_runtime::registry::resolve::RegistrySetResolver;
use awaken_runtime::registry::{AgentSpecRegistry, RegistryHandle, RegistrySet};
use awaken_runtime::{AgentResolver, AgentRuntime};
use super::encoder::AiSdkEncoder;
use super::request::{AiSdkChatRequest, ProcessedRequest};
const AI_SDK_STREAM_HEADER: &str = "x-vercel-ai-ui-message-stream";
const AI_SDK_STREAM_VERSION: &str = "v1";
pub(super) fn ai_sdk_sse_response<S>(stream: S) -> Response
where
S: futures::Stream<Item = Result<Bytes, Infallible>> + Send + 'static,
{
let mut response = sse_response(stream);
response.headers_mut().insert(
axum::http::HeaderName::from_static(AI_SDK_STREAM_HEADER),
axum::http::HeaderValue::from_static(AI_SDK_STREAM_VERSION),
);
response
}
pub fn ai_sdk_routes() -> Router<S> {
Router::new()
.route("/v1/ai-sdk/chat", post(ai_sdk_chat))
.route(
"/v1/ai-sdk/threads/:thread_id/runs",
post(ai_sdk_chat_threaded),
)
.route(
"/v1/ai-sdk/agents/:agent_id/runs",
post(ai_sdk_chat_agent_scoped),
)
.route(
"/v1/ai-sdk/agent-previews/runs",
post(ai_sdk_chat_preview_agent),
)
.route("/v1/ai-sdk/chat/:thread_id/stream", get(resume_stream))
.route("/v1/ai-sdk/threads/:thread_id/stream", get(resume_stream))
.route(
"/v1/ai-sdk/threads/:thread_id/messages",
get(thread_messages),
)
.route("/v1/ai-sdk/threads/:thread_id/cancel", post(cancel_thread))
.route(
"/v1/ai-sdk/threads/:thread_id/interrupt",
post(interrupt_thread),
)
.merge(super::replay::ai_sdk_replay_routes())
}
pub(crate) fn ai_sdk_admin_routes() -> Router<S> {
Router::new().route("/v1/admin/assistant/runs", post(ai_sdk_admin_assistant))
}
async fn ai_sdk_chat(
State(st): State<S>,
Extension(scope): Extension<ScopeContext>,
Json(payload): Json<AiSdkChatRequest>,
) -> Result<Response, ApiError> {
ai_sdk_chat_inner(st.scoped(&scope), payload).await
}
async fn ai_sdk_chat_threaded(
State(st): State<S>,
Extension(scope): Extension<ScopeContext>,
Path(thread_id): Path<String>,
Json(mut payload): Json<AiSdkChatRequest>,
) -> Result<Response, ApiError> {
payload.thread_id = Some(thread_id);
ai_sdk_chat_inner(st.scoped(&scope), payload).await
}
async fn ai_sdk_chat_agent_scoped(
State(st): State<S>,
Extension(scope): Extension<ScopeContext>,
Path(agent_id): Path<String>,
Json(mut payload): Json<AiSdkChatRequest>,
) -> Result<Response, ApiError> {
payload.agent_id = Some(agent_id);
ai_sdk_chat_inner(st.scoped(&scope), payload).await
}
async fn ai_sdk_admin_assistant(
State(st): State<S>,
Extension(scope): Extension<ScopeContext>,
Json(payload): Json<AiSdkChatRequest>,
) -> Result<Response, ApiError> {
let st = st.scoped(&scope);
let config_state = st.config.as_ref().cloned().ok_or_else(|| {
ApiError::ServiceUnavailable("admin assistant requires config routes".into())
})?;
let config_state = crate::app::ConfigRoutesState {
admin: st.admin.clone(),
config: config_state,
run: st.run.clone(),
scope_provider: st.scope_provider.clone(),
};
let current =
st.run.runtime.registry_set().ok_or_else(|| {
ApiError::Internal("runtime does not expose a registry snapshot".into())
})?;
let assistant_config = crate::admin_assistant::load_config(&config_state.config)
.await
.map_err(|error| ApiError::Internal(error.to_string()))?;
let model_id = crate::admin_assistant::resolve_admin_assistant_model_id(
current.models.as_ref(),
current.providers.as_ref(),
assistant_config.model_id.as_deref(),
)
.ok_or_else(|| {
ApiError::Conflict(
"configure and publish the first model before using the admin assistant".into(),
)
})?;
let agent = crate::admin_assistant::admin_assistant_agent(
model_id,
Some(assistant_config.policy_prompt),
);
let processed = super::request::process_preview_chat_request(
payload.messages,
payload.thread_id,
payload.state,
)
.map_err(ApiError::BadRequest)?;
let candidate = build_admin_assistant_registry_set(¤t, &agent, config_state)?;
let preview_runtime = Arc::new(
AgentRuntime::new_with_execution_resolver(Arc::new(RegistrySetResolver::new(
candidate.clone(),
)))
.with_registry_handle(RegistryHandle::new(candidate)),
);
let mut request = RunActivation::new(
processed.thread_id,
crate::request::inject_frontend_context(processed.messages, processed.state),
)
.with_agent_id(agent.id)
.with_adapter(awaken_server_contract::contract::tool_intercept::AdapterKind::AiSdk);
if !processed.decisions.is_empty() {
request = request.with_decisions(processed.decisions);
}
let sse_rx = spawn_ephemeral_runtime_stream(
preview_runtime,
request,
st.sse_buffer_size,
"admin assistant run failed",
);
Ok(ai_sdk_sse_response(sse_body_stream(sse_rx)))
}
async fn ai_sdk_chat_preview_agent(
State(st): State<S>,
Extension(scope): Extension<ScopeContext>,
headers: HeaderMap,
Json(payload): Json<super::request::PreviewAgentChatRequest>,
) -> Result<Response, ApiError> {
let st = st.scoped(&scope);
if let Err(err) = crate::config_routes::ensure_admin_auth(&st.admin, &headers) {
return Ok(err.into_response());
}
let super::request::PreviewAgentChatRequest {
messages,
thread_id,
state,
mut agent,
} = payload;
if agent.model_id.trim().is_empty() {
return Err(ApiError::BadRequest(
"preview agent model_id cannot be empty".to_string(),
));
}
if agent.id.trim().is_empty() {
agent.id = "draft-preview".to_string();
}
if agent.uses_remote_backend() || agent.registry.is_some() {
return Err(ApiError::BadRequest(
"preview agent payload must not carry remote backend, `endpoint`, or `registry` fields"
.to_string(),
));
}
let durable_resolution_id = match st.config.as_ref() {
Some(config) if config.runtime_manager.has_versioned_registry_store() => Some(
config
.runtime_manager
.publish_ephemeral_with_extra_agent(&agent)
.await
.map_err(|error| {
ApiError::Internal(format!("failed to publish draft preview registry: {error}"))
})?
.to_string(),
),
_ => None,
};
if let Some(resolution_id) = durable_resolution_id {
let chat_request = AiSdkChatRequest {
messages,
thread_id,
agent_id: Some(agent.id),
state,
};
return ai_sdk_chat_inner_with_resolution_id_hint(st, chat_request, Some(resolution_id))
.await;
}
let processed = super::request::process_preview_chat_request(messages, thread_id, state)
.map_err(ApiError::BadRequest)?;
let candidate = build_preview_registry_set(&st, &agent)?;
let preview_runtime = Arc::new(
AgentRuntime::new_with_execution_resolver(Arc::new(RegistrySetResolver::new(
candidate.clone(),
)))
.with_registry_handle(RegistryHandle::new(candidate)),
);
let mut request = RunActivation::new(
processed.thread_id,
crate::request::inject_frontend_context(processed.messages, processed.state),
)
.with_agent_id(agent.id)
.with_adapter(awaken_server_contract::contract::tool_intercept::AdapterKind::AiSdk);
if !processed.decisions.is_empty() {
request = request.with_decisions(processed.decisions);
}
let sse_rx = spawn_ephemeral_runtime_stream(
preview_runtime,
request,
st.sse_buffer_size,
"agent preview run failed",
);
Ok(ai_sdk_sse_response(sse_body_stream(sse_rx)))
}
fn spawn_ephemeral_runtime_stream(
runtime: Arc<AgentRuntime>,
request: RunActivation,
sse_buffer_size: usize,
failure_message: &'static str,
) -> tokio::sync::mpsc::Receiver<Bytes> {
let (event_tx, event_rx) = tokio::sync::mpsc::channel(sse_buffer_size.max(32));
let sink: Arc<dyn EventSink> = Arc::new(BoundedChannelEventSink::new(event_tx));
let run_task = tokio::spawn(async move { runtime.run(request, sink).await });
let encoder = AiSdkEncoder::new();
let mut relay_rx = wire_sse_relay(event_rx, encoder, sse_buffer_size, None);
let (final_tx, final_rx) = tokio::sync::mpsc::channel::<Bytes>(sse_buffer_size.max(1));
tokio::spawn(async move {
let mut run_task = std::pin::pin!(run_task);
let mut relay_done = false;
let mut run_done = false;
loop {
tokio::select! {
frame = relay_rx.recv(), if !relay_done => {
match frame {
Some(frame) => {
if final_tx.send(frame).await.is_err() {
break;
}
}
None => relay_done = true,
}
}
result = &mut run_task, if !run_done => {
run_done = true;
match result {
Ok(Ok(_)) => {}
Ok(Err(error)) => {
tracing::error!(error = %error, message = failure_message, "ephemeral runtime run failed");
let _ = final_tx
.send(format_relay_error(&format!("{failure_message}: {error}")))
.await;
}
Err(error) => {
tracing::error!(error = %error, message = failure_message, "ephemeral runtime task failed");
let _ = final_tx
.send(format_relay_error(&format!("{failure_message}: {error}")))
.await;
}
}
}
}
if relay_done && run_done {
break;
}
}
});
final_rx
}
#[derive(Clone)]
struct PreviewAgentRegistry {
preview: AgentSpec,
fallback: Arc<dyn AgentSpecRegistry>,
}
impl PreviewAgentRegistry {
fn new(preview: AgentSpec, fallback: Arc<dyn AgentSpecRegistry>) -> Self {
Self { preview, fallback }
}
}
impl AgentSpecRegistry for PreviewAgentRegistry {
fn get_agent(&self, id: &str) -> Option<AgentSpec> {
if id == self.preview.id {
Some(self.preview.clone())
} else {
self.fallback.get_agent(id)
}
}
fn agent_ids(&self) -> Vec<String> {
let mut ids = self.fallback.agent_ids();
if !ids.iter().any(|id| id == &self.preview.id) {
ids.push(self.preview.id.clone());
}
ids
}
}
fn build_preview_registry_set(st: &S, agent: &AgentSpec) -> Result<RegistrySet, ApiError> {
let current =
st.run.runtime.registry_set().ok_or_else(|| {
ApiError::Internal("runtime does not expose a registry snapshot".into())
})?;
let candidate = RegistrySet {
agents: Arc::new(PreviewAgentRegistry::new(
agent.clone(),
current.agents.clone(),
)),
tools: current.tools.clone(),
models: current.models.clone(),
providers: current.providers.clone(),
plugins: current.plugins.clone(),
backends: current.backends.clone(),
};
if !agent.uses_remote_backend() {
RegistrySetResolver::new(candidate.clone())
.resolve(&agent.id)
.map_err(|error| {
ApiError::BadRequest(format!("invalid preview agent '{}': {error}", agent.id))
})?;
}
Ok(candidate)
}
fn build_admin_assistant_registry_set(
current: &RegistrySet,
agent: &AgentSpec,
config_state: crate::app::ConfigRoutesState,
) -> Result<RegistrySet, ApiError> {
let candidate = RegistrySet {
agents: Arc::new(PreviewAgentRegistry::new(
agent.clone(),
current.agents.clone(),
)),
tools: crate::admin_assistant::admin_tool_registry(config_state),
models: current.models.clone(),
providers: current.providers.clone(),
plugins: current.plugins.clone(),
backends: current.backends.clone(),
};
RegistrySetResolver::new(candidate.clone())
.resolve(&agent.id)
.map_err(|error| {
ApiError::Internal(format!(
"invalid admin assistant agent '{}': {error}",
agent.id
))
})?;
Ok(candidate)
}
async fn ai_sdk_chat_inner(st: S, payload: AiSdkChatRequest) -> Result<Response, ApiError> {
ai_sdk_chat_inner_with_resolution_id_hint(st, payload, None).await
}
async fn ai_sdk_chat_inner_with_resolution_id_hint(
st: S,
payload: AiSdkChatRequest,
resolution_id_hint: Option<String>,
) -> Result<Response, ApiError> {
let processed = super::request::process_chat_request(st.run.store().as_ref(), payload)
.await
.map_err(ApiError::BadRequest)?;
let resume_only = processed.is_resume_only();
let ProcessedRequest {
thread_id,
messages,
decisions,
has_interaction_responses: _,
state,
agent_id,
} = processed;
if !decisions.is_empty() {
let (new_event_tx, new_event_rx) = tokio::sync::mpsc::channel(256);
let mailbox = st.run.mailbox();
let reconnected = mailbox
.reconnect_sink(&st.run.scoped_id(&thread_id), new_event_tx)
.await;
if reconnected {
let mut any_delivered = false;
for (tool_call_id, resume) in &decisions {
if mailbox.send_decision(
&st.run.scoped_id(&thread_id),
tool_call_id.clone(),
resume.clone(),
) {
any_delivered = true;
}
}
if any_delivered {
let replay_buffer = Arc::new(EventReplayBuffer::new(st.replay_buffer_capacity));
let buffer_key = st.run.scoped_id(&thread_id);
st.insert_replay_buffer(buffer_key.clone(), Arc::clone(&replay_buffer));
let encoder = AiSdkEncoder::new();
let sse_rx = wire_sse_relay(
new_event_rx,
encoder,
st.sse_buffer_size,
Some(Arc::clone(&replay_buffer)),
);
let st_cleanup = st.clone();
let replay_buf = Arc::clone(&replay_buffer);
let tid = buffer_key;
let mut rx = sse_rx;
let (final_tx, final_rx) = tokio::sync::mpsc::channel::<Bytes>(st.sse_buffer_size);
tokio::spawn(async move {
while let Some(frame) = rx.recv().await {
if final_tx.send(frame).await.is_err() {
break;
}
}
replay_buf.close_subscribers();
st_cleanup.remove_replay_buffer(&tid);
});
return Ok(ai_sdk_sse_response(sse_body_stream(final_rx)));
}
}
}
if resume_only {
let (_, rx) = tokio::sync::mpsc::channel(1);
let encoder = AiSdkEncoder::new();
let sse_rx = crate::http_run::wire_sse_relay(rx, encoder, st.sse_buffer_size, None);
return Ok(ai_sdk_sse_response(sse_body_stream(sse_rx)));
}
let messages = crate::request::inject_frontend_context(messages, state);
let mut request = RunActivation::new(thread_id.clone(), messages)
.with_adapter(awaken_server_contract::contract::tool_intercept::AdapterKind::AiSdk);
if let Some(id) = agent_id {
request = request.with_agent_id(id);
}
if !decisions.is_empty() {
request = request.with_decisions(decisions);
}
if let Some(resolution_id_hint) = resolution_id_hint {
request = request.with_resolution_id_hint(resolution_id_hint);
}
let (_result, event_rx) = st
.run
.mailbox()
.submit(st.run.scope_activation(request))
.await
.map_err(map_mailbox_error)?;
let replay_buffer = Arc::new(EventReplayBuffer::new(st.replay_buffer_capacity));
let buffer_key = st.run.scoped_id(&thread_id);
st.insert_replay_buffer(buffer_key.clone(), Arc::clone(&replay_buffer));
let encoder = AiSdkEncoder::new();
let sse_rx = wire_sse_relay(event_rx, encoder, st.sse_buffer_size, Some(replay_buffer));
let st_cleanup = st.clone();
let replay_buf_for_cleanup = st
.get_replay_buffer(&thread_id)
.or_else(|| st.get_replay_buffer(&buffer_key))
.ok_or_else(|| ApiError::Internal("replay buffer disappeared after insert".into()))?;
let cleanup_thread_id = buffer_key;
let mut sse_rx_forwarded = sse_rx;
let (final_tx, final_rx) = tokio::sync::mpsc::channel::<Bytes>(st.sse_buffer_size);
tokio::spawn(async move {
let mut client_tx = Some(final_tx);
let mut waiting_for_client_finish = false;
while let Some(frame) = sse_rx_forwarded.recv().await {
if is_waiting_state_snapshot_frame(&frame) {
waiting_for_client_finish = true;
}
let should_close_client = is_suspended_finish_frame(&frame);
let is_finish_step = is_finish_step_frame(&frame);
if let Some(tx) = client_tx.as_ref() {
if tx.send(frame).await.is_err() {
client_tx = None;
} else if should_close_client {
client_tx = None;
} else if waiting_for_client_finish && is_finish_step {
let (_, finish_frame) = replay_buf_for_cleanup
.push_json(r#"{"type":"finish","finishReason":"tool-calls"}"#);
let _ = tx.send(finish_frame).await;
client_tx = None;
waiting_for_client_finish = false;
}
}
}
replay_buf_for_cleanup.close_subscribers();
st_cleanup.remove_replay_buffer(&cleanup_thread_id);
});
Ok(ai_sdk_sse_response(sse_body_stream(final_rx)))
}
async fn resume_stream(
State(st): State<S>,
Extension(scope): Extension<ScopeContext>,
Path(thread_id): Path<String>,
headers: HeaderMap,
) -> Response {
let st = st.scoped(&scope);
if super::replay::resume_header_has_protocol_cursor(&headers) {
match super::replay::resume_from_replay_log(&st, &thread_id, &headers).await {
Ok(Some(response)) => return response,
Ok(None) => return axum::http::StatusCode::NO_CONTENT.into_response(),
Err(error) => return error.into_response(),
}
}
let buffer = st.get_replay_buffer(&st.run.scoped_id(&thread_id));
let Some(buffer) = buffer else {
match super::replay::resume_from_replay_log(&st, &thread_id, &headers).await {
Ok(Some(response)) => return response,
Ok(None) => {}
Err(error) => return error.into_response(),
}
return axum::http::StatusCode::NO_CONTENT.into_response();
};
let (replayed, live_rx) = buffer.subscribe_after(0);
let replay_stream = futures::stream::iter(replayed.into_iter().map(Ok::<Bytes, Infallible>));
let live_stream =
tokio_stream::wrappers::UnboundedReceiverStream::new(live_rx).map(Ok::<Bytes, Infallible>);
let combined = replay_stream.chain(live_stream);
ai_sdk_sse_response(combined)
}
async fn thread_messages(
State(st): State<S>,
Extension(scope): Extension<ScopeContext>,
Path(id): Path<String>,
Query(params): Query<crate::query::MessageQueryParams>,
) -> Result<Json<Value>, ApiError> {
let st = st.scoped(&scope);
let storage_query = params.storage_query().map_err(ApiError::BadRequest)?;
let records = st
.run
.store()
.load_message_records(&id)
.await
.map_err(|e| ApiError::Internal(e.to_string()))?
.unwrap_or_default();
let page =
encode_history_page(records, &storage_query, ¶ms).map_err(ApiError::BadRequest)?;
Ok(Json(serde_json::json!({
"messages": page.items,
"total": page.total,
"has_more": page.has_more,
"next_cursor": page.next_cursor,
})))
}
fn encode_history_page(
records: Vec<MessageRecord>,
storage_query: &MessageQuery,
params: &crate::query::MessageQueryParams,
) -> Result<crate::query::CursorPage<Value>, String> {
let mut records: Vec<MessageRecord> = records
.into_iter()
.filter(|record| storage_query.matches_record(record))
.collect();
records.sort_by_key(|record| record.seq);
let messages = records.into_iter().map(|record| record.message).collect();
let mut encoded_messages = encode_history_messages(messages);
if matches!(storage_query.order, MessageOrder::Desc) {
encoded_messages.reverse();
}
params.paginate(encoded_messages)
}
fn encode_history_messages(messages: Vec<Message>) -> Vec<Value> {
let mut encoded: Vec<Value> = Vec::new();
let mut pending_tool_parts: std::collections::HashMap<String, (usize, usize)> =
std::collections::HashMap::new();
for message in messages {
match message.role {
Role::User | Role::System => {
let parts = content_blocks_to_ui_parts(&message.content);
if parts.is_empty() {
continue;
}
encoded.push(json!({
"id": message.id,
"role": match message.role {
Role::User => "user",
Role::System => "system",
_ => unreachable!(),
},
"parts": parts,
}));
}
Role::Assistant => {
let mut parts = content_blocks_to_ui_parts(&message.content);
let message_index = encoded.len();
if let Some(tool_calls) = &message.tool_calls {
for call in tool_calls {
let part_index = parts.len();
parts.push(tool_call_part(call));
pending_tool_parts.insert(call.id.clone(), (message_index, part_index));
}
}
if parts.is_empty() {
continue;
}
encoded.push(json!({
"id": message.id,
"role": "assistant",
"parts": parts,
}));
}
Role::Tool => {
let Some(call_id) = message.tool_call_id.as_ref() else {
encoded.push(json!({
"id": message.id,
"role": "tool",
"parts": content_blocks_to_ui_parts(&message.content),
}));
continue;
};
let Some((message_index, part_index)) = pending_tool_parts.remove(call_id) else {
encoded.push(json!({
"id": message.id,
"role": "tool",
"parts": content_blocks_to_ui_parts(&message.content),
}));
continue;
};
let Some(message_object) = encoded
.get_mut(message_index)
.and_then(Value::as_object_mut)
else {
continue;
};
let Some(parts) = message_object
.get_mut("parts")
.and_then(Value::as_array_mut)
else {
continue;
};
let Some(part) = parts.get_mut(part_index).and_then(Value::as_object_mut) else {
continue;
};
let output_text = parse_tool_message_output(&message);
let is_suspended = output_text
.as_str()
.is_some_and(|s| s.contains("suspended"));
if is_suspended {
} else {
part.insert(
"state".to_string(),
Value::String("output-available".into()),
);
part.insert("output".to_string(), output_text);
part.insert("providerExecuted".to_string(), Value::Bool(true));
}
}
}
}
encoded
}
fn content_blocks_to_ui_parts(content: &[ContentBlock]) -> Vec<Value> {
content
.iter()
.filter_map(|block| match block {
ContentBlock::Text { text } => Some(json!({"type": "text", "text": text})),
_ => None,
})
.collect()
}
fn tool_call_part(call: &ToolCall) -> Value {
json!({
"type": format!("tool-{}", call.name),
"toolName": call.name,
"toolCallId": call.id,
"state": "input-available",
"input": call.arguments,
"providerExecuted": true,
})
}
fn parse_tool_message_output(message: &Message) -> Value {
let text = extract_text(&message.content);
serde_json::from_str(&text).unwrap_or(Value::String(text))
}
async fn cancel_thread(
State(st): State<S>,
Extension(scope): Extension<ScopeContext>,
Path(thread_id): Path<String>,
) -> Result<Response, ApiError> {
let st = st.scoped(&scope);
let cancelled = st
.run
.mailbox()
.cancel(&st.run.scoped_id(&thread_id))
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;
if cancelled {
return Ok((
StatusCode::ACCEPTED,
Json(json!({
"status": "cancel_requested",
"thread_id": thread_id,
})),
)
.into_response());
}
Err(ApiError::ThreadNotFound(thread_id))
}
async fn interrupt_thread(
State(st): State<S>,
Extension(scope): Extension<ScopeContext>,
Path(thread_id): Path<String>,
) -> Result<Response, ApiError> {
let st = st.scoped(&scope);
let interrupted = st
.run
.mailbox()
.interrupt(&st.run.scoped_id(&thread_id))
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;
if interrupted.active_dispatch.is_some() || interrupted.superseded_count > 0 {
return Ok((
StatusCode::ACCEPTED,
Json(json!({
"status": "interrupt_requested",
"thread_id": thread_id,
"superseded_dispatches": interrupted.superseded_count,
})),
)
.into_response());
}
Err(ApiError::ThreadNotFound(thread_id))
}
fn is_suspended_finish_frame(frame: &Bytes) -> bool {
let Ok(text) = std::str::from_utf8(frame) else {
return false;
};
let Some(data_line) = text.lines().find_map(|line| line.strip_prefix("data: ")) else {
return false;
};
let Ok(value) = serde_json::from_str::<Value>(data_line) else {
return false;
};
value.get("type").and_then(Value::as_str) == Some("finish")
&& value.get("finishReason").and_then(Value::as_str) == Some("tool-calls")
}
fn parse_frame_json(frame: &Bytes) -> Option<Value> {
let text = std::str::from_utf8(frame).ok()?;
let data_line = text.lines().find_map(|line| line.strip_prefix("data: "))?;
serde_json::from_str::<Value>(data_line).ok()
}
fn is_waiting_state_snapshot_frame(frame: &Bytes) -> bool {
let Some(value) = parse_frame_json(frame) else {
return false;
};
value.get("type").and_then(Value::as_str) == Some("data-state-snapshot")
&& value
.get("data")
.and_then(|data| data.get("extensions"))
.and_then(|ext| ext.get("__runtime.run_lifecycle"))
.and_then(|lifecycle| lifecycle.get("status"))
.and_then(Value::as_str)
== Some("waiting")
}
fn is_finish_step_frame(frame: &Bytes) -> bool {
parse_frame_json(frame)
.and_then(|value| value.get("type").and_then(Value::as_str).map(str::to_owned))
.as_deref()
== Some("finish-step")
}
#[cfg(test)]
mod tests {
use super::*;
use awaken_runtime::RuntimeError;
use awaken_runtime::registry::ResolvedAgent;
use futures::stream;
use serde_json::json;
struct FailingResolver;
impl AgentResolver for FailingResolver {
fn resolve(&self, _agent_id: &str) -> Result<ResolvedAgent, RuntimeError> {
Err(RuntimeError::ResolveFailed {
message: "forced test failure".into(),
})
}
}
#[test]
fn ai_sdk_sse_response_sets_transport_header() {
let response = ai_sdk_sse_response(stream::empty());
assert_eq!(
response
.headers()
.get(AI_SDK_STREAM_HEADER)
.and_then(|value| value.to_str().ok()),
Some(AI_SDK_STREAM_VERSION)
);
}
#[tokio::test]
async fn ephemeral_runtime_stream_emits_error_frame_when_run_fails() {
let runtime = Arc::new(AgentRuntime::new(Arc::new(FailingResolver)));
let request = RunActivation::new("thread", vec![Message::user("hello")])
.with_agent_id("missing-agent");
let mut rx =
spawn_ephemeral_runtime_stream(runtime, request, 8, "test ephemeral run failed");
let mut frames = Vec::new();
while let Some(frame) = rx.recv().await {
frames.push(String::from_utf8(frame.to_vec()).expect("utf8 sse frame"));
if frames
.iter()
.any(|frame| frame.contains(r#""type":"error""#))
{
break;
}
}
let body = frames.join("");
assert!(
body.contains(r#""type":"error""#),
"run failure should be visible to the SSE client: {body}"
);
assert!(
body.contains("test ephemeral run failed"),
"error frame should name the failed runtime path: {body}"
);
}
#[test]
fn detects_suspended_finish_frame() {
let frame =
Bytes::from("id: 7\ndata: {\"type\":\"finish\",\"finishReason\":\"tool-calls\"}\n\n");
assert!(is_suspended_finish_frame(&frame));
let natural =
Bytes::from("id: 8\ndata: {\"type\":\"finish\",\"finishReason\":\"stop\"}\n\n");
assert!(!is_suspended_finish_frame(&natural));
}
#[test]
fn encodes_tool_history_as_assistant_tool_parts() {
let messages = vec![
Message::user("show me a dashboard").with_id("u1".into()),
Message::assistant_with_tool_calls(
"Generating the dashboard now.",
vec![ToolCall::new(
"call-1",
"render_json_ui",
json!({"prompt": "Quarterly dashboard"}),
)],
)
.with_id("a1".into()),
Message::tool("call-1", r#"{"content":{"root":"page"},"steps":1}"#)
.with_id("t1".into()),
Message::assistant("Done.").with_id("a2".into()),
];
let encoded = encode_history_messages(messages);
assert_eq!(encoded.len(), 3);
let assistant_parts = encoded[1]["parts"].as_array().expect("assistant parts");
assert_eq!(assistant_parts[0]["type"].as_str(), Some("text"));
assert_eq!(
assistant_parts[1]["type"].as_str(),
Some("tool-render_json_ui")
);
assert_eq!(assistant_parts[1]["toolCallId"].as_str(), Some("call-1"));
assert_eq!(
assistant_parts[1]["state"].as_str(),
Some("output-available")
);
assert_eq!(assistant_parts[1]["providerExecuted"].as_bool(), Some(true));
assert_eq!(
assistant_parts[1]["output"]["content"]["root"].as_str(),
Some("page")
);
}
#[test]
fn encoded_history_total_matches_encoded_messages() {
let messages = vec![
Message::user("show me a dashboard").with_id("u1".into()),
Message::assistant_with_tool_calls(
"Generating the dashboard now.",
vec![ToolCall::new(
"call-1",
"render_json_ui",
json!({"prompt": "Quarterly dashboard"}),
)],
)
.with_id("a1".into()),
Message::tool("call-1", r#"{"content":{"root":"page"},"steps":1}"#)
.with_id("t1".into()),
];
let encoded = encode_history_messages(messages);
assert_eq!(encoded.len(), 2);
assert_eq!(
encoded.len(),
2,
"history pagination must use encoded message count"
);
}
#[test]
fn encode_history_page_desc_preserves_tool_call_output_merge() {
let params: crate::query::MessageQueryParams =
serde_json::from_str(r#"{"order":"desc"}"#).unwrap();
let storage_query = params.storage_query().unwrap();
let records = vec![
MessageRecord::from_message("t", 1, Message::user("show me a dashboard")),
MessageRecord::from_message(
"t",
2,
Message::assistant_with_tool_calls(
"Generating the dashboard now.",
vec![ToolCall::new(
"call-1",
"render_json_ui",
json!({"prompt": "Quarterly dashboard"}),
)],
),
),
MessageRecord::from_message(
"t",
3,
Message::tool("call-1", r#"{"content":{"root":"page"},"steps":1}"#),
),
MessageRecord::from_message("t", 4, Message::assistant("Done.")),
];
let page = encode_history_page(records, &storage_query, ¶ms).unwrap();
assert_eq!(page.total, 3);
assert_eq!(page.items.len(), 3);
assert_eq!(page.items[0]["role"].as_str(), Some("assistant"));
assert_eq!(page.items[0]["parts"][0]["text"].as_str(), Some("Done."));
assert_eq!(page.items[1]["role"].as_str(), Some("assistant"));
assert_eq!(
page.items[1]["parts"][1]["type"].as_str(),
Some("tool-render_json_ui")
);
assert_eq!(
page.items[1]["parts"][1]["state"].as_str(),
Some("output-available")
);
assert_eq!(
page.items[1]["parts"][1]["output"]["content"]["root"].as_str(),
Some("page")
);
assert!(
page.items
.iter()
.all(|message| message["role"].as_str() != Some("tool"))
);
}
#[test]
fn content_blocks_to_ui_parts_text() {
let blocks = vec![ContentBlock::text("hello")];
let parts = content_blocks_to_ui_parts(&blocks);
assert_eq!(parts.len(), 1);
assert_eq!(parts[0]["type"].as_str(), Some("text"));
assert_eq!(parts[0]["text"].as_str(), Some("hello"));
}
#[test]
fn content_blocks_to_ui_parts_empty() {
let parts = content_blocks_to_ui_parts(&[]);
assert!(parts.is_empty());
}
#[test]
fn content_blocks_to_ui_parts_non_text_skipped() {
let blocks = vec![
ContentBlock::text("keep"),
ContentBlock::image_url("https://example.com/img.png"),
];
let parts = content_blocks_to_ui_parts(&blocks);
assert_eq!(parts.len(), 1);
assert_eq!(parts[0]["text"].as_str(), Some("keep"));
}
#[test]
fn tool_call_part_structure() {
let call = ToolCall::new("c1", "search", json!({"q": "rust"}));
let part = tool_call_part(&call);
assert_eq!(part["type"].as_str(), Some("tool-search"));
assert_eq!(part["toolName"].as_str(), Some("search"));
assert_eq!(part["toolCallId"].as_str(), Some("c1"));
assert_eq!(part["state"].as_str(), Some("input-available"));
assert_eq!(part["providerExecuted"].as_bool(), Some(true));
assert_eq!(part["input"]["q"].as_str(), Some("rust"));
}
#[test]
fn parse_tool_message_output_json() {
let msg = Message::tool("c1", r#"{"key": "value"}"#);
let output = parse_tool_message_output(&msg);
assert_eq!(output["key"].as_str(), Some("value"));
}
#[test]
fn parse_tool_message_output_plain_text() {
let msg = Message::tool("c1", "not json at all");
let output = parse_tool_message_output(&msg);
assert_eq!(output.as_str(), Some("not json at all"));
}
#[test]
fn parse_frame_json_valid() {
let frame = Bytes::from("id: 1\ndata: {\"type\":\"text\"}\n\n");
let val = parse_frame_json(&frame).unwrap();
assert_eq!(val["type"].as_str(), Some("text"));
}
#[test]
fn parse_frame_json_no_data_line() {
let frame = Bytes::from("id: 1\nevent: ping\n\n");
assert!(parse_frame_json(&frame).is_none());
}
#[test]
fn parse_frame_json_invalid_json() {
let frame = Bytes::from("data: {not valid json}\n\n");
assert!(parse_frame_json(&frame).is_none());
}
#[test]
fn is_finish_step_frame_true() {
let frame = Bytes::from("id: 5\ndata: {\"type\":\"finish-step\"}\n\n");
assert!(is_finish_step_frame(&frame));
}
#[test]
fn is_finish_step_frame_false() {
let frame = Bytes::from("id: 5\ndata: {\"type\":\"text\"}\n\n");
assert!(!is_finish_step_frame(&frame));
}
}