1use anyhow::Result;
10use reqwest::Client;
11use tokio::process::Command;
12
13use crate::types::*;
14
15pub const PRIMARY: ModelConfig = ModelConfig {
16 id: "claude-opus-4-6",
17 max_tokens: 16384,
18};
19
20pub const AUXILIARY: ModelConfig = ModelConfig {
21 id: "claude-haiku-4-5-20251001",
22 max_tokens: 8192,
23};
24
/// Anthropic Messages API endpoint that the HTTP backend posts to.
const API_URL: &str = "https://api.anthropic.com/v1/messages";
/// Value sent in the `anthropic-version` request header.
const API_VERSION: &str = "2023-06-01";
27
/// Transport that `InferenceEngine` uses to reach Claude.
///
/// Fieldless, so it also derives `Eq` and `Hash` and can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InferenceBackend {
    /// Run the locally installed `claude` CLI in print mode (`claude -p`).
    Cli,
    /// Call the Anthropic Messages API over HTTPS; needs an API key.
    Api,
}
33
34pub struct InferenceEngine {
35 client: Client,
36 api_key: Option<String>,
37 backend: InferenceBackend,
38}
39
40impl InferenceEngine {
41 pub fn new(api_key: Option<&str>, backend: InferenceBackend) -> Self {
42 Self {
43 client: Client::new(),
44 api_key: api_key.map(String::from),
45 backend,
46 }
47 }
48
49 pub async fn chat_stream(
51 &self,
52 request: &InferenceRequest,
53 on_text: &mut dyn FnMut(&str),
54 ) -> Result<InferenceResponse> {
55 match self.backend {
56 InferenceBackend::Cli => self.chat_stream_cli(request, on_text).await,
57 InferenceBackend::Api => self.chat_stream_api(request, on_text).await,
58 }
59 }
60
61 async fn chat_stream_api(
63 &self,
64 request: &InferenceRequest,
65 on_text: &mut dyn FnMut(&str),
66 ) -> Result<InferenceResponse> {
67 let api_key = self.api_key.as_deref()
68 .ok_or_else(|| anyhow::anyhow!("API backend requires ANTHROPIC_API_KEY"))?;
69
70 let mut body = serde_json::to_value(request)?;
71 body.as_object_mut().unwrap().insert("stream".to_string(), serde_json::Value::Bool(true));
72
73 let response = self.client
74 .post(API_URL)
75 .header("x-api-key", api_key)
76 .header("anthropic-version", API_VERSION)
77 .header("content-type", "application/json")
78 .json(&body)
79 .send()
80 .await?;
81
82 let status = response.status();
83 if !status.is_success() {
84 let body = response.text().await.unwrap_or_default();
85 anyhow::bail!("API error {status}: {body}");
86 }
87
88 parse_sse_stream(response, on_text).await
89 }
90
91 async fn chat_stream_cli(
93 &self,
94 request: &InferenceRequest,
95 on_text: &mut dyn FnMut(&str),
96 ) -> Result<InferenceResponse> {
97 let prompt = extract_last_user_text(&request.messages);
98 if prompt.is_empty() {
99 anyhow::bail!("No user message to send");
100 }
101
102 let mut args = vec![
103 "-p".to_string(),
104 "--output-format".to_string(),
105 "stream-json".to_string(),
106 "--verbose".to_string(),
107 ];
108
109 if let Some(ref system) = request.system {
110 args.push("--append-system-prompt".to_string());
111 args.push(system.clone());
112 }
113
114 if let Some(ref tools) = request.tools {
115 let tool_names: Vec<String> = tools.iter().map(|t| {
116 match t.name.as_str() {
117 "bash" => "Bash".to_string(),
118 "read" => "Read".to_string(),
119 "edit" => "Edit".to_string(),
120 "write" => "Write".to_string(),
121 "glob" => "Glob".to_string(),
122 "grep" => "Grep".to_string(),
123 "web_fetch" => "WebFetch".to_string(),
124 other => other.to_string(),
125 }
126 }).collect();
127 args.push("--allowedTools".to_string());
128 args.push(tool_names.join(","));
129 }
130
131 args.push("--".to_string());
133 args.push(prompt);
134
135 use tokio::io::{AsyncBufReadExt, BufReader};
136
137 let mut child = Command::new("claude")
138 .args(&args)
139 .stdin(std::process::Stdio::null())
140 .stdout(std::process::Stdio::piped())
141 .stderr(std::process::Stdio::piped())
142 .env_remove("CLAUDECODE")
143 .env_remove("CLAUDE_CODE_ENTRYPOINT")
144 .env_remove("ANTHROPIC_API_KEY")
145 .spawn()?;
146
147 let stdout = child.stdout.take()
148 .ok_or_else(|| anyhow::anyhow!("Failed to capture CLI stdout"))?;
149 let reader = BufReader::new(stdout);
150 let mut lines = reader.lines();
151
152 let mut result_event: Option<serde_json::Value> = None;
153 let mut full_text = String::new();
154
155 while let Ok(Some(line)) = lines.next_line().await {
156 if line.trim().is_empty() { continue; }
157
158 let event: serde_json::Value = match serde_json::from_str(&line) {
159 Ok(v) => v,
160 Err(_) => continue,
161 };
162
163 let event_type = event.get("type").and_then(|v| v.as_str()).unwrap_or("");
164
165 match event_type {
166 "content_block_delta" => {
167 if let Some(delta) = event.get("delta") {
168 if let Some(text) = delta.get("text").and_then(|v| v.as_str()) {
169 on_text(text);
170 full_text.push_str(text);
171 }
172 }
173 }
174 "assistant" => {
175 if let Some(message) = event.get("message") {
177 if let Some(content) = message.get("content").and_then(|c| c.as_array()) {
178 for block in content {
179 if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
180 on_text(text);
181 full_text.push_str(text);
182 }
183 }
184 }
185 if let Some(err) = event.get("error").and_then(|e| e.as_str()) {
187 if err == "billing_error" {
188 let msg = full_text.clone();
189 if !msg.is_empty() {
190 anyhow::bail!("{msg}");
191 }
192 }
193 }
194 }
195 }
196 "result" => {
197 if event.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false) {
199 let msg = event.get("result").and_then(|v| v.as_str()).unwrap_or("Unknown error");
200 if msg.contains("Credit balance") || msg.contains("billing") || msg.contains("auth") {
201 anyhow::bail!("{msg}. Check your Claude subscription or API credits.");
202 }
203 }
204 result_event = Some(event);
205 }
206 _ => {}
207 }
208 }
209
210 let _status = child.wait().await?;
211
212 let usage = result_event.as_ref().map(|r| {
214 let u = r.get("usage").unwrap_or(&serde_json::Value::Null);
215 Usage {
216 input_tokens: u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
217 output_tokens: u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
218 cache_read_input_tokens: u.get("cache_read_input_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
219 cache_creation_input_tokens: u.get("cache_creation_input_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
220 }
221 }).unwrap_or_default();
222
223 let cost = result_event.as_ref()
224 .and_then(|r| r.get("total_cost_usd"))
225 .and_then(|v| v.as_f64());
226
227 let duration_ms = result_event.as_ref()
228 .and_then(|r| r.get("duration_api_ms"))
229 .and_then(|v| v.as_u64())
230 .unwrap_or(0);
231
232 if !full_text.is_empty() {
233 Ok(InferenceResponse {
234 content: vec![ContentBlock::Text { text: full_text }],
235 stop_reason: Some("end_turn".to_string()),
236 usage,
237 model: "cli-stream".to_string(),
238 cli_meta: Some(CliMeta {
239 cost_usd: cost.unwrap_or(0.0),
240 duration_ms,
241 duration_api_ms: result_event.as_ref()
242 .and_then(|r| r.get("duration_api_ms"))
243 .and_then(|v| v.as_u64())
244 .unwrap_or(0),
245 num_turns: result_event.as_ref()
246 .and_then(|r| r.get("num_turns"))
247 .and_then(|v| v.as_u64())
248 .unwrap_or(1),
249 }),
250 })
251 } else if let Some(ref r) = result_event {
252 let result_text = r.get("result").and_then(|v| v.as_str()).unwrap_or("");
254 if r.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false) {
255 anyhow::bail!("Claude CLI error: {result_text}");
256 }
257 Ok(InferenceResponse {
258 content: vec![ContentBlock::Text { text: result_text.to_string() }],
259 stop_reason: Some("end_turn".to_string()),
260 usage,
261 model: "cli-stream".to_string(),
262 cli_meta: None,
263 })
264 } else {
265 anyhow::bail!("CLI stream ended with no output")
266 }
267 }
268
269 pub fn build_request(
270 &self,
271 messages: &[Message],
272 system: Option<&str>,
273 tools: &[ToolDefinition],
274 model: Option<&str>,
275 ) -> InferenceRequest {
276 InferenceRequest {
277 model: model.unwrap_or(PRIMARY.id).to_string(),
278 max_tokens: PRIMARY.max_tokens,
279 messages: messages.to_vec(),
280 system: system.map(String::from),
281 tools: if tools.is_empty() {
282 None
283 } else {
284 Some(tools.to_vec())
285 },
286 }
287 }
288
289 pub fn auto_route(&self, messages: &[Message]) -> &'static str {
292 let last_text = extract_last_user_text(messages);
293 let complexity = estimate_complexity(&last_text);
294 if complexity >= ComplexityLevel::High {
295 PRIMARY.id
296 } else {
297 AUXILIARY.id
298 }
299 }
300}
301
302fn extract_last_user_text(messages: &[Message]) -> String {
303 for msg in messages.iter().rev() {
304 if matches!(msg.role, Role::User) {
305 match &msg.content {
306 MessageContent::Text(t) => return t.clone(),
307 MessageContent::Blocks(blocks) => {
308 for block in blocks {
309 if let ContentBlock::Text { text } = block {
310 return text.clone();
311 }
312 }
313 }
314 }
315 }
316 }
317 String::new()
318}
319
320async fn parse_sse_stream(
322 response: reqwest::Response,
323 on_text: &mut dyn FnMut(&str),
324) -> Result<InferenceResponse> {
325 use futures_util::StreamExt;
326
327 let mut stream = response.bytes_stream();
328 let mut buffer = String::new();
329
330 let mut content_blocks: Vec<ContentBlock> = Vec::new();
331 let mut current_text = String::new();
332 let mut current_tool_id = String::new();
333 let mut current_tool_name = String::new();
334 let mut current_tool_input = String::new();
335 let mut in_tool = false;
336 let mut stop_reason: Option<String> = None;
337 let mut model = String::new();
338 let mut usage = Usage::default();
339
340 while let Some(chunk) = stream.next().await {
341 let chunk = chunk?;
342 buffer.push_str(&String::from_utf8_lossy(&chunk));
343
344 while let Some(pos) = buffer.find("\n\n") {
345 let event_block = buffer[..pos].to_string();
346 buffer = buffer[pos + 2..].to_string();
347
348 let data = event_block.lines()
349 .find(|l| l.starts_with("data: "))
350 .map(|l| &l[6..]);
351
352 let Some(data) = data else { continue };
353 if data == "[DONE]" { continue; }
354
355 let event: serde_json::Value = match serde_json::from_str(data) {
356 Ok(v) => v,
357 Err(_) => continue,
358 };
359
360 let etype = event.get("type").and_then(|v| v.as_str()).unwrap_or("");
361
362 match etype {
363 "message_start" => {
364 if let Some(msg) = event.get("message") {
365 model = msg.get("model").and_then(|v| v.as_str()).unwrap_or("").to_string();
366 if let Some(u) = msg.get("usage") {
367 usage.input_tokens = u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0);
368 usage.cache_read_input_tokens = u.get("cache_read_input_tokens").and_then(|v| v.as_u64()).unwrap_or(0);
369 usage.cache_creation_input_tokens = u.get("cache_creation_input_tokens").and_then(|v| v.as_u64()).unwrap_or(0);
370 }
371 }
372 }
373 "content_block_start" => {
374 if let Some(cb) = event.get("content_block") {
375 match cb.get("type").and_then(|v| v.as_str()) {
376 Some("text") => { in_tool = false; }
377 Some("tool_use") => {
378 in_tool = true;
379 current_tool_id = cb.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string();
380 current_tool_name = cb.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string();
381 current_tool_input.clear();
382 }
383 _ => {}
384 }
385 }
386 }
387 "content_block_delta" => {
388 if let Some(delta) = event.get("delta") {
389 match delta.get("type").and_then(|v| v.as_str()) {
390 Some("text_delta") => {
391 if let Some(text) = delta.get("text").and_then(|v| v.as_str()) {
392 on_text(text);
393 current_text.push_str(text);
394 }
395 }
396 Some("input_json_delta") => {
397 if let Some(json) = delta.get("partial_json").and_then(|v| v.as_str()) {
398 current_tool_input.push_str(json);
399 }
400 }
401 _ => {}
402 }
403 }
404 }
405 "content_block_stop" => {
406 if !in_tool && !current_text.is_empty() {
407 content_blocks.push(ContentBlock::Text { text: current_text.clone() });
408 current_text.clear();
409 }
410 if in_tool && !current_tool_name.is_empty() {
411 let input = serde_json::from_str(¤t_tool_input)
412 .unwrap_or(serde_json::Value::Object(Default::default()));
413 content_blocks.push(ContentBlock::ToolUse {
414 id: current_tool_id.clone(),
415 name: current_tool_name.clone(),
416 input,
417 });
418 current_tool_name.clear();
419 current_tool_input.clear();
420 in_tool = false;
421 }
422 }
423 "message_delta" => {
424 if let Some(delta) = event.get("delta") {
425 stop_reason = delta.get("stop_reason")
426 .and_then(|v| v.as_str())
427 .map(String::from);
428 }
429 if let Some(u) = event.get("usage") {
430 usage.output_tokens = u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0);
431 }
432 }
433 _ => {}
434 }
435 }
436 }
437
438 if !current_text.is_empty() {
439 content_blocks.push(ContentBlock::Text { text: current_text });
440 }
441
442 Ok(InferenceResponse {
443 content: content_blocks,
444 stop_reason,
445 usage,
446 model,
447 cli_meta: None,
448 })
449}
450
/// Coarse difficulty rating used when routing a prompt to a model.
///
/// Variants are declared from easiest to hardest, so the derived ordering is
/// `Low < Medium < High`.
#[derive(Debug, PartialEq, PartialOrd)]
enum ComplexityLevel {
    Low,
    Medium,
    High,
}
459
460fn estimate_complexity(text: &str) -> ComplexityLevel {
461 let lower = text.to_lowercase();
462 let words: Vec<&str> = text.split_whitespace().collect();
463 let word_count = words.len();
464
465 let high_signals = [
467 "refactor", "architect", "design", "migrate", "implement",
468 "debug", "investigate", "analyze", "review", "security",
469 "optimize", "performance", "deploy", "infrastructure",
470 "multiple files", "multi-file", "across the codebase",
471 ];
472 let has_high = high_signals.iter().any(|s| lower.contains(s));
473
474 let low_signals = [
476 "what is", "how do", "explain", "hello", "hi ", "thanks",
477 "what's", "define", "list", "show me", "tell me",
478 ];
479 let has_low = low_signals.iter().any(|s| lower.contains(s));
480
481 let has_code = text.contains("```") || text.contains("fn ") || text.contains("def ");
483 let has_paths = text.contains('/') && text.contains('.');
484
485 if has_high || word_count > 50 || (has_code && has_paths) {
486 ComplexityLevel::High
487 } else if has_low && word_count < 15 && !has_code {
488 ComplexityLevel::Low
489 } else {
490 ComplexityLevel::Medium
491 }
492}