use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use tirea_contract::io::decision_translation::suspension_response_to_decision;
use tirea_contract::io::ResumeDecisionAction;
use tirea_contract::runtime::suspended_calls_from_state;
use tirea_contract::{gen_message_id, RunOrigin, RunRequest, Visibility};
use tirea_contract::{SuspensionResponse, ToolCallDecision};
use tracing::warn;
/// Author role of an AG-UI [`Message`], serialized in lowercase
/// (`"developer"`, `"system"`, `"assistant"`, `"user"`, `"tool"`, `"activity"`, `"reasoning"`).
///
/// The default role is [`Role::Assistant`].
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Converted to the core `System` role by `core_message_from_ag_ui`.
    Developer,
    /// Converted to the core `System` role.
    System,
    /// Default role. Dropped by `convert_agui_messages`.
    #[default]
    Assistant,
    /// Converted to the core `User` role.
    User,
    /// Tool result message; `toolCallId` identifies the call it answers.
    Tool,
    /// Dropped by `convert_agui_messages`; otherwise converted to core `Assistant`.
    Activity,
    /// Dropped by `convert_agui_messages`; otherwise converted to core `Assistant`.
    Reasoning,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Message {
pub role: Role,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(rename = "toolCallId", skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
impl Message {
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: content.into(),
id: None,
tool_call_id: None,
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: content.into(),
id: None,
tool_call_id: None,
}
}
pub fn system(content: impl Into<String>) -> Self {
Self {
role: Role::System,
content: content.into(),
id: None,
tool_call_id: None,
}
}
pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
Self {
role: Role::Tool,
content: content.into(),
id: None,
tool_call_id: Some(tool_call_id.into()),
}
}
pub fn activity(content: impl Into<String>) -> Self {
Self {
role: Role::Activity,
content: content.into(),
id: None,
tool_call_id: None,
}
}
pub fn reasoning(content: impl Into<String>) -> Self {
Self {
role: Role::Reasoning,
content: content.into(),
id: None,
tool_call_id: None,
}
}
}
/// A frontend-supplied context entry, rendered by `build_context_addendum`
/// as `[description]: value`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Context {
    /// Label shown in brackets when the entry is rendered.
    pub description: String,
    /// Arbitrary JSON value; strings are rendered verbatim, other values as JSON text.
    pub value: Value,
}
/// Where a [`Tool`] executes; serialized in lowercase (`"backend"` / `"frontend"`).
///
/// Defaults to [`ToolExecutionLocation::Frontend`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolExecutionLocation {
    /// Executed on the backend.
    Backend,
    /// Executed on the frontend (the default).
    #[default]
    Frontend,
}
/// A tool definition carried in [`RunAgentInput::tools`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Tool {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Optional parameter description as arbitrary JSON (presumably a JSON
    /// Schema — confirm with callers); omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Value>,
    /// Execution location.
    ///
    /// A missing field deserializes to `Frontend`, and `Frontend` is omitted
    /// when serializing.
    #[serde(default, skip_serializing_if = "is_default_frontend")]
    pub execute: ToolExecutionLocation,
}
/// `skip_serializing_if` predicate for `Tool::execute`: true for the default
/// `Frontend` location, so it is left out of serialized output.
fn is_default_frontend(loc: &ToolExecutionLocation) -> bool {
    matches!(loc, ToolExecutionLocation::Frontend)
}
impl Tool {
pub fn backend(name: impl Into<String>, description: impl Into<String>) -> Self {
Self {
name: name.into(),
description: description.into(),
parameters: None,
execute: ToolExecutionLocation::Backend,
}
}
pub fn frontend(name: impl Into<String>, description: impl Into<String>) -> Self {
Self {
name: name.into(),
description: description.into(),
parameters: None,
execute: ToolExecutionLocation::Frontend,
}
}
pub fn with_parameters(mut self, parameters: Value) -> Self {
self.parameters = Some(parameters);
self
}
pub fn is_frontend(&self) -> bool {
self.execute == ToolExecutionLocation::Frontend
}
}
/// The body of an AG-UI "run agent" request.
///
/// Wire names are camelCase. `parentThreadId` and `forwardedProps` also accept
/// snake_case aliases on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunAgentInput {
    /// Thread identifier (`threadId`); must be non-empty per `validate`.
    #[serde(rename = "threadId")]
    pub thread_id: String,
    /// Run identifier (`runId`); must be non-empty per `validate`.
    #[serde(rename = "runId")]
    pub run_id: String,
    /// Conversation messages, in order.
    pub messages: Vec<Message>,
    /// Tools advertised by the client; empty when absent.
    #[serde(default)]
    pub tools: Vec<Tool>,
    /// Context entries; empty when absent.
    #[serde(default)]
    pub context: Vec<Context>,
    /// Optional state snapshot; suspended calls are read from it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub state: Option<Value>,
    /// Optional parent run id (`parentRunId`).
    #[serde(rename = "parentRunId", skip_serializing_if = "Option::is_none")]
    pub parent_run_id: Option<String>,
    /// Optional parent thread id (`parentThreadId`, alias `parent_thread_id`).
    #[serde(
        rename = "parentThreadId",
        alias = "parent_thread_id",
        skip_serializing_if = "Option::is_none"
    )]
    pub parent_thread_id: Option<String>,
    /// Optional model name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Optional system prompt (`systemPrompt`).
    #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Optional free-form configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub config: Option<Value>,
    /// Optional forwarded properties (`forwardedProps`, alias `forwarded_props`).
    #[serde(
        rename = "forwardedProps",
        alias = "forwarded_props",
        skip_serializing_if = "Option::is_none"
    )]
    pub forwarded_props: Option<Value>,
}
impl RunAgentInput {
pub fn new(thread_id: impl Into<String>, run_id: impl Into<String>) -> Self {
Self {
thread_id: thread_id.into(),
run_id: run_id.into(),
messages: Vec::new(),
tools: Vec::new(),
context: Vec::new(),
state: None,
parent_run_id: None,
parent_thread_id: None,
model: None,
system_prompt: None,
config: None,
forwarded_props: None,
}
}
pub fn with_message(mut self, message: Message) -> Self {
self.messages.push(message);
self
}
pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
self.messages.extend(messages);
self
}
pub fn with_state(mut self, state: Value) -> Self {
self.state = Some(state);
self
}
pub fn with_parent_thread_id(mut self, parent_thread_id: impl Into<String>) -> Self {
self.parent_thread_id = Some(parent_thread_id.into());
self
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = Some(model.into());
self
}
pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = Some(prompt.into());
self
}
pub fn with_forwarded_props(mut self, forwarded_props: Value) -> Self {
self.forwarded_props = Some(forwarded_props);
self
}
pub fn validate(&self) -> Result<(), RequestError> {
if self.thread_id.is_empty() {
return Err(RequestError::invalid_field("threadId cannot be empty"));
}
if self.run_id.is_empty() {
return Err(RequestError::invalid_field("runId cannot be empty"));
}
Ok(())
}
pub fn frontend_tools(&self) -> Vec<&Tool> {
self.tools.iter().filter(|t| t.is_frontend()).collect()
}
pub fn has_any_interaction_responses(&self) -> bool {
!self.interaction_responses().is_empty()
}
pub fn has_any_suspension_decisions(&self) -> bool {
!self.suspension_decisions().is_empty()
}
pub fn has_user_input(&self) -> bool {
self.messages
.iter()
.any(|message| message.role == Role::User && !message.content.trim().is_empty())
}
pub fn into_runtime_run_request(self, agent_id: String) -> RunRequest {
let initial_decisions = self.suspension_decisions();
RunRequest {
agent_id,
thread_id: Some(self.thread_id),
run_id: Some(self.run_id),
parent_run_id: self.parent_run_id,
parent_thread_id: self.parent_thread_id,
resource_id: None,
origin: RunOrigin::AgUi,
state: self.state,
messages: convert_agui_messages(&self.messages),
initial_decisions,
source_mailbox_entry_id: None,
}
}
pub fn interaction_responses(&self) -> Vec<SuspensionResponse> {
let expected_ids = self.suspended_call_response_ids();
let mut latest_by_id: HashMap<String, (usize, Value)> = HashMap::new();
self.messages
.iter()
.enumerate()
.filter(|(_, m)| m.role == Role::Tool)
.filter_map(|(idx, m)| {
m.tool_call_id.as_ref().and_then(|id| {
if !expected_ids.is_empty() && !expected_ids.contains(id) {
return None;
}
let result = parse_interaction_result_value(&m.content);
Some((idx, id.clone(), result))
})
})
.for_each(|(idx, id, result)| {
latest_by_id.insert(id, (idx, result));
});
let mut responses: Vec<(usize, SuspensionResponse)> = latest_by_id
.into_iter()
.map(|(id, (idx, result))| (idx, SuspensionResponse::new(id, result)))
.collect();
responses.sort_by_key(|(idx, _)| *idx);
responses
.into_iter()
.map(|(_, response)| response)
.collect()
}
pub fn suspension_decisions(&self) -> Vec<ToolCallDecision> {
self.interaction_responses()
.into_iter()
.map(suspension_response_to_decision)
.collect()
}
pub fn approved_target_ids(&self) -> Vec<String> {
self.suspension_decisions()
.into_iter()
.filter(|d| matches!(d.resume.action, ResumeDecisionAction::Resume))
.map(|d| d.target_id)
.collect()
}
pub fn denied_target_ids(&self) -> Vec<String> {
self.suspension_decisions()
.into_iter()
.filter(|d| matches!(d.resume.action, ResumeDecisionAction::Cancel))
.map(|d| d.target_id)
.collect()
}
fn suspended_call_response_ids(&self) -> HashSet<String> {
let mut ids = HashSet::new();
let Some(state) = self.state.as_ref() else {
return ids;
};
let calls = suspended_calls_from_state(state);
for call in calls.values() {
ids.insert(call.ticket.pending.id.clone());
ids.insert(call.call_id.clone());
ids.insert(call.ticket.suspension.id.clone());
}
ids
}
}
/// Parses tool-message content as JSON, or keeps it verbatim as a JSON string
/// when it is not valid JSON.
fn parse_interaction_result_value(content: &str) -> Value {
    match serde_json::from_str::<Value>(content) {
        Ok(parsed) => parsed,
        Err(_) => Value::String(content.to_owned()),
    }
}
/// Converts one AG-UI [`Message`] into a core `tirea_contract::Message`.
///
/// Role mapping:
/// - `Developer` and `System` become `System`.
/// - `Assistant`, `Activity` and `Reasoning` become `Assistant`.
/// - `User` and `Tool` keep their role.
///
/// A missing id is replaced with a fresh `gen_message_id()`. Tool calls and
/// metadata are left unset, and visibility uses its default.
pub fn core_message_from_ag_ui(msg: &Message) -> tirea_contract::Message {
    let role = match msg.role {
        Role::System | Role::Developer => tirea_contract::Role::System,
        Role::User => tirea_contract::Role::User,
        Role::Tool => tirea_contract::Role::Tool,
        Role::Assistant | Role::Activity | Role::Reasoning => tirea_contract::Role::Assistant,
    };
    let id = msg.id.clone().unwrap_or_else(gen_message_id);
    tirea_contract::Message {
        id: Some(id),
        role,
        content: msg.content.clone(),
        tool_calls: None,
        tool_call_id: msg.tool_call_id.clone(),
        visibility: Visibility::default(),
        metadata: None,
    }
}
pub fn convert_agui_messages(messages: &[Message]) -> Vec<tirea_contract::Message> {
messages
.iter()
.filter(|m| {
m.role != Role::Assistant && m.role != Role::Activity && m.role != Role::Reasoning
})
.map(core_message_from_ag_ui)
.collect()
}
/// Error describing why an AG-UI request was rejected.
///
/// Displayed as `[CODE] message`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestError {
    /// Machine-readable code: `INVALID_FIELD`, `VALIDATION_ERROR` or `INTERNAL_ERROR`
    /// when built with the constructors below.
    pub code: String,
    /// Human-readable description.
    pub message: String,
}

