1use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
8#[serde(rename_all = "lowercase")]
9pub enum Role {
10 System,
11 User,
12 Assistant,
13 Tool,
14 Developer,
15}
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
19#[serde(rename_all = "snake_case")]
20pub enum FinishReason {
21 Stop,
22 Length,
23 ToolCalls,
24 ContentFilter,
25 FunctionCall,
26 #[serde(other)]
28 Unknown,
29}
30
31#[derive(Debug, Clone, Default, Serialize, Deserialize)]
33#[non_exhaustive]
34pub struct Usage {
35 #[serde(default)]
36 pub prompt_tokens: u64,
37 #[serde(default)]
38 pub completion_tokens: u64,
39 #[serde(default)]
40 pub total_tokens: u64,
41 #[serde(default, skip_serializing_if = "Option::is_none")]
43 pub prompt_tokens_details: Option<serde_json::Value>,
44 #[serde(default, skip_serializing_if = "Option::is_none")]
46 pub completion_tokens_details: Option<serde_json::Value>,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct ExpiresAfter {
54 pub anchor: String,
55 pub seconds: u64,
56}
57
58impl ExpiresAfter {
59 pub fn after_creation(seconds: u64) -> Self {
61 Self {
62 anchor: "created_at".into(),
63 seconds,
64 }
65 }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(tag = "type", rename_all = "snake_case")]
71pub enum ResponseFormat {
72 Text,
73 JsonObject,
74 JsonSchema { json_schema: serde_json::Value },
77}
78
79impl ResponseFormat {
80 pub fn json_schema(name: impl Into<String>, schema: serde_json::Value, strict: bool) -> Self {
82 Self::JsonSchema {
83 json_schema: serde_json::json!({
84 "name": name.into(),
85 "schema": schema,
86 "strict": strict,
87 }),
88 }
89 }
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct Tool {
95 #[serde(rename = "type")]
96 pub tool_type: String,
97 pub function: FunctionDef,
98}
99
100impl Tool {
101 pub fn function(
103 name: impl Into<String>,
104 description: impl Into<String>,
105 parameters: serde_json::Value,
106 ) -> Self {
107 Self {
108 tool_type: "function".into(),
109 function: FunctionDef {
110 name: name.into(),
111 description: Some(description.into()),
112 parameters: Some(parameters),
113 strict: None,
114 },
115 }
116 }
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct FunctionDef {
122 pub name: String,
123 #[serde(default, skip_serializing_if = "Option::is_none")]
124 pub description: Option<String>,
125 #[serde(default, skip_serializing_if = "Option::is_none")]
126 pub parameters: Option<serde_json::Value>,
127 #[serde(default, skip_serializing_if = "Option::is_none")]
128 pub strict: Option<bool>,
129}
130
131#[derive(Debug, Clone, PartialEq, Eq)]
134pub enum ToolChoice {
135 None,
136 Auto,
137 Required,
138 Function(String),
140}
141
142impl Serialize for ToolChoice {
143 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
144 match self {
145 Self::None => serializer.serialize_str("none"),
146 Self::Auto => serializer.serialize_str("auto"),
147 Self::Required => serializer.serialize_str("required"),
148 Self::Function(name) => serde_json::json!({
149 "type": "function",
150 "function": { "name": name },
151 })
152 .serialize(serializer),
153 }
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tool_choice_serialization() {
        assert_eq!(serde_json::to_value(ToolChoice::Auto).unwrap(), "auto");
        assert_eq!(serde_json::to_value(ToolChoice::None).unwrap(), "none");
        assert_eq!(
            serde_json::to_value(ToolChoice::Required).unwrap(),
            "required"
        );
        assert_eq!(
            serde_json::to_value(ToolChoice::Function("get_weather".into())).unwrap(),
            serde_json::json!({"type": "function", "function": {"name": "get_weather"}})
        );
    }

    #[test]
    fn response_format_tagging() {
        assert_eq!(
            serde_json::to_value(ResponseFormat::Text).unwrap(),
            serde_json::json!({"type": "text"})
        );
        assert_eq!(
            serde_json::to_value(ResponseFormat::JsonObject).unwrap(),
            serde_json::json!({"type": "json_object"})
        );
        let schema = ResponseFormat::json_schema("out", serde_json::json!({"type": "object"}), true);
        assert_eq!(
            serde_json::to_value(schema).unwrap(),
            serde_json::json!({
                "type": "json_schema",
                "json_schema": {"name": "out", "schema": {"type": "object"}, "strict": true}
            })
        );
    }

    #[test]
    fn role_uses_lowercase_wire_names() {
        assert_eq!(serde_json::to_value(Role::Developer).unwrap(), "developer");
        assert_eq!(serde_json::to_value(Role::Assistant).unwrap(), "assistant");
        let role: Role = serde_json::from_str("\"tool\"").unwrap();
        assert_eq!(role, Role::Tool);
    }

    #[test]
    fn known_finish_reasons_use_snake_case() {
        let reason: FinishReason = serde_json::from_str("\"tool_calls\"").unwrap();
        assert_eq!(reason, FinishReason::ToolCalls);
        assert_eq!(
            serde_json::to_value(FinishReason::ContentFilter).unwrap(),
            "content_filter"
        );
    }

    #[test]
    fn unknown_finish_reason_is_forward_compatible() {
        let reason: FinishReason = serde_json::from_str("\"eos_token\"").unwrap();
        assert_eq!(reason, FinishReason::Unknown);
    }

    #[test]
    fn usage_fields_default_when_missing() {
        let usage: Usage = serde_json::from_str(r#"{"prompt_tokens": 7}"#).unwrap();
        assert_eq!(usage.prompt_tokens, 7);
        assert_eq!(usage.completion_tokens, 0);
        assert_eq!(usage.total_tokens, 0);
        assert!(usage.prompt_tokens_details.is_none());
        assert!(usage.completion_tokens_details.is_none());
        // `None` detail objects must be omitted, not emitted as `null`.
        assert_eq!(
            serde_json::to_value(&usage).unwrap(),
            serde_json::json!({"prompt_tokens": 7, "completion_tokens": 0, "total_tokens": 0})
        );
    }

    #[test]
    fn tool_function_shape() {
        let tool = Tool::function(
            "get_weather",
            "Look up the weather",
            serde_json::json!({"type": "object"}),
        );
        // `strict` is `None`, so it must not appear in the output.
        assert_eq!(
            serde_json::to_value(&tool).unwrap(),
            serde_json::json!({
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Look up the weather",
                    "parameters": {"type": "object"}
                }
            })
        );
    }

    #[test]
    fn expires_after_creation_shape() {
        assert_eq!(
            serde_json::to_value(ExpiresAfter::after_creation(3600)).unwrap(),
            serde_json::json!({"anchor": "created_at", "seconds": 3600})
        );
    }
}