Skip to main content

agent_sdk/
client.rs

1//! Canonical API types for the Messages API.
2//!
3//! These types are the shared representation used by all providers.
4//! Each provider translates to/from these types internally.
5
6use std::time::Duration;
7
8use serde::{Deserialize, Serialize};
9
10// ---------------------------------------------------------------------------
11// API types – request
12// ---------------------------------------------------------------------------
13
14/// Parameters for extended-thinking / budget control.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ThinkingParam {
17    /// The thinking budget token (e.g. `"enabled"` or a specific budget).
18    #[serde(rename = "type")]
19    pub kind: String,
20    /// Optional budget tokens limit.
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub budget_tokens: Option<u64>,
23}
24
25/// Cache control marker for prompt caching.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct CacheControl {
28    #[serde(rename = "type")]
29    pub kind: String,
30}
31
32impl CacheControl {
33    pub fn ephemeral() -> Self {
34        Self {
35            kind: "ephemeral".to_string(),
36        }
37    }
38}
39
40/// A tool definition sent with the request.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ToolDefinition {
43    pub name: String,
44    pub description: String,
45    pub input_schema: serde_json::Value,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub cache_control: Option<CacheControl>,
48}
49
50/// A system prompt content block (supports cache_control).
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct SystemBlock {
53    #[serde(rename = "type")]
54    pub kind: String,
55    pub text: String,
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub cache_control: Option<CacheControl>,
58}
59
60/// Source data for an image content block.
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct ImageSource {
63    /// Always "base64".
64    #[serde(rename = "type")]
65    pub kind: String,
66    /// MIME type (e.g. "image/png").
67    pub media_type: String,
68    /// Base64-encoded image data.
69    pub data: String,
70}
71
72/// A single content block inside an API message.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74#[serde(tag = "type")]
75pub enum ApiContentBlock {
76    #[serde(rename = "text")]
77    Text {
78        text: String,
79        #[serde(skip_serializing_if = "Option::is_none")]
80        cache_control: Option<CacheControl>,
81    },
82
83    #[serde(rename = "image")]
84    Image { source: ImageSource },
85
86    #[serde(rename = "tool_use")]
87    ToolUse {
88        id: String,
89        name: String,
90        input: serde_json::Value,
91    },
92
93    #[serde(rename = "tool_result")]
94    ToolResult {
95        tool_use_id: String,
96        content: serde_json::Value,
97        #[serde(skip_serializing_if = "Option::is_none")]
98        is_error: Option<bool>,
99        #[serde(skip_serializing_if = "Option::is_none")]
100        cache_control: Option<CacheControl>,
101        /// The tool/function name (used by Gemini's functionResponse).
102        #[serde(skip_serializing_if = "Option::is_none")]
103        name: Option<String>,
104    },
105
106    #[serde(rename = "thinking")]
107    Thinking { thinking: String },
108}
109
110/// A message in the conversation sent to the API.
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct ApiMessage {
113    pub role: String,
114    pub content: Vec<ApiContentBlock>,
115}
116
117/// The full request body for `POST /v1/messages`.
118#[derive(Debug, Clone, Serialize)]
119pub struct CreateMessageRequest {
120    pub model: String,
121    pub max_tokens: u32,
122    pub messages: Vec<ApiMessage>,
123    #[serde(skip_serializing_if = "Option::is_none")]
124    pub system: Option<Vec<SystemBlock>>,
125    #[serde(skip_serializing_if = "Option::is_none")]
126    pub tools: Option<Vec<ToolDefinition>>,
127    pub stream: bool,
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub metadata: Option<serde_json::Value>,
130    #[serde(skip_serializing_if = "Option::is_none")]
131    pub thinking: Option<ThinkingParam>,
132}
133
134// ---------------------------------------------------------------------------
135// API types – response
136// ---------------------------------------------------------------------------
137
138/// Token usage returned by the API.
139#[derive(Debug, Clone, Default, Serialize, Deserialize)]
140pub struct ApiUsage {
141    #[serde(default)]
142    pub input_tokens: u64,
143    #[serde(default)]
144    pub output_tokens: u64,
145    #[serde(default)]
146    pub cache_creation_input_tokens: Option<u64>,
147    #[serde(default)]
148    pub cache_read_input_tokens: Option<u64>,
149}
150
151/// A full (non-streaming) response from the Messages API.
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct MessageResponse {
154    pub id: String,
155    pub role: String,
156    pub content: Vec<ApiContentBlock>,
157    pub model: String,
158    pub stop_reason: Option<String>,
159    #[serde(default)]
160    pub usage: ApiUsage,
161}
162
163/// Error payload returned by the API.
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct ApiError {
166    #[serde(rename = "type")]
167    pub kind: String,
168    pub message: String,
169}
170
171/// Error wrapper as returned in the top-level JSON.
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct ApiErrorResponse {
174    pub error: ApiError,
175}
176
177// ---------------------------------------------------------------------------
178// Streaming types
179// ---------------------------------------------------------------------------
180
181/// Delta for a text content block.
182#[derive(Debug, Clone, Serialize, Deserialize)]
183#[serde(tag = "type")]
184pub enum ContentDelta {
185    #[serde(rename = "text_delta")]
186    TextDelta { text: String },
187
188    #[serde(rename = "input_json_delta")]
189    InputJsonDelta { partial_json: String },
190
191    #[serde(rename = "thinking_delta")]
192    ThinkingDelta { thinking: String },
193}
194
195/// Delta that comes with `message_delta` events.
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct MessageDelta {
198    pub stop_reason: Option<String>,
199}
200
201/// Server-sent events emitted during streaming.
202#[derive(Debug, Clone)]
203pub enum StreamEvent {
204    MessageStart {
205        message: MessageResponse,
206    },
207    ContentBlockStart {
208        index: usize,
209        content_block: ApiContentBlock,
210    },
211    ContentBlockDelta {
212        index: usize,
213        delta: ContentDelta,
214    },
215    ContentBlockStop {
216        index: usize,
217    },
218    MessageDelta {
219        delta: MessageDelta,
220        usage: ApiUsage,
221    },
222    MessageStop,
223    Ping,
224    Error {
225        error: ApiError,
226    },
227}
228
229// ---------------------------------------------------------------------------
230// Retry configuration (shared by providers)
231// ---------------------------------------------------------------------------
232
/// Exponential back-off settings shared by the providers.
///
/// `max_retries` does not include the first attempt.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Retries allowed in addition to the initial request.
    pub max_retries: u32,
    /// Back-off used for the first retry.
    pub initial_backoff: Duration,
    /// Factor the back-off is multiplied by after each attempt.
    pub backoff_multiplier: f64,
    /// Ceiling on the back-off duration.
    pub max_backoff: Duration,
}

