1use std::time::Duration;
13
14use futures::StreamExt;
15use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
16use tokio::sync::mpsc;
17use tracing::{debug, warn};
18
19use crate::error::LlmError;
20use crate::llm::message::{Message, messages_to_api_params};
21use crate::llm::stream::{RawSseEvent, StreamEvent, StreamParser};
22use crate::tools::ToolSchema;
23
24pub struct LlmClient {
26 http: reqwest::Client,
27 base_url: String,
28 api_key: String,
29 model: String,
30}
31
/// Extended-thinking configuration attached to a completion request.
///
/// Defaults to [`ThinkingMode::Adaptive`]. Equality/hash derives allow callers
/// to compare configured modes directly.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub enum ThinkingMode {
    /// No `thinking` field is written to the request body for this mode.
    #[default]
    Adaptive,
    /// Thinking enabled with an explicit token budget (sent as `budget_tokens`).
    Enabled { budget_tokens: u32 },
    /// Thinking explicitly disabled (sent as `{"type": "disabled"}`).
    Disabled,
}
43
/// Constraint on how the model may pick among the supplied tools.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ToolChoice {
    /// Serialized as `{"type": "auto"}`.
    Auto,
    /// Serialized as `{"type": "tool", "name": name}`.
    Specific { name: String },
    /// Serialized as `{"type": "none"}`.
    None,
}
54
/// Requested effort level for a completion.
///
/// Written to the request body as `metadata.effort` with the value `"low"`,
/// `"medium"`, or `"high"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EffortLevel {
    Low,
    Medium,
    High,
}
62
63pub struct CompletionRequest<'a> {
65 pub messages: &'a [Message],
66 pub system_prompt: &'a str,
67 pub tools: &'a [ToolSchema],
68 pub max_tokens: Option<u32>,
69 pub tool_choice: Option<ToolChoice>,
71 pub thinking: Option<ThinkingMode>,
73 pub effort: Option<EffortLevel>,
75 pub output_schema: Option<serde_json::Value>,
77 pub enable_caching: bool,
79 pub fallback_model: Option<String>,
81 pub temperature: Option<f64>,
83}
84
85impl<'a> CompletionRequest<'a> {
86 pub fn simple(
88 messages: &'a [Message],
89 system_prompt: &'a str,
90 tools: &'a [ToolSchema],
91 max_tokens: Option<u32>,
92 ) -> Self {
93 Self {
94 messages,
95 system_prompt,
96 tools,
97 max_tokens,
98 tool_choice: None,
99 thinking: None,
100 effort: None,
101 output_schema: None,
102 enable_caching: true,
103 fallback_model: None,
104 temperature: None,
105 }
106 }
107}
108
109impl LlmClient {
110 pub fn new(base_url: &str, api_key: &str, model: &str) -> Self {
111 let http = reqwest::Client::builder()
112 .timeout(Duration::from_secs(300))
113 .build()
114 .expect("failed to build HTTP client");
115
116 Self {
117 http,
118 base_url: base_url.trim_end_matches('/').to_string(),
119 api_key: api_key.to_string(),
120 model: model.to_string(),
121 }
122 }
123
124 pub async fn stream_completion(
126 &self,
127 request: CompletionRequest<'_>,
128 ) -> Result<mpsc::Receiver<StreamEvent>, LlmError> {
129 let model = request
130 .fallback_model
131 .clone()
132 .unwrap_or_else(|| self.model.clone());
133
134 self.stream_with_model(&model, request).await
135 }
136
137 async fn stream_with_model(
138 &self,
139 model: &str,
140 request: CompletionRequest<'_>,
141 ) -> Result<mpsc::Receiver<StreamEvent>, LlmError> {
142 let url = format!("{}/messages", self.base_url);
143
144 let mut headers = HeaderMap::new();
146 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
147 headers.insert(
148 "x-api-key",
149 HeaderValue::from_str(&self.api_key).map_err(|e| LlmError::AuthError(e.to_string()))?,
150 );
151 headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
152
153 let mut betas: Vec<&str> = Vec::new();
155
156 if request.thinking.is_some() {
157 betas.push("interleaved-thinking-2025-05-14");
158 }
159 if request.output_schema.is_some() {
160 betas.push("structured-outputs-2025-05-14");
161 }
162 if request.enable_caching {
163 betas.push("prompt-caching-2024-07-31");
164 }
165 if request.effort.is_some() {
166 betas.push("effort-control-2025-01-24");
167 }
168
169 if !betas.is_empty() {
170 headers.insert(
171 "anthropic-beta",
172 HeaderValue::from_str(&betas.join(",")).unwrap_or(HeaderValue::from_static("")),
173 );
174 }
175
176 let tools_json: Vec<serde_json::Value> = request
178 .tools
179 .iter()
180 .map(|t| {
181 serde_json::json!({
182 "name": t.name,
183 "description": t.description,
184 "input_schema": t.input_schema,
185 })
186 })
187 .collect();
188
189 let system = if request.enable_caching {
191 serde_json::json!([{
192 "type": "text",
193 "text": request.system_prompt,
194 "cache_control": { "type": "ephemeral" }
195 }])
196 } else {
197 serde_json::json!(request.system_prompt)
198 };
199
200 let mut body = serde_json::json!({
202 "model": model,
203 "max_tokens": request.max_tokens.unwrap_or(16384),
204 "stream": true,
205 "system": system,
206 "messages": messages_to_api_params(request.messages),
207 "tools": tools_json,
208 });
209
210 if let Some(ref tc) = request.tool_choice {
212 body["tool_choice"] = match tc {
213 ToolChoice::Auto => serde_json::json!({"type": "auto"}),
214 ToolChoice::Specific { name } => {
215 serde_json::json!({"type": "tool", "name": name})
216 }
217 ToolChoice::None => serde_json::json!({"type": "none"}),
218 };
219 }
220
221 if let Some(ref thinking) = request.thinking {
222 match thinking {
223 ThinkingMode::Enabled { budget_tokens } => {
224 body["thinking"] = serde_json::json!({
225 "type": "enabled",
226 "budget_tokens": budget_tokens,
227 });
228 }
229 ThinkingMode::Disabled => {
230 body["thinking"] = serde_json::json!({"type": "disabled"});
231 }
232 ThinkingMode::Adaptive => {
233 }
235 }
236 }
237
238 if let Some(effort) = request.effort {
239 let value = match effort {
240 EffortLevel::Low => "low",
241 EffortLevel::Medium => "medium",
242 EffortLevel::High => "high",
243 };
244 body["metadata"] = serde_json::json!({
245 "effort": value,
246 });
247 }
248
249 if let Some(ref schema) = request.output_schema {
250 body["output_schema"] = schema.clone();
251 }
252
253 if let Some(temp) = request.temperature {
254 body["temperature"] = serde_json::json!(temp);
255 }
256
257 debug!("API request to {url} (model={model})");
258
259 let response = self
260 .http
261 .post(&url)
262 .headers(headers)
263 .json(&body)
264 .send()
265 .await
266 .map_err(|e| LlmError::Http(e.to_string()))?;
267
268 let status = response.status();
269 if !status.is_success() {
270 let body_text = response.text().await.unwrap_or_default();
271
272 if status.as_u16() == 429 {
273 let retry_after = parse_retry_after(&body_text);
274 return Err(LlmError::RateLimited {
275 retry_after_ms: retry_after,
276 });
277 }
278
279 if status.as_u16() == 529 {
280 return Err(LlmError::RateLimited {
282 retry_after_ms: 5000,
283 });
284 }
285
286 if status.as_u16() == 401 || status.as_u16() == 403 {
287 return Err(LlmError::AuthError(body_text));
288 }
289
290 return Err(LlmError::Api {
291 status: status.as_u16(),
292 body: body_text,
293 });
294 }
295
296 let (tx, rx) = mpsc::channel(64);
298 tokio::spawn(async move {
299 let mut parser = StreamParser::new();
300 let mut byte_stream = response.bytes_stream();
301 let mut buffer = String::new();
302 let start = std::time::Instant::now();
303 let mut first_token = false;
304
305 while let Some(chunk_result) = byte_stream.next().await {
306 let chunk = match chunk_result {
307 Ok(c) => c,
308 Err(e) => {
309 let _ = tx.send(StreamEvent::Error(e.to_string())).await;
310 break;
311 }
312 };
313
314 buffer.push_str(&String::from_utf8_lossy(&chunk));
315
316 while let Some(pos) = buffer.find("\n\n") {
317 let event_text = buffer[..pos].to_string();
318 buffer = buffer[pos + 2..].to_string();
319
320 if let Some(data) = extract_sse_data(&event_text) {
321 if data == "[DONE]" {
322 return;
323 }
324
325 match serde_json::from_str::<RawSseEvent>(data) {
326 Ok(raw) => {
327 let events = parser.process(raw);
328 for event in events {
329 if !first_token && matches!(event, StreamEvent::TextDelta(_)) {
330 first_token = true;
331 let ttft = start.elapsed().as_millis() as u64;
332 let _ = tx.send(StreamEvent::Ttft(ttft)).await;
333 }
334 if tx.send(event).await.is_err() {
335 return;
336 }
337 }
338 }
339 Err(e) => {
340 warn!("SSE parse error: {e}");
341 }
342 }
343 }
344 }
345 }
346 });
347
348 Ok(rx)
349 }
350}
351
/// Returns the value of the first `data:` field in one SSE event block.
///
/// A single space after the colon is removed (`"data: x"` → `"x"`). When the
/// colon is not followed by a space, all leading whitespace is trimmed instead.
/// Returns `None` if no line of the event starts with `data:`.
fn extract_sse_data(event_text: &str) -> Option<&str> {
    event_text.lines().find_map(|line| {
        let rest = line.strip_prefix("data:")?;
        Some(match rest.strip_prefix(' ') {
            Some(value) => value,
            None => rest.trim_start(),
        })
    })
}
364
365fn parse_retry_after(body: &str) -> u64 {
367 if let Ok(v) = serde_json::from_str::<serde_json::Value>(body)
368 && let Some(retry) = v
369 .get("error")
370 .and_then(|e| e.get("retry_after"))
371 .and_then(|r| r.as_f64())
372 {
373 return (retry * 1000.0) as u64;
374 }
375 1000
376}