lib_client_openai/
types.rs

1//! Data types for the OpenAI API.
2
3use serde::{Deserialize, Serialize};
4
5/// Message role.
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
7#[serde(rename_all = "lowercase")]
8pub enum Role {
9    System,
10    User,
11    Assistant,
12    Tool,
13}
14
15/// A message in the conversation.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Message {
18    /// Message role.
19    pub role: Role,
20    /// Message content.
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub content: Option<String>,
23    /// Tool calls made by the assistant.
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub tool_calls: Option<Vec<ToolCall>>,
26    /// Tool call ID (for tool role messages).
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub tool_call_id: Option<String>,
29}
30
31impl Message {
32    /// Create a system message.
33    pub fn system(content: impl Into<String>) -> Self {
34        Self {
35            role: Role::System,
36            content: Some(content.into()),
37            tool_calls: None,
38            tool_call_id: None,
39        }
40    }
41
42    /// Create a user message.
43    pub fn user(content: impl Into<String>) -> Self {
44        Self {
45            role: Role::User,
46            content: Some(content.into()),
47            tool_calls: None,
48            tool_call_id: None,
49        }
50    }
51
52    /// Create an assistant message.
53    pub fn assistant(content: impl Into<String>) -> Self {
54        Self {
55            role: Role::Assistant,
56            content: Some(content.into()),
57            tool_calls: None,
58            tool_call_id: None,
59        }
60    }
61
62    /// Create an assistant message with tool calls.
63    pub fn assistant_with_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
64        Self {
65            role: Role::Assistant,
66            content: None,
67            tool_calls: Some(tool_calls),
68            tool_call_id: None,
69        }
70    }
71
72    /// Create a tool result message.
73    pub fn tool(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
74        Self {
75            role: Role::Tool,
76            content: Some(content.into()),
77            tool_calls: None,
78            tool_call_id: Some(tool_call_id.into()),
79        }
80    }
81}
82
83/// Tool call made by the assistant.
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct ToolCall {
86    /// Tool call ID.
87    pub id: String,
88    /// Tool type (always "function").
89    #[serde(rename = "type")]
90    pub tool_type: String,
91    /// Function call details.
92    pub function: FunctionCall,
93}
94
95impl ToolCall {
96    /// Create a new tool call.
97    pub fn new(
98        id: impl Into<String>,
99        name: impl Into<String>,
100        arguments: impl Into<String>,
101    ) -> Self {
102        Self {
103            id: id.into(),
104            tool_type: "function".to_string(),
105            function: FunctionCall {
106                name: name.into(),
107                arguments: arguments.into(),
108            },
109        }
110    }
111}
112
113/// Function call details.
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct FunctionCall {
116    /// Function name.
117    pub name: String,
118    /// JSON-encoded arguments.
119    pub arguments: String,
120}
121
122/// Tool definition.
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct Tool {
125    /// Tool type (always "function").
126    #[serde(rename = "type")]
127    pub tool_type: String,
128    /// Function definition.
129    pub function: FunctionDefinition,
130}
131
132impl Tool {
133    /// Create a new function tool.
134    pub fn function(
135        name: impl Into<String>,
136        description: impl Into<String>,
137        parameters: serde_json::Value,
138    ) -> Self {
139        Self {
140            tool_type: "function".to_string(),
141            function: FunctionDefinition {
142                name: name.into(),
143                description: description.into(),
144                parameters,
145            },
146        }
147    }
148}
149
150/// Function definition.
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct FunctionDefinition {
153    /// Function name.
154    pub name: String,
155    /// Function description.
156    pub description: String,
157    /// JSON schema for parameters.
158    pub parameters: serde_json::Value,
159}
160
161/// Request to create a chat completion.
162#[derive(Debug, Clone, Serialize)]
163pub struct CreateChatCompletionRequest {
164    /// Model to use.
165    pub model: String,
166    /// Messages in the conversation.
167    pub messages: Vec<Message>,
168    /// Maximum tokens to generate.
169    #[serde(skip_serializing_if = "Option::is_none")]
170    pub max_tokens: Option<usize>,
171    /// Maximum completion tokens (for o1/o3 models).
172    #[serde(skip_serializing_if = "Option::is_none")]
173    pub max_completion_tokens: Option<usize>,
174    /// Temperature for sampling.
175    #[serde(skip_serializing_if = "Option::is_none")]
176    pub temperature: Option<f32>,
177    /// Top-p sampling.
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub top_p: Option<f32>,
180    /// Stop sequences.
181    #[serde(skip_serializing_if = "Option::is_none")]
182    pub stop: Option<Vec<String>>,
183    /// Available tools.
184    #[serde(skip_serializing_if = "Option::is_none")]
185    pub tools: Option<Vec<Tool>>,
186    /// Whether to stream the response.
187    #[serde(skip_serializing_if = "Option::is_none")]
188    pub stream: Option<bool>,
189    /// Number of completions to generate.
190    #[serde(skip_serializing_if = "Option::is_none")]
191    pub n: Option<usize>,
192    /// Presence penalty.
193    #[serde(skip_serializing_if = "Option::is_none")]
194    pub presence_penalty: Option<f32>,
195    /// Frequency penalty.
196    #[serde(skip_serializing_if = "Option::is_none")]
197    pub frequency_penalty: Option<f32>,
198}
199
200impl CreateChatCompletionRequest {
201    /// Create a new chat completion request.
202    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
203        Self {
204            model: model.into(),
205            messages,
206            max_tokens: None,
207            max_completion_tokens: None,
208            temperature: None,
209            top_p: None,
210            stop: None,
211            tools: None,
212            stream: None,
213            n: None,
214            presence_penalty: None,
215            frequency_penalty: None,
216        }
217    }
218
219    /// Set max tokens.
220    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
221        self.max_tokens = Some(max_tokens);
222        self
223    }
224
225    /// Set max completion tokens (for o1/o3 models).
226    pub fn with_max_completion_tokens(mut self, max_tokens: usize) -> Self {
227        self.max_completion_tokens = Some(max_tokens);
228        self
229    }
230
231    /// Set temperature.
232    pub fn with_temperature(mut self, temperature: f32) -> Self {
233        self.temperature = Some(temperature);
234        self
235    }
236
237    /// Set top-p sampling.
238    pub fn with_top_p(mut self, top_p: f32) -> Self {
239        self.top_p = Some(top_p);
240        self
241    }
242
243    /// Set stop sequences.
244    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
245        self.stop = Some(stop);
246        self
247    }
248
249    /// Set available tools.
250    pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
251        self.tools = Some(tools);
252        self
253    }
254}
255
256/// Token usage statistics.
257#[derive(Debug, Clone, Deserialize)]
258pub struct Usage {
259    /// Prompt tokens.
260    pub prompt_tokens: usize,
261    /// Completion tokens.
262    pub completion_tokens: usize,
263    /// Total tokens.
264    pub total_tokens: usize,
265}
266
267/// A completion choice.
268#[derive(Debug, Clone, Deserialize)]
269pub struct Choice {
270    /// Choice index.
271    pub index: usize,
272    /// Generated message.
273    pub message: Message,
274    /// Finish reason.
275    pub finish_reason: Option<String>,
276}
277
278/// Response from creating a chat completion.
279#[derive(Debug, Clone, Deserialize)]
280pub struct CreateChatCompletionResponse {
281    /// Response ID.
282    pub id: String,
283    /// Object type.
284    pub object: String,
285    /// Creation timestamp.
286    pub created: u64,
287    /// Model used.
288    pub model: String,
289    /// Completion choices.
290    pub choices: Vec<Choice>,
291    /// Token usage.
292    pub usage: Option<Usage>,
293}
294
295impl CreateChatCompletionResponse {
296    /// Get the first choice's message content.
297    pub fn content(&self) -> Option<&str> {
298        self.choices
299            .first()
300            .and_then(|c| c.message.content.as_deref())
301    }
302
303    /// Get the first choice's tool calls.
304    pub fn tool_calls(&self) -> Option<&Vec<ToolCall>> {
305        self.choices
306            .first()
307            .and_then(|c| c.message.tool_calls.as_ref())
308    }
309
310    /// Check if the response contains tool calls.
311    pub fn has_tool_calls(&self) -> bool {
312        self.choices
313            .first()
314            .map(|c| c.message.tool_calls.is_some())
315            .unwrap_or(false)
316    }
317}
318
319/// Model information.
320#[derive(Debug, Clone, Deserialize)]
321pub struct Model {
322    /// Model ID.
323    pub id: String,
324    /// Object type.
325    pub object: String,
326    /// Creation timestamp.
327    pub created: u64,
328    /// Owner organization.
329    pub owned_by: String,
330}
331
332/// List of models.
333#[derive(Debug, Clone, Deserialize)]
334pub struct ModelList {
335    /// Object type.
336    pub object: String,
337    /// Models.
338    pub data: Vec<Model>,
339}
340
341/// Error response from the API.
342#[derive(Debug, Clone, Deserialize)]
343pub struct ErrorResponse {
344    /// Error details.
345    pub error: ErrorDetail,
346}
347
348/// Error detail.
349#[derive(Debug, Clone, Deserialize)]
350pub struct ErrorDetail {
351    /// Error message.
352    pub message: String,
353    /// Error type.
354    #[serde(rename = "type")]
355    pub error_type: String,
356    /// Parameter that caused the error.
357    pub param: Option<String>,
358    /// Error code.
359    pub code: Option<String>,
360}