// systemprompt_agent/services/message.rs

1use anyhow::{anyhow, Result};
2use serde_json::json;
3use uuid::Uuid;
4
5use crate::models::a2a::{Message, Part, TextPart};
6use crate::repository::task::TaskRepository;
7use systemprompt_database::{DatabaseProvider, DatabaseTransaction, DbPool};
8use systemprompt_identifiers::{ContextId, TaskId};
9use systemprompt_models::RequestContext;
10use systemprompt_traits::Repository;
11
12pub struct MessageService {
13    task_repo: TaskRepository,
14}
15
16impl std::fmt::Debug for MessageService {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        f.debug_struct("MessageService").finish_non_exhaustive()
19    }
20}
21
22impl MessageService {
23    pub fn new(db_pool: DbPool) -> Self {
24        Self {
25            task_repo: TaskRepository::new(db_pool),
26        }
27    }
28
29    pub async fn persist_message_in_tx(
30        &self,
31        tx: &mut dyn DatabaseTransaction,
32        message: &Message,
33        task_id: &TaskId,
34        context_id: &ContextId,
35        user_id: Option<&systemprompt_identifiers::UserId>,
36        session_id: &systemprompt_identifiers::SessionId,
37        trace_id: &systemprompt_identifiers::TraceId,
38    ) -> Result<i32> {
39        let sequence_number = self
40            .task_repo
41            .get_next_sequence_number_in_tx(tx, task_id)
42            .await?;
43
44        self.task_repo
45            .persist_message_with_tx(
46                tx,
47                message,
48                task_id,
49                context_id,
50                sequence_number,
51                user_id,
52                session_id,
53                trace_id,
54            )
55            .await
56            .map_err(|e| anyhow!("Failed to persist message: {}", e))?;
57
58        tracing::info!(
59            message_id = %message.id,
60            task_id = %task_id,
61            sequence_number = sequence_number,
62            "Message persisted"
63        );
64
65        Ok(sequence_number)
66    }
67
68    pub async fn persist_messages(
69        &self,
70        task_id: &TaskId,
71        context_id: &ContextId,
72        messages: Vec<Message>,
73        user_id: Option<&systemprompt_identifiers::UserId>,
74        session_id: &systemprompt_identifiers::SessionId,
75        trace_id: &systemprompt_identifiers::TraceId,
76    ) -> Result<Vec<i32>> {
77        if messages.is_empty() {
78            return Ok(Vec::new());
79        }
80
81        let mut tx = self.task_repo.pool().as_ref().begin_transaction().await?;
82        let mut sequence_numbers = Vec::new();
83
84        tracing::info!(
85            task_id = %task_id,
86            message_count = messages.len(),
87            "Persisting multiple messages"
88        );
89
90        for message in messages {
91            let seq = self
92                .persist_message_in_tx(
93                    &mut *tx, &message, task_id, context_id, user_id, session_id, trace_id,
94                )
95                .await?;
96            sequence_numbers.push(seq);
97        }
98
99        tx.commit().await?;
100
101        tracing::info!(
102            task_id = %task_id,
103            sequence_numbers = ?sequence_numbers,
104            "Messages persisted successfully"
105        );
106
107        Ok(sequence_numbers)
108    }
109
110    pub async fn create_tool_execution_message(
111        &self,
112        task_id: &TaskId,
113        context_id: &ContextId,
114        tool_name: &str,
115        tool_args: &serde_json::Value,
116        request_context: &RequestContext,
117    ) -> Result<(String, i32)> {
118        let message_id = Uuid::new_v4().to_string();
119
120        let tool_args_display =
121            serde_json::to_string_pretty(tool_args).unwrap_or_else(|_| tool_args.to_string());
122
123        let timestamp = chrono::Utc::now().to_rfc3339();
124
125        let message = Message {
126            role: "user".to_string(),
127            id: message_id.clone().into(),
128            task_id: Some(task_id.clone()),
129            context_id: context_id.clone(),
130            kind: "message".to_string(),
131            parts: vec![Part::Text(TextPart {
132                text: format!(
133                    "Executed MCP tool: {} with arguments:\n{}\n\nExecution ID: {} at {}",
134                    tool_name,
135                    tool_args_display,
136                    task_id.as_str(),
137                    timestamp
138                ),
139            })],
140            metadata: Some(json!({
141                "source": "mcp_direct_call",
142                "tool_name": tool_name,
143                "is_synthetic": true,
144                "tool_args": tool_args,
145                "execution_timestamp": timestamp,
146            })),
147            extensions: None,
148            reference_task_ids: None,
149        };
150
151        let mut tx = self.task_repo.pool().as_ref().begin_transaction().await?;
152
153        let sequence_number = self
154            .persist_message_in_tx(
155                &mut *tx,
156                &message,
157                task_id,
158                context_id,
159                Some(request_context.user_id()),
160                request_context.session_id(),
161                request_context.trace_id(),
162            )
163            .await?;
164
165        tx.commit().await?;
166
167        tracing::info!(
168            message_id = %message_id,
169            task_id = %task_id,
170            tool_name = %tool_name,
171            sequence_number = sequence_number,
172            "Created synthetic tool execution message"
173        );
174
175        Ok((message_id, sequence_number))
176    }
177}