// punch_runtime/driver.rs
1//! LLM driver trait and provider implementations.
2//!
3//! The [`LlmDriver`] trait abstracts over different LLM providers so the
4//! fighter loop is provider-agnostic. Concrete implementations handle the
5//! wire format differences between Anthropic, OpenAI-compatible APIs, etc.
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use hmac::{Hmac, Mac};
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use sha2::{Digest, Sha256};
14
15use punch_types::{
16    Message, ModelConfig, Provider, PunchError, PunchResult, Role, ToolCall, ToolDefinition,
17};
18
19// ---------------------------------------------------------------------------
20// Core types
21// ---------------------------------------------------------------------------
22
/// Why the model stopped generating.
///
/// Serialized in `snake_case` (`end_turn`, `tool_use`, ...) so the variants
/// round-trip in the same shape providers use on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn naturally.
    EndTurn,
    /// The model wants to invoke one or more tools.
    ToolUse,
    /// The response was truncated due to max_tokens.
    MaxTokens,
    /// An error occurred during generation (also used for unrecognized
    /// provider stop reasons).
    Error,
}
36
/// Token usage statistics for a single completion.
///
/// `Default` yields a zeroed accumulator suitable for summing usage across
/// multiple completions via [`TokenUsage::accumulate`].
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt/input side.
    pub input_tokens: u64,
    /// Tokens generated by the model.
    pub output_tokens: u64,
}
43
44impl TokenUsage {
45    /// Add another usage on top of this one (accumulator).
46    pub fn accumulate(&mut self, other: &TokenUsage) {
47        self.input_tokens += other.input_tokens;
48        self.output_tokens += other.output_tokens;
49    }
50
51    /// Total tokens consumed.
52    pub fn total(&self) -> u64 {
53        self.input_tokens + self.output_tokens
54    }
55}
56
/// A request to the LLM for a completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Model identifier (e.g. "claude-sonnet-4-20250514").
    pub model: String,
    /// Conversation messages.
    pub messages: Vec<Message>,
    /// Tools available for the model to call.
    ///
    /// `#[serde(default)]` lets serialized requests omit this field
    /// entirely; it deserializes to an empty list.
    #[serde(default)]
    pub tools: Vec<ToolDefinition>,
    /// Maximum tokens to generate.
    pub max_tokens: u32,
    /// Sampling temperature. `None` means "use the provider default"
    /// (the field is omitted from the request body).
    pub temperature: Option<f32>,
    /// System prompt (separate from messages for providers that support it).
    pub system_prompt: Option<String>,
}
74
/// The response from an LLM completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// The assistant message (may contain tool calls).
    pub message: Message,
    /// Token usage for this completion. Zero when the provider omitted
    /// usage metadata.
    pub usage: TokenUsage,
    /// Why the model stopped.
    pub stop_reason: StopReason,
}
85
86// ---------------------------------------------------------------------------
87// Think-tag stripping
88// ---------------------------------------------------------------------------
89
/// Strip reasoning/thinking tags from LLM responses.
///
/// Many reasoning models (Qwen, DeepSeek, etc.) wrap internal chain-of-thought
/// in `<think>...</think>`, `<thinking>...</thinking>`, or `<reasoning>...</reasoning>`
/// tags. This function extracts only the visible output.
///
/// If the entire response is inside think tags (no visible output), returns
/// the original content unchanged so the user still sees something.
pub fn strip_thinking_tags(content: &str) -> String {
    const TAGS: [&str; 4] = ["think", "thinking", "reasoning", "reflection"];

    let mut visible = content.to_string();

    for tag in TAGS {
        let open_tag = format!("<{tag}>");
        let close_tag = format!("</{tag}>");

        // Delete every <tag>...</tag> block; an unclosed open tag swallows
        // everything to the end of the string.
        loop {
            let Some(open_at) = visible.find(&open_tag) else {
                break;
            };
            match visible[open_at..].find(&close_tag) {
                Some(rel_close) => {
                    let block_end = open_at + rel_close + close_tag.len();
                    visible.replace_range(open_at..block_end, "");
                }
                None => {
                    visible.truncate(open_at);
                    break;
                }
            }
        }
    }

    let cleaned = visible.trim();

    // If stripping removed everything, the model spent all tokens thinking;
    // return the original content so the caller still sees something.
    if cleaned.is_empty() {
        content.to_string()
    } else {
        cleaned.to_string()
    }
}
129
130// ---------------------------------------------------------------------------
131// Trait
132// ---------------------------------------------------------------------------
133
/// Abstraction over LLM providers.
///
/// Implementors translate a [`CompletionRequest`] into a provider-specific
/// wire format and map the provider's reply back into a
/// [`CompletionResponse`], so the caller stays provider-agnostic.
#[async_trait]
pub trait LlmDriver: Send + Sync + 'static {
    /// Send a completion request and return the response.
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse>;

    /// Streaming variant. Default implementation falls back to `complete`.
    ///
    /// NOTE(review): neither driver visible in this file overrides this,
    /// so callers currently get non-streaming behavior everywhere —
    /// confirm against the rest of the file/repo.
    async fn stream_complete(
        &self,
        request: CompletionRequest,
    ) -> PunchResult<CompletionResponse> {
        self.complete(request).await
    }
}
148
149// ---------------------------------------------------------------------------
150// Anthropic driver
151// ---------------------------------------------------------------------------
152
/// Driver for the Anthropic Messages API (api.anthropic.com).
pub struct AnthropicDriver {
    /// HTTP client; may be shared across drivers for connection pooling.
    client: Client,
    /// Raw API key, sent in the `x-api-key` header.
    api_key: String,
    /// Base URL without trailing path, e.g. "https://api.anthropic.com".
    base_url: String,
}
159
160impl AnthropicDriver {
161    /// Create a new Anthropic driver.
162    ///
163    /// `api_key` is the raw key value, not the env var name.
164    pub fn new(api_key: String, base_url: Option<String>) -> Self {
165        Self {
166            client: Client::new(),
167            api_key,
168            base_url: base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
169        }
170    }
171
172    /// Create a new Anthropic driver with a shared HTTP client.
173    ///
174    /// This allows connection pooling across all drivers.
175    pub fn with_client(client: Client, api_key: String, base_url: Option<String>) -> Self {
176        Self {
177            client,
178            api_key,
179            base_url: base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
180        }
181    }
182
183    /// Build the Anthropic API request body from our internal types.
184    fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
185        let mut messages = Vec::new();
186
187        for msg in &request.messages {
188            match msg.role {
189                Role::User => {
190                    messages.push(serde_json::json!({
191                        "role": "user",
192                        "content": msg.content,
193                    }));
194                }
195                Role::Assistant => {
196                    let mut content_blocks: Vec<serde_json::Value> = Vec::new();
197
198                    if !msg.content.is_empty() {
199                        content_blocks.push(serde_json::json!({
200                            "type": "text",
201                            "text": msg.content,
202                        }));
203                    }
204
205                    for tc in &msg.tool_calls {
206                        content_blocks.push(serde_json::json!({
207                            "type": "tool_use",
208                            "id": tc.id,
209                            "name": tc.name,
210                            "input": tc.input,
211                        }));
212                    }
213
214                    if content_blocks.is_empty() {
215                        content_blocks.push(serde_json::json!({
216                            "type": "text",
217                            "text": "",
218                        }));
219                    }
220
221                    messages.push(serde_json::json!({
222                        "role": "assistant",
223                        "content": content_blocks,
224                    }));
225                }
226                Role::Tool => {
227                    let mut result_blocks: Vec<serde_json::Value> = Vec::new();
228                    for tr in &msg.tool_results {
229                        result_blocks.push(serde_json::json!({
230                            "type": "tool_result",
231                            "tool_use_id": tr.id,
232                            "content": tr.content,
233                            "is_error": tr.is_error,
234                        }));
235                    }
236                    messages.push(serde_json::json!({
237                        "role": "user",
238                        "content": result_blocks,
239                    }));
240                }
241                Role::System => {
242                    // System messages are handled via the top-level `system` param;
243                    // skip them in the messages array.
244                }
245            }
246        }
247
248        let tools: Vec<serde_json::Value> = request
249            .tools
250            .iter()
251            .map(|t| {
252                serde_json::json!({
253                    "name": t.name,
254                    "description": t.description,
255                    "input_schema": t.input_schema,
256                })
257            })
258            .collect();
259
260        let mut body = serde_json::json!({
261            "model": request.model,
262            "messages": messages,
263            "max_tokens": request.max_tokens,
264        });
265
266        if let Some(temp) = request.temperature {
267            body["temperature"] = serde_json::json!(temp);
268        }
269
270        if let Some(ref system) = request.system_prompt {
271            body["system"] = serde_json::json!(system);
272        }
273
274        if !tools.is_empty() {
275            body["tools"] = serde_json::json!(tools);
276        }
277
278        body
279    }
280
281    /// Parse the Anthropic API response into our internal types.
282    fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
283        let stop_reason = match body["stop_reason"].as_str() {
284            Some("end_turn") => StopReason::EndTurn,
285            Some("tool_use") => StopReason::ToolUse,
286            Some("max_tokens") => StopReason::MaxTokens,
287            _ => StopReason::Error,
288        };
289
290        let usage = TokenUsage {
291            input_tokens: body["usage"]["input_tokens"].as_u64().unwrap_or(0),
292            output_tokens: body["usage"]["output_tokens"].as_u64().unwrap_or(0),
293        };
294
295        let mut text_content = String::new();
296        let mut tool_calls = Vec::new();
297
298        if let Some(content_array) = body["content"].as_array() {
299            for block in content_array {
300                match block["type"].as_str() {
301                    Some("text") => {
302                        if let Some(text) = block["text"].as_str() {
303                            if !text_content.is_empty() {
304                                text_content.push('\n');
305                            }
306                            text_content.push_str(text);
307                        }
308                    }
309                    Some("tool_use") => {
310                        tool_calls.push(ToolCall {
311                            id: block["id"].as_str().unwrap_or_default().to_string(),
312                            name: block["name"].as_str().unwrap_or_default().to_string(),
313                            input: block["input"].clone(),
314                        });
315                    }
316                    _ => {}
317                }
318            }
319        }
320
321        // Strip thinking tags from reasoning models
322        let text_content = strip_thinking_tags(&text_content);
323
324        let message = Message {
325            role: Role::Assistant,
326            content: text_content,
327            tool_calls,
328            tool_results: Vec::new(),
329            timestamp: chrono::Utc::now(),
330        };
331
332        Ok(CompletionResponse {
333            message,
334            usage,
335            stop_reason,
336        })
337    }
338}
339
340#[async_trait]
341impl LlmDriver for AnthropicDriver {
342    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
343        let url = format!("{}/v1/messages", self.base_url);
344        let body = self.build_request_body(&request);
345
346        let response = self
347            .client
348            .post(&url)
349            .header("x-api-key", &self.api_key)
350            .header("anthropic-version", "2023-06-01")
351            .header("content-type", "application/json")
352            .json(&body)
353            .send()
354            .await
355            .map_err(|e| PunchError::Provider {
356                provider: "anthropic".to_string(),
357                message: format!("request failed: {e}"),
358            })?;
359
360        let status = response.status();
361
362        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
363            let retry_after = response
364                .headers()
365                .get("retry-after")
366                .and_then(|v| v.to_str().ok())
367                .and_then(|s| s.parse::<u64>().ok())
368                .unwrap_or(60)
369                * 1000;
370
371            return Err(PunchError::RateLimited {
372                provider: "anthropic".to_string(),
373                retry_after_ms: retry_after,
374            });
375        }
376
377        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
378            return Err(PunchError::Auth(
379                "anthropic API key is invalid or lacks permissions".to_string(),
380            ));
381        }
382
383        let response_body: serde_json::Value =
384            response.json().await.map_err(|e| PunchError::Provider {
385                provider: "anthropic".to_string(),
386                message: format!("failed to parse response: {e}"),
387            })?;
388
389        if !status.is_success() {
390            let error_msg = response_body["error"]["message"]
391                .as_str()
392                .unwrap_or("unknown error");
393            return Err(PunchError::Provider {
394                provider: "anthropic".to_string(),
395                message: format!("API error ({}): {}", status, error_msg),
396            });
397        }
398
399        self.parse_response(&response_body)
400    }
401}
402
403// ---------------------------------------------------------------------------
404// OpenAI-compatible driver
405// ---------------------------------------------------------------------------
406
/// Driver for OpenAI-compatible chat completions APIs.
///
/// Works with OpenAI, Groq, DeepSeek, Together, Fireworks,
/// Cerebras, xAI, Mistral, and any other provider exposing the
/// `/v1/chat/completions` endpoint.
pub struct OpenAiCompatibleDriver {
    /// HTTP client; may be shared across drivers for connection pooling.
    client: Client,
    /// Bearer token sent in the `authorization` header.
    api_key: String,
    /// Base URL; a trailing slash is tolerated (trimmed when building URLs).
    base_url: String,
    /// Human-readable provider label used in error messages.
    provider_name: String,
}
418
419impl OpenAiCompatibleDriver {
420    /// Create a new OpenAI-compatible driver.
421    pub fn new(api_key: String, base_url: String, provider_name: String) -> Self {
422        Self {
423            client: Client::new(),
424            api_key,
425            base_url,
426            provider_name,
427        }
428    }
429
430    /// Create a new OpenAI-compatible driver with a shared HTTP client.
431    pub fn with_client(
432        client: Client,
433        api_key: String,
434        base_url: String,
435        provider_name: String,
436    ) -> Self {
437        Self {
438            client,
439            api_key,
440            base_url,
441            provider_name,
442        }
443    }
444
445    /// Build the OpenAI chat completions request body.
446    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
447        let mut messages = Vec::new();
448
449        // System prompt as a system message.
450        if let Some(ref system) = request.system_prompt {
451            messages.push(serde_json::json!({
452                "role": "system",
453                "content": system,
454            }));
455        }
456
457        for msg in &request.messages {
458            match msg.role {
459                Role::System => {
460                    messages.push(serde_json::json!({
461                        "role": "system",
462                        "content": msg.content,
463                    }));
464                }
465                Role::User => {
466                    messages.push(serde_json::json!({
467                        "role": "user",
468                        "content": msg.content,
469                    }));
470                }
471                Role::Assistant => {
472                    let mut m = serde_json::json!({
473                        "role": "assistant",
474                    });
475
476                    if !msg.content.is_empty() {
477                        m["content"] = serde_json::json!(msg.content);
478                    }
479
480                    if !msg.tool_calls.is_empty() {
481                        let tc: Vec<serde_json::Value> = msg
482                            .tool_calls
483                            .iter()
484                            .map(|tc| {
485                                serde_json::json!({
486                                    "id": tc.id,
487                                    "type": "function",
488                                    "function": {
489                                        "name": tc.name,
490                                        "arguments": tc.input.to_string(),
491                                    },
492                                })
493                            })
494                            .collect();
495                        m["tool_calls"] = serde_json::json!(tc);
496                    }
497
498                    messages.push(m);
499                }
500                Role::Tool => {
501                    for tr in &msg.tool_results {
502                        messages.push(serde_json::json!({
503                            "role": "tool",
504                            "tool_call_id": tr.id,
505                            "content": tr.content,
506                        }));
507                    }
508                }
509            }
510        }
511
512        let tools: Vec<serde_json::Value> = request
513            .tools
514            .iter()
515            .map(|t| {
516                serde_json::json!({
517                    "type": "function",
518                    "function": {
519                        "name": t.name,
520                        "description": t.description,
521                        "parameters": t.input_schema,
522                    },
523                })
524            })
525            .collect();
526
527        let mut body = serde_json::json!({
528            "model": request.model,
529            "messages": messages,
530            "max_tokens": request.max_tokens,
531        });
532
533        if let Some(temp) = request.temperature {
534            body["temperature"] = serde_json::json!(temp);
535        }
536
537        if !tools.is_empty() {
538            body["tools"] = serde_json::json!(tools);
539        }
540
541        body
542    }
543
544    /// Parse the OpenAI chat completions response.
545    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
546        let choice = body["choices"]
547            .get(0)
548            .ok_or_else(|| PunchError::Provider {
549                provider: self.provider_name.clone(),
550                message: "no choices in response".to_string(),
551            })?;
552
553        let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
554        let stop_reason = match finish_reason {
555            "stop" => StopReason::EndTurn,
556            "tool_calls" => StopReason::ToolUse,
557            "length" => StopReason::MaxTokens,
558            _ => StopReason::EndTurn,
559        };
560
561        let msg = &choice["message"];
562        let raw_content = msg["content"].as_str().unwrap_or("");
563        // Strip thinking tags from reasoning models (Qwen, DeepSeek R1, etc.)
564        let content = strip_thinking_tags(raw_content);
565
566        let mut tool_calls = Vec::new();
567        if let Some(tc_array) = msg["tool_calls"].as_array() {
568            for tc in tc_array {
569                let id = tc["id"].as_str().unwrap_or_default().to_string();
570                let name = tc["function"]["name"]
571                    .as_str()
572                    .unwrap_or_default()
573                    .to_string();
574                let args_str = tc["function"]["arguments"].as_str().unwrap_or("{}");
575                let input: serde_json::Value =
576                    serde_json::from_str(args_str).unwrap_or(serde_json::json!({}));
577
578                tool_calls.push(ToolCall { id, name, input });
579            }
580        }
581
582        let usage = TokenUsage {
583            input_tokens: body["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
584            output_tokens: body["usage"]["completion_tokens"].as_u64().unwrap_or(0),
585        };
586
587        // If there are tool calls but finish_reason was not "tool_calls", fix it up.
588        let stop_reason = if !tool_calls.is_empty() && stop_reason != StopReason::ToolUse {
589            StopReason::ToolUse
590        } else {
591            stop_reason
592        };
593
594        let message = Message {
595            role: Role::Assistant,
596            content,
597            tool_calls,
598            tool_results: Vec::new(),
599            timestamp: chrono::Utc::now(),
600        };
601
602        Ok(CompletionResponse {
603            message,
604            usage,
605            stop_reason,
606        })
607    }
608}
609
610#[async_trait]
611impl LlmDriver for OpenAiCompatibleDriver {
612    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
613        let url = format!(
614            "{}/v1/chat/completions",
615            self.base_url.trim_end_matches('/')
616        );
617        let body = self.build_request_body(&request);
618
619        let response = self
620            .client
621            .post(&url)
622            .header("authorization", format!("Bearer {}", self.api_key))
623            .header("content-type", "application/json")
624            .json(&body)
625            .send()
626            .await
627            .map_err(|e| PunchError::Provider {
628                provider: self.provider_name.clone(),
629                message: format!("request failed: {e}"),
630            })?;
631
632        let status = response.status();
633
634        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
635            let retry_after = response
636                .headers()
637                .get("retry-after")
638                .and_then(|v| v.to_str().ok())
639                .and_then(|s| s.parse::<u64>().ok())
640                .unwrap_or(60)
641                * 1000;
642
643            return Err(PunchError::RateLimited {
644                provider: self.provider_name.clone(),
645                retry_after_ms: retry_after,
646            });
647        }
648
649        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
650            return Err(PunchError::Auth(format!(
651                "{} API key is invalid or lacks permissions",
652                self.provider_name
653            )));
654        }
655
656        let response_body: serde_json::Value =
657            response.json().await.map_err(|e| PunchError::Provider {
658                provider: self.provider_name.clone(),
659                message: format!("failed to parse response: {e}"),
660            })?;
661
662        if !status.is_success() {
663            let error_msg = response_body["error"]["message"]
664                .as_str()
665                .unwrap_or("unknown error");
666            return Err(PunchError::Provider {
667                provider: self.provider_name.clone(),
668                message: format!("API error ({}): {}", status, error_msg),
669            });
670        }
671
672        self.parse_response(&response_body)
673    }
674}
675
676// ---------------------------------------------------------------------------
677// Gemini driver
678// ---------------------------------------------------------------------------
679
/// Driver for the Google Gemini (Generative Language) API.
pub struct GeminiDriver {
    /// HTTP client; may be shared across drivers for connection pooling.
    client: Client,
    /// API key, appended to the request URL as the `key` query parameter.
    api_key: String,
    /// Base URL, defaulting to "https://generativelanguage.googleapis.com".
    base_url: String,
}
686
687impl GeminiDriver {
688    /// Create a new Gemini driver.
689    pub fn new(api_key: String, base_url: Option<String>) -> Self {
690        Self {
691            client: Client::new(),
692            api_key,
693            base_url: base_url
694                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string()),
695        }
696    }
697
698    /// Create a new Gemini driver with a shared HTTP client.
699    pub fn with_client(client: Client, api_key: String, base_url: Option<String>) -> Self {
700        Self {
701            client,
702            api_key,
703            base_url: base_url
704                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string()),
705        }
706    }
707
708    /// Build the Gemini API request body.
709    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
710        let mut contents = Vec::new();
711        let mut system_text: Option<String> = request.system_prompt.clone();
712
713        for msg in &request.messages {
714            match msg.role {
715                Role::System => {
716                    // Gemini does not have a system role; prepend to first user message.
717                    let existing = system_text.take().unwrap_or_default();
718                    let combined = if existing.is_empty() {
719                        msg.content.clone()
720                    } else {
721                        format!("{}\n{}", existing, msg.content)
722                    };
723                    system_text = Some(combined);
724                }
725                Role::User => {
726                    let mut text = String::new();
727                    if let Some(sys) = system_text.take()
728                        && !sys.is_empty()
729                    {
730                        text.push_str(&sys);
731                        text.push_str("\n\n");
732                    }
733                    text.push_str(&msg.content);
734                    contents.push(serde_json::json!({
735                        "role": "user",
736                        "parts": [{"text": text}],
737                    }));
738                }
739                Role::Assistant => {
740                    let mut parts: Vec<serde_json::Value> = Vec::new();
741                    if !msg.content.is_empty() {
742                        parts.push(serde_json::json!({"text": msg.content}));
743                    }
744                    for tc in &msg.tool_calls {
745                        parts.push(serde_json::json!({
746                            "functionCall": {
747                                "name": tc.name,
748                                "args": tc.input,
749                            }
750                        }));
751                    }
752                    if parts.is_empty() {
753                        parts.push(serde_json::json!({"text": ""}));
754                    }
755                    contents.push(serde_json::json!({
756                        "role": "model",
757                        "parts": parts,
758                    }));
759                }
760                Role::Tool => {
761                    let mut parts: Vec<serde_json::Value> = Vec::new();
762                    for tr in &msg.tool_results {
763                        parts.push(serde_json::json!({
764                            "functionResponse": {
765                                "name": tr.id.clone(),
766                                "response": {"content": tr.content},
767                            }
768                        }));
769                    }
770                    contents.push(serde_json::json!({
771                        "role": "user",
772                        "parts": parts,
773                    }));
774                }
775            }
776        }
777
778        // If we still have an unused system prompt (no user messages yet), add it.
779        if let Some(sys) = system_text
780            && !sys.is_empty()
781        {
782            contents.insert(
783                0,
784                serde_json::json!({
785                    "role": "user",
786                    "parts": [{"text": sys}],
787                }),
788            );
789        }
790
791        let mut body = serde_json::json!({
792            "contents": contents,
793        });
794
795        let mut gen_config = serde_json::json!({
796            "maxOutputTokens": request.max_tokens,
797        });
798        if let Some(temp) = request.temperature {
799            gen_config["temperature"] = serde_json::json!(temp);
800        }
801        body["generationConfig"] = gen_config;
802
803        if !request.tools.is_empty() {
804            let func_decls: Vec<serde_json::Value> = request
805                .tools
806                .iter()
807                .map(|t| {
808                    serde_json::json!({
809                        "name": t.name,
810                        "description": t.description,
811                        "parameters": t.input_schema,
812                    })
813                })
814                .collect();
815            body["tools"] = serde_json::json!([{"function_declarations": func_decls}]);
816        }
817
818        body
819    }
820
821    /// Build the full URL for a Gemini request.
822    pub fn build_url(&self, model: &str) -> String {
823        format!(
824            "{}/v1beta/models/{}:generateContent?key={}",
825            self.base_url.trim_end_matches('/'),
826            model,
827            self.api_key,
828        )
829    }
830
831    /// Parse the Gemini API response.
832    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
833        let candidate = body["candidates"]
834            .get(0)
835            .ok_or_else(|| PunchError::Provider {
836                provider: "gemini".to_string(),
837                message: "no candidates in response".to_string(),
838            })?;
839
840        let parts = candidate["content"]["parts"]
841            .as_array()
842            .cloned()
843            .unwrap_or_default();
844
845        let mut text_content = String::new();
846        let mut tool_calls = Vec::new();
847
848        for part in &parts {
849            if let Some(text) = part["text"].as_str() {
850                if !text_content.is_empty() {
851                    text_content.push('\n');
852                }
853                text_content.push_str(text);
854            }
855            if let Some(fc) = part.get("functionCall") {
856                let name = fc["name"].as_str().unwrap_or_default().to_string();
857                let args = fc["args"].clone();
858                tool_calls.push(ToolCall {
859                    id: format!("gemini-{}", uuid::Uuid::new_v4()),
860                    name,
861                    input: args,
862                });
863            }
864        }
865
866        let finish_reason = candidate["finishReason"].as_str().unwrap_or("STOP");
867        let stop_reason = if !tool_calls.is_empty() {
868            StopReason::ToolUse
869        } else {
870            match finish_reason {
871                "STOP" => StopReason::EndTurn,
872                "MAX_TOKENS" => StopReason::MaxTokens,
873                _ => StopReason::EndTurn,
874            }
875        };
876
877        let usage = TokenUsage {
878            input_tokens: body["usageMetadata"]["promptTokenCount"]
879                .as_u64()
880                .unwrap_or(0),
881            output_tokens: body["usageMetadata"]["candidatesTokenCount"]
882                .as_u64()
883                .unwrap_or(0),
884        };
885
886        // Strip thinking tags from reasoning models
887        let text_content = strip_thinking_tags(&text_content);
888
889        let message = Message {
890            role: Role::Assistant,
891            content: text_content,
892            tool_calls,
893            tool_results: Vec::new(),
894            timestamp: chrono::Utc::now(),
895        };
896
897        Ok(CompletionResponse {
898            message,
899            usage,
900            stop_reason,
901        })
902    }
903}
904
905#[async_trait]
906impl LlmDriver for GeminiDriver {
907    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
908        let url = self.build_url(&request.model);
909        let body = self.build_request_body(&request);
910
911        let response = self
912            .client
913            .post(&url)
914            .header("content-type", "application/json")
915            .json(&body)
916            .send()
917            .await
918            .map_err(|e| PunchError::Provider {
919                provider: "gemini".to_string(),
920                message: format!("request failed: {e}"),
921            })?;
922
923        let status = response.status();
924
925        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
926            return Err(PunchError::RateLimited {
927                provider: "gemini".to_string(),
928                retry_after_ms: 60_000,
929            });
930        }
931
932        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
933            return Err(PunchError::Auth(
934                "Gemini API key is invalid or lacks permissions".to_string(),
935            ));
936        }
937
938        let response_body: serde_json::Value =
939            response.json().await.map_err(|e| PunchError::Provider {
940                provider: "gemini".to_string(),
941                message: format!("failed to parse response: {e}"),
942            })?;
943
944        if !status.is_success() {
945            let error_msg = response_body["error"]["message"]
946                .as_str()
947                .unwrap_or("unknown error");
948            return Err(PunchError::Provider {
949                provider: "gemini".to_string(),
950                message: format!("API error ({}): {}", status, error_msg),
951            });
952        }
953
954        self.parse_response(&response_body)
955    }
956}
957
958// ---------------------------------------------------------------------------
959// Ollama driver
960// ---------------------------------------------------------------------------
961
/// Driver for local Ollama instances using the chat API.
pub struct OllamaDriver {
    // HTTP client used for all requests (may be shared across drivers).
    client: Client,
    // Base URL of the Ollama server, e.g. "http://localhost:11434".
    base_url: String,
}
967
968impl OllamaDriver {
969    /// Create a new Ollama driver.
970    pub fn new(base_url: Option<String>) -> Self {
971        Self {
972            client: Client::new(),
973            base_url: base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
974        }
975    }
976
977    /// Create a new Ollama driver with a shared HTTP client.
978    pub fn with_client(client: Client, base_url: Option<String>) -> Self {
979        Self {
980            client,
981            base_url: base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
982        }
983    }
984
985    /// Get the base URL.
986    pub fn base_url(&self) -> &str {
987        &self.base_url
988    }
989
990    /// Build the Ollama chat request body.
991    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
992        let mut messages = Vec::new();
993
994        if let Some(ref system) = request.system_prompt {
995            messages.push(serde_json::json!({
996                "role": "system",
997                "content": system,
998            }));
999        }
1000
1001        for msg in &request.messages {
1002            match msg.role {
1003                Role::System => {
1004                    messages.push(serde_json::json!({
1005                        "role": "system",
1006                        "content": msg.content,
1007                    }));
1008                }
1009                Role::User => {
1010                    messages.push(serde_json::json!({
1011                        "role": "user",
1012                        "content": msg.content,
1013                    }));
1014                }
1015                Role::Assistant => {
1016                    let mut m = serde_json::json!({
1017                        "role": "assistant",
1018                        "content": msg.content,
1019                    });
1020                    if !msg.tool_calls.is_empty() {
1021                        let tc: Vec<serde_json::Value> = msg
1022                            .tool_calls
1023                            .iter()
1024                            .map(|tc| {
1025                                serde_json::json!({
1026                                    "function": {
1027                                        "name": tc.name,
1028                                        "arguments": tc.input,
1029                                    }
1030                                })
1031                            })
1032                            .collect();
1033                        m["tool_calls"] = serde_json::json!(tc);
1034                    }
1035                    messages.push(m);
1036                }
1037                Role::Tool => {
1038                    for tr in &msg.tool_results {
1039                        messages.push(serde_json::json!({
1040                            "role": "tool",
1041                            "content": tr.content,
1042                        }));
1043                    }
1044                }
1045            }
1046        }
1047
1048        let mut body = serde_json::json!({
1049            "model": request.model,
1050            "messages": messages,
1051            "stream": false,
1052        });
1053
1054        let mut options = serde_json::json!({});
1055        if let Some(temp) = request.temperature {
1056            options["temperature"] = serde_json::json!(temp);
1057        }
1058        if request.max_tokens > 0 {
1059            options["num_predict"] = serde_json::json!(request.max_tokens);
1060        }
1061        body["options"] = options;
1062
1063        // Disable thinking mode for reasoning models (Qwen, DeepSeek) to prevent
1064        // the model from spending its entire token budget on internal reasoning.
1065        // The think tags get stripped anyway, so we avoid wasting tokens.
1066        body["think"] = serde_json::json!(false);
1067
1068        if !request.tools.is_empty() {
1069            let tools: Vec<serde_json::Value> = request
1070                .tools
1071                .iter()
1072                .map(|t| {
1073                    serde_json::json!({
1074                        "type": "function",
1075                        "function": {
1076                            "name": t.name,
1077                            "description": t.description,
1078                            "parameters": t.input_schema,
1079                        }
1080                    })
1081                })
1082                .collect();
1083            body["tools"] = serde_json::json!(tools);
1084        }
1085
1086        body
1087    }
1088
1089    /// Parse the Ollama chat response.
1090    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
1091        let msg = &body["message"];
1092        let raw_content = msg["content"].as_str().unwrap_or("");
1093        // Strip thinking tags from reasoning models (Qwen, DeepSeek, etc.)
1094        let content = strip_thinking_tags(raw_content);
1095
1096        let mut tool_calls = Vec::new();
1097        if let Some(tc_array) = msg["tool_calls"].as_array() {
1098            for tc in tc_array {
1099                let name = tc["function"]["name"]
1100                    .as_str()
1101                    .unwrap_or_default()
1102                    .to_string();
1103                let input = tc["function"]["arguments"].clone();
1104                tool_calls.push(ToolCall {
1105                    id: format!("ollama-{}", uuid::Uuid::new_v4()),
1106                    name,
1107                    input,
1108                });
1109            }
1110        }
1111
1112        let stop_reason = if !tool_calls.is_empty() {
1113            StopReason::ToolUse
1114        } else if body["done"].as_bool().unwrap_or(true) {
1115            StopReason::EndTurn
1116        } else {
1117            StopReason::MaxTokens
1118        };
1119
1120        let usage = TokenUsage {
1121            input_tokens: body["prompt_eval_count"].as_u64().unwrap_or(0),
1122            output_tokens: body["eval_count"].as_u64().unwrap_or(0),
1123        };
1124
1125        let message = Message {
1126            role: Role::Assistant,
1127            content,
1128            tool_calls,
1129            tool_results: Vec::new(),
1130            timestamp: chrono::Utc::now(),
1131        };
1132
1133        Ok(CompletionResponse {
1134            message,
1135            usage,
1136            stop_reason,
1137        })
1138    }
1139}
1140
1141#[async_trait]
1142impl LlmDriver for OllamaDriver {
1143    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
1144        let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
1145        let body = self.build_request_body(&request);
1146
1147        let response = self
1148            .client
1149            .post(&url)
1150            .header("content-type", "application/json")
1151            .json(&body)
1152            .send()
1153            .await
1154            .map_err(|e| PunchError::Provider {
1155                provider: "ollama".to_string(),
1156                message: format!("request failed: {e}"),
1157            })?;
1158
1159        let status = response.status();
1160        let response_body: serde_json::Value =
1161            response.json().await.map_err(|e| PunchError::Provider {
1162                provider: "ollama".to_string(),
1163                message: format!("failed to parse response: {e}"),
1164            })?;
1165
1166        if !status.is_success() {
1167            let error_msg = response_body["error"]
1168                .as_str()
1169                .unwrap_or("unknown error");
1170            return Err(PunchError::Provider {
1171                provider: "ollama".to_string(),
1172                message: format!("API error ({}): {}", status, error_msg),
1173            });
1174        }
1175
1176        self.parse_response(&response_body)
1177    }
1178}
1179
1180// ---------------------------------------------------------------------------
1181// AWS Bedrock driver
1182// ---------------------------------------------------------------------------
1183
/// Driver for AWS Bedrock using the Converse API with SigV4 authentication.
pub struct BedrockDriver {
    // HTTP client used for all requests (may be shared across drivers).
    client: Client,
    // AWS access key id used in the SigV4 Credential scope.
    access_key: String,
    // AWS secret access key used to derive the SigV4 signing key.
    secret_key: String,
    // AWS region, e.g. "us-east-1"; also part of the endpoint hostname.
    region: String,
}
1191
impl BedrockDriver {
    /// Create a new Bedrock driver with its own HTTP client.
    pub fn new(access_key: String, secret_key: String, region: String) -> Self {
        Self {
            client: Client::new(),
            access_key,
            secret_key,
            region,
        }
    }