impl Default for RetryConfig {
    /// Five retries, starting at one second and doubling, capped at one minute.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(1_000);
        let max_backoff = Duration::from_secs(60);
        RetryConfig {
            max_retries: 5,
            backoff_multiplier: 2.0,
            initial_backoff,
            max_backoff,
        }
    }
}
256
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serialize_request_omits_none_fields() {
        let req = CreateMessageRequest {
            model: "claude-haiku-4-5".into(),
            max_tokens: 1024,
            messages: vec![ApiMessage {
                role: "user".into(),
                content: vec![ApiContentBlock::Text {
                    text: "Hello".into(),
                    cache_control: None,
                }],
            }],
            system: None,
            tools: None,
            stream: false,
            metadata: None,
            thinking: None,
        };
        let json = serde_json::to_value(&req).unwrap();
        let obj = json
            .as_object()
            .expect("request should serialize to a JSON object");
        for key in ["system", "tools", "metadata", "thinking"] {
            assert!(!obj.contains_key(key), "`{key}` should be omitted when None");
        }
    }

    #[test]
    fn tool_use_content_block_roundtrips() {
        let block = ApiContentBlock::ToolUse {
            id: "tu_123".into(),
            name: "bash".into(),
            input: serde_json::json!({"command": "ls"}),
        };
        let json = serde_json::to_string(&block).unwrap();
        let back: ApiContentBlock = serde_json::from_str(&json).unwrap();
        match back {
            ApiContentBlock::ToolUse { id, name, input } => {
                assert_eq!(id, "tu_123");
                assert_eq!(name, "bash");
                assert_eq!(input, serde_json::json!({"command": "ls"}));
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn tool_result_content_block_roundtrips() {
        let block = ApiContentBlock::ToolResult {
            tool_use_id: "tu_123".into(),
            content: serde_json::json!("output text"),
            is_error: Some(false),
            cache_control: None,
            name: None,
        };
        let json = serde_json::to_string(&block).unwrap();
        let back: ApiContentBlock = serde_json::from_str(&json).unwrap();
        match back {
            ApiContentBlock::ToolResult {
                tool_use_id,
                content,
                is_error,
                ..
            } => {
                assert_eq!(tool_use_id, "tu_123");
                assert_eq!(content, serde_json::json!("output text"));
                assert_eq!(is_error, Some(false));
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn image_content_block_roundtrips() {
        let block = ApiContentBlock::Image {
            source: ImageSource {
                kind: "base64".into(),
                media_type: "image/png".into(),
                data: "iVBORw0KGgo=".into(),
            },
        };
        let json = serde_json::to_string(&block).unwrap();
        assert!(json.contains("\"type\":\"image\""));
        assert!(json.contains("\"media_type\":\"image/png\""));
        let back: ApiContentBlock = serde_json::from_str(&json).unwrap();
        match back {
            ApiContentBlock::Image { source } => {
                assert_eq!(source.kind, "base64");
                assert_eq!(source.media_type, "image/png");
                assert_eq!(source.data, "iVBORw0KGgo=");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn image_in_user_message_serializes() {
        let msg = ApiMessage {
            role: "user".into(),
            content: vec![
                ApiContentBlock::Image {
                    source: ImageSource {
                        kind: "base64".into(),
                        media_type: "image/jpeg".into(),
                        data: "abc123".into(),
                    },
                },
                ApiContentBlock::Text {
                    text: "What is this?".into(),
                    cache_control: None,
                },
            ],
        };
        let json = serde_json::to_value(&msg).unwrap();
        let content = json["content"].as_array().unwrap();
        assert_eq!(content.len(), 2);
        assert_eq!(content[0]["type"], "image");
        assert_eq!(content[1]["type"], "text");
    }

    #[test]
    fn api_usage_fills_missing_fields_with_defaults() {
        let usage: ApiUsage = serde_json::from_str(r#"{"output_tokens": 7}"#).unwrap();
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 7);
        assert_eq!(usage.cache_creation_input_tokens, None);
        assert_eq!(usage.cache_read_input_tokens, None);
    }

    #[test]
    fn content_deltas_deserialize_by_type_tag() {
        let text: ContentDelta =
            serde_json::from_str(r#"{"type":"text_delta","text":"hi"}"#).unwrap();
        match text {
            ContentDelta::TextDelta { text } => assert_eq!(text, "hi"),
            other => panic!("wrong variant: {other:?}"),
        }

        let json: ContentDelta =
            serde_json::from_str(r#"{"type":"input_json_delta","partial_json":"{\"a\":"}"#)
                .unwrap();
        match json {
            ContentDelta::InputJsonDelta { partial_json } => assert_eq!(partial_json, "{\"a\":"),
            other => panic!("wrong variant: {other:?}"),
        }

        let thinking: ContentDelta =
            serde_json::from_str(r#"{"type":"thinking_delta","thinking":"hmm"}"#).unwrap();
        match thinking {
            ContentDelta::ThinkingDelta { thinking } => assert_eq!(thinking, "hmm"),
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn retry_config_defaults() {
        let cfg = RetryConfig::default();
        assert_eq!(cfg.max_retries, 5);
        assert_eq!(cfg.initial_backoff, std::time::Duration::from_secs(1));
        assert_eq!(cfg.backoff_multiplier, 2.0);
        assert_eq!(cfg.max_backoff, std::time::Duration::from_secs(60));
        assert!(cfg.initial_backoff <= cfg.max_backoff);
    }

    // Checks provider capabilities; previously misnamed
    // `backoff_duration_increases` even though it never touched back-off.
    #[test]
    fn anthropic_provider_supports_streaming_and_tool_use() {
        use crate::provider::LlmProvider;
        use crate::AnthropicProvider;
        let provider = AnthropicProvider::with_api_key("test-key");
        let caps = provider.capabilities();
        assert!(caps.streaming);
        assert!(caps.tool_use);
    }
}
384        let caps = provider.capabilities();
385        assert!(caps.streaming);
386        assert!(caps.tool_use);
387    }
388}