Skip to main content

battlecommand_forge/
llm.rs

1use anyhow::{Context, Result};
2use reqwest::Client;
3use serde_json::json;
4use std::time::Instant;
5use tokio::sync::mpsc;
6
/// Port the Ollama server listens on when `OLLAMA_HOST` does not name one.
const OLLAMA_DEFAULT_PORT: u16 = 11434;

/// Get the Ollama base URL from the `OLLAMA_HOST` env var, or default to
/// `http://localhost:11434`.
///
/// Accepted forms:
/// - `"host"` / `"[::1]"` (no scheme, no port) → `http://host:11434`
/// - `"host:port"` (no scheme) → `http://host:port`
/// - `"http://host:port"` / `"https://host"` → used as given, so an explicit
///   scheme without a port keeps that scheme's standard port.
///
/// Surrounding whitespace and trailing slashes are removed. An unset, blank,
/// or non-UTF-8 value falls back to the default.
pub fn ollama_url() -> String {
    let raw = std::env::var("OLLAMA_HOST").unwrap_or_default();
    normalize_ollama_host(&raw)
}

/// Turn a raw `OLLAMA_HOST` value into a base URL (rules documented on [`ollama_url`]).
fn normalize_ollama_host(raw: &str) -> String {
    let host = raw.trim().trim_end_matches('/');
    if host.is_empty() {
        return format!("http://localhost:{}", OLLAMA_DEFAULT_PORT);
    }
    if host.starts_with("http://") || host.starts_with("https://") {
        return host.to_string();
    }

    // Scheme-less: only the authority (text before any path) can carry a port.
    let (authority, path) = host.split_at(host.find('/').unwrap_or(host.len()));
    if authority_has_port(authority) {
        format!("http://{}", host)
    } else {
        // Without the explicit port, `http://host` would silently target port 80.
        format!("http://{}:{}{}", authority, OLLAMA_DEFAULT_PORT, path)
    }
}

