systemprompt_agent/services/
message.rs1use anyhow::{anyhow, Result};
2use serde_json::json;
3use uuid::Uuid;
4
5use crate::models::a2a::{Message, Part, TextPart};
6use crate::repository::task::TaskRepository;
7use systemprompt_database::{DatabaseProvider, DatabaseTransaction, DbPool};
8use systemprompt_identifiers::{ContextId, TaskId};
9use systemprompt_models::RequestContext;
10use systemprompt_traits::Repository;
11
12pub struct MessageService {
13 task_repo: TaskRepository,
14}
15
16impl std::fmt::Debug for MessageService {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 f.debug_struct("MessageService").finish_non_exhaustive()
19 }
20}
21
22impl MessageService {
23 pub fn new(db_pool: DbPool) -> Self {
24 Self {
25 task_repo: TaskRepository::new(db_pool),
26 }
27 }
28
29 pub async fn persist_message_in_tx(
30 &self,
31 tx: &mut dyn DatabaseTransaction,
32 message: &Message,
33 task_id: &TaskId,
34 context_id: &ContextId,
35 user_id: Option<&systemprompt_identifiers::UserId>,
36 session_id: &systemprompt_identifiers::SessionId,
37 trace_id: &systemprompt_identifiers::TraceId,
38 ) -> Result<i32> {
39 let sequence_number = self
40 .task_repo
41 .get_next_sequence_number_in_tx(tx, task_id)
42 .await?;
43
44 self.task_repo
45 .persist_message_with_tx(
46 tx,
47 message,
48 task_id,
49 context_id,
50 sequence_number,
51 user_id,
52 session_id,
53 trace_id,
54 )
55 .await
56 .map_err(|e| anyhow!("Failed to persist message: {}", e))?;
57
58 tracing::info!(
59 message_id = %message.id,
60 task_id = %task_id,
61 sequence_number = sequence_number,
62 "Message persisted"
63 );
64
65 Ok(sequence_number)
66 }
67
68 pub async fn persist_messages(
69 &self,
70 task_id: &TaskId,
71 context_id: &ContextId,
72 messages: Vec<Message>,
73 user_id: Option<&systemprompt_identifiers::UserId>,
74 session_id: &systemprompt_identifiers::SessionId,
75 trace_id: &systemprompt_identifiers::TraceId,
76 ) -> Result<Vec<i32>> {
77 if messages.is_empty() {
78 return Ok(Vec::new());
79 }
80
81 let mut tx = self.task_repo.pool().as_ref().begin_transaction().await?;
82 let mut sequence_numbers = Vec::new();
83
84 tracing::info!(
85 task_id = %task_id,
86 message_count = messages.len(),
87 "Persisting multiple messages"
88 );
89
90 for message in messages {
91 let seq = self
92 .persist_message_in_tx(
93 &mut *tx, &message, task_id, context_id, user_id, session_id, trace_id,
94 )
95 .await?;
96 sequence_numbers.push(seq);
97 }
98
99 tx.commit().await?;
100
101 tracing::info!(
102 task_id = %task_id,
103 sequence_numbers = ?sequence_numbers,
104 "Messages persisted successfully"
105 );
106
107 Ok(sequence_numbers)
108 }
109
110 pub async fn create_tool_execution_message(
111 &self,
112 task_id: &TaskId,
113 context_id: &ContextId,
114 tool_name: &str,
115 tool_args: &serde_json::Value,
116 request_context: &RequestContext,
117 ) -> Result<(String, i32)> {
118 let message_id = Uuid::new_v4().to_string();
119
120 let tool_args_display =
121 serde_json::to_string_pretty(tool_args).unwrap_or_else(|_| tool_args.to_string());
122
123 let timestamp = chrono::Utc::now().to_rfc3339();
124
125 let message = Message {
126 role: "user".to_string(),
127 id: message_id.clone().into(),
128 task_id: Some(task_id.clone()),
129 context_id: context_id.clone(),
130 kind: "message".to_string(),
131 parts: vec![Part::Text(TextPart {
132 text: format!(
133 "Executed MCP tool: {} with arguments:\n{}\n\nExecution ID: {} at {}",
134 tool_name,
135 tool_args_display,
136 task_id.as_str(),
137 timestamp
138 ),
139 })],
140 metadata: Some(json!({
141 "source": "mcp_direct_call",
142 "tool_name": tool_name,
143 "is_synthetic": true,
144 "tool_args": tool_args,
145 "execution_timestamp": timestamp,
146 })),
147 extensions: None,
148 reference_task_ids: None,
149 };
150
151 let mut tx = self.task_repo.pool().as_ref().begin_transaction().await?;
152
153 let sequence_number = self
154 .persist_message_in_tx(
155 &mut *tx,
156 &message,
157 task_id,
158 context_id,
159 Some(request_context.user_id()),
160 request_context.session_id(),
161 request_context.trace_id(),
162 )
163 .await?;
164
165 tx.commit().await?;
166
167 tracing::info!(
168 message_id = %message_id,
169 task_id = %task_id,
170 tool_name = %tool_name,
171 sequence_number = sequence_number,
172 "Created synthetic tool execution message"
173 );
174
175 Ok((message_id, sequence_number))
176 }
177}