Skip to main content

forge_guardrails/core/
message.rs

1use indexmap::IndexMap;
2use serde::Serialize;
3use serde_json::Value;
4use std::fmt;
5
6/// Message role: system, user, assistant, or tool.
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
8#[serde(rename_all = "lowercase")]
9pub enum MessageRole {
10    /// System instruction role.
11    System,
12    /// User input role.
13    User,
14    /// Assistant/model response role.
15    Assistant,
16    /// Tool result response role.
17    Tool,
18}
19
20impl MessageRole {
21    /// Returns the string representation of the message role.
22    pub fn as_str(&self) -> &'static str {
23        match self {
24            Self::System => "system",
25            Self::User => "user",
26            Self::Assistant => "assistant",
27            Self::Tool => "tool",
28        }
29    }
30}
31
32impl fmt::Display for MessageRole {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        write!(f, "{}", self.as_str())
35    }
36}
37
38/// Message type classification for metadata tagging.
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
40#[serde(rename_all = "snake_case")]
41pub enum MessageType {
42    /// Crate-level or system instruction prompt.
43    SystemPrompt,
44    /// Direct user request or query.
45    UserInput,
46    /// Model-generated tool invocation request.
47    ToolCall,
48    /// Result payload returned from tool execution.
49    ToolResult,
50    /// Internal reasoning or thinking thoughts from the model.
51    Reasoning,
52    /// Final text response intended for the user.
53    TextResponse,
54    /// Nudge generated when a required step is missing.
55    StepNudge,
56    /// Nudge generated when a prerequisite step is missing.
57    PrerequisiteNudge,
58    /// Nudge or retry instructions following an error.
59    RetryNudge,
60    /// Warning nudge triggered by budget/hardware limits.
61    ContextWarning,
62    /// Summarized representation of older turns.
63    Summary,
64}
65
66impl MessageType {
67    /// Returns the string representation of the message type.
68    pub fn as_str(&self) -> &'static str {
69        match self {
70            Self::SystemPrompt => "system_prompt",
71            Self::UserInput => "user_input",
72            Self::ToolCall => "tool_call",
73            Self::ToolResult => "tool_result",
74            Self::Reasoning => "reasoning",
75            Self::TextResponse => "text_response",
76            Self::StepNudge => "step_nudge",
77            Self::PrerequisiteNudge => "prerequisite_nudge",
78            Self::RetryNudge => "retry_nudge",
79            Self::ContextWarning => "context_warning",
80            Self::Summary => "summary",
81        }
82    }
83}
84
85impl fmt::Display for MessageType {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        write!(f, "{}", self.as_str())
88    }
89}
90
91/// Immutable metadata attached to a message.
92#[derive(Debug, Clone, PartialEq)]
93pub struct MessageMeta {
94    /// The classified message type.
95    pub msg_type: MessageType,
96    /// The iteration or step index of the workflow loop.
97    pub step_index: Option<i64>,
98    /// The original message type if this message was transformed.
99    pub original_type: Option<MessageType>,
100    /// Estimated token count for this message.
101    pub token_estimate: Option<i64>,
102}
103
104impl MessageMeta {
105    /// Creates a new `MessageMeta` with the given type.
106    pub fn new(msg_type: MessageType) -> Self {
107        Self {
108            msg_type,
109            step_index: None,
110            original_type: None,
111            token_estimate: None,
112        }
113    }
114
115    /// Sets the step index.
116    pub fn with_step_index(mut self, idx: i64) -> Self {
117        self.step_index = Some(idx);
118        self
119    }
120
121    /// Sets the original message type.
122    pub fn with_original_type(mut self, t: MessageType) -> Self {
123        self.original_type = Some(t);
124        self
125    }
126
127    /// Sets the token estimate.
128    pub fn with_token_estimate(mut self, est: i64) -> Self {
129        self.token_estimate = Some(est);
130        self
131    }
132}
133
134/// Immutable representation of a single tool call within a message.
135#[derive(Debug, Clone, PartialEq)]
136pub struct ToolCallInfo {
137    /// Name of the tool being called.
138    pub name: String,
139    /// Arguments passed to the tool, if any.
140    pub args: Option<IndexMap<String, Value>>,
141    /// Uniquely generated identifier for this tool call.
142    pub call_id: String,
143}
144
145impl ToolCallInfo {
146    /// Creates a new `ToolCallInfo`.
147    pub fn new(
148        name: impl Into<String>,
149        args: Option<IndexMap<String, Value>>,
150        call_id: impl Into<String>,
151    ) -> Self {
152        Self {
153            name: name.into(),
154            args,
155            call_id: call_id.into(),
156        }
157    }
158
159    fn effective_args(&self) -> &IndexMap<String, Value> {
160        static EMPTY: std::sync::OnceLock<IndexMap<String, Value>> = std::sync::OnceLock::new();
161        self.args
162            .as_ref()
163            .unwrap_or_else(|| EMPTY.get_or_init(IndexMap::new))
164    }
165}
166
167fn json_dumps_default_value(value: &Value) -> String {
168    match value {
169        Value::Null => "null".to_string(),
170        Value::Bool(v) => v.to_string(),
171        Value::Number(v) => v.to_string(),
172        Value::String(v) => serde_json::to_string(v).unwrap_or_else(|_| "\"\"".to_string()),
173        Value::Array(values) => {
174            let inner = values
175                .iter()
176                .map(json_dumps_default_value)
177                .collect::<Vec<_>>()
178                .join(", ");
179            format!("[{}]", inner)
180        }
181        Value::Object(values) => {
182            let inner = values
183                .iter()
184                .map(|(key, val)| {
185                    let key = serde_json::to_string(key).unwrap_or_else(|_| "\"\"".to_string());
186                    format!("{}: {}", key, json_dumps_default_value(val))
187                })
188                .collect::<Vec<_>>()
189                .join(", ");
190            format!("{{{}}}", inner)
191        }
192    }
193}
194
195fn json_dumps_default_object(values: &IndexMap<String, Value>) -> String {
196    let inner = values
197        .iter()
198        .map(|(key, val)| {
199            let key = serde_json::to_string(key).unwrap_or_else(|_| "\"\"".to_string());
200            format!("{}: {}", key, json_dumps_default_value(val))
201        })
202        .collect::<Vec<_>>()
203        .join(", ");
204    format!("{{{}}}", inner)
205}
206
207/// A conversation message with dual serialization format support.
208#[derive(Debug, Clone)]
209pub struct Message {
210    /// Message role (system, user, assistant, tool).
211    pub role: MessageRole,
212    /// Main text content of the message.
213    pub content: String,
214    /// Metadata tagging for tracking and budget decisions.
215    pub metadata: MessageMeta,
216    /// Tool name associated with the message, if it is a tool response.
217    pub tool_name: Option<String>,
218    /// Tool call identifier paired with this message.
219    pub tool_call_id: Option<String>,
220    /// List of tool call invocations, if this is an assistant message.
221    pub tool_calls: Option<Vec<ToolCallInfo>>,
222}
223
224impl Message {
225    /// Creates a new `Message`.
226    pub fn new(role: MessageRole, content: impl Into<String>, metadata: MessageMeta) -> Self {
227        Self {
228            role,
229            content: content.into(),
230            metadata,
231            tool_name: None,
232            tool_call_id: None,
233            tool_calls: None,
234        }
235    }
236
237    /// Sets the tool name.
238    pub fn with_tool_name(mut self, name: impl Into<String>) -> Self {
239        self.tool_name = Some(name.into());
240        self
241    }
242
243    /// Sets the tool call ID.
244    pub fn with_tool_call_id(mut self, id: impl Into<String>) -> Self {
245        self.tool_call_id = Some(id.into());
246        self
247    }
248
249    /// Sets the list of tool calls.
250    pub fn with_tool_calls(mut self, calls: Vec<ToolCallInfo>) -> Self {
251        self.tool_calls = Some(calls);
252        self
253    }
254
255    /// Serialize this message for an LLM API.
256    ///
257    /// Format "ollama" (default): tool calls have no id/type, args as dict.
258    /// Format "openai": tool calls have id, type="function", args as JSON string.
259    pub fn serialize(&self, format: &str) -> Value {
260        match format {
261            "ollama" => self.serialize_ollama(),
262            "openai" => self.serialize_openai(),
263            _ => self.serialize_ollama(),
264        }
265    }
266
267    fn serialize_ollama(&self) -> Value {
268        let mut map = serde_json::Map::new();
269        map.insert(
270            "role".to_string(),
271            Value::String(self.role.as_str().to_string()),
272        );
273        map.insert("content".to_string(), Value::String(self.content.clone()));
274
275        if let Some(calls) = &self.tool_calls {
276            let tool_calls_json: Vec<Value> = calls
277                .iter()
278                .map(|tc| {
279                    let mut entry = serde_json::Map::new();
280                    let mut func = serde_json::Map::new();
281                    func.insert("name".to_string(), Value::String(tc.name.clone()));
282                    func.insert(
283                        "arguments".to_string(),
284                        serde_json::to_value(tc.effective_args())
285                            .unwrap_or(Value::Object(serde_json::Map::new())),
286                    );
287                    entry.insert("function".to_string(), Value::Object(func));
288                    Value::Object(entry)
289                })
290                .collect();
291            map.insert("tool_calls".to_string(), Value::Array(tool_calls_json));
292        }
293
294        if self.role == MessageRole::Tool {
295            if let Some(name) = &self.tool_name {
296                map.insert("tool_name".to_string(), Value::String(name.clone()));
297            }
298        }
299
300        Value::Object(map)
301    }
302
303    fn serialize_openai(&self) -> Value {
304        let mut map = serde_json::Map::new();
305        map.insert(
306            "role".to_string(),
307            Value::String(self.role.as_str().to_string()),
308        );
309        map.insert("content".to_string(), Value::String(self.content.clone()));
310
311        if let Some(calls) = &self.tool_calls {
312            let tool_calls_json: Vec<Value> = calls
313                .iter()
314                .map(|tc| {
315                    let mut entry = serde_json::Map::new();
316                    entry.insert("id".to_string(), Value::String(tc.call_id.clone()));
317                    entry.insert("type".to_string(), Value::String("function".to_string()));
318                    let mut func = serde_json::Map::new();
319                    func.insert("name".to_string(), Value::String(tc.name.clone()));
320                    let args_str = json_dumps_default_object(tc.effective_args());
321                    func.insert("arguments".to_string(), Value::String(args_str));
322                    entry.insert("function".to_string(), Value::Object(func));
323                    Value::Object(entry)
324                })
325                .collect();
326            map.insert("tool_calls".to_string(), Value::Array(tool_calls_json));
327        }
328
329        if self.role == MessageRole::Tool {
330            if let Some(name) = &self.tool_name {
331                map.insert("name".to_string(), Value::String(name.clone()));
332            }
333            if let Some(id) = &self.tool_call_id {
334                map.insert("tool_call_id".to_string(), Value::String(id.clone()));
335            }
336        }
337
338        Value::Object(map)
339    }
340}