    /// Create a new Bedrock driver with a shared HTTP client.
    pub fn with_client(
        client: Client,
        access_key: String,
        secret_key: String,
        region: String,
    ) -> Self {
        Self {
            client,
            access_key,
            secret_key,
            region,
        }
    }

    /// Build the Bedrock Converse API request body.
    ///
    /// Converts the provider-agnostic conversation into Converse format:
    /// tool results are sent as `user` messages containing `toolResult`
    /// blocks, and the system prompt goes into the top-level `system` field
    /// rather than the message list.
    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        for msg in &request.messages {
            match msg.role {
                Role::User => {
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": [{"text": msg.content}],
                    }));
                }
                Role::Assistant => {
                    let mut content: Vec<serde_json::Value> = Vec::new();
                    if !msg.content.is_empty() {
                        content.push(serde_json::json!({"text": msg.content}));
                    }
                    for tc in &msg.tool_calls {
                        content.push(serde_json::json!({
                            "toolUse": {
                                "toolUseId": tc.id,
                                "name": tc.name,
                                "input": tc.input,
                            }
                        }));
                    }
                    // Converse rejects assistant messages with no content
                    // blocks, so pad with an empty text block.
                    if content.is_empty() {
                        content.push(serde_json::json!({"text": ""}));
                    }
                    messages.push(serde_json::json!({
                        "role": "assistant",
                        "content": content,
                    }));
                }
                Role::Tool => {
                    // Tool results are delivered back as a user turn of
                    // toolResult blocks, per the Converse API.
                    let mut content: Vec<serde_json::Value> = Vec::new();
                    for tr in &msg.tool_results {
                        content.push(serde_json::json!({
                            "toolResult": {
                                "toolUseId": tr.id,
                                "content": [{"text": tr.content}],
                                "status": if tr.is_error { "error" } else { "success" },
                            }
                        }));
                    }
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": content,
                    }));
                }
                Role::System => {
                    // System messages handled separately.
                }
            }
        }

        let mut body = serde_json::json!({
            "messages": messages,
        });

        let mut inference_config = serde_json::json!({
            "maxTokens": request.max_tokens,
        });
        if let Some(temp) = request.temperature {
            inference_config["temperature"] = serde_json::json!(temp);
        }
        body["inferenceConfig"] = inference_config;

        if let Some(ref system) = request.system_prompt {
            body["system"] = serde_json::json!([{"text": system}]);
        }

        if !request.tools.is_empty() {
            let tool_config: Vec<serde_json::Value> = request
                .tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "toolSpec": {
                            "name": t.name,
                            "description": t.description,
                            "inputSchema": {"json": t.input_schema},
                        }
                    })
                })
                .collect();
            body["toolConfig"] = serde_json::json!({"tools": tool_config});
        }

        body
    }

    /// Build the endpoint URL for a model.
    ///
    /// NOTE(review): the model id is interpolated unencoded. Bedrock model
    /// ids can contain ':' (e.g. version suffixes like "...-v1:0"), which
    /// SigV4's canonical URI expects percent-encoded — verify signing works
    /// for such ids before relying on this.
    pub fn build_url(&self, model_id: &str) -> String {
        format!(
            "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
            self.region, model_id,
        )
    }

    /// Parse the Bedrock Converse API response.
    ///
    /// Joins text blocks with newlines, collects `toolUse` blocks into
    /// [`ToolCall`]s (Bedrock supplies real toolUseIds), and maps the
    /// Converse stop reason and usage counters.
    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let content = body["output"]["message"]["content"]
            .as_array()
            .cloned()
            .unwrap_or_default();

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        for block in &content {
            if let Some(text) = block["text"].as_str() {
                // Join multiple text blocks with a newline separator.
                if !text_content.is_empty() {
                    text_content.push('\n');
                }
                text_content.push_str(text);
            }
            if let Some(tu) = block.get("toolUse") {
                tool_calls.push(ToolCall {
                    id: tu["toolUseId"].as_str().unwrap_or_default().to_string(),
                    name: tu["name"].as_str().unwrap_or_default().to_string(),
                    input: tu["input"].clone(),
                });
            }
        }

        let stop_reason_str = body["stopReason"].as_str().unwrap_or("end_turn");
        // Any tool call takes precedence over the reported stop reason.
        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else {
            match stop_reason_str {
                "end_turn" => StopReason::EndTurn,
                "tool_use" => StopReason::ToolUse,
                "max_tokens" => StopReason::MaxTokens,
                _ => StopReason::EndTurn,
            }
        };

        let usage = TokenUsage {
            input_tokens: body["usage"]["inputTokens"].as_u64().unwrap_or(0),
            output_tokens: body["usage"]["outputTokens"].as_u64().unwrap_or(0),
        };

        // Strip thinking tags from reasoning models
        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }

    /// Compute an AWS SigV4 signature and return the Authorization header value.
    ///
    /// This is a basic implementation sufficient for Bedrock API calls.
    /// It assumes the request URL has no query string (the canonical query
    /// string is always empty) and that `timestamp` is an ISO-basic SigV4
    /// timestamp — the `&timestamp[..8]` slice below panics otherwise.
    pub fn sign_request(
        &self,
        method: &str,
        url: &str,
        headers: &[(String, String)],
        payload: &[u8],
        timestamp: &str, // format: "20260313T120000Z"
    ) -> PunchResult<String> {
        let date = &timestamp[..8]; // "20260313" — the YYYYMMDD prefix
        // Signing service name for the bedrock-runtime endpoint.
        let service = "bedrock";

        // Parse the URL to get host and path.
        let parsed = url::Url::parse(url).map_err(|e| PunchError::Provider {
            provider: "bedrock".to_string(),
            message: format!("invalid URL: {e}"),
        })?;
        let host = parsed.host_str().unwrap_or("");
        let path = parsed.path();

        // 1. Create canonical request.
        let payload_hash = hex_sha256(payload);

        // Signed-header names: caller headers plus host and x-amz-date,
        // lowercased, sorted, deduped — per the SigV4 spec.
        let mut signed_header_names: Vec<String> =
            headers.iter().map(|(k, _)| k.to_lowercase()).collect();
        signed_header_names.push("host".to_string());
        signed_header_names.push("x-amz-date".to_string());
        signed_header_names.sort();
        signed_header_names.dedup();

        let mut header_map: Vec<(String, String)> = headers
            .iter()
            .map(|(k, v)| (k.to_lowercase(), v.trim().to_string()))
            .collect();
        header_map.push(("host".to_string(), host.to_string()));
        header_map.push(("x-amz-date".to_string(), timestamp.to_string()));
        header_map.sort_by(|a, b| a.0.cmp(&b.0));
        // dedup_by removes the *earlier* of two adjacent equal-keyed entries,
        // so for duplicate keys the later-pushed value (host/x-amz-date) wins.
        header_map.dedup_by(|a, b| a.0 == b.0);

        // Canonical headers: "key:value\n" for each, already sorted.
        let canonical_headers: String = header_map
            .iter()
            .map(|(k, v)| format!("{}:{}\n", k, v))
            .collect();

        let signed_headers = signed_header_names.join(";");

        // Layout: METHOD\nPATH\nQUERY(empty)\nHEADERS\nSIGNED\nHASH.
        // canonical_headers already ends in '\n', giving the required blank
        // separator before signed_headers.
        let canonical_request = format!(
            "{}\n{}\n\n{}\n{}\n{}",
            method, path, canonical_headers, signed_headers, payload_hash,
        );

        // 2. Create string to sign.
        let credential_scope = format!("{}/{}/{}/aws4_request", date, self.region, service);
        let string_to_sign = format!(
            "AWS4-HMAC-SHA256\n{}\n{}\n{}",
            timestamp,
            credential_scope,
            hex_sha256(canonical_request.as_bytes()),
        );

        // 3. Calculate signing key: chained HMACs over date, region, service.
        let k_date = hmac_sha256(
            format!("AWS4{}", self.secret_key).as_bytes(),
            date.as_bytes(),
        );
        let k_region = hmac_sha256(&k_date, self.region.as_bytes());
        let k_service = hmac_sha256(&k_region, service.as_bytes());
        let k_signing = hmac_sha256(&k_service, b"aws4_request");

        // 4. Calculate signature.
        let signature = hex_encode(&hmac_sha256(&k_signing, string_to_sign.as_bytes()));

        // 5. Build Authorization header.
        Ok(format!(
            "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
            self.access_key, credential_scope, signed_headers, signature,
        ))
    }
}
1460
1461/// Compute SHA-256 hex digest.
1462fn hex_sha256(data: &[u8]) -> String {
1463    let mut hasher = Sha256::new();
1464    hasher.update(data);
1465    hex_encode(hasher.finalize().as_slice())
1466}
1467
1468/// Compute HMAC-SHA256.
1469fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
1470    type HmacSha256 = Hmac<Sha256>;
1471    let mut mac = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size");
1472    mac.update(data);
1473    mac.finalize().into_bytes().to_vec()
1474}
1475
/// Hex-encode bytes (lowercase) without an external crate.
///
/// Preallocates the output (two hex digits per byte) and formats in place,
/// avoiding the per-byte `String` allocation of `map(format!)`.
fn hex_encode(bytes: &[u8]) -> String {
    use std::fmt::Write;
    let mut out = String::with_capacity(bytes.len() * 2);
    for b in bytes {
        // Writing to a String is infallible.
        let _ = write!(out, "{:02x}", b);
    }
    out
}
1480
1481#[async_trait]
1482impl LlmDriver for BedrockDriver {
1483    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
1484        let url = self.build_url(&request.model);
1485        let body = self.build_request_body(&request);
1486        let payload = serde_json::to_vec(&body).map_err(|e| PunchError::Provider {
1487            provider: "bedrock".to_string(),
1488            message: format!("failed to serialize request: {e}"),
1489        })?;
1490
1491        let timestamp = chrono::Utc::now().format("%Y%m%dT%H%M%SZ").to_string();
1492
1493        let auth_header = self.sign_request(
1494            "POST",
1495            &url,
1496            &[("content-type".to_string(), "application/json".to_string())],
1497            &payload,
1498            &timestamp,
1499        )?;
1500
1501        let parsed_url = url::Url::parse(&url).map_err(|e| PunchError::Provider {
1502            provider: "bedrock".to_string(),
1503            message: format!("invalid URL: {e}"),
1504        })?;
1505        let host = parsed_url.host_str().unwrap_or_default().to_string();
1506
1507        let response = self
1508            .client
1509            .post(&url)
1510            .header("content-type", "application/json")
1511            .header("host", &host)
1512            .header("x-amz-date", &timestamp)
1513            .header("authorization", &auth_header)
1514            .body(payload)
1515            .send()
1516            .await
1517            .map_err(|e| PunchError::Provider {
1518                provider: "bedrock".to_string(),
1519                message: format!("request failed: {e}"),
1520            })?;
1521
1522        let status = response.status();
1523
1524        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
1525            return Err(PunchError::RateLimited {
1526                provider: "bedrock".to_string(),
1527                retry_after_ms: 60_000,
1528            });
1529        }
1530
1531        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
1532            return Err(PunchError::Auth(
1533                "AWS Bedrock credentials are invalid or lack permissions".to_string(),
1534            ));
1535        }
1536
1537        let response_body: serde_json::Value =
1538            response.json().await.map_err(|e| PunchError::Provider {
1539                provider: "bedrock".to_string(),
1540                message: format!("failed to parse response: {e}"),
1541            })?;
1542
1543        if !status.is_success() {
1544            let error_msg = response_body["message"]
1545                .as_str()
1546                .unwrap_or("unknown error");
1547            return Err(PunchError::Provider {
1548                provider: "bedrock".to_string(),
1549                message: format!("API error ({}): {}", status, error_msg),
1550            });
1551        }
1552
1553        self.parse_response(&response_body)
1554    }
1555}
1556
1557// ---------------------------------------------------------------------------
1558// Azure OpenAI driver
1559// ---------------------------------------------------------------------------
1560
/// Driver for Azure OpenAI deployments.
///
/// Uses the same request/response format as OpenAI but with Azure-specific
/// URL construction and API key header.
pub struct AzureOpenAiDriver {
    // Handles request-body construction and response parsing; also owns the
    // HTTP client and API key.
    inner: OpenAiCompatibleDriver,
    // Azure resource name — the subdomain of <resource>.openai.azure.com.
    resource: String,
    // Deployment name that appears in the endpoint path.
    deployment: String,
    // API version query parameter, e.g. "2024-02-01".
    api_version: String,
}
1571
1572impl AzureOpenAiDriver {
1573    /// Create a new Azure OpenAI driver.
1574    ///
1575    /// - `api_key`: The Azure OpenAI API key.
1576    /// - `resource`: The Azure resource name (subdomain).
1577    /// - `deployment`: The deployment name.
1578    /// - `api_version`: API version string (e.g., "2024-02-01").
1579    pub fn new(
1580        api_key: String,
1581        resource: String,
1582        deployment: String,
1583        api_version: Option<String>,
1584    ) -> Self {
1585        let base_url = format!("https://{}.openai.azure.com", resource);
1586        Self {
1587            inner: OpenAiCompatibleDriver::new(api_key, base_url, "azure_openai".to_string()),
1588            resource,
1589            deployment,
1590            api_version: api_version.unwrap_or_else(|| "2024-02-01".to_string()),
1591        }
1592    }
1593
1594    /// Create a new Azure OpenAI driver with a shared HTTP client.
1595    pub fn with_client(
1596        client: Client,
1597        api_key: String,
1598        resource: String,
1599        deployment: String,
1600        api_version: Option<String>,
1601    ) -> Self {
1602        let base_url = format!("https://{}.openai.azure.com", resource);
1603        Self {
1604            inner: OpenAiCompatibleDriver::with_client(
1605                client,
1606                api_key,
1607                base_url,
1608                "azure_openai".to_string(),
1609            ),
1610            resource,
1611            deployment,
1612            api_version: api_version.unwrap_or_else(|| "2024-02-01".to_string()),
1613        }
1614    }
1615
1616    /// Build the Azure OpenAI endpoint URL.
1617    pub fn build_url(&self) -> String {
1618        format!(
1619            "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
1620            self.resource, self.deployment, self.api_version,
1621        )
1622    }
1623
1624    /// Get the resource name.
1625    pub fn resource(&self) -> &str {
1626        &self.resource
1627    }
1628
1629    /// Get the deployment name.
1630    pub fn deployment(&self) -> &str {
1631        &self.deployment
1632    }
1633
1634    /// Build request body (delegates to inner OpenAI-compatible driver).
1635    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
1636        self.inner.build_request_body(request)
1637    }
1638
1639    /// Parse response (delegates to inner OpenAI-compatible driver).
1640    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
1641        self.inner.parse_response(body)
1642    }
1643}
1644
#[async_trait]
impl LlmDriver for AzureOpenAiDriver {
    /// Send a chat-completions request to the Azure deployment and parse it.
    ///
    /// Azure authenticates via the `api-key` header (not a Bearer token).
    /// Maps HTTP 429 to [`PunchError::RateLimited`] using the `Retry-After`
    /// header, 401/403 to [`PunchError::Auth`], and other non-success
    /// statuses to [`PunchError::Provider`].
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = self.build_url();
        let body = self.inner.build_request_body(&request);

