1use serde::{Deserialize, Serialize};
2
/// Parameters for a single chat call: system text, conversation history,
/// optional tool definitions and a token limit.
///
/// This type derives neither `Serialize` nor `Deserialize`.
/// NOTE(review): presumably each backend maps it onto its own wire format — confirm.
#[derive(Debug, Clone)]
pub struct ChatRequest {
    /// System instructions, held separately from the `messages` history.
    pub system: String,
    /// Conversation turns, in order.
    pub messages: Vec<Message>,
    /// Tools offered to the model; `None` when no tools are supplied.
    pub tools: Option<Vec<Tool>>,
    /// Token limit for the call.
    /// NOTE(review): presumably caps generated output tokens — confirm per backend.
    pub max_tokens: u32,
}
10
/// One conversation turn: who produced it and what it contains.
///
/// Serialized as an object with `"role"` and `"content"` keys; see [`Role`] and
/// [`Content`] for how each value is encoded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author of this turn.
    pub role: Role,
    /// Either a bare string or a list of typed content blocks.
    pub content: Content,
}
16
17impl Message {
18 #[must_use]
19 pub fn user(text: impl Into<String>) -> Self {
20 Self {
21 role: Role::User,
22 content: Content::Text(text.into()),
23 }
24 }
25
26 #[must_use]
27 pub fn assistant(text: impl Into<String>) -> Self {
28 Self {
29 role: Role::Assistant,
30 content: Content::Text(text.into()),
31 }
32 }
33
34 #[must_use]
35 pub fn assistant_with_tool_use(
36 text: Option<String>,
37 id: impl Into<String>,
38 name: impl Into<String>,
39 input: serde_json::Value,
40 ) -> Self {
41 let mut blocks = Vec::new();
42 if let Some(t) = text {
43 blocks.push(ContentBlock::Text { text: t });
44 }
45 blocks.push(ContentBlock::ToolUse {
46 id: id.into(),
47 name: name.into(),
48 input,
49 thought_signature: None,
50 });
51 Self {
52 role: Role::Assistant,
53 content: Content::Blocks(blocks),
54 }
55 }
56
57 #[must_use]
58 pub fn tool_result(
59 tool_use_id: impl Into<String>,
60 content: impl Into<String>,
61 is_error: bool,
62 ) -> Self {
63 Self {
64 role: Role::User,
65 content: Content::Blocks(vec![ContentBlock::ToolResult {
66 tool_use_id: tool_use_id.into(),
67 content: content.into(),
68 is_error: if is_error { Some(true) } else { None },
69 }]),
70 }
71 }
72}
73
74#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
75#[serde(rename_all = "lowercase")]
76pub enum Role {
77 User,
78 Assistant,
79}
80
/// Body of a [`Message`].
///
/// `#[serde(untagged)]`: `Text` is encoded as a bare JSON string and `Blocks` as
/// a JSON array of [`ContentBlock`]s, with no wrapper. On deserialization serde
/// tries the variants in declaration order (`Text` first).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Content {
    /// Plain text content.
    Text(String),
    /// A sequence of typed content blocks.
    Blocks(Vec<ContentBlock>),
}
87
88impl Content {
89 #[must_use]
90 pub fn first_text(&self) -> Option<&str> {
91 match self {
92 Self::Text(s) => Some(s),
93 Self::Blocks(blocks) => blocks.iter().find_map(|b| match b {
94 ContentBlock::Text { text } => Some(text.as_str()),
95 _ => None,
96 }),
97 }
98 }
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
102#[serde(tag = "type")]
103pub enum ContentBlock {
104 #[serde(rename = "text")]
105 Text { text: String },
106
107 #[serde(rename = "tool_use")]
108 ToolUse {
109 id: String,
110 name: String,
111 input: serde_json::Value,
112 #[serde(skip_serializing_if = "Option::is_none")]
115 thought_signature: Option<String>,
116 },
117
118 #[serde(rename = "tool_result")]
119 ToolResult {
120 tool_use_id: String,
121 content: String,
122 #[serde(skip_serializing_if = "Option::is_none")]
123 is_error: Option<bool>,
124 },
125}
126
/// Definition of a tool offered to the model via [`ChatRequest::tools`].
///
/// Serialize-only: this type derives `Serialize` but not `Deserialize`.
#[derive(Debug, Clone, Serialize)]
pub struct Tool {
    /// Tool identifier.
    /// NOTE(review): presumably the value that appears as `name` in `ContentBlock::ToolUse` — confirm.
    pub name: String,
    /// Free-form description of the tool.
    pub description: String,
    /// Description of the tool's input.
    /// NOTE(review): presumably a JSON Schema object — confirm.
    pub input_schema: serde_json::Value,
}
133
/// Result of a successful chat call.
///
/// This type derives neither `Serialize` nor `Deserialize`.
/// NOTE(review): presumably built by each backend from its own wire format — confirm.
#[derive(Debug, Clone)]
pub struct ChatResponse {
    /// Response identifier.
    pub id: String,
    /// Content blocks returned, in order.
    pub content: Vec<ContentBlock>,
    /// Model name reported for this response.
    pub model: String,
    /// Why generation stopped; `None` when no reason is available.
    pub stop_reason: Option<StopReason>,
    /// Token counts for this call.
    pub usage: Usage,
}
142
143impl ChatResponse {
144 #[must_use]
145 pub fn first_text(&self) -> Option<&str> {
146 self.content.iter().find_map(|b| match b {
147 ContentBlock::Text { text } => Some(text.as_str()),
148 _ => None,
149 })
150 }
151
152 pub fn tool_uses(&self) -> impl Iterator<Item = (&str, &str, &serde_json::Value)> {
153 self.content.iter().filter_map(|b| match b {
154 ContentBlock::ToolUse {
155 id, name, input, ..
156 } => Some((id.as_str(), name.as_str(), input)),
157 _ => None,
158 })
159 }
160
161 #[must_use]
162 pub fn has_tool_use(&self) -> bool {
163 self.content
164 .iter()
165 .any(|b| matches!(b, ContentBlock::ToolUse { .. }))
166 }
167}
168
169#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq)]
170#[serde(rename_all = "snake_case")]
171pub enum StopReason {
172 EndTurn,
173 ToolUse,
174 MaxTokens,
175 StopSequence,
176}
177
/// Token counts for one chat call.
///
/// Deserialize-only; read from keys named exactly like the fields.
#[derive(Debug, Clone, Deserialize)]
pub struct Usage {
    /// Tokens counted on the input side of the call.
    pub input_tokens: u32,
    /// Tokens counted on the output side of the call.
    pub output_tokens: u32,
}
183
/// Coarse classification of a chat call's result.
///
/// Only `Debug` is derived (no `Clone`).
#[derive(Debug)]
pub enum ChatOutcome {
    /// The call succeeded and produced a response.
    Success(ChatResponse),
    /// The call was rejected due to rate limiting.
    /// NOTE(review): presumably mapped from an HTTP 429 — confirm in the client.
    RateLimited,
    /// The request was rejected as invalid; the string presumably carries the
    /// provider's error detail.
    InvalidRequest(String),
    /// The provider reported a server-side failure; the string presumably
    /// carries the error detail.
    ServerError(String),
}