//! Chat-completion data types for `agent_runtime/llm` (`types.rs`): conversation
//! messages, requests, responses, tool calls, and token usage.

1use serde::{Deserialize, Serialize};
2use serde_json::Value as JsonValue;
3
// Unit tests for this module live in the sibling file `types_test.rs` and are
// compiled only for test builds.
#[cfg(test)]
#[path = "types_test.rs"]
mod types_test;
7
8/// Role of a message in a conversation
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
10#[serde(rename_all = "lowercase")]
11pub enum Role {
12    System,
13    User,
14    Assistant,
15    Tool,
16}
17
18/// A single message in a chat conversation
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ChatMessage {
21    pub role: Role,
22    pub content: String,
23
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub tool_calls: Option<Vec<ToolCall>>,
26
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub tool_call_id: Option<String>,
29}
30
31impl ChatMessage {
32    pub fn system(content: impl Into<String>) -> Self {
33        Self {
34            role: Role::System,
35            content: content.into(),
36            tool_calls: None,
37            tool_call_id: None,
38        }
39    }
40
41    pub fn user(content: impl Into<String>) -> Self {
42        Self {
43            role: Role::User,
44            content: content.into(),
45            tool_calls: None,
46            tool_call_id: None,
47        }
48    }
49
50    pub fn assistant(content: impl Into<String>) -> Self {
51        Self {
52            role: Role::Assistant,
53            content: content.into(),
54            tool_calls: None,
55            tool_call_id: None,
56        }
57    }
58
59    pub fn assistant_with_tool_calls(
60        content: impl Into<String>,
61        tool_calls: Vec<ToolCall>,
62    ) -> Self {
63        Self {
64            role: Role::Assistant,
65            content: content.into(),
66            tool_calls: Some(tool_calls),
67            tool_call_id: None,
68        }
69    }
70
71    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
72        Self {
73            role: Role::Tool,
74            content: content.into(),
75            tool_calls: None,
76            tool_call_id: Some(tool_call_id.into()),
77        }
78    }
79}
80
81/// Request for chat completion
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct ChatRequest {
84    pub messages: Vec<ChatMessage>,
85
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub temperature: Option<f32>,
88
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub max_tokens: Option<u32>,
91
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub top_p: Option<f32>,
94
95    #[serde(skip_serializing_if = "Option::is_none")]
96    pub tools: Option<Vec<JsonValue>>,
97}
98
99impl ChatRequest {
100    pub fn new(messages: Vec<ChatMessage>) -> Self {
101        Self {
102            messages,
103            temperature: None,
104            max_tokens: None,
105            top_p: None,
106            tools: None,
107        }
108    }
109
110    pub fn with_temperature(mut self, temperature: f32) -> Self {
111        self.temperature = Some(temperature);
112        self
113    }
114
115    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
116        self.max_tokens = Some(max_tokens);
117        self
118    }
119
120    pub fn with_top_p(mut self, top_p: f32) -> Self {
121        self.top_p = Some(top_p);
122        self
123    }
124
125    pub fn with_tools(mut self, tools: Vec<JsonValue>) -> Self {
126        self.tools = Some(tools);
127        self
128    }
129}
130
131/// Response from chat completion
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct ChatResponse {
134    pub content: String,
135    pub model: String,
136
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub usage: Option<Usage>,
139
140    #[serde(skip_serializing_if = "Option::is_none")]
141    pub finish_reason: Option<String>,
142
143    #[serde(skip_serializing_if = "Option::is_none")]
144    pub tool_calls: Option<Vec<ToolCall>>,
145}
146
147/// A tool call request from the LLM
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct ToolCall {
150    pub id: String,
151    pub r#type: String, // Usually "function"
152    pub function: FunctionCall,
153}
154
155/// Function call details
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct FunctionCall {
158    pub name: String,
159    pub arguments: String, // JSON string
160}
161
162/// Token usage statistics
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct Usage {
165    pub prompt_tokens: u32,
166    pub completion_tokens: u32,
167    pub total_tokens: u32,
168}