Skip to main content

systemprompt_agent/services/
message.rs

1use anyhow::{anyhow, Result};
2use serde_json::json;
3use uuid::Uuid;
4
5use crate::models::a2a::{Message, Part, TextPart};
6use crate::repository::task::TaskRepository;
7use systemprompt_database::{DatabaseProvider, DatabaseTransaction, DbPool};
8use systemprompt_identifiers::{ContextId, TaskId};
9use systemprompt_models::RequestContext;
10
11pub struct MessageService {
12    task_repo: TaskRepository,
13}
14
15impl std::fmt::Debug for MessageService {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        f.debug_struct("MessageService").finish_non_exhaustive()
18    }
19}
20
21impl MessageService {
22    pub fn new(db_pool: &DbPool) -> Result<Self> {
23        Ok(Self {
24            task_repo: TaskRepository::new(db_pool)?,
25        })
26    }
27
28    pub async fn persist_message_in_tx(
29        &self,
30        tx: &mut dyn DatabaseTransaction,
31        message: &Message,
32        task_id: &TaskId,
33        context_id: &ContextId,
34        user_id: Option<&systemprompt_identifiers::UserId>,
35        session_id: &systemprompt_identifiers::SessionId,
36        trace_id: &systemprompt_identifiers::TraceId,
37    ) -> Result<i32> {
38        let sequence_number = self
39            .task_repo
40            .get_next_sequence_number_in_tx(tx, task_id)
41            .await?;
42
43        self.task_repo
44            .persist_message_with_tx(
45                tx,
46                message,
47                task_id,
48                context_id,
49                sequence_number,
50                user_id,
51                session_id,
52                trace_id,
53            )
54            .await
55            .map_err(|e| anyhow!("Failed to persist message: {}", e))?;
56
57        tracing::info!(
58            message_id = %message.id,
59            task_id = %task_id,
60            sequence_number = sequence_number,
61            "Message persisted"
62        );
63
64        Ok(sequence_number)
65    }
66
67    pub async fn persist_messages(
68        &self,
69        task_id: &TaskId,
70        context_id: &ContextId,
71        messages: Vec<Message>,
72        user_id: Option<&systemprompt_identifiers::UserId>,
73        session_id: &systemprompt_identifiers::SessionId,
74        trace_id: &systemprompt_identifiers::TraceId,
75    ) -> Result<Vec<i32>> {
76        if messages.is_empty() {
77            return Ok(Vec::new());
78        }
79
80        let mut tx = self
81            .task_repo
82            .db_pool()
83            .as_ref()
84            .begin_transaction()
85            .await?;
86        let mut sequence_numbers = Vec::new();
87
88        tracing::info!(
89            task_id = %task_id,
90            message_count = messages.len(),
91            "Persisting multiple messages"
92        );
93
94        for message in messages {
95            let seq = self
96                .persist_message_in_tx(
97                    &mut *tx, &message, task_id, context_id, user_id, session_id, trace_id,
98                )
99                .await?;
100            sequence_numbers.push(seq);
101        }
102
103        tx.commit().await?;
104
105        tracing::info!(
106            task_id = %task_id,
107            sequence_numbers = ?sequence_numbers,
108            "Messages persisted successfully"
109        );
110
111        Ok(sequence_numbers)
112    }
113
114    pub async fn create_tool_execution_message(
115        &self,
116        task_id: &TaskId,
117        context_id: &ContextId,
118        tool_name: &str,
119        tool_args: &serde_json::Value,
120        request_context: &RequestContext,
121    ) -> Result<(String, i32)> {
122        let message_id = Uuid::new_v4().to_string();
123
124        let tool_args_display =
125            serde_json::to_string_pretty(tool_args).unwrap_or_else(|_| tool_args.to_string());
126
127        let timestamp = chrono::Utc::now().to_rfc3339();
128
129        let message = Message {
130            role: "user".to_string(),
131            id: message_id.clone().into(),
132            task_id: Some(task_id.clone()),
133            context_id: context_id.clone(),
134            kind: "message".to_string(),
135            parts: vec![Part::Text(TextPart {
136                text: format!(
137                    "Executed MCP tool: {} with arguments:\n{}\n\nExecution ID: {} at {}",
138                    tool_name,
139                    tool_args_display,
140                    task_id.as_str(),
141                    timestamp
142                ),
143            })],
144            metadata: Some(json!({
145                "source": "mcp_direct_call",
146                "tool_name": tool_name,
147                "is_synthetic": true,
148                "tool_args": tool_args,
149                "execution_timestamp": timestamp,
150            })),
151            extensions: None,
152            reference_task_ids: None,
153        };
154
155        let mut tx = self
156            .task_repo
157            .db_pool()
158            .as_ref()
159            .begin_transaction()
160            .await?;
161
162        let sequence_number = self
163            .persist_message_in_tx(
164                &mut *tx,
165                &message,
166                task_id,
167                context_id,
168                Some(request_context.user_id()),
169                request_context.session_id(),
170                request_context.trace_id(),
171            )
172            .await?;
173
174        tx.commit().await?;
175
176        tracing::info!(
177            message_id = %message_id,
178            task_id = %task_id,
179            tool_name = %tool_name,
180            sequence_number = sequence_number,
181            "Created synthetic tool execution message"
182        );
183
184        Ok((message_id, sequence_number))
185    }
186}