/// True when `authority` (`host`, `host:port`, `[v6]`, or `[v6]:port`) names a port.
fn authority_has_port(authority: &str) -> bool {
    match authority.strip_prefix('[') {
        // Bracketed IPv6 literal: the port, if any, follows the closing bracket.
        Some(bracketed) => bracketed
            .split_once(']')
            .map_or(false, |(_, rest)| rest.starts_with(':')),
        None => authority.contains(':'),
    }
}
22
23/// Events emitted during streaming generation.
24#[derive(Debug)]
25pub enum StreamEvent {
26    /// A chunk of generated text.
27    Token(String),
28    /// Generation complete, full text included.
29    Done(String),
30    /// An error occurred.
31    Error(String),
32    /// CTO tool call started.
33    ToolCallStart { name: String, args: String },
34    /// CTO tool call result.
35    ToolCallResult { name: String, result: String },
36    /// CTO agent returned after async task.
37    AgentReturn(Box<crate::cto::CtoAgent>),
38}
39
40// ── Ollama /api/chat tool calling types ──
41
42/// Chat message for /api/chat (multi-turn with tool support).
43#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
44pub struct ChatMessage {
45    pub role: String,
46    pub content: String,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub tool_calls: Option<Vec<OllamaToolCall>>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub tool_call_id: Option<String>,
51}
52
53/// Ollama tool definition (JSON Schema).
54#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
55pub struct OllamaTool {
56    #[serde(rename = "type")]
57    pub tool_type: String,
58    pub function: OllamaToolFunction,
59}
60
61#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
62pub struct OllamaToolFunction {
63    pub name: String,
64    pub description: String,
65    pub parameters: serde_json::Value,
66}
67
68/// Tool call returned by the model.
69#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
70pub struct OllamaToolCall {
71    pub function: OllamaToolCallFunction,
72}
73
74#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
75pub struct OllamaToolCallFunction {
76    pub name: String,
77    pub arguments: serde_json::Value,
78}
79
80/// Result of chat_with_tools call.
81#[derive(Debug, Clone)]
82pub struct ChatToolResponse {
83    pub content: String,
84    pub tool_calls: Vec<OllamaToolCall>,
85}
86
/// Metrics recorded for a single LLM call, used in pipeline reports.
#[derive(Debug, Clone)]
pub struct LlmCallStats {
    /// Model identifier the call was made with.
    pub model: String,
    /// Wall-clock duration of the call, in seconds.
    pub duration_secs: f64,
    /// Output tokens counted (0 when the call path does not count them).
    pub token_count: u64,
    /// Output throughput derived from `token_count` (0.0 when not measured).
    pub tok_per_sec: f64,
    /// Number of lines in the generated text.
    pub output_lines: u64,
}
96
/// Backend an `LlmClient` routes requests to, chosen from the model name.
#[derive(Debug, Clone, Copy, PartialEq)]
enum CloudProvider {
    /// No cloud provider: requests go to the local Ollama server.
    None,
    /// Anthropic's Claude API.
    Claude,
    /// xAI's Grok API (OpenAI-compatible wire format).
    Grok,
}
104
105/// Unified LLM client — routes to Claude API, Grok API, or Ollama based on model name.
106/// Supports both blocking and streaming generation.
107#[derive(Clone)]
108pub struct LlmClient {
109    http: Client,
110    claude_key: Option<String>,
111    grok_key: Option<String>,
112    model: String,
113    provider: CloudProvider,
114    context_size: u32,
115    max_predict: u32,
116}
117
118impl LlmClient {
119    pub fn new(model: &str) -> Self {
120        Self::with_limits(model, 32768, 8192)
121    }
122
123    pub fn with_limits(model: &str, context_size: u32, max_predict: u32) -> Self {
124        let provider = if model.starts_with("claude-") {
125            CloudProvider::Claude
126        } else if model.starts_with("grok-") {
127            CloudProvider::Grok
128        } else {
129            CloudProvider::None
130        };
131        Self {
132            http: Client::builder()
133                .timeout(std::time::Duration::from_secs(1800))
134                .build()
135                .expect("http client"),
136            claude_key: std::env::var("ANTHROPIC_API_KEY").ok(),
137            grok_key: std::env::var("XAI_API_KEY").ok(),
138            model: model.to_string(),
139            provider,
140            context_size,
141            max_predict,
142        }
143    }
144
145    /// Route to the appropriate cloud/local provider.
146    async fn route_generate(&self, role: &str, system: &str, user_prompt: &str) -> Result<String> {
147        match self.provider {
148            CloudProvider::Claude => {
149                if let Some(ref key) = self.claude_key {
150                    self.call_claude(key, role, system, user_prompt).await
151                } else {
152                    eprintln!("   {} Claude model selected but ANTHROPIC_API_KEY not set, falling back to Ollama", role);
153                    self.call_ollama(role, system, user_prompt).await
154                }
155            }
156            CloudProvider::Grok => {
157                if let Some(ref key) = self.grok_key {
158                    self.call_grok(key, role, system, user_prompt).await
159                } else {
160                    eprintln!(
161                        "   {} Grok model selected but XAI_API_KEY not set, falling back to Ollama",
162                        role
163                    );
164                    self.call_ollama(role, system, user_prompt).await
165                }
166            }
167            CloudProvider::None => {
168                let ollama_result = self.call_ollama(role, system, user_prompt).await;
169                if ollama_result.is_err() {
170                    if let Some(ref key) = self.claude_key {
171                        eprintln!(
172                            "   {} Ollama unavailable, falling back to Claude Opus",
173                            role
174                        );
175                        self.call_claude_fallback(key, role, system, user_prompt)
176                            .await
177                    } else {
178                        ollama_result
179                    }
180                } else {
181                    ollama_result
182                }
183            }
184        }
185    }
186
187    /// Generate text (blocking — waits for full response).
188    pub async fn generate(&self, role: &str, system: &str, user_prompt: &str) -> Result<String> {
189        let start = Instant::now();
190        let result = self.route_generate(role, system, user_prompt).await;
191
192        match &result {
193            Ok(text) => {
194                let dur = start.elapsed();
195                let lines = text.lines().count();
196                println!("   {} [{} lines, {:.1}s]", role, lines, dur.as_secs_f64());
197            }
198            Err(e) => {
199                eprintln!("   {} FAILED: {}", role, e);
200            }
201        }
202        result
203    }
204
205    /// Generate text and return stats for reports.
206    pub async fn generate_with_stats(
207        &self,
208        role: &str,
209        system: &str,
210        user_prompt: &str,
211    ) -> Result<(String, LlmCallStats)> {
212        let start = Instant::now();
213        let result = self.route_generate(role, system, user_prompt).await;
214
215        match result {
216            Ok(text) => {
217                let dur = start.elapsed();
218                let lines = text.lines().count() as u64;
219                println!("   {} [{} lines, {:.1}s]", role, lines, dur.as_secs_f64());
220                let stats = LlmCallStats {
221                    model: self.model.clone(),
222                    duration_secs: dur.as_secs_f64(),
223                    token_count: 0, // non-streaming doesn't count tokens
224                    tok_per_sec: 0.0,
225                    output_lines: lines,
226                };
227                Ok((text, stats))
228            }
229            Err(e) => {
230                eprintln!("   {} FAILED: {}", role, e);
231                Err(e)
232            }
233        }
234    }
235
236    /// Generate with live streaming and return stats for reports.
237    pub async fn generate_live_with_stats(
238        &self,
239        role: &str,
240        system: &str,
241        user_prompt: &str,
242    ) -> Result<(String, LlmCallStats)> {
243        if self.provider != CloudProvider::None {
244            // Cloud models: use generate_live (which now streams), then wrap in stats
245            let start = Instant::now();
246            let text = self.generate_live(role, system, user_prompt).await?;
247            let dur = start.elapsed();
248            let lines = text.lines().count() as u64;
249            return Ok((
250                text,
251                LlmCallStats {
252                    model: self.model.clone(),
253                    duration_secs: dur.as_secs_f64(),
254                    token_count: 0,
255                    tok_per_sec: 0.0,
256                    output_lines: lines,
257                },
258            ));
259        }
260
261        let live_result = self.call_ollama_live(role, system, user_prompt).await;
262        if live_result.is_err() {
263            if let Some(ref key) = self.claude_key {
264                eprintln!(
265                    "   {} Ollama unavailable, falling back to Claude Opus",
266                    role
267                );
268                return self
269                    .generate_with_stats_claude_fallback(key, role, system, user_prompt)
270                    .await;
271            }
272        }
273        let (text, token_count, line_count, dur) = live_result?;
274        let tok_per_sec = if dur > 0.0 {
275            token_count as f64 / dur
276        } else {
277            0.0
278        };
279
280        let stats = LlmCallStats {
281            model: self.model.clone(),
282            duration_secs: dur,
283            token_count,
284            tok_per_sec,
285            output_lines: line_count,
286        };
287        Ok((text, stats))
288    }
289
290    /// Internal: Ollama live streaming, returns (text, token_count, line_count, duration_secs).
291    async fn call_ollama_live(
292        &self,
293        role: &str,
294        system: &str,
295        user_prompt: &str,
296    ) -> Result<(String, u64, u64, f64)> {
297        use futures_util::StreamExt;
298        use std::io::Write;
299
300        let start = Instant::now();
301        println!("   {} -> Ollama live ({})", role, self.model);
302        print!("   \x1b[90m");
303
304        let body = serde_json::json!({
305            "model": &self.model,
306            "system": system,
307            "prompt": user_prompt,
308            "stream": true,
309            "options": { "temperature": 0.0, "num_ctx": self.context_size, "num_predict": self.max_predict }
310        });
311
312        let resp = self
313            .http
314            .post(format!("{}/api/generate", ollama_url()))
315            .json(&body)
316            .send()
317            .await
318            .context("Ollama request failed — is `ollama serve` running?")?;
319
320        // Check HTTP status before streaming (model not found returns 404)
321        if !resp.status().is_success() {
322            let status = resp.status();
323            let body = resp.text().await.unwrap_or_default();
324            let err_msg = serde_json::from_str::<serde_json::Value>(&body)
325                .ok()
326                .and_then(|j| j["error"].as_str().map(|s| s.to_string()))
327                .unwrap_or_else(|| format!("HTTP {}", status));
328            print!("\x1b[0m");
329            anyhow::bail!("Ollama error for model '{}': {}", self.model, err_msg);
330        }
331
332        let mut full_text = String::new();
333        let mut stream = resp.bytes_stream();
334        let mut buffer = String::new();
335        let mut token_count = 0u64;
336        let mut line_count = 0u64;
337
338        while let Some(chunk) = stream.next().await {
339            let chunk = chunk.context("Stream chunk error")?;
340            buffer.push_str(&String::from_utf8_lossy(&chunk));
341
342            while let Some(nl) = buffer.find('\n') {
343                let line = buffer[..nl].to_string();
344                buffer = buffer[nl + 1..].to_string();
345
346                if line.trim().is_empty() {
347                    continue;
348                }
349
350                if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) {
351                    if let Some(token) = json["response"].as_str() {
352                        if !token.is_empty() {
353                            full_text.push_str(token);
354                            token_count += 1;
355                            print!("{}", token);
356                            let _ = std::io::stdout().flush();
357                            line_count += token.matches('\n').count() as u64;
358                        }
359                    }
360                    if json["done"].as_bool().unwrap_or(false) {
361                        break;
362                    }
363                }
364            }
365        }
366
367        let dur = start.elapsed();
368        let tok_per_sec = if dur.as_secs_f64() > 0.0 {
369            token_count as f64 / dur.as_secs_f64()
370        } else {
371            0.0
372        };
373
374        println!("\x1b[0m");
375        println!(
376            "   {} [{} lines, {} tokens, {:.1}s, {:.0} tok/s]",
377            role,
378            line_count,
379            token_count,
380            dur.as_secs_f64(),
381            tok_per_sec
382        );
383
384        Ok((full_text, token_count, line_count, dur.as_secs_f64()))
385    }
386
387    /// Generate with live token-by-token output to stdout.
388    /// Shows the model's thinking process in real-time.
389    pub async fn generate_live(
390        &self,
391        role: &str,
392        system: &str,
393        user_prompt: &str,
394    ) -> Result<String> {
395        use std::io::Write;
396
397        match self.provider {
398            CloudProvider::Claude => {
399                if let Some(ref key) = self.claude_key {
400                    return self.call_claude_live(key, role, system, user_prompt).await;
401                }
402                return self.generate(role, system, user_prompt).await;
403            }
404            CloudProvider::Grok => {
405                if let Some(ref key) = self.grok_key {
406                    return self.call_grok_live(key, role, system, user_prompt).await;
407                }
408                return self.generate(role, system, user_prompt).await;
409            }
410            CloudProvider::None => {} // fall through to Ollama live below
411        }
412
413        // Test Ollama connectivity first with a quick check
414        let ollama_check = self
415            .http
416            .get(format!("{}/api/tags", ollama_url()))
417            .send()
418            .await;
419        if ollama_check.is_err() {
420            if let Some(ref key) = self.claude_key {
421                eprintln!(
422                    "   {} Ollama unavailable, falling back to Claude Opus",
423                    role
424                );
425                return self
426                    .call_claude_fallback(key, role, system, user_prompt)
427                    .await;
428            }
429        }
430
431        let start = Instant::now();
432        println!("   {} -> Ollama live ({})", role, self.model);
433        print!("   \x1b[90m"); // dim gray for streaming output
434
435        let body = serde_json::json!({
436            "model": &self.model,
437            "system": system,
438            "prompt": user_prompt,
439            "stream": true,
440            "options": { "temperature": 0.0, "num_ctx": self.context_size, "num_predict": self.max_predict }
441        });
442
443        let resp = self
444            .http
445            .post(format!("{}/api/generate", ollama_url()))
446            .json(&body)
447            .send()
448            .await
449            .context("Ollama request failed — is `ollama serve` running?")?;
450
451        if !resp.status().is_success() {
452            let status = resp.status();
453            let body = resp.text().await.unwrap_or_default();
454            let err_msg = serde_json::from_str::<serde_json::Value>(&body)
455                .ok()
456                .and_then(|j| j["error"].as_str().map(|s| s.to_string()))
457                .unwrap_or_else(|| format!("HTTP {}", status));
458            print!("\x1b[0m");
459            anyhow::bail!("Ollama error for model '{}': {}", self.model, err_msg);
460        }
461
462        let mut full_text = String::new();
463        let mut stream = resp.bytes_stream();
464        let mut buffer = String::new();
465        let mut token_count = 0u64;
466        let mut line_count = 0u64;
467
468        use futures_util::StreamExt;
469
470        while let Some(chunk) = stream.next().await {
471            let chunk = chunk.context("Stream chunk error")?;
472            buffer.push_str(&String::from_utf8_lossy(&chunk));
473
474            while let Some(nl) = buffer.find('\n') {
475                let line = buffer[..nl].to_string();
476                buffer = buffer[nl + 1..].to_string();
477
478                if line.trim().is_empty() {
479                    continue;
480                }
481
482                if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) {
483                    if let Some(token) = json["response"].as_str() {
484                        if !token.is_empty() {
485                            full_text.push_str(token);
486                            token_count += 1;
487
488                            // Print token to stdout in real-time
489                            print!("{}", token);
490                            let _ = std::io::stdout().flush();
491
492                            // Track newlines for line count
493                            line_count += token.matches('\n').count() as u64;
494                        }
495                    }
496                    if json["done"].as_bool().unwrap_or(false) {
497                        break;
498                    }
499                }
500            }
501        }
502
503        // Reset color and print summary
504        let dur = start.elapsed();
505        let tok_per_sec = if dur.as_secs_f64() > 0.0 {
506            token_count as f64 / dur.as_secs_f64()
507        } else {
508            0.0
509        };
510
511        println!("\x1b[0m"); // reset color
512        println!(
513            "   {} [{} lines, {} tokens, {:.1}s, {:.0} tok/s]",
514            role,
515            line_count,
516            token_count,
517            dur.as_secs_f64(),
518            tok_per_sec
519        );
520
521        Ok(full_text)
522    }
523
524    /// Generate text with streaming — tokens sent via channel as they arrive.
525    /// Returns the full accumulated text when done.
526    pub async fn generate_streaming(
527        &self,
528        role: &str,
529        system: &str,
530        user_prompt: &str,
531        tx: mpsc::Sender<StreamEvent>,
532    ) -> Result<String> {
533        let start = Instant::now();
534
535        let result = if self.provider != CloudProvider::None {
536            // Cloud providers: use non-streaming for now, send as single chunk
537            let text = self.generate(role, system, user_prompt).await?;
538            let _ = tx.send(StreamEvent::Token(text.clone())).await;
539            let _ = tx.send(StreamEvent::Done(text.clone())).await;
540            Ok(text)
541        } else {
542            self.call_ollama_streaming(role, system, user_prompt, &tx)
543                .await
544        };
545
546        match &result {
547            Ok(text) => {
548                let dur = start.elapsed();
549                println!(
550                    "   {} [streamed, {} lines, {:.1}s]",
551                    role,
552                    text.lines().count(),
553                    dur.as_secs_f64()
554                );
555            }
556            Err(e) => {
557                let _ = tx.send(StreamEvent::Error(e.to_string())).await;
558                eprintln!("   {} STREAM FAILED: {}", role, e);
559            }
560        }
561        result
562    }
563
564    /// Ollama streaming: parse NDJSON lines as they arrive.
565    async fn call_ollama_streaming(
566        &self,
567        role: &str,
568        system: &str,
569        user_prompt: &str,
570        tx: &mpsc::Sender<StreamEvent>,
571    ) -> Result<String> {
572        println!("   {} -> Ollama streaming ({})", role, self.model);
573
574        let body = json!({
575            "model": &self.model,
576            "system": system,
577            "prompt": user_prompt,
578            "stream": true,
579            "options": {
580                "temperature": 0.0,
581                "num_ctx": self.context_size,
582                "num_predict": self.max_predict
583            }
584        });
585
586        let resp = self
587            .http
588            .post(format!("{}/api/generate", ollama_url()))
589            .json(&body)
590            .send()
591            .await
592            .context("Ollama streaming request failed")?;
593
594        let mut full_text = String::new();
595        let mut stream = resp.bytes_stream();
596
597        use futures_util::StreamExt;
598        let mut buffer = String::new();
599
600        while let Some(chunk) = stream.next().await {
601            let chunk = chunk.context("Stream chunk error")?;
602            buffer.push_str(&String::from_utf8_lossy(&chunk));
603
604            // Process complete NDJSON lines
605            while let Some(newline_pos) = buffer.find('\n') {
606                let line = buffer[..newline_pos].to_string();
607                buffer = buffer[newline_pos + 1..].to_string();
608
609                if line.trim().is_empty() {
610                    continue;
611                }
612
613                if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) {
614                    if let Some(token) = json["response"].as_str() {
615                        if !token.is_empty() {
616                            full_text.push_str(token);
617                            let _ = tx.send(StreamEvent::Token(token.to_string())).await;
618                        }
619                    }
620
621                    if json["done"].as_bool().unwrap_or(false) {
622                        let _ = tx.send(StreamEvent::Done(full_text.clone())).await;
623                        return Ok(full_text);
624                    }
625                }
626            }
627        }
628
629        // If we get here without a done signal, send what we have
630        let _ = tx.send(StreamEvent::Done(full_text.clone())).await;
631        Ok(full_text)
632    }
633
634    async fn call_claude(
635        &self,
636        api_key: &str,
637        role: &str,
638        system: &str,
639        user_prompt: &str,
640    ) -> Result<String> {
641        println!("   {} -> Claude ({})", role, self.model);
642
643        let body = json!({
644            "model": &self.model,
645            "max_tokens": self.max_predict,
646            "system": system,
647            "messages": [{"role": "user", "content": user_prompt}]
648        });
649
650        let resp = self
651            .http
652            .post("https://api.anthropic.com/v1/messages")
653            .header("x-api-key", api_key)
654            .header("anthropic-version", "2023-06-01")
655            .header("content-type", "application/json")
656            .json(&body)
657            .send()
658            .await
659            .context("Claude API request failed")?;
660
661        let status = resp.status();
662        let text = resp.text().await?;
663
664        if !status.is_success() {
665            anyhow::bail!(
666                "Claude API error ({}): {}",
667                status,
668                text.chars().take(200).collect::<String>()
669            );
670        }
671
672        let json: serde_json::Value = serde_json::from_str(&text)?;
673        let content = json["content"][0]["text"]
674            .as_str()
675            .unwrap_or("")
676            .to_string();
677
678        // Log cost from usage data
679        let input_tokens = json["usage"]["input_tokens"].as_u64().unwrap_or(0);
680        let output_tokens = json["usage"]["output_tokens"].as_u64().unwrap_or(0);
681        let _ =
682            crate::enterprise::log_cost("mission", &self.model, role, input_tokens, output_tokens);
683
684        Ok(content)
685    }
686
687    /// Claude streaming: SSE with content_block_delta events.
688    async fn call_claude_live(
689        &self,
690        api_key: &str,
691        role: &str,
692        system: &str,
693        user_prompt: &str,
694    ) -> Result<String> {
695        use futures_util::StreamExt;
696        use std::io::Write;
697
698        let start = Instant::now();
699        println!("   {} -> Claude live ({})", role, self.model);
700        print!("   \x1b[90m");
701
702        let body = json!({
703            "model": &self.model,
704            "max_tokens": self.max_predict,
705            "stream": true,
706            "system": system,
707            "messages": [{"role": "user", "content": user_prompt}]
708        });
709
710        let resp = self
711            .http
712            .post("https://api.anthropic.com/v1/messages")
713            .header("x-api-key", api_key)
714            .header("anthropic-version", "2023-06-01")
715            .header("content-type", "application/json")
716            .json(&body)
717            .send()
718            .await
719            .context("Claude streaming request failed")?;
720
721        if !resp.status().is_success() {
722            let status = resp.status();
723            let body = resp.text().await.unwrap_or_default();
724            print!("\x1b[0m");
725            anyhow::bail!(
726                "Claude API error ({}): {}",
727                status,
728                body.chars().take(200).collect::<String>()
729            );
730        }
731
732        let mut full_text = String::new();
733        let mut token_count = 0u64;
734        let mut line_count = 0u64;
735        let mut input_tokens = 0u64;
736        let mut output_tokens = 0u64;
737        let mut stream = resp.bytes_stream();
738        let mut buffer = String::new();
739
740        while let Some(chunk) = stream.next().await {
741            let chunk = chunk.context("Stream chunk error")?;
742            buffer.push_str(&String::from_utf8_lossy(&chunk));
743
744            while let Some(nl) = buffer.find('\n') {
745                let line = buffer[..nl].to_string();
746                buffer = buffer[nl + 1..].to_string();
747
748                let line = line.trim();
749                if !line.starts_with("data: ") {
750                    continue;
751                }
752                let data = &line[6..];
753                if data == "[DONE]" {
754                    break;
755                }
756
757                if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
758                    // Claude SSE: content_block_delta with delta.text
759                    if json["type"].as_str() == Some("content_block_delta") {
760                        if let Some(text) = json["delta"]["text"].as_str() {
761                            full_text.push_str(text);
762                            token_count += 1;
763                            line_count += text.matches('\n').count() as u64;
764                            print!("{}", text);
765                            let _ = std::io::stdout().flush();
766                        }
767                    }
768                    // Claude SSE: message_delta has usage.output_tokens
769                    if json["type"].as_str() == Some("message_delta") {
770                        output_tokens = json["usage"]["output_tokens"]
771                            .as_u64()
772                            .unwrap_or(token_count);
773                    }
774                    // Claude SSE: message_start has usage.input_tokens
775                    if json["type"].as_str() == Some("message_start") {
776                        input_tokens = json["message"]["usage"]["input_tokens"]
777                            .as_u64()
778                            .unwrap_or(0);
779                    }
780                }
781            }
782        }
783
784        let dur = start.elapsed();
785        if output_tokens == 0 {
786            output_tokens = token_count;
787        }
788        let tok_per_sec = if dur.as_secs_f64() > 0.0 {
789            output_tokens as f64 / dur.as_secs_f64()
790        } else {
791            0.0
792        };
793        println!("\x1b[0m");
794        println!(
795            "   {} [{} lines, {} tokens, {:.1}s, {:.0} tok/s]",
796            role,
797            line_count,
798            output_tokens,
799            dur.as_secs_f64(),
800            tok_per_sec
801        );
802
803        let _ =
804            crate::enterprise::log_cost("mission", &self.model, role, input_tokens, output_tokens);
805
806        Ok(full_text)
807    }
808
809    /// Grok streaming: OpenAI-compatible SSE with choices[0].delta.content.
810    async fn call_grok_live(
811        &self,
812        api_key: &str,
813        role: &str,
814        system: &str,
815        user_prompt: &str,
816    ) -> Result<String> {
817        use futures_util::StreamExt;
818        use std::io::Write;
819
820        let start = Instant::now();
821        println!("   {} -> Grok live ({})", role, self.model);
822        print!("   \x1b[90m");
823
824        let body = json!({
825            "model": &self.model,
826            "max_tokens": self.max_predict,
827            "temperature": 0.0,
828            "stream": true,
829            "messages": [
830                {"role": "system", "content": system},
831                {"role": "user", "content": user_prompt}
832            ]
833        });
834
835        let resp = self
836            .http
837            .post("https://api.x.ai/v1/chat/completions")
838            .header("Authorization", format!("Bearer {}", api_key))
839            .header("content-type", "application/json")
840            .json(&body)
841            .send()
842            .await
843            .context("Grok streaming request failed")?;
844
845        if !resp.status().is_success() {
846            let status = resp.status();
847            let body = resp.text().await.unwrap_or_default();
848            print!("\x1b[0m");
849            anyhow::bail!(
850                "Grok API error ({}): {}",
851                status,
852                body.chars().take(200).collect::<String>()
853            );
854        }
855
856        let mut full_text = String::new();
857        let mut token_count = 0u64;
858        let mut line_count = 0u64;
859        let mut stream = resp.bytes_stream();
860        let mut buffer = String::new();
861
862        while let Some(chunk) = stream.next().await {
863            let chunk = chunk.context("Stream chunk error")?;
864            buffer.push_str(&String::from_utf8_lossy(&chunk));
865
866            while let Some(nl) = buffer.find('\n') {
867                let line = buffer[..nl].to_string();
868                buffer = buffer[nl + 1..].to_string();
869
870                let line = line.trim();
871                if !line.starts_with("data: ") {
872                    continue;
873                }
874                let data = &line[6..];
875                if data == "[DONE]" {
876                    break;
877                }
878
879                if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
880                    if let Some(text) = json["choices"][0]["delta"]["content"].as_str() {
881                        full_text.push_str(text);
882                        token_count += 1;
883                        line_count += text.matches('\n').count() as u64;
884                        print!("{}", text);
885                        let _ = std::io::stdout().flush();
886                    }
887                }
888            }
889        }
890
891        let dur = start.elapsed();
892        let tok_per_sec = if dur.as_secs_f64() > 0.0 {
893            token_count as f64 / dur.as_secs_f64()
894        } else {
895            0.0
896        };
897        println!("\x1b[0m");
898        println!(
899            "   {} [{} lines, {} tokens, {:.1}s, {:.0} tok/s]",
900            role,
901            line_count,
902            token_count,
903            dur.as_secs_f64(),
904            tok_per_sec
905        );
906
907        // Estimate input tokens (~4 chars/token), output tokens from stream count
908        let est_input = (system.len() + user_prompt.len()) as u64 / 4;
909        let _ = crate::enterprise::log_cost("mission", &self.model, role, est_input, token_count);
910
911        Ok(full_text)
912    }
913
914    async fn call_grok(
915        &self,
916        api_key: &str,
917        role: &str,
918        system: &str,
919        user_prompt: &str,
920    ) -> Result<String> {
921        println!("   {} -> Grok ({})", role, self.model);
922
923        let body = json!({
924            "model": &self.model,
925            "max_tokens": self.max_predict,
926            "temperature": 0.0,
927            "messages": [
928                {"role": "system", "content": system},
929                {"role": "user", "content": user_prompt}
930            ]
931        });
932
933        let resp = self
934            .http
935            .post("https://api.x.ai/v1/chat/completions")
936            .header("Authorization", format!("Bearer {}", api_key))
937            .header("content-type", "application/json")
938            .json(&body)
939            .send()
940            .await
941            .context("Grok API request failed")?;
942
943        let status = resp.status();
944        let text = resp.text().await?;
945
946        if !status.is_success() {
947            anyhow::bail!(
948                "Grok API error ({}): {}",
949                status,
950                text.chars().take(200).collect::<String>()
951            );
952        }
953
954        let json: serde_json::Value = serde_json::from_str(&text)?;
955        let content = json["choices"][0]["message"]["content"]
956            .as_str()
957            .unwrap_or("")
958            .to_string();
959
960        // Log cost from usage data
961        let input_tokens = json["usage"]["prompt_tokens"].as_u64().unwrap_or(0);
962        let output_tokens = json["usage"]["completion_tokens"].as_u64().unwrap_or(0);
963        let _ =
964            crate::enterprise::log_cost("mission", &self.model, role, input_tokens, output_tokens);
965
966        Ok(content)
967    }
968
969    async fn call_ollama(&self, role: &str, system: &str, user_prompt: &str) -> Result<String> {
970        println!("   {} -> Ollama ({})", role, self.model);
971
972        let body = json!({
973            "model": &self.model,
974            "system": system,
975            "prompt": user_prompt,
976            "stream": false,
977            "options": {
978                "temperature": 0.0,
979                "num_ctx": self.context_size,
980                "num_predict": self.max_predict
981            }
982        });
983
984        let resp = self
985            .http
986            .post(format!("{}/api/generate", ollama_url()))
987            .json(&body)
988            .send()
989            .await
990            .context("Ollama request failed — is `ollama serve` running?")?;
991
992        let status = resp.status();
993        let text = resp.text().await?;
994        let json: serde_json::Value =
995            serde_json::from_str(&text).context("Ollama returned invalid JSON")?;
996
997        // Check for Ollama error (model not found, pull required, etc.)
998        if !status.is_success() || json.get("error").is_some() {
999            let err_msg = json["error"].as_str().unwrap_or("unknown error");
1000            anyhow::bail!("Ollama error for model '{}': {}", self.model, err_msg);
1001        }
1002
1003        let response = json["response"].as_str().unwrap_or("").to_string();
1004
1005        Ok(response)
1006    }
1007
1008    /// Fallback: call Claude Opus when Ollama is unavailable.
1009    /// Uses claude-opus-4-6 with the same system/user prompt.
1010    async fn call_claude_fallback(
1011        &self,
1012        api_key: &str,
1013        role: &str,
1014        system: &str,
1015        user_prompt: &str,
1016    ) -> Result<String> {
1017        println!("   {} -> Claude Opus (fallback from {})", role, self.model);
1018        self.call_claude(api_key, role, system, user_prompt).await
1019    }
1020
1021    /// Fallback with stats for live streaming functions.
1022    async fn generate_with_stats_claude_fallback(
1023        &self,
1024        api_key: &str,
1025        role: &str,
1026        system: &str,
1027        user_prompt: &str,
1028    ) -> Result<(String, LlmCallStats)> {
1029        let start = Instant::now();
1030        let text = self
1031            .call_claude_fallback(api_key, role, system, user_prompt)
1032            .await?;
1033        let dur = start.elapsed();
1034        let lines = text.lines().count() as u64;
1035        println!("   {} [{} lines, {:.1}s]", role, lines, dur.as_secs_f64());
1036        let stats = LlmCallStats {
1037            model: "claude-opus-4-6 (fallback)".to_string(),
1038            duration_secs: dur.as_secs_f64(),
1039            token_count: 0,
1040            tok_per_sec: 0.0,
1041            output_lines: lines,
1042        };
1043        Ok((text, stats))
1044    }
1045
1046    // ── Chat with tools (Ollama /api/chat, Claude tool_use, Grok functions) ──
1047
1048    /// Chat with native tool calling. Returns assistant content + tool calls.
1049    pub async fn chat_with_tools(
1050        &self,
1051        messages: &[ChatMessage],
1052        tools: &[OllamaTool],
1053    ) -> Result<ChatToolResponse> {
1054        match self.provider {
1055            CloudProvider::None => self.chat_tools_ollama(messages, tools).await,
1056            CloudProvider::Claude => self.chat_tools_claude(messages, tools).await,
1057            CloudProvider::Grok => self.chat_tools_grok(messages, tools).await,
1058        }
1059    }
1060
1061    async fn chat_tools_ollama(
1062        &self,
1063        messages: &[ChatMessage],
1064        tools: &[OllamaTool],
1065    ) -> Result<ChatToolResponse> {
1066        let msgs: Vec<serde_json::Value> = messages
1067            .iter()
1068            .map(|m| {
1069                let mut msg = json!({ "role": m.role, "content": m.content });
1070                if let Some(ref tc) = m.tool_calls {
1071                    msg["tool_calls"] = serde_json::to_value(tc).unwrap_or_default();
1072                }
1073                if let Some(ref id) = m.tool_call_id {
1074                    msg["tool_call_id"] = json!(id);
1075                }
1076                msg
1077            })
1078            .collect();
1079
1080        let body = json!({
1081            "model": &self.model,
1082            "messages": msgs,
1083            "tools": tools,
1084            "stream": false,
1085            "options": {
1086                "temperature": 0.0,
1087                "num_ctx": self.context_size,
1088                "num_predict": self.max_predict
1089            }
1090        });
1091
1092        let url = format!("{}/api/chat", ollama_url());
1093        let resp = self
1094            .http
1095            .post(&url)
1096            .json(&body)
1097            .send()
1098            .await
1099            .context("Ollama chat_with_tools request failed")?;
1100        let data: serde_json::Value = resp
1101            .json()
1102            .await
1103            .context("Ollama chat_with_tools parse failed")?;
1104
1105        let content = data["message"]["content"]
1106            .as_str()
1107            .unwrap_or("")
1108            .to_string();
1109        let tool_calls: Vec<OllamaToolCall> = data["message"]["tool_calls"]
1110            .as_array()
1111            .map(|arr| {
1112                arr.iter()
1113                    .filter_map(|tc| serde_json::from_value(tc.clone()).ok())
1114                    .collect()
1115            })
1116            .unwrap_or_default();
1117
1118        // Fallback: if no native tool calls, check for TOOL_CALL: text pattern
1119        if tool_calls.is_empty() {
1120            if let Some(tc) = extract_text_tool_call(&content) {
1121                return Ok(ChatToolResponse {
1122                    content: content.clone(),
1123                    tool_calls: vec![tc],
1124                });
1125            }
1126        }
1127
1128        Ok(ChatToolResponse {
1129            content,
1130            tool_calls,
1131        })
1132    }
1133
1134    async fn chat_tools_claude(
1135        &self,
1136        messages: &[ChatMessage],
1137        tools: &[OllamaTool],
1138    ) -> Result<ChatToolResponse> {
1139        let api_key = self
1140            .claude_key
1141            .as_deref()
1142            .ok_or_else(|| anyhow::anyhow!("ANTHROPIC_API_KEY required for Claude tool calling"))?;
1143
1144        // Convert messages (skip system, extract separately)
1145        let system_msg = messages
1146            .iter()
1147            .find(|m| m.role == "system")
1148            .map(|m| m.content.clone())
1149            .unwrap_or_default();
1150        let claude_msgs: Vec<serde_json::Value> = messages.iter()
1151            .filter(|m| m.role != "system")
1152            .map(|m| {
1153                let role = if m.role == "tool" { "user" } else { &m.role };
1154                let content = if m.role == "tool" {
1155                    json!([{ "type": "tool_result", "tool_use_id": m.tool_call_id.as_deref().unwrap_or(""), "content": m.content }])
1156                } else {
1157                    json!(m.content)
1158                };
1159                json!({ "role": role, "content": content })
1160            })
1161            .collect();
1162
1163        // Convert tools to Claude format
1164        let claude_tools: Vec<serde_json::Value> = tools.iter().map(|t| {
1165            json!({ "name": t.function.name, "description": t.function.description, "input_schema": t.function.parameters })
1166        }).collect();
1167
1168        let body = json!({
1169            "model": &self.model,
1170            "max_tokens": self.max_predict,
1171            "system": system_msg,
1172            "messages": claude_msgs,
1173            "tools": claude_tools,
1174        });
1175
1176        let resp = self
1177            .http
1178            .post("https://api.anthropic.com/v1/messages")
1179            .header("x-api-key", api_key)
1180            .header("anthropic-version", "2023-06-01")
1181            .header("content-type", "application/json")
1182            .json(&body)
1183            .send()
1184            .await
1185            .context("Claude chat_with_tools failed")?;
1186        let data: serde_json::Value = resp.json().await?;
1187
1188        let mut content = String::new();
1189        let mut tool_calls = Vec::new();
1190
1191        if let Some(blocks) = data["content"].as_array() {
1192            for block in blocks {
1193                match block["type"].as_str() {
1194                    Some("text") => {
1195                        content.push_str(block["text"].as_str().unwrap_or(""));
1196                    }
1197                    Some("tool_use") => {
1198                        tool_calls.push(OllamaToolCall {
1199                            function: OllamaToolCallFunction {
1200                                name: block["name"].as_str().unwrap_or("").to_string(),
1201                                arguments: block["input"].clone(),
1202                            },
1203                        });
1204                    }
1205                    _ => {}
1206                }
1207            }
1208        }
1209
1210        Ok(ChatToolResponse {
1211            content,
1212            tool_calls,
1213        })
1214    }
1215
1216    async fn chat_tools_grok(
1217        &self,
1218        messages: &[ChatMessage],
1219        tools: &[OllamaTool],
1220    ) -> Result<ChatToolResponse> {
1221        let api_key = self
1222            .grok_key
1223            .as_deref()
1224            .ok_or_else(|| anyhow::anyhow!("XAI_API_KEY required for Grok tool calling"))?;
1225
1226        let oai_msgs: Vec<serde_json::Value> = messages
1227            .iter()
1228            .map(|m| json!({ "role": m.role, "content": m.content }))
1229            .collect();
1230
1231        let oai_tools: Vec<serde_json::Value> = tools.iter().map(|t| {
1232            json!({ "type": "function", "function": { "name": t.function.name, "description": t.function.description, "parameters": t.function.parameters } })
1233        }).collect();
1234
1235        let body = json!({
1236            "model": &self.model,
1237            "messages": oai_msgs,
1238            "tools": oai_tools,
1239            "max_tokens": self.max_predict,
1240        });
1241
1242        let resp = self
1243            .http
1244            .post("https://api.x.ai/v1/chat/completions")
1245            .header("Authorization", format!("Bearer {}", api_key))
1246            .header("Content-Type", "application/json")
1247            .json(&body)
1248            .send()
1249            .await
1250            .context("Grok chat_with_tools failed")?;
1251        let data: serde_json::Value = resp.json().await?;
1252
1253        let content = data["choices"][0]["message"]["content"]
1254            .as_str()
1255            .unwrap_or("")
1256            .to_string();
1257        let tool_calls: Vec<OllamaToolCall> = data["choices"][0]["message"]["tool_calls"]
1258            .as_array()
1259            .map(|arr| {
1260                arr.iter()
1261                    .filter_map(|tc| {
1262                        let name = tc["function"]["name"].as_str()?.to_string();
1263                        let args_str = tc["function"]["arguments"].as_str().unwrap_or("{}");
1264                        let arguments = serde_json::from_str(args_str).unwrap_or(json!({}));
1265                        Some(OllamaToolCall {
1266                            function: OllamaToolCallFunction { name, arguments },
1267                        })
1268                    })
1269                    .collect()
1270            })
1271            .unwrap_or_default();
1272
1273        Ok(ChatToolResponse {
1274            content,
1275            tool_calls,
1276        })
1277    }
1278}
1279
1280/// Fallback: extract tool call from text pattern "TOOL_CALL: name args".
1281fn extract_text_tool_call(text: &str) -> Option<OllamaToolCall> {
1282    for line in text.lines() {
1283        let trimmed = line.trim();
1284        if trimmed.starts_with("TOOL_CALL:") {
1285            let rest = trimmed.strip_prefix("TOOL_CALL:")?.trim();
1286            let (name, args) = rest.split_once(' ').unwrap_or((rest, ""));
1287            return Some(OllamaToolCall {
1288                function: OllamaToolCallFunction {
1289                    name: name.to_string(),
1290                    arguments: json!({ "input": args.trim() }),
1291                },
1292            });
1293        }
1294    }
1295    None
1296}
1297
/// Extract clean code from an LLM response that may wrap it in markdown fences.
///
/// Search order:
/// 1. A fence tagged with exactly `language` (e.g. "```python"). The tag must be followed
///    by whitespace or the end of input, so asking for "c" does not match a "```cpp" fence
///    (nor "py" a "```python" fence); such longer tags are skipped and the search
///    continues. With an empty `language`, the first "```" is used, as before.
/// 2. The first untagged fence ("```" immediately followed by a newline).
/// 3. Otherwise the whole response.
///
/// The result is trimmed of surrounding whitespace in every case.
pub fn extract_code(raw: &str, language: &str) -> String {
    // Try ```language\n...\n```
    let fence = format!("```{}", language);
    let mut search_from = 0;
    while let Some(pos) = raw[search_from..].find(&fence) {
        let tag_end = search_from + pos + fence.len();
        let after_fence = &raw[tag_end..];
        // A longer tag sharing our prefix is a different language: skip past it instead
        // of returning the tail of its tag ("pp\n...") as code.
        if !language.is_empty() && after_fence.starts_with(|c: char| !c.is_whitespace()) {
            search_from = tag_end;
            continue;
        }
        let body = after_fence.strip_prefix('\n').unwrap_or(after_fence);
        if let Some(end) = body.find("```") {
            return body[..end].trim().to_string();
        }
        // Opening fence without a closing one: fall through to the generic search.
        break;
    }

    // Try generic ```\n...\n```
    if let Some(start) = raw.find("```\n") {
        let after = &raw[start + 4..];
        if let Some(end) = after.find("```") {
            return after[..end].trim().to_string();
        }
    }

    // No fences found — return as-is (trimmed)
    raw.trim().to_string()
}
1321
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_code_with_python_fence() {
        let response = "Here is the code:\n```python\ndef hello():\n    print('hi')\n```\nDone!";
        let extracted = extract_code(response, "python");
        assert_eq!(extracted, "def hello():\n    print('hi')");
    }

    #[test]
    fn test_extract_code_generic_fence() {
        let response = "```\nconst x = 1;\n```";
        let extracted = extract_code(response, "javascript");
        assert_eq!(extracted, "const x = 1;");
    }

    #[test]
    fn test_extract_code_no_fence() {
        let response = "def hello():\n    print('hi')";
        // Without any fence the input comes back trimmed but otherwise untouched.
        assert_eq!(extract_code(response, "python"), response.trim());
    }

    #[test]
    fn test_stream_event_variants() {
        // Each text-carrying variant must hand back exactly the payload it was built with.
        let token = StreamEvent::Token("hello".into());
        let done = StreamEvent::Done("full text".into());
        let err = StreamEvent::Error("oops".into());

        assert!(
            matches!(token, StreamEvent::Token(ref t) if t == "hello"),
            "unexpected variant: {:?}",
            token
        );
        assert!(
            matches!(done, StreamEvent::Done(ref t) if t == "full text"),
            "unexpected variant: {:?}",
            done
        );
        assert!(
            matches!(err, StreamEvent::Error(ref t) if t == "oops"),
            "unexpected variant: {:?}",
            err
        );
    }
}