systemprompt_agent/services/
message.rs1use anyhow::{Result, anyhow};
2use serde_json::json;
3use uuid::Uuid;
4
5use crate::models::a2a::{Message, MessageRole, Part, TextPart};
6use crate::repository::context::message::PersistMessageWithTxParams;
7use crate::repository::task::TaskRepository;
8use systemprompt_database::{DatabaseProvider, DatabaseTransaction, DbPool};
9use systemprompt_identifiers::{ContextId, TaskId};
10use systemprompt_models::RequestContext;
11
12pub struct PersistMessageInTxParams<'a> {
13 pub tx: &'a mut dyn DatabaseTransaction,
14 pub message: &'a Message,
15 pub task_id: &'a TaskId,
16 pub context_id: &'a ContextId,
17 pub user_id: Option<&'a systemprompt_identifiers::UserId>,
18 pub session_id: &'a systemprompt_identifiers::SessionId,
19 pub trace_id: &'a systemprompt_identifiers::TraceId,
20}
21
22impl std::fmt::Debug for PersistMessageInTxParams<'_> {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("PersistMessageInTxParams")
25 .field("message", &self.message)
26 .field("task_id", &self.task_id)
27 .field("context_id", &self.context_id)
28 .field("user_id", &self.user_id)
29 .field("session_id", &self.session_id)
30 .field("trace_id", &self.trace_id)
31 .finish_non_exhaustive()
32 }
33}
34
35#[derive(Debug)]
36pub struct PersistMessagesParams<'a> {
37 pub task_id: &'a TaskId,
38 pub context_id: &'a ContextId,
39 pub messages: Vec<Message>,
40 pub user_id: Option<&'a systemprompt_identifiers::UserId>,
41 pub session_id: &'a systemprompt_identifiers::SessionId,
42 pub trace_id: &'a systemprompt_identifiers::TraceId,
43}
44
45#[derive(Debug)]
46pub struct CreateToolExecutionMessageParams<'a> {
47 pub task_id: &'a TaskId,
48 pub context_id: &'a ContextId,
49 pub tool_name: &'a str,
50 pub tool_args: &'a serde_json::Value,
51 pub request_context: &'a RequestContext,
52}
53
54pub struct MessageService {
55 task_repo: TaskRepository,
56}
57
58impl std::fmt::Debug for MessageService {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("MessageService").finish_non_exhaustive()
61 }
62}
63
64impl MessageService {
65 pub fn new(db_pool: &DbPool) -> Result<Self> {
66 Ok(Self {
67 task_repo: TaskRepository::new(db_pool)?,
68 })
69 }
70
71 pub async fn persist_message_in_tx(&self, params: PersistMessageInTxParams<'_>) -> Result<i32> {
72 let PersistMessageInTxParams {
73 tx,
74 message,
75 task_id,
76 context_id,
77 user_id,
78 session_id,
79 trace_id,
80 } = params;
81 let sequence_number = self
82 .task_repo
83 .get_next_sequence_number_in_tx(tx, task_id)
84 .await?;
85
86 self.task_repo
87 .persist_message_with_tx(PersistMessageWithTxParams {
88 tx,
89 message,
90 task_id,
91 context_id,
92 sequence_number,
93 user_id,
94 session_id,
95 trace_id,
96 })
97 .await
98 .map_err(|e| anyhow!("Failed to persist message: {}", e))?;
99
100 tracing::info!(
101 message_id = %message.message_id,
102 task_id = %task_id,
103 sequence_number = sequence_number,
104 "Message persisted"
105 );
106
107 Ok(sequence_number)
108 }
109
110 pub async fn persist_messages(&self, params: PersistMessagesParams<'_>) -> Result<Vec<i32>> {
111 let PersistMessagesParams {
112 task_id,
113 context_id,
114 messages,
115 user_id,
116 session_id,
117 trace_id,
118 } = params;
119
120 if messages.is_empty() {
121 return Ok(Vec::new());
122 }
123
124 let mut tx = self
125 .task_repo
126 .db_pool()
127 .as_ref()
128 .begin_transaction()
129 .await?;
130 let mut sequence_numbers = Vec::new();
131
132 tracing::info!(
133 task_id = %task_id,
134 message_count = messages.len(),
135 "Persisting multiple messages"
136 );
137
138 for message in messages {
139 let seq = self
140 .persist_message_in_tx(PersistMessageInTxParams {
141 tx: &mut *tx,
142 message: &message,
143 task_id,
144 context_id,
145 user_id,
146 session_id,
147 trace_id,
148 })
149 .await?;
150 sequence_numbers.push(seq);
151 }
152
153 tx.commit().await?;
154
155 tracing::info!(
156 task_id = %task_id,
157 sequence_numbers = ?sequence_numbers,
158 "Messages persisted successfully"
159 );
160
161 Ok(sequence_numbers)
162 }
163
164 pub async fn create_tool_execution_message(
165 &self,
166 params: CreateToolExecutionMessageParams<'_>,
167 ) -> Result<(String, i32)> {
168 let CreateToolExecutionMessageParams {
169 task_id,
170 context_id,
171 tool_name,
172 tool_args,
173 request_context,
174 } = params;
175 let message_id = Uuid::new_v4().to_string();
176
177 let tool_args_display =
178 serde_json::to_string_pretty(tool_args).unwrap_or_else(|_| tool_args.to_string());
179
180 let timestamp = chrono::Utc::now().to_rfc3339();
181
182 let message = Message {
183 role: MessageRole::User,
184 message_id: message_id.clone().into(),
185 task_id: Some(task_id.clone()),
186 context_id: context_id.clone(),
187 parts: vec![Part::Text(TextPart {
188 text: format!(
189 "Executed MCP tool: {} with arguments:\n{}\n\nExecution ID: {} at {}",
190 tool_name,
191 tool_args_display,
192 task_id.as_str(),
193 timestamp
194 ),
195 })],
196 metadata: Some(json!({
197 "source": "mcp_direct_call",
198 "tool_name": tool_name,
199 "is_synthetic": true,
200 "tool_args": tool_args,
201 "execution_timestamp": timestamp,
202 })),
203 extensions: None,
204 reference_task_ids: None,
205 };
206
207 let mut tx = self
208 .task_repo
209 .db_pool()
210 .as_ref()
211 .begin_transaction()
212 .await?;
213
214 let sequence_number = self
215 .persist_message_in_tx(PersistMessageInTxParams {
216 tx: &mut *tx,
217 message: &message,
218 task_id,
219 context_id,
220 user_id: Some(request_context.user_id()),
221 session_id: request_context.session_id(),
222 trace_id: request_context.trace_id(),
223 })
224 .await?;
225
226 tx.commit().await?;
227
228 tracing::info!(
229 message_id = %message_id,
230 task_id = %task_id,
231 tool_name = %tool_name,
232 sequence_number = sequence_number,
233 "Created synthetic tool execution message"
234 );
235
236 Ok((message_id, sequence_number))
237 }
238}