Skip to main content

agent_sdk/
provider.rs

//! LLM provider trait ([`LlmProvider`]) and the shared types it exposes:
//! [`ProviderCapabilities`] and [`CostRates`].
2
3use std::pin::Pin;
4
5use async_trait::async_trait;
6use futures::Stream;
7
8use crate::client::{CreateMessageRequest, MessageResponse, StreamEvent};
9use crate::error::Result;
10
/// Feature flags a provider reports via [`LlmProvider::capabilities`].
///
/// Each flag is independent; `true` means the provider supports the feature.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderCapabilities {
    /// Supports streaming responses.
    pub streaming: bool,
    /// Supports tool/function calling.
    pub tool_use: bool,
    /// Supports extended thinking / chain-of-thought.
    pub thinking: bool,
    /// Supports prompt caching (`cache_control` blocks).
    pub prompt_caching: bool,
}
23
/// Per-model pricing, in USD per million tokens.
///
/// Cache pricing is not stored as absolute rates: it is expressed as
/// multipliers of `input_per_million`.
#[derive(Debug, Clone, PartialEq)]
pub struct CostRates {
    /// USD per million input (prompt) tokens. This is also the base rate that
    /// the cache multipliers scale.
    pub input_per_million: f64,
    /// USD per million output (generated) tokens.
    pub output_per_million: f64,
    /// Multiplier for cache-read tokens relative to the input rate (e.g. `0.1` = 10%).
    /// `None` means cache-read tokens are billed at the standard input rate.
    pub cache_read_multiplier: Option<f64>,
    /// Multiplier for cache-creation tokens relative to the input rate (e.g. `1.25` = 125%).
    /// `None` means cache-creation tokens are billed at the standard input rate.
    pub cache_creation_multiplier: Option<f64>,
}
36
37impl CostRates {
38    /// Compute cost for a given number of input/output tokens (ignoring cache).
39    pub fn compute(&self, input_tokens: u64, output_tokens: u64) -> f64 {
40        self.compute_with_cache(input_tokens, output_tokens, 0, 0)
41    }
42
43    /// Compute cost accounting for cached token pricing.
44    ///
45    /// `input_tokens` here is only the uncached portion (as returned by the API).
46    /// `cache_read` and `cache_creation` are billed at their respective multiplied rates.
47    pub fn compute_with_cache(
48        &self,
49        input_tokens: u64,
50        output_tokens: u64,
51        cache_read_tokens: u64,
52        cache_creation_tokens: u64,
53    ) -> f64 {
54        let read_rate = self.input_per_million * self.cache_read_multiplier.unwrap_or(1.0);
55        let create_rate = self.input_per_million * self.cache_creation_multiplier.unwrap_or(1.0);
56        (input_tokens as f64 * self.input_per_million
57            + cache_read_tokens as f64 * read_rate
58            + cache_creation_tokens as f64 * create_rate
59            + output_tokens as f64 * self.output_per_million)
60            / 1_000_000.0
61    }
62}
63
64/// Trait that all LLM providers implement.
65///
66/// Each provider translates between the canonical API types
67/// (`CreateMessageRequest`, `MessageResponse`, `StreamEvent`) and
68/// its own wire format internally.
69#[async_trait]
70pub trait LlmProvider: Send + Sync {
71    /// Human-readable provider name (e.g. "anthropic", "openai").
72    fn name(&self) -> &str;
73
74    /// Capabilities this provider supports.
75    fn capabilities(&self) -> ProviderCapabilities;
76
77    /// Cost rates for a given model.
78    fn cost_rates(&self, model: &str) -> CostRates;
79
80    /// Send a non-streaming request and return the complete response.
81    async fn create_message(&self, request: &CreateMessageRequest) -> Result<MessageResponse>;
82
83    /// Send a streaming request and return a stream of events.
84    async fn create_message_stream(
85        &self,
86        request: &CreateMessageRequest,
87    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>>;
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for cost comparisons whose expected value is rounded differently.
    const TOL: f64 = 1e-12;

    /// Rates with no cache multipliers: cache tokens fall back to the input rate.
    fn simple_rates(input: f64, output: f64) -> CostRates {
        CostRates {
            input_per_million: input,
            output_per_million: output,
            cache_read_multiplier: None,
            cache_creation_multiplier: None,
        }
    }

    /// Sonnet-style pricing: $3/M input, $15/M output, cache reads at 10% and
    /// cache writes at 125% of the input rate.
    fn sonnet_rates() -> CostRates {
        CostRates {
            input_per_million: 3.0,
            output_per_million: 15.0,
            cache_read_multiplier: Some(0.1),
            cache_creation_multiplier: Some(1.25),
        }
    }

    /// Asserts `actual` is within `tol` of `expected`; failures report the caller's line.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn cost_rates_compute() {
        let rates = simple_rates(2.0, 8.0);
        // 1M input tokens at $2/M + 500K output tokens at $8/M = $2 + $4 = $6
        assert_close(rates.compute(1_000_000, 500_000), 6.0, 1e-9);
    }

    #[test]
    fn cost_rates_compute_zero_tokens() {
        let rates = simple_rates(10.0, 40.0);
        assert_close(rates.compute(0, 0), 0.0, 1e-9);
    }

    #[test]
    fn cost_rates_compute_small_usage() {
        let rates = simple_rates(2.5, 10.0);
        // 100 input + 50 output => (100 * 2.5 + 50 * 10.0) / 1_000_000 = 750 / 1_000_000
        assert_close(rates.compute(100, 50), 750.0 / 1_000_000.0, TOL);
    }

    #[test]
    fn cost_rates_compute_ignores_cache_multipliers() {
        // `compute` is the no-cache shorthand, so multipliers must not affect it.
        let rates = sonnet_rates();
        assert_close(
            rates.compute(1234, 567),
            rates.compute_with_cache(1234, 567, 0, 0),
            TOL,
        );
        assert_close(
            rates.compute(1234, 567),
            (1234.0 * 3.0 + 567.0 * 15.0) / 1_000_000.0,
            TOL,
        );
    }

    #[test]
    fn cost_rates_with_cache() {
        let rates = sonnet_rates();
        // 1000 uncached at $3/M + 10000 cache_read at $0.30/M
        // + 2000 cache_creation at $3.75/M + 500 output at $15/M
        let cost = rates.compute_with_cache(1000, 500, 10_000, 2000);
        let expected =
            (1000.0 * 3.0 + 10_000.0 * 0.3 + 2000.0 * 3.75 + 500.0 * 15.0) / 1_000_000.0;
        assert_close(cost, expected, TOL);
    }

    #[test]
    fn cost_rates_cache_read_only() {
        // All input from cache (common on subsequent turns)
        let cost = sonnet_rates().compute_with_cache(0, 200, 13_000, 0);
        let expected = (13_000.0 * 0.3 + 200.0 * 15.0) / 1_000_000.0;
        assert_close(cost, expected, TOL);
    }

    #[test]
    fn cost_rates_cache_creation_only() {
        // First turn: system prompt written to cache, no reads yet
        let cost = sonnet_rates().compute_with_cache(500, 452, 0, 13_000);
        let expected = (500.0 * 3.0 + 13_000.0 * 3.75 + 452.0 * 15.0) / 1_000_000.0;
        assert_close(cost, expected, TOL);
    }

    #[test]
    fn cost_rates_no_cache_multipliers_bills_at_standard_rate() {
        // Rates without cache multipliers bill cache tokens at the plain input rate.
        let rates = simple_rates(2.0, 8.0);
        let cost = rates.compute_with_cache(1000, 500, 5000, 3000);
        // All input tokens (1000 + 5000 + 3000) at $2/M + 500 output at $8/M
        let expected = (9000.0 * 2.0 + 500.0 * 8.0) / 1_000_000.0;
        assert_close(cost, expected, TOL);
    }

    #[test]
    fn cost_rates_partial_cache_multipliers() {
        // Only the read multiplier is set; cache creation falls back to the input rate.
        let rates = CostRates {
            input_per_million: 2.0,
            output_per_million: 8.0,
            cache_read_multiplier: Some(0.5),
            cache_creation_multiplier: None,
        };
        let cost = rates.compute_with_cache(1000, 500, 4000, 3000);
        // 1000 at $2/M + 4000 at $1/M + 3000 at $2/M + 500 at $8/M
        let expected = (1000.0 * 2.0 + 4000.0 * 1.0 + 3000.0 * 2.0 + 500.0 * 8.0) / 1_000_000.0;
        assert_close(cost, expected, TOL);
    }

    #[test]
    fn multi_turn_cost_accumulation_with_cache() {
        // Simulates the accumulation pattern from run_agent_loop in query.rs:
        // total_cost += rates.compute_with_cache(...) per turn.
        let rates = sonnet_rates();

        let mut total_cost: f64 = 0.0;

        // Turn 1: first request, system prompt written to cache
        // API returns: input_tokens=500 (uncached), cache_creation=12000, cache_read=0
        total_cost += rates.compute_with_cache(500, 800, 0, 12_000);

        // Turn 2: tool use follow-up, system prompt now served from cache
        // API returns: input_tokens=200 (new user msg), cache_creation=0, cache_read=12000
        total_cost += rates.compute_with_cache(200, 400, 12_000, 0);

        // Turn 3: another follow-up, still reading from cache
        // API returns: input_tokens=300, cache_creation=0, cache_read=12000
        total_cost += rates.compute_with_cache(300, 600, 12_000, 0);

        let turn1 = (500.0 * 3.0 + 12_000.0 * 3.75 + 800.0 * 15.0) / 1_000_000.0;
        let turn2 = (200.0 * 3.0 + 12_000.0 * 0.3 + 400.0 * 15.0) / 1_000_000.0;
        let turn3 = (300.0 * 3.0 + 12_000.0 * 0.3 + 600.0 * 15.0) / 1_000_000.0;
        assert_close(total_cost, turn1 + turn2 + turn3, TOL);

        // Sanity: cache reads should make turns 2-3 much cheaper than turn 1
        assert!(
            turn2 < turn1,
            "turn 2 should be cheaper than turn 1 (cache reads vs creation)"
        );
        assert!(
            turn3 < turn1,
            "turn 3 should be cheaper than turn 1 (cache reads vs creation)"
        );
    }
}