        let response = self
            .inner
            .client
            .post(&url)
            .header("api-key", &self.inner.api_key)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "azure_openai".to_string(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            // Retry-After is in seconds; convert to ms with a 60s fallback.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap_or(60)
                * 1000;

            return Err(PunchError::RateLimited {
                provider: "azure_openai".to_string(),
                retry_after_ms: retry_after,
            });
        }

        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(PunchError::Auth(
                "Azure OpenAI API key is invalid or lacks permissions".to_string(),
            ));
        }

        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: "azure_openai".to_string(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            // Azure follows the OpenAI error shape: error.message.
            let error_msg = response_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "azure_openai".to_string(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.inner.parse_response(&response_body)
    }
}
1707
1708// ---------------------------------------------------------------------------
1709// Factory
1710// ---------------------------------------------------------------------------
1711
1712/// Default base URLs for known providers.
1713fn default_base_url(provider: &Provider) -> &'static str {
1714    match provider {
1715        Provider::Anthropic => "https://api.anthropic.com",
1716        Provider::OpenAI => "https://api.openai.com",
1717        Provider::Google => "https://generativelanguage.googleapis.com",
1718        Provider::Groq => "https://api.groq.com/openai",
1719        Provider::DeepSeek => "https://api.deepseek.com",
1720        Provider::Ollama => "http://localhost:11434",
1721        Provider::Mistral => "https://api.mistral.ai",
1722        Provider::Together => "https://api.together.xyz",
1723        Provider::Fireworks => "https://api.fireworks.ai/inference",
1724        Provider::Cerebras => "https://api.cerebras.ai",
1725        Provider::XAI => "https://api.x.ai",
1726        Provider::Cohere => "https://api.cohere.ai",
1727        Provider::Bedrock => "https://bedrock-runtime.us-east-1.amazonaws.com",
1728        Provider::AzureOpenAi => "",
1729        Provider::Custom(_) => "",
1730    }
1731}
1732
/// Create an [`LlmDriver`] from a [`ModelConfig`].
///
/// Reads the API key from the environment variable specified in
/// `config.api_key_env`. Returns an error if the env var is missing
/// (except for Ollama which does not require auth).
///
/// Convenience wrapper around [`create_driver_with_client`] that does not
/// share an HTTP client: each driver constructs its own `reqwest::Client`.
pub fn create_driver(config: &ModelConfig) -> PunchResult<Arc<dyn LlmDriver>> {
    create_driver_with_client(config, None)
}
1745
1746/// Create a driver from config with an optional shared [`reqwest::Client`].
1747pub fn create_driver_with_client(
1748    config: &ModelConfig,
1749    shared_client: Option<&Client>,
1750) -> PunchResult<Arc<dyn LlmDriver>> {
1751    let api_key = match &config.api_key_env {
1752        Some(env_var) => std::env::var(env_var).map_err(|_| {
1753            PunchError::Auth(format!(
1754                "environment variable '{}' not set for {} driver",
1755                env_var, config.provider
1756            ))
1757        })?,
1758        None => {
1759            // Ollama typically has no auth; others will fail at the API.
1760            String::new()
1761        }
1762    };
1763
1764    let base_url = config
1765        .base_url
1766        .clone()
1767        .unwrap_or_else(|| default_base_url(&config.provider).to_string());
1768
1769    match &config.provider {
1770        Provider::Anthropic => {
1771            if let Some(client) = shared_client {
1772                Ok(Arc::new(AnthropicDriver::with_client(
1773                    client.clone(),
1774                    api_key,
1775                    Some(base_url),
1776                )))
1777            } else {
1778                Ok(Arc::new(AnthropicDriver::new(api_key, Some(base_url))))
1779            }
1780        }
1781        Provider::Google => {
1782            if let Some(client) = shared_client {
1783                Ok(Arc::new(GeminiDriver::with_client(
1784                    client.clone(),
1785                    api_key,
1786                    Some(base_url),
1787                )))
1788            } else {
1789                Ok(Arc::new(GeminiDriver::new(api_key, Some(base_url))))
1790            }
1791        }
1792        Provider::Ollama => {
1793            if let Some(client) = shared_client {
1794                Ok(Arc::new(OllamaDriver::with_client(
1795                    client.clone(),
1796                    Some(base_url),
1797                )))
1798            } else {
1799                Ok(Arc::new(OllamaDriver::new(Some(base_url))))
1800            }
1801        }
1802        Provider::Bedrock => {
1803            // For Bedrock, api_key is expected to be "ACCESS_KEY:SECRET_KEY" or
1804            // we read AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from env.
1805            let (access_key, secret_key) = if api_key.contains(':') {
1806                let parts: Vec<&str> = api_key.splitn(2, ':').collect();
1807                (parts[0].to_string(), parts[1].to_string())
1808            } else {
1809                let ak = std::env::var("AWS_ACCESS_KEY_ID").unwrap_or(api_key);
1810                let sk = std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_default();
1811                (ak, sk)
1812            };
1813            // Extract region from base_url or default to us-east-1.
1814            let region = if base_url.contains("bedrock-runtime.") {
1815                base_url
1816                    .trim_start_matches("https://bedrock-runtime.")
1817                    .split('.')
1818                    .next()
1819                    .unwrap_or("us-east-1")
1820                    .to_string()
1821            } else {
1822                "us-east-1".to_string()
1823            };
1824            if let Some(client) = shared_client {
1825                Ok(Arc::new(BedrockDriver::with_client(
1826                    client.clone(),
1827                    access_key,
1828                    secret_key,
1829                    region,
1830                )))
1831            } else {
1832                Ok(Arc::new(BedrockDriver::new(access_key, secret_key, region)))
1833            }
1834        }
1835        Provider::AzureOpenAi => {
1836            // For Azure, base_url should be "https://{resource}.openai.azure.com"
1837            // and model is the deployment name.
1838            let resource = if base_url.contains(".openai.azure.com") {
1839                base_url
1840                    .trim_start_matches("https://")
1841                    .split('.')
1842                    .next()
1843                    .unwrap_or("default")
1844                    .to_string()
1845            } else {
1846                base_url.clone()
1847            };
1848            let deployment = config.model.clone();
1849            if let Some(client) = shared_client {
1850                Ok(Arc::new(AzureOpenAiDriver::with_client(
1851                    client.clone(),
1852                    api_key,
1853                    resource,
1854                    deployment,
1855                    None,
1856                )))
1857            } else {
1858                Ok(Arc::new(AzureOpenAiDriver::new(
1859                    api_key,
1860                    resource,
1861                    deployment,
1862                    None,
1863                )))
1864            }
1865        }
1866        provider => {
1867            let name = provider.to_string();
1868            if let Some(client) = shared_client {
1869                Ok(Arc::new(OpenAiCompatibleDriver::with_client(
1870                    client.clone(),
1871                    api_key,
1872                    base_url,
1873                    name,
1874                )))
1875            } else {
1876                Ok(Arc::new(OpenAiCompatibleDriver::new(
1877                    api_key, base_url, name,
1878                )))
1879            }
1880        }
1881    }
1882}
1883
1884// ---------------------------------------------------------------------------
1885// Tests
1886// ---------------------------------------------------------------------------
1887
1888#[cfg(test)]
1889mod tests {
1890    use super::*;
1891    use punch_types::ToolCategory;
1892
1893    /// Helper to build a simple completion request for testing.
1894    fn simple_request() -> CompletionRequest {
1895        CompletionRequest {
1896            model: "test-model".to_string(),
1897            messages: vec![Message::new(Role::User, "Hello")],
1898            tools: Vec::new(),
1899            max_tokens: 4096,
1900            temperature: Some(0.7),
1901            system_prompt: Some("You are helpful.".to_string()),
1902        }
1903    }
1904
1905    /// Helper to build a request with tools.
1906    fn request_with_tools() -> CompletionRequest {
1907        CompletionRequest {
1908            model: "test-model".to_string(),
1909            messages: vec![Message::new(Role::User, "Use the tool")],
1910            tools: vec![ToolDefinition {
1911                name: "get_weather".to_string(),
1912                description: "Get weather for a city".to_string(),
1913                input_schema: serde_json::json!({
1914                    "type": "object",
1915                    "properties": {
1916                        "city": {"type": "string"}
1917                    }
1918                }),
1919                category: ToolCategory::Web,
1920            }],
1921            max_tokens: 4096,
1922            temperature: Some(0.7),
1923            system_prompt: None,
1924        }
1925    }
1926
1927    // -----------------------------------------------------------------------
1928    // Gemini tests
1929    // -----------------------------------------------------------------------
1930
1931    #[test]
1932    fn gemini_request_formatting() {
1933        let driver = GeminiDriver::new("test-key".to_string(), None);
1934        let body = driver.build_request_body(&simple_request());
1935
1936        let contents = body["contents"].as_array().unwrap();
1937        assert_eq!(contents.len(), 1);
1938        // System prompt should be prepended to user message.
1939        let first_text = contents[0]["parts"][0]["text"].as_str().unwrap();
1940        assert!(first_text.contains("You are helpful."));
1941        assert!(first_text.contains("Hello"));
1942        // Role should be "user" (not "system").
1943        assert_eq!(contents[0]["role"].as_str().unwrap(), "user");
1944
1945        assert_eq!(body["generationConfig"]["maxOutputTokens"], 4096);
1946        assert!((body["generationConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
1947    }
1948
1949    #[test]
1950    fn gemini_response_parsing() {
1951        let driver = GeminiDriver::new("test-key".to_string(), None);
1952        let response_body = serde_json::json!({
1953            "candidates": [{
1954                "content": {
1955                    "parts": [{"text": "Hello there!"}],
1956                    "role": "model"
1957                },
1958                "finishReason": "STOP"
1959            }],
1960            "usageMetadata": {
1961                "promptTokenCount": 10,
1962                "candidatesTokenCount": 5
1963            }
1964        });
1965
1966        let resp = driver.parse_response(&response_body).unwrap();
1967        assert_eq!(resp.message.content, "Hello there!");
1968        assert_eq!(resp.stop_reason, StopReason::EndTurn);
1969        assert_eq!(resp.usage.input_tokens, 10);
1970        assert_eq!(resp.usage.output_tokens, 5);
1971    }
1972
1973    #[test]
1974    fn gemini_role_mapping_system_prepended() {
1975        let driver = GeminiDriver::new("test-key".to_string(), None);
1976        let req = CompletionRequest {
1977            model: "gemini-pro".to_string(),
1978            messages: vec![
1979                Message::new(Role::System, "Be concise."),
1980                Message::new(Role::User, "Hi"),
1981            ],
1982            tools: Vec::new(),
1983            max_tokens: 1024,
1984            temperature: None,
1985            system_prompt: None,
1986        };
1987        let body = driver.build_request_body(&req);
1988        let contents = body["contents"].as_array().unwrap();
1989        // System message should be merged into user message.
1990        assert_eq!(contents.len(), 1);
1991        let text = contents[0]["parts"][0]["text"].as_str().unwrap();
1992        assert!(text.contains("Be concise."));
1993        assert!(text.contains("Hi"));
1994    }
1995
1996    #[test]
1997    fn gemini_function_call_parsing() {
1998        let driver = GeminiDriver::new("test-key".to_string(), None);
1999        let response_body = serde_json::json!({
2000            "candidates": [{
2001                "content": {
2002                    "parts": [
2003                        {"text": "Let me check the weather."},
2004                        {
2005                            "functionCall": {
2006                                "name": "get_weather",
2007                                "args": {"city": "London"}
2008                            }
2009                        }
2010                    ],
2011                    "role": "model"
2012                },
2013                "finishReason": "STOP"
2014            }],
2015            "usageMetadata": {
2016                "promptTokenCount": 15,
2017                "candidatesTokenCount": 8
2018            }
2019        });
2020
2021        let resp = driver.parse_response(&response_body).unwrap();
2022        assert_eq!(resp.message.content, "Let me check the weather.");
2023        assert_eq!(resp.stop_reason, StopReason::ToolUse);
2024        assert_eq!(resp.message.tool_calls.len(), 1);
2025        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
2026        assert_eq!(resp.message.tool_calls[0].input["city"], "London");
2027    }
2028
2029    #[test]
2030    fn gemini_api_key_in_url() {
2031        let driver = GeminiDriver::new("my-secret-key".to_string(), None);
2032        let url = driver.build_url("gemini-pro");
2033        assert!(url.contains("key=my-secret-key"));
2034        assert!(url.contains("models/gemini-pro:generateContent"));
2035    }
2036
2037    // -----------------------------------------------------------------------
2038    // Ollama tests
2039    // -----------------------------------------------------------------------
2040
2041    #[test]
2042    fn ollama_request_formatting() {
2043        let driver = OllamaDriver::new(None);
2044        let body = driver.build_request_body(&simple_request());
2045
2046        assert_eq!(body["model"], "test-model");
2047        assert_eq!(body["stream"], false);
2048        let messages = body["messages"].as_array().unwrap();
2049        // system prompt + user message = 2 messages
2050        assert_eq!(messages.len(), 2);
2051        assert_eq!(messages[0]["role"], "system");
2052        assert_eq!(messages[0]["content"], "You are helpful.");
2053        assert_eq!(messages[1]["role"], "user");
2054        assert_eq!(messages[1]["content"], "Hello");
2055        assert!((body["options"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
2056    }
2057
2058    #[test]
2059    fn ollama_response_parsing() {
2060        let driver = OllamaDriver::new(None);
2061        let response_body = serde_json::json!({
2062            "message": {
2063                "role": "assistant",
2064                "content": "Hi there!"
2065            },
2066            "done": true,
2067            "prompt_eval_count": 20,
2068            "eval_count": 10
2069        });
2070
2071        let resp = driver.parse_response(&response_body).unwrap();
2072        assert_eq!(resp.message.content, "Hi there!");
2073        assert_eq!(resp.stop_reason, StopReason::EndTurn);
2074        assert_eq!(resp.usage.input_tokens, 20);
2075        assert_eq!(resp.usage.output_tokens, 10);
2076    }
2077
2078    #[test]
2079    fn ollama_default_endpoint() {
2080        let driver = OllamaDriver::new(None);
2081        assert_eq!(driver.base_url(), "http://localhost:11434");
2082    }
2083
2084    #[test]
2085    fn ollama_custom_endpoint() {
2086        let driver = OllamaDriver::new(Some("http://myhost:9999".to_string()));
2087        assert_eq!(driver.base_url(), "http://myhost:9999");
2088    }
2089
2090    // -----------------------------------------------------------------------
2091    // Bedrock tests
2092    // -----------------------------------------------------------------------
2093
2094    #[test]
2095    fn bedrock_request_formatting() {
2096        let driver = BedrockDriver::new(
2097            "TESTKEY".to_string(),
2098            "testsecret".to_string(),
2099            "us-west-2".to_string(),
2100        );
2101        let body = driver.build_request_body(&simple_request());
2102
2103        let messages = body["messages"].as_array().unwrap();
2104        assert_eq!(messages.len(), 1);
2105        assert_eq!(messages[0]["role"], "user");
2106        assert_eq!(messages[0]["content"][0]["text"], "Hello");
2107
2108        assert_eq!(body["inferenceConfig"]["maxTokens"], 4096);
2109        assert!((body["inferenceConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
2110        assert_eq!(body["system"][0]["text"], "You are helpful.");
2111    }
2112
2113    #[test]
2114    fn bedrock_sigv4_canonical_request() {
2115        let driver = BedrockDriver::new(
2116            "TESTACCESS1234567890".to_string(),
2117            "TestSecretKeyValue1234567890abcdefghijk".to_string(),
2118            "us-east-1".to_string(),
2119        );
2120
2121        let payload = b"{}";
2122        let timestamp = "20260313T120000Z";
2123
2124        let auth = driver
2125            .sign_request(
2126                "POST",
2127                "https://bedrock-runtime.us-east-1.amazonaws.com/model/test/converse",
2128                &[("content-type".to_string(), "application/json".to_string())],
2129                payload,
2130                timestamp,
2131            )
2132            .unwrap();
2133
2134        assert!(auth.starts_with(
2135            "AWS4-HMAC-SHA256 Credential=TESTACCESS1234567890/20260313/us-east-1/bedrock/aws4_request"
2136        ));
2137        assert!(auth.contains("SignedHeaders=content-type;host;x-amz-date"));
2138        assert!(auth.contains("Signature="));
2139    }
2140
2141    #[test]
2142    fn bedrock_response_parsing() {
2143        let driver = BedrockDriver::new(
2144            "key".to_string(),
2145            "secret".to_string(),
2146            "us-east-1".to_string(),
2147        );
2148        let response_body = serde_json::json!({
2149            "output": {
2150                "message": {
2151                    "role": "assistant",
2152                    "content": [{"text": "The answer is 42."}]
2153                }
2154            },
2155            "stopReason": "end_turn",
2156            "usage": {
2157                "inputTokens": 100,
2158                "outputTokens": 50
2159            }
2160        });
2161
2162        let resp = driver.parse_response(&response_body).unwrap();
2163        assert_eq!(resp.message.content, "The answer is 42.");
2164        assert_eq!(resp.stop_reason, StopReason::EndTurn);
2165        assert_eq!(resp.usage.input_tokens, 100);
2166        assert_eq!(resp.usage.output_tokens, 50);
2167    }
2168
2169    // -----------------------------------------------------------------------
2170    // Azure OpenAI tests
2171    // -----------------------------------------------------------------------
2172
2173    #[test]
2174    fn azure_openai_url_construction() {
2175        let driver = AzureOpenAiDriver::new(
2176            "my-azure-key".to_string(),
2177            "myresource".to_string(),
2178            "gpt-4-deployment".to_string(),
2179            None,
2180        );
2181        let url = driver.build_url();
2182        assert_eq!(
2183            url,
2184            "https://myresource.openai.azure.com/openai/deployments/gpt-4-deployment/chat/completions?api-version=2024-02-01"
2185        );
2186    }
2187
2188    #[test]
2189    fn azure_openai_custom_api_version() {
2190        let driver = AzureOpenAiDriver::new(
2191            "key".to_string(),
2192            "res".to_string(),
2193            "dep".to_string(),
2194            Some("2024-06-01".to_string()),
2195        );
2196        let url = driver.build_url();
2197        assert!(url.contains("api-version=2024-06-01"));
2198    }
2199
2200    #[test]
2201    fn azure_openai_request_formatting() {
2202        let driver = AzureOpenAiDriver::new(
2203            "key".to_string(),
2204            "res".to_string(),
2205            "dep".to_string(),
2206            None,
2207        );
2208        let body = driver.build_request_body(&simple_request());
2209        // Should use OpenAI format.
2210        let messages = body["messages"].as_array().unwrap();
2211        // system prompt + user message = 2
2212        assert_eq!(messages.len(), 2);
2213        assert_eq!(messages[0]["role"], "system");
2214        assert_eq!(messages[1]["role"], "user");
2215        assert_eq!(body["model"], "test-model");
2216    }
2217
2218    #[test]
2219    fn azure_openai_resource_and_deployment() {
2220        let driver = AzureOpenAiDriver::new(
2221            "key".to_string(),
2222            "my-resource".to_string(),
2223            "my-deploy".to_string(),
2224            None,
2225        );
2226        assert_eq!(driver.resource(), "my-resource");
2227        assert_eq!(driver.deployment(), "my-deploy");
2228    }
2229
2230    // -----------------------------------------------------------------------
2231    // create_driver dispatch tests
2232    // -----------------------------------------------------------------------
2233
2234    #[test]
2235    fn create_driver_dispatches_ollama() {
2236        let config = ModelConfig {
2237            provider: Provider::Ollama,
2238            model: "llama3".to_string(),
2239            api_key_env: None,
2240            base_url: None,
2241            max_tokens: None,
2242            temperature: None,
2243        };
2244        // Ollama does not need an API key, so this should succeed.
2245        let driver = create_driver(&config);
2246        assert!(driver.is_ok());
2247    }
2248
2249    #[test]
2250    fn create_driver_dispatches_gemini() {
2251        // Set a fake env var for this test.
2252        // SAFETY: Test is single-threaded relative to this env var name.
2253        unsafe { std::env::set_var("TEST_GEMINI_KEY_DISPATCH", "fake-key") };
2254        let config = ModelConfig {
2255            provider: Provider::Google,
2256            model: "gemini-pro".to_string(),
2257            api_key_env: Some("TEST_GEMINI_KEY_DISPATCH".to_string()),
2258            base_url: None,
2259            max_tokens: None,
2260            temperature: None,
2261        };
2262        let driver = create_driver(&config);
2263        assert!(driver.is_ok());
2264        unsafe { std::env::remove_var("TEST_GEMINI_KEY_DISPATCH") };
2265    }
2266
2267    #[test]
2268    fn create_driver_dispatches_bedrock() {
2269        // SAFETY: Test is single-threaded relative to this env var name.
2270        unsafe { std::env::set_var("TEST_BEDROCK_KEY_DISPATCH", "TESTKEY:TESTSECRET") };
2271        let config = ModelConfig {
2272            provider: Provider::Bedrock,
2273            model: "anthropic.claude-v2".to_string(),
2274            api_key_env: Some("TEST_BEDROCK_KEY_DISPATCH".to_string()),
2275            base_url: None,
2276            max_tokens: None,
2277            temperature: None,
2278        };
2279        let driver = create_driver(&config);
2280        assert!(driver.is_ok());
2281        unsafe { std::env::remove_var("TEST_BEDROCK_KEY_DISPATCH") };
2282    }
2283
2284    #[test]
2285    fn create_driver_dispatches_azure_openai() {
2286        // SAFETY: Test is single-threaded relative to this env var name.
2287        unsafe { std::env::set_var("TEST_AZURE_KEY_DISPATCH", "azure-key") };
2288        let config = ModelConfig {
2289            provider: Provider::AzureOpenAi,
2290            model: "gpt-4".to_string(),
2291            api_key_env: Some("TEST_AZURE_KEY_DISPATCH".to_string()),
2292            base_url: Some("https://myres.openai.azure.com".to_string()),
2293            max_tokens: None,
2294            temperature: None,
2295        };
2296        let driver = create_driver(&config);
2297        assert!(driver.is_ok());
2298        unsafe { std::env::remove_var("TEST_AZURE_KEY_DISPATCH") };
2299    }
2300
2301    #[test]
2302    fn gemini_tools_in_request() {
2303        let driver = GeminiDriver::new("key".to_string(), None);
2304        let body = driver.build_request_body(&request_with_tools());
2305
2306        let tools = body["tools"].as_array().unwrap();
2307        assert_eq!(tools.len(), 1);
2308        let func_decls = tools[0]["function_declarations"].as_array().unwrap();
2309        assert_eq!(func_decls.len(), 1);
2310        assert_eq!(func_decls[0]["name"], "get_weather");
2311    }
2312
2313    #[test]
2314    fn ollama_tools_in_request() {
2315        let driver = OllamaDriver::new(None);
2316        let body = driver.build_request_body(&request_with_tools());
2317
2318        let tools = body["tools"].as_array().unwrap();
2319        assert_eq!(tools.len(), 1);
2320        assert_eq!(tools[0]["type"], "function");
2321        assert_eq!(tools[0]["function"]["name"], "get_weather");
2322    }
2323
2324    #[test]
2325    fn bedrock_url_construction() {
2326        let driver = BedrockDriver::new(
2327            "key".to_string(),
2328            "secret".to_string(),
2329            "eu-west-1".to_string(),
2330        );
2331        let url = driver.build_url("anthropic.claude-3-sonnet");
2332        assert_eq!(
2333            url,
2334            "https://bedrock-runtime.eu-west-1.amazonaws.com/model/anthropic.claude-3-sonnet/converse"
2335        );
2336    }
2337
2338    // -----------------------------------------------------------------------
2339    // TokenUsage tests
2340    // -----------------------------------------------------------------------
2341
2342    #[test]
2343    fn token_usage_default() {
2344        let u = TokenUsage::default();
2345        assert_eq!(u.input_tokens, 0);
2346        assert_eq!(u.output_tokens, 0);
2347        assert_eq!(u.total(), 0);
2348    }
2349
2350    #[test]
2351    fn token_usage_accumulate() {
2352        let mut u = TokenUsage { input_tokens: 10, output_tokens: 20 };
2353        let other = TokenUsage { input_tokens: 5, output_tokens: 15 };
2354        u.accumulate(&other);
2355        assert_eq!(u.input_tokens, 15);
2356        assert_eq!(u.output_tokens, 35);
2357        assert_eq!(u.total(), 50);
2358    }
2359
2360    #[test]
2361    fn token_usage_total() {
2362        let u = TokenUsage { input_tokens: 100, output_tokens: 200 };
2363        assert_eq!(u.total(), 300);
2364    }
2365
2366    // -----------------------------------------------------------------------
2367    // StopReason serialization
2368    // -----------------------------------------------------------------------
2369
2370    #[test]
2371    fn stop_reason_serialization() {
2372        let json = serde_json::to_string(&StopReason::EndTurn).unwrap();
2373        assert_eq!(json, "\"end_turn\"");
2374
2375        let json = serde_json::to_string(&StopReason::ToolUse).unwrap();
2376        assert_eq!(json, "\"tool_use\"");
2377
2378        let json = serde_json::to_string(&StopReason::MaxTokens).unwrap();
2379        assert_eq!(json, "\"max_tokens\"");
2380
2381        let json = serde_json::to_string(&StopReason::Error).unwrap();
2382        assert_eq!(json, "\"error\"");
2383    }
2384
2385    #[test]
2386    fn stop_reason_deserialization() {
2387        let sr: StopReason = serde_json::from_str("\"end_turn\"").unwrap();
2388        assert_eq!(sr, StopReason::EndTurn);
2389
2390        let sr: StopReason = serde_json::from_str("\"tool_use\"").unwrap();
2391        assert_eq!(sr, StopReason::ToolUse);
2392    }
2393
2394    // -----------------------------------------------------------------------
2395    // Anthropic driver tests
2396    // -----------------------------------------------------------------------
2397
2398    #[test]
2399    fn anthropic_request_body_simple() {
2400        let driver = AnthropicDriver::new("test-key".to_string(), None);
2401        let body = driver.build_request_body(&simple_request());
2402
2403        assert_eq!(body["model"], "test-model");
2404        assert_eq!(body["max_tokens"], 4096);
2405        assert_eq!(body["system"], "You are helpful.");
2406        assert!((body["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
2407
2408        let messages = body["messages"].as_array().unwrap();
2409        assert_eq!(messages.len(), 1);
2410        assert_eq!(messages[0]["role"], "user");
2411        assert_eq!(messages[0]["content"], "Hello");
2412    }
2413
2414    #[test]
2415    fn anthropic_request_body_with_tools() {
2416        let driver = AnthropicDriver::new("test-key".to_string(), None);
2417        let body = driver.build_request_body(&request_with_tools());
2418
2419        let tools = body["tools"].as_array().unwrap();
2420        assert_eq!(tools.len(), 1);
2421        assert_eq!(tools[0]["name"], "get_weather");
2422        assert!(tools[0]["input_schema"]["properties"].is_object());
2423    }
2424
2425    #[test]
2426    fn anthropic_request_body_no_system_prompt() {
2427        let driver = AnthropicDriver::new("test-key".to_string(), None);
2428        let req = CompletionRequest {
2429            model: "test".into(),
2430            messages: vec![Message::new(Role::User, "Hi")],
2431            tools: Vec::new(),
2432            max_tokens: 100,
2433            temperature: None,
2434            system_prompt: None,
2435        };
2436        let body = driver.build_request_body(&req);
2437        assert!(body.get("system").is_none());
2438        assert!(body.get("temperature").is_none());
2439    }
2440
2441    #[test]
2442    fn anthropic_parse_response_text() {
2443        let driver = AnthropicDriver::new("test-key".to_string(), None);
2444        let response_body = serde_json::json!({
2445            "content": [
2446                {"type": "text", "text": "Hello!"}
2447            ],
2448            "stop_reason": "end_turn",
2449            "usage": {
2450                "input_tokens": 10,
2451                "output_tokens": 5
2452            }
2453        });
2454
2455        let resp = driver.parse_response(&response_body).unwrap();
2456        assert_eq!(resp.message.content, "Hello!");
2457        assert_eq!(resp.stop_reason, StopReason::EndTurn);
2458        assert_eq!(resp.usage.input_tokens, 10);
2459        assert_eq!(resp.usage.output_tokens, 5);
2460        assert!(resp.message.tool_calls.is_empty());
2461    }
2462
2463    #[test]
2464    fn anthropic_parse_response_tool_use() {
2465        let driver = AnthropicDriver::new("test-key".to_string(), None);
2466        let response_body = serde_json::json!({
2467            "content": [
2468                {"type": "text", "text": "Let me check."},
2469                {
2470                    "type": "tool_use",
2471                    "id": "tool_abc",
2472                    "name": "get_weather",
2473                    "input": {"city": "NYC"}
2474                }
2475            ],
2476            "stop_reason": "tool_use",
2477            "usage": {"input_tokens": 20, "output_tokens": 15}
2478        });
2479
2480        let resp = driver.parse_response(&response_body).unwrap();
2481        assert_eq!(resp.message.content, "Let me check.");
2482        assert_eq!(resp.stop_reason, StopReason::ToolUse);
2483        assert_eq!(resp.message.tool_calls.len(), 1);
2484        assert_eq!(resp.message.tool_calls[0].id, "tool_abc");
2485        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
2486        assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
2487    }
2488
2489    #[test]
2490    fn anthropic_parse_response_max_tokens() {
2491        let driver = AnthropicDriver::new("test-key".to_string(), None);
2492        let response_body = serde_json::json!({
2493            "content": [{"type": "text", "text": "truncated"}],
2494            "stop_reason": "max_tokens",
2495            "usage": {"input_tokens": 5, "output_tokens": 100}
2496        });
2497
2498        let resp = driver.parse_response(&response_body).unwrap();
2499        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
2500    }
2501
2502    #[test]
2503    fn anthropic_parse_response_unknown_stop_reason() {
2504        let driver = AnthropicDriver::new("test-key".to_string(), None);
2505        let response_body = serde_json::json!({
2506            "content": [{"type": "text", "text": "err"}],
2507            "stop_reason": "something_unknown",
2508            "usage": {"input_tokens": 0, "output_tokens": 0}
2509        });
2510
2511        let resp = driver.parse_response(&response_body).unwrap();
2512        assert_eq!(resp.stop_reason, StopReason::Error);
2513    }
2514
2515    #[test]
2516    fn anthropic_request_body_with_assistant_and_tool_messages() {
2517        let driver = AnthropicDriver::new("test-key".to_string(), None);
2518        let req = CompletionRequest {
2519            model: "test".into(),
2520            messages: vec![
2521                Message::new(Role::User, "Hi"),
2522                Message {
2523                    role: Role::Assistant,
2524                    content: "I'll check".into(),
2525                    tool_calls: vec![ToolCall {
2526                        id: "call_1".into(),
2527                        name: "file_read".into(),
2528                        input: serde_json::json!({"path": "/tmp/test"}),
2529                    }],
2530                    tool_results: Vec::new(),
2531                    timestamp: chrono::Utc::now(),
2532                },
2533                Message {
2534                    role: Role::Tool,
2535                    content: String::new(),
2536                    tool_calls: Vec::new(),
2537                    tool_results: vec![punch_types::ToolCallResult {
2538                        id: "call_1".into(),
2539                        content: "file contents".into(),
2540                        is_error: false,
2541                    }],
2542                    timestamp: chrono::Utc::now(),
2543                },
2544            ],
2545            tools: Vec::new(),
2546            max_tokens: 100,
2547            temperature: None,
2548            system_prompt: None,
2549        };
2550
2551        let body = driver.build_request_body(&req);
2552        let messages = body["messages"].as_array().unwrap();
2553        assert_eq!(messages.len(), 3);
2554        assert_eq!(messages[0]["role"], "user");
2555        assert_eq!(messages[1]["role"], "assistant");
2556        assert_eq!(messages[2]["role"], "user"); // Tool results go as user role
2557    }
2558
2559    #[test]
2560    fn anthropic_request_body_system_message_skipped() {
2561        let driver = AnthropicDriver::new("test-key".to_string(), None);
2562        let req = CompletionRequest {
2563            model: "test".into(),
2564            messages: vec![
2565                Message::new(Role::System, "System instruction"),
2566                Message::new(Role::User, "Hi"),
2567            ],
2568            tools: Vec::new(),
2569            max_tokens: 100,
2570            temperature: None,
2571            system_prompt: None,
2572        };
2573
2574        let body = driver.build_request_body(&req);
2575        let messages = body["messages"].as_array().unwrap();
2576        // System messages are skipped in messages array
2577        assert_eq!(messages.len(), 1);
2578        assert_eq!(messages[0]["role"], "user");
2579    }
2580
2581    // -----------------------------------------------------------------------
2582    // OpenAI-compatible driver tests
2583    // -----------------------------------------------------------------------
2584
2585    #[test]
2586    fn openai_request_body_simple() {
2587        let driver = OpenAiCompatibleDriver::new(
2588            "key".into(),
2589            "https://api.openai.com".into(),
2590            "openai".into(),
2591        );
2592        let body = driver.build_request_body(&simple_request());
2593
2594        assert_eq!(body["model"], "test-model");
2595        let messages = body["messages"].as_array().unwrap();
2596        assert_eq!(messages.len(), 2);
2597        assert_eq!(messages[0]["role"], "system");
2598        assert_eq!(messages[0]["content"], "You are helpful.");
2599        assert_eq!(messages[1]["role"], "user");
2600    }
2601
2602    #[test]
2603    fn openai_request_body_with_tools() {
2604        let driver = OpenAiCompatibleDriver::new(
2605            "key".into(),
2606            "https://api.openai.com".into(),
2607            "openai".into(),
2608        );
2609        let body = driver.build_request_body(&request_with_tools());
2610
2611        let tools = body["tools"].as_array().unwrap();
2612        assert_eq!(tools.len(), 1);
2613        assert_eq!(tools[0]["type"], "function");
2614        assert_eq!(tools[0]["function"]["name"], "get_weather");
2615    }
2616
2617    #[test]
2618    fn openai_parse_response_text() {
2619        let driver = OpenAiCompatibleDriver::new(
2620            "key".into(),
2621            "https://api.openai.com".into(),
2622            "openai".into(),
2623        );
2624        let response_body = serde_json::json!({
2625            "choices": [{
2626                "message": {
2627                    "role": "assistant",
2628                    "content": "Hello!"
2629                },
2630                "finish_reason": "stop"
2631            }],
2632            "usage": {
2633                "prompt_tokens": 10,
2634                "completion_tokens": 5
2635            }
2636        });
2637
2638        let resp = driver.parse_response(&response_body).unwrap();
2639        assert_eq!(resp.message.content, "Hello!");
2640        assert_eq!(resp.stop_reason, StopReason::EndTurn);
2641        assert_eq!(resp.usage.input_tokens, 10);
2642        assert_eq!(resp.usage.output_tokens, 5);
2643    }
2644
2645    #[test]
2646    fn openai_parse_response_tool_calls() {
2647        let driver = OpenAiCompatibleDriver::new(
2648            "key".into(),
2649            "https://api.openai.com".into(),
2650            "openai".into(),
2651        );
2652        let response_body = serde_json::json!({
2653            "choices": [{
2654                "message": {
2655                    "role": "assistant",
2656                    "content": null,
2657                    "tool_calls": [{
2658                        "id": "call_123",
2659                        "type": "function",
2660                        "function": {
2661                            "name": "get_weather",
2662                            "arguments": "{\"city\": \"NYC\"}"
2663                        }
2664                    }]
2665                },
2666                "finish_reason": "tool_calls"
2667            }],
2668            "usage": {"prompt_tokens": 10, "completion_tokens": 5}
2669        });
2670
2671        let resp = driver.parse_response(&response_body).unwrap();
2672        assert_eq!(resp.stop_reason, StopReason::ToolUse);
2673        assert_eq!(resp.message.tool_calls.len(), 1);
2674        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
2675        assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
2676    }
2677
2678    #[test]
2679    fn openai_parse_response_tool_calls_fix_stop_reason() {
2680        let driver = OpenAiCompatibleDriver::new(
2681            "key".into(),
2682            "https://api.openai.com".into(),
2683            "openai".into(),
2684        );
2685        // finish_reason is "stop" but there are tool_calls — should fix to ToolUse
2686        let response_body = serde_json::json!({
2687            "choices": [{
2688                "message": {
2689                    "role": "assistant",
2690                    "content": "Using tool",
2691                    "tool_calls": [{
2692                        "id": "call_1",
2693                        "type": "function",
2694                        "function": {
2695                            "name": "test_tool",
2696                            "arguments": "{}"
2697                        }
2698                    }]
2699                },
2700                "finish_reason": "stop"
2701            }],
2702            "usage": {"prompt_tokens": 0, "completion_tokens": 0}
2703        });
2704
2705        let resp = driver.parse_response(&response_body).unwrap();
2706        assert_eq!(resp.stop_reason, StopReason::ToolUse);
2707    }
2708
2709    #[test]
2710    fn openai_parse_response_length_stop_reason() {
2711        let driver = OpenAiCompatibleDriver::new(
2712            "key".into(),
2713            "https://api.openai.com".into(),
2714            "openai".into(),
2715        );
2716        let response_body = serde_json::json!({
2717            "choices": [{
2718                "message": {"role": "assistant", "content": "cut off"},
2719                "finish_reason": "length"
2720            }],
2721            "usage": {"prompt_tokens": 0, "completion_tokens": 0}
2722        });
2723
2724        let resp = driver.parse_response(&response_body).unwrap();
2725        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
2726    }
2727
2728    #[test]
2729    fn openai_parse_response_no_choices_error() {
2730        let driver = OpenAiCompatibleDriver::new(
2731            "key".into(),
2732            "https://api.openai.com".into(),
2733            "openai".into(),
2734        );
2735        let response_body = serde_json::json!({"choices": []});
2736
2737        let result = driver.parse_response(&response_body);
2738        assert!(result.is_err());
2739    }
2740
2741    // -----------------------------------------------------------------------
2742    // Gemini driver additional tests
2743    // -----------------------------------------------------------------------
2744
2745    #[test]
2746    fn gemini_assistant_message_formatting() {
2747        let driver = GeminiDriver::new("key".to_string(), None);
2748        let req = CompletionRequest {
2749            model: "gemini-pro".into(),
2750            messages: vec![
2751                Message::new(Role::User, "Hi"),
2752                Message {
2753                    role: Role::Assistant,
2754                    content: "Let me help".into(),
2755                    tool_calls: vec![ToolCall {
2756                        id: "tc1".into(),
2757                        name: "get_weather".into(),
2758                        input: serde_json::json!({"city": "NYC"}),
2759                    }],
2760                    tool_results: Vec::new(),
2761                    timestamp: chrono::Utc::now(),
2762                },
2763            ],
2764            tools: Vec::new(),
2765            max_tokens: 100,
2766            temperature: None,
2767            system_prompt: None,
2768        };
2769
2770        let body = driver.build_request_body(&req);
2771        let contents = body["contents"].as_array().unwrap();
2772        assert_eq!(contents[1]["role"], "model"); // Gemini uses "model" not "assistant"
2773        let parts = contents[1]["parts"].as_array().unwrap();
2774        assert!(parts.len() >= 2); // text part + functionCall part
2775    }
2776
2777    #[test]
2778    fn gemini_max_tokens_stop_reason() {
2779        let driver = GeminiDriver::new("key".to_string(), None);
2780        let response_body = serde_json::json!({
2781            "candidates": [{
2782                "content": {
2783                    "parts": [{"text": "truncated"}],
2784                    "role": "model"
2785                },
2786                "finishReason": "MAX_TOKENS"
2787            }],
2788            "usageMetadata": {"promptTokenCount": 0, "candidatesTokenCount": 0}
2789        });
2790
2791        let resp = driver.parse_response(&response_body).unwrap();
2792        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
2793    }
2794
2795    #[test]
2796    fn gemini_custom_base_url() {
2797        let driver = GeminiDriver::new("key".to_string(), Some("https://custom.example.com".into()));
2798        let url = driver.build_url("gemini-pro");
2799        assert!(url.starts_with("https://custom.example.com/"));
2800    }
2801
2802    // -----------------------------------------------------------------------
2803    // Ollama driver additional tests
2804    // -----------------------------------------------------------------------
2805
2806    #[test]
2807    fn ollama_response_with_tool_calls() {
2808        let driver = OllamaDriver::new(None);
2809        let response_body = serde_json::json!({
2810            "message": {
2811                "role": "assistant",
2812                "content": "",
2813                "tool_calls": [{
2814                    "function": {
2815                        "name": "get_weather",
2816                        "arguments": {"city": "London"}
2817                    }
2818                }]
2819            },
2820            "done": true,
2821            "prompt_eval_count": 10,
2822            "eval_count": 5
2823        });
2824
2825        let resp = driver.parse_response(&response_body).unwrap();
2826        assert_eq!(resp.stop_reason, StopReason::ToolUse);
2827        assert_eq!(resp.message.tool_calls.len(), 1);
2828        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
2829    }
2830
2831    #[test]
2832    fn ollama_response_not_done() {
2833        let driver = OllamaDriver::new(None);
2834        let response_body = serde_json::json!({
2835            "message": {"role": "assistant", "content": "partial"},
2836            "done": false,
2837            "prompt_eval_count": 10,
2838            "eval_count": 5
2839        });
2840
2841        let resp = driver.parse_response(&response_body).unwrap();
2842        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
2843    }
2844
2845    // -----------------------------------------------------------------------
2846    // Bedrock driver additional tests
2847    // -----------------------------------------------------------------------
2848
2849    #[test]
2850    fn bedrock_request_with_tools() {
2851        let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
2852        let body = driver.build_request_body(&request_with_tools());
2853
2854        let tool_config = &body["toolConfig"]["tools"];
2855        assert!(tool_config.is_array());
2856        let tools = tool_config.as_array().unwrap();
2857        assert_eq!(tools.len(), 1);
2858        assert_eq!(tools[0]["toolSpec"]["name"], "get_weather");
2859    }
2860
2861    #[test]
2862    fn bedrock_response_with_tool_use() {
2863        let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
2864        let response_body = serde_json::json!({
2865            "output": {
2866                "message": {
2867                    "role": "assistant",
2868                    "content": [
2869                        {"text": "Using tool"},
2870                        {"toolUse": {
2871                            "toolUseId": "tu_123",
2872                            "name": "get_weather",
2873                            "input": {"city": "NYC"}
2874                        }}
2875                    ]
2876                }
2877            },
2878            "stopReason": "tool_use",
2879            "usage": {"inputTokens": 10, "outputTokens": 20}
2880        });
2881
2882        let resp = driver.parse_response(&response_body).unwrap();
2883        assert_eq!(resp.stop_reason, StopReason::ToolUse);
2884        assert_eq!(resp.message.tool_calls.len(), 1);
2885        assert_eq!(resp.message.tool_calls[0].id, "tu_123");
2886        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
2887    }
2888
2889    #[test]
2890    fn bedrock_request_with_tool_results() {
2891        let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
2892        let req = CompletionRequest {
2893            model: "test".into(),
2894            messages: vec![
2895                Message::new(Role::User, "Hi"),
2896                Message {
2897                    role: Role::Tool,
2898                    content: String::new(),
2899                    tool_calls: Vec::new(),
2900                    tool_results: vec![punch_types::ToolCallResult {
2901                        id: "tu_1".into(),
2902                        content: "result data".into(),
2903                        is_error: false,
2904                    }],
2905                    timestamp: chrono::Utc::now(),
2906                },
2907            ],
2908            tools: Vec::new(),
2909            max_tokens: 100,
2910            temperature: None,
2911            system_prompt: None,
2912        };
2913
2914        let body = driver.build_request_body(&req);
2915        let messages = body["messages"].as_array().unwrap();
2916        assert_eq!(messages[1]["role"], "user"); // Bedrock sends tool results as user
2917        let content = messages[1]["content"].as_array().unwrap();
2918        assert!(content[0]["toolResult"].is_object());
2919        assert_eq!(content[0]["toolResult"]["status"], "success");
2920    }
2921
2922    #[test]
2923    fn bedrock_url_different_regions() {
2924        let driver = BedrockDriver::new("k".into(), "s".into(), "ap-southeast-1".into());
2925        let url = driver.build_url("model-id");
2926        assert!(url.contains("ap-southeast-1"));
2927    }
2928
2929    // -----------------------------------------------------------------------
2930    // Azure OpenAI additional tests
2931    // -----------------------------------------------------------------------
2932
2933    #[test]
2934    fn azure_openai_delegates_parse_to_openai() {
2935        let driver = AzureOpenAiDriver::new(
2936            "key".into(), "res".into(), "dep".into(), None,
2937        );
2938        let response_body = serde_json::json!({
2939            "choices": [{
2940                "message": {"role": "assistant", "content": "Azure response"},
2941                "finish_reason": "stop"
2942            }],
2943            "usage": {"prompt_tokens": 5, "completion_tokens": 3}
2944        });
2945
2946        let resp = driver.parse_response(&response_body).unwrap();
2947        assert_eq!(resp.message.content, "Azure response");
2948    }
2949
2950    // -----------------------------------------------------------------------
2951    // default_base_url tests
2952    // -----------------------------------------------------------------------
2953
2954    #[test]
2955    fn default_base_url_anthropic() {
2956        assert_eq!(default_base_url(&Provider::Anthropic), "https://api.anthropic.com");
2957    }
2958
2959    #[test]
2960    fn default_base_url_openai() {
2961        assert_eq!(default_base_url(&Provider::OpenAI), "https://api.openai.com");
2962    }
2963
2964    #[test]
2965    fn default_base_url_google() {
2966        assert_eq!(default_base_url(&Provider::Google), "https://generativelanguage.googleapis.com");
2967    }
2968
2969    #[test]
2970    fn default_base_url_ollama() {
2971        assert_eq!(default_base_url(&Provider::Ollama), "http://localhost:11434");
2972    }
2973
2974    #[test]
2975    fn default_base_url_groq() {
2976        assert_eq!(default_base_url(&Provider::Groq), "https://api.groq.com/openai");
2977    }
2978
2979    #[test]
2980    fn default_base_url_deepseek() {
2981        assert_eq!(default_base_url(&Provider::DeepSeek), "https://api.deepseek.com");
2982    }
2983
2984    // -----------------------------------------------------------------------
2985    // hex_sha256 and hex_encode tests
2986    // -----------------------------------------------------------------------
2987
2988    #[test]
2989    fn test_hex_sha256() {
2990        let hash = hex_sha256(b"");
2991        assert_eq!(hash, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855");
2992    }
2993
2994    #[test]
2995    fn test_hex_encode() {
2996        assert_eq!(hex_encode(&[0x00, 0xff, 0x0a, 0xbc]), "00ff0abc");
2997    }
2998
2999    #[test]
3000    fn test_hmac_sha256_basic() {
3001        let result = hmac_sha256(b"key", b"data");
3002        assert!(!result.is_empty());
3003        assert_eq!(result.len(), 32); // SHA-256 produces 32 bytes
3004    }
3005
3006    // -----------------------------------------------------------------------
3007    // create_driver error cases
3008    // -----------------------------------------------------------------------
3009
3010    #[test]
3011    fn create_driver_missing_api_key_env() {
3012        let config = ModelConfig {
3013            provider: Provider::Anthropic,
3014            model: "claude-3".into(),
3015            api_key_env: Some("PUNCH_TEST_NONEXISTENT_KEY_XYZ".into()),
3016            base_url: None,
3017            max_tokens: None,
3018            temperature: None,
3019        };
3020        let result = create_driver(&config);
3021        assert!(result.is_err());
3022    }
3023
3024    #[test]
3025    fn create_driver_openai_compatible_fallback() {
3026        // Custom provider should fall through to OpenAI-compatible
3027        unsafe { std::env::set_var("TEST_CUSTOM_KEY_DRIVER", "fake-key") };
3028        let config = ModelConfig {
3029            provider: Provider::Custom("my-custom".into()),
3030            model: "custom-model".into(),
3031            api_key_env: Some("TEST_CUSTOM_KEY_DRIVER".into()),
3032            base_url: Some("https://custom.api.com".into()),
3033            max_tokens: None,
3034            temperature: None,
3035        };
3036        let result = create_driver(&config);
3037        assert!(result.is_ok());
3038        unsafe { std::env::remove_var("TEST_CUSTOM_KEY_DRIVER") };
3039    }
3040
3041    // -----------------------------------------------------------------------
3042    // strip_thinking_tags tests
3043    // -----------------------------------------------------------------------
3044
3045    #[test]
3046    fn strip_thinking_tags_removes_think_block() {
3047        let input = "<think>internal reasoning here</think>The answer is 42.";
3048        assert_eq!(strip_thinking_tags(input), "The answer is 42.");
3049    }
3050
3051    #[test]
3052    fn strip_thinking_tags_removes_thinking_block() {
3053        let input = "<thinking>step by step reasoning</thinking>Hello world!";
3054        assert_eq!(strip_thinking_tags(input), "Hello world!");
3055    }
3056
3057    #[test]
3058    fn strip_thinking_tags_removes_reasoning_block() {
3059        let input = "<reasoning>let me figure this out</reasoning>The result is correct.";
3060        assert_eq!(strip_thinking_tags(input), "The result is correct.");
3061    }
3062
3063    #[test]
3064    fn strip_thinking_tags_removes_reflection_block() {
3065        let input = "<reflection>checking my work</reflection>Yes, that's right.";
3066        assert_eq!(strip_thinking_tags(input), "Yes, that's right.");
3067    }
3068
3069    #[test]
3070    fn strip_thinking_tags_removes_multiple_blocks() {
3071        let input = "<think>first thought</think>Hello <thinking>second thought</thinking>world!";
3072        assert_eq!(strip_thinking_tags(input), "Hello world!");
3073    }
3074
3075    #[test]
3076    fn strip_thinking_tags_preserves_content_without_tags() {
3077        let input = "Just a normal response with no thinking tags.";
3078        assert_eq!(strip_thinking_tags(input), input);
3079    }
3080
3081    #[test]
3082    fn strip_thinking_tags_handles_multiline_tags() {
3083        let input = "<think>\nLine 1\nLine 2\nLine 3\n</think>\nThe final answer.";
3084        assert_eq!(strip_thinking_tags(input), "The final answer.");
3085    }
3086
3087    #[test]
3088    fn strip_thinking_tags_returns_original_if_all_thinking() {
3089        // If the entire response is thinking with no visible output,
3090        // return the original so the user sees something.
3091        let input = "<think>this is all thinking content and nothing else</think>";
3092        assert_eq!(strip_thinking_tags(input), input);
3093    }
3094
3095    #[test]
3096    fn strip_thinking_tags_handles_unclosed_tag() {
3097        let input = "Some text<think>unclosed thinking block";
3098        assert_eq!(strip_thinking_tags(input), "Some text");
3099    }
3100
3101    #[test]
3102    fn strip_thinking_tags_handles_empty_input() {
3103        assert_eq!(strip_thinking_tags(""), "");
3104    }
3105
3106    #[test]
3107    fn strip_thinking_tags_handles_empty_think_block() {
3108        let input = "<think></think>Visible content.";
3109        assert_eq!(strip_thinking_tags(input), "Visible content.");
3110    }
3111
3112    #[test]
3113    fn strip_thinking_tags_trims_whitespace() {
3114        let input = "  <think>reasoning</think>  Result  ";
3115        assert_eq!(strip_thinking_tags(input), "Result");
3116    }
3117
3118    #[test]
3119    fn strip_thinking_tags_mixed_tag_types() {
3120        let input = "<think>t1</think>A<reasoning>r1</reasoning>B<reflection>f1</reflection>C";
3121        assert_eq!(strip_thinking_tags(input), "ABC");
3122    }
3123
3124    #[test]
3125    fn ollama_response_strips_thinking_tags() {
3126        let driver = OllamaDriver::new(None);
3127        let response_body = serde_json::json!({
3128            "message": {
3129                "role": "assistant",
3130                "content": "<think>\nLet me think about this...\nThe user wants hello world.\n</think>\nHello, world!"
3131            },
3132            "done": true,
3133            "prompt_eval_count": 20,
3134            "eval_count": 50
3135        });
3136
3137        let resp = driver.parse_response(&response_body).unwrap();
3138        assert_eq!(resp.message.content, "Hello, world!");
3139        assert!(!resp.message.content.contains("<think>"));
3140    }
3141
3142    #[test]
3143    fn gemini_response_strips_thinking_tags() {
3144        let driver = GeminiDriver::new("test-key".to_string(), None);
3145        let response_body = serde_json::json!({
3146            "candidates": [{
3147                "content": {
3148                    "parts": [{"text": "<thinking>reasoning step</thinking>The answer is 7."}],
3149                    "role": "model"
3150                },
3151                "finishReason": "STOP"
3152            }],
3153            "usageMetadata": {
3154                "promptTokenCount": 10,
3155                "candidatesTokenCount": 20
3156            }
3157        });
3158
3159        let resp = driver.parse_response(&response_body).unwrap();
3160        assert_eq!(resp.message.content, "The answer is 7.");
3161        assert!(!resp.message.content.contains("<thinking>"));
3162    }
3163
3164    #[test]
3165    fn anthropic_response_strips_thinking_tags() {
3166        let driver = AnthropicDriver::new("test-key".to_string(), None);
3167        let response_body = serde_json::json!({
3168            "content": [
3169                {"type": "text", "text": "<think>internal thought</think>Clean output."}
3170            ],
3171            "stop_reason": "end_turn",
3172            "usage": {"input_tokens": 10, "output_tokens": 5}
3173        });
3174
3175        let resp = driver.parse_response(&response_body).unwrap();
3176        assert_eq!(resp.message.content, "Clean output.");
3177    }
3178
3179    #[test]
3180    fn bedrock_response_strips_thinking_tags() {
3181        let driver = BedrockDriver::new(
3182            "key".to_string(),
3183            "secret".to_string(),
3184            "us-east-1".to_string(),
3185        );
3186        let response_body = serde_json::json!({
3187            "output": {
3188                "message": {
3189                    "role": "assistant",
3190                    "content": [{"text": "<reasoning>deep thought</reasoning>Result here."}]
3191                }
3192            },
3193            "stopReason": "end_turn",
3194            "usage": {"inputTokens": 50, "outputTokens": 25}
3195        });
3196
3197        let resp = driver.parse_response(&response_body).unwrap();
3198        assert_eq!(resp.message.content, "Result here.");
3199    }
3200}