impl RequestError {
    /// Shared constructor pairing a fixed `code` with a caller-supplied message.
    fn with_code(code: &str, message: impl Into<String>) -> Self {
        Self {
            code: code.to_owned(),
            message: message.into(),
        }
    }

    /// Error with code `INVALID_FIELD`.
    pub fn invalid_field(message: impl Into<String>) -> Self {
        Self::with_code("INVALID_FIELD", message)
    }

    /// Error with code `VALIDATION_ERROR`.
    pub fn validation(message: impl Into<String>) -> Self {
        Self::with_code("VALIDATION_ERROR", message)
    }

    /// Error with code `INTERNAL_ERROR`.
    pub fn internal(message: impl Into<String>) -> Self {
        Self::with_code("INTERNAL_ERROR", message)
    }
}

impl std::fmt::Display for RequestError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { code, message } = self;
        write!(f, "[{code}] {message}")
    }
}

/// No wrapped cause; `source()` uses the trait's default.
impl std::error::Error for RequestError {}

/// A bare `String` becomes a `VALIDATION_ERROR`.
impl From<String> for RequestError {
    fn from(message: String) -> Self {
        RequestError::validation(message)
    }
}
/// Renders the request's frontend context entries as a text addendum.
///
/// Returns `None` when `request.context` is empty. Each entry becomes a line
/// `[description]: value`. String values are inserted verbatim and other values
/// as JSON text. A value that fails to serialize is logged and replaced with
/// `<unserializable-context-value>`.
pub fn build_context_addendum(request: &RunAgentInput) -> Option<String> {
    if request.context.is_empty() {
        return None;
    }
    let lines: Vec<String> = request
        .context
        .iter()
        .map(|entry| {
            let rendered = if let Value::String(text) = &entry.value {
                text.clone()
            } else {
                serde_json::to_string(&entry.value).unwrap_or_else(|err| {
                    warn!(
                        error = %err,
                        description = %entry.description,
                        "failed to stringify AG-UI context value"
                    );
                    "<unserializable-context-value>".to_string()
                })
            };
            format!("[{}]: {}", entry.description, rendered)
        })
        .collect();
    Some(format!(
        "\n\nThe following context is available from the frontend:\n{}",
        lines.join("\n")
    ))
}