Skip to main content

agent_sdk/
client.rs

1//! Canonical API types for the Messages API.
2//!
3//! These types are the shared representation used by all providers.
4//! Each provider translates to/from these types internally.
5
6use std::time::Duration;
7
8use serde::{Deserialize, Serialize};
9
10
11// ---------------------------------------------------------------------------
12// API types – request
13// ---------------------------------------------------------------------------
14
15/// Parameters for extended-thinking / budget control.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ThinkingParam {
18    /// The thinking budget token (e.g. `"enabled"` or a specific budget).
19    #[serde(rename = "type")]
20    pub kind: String,
21    /// Optional budget tokens limit.
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub budget_tokens: Option<u64>,
24}
25
26/// Cache control marker for prompt caching.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct CacheControl {
29    #[serde(rename = "type")]
30    pub kind: String,
31}
32
33impl CacheControl {
34    pub fn ephemeral() -> Self {
35        Self {
36            kind: "ephemeral".to_string(),
37        }
38    }
39}
40
41/// A tool definition sent with the request.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ToolDefinition {
44    pub name: String,
45    pub description: String,
46    pub input_schema: serde_json::Value,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub cache_control: Option<CacheControl>,
49}
50
51/// A system prompt content block (supports cache_control).
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct SystemBlock {
54    #[serde(rename = "type")]
55    pub kind: String,
56    pub text: String,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub cache_control: Option<CacheControl>,
59}
60
61/// Source data for an image content block.
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct ImageSource {
64    /// Always "base64".
65    #[serde(rename = "type")]
66    pub kind: String,
67    /// MIME type (e.g. "image/png").
68    pub media_type: String,
69    /// Base64-encoded image data.
70    pub data: String,
71}
72
73/// A single content block inside an API message.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75#[serde(tag = "type")]
76pub enum ApiContentBlock {
77    #[serde(rename = "text")]
78    Text {
79        text: String,
80        #[serde(skip_serializing_if = "Option::is_none")]
81        cache_control: Option<CacheControl>,
82    },
83
84    #[serde(rename = "image")]
85    Image {
86        source: ImageSource,
87    },
88
89    #[serde(rename = "tool_use")]
90    ToolUse {
91        id: String,
92        name: String,
93        input: serde_json::Value,
94    },
95
96    #[serde(rename = "tool_result")]
97    ToolResult {
98        tool_use_id: String,
99        content: serde_json::Value,
100        #[serde(skip_serializing_if = "Option::is_none")]
101        is_error: Option<bool>,
102        #[serde(skip_serializing_if = "Option::is_none")]
103        cache_control: Option<CacheControl>,
104        /// The tool/function name (used by Gemini's functionResponse).
105        #[serde(skip_serializing_if = "Option::is_none")]
106        name: Option<String>,
107    },
108
109    #[serde(rename = "thinking")]
110    Thinking { thinking: String },
111}
112
113/// A message in the conversation sent to the API.
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct ApiMessage {
116    pub role: String,
117    pub content: Vec<ApiContentBlock>,
118}
119
120/// The full request body for `POST /v1/messages`.
121#[derive(Debug, Clone, Serialize)]
122pub struct CreateMessageRequest {
123    pub model: String,
124    pub max_tokens: u32,
125    pub messages: Vec<ApiMessage>,
126    #[serde(skip_serializing_if = "Option::is_none")]
127    pub system: Option<Vec<SystemBlock>>,
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub tools: Option<Vec<ToolDefinition>>,
130    pub stream: bool,
131    #[serde(skip_serializing_if = "Option::is_none")]
132    pub metadata: Option<serde_json::Value>,
133    #[serde(skip_serializing_if = "Option::is_none")]
134    pub thinking: Option<ThinkingParam>,
135}
136
137// ---------------------------------------------------------------------------
138// API types – response
139// ---------------------------------------------------------------------------
140
141/// Token usage returned by the API.
142#[derive(Debug, Clone, Default, Serialize, Deserialize)]
143pub struct ApiUsage {
144    #[serde(default)]
145    pub input_tokens: u64,
146    #[serde(default)]
147    pub output_tokens: u64,
148    #[serde(default)]
149    pub cache_creation_input_tokens: Option<u64>,
150    #[serde(default)]
151    pub cache_read_input_tokens: Option<u64>,
152}
153
154/// A full (non-streaming) response from the Messages API.
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct MessageResponse {
157    pub id: String,
158    pub role: String,
159    pub content: Vec<ApiContentBlock>,
160    pub model: String,
161    pub stop_reason: Option<String>,
162    #[serde(default)]
163    pub usage: ApiUsage,
164}
165
166/// Error payload returned by the API.
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct ApiError {
169    #[serde(rename = "type")]
170    pub kind: String,
171    pub message: String,
172}
173
174/// Error wrapper as returned in the top-level JSON.
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct ApiErrorResponse {
177    pub error: ApiError,
178}
179
180// ---------------------------------------------------------------------------
181// Streaming types
182// ---------------------------------------------------------------------------
183
184/// Delta for a text content block.
185#[derive(Debug, Clone, Serialize, Deserialize)]
186#[serde(tag = "type")]
187pub enum ContentDelta {
188    #[serde(rename = "text_delta")]
189    TextDelta { text: String },
190
191    #[serde(rename = "input_json_delta")]
192    InputJsonDelta { partial_json: String },
193
194    #[serde(rename = "thinking_delta")]
195    ThinkingDelta { thinking: String },
196}
197
198/// Delta that comes with `message_delta` events.
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct MessageDelta {
201    pub stop_reason: Option<String>,
202}
203
204/// Server-sent events emitted during streaming.
205#[derive(Debug, Clone)]
206pub enum StreamEvent {
207    MessageStart {
208        message: MessageResponse,
209    },
210    ContentBlockStart {
211        index: usize,
212        content_block: ApiContentBlock,
213    },
214    ContentBlockDelta {
215        index: usize,
216        delta: ContentDelta,
217    },
218    ContentBlockStop {
219        index: usize,
220    },
221    MessageDelta {
222        delta: MessageDelta,
223        usage: ApiUsage,
224    },
225    MessageStop,
226    Ping,
227    Error {
228        error: ApiError,
229    },
230}
231
232// ---------------------------------------------------------------------------
233// Retry configuration (shared by providers)
234// ---------------------------------------------------------------------------
235
/// Exponential-backoff retry settings shared by the providers.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Maximum number of retries, excluding the initial request.
    pub max_retries: u32,
    /// Back-off duration before the first retry.
    pub initial_backoff: Duration,
    /// Factor the back-off is multiplied by after each attempt.
    pub backoff_multiplier: f64,
    /// Upper bound on the back-off duration.
    pub max_backoff: Duration,
}

