Skip to main content

limit_llm/
types.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
4pub struct Message {
5    pub role: Role,
6    #[serde(skip_serializing_if = "Option::is_none")]
7    pub content: Option<String>,
8    #[serde(skip_serializing_if = "Option::is_none")]
9    pub tool_calls: Option<Vec<ToolCall>>,
10    #[serde(skip_serializing_if = "Option::is_none")]
11    pub tool_call_id: Option<String>,
12}
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
15#[serde(rename_all = "lowercase")]
16pub enum Role {
17    User,
18    Assistant,
19    System,
20    Tool,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ToolCall {
25    pub id: String,
26    #[serde(rename = "type")]
27    pub tool_type: String,
28    pub function: FunctionCall,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct FunctionCall {
33    pub name: String,
34    /// JSON string representation of the function arguments
35    pub arguments: String,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct Tool {
40    #[serde(rename = "type")]
41    pub tool_type: String,
42    pub function: ToolFunction,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct ToolFunction {
47    pub name: String,
48    pub description: String,
49    pub parameters: serde_json::Value,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct Response {
54    pub content: String,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub tool_calls: Option<Vec<ToolCall>>,
57    pub usage: Usage,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct Usage {
62    pub input_tokens: u64,
63    pub output_tokens: u64,
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value` to compact JSON and parses it back.
    ///
    /// Returns both the JSON text and the re-parsed value. Panics if either
    /// step fails.
    fn roundtrip<T>(value: &T) -> (String, T)
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let json = serde_json::to_string(value).expect("serialization should succeed");
        let back = serde_json::from_str(&json).expect("deserialization should succeed");
        (json, back)
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message {
            role: Role::User,
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
        };
        let (json, deserialized) = roundtrip(&msg);
        // `None` fields must be skipped entirely rather than written as `null`.
        assert_eq!(json, r#"{"role":"user","content":"Hello"}"#);
        assert_eq!(deserialized.role, msg.role);
        assert_eq!(deserialized.content, msg.content);
        assert!(deserialized.tool_calls.is_none());
        assert!(deserialized.tool_call_id.is_none());
    }

    #[test]
    fn test_message_deserializes_without_optional_fields() {
        // Missing `Option` fields on input must default to `None`.
        let msg: Message = serde_json::from_str(r#"{"role":"system"}"#).unwrap();
        assert_eq!(msg.role, Role::System);
        assert!(msg.content.is_none());
        assert!(msg.tool_calls.is_none());
        assert!(msg.tool_call_id.is_none());
    }

    #[test]
    fn test_message_with_tool_calls() {
        let msg = Message {
            role: Role::Assistant,
            content: Some("".to_string()),
            tool_calls: Some(vec![ToolCall {
                id: "call_123".to_string(),
                tool_type: "function".to_string(),
                function: FunctionCall {
                    name: "test_tool".to_string(),
                    arguments: serde_json::json!({"arg": "value"}).to_string(),
                },
            }]),
            tool_call_id: None,
        };
        let (json, deserialized) = roundtrip(&msg);
        // `tool_type` is renamed to `type` on the wire.
        assert!(json.contains(r#""type":"function""#));
        assert!(!json.contains("tool_type"));
        // An empty string is still `Some` and therefore still serialized.
        assert_eq!(deserialized.content.as_deref(), Some(""));

        let calls = deserialized
            .tool_calls
            .expect("tool_calls should survive the roundtrip");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_123");
        assert_eq!(calls[0].tool_type, "function");
        assert_eq!(calls[0].function.name, "test_tool");
        assert_eq!(calls[0].function.arguments, r#"{"arg":"value"}"#);
    }

    #[test]
    fn test_tool_result_message() {
        let msg = Message {
            role: Role::Tool,
            content: Some("result output".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
        };
        let (json, deserialized) = roundtrip(&msg);
        assert!(json.contains(r#""tool_call_id":"call_123""#));
        assert!(!json.contains("\"tool_calls\""));
        assert_eq!(deserialized.role, Role::Tool);
        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));
        assert_eq!(deserialized.content.as_deref(), Some("result output"));
    }

    #[test]
    fn test_assistant_with_tool_calls_serialization() {
        let msg = Message {
            role: Role::Assistant,
            content: None, // assistant turn carrying only tool calls
            tool_calls: Some(vec![ToolCall {
                id: "call_123".to_string(),
                tool_type: "function".to_string(),
                function: FunctionCall {
                    name: "test_tool".to_string(),
                    arguments: serde_json::json!({}).to_string(),
                },
            }]),
            tool_call_id: None,
        };
        let (json, deserialized) = roundtrip(&msg);
        // The `content` key must be absent altogether, not merely non-null.
        assert!(!json.contains("\"content\""));
        assert!(json.contains("\"tool_calls\""));
        assert!(deserialized.content.is_none());
        assert_eq!(deserialized.tool_calls.map(|calls| calls.len()), Some(1));
    }

    #[test]
    fn test_role_serialization() {
        let cases = [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
            (Role::Tool, "\"tool\""),
        ];
        for (role, expected) in cases.iter() {
            assert_eq!(serde_json::to_string(role).unwrap(), *expected);
            let parsed: Role = serde_json::from_str(expected).unwrap();
            assert_eq!(parsed, *role);
        }
        // `rename_all = "lowercase"` means the capitalized form is rejected.
        assert!(serde_json::from_str::<Role>("\"User\"").is_err());
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool {
            tool_type: "function".to_string(),
            function: ToolFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({"type": "object"}),
            },
        };
        let (json, deserialized) = roundtrip(&tool);
        assert!(json.contains(r#""type":"function""#));
        assert!(!json.contains("tool_type"));
        assert_eq!(deserialized.tool_type, tool.tool_type);
        assert_eq!(deserialized.function.name, tool.function.name);
        assert_eq!(deserialized.function.description, tool.function.description);
        assert_eq!(deserialized.function.parameters, tool.function.parameters);
    }

    #[test]
    fn test_response_serialization() {
        let response = Response {
            content: "Hello, world!".to_string(),
            tool_calls: None,
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
            },
        };
        let (json, deserialized) = roundtrip(&response);
        assert!(!json.contains("\"tool_calls\""));
        assert_eq!(deserialized.content, response.content);
        assert!(deserialized.tool_calls.is_none());
        assert_eq!(deserialized.usage.input_tokens, response.usage.input_tokens);
        assert_eq!(deserialized.usage.output_tokens, response.usage.output_tokens);
    }

    #[test]
    fn test_usage_serialization() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
        };
        let (json, deserialized) = roundtrip(&usage);
        assert_eq!(json, r#"{"input_tokens":100,"output_tokens":50}"#);
        assert_eq!(deserialized.input_tokens, usage.input_tokens);
        assert_eq!(deserialized.output_tokens, usage.output_tokens);
    }
}