// agent_code_lib/llm/client.rs
1//! HTTP streaming client for LLM APIs.
2//!
3//! Sends conversation messages to an LLM API and streams back response
4//! events via Server-Sent Events (SSE). Features:
5//!
6//! - Prompt caching with cache_control markers
7//! - Beta header negotiation (thinking, structured outputs, effort)
8//! - Retry with exponential backoff and fallback model
9//! - Tool choice constraints
10//! - Thinking/reasoning token configuration
11
12use std::time::Duration;
13
14use futures::StreamExt;
15use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
16use tokio::sync::mpsc;
17use tracing::{debug, warn};
18
19use crate::error::LlmError;
20use crate::llm::message::{Message, messages_to_api_params};
21use crate::llm::stream::{RawSseEvent, StreamEvent, StreamParser};
22use crate::tools::ToolSchema;
23
/// Client for communicating with an LLM API.
///
/// Holds the HTTP connection pool plus the credentials and default model
/// applied to every request. Construct with [`LlmClient::new`].
pub struct LlmClient {
    // Pooled HTTP client; built once in `new` with a long request timeout.
    http: reqwest::Client,
    // API root URL; any trailing '/' is stripped in `new` so endpoint
    // paths can be appended uniformly.
    base_url: String,
    // Sent as the `x-api-key` header on each request.
    api_key: String,
    // Default model name used for completions.
    model: String,
}
31
/// Configuration for thinking/reasoning behavior.
///
/// `Adaptive` is the default and is represented by *omitting* the explicit
/// thinking config from the request body (the server decides).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ThinkingMode {
    /// Let the model decide when to think.
    #[default]
    Adaptive,
    /// Always enable extended thinking with a token budget.
    Enabled { budget_tokens: u32 },
    /// Disable extended thinking.
    Disabled,
}
43
/// Controls how the model selects tools.
///
/// Serialized into the request body's `tool_choice` field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolChoice {
    /// Model decides whether and which tools to use.
    Auto,
    /// Model must use the specified tool.
    Specific { name: String },
    /// Model must not use any tools.
    None,
}
54
/// Agent effort level (influences thoroughness and token usage).
///
/// Serialized as a lowercase string into the request's `metadata.effort`
/// field when set on a [`CompletionRequest`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EffortLevel {
    Low,
    Medium,
    High,
}
62
/// A request to the LLM API.
///
/// Borrows the conversation data for the lifetime of the request; use
/// [`CompletionRequest::simple`] for the common case where only the
/// required fields matter.
pub struct CompletionRequest<'a> {
    /// Conversation messages to send, in order.
    pub messages: &'a [Message],
    /// System prompt text (gets a cache_control marker when
    /// `enable_caching` is set).
    pub system_prompt: &'a str,
    /// Schemas of the tools the model may call.
    pub tools: &'a [ToolSchema],
    /// Maximum response tokens; the client applies a default when `None`.
    pub max_tokens: Option<u32>,
    /// Tool selection constraint.
    pub tool_choice: Option<ToolChoice>,
    /// Thinking/reasoning configuration.
    pub thinking: Option<ThinkingMode>,
    /// Effort level for the response.
    pub effort: Option<EffortLevel>,
    /// JSON schema for structured output mode.
    pub output_schema: Option<serde_json::Value>,
    /// Enable prompt caching.
    pub enable_caching: bool,
    /// Fallback model if primary is overloaded.
    pub fallback_model: Option<String>,
    /// Temperature override.
    pub temperature: Option<f64>,
}
84
85impl<'a> CompletionRequest<'a> {
86    /// Create a simple request with just messages and system prompt.
87    pub fn simple(
88        messages: &'a [Message],
89        system_prompt: &'a str,
90        tools: &'a [ToolSchema],
91        max_tokens: Option<u32>,
92    ) -> Self {
93        Self {
94            messages,
95            system_prompt,
96            tools,
97            max_tokens,
98            tool_choice: None,
99            thinking: None,
100            effort: None,
101            output_schema: None,
102            enable_caching: true,
103            fallback_model: None,
104            temperature: None,
105        }
106    }
107}
108
109impl LlmClient {
110    pub fn new(base_url: &str, api_key: &str, model: &str) -> Self {
111        let http = reqwest::Client::builder()
112            .timeout(Duration::from_secs(300))
113            .build()
114            .expect("failed to build HTTP client");
115
116        Self {
117            http,
118            base_url: base_url.trim_end_matches('/').to_string(),
119            api_key: api_key.to_string(),
120            model: model.to_string(),
121        }
122    }
123
124    /// Stream a completion request, yielding `StreamEvent` values.
125    pub async fn stream_completion(
126        &self,
127        request: CompletionRequest<'_>,
128    ) -> Result<mpsc::Receiver<StreamEvent>, LlmError> {
129        let model = request
130            .fallback_model
131            .clone()
132            .unwrap_or_else(|| self.model.clone());
133
134        self.stream_with_model(&model, request).await
135    }
136
137    async fn stream_with_model(
138        &self,
139        model: &str,
140        request: CompletionRequest<'_>,
141    ) -> Result<mpsc::Receiver<StreamEvent>, LlmError> {
142        let url = format!("{}/messages", self.base_url);
143
144        // Build headers with beta features.
145        let mut headers = HeaderMap::new();
146        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
147        headers.insert(
148            "x-api-key",
149            HeaderValue::from_str(&self.api_key).map_err(|e| LlmError::AuthError(e.to_string()))?,
150        );
151        headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
152
153        // Collect beta features to enable.
154        let mut betas: Vec<&str> = Vec::new();
155
156        if request.thinking.is_some() {
157            betas.push("interleaved-thinking-2025-05-14");
158        }
159        if request.output_schema.is_some() {
160            betas.push("structured-outputs-2025-05-14");
161        }
162        if request.enable_caching {
163            betas.push("prompt-caching-2024-07-31");
164        }
165        if request.effort.is_some() {
166            betas.push("effort-control-2025-01-24");
167        }
168
169        if !betas.is_empty() {
170            headers.insert(
171                "anthropic-beta",
172                HeaderValue::from_str(&betas.join(",")).unwrap_or(HeaderValue::from_static("")),
173            );
174        }
175
176        // Build tool definitions with cache control on the last tool.
177        let tool_count = request.tools.len();
178        let tools_json: Vec<serde_json::Value> = request
179            .tools
180            .iter()
181            .enumerate()
182            .map(|(i, t)| {
183                let mut tool = serde_json::json!({
184                    "name": t.name,
185                    "description": t.description,
186                    "input_schema": t.input_schema,
187                });
188                if request.enable_caching && i == tool_count - 1 && tool_count > 0 {
189                    tool["cache_control"] = serde_json::json!({"type": "ephemeral"});
190                }
191                tool
192            })
193            .collect();
194
195        // Build system prompt with cache control.
196        let system = if request.enable_caching {
197            serde_json::json!([{
198                "type": "text",
199                "text": request.system_prompt,
200                "cache_control": { "type": "ephemeral" }
201            }])
202        } else {
203            serde_json::json!(request.system_prompt)
204        };
205
206        // Build request body.
207        let mut body = serde_json::json!({
208            "model": model,
209            "max_tokens": request.max_tokens.unwrap_or(16384),
210            "stream": true,
211            "system": system,
212            "messages": messages_to_api_params(request.messages),
213            "tools": tools_json,
214        });
215
216        // Add optional parameters.
217        if let Some(ref tc) = request.tool_choice {
218            body["tool_choice"] = match tc {
219                ToolChoice::Auto => serde_json::json!({"type": "auto"}),
220                ToolChoice::Specific { name } => {
221                    serde_json::json!({"type": "tool", "name": name})
222                }
223                ToolChoice::None => serde_json::json!({"type": "none"}),
224            };
225        }
226
227        if let Some(ref thinking) = request.thinking {
228            match thinking {
229                ThinkingMode::Enabled { budget_tokens } => {
230                    body["thinking"] = serde_json::json!({
231                        "type": "enabled",
232                        "budget_tokens": budget_tokens,
233                    });
234                }
235                ThinkingMode::Disabled => {
236                    body["thinking"] = serde_json::json!({"type": "disabled"});
237                }
238                ThinkingMode::Adaptive => {
239                    // Adaptive is the default — don't send explicit config.
240                }
241            }
242        }
243
244        if let Some(effort) = request.effort {
245            let value = match effort {
246                EffortLevel::Low => "low",
247                EffortLevel::Medium => "medium",
248                EffortLevel::High => "high",
249            };
250            body["metadata"] = serde_json::json!({
251                "effort": value,
252            });
253        }
254
255        if let Some(ref schema) = request.output_schema {
256            body["output_schema"] = schema.clone();
257        }
258
259        if let Some(temp) = request.temperature {
260            body["temperature"] = serde_json::json!(temp);
261        }
262
263        debug!("API request to {url} (model={model})");
264
265        let response = self
266            .http
267            .post(&url)
268            .headers(headers)
269            .json(&body)
270            .send()
271            .await
272            .map_err(|e| LlmError::Http(e.to_string()))?;
273
274        let status = response.status();
275        if !status.is_success() {
276            let body_text = response.text().await.unwrap_or_default();
277
278            if status.as_u16() == 429 {
279                let retry_after = parse_retry_after(&body_text);
280                return Err(LlmError::RateLimited {
281                    retry_after_ms: retry_after,
282                });
283            }
284
285            if status.as_u16() == 529 {
286                // Overloaded — treat like rate limit with longer backoff.
287                return Err(LlmError::RateLimited {
288                    retry_after_ms: 5000,
289                });
290            }
291
292            if status.as_u16() == 401 || status.as_u16() == 403 {
293                return Err(LlmError::AuthError(body_text));
294            }
295
296            return Err(LlmError::Api {
297                status: status.as_u16(),
298                body: body_text,
299            });
300        }
301
302        // Spawn SSE reader task.
303        let (tx, rx) = mpsc::channel(64);
304        tokio::spawn(async move {
305            let mut parser = StreamParser::new();
306            let mut byte_stream = response.bytes_stream();
307            let mut buffer = String::new();
308            let start = std::time::Instant::now();
309            let mut first_token = false;
310
311            while let Some(chunk_result) = byte_stream.next().await {
312                let chunk = match chunk_result {
313                    Ok(c) => c,
314                    Err(e) => {
315                        let _ = tx.send(StreamEvent::Error(e.to_string())).await;
316                        break;
317                    }
318                };
319
320                buffer.push_str(&String::from_utf8_lossy(&chunk));
321
322                while let Some(pos) = buffer.find("\n\n") {
323                    let event_text = buffer[..pos].to_string();
324                    buffer = buffer[pos + 2..].to_string();
325
326                    if let Some(data) = extract_sse_data(&event_text) {
327                        if data == "[DONE]" {
328                            return;
329                        }
330
331                        match serde_json::from_str::<RawSseEvent>(data) {
332                            Ok(raw) => {
333                                let events = parser.process(raw);
334                                for event in events {
335                                    if !first_token && matches!(event, StreamEvent::TextDelta(_)) {
336                                        first_token = true;
337                                        let ttft = start.elapsed().as_millis() as u64;
338                                        let _ = tx.send(StreamEvent::Ttft(ttft)).await;
339                                    }
340                                    if tx.send(event).await.is_err() {
341                                        return;
342                                    }
343                                }
344                            }
345                            Err(e) => {
346                                warn!("SSE parse error: {e}");
347                            }
348                        }
349                    }
350                }
351            }
352        });
353
354        Ok(rx)
355    }
356}
357
/// Extract the `data:` payload from an SSE event block.
///
/// Per the SSE convention, a single space immediately after the colon is
/// not part of the value; any further leading whitespace after a bare
/// `data:` (no space) is trimmed. Returns the first `data` line found, or
/// `None` if the block has none.
fn extract_sse_data(event_text: &str) -> Option<&str> {
    event_text.lines().find_map(|line| {
        let rest = line.strip_prefix("data:")?;
        Some(rest.strip_prefix(' ').unwrap_or_else(|| rest.trim_start()))
    })
}
370
371/// Try to parse a retry-after value from an error response.
372fn parse_retry_after(body: &str) -> u64 {
373    if let Ok(v) = serde_json::from_str::<serde_json::Value>(body)
374        && let Some(retry) = v
375            .get("error")
376            .and_then(|e| e.get("retry_after"))
377            .and_then(|r| r.as_f64())
378    {
379        return (retry * 1000.0) as u64;
380    }
381    1000
382}