Skip to main content

claude_agent/
inference.rs

1// Layer 4: THE BRAIN CORE — Dual Model Inference
2// Left hemisphere: Opus 4.6 (primary). Right hemisphere: Haiku 4.5 (auxiliary).
3// Crystalline bridge: Prompt Cache (8.3x read ratio).
4//
5// Two backends:
6//   1. CLI mode (default) — uses `claude -p` from Claude Code subscription
7//   2. API mode — direct POST to api.anthropic.com (requires API credits)
8
9use anyhow::Result;
10use reqwest::Client;
11use tokio::process::Command;
12
13use crate::types::*;
14
15pub const PRIMARY: ModelConfig = ModelConfig {
16    id: "claude-opus-4-6",
17    max_tokens: 16384,
18};
19
20pub const AUXILIARY: ModelConfig = ModelConfig {
21    id: "claude-haiku-4-5-20251001",
22    max_tokens: 8192,
23};
24
/// Messages API endpoint that the API backend POSTs to.
const API_URL: &str = "https://api.anthropic.com/v1/messages";
/// Value sent in the `anthropic-version` request header.
const API_VERSION: &str = "2023-06-01";
27
/// Transport an [`InferenceEngine`] uses to reach the model.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum InferenceBackend {
    /// Shell out to `claude -p` (Claude Max subscription).
    Cli,
    /// POST directly to the Messages API (requires API credits).
    Api,
}
33
34pub struct InferenceEngine {
35    client: Client,
36    api_key: Option<String>,
37    backend: InferenceBackend,
38}
39
40impl InferenceEngine {
41    pub fn new(api_key: Option<&str>, backend: InferenceBackend) -> Self {
42        Self {
43            client: Client::new(),
44            api_key: api_key.map(String::from),
45            backend,
46        }
47    }
48
49    /// Streaming chat — displays text tokens as they arrive, returns accumulated response
50    pub async fn chat_stream(
51        &self,
52        request: &InferenceRequest,
53        on_text: &mut dyn FnMut(&str),
54    ) -> Result<InferenceResponse> {
55        match self.backend {
56            InferenceBackend::Cli => self.chat_stream_cli(request, on_text).await,
57            InferenceBackend::Api => self.chat_stream_api(request, on_text).await,
58        }
59    }
60
61    /// API streaming — SSE from api.anthropic.com
62    async fn chat_stream_api(
63        &self,
64        request: &InferenceRequest,
65        on_text: &mut dyn FnMut(&str),
66    ) -> Result<InferenceResponse> {
67        let api_key = self.api_key.as_deref()
68            .ok_or_else(|| anyhow::anyhow!("API backend requires ANTHROPIC_API_KEY"))?;
69
70        let mut body = serde_json::to_value(request)?;
71        body.as_object_mut().unwrap().insert("stream".to_string(), serde_json::Value::Bool(true));
72
73        let response = self.client
74            .post(API_URL)
75            .header("x-api-key", api_key)
76            .header("anthropic-version", API_VERSION)
77            .header("content-type", "application/json")
78            .json(&body)
79            .send()
80            .await?;
81
82        let status = response.status();
83        if !status.is_success() {
84            let body = response.text().await.unwrap_or_default();
85            anyhow::bail!("API error {status}: {body}");
86        }
87
88        parse_sse_stream(response, on_text).await
89    }
90
91    /// CLI streaming — incremental stdout from `claude -p --output-format stream-json`
92    async fn chat_stream_cli(
93        &self,
94        request: &InferenceRequest,
95        on_text: &mut dyn FnMut(&str),
96    ) -> Result<InferenceResponse> {
97        let prompt = extract_last_user_text(&request.messages);
98        if prompt.is_empty() {
99            anyhow::bail!("No user message to send");
100        }
101
102        let mut args = vec![
103            "-p".to_string(),
104            "--output-format".to_string(),
105            "stream-json".to_string(),
106            "--verbose".to_string(),
107        ];
108
109        if let Some(ref system) = request.system {
110            args.push("--append-system-prompt".to_string());
111            args.push(system.clone());
112        }
113
114        if let Some(ref tools) = request.tools {
115            let tool_names: Vec<String> = tools.iter().map(|t| {
116                match t.name.as_str() {
117                    "bash" => "Bash".to_string(),
118                    "read" => "Read".to_string(),
119                    "edit" => "Edit".to_string(),
120                    "write" => "Write".to_string(),
121                    "glob" => "Glob".to_string(),
122                    "grep" => "Grep".to_string(),
123                    "web_fetch" => "WebFetch".to_string(),
124                    other => other.to_string(),
125                }
126            }).collect();
127            args.push("--allowedTools".to_string());
128            args.push(tool_names.join(","));
129        }
130
131        // Prompt goes after -- to separate from flags
132        args.push("--".to_string());
133        args.push(prompt);
134
135        use tokio::io::{AsyncBufReadExt, BufReader};
136
137        let mut child = Command::new("claude")
138            .args(&args)
139            .stdin(std::process::Stdio::null())
140            .stdout(std::process::Stdio::piped())
141            .stderr(std::process::Stdio::piped())
142            .env_remove("CLAUDECODE")
143            .env_remove("CLAUDE_CODE_ENTRYPOINT")
144            .env_remove("ANTHROPIC_API_KEY")
145            .spawn()?;
146
147        let stdout = child.stdout.take()
148            .ok_or_else(|| anyhow::anyhow!("Failed to capture CLI stdout"))?;
149        let reader = BufReader::new(stdout);
150        let mut lines = reader.lines();
151
152        let mut result_event: Option<serde_json::Value> = None;
153        let mut full_text = String::new();
154
155        while let Ok(Some(line)) = lines.next_line().await {
156            if line.trim().is_empty() { continue; }
157
158            let event: serde_json::Value = match serde_json::from_str(&line) {
159                Ok(v) => v,
160                Err(_) => continue,
161            };
162
163            let event_type = event.get("type").and_then(|v| v.as_str()).unwrap_or("");
164
165            match event_type {
166                "content_block_delta" => {
167                    if let Some(delta) = event.get("delta") {
168                        if let Some(text) = delta.get("text").and_then(|v| v.as_str()) {
169                            on_text(text);
170                            full_text.push_str(text);
171                        }
172                    }
173                }
174                "assistant" => {
175                    // stream-json wraps the response in {"type":"assistant","message":{...}}
176                    if let Some(message) = event.get("message") {
177                        if let Some(content) = message.get("content").and_then(|c| c.as_array()) {
178                            for block in content {
179                                if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
180                                    on_text(text);
181                                    full_text.push_str(text);
182                                }
183                            }
184                        }
185                        // Check for billing errors
186                        if let Some(err) = event.get("error").and_then(|e| e.as_str()) {
187                            if err == "billing_error" {
188                                let msg = full_text.clone();
189                                if !msg.is_empty() {
190                                    anyhow::bail!("{msg}");
191                                }
192                            }
193                        }
194                    }
195                }
196                "result" => {
197                    // Check for billing/auth errors in result
198                    if event.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false) {
199                        let msg = event.get("result").and_then(|v| v.as_str()).unwrap_or("Unknown error");
200                        if msg.contains("Credit balance") || msg.contains("billing") || msg.contains("auth") {
201                            anyhow::bail!("{msg}. Check your Claude subscription or API credits.");
202                        }
203                    }
204                    result_event = Some(event);
205                }
206                _ => {}
207            }
208        }
209
210        let _status = child.wait().await?;
211
212        // Extract usage from result event
213        let usage = result_event.as_ref().map(|r| {
214            let u = r.get("usage").unwrap_or(&serde_json::Value::Null);
215            Usage {
216                input_tokens: u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
217                output_tokens: u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
218                cache_read_input_tokens: u.get("cache_read_input_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
219                cache_creation_input_tokens: u.get("cache_creation_input_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
220            }
221        }).unwrap_or_default();
222
223        let cost = result_event.as_ref()
224            .and_then(|r| r.get("total_cost_usd"))
225            .and_then(|v| v.as_f64());
226
227        let duration_ms = result_event.as_ref()
228            .and_then(|r| r.get("duration_api_ms"))
229            .and_then(|v| v.as_u64())
230            .unwrap_or(0);
231
232        if !full_text.is_empty() {
233            Ok(InferenceResponse {
234                content: vec![ContentBlock::Text { text: full_text }],
235                stop_reason: Some("end_turn".to_string()),
236                usage,
237                model: "cli-stream".to_string(),
238                cli_meta: Some(CliMeta {
239                    cost_usd: cost.unwrap_or(0.0),
240                    duration_ms,
241                    duration_api_ms: result_event.as_ref()
242                        .and_then(|r| r.get("duration_api_ms"))
243                        .and_then(|v| v.as_u64())
244                        .unwrap_or(0),
245                    num_turns: result_event.as_ref()
246                        .and_then(|r| r.get("num_turns"))
247                        .and_then(|v| v.as_u64())
248                        .unwrap_or(1),
249                }),
250            })
251        } else if let Some(ref r) = result_event {
252            // Result exists but no text — check for error
253            let result_text = r.get("result").and_then(|v| v.as_str()).unwrap_or("");
254            if r.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false) {
255                anyhow::bail!("Claude CLI error: {result_text}");
256            }
257            Ok(InferenceResponse {
258                content: vec![ContentBlock::Text { text: result_text.to_string() }],
259                stop_reason: Some("end_turn".to_string()),
260                usage,
261                model: "cli-stream".to_string(),
262                cli_meta: None,
263            })
264        } else {
265            anyhow::bail!("CLI stream ended with no output")
266        }
267    }
268
269    pub fn build_request(
270        &self,
271        messages: &[Message],
272        system: Option<&str>,
273        tools: &[ToolDefinition],
274        model: Option<&str>,
275    ) -> InferenceRequest {
276        InferenceRequest {
277            model: model.unwrap_or(PRIMARY.id).to_string(),
278            max_tokens: PRIMARY.max_tokens,
279            messages: messages.to_vec(),
280            system: system.map(String::from),
281            tools: if tools.is_empty() {
282                None
283            } else {
284                Some(tools.to_vec())
285            },
286        }
287    }
288
289    /// Auto-route: pick Opus for complex tasks, Haiku for simple ones.
290    /// Returns the model ID to use based on the last user message.
291    pub fn auto_route(&self, messages: &[Message]) -> &'static str {
292        let last_text = extract_last_user_text(messages);
293        let complexity = estimate_complexity(&last_text);
294        if complexity >= ComplexityLevel::High {
295            PRIMARY.id
296        } else {
297            AUXILIARY.id
298        }
299    }
300}
301
302fn extract_last_user_text(messages: &[Message]) -> String {
303    for msg in messages.iter().rev() {
304        if matches!(msg.role, Role::User) {
305            match &msg.content {
306                MessageContent::Text(t) => return t.clone(),
307                MessageContent::Blocks(blocks) => {
308                    for block in blocks {
309                        if let ContentBlock::Text { text } = block {
310                            return text.clone();
311                        }
312                    }
313                }
314            }
315        }
316    }
317    String::new()
318}
319
320/// Parse an SSE stream from the Anthropic Messages API
321async fn parse_sse_stream(
322    response: reqwest::Response,
323    on_text: &mut dyn FnMut(&str),
324) -> Result<InferenceResponse> {
325    use futures_util::StreamExt;
326
327    let mut stream = response.bytes_stream();
328    let mut buffer = String::new();
329
330    let mut content_blocks: Vec<ContentBlock> = Vec::new();
331    let mut current_text = String::new();
332    let mut current_tool_id = String::new();
333    let mut current_tool_name = String::new();
334    let mut current_tool_input = String::new();
335    let mut in_tool = false;
336    let mut stop_reason: Option<String> = None;
337    let mut model = String::new();
338    let mut usage = Usage::default();
339
340    while let Some(chunk) = stream.next().await {
341        let chunk = chunk?;
342        buffer.push_str(&String::from_utf8_lossy(&chunk));
343
344        while let Some(pos) = buffer.find("\n\n") {
345            let event_block = buffer[..pos].to_string();
346            buffer = buffer[pos + 2..].to_string();
347
348            let data = event_block.lines()
349                .find(|l| l.starts_with("data: "))
350                .map(|l| &l[6..]);
351
352            let Some(data) = data else { continue };
353            if data == "[DONE]" { continue; }
354
355            let event: serde_json::Value = match serde_json::from_str(data) {
356                Ok(v) => v,
357                Err(_) => continue,
358            };
359
360            let etype = event.get("type").and_then(|v| v.as_str()).unwrap_or("");
361
362            match etype {
363                "message_start" => {
364                    if let Some(msg) = event.get("message") {
365                        model = msg.get("model").and_then(|v| v.as_str()).unwrap_or("").to_string();
366                        if let Some(u) = msg.get("usage") {
367                            usage.input_tokens = u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0);
368                            usage.cache_read_input_tokens = u.get("cache_read_input_tokens").and_then(|v| v.as_u64()).unwrap_or(0);
369                            usage.cache_creation_input_tokens = u.get("cache_creation_input_tokens").and_then(|v| v.as_u64()).unwrap_or(0);
370                        }
371                    }
372                }
373                "content_block_start" => {
374                    if let Some(cb) = event.get("content_block") {
375                        match cb.get("type").and_then(|v| v.as_str()) {
376                            Some("text") => { in_tool = false; }
377                            Some("tool_use") => {
378                                in_tool = true;
379                                current_tool_id = cb.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string();
380                                current_tool_name = cb.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string();
381                                current_tool_input.clear();
382                            }
383                            _ => {}
384                        }
385                    }
386                }
387                "content_block_delta" => {
388                    if let Some(delta) = event.get("delta") {
389                        match delta.get("type").and_then(|v| v.as_str()) {
390                            Some("text_delta") => {
391                                if let Some(text) = delta.get("text").and_then(|v| v.as_str()) {
392                                    on_text(text);
393                                    current_text.push_str(text);
394                                }
395                            }
396                            Some("input_json_delta") => {
397                                if let Some(json) = delta.get("partial_json").and_then(|v| v.as_str()) {
398                                    current_tool_input.push_str(json);
399                                }
400                            }
401                            _ => {}
402                        }
403                    }
404                }
405                "content_block_stop" => {
406                    if !in_tool && !current_text.is_empty() {
407                        content_blocks.push(ContentBlock::Text { text: current_text.clone() });
408                        current_text.clear();
409                    }
410                    if in_tool && !current_tool_name.is_empty() {
411                        let input = serde_json::from_str(&current_tool_input)
412                            .unwrap_or(serde_json::Value::Object(Default::default()));
413                        content_blocks.push(ContentBlock::ToolUse {
414                            id: current_tool_id.clone(),
415                            name: current_tool_name.clone(),
416                            input,
417                        });
418                        current_tool_name.clear();
419                        current_tool_input.clear();
420                        in_tool = false;
421                    }
422                }
423                "message_delta" => {
424                    if let Some(delta) = event.get("delta") {
425                        stop_reason = delta.get("stop_reason")
426                            .and_then(|v| v.as_str())
427                            .map(String::from);
428                    }
429                    if let Some(u) = event.get("usage") {
430                        usage.output_tokens = u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0);
431                    }
432                }
433                _ => {}
434            }
435        }
436    }
437
438    if !current_text.is_empty() {
439        content_blocks.push(ContentBlock::Text { text: current_text });
440    }
441
442    Ok(InferenceResponse {
443        content: content_blocks,
444        stop_reason,
445        usage,
446        model,
447        cli_meta: None,
448    })
449}
450
451// ─── Multi-Model Routing ────────────────────────────────────────
452
/// Coarse difficulty bucket for a user request, consumed by `auto_route`.
///
/// Variants are declared from easiest to hardest, so the derived `PartialOrd`
/// orders them `Low < Medium < High`.
#[derive(PartialEq, PartialOrd, Debug)]
enum ComplexityLevel {
    /// Simple questions, greetings, one-liners.
    Low,
    /// Moderate tasks, single-file edits.
    Medium,
    /// Multi-file changes, architecture, debugging.
    High,
}
459
460fn estimate_complexity(text: &str) -> ComplexityLevel {
461    let lower = text.to_lowercase();
462    let words: Vec<&str> = text.split_whitespace().collect();
463    let word_count = words.len();
464
465    // High complexity signals
466    let high_signals = [
467        "refactor", "architect", "design", "migrate", "implement",
468        "debug", "investigate", "analyze", "review", "security",
469        "optimize", "performance", "deploy", "infrastructure",
470        "multiple files", "multi-file", "across the codebase",
471    ];
472    let has_high = high_signals.iter().any(|s| lower.contains(s));
473
474    // Low complexity signals
475    let low_signals = [
476        "what is", "how do", "explain", "hello", "hi ", "thanks",
477        "what's", "define", "list", "show me", "tell me",
478    ];
479    let has_low = low_signals.iter().any(|s| lower.contains(s));
480
481    // Code blocks or file paths suggest medium+
482    let has_code = text.contains("```") || text.contains("fn ") || text.contains("def ");
483    let has_paths = text.contains('/') && text.contains('.');
484
485    if has_high || word_count > 50 || (has_code && has_paths) {
486        ComplexityLevel::High
487    } else if has_low && word_count < 15 && !has_code {
488        ComplexityLevel::Low
489    } else {
490        ComplexityLevel::Medium
491    }
492}