// punch_runtime/driver.rs
1//! LLM driver trait and provider implementations.
2//!
3//! The [`LlmDriver`] trait abstracts over different LLM providers so the
4//! fighter loop is provider-agnostic. Concrete implementations handle the
5//! wire format differences between Anthropic, OpenAI-compatible APIs, etc.
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use futures::StreamExt;
11use hmac::{Hmac, Mac};
12use reqwest::Client;
13use serde::{Deserialize, Serialize};
14use sha2::{Digest, Sha256};
15
16use punch_types::{
17    Message, ModelConfig, Provider, PunchError, PunchResult, Role, ToolCall, ToolDefinition,
18};
19
20// ---------------------------------------------------------------------------
21// Core types
22// ---------------------------------------------------------------------------
23
/// Why the model stopped generating.
///
/// Serialized in `snake_case` (e.g. `end_turn`), mirroring the wire values
/// the Anthropic driver maps from in `parse_response`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn naturally.
    EndTurn,
    /// The model wants to invoke one or more tools.
    ToolUse,
    /// The response was truncated due to max_tokens.
    MaxTokens,
    /// An error occurred during generation (also used for unrecognized
    /// provider stop reasons).
    Error,
}
37
/// Token usage statistics for a single completion.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Prompt-side token count as reported by the provider.
    pub input_tokens: u64,
    /// Generated token count as reported by the provider.
    pub output_tokens: u64,
}
44
45impl TokenUsage {
46    /// Add another usage on top of this one (accumulator).
47    pub fn accumulate(&mut self, other: &TokenUsage) {
48        self.input_tokens += other.input_tokens;
49        self.output_tokens += other.output_tokens;
50    }
51
52    /// Total tokens consumed.
53    pub fn total(&self) -> u64 {
54        self.input_tokens + self.output_tokens
55    }
56}
57
/// A request to the LLM for a completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Model identifier (e.g. "claude-sonnet-4-20250514").
    pub model: String,
    /// Conversation messages.
    pub messages: Vec<Message>,
    /// Tools available for the model to call.
    /// Defaults to an empty list when the field is absent on deserialize.
    #[serde(default)]
    pub tools: Vec<ToolDefinition>,
    /// Maximum tokens to generate.
    pub max_tokens: u32,
    /// Sampling temperature. `None` omits the field, leaving the provider's
    /// default in effect.
    pub temperature: Option<f32>,
    /// System prompt (separate from messages for providers that support it).
    pub system_prompt: Option<String>,
}
75
/// The response from an LLM completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// The assistant message (may contain tool calls).
    pub message: Message,
    /// Token usage for this completion.
    pub usage: TokenUsage,
    /// Why the model stopped generating.
    pub stop_reason: StopReason,
}
86
87// ---------------------------------------------------------------------------
88// Streaming types
89// ---------------------------------------------------------------------------
90
/// A chunk from a streaming LLM response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// Incremental text content (may be empty for tool-call-only chunks).
    pub delta: String,
    /// Whether this is the final chunk of the response.
    pub is_final: bool,
    /// Partial tool call data, if this chunk carries any.
    pub tool_call_delta: Option<ToolCallDelta>,
}
101
/// Incremental tool call data in a streaming response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
    /// Position of the tool call this delta belongs to.
    pub index: usize,
    /// Tool call id; the Anthropic driver sends it only on the first
    /// delta for a call, `None` afterwards.
    pub id: Option<String>,
    /// Tool name; same first-delta-only convention as `id`.
    pub name: Option<String>,
    /// Fragment of the JSON-encoded tool arguments.
    pub arguments_delta: String,
}
110
/// Callback invoked for each streaming chunk.
///
/// `Arc`-shared and `Send + Sync` so drivers may call it from whichever
/// task processes the response stream.
pub type StreamCallback = Arc<dyn Fn(StreamChunk) + Send + Sync>;
113
114// ---------------------------------------------------------------------------
115// SSE parsing helpers
116// ---------------------------------------------------------------------------
117
/// Parse a Server-Sent Events stream from raw bytes into discrete events.
///
/// Each event is a `(event_type, data)` tuple. Framing follows the WHATWG
/// EventSource rules: a blank line terminates the current event, the colon
/// after `event:`/`data:` may or may not be followed by a space (at most one
/// leading space is stripped from the value), multiple `data:` lines are
/// joined with `\n`, and comment lines (starting with `:`) are ignored.
/// Events without an explicit `event:` field default to `"message"`.
fn parse_sse_events(raw: &str) -> Vec<(String, String)> {
    let mut events = Vec::new();
    let mut current_event = String::new();
    let mut current_data = String::new();

    // Push the in-progress event (if any) and reset the accumulators.
    let flush = |event: &mut String, data: &mut String, events: &mut Vec<(String, String)>| {
        if !data.is_empty() || !event.is_empty() {
            let name = if event.is_empty() {
                "message".to_string()
            } else {
                std::mem::take(event)
            };
            events.push((name, std::mem::take(data)));
        }
    };

    for line in raw.lines() {
        if line.is_empty() {
            // Blank line = end of event.
            flush(&mut current_event, &mut current_data, &mut events);
        } else if line.starts_with(':') {
            // SSE comment / keep-alive line — ignore.
        } else if let Some(val) = line.strip_prefix("event:") {
            current_event = val.trim().to_string();
        } else if let Some(val) = line.strip_prefix("data:") {
            // Per spec, strip at most ONE leading space and keep the rest of
            // the value verbatim. (The previous fallback branch used `.trim()`,
            // which could corrupt payloads with significant whitespace.)
            let val = val.strip_prefix(' ').unwrap_or(val);
            if !current_data.is_empty() {
                current_data.push('\n');
            }
            current_data.push_str(val);
        }
    }

    // Flush any trailing event without a final blank line.
    flush(&mut current_event, &mut current_data, &mut events);

    events
}
172
173/// Read the full response body as a stream of bytes and return as a String.
174async fn read_stream_body(response: reqwest::Response) -> PunchResult<String> {
175    let mut stream = response.bytes_stream();
176    let mut body = Vec::new();
177    while let Some(chunk) = stream.next().await {
178        let chunk = chunk.map_err(|e| PunchError::Provider {
179            provider: "stream".to_string(),
180            message: format!("stream read error: {e}"),
181        })?;
182        body.extend_from_slice(&chunk);
183    }
184    String::from_utf8(body).map_err(|e| PunchError::Provider {
185        provider: "stream".to_string(),
186        message: format!("invalid UTF-8 in stream: {e}"),
187    })
188}
189
190// ---------------------------------------------------------------------------
191// Think-tag stripping
192// ---------------------------------------------------------------------------
193
/// Strip reasoning/thinking tags from LLM responses.
///
/// Many reasoning models (Qwen, DeepSeek, etc.) wrap internal chain-of-thought
/// in `<think>...</think>`, `<thinking>...</thinking>`, or `<reasoning>...</reasoning>`
/// tags. This function extracts only the visible output. An unclosed opening
/// tag truncates the response from that tag onward.
///
/// If the entire response is inside think tags (no visible output), returns
/// the original content unchanged so the user still sees something.
pub fn strip_thinking_tags(content: &str) -> String {
    let mut result = content.to_string();

    // Strip all known thinking tag variants.
    for tag in &["think", "thinking", "reasoning", "reflection"] {
        let open = format!("<{}>", tag);
        let close = format!("</{}>", tag);

        // Remove every <tag>...</tag> block. `replace_range` edits the string
        // in place; the previous `format!`-based splice reallocated the whole
        // string per removal (accidentally quadratic on many blocks).
        while let Some(start) = result.find(&open) {
            match result[start..].find(&close) {
                Some(rel_end) => {
                    result.replace_range(start..start + rel_end + close.len(), "");
                }
                None => {
                    // Unclosed tag — drop from the open tag to the end.
                    result.truncate(start);
                    break;
                }
            }
        }
    }

    let trimmed = result.trim();

    // If stripping removed everything, the model spent all its tokens
    // thinking; fall back to the original so the caller still sees output.
    if trimmed.is_empty() {
        content.to_string()
    } else {
        trimmed.to_string()
    }
}
233
234// ---------------------------------------------------------------------------
235// Trait
236// ---------------------------------------------------------------------------
237
238/// Abstraction over LLM providers.
239#[async_trait]
240pub trait LlmDriver: Send + Sync + 'static {
241    /// Send a completion request and return the response.
242    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse>;
243
244    /// Streaming variant. Default implementation falls back to `complete`.
245    async fn stream_complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
246        let noop: StreamCallback = Arc::new(|_| {});
247        self.stream_complete_with_callback(request, noop).await
248    }
249
250    /// Streaming completion with per-chunk callback.
251    /// Returns the final assembled `CompletionResponse`.
252    async fn stream_complete_with_callback(
253        &self,
254        request: CompletionRequest,
255        callback: StreamCallback,
256    ) -> PunchResult<CompletionResponse> {
257        // Default: call complete() and send a single chunk.
258        let response = self.complete(request).await?;
259        callback(StreamChunk {
260            delta: response.message.content.clone(),
261            is_final: true,
262            tool_call_delta: None,
263        });
264        Ok(response)
265    }
266}
267
268// ---------------------------------------------------------------------------
269// Anthropic driver
270// ---------------------------------------------------------------------------
271
/// Driver for the Anthropic Messages API (api.anthropic.com).
pub struct AnthropicDriver {
    /// HTTP client; may be shared across drivers for connection pooling.
    client: Client,
    /// Raw API key, sent via the `x-api-key` header.
    api_key: String,
    /// API origin; `/v1/messages` is appended when building request URLs.
    base_url: String,
}
278
impl AnthropicDriver {
    /// Create a new Anthropic driver.
    ///
    /// `api_key` is the raw key value, not the env var name.
    pub fn new(api_key: String, base_url: Option<String>) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
        }
    }

    /// Create a new Anthropic driver with a shared HTTP client.
    ///
    /// This allows connection pooling across all drivers.
    pub fn with_client(client: Client, api_key: String, base_url: Option<String>) -> Self {
        Self {
            client,
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
        }
    }

    /// Build the Anthropic API request body from our internal types.
    ///
    /// Role mapping:
    /// - `User` → `user` (plain string, or content blocks when multimodal)
    /// - `Assistant` → `assistant` (text block plus `tool_use` blocks)
    /// - `Tool` → a `user` message of `tool_result` blocks (the Messages API
    ///   carries tool results inside user messages)
    /// - `System` → skipped here; carried in the top-level `system` param
    fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        for msg in &request.messages {
            match msg.role {
                Role::User => {
                    if msg.content_parts.is_empty() {
                        // Text-only message: plain string content.
                        messages.push(serde_json::json!({
                            "role": "user",
                            "content": msg.content,
                        }));
                    } else {
                        // Multimodal: build content blocks from parts.
                        let mut content_blocks: Vec<serde_json::Value> = Vec::new();
                        if !msg.content.is_empty() {
                            content_blocks.push(serde_json::json!({
                                "type": "text",
                                "text": msg.content,
                            }));
                        }
                        for part in &msg.content_parts {
                            match part {
                                punch_types::ContentPart::Text { text } => {
                                    content_blocks.push(serde_json::json!({
                                        "type": "text",
                                        "text": text,
                                    }));
                                }
                                punch_types::ContentPart::Image { media_type, data } => {
                                    // NOTE(review): `data` is passed through as-is —
                                    // assumed to be base64-encoded already; confirm
                                    // against ContentPart's producer.
                                    content_blocks.push(serde_json::json!({
                                        "type": "image",
                                        "source": {
                                            "type": "base64",
                                            "media_type": media_type,
                                            "data": data,
                                        },
                                    }));
                                }
                            }
                        }
                        messages.push(serde_json::json!({
                            "role": "user",
                            "content": content_blocks,
                        }));
                    }
                }
                Role::Assistant => {
                    let mut content_blocks: Vec<serde_json::Value> = Vec::new();

                    if !msg.content.is_empty() {
                        content_blocks.push(serde_json::json!({
                            "type": "text",
                            "text": msg.content,
                        }));
                    }

                    for tc in &msg.tool_calls {
                        content_blocks.push(serde_json::json!({
                            "type": "tool_use",
                            "id": tc.id,
                            "name": tc.name,
                            "input": tc.input,
                        }));
                    }

                    // Never send an empty content array; pad with an empty text
                    // block (presumably the API rejects empty content — confirm
                    // against the Messages API docs).
                    if content_blocks.is_empty() {
                        content_blocks.push(serde_json::json!({
                            "type": "text",
                            "text": "",
                        }));
                    }

                    messages.push(serde_json::json!({
                        "role": "assistant",
                        "content": content_blocks,
                    }));
                }
                Role::Tool => {
                    let mut result_blocks: Vec<serde_json::Value> = Vec::new();
                    for tr in &msg.tool_results {
                        // Build content for this tool result — may include an image.
                        if let Some(ref image) = tr.image {
                            // Image-bearing result: structured content blocks
                            // (text first, then the image).
                            let mut content: Vec<serde_json::Value> = vec![serde_json::json!({
                                "type": "text",
                                "text": tr.content,
                            })];
                            if let punch_types::ContentPart::Image { media_type, data } = image {
                                content.push(serde_json::json!({
                                    "type": "image",
                                    "source": {
                                        "type": "base64",
                                        "media_type": media_type,
                                        "data": data,
                                    },
                                }));
                            }
                            result_blocks.push(serde_json::json!({
                                "type": "tool_result",
                                "tool_use_id": tr.id,
                                "content": content,
                                "is_error": tr.is_error,
                            }));
                        } else {
                            // Text-only result: plain string content.
                            result_blocks.push(serde_json::json!({
                                "type": "tool_result",
                                "tool_use_id": tr.id,
                                "content": tr.content,
                                "is_error": tr.is_error,
                            }));
                        }
                    }
                    // Tool results ride inside a `user`-role message on the wire.
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": result_blocks,
                    }));
                }
                Role::System => {
                    // System messages are handled via the top-level `system` param;
                    // skip them in the messages array.
                }
            }
        }

        let tools: Vec<serde_json::Value> = request
            .tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "name": t.name,
                    "description": t.description,
                    "input_schema": t.input_schema,
                })
            })
            .collect();

        let mut body = serde_json::json!({
            "model": request.model,
            "messages": messages,
            "max_tokens": request.max_tokens,
        });

        // Temperature is only sent when explicitly set, leaving the API
        // default otherwise.
        if let Some(temp) = request.temperature {
            body["temperature"] = serde_json::json!(temp);
        }

        // Anthropic prompt caching: use structured system content blocks with
        // cache_control so the system prompt is cached across turns (~90% cost
        // reduction on cached input tokens).
        if let Some(ref system) = request.system_prompt {
            body["system"] = serde_json::json!([
                {
                    "type": "text",
                    "text": system,
                    "cache_control": {"type": "ephemeral"},
                }
            ]);
        }

        if !tools.is_empty() {
            // Mark the last tool with cache_control so the entire tool block
            // is included in the cached prefix.
            let mut tools_json = serde_json::json!(tools);
            if let Some(arr) = tools_json.as_array_mut()
                && let Some(last) = arr.last_mut()
            {
                last["cache_control"] = serde_json::json!({"type": "ephemeral"});
            }
            body["tools"] = tools_json;
        }

        body
    }

    /// Parse the Anthropic API response into our internal types.
    ///
    /// Unknown `stop_reason` values map to [`StopReason::Error`]; missing
    /// usage counters default to 0.
    fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let stop_reason = match body["stop_reason"].as_str() {
            Some("end_turn") => StopReason::EndTurn,
            Some("tool_use") => StopReason::ToolUse,
            Some("max_tokens") => StopReason::MaxTokens,
            _ => StopReason::Error,
        };

        let usage = TokenUsage {
            input_tokens: body["usage"]["input_tokens"].as_u64().unwrap_or(0),
            output_tokens: body["usage"]["output_tokens"].as_u64().unwrap_or(0),
        };

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        // Collect text and tool_use blocks; multiple text blocks are joined
        // with newlines. Other block types are ignored.
        if let Some(content_array) = body["content"].as_array() {
            for block in content_array {
                match block["type"].as_str() {
                    Some("text") => {
                        if let Some(text) = block["text"].as_str() {
                            if !text_content.is_empty() {
                                text_content.push('\n');
                            }
                            text_content.push_str(text);
                        }
                    }
                    Some("tool_use") => {
                        tool_calls.push(ToolCall {
                            id: block["id"].as_str().unwrap_or_default().to_string(),
                            name: block["name"].as_str().unwrap_or_default().to_string(),
                            input: block["input"].clone(),
                        });
                    }
                    _ => {}
                }
            }
        }

        // Strip thinking tags from reasoning models
        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            content_parts: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}