impl Default for RetryConfig {
    /// Five retries, starting at one second, multiplier 2.0, capped at 60 seconds.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(1_000);
        let max_backoff = Duration::from_secs(60);
        RetryConfig {
            max_retries: 5,
            backoff_multiplier: 2.0,
            initial_backoff,
            max_backoff,
        }
    }
}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serialize_request_omits_none_fields() {
        let req = CreateMessageRequest {
            model: "claude-haiku-4-5".into(),
            max_tokens: 1024,
            messages: vec![ApiMessage {
                role: "user".into(),
                content: vec![ApiContentBlock::Text {
                    text: "Hello".into(),
                    cache_control: None,
                }],
            }],
            system: None,
            tools: None,
            stream: false,
            metadata: None,
            thinking: None,
        };
        let json = serde_json::to_value(&req).unwrap();
        let obj = json
            .as_object()
            .expect("request must serialize to a JSON object");
        for key in ["system", "tools", "metadata", "thinking"] {
            assert!(!obj.contains_key(key), "unexpected key {key}");
        }
        // Non-optional fields are always present.
        assert_eq!(obj["stream"], serde_json::json!(false));
    }

    #[test]
    fn thinking_param_omits_missing_budget() {
        let no_budget = ThinkingParam {
            kind: "enabled".into(),
            budget_tokens: None,
        };
        assert_eq!(
            serde_json::to_value(&no_budget).unwrap(),
            serde_json::json!({ "type": "enabled" })
        );

        let with_budget = ThinkingParam {
            kind: "enabled".into(),
            budget_tokens: Some(2048),
        };
        assert_eq!(
            serde_json::to_value(&with_budget).unwrap(),
            serde_json::json!({ "type": "enabled", "budget_tokens": 2048 })
        );
    }

    #[test]
    fn usage_defaults_missing_fields() {
        let usage: ApiUsage = serde_json::from_str(r#"{"input_tokens": 7}"#).unwrap();
        assert_eq!(usage.input_tokens, 7);
        assert_eq!(usage.output_tokens, 0);
        assert_eq!(usage.cache_creation_input_tokens, None);
        assert_eq!(usage.cache_read_input_tokens, None);
    }

    #[test]
    fn content_delta_deserializes_by_type_tag() {
        let thinking: ContentDelta =
            serde_json::from_str(r#"{"type":"thinking_delta","thinking":"hmm"}"#).unwrap();
        match thinking {
            ContentDelta::ThinkingDelta { thinking } => assert_eq!(thinking, "hmm"),
            other => panic!("wrong variant: {other:?}"),
        }

        let json_delta: ContentDelta =
            serde_json::from_str(r#"{"type":"input_json_delta","partial_json":"abc"}"#)
                .unwrap();
        match json_delta {
            ContentDelta::InputJsonDelta { partial_json } => assert_eq!(partial_json, "abc"),
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn tool_use_content_block_roundtrips() {
        let block = ApiContentBlock::ToolUse {
            id: "tu_123".into(),
            name: "bash".into(),
            input: serde_json::json!({"command": "ls"}),
        };
        let json = serde_json::to_string(&block).unwrap();
        let back: ApiContentBlock = serde_json::from_str(&json).unwrap();
        match back {
            ApiContentBlock::ToolUse { id, name, input } => {
                assert_eq!(id, "tu_123");
                assert_eq!(name, "bash");
                assert_eq!(input, serde_json::json!({"command": "ls"}));
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn tool_result_content_block_roundtrips() {
        let block = ApiContentBlock::ToolResult {
            tool_use_id: "tu_123".into(),
            content: serde_json::json!("output text"),
            is_error: Some(false),
            cache_control: None,
            name: None,
        };
        let json = serde_json::to_string(&block).unwrap();
        let back: ApiContentBlock = serde_json::from_str(&json).unwrap();
        match back {
            ApiContentBlock::ToolResult {
                tool_use_id,
                content,
                is_error,
                ..
            } => {
                assert_eq!(tool_use_id, "tu_123");
                assert_eq!(content, serde_json::json!("output text"));
                assert_eq!(is_error, Some(false));
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn image_content_block_roundtrips() {
        let block = ApiContentBlock::Image {
            source: ImageSource {
                kind: "base64".into(),
                media_type: "image/png".into(),
                data: "iVBORw0KGgo=".into(),
            },
        };
        let json = serde_json::to_string(&block).unwrap();
        assert!(json.contains("\"type\":\"image\""));
        assert!(json.contains("\"media_type\":\"image/png\""));
        let back: ApiContentBlock = serde_json::from_str(&json).unwrap();
        match back {
            ApiContentBlock::Image { source } => {
                assert_eq!(source.kind, "base64");
                assert_eq!(source.media_type, "image/png");
                assert_eq!(source.data, "iVBORw0KGgo=");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn image_in_user_message_serializes() {
        let msg = ApiMessage {
            role: "user".into(),
            content: vec![
                ApiContentBlock::Image {
                    source: ImageSource {
                        kind: "base64".into(),
                        media_type: "image/jpeg".into(),
                        data: "abc123".into(),
                    },
                },
                ApiContentBlock::Text {
                    text: "What is this?".into(),
                    cache_control: None,
                },
            ],
        };
        let json = serde_json::to_value(&msg).unwrap();
        let content = json["content"].as_array().unwrap();
        assert_eq!(content.len(), 2);
        assert_eq!(content[0]["type"], "image");
        assert_eq!(content[1]["type"], "text");
    }

    #[test]
    fn retry_config_defaults() {
        let cfg = RetryConfig::default();
        assert_eq!(cfg.max_retries, 5);
        assert_eq!(cfg.initial_backoff, Duration::from_secs(1));
        assert_eq!(cfg.backoff_multiplier, 2.0);
        assert_eq!(cfg.max_backoff, Duration::from_secs(60));
    }

    /// Smoke-tests the capabilities advertised by the Anthropic provider.
    #[test]
    fn anthropic_provider_advertises_streaming_and_tool_use() {
        use crate::provider::LlmProvider;
        use crate::AnthropicProvider;
        let provider = AnthropicProvider::with_api_key("test-key");
        let caps = provider.capabilities();
        assert!(caps.streaming);
        assert!(caps.tool_use);
    }
}