1use std::time::Duration;
13
14use futures::StreamExt;
15use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
16use tokio::sync::mpsc;
17use tracing::{debug, warn};
18
19use crate::error::LlmError;
20use crate::llm::message::{Message, messages_to_api_params};
21use crate::llm::stream::{RawSseEvent, StreamEvent, StreamParser};
22use crate::tools::ToolSchema;
23
24pub struct LlmClient {
26 http: reqwest::Client,
27 base_url: String,
28 api_key: String,
29 model: String,
30}
31
/// Extended-thinking configuration for a `CompletionRequest`.
///
/// How each mode is written into the request body:
/// * `Adaptive` (the default): no `thinking` field is sent.
/// * `Enabled { budget_tokens }`: `{"type": "enabled", "budget_tokens": …}`.
/// * `Disabled`: `{"type": "disabled"}`.
///
/// Any `Some(_)` mode on a request, including `Adaptive` and `Disabled`, also
/// adds the `interleaved-thinking-2025-05-14` beta header.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ThinkingMode {
    /// No `thinking` field is sent.
    #[default]
    Adaptive,
    /// Thinking enabled with an explicit token budget.
    Enabled {
        /// Token budget, sent as `budget_tokens`.
        budget_tokens: u32,
    },
    /// Thinking explicitly disabled.
    Disabled,
}
43
/// How the model is allowed to choose among the offered tools.
///
/// Serialized into the request's `tool_choice` field as:
/// * `Auto`: `{"type": "auto"}`
/// * `Specific { name }`: `{"type": "tool", "name": name}`
/// * `None`: `{"type": "none"}`
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolChoice {
    /// Sent as `{"type": "auto"}`.
    Auto,
    /// Sent as `{"type": "tool", "name": name}`.
    Specific {
        /// Name of the tool to request.
        name: String,
    },
    /// Sent as `{"type": "none"}`.
    None,
}
54
/// Effort hint attached to a request.
///
/// Serialized as the string `"low"`, `"medium"` or `"high"` under
/// `metadata.effort` in the request body, and it enables the
/// `effort-control-2025-01-24` beta header.
///
/// NOTE(review): confirm the target API actually reads effort from `metadata`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EffortLevel {
    /// Sent as `"low"`.
    Low,
    /// Sent as `"medium"`.
    Medium,
    /// Sent as `"high"`.
    High,
}
62
63pub struct CompletionRequest<'a> {
65 pub messages: &'a [Message],
66 pub system_prompt: &'a str,
67 pub tools: &'a [ToolSchema],
68 pub max_tokens: Option<u32>,
69 pub tool_choice: Option<ToolChoice>,
71 pub thinking: Option<ThinkingMode>,
73 pub effort: Option<EffortLevel>,
75 pub output_schema: Option<serde_json::Value>,
77 pub enable_caching: bool,
79 pub fallback_model: Option<String>,
81 pub temperature: Option<f64>,
83}
84
85impl<'a> CompletionRequest<'a> {
86 pub fn simple(
88 messages: &'a [Message],
89 system_prompt: &'a str,
90 tools: &'a [ToolSchema],
91 max_tokens: Option<u32>,
92 ) -> Self {
93 Self {
94 messages,
95 system_prompt,
96 tools,
97 max_tokens,
98 tool_choice: None,
99 thinking: None,
100 effort: None,
101 output_schema: None,
102 enable_caching: true,
103 fallback_model: None,
104 temperature: None,
105 }
106 }
107}
108
109impl LlmClient {
110 pub fn new(base_url: &str, api_key: &str, model: &str) -> Self {
111 let http = reqwest::Client::builder()
112 .timeout(Duration::from_secs(300))
113 .build()
114 .expect("failed to build HTTP client");
115
116 Self {
117 http,
118 base_url: base_url.trim_end_matches('/').to_string(),
119 api_key: api_key.to_string(),
120 model: model.to_string(),
121 }
122 }
123
124 pub async fn stream_completion(
126 &self,
127 request: CompletionRequest<'_>,
128 ) -> Result<mpsc::Receiver<StreamEvent>, LlmError> {
129 let model = request
130 .fallback_model
131 .clone()
132 .unwrap_or_else(|| self.model.clone());
133
134 self.stream_with_model(&model, request).await
135 }
136
137 async fn stream_with_model(
138 &self,
139 model: &str,
140 request: CompletionRequest<'_>,
141 ) -> Result<mpsc::Receiver<StreamEvent>, LlmError> {
142 let url = format!("{}/messages", self.base_url);
143
144 let mut headers = HeaderMap::new();
146 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
147 headers.insert(
148 "x-api-key",
149 HeaderValue::from_str(&self.api_key).map_err(|e| LlmError::AuthError(e.to_string()))?,
150 );
151 headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
152
153 let mut betas: Vec<&str> = Vec::new();
155
156 if request.thinking.is_some() {
157 betas.push("interleaved-thinking-2025-05-14");
158 }
159 if request.output_schema.is_some() {
160 betas.push("structured-outputs-2025-05-14");
161 }
162 if request.enable_caching {
163 betas.push("prompt-caching-2024-07-31");
164 }
165 if request.effort.is_some() {
166 betas.push("effort-control-2025-01-24");
167 }
168
169 if !betas.is_empty() {
170 headers.insert(
171 "anthropic-beta",
172 HeaderValue::from_str(&betas.join(",")).unwrap_or(HeaderValue::from_static("")),
173 );
174 }
175
176 let tool_count = request.tools.len();
178 let tools_json: Vec<serde_json::Value> = request
179 .tools
180 .iter()
181 .enumerate()
182 .map(|(i, t)| {
183 let mut tool = serde_json::json!({
184 "name": t.name,
185 "description": t.description,
186 "input_schema": t.input_schema,
187 });
188 if request.enable_caching && i == tool_count - 1 && tool_count > 0 {
189 tool["cache_control"] = serde_json::json!({"type": "ephemeral"});
190 }
191 tool
192 })
193 .collect();
194
195 let system = if request.enable_caching {
197 serde_json::json!([{
198 "type": "text",
199 "text": request.system_prompt,
200 "cache_control": { "type": "ephemeral" }
201 }])
202 } else {
203 serde_json::json!(request.system_prompt)
204 };
205
206 let mut body = serde_json::json!({
208 "model": model,
209 "max_tokens": request.max_tokens.unwrap_or(16384),
210 "stream": true,
211 "system": system,
212 "messages": messages_to_api_params(request.messages),
213 "tools": tools_json,
214 });
215
216 if let Some(ref tc) = request.tool_choice {
218 body["tool_choice"] = match tc {
219 ToolChoice::Auto => serde_json::json!({"type": "auto"}),
220 ToolChoice::Specific { name } => {
221 serde_json::json!({"type": "tool", "name": name})
222 }
223 ToolChoice::None => serde_json::json!({"type": "none"}),
224 };
225 }
226
227 if let Some(ref thinking) = request.thinking {
228 match thinking {
229 ThinkingMode::Enabled { budget_tokens } => {
230 body["thinking"] = serde_json::json!({
231 "type": "enabled",
232 "budget_tokens": budget_tokens,
233 });
234 }
235 ThinkingMode::Disabled => {
236 body["thinking"] = serde_json::json!({"type": "disabled"});
237 }
238 ThinkingMode::Adaptive => {
239 }
241 }
242 }
243
244 if let Some(effort) = request.effort {
245 let value = match effort {
246 EffortLevel::Low => "low",
247 EffortLevel::Medium => "medium",
248 EffortLevel::High => "high",
249 };
250 body["metadata"] = serde_json::json!({
251 "effort": value,
252 });
253 }
254
255 if let Some(ref schema) = request.output_schema {
256 body["output_schema"] = schema.clone();
257 }
258
259 if let Some(temp) = request.temperature {
260 body["temperature"] = serde_json::json!(temp);
261 }
262
263 debug!("API request to {url} (model={model})");
264
265 let response = self
266 .http
267 .post(&url)
268 .headers(headers)
269 .json(&body)
270 .send()
271 .await
272 .map_err(|e| LlmError::Http(e.to_string()))?;
273
274 let status = response.status();
275 if !status.is_success() {
276 let body_text = response.text().await.unwrap_or_default();
277
278 if status.as_u16() == 429 {
279 let retry_after = parse_retry_after(&body_text);
280 return Err(LlmError::RateLimited {
281 retry_after_ms: retry_after,
282 });
283 }
284
285 if status.as_u16() == 529 {
286 return Err(LlmError::RateLimited {
288 retry_after_ms: 5000,
289 });
290 }
291
292 if status.as_u16() == 401 || status.as_u16() == 403 {
293 return Err(LlmError::AuthError(body_text));
294 }
295
296 return Err(LlmError::Api {
297 status: status.as_u16(),
298 body: body_text,
299 });
300 }
301
302 let (tx, rx) = mpsc::channel(64);
304 tokio::spawn(async move {
305 let mut parser = StreamParser::new();
306 let mut byte_stream = response.bytes_stream();
307 let mut buffer = String::new();
308 let start = std::time::Instant::now();
309 let mut first_token = false;
310
311 while let Some(chunk_result) = byte_stream.next().await {
312 let chunk = match chunk_result {
313 Ok(c) => c,
314 Err(e) => {
315 let _ = tx.send(StreamEvent::Error(e.to_string())).await;
316 break;
317 }
318 };
319
320 buffer.push_str(&String::from_utf8_lossy(&chunk));
321
322 while let Some(pos) = buffer.find("\n\n") {
323 let event_text = buffer[..pos].to_string();
324 buffer = buffer[pos + 2..].to_string();
325
326 if let Some(data) = extract_sse_data(&event_text) {
327 if data == "[DONE]" {
328 return;
329 }
330
331 match serde_json::from_str::<RawSseEvent>(data) {
332 Ok(raw) => {
333 let events = parser.process(raw);
334 for event in events {
335 if !first_token && matches!(event, StreamEvent::TextDelta(_)) {
336 first_token = true;
337 let ttft = start.elapsed().as_millis() as u64;
338 let _ = tx.send(StreamEvent::Ttft(ttft)).await;
339 }
340 if tx.send(event).await.is_err() {
341 return;
342 }
343 }
344 }
345 Err(e) => {
346 warn!("SSE parse error: {e}");
347 }
348 }
349 }
350 }
351 }
352 });
353
354 Ok(rx)
355 }
356}
357
/// Returns the payload of the first `data:` line in one SSE event block.
///
/// If a space follows the colon, exactly that one space is dropped. Otherwise
/// any leading whitespace is trimmed. Later `data:` lines in the same event are
/// ignored. Returns `None` when the event has no `data:` line at all.
fn extract_sse_data(event_text: &str) -> Option<&str> {
    event_text.lines().find_map(|line| {
        let payload = line.strip_prefix("data:")?;
        Some(match payload.strip_prefix(' ') {
            Some(rest) => rest,
            None => payload.trim_start(),
        })
    })
}
370
371fn parse_retry_after(body: &str) -> u64 {
373 if let Ok(v) = serde_json::from_str::<serde_json::Value>(body)
374 && let Some(retry) = v
375 .get("error")
376 .and_then(|e| e.get("retry_after"))
377 .and_then(|r| r.as_f64())
378 {
379 return (retry * 1000.0) as u64;
380 }
381 1000
382}