535
536#[async_trait]
537impl LlmDriver for AnthropicDriver {
538    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
539        let url = format!("{}/v1/messages", self.base_url);
540        let body = self.build_request_body(&request);
541
542        let response = self
543            .client
544            .post(&url)
545            .header("x-api-key", &self.api_key)
546            .header("anthropic-version", "2023-06-01")
547            .header("content-type", "application/json")
548            .json(&body)
549            .send()
550            .await
551            .map_err(|e| PunchError::Provider {
552                provider: "anthropic".to_string(),
553                message: format!("request failed: {e}"),
554            })?;
555
556        let status = response.status();
557
558        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
559            let retry_after = response
560                .headers()
561                .get("retry-after")
562                .and_then(|v| v.to_str().ok())
563                .and_then(|s| s.parse::<u64>().ok())
564                .unwrap_or(60)
565                * 1000;
566
567            return Err(PunchError::RateLimited {
568                provider: "anthropic".to_string(),
569                retry_after_ms: retry_after,
570            });
571        }
572
573        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
574            return Err(PunchError::Auth(
575                "anthropic API key is invalid or lacks permissions".to_string(),
576            ));
577        }
578
579        let response_body: serde_json::Value =
580            response.json().await.map_err(|e| PunchError::Provider {
581                provider: "anthropic".to_string(),
582                message: format!("failed to parse response: {e}"),
583            })?;
584
585        if !status.is_success() {
586            let error_msg = response_body["error"]["message"]
587                .as_str()
588                .unwrap_or("unknown error");
589            return Err(PunchError::Provider {
590                provider: "anthropic".to_string(),
591                message: format!("API error ({}): {}", status, error_msg),
592            });
593        }
594
595        self.parse_response(&response_body)
596    }
597
598    async fn stream_complete_with_callback(
599        &self,
600        request: CompletionRequest,
601        callback: StreamCallback,
602    ) -> PunchResult<CompletionResponse> {
603        let url = format!("{}/v1/messages", self.base_url);
604        let mut body = self.build_request_body(&request);
605        body["stream"] = serde_json::json!(true);
606
607        let response = self
608            .client
609            .post(&url)
610            .header("x-api-key", &self.api_key)
611            .header("anthropic-version", "2023-06-01")
612            .header("content-type", "application/json")
613            .json(&body)
614            .send()
615            .await
616            .map_err(|e| PunchError::Provider {
617                provider: "anthropic".to_string(),
618                message: format!("stream request failed: {e}"),
619            })?;
620
621        let status = response.status();
622        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
623            return Err(PunchError::RateLimited {
624                provider: "anthropic".to_string(),
625                retry_after_ms: 60_000,
626            });
627        }
628        if !status.is_success() {
629            let err_body: serde_json::Value =
630                response.json().await.unwrap_or(serde_json::json!({}));
631            let msg = err_body["error"]["message"]
632                .as_str()
633                .unwrap_or("unknown error");
634            return Err(PunchError::Provider {
635                provider: "anthropic".to_string(),
636                message: format!("API error ({}): {}", status, msg),
637            });
638        }
639
640        let raw = read_stream_body(response).await?;
641        let events = parse_sse_events(&raw);
642
643        let mut text_content = String::new();
644        let mut tool_calls: Vec<ToolCall> = Vec::new();
645        let mut usage = TokenUsage::default();
646        let mut stop_reason = StopReason::EndTurn;
647        // Track current content block index for tool use assembly
648        let mut current_tool_index: Option<usize> = None;
649
650        for (event_type, data) in &events {
651            let parsed: serde_json::Value = match serde_json::from_str(data) {
652                Ok(v) => v,
653                Err(_) => continue,
654            };
655
656            match event_type.as_str() {
657                "message_start" => {
658                    if let Some(inp) = parsed["message"]["usage"]["input_tokens"].as_u64() {
659                        usage.input_tokens = inp;
660                    }
661                }
662                "content_block_start" => {
663                    let block = &parsed["content_block"];
664                    match block["type"].as_str() {
665                        Some("tool_use") => {
666                            let id = block["id"].as_str().unwrap_or_default().to_string();
667                            let name = block["name"].as_str().unwrap_or_default().to_string();
668                            tool_calls.push(ToolCall {
669                                id: id.clone(),
670                                name: name.clone(),
671                                input: serde_json::json!({}),
672                            });
673                            current_tool_index = Some(tool_calls.len() - 1);
674                            callback(StreamChunk {
675                                delta: String::new(),
676                                is_final: false,
677                                tool_call_delta: Some(ToolCallDelta {
678                                    index: tool_calls.len() - 1,
679                                    id: Some(id),
680                                    name: Some(name),
681                                    arguments_delta: String::new(),
682                                }),
683                            });
684                        }
685                        Some("text") => {
686                            current_tool_index = None;
687                        }
688                        _ => {}
689                    }
690                }
691                "content_block_delta" => {
692                    let delta = &parsed["delta"];
693                    match delta["type"].as_str() {
694                        Some("text_delta") => {
695                            let text = delta["text"].as_str().unwrap_or("");
696                            text_content.push_str(text);
697                            callback(StreamChunk {
698                                delta: text.to_string(),
699                                is_final: false,
700                                tool_call_delta: None,
701                            });
702                        }
703                        Some("input_json_delta") => {
704                            let partial = delta["partial_json"].as_str().unwrap_or("");
705                            if let Some(idx) = current_tool_index {
706                                callback(StreamChunk {
707                                    delta: String::new(),
708                                    is_final: false,
709                                    tool_call_delta: Some(ToolCallDelta {
710                                        index: idx,
711                                        id: None,
712                                        name: None,
713                                        arguments_delta: partial.to_string(),
714                                    }),
715                                });
716                            }
717                        }
718                        _ => {}
719                    }
720                }
721                "message_delta" => {
722                    if let Some(sr) = parsed["delta"]["stop_reason"].as_str() {
723                        stop_reason = match sr {
724                            "end_turn" => StopReason::EndTurn,
725                            "tool_use" => StopReason::ToolUse,
726                            "max_tokens" => StopReason::MaxTokens,
727                            _ => StopReason::Error,
728                        };
729                    }
730                    if let Some(out) = parsed["usage"]["output_tokens"].as_u64() {
731                        usage.output_tokens = out;
732                    }
733                }
734                "message_stop" => {
735                    callback(StreamChunk {
736                        delta: String::new(),
737                        is_final: true,
738                        tool_call_delta: None,
739                    });
740                }
741                _ => {}
742            }
743        }
744
745        // Reassemble tool call inputs from the accumulated JSON fragments.
746        // The Anthropic SSE stream sends tool input as `input_json_delta` fragments.
747        // We need to re-parse the full accumulated JSON for each tool call.
748        // Since we only captured deltas via callback, we rebuild from the raw events.
749        let mut tool_json_bufs: Vec<String> = vec![String::new(); tool_calls.len()];
750        let mut tc_idx: Option<usize> = None;
751        for (event_type, data) in &events {
752            let parsed: serde_json::Value = match serde_json::from_str(data) {
753                Ok(v) => v,
754                Err(_) => continue,
755            };
756            match event_type.as_str() {
757                "content_block_start" => {
758                    if parsed["content_block"]["type"].as_str() == Some("tool_use") {
759                        tc_idx = Some(tc_idx.map_or(0, |i| i + 1));
760                    } else {
761                        tc_idx = None;
762                    }
763                }
764                "content_block_delta" => {
765                    if parsed["delta"]["type"].as_str() == Some("input_json_delta")
766                        && let Some(idx) = tc_idx
767                        && let Some(buf) = tool_json_bufs.get_mut(idx)
768                    {
769                        buf.push_str(parsed["delta"]["partial_json"].as_str().unwrap_or(""));
770                    }
771                }
772                _ => {}
773            }
774        }
775        for (i, buf) in tool_json_bufs.into_iter().enumerate() {
776            if !buf.is_empty()
777                && let Some(tc) = tool_calls.get_mut(i)
778            {
779                tc.input = serde_json::from_str(&buf).unwrap_or(serde_json::json!({}));
780            }
781        }
782
783        let text_content = strip_thinking_tags(&text_content);
784
785        if !tool_calls.is_empty() && stop_reason != StopReason::ToolUse {
786            stop_reason = StopReason::ToolUse;
787        }
788
789        let message = Message {
790            role: Role::Assistant,
791            content: text_content,
792            tool_calls,
793            tool_results: Vec::new(),
794            content_parts: Vec::new(),
795            timestamp: chrono::Utc::now(),
796        };
797
798        Ok(CompletionResponse {
799            message,
800            usage,
801            stop_reason,
802        })
803    }
804}
805
806// ---------------------------------------------------------------------------
807// OpenAI-compatible driver
808// ---------------------------------------------------------------------------
809
/// Driver for OpenAI-compatible chat completions APIs.
///
/// Works with OpenAI, Groq, DeepSeek, Together, Fireworks,
/// Cerebras, xAI, Mistral, and any other provider exposing the
/// `/v1/chat/completions` endpoint.
pub struct OpenAiCompatibleDriver {
    // HTTP client used for all requests; may be shared across drivers.
    client: Client,
    // Bearer token sent in the `authorization` header of every request.
    api_key: String,
    // Provider base URL; `/v1/chat/completions` is appended (trailing '/' tolerated).
    base_url: String,
    // Human-readable provider name used in error and rate-limit reports.
    provider_name: String,
}
821
822impl OpenAiCompatibleDriver {
823    /// Create a new OpenAI-compatible driver.
824    pub fn new(api_key: String, base_url: String, provider_name: String) -> Self {
825        Self {
826            client: Client::new(),
827            api_key,
828            base_url,
829            provider_name,
830        }
831    }
832
833    /// Create a new OpenAI-compatible driver with a shared HTTP client.
834    pub fn with_client(
835        client: Client,
836        api_key: String,
837        base_url: String,
838        provider_name: String,
839    ) -> Self {
840        Self {
841            client,
842            api_key,
843            base_url,
844            provider_name,
845        }
846    }
847
848    /// Build the OpenAI chat completions request body.
849    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
850        let mut messages = Vec::new();
851
852        // System prompt as a system message.
853        if let Some(ref system) = request.system_prompt {
854            messages.push(serde_json::json!({
855                "role": "system",
856                "content": system,
857            }));
858        }
859
860        for msg in &request.messages {
861            match msg.role {
862                Role::System => {
863                    messages.push(serde_json::json!({
864                        "role": "system",
865                        "content": msg.content,
866                    }));
867                }
868                Role::User => {
869                    if msg.content_parts.is_empty() {
870                        messages.push(serde_json::json!({
871                            "role": "user",
872                            "content": msg.content,
873                        }));
874                    } else {
875                        // Multimodal: OpenAI format with content array.
876                        let mut content_blocks: Vec<serde_json::Value> = Vec::new();
877                        if !msg.content.is_empty() {
878                            content_blocks.push(serde_json::json!({
879                                "type": "text",
880                                "text": msg.content,
881                            }));
882                        }
883                        for part in &msg.content_parts {
884                            match part {
885                                punch_types::ContentPart::Text { text } => {
886                                    content_blocks.push(serde_json::json!({
887                                        "type": "text",
888                                        "text": text,
889                                    }));
890                                }
891                                punch_types::ContentPart::Image { media_type, data } => {
892                                    content_blocks.push(serde_json::json!({
893                                        "type": "image_url",
894                                        "image_url": {
895                                            "url": format!("data:{media_type};base64,{data}"),
896                                        },
897                                    }));
898                                }
899                            }
900                        }
901                        messages.push(serde_json::json!({
902                            "role": "user",
903                            "content": content_blocks,
904                        }));
905                    }
906                }
907                Role::Assistant => {
908                    let mut m = serde_json::json!({
909                        "role": "assistant",
910                    });
911
912                    if !msg.content.is_empty() {
913                        m["content"] = serde_json::json!(msg.content);
914                    }
915
916                    if !msg.tool_calls.is_empty() {
917                        let tc: Vec<serde_json::Value> = msg
918                            .tool_calls
919                            .iter()
920                            .map(|tc| {
921                                serde_json::json!({
922                                    "id": tc.id,
923                                    "type": "function",
924                                    "function": {
925                                        "name": tc.name,
926                                        "arguments": tc.input.to_string(),
927                                    },
928                                })
929                            })
930                            .collect();
931                        m["tool_calls"] = serde_json::json!(tc);
932                    }
933
934                    messages.push(m);
935                }
936                Role::Tool => {
937                    for tr in &msg.tool_results {
938                        messages.push(serde_json::json!({
939                            "role": "tool",
940                            "tool_call_id": tr.id,
941                            "content": tr.content,
942                        }));
943                    }
944                }
945            }
946        }
947
948        let tools: Vec<serde_json::Value> = request
949            .tools
950            .iter()
951            .map(|t| {
952                serde_json::json!({
953                    "type": "function",
954                    "function": {
955                        "name": t.name,
956                        "description": t.description,
957                        "parameters": t.input_schema,
958                    },
959                })
960            })
961            .collect();
962
963        let mut body = serde_json::json!({
964            "model": request.model,
965            "messages": messages,
966            "max_tokens": request.max_tokens,
967        });
968
969        if let Some(temp) = request.temperature {
970            body["temperature"] = serde_json::json!(temp);
971        }
972
973        if !tools.is_empty() {
974            body["tools"] = serde_json::json!(tools);
975        }
976
977        body
978    }
979
980    /// Parse the OpenAI chat completions response.
981    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
982        let choice = body["choices"].get(0).ok_or_else(|| PunchError::Provider {
983            provider: self.provider_name.clone(),
984            message: "no choices in response".to_string(),
985        })?;
986
987        let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
988        let stop_reason = match finish_reason {
989            "stop" => StopReason::EndTurn,
990            "tool_calls" => StopReason::ToolUse,
991            "length" => StopReason::MaxTokens,
992            _ => StopReason::EndTurn,
993        };
994
995        let msg = &choice["message"];
996        let raw_content = msg["content"].as_str().unwrap_or("");
997        // Strip thinking tags from reasoning models (Qwen, DeepSeek R1, etc.)
998        let content = strip_thinking_tags(raw_content);
999
1000        let mut tool_calls = Vec::new();
1001        if let Some(tc_array) = msg["tool_calls"].as_array() {
1002            for tc in tc_array {
1003                let id = tc["id"].as_str().unwrap_or_default().to_string();
1004                let name = tc["function"]["name"]
1005                    .as_str()
1006                    .unwrap_or_default()
1007                    .to_string();
1008                let args_str = tc["function"]["arguments"].as_str().unwrap_or("{}");
1009                let input: serde_json::Value =
1010                    serde_json::from_str(args_str).unwrap_or(serde_json::json!({}));
1011
1012                tool_calls.push(ToolCall { id, name, input });
1013            }
1014        }
1015
1016        let usage = TokenUsage {
1017            input_tokens: body["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
1018            output_tokens: body["usage"]["completion_tokens"].as_u64().unwrap_or(0),
1019        };
1020
1021        // If there are tool calls but finish_reason was not "tool_calls", fix it up.
1022        let stop_reason = if !tool_calls.is_empty() && stop_reason != StopReason::ToolUse {
1023            StopReason::ToolUse
1024        } else {
1025            stop_reason
1026        };
1027
1028        let message = Message {
1029            role: Role::Assistant,
1030            content,
1031            tool_calls,
1032            tool_results: Vec::new(),
1033            content_parts: Vec::new(),
1034            timestamp: chrono::Utc::now(),
1035        };
1036
1037        Ok(CompletionResponse {
1038            message,
1039            usage,
1040            stop_reason,
1041        })
1042    }
1043}
1044
1045#[async_trait]
1046impl LlmDriver for OpenAiCompatibleDriver {
1047    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
1048        let url = format!(
1049            "{}/v1/chat/completions",
1050            self.base_url.trim_end_matches('/')
1051        );
1052        let body = self.build_request_body(&request);
1053
1054        let response = self
1055            .client
1056            .post(&url)
1057            .header("authorization", format!("Bearer {}", self.api_key))
1058            .header("content-type", "application/json")
1059            .json(&body)
1060            .send()
1061            .await
1062            .map_err(|e| PunchError::Provider {
1063                provider: self.provider_name.clone(),
1064                message: format!("request failed: {e}"),
1065            })?;
1066
1067        let status = response.status();
1068
1069        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
1070            let retry_after = response
1071                .headers()
1072                .get("retry-after")
1073                .and_then(|v| v.to_str().ok())
1074                .and_then(|s| s.parse::<u64>().ok())
1075                .unwrap_or(60)
1076                * 1000;
1077
1078            return Err(PunchError::RateLimited {
1079                provider: self.provider_name.clone(),
1080                retry_after_ms: retry_after,
1081            });
1082        }
1083
1084        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
1085            return Err(PunchError::Auth(format!(
1086                "{} API key is invalid or lacks permissions",
1087                self.provider_name
1088            )));
1089        }
1090
1091        let response_body: serde_json::Value =
1092            response.json().await.map_err(|e| PunchError::Provider {
1093                provider: self.provider_name.clone(),
1094                message: format!("failed to parse response: {e}"),
1095            })?;
1096
1097        if !status.is_success() {
1098            let error_msg = response_body["error"]["message"]
1099                .as_str()
1100                .unwrap_or("unknown error");
1101            return Err(PunchError::Provider {
1102                provider: self.provider_name.clone(),
1103                message: format!("API error ({}): {}", status, error_msg),
1104            });
1105        }
1106
1107        self.parse_response(&response_body)
1108    }
1109
1110    async fn stream_complete_with_callback(
1111        &self,
1112        request: CompletionRequest,
1113        callback: StreamCallback,
1114    ) -> PunchResult<CompletionResponse> {
1115        let url = format!(
1116            "{}/v1/chat/completions",
1117            self.base_url.trim_end_matches('/')
1118        );
1119        let mut body = self.build_request_body(&request);
1120        body["stream"] = serde_json::json!(true);
1121
1122        let response = self
1123            .client
1124            .post(&url)
1125            .header("authorization", format!("Bearer {}", self.api_key))
1126            .header("content-type", "application/json")
1127            .json(&body)
1128            .send()
1129            .await
1130            .map_err(|e| PunchError::Provider {
1131                provider: self.provider_name.clone(),
1132                message: format!("stream request failed: {e}"),
1133            })?;
1134
1135        let status = response.status();
1136        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
1137            return Err(PunchError::RateLimited {
1138                provider: self.provider_name.clone(),
1139                retry_after_ms: 60_000,
1140            });
1141        }
1142        if !status.is_success() {
1143            let err_body: serde_json::Value =
1144                response.json().await.unwrap_or(serde_json::json!({}));
1145            let msg = err_body["error"]["message"]
1146                .as_str()
1147                .unwrap_or("unknown error");
1148            return Err(PunchError::Provider {
1149                provider: self.provider_name.clone(),
1150                message: format!("API error ({}): {}", status, msg),
1151            });
1152        }
1153
1154        let raw = read_stream_body(response).await?;
1155        let assembled = self.parse_openai_stream(&raw, &callback)?;
1156        Ok(assembled)
1157    }
1158}
1159
1160impl OpenAiCompatibleDriver {
1161    /// Parse an OpenAI-style SSE stream into a `CompletionResponse`, invoking
1162    /// the callback for each chunk.
1163    pub fn parse_openai_stream(
1164        &self,
1165        raw: &str,
1166        callback: &StreamCallback,
1167    ) -> PunchResult<CompletionResponse> {
1168        let events = parse_sse_events(raw);
1169
1170        let mut text_content = String::new();
1171        // tool_calls keyed by index
1172        let mut tool_map: std::collections::BTreeMap<usize, (String, String, String)> =
1173            std::collections::BTreeMap::new();
1174        let mut finish_reason = String::new();
1175
1176        for (_event_type, data) in &events {
1177            if data.trim() == "[DONE]" {
1178                callback(StreamChunk {
1179                    delta: String::new(),
1180                    is_final: true,
1181                    tool_call_delta: None,
1182                });
1183                break;
1184            }
1185
1186            let parsed: serde_json::Value = match serde_json::from_str(data) {
1187                Ok(v) => v,
1188                Err(_) => continue,
1189            };
1190
1191            let choice = match parsed["choices"].get(0) {
1192                Some(c) => c,
1193                None => continue,
1194            };
1195
1196            if let Some(fr) = choice["finish_reason"].as_str() {
1197                finish_reason = fr.to_string();
1198            }
1199
1200            let delta = &choice["delta"];
1201
1202            // Text content delta
1203            if let Some(content) = delta["content"].as_str() {
1204                text_content.push_str(content);
1205                callback(StreamChunk {
1206                    delta: content.to_string(),
1207                    is_final: false,
1208                    tool_call_delta: None,
1209                });
1210            }
1211
1212            // Tool call deltas
1213            if let Some(tc_array) = delta["tool_calls"].as_array() {
1214                for tc in tc_array {
1215                    let idx = tc["index"].as_u64().unwrap_or(0) as usize;
1216                    let entry = tool_map
1217                        .entry(idx)
1218                        .or_insert_with(|| (String::new(), String::new(), String::new()));
1219
1220                    let id_delta = tc["id"].as_str().unwrap_or("");
1221                    let name_delta = tc["function"]["name"].as_str().unwrap_or("");
1222                    let args_delta = tc["function"]["arguments"].as_str().unwrap_or("");
1223
1224                    if !id_delta.is_empty() {
1225                        entry.0.push_str(id_delta);
1226                    }
1227                    if !name_delta.is_empty() {
1228                        entry.1.push_str(name_delta);
1229                    }
1230                    entry.2.push_str(args_delta);
1231
1232                    callback(StreamChunk {
1233                        delta: String::new(),
1234                        is_final: false,
1235                        tool_call_delta: Some(ToolCallDelta {
1236                            index: idx,
1237                            id: if id_delta.is_empty() {
1238                                None
1239                            } else {
1240                                Some(id_delta.to_string())
1241                            },
1242                            name: if name_delta.is_empty() {
1243                                None
1244                            } else {
1245                                Some(name_delta.to_string())
1246                            },
1247                            arguments_delta: args_delta.to_string(),
1248                        }),
1249                    });
1250                }
1251            }
1252        }
1253
1254        let tool_calls: Vec<ToolCall> = tool_map
1255            .into_values()
1256            .map(|(id, name, args)| {
1257                let input: serde_json::Value =
1258                    serde_json::from_str(&args).unwrap_or(serde_json::json!({}));
1259                ToolCall { id, name, input }
1260            })
1261            .collect();
1262
1263        let stop_reason = if !tool_calls.is_empty() {
1264            StopReason::ToolUse
1265        } else {
1266            match finish_reason.as_str() {
1267                "stop" => StopReason::EndTurn,
1268                "tool_calls" => StopReason::ToolUse,
1269                "length" => StopReason::MaxTokens,
1270                _ => StopReason::EndTurn,
1271            }
1272        };
1273
1274        let text_content = strip_thinking_tags(&text_content);
1275
1276        let message = Message {
1277            role: Role::Assistant,
1278            content: text_content,
1279            tool_calls,
1280            tool_results: Vec::new(),
1281            content_parts: Vec::new(),
1282            timestamp: chrono::Utc::now(),
1283        };
1284
1285        // OpenAI streaming does not include usage in most chunks; set to zero.
1286        Ok(CompletionResponse {
1287            message,
1288            usage: TokenUsage::default(),
1289            stop_reason,
1290        })
1291    }
1292}
1293
1294// ---------------------------------------------------------------------------
1295// Gemini driver
1296// ---------------------------------------------------------------------------
1297
1298/// Driver for the Google Gemini (Generative Language) API.
1299pub struct GeminiDriver {
1300    client: Client,
1301    api_key: String,
1302    base_url: String,
1303}
1304
1305impl GeminiDriver {
1306    /// Create a new Gemini driver.
1307    pub fn new(api_key: String, base_url: Option<String>) -> Self {
1308        Self {
1309            client: Client::new(),
1310            api_key,
1311            base_url: base_url
1312                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string()),
1313        }
1314    }
1315
1316    /// Create a new Gemini driver with a shared HTTP client.
1317    pub fn with_client(client: Client, api_key: String, base_url: Option<String>) -> Self {
1318        Self {
1319            client,
1320            api_key,
1321            base_url: base_url
1322                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string()),
1323        }
1324    }
1325
1326    /// Build the Gemini API request body.
1327    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
1328        let mut contents = Vec::new();
1329        // Collect system text for the dedicated systemInstruction field.
1330        let mut system_text: Option<String> = request.system_prompt.clone();
1331
1332        for msg in &request.messages {
1333            match msg.role {
1334                Role::System => {
1335                    // Accumulate system-role messages into the systemInstruction.
1336                    let existing = system_text.take().unwrap_or_default();
1337                    let combined = if existing.is_empty() {
1338                        msg.content.clone()
1339                    } else {
1340                        format!("{}\n{}", existing, msg.content)
1341                    };
1342                    system_text = Some(combined);
1343                }
1344                Role::User => {
1345                    let mut parts: Vec<serde_json::Value> = Vec::new();
1346                    if !msg.content.is_empty() {
1347                        parts.push(serde_json::json!({"text": msg.content}));
1348                    }
1349                    // Add multimodal parts for Gemini.
1350                    for part in &msg.content_parts {
1351                        match part {
1352                            punch_types::ContentPart::Text { text: t } => {
1353                                parts.push(serde_json::json!({"text": t}));
1354                            }
1355                            punch_types::ContentPart::Image { media_type, data } => {
1356                                parts.push(serde_json::json!({
1357                                    "inline_data": {
1358                                        "mime_type": media_type,
1359                                        "data": data,
1360                                    }
1361                                }));
1362                            }
1363                        }
1364                    }
1365                    if parts.is_empty() {
1366                        parts.push(serde_json::json!({"text": ""}));
1367                    }
1368                    contents.push(serde_json::json!({
1369                        "role": "user",
1370                        "parts": parts,
1371                    }));
1372                }
1373                Role::Assistant => {
1374                    let mut parts: Vec<serde_json::Value> = Vec::new();
1375                    if !msg.content.is_empty() {
1376                        parts.push(serde_json::json!({"text": msg.content}));
1377                    }
1378                    for tc in &msg.tool_calls {
1379                        parts.push(serde_json::json!({
1380                            "functionCall": {
1381                                "name": tc.name,
1382                                "args": tc.input,
1383                            }
1384                        }));
1385                    }
1386                    if parts.is_empty() {
1387                        parts.push(serde_json::json!({"text": ""}));
1388                    }
1389                    contents.push(serde_json::json!({
1390                        "role": "model",
1391                        "parts": parts,
1392                    }));
1393                }
1394                Role::Tool => {
1395                    let mut parts: Vec<serde_json::Value> = Vec::new();
1396                    for tr in &msg.tool_results {
1397                        parts.push(serde_json::json!({
1398                            "functionResponse": {
1399                                "name": tr.id.clone(),
1400                                "response": {"content": tr.content},
1401                            }
1402                        }));
1403                    }
1404                    contents.push(serde_json::json!({
1405                        "role": "user",
1406                        "parts": parts,
1407                    }));
1408                }
1409            }
1410        }
1411
1412        let mut body = serde_json::json!({
1413            "contents": contents,
1414        });
1415
1416        // Use Gemini's dedicated systemInstruction field instead of prepending
1417        // to user messages. This enables Gemini's automatic prompt caching and
1418        // keeps the system prompt separate from conversation content.
1419        if let Some(sys) = system_text
1420            && !sys.is_empty()
1421        {
1422            body["system_instruction"] = serde_json::json!({
1423                "parts": [{"text": sys}],
1424            });
1425        }
1426
1427        let mut gen_config = serde_json::json!({
1428            "maxOutputTokens": request.max_tokens,
1429        });
1430        if let Some(temp) = request.temperature {
1431            gen_config["temperature"] = serde_json::json!(temp);
1432        }
1433        body["generationConfig"] = gen_config;
1434
1435        if !request.tools.is_empty() {
1436            let func_decls: Vec<serde_json::Value> = request
1437                .tools
1438                .iter()
1439                .map(|t| {
1440                    serde_json::json!({
1441                        "name": t.name,
1442                        "description": t.description,
1443                        "parameters": t.input_schema,
1444                    })
1445                })
1446                .collect();
1447            body["tools"] = serde_json::json!([{"function_declarations": func_decls}]);
1448        }
1449
1450        body
1451    }
1452
1453    /// Build the full URL for a Gemini request.
1454    pub fn build_url(&self, model: &str) -> String {
1455        format!(
1456            "{}/v1beta/models/{}:generateContent?key={}",
1457            self.base_url.trim_end_matches('/'),
1458            model,
1459            self.api_key,
1460        )
1461    }
1462
1463    /// Build the URL for Gemini streaming.
1464    pub fn build_stream_url(&self, model: &str) -> String {
1465        format!(
1466            "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
1467            self.base_url.trim_end_matches('/'),
1468            model,
1469            self.api_key,
1470        )
1471    }
1472
1473    /// Parse the Gemini API response.
1474    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
1475        let candidate = body["candidates"]
1476            .get(0)
1477            .ok_or_else(|| PunchError::Provider {
1478                provider: "gemini".to_string(),
1479                message: "no candidates in response".to_string(),
1480            })?;
1481
1482        let parts = candidate["content"]["parts"]
1483            .as_array()
1484            .cloned()
1485            .unwrap_or_default();
1486
1487        let mut text_content = String::new();
1488        let mut tool_calls = Vec::new();
1489
1490        for part in &parts {
1491            if let Some(text) = part["text"].as_str() {
1492                if !text_content.is_empty() {
1493                    text_content.push('\n');
1494                }
1495                text_content.push_str(text);
1496            }
1497            if let Some(fc) = part.get("functionCall") {
1498                let name = fc["name"].as_str().unwrap_or_default().to_string();
1499                let args = fc["args"].clone();
1500                tool_calls.push(ToolCall {
1501                    id: format!("gemini-{}", uuid::Uuid::new_v4()),
1502                    name,
1503                    input: args,
1504                });
1505            }
1506        }
1507
1508        let finish_reason = candidate["finishReason"].as_str().unwrap_or("STOP");
1509        let stop_reason = if !tool_calls.is_empty() {
1510            StopReason::ToolUse
1511        } else {
1512            match finish_reason {
1513                "STOP" => StopReason::EndTurn,
1514                "MAX_TOKENS" => StopReason::MaxTokens,
1515                _ => StopReason::EndTurn,
1516            }
1517        };
1518
1519        let usage = TokenUsage {
1520            input_tokens: body["usageMetadata"]["promptTokenCount"]
1521                .as_u64()
1522                .unwrap_or(0),
1523            output_tokens: body["usageMetadata"]["candidatesTokenCount"]
1524                .as_u64()
1525                .unwrap_or(0),
1526        };
1527
1528        // Strip thinking tags from reasoning models
1529        let text_content = strip_thinking_tags(&text_content);
1530
1531        let message = Message {
1532            role: Role::Assistant,
1533            content: text_content,
1534            tool_calls,
1535            tool_results: Vec::new(),
1536            content_parts: Vec::new(),
1537            timestamp: chrono::Utc::now(),
1538        };
1539
1540        Ok(CompletionResponse {
1541            message,
1542            usage,
1543            stop_reason,
1544        })
1545    }
1546}
1547
#[async_trait]
impl LlmDriver for GeminiDriver {
    /// Non-streaming completion via the Gemini API.
    ///
    /// Error mapping: HTTP 429 becomes [`PunchError::RateLimited`] with a
    /// fixed 60s retry hint (no retry header is consulted here), 401/403
    /// become [`PunchError::Auth`], and any other non-2xx status surfaces
    /// the API's own error message as [`PunchError::Provider`].
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = self.build_url(&request.model);
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();

