mod helpers;
use std::sync::Arc;
use uuid::Uuid;
use self::helpers::{
BroadcastAguiLifecycleParams, broadcast_agui_lifecycle, collect_stream_response,
};
use crate::models::a2a::{Message, MessageRole, Part, Task, TaskState, TaskStatus, TextPart};
use crate::services::a2a_server::processing::message::persistence::{
broadcast_completion, persist_completed_task,
};
use crate::services::a2a_server::processing::message::stream_processor::StreamProcessor;
use crate::services::a2a_server::processing::message::{
MessageProcessor, ProcessMessageStreamParams,
};
use crate::services::a2a_server::processing::task_builder::build_completed_task;
use crate::services::a2a_server::streaming::broadcast::{
BroadcastTaskCreatedParams, broadcast_task_created,
};
use crate::services::shared::{AgentServiceError, Result};
use systemprompt_identifiers::{MessageId, SessionId, TaskId, TraceId, UserId};
use systemprompt_models::{RequestContext, TaskMetadata};
impl MessageProcessor {
pub(in crate::services::a2a_server) async fn handle_message(
&self,
message: Message,
agent_name: &str,
context: &RequestContext,
) -> Result<Task> {
tracing::info!(agent_name = %agent_name, "Handling non-streaming message");
let agent_runtime = self.load_agent_runtime(agent_name).await?;
let context_id = &message.context_id;
self.context_repo
.get_context(context_id, context.user_id())
.await
.map_err(|e| {
AgentServiceError::Internal(format!(
"Context validation failed - context_id: {}, user_id: {}, error: {}",
context_id,
context.user_id(),
e
))
})?;
tracing::info!(
context_id = %context_id,
user_id = %context.user_id(),
"Context validated"
);
let task_id = message.task_id.clone().map_or_else(
|| {
let new_task_id = TaskId::new(Uuid::new_v4().to_string());
tracing::info!(task_id = %new_task_id, "Starting NEW task with generated ID");
new_task_id
},
|existing_task_id| {
tracing::info!(task_id = %existing_task_id, "Continuing existing task");
existing_task_id
},
);
let metadata = TaskMetadata::new_agent_message(agent_name.to_owned());
let task = Task {
id: task_id.clone(),
context_id: context_id.clone(),
status: TaskStatus {
state: TaskState::Submitted,
message: None,
timestamp: Some(chrono::Utc::now()),
},
history: None,
artifacts: None,
metadata: Some(metadata),
created_at: Some(chrono::Utc::now()),
last_modified: Some(chrono::Utc::now()),
};
if let Err(e) = self
.task_repo
.create_task(crate::repository::task::RepoCreateTaskParams {
task: &task,
user_id: &UserId::new(context.user_id().as_str()),
session_id: &SessionId::new(context.session_id().as_str()),
trace_id: &TraceId::new(context.trace_id().as_str()),
agent_name,
})
.await
{
return Err(AgentServiceError::Internal(format!(
"Failed to persist task at start: {e}"
)));
}
tracing::info!(task_id = %task_id, "Task persisted to database");
broadcast_task_created(BroadcastTaskCreatedParams {
task_id: &task_id,
context_id,
user_id: context.user_id().as_str(),
user_message: &message,
agent_name,
token: context.auth_token().as_str(),
})
.await;
let working_timestamp = chrono::Utc::now();
if let Err(e) = self
.task_repo
.update_task_state(&task_id, TaskState::Working, &working_timestamp)
.await
{
tracing::error!(task_id = %task_id, error = %e, "Failed to mark task as working");
}
let stream_processor = StreamProcessor {
ai_service: Arc::clone(&self.ai_service),
context_service: self.context_service.clone(),
skill_service: Arc::clone(&self.skill_service),
execution_step_repo: Arc::clone(&self.execution_step_repo),
};
let chunk_rx = stream_processor
.process_message_stream(ProcessMessageStreamParams {
a2a_message: &message,
agent_runtime: &agent_runtime,
agent_name,
context,
task_id: task_id.clone(),
})
.await?;
let (response_text, tool_artifacts) = collect_stream_response(chunk_rx, context).await?;
let task = build_completed_task(
task_id,
context_id.clone(),
response_text.clone(),
message.clone(),
tool_artifacts,
);
let agent_message = task.status.message.clone().unwrap_or_else(|| {
let client_message_id = message
.metadata
.as_ref()
.and_then(|m| m.get("clientMessageId"))
.cloned();
let metadata = client_message_id.map(|id| serde_json::json!({"clientMessageId": id}));
Message {
role: MessageRole::Agent,
parts: vec![Part::Text(TextPart {
text: response_text.clone(),
})],
message_id: MessageId::generate(),
task_id: Some(task.id.clone()),
context_id: task.context_id.clone(),
metadata,
extensions: None,
reference_task_ids: None,
}
});
if context.user_type() == systemprompt_models::auth::UserType::Anon {
tracing::warn!(
context_id = %context_id,
session_id = %context.session_id(),
"Saving messages for anonymous user"
);
}
self.persist_or_mark_failed(&task, &message, &agent_message, context)
.await?;
broadcast_completion(&task, context).await;
broadcast_agui_lifecycle(BroadcastAguiLifecycleParams {
context,
context_id,
task: &task,
agent_message: &agent_message,
response_text: &response_text,
})
.await;
Ok(task)
}
async fn persist_or_mark_failed(
&self,
task: &Task,
user_message: &Message,
agent_message: &Message,
context: &RequestContext,
) -> Result<()> {
let Err(e) = persist_completed_task(
crate::services::a2a_server::processing::message::persistence::PersistCompletedTaskParams {
task,
user_message,
agent_message,
context,
task_repo: &self.task_repo,
db_pool: &self.db_pool,
artifacts_already_published: false,
},
)
.await
else {
return Ok(());
};
let error_msg = format!("Failed to persist completed task: {}", e);
tracing::error!(task_id = %task.id, error = %e, "Failed to persist completed task");
let failed_timestamp = chrono::Utc::now();
if let Err(update_err) = self
.task_repo
.update_task_failed_with_error(&task.id, &error_msg, &failed_timestamp)
.await
{
tracing::error!(task_id = %task.id, error = %update_err, "Failed to update task to failed state");
}
Err(e)
}
}