        // Rate limiting gets a dedicated error variant so callers can back off.
        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            return Err(PunchError::RateLimited {
                provider: "gemini".to_string(),
                retry_after_ms: 60_000,
            });
        }

        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(PunchError::Auth(
                "Gemini API key is invalid or lacks permissions".to_string(),
            ));
        }

        // Parse the body before the success check so non-2xx responses can
        // still be mined for the API's structured error message.
        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }

    /// Streaming completion via the Gemini SSE endpoint.
    ///
    /// NOTE(review): `read_stream_body` collects the entire HTTP body before
    /// the SSE events are parsed, so `callback` only fires after the network
    /// transfer has finished — confirm this is acceptable where real-time
    /// incremental output matters.
    async fn stream_complete_with_callback(
        &self,
        request: CompletionRequest,
        callback: StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        let url = self.build_stream_url(&request.model);
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("stream request failed: {e}"),
            })?;

        let status = response.status();
        if !status.is_success() {
            let err_body: serde_json::Value =
                response.json().await.unwrap_or(serde_json::json!({}));
            let msg = err_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("API error ({}): {}", status, msg),
            });
        }

        let raw = read_stream_body(response).await?;
        let events = parse_sse_events(&raw);

        // Accumulators assembled across all SSE events.
        let mut text_content = String::new();
        let mut tool_calls: Vec<ToolCall> = Vec::new();
        let mut usage = TokenUsage::default();
        let mut finish_reason = String::new();

        for (_event_type, data) in &events {
            // Malformed event payloads are skipped rather than failing the stream.
            let parsed: serde_json::Value = match serde_json::from_str(data) {
                Ok(v) => v,
                Err(_) => continue,
            };

            // Extract parts from the candidate
            if let Some(parts) = parsed["candidates"][0]["content"]["parts"].as_array() {
                for part in parts {
                    if let Some(text) = part["text"].as_str() {
                        text_content.push_str(text);
                        callback(StreamChunk {
                            delta: text.to_string(),
                            is_final: false,
                            tool_call_delta: None,
                        });
                    }
                    if let Some(fc) = part.get("functionCall") {
                        let name = fc["name"].as_str().unwrap_or_default().to_string();
                        let args = fc["args"].clone();
                        let idx = tool_calls.len();
                        // Gemini does not supply tool-call IDs, so one is synthesized.
                        tool_calls.push(ToolCall {
                            id: format!("gemini-{}", uuid::Uuid::new_v4()),
                            name: name.clone(),
                            input: args,
                        });
                        callback(StreamChunk {
                            delta: String::new(),
                            is_final: false,
                            tool_call_delta: Some(ToolCallDelta {
                                index: idx,
                                id: None,
                                name: Some(name),
                                arguments_delta: String::new(),
                            }),
                        });
                    }
                }
            }

            // Later events overwrite earlier finish reasons; the last one wins.
            if let Some(fr) = parsed["candidates"][0]["finishReason"].as_str() {
                finish_reason = fr.to_string();
            }

            // Usage from the last chunk
            if let Some(inp) = parsed["usageMetadata"]["promptTokenCount"].as_u64() {
                usage.input_tokens = inp;
            }
            if let Some(out) = parsed["usageMetadata"]["candidatesTokenCount"].as_u64() {
                usage.output_tokens = out;
            }
        }

        // Signal end-of-stream to the caller.
        callback(StreamChunk {
            delta: String::new(),
            is_final: true,
            tool_call_delta: None,
        });

        // Tool calls take precedence over the textual finish reason;
        // unrecognized reasons default to EndTurn.
        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else {
            match finish_reason.as_str() {
                "STOP" => StopReason::EndTurn,
                "MAX_TOKENS" => StopReason::MaxTokens,
                _ => StopReason::EndTurn,
            }
        };

        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            content_parts: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}
1728
1729// ---------------------------------------------------------------------------
1730// Ollama driver
1731// ---------------------------------------------------------------------------
1732
/// Driver for local Ollama instances using the chat API.
pub struct OllamaDriver {
    /// HTTP client used for all requests (may be shared across drivers).
    client: Client,
    /// Server base URL; defaults to `http://localhost:11434` when not supplied.
    base_url: String,
}
1738
1739impl OllamaDriver {
1740    /// Create a new Ollama driver.
1741    pub fn new(base_url: Option<String>) -> Self {
1742        Self {
1743            client: Client::new(),
1744            base_url: base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
1745        }
1746    }
1747
1748    /// Create a new Ollama driver with a shared HTTP client.
1749    pub fn with_client(client: Client, base_url: Option<String>) -> Self {
1750        Self {
1751            client,
1752            base_url: base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
1753        }
1754    }
1755
1756    /// Get the base URL.
1757    pub fn base_url(&self) -> &str {
1758        &self.base_url
1759    }
1760
1761    /// Build the Ollama chat request body.
1762    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
1763        let mut messages = Vec::new();
1764
1765        if let Some(ref system) = request.system_prompt {
1766            messages.push(serde_json::json!({
1767                "role": "system",
1768                "content": system,
1769            }));
1770        }
1771
1772        for msg in &request.messages {
1773            match msg.role {
1774                Role::System => {
1775                    messages.push(serde_json::json!({
1776                        "role": "system",
1777                        "content": msg.content,
1778                    }));
1779                }
1780                Role::User => {
1781                    // Ollama multimodal: images go in a separate "images" array.
1782                    let images: Vec<&str> = msg
1783                        .content_parts
1784                        .iter()
1785                        .filter_map(|p| match p {
1786                            punch_types::ContentPart::Image { data, .. } => Some(data.as_str()),
1787                            _ => None,
1788                        })
1789                        .collect();
1790                    let mut m = serde_json::json!({
1791                        "role": "user",
1792                        "content": msg.content,
1793                    });
1794                    if !images.is_empty() {
1795                        m["images"] = serde_json::json!(images);
1796                    }
1797                    messages.push(m);
1798                }
1799                Role::Assistant => {
1800                    let mut m = serde_json::json!({
1801                        "role": "assistant",
1802                        "content": msg.content,
1803                    });
1804                    if !msg.tool_calls.is_empty() {
1805                        let tc: Vec<serde_json::Value> = msg
1806                            .tool_calls
1807                            .iter()
1808                            .map(|tc| {
1809                                serde_json::json!({
1810                                    "function": {
1811                                        "name": tc.name,
1812                                        "arguments": tc.input,
1813                                    }
1814                                })
1815                            })
1816                            .collect();
1817                        m["tool_calls"] = serde_json::json!(tc);
1818                    }
1819                    messages.push(m);
1820                }
1821                Role::Tool => {
1822                    for tr in &msg.tool_results {
1823                        messages.push(serde_json::json!({
1824                            "role": "tool",
1825                            "content": tr.content,
1826                        }));
1827                    }
1828                }
1829            }
1830        }
1831
1832        let mut body = serde_json::json!({
1833            "model": request.model,
1834            "messages": messages,
1835            "stream": false,
1836        });
1837
1838        let mut options = serde_json::json!({});
1839        if let Some(temp) = request.temperature {
1840            options["temperature"] = serde_json::json!(temp);
1841        }
1842        if request.max_tokens > 0 {
1843            options["num_predict"] = serde_json::json!(request.max_tokens);
1844        }
1845        body["options"] = options;
1846
1847        // Disable thinking mode for reasoning models (Qwen, DeepSeek) to prevent
1848        // the model from spending its entire token budget on internal reasoning.
1849        // The think tags get stripped anyway, so we avoid wasting tokens.
1850        body["think"] = serde_json::json!(false);
1851
1852        if !request.tools.is_empty() {
1853            let tools: Vec<serde_json::Value> = request
1854                .tools
1855                .iter()
1856                .map(|t| {
1857                    serde_json::json!({
1858                        "type": "function",
1859                        "function": {
1860                            "name": t.name,
1861                            "description": t.description,
1862                            "parameters": t.input_schema,
1863                        }
1864                    })
1865                })
1866                .collect();
1867            body["tools"] = serde_json::json!(tools);
1868        }
1869
1870        body
1871    }
1872
1873    /// Parse the Ollama chat response.
1874    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
1875        let msg = &body["message"];
1876        let raw_content = msg["content"].as_str().unwrap_or("");
1877        // Strip thinking tags from reasoning models (Qwen, DeepSeek, etc.)
1878        let content = strip_thinking_tags(raw_content);
1879
1880        let mut tool_calls = Vec::new();
1881        if let Some(tc_array) = msg["tool_calls"].as_array() {
1882            for tc in tc_array {
1883                let name = tc["function"]["name"]
1884                    .as_str()
1885                    .unwrap_or_default()
1886                    .to_string();
1887                let input = tc["function"]["arguments"].clone();
1888                tool_calls.push(ToolCall {
1889                    id: format!("ollama-{}", uuid::Uuid::new_v4()),
1890                    name,
1891                    input,
1892                });
1893            }
1894        }
1895
1896        let stop_reason = if !tool_calls.is_empty() {
1897            StopReason::ToolUse
1898        } else if body["done"].as_bool().unwrap_or(true) {
1899            StopReason::EndTurn
1900        } else {
1901            StopReason::MaxTokens
1902        };
1903
1904        let usage = TokenUsage {
1905            input_tokens: body["prompt_eval_count"].as_u64().unwrap_or(0),
1906            output_tokens: body["eval_count"].as_u64().unwrap_or(0),
1907        };
1908
1909        let message = Message {
1910            role: Role::Assistant,
1911            content,
1912            tool_calls,
1913            tool_results: Vec::new(),
1914            content_parts: Vec::new(),
1915            timestamp: chrono::Utc::now(),
1916        };
1917
1918        Ok(CompletionResponse {
1919            message,
1920            usage,
1921            stop_reason,
1922        })
1923    }
1924}
1925
1926#[async_trait]
1927impl LlmDriver for OllamaDriver {
1928    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
1929        let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
1930        let body = self.build_request_body(&request);
1931
1932        let response = self
1933            .client
1934            .post(&url)
1935            .header("content-type", "application/json")
1936            .json(&body)
1937            .send()
1938            .await
1939            .map_err(|e| PunchError::Provider {
1940                provider: "ollama".to_string(),
1941                message: format!("request failed: {e}"),
1942            })?;
1943
1944        let status = response.status();
1945        let response_body: serde_json::Value =
1946            response.json().await.map_err(|e| PunchError::Provider {
1947                provider: "ollama".to_string(),
1948                message: format!("failed to parse response: {e}"),
1949            })?;
1950
1951        if !status.is_success() {
1952            let error_msg = response_body["error"].as_str().unwrap_or("unknown error");
1953            return Err(PunchError::Provider {
1954                provider: "ollama".to_string(),
1955                message: format!("API error ({}): {}", status, error_msg),
1956            });
1957        }
1958
1959        self.parse_response(&response_body)
1960    }
1961
1962    async fn stream_complete_with_callback(
1963        &self,
1964        request: CompletionRequest,
1965        callback: StreamCallback,
1966    ) -> PunchResult<CompletionResponse> {
1967        let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
1968        let mut body = self.build_request_body(&request);
1969        body["stream"] = serde_json::json!(true);
1970
1971        let response = self
1972            .client
1973            .post(&url)
1974            .header("content-type", "application/json")
1975            .json(&body)
1976            .send()
1977            .await
1978            .map_err(|e| PunchError::Provider {
1979                provider: "ollama".to_string(),
1980                message: format!("stream request failed: {e}"),
1981            })?;
1982
1983        let status = response.status();
1984        if !status.is_success() {
1985            let err_body: serde_json::Value =
1986                response.json().await.unwrap_or(serde_json::json!({}));
1987            let msg = err_body["error"].as_str().unwrap_or("unknown error");
1988            return Err(PunchError::Provider {
1989                provider: "ollama".to_string(),
1990                message: format!("API error ({}): {}", status, msg),
1991            });
1992        }
1993
1994        let raw = read_stream_body(response).await?;
1995        let assembled = self.parse_ollama_stream(&raw, &callback)?;
1996        Ok(assembled)
1997    }
1998}
1999
2000impl OllamaDriver {
2001    /// Parse Ollama's newline-delimited JSON stream into a `CompletionResponse`,
2002    /// invoking the callback for each chunk.
2003    pub fn parse_ollama_stream(
2004        &self,
2005        raw: &str,
2006        callback: &StreamCallback,
2007    ) -> PunchResult<CompletionResponse> {
2008        let mut text_content = String::new();
2009        let mut tool_calls: Vec<ToolCall> = Vec::new();
2010        let mut usage = TokenUsage::default();
2011        let mut done = false;
2012
2013        for line in raw.lines() {
2014            let line = line.trim();
2015            if line.is_empty() {
2016                continue;
2017            }
2018
2019            let parsed: serde_json::Value = match serde_json::from_str(line) {
2020                Ok(v) => v,
2021                Err(_) => continue,
2022            };
2023
2024            if parsed["done"].as_bool() == Some(true) {
2025                done = true;
2026                // Final chunk may include stats
2027                if let Some(inp) = parsed["prompt_eval_count"].as_u64() {
2028                    usage.input_tokens = inp;
2029                }
2030                if let Some(out) = parsed["eval_count"].as_u64() {
2031                    usage.output_tokens = out;
2032                }
2033                // Final chunk may also have tool calls
2034                if let Some(tc_array) = parsed["message"]["tool_calls"].as_array() {
2035                    for tc in tc_array {
2036                        let name = tc["function"]["name"]
2037                            .as_str()
2038                            .unwrap_or_default()
2039                            .to_string();
2040                        let input = tc["function"]["arguments"].clone();
2041                        tool_calls.push(ToolCall {
2042                            id: format!("ollama-{}", uuid::Uuid::new_v4()),
2043                            name,
2044                            input,
2045                        });
2046                    }
2047                }
2048                callback(StreamChunk {
2049                    delta: String::new(),
2050                    is_final: true,
2051                    tool_call_delta: None,
2052                });
2053                break;
2054            }
2055
2056            // Streaming chunk with content
2057            let content = parsed["message"]["content"].as_str().unwrap_or("");
2058            if !content.is_empty() {
2059                text_content.push_str(content);
2060                callback(StreamChunk {
2061                    delta: content.to_string(),
2062                    is_final: false,
2063                    tool_call_delta: None,
2064                });
2065            }
2066        }
2067
2068        let text_content = strip_thinking_tags(&text_content);
2069
2070        let stop_reason = if !tool_calls.is_empty() {
2071            StopReason::ToolUse
2072        } else if done {
2073            StopReason::EndTurn
2074        } else {
2075            StopReason::MaxTokens
2076        };
2077
2078        let message = Message {
2079            role: Role::Assistant,
2080            content: text_content,
2081            tool_calls,
2082            tool_results: Vec::new(),
2083            content_parts: Vec::new(),
2084            timestamp: chrono::Utc::now(),
2085        };
2086
2087        Ok(CompletionResponse {
2088            message,
2089            usage,
2090            stop_reason,
2091        })
2092    }
2093}
2094
2095// ---------------------------------------------------------------------------
2096// AWS Bedrock driver
2097// ---------------------------------------------------------------------------
2098
/// Driver for AWS Bedrock using the Converse API with SigV4 authentication.
pub struct BedrockDriver {
    /// HTTP client used for all requests (may be shared across drivers).
    client: Client,
    /// AWS access key ID, placed in the SigV4 `Credential` component.
    access_key: String,
    /// AWS secret access key, used to derive the SigV4 signing key.
    secret_key: String,
    /// AWS region; part of both the endpoint host and the signing scope.
    region: String,
}
2106
impl BedrockDriver {
    /// Create a new Bedrock driver.
    ///
    /// Uses a fresh [`Client`]; see [`Self::with_client`] to share one.
    pub fn new(access_key: String, secret_key: String, region: String) -> Self {
        Self {
            client: Client::new(),
            access_key,
            secret_key,
            region,
        }
    }

    /// Create a new Bedrock driver with a shared HTTP client.
    pub fn with_client(
        client: Client,
        access_key: String,
        secret_key: String,
        region: String,
    ) -> Self {
        Self {
            client,
            access_key,
            secret_key,
            region,
        }
    }

    /// Build the Bedrock Converse API request body.
    ///
    /// Maps the provider-agnostic [`CompletionRequest`] onto the Converse
    /// wire format: `messages`, `inferenceConfig` (maxTokens/temperature),
    /// an optional `system` block, and `toolConfig` when tools are declared.
    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        for msg in &request.messages {
            match msg.role {
                Role::User => {
                    let mut content: Vec<serde_json::Value> = Vec::new();
                    if !msg.content.is_empty() {
                        content.push(serde_json::json!({"text": msg.content}));
                    }
                    // Add multimodal parts for Bedrock (same as Anthropic).
                    for part in &msg.content_parts {
                        match part {
                            punch_types::ContentPart::Text { text } => {
                                content.push(serde_json::json!({"text": text}));
                            }
                            punch_types::ContentPart::Image { media_type, data } => {
                                // "image/png" -> format "png"; falls back to png.
                                content.push(serde_json::json!({
                                    "image": {
                                        "format": media_type.rsplit('/').next().unwrap_or("png"),
                                        "source": {
                                            "bytes": data,
                                        }
                                    }
                                }));
                            }
                        }
                    }
                    // Converse rejects empty content arrays; pad with empty text.
                    if content.is_empty() {
                        content.push(serde_json::json!({"text": ""}));
                    }
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": content,
                    }));
                }
                Role::Assistant => {
                    let mut content: Vec<serde_json::Value> = Vec::new();
                    if !msg.content.is_empty() {
                        content.push(serde_json::json!({"text": msg.content}));
                    }
                    // Replay prior tool invocations as toolUse blocks.
                    for tc in &msg.tool_calls {
                        content.push(serde_json::json!({
                            "toolUse": {
                                "toolUseId": tc.id,
                                "name": tc.name,
                                "input": tc.input,
                            }
                        }));
                    }
                    if content.is_empty() {
                        content.push(serde_json::json!({"text": ""}));
                    }
                    messages.push(serde_json::json!({
                        "role": "assistant",
                        "content": content,
                    }));
                }
                Role::Tool => {
                    // The Converse API carries tool results as user-role
                    // messages containing toolResult blocks.
                    let mut content: Vec<serde_json::Value> = Vec::new();
                    for tr in &msg.tool_results {
                        content.push(serde_json::json!({
                            "toolResult": {
                                "toolUseId": tr.id,
                                "content": [{"text": tr.content}],
                                "status": if tr.is_error { "error" } else { "success" },
                            }
                        }));
                    }
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": content,
                    }));
                }
                Role::System => {
                    // System messages handled separately.
                }
            }
        }

        let mut body = serde_json::json!({
            "messages": messages,
        });

        let mut inference_config = serde_json::json!({
            "maxTokens": request.max_tokens,
        });
        if let Some(temp) = request.temperature {
            inference_config["temperature"] = serde_json::json!(temp);
        }
        body["inferenceConfig"] = inference_config;

        // The request-level system prompt maps to the top-level "system" array.
        if let Some(ref system) = request.system_prompt {
            body["system"] = serde_json::json!([{"text": system}]);
        }

        if !request.tools.is_empty() {
            let tool_config: Vec<serde_json::Value> = request
                .tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "toolSpec": {
                            "name": t.name,
                            "description": t.description,
                            "inputSchema": {"json": t.input_schema},
                        }
                    })
                })
                .collect();
            body["toolConfig"] = serde_json::json!({"tools": tool_config});
        }

        body
    }

    /// Build the endpoint URL for a model.
    ///
    /// NOTE(review): `model_id` is interpolated verbatim. Bedrock model IDs
    /// contain `:` (e.g. a `-v1:0` suffix), which SigV4 canonicalization may
    /// require percent-encoded in the path — confirm against the AWS SigV4
    /// spec if signature mismatches occur.
    pub fn build_url(&self, model_id: &str) -> String {
        format!(
            "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
            self.region, model_id,
        )
    }

    /// Parse the Bedrock Converse API response.
    ///
    /// Collects text blocks (newline-joined) and toolUse blocks from
    /// `output.message.content`, maps `stopReason`, and reads token usage
    /// from `usage.inputTokens` / `usage.outputTokens`.
    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let content = body["output"]["message"]["content"]
            .as_array()
            .cloned()
            .unwrap_or_default();

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        for block in &content {
            if let Some(text) = block["text"].as_str() {
                // Join multiple text blocks with a newline separator.
                if !text_content.is_empty() {
                    text_content.push('\n');
                }
                text_content.push_str(text);
            }
            if let Some(tu) = block.get("toolUse") {
                tool_calls.push(ToolCall {
                    id: tu["toolUseId"].as_str().unwrap_or_default().to_string(),
                    name: tu["name"].as_str().unwrap_or_default().to_string(),
                    input: tu["input"].clone(),
                });
            }
        }

        // Presence of tool calls overrides the textual stop reason;
        // unrecognized reasons default to EndTurn.
        let stop_reason_str = body["stopReason"].as_str().unwrap_or("end_turn");
        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else {
            match stop_reason_str {
                "end_turn" => StopReason::EndTurn,
                "tool_use" => StopReason::ToolUse,
                "max_tokens" => StopReason::MaxTokens,
                _ => StopReason::EndTurn,
            }
        };

        let usage = TokenUsage {
            input_tokens: body["usage"]["inputTokens"].as_u64().unwrap_or(0),
            output_tokens: body["usage"]["outputTokens"].as_u64().unwrap_or(0),
        };

        // Strip thinking tags from reasoning models
        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            content_parts: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }

    /// Compute an AWS SigV4 signature and return the Authorization header value.
    ///
    /// This is a basic implementation sufficient for Bedrock API calls.
    ///
    /// Limitations visible here:
    /// - The canonical query string is assumed empty (the bare `\n\n` in the
    ///   canonical-request format string) — Converse URLs carry no query.
    /// - The canonical path is taken from `url::Url::path()` without extra
    ///   URI encoding; NOTE(review): SigV4 expects a URI-encoded canonical
    ///   path, which may matter for model IDs containing `:` — confirm.
    /// - `timestamp` must be at least 8 bytes in the documented
    ///   `YYYYMMDDTHHMMSSZ` form, or the `[..8]` slice below panics.
    pub fn sign_request(
        &self,
        method: &str,
        url: &str,
        headers: &[(String, String)],
        payload: &[u8],
        timestamp: &str, // format: "20260313T120000Z"
    ) -> PunchResult<String> {
        let date = &timestamp[..8]; // "20260313"
        let service = "bedrock";

        // Parse the URL to get host and path.
        let parsed = url::Url::parse(url).map_err(|e| PunchError::Provider {
            provider: "bedrock".to_string(),
            message: format!("invalid URL: {e}"),
        })?;
        let host = parsed.host_str().unwrap_or("");
        let path = parsed.path();

        // 1. Create canonical request.
        let payload_hash = hex_sha256(payload);

        // Signed-header list: caller headers plus host and x-amz-date,
        // lowercased, sorted, deduplicated — per the SigV4 spec.
        let mut signed_header_names: Vec<String> =
            headers.iter().map(|(k, _)| k.to_lowercase()).collect();
        signed_header_names.push("host".to_string());
        signed_header_names.push("x-amz-date".to_string());
        signed_header_names.sort();
        signed_header_names.dedup();

        // Canonical headers: lowercase names, trimmed values, sorted by name.
        let mut header_map: Vec<(String, String)> = headers
            .iter()
            .map(|(k, v)| (k.to_lowercase(), v.trim().to_string()))
            .collect();
        header_map.push(("host".to_string(), host.to_string()));
        header_map.push(("x-amz-date".to_string(), timestamp.to_string()));
        header_map.sort_by(|a, b| a.0.cmp(&b.0));
        header_map.dedup_by(|a, b| a.0 == b.0);

        let canonical_headers: String = header_map
            .iter()
            .map(|(k, v)| format!("{}:{}\n", k, v))
            .collect();

        let signed_headers = signed_header_names.join(";");

        // Layout: method \n path \n (empty query) \n canonical_headers
        // (each entry already ends in \n) \n signed_headers \n payload_hash.
        let canonical_request = format!(
            "{}\n{}\n\n{}\n{}\n{}",
            method, path, canonical_headers, signed_headers, payload_hash,
        );

        // 2. Create string to sign.
        let credential_scope = format!("{}/{}/{}/aws4_request", date, self.region, service);
        let string_to_sign = format!(
            "AWS4-HMAC-SHA256\n{}\n{}\n{}",
            timestamp,
            credential_scope,
            hex_sha256(canonical_request.as_bytes()),
        );

        // 3. Calculate signing key (the standard SigV4 HMAC chain:
        // date -> region -> service -> "aws4_request").
        let k_date = hmac_sha256(
            format!("AWS4{}", self.secret_key).as_bytes(),
            date.as_bytes(),
        );
        let k_region = hmac_sha256(&k_date, self.region.as_bytes());
        let k_service = hmac_sha256(&k_region, service.as_bytes());
        let k_signing = hmac_sha256(&k_service, b"aws4_request");

        // 4. Calculate signature.
        let signature = hex_encode(&hmac_sha256(&k_signing, string_to_sign.as_bytes()));

        // 5. Build Authorization header.
        Ok(format!(
            "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
            self.access_key, credential_scope, signed_headers, signature,
        ))
    }
}
2401
2402/// Compute SHA-256 hex digest.
2403fn hex_sha256(data: &[u8]) -> String {
2404    let mut hasher = Sha256::new();
2405    hasher.update(data);
2406    hex_encode(hasher.finalize().as_slice())
2407}
2408
2409/// Compute HMAC-SHA256.
2410fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
2411    type HmacSha256 = Hmac<Sha256>;
2412    let mut mac = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size");
2413    mac.update(data);
2414    mac.finalize().into_bytes().to_vec()
2415}
2416
/// Hex-encode bytes as lowercase hex without an external crate.
///
/// Writes two nibble characters per byte into a pre-sized buffer instead of
/// allocating a temporary `String` per byte via `format!`.
fn hex_encode(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(HEX_DIGITS[(byte >> 4) as usize] as char);
        out.push(HEX_DIGITS[(byte & 0x0f) as usize] as char);
    }
    out
}
2421
#[async_trait]
impl LlmDriver for BedrockDriver {
    /// Perform a non-streaming completion via the Bedrock Converse endpoint,
    /// signing the request with AWS Signature Version 4.
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = self.build_url(&request.model);
        let body = self.build_request_body(&request);
        // Serialize once so the bytes hashed for the SigV4 signature are the
        // exact bytes sent as the request body.
        let payload = serde_json::to_vec(&body).map_err(|e| PunchError::Provider {
            provider: "bedrock".to_string(),
            message: format!("failed to serialize request: {e}"),
        })?;

        // SigV4 timestamp in the required YYYYMMDD'T'HHMMSS'Z' format.
        let timestamp = chrono::Utc::now().format("%Y%m%dT%H%M%SZ").to_string();

        // Only content-type is passed in here; sign_request itself appends
        // the host and x-amz-date headers before canonicalization.
        let auth_header = self.sign_request(
            "POST",
            &url,
            &[("content-type".to_string(), "application/json".to_string())],
            &payload,
            &timestamp,
        )?;

        let parsed_url = url::Url::parse(&url).map_err(|e| PunchError::Provider {
            provider: "bedrock".to_string(),
            message: format!("invalid URL: {e}"),
        })?;
        let host = parsed_url.host_str().unwrap_or_default().to_string();

        // The headers sent here must exactly match the signed header set
        // (content-type, host, x-amz-date) or AWS rejects the signature.
        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .header("host", &host)
            .header("x-amz-date", &timestamp)
            .header("authorization", &auth_header)
            .body(payload)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "bedrock".to_string(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();

        // No Retry-After parsing here — a fixed 60s backoff is reported.
        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            return Err(PunchError::RateLimited {
                provider: "bedrock".to_string(),
                retry_after_ms: 60_000,
            });
        }

        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(PunchError::Auth(
                "AWS Bedrock credentials are invalid or lack permissions".to_string(),
            ));
        }

        // Parse the body before checking status: error responses also carry a
        // JSON payload with a "message" field.
        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: "bedrock".to_string(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["message"].as_str().unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "bedrock".to_string(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }

    /// Emulated streaming: runs `complete` and delivers the whole result as a
    /// single final chunk via `callback`.
    async fn stream_complete_with_callback(
        &self,
        request: CompletionRequest,
        callback: StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        // Bedrock uses a proprietary binary event stream format for streaming.
        // Fall back to non-streaming and emit the result as a single final chunk.
        let response = self.complete(request).await?;
        callback(StreamChunk {
            delta: response.message.content.clone(),
            is_final: true,
            tool_call_delta: None,
        });
        Ok(response)
    }
}
2511
2512// ---------------------------------------------------------------------------
2513// Azure OpenAI driver
2514// ---------------------------------------------------------------------------
2515
/// Driver for Azure OpenAI deployments.
///
/// Uses the same request/response format as OpenAI but with Azure-specific
/// URL construction and API key header.
pub struct AzureOpenAiDriver {
    /// OpenAI-compatible driver that handles body building, response parsing,
    /// and holds the HTTP client and API key.
    inner: OpenAiCompatibleDriver,
    /// Azure resource name (the `{resource}.openai.azure.com` subdomain).
    resource: String,
    /// Deployment name, used in the URL path in place of a model name.
    deployment: String,
    /// `api-version` query parameter value (e.g. "2024-02-01").
    api_version: String,
}
2526
2527impl AzureOpenAiDriver {
2528    /// Create a new Azure OpenAI driver.
2529    ///
2530    /// - `api_key`: The Azure OpenAI API key.
2531    /// - `resource`: The Azure resource name (subdomain).
2532    /// - `deployment`: The deployment name.
2533    /// - `api_version`: API version string (e.g., "2024-02-01").
2534    pub fn new(
2535        api_key: String,
2536        resource: String,
2537        deployment: String,
2538        api_version: Option<String>,
2539    ) -> Self {
2540        let base_url = format!("https://{}.openai.azure.com", resource);
2541        Self {
2542            inner: OpenAiCompatibleDriver::new(api_key, base_url, "azure_openai".to_string()),
2543            resource,
2544            deployment,
2545            api_version: api_version.unwrap_or_else(|| "2024-02-01".to_string()),
2546        }
2547    }
2548
2549    /// Create a new Azure OpenAI driver with a shared HTTP client.
2550    pub fn with_client(
2551        client: Client,
2552        api_key: String,
2553        resource: String,
2554        deployment: String,
2555        api_version: Option<String>,
2556    ) -> Self {
2557        let base_url = format!("https://{}.openai.azure.com", resource);
2558        Self {
2559            inner: OpenAiCompatibleDriver::with_client(
2560                client,
2561                api_key,
2562                base_url,
2563                "azure_openai".to_string(),
2564            ),
2565            resource,
2566            deployment,
2567            api_version: api_version.unwrap_or_else(|| "2024-02-01".to_string()),
2568        }
2569    }
2570
2571    /// Build the Azure OpenAI endpoint URL.
2572    pub fn build_url(&self) -> String {
2573        format!(
2574            "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
2575            self.resource, self.deployment, self.api_version,
2576        )
2577    }
2578
2579    /// Get the resource name.
2580    pub fn resource(&self) -> &str {
2581        &self.resource
2582    }
2583
2584    /// Get the deployment name.
2585    pub fn deployment(&self) -> &str {
2586        &self.deployment
2587    }
2588
2589    /// Build request body (delegates to inner OpenAI-compatible driver).
2590    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
2591        self.inner.build_request_body(request)
2592    }
2593
2594    /// Parse response (delegates to inner OpenAI-compatible driver).
2595    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
2596        self.inner.parse_response(body)
2597    }
2598}
2599
2600#[async_trait]
2601impl LlmDriver for AzureOpenAiDriver {
2602    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
2603        let url = self.build_url();
2604        let body = self.inner.build_request_body(&request);
2605
2606        let response = self
2607            .inner
2608            .client
2609            .post(&url)
2610            .header("api-key", &self.inner.api_key)
2611            .header("content-type", "application/json")
2612            .json(&body)
2613            .send()
2614            .await
2615            .map_err(|e| PunchError::Provider {
2616                provider: "azure_openai".to_string(),
2617                message: format!("request failed: {e}"),
2618            })?;
2619
2620        let status = response.status();
2621
2622        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
2623            let retry_after = response
2624                .headers()
2625                .get("retry-after")
2626                .and_then(|v| v.to_str().ok())
2627                .and_then(|s| s.parse::<u64>().ok())
2628                .unwrap_or(60)
2629                * 1000;
2630
2631            return Err(PunchError::RateLimited {
2632                provider: "azure_openai".to_string(),
2633                retry_after_ms: retry_after,
2634            });
2635        }
2636
2637        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
2638            return Err(PunchError::Auth(
2639                "Azure OpenAI API key is invalid or lacks permissions".to_string(),
2640            ));
2641        }
2642
2643        let response_body: serde_json::Value =
2644            response.json().await.map_err(|e| PunchError::Provider {
2645                provider: "azure_openai".to_string(),
2646                message: format!("failed to parse response: {e}"),
2647            })?;
2648
2649        if !status.is_success() {
2650            let error_msg = response_body["error"]["message"]
2651                .as_str()
2652                .unwrap_or("unknown error");
2653            return Err(PunchError::Provider {
2654                provider: "azure_openai".to_string(),
2655                message: format!("API error ({}): {}", status, error_msg),
2656            });
2657        }
2658
2659        self.inner.parse_response(&response_body)
2660    }
2661
2662    async fn stream_complete_with_callback(
2663        &self,
2664        request: CompletionRequest,
2665        callback: StreamCallback,
2666    ) -> PunchResult<CompletionResponse> {
2667        let url = self.build_url();
2668        let mut body = self.inner.build_request_body(&request);
2669        body["stream"] = serde_json::json!(true);
2670
2671        let response = self
2672            .inner
2673            .client
2674            .post(&url)
2675            .header("api-key", &self.inner.api_key)
2676            .header("content-type", "application/json")
2677            .json(&body)
2678            .send()
2679            .await
2680            .map_err(|e| PunchError::Provider {
2681                provider: "azure_openai".to_string(),
2682                message: format!("stream request failed: {e}"),
2683            })?;
2684
2685        let status = response.status();
2686        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
2687            return Err(PunchError::RateLimited {
2688                provider: "azure_openai".to_string(),
2689                retry_after_ms: 60_000,
2690            });
2691        }
2692        if !status.is_success() {
2693            let err_body: serde_json::Value =
2694                response.json().await.unwrap_or(serde_json::json!({}));
2695            let msg = err_body["error"]["message"]
2696                .as_str()
2697                .unwrap_or("unknown error");
2698            return Err(PunchError::Provider {
2699                provider: "azure_openai".to_string(),
2700                message: format!("API error ({}): {}", status, msg),
2701            });
2702        }
2703
2704        let raw = read_stream_body(response).await?;
2705        // Azure OpenAI uses the same SSE format as OpenAI
2706        let assembled = self.inner.parse_openai_stream(&raw, &callback)?;
2707        Ok(assembled)
2708    }
2709}
2710
2711// ---------------------------------------------------------------------------
2712// Factory
2713// ---------------------------------------------------------------------------
2714
2715/// Default base URLs for known providers.
2716fn default_base_url(provider: &Provider) -> &'static str {
2717    match provider {
2718        Provider::Anthropic => "https://api.anthropic.com",
2719        Provider::OpenAI => "https://api.openai.com",
2720        Provider::Google => "https://generativelanguage.googleapis.com",
2721        Provider::Groq => "https://api.groq.com/openai",
2722        Provider::DeepSeek => "https://api.deepseek.com",
2723        Provider::Ollama => "http://localhost:11434",
2724        Provider::Mistral => "https://api.mistral.ai",
2725        Provider::Together => "https://api.together.xyz",
2726        Provider::Fireworks => "https://api.fireworks.ai/inference",
2727        Provider::Cerebras => "https://api.cerebras.ai",
2728        Provider::XAI => "https://api.x.ai",
2729        Provider::Cohere => "https://api.cohere.ai",
2730        Provider::Bedrock => "https://bedrock-runtime.us-east-1.amazonaws.com",
2731        Provider::AzureOpenAi => "",
2732        Provider::Custom(_) => "",
2733    }
2734}
2735
/// Create an [`LlmDriver`] from a [`ModelConfig`].
///
/// Reads the API key from the environment variable specified in
/// `config.api_key_env`. Returns an error if the env var is missing
/// (except for Ollama which does not require auth).
///
/// Convenience wrapper around [`create_driver_with_client`] with no shared
/// HTTP client: each driver creates its own client. Use
/// [`create_driver_with_client`] to share a client for connection pooling.
pub fn create_driver(config: &ModelConfig) -> PunchResult<Arc<dyn LlmDriver>> {
    create_driver_with_client(config, None)
}
2748
2749/// Create a driver from config with an optional shared [`reqwest::Client`].
2750pub fn create_driver_with_client(
2751    config: &ModelConfig,
2752    shared_client: Option<&Client>,
2753) -> PunchResult<Arc<dyn LlmDriver>> {
2754    let api_key = match &config.api_key_env {
2755        Some(env_var) => std::env::var(env_var).map_err(|_| {
2756            PunchError::Auth(format!(
2757                "environment variable '{}' not set for {} driver",
2758                env_var, config.provider
2759            ))
2760        })?,
2761        None => {
2762            // Ollama typically has no auth; others will fail at the API.
2763            String::new()
2764        }
2765    };
2766
2767    let base_url = config
2768        .base_url
2769        .clone()
2770        .unwrap_or_else(|| default_base_url(&config.provider).to_string());
2771
2772    match &config.provider {
2773        Provider::Anthropic => {
2774            if let Some(client) = shared_client {
2775                Ok(Arc::new(AnthropicDriver::with_client(
2776                    client.clone(),
2777                    api_key,
2778                    Some(base_url),
2779                )))
2780            } else {
2781                Ok(Arc::new(AnthropicDriver::new(api_key, Some(base_url))))
2782            }
2783        }
2784        Provider::Google => {
2785            if let Some(client) = shared_client {
2786                Ok(Arc::new(GeminiDriver::with_client(
2787                    client.clone(),
2788                    api_key,
2789                    Some(base_url),
2790                )))
2791            } else {
2792                Ok(Arc::new(GeminiDriver::new(api_key, Some(base_url))))
2793            }
2794        }
2795        Provider::Ollama => {
2796            if let Some(client) = shared_client {
2797                Ok(Arc::new(OllamaDriver::with_client(
2798                    client.clone(),
2799                    Some(base_url),
2800                )))
2801            } else {
2802                Ok(Arc::new(OllamaDriver::new(Some(base_url))))
2803            }
2804        }
2805        Provider::Bedrock => {
2806            // For Bedrock, api_key is expected to be "ACCESS_KEY:SECRET_KEY" or
2807            // we read AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from env.
2808            let (access_key, secret_key) = if api_key.contains(':') {
2809                let parts: Vec<&str> = api_key.splitn(2, ':').collect();
2810                (parts[0].to_string(), parts[1].to_string())
2811            } else {
2812                let ak = std::env::var("AWS_ACCESS_KEY_ID").unwrap_or(api_key);
2813                let sk = std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_default();
2814                (ak, sk)
2815            };
2816            // Extract region from base_url or default to us-east-1.
2817            let region = if base_url.contains("bedrock-runtime.") {
2818                base_url
2819                    .trim_start_matches("https://bedrock-runtime.")
2820                    .split('.')
2821                    .next()
2822                    .unwrap_or("us-east-1")
2823                    .to_string()
2824            } else {
2825                "us-east-1".to_string()
2826            };
2827            if let Some(client) = shared_client {
2828                Ok(Arc::new(BedrockDriver::with_client(
2829                    client.clone(),
2830                    access_key,
2831                    secret_key,
2832                    region,
2833                )))
2834            } else {
2835                Ok(Arc::new(BedrockDriver::new(access_key, secret_key, region)))
2836            }
2837        }
2838        Provider::AzureOpenAi => {
2839            // For Azure, base_url should be "https://{resource}.openai.azure.com"
2840            // and model is the deployment name.
2841            let resource = if base_url.contains(".openai.azure.com") {
2842                base_url
2843                    .trim_start_matches("https://")
2844                    .split('.')
2845                    .next()
2846                    .unwrap_or("default")
2847                    .to_string()
2848            } else {
2849                base_url.clone()
2850            };
2851            let deployment = config.model.clone();
2852            if let Some(client) = shared_client {
2853                Ok(Arc::new(AzureOpenAiDriver::with_client(
2854                    client.clone(),
2855                    api_key,
2856                    resource,
2857                    deployment,
2858                    None,
2859                )))
2860            } else {
2861                Ok(Arc::new(AzureOpenAiDriver::new(
2862                    api_key, resource, deployment, None,
2863                )))
2864            }
2865        }
2866        provider => {
2867            let name = provider.to_string();
2868            if let Some(client) = shared_client {
2869                Ok(Arc::new(OpenAiCompatibleDriver::with_client(
2870                    client.clone(),
2871                    api_key,
2872                    base_url,
2873                    name,
2874                )))
2875            } else {
2876                Ok(Arc::new(OpenAiCompatibleDriver::new(
2877                    api_key, base_url, name,
2878                )))
2879            }
2880        }
2881    }
2882}
2883
2884// ---------------------------------------------------------------------------
2885// Tests
2886// ---------------------------------------------------------------------------
2887
2888#[cfg(test)]
2889mod tests {
2890    use super::*;
2891    use punch_types::ToolCategory;
2892
2893    /// Helper to build a simple completion request for testing.
2894    fn simple_request() -> CompletionRequest {
2895        CompletionRequest {
2896            model: "test-model".to_string(),
2897            messages: vec![Message::new(Role::User, "Hello")],
2898            tools: Vec::new(),
2899            max_tokens: 4096,
2900            temperature: Some(0.7),
2901            system_prompt: Some("You are helpful.".to_string()),
2902        }
2903    }
2904
2905    /// Helper to build a request with tools.
2906    fn request_with_tools() -> CompletionRequest {
2907        CompletionRequest {
2908            model: "test-model".to_string(),
2909            messages: vec![Message::new(Role::User, "Use the tool")],
2910            tools: vec![ToolDefinition {
2911                name: "get_weather".to_string(),
2912                description: "Get weather for a city".to_string(),
2913                input_schema: serde_json::json!({
2914                    "type": "object",
2915                    "properties": {
2916                        "city": {"type": "string"}
2917                    }
2918                }),
2919                category: ToolCategory::Web,
2920            }],
2921            max_tokens: 4096,
2922            temperature: Some(0.7),
2923            system_prompt: None,
2924        }
2925    }
2926
2927    // -----------------------------------------------------------------------
2928    // Gemini tests
2929    // -----------------------------------------------------------------------
2930
2931    #[test]
2932    fn gemini_request_formatting() {
2933        let driver = GeminiDriver::new("test-key".to_string(), None);
2934        let body = driver.build_request_body(&simple_request());
2935
2936        let contents = body["contents"].as_array().unwrap();
2937        assert_eq!(contents.len(), 1);
2938        // User message should contain only the user text (system is separate).
2939        let first_text = contents[0]["parts"][0]["text"].as_str().unwrap();
2940        assert_eq!(first_text, "Hello");
2941        assert_eq!(contents[0]["role"].as_str().unwrap(), "user");
2942        // System prompt should be in the dedicated systemInstruction field.
2943        let sys_text = body["system_instruction"]["parts"][0]["text"]
2944            .as_str()
2945            .unwrap();
2946        assert_eq!(sys_text, "You are helpful.");
2947
2948        assert_eq!(body["generationConfig"]["maxOutputTokens"], 4096);
2949        assert!((body["generationConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
2950    }
2951
2952    #[test]
2953    fn gemini_response_parsing() {
2954        let driver = GeminiDriver::new("test-key".to_string(), None);
2955        let response_body = serde_json::json!({
2956            "candidates": [{
2957                "content": {
2958                    "parts": [{"text": "Hello there!"}],
2959                    "role": "model"
2960                },
2961                "finishReason": "STOP"
2962            }],
2963            "usageMetadata": {
2964                "promptTokenCount": 10,
2965                "candidatesTokenCount": 5
2966            }
2967        });
2968
2969        let resp = driver.parse_response(&response_body).unwrap();
2970        assert_eq!(resp.message.content, "Hello there!");
2971        assert_eq!(resp.stop_reason, StopReason::EndTurn);
2972        assert_eq!(resp.usage.input_tokens, 10);
2973        assert_eq!(resp.usage.output_tokens, 5);
2974    }
2975
2976    #[test]
2977    fn gemini_role_mapping_system_prepended() {
2978        let driver = GeminiDriver::new("test-key".to_string(), None);
2979        let req = CompletionRequest {
2980            model: "gemini-pro".to_string(),
2981            messages: vec![
2982                Message::new(Role::System, "Be concise."),
2983                Message::new(Role::User, "Hi"),
2984            ],
2985            tools: Vec::new(),
2986            max_tokens: 1024,
2987            temperature: None,
2988            system_prompt: None,
2989        };
2990        let body = driver.build_request_body(&req);
2991        let contents = body["contents"].as_array().unwrap();
2992        // System message should go to systemInstruction, not user message.
2993        assert_eq!(contents.len(), 1);
2994        let text = contents[0]["parts"][0]["text"].as_str().unwrap();
2995        assert_eq!(text, "Hi");
2996        // System text lives in the dedicated field.
2997        let sys_text = body["system_instruction"]["parts"][0]["text"]
2998            .as_str()
2999            .unwrap();
3000        assert_eq!(sys_text, "Be concise.");
3001    }
3002
3003    #[test]
3004    fn gemini_function_call_parsing() {
3005        let driver = GeminiDriver::new("test-key".to_string(), None);
3006        let response_body = serde_json::json!({
3007            "candidates": [{
3008                "content": {
3009                    "parts": [
3010                        {"text": "Let me check the weather."},
3011                        {
3012                            "functionCall": {
3013                                "name": "get_weather",
3014                                "args": {"city": "London"}
3015                            }
3016                        }
3017                    ],
3018                    "role": "model"
3019                },
3020                "finishReason": "STOP"
3021            }],
3022            "usageMetadata": {
3023                "promptTokenCount": 15,
3024                "candidatesTokenCount": 8
3025            }
3026        });
3027
3028        let resp = driver.parse_response(&response_body).unwrap();
3029        assert_eq!(resp.message.content, "Let me check the weather.");
3030        assert_eq!(resp.stop_reason, StopReason::ToolUse);
3031        assert_eq!(resp.message.tool_calls.len(), 1);
3032        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
3033        assert_eq!(resp.message.tool_calls[0].input["city"], "London");
3034    }
3035
3036    #[test]
3037    fn gemini_api_key_in_url() {
3038        let driver = GeminiDriver::new("my-secret-key".to_string(), None);
3039        let url = driver.build_url("gemini-pro");
3040        assert!(url.contains("key=my-secret-key"));
3041        assert!(url.contains("models/gemini-pro:generateContent"));
3042    }
3043
3044    // -----------------------------------------------------------------------
3045    // Ollama tests
3046    // -----------------------------------------------------------------------
3047
3048    #[test]
3049    fn ollama_request_formatting() {
3050        let driver = OllamaDriver::new(None);
3051        let body = driver.build_request_body(&simple_request());
3052
3053        assert_eq!(body["model"], "test-model");
3054        assert_eq!(body["stream"], false);
3055        let messages = body["messages"].as_array().unwrap();
3056        // system prompt + user message = 2 messages
3057        assert_eq!(messages.len(), 2);
3058        assert_eq!(messages[0]["role"], "system");
3059        assert_eq!(messages[0]["content"], "You are helpful.");
3060        assert_eq!(messages[1]["role"], "user");
3061        assert_eq!(messages[1]["content"], "Hello");
3062        assert!((body["options"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
3063    }
3064
3065    #[test]
3066    fn ollama_response_parsing() {
3067        let driver = OllamaDriver::new(None);
3068        let response_body = serde_json::json!({
3069            "message": {
3070                "role": "assistant",
3071                "content": "Hi there!"
3072            },
3073            "done": true,
3074            "prompt_eval_count": 20,
3075            "eval_count": 10
3076        });
3077
3078        let resp = driver.parse_response(&response_body).unwrap();
3079        assert_eq!(resp.message.content, "Hi there!");
3080        assert_eq!(resp.stop_reason, StopReason::EndTurn);
3081        assert_eq!(resp.usage.input_tokens, 20);
3082        assert_eq!(resp.usage.output_tokens, 10);
3083    }
3084
3085    #[test]
3086    fn ollama_default_endpoint() {
3087        let driver = OllamaDriver::new(None);
3088        assert_eq!(driver.base_url(), "http://localhost:11434");
3089    }
3090
3091    #[test]
3092    fn ollama_custom_endpoint() {
3093        let driver = OllamaDriver::new(Some("http://myhost:9999".to_string()));
3094        assert_eq!(driver.base_url(), "http://myhost:9999");
3095    }
3096
3097    // -----------------------------------------------------------------------
3098    // Bedrock tests
3099    // -----------------------------------------------------------------------
3100
3101    #[test]
3102    fn bedrock_request_formatting() {
3103        let driver = BedrockDriver::new(
3104            "TESTKEY".to_string(),
3105            "testsecret".to_string(),
3106            "us-west-2".to_string(),
3107        );
3108        let body = driver.build_request_body(&simple_request());
3109
3110        let messages = body["messages"].as_array().unwrap();
3111        assert_eq!(messages.len(), 1);
3112        assert_eq!(messages[0]["role"], "user");
3113        assert_eq!(messages[0]["content"][0]["text"], "Hello");
3114
3115        assert_eq!(body["inferenceConfig"]["maxTokens"], 4096);
3116        assert!((body["inferenceConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
3117        assert_eq!(body["system"][0]["text"], "You are helpful.");
3118    }
3119
3120    #[test]
3121    fn bedrock_sigv4_canonical_request() {
3122        let driver = BedrockDriver::new(
3123            "TESTACCESS1234567890".to_string(),
3124            "TestSecretKeyValue1234567890abcdefghijk".to_string(),
3125            "us-east-1".to_string(),
3126        );
3127
3128        let payload = b"{}";
3129        let timestamp = "20260313T120000Z";
3130
3131        let auth = driver
3132            .sign_request(
3133                "POST",
3134                "https://bedrock-runtime.us-east-1.amazonaws.com/model/test/converse",
3135                &[("content-type".to_string(), "application/json".to_string())],
3136                payload,
3137                timestamp,
3138            )
3139            .unwrap();
3140
3141        assert!(auth.starts_with(
3142            "AWS4-HMAC-SHA256 Credential=TESTACCESS1234567890/20260313/us-east-1/bedrock/aws4_request"
3143        ));
3144        assert!(auth.contains("SignedHeaders=content-type;host;x-amz-date"));
3145        assert!(auth.contains("Signature="));
3146    }
3147
3148    #[test]
3149    fn bedrock_response_parsing() {
3150        let driver = BedrockDriver::new(
3151            "key".to_string(),
3152            "secret".to_string(),
3153            "us-east-1".to_string(),
3154        );
3155        let response_body = serde_json::json!({
3156            "output": {
3157                "message": {
3158                    "role": "assistant",
3159                    "content": [{"text": "The answer is 42."}]
3160                }
3161            },
3162            "stopReason": "end_turn",
3163            "usage": {
3164                "inputTokens": 100,
3165                "outputTokens": 50
3166            }
3167        });
3168
3169        let resp = driver.parse_response(&response_body).unwrap();
3170        assert_eq!(resp.message.content, "The answer is 42.");
3171        assert_eq!(resp.stop_reason, StopReason::EndTurn);
3172        assert_eq!(resp.usage.input_tokens, 100);
3173        assert_eq!(resp.usage.output_tokens, 50);
3174    }
3175
3176    // -----------------------------------------------------------------------
3177    // Azure OpenAI tests
3178    // -----------------------------------------------------------------------
3179
3180    #[test]
3181    fn azure_openai_url_construction() {
3182        let driver = AzureOpenAiDriver::new(
3183            "my-azure-key".to_string(),
3184            "myresource".to_string(),
3185            "gpt-4-deployment".to_string(),
3186            None,
3187        );
3188        let url = driver.build_url();
3189        assert_eq!(
3190            url,
3191            "https://myresource.openai.azure.com/openai/deployments/gpt-4-deployment/chat/completions?api-version=2024-02-01"
3192        );
3193    }
3194
3195    #[test]
3196    fn azure_openai_custom_api_version() {
3197        let driver = AzureOpenAiDriver::new(
3198            "key".to_string(),
3199            "res".to_string(),
3200            "dep".to_string(),
3201            Some("2024-06-01".to_string()),
3202        );
3203        let url = driver.build_url();
3204        assert!(url.contains("api-version=2024-06-01"));
3205    }
3206
3207    #[test]
3208    fn azure_openai_request_formatting() {
3209        let driver = AzureOpenAiDriver::new(
3210            "key".to_string(),
3211            "res".to_string(),
3212            "dep".to_string(),
3213            None,
3214        );
3215        let body = driver.build_request_body(&simple_request());
3216        // Should use OpenAI format.
3217        let messages = body["messages"].as_array().unwrap();
3218        // system prompt + user message = 2
3219        assert_eq!(messages.len(), 2);
3220        assert_eq!(messages[0]["role"], "system");
3221        assert_eq!(messages[1]["role"], "user");
3222        assert_eq!(body["model"], "test-model");
3223    }
3224
3225    #[test]
3226    fn azure_openai_resource_and_deployment() {
3227        let driver = AzureOpenAiDriver::new(
3228            "key".to_string(),
3229            "my-resource".to_string(),
3230            "my-deploy".to_string(),
3231            None,
3232        );
3233        assert_eq!(driver.resource(), "my-resource");
3234        assert_eq!(driver.deployment(), "my-deploy");
3235    }
3236
3237    // -----------------------------------------------------------------------
3238    // create_driver dispatch tests
3239    // -----------------------------------------------------------------------
3240
3241    #[test]
3242    fn create_driver_dispatches_ollama() {
3243        let config = ModelConfig {
3244            provider: Provider::Ollama,
3245            model: "llama3".to_string(),
3246            api_key_env: None,
3247            base_url: None,
3248            max_tokens: None,
3249            temperature: None,
3250        };
3251        // Ollama does not need an API key, so this should succeed.
3252        let driver = create_driver(&config);
3253        assert!(driver.is_ok());
3254    }
3255
3256    #[test]
3257    fn create_driver_dispatches_gemini() {
3258        // Set a fake env var for this test.
3259        // SAFETY: Test is single-threaded relative to this env var name.
3260        unsafe { std::env::set_var("TEST_GEMINI_KEY_DISPATCH", "fake-key") };
3261        let config = ModelConfig {
3262            provider: Provider::Google,
3263            model: "gemini-pro".to_string(),
3264            api_key_env: Some("TEST_GEMINI_KEY_DISPATCH".to_string()),
3265            base_url: None,
3266            max_tokens: None,
3267            temperature: None,
3268        };
3269        let driver = create_driver(&config);
3270        assert!(driver.is_ok());
3271        unsafe { std::env::remove_var("TEST_GEMINI_KEY_DISPATCH") };
3272    }
3273
3274    #[test]
3275    fn create_driver_dispatches_bedrock() {
3276        // SAFETY: Test is single-threaded relative to this env var name.
3277        unsafe { std::env::set_var("TEST_BEDROCK_KEY_DISPATCH", "TESTKEY:TESTSECRET") };
3278        let config = ModelConfig {
3279            provider: Provider::Bedrock,
3280            model: "anthropic.claude-v2".to_string(),
3281            api_key_env: Some("TEST_BEDROCK_KEY_DISPATCH".to_string()),
3282            base_url: None,
3283            max_tokens: None,
3284            temperature: None,
3285        };
3286        let driver = create_driver(&config);
3287        assert!(driver.is_ok());
3288        unsafe { std::env::remove_var("TEST_BEDROCK_KEY_DISPATCH") };
3289    }
3290
3291    #[test]
3292    fn create_driver_dispatches_azure_openai() {
3293        // SAFETY: Test is single-threaded relative to this env var name.
3294        unsafe { std::env::set_var("TEST_AZURE_KEY_DISPATCH", "azure-key") };
3295        let config = ModelConfig {
3296            provider: Provider::AzureOpenAi,
3297            model: "gpt-4".to_string(),
3298            api_key_env: Some("TEST_AZURE_KEY_DISPATCH".to_string()),
3299            base_url: Some("https://myres.openai.azure.com".to_string()),
3300            max_tokens: None,
3301            temperature: None,
3302        };
3303        let driver = create_driver(&config);
3304        assert!(driver.is_ok());
3305        unsafe { std::env::remove_var("TEST_AZURE_KEY_DISPATCH") };
3306    }
3307
3308    #[test]
3309    fn gemini_tools_in_request() {
3310        let driver = GeminiDriver::new("key".to_string(), None);
3311        let body = driver.build_request_body(&request_with_tools());
3312
3313        let tools = body["tools"].as_array().unwrap();
3314        assert_eq!(tools.len(), 1);
3315        let func_decls = tools[0]["function_declarations"].as_array().unwrap();
3316        assert_eq!(func_decls.len(), 1);
3317        assert_eq!(func_decls[0]["name"], "get_weather");
3318    }
3319
3320    #[test]
3321    fn ollama_tools_in_request() {
3322        let driver = OllamaDriver::new(None);
3323        let body = driver.build_request_body(&request_with_tools());
3324
3325        let tools = body["tools"].as_array().unwrap();
3326        assert_eq!(tools.len(), 1);
3327        assert_eq!(tools[0]["type"], "function");
3328        assert_eq!(tools[0]["function"]["name"], "get_weather");
3329    }
3330
3331    #[test]
3332    fn bedrock_url_construction() {
3333        let driver = BedrockDriver::new(
3334            "key".to_string(),
3335            "secret".to_string(),
3336            "eu-west-1".to_string(),
3337        );
3338        let url = driver.build_url("anthropic.claude-3-sonnet");
3339        assert_eq!(
3340            url,
3341            "https://bedrock-runtime.eu-west-1.amazonaws.com/model/anthropic.claude-3-sonnet/converse"
3342        );
3343    }
3344
3345    // -----------------------------------------------------------------------
3346    // TokenUsage tests
3347    // -----------------------------------------------------------------------
3348
3349    #[test]
3350    fn token_usage_default() {
3351        let u = TokenUsage::default();
3352        assert_eq!(u.input_tokens, 0);
3353        assert_eq!(u.output_tokens, 0);
3354        assert_eq!(u.total(), 0);
3355    }
3356
3357    #[test]
3358    fn token_usage_accumulate() {
3359        let mut u = TokenUsage {
3360            input_tokens: 10,
3361            output_tokens: 20,
3362        };
3363        let other = TokenUsage {
3364            input_tokens: 5,
3365            output_tokens: 15,
3366        };
3367        u.accumulate(&other);
3368        assert_eq!(u.input_tokens, 15);
3369        assert_eq!(u.output_tokens, 35);
3370        assert_eq!(u.total(), 50);
3371    }
3372
3373    #[test]
3374    fn token_usage_total() {
3375        let u = TokenUsage {
3376            input_tokens: 100,
3377            output_tokens: 200,
3378        };
3379        assert_eq!(u.total(), 300);
3380    }
3381
3382    // -----------------------------------------------------------------------
3383    // StopReason serialization
3384    // -----------------------------------------------------------------------
3385
3386    #[test]
3387    fn stop_reason_serialization() {
3388        let json = serde_json::to_string(&StopReason::EndTurn).unwrap();
3389        assert_eq!(json, "\"end_turn\"");
3390
3391        let json = serde_json::to_string(&StopReason::ToolUse).unwrap();
3392        assert_eq!(json, "\"tool_use\"");
3393
3394        let json = serde_json::to_string(&StopReason::MaxTokens).unwrap();
3395        assert_eq!(json, "\"max_tokens\"");
3396
3397        let json = serde_json::to_string(&StopReason::Error).unwrap();
3398        assert_eq!(json, "\"error\"");
3399    }
3400
3401    #[test]
3402    fn stop_reason_deserialization() {
3403        let sr: StopReason = serde_json::from_str("\"end_turn\"").unwrap();
3404        assert_eq!(sr, StopReason::EndTurn);
3405
3406        let sr: StopReason = serde_json::from_str("\"tool_use\"").unwrap();
3407        assert_eq!(sr, StopReason::ToolUse);
3408    }
3409
3410    // -----------------------------------------------------------------------
3411    // Anthropic driver tests
3412    // -----------------------------------------------------------------------
3413
3414    #[test]
3415    fn anthropic_request_body_simple() {
3416        let driver = AnthropicDriver::new("test-key".to_string(), None);
3417        let body = driver.build_request_body(&simple_request());
3418
3419        assert_eq!(body["model"], "test-model");
3420        assert_eq!(body["max_tokens"], 4096);
3421        // System prompt is now a structured content block with cache_control.
3422        let system = body["system"].as_array().unwrap();
3423        assert_eq!(system.len(), 1);
3424        assert_eq!(system[0]["type"], "text");
3425        assert_eq!(system[0]["text"], "You are helpful.");
3426        assert_eq!(system[0]["cache_control"]["type"], "ephemeral");
3427        assert!((body["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
3428
3429        let messages = body["messages"].as_array().unwrap();
3430        assert_eq!(messages.len(), 1);
3431        assert_eq!(messages[0]["role"], "user");
3432        assert_eq!(messages[0]["content"], "Hello");
3433    }
3434
3435    #[test]
3436    fn anthropic_request_body_with_tools() {
3437        let driver = AnthropicDriver::new("test-key".to_string(), None);
3438        let body = driver.build_request_body(&request_with_tools());
3439
3440        let tools = body["tools"].as_array().unwrap();
3441        assert_eq!(tools.len(), 1);
3442        assert_eq!(tools[0]["name"], "get_weather");
3443        assert!(tools[0]["input_schema"]["properties"].is_object());
3444    }
3445
3446    #[test]
3447    fn anthropic_request_body_no_system_prompt() {
3448        let driver = AnthropicDriver::new("test-key".to_string(), None);
3449        let req = CompletionRequest {
3450            model: "test".into(),
3451            messages: vec![Message::new(Role::User, "Hi")],
3452            tools: Vec::new(),
3453            max_tokens: 100,
3454            temperature: None,
3455            system_prompt: None,
3456        };
3457        let body = driver.build_request_body(&req);
3458        assert!(body.get("system").is_none());
3459        assert!(body.get("temperature").is_none());
3460    }
3461
3462    #[test]
3463    fn anthropic_parse_response_text() {
3464        let driver = AnthropicDriver::new("test-key".to_string(), None);
3465        let response_body = serde_json::json!({
3466            "content": [
3467                {"type": "text", "text": "Hello!"}
3468            ],
3469            "stop_reason": "end_turn",
3470            "usage": {
3471                "input_tokens": 10,
3472                "output_tokens": 5
3473            }
3474        });
3475
3476        let resp = driver.parse_response(&response_body).unwrap();
3477        assert_eq!(resp.message.content, "Hello!");
3478        assert_eq!(resp.stop_reason, StopReason::EndTurn);
3479        assert_eq!(resp.usage.input_tokens, 10);
3480        assert_eq!(resp.usage.output_tokens, 5);
3481        assert!(resp.message.tool_calls.is_empty());
3482    }
3483
3484    #[test]
3485    fn anthropic_parse_response_tool_use() {
3486        let driver = AnthropicDriver::new("test-key".to_string(), None);
3487        let response_body = serde_json::json!({
3488            "content": [
3489                {"type": "text", "text": "Let me check."},
3490                {
3491                    "type": "tool_use",
3492                    "id": "tool_abc",
3493                    "name": "get_weather",
3494                    "input": {"city": "NYC"}
3495                }
3496            ],
3497            "stop_reason": "tool_use",
3498            "usage": {"input_tokens": 20, "output_tokens": 15}
3499        });
3500
3501        let resp = driver.parse_response(&response_body).unwrap();
3502        assert_eq!(resp.message.content, "Let me check.");
3503        assert_eq!(resp.stop_reason, StopReason::ToolUse);
3504        assert_eq!(resp.message.tool_calls.len(), 1);
3505        assert_eq!(resp.message.tool_calls[0].id, "tool_abc");
3506        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
3507        assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
3508    }
3509
3510    #[test]
3511    fn anthropic_parse_response_max_tokens() {
3512        let driver = AnthropicDriver::new("test-key".to_string(), None);
3513        let response_body = serde_json::json!({
3514            "content": [{"type": "text", "text": "truncated"}],
3515            "stop_reason": "max_tokens",
3516            "usage": {"input_tokens": 5, "output_tokens": 100}
3517        });
3518
3519        let resp = driver.parse_response(&response_body).unwrap();
3520        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
3521    }
3522
3523    #[test]
3524    fn anthropic_parse_response_unknown_stop_reason() {
3525        let driver = AnthropicDriver::new("test-key".to_string(), None);
3526        let response_body = serde_json::json!({
3527            "content": [{"type": "text", "text": "err"}],
3528            "stop_reason": "something_unknown",
3529            "usage": {"input_tokens": 0, "output_tokens": 0}
3530        });
3531
3532        let resp = driver.parse_response(&response_body).unwrap();
3533        assert_eq!(resp.stop_reason, StopReason::Error);
3534    }
3535
    #[test]
    fn anthropic_request_body_with_assistant_and_tool_messages() {
        // Exercises a full tool-use round trip in the request body: an
        // assistant turn carrying a tool_call, followed by a Tool-role message
        // carrying the result. The assistant turn keeps its role, while the
        // tool result is re-sent under the "user" role (Anthropic wire format).
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let req = CompletionRequest {
            model: "test".into(),
            messages: vec![
                Message::new(Role::User, "Hi"),
                // Assistant turn that requested a tool invocation.
                Message {
                    role: Role::Assistant,
                    content: "I'll check".into(),
                    content_parts: Vec::new(),
                    tool_calls: vec![ToolCall {
                        id: "call_1".into(),
                        name: "file_read".into(),
                        input: serde_json::json!({"path": "/tmp/test"}),
                    }],
                    tool_results: Vec::new(),
                    timestamp: chrono::Utc::now(),
                },
                // Tool-role turn delivering the result for call_1.
                Message {
                    role: Role::Tool,
                    content: String::new(),
                    content_parts: Vec::new(),
                    tool_calls: Vec::new(),
                    tool_results: vec![punch_types::ToolCallResult {
                        id: "call_1".into(),
                        content: "file contents".into(),
                        is_error: false,
                        image: None,
                    }],
                    timestamp: chrono::Utc::now(),
                },
            ],
            tools: Vec::new(),
            max_tokens: 100,
            temperature: None,
            system_prompt: None,
        };

        let body = driver.build_request_body(&req);
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0]["role"], "user");
        assert_eq!(messages[1]["role"], "assistant");
        assert_eq!(messages[2]["role"], "user"); // Tool results go as user role
    }
3582
3583    #[test]
3584    fn anthropic_request_body_system_message_skipped() {
3585        let driver = AnthropicDriver::new("test-key".to_string(), None);
3586        let req = CompletionRequest {
3587            model: "test".into(),
3588            messages: vec![
3589                Message::new(Role::System, "System instruction"),
3590                Message::new(Role::User, "Hi"),
3591            ],
3592            tools: Vec::new(),
3593            max_tokens: 100,
3594            temperature: None,
3595            system_prompt: None,
3596        };
3597
3598        let body = driver.build_request_body(&req);
3599        let messages = body["messages"].as_array().unwrap();
3600        // System messages are skipped in messages array
3601        assert_eq!(messages.len(), 1);
3602        assert_eq!(messages[0]["role"], "user");
3603    }
3604
3605    // -----------------------------------------------------------------------
3606    // OpenAI-compatible driver tests
3607    // -----------------------------------------------------------------------
3608
3609    #[test]
3610    fn openai_request_body_simple() {
3611        let driver = OpenAiCompatibleDriver::new(
3612            "key".into(),
3613            "https://api.openai.com".into(),
3614            "openai".into(),
3615        );
3616        let body = driver.build_request_body(&simple_request());
3617
3618        assert_eq!(body["model"], "test-model");
3619        let messages = body["messages"].as_array().unwrap();
3620        assert_eq!(messages.len(), 2);
3621        assert_eq!(messages[0]["role"], "system");
3622        assert_eq!(messages[0]["content"], "You are helpful.");
3623        assert_eq!(messages[1]["role"], "user");
3624    }
3625
3626    #[test]
3627    fn openai_request_body_with_tools() {
3628        let driver = OpenAiCompatibleDriver::new(
3629            "key".into(),
3630            "https://api.openai.com".into(),
3631            "openai".into(),
3632        );
3633        let body = driver.build_request_body(&request_with_tools());
3634
3635        let tools = body["tools"].as_array().unwrap();
3636        assert_eq!(tools.len(), 1);
3637        assert_eq!(tools[0]["type"], "function");
3638        assert_eq!(tools[0]["function"]["name"], "get_weather");
3639    }
3640
3641    #[test]
3642    fn openai_parse_response_text() {
3643        let driver = OpenAiCompatibleDriver::new(
3644            "key".into(),
3645            "https://api.openai.com".into(),
3646            "openai".into(),
3647        );
3648        let response_body = serde_json::json!({
3649            "choices": [{
3650                "message": {
3651                    "role": "assistant",
3652                    "content": "Hello!"
3653                },
3654                "finish_reason": "stop"
3655            }],
3656            "usage": {
3657                "prompt_tokens": 10,
3658                "completion_tokens": 5
3659            }
3660        });
3661
3662        let resp = driver.parse_response(&response_body).unwrap();
3663        assert_eq!(resp.message.content, "Hello!");
3664        assert_eq!(resp.stop_reason, StopReason::EndTurn);
3665        assert_eq!(resp.usage.input_tokens, 10);
3666        assert_eq!(resp.usage.output_tokens, 5);
3667    }
3668
3669    #[test]
3670    fn openai_parse_response_tool_calls() {
3671        let driver = OpenAiCompatibleDriver::new(
3672            "key".into(),
3673            "https://api.openai.com".into(),
3674            "openai".into(),
3675        );
3676        let response_body = serde_json::json!({
3677            "choices": [{
3678                "message": {
3679                    "role": "assistant",
3680                    "content": null,
3681                    "tool_calls": [{
3682                        "id": "call_123",
3683                        "type": "function",
3684                        "function": {
3685                            "name": "get_weather",
3686                            "arguments": "{\"city\": \"NYC\"}"
3687                        }
3688                    }]
3689                },
3690                "finish_reason": "tool_calls"
3691            }],
3692            "usage": {"prompt_tokens": 10, "completion_tokens": 5}
3693        });
3694
3695        let resp = driver.parse_response(&response_body).unwrap();
3696        assert_eq!(resp.stop_reason, StopReason::ToolUse);
3697        assert_eq!(resp.message.tool_calls.len(), 1);
3698        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
3699        assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
3700    }
3701
3702    #[test]
3703    fn openai_parse_response_tool_calls_fix_stop_reason() {
3704        let driver = OpenAiCompatibleDriver::new(
3705            "key".into(),
3706            "https://api.openai.com".into(),
3707            "openai".into(),
3708        );
3709        // finish_reason is "stop" but there are tool_calls — should fix to ToolUse
3710        let response_body = serde_json::json!({
3711            "choices": [{
3712                "message": {
3713                    "role": "assistant",
3714                    "content": "Using tool",
3715                    "tool_calls": [{
3716                        "id": "call_1",
3717                        "type": "function",
3718                        "function": {
3719                            "name": "test_tool",
3720                            "arguments": "{}"
3721                        }
3722                    }]
3723                },
3724                "finish_reason": "stop"
3725            }],
3726            "usage": {"prompt_tokens": 0, "completion_tokens": 0}
3727        });
3728
3729        let resp = driver.parse_response(&response_body).unwrap();
3730        assert_eq!(resp.stop_reason, StopReason::ToolUse);
3731    }
3732
3733    #[test]
3734    fn openai_parse_response_length_stop_reason() {
3735        let driver = OpenAiCompatibleDriver::new(
3736            "key".into(),
3737            "https://api.openai.com".into(),
3738            "openai".into(),
3739        );
3740        let response_body = serde_json::json!({
3741            "choices": [{
3742                "message": {"role": "assistant", "content": "cut off"},
3743                "finish_reason": "length"
3744            }],
3745            "usage": {"prompt_tokens": 0, "completion_tokens": 0}
3746        });
3747
3748        let resp = driver.parse_response(&response_body).unwrap();
3749        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
3750    }
3751
3752    #[test]
3753    fn openai_parse_response_no_choices_error() {
3754        let driver = OpenAiCompatibleDriver::new(
3755            "key".into(),
3756            "https://api.openai.com".into(),
3757            "openai".into(),
3758        );
3759        let response_body = serde_json::json!({"choices": []});
3760
3761        let result = driver.parse_response(&response_body);
3762        assert!(result.is_err());
3763    }
3764
3765    // -----------------------------------------------------------------------
3766    // Gemini driver additional tests
3767    // -----------------------------------------------------------------------
3768
    #[test]
    fn gemini_assistant_message_formatting() {
        // An assistant turn with both text and a tool call must be emitted as a
        // Gemini "model" content entry containing at least a text part and a
        // functionCall part.
        let driver = GeminiDriver::new("key".to_string(), None);
        let req = CompletionRequest {
            model: "gemini-pro".into(),
            messages: vec![
                Message::new(Role::User, "Hi"),
                // Assistant turn carrying text plus a pending tool call.
                Message {
                    role: Role::Assistant,
                    content: "Let me help".into(),
                    content_parts: Vec::new(),
                    tool_calls: vec![ToolCall {
                        id: "tc1".into(),
                        name: "get_weather".into(),
                        input: serde_json::json!({"city": "NYC"}),
                    }],
                    tool_results: Vec::new(),
                    timestamp: chrono::Utc::now(),
                },
            ],
            tools: Vec::new(),
            max_tokens: 100,
            temperature: None,
            system_prompt: None,
        };

        let body = driver.build_request_body(&req);
        let contents = body["contents"].as_array().unwrap();
        assert_eq!(contents[1]["role"], "model"); // Gemini uses "model" not "assistant"
        let parts = contents[1]["parts"].as_array().unwrap();
        assert!(parts.len() >= 2); // text part + functionCall part
    }
3801
3802    #[test]
3803    fn gemini_max_tokens_stop_reason() {
3804        let driver = GeminiDriver::new("key".to_string(), None);
3805        let response_body = serde_json::json!({
3806            "candidates": [{
3807                "content": {
3808                    "parts": [{"text": "truncated"}],
3809                    "role": "model"
3810                },
3811                "finishReason": "MAX_TOKENS"
3812            }],
3813            "usageMetadata": {"promptTokenCount": 0, "candidatesTokenCount": 0}
3814        });
3815
3816        let resp = driver.parse_response(&response_body).unwrap();
3817        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
3818    }
3819
3820    #[test]
3821    fn gemini_custom_base_url() {
3822        let driver =
3823            GeminiDriver::new("key".to_string(), Some("https://custom.example.com".into()));
3824        let url = driver.build_url("gemini-pro");
3825        assert!(url.starts_with("https://custom.example.com/"));
3826    }
3827
3828    // -----------------------------------------------------------------------
3829    // Ollama driver additional tests
3830    // -----------------------------------------------------------------------
3831
3832    #[test]
3833    fn ollama_response_with_tool_calls() {
3834        let driver = OllamaDriver::new(None);
3835        let response_body = serde_json::json!({
3836            "message": {
3837                "role": "assistant",
3838                "content": "",
3839                "tool_calls": [{
3840                    "function": {
3841                        "name": "get_weather",
3842                        "arguments": {"city": "London"}
3843                    }
3844                }]
3845            },
3846            "done": true,
3847            "prompt_eval_count": 10,
3848            "eval_count": 5
3849        });
3850
3851        let resp = driver.parse_response(&response_body).unwrap();
3852        assert_eq!(resp.stop_reason, StopReason::ToolUse);
3853        assert_eq!(resp.message.tool_calls.len(), 1);
3854        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
3855    }
3856
3857    #[test]
3858    fn ollama_response_not_done() {
3859        let driver = OllamaDriver::new(None);
3860        let response_body = serde_json::json!({
3861            "message": {"role": "assistant", "content": "partial"},
3862            "done": false,
3863            "prompt_eval_count": 10,
3864            "eval_count": 5
3865        });
3866
3867        let resp = driver.parse_response(&response_body).unwrap();
3868        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
3869    }
3870
3871    // -----------------------------------------------------------------------
3872    // Bedrock driver additional tests
3873    // -----------------------------------------------------------------------
3874
3875    #[test]
3876    fn bedrock_request_with_tools() {
3877        let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
3878        let body = driver.build_request_body(&request_with_tools());
3879
3880        let tool_config = &body["toolConfig"]["tools"];
3881        assert!(tool_config.is_array());
3882        let tools = tool_config.as_array().unwrap();
3883        assert_eq!(tools.len(), 1);
3884        assert_eq!(tools[0]["toolSpec"]["name"], "get_weather");
3885    }
3886
3887    #[test]
3888    fn bedrock_response_with_tool_use() {
3889        let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
3890        let response_body = serde_json::json!({
3891            "output": {
3892                "message": {
3893                    "role": "assistant",
3894                    "content": [
3895                        {"text": "Using tool"},
3896                        {"toolUse": {
3897                            "toolUseId": "tu_123",
3898                            "name": "get_weather",
3899                            "input": {"city": "NYC"}
3900                        }}
3901                    ]
3902                }
3903            },
3904            "stopReason": "tool_use",
3905            "usage": {"inputTokens": 10, "outputTokens": 20}
3906        });
3907
3908        let resp = driver.parse_response(&response_body).unwrap();
3909        assert_eq!(resp.stop_reason, StopReason::ToolUse);
3910        assert_eq!(resp.message.tool_calls.len(), 1);
3911        assert_eq!(resp.message.tool_calls[0].id, "tu_123");
3912        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
3913    }
3914
    #[test]
    fn bedrock_request_with_tool_results() {
        // A Tool-role message must be sent to Bedrock Converse as a "user"
        // message whose content holds a toolResult block with a status field.
        let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
        let req = CompletionRequest {
            model: "test".into(),
            messages: vec![
                Message::new(Role::User, "Hi"),
                // Tool-role turn delivering a successful (is_error: false) result.
                Message {
                    role: Role::Tool,
                    content: String::new(),
                    content_parts: Vec::new(),
                    tool_calls: Vec::new(),
                    tool_results: vec![punch_types::ToolCallResult {
                        id: "tu_1".into(),
                        content: "result data".into(),
                        is_error: false,
                        image: None,
                    }],
                    timestamp: chrono::Utc::now(),
                },
            ],
            tools: Vec::new(),
            max_tokens: 100,
            temperature: None,
            system_prompt: None,
        };

        let body = driver.build_request_body(&req);
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages[1]["role"], "user"); // Bedrock sends tool results as user
        let content = messages[1]["content"].as_array().unwrap();
        assert!(content[0]["toolResult"].is_object());
        // is_error: false maps to Converse's "success" status.
        assert_eq!(content[0]["toolResult"]["status"], "success");
    }
3949
3950    #[test]
3951    fn bedrock_url_different_regions() {
3952        let driver = BedrockDriver::new("k".into(), "s".into(), "ap-southeast-1".into());
3953        let url = driver.build_url("model-id");
3954        assert!(url.contains("ap-southeast-1"));
3955    }
3956
3957    // -----------------------------------------------------------------------
3958    // Azure OpenAI additional tests
3959    // -----------------------------------------------------------------------
3960
3961    #[test]
3962    fn azure_openai_delegates_parse_to_openai() {
3963        let driver = AzureOpenAiDriver::new("key".into(), "res".into(), "dep".into(), None);
3964        let response_body = serde_json::json!({
3965            "choices": [{
3966                "message": {"role": "assistant", "content": "Azure response"},
3967                "finish_reason": "stop"
3968            }],
3969            "usage": {"prompt_tokens": 5, "completion_tokens": 3}
3970        });
3971
3972        let resp = driver.parse_response(&response_body).unwrap();
3973        assert_eq!(resp.message.content, "Azure response");
3974    }
3975
3976    // -----------------------------------------------------------------------
3977    // default_base_url tests
3978    // -----------------------------------------------------------------------
3979
3980    #[test]
3981    fn default_base_url_anthropic() {
3982        assert_eq!(
3983            default_base_url(&Provider::Anthropic),
3984            "https://api.anthropic.com"
3985        );
3986    }
3987
3988    #[test]
3989    fn default_base_url_openai() {
3990        assert_eq!(
3991            default_base_url(&Provider::OpenAI),
3992            "https://api.openai.com"
3993        );
3994    }
3995
3996    #[test]
3997    fn default_base_url_google() {
3998        assert_eq!(
3999            default_base_url(&Provider::Google),
4000            "https://generativelanguage.googleapis.com"
4001        );
4002    }
4003
4004    #[test]
4005    fn default_base_url_ollama() {
4006        assert_eq!(
4007            default_base_url(&Provider::Ollama),
4008            "http://localhost:11434"
4009        );
4010    }
4011
4012    #[test]
4013    fn default_base_url_groq() {
4014        assert_eq!(
4015            default_base_url(&Provider::Groq),
4016            "https://api.groq.com/openai"
4017        );
4018    }
4019
4020    #[test]
4021    fn default_base_url_deepseek() {
4022        assert_eq!(
4023            default_base_url(&Provider::DeepSeek),
4024            "https://api.deepseek.com"
4025        );
4026    }
4027
4028    // -----------------------------------------------------------------------
4029    // hex_sha256 and hex_encode tests
4030    // -----------------------------------------------------------------------
4031
4032    #[test]
4033    fn test_hex_sha256() {
4034        let hash = hex_sha256(b"");
4035        assert_eq!(
4036            hash,
4037            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
4038        );
4039    }
4040
4041    #[test]
4042    fn test_hex_encode() {
4043        assert_eq!(hex_encode(&[0x00, 0xff, 0x0a, 0xbc]), "00ff0abc");
4044    }
4045
4046    #[test]
4047    fn test_hmac_sha256_basic() {
4048        let result = hmac_sha256(b"key", b"data");
4049        assert!(!result.is_empty());
4050        assert_eq!(result.len(), 32); // SHA-256 produces 32 bytes
4051    }
4052
4053    // -----------------------------------------------------------------------
4054    // create_driver error cases
4055    // -----------------------------------------------------------------------
4056
4057    #[test]
4058    fn create_driver_missing_api_key_env() {
4059        let config = ModelConfig {
4060            provider: Provider::Anthropic,
4061            model: "claude-3".into(),
4062            api_key_env: Some("PUNCH_TEST_NONEXISTENT_KEY_XYZ".into()),
4063            base_url: None,
4064            max_tokens: None,
4065            temperature: None,
4066        };
4067        let result = create_driver(&config);
4068        assert!(result.is_err());
4069    }
4070
4071    #[test]
4072    fn create_driver_openai_compatible_fallback() {
4073        // Custom provider should fall through to OpenAI-compatible
4074        unsafe { std::env::set_var("TEST_CUSTOM_KEY_DRIVER", "fake-key") };
4075        let config = ModelConfig {
4076            provider: Provider::Custom("my-custom".into()),
4077            model: "custom-model".into(),
4078            api_key_env: Some("TEST_CUSTOM_KEY_DRIVER".into()),
4079            base_url: Some("https://custom.api.com".into()),
4080            max_tokens: None,
4081            temperature: None,
4082        };
4083        let result = create_driver(&config);
4084        assert!(result.is_ok());
4085        unsafe { std::env::remove_var("TEST_CUSTOM_KEY_DRIVER") };
4086    }
4087
4088    // -----------------------------------------------------------------------
4089    // strip_thinking_tags tests
4090    // -----------------------------------------------------------------------
4091
4092    #[test]
4093    fn strip_thinking_tags_removes_think_block() {
4094        let input = "<think>internal reasoning here</think>The answer is 42.";
4095        assert_eq!(strip_thinking_tags(input), "The answer is 42.");
4096    }
4097
4098    #[test]
4099    fn strip_thinking_tags_removes_thinking_block() {
4100        let input = "<thinking>step by step reasoning</thinking>Hello world!";
4101        assert_eq!(strip_thinking_tags(input), "Hello world!");
4102    }
4103
4104    #[test]
4105    fn strip_thinking_tags_removes_reasoning_block() {
4106        let input = "<reasoning>let me figure this out</reasoning>The result is correct.";
4107        assert_eq!(strip_thinking_tags(input), "The result is correct.");
4108    }
4109
4110    #[test]
4111    fn strip_thinking_tags_removes_reflection_block() {
4112        let input = "<reflection>checking my work</reflection>Yes, that's right.";
4113        assert_eq!(strip_thinking_tags(input), "Yes, that's right.");
4114    }
4115
4116    #[test]
4117    fn strip_thinking_tags_removes_multiple_blocks() {
4118        let input = "<think>first thought</think>Hello <thinking>second thought</thinking>world!";
4119        assert_eq!(strip_thinking_tags(input), "Hello world!");
4120    }
4121
4122    #[test]
4123    fn strip_thinking_tags_preserves_content_without_tags() {
4124        let input = "Just a normal response with no thinking tags.";
4125        assert_eq!(strip_thinking_tags(input), input);
4126    }
4127
4128    #[test]
4129    fn strip_thinking_tags_handles_multiline_tags() {
4130        let input = "<think>\nLine 1\nLine 2\nLine 3\n</think>\nThe final answer.";
4131        assert_eq!(strip_thinking_tags(input), "The final answer.");
4132    }
4133
4134    #[test]
4135    fn strip_thinking_tags_returns_original_if_all_thinking() {
4136        // If the entire response is thinking with no visible output,
4137        // return the original so the user sees something.
4138        let input = "<think>this is all thinking content and nothing else</think>";
4139        assert_eq!(strip_thinking_tags(input), input);
4140    }
4141
4142    #[test]
4143    fn strip_thinking_tags_handles_unclosed_tag() {
4144        let input = "Some text<think>unclosed thinking block";
4145        assert_eq!(strip_thinking_tags(input), "Some text");
4146    }
4147
4148    #[test]
4149    fn strip_thinking_tags_handles_empty_input() {
4150        assert_eq!(strip_thinking_tags(""), "");
4151    }
4152
4153    #[test]
4154    fn strip_thinking_tags_handles_empty_think_block() {
4155        let input = "<think></think>Visible content.";
4156        assert_eq!(strip_thinking_tags(input), "Visible content.");
4157    }
4158
4159    #[test]
4160    fn strip_thinking_tags_trims_whitespace() {
4161        let input = "  <think>reasoning</think>  Result  ";
4162        assert_eq!(strip_thinking_tags(input), "Result");
4163    }
4164
4165    #[test]
4166    fn strip_thinking_tags_mixed_tag_types() {
4167        let input = "<think>t1</think>A<reasoning>r1</reasoning>B<reflection>f1</reflection>C";
4168        assert_eq!(strip_thinking_tags(input), "ABC");
4169    }
4170
4171    #[test]
4172    fn ollama_response_strips_thinking_tags() {
4173        let driver = OllamaDriver::new(None);
4174        let response_body = serde_json::json!({
4175            "message": {
4176                "role": "assistant",
4177                "content": "<think>\nLet me think about this...\nThe user wants hello world.\n</think>\nHello, world!"
4178            },
4179            "done": true,
4180            "prompt_eval_count": 20,
4181            "eval_count": 50
4182        });
4183
4184        let resp = driver.parse_response(&response_body).unwrap();
4185        assert_eq!(resp.message.content, "Hello, world!");
4186        assert!(!resp.message.content.contains("<think>"));
4187    }
4188
4189    #[test]
4190    fn gemini_response_strips_thinking_tags() {
4191        let driver = GeminiDriver::new("test-key".to_string(), None);
4192        let response_body = serde_json::json!({
4193            "candidates": [{
4194                "content": {
4195                    "parts": [{"text": "<thinking>reasoning step</thinking>The answer is 7."}],
4196                    "role": "model"
4197                },
4198                "finishReason": "STOP"
4199            }],
4200            "usageMetadata": {
4201                "promptTokenCount": 10,
4202                "candidatesTokenCount": 20
4203            }
4204        });
4205
4206        let resp = driver.parse_response(&response_body).unwrap();
4207        assert_eq!(resp.message.content, "The answer is 7.");
4208        assert!(!resp.message.content.contains("<thinking>"));
4209    }
4210
4211    #[test]
4212    fn anthropic_response_strips_thinking_tags() {
4213        let driver = AnthropicDriver::new("test-key".to_string(), None);
4214        let response_body = serde_json::json!({
4215            "content": [
4216                {"type": "text", "text": "<think>internal thought</think>Clean output."}
4217            ],
4218            "stop_reason": "end_turn",
4219            "usage": {"input_tokens": 10, "output_tokens": 5}
4220        });
4221
4222        let resp = driver.parse_response(&response_body).unwrap();
4223        assert_eq!(resp.message.content, "Clean output.");
4224    }
4225
4226    #[test]
4227    fn bedrock_response_strips_thinking_tags() {
4228        let driver = BedrockDriver::new(
4229            "key".to_string(),
4230            "secret".to_string(),
4231            "us-east-1".to_string(),
4232        );
4233        let response_body = serde_json::json!({
4234            "output": {
4235                "message": {
4236                    "role": "assistant",
4237                    "content": [{"text": "<reasoning>deep thought</reasoning>Result here."}]
4238                }
4239            },
4240            "stopReason": "end_turn",
4241            "usage": {"inputTokens": 50, "outputTokens": 25}
4242        });
4243
4244        let resp = driver.parse_response(&response_body).unwrap();
4245        assert_eq!(resp.message.content, "Result here.");
4246    }
4247
4248    // -----------------------------------------------------------------------
4249    // StreamChunk / ToolCallDelta serialization tests
4250    // -----------------------------------------------------------------------
4251
4252    #[test]
4253    fn stream_chunk_serialization_roundtrip() {
4254        let chunk = StreamChunk {
4255            delta: "Hello".to_string(),
4256            is_final: false,
4257            tool_call_delta: None,
4258        };
4259        let json = serde_json::to_string(&chunk).unwrap();
4260        let deserialized: StreamChunk = serde_json::from_str(&json).unwrap();
4261        assert_eq!(deserialized.delta, "Hello");
4262        assert!(!deserialized.is_final);
4263        assert!(deserialized.tool_call_delta.is_none());
4264    }
4265
4266    #[test]
4267    fn stream_chunk_with_tool_call_delta_serialization() {
4268        let chunk = StreamChunk {
4269            delta: String::new(),
4270            is_final: false,
4271            tool_call_delta: Some(ToolCallDelta {
4272                index: 0,
4273                id: Some("call_123".to_string()),
4274                name: Some("get_weather".to_string()),
4275                arguments_delta: "{\"city\":".to_string(),
4276            }),
4277        };
4278        let json = serde_json::to_string(&chunk).unwrap();
4279        let deserialized: StreamChunk = serde_json::from_str(&json).unwrap();
4280        let tcd = deserialized.tool_call_delta.unwrap();
4281        assert_eq!(tcd.index, 0);
4282        assert_eq!(tcd.id.unwrap(), "call_123");
4283        assert_eq!(tcd.name.unwrap(), "get_weather");
4284        assert_eq!(tcd.arguments_delta, "{\"city\":");
4285    }
4286
4287    #[test]
4288    fn stream_chunk_final_serialization() {
4289        let chunk = StreamChunk {
4290            delta: String::new(),
4291            is_final: true,
4292            tool_call_delta: None,
4293        };
4294        let json = serde_json::to_string(&chunk).unwrap();
4295        assert!(json.contains("\"is_final\":true"));
4296    }
4297
4298    #[test]
4299    fn tool_call_delta_serialization_roundtrip() {
4300        let tcd = ToolCallDelta {
4301            index: 2,
4302            id: None,
4303            name: None,
4304            arguments_delta: "\"NYC\"}".to_string(),
4305        };
4306        let json = serde_json::to_string(&tcd).unwrap();
4307        let deserialized: ToolCallDelta = serde_json::from_str(&json).unwrap();
4308        assert_eq!(deserialized.index, 2);
4309        assert!(deserialized.id.is_none());
4310        assert!(deserialized.name.is_none());
4311        assert_eq!(deserialized.arguments_delta, "\"NYC\"}");
4312    }
4313
4314    // -----------------------------------------------------------------------
4315    // SSE parsing tests
4316    // -----------------------------------------------------------------------
4317
4318    #[test]
4319    fn parse_sse_events_basic() {
4320        let raw = "event: message_start\ndata: {\"type\":\"message_start\"}\n\nevent: content_block_delta\ndata: {\"delta\":{\"text\":\"Hi\"}}\n\n";
4321        let events = parse_sse_events(raw);
4322        assert_eq!(events.len(), 2);
4323        assert_eq!(events[0].0, "message_start");
4324        assert_eq!(events[1].0, "content_block_delta");
4325    }
4326
4327    #[test]
4328    fn parse_sse_events_with_done() {
4329        let raw = "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\ndata: [DONE]\n\n";
4330        let events = parse_sse_events(raw);
4331        assert_eq!(events.len(), 2);
4332        assert_eq!(events[1].1, "[DONE]");
4333    }
4334
4335    #[test]
4336    fn parse_sse_events_empty_input() {
4337        let events = parse_sse_events("");
4338        assert!(events.is_empty());
4339    }
4340
4341    #[test]
4342    fn parse_sse_events_no_trailing_newline() {
4343        let raw = "event: test\ndata: {\"value\":1}";
4344        let events = parse_sse_events(raw);
4345        assert_eq!(events.len(), 1);
4346        assert_eq!(events[0].0, "test");
4347    }
4348
4349    #[test]
4350    fn parse_sse_events_multiline_data() {
4351        let raw = "data: line1\ndata: line2\n\n";
4352        let events = parse_sse_events(raw);
4353        assert_eq!(events.len(), 1);
4354        assert_eq!(events[0].1, "line1\nline2");
4355    }
4356
4357    #[test]
4358    fn parse_sse_events_no_event_field() {
4359        let raw = "data: {\"hello\":\"world\"}\n\n";
4360        let events = parse_sse_events(raw);
4361        assert_eq!(events.len(), 1);
4362        assert_eq!(events[0].0, "message"); // default event type
4363    }
4364
4365    // -----------------------------------------------------------------------
4366    // Anthropic streaming tests
4367    // -----------------------------------------------------------------------
4368
    #[test]
    fn anthropic_stream_text_only() {
        // A complete Anthropic SSE transcript for a plain-text response:
        // message_start carries input-token usage, content_block_delta events
        // carry text fragments, message_delta carries the stop reason plus
        // output-token usage, and message_stop terminates the stream.
        let raw = concat!(
            "event: message_start\n",
            "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":25}}}\n\n",
            "event: content_block_start\n",
            "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\" world\"}}\n\n",
            "event: message_delta\n",
            "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":10}}\n\n",
            "event: message_stop\n",
            "data: {\"type\":\"message_stop\"}\n\n",
        );

        let events = parse_sse_events(raw);
        // Capture every chunk the callback receives so count/order can be
        // asserted afterwards; Arc<Mutex<..>> because StreamCallback owns the
        // closure.
        let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let chunks_clone = chunks.clone();
        let callback: StreamCallback = Arc::new(move |chunk| {
            chunks_clone.lock().unwrap().push(chunk);
        });

        // Simulate the Anthropic stream processing
        let mut text_content = String::new();
        let mut usage = TokenUsage::default();
        let mut stop_reason = StopReason::EndTurn;

        for (event_type, data) in &events {
            // Skip any payload that is not valid JSON (e.g. sentinels).
            let parsed: serde_json::Value = match serde_json::from_str(data) {
                Ok(v) => v,
                Err(_) => continue,
            };

            match event_type.as_str() {
                "message_start" => {
                    // Input tokens are only reported on the opening event.
                    if let Some(inp) = parsed["message"]["usage"]["input_tokens"].as_u64() {
                        usage.input_tokens = inp;
                    }
                }
                "content_block_delta" => {
                    // Accumulate text and forward each fragment to the callback.
                    if let Some(text) = parsed["delta"]["text"].as_str() {
                        text_content.push_str(text);
                        callback(StreamChunk {
                            delta: text.to_string(),
                            is_final: false,
                            tool_call_delta: None,
                        });
                    }
                }
                "message_delta" => {
                    // Stop reason and output tokens arrive near the end.
                    if let Some(sr) = parsed["delta"]["stop_reason"].as_str() {
                        stop_reason = match sr {
                            "end_turn" => StopReason::EndTurn,
                            "tool_use" => StopReason::ToolUse,
                            _ => StopReason::Error,
                        };
                    }
                    if let Some(out) = parsed["usage"]["output_tokens"].as_u64() {
                        usage.output_tokens = out;
                    }
                }
                "message_stop" => {
                    // Emit the terminal (empty-delta) chunk.
                    callback(StreamChunk {
                        delta: String::new(),
                        is_final: true,
                        tool_call_delta: None,
                    });
                }
                _ => {}
            }
        }

        assert_eq!(text_content, "Hello world");
        assert_eq!(usage.input_tokens, 25);
        assert_eq!(usage.output_tokens, 10);
        assert_eq!(stop_reason, StopReason::EndTurn);

        let received = chunks.lock().unwrap();
        assert_eq!(received.len(), 3); // "Hello", " world", final
        assert_eq!(received[0].delta, "Hello");
        assert_eq!(received[1].delta, " world");
        assert!(received[2].is_final);
    }
4455
    #[test]
    fn anthropic_stream_with_tool_use() {
        // Anthropic SSE transcript mixing a text block (index 0) with a
        // tool_use block (index 1) whose JSON input arrives as two
        // input_json_delta fragments that must be concatenated.
        let raw = concat!(
            "event: message_start\n",
            "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":15}}}\n\n",
            "event: content_block_start\n",
            "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Checking.\"}}\n\n",
            "event: content_block_start\n",
            "data: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"tool_1\",\"name\":\"get_weather\"}}\n\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\"\"}}\n\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\": \\\"NYC\\\"}\"}}\n\n",
            "event: message_delta\n",
            "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":20}}\n\n",
            "event: message_stop\n",
            "data: {\"type\":\"message_stop\"}\n\n",
        );

        let events = parse_sse_events(raw);
        // Verify we can parse all events
        assert!(events.len() >= 7);

        // Verify tool JSON reconstruction
        // One buffer per tool_use block; tc_idx tracks which buffer the
        // current input_json_delta fragments belong to.
        let mut tool_json_bufs: Vec<String> = Vec::new();
        let mut tc_idx: Option<usize> = None;

        for (event_type, data) in &events {
            let parsed: serde_json::Value = match serde_json::from_str(data) {
                Ok(v) => v,
                Err(_) => continue,
            };
            match event_type.as_str() {
                "content_block_start" => {
                    // A tool_use block opens a fresh argument buffer; a text
                    // block clears the index so its deltas are not captured.
                    if parsed["content_block"]["type"].as_str() == Some("tool_use") {
                        tool_json_bufs.push(String::new());
                        tc_idx = Some(tool_json_bufs.len() - 1);
                    } else {
                        tc_idx = None;
                    }
                }
                "content_block_delta" => {
                    // Only append fragments while inside a tool_use block.
                    if parsed["delta"]["type"].as_str() == Some("input_json_delta")
                        && let Some(idx) = tc_idx
                        && let Some(buf) = tool_json_bufs.get_mut(idx)
                    {
                        buf.push_str(parsed["delta"]["partial_json"].as_str().unwrap_or(""));
                    }
                }
                _ => {}
            }
        }

        assert_eq!(tool_json_bufs.len(), 1);
        assert_eq!(tool_json_bufs[0], "{\"city\": \"NYC\"}");

        // The reassembled fragments must form valid JSON.
        let parsed_input: serde_json::Value = serde_json::from_str(&tool_json_bufs[0]).unwrap();
        assert_eq!(parsed_input["city"], "NYC");
    }
4517
4518    // -----------------------------------------------------------------------
4519    // OpenAI streaming tests
4520    // -----------------------------------------------------------------------
4521
4522    #[test]
4523    fn openai_stream_text_only() {
4524        let driver = OpenAiCompatibleDriver::new(
4525            "key".into(),
4526            "https://api.openai.com".into(),
4527            "openai".into(),
4528        );
4529
4530        let raw = concat!(
4531            "data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"},\"index\":0}]}\n\n",
4532            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"index\":0}]}\n\n",
4533            "data: {\"choices\":[{\"delta\":{\"content\":\" world\"},\"index\":0}]}\n\n",
4534            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n",
4535            "data: [DONE]\n\n",
4536        );
4537
4538        let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
4539            Arc::new(std::sync::Mutex::new(Vec::new()));
4540        let chunks_clone = chunks.clone();
4541        let callback: StreamCallback = Arc::new(move |chunk| {
4542            chunks_clone.lock().unwrap().push(chunk);
4543        });
4544
4545        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4546
4547        assert_eq!(resp.message.content, "Hello world");
4548        assert_eq!(resp.stop_reason, StopReason::EndTurn);
4549        assert!(resp.message.tool_calls.is_empty());
4550
4551        let received = chunks.lock().unwrap();
4552        // "Hello", " world", final [DONE]
4553        assert!(received.len() >= 3);
4554        assert_eq!(received[0].delta, "Hello");
4555        assert_eq!(received[1].delta, " world");
4556        assert!(received.last().unwrap().is_final);
4557    }
4558
4559    #[test]
4560    fn openai_stream_with_tool_calls() {
4561        let driver = OpenAiCompatibleDriver::new(
4562            "key".into(),
4563            "https://api.openai.com".into(),
4564            "openai".into(),
4565        );
4566
4567        let raw = concat!(
4568            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_abc\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"\"}}]},\"index\":0}]}\n\n",
4569            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"ci\"}}]},\"index\":0}]}\n\n",
4570            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"ty\\\": \\\"NYC\\\"}\"}}]},\"index\":0}]}\n\n",
4571            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"tool_calls\"}]}\n\n",
4572            "data: [DONE]\n\n",
4573        );
4574
4575        let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
4576            Arc::new(std::sync::Mutex::new(Vec::new()));
4577        let chunks_clone = chunks.clone();
4578        let callback: StreamCallback = Arc::new(move |chunk| {
4579            chunks_clone.lock().unwrap().push(chunk);
4580        });
4581
4582        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4583
4584        assert_eq!(resp.stop_reason, StopReason::ToolUse);
4585        assert_eq!(resp.message.tool_calls.len(), 1);
4586        assert_eq!(resp.message.tool_calls[0].id, "call_abc");
4587        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
4588        assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
4589
4590        let received = chunks.lock().unwrap();
4591        // Should have tool call delta chunks and a final chunk
4592        let tool_chunks: Vec<_> = received
4593            .iter()
4594            .filter(|c| c.tool_call_delta.is_some())
4595            .collect();
4596        assert!(tool_chunks.len() >= 3); // id+name, partial args, more args
4597        assert!(received.last().unwrap().is_final);
4598    }
4599
4600    #[test]
4601    fn openai_stream_with_mixed_content_and_tools() {
4602        let driver = OpenAiCompatibleDriver::new(
4603            "key".into(),
4604            "https://api.openai.com".into(),
4605            "openai".into(),
4606        );
4607
4608        let raw = concat!(
4609            "data: {\"choices\":[{\"delta\":{\"content\":\"Sure, \"},\"index\":0}]}\n\n",
4610            "data: {\"choices\":[{\"delta\":{\"content\":\"checking.\"},\"index\":0}]}\n\n",
4611            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"search\",\"arguments\":\"{\\\"q\\\":\\\"test\\\"}\"}}]},\"index\":0}]}\n\n",
4612            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"tool_calls\"}]}\n\n",
4613            "data: [DONE]\n\n",
4614        );
4615
4616        let callback: StreamCallback = Arc::new(|_| {});
4617        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4618
4619        assert_eq!(resp.message.content, "Sure, checking.");
4620        assert_eq!(resp.stop_reason, StopReason::ToolUse);
4621        assert_eq!(resp.message.tool_calls.len(), 1);
4622        assert_eq!(resp.message.tool_calls[0].name, "search");
4623    }
4624
4625    #[test]
4626    fn openai_stream_length_stop_reason() {
4627        let driver = OpenAiCompatibleDriver::new(
4628            "key".into(),
4629            "https://api.openai.com".into(),
4630            "openai".into(),
4631        );
4632
4633        let raw = concat!(
4634            "data: {\"choices\":[{\"delta\":{\"content\":\"truncated\"},\"index\":0}]}\n\n",
4635            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"length\"}]}\n\n",
4636            "data: [DONE]\n\n",
4637        );
4638
4639        let callback: StreamCallback = Arc::new(|_| {});
4640        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4641        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
4642    }
4643
4644    // -----------------------------------------------------------------------
4645    // Ollama streaming tests
4646    // -----------------------------------------------------------------------
4647
4648    #[test]
4649    fn ollama_stream_text_only() {
4650        let driver = OllamaDriver::new(None);
4651
4652        let raw = concat!(
4653            "{\"message\":{\"role\":\"assistant\",\"content\":\"Hello\"},\"done\":false}\n",
4654            "{\"message\":{\"role\":\"assistant\",\"content\":\" world\"},\"done\":false}\n",
4655            "{\"message\":{\"role\":\"assistant\",\"content\":\"!\"},\"done\":false}\n",
4656            "{\"done\":true,\"prompt_eval_count\":15,\"eval_count\":8}\n",
4657        );
4658
4659        let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
4660            Arc::new(std::sync::Mutex::new(Vec::new()));
4661        let chunks_clone = chunks.clone();
4662        let callback: StreamCallback = Arc::new(move |chunk| {
4663            chunks_clone.lock().unwrap().push(chunk);
4664        });
4665
4666        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();
4667
4668        assert_eq!(resp.message.content, "Hello world!");
4669        assert_eq!(resp.stop_reason, StopReason::EndTurn);
4670        assert_eq!(resp.usage.input_tokens, 15);
4671        assert_eq!(resp.usage.output_tokens, 8);
4672
4673        let received = chunks.lock().unwrap();
4674        assert_eq!(received.len(), 4); // 3 content + 1 final
4675        assert_eq!(received[0].delta, "Hello");
4676        assert_eq!(received[1].delta, " world");
4677        assert_eq!(received[2].delta, "!");
4678        assert!(received[3].is_final);
4679    }
4680
4681    #[test]
4682    fn ollama_stream_with_tool_calls() {
4683        let driver = OllamaDriver::new(None);
4684
4685        let raw = concat!(
4686            "{\"message\":{\"role\":\"assistant\",\"content\":\"Let me check.\"},\"done\":false}\n",
4687            "{\"message\":{\"role\":\"assistant\",\"content\":\"\",\"tool_calls\":[{\"function\":{\"name\":\"get_weather\",\"arguments\":{\"city\":\"London\"}}}]},\"done\":true,\"prompt_eval_count\":10,\"eval_count\":5}\n",
4688        );
4689
4690        let callback: StreamCallback = Arc::new(|_| {});
4691        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();
4692
4693        assert_eq!(resp.message.content, "Let me check.");
4694        assert_eq!(resp.stop_reason, StopReason::ToolUse);
4695        assert_eq!(resp.message.tool_calls.len(), 1);
4696        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
4697        assert_eq!(resp.usage.input_tokens, 10);
4698    }
4699
4700    #[test]
4701    fn ollama_stream_strips_thinking_tags() {
4702        let driver = OllamaDriver::new(None);
4703
4704        let raw = concat!(
4705            "{\"message\":{\"role\":\"assistant\",\"content\":\"<think>hmm</think>\"},\"done\":false}\n",
4706            "{\"message\":{\"role\":\"assistant\",\"content\":\"Clean answer.\"},\"done\":false}\n",
4707            "{\"done\":true,\"prompt_eval_count\":5,\"eval_count\":3}\n",
4708        );
4709
4710        let callback: StreamCallback = Arc::new(|_| {});
4711        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();
4712        assert_eq!(resp.message.content, "Clean answer.");
4713    }
4714
4715    // -----------------------------------------------------------------------
4716    // Gemini streaming tests
4717    // -----------------------------------------------------------------------
4718
4719    #[test]
4720    fn gemini_stream_url_construction() {
4721        let driver = GeminiDriver::new("my-key".to_string(), None);
4722        let url = driver.build_stream_url("gemini-pro");
4723        assert!(url.contains("streamGenerateContent"));
4724        assert!(url.contains("alt=sse"));
4725        assert!(url.contains("key=my-key"));
4726        assert!(url.contains("models/gemini-pro"));
4727    }
4728
4729    #[test]
4730    fn gemini_stream_custom_base_url() {
4731        let driver = GeminiDriver::new(
4732            "key".to_string(),
4733            Some("https://custom.example.com".to_string()),
4734        );
4735        let url = driver.build_stream_url("gemini-pro");
4736        assert!(url.starts_with("https://custom.example.com/"));
4737        assert!(url.contains("streamGenerateContent"));
4738    }
4739
4740    // -----------------------------------------------------------------------
4741    // Callback mechanism tests
4742    // -----------------------------------------------------------------------
4743
4744    #[test]
4745    fn callback_receives_all_chunks_in_order() {
4746        let driver = OpenAiCompatibleDriver::new(
4747            "key".into(),
4748            "https://api.openai.com".into(),
4749            "openai".into(),
4750        );
4751
4752        let raw = concat!(
4753            "data: {\"choices\":[{\"delta\":{\"content\":\"A\"},\"index\":0}]}\n\n",
4754            "data: {\"choices\":[{\"delta\":{\"content\":\"B\"},\"index\":0}]}\n\n",
4755            "data: {\"choices\":[{\"delta\":{\"content\":\"C\"},\"index\":0}]}\n\n",
4756            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n",
4757            "data: [DONE]\n\n",
4758        );
4759
4760        let deltas: Arc<std::sync::Mutex<Vec<String>>> =
4761            Arc::new(std::sync::Mutex::new(Vec::new()));
4762        let deltas_clone = deltas.clone();
4763        let callback: StreamCallback = Arc::new(move |chunk| {
4764            if !chunk.delta.is_empty() || chunk.is_final {
4765                deltas_clone.lock().unwrap().push(chunk.delta.clone());
4766            }
4767        });
4768
4769        let _resp = driver.parse_openai_stream(raw, &callback).unwrap();
4770        let received = deltas.lock().unwrap();
4771        assert_eq!(received.as_slice(), &["A", "B", "C", ""]);
4772    }
4773
4774    #[test]
4775    fn openai_stream_multiple_tool_calls() {
4776        let driver = OpenAiCompatibleDriver::new(
4777            "key".into(),
4778            "https://api.openai.com".into(),
4779            "openai".into(),
4780        );
4781
4782        let raw = concat!(
4783            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"tool_a\",\"arguments\":\"{\\\"x\\\":1}\"}}]},\"index\":0}]}\n\n",
4784            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":1,\"id\":\"call_2\",\"type\":\"function\",\"function\":{\"name\":\"tool_b\",\"arguments\":\"{\\\"y\\\":2}\"}}]},\"index\":0}]}\n\n",
4785            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"tool_calls\"}]}\n\n",
4786            "data: [DONE]\n\n",
4787        );
4788
4789        let callback: StreamCallback = Arc::new(|_| {});
4790        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4791
4792        assert_eq!(resp.message.tool_calls.len(), 2);
4793        assert_eq!(resp.message.tool_calls[0].id, "call_1");
4794        assert_eq!(resp.message.tool_calls[0].name, "tool_a");
4795        assert_eq!(resp.message.tool_calls[0].input["x"], 1);
4796        assert_eq!(resp.message.tool_calls[1].id, "call_2");
4797        assert_eq!(resp.message.tool_calls[1].name, "tool_b");
4798        assert_eq!(resp.message.tool_calls[1].input["y"], 2);
4799    }
4800
4801    // -----------------------------------------------------------------------
4802    // Default stream_complete_with_callback (trait default) test
4803    // -----------------------------------------------------------------------
4804
4805    #[test]
4806    fn stream_chunk_default_values() {
4807        let chunk = StreamChunk {
4808            delta: String::new(),
4809            is_final: false,
4810            tool_call_delta: None,
4811        };
4812        assert!(chunk.delta.is_empty());
4813        assert!(!chunk.is_final);
4814        assert!(chunk.tool_call_delta.is_none());
4815    }
4816
4817    #[test]
4818    fn openai_stream_empty_input() {
4819        let driver = OpenAiCompatibleDriver::new(
4820            "key".into(),
4821            "https://api.openai.com".into(),
4822            "openai".into(),
4823        );
4824
4825        let raw = "data: [DONE]\n\n";
4826
4827        let callback: StreamCallback = Arc::new(|_| {});
4828        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4829
4830        assert_eq!(resp.message.content, "");
4831        assert!(resp.message.tool_calls.is_empty());
4832    }
4833
4834    #[test]
4835    fn ollama_stream_empty_input() {
4836        let driver = OllamaDriver::new(None);
4837        let raw = "";
4838
4839        let callback: StreamCallback = Arc::new(|_| {});
4840        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();
4841
4842        assert_eq!(resp.message.content, "");
4843        assert_eq!(resp.stop_reason, StopReason::MaxTokens); // not done
4844    }
4845
4846    #[test]
4847    fn openai_stream_strips_thinking_tags() {
4848        let driver = OpenAiCompatibleDriver::new(
4849            "key".into(),
4850            "https://api.openai.com".into(),
4851            "openai".into(),
4852        );
4853
4854        let raw = concat!(
4855            "data: {\"choices\":[{\"delta\":{\"content\":\"<think>internal</think>\"},\"index\":0}]}\n\n",
4856            "data: {\"choices\":[{\"delta\":{\"content\":\"Result\"},\"index\":0}]}\n\n",
4857            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n",
4858            "data: [DONE]\n\n",
4859        );
4860
4861        let callback: StreamCallback = Arc::new(|_| {});
4862        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4863        assert_eq!(resp.message.content, "Result");
4864    }
4865}