Skip to main content

punch_runtime/
driver.rs

1//! LLM driver trait and provider implementations.
2//!
3//! The [`LlmDriver`] trait abstracts over different LLM providers so the
4//! fighter loop is provider-agnostic. Concrete implementations handle the
5//! wire format differences between Anthropic, OpenAI-compatible APIs, etc.
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use futures::StreamExt;
11use hmac::{Hmac, Mac};
12use reqwest::Client;
13use serde::{Deserialize, Serialize};
14use sha2::{Digest, Sha256};
15
16use punch_types::{
17    Message, ModelConfig, Provider, PunchError, PunchResult, Role, ToolCall, ToolDefinition,
18};
19
20// ---------------------------------------------------------------------------
21// Core types
22// ---------------------------------------------------------------------------
23
24/// Why the model stopped generating.
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
26#[serde(rename_all = "snake_case")]
27pub enum StopReason {
28    /// The model finished its turn naturally.
29    EndTurn,
30    /// The model wants to invoke one or more tools.
31    ToolUse,
32    /// The response was truncated due to max_tokens.
33    MaxTokens,
34    /// An error occurred during generation.
35    Error,
36}
37
38/// Token usage statistics for a single completion.
39#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
40pub struct TokenUsage {
41    pub input_tokens: u64,
42    pub output_tokens: u64,
43}
44
45impl TokenUsage {
46    /// Add another usage on top of this one (accumulator).
47    pub fn accumulate(&mut self, other: &TokenUsage) {
48        self.input_tokens += other.input_tokens;
49        self.output_tokens += other.output_tokens;
50    }
51
52    /// Total tokens consumed.
53    pub fn total(&self) -> u64 {
54        self.input_tokens + self.output_tokens
55    }
56}
57
58/// A request to the LLM for a completion.
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct CompletionRequest {
61    /// Model identifier (e.g. "claude-sonnet-4-20250514").
62    pub model: String,
63    /// Conversation messages.
64    pub messages: Vec<Message>,
65    /// Tools available for the model to call.
66    #[serde(default)]
67    pub tools: Vec<ToolDefinition>,
68    /// Maximum tokens to generate.
69    pub max_tokens: u32,
70    /// Sampling temperature.
71    pub temperature: Option<f32>,
72    /// System prompt (separate from messages for providers that support it).
73    pub system_prompt: Option<String>,
74}
75
76/// The response from an LLM completion.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct CompletionResponse {
79    /// The assistant message (may contain tool calls).
80    pub message: Message,
81    /// Token usage for this completion.
82    pub usage: TokenUsage,
83    /// Why the model stopped.
84    pub stop_reason: StopReason,
85}
86
87// ---------------------------------------------------------------------------
88// Streaming types
89// ---------------------------------------------------------------------------
90
91/// A chunk from a streaming LLM response.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct StreamChunk {
94    /// Incremental text content.
95    pub delta: String,
96    /// Whether this is the final chunk.
97    pub is_final: bool,
98    /// Partial tool call data if any.
99    pub tool_call_delta: Option<ToolCallDelta>,
100}
101
102/// Incremental tool call data in a streaming response.
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct ToolCallDelta {
105    pub index: usize,
106    pub id: Option<String>,
107    pub name: Option<String>,
108    pub arguments_delta: String,
109}
110
111/// Callback invoked for each streaming chunk.
112pub type StreamCallback = Arc<dyn Fn(StreamChunk) + Send + Sync>;
113
114// ---------------------------------------------------------------------------
115// SSE parsing helpers
116// ---------------------------------------------------------------------------
117
/// Parse a buffered Server-Sent Events payload into discrete events.
///
/// Returns one `(event_type, data)` pair per event, in stream order. Events are
/// delimited by blank lines, and an event with no `event:` field gets the SSE
/// default type `"message"`. Several `data:` lines within one event are joined
/// with `\n`.
///
/// Field values follow the SSE parsing rules: exactly one space after the colon
/// is removed and the rest of the value is kept verbatim. The event name is
/// additionally trimmed. Comment lines (starting with `:`) and other fields such
/// as `id:` or `retry:` are ignored. A trailing event that is not followed by a
/// blank line is still returned.
fn parse_sse_events(raw: &str) -> Vec<(String, String)> {
    /// Emit the pending event (if any) and reset the per-event buffers.
    fn flush(
        events: &mut Vec<(String, String)>,
        event_type: &mut String,
        data: &mut Option<String>,
    ) {
        if data.is_none() && event_type.is_empty() {
            return;
        }
        let name = std::mem::take(event_type);
        let name = if name.is_empty() {
            "message".to_string()
        } else {
            name
        };
        events.push((name, data.take().unwrap_or_default()));
    }

    let mut events = Vec::new();
    let mut event_type = String::new();
    // `None` until the current event has seen a `data` field. An empty `data:`
    // line therefore still counts and contributes its own (empty) line.
    let mut data: Option<String> = None;

    // `str::lines` also strips a trailing `\r`, so CRLF-delimited streams work.
    for line in raw.lines() {
        if line.is_empty() {
            flush(&mut events, &mut event_type, &mut data);
            continue;
        }
        if line.starts_with(':') {
            // Comment / keep-alive line.
            continue;
        }
        let (field, value) = match line.split_once(':') {
            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
            // A bare field name has an empty value.
            None => (line, ""),
        };
        match field {
            "event" => event_type = value.trim().to_string(),
            "data" => match data.as_mut() {
                Some(buf) => {
                    buf.push('\n');
                    buf.push_str(value);
                }
                None => data = Some(value.to_string()),
            },
            _ => {}
        }
    }

    flush(&mut events, &mut event_type, &mut data);
    events
}
172
173/// Read the full response body as a stream of bytes and return as a String.
174async fn read_stream_body(response: reqwest::Response) -> PunchResult<String> {
175    let mut stream = response.bytes_stream();
176    let mut body = Vec::new();
177    while let Some(chunk) = stream.next().await {
178        let chunk = chunk.map_err(|e| PunchError::Provider {
179            provider: "stream".to_string(),
180            message: format!("stream read error: {e}"),
181        })?;
182        body.extend_from_slice(&chunk);
183    }
184    String::from_utf8(body).map_err(|e| PunchError::Provider {
185        provider: "stream".to_string(),
186        message: format!("invalid UTF-8 in stream: {e}"),
187    })
188}
189
190// ---------------------------------------------------------------------------
191// Think-tag stripping
192// ---------------------------------------------------------------------------
193
/// Strip reasoning/thinking tags from LLM responses.
///
/// Many reasoning models (Qwen, DeepSeek, etc.) wrap internal chain-of-thought
/// in `<think>...</think>`, `<thinking>...</thinking>`, `<reasoning>...</reasoning>`
/// or `<reflection>...</reflection>` tags. This function removes those blocks and
/// returns only the visible output, trimmed of surrounding whitespace.
///
/// Edge cases:
/// * An opening tag with no matching closing tag removes everything from the
///   opening tag to the end of the text.
/// * A closing tag that appears before any matching opening tag is treated as
///   the end of a reasoning block whose opening tag was not emitted. Some chat
///   templates put the opening tag in the prompt. Everything up to and
///   including that closing tag is removed.
///
/// If stripping leaves no visible output, the original content is returned
/// unchanged so the user still sees something.
pub fn strip_thinking_tags(content: &str) -> String {
    /// Remove every `open ... close` block (plus a leading orphan `close`) from `text`.
    fn remove_blocks(text: &str, open: &str, close: &str) -> String {
        let mut rest = text;

        // Leading orphan closer: the reply starts mid-reasoning.
        if let Some(close_at) = rest.find(close) {
            let opened_before = rest.find(open).map_or(false, |open_at| open_at < close_at);
            if !opened_before {
                rest = &rest[close_at + close.len()..];
            }
        }

        // Single forward pass; avoids re-allocating the whole string per block.
        let mut kept = String::with_capacity(rest.len());
        while let Some(start) = rest.find(open) {
            kept.push_str(&rest[..start]);
            let inner = &rest[start + open.len()..];
            match inner.find(close) {
                Some(end) => rest = &inner[end + close.len()..],
                // Unclosed block: everything after the opening tag is reasoning.
                None => return kept,
            }
        }
        kept.push_str(rest);
        kept
    }

    let mut result = content.to_string();
    for tag in ["think", "thinking", "reasoning", "reflection"] {
        result = remove_blocks(&result, &format!("<{tag}>"), &format!("</{tag}>"));
    }

    let trimmed = result.trim();
    // If stripping removed everything (the model spent all tokens thinking),
    // fall back to the original content.
    if trimmed.is_empty() {
        content.to_string()
    } else {
        trimmed.to_string()
    }
}
233
234// ---------------------------------------------------------------------------
235// Trait
236// ---------------------------------------------------------------------------
237
238/// Abstraction over LLM providers.
239#[async_trait]
240pub trait LlmDriver: Send + Sync + 'static {
241    /// Send a completion request and return the response.
242    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse>;
243
244    /// Streaming variant. Default implementation falls back to `complete`.
245    async fn stream_complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
246        let noop: StreamCallback = Arc::new(|_| {});
247        self.stream_complete_with_callback(request, noop).await
248    }
249
250    /// Streaming completion with per-chunk callback.
251    /// Returns the final assembled `CompletionResponse`.
252    async fn stream_complete_with_callback(
253        &self,
254        request: CompletionRequest,
255        callback: StreamCallback,
256    ) -> PunchResult<CompletionResponse> {
257        // Default: call complete() and send a single chunk.
258        let response = self.complete(request).await?;
259        callback(StreamChunk {
260            delta: response.message.content.clone(),
261            is_final: true,
262            tool_call_delta: None,
263        });
264        Ok(response)
265    }
266}
267
268// ---------------------------------------------------------------------------
269// Anthropic driver
270// ---------------------------------------------------------------------------
271
272/// Driver for the Anthropic Messages API (api.anthropic.com).
273pub struct AnthropicDriver {
274    client: Client,
275    api_key: String,
276    base_url: String,
277}
278
279impl AnthropicDriver {
280    /// Create a new Anthropic driver.
281    ///
282    /// `api_key` is the raw key value, not the env var name.
283    pub fn new(api_key: String, base_url: Option<String>) -> Self {
284        Self {
285            client: Client::new(),
286            api_key,
287            base_url: base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
288        }
289    }
290
291    /// Create a new Anthropic driver with a shared HTTP client.
292    ///
293    /// This allows connection pooling across all drivers.
294    pub fn with_client(client: Client, api_key: String, base_url: Option<String>) -> Self {
295        Self {
296            client,
297            api_key,
298            base_url: base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
299        }
300    }
301
302    /// Build the Anthropic API request body from our internal types.
303    fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
304        let mut messages = Vec::new();
305
306        for msg in &request.messages {
307            match msg.role {
308                Role::User => {
309                    messages.push(serde_json::json!({
310                        "role": "user",
311                        "content": msg.content,
312                    }));
313                }
314                Role::Assistant => {
315                    let mut content_blocks: Vec<serde_json::Value> = Vec::new();
316
317                    if !msg.content.is_empty() {
318                        content_blocks.push(serde_json::json!({
319                            "type": "text",
320                            "text": msg.content,
321                        }));
322                    }
323
324                    for tc in &msg.tool_calls {
325                        content_blocks.push(serde_json::json!({
326                            "type": "tool_use",
327                            "id": tc.id,
328                            "name": tc.name,
329                            "input": tc.input,
330                        }));
331                    }
332
333                    if content_blocks.is_empty() {
334                        content_blocks.push(serde_json::json!({
335                            "type": "text",
336                            "text": "",
337                        }));
338                    }
339
340                    messages.push(serde_json::json!({
341                        "role": "assistant",
342                        "content": content_blocks,
343                    }));
344                }
345                Role::Tool => {
346                    let mut result_blocks: Vec<serde_json::Value> = Vec::new();
347                    for tr in &msg.tool_results {
348                        result_blocks.push(serde_json::json!({
349                            "type": "tool_result",
350                            "tool_use_id": tr.id,
351                            "content": tr.content,
352                            "is_error": tr.is_error,
353                        }));
354                    }
355                    messages.push(serde_json::json!({
356                        "role": "user",
357                        "content": result_blocks,
358                    }));
359                }
360                Role::System => {
361                    // System messages are handled via the top-level `system` param;
362                    // skip them in the messages array.
363                }
364            }
365        }
366
367        let tools: Vec<serde_json::Value> = request
368            .tools
369            .iter()
370            .map(|t| {
371                serde_json::json!({
372                    "name": t.name,
373                    "description": t.description,
374                    "input_schema": t.input_schema,
375                })
376            })
377            .collect();
378
379        let mut body = serde_json::json!({
380            "model": request.model,
381            "messages": messages,
382            "max_tokens": request.max_tokens,
383        });
384
385        if let Some(temp) = request.temperature {
386            body["temperature"] = serde_json::json!(temp);
387        }
388
389        if let Some(ref system) = request.system_prompt {
390            body["system"] = serde_json::json!(system);
391        }
392
393        if !tools.is_empty() {
394            body["tools"] = serde_json::json!(tools);
395        }
396
397        body
398    }
399
400    /// Parse the Anthropic API response into our internal types.
401    fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
402        let stop_reason = match body["stop_reason"].as_str() {
403            Some("end_turn") => StopReason::EndTurn,
404            Some("tool_use") => StopReason::ToolUse,
405            Some("max_tokens") => StopReason::MaxTokens,
406            _ => StopReason::Error,
407        };
408
409        let usage = TokenUsage {
410            input_tokens: body["usage"]["input_tokens"].as_u64().unwrap_or(0),
411            output_tokens: body["usage"]["output_tokens"].as_u64().unwrap_or(0),
412        };
413
414        let mut text_content = String::new();
415        let mut tool_calls = Vec::new();
416
417        if let Some(content_array) = body["content"].as_array() {
418            for block in content_array {
419                match block["type"].as_str() {
420                    Some("text") => {
421                        if let Some(text) = block["text"].as_str() {
422                            if !text_content.is_empty() {
423                                text_content.push('\n');
424                            }
425                            text_content.push_str(text);
426                        }
427                    }
428                    Some("tool_use") => {
429                        tool_calls.push(ToolCall {
430                            id: block["id"].as_str().unwrap_or_default().to_string(),
431                            name: block["name"].as_str().unwrap_or_default().to_string(),
432                            input: block["input"].clone(),
433                        });
434                    }
435                    _ => {}
436                }
437            }
438        }
439
440        // Strip thinking tags from reasoning models
441        let text_content = strip_thinking_tags(&text_content);
442
443        let message = Message {
444            role: Role::Assistant,
445            content: text_content,
446            tool_calls,
447            tool_results: Vec::new(),
448            timestamp: chrono::Utc::now(),
449        };
450
451        Ok(CompletionResponse {
452            message,
453            usage,
454            stop_reason,
455        })
456    }
457}
458
459#[async_trait]
460impl LlmDriver for AnthropicDriver {
461    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
462        let url = format!("{}/v1/messages", self.base_url);
463        let body = self.build_request_body(&request);
464
465        let response = self
466            .client
467            .post(&url)
468            .header("x-api-key", &self.api_key)
469            .header("anthropic-version", "2023-06-01")
470            .header("content-type", "application/json")
471            .json(&body)
472            .send()
473            .await
474            .map_err(|e| PunchError::Provider {
475                provider: "anthropic".to_string(),
476                message: format!("request failed: {e}"),
477            })?;
478
479        let status = response.status();
480
481        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
482            let retry_after = response
483                .headers()
484                .get("retry-after")
485                .and_then(|v| v.to_str().ok())
486                .and_then(|s| s.parse::<u64>().ok())
487                .unwrap_or(60)
488                * 1000;
489
490            return Err(PunchError::RateLimited {
491                provider: "anthropic".to_string(),
492                retry_after_ms: retry_after,
493            });
494        }
495
496        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
497            return Err(PunchError::Auth(
498                "anthropic API key is invalid or lacks permissions".to_string(),
499            ));
500        }
501
502        let response_body: serde_json::Value =
503            response.json().await.map_err(|e| PunchError::Provider {
504                provider: "anthropic".to_string(),
505                message: format!("failed to parse response: {e}"),
506            })?;
507
508        if !status.is_success() {
509            let error_msg = response_body["error"]["message"]
510                .as_str()
511                .unwrap_or("unknown error");
512            return Err(PunchError::Provider {
513                provider: "anthropic".to_string(),
514                message: format!("API error ({}): {}", status, error_msg),
515            });
516        }
517
518        self.parse_response(&response_body)
519    }
520
521    async fn stream_complete_with_callback(
522        &self,
523        request: CompletionRequest,
524        callback: StreamCallback,
525    ) -> PunchResult<CompletionResponse> {
526        let url = format!("{}/v1/messages", self.base_url);
527        let mut body = self.build_request_body(&request);
528        body["stream"] = serde_json::json!(true);
529
530        let response = self
531            .client
532            .post(&url)
533            .header("x-api-key", &self.api_key)
534            .header("anthropic-version", "2023-06-01")
535            .header("content-type", "application/json")
536            .json(&body)
537            .send()
538            .await
539            .map_err(|e| PunchError::Provider {
540                provider: "anthropic".to_string(),
541                message: format!("stream request failed: {e}"),
542            })?;
543
544        let status = response.status();
545        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
546            return Err(PunchError::RateLimited {
547                provider: "anthropic".to_string(),
548                retry_after_ms: 60_000,
549            });
550        }
551        if !status.is_success() {
552            let err_body: serde_json::Value =
553                response.json().await.unwrap_or(serde_json::json!({}));
554            let msg = err_body["error"]["message"]
555                .as_str()
556                .unwrap_or("unknown error");
557            return Err(PunchError::Provider {
558                provider: "anthropic".to_string(),
559                message: format!("API error ({}): {}", status, msg),
560            });
561        }
562
563        let raw = read_stream_body(response).await?;
564        let events = parse_sse_events(&raw);
565
566        let mut text_content = String::new();
567        let mut tool_calls: Vec<ToolCall> = Vec::new();
568        let mut usage = TokenUsage::default();
569        let mut stop_reason = StopReason::EndTurn;
570        // Track current content block index for tool use assembly
571        let mut current_tool_index: Option<usize> = None;
572
573        for (event_type, data) in &events {
574            let parsed: serde_json::Value = match serde_json::from_str(data) {
575                Ok(v) => v,
576                Err(_) => continue,
577            };
578
579            match event_type.as_str() {
580                "message_start" => {
581                    if let Some(inp) = parsed["message"]["usage"]["input_tokens"].as_u64() {
582                        usage.input_tokens = inp;
583                    }
584                }
585                "content_block_start" => {
586                    let block = &parsed["content_block"];
587                    match block["type"].as_str() {
588                        Some("tool_use") => {
589                            let id = block["id"].as_str().unwrap_or_default().to_string();
590                            let name = block["name"].as_str().unwrap_or_default().to_string();
591                            tool_calls.push(ToolCall {
592                                id: id.clone(),
593                                name: name.clone(),
594                                input: serde_json::json!({}),
595                            });
596                            current_tool_index = Some(tool_calls.len() - 1);
597                            callback(StreamChunk {
598                                delta: String::new(),
599                                is_final: false,
600                                tool_call_delta: Some(ToolCallDelta {
601                                    index: tool_calls.len() - 1,
602                                    id: Some(id),
603                                    name: Some(name),
604                                    arguments_delta: String::new(),
605                                }),
606                            });
607                        }
608                        Some("text") => {
609                            current_tool_index = None;
610                        }
611                        _ => {}
612                    }
613                }
614                "content_block_delta" => {
615                    let delta = &parsed["delta"];
616                    match delta["type"].as_str() {
617                        Some("text_delta") => {
618                            let text = delta["text"].as_str().unwrap_or("");
619                            text_content.push_str(text);
620                            callback(StreamChunk {
621                                delta: text.to_string(),
622                                is_final: false,
623                                tool_call_delta: None,
624                            });
625                        }
626                        Some("input_json_delta") => {
627                            let partial = delta["partial_json"].as_str().unwrap_or("");
628                            if let Some(idx) = current_tool_index {
629                                callback(StreamChunk {
630                                    delta: String::new(),
631                                    is_final: false,
632                                    tool_call_delta: Some(ToolCallDelta {
633                                        index: idx,
634                                        id: None,
635                                        name: None,
636                                        arguments_delta: partial.to_string(),
637                                    }),
638                                });
639                            }
640                        }
641                        _ => {}
642                    }
643                }
644                "message_delta" => {
645                    if let Some(sr) = parsed["delta"]["stop_reason"].as_str() {
646                        stop_reason = match sr {
647                            "end_turn" => StopReason::EndTurn,
648                            "tool_use" => StopReason::ToolUse,
649                            "max_tokens" => StopReason::MaxTokens,
650                            _ => StopReason::Error,
651                        };
652                    }
653                    if let Some(out) = parsed["usage"]["output_tokens"].as_u64() {
654                        usage.output_tokens = out;
655                    }
656                }
657                "message_stop" => {
658                    callback(StreamChunk {
659                        delta: String::new(),
660                        is_final: true,
661                        tool_call_delta: None,
662                    });
663                }
664                _ => {}
665            }
666        }
667
668        // Reassemble tool call inputs from the accumulated JSON fragments.
669        // The Anthropic SSE stream sends tool input as `input_json_delta` fragments.
670        // We need to re-parse the full accumulated JSON for each tool call.
671        // Since we only captured deltas via callback, we rebuild from the raw events.
672        let mut tool_json_bufs: Vec<String> = vec![String::new(); tool_calls.len()];
673        let mut tc_idx: Option<usize> = None;
674        for (event_type, data) in &events {
675            let parsed: serde_json::Value = match serde_json::from_str(data) {
676                Ok(v) => v,
677                Err(_) => continue,
678            };
679            match event_type.as_str() {
680                "content_block_start" => {
681                    if parsed["content_block"]["type"].as_str() == Some("tool_use") {
682                        tc_idx = Some(tc_idx.map_or(0, |i| i + 1));
683                    } else {
684                        tc_idx = None;
685                    }
686                }
687                "content_block_delta" => {
688                    if parsed["delta"]["type"].as_str() == Some("input_json_delta")
689                        && let Some(idx) = tc_idx
690                        && let Some(buf) = tool_json_bufs.get_mut(idx)
691                    {
692                        buf.push_str(parsed["delta"]["partial_json"].as_str().unwrap_or(""));
693                    }
694                }
695                _ => {}
696            }
697        }
698        for (i, buf) in tool_json_bufs.into_iter().enumerate() {
699            if !buf.is_empty()
700                && let Some(tc) = tool_calls.get_mut(i)
701            {
702                tc.input = serde_json::from_str(&buf).unwrap_or(serde_json::json!({}));
703            }
704        }
705
706        let text_content = strip_thinking_tags(&text_content);
707
708        if !tool_calls.is_empty() && stop_reason != StopReason::ToolUse {
709            stop_reason = StopReason::ToolUse;
710        }
711
712        let message = Message {
713            role: Role::Assistant,
714            content: text_content,
715            tool_calls,
716            tool_results: Vec::new(),
717            timestamp: chrono::Utc::now(),
718        };
719
720        Ok(CompletionResponse {
721            message,
722            usage,
723            stop_reason,
724        })
725    }
726}
727
728// ---------------------------------------------------------------------------
729// OpenAI-compatible driver
730// ---------------------------------------------------------------------------
731
732/// Driver for OpenAI-compatible chat completions APIs.
733///
734/// Works with OpenAI, Groq, DeepSeek, Together, Fireworks,
735/// Cerebras, xAI, Mistral, and any other provider exposing the
736/// `/v1/chat/completions` endpoint.
737pub struct OpenAiCompatibleDriver {
738    client: Client,
739    api_key: String,
740    base_url: String,
741    provider_name: String,
742}
743
744impl OpenAiCompatibleDriver {
745    /// Create a new OpenAI-compatible driver.
746    pub fn new(api_key: String, base_url: String, provider_name: String) -> Self {
747        Self {
748            client: Client::new(),
749            api_key,
750            base_url,
751            provider_name,
752        }
753    }
754
755    /// Create a new OpenAI-compatible driver with a shared HTTP client.
756    pub fn with_client(
757        client: Client,
758        api_key: String,
759        base_url: String,
760        provider_name: String,
761    ) -> Self {
762        Self {
763            client,
764            api_key,
765            base_url,
766            provider_name,
767        }
768    }
769
770    /// Build the OpenAI chat completions request body.
771    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
772        let mut messages = Vec::new();
773
774        // System prompt as a system message.
775        if let Some(ref system) = request.system_prompt {
776            messages.push(serde_json::json!({
777                "role": "system",
778                "content": system,
779            }));
780        }
781
782        for msg in &request.messages {
783            match msg.role {
784                Role::System => {
785                    messages.push(serde_json::json!({
786                        "role": "system",
787                        "content": msg.content,
788                    }));
789                }
790                Role::User => {
791                    messages.push(serde_json::json!({
792                        "role": "user",
793                        "content": msg.content,
794                    }));
795                }
796                Role::Assistant => {
797                    let mut m = serde_json::json!({
798                        "role": "assistant",
799                    });
800
801                    if !msg.content.is_empty() {
802                        m["content"] = serde_json::json!(msg.content);
803                    }
804
805                    if !msg.tool_calls.is_empty() {
806                        let tc: Vec<serde_json::Value> = msg
807                            .tool_calls
808                            .iter()
809                            .map(|tc| {
810                                serde_json::json!({
811                                    "id": tc.id,
812                                    "type": "function",
813                                    "function": {
814                                        "name": tc.name,
815                                        "arguments": tc.input.to_string(),
816                                    },
817                                })
818                            })
819                            .collect();
820                        m["tool_calls"] = serde_json::json!(tc);
821                    }
822
823                    messages.push(m);
824                }
825                Role::Tool => {
826                    for tr in &msg.tool_results {
827                        messages.push(serde_json::json!({
828                            "role": "tool",
829                            "tool_call_id": tr.id,
830                            "content": tr.content,
831                        }));
832                    }
833                }
834            }
835        }
836
837        let tools: Vec<serde_json::Value> = request
838            .tools
839            .iter()
840            .map(|t| {
841                serde_json::json!({
842                    "type": "function",
843                    "function": {
844                        "name": t.name,
845                        "description": t.description,
846                        "parameters": t.input_schema,
847                    },
848                })
849            })
850            .collect();
851
852        let mut body = serde_json::json!({
853            "model": request.model,
854            "messages": messages,
855            "max_tokens": request.max_tokens,
856        });
857
858        if let Some(temp) = request.temperature {
859            body["temperature"] = serde_json::json!(temp);
860        }
861
862        if !tools.is_empty() {
863            body["tools"] = serde_json::json!(tools);
864        }
865
866        body
867    }
868
869    /// Parse the OpenAI chat completions response.
870    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
871        let choice = body["choices"].get(0).ok_or_else(|| PunchError::Provider {
872            provider: self.provider_name.clone(),
873            message: "no choices in response".to_string(),
874        })?;
875
876        let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
877        let stop_reason = match finish_reason {
878            "stop" => StopReason::EndTurn,
879            "tool_calls" => StopReason::ToolUse,
880            "length" => StopReason::MaxTokens,
881            _ => StopReason::EndTurn,
882        };
883
884        let msg = &choice["message"];
885        let raw_content = msg["content"].as_str().unwrap_or("");
886        // Strip thinking tags from reasoning models (Qwen, DeepSeek R1, etc.)
887        let content = strip_thinking_tags(raw_content);
888
889        let mut tool_calls = Vec::new();
890        if let Some(tc_array) = msg["tool_calls"].as_array() {
891            for tc in tc_array {
892                let id = tc["id"].as_str().unwrap_or_default().to_string();
893                let name = tc["function"]["name"]
894                    .as_str()
895                    .unwrap_or_default()
896                    .to_string();
897                let args_str = tc["function"]["arguments"].as_str().unwrap_or("{}");
898                let input: serde_json::Value =
899                    serde_json::from_str(args_str).unwrap_or(serde_json::json!({}));
900
901                tool_calls.push(ToolCall { id, name, input });
902            }
903        }
904
905        let usage = TokenUsage {
906            input_tokens: body["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
907            output_tokens: body["usage"]["completion_tokens"].as_u64().unwrap_or(0),
908        };
909
910        // If there are tool calls but finish_reason was not "tool_calls", fix it up.
911        let stop_reason = if !tool_calls.is_empty() && stop_reason != StopReason::ToolUse {
912            StopReason::ToolUse
913        } else {
914            stop_reason
915        };
916
917        let message = Message {
918            role: Role::Assistant,
919            content,
920            tool_calls,
921            tool_results: Vec::new(),
922            timestamp: chrono::Utc::now(),
923        };
924
925        Ok(CompletionResponse {
926            message,
927            usage,
928            stop_reason,
929        })
930    }
931}
932
933#[async_trait]
934impl LlmDriver for OpenAiCompatibleDriver {
935    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
936        let url = format!(
937            "{}/v1/chat/completions",
938            self.base_url.trim_end_matches('/')
939        );
940        let body = self.build_request_body(&request);
941
942        let response = self
943            .client
944            .post(&url)
945            .header("authorization", format!("Bearer {}", self.api_key))
946            .header("content-type", "application/json")
947            .json(&body)
948            .send()
949            .await
950            .map_err(|e| PunchError::Provider {
951                provider: self.provider_name.clone(),
952                message: format!("request failed: {e}"),
953            })?;
954
955        let status = response.status();
956
957        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
958            let retry_after = response
959                .headers()
960                .get("retry-after")
961                .and_then(|v| v.to_str().ok())
962                .and_then(|s| s.parse::<u64>().ok())
963                .unwrap_or(60)
964                * 1000;
965
966            return Err(PunchError::RateLimited {
967                provider: self.provider_name.clone(),
968                retry_after_ms: retry_after,
969            });
970        }
971
972        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
973            return Err(PunchError::Auth(format!(
974                "{} API key is invalid or lacks permissions",
975                self.provider_name
976            )));
977        }
978
979        let response_body: serde_json::Value =
980            response.json().await.map_err(|e| PunchError::Provider {
981                provider: self.provider_name.clone(),
982                message: format!("failed to parse response: {e}"),
983            })?;
984
985        if !status.is_success() {
986            let error_msg = response_body["error"]["message"]
987                .as_str()
988                .unwrap_or("unknown error");
989            return Err(PunchError::Provider {
990                provider: self.provider_name.clone(),
991                message: format!("API error ({}): {}", status, error_msg),
992            });
993        }
994
995        self.parse_response(&response_body)
996    }
997
998    async fn stream_complete_with_callback(
999        &self,
1000        request: CompletionRequest,
1001        callback: StreamCallback,
1002    ) -> PunchResult<CompletionResponse> {
1003        let url = format!(
1004            "{}/v1/chat/completions",
1005            self.base_url.trim_end_matches('/')
1006        );
1007        let mut body = self.build_request_body(&request);
1008        body["stream"] = serde_json::json!(true);
1009
1010        let response = self
1011            .client
1012            .post(&url)
1013            .header("authorization", format!("Bearer {}", self.api_key))
1014            .header("content-type", "application/json")
1015            .json(&body)
1016            .send()
1017            .await
1018            .map_err(|e| PunchError::Provider {
1019                provider: self.provider_name.clone(),
1020                message: format!("stream request failed: {e}"),
1021            })?;
1022
1023        let status = response.status();
1024        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
1025            return Err(PunchError::RateLimited {
1026                provider: self.provider_name.clone(),
1027                retry_after_ms: 60_000,
1028            });
1029        }
1030        if !status.is_success() {
1031            let err_body: serde_json::Value =
1032                response.json().await.unwrap_or(serde_json::json!({}));
1033            let msg = err_body["error"]["message"]
1034                .as_str()
1035                .unwrap_or("unknown error");
1036            return Err(PunchError::Provider {
1037                provider: self.provider_name.clone(),
1038                message: format!("API error ({}): {}", status, msg),
1039            });
1040        }
1041
1042        let raw = read_stream_body(response).await?;
1043        let assembled = self.parse_openai_stream(&raw, &callback)?;
1044        Ok(assembled)
1045    }
1046}
1047
1048impl OpenAiCompatibleDriver {
1049    /// Parse an OpenAI-style SSE stream into a `CompletionResponse`, invoking
1050    /// the callback for each chunk.
1051    pub fn parse_openai_stream(
1052        &self,
1053        raw: &str,
1054        callback: &StreamCallback,
1055    ) -> PunchResult<CompletionResponse> {
1056        let events = parse_sse_events(raw);
1057
1058        let mut text_content = String::new();
1059        // tool_calls keyed by index
1060        let mut tool_map: std::collections::BTreeMap<usize, (String, String, String)> =
1061            std::collections::BTreeMap::new();
1062        let mut finish_reason = String::new();
1063
1064        for (_event_type, data) in &events {
1065            if data.trim() == "[DONE]" {
1066                callback(StreamChunk {
1067                    delta: String::new(),
1068                    is_final: true,
1069                    tool_call_delta: None,
1070                });
1071                break;
1072            }
1073
1074            let parsed: serde_json::Value = match serde_json::from_str(data) {
1075                Ok(v) => v,
1076                Err(_) => continue,
1077            };
1078
1079            let choice = match parsed["choices"].get(0) {
1080                Some(c) => c,
1081                None => continue,
1082            };
1083
1084            if let Some(fr) = choice["finish_reason"].as_str() {
1085                finish_reason = fr.to_string();
1086            }
1087
1088            let delta = &choice["delta"];
1089
1090            // Text content delta
1091            if let Some(content) = delta["content"].as_str() {
1092                text_content.push_str(content);
1093                callback(StreamChunk {
1094                    delta: content.to_string(),
1095                    is_final: false,
1096                    tool_call_delta: None,
1097                });
1098            }
1099
1100            // Tool call deltas
1101            if let Some(tc_array) = delta["tool_calls"].as_array() {
1102                for tc in tc_array {
1103                    let idx = tc["index"].as_u64().unwrap_or(0) as usize;
1104                    let entry = tool_map
1105                        .entry(idx)
1106                        .or_insert_with(|| (String::new(), String::new(), String::new()));
1107
1108                    let id_delta = tc["id"].as_str().unwrap_or("");
1109                    let name_delta = tc["function"]["name"].as_str().unwrap_or("");
1110                    let args_delta = tc["function"]["arguments"].as_str().unwrap_or("");
1111
1112                    if !id_delta.is_empty() {
1113                        entry.0.push_str(id_delta);
1114                    }
1115                    if !name_delta.is_empty() {
1116                        entry.1.push_str(name_delta);
1117                    }
1118                    entry.2.push_str(args_delta);
1119
1120                    callback(StreamChunk {
1121                        delta: String::new(),
1122                        is_final: false,
1123                        tool_call_delta: Some(ToolCallDelta {
1124                            index: idx,
1125                            id: if id_delta.is_empty() {
1126                                None
1127                            } else {
1128                                Some(id_delta.to_string())
1129                            },
1130                            name: if name_delta.is_empty() {
1131                                None
1132                            } else {
1133                                Some(name_delta.to_string())
1134                            },
1135                            arguments_delta: args_delta.to_string(),
1136                        }),
1137                    });
1138                }
1139            }
1140        }
1141
1142        let tool_calls: Vec<ToolCall> = tool_map
1143            .into_values()
1144            .map(|(id, name, args)| {
1145                let input: serde_json::Value =
1146                    serde_json::from_str(&args).unwrap_or(serde_json::json!({}));
1147                ToolCall { id, name, input }
1148            })
1149            .collect();
1150
1151        let stop_reason = if !tool_calls.is_empty() {
1152            StopReason::ToolUse
1153        } else {
1154            match finish_reason.as_str() {
1155                "stop" => StopReason::EndTurn,
1156                "tool_calls" => StopReason::ToolUse,
1157                "length" => StopReason::MaxTokens,
1158                _ => StopReason::EndTurn,
1159            }
1160        };
1161
1162        let text_content = strip_thinking_tags(&text_content);
1163
1164        let message = Message {
1165            role: Role::Assistant,
1166            content: text_content,
1167            tool_calls,
1168            tool_results: Vec::new(),
1169            timestamp: chrono::Utc::now(),
1170        };
1171
1172        // OpenAI streaming does not include usage in most chunks; set to zero.
1173        Ok(CompletionResponse {
1174            message,
1175            usage: TokenUsage::default(),
1176            stop_reason,
1177        })
1178    }
1179}
1180
1181// ---------------------------------------------------------------------------
1182// Gemini driver
1183// ---------------------------------------------------------------------------
1184
1185/// Driver for the Google Gemini (Generative Language) API.
1186pub struct GeminiDriver {
1187    client: Client,
1188    api_key: String,
1189    base_url: String,
1190}
1191
1192impl GeminiDriver {
1193    /// Create a new Gemini driver.
1194    pub fn new(api_key: String, base_url: Option<String>) -> Self {
1195        Self {
1196            client: Client::new(),
1197            api_key,
1198            base_url: base_url
1199                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string()),
1200        }
1201    }
1202
1203    /// Create a new Gemini driver with a shared HTTP client.
1204    pub fn with_client(client: Client, api_key: String, base_url: Option<String>) -> Self {
1205        Self {
1206            client,
1207            api_key,
1208            base_url: base_url
1209                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string()),
1210        }
1211    }
1212
1213    /// Build the Gemini API request body.
1214    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
1215        let mut contents = Vec::new();
1216        let mut system_text: Option<String> = request.system_prompt.clone();
1217
1218        for msg in &request.messages {
1219            match msg.role {
1220                Role::System => {
1221                    // Gemini does not have a system role; prepend to first user message.
1222                    let existing = system_text.take().unwrap_or_default();
1223                    let combined = if existing.is_empty() {
1224                        msg.content.clone()
1225                    } else {
1226                        format!("{}\n{}", existing, msg.content)
1227                    };
1228                    system_text = Some(combined);
1229                }
1230                Role::User => {
1231                    let mut text = String::new();
1232                    if let Some(sys) = system_text.take()
1233                        && !sys.is_empty()
1234                    {
1235                        text.push_str(&sys);
1236                        text.push_str("\n\n");
1237                    }
1238                    text.push_str(&msg.content);
1239                    contents.push(serde_json::json!({
1240                        "role": "user",
1241                        "parts": [{"text": text}],
1242                    }));
1243                }
1244                Role::Assistant => {
1245                    let mut parts: Vec<serde_json::Value> = Vec::new();
1246                    if !msg.content.is_empty() {
1247                        parts.push(serde_json::json!({"text": msg.content}));
1248                    }
1249                    for tc in &msg.tool_calls {
1250                        parts.push(serde_json::json!({
1251                            "functionCall": {
1252                                "name": tc.name,
1253                                "args": tc.input,
1254                            }
1255                        }));
1256                    }
1257                    if parts.is_empty() {
1258                        parts.push(serde_json::json!({"text": ""}));
1259                    }
1260                    contents.push(serde_json::json!({
1261                        "role": "model",
1262                        "parts": parts,
1263                    }));
1264                }
1265                Role::Tool => {
1266                    let mut parts: Vec<serde_json::Value> = Vec::new();
1267                    for tr in &msg.tool_results {
1268                        parts.push(serde_json::json!({
1269                            "functionResponse": {
1270                                "name": tr.id.clone(),
1271                                "response": {"content": tr.content},
1272                            }
1273                        }));
1274                    }
1275                    contents.push(serde_json::json!({
1276                        "role": "user",
1277                        "parts": parts,
1278                    }));
1279                }
1280            }
1281        }
1282
1283        // If we still have an unused system prompt (no user messages yet), add it.
1284        if let Some(sys) = system_text
1285            && !sys.is_empty()
1286        {
1287            contents.insert(
1288                0,
1289                serde_json::json!({
1290                    "role": "user",
1291                    "parts": [{"text": sys}],
1292                }),
1293            );
1294        }
1295
1296        let mut body = serde_json::json!({
1297            "contents": contents,
1298        });
1299
1300        let mut gen_config = serde_json::json!({
1301            "maxOutputTokens": request.max_tokens,
1302        });
1303        if let Some(temp) = request.temperature {
1304            gen_config["temperature"] = serde_json::json!(temp);
1305        }
1306        body["generationConfig"] = gen_config;
1307
1308        if !request.tools.is_empty() {
1309            let func_decls: Vec<serde_json::Value> = request
1310                .tools
1311                .iter()
1312                .map(|t| {
1313                    serde_json::json!({
1314                        "name": t.name,
1315                        "description": t.description,
1316                        "parameters": t.input_schema,
1317                    })
1318                })
1319                .collect();
1320            body["tools"] = serde_json::json!([{"function_declarations": func_decls}]);
1321        }
1322
1323        body
1324    }
1325
1326    /// Build the full URL for a Gemini request.
1327    pub fn build_url(&self, model: &str) -> String {
1328        format!(
1329            "{}/v1beta/models/{}:generateContent?key={}",
1330            self.base_url.trim_end_matches('/'),
1331            model,
1332            self.api_key,
1333        )
1334    }
1335
1336    /// Build the URL for Gemini streaming.
1337    pub fn build_stream_url(&self, model: &str) -> String {
1338        format!(
1339            "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
1340            self.base_url.trim_end_matches('/'),
1341            model,
1342            self.api_key,
1343        )
1344    }
1345
1346    /// Parse the Gemini API response.
1347    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
1348        let candidate = body["candidates"]
1349            .get(0)
1350            .ok_or_else(|| PunchError::Provider {
1351                provider: "gemini".to_string(),
1352                message: "no candidates in response".to_string(),
1353            })?;
1354
1355        let parts = candidate["content"]["parts"]
1356            .as_array()
1357            .cloned()
1358            .unwrap_or_default();
1359
1360        let mut text_content = String::new();
1361        let mut tool_calls = Vec::new();
1362
1363        for part in &parts {
1364            if let Some(text) = part["text"].as_str() {
1365                if !text_content.is_empty() {
1366                    text_content.push('\n');
1367                }
1368                text_content.push_str(text);
1369            }
1370            if let Some(fc) = part.get("functionCall") {
1371                let name = fc["name"].as_str().unwrap_or_default().to_string();
1372                let args = fc["args"].clone();
1373                tool_calls.push(ToolCall {
1374                    id: format!("gemini-{}", uuid::Uuid::new_v4()),
1375                    name,
1376                    input: args,
1377                });
1378            }
1379        }
1380
1381        let finish_reason = candidate["finishReason"].as_str().unwrap_or("STOP");
1382        let stop_reason = if !tool_calls.is_empty() {
1383            StopReason::ToolUse
1384        } else {
1385            match finish_reason {
1386                "STOP" => StopReason::EndTurn,
1387                "MAX_TOKENS" => StopReason::MaxTokens,
1388                _ => StopReason::EndTurn,
1389            }
1390        };
1391
1392        let usage = TokenUsage {
1393            input_tokens: body["usageMetadata"]["promptTokenCount"]
1394                .as_u64()
1395                .unwrap_or(0),
1396            output_tokens: body["usageMetadata"]["candidatesTokenCount"]
1397                .as_u64()
1398                .unwrap_or(0),
1399        };
1400
1401        // Strip thinking tags from reasoning models
1402        let text_content = strip_thinking_tags(&text_content);
1403
1404        let message = Message {
1405            role: Role::Assistant,
1406            content: text_content,
1407            tool_calls,
1408            tool_results: Vec::new(),
1409            timestamp: chrono::Utc::now(),
1410        };
1411
1412        Ok(CompletionResponse {
1413            message,
1414            usage,
1415            stop_reason,
1416        })
1417    }
1418}
1419
1420#[async_trait]
1421impl LlmDriver for GeminiDriver {
1422    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
1423        let url = self.build_url(&request.model);
1424        let body = self.build_request_body(&request);
1425
1426        let response = self
1427            .client
1428            .post(&url)
1429            .header("content-type", "application/json")
1430            .json(&body)
1431            .send()
1432            .await
1433            .map_err(|e| PunchError::Provider {
1434                provider: "gemini".to_string(),
1435                message: format!("request failed: {e}"),
1436            })?;
1437
1438        let status = response.status();
1439
1440        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
1441            return Err(PunchError::RateLimited {
1442                provider: "gemini".to_string(),
1443                retry_after_ms: 60_000,
1444            });
1445        }
1446
1447        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
1448            return Err(PunchError::Auth(
1449                "Gemini API key is invalid or lacks permissions".to_string(),
1450            ));
1451        }
1452
1453        let response_body: serde_json::Value =
1454            response.json().await.map_err(|e| PunchError::Provider {
1455                provider: "gemini".to_string(),
1456                message: format!("failed to parse response: {e}"),
1457            })?;
1458
1459        if !status.is_success() {
1460            let error_msg = response_body["error"]["message"]
1461                .as_str()
1462                .unwrap_or("unknown error");
1463            return Err(PunchError::Provider {
1464                provider: "gemini".to_string(),
1465                message: format!("API error ({}): {}", status, error_msg),
1466            });
1467        }
1468
1469        self.parse_response(&response_body)
1470    }
1471
1472    async fn stream_complete_with_callback(
1473        &self,
1474        request: CompletionRequest,
1475        callback: StreamCallback,
1476    ) -> PunchResult<CompletionResponse> {
1477        let url = self.build_stream_url(&request.model);
1478        let body = self.build_request_body(&request);
1479
1480        let response = self
1481            .client
1482            .post(&url)
1483            .header("content-type", "application/json")
1484            .json(&body)
1485            .send()
1486            .await
1487            .map_err(|e| PunchError::Provider {
1488                provider: "gemini".to_string(),
1489                message: format!("stream request failed: {e}"),
1490            })?;
1491
1492        let status = response.status();
1493        if !status.is_success() {
1494            let err_body: serde_json::Value =
1495                response.json().await.unwrap_or(serde_json::json!({}));
1496            let msg = err_body["error"]["message"]
1497                .as_str()
1498                .unwrap_or("unknown error");
1499            return Err(PunchError::Provider {
1500                provider: "gemini".to_string(),
1501                message: format!("API error ({}): {}", status, msg),
1502            });
1503        }
1504
1505        let raw = read_stream_body(response).await?;
1506        let events = parse_sse_events(&raw);
1507
1508        let mut text_content = String::new();
1509        let mut tool_calls: Vec<ToolCall> = Vec::new();
1510        let mut usage = TokenUsage::default();
1511        let mut finish_reason = String::new();
1512
1513        for (_event_type, data) in &events {
1514            let parsed: serde_json::Value = match serde_json::from_str(data) {
1515                Ok(v) => v,
1516                Err(_) => continue,
1517            };
1518
1519            // Extract parts from the candidate
1520            if let Some(parts) = parsed["candidates"][0]["content"]["parts"].as_array() {
1521                for part in parts {
1522                    if let Some(text) = part["text"].as_str() {
1523                        text_content.push_str(text);
1524                        callback(StreamChunk {
1525                            delta: text.to_string(),
1526                            is_final: false,
1527                            tool_call_delta: None,
1528                        });
1529                    }
1530                    if let Some(fc) = part.get("functionCall") {
1531                        let name = fc["name"].as_str().unwrap_or_default().to_string();
1532                        let args = fc["args"].clone();
1533                        let idx = tool_calls.len();
1534                        tool_calls.push(ToolCall {
1535                            id: format!("gemini-{}", uuid::Uuid::new_v4()),
1536                            name: name.clone(),
1537                            input: args,
1538                        });
1539                        callback(StreamChunk {
1540                            delta: String::new(),
1541                            is_final: false,
1542                            tool_call_delta: Some(ToolCallDelta {
1543                                index: idx,
1544                                id: None,
1545                                name: Some(name),
1546                                arguments_delta: String::new(),
1547                            }),
1548                        });
1549                    }
1550                }
1551            }
1552
1553            if let Some(fr) = parsed["candidates"][0]["finishReason"].as_str() {
1554                finish_reason = fr.to_string();
1555            }
1556
1557            // Usage from the last chunk
1558            if let Some(inp) = parsed["usageMetadata"]["promptTokenCount"].as_u64() {
1559                usage.input_tokens = inp;
1560            }
1561            if let Some(out) = parsed["usageMetadata"]["candidatesTokenCount"].as_u64() {
1562                usage.output_tokens = out;
1563            }
1564        }
1565
1566        callback(StreamChunk {
1567            delta: String::new(),
1568            is_final: true,
1569            tool_call_delta: None,
1570        });
1571
1572        let stop_reason = if !tool_calls.is_empty() {
1573            StopReason::ToolUse
1574        } else {
1575            match finish_reason.as_str() {
1576                "STOP" => StopReason::EndTurn,
1577                "MAX_TOKENS" => StopReason::MaxTokens,
1578                _ => StopReason::EndTurn,
1579            }
1580        };
1581
1582        let text_content = strip_thinking_tags(&text_content);
1583
1584        let message = Message {
1585            role: Role::Assistant,
1586            content: text_content,
1587            tool_calls,
1588            tool_results: Vec::new(),
1589            timestamp: chrono::Utc::now(),
1590        };
1591
1592        Ok(CompletionResponse {
1593            message,
1594            usage,
1595            stop_reason,
1596        })
1597    }
1598}
1599
1600// ---------------------------------------------------------------------------
1601// Ollama driver
1602// ---------------------------------------------------------------------------
1603
1604/// Driver for local Ollama instances using the chat API.
1605pub struct OllamaDriver {
1606    client: Client,
1607    base_url: String,
1608}
1609
1610impl OllamaDriver {
1611    /// Create a new Ollama driver.
1612    pub fn new(base_url: Option<String>) -> Self {
1613        Self {
1614            client: Client::new(),
1615            base_url: base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
1616        }
1617    }
1618
1619    /// Create a new Ollama driver with a shared HTTP client.
1620    pub fn with_client(client: Client, base_url: Option<String>) -> Self {
1621        Self {
1622            client,
1623            base_url: base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
1624        }
1625    }
1626
1627    /// Get the base URL.
1628    pub fn base_url(&self) -> &str {
1629        &self.base_url
1630    }
1631
1632    /// Build the Ollama chat request body.
1633    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
1634        let mut messages = Vec::new();
1635
1636        if let Some(ref system) = request.system_prompt {
1637            messages.push(serde_json::json!({
1638                "role": "system",
1639                "content": system,
1640            }));
1641        }
1642
1643        for msg in &request.messages {
1644            match msg.role {
1645                Role::System => {
1646                    messages.push(serde_json::json!({
1647                        "role": "system",
1648                        "content": msg.content,
1649                    }));
1650                }
1651                Role::User => {
1652                    messages.push(serde_json::json!({
1653                        "role": "user",
1654                        "content": msg.content,
1655                    }));
1656                }
1657                Role::Assistant => {
1658                    let mut m = serde_json::json!({
1659                        "role": "assistant",
1660                        "content": msg.content,
1661                    });
1662                    if !msg.tool_calls.is_empty() {
1663                        let tc: Vec<serde_json::Value> = msg
1664                            .tool_calls
1665                            .iter()
1666                            .map(|tc| {
1667                                serde_json::json!({
1668                                    "function": {
1669                                        "name": tc.name,
1670                                        "arguments": tc.input,
1671                                    }
1672                                })
1673                            })
1674                            .collect();
1675                        m["tool_calls"] = serde_json::json!(tc);
1676                    }
1677                    messages.push(m);
1678                }
1679                Role::Tool => {
1680                    for tr in &msg.tool_results {
1681                        messages.push(serde_json::json!({
1682                            "role": "tool",
1683                            "content": tr.content,
1684                        }));
1685                    }
1686                }
1687            }
1688        }
1689
1690        let mut body = serde_json::json!({
1691            "model": request.model,
1692            "messages": messages,
1693            "stream": false,
1694        });
1695
1696        let mut options = serde_json::json!({});
1697        if let Some(temp) = request.temperature {
1698            options["temperature"] = serde_json::json!(temp);
1699        }
1700        if request.max_tokens > 0 {
1701            options["num_predict"] = serde_json::json!(request.max_tokens);
1702        }
1703        body["options"] = options;
1704
1705        // Disable thinking mode for reasoning models (Qwen, DeepSeek) to prevent
1706        // the model from spending its entire token budget on internal reasoning.
1707        // The think tags get stripped anyway, so we avoid wasting tokens.
1708        body["think"] = serde_json::json!(false);
1709
1710        if !request.tools.is_empty() {
1711            let tools: Vec<serde_json::Value> = request
1712                .tools
1713                .iter()
1714                .map(|t| {
1715                    serde_json::json!({
1716                        "type": "function",
1717                        "function": {
1718                            "name": t.name,
1719                            "description": t.description,
1720                            "parameters": t.input_schema,
1721                        }
1722                    })
1723                })
1724                .collect();
1725            body["tools"] = serde_json::json!(tools);
1726        }
1727
1728        body
1729    }
1730
1731    /// Parse the Ollama chat response.
1732    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
1733        let msg = &body["message"];
1734        let raw_content = msg["content"].as_str().unwrap_or("");
1735        // Strip thinking tags from reasoning models (Qwen, DeepSeek, etc.)
1736        let content = strip_thinking_tags(raw_content);
1737
1738        let mut tool_calls = Vec::new();
1739        if let Some(tc_array) = msg["tool_calls"].as_array() {
1740            for tc in tc_array {
1741                let name = tc["function"]["name"]
1742                    .as_str()
1743                    .unwrap_or_default()
1744                    .to_string();
1745                let input = tc["function"]["arguments"].clone();
1746                tool_calls.push(ToolCall {
1747                    id: format!("ollama-{}", uuid::Uuid::new_v4()),
1748                    name,
1749                    input,
1750                });
1751            }
1752        }
1753
1754        let stop_reason = if !tool_calls.is_empty() {
1755            StopReason::ToolUse
1756        } else if body["done"].as_bool().unwrap_or(true) {
1757            StopReason::EndTurn
1758        } else {
1759            StopReason::MaxTokens
1760        };
1761
1762        let usage = TokenUsage {
1763            input_tokens: body["prompt_eval_count"].as_u64().unwrap_or(0),
1764            output_tokens: body["eval_count"].as_u64().unwrap_or(0),
1765        };
1766
1767        let message = Message {
1768            role: Role::Assistant,
1769            content,
1770            tool_calls,
1771            tool_results: Vec::new(),
1772            timestamp: chrono::Utc::now(),
1773        };
1774
1775        Ok(CompletionResponse {
1776            message,
1777            usage,
1778            stop_reason,
1779        })
1780    }
1781}
1782
1783#[async_trait]
1784impl LlmDriver for OllamaDriver {
1785    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
1786        let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
1787        let body = self.build_request_body(&request);
1788
1789        let response = self
1790            .client
1791            .post(&url)
1792            .header("content-type", "application/json")
1793            .json(&body)
1794            .send()
1795            .await
1796            .map_err(|e| PunchError::Provider {
1797                provider: "ollama".to_string(),
1798                message: format!("request failed: {e}"),
1799            })?;
1800
1801        let status = response.status();
1802        let response_body: serde_json::Value =
1803            response.json().await.map_err(|e| PunchError::Provider {
1804                provider: "ollama".to_string(),
1805                message: format!("failed to parse response: {e}"),
1806            })?;
1807
1808        if !status.is_success() {
1809            let error_msg = response_body["error"].as_str().unwrap_or("unknown error");
1810            return Err(PunchError::Provider {
1811                provider: "ollama".to_string(),
1812                message: format!("API error ({}): {}", status, error_msg),
1813            });
1814        }
1815
1816        self.parse_response(&response_body)
1817    }
1818
1819    async fn stream_complete_with_callback(
1820        &self,
1821        request: CompletionRequest,
1822        callback: StreamCallback,
1823    ) -> PunchResult<CompletionResponse> {
1824        let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
1825        let mut body = self.build_request_body(&request);
1826        body["stream"] = serde_json::json!(true);
1827
1828        let response = self
1829            .client
1830            .post(&url)
1831            .header("content-type", "application/json")
1832            .json(&body)
1833            .send()
1834            .await
1835            .map_err(|e| PunchError::Provider {
1836                provider: "ollama".to_string(),
1837                message: format!("stream request failed: {e}"),
1838            })?;
1839
1840        let status = response.status();
1841        if !status.is_success() {
1842            let err_body: serde_json::Value =
1843                response.json().await.unwrap_or(serde_json::json!({}));
1844            let msg = err_body["error"].as_str().unwrap_or("unknown error");
1845            return Err(PunchError::Provider {
1846                provider: "ollama".to_string(),
1847                message: format!("API error ({}): {}", status, msg),
1848            });
1849        }
1850
1851        let raw = read_stream_body(response).await?;
1852        let assembled = self.parse_ollama_stream(&raw, &callback)?;
1853        Ok(assembled)
1854    }
1855}
1856
1857impl OllamaDriver {
1858    /// Parse Ollama's newline-delimited JSON stream into a `CompletionResponse`,
1859    /// invoking the callback for each chunk.
1860    pub fn parse_ollama_stream(
1861        &self,
1862        raw: &str,
1863        callback: &StreamCallback,
1864    ) -> PunchResult<CompletionResponse> {
1865        let mut text_content = String::new();
1866        let mut tool_calls: Vec<ToolCall> = Vec::new();
1867        let mut usage = TokenUsage::default();
1868        let mut done = false;
1869
1870        for line in raw.lines() {
1871            let line = line.trim();
1872            if line.is_empty() {
1873                continue;
1874            }
1875
1876            let parsed: serde_json::Value = match serde_json::from_str(line) {
1877                Ok(v) => v,
1878                Err(_) => continue,
1879            };
1880
1881            if parsed["done"].as_bool() == Some(true) {
1882                done = true;
1883                // Final chunk may include stats
1884                if let Some(inp) = parsed["prompt_eval_count"].as_u64() {
1885                    usage.input_tokens = inp;
1886                }
1887                if let Some(out) = parsed["eval_count"].as_u64() {
1888                    usage.output_tokens = out;
1889                }
1890                // Final chunk may also have tool calls
1891                if let Some(tc_array) = parsed["message"]["tool_calls"].as_array() {
1892                    for tc in tc_array {
1893                        let name = tc["function"]["name"]
1894                            .as_str()
1895                            .unwrap_or_default()
1896                            .to_string();
1897                        let input = tc["function"]["arguments"].clone();
1898                        tool_calls.push(ToolCall {
1899                            id: format!("ollama-{}", uuid::Uuid::new_v4()),
1900                            name,
1901                            input,
1902                        });
1903                    }
1904                }
1905                callback(StreamChunk {
1906                    delta: String::new(),
1907                    is_final: true,
1908                    tool_call_delta: None,
1909                });
1910                break;
1911            }
1912
1913            // Streaming chunk with content
1914            let content = parsed["message"]["content"].as_str().unwrap_or("");
1915            if !content.is_empty() {
1916                text_content.push_str(content);
1917                callback(StreamChunk {
1918                    delta: content.to_string(),
1919                    is_final: false,
1920                    tool_call_delta: None,
1921                });
1922            }
1923        }
1924
1925        let text_content = strip_thinking_tags(&text_content);
1926
1927        let stop_reason = if !tool_calls.is_empty() {
1928            StopReason::ToolUse
1929        } else if done {
1930            StopReason::EndTurn
1931        } else {
1932            StopReason::MaxTokens
1933        };
1934
1935        let message = Message {
1936            role: Role::Assistant,
1937            content: text_content,
1938            tool_calls,
1939            tool_results: Vec::new(),
1940            timestamp: chrono::Utc::now(),
1941        };
1942
1943        Ok(CompletionResponse {
1944            message,
1945            usage,
1946            stop_reason,
1947        })
1948    }
1949}
1950
1951// ---------------------------------------------------------------------------
1952// AWS Bedrock driver
1953// ---------------------------------------------------------------------------
1954
1955/// Driver for AWS Bedrock using the Converse API with SigV4 authentication.
1956pub struct BedrockDriver {
1957    client: Client,
1958    access_key: String,
1959    secret_key: String,
1960    region: String,
1961}
1962
1963impl BedrockDriver {
1964    /// Create a new Bedrock driver.
1965    pub fn new(access_key: String, secret_key: String, region: String) -> Self {
1966        Self {
1967            client: Client::new(),
1968            access_key,
1969            secret_key,
1970            region,
1971        }
1972    }
1973
1974    /// Create a new Bedrock driver with a shared HTTP client.
1975    pub fn with_client(
1976        client: Client,
1977        access_key: String,
1978        secret_key: String,
1979        region: String,
1980    ) -> Self {
1981        Self {
1982            client,
1983            access_key,
1984            secret_key,
1985            region,
1986        }
1987    }
1988
1989    /// Build the Bedrock Converse API request body.
1990    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
1991        let mut messages = Vec::new();
1992
1993        for msg in &request.messages {
1994            match msg.role {
1995                Role::User => {
1996                    messages.push(serde_json::json!({
1997                        "role": "user",
1998                        "content": [{"text": msg.content}],
1999                    }));
2000                }
2001                Role::Assistant => {
2002                    let mut content: Vec<serde_json::Value> = Vec::new();
2003                    if !msg.content.is_empty() {
2004                        content.push(serde_json::json!({"text": msg.content}));
2005                    }
2006                    for tc in &msg.tool_calls {
2007                        content.push(serde_json::json!({
2008                            "toolUse": {
2009                                "toolUseId": tc.id,
2010                                "name": tc.name,
2011                                "input": tc.input,
2012                            }
2013                        }));
2014                    }
2015                    if content.is_empty() {
2016                        content.push(serde_json::json!({"text": ""}));
2017                    }
2018                    messages.push(serde_json::json!({
2019                        "role": "assistant",
2020                        "content": content,
2021                    }));
2022                }
2023                Role::Tool => {
2024                    let mut content: Vec<serde_json::Value> = Vec::new();
2025                    for tr in &msg.tool_results {
2026                        content.push(serde_json::json!({
2027                            "toolResult": {
2028                                "toolUseId": tr.id,
2029                                "content": [{"text": tr.content}],
2030                                "status": if tr.is_error { "error" } else { "success" },
2031                            }
2032                        }));
2033                    }
2034                    messages.push(serde_json::json!({
2035                        "role": "user",
2036                        "content": content,
2037                    }));
2038                }
2039                Role::System => {
2040                    // System messages handled separately.
2041                }
2042            }
2043        }
2044
2045        let mut body = serde_json::json!({
2046            "messages": messages,
2047        });
2048
2049        let mut inference_config = serde_json::json!({
2050            "maxTokens": request.max_tokens,
2051        });
2052        if let Some(temp) = request.temperature {
2053            inference_config["temperature"] = serde_json::json!(temp);
2054        }
2055        body["inferenceConfig"] = inference_config;
2056
2057        if let Some(ref system) = request.system_prompt {
2058            body["system"] = serde_json::json!([{"text": system}]);
2059        }
2060
2061        if !request.tools.is_empty() {
2062            let tool_config: Vec<serde_json::Value> = request
2063                .tools
2064                .iter()
2065                .map(|t| {
2066                    serde_json::json!({
2067                        "toolSpec": {
2068                            "name": t.name,
2069                            "description": t.description,
2070                            "inputSchema": {"json": t.input_schema},
2071                        }
2072                    })
2073                })
2074                .collect();
2075            body["toolConfig"] = serde_json::json!({"tools": tool_config});
2076        }
2077
2078        body
2079    }
2080
2081    /// Build the endpoint URL for a model.
2082    pub fn build_url(&self, model_id: &str) -> String {
2083        format!(
2084            "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
2085            self.region, model_id,
2086        )
2087    }
2088
2089    /// Parse the Bedrock Converse API response.
2090    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
2091        let content = body["output"]["message"]["content"]
2092            .as_array()
2093            .cloned()
2094            .unwrap_or_default();
2095
2096        let mut text_content = String::new();
2097        let mut tool_calls = Vec::new();
2098
2099        for block in &content {
2100            if let Some(text) = block["text"].as_str() {
2101                if !text_content.is_empty() {
2102                    text_content.push('\n');
2103                }
2104                text_content.push_str(text);
2105            }
2106            if let Some(tu) = block.get("toolUse") {
2107                tool_calls.push(ToolCall {
2108                    id: tu["toolUseId"].as_str().unwrap_or_default().to_string(),
2109                    name: tu["name"].as_str().unwrap_or_default().to_string(),
2110                    input: tu["input"].clone(),
2111                });
2112            }
2113        }
2114
2115        let stop_reason_str = body["stopReason"].as_str().unwrap_or("end_turn");
2116        let stop_reason = if !tool_calls.is_empty() {
2117            StopReason::ToolUse
2118        } else {
2119            match stop_reason_str {
2120                "end_turn" => StopReason::EndTurn,
2121                "tool_use" => StopReason::ToolUse,
2122                "max_tokens" => StopReason::MaxTokens,
2123                _ => StopReason::EndTurn,
2124            }
2125        };
2126
2127        let usage = TokenUsage {
2128            input_tokens: body["usage"]["inputTokens"].as_u64().unwrap_or(0),
2129            output_tokens: body["usage"]["outputTokens"].as_u64().unwrap_or(0),
2130        };
2131
2132        // Strip thinking tags from reasoning models
2133        let text_content = strip_thinking_tags(&text_content);
2134
2135        let message = Message {
2136            role: Role::Assistant,
2137            content: text_content,
2138            tool_calls,
2139            tool_results: Vec::new(),
2140            timestamp: chrono::Utc::now(),
2141        };
2142
2143        Ok(CompletionResponse {
2144            message,
2145            usage,
2146            stop_reason,
2147        })
2148    }
2149
2150    /// Compute an AWS SigV4 signature and return the Authorization header value.
2151    ///
2152    /// This is a basic implementation sufficient for Bedrock API calls.
2153    pub fn sign_request(
2154        &self,
2155        method: &str,
2156        url: &str,
2157        headers: &[(String, String)],
2158        payload: &[u8],
2159        timestamp: &str, // format: "20260313T120000Z"
2160    ) -> PunchResult<String> {
2161        let date = &timestamp[..8]; // "20260313"
2162        let service = "bedrock";
2163
2164        // Parse the URL to get host and path.
2165        let parsed = url::Url::parse(url).map_err(|e| PunchError::Provider {
2166            provider: "bedrock".to_string(),
2167            message: format!("invalid URL: {e}"),
2168        })?;
2169        let host = parsed.host_str().unwrap_or("");
2170        let path = parsed.path();
2171
2172        // 1. Create canonical request.
2173        let payload_hash = hex_sha256(payload);
2174
2175        let mut signed_header_names: Vec<String> =
2176            headers.iter().map(|(k, _)| k.to_lowercase()).collect();
2177        signed_header_names.push("host".to_string());
2178        signed_header_names.push("x-amz-date".to_string());
2179        signed_header_names.sort();
2180        signed_header_names.dedup();
2181
2182        let mut header_map: Vec<(String, String)> = headers
2183            .iter()
2184            .map(|(k, v)| (k.to_lowercase(), v.trim().to_string()))
2185            .collect();
2186        header_map.push(("host".to_string(), host.to_string()));
2187        header_map.push(("x-amz-date".to_string(), timestamp.to_string()));
2188        header_map.sort_by(|a, b| a.0.cmp(&b.0));
2189        header_map.dedup_by(|a, b| a.0 == b.0);
2190
2191        let canonical_headers: String = header_map
2192            .iter()
2193            .map(|(k, v)| format!("{}:{}\n", k, v))
2194            .collect();
2195
2196        let signed_headers = signed_header_names.join(";");
2197
2198        let canonical_request = format!(
2199            "{}\n{}\n\n{}\n{}\n{}",
2200            method, path, canonical_headers, signed_headers, payload_hash,
2201        );
2202
2203        // 2. Create string to sign.
2204        let credential_scope = format!("{}/{}/{}/aws4_request", date, self.region, service);
2205        let string_to_sign = format!(
2206            "AWS4-HMAC-SHA256\n{}\n{}\n{}",
2207            timestamp,
2208            credential_scope,
2209            hex_sha256(canonical_request.as_bytes()),
2210        );
2211
2212        // 3. Calculate signing key.
2213        let k_date = hmac_sha256(
2214            format!("AWS4{}", self.secret_key).as_bytes(),
2215            date.as_bytes(),
2216        );
2217        let k_region = hmac_sha256(&k_date, self.region.as_bytes());
2218        let k_service = hmac_sha256(&k_region, service.as_bytes());
2219        let k_signing = hmac_sha256(&k_service, b"aws4_request");
2220
2221        // 4. Calculate signature.
2222        let signature = hex_encode(&hmac_sha256(&k_signing, string_to_sign.as_bytes()));
2223
2224        // 5. Build Authorization header.
2225        Ok(format!(
2226            "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
2227            self.access_key, credential_scope, signed_headers, signature,
2228        ))
2229    }
2230}
2231
2232/// Compute SHA-256 hex digest.
2233fn hex_sha256(data: &[u8]) -> String {
2234    let mut hasher = Sha256::new();
2235    hasher.update(data);
2236    hex_encode(hasher.finalize().as_slice())
2237}
2238
2239/// Compute HMAC-SHA256.
2240fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
2241    type HmacSha256 = Hmac<Sha256>;
2242    let mut mac = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size");
2243    mac.update(data);
2244    mac.finalize().into_bytes().to_vec()
2245}
2246
/// Hex-encode bytes (lowercase, two digits per byte) without an external crate.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
2251
2252#[async_trait]
2253impl LlmDriver for BedrockDriver {
2254    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
2255        let url = self.build_url(&request.model);
2256        let body = self.build_request_body(&request);
2257        let payload = serde_json::to_vec(&body).map_err(|e| PunchError::Provider {
2258            provider: "bedrock".to_string(),
2259            message: format!("failed to serialize request: {e}"),
2260        })?;
2261
2262        let timestamp = chrono::Utc::now().format("%Y%m%dT%H%M%SZ").to_string();
2263
2264        let auth_header = self.sign_request(
2265            "POST",
2266            &url,
2267            &[("content-type".to_string(), "application/json".to_string())],
2268            &payload,
2269            &timestamp,
2270        )?;
2271
2272        let parsed_url = url::Url::parse(&url).map_err(|e| PunchError::Provider {
2273            provider: "bedrock".to_string(),
2274            message: format!("invalid URL: {e}"),
2275        })?;
2276        let host = parsed_url.host_str().unwrap_or_default().to_string();
2277
2278        let response = self
2279            .client
2280            .post(&url)
2281            .header("content-type", "application/json")
2282            .header("host", &host)
2283            .header("x-amz-date", &timestamp)
2284            .header("authorization", &auth_header)
2285            .body(payload)
2286            .send()
2287            .await
2288            .map_err(|e| PunchError::Provider {
2289                provider: "bedrock".to_string(),
2290                message: format!("request failed: {e}"),
2291            })?;
2292
2293        let status = response.status();
2294
2295        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
2296            return Err(PunchError::RateLimited {
2297                provider: "bedrock".to_string(),
2298                retry_after_ms: 60_000,
2299            });
2300        }
2301
2302        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
2303            return Err(PunchError::Auth(
2304                "AWS Bedrock credentials are invalid or lack permissions".to_string(),
2305            ));
2306        }
2307
2308        let response_body: serde_json::Value =
2309            response.json().await.map_err(|e| PunchError::Provider {
2310                provider: "bedrock".to_string(),
2311                message: format!("failed to parse response: {e}"),
2312            })?;
2313
2314        if !status.is_success() {
2315            let error_msg = response_body["message"].as_str().unwrap_or("unknown error");
2316            return Err(PunchError::Provider {
2317                provider: "bedrock".to_string(),
2318                message: format!("API error ({}): {}", status, error_msg),
2319            });
2320        }
2321
2322        self.parse_response(&response_body)
2323    }
2324
2325    async fn stream_complete_with_callback(
2326        &self,
2327        request: CompletionRequest,
2328        callback: StreamCallback,
2329    ) -> PunchResult<CompletionResponse> {
2330        // Bedrock uses a proprietary binary event stream format for streaming.
2331        // Fall back to non-streaming and emit the result as a single final chunk.
2332        let response = self.complete(request).await?;
2333        callback(StreamChunk {
2334            delta: response.message.content.clone(),
2335            is_final: true,
2336            tool_call_delta: None,
2337        });
2338        Ok(response)
2339    }
2340}
2341
2342// ---------------------------------------------------------------------------
2343// Azure OpenAI driver
2344// ---------------------------------------------------------------------------
2345
2346/// Driver for Azure OpenAI deployments.
2347///
2348/// Uses the same request/response format as OpenAI but with Azure-specific
2349/// URL construction and API key header.
2350pub struct AzureOpenAiDriver {
2351    inner: OpenAiCompatibleDriver,
2352    resource: String,
2353    deployment: String,
2354    api_version: String,
2355}
2356
2357impl AzureOpenAiDriver {
2358    /// Create a new Azure OpenAI driver.
2359    ///
2360    /// - `api_key`: The Azure OpenAI API key.
2361    /// - `resource`: The Azure resource name (subdomain).
2362    /// - `deployment`: The deployment name.
2363    /// - `api_version`: API version string (e.g., "2024-02-01").
2364    pub fn new(
2365        api_key: String,
2366        resource: String,
2367        deployment: String,
2368        api_version: Option<String>,
2369    ) -> Self {
2370        let base_url = format!("https://{}.openai.azure.com", resource);
2371        Self {
2372            inner: OpenAiCompatibleDriver::new(api_key, base_url, "azure_openai".to_string()),
2373            resource,
2374            deployment,
2375            api_version: api_version.unwrap_or_else(|| "2024-02-01".to_string()),
2376        }
2377    }
2378
2379    /// Create a new Azure OpenAI driver with a shared HTTP client.
2380    pub fn with_client(
2381        client: Client,
2382        api_key: String,
2383        resource: String,
2384        deployment: String,
2385        api_version: Option<String>,
2386    ) -> Self {
2387        let base_url = format!("https://{}.openai.azure.com", resource);
2388        Self {
2389            inner: OpenAiCompatibleDriver::with_client(
2390                client,
2391                api_key,
2392                base_url,
2393                "azure_openai".to_string(),
2394            ),
2395            resource,
2396            deployment,
2397            api_version: api_version.unwrap_or_else(|| "2024-02-01".to_string()),
2398        }
2399    }
2400
2401    /// Build the Azure OpenAI endpoint URL.
2402    pub fn build_url(&self) -> String {
2403        format!(
2404            "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
2405            self.resource, self.deployment, self.api_version,
2406        )
2407    }
2408
2409    /// Get the resource name.
2410    pub fn resource(&self) -> &str {
2411        &self.resource
2412    }
2413
2414    /// Get the deployment name.
2415    pub fn deployment(&self) -> &str {
2416        &self.deployment
2417    }
2418
2419    /// Build request body (delegates to inner OpenAI-compatible driver).
2420    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
2421        self.inner.build_request_body(request)
2422    }
2423
2424    /// Parse response (delegates to inner OpenAI-compatible driver).
2425    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
2426        self.inner.parse_response(body)
2427    }
2428}
2429
2430#[async_trait]
2431impl LlmDriver for AzureOpenAiDriver {
2432    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
2433        let url = self.build_url();
2434        let body = self.inner.build_request_body(&request);
2435
2436        let response = self
2437            .inner
2438            .client
2439            .post(&url)
2440            .header("api-key", &self.inner.api_key)
2441            .header("content-type", "application/json")
2442            .json(&body)
2443            .send()
2444            .await
2445            .map_err(|e| PunchError::Provider {
2446                provider: "azure_openai".to_string(),
2447                message: format!("request failed: {e}"),
2448            })?;
2449
2450        let status = response.status();
2451
2452        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
2453            let retry_after = response
2454                .headers()
2455                .get("retry-after")
2456                .and_then(|v| v.to_str().ok())
2457                .and_then(|s| s.parse::<u64>().ok())
2458                .unwrap_or(60)
2459                * 1000;
2460
2461            return Err(PunchError::RateLimited {
2462                provider: "azure_openai".to_string(),
2463                retry_after_ms: retry_after,
2464            });
2465        }
2466
2467        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
2468            return Err(PunchError::Auth(
2469                "Azure OpenAI API key is invalid or lacks permissions".to_string(),
2470            ));
2471        }
2472
2473        let response_body: serde_json::Value =
2474            response.json().await.map_err(|e| PunchError::Provider {
2475                provider: "azure_openai".to_string(),
2476                message: format!("failed to parse response: {e}"),
2477            })?;
2478
2479        if !status.is_success() {
2480            let error_msg = response_body["error"]["message"]
2481                .as_str()
2482                .unwrap_or("unknown error");
2483            return Err(PunchError::Provider {
2484                provider: "azure_openai".to_string(),
2485                message: format!("API error ({}): {}", status, error_msg),
2486            });
2487        }
2488
2489        self.inner.parse_response(&response_body)
2490    }
2491
2492    async fn stream_complete_with_callback(
2493        &self,
2494        request: CompletionRequest,
2495        callback: StreamCallback,
2496    ) -> PunchResult<CompletionResponse> {
2497        let url = self.build_url();
2498        let mut body = self.inner.build_request_body(&request);
2499        body["stream"] = serde_json::json!(true);
2500
2501        let response = self
2502            .inner
2503            .client
2504            .post(&url)
2505            .header("api-key", &self.inner.api_key)
2506            .header("content-type", "application/json")
2507            .json(&body)
2508            .send()
2509            .await
2510            .map_err(|e| PunchError::Provider {
2511                provider: "azure_openai".to_string(),
2512                message: format!("stream request failed: {e}"),
2513            })?;
2514
2515        let status = response.status();
2516        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
2517            return Err(PunchError::RateLimited {
2518                provider: "azure_openai".to_string(),
2519                retry_after_ms: 60_000,
2520            });
2521        }
2522        if !status.is_success() {
2523            let err_body: serde_json::Value =
2524                response.json().await.unwrap_or(serde_json::json!({}));
2525            let msg = err_body["error"]["message"]
2526                .as_str()
2527                .unwrap_or("unknown error");
2528            return Err(PunchError::Provider {
2529                provider: "azure_openai".to_string(),
2530                message: format!("API error ({}): {}", status, msg),
2531            });
2532        }
2533
2534        let raw = read_stream_body(response).await?;
2535        // Azure OpenAI uses the same SSE format as OpenAI
2536        let assembled = self.inner.parse_openai_stream(&raw, &callback)?;
2537        Ok(assembled)
2538    }
2539}
2540
2541// ---------------------------------------------------------------------------
2542// Factory
2543// ---------------------------------------------------------------------------
2544
2545/// Default base URLs for known providers.
2546fn default_base_url(provider: &Provider) -> &'static str {
2547    match provider {
2548        Provider::Anthropic => "https://api.anthropic.com",
2549        Provider::OpenAI => "https://api.openai.com",
2550        Provider::Google => "https://generativelanguage.googleapis.com",
2551        Provider::Groq => "https://api.groq.com/openai",
2552        Provider::DeepSeek => "https://api.deepseek.com",
2553        Provider::Ollama => "http://localhost:11434",
2554        Provider::Mistral => "https://api.mistral.ai",
2555        Provider::Together => "https://api.together.xyz",
2556        Provider::Fireworks => "https://api.fireworks.ai/inference",
2557        Provider::Cerebras => "https://api.cerebras.ai",
2558        Provider::XAI => "https://api.x.ai",
2559        Provider::Cohere => "https://api.cohere.ai",
2560        Provider::Bedrock => "https://bedrock-runtime.us-east-1.amazonaws.com",
2561        Provider::AzureOpenAi => "",
2562        Provider::Custom(_) => "",
2563    }
2564}
2565
2566/// Create an [`LlmDriver`] from a [`ModelConfig`].
2567///
2568/// Reads the API key from the environment variable specified in
2569/// `config.api_key_env`. Returns an error if the env var is missing
2570/// (except for Ollama which does not require auth).
2571/// Create a driver from config, optionally using a shared HTTP client.
2572///
2573/// If `shared_client` is `Some`, the driver will use that client for
2574/// connection pooling. Otherwise it creates its own client (backward compat).
2575pub fn create_driver(config: &ModelConfig) -> PunchResult<Arc<dyn LlmDriver>> {
2576    create_driver_with_client(config, None)
2577}
2578
2579/// Create a driver from config with an optional shared [`reqwest::Client`].
2580pub fn create_driver_with_client(
2581    config: &ModelConfig,
2582    shared_client: Option<&Client>,
2583) -> PunchResult<Arc<dyn LlmDriver>> {
2584    let api_key = match &config.api_key_env {
2585        Some(env_var) => std::env::var(env_var).map_err(|_| {
2586            PunchError::Auth(format!(
2587                "environment variable '{}' not set for {} driver",
2588                env_var, config.provider
2589            ))
2590        })?,
2591        None => {
2592            // Ollama typically has no auth; others will fail at the API.
2593            String::new()
2594        }
2595    };
2596
2597    let base_url = config
2598        .base_url
2599        .clone()
2600        .unwrap_or_else(|| default_base_url(&config.provider).to_string());
2601
2602    match &config.provider {
2603        Provider::Anthropic => {
2604            if let Some(client) = shared_client {
2605                Ok(Arc::new(AnthropicDriver::with_client(
2606                    client.clone(),
2607                    api_key,
2608                    Some(base_url),
2609                )))
2610            } else {
2611                Ok(Arc::new(AnthropicDriver::new(api_key, Some(base_url))))
2612            }
2613        }
2614        Provider::Google => {
2615            if let Some(client) = shared_client {
2616                Ok(Arc::new(GeminiDriver::with_client(
2617                    client.clone(),
2618                    api_key,
2619                    Some(base_url),
2620                )))
2621            } else {
2622                Ok(Arc::new(GeminiDriver::new(api_key, Some(base_url))))
2623            }
2624        }
2625        Provider::Ollama => {
2626            if let Some(client) = shared_client {
2627                Ok(Arc::new(OllamaDriver::with_client(
2628                    client.clone(),
2629                    Some(base_url),
2630                )))
2631            } else {
2632                Ok(Arc::new(OllamaDriver::new(Some(base_url))))
2633            }
2634        }
2635        Provider::Bedrock => {
2636            // For Bedrock, api_key is expected to be "ACCESS_KEY:SECRET_KEY" or
2637            // we read AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from env.
2638            let (access_key, secret_key) = if api_key.contains(':') {
2639                let parts: Vec<&str> = api_key.splitn(2, ':').collect();
2640                (parts[0].to_string(), parts[1].to_string())
2641            } else {
2642                let ak = std::env::var("AWS_ACCESS_KEY_ID").unwrap_or(api_key);
2643                let sk = std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_default();
2644                (ak, sk)
2645            };
2646            // Extract region from base_url or default to us-east-1.
2647            let region = if base_url.contains("bedrock-runtime.") {
2648                base_url
2649                    .trim_start_matches("https://bedrock-runtime.")
2650                    .split('.')
2651                    .next()
2652                    .unwrap_or("us-east-1")
2653                    .to_string()
2654            } else {
2655                "us-east-1".to_string()
2656            };
2657            if let Some(client) = shared_client {
2658                Ok(Arc::new(BedrockDriver::with_client(
2659                    client.clone(),
2660                    access_key,
2661                    secret_key,
2662                    region,
2663                )))
2664            } else {
2665                Ok(Arc::new(BedrockDriver::new(access_key, secret_key, region)))
2666            }
2667        }
2668        Provider::AzureOpenAi => {
2669            // For Azure, base_url should be "https://{resource}.openai.azure.com"
2670            // and model is the deployment name.
2671            let resource = if base_url.contains(".openai.azure.com") {
2672                base_url
2673                    .trim_start_matches("https://")
2674                    .split('.')
2675                    .next()
2676                    .unwrap_or("default")
2677                    .to_string()
2678            } else {
2679                base_url.clone()
2680            };
2681            let deployment = config.model.clone();
2682            if let Some(client) = shared_client {
2683                Ok(Arc::new(AzureOpenAiDriver::with_client(
2684                    client.clone(),
2685                    api_key,
2686                    resource,
2687                    deployment,
2688                    None,
2689                )))
2690            } else {
2691                Ok(Arc::new(AzureOpenAiDriver::new(
2692                    api_key, resource, deployment, None,
2693                )))
2694            }
2695        }
2696        provider => {
2697            let name = provider.to_string();
2698            if let Some(client) = shared_client {
2699                Ok(Arc::new(OpenAiCompatibleDriver::with_client(
2700                    client.clone(),
2701                    api_key,
2702                    base_url,
2703                    name,
2704                )))
2705            } else {
2706                Ok(Arc::new(OpenAiCompatibleDriver::new(
2707                    api_key, base_url, name,
2708                )))
2709            }
2710        }
2711    }
2712}
2713
2714// ---------------------------------------------------------------------------
2715// Tests
2716// ---------------------------------------------------------------------------
2717
2718#[cfg(test)]
2719mod tests {
2720    use super::*;
2721    use punch_types::ToolCategory;
2722
2723    /// Helper to build a simple completion request for testing.
2724    fn simple_request() -> CompletionRequest {
2725        CompletionRequest {
2726            model: "test-model".to_string(),
2727            messages: vec![Message::new(Role::User, "Hello")],
2728            tools: Vec::new(),
2729            max_tokens: 4096,
2730            temperature: Some(0.7),
2731            system_prompt: Some("You are helpful.".to_string()),
2732        }
2733    }
2734
2735    /// Helper to build a request with tools.
2736    fn request_with_tools() -> CompletionRequest {
2737        CompletionRequest {
2738            model: "test-model".to_string(),
2739            messages: vec![Message::new(Role::User, "Use the tool")],
2740            tools: vec![ToolDefinition {
2741                name: "get_weather".to_string(),
2742                description: "Get weather for a city".to_string(),
2743                input_schema: serde_json::json!({
2744                    "type": "object",
2745                    "properties": {
2746                        "city": {"type": "string"}
2747                    }
2748                }),
2749                category: ToolCategory::Web,
2750            }],
2751            max_tokens: 4096,
2752            temperature: Some(0.7),
2753            system_prompt: None,
2754        }
2755    }
2756
2757    // -----------------------------------------------------------------------
2758    // Gemini tests
2759    // -----------------------------------------------------------------------
2760
2761    #[test]
2762    fn gemini_request_formatting() {
2763        let driver = GeminiDriver::new("test-key".to_string(), None);
2764        let body = driver.build_request_body(&simple_request());
2765
2766        let contents = body["contents"].as_array().unwrap();
2767        assert_eq!(contents.len(), 1);
2768        // System prompt should be prepended to user message.
2769        let first_text = contents[0]["parts"][0]["text"].as_str().unwrap();
2770        assert!(first_text.contains("You are helpful."));
2771        assert!(first_text.contains("Hello"));
2772        // Role should be "user" (not "system").
2773        assert_eq!(contents[0]["role"].as_str().unwrap(), "user");
2774
2775        assert_eq!(body["generationConfig"]["maxOutputTokens"], 4096);
2776        assert!((body["generationConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
2777    }
2778
2779    #[test]
2780    fn gemini_response_parsing() {
2781        let driver = GeminiDriver::new("test-key".to_string(), None);
2782        let response_body = serde_json::json!({
2783            "candidates": [{
2784                "content": {
2785                    "parts": [{"text": "Hello there!"}],
2786                    "role": "model"
2787                },
2788                "finishReason": "STOP"
2789            }],
2790            "usageMetadata": {
2791                "promptTokenCount": 10,
2792                "candidatesTokenCount": 5
2793            }
2794        });
2795
2796        let resp = driver.parse_response(&response_body).unwrap();
2797        assert_eq!(resp.message.content, "Hello there!");
2798        assert_eq!(resp.stop_reason, StopReason::EndTurn);
2799        assert_eq!(resp.usage.input_tokens, 10);
2800        assert_eq!(resp.usage.output_tokens, 5);
2801    }
2802
2803    #[test]
2804    fn gemini_role_mapping_system_prepended() {
2805        let driver = GeminiDriver::new("test-key".to_string(), None);
2806        let req = CompletionRequest {
2807            model: "gemini-pro".to_string(),
2808            messages: vec![
2809                Message::new(Role::System, "Be concise."),
2810                Message::new(Role::User, "Hi"),
2811            ],
2812            tools: Vec::new(),
2813            max_tokens: 1024,
2814            temperature: None,
2815            system_prompt: None,
2816        };
2817        let body = driver.build_request_body(&req);
2818        let contents = body["contents"].as_array().unwrap();
2819        // System message should be merged into user message.
2820        assert_eq!(contents.len(), 1);
2821        let text = contents[0]["parts"][0]["text"].as_str().unwrap();
2822        assert!(text.contains("Be concise."));
2823        assert!(text.contains("Hi"));
2824    }
2825
2826    #[test]
2827    fn gemini_function_call_parsing() {
2828        let driver = GeminiDriver::new("test-key".to_string(), None);
2829        let response_body = serde_json::json!({
2830            "candidates": [{
2831                "content": {
2832                    "parts": [
2833                        {"text": "Let me check the weather."},
2834                        {
2835                            "functionCall": {
2836                                "name": "get_weather",
2837                                "args": {"city": "London"}
2838                            }
2839                        }
2840                    ],
2841                    "role": "model"
2842                },
2843                "finishReason": "STOP"
2844            }],
2845            "usageMetadata": {
2846                "promptTokenCount": 15,
2847                "candidatesTokenCount": 8
2848            }
2849        });
2850
2851        let resp = driver.parse_response(&response_body).unwrap();
2852        assert_eq!(resp.message.content, "Let me check the weather.");
2853        assert_eq!(resp.stop_reason, StopReason::ToolUse);
2854        assert_eq!(resp.message.tool_calls.len(), 1);
2855        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
2856        assert_eq!(resp.message.tool_calls[0].input["city"], "London");
2857    }
2858
2859    #[test]
2860    fn gemini_api_key_in_url() {
2861        let driver = GeminiDriver::new("my-secret-key".to_string(), None);
2862        let url = driver.build_url("gemini-pro");
2863        assert!(url.contains("key=my-secret-key"));
2864        assert!(url.contains("models/gemini-pro:generateContent"));
2865    }
2866
2867    // -----------------------------------------------------------------------
2868    // Ollama tests
2869    // -----------------------------------------------------------------------
2870
2871    #[test]
2872    fn ollama_request_formatting() {
2873        let driver = OllamaDriver::new(None);
2874        let body = driver.build_request_body(&simple_request());
2875
2876        assert_eq!(body["model"], "test-model");
2877        assert_eq!(body["stream"], false);
2878        let messages = body["messages"].as_array().unwrap();
2879        // system prompt + user message = 2 messages
2880        assert_eq!(messages.len(), 2);
2881        assert_eq!(messages[0]["role"], "system");
2882        assert_eq!(messages[0]["content"], "You are helpful.");
2883        assert_eq!(messages[1]["role"], "user");
2884        assert_eq!(messages[1]["content"], "Hello");
2885        assert!((body["options"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
2886    }
2887
2888    #[test]
2889    fn ollama_response_parsing() {
2890        let driver = OllamaDriver::new(None);
2891        let response_body = serde_json::json!({
2892            "message": {
2893                "role": "assistant",
2894                "content": "Hi there!"
2895            },
2896            "done": true,
2897            "prompt_eval_count": 20,
2898            "eval_count": 10
2899        });
2900
2901        let resp = driver.parse_response(&response_body).unwrap();
2902        assert_eq!(resp.message.content, "Hi there!");
2903        assert_eq!(resp.stop_reason, StopReason::EndTurn);
2904        assert_eq!(resp.usage.input_tokens, 20);
2905        assert_eq!(resp.usage.output_tokens, 10);
2906    }
2907
2908    #[test]
2909    fn ollama_default_endpoint() {
2910        let driver = OllamaDriver::new(None);
2911        assert_eq!(driver.base_url(), "http://localhost:11434");
2912    }
2913
2914    #[test]
2915    fn ollama_custom_endpoint() {
2916        let driver = OllamaDriver::new(Some("http://myhost:9999".to_string()));
2917        assert_eq!(driver.base_url(), "http://myhost:9999");
2918    }
2919
2920    // -----------------------------------------------------------------------
2921    // Bedrock tests
2922    // -----------------------------------------------------------------------
2923
2924    #[test]
2925    fn bedrock_request_formatting() {
2926        let driver = BedrockDriver::new(
2927            "TESTKEY".to_string(),
2928            "testsecret".to_string(),
2929            "us-west-2".to_string(),
2930        );
2931        let body = driver.build_request_body(&simple_request());
2932
2933        let messages = body["messages"].as_array().unwrap();
2934        assert_eq!(messages.len(), 1);
2935        assert_eq!(messages[0]["role"], "user");
2936        assert_eq!(messages[0]["content"][0]["text"], "Hello");
2937
2938        assert_eq!(body["inferenceConfig"]["maxTokens"], 4096);
2939        assert!((body["inferenceConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
2940        assert_eq!(body["system"][0]["text"], "You are helpful.");
2941    }
2942
2943    #[test]
2944    fn bedrock_sigv4_canonical_request() {
2945        let driver = BedrockDriver::new(
2946            "TESTACCESS1234567890".to_string(),
2947            "TestSecretKeyValue1234567890abcdefghijk".to_string(),
2948            "us-east-1".to_string(),
2949        );
2950
2951        let payload = b"{}";
2952        let timestamp = "20260313T120000Z";
2953
2954        let auth = driver
2955            .sign_request(
2956                "POST",
2957                "https://bedrock-runtime.us-east-1.amazonaws.com/model/test/converse",
2958                &[("content-type".to_string(), "application/json".to_string())],
2959                payload,
2960                timestamp,
2961            )
2962            .unwrap();
2963
2964        assert!(auth.starts_with(
2965            "AWS4-HMAC-SHA256 Credential=TESTACCESS1234567890/20260313/us-east-1/bedrock/aws4_request"
2966        ));
2967        assert!(auth.contains("SignedHeaders=content-type;host;x-amz-date"));
2968        assert!(auth.contains("Signature="));
2969    }
2970
2971    #[test]
2972    fn bedrock_response_parsing() {
2973        let driver = BedrockDriver::new(
2974            "key".to_string(),
2975            "secret".to_string(),
2976            "us-east-1".to_string(),
2977        );
2978        let response_body = serde_json::json!({
2979            "output": {
2980                "message": {
2981                    "role": "assistant",
2982                    "content": [{"text": "The answer is 42."}]
2983                }
2984            },
2985            "stopReason": "end_turn",
2986            "usage": {
2987                "inputTokens": 100,
2988                "outputTokens": 50
2989            }
2990        });
2991
2992        let resp = driver.parse_response(&response_body).unwrap();
2993        assert_eq!(resp.message.content, "The answer is 42.");
2994        assert_eq!(resp.stop_reason, StopReason::EndTurn);
2995        assert_eq!(resp.usage.input_tokens, 100);
2996        assert_eq!(resp.usage.output_tokens, 50);
2997    }
2998
2999    // -----------------------------------------------------------------------
3000    // Azure OpenAI tests
3001    // -----------------------------------------------------------------------
3002
3003    #[test]
3004    fn azure_openai_url_construction() {
3005        let driver = AzureOpenAiDriver::new(
3006            "my-azure-key".to_string(),
3007            "myresource".to_string(),
3008            "gpt-4-deployment".to_string(),
3009            None,
3010        );
3011        let url = driver.build_url();
3012        assert_eq!(
3013            url,
3014            "https://myresource.openai.azure.com/openai/deployments/gpt-4-deployment/chat/completions?api-version=2024-02-01"
3015        );
3016    }
3017
3018    #[test]
3019    fn azure_openai_custom_api_version() {
3020        let driver = AzureOpenAiDriver::new(
3021            "key".to_string(),
3022            "res".to_string(),
3023            "dep".to_string(),
3024            Some("2024-06-01".to_string()),
3025        );
3026        let url = driver.build_url();
3027        assert!(url.contains("api-version=2024-06-01"));
3028    }
3029
3030    #[test]
3031    fn azure_openai_request_formatting() {
3032        let driver = AzureOpenAiDriver::new(
3033            "key".to_string(),
3034            "res".to_string(),
3035            "dep".to_string(),
3036            None,
3037        );
3038        let body = driver.build_request_body(&simple_request());
3039        // Should use OpenAI format.
3040        let messages = body["messages"].as_array().unwrap();
3041        // system prompt + user message = 2
3042        assert_eq!(messages.len(), 2);
3043        assert_eq!(messages[0]["role"], "system");
3044        assert_eq!(messages[1]["role"], "user");
3045        assert_eq!(body["model"], "test-model");
3046    }
3047
3048    #[test]
3049    fn azure_openai_resource_and_deployment() {
3050        let driver = AzureOpenAiDriver::new(
3051            "key".to_string(),
3052            "my-resource".to_string(),
3053            "my-deploy".to_string(),
3054            None,
3055        );
3056        assert_eq!(driver.resource(), "my-resource");
3057        assert_eq!(driver.deployment(), "my-deploy");
3058    }
3059
3060    // -----------------------------------------------------------------------
3061    // create_driver dispatch tests
3062    // -----------------------------------------------------------------------
3063
3064    #[test]
3065    fn create_driver_dispatches_ollama() {
3066        let config = ModelConfig {
3067            provider: Provider::Ollama,
3068            model: "llama3".to_string(),
3069            api_key_env: None,
3070            base_url: None,
3071            max_tokens: None,
3072            temperature: None,
3073        };
3074        // Ollama does not need an API key, so this should succeed.
3075        let driver = create_driver(&config);
3076        assert!(driver.is_ok());
3077    }
3078
3079    #[test]
3080    fn create_driver_dispatches_gemini() {
3081        // Set a fake env var for this test.
3082        // SAFETY: Test is single-threaded relative to this env var name.
3083        unsafe { std::env::set_var("TEST_GEMINI_KEY_DISPATCH", "fake-key") };
3084        let config = ModelConfig {
3085            provider: Provider::Google,
3086            model: "gemini-pro".to_string(),
3087            api_key_env: Some("TEST_GEMINI_KEY_DISPATCH".to_string()),
3088            base_url: None,
3089            max_tokens: None,
3090            temperature: None,
3091        };
3092        let driver = create_driver(&config);
3093        assert!(driver.is_ok());
3094        unsafe { std::env::remove_var("TEST_GEMINI_KEY_DISPATCH") };
3095    }
3096
3097    #[test]
3098    fn create_driver_dispatches_bedrock() {
3099        // SAFETY: Test is single-threaded relative to this env var name.
3100        unsafe { std::env::set_var("TEST_BEDROCK_KEY_DISPATCH", "TESTKEY:TESTSECRET") };
3101        let config = ModelConfig {
3102            provider: Provider::Bedrock,
3103            model: "anthropic.claude-v2".to_string(),
3104            api_key_env: Some("TEST_BEDROCK_KEY_DISPATCH".to_string()),
3105            base_url: None,
3106            max_tokens: None,
3107            temperature: None,
3108        };
3109        let driver = create_driver(&config);
3110        assert!(driver.is_ok());
3111        unsafe { std::env::remove_var("TEST_BEDROCK_KEY_DISPATCH") };
3112    }
3113
3114    #[test]
3115    fn create_driver_dispatches_azure_openai() {
3116        // SAFETY: Test is single-threaded relative to this env var name.
3117        unsafe { std::env::set_var("TEST_AZURE_KEY_DISPATCH", "azure-key") };
3118        let config = ModelConfig {
3119            provider: Provider::AzureOpenAi,
3120            model: "gpt-4".to_string(),
3121            api_key_env: Some("TEST_AZURE_KEY_DISPATCH".to_string()),
3122            base_url: Some("https://myres.openai.azure.com".to_string()),
3123            max_tokens: None,
3124            temperature: None,
3125        };
3126        let driver = create_driver(&config);
3127        assert!(driver.is_ok());
3128        unsafe { std::env::remove_var("TEST_AZURE_KEY_DISPATCH") };
3129    }
3130
3131    #[test]
3132    fn gemini_tools_in_request() {
3133        let driver = GeminiDriver::new("key".to_string(), None);
3134        let body = driver.build_request_body(&request_with_tools());
3135
3136        let tools = body["tools"].as_array().unwrap();
3137        assert_eq!(tools.len(), 1);
3138        let func_decls = tools[0]["function_declarations"].as_array().unwrap();
3139        assert_eq!(func_decls.len(), 1);
3140        assert_eq!(func_decls[0]["name"], "get_weather");
3141    }
3142
3143    #[test]
3144    fn ollama_tools_in_request() {
3145        let driver = OllamaDriver::new(None);
3146        let body = driver.build_request_body(&request_with_tools());
3147
3148        let tools = body["tools"].as_array().unwrap();
3149        assert_eq!(tools.len(), 1);
3150        assert_eq!(tools[0]["type"], "function");
3151        assert_eq!(tools[0]["function"]["name"], "get_weather");
3152    }
3153
3154    #[test]
3155    fn bedrock_url_construction() {
3156        let driver = BedrockDriver::new(
3157            "key".to_string(),
3158            "secret".to_string(),
3159            "eu-west-1".to_string(),
3160        );
3161        let url = driver.build_url("anthropic.claude-3-sonnet");
3162        assert_eq!(
3163            url,
3164            "https://bedrock-runtime.eu-west-1.amazonaws.com/model/anthropic.claude-3-sonnet/converse"
3165        );
3166    }
3167
3168    // -----------------------------------------------------------------------
3169    // TokenUsage tests
3170    // -----------------------------------------------------------------------
3171
3172    #[test]
3173    fn token_usage_default() {
3174        let u = TokenUsage::default();
3175        assert_eq!(u.input_tokens, 0);
3176        assert_eq!(u.output_tokens, 0);
3177        assert_eq!(u.total(), 0);
3178    }
3179
3180    #[test]
3181    fn token_usage_accumulate() {
3182        let mut u = TokenUsage {
3183            input_tokens: 10,
3184            output_tokens: 20,
3185        };
3186        let other = TokenUsage {
3187            input_tokens: 5,
3188            output_tokens: 15,
3189        };
3190        u.accumulate(&other);
3191        assert_eq!(u.input_tokens, 15);
3192        assert_eq!(u.output_tokens, 35);
3193        assert_eq!(u.total(), 50);
3194    }
3195
3196    #[test]
3197    fn token_usage_total() {
3198        let u = TokenUsage {
3199            input_tokens: 100,
3200            output_tokens: 200,
3201        };
3202        assert_eq!(u.total(), 300);
3203    }
3204
3205    // -----------------------------------------------------------------------
3206    // StopReason serialization
3207    // -----------------------------------------------------------------------
3208
3209    #[test]
3210    fn stop_reason_serialization() {
3211        let json = serde_json::to_string(&StopReason::EndTurn).unwrap();
3212        assert_eq!(json, "\"end_turn\"");
3213
3214        let json = serde_json::to_string(&StopReason::ToolUse).unwrap();
3215        assert_eq!(json, "\"tool_use\"");
3216
3217        let json = serde_json::to_string(&StopReason::MaxTokens).unwrap();
3218        assert_eq!(json, "\"max_tokens\"");
3219
3220        let json = serde_json::to_string(&StopReason::Error).unwrap();
3221        assert_eq!(json, "\"error\"");
3222    }
3223
3224    #[test]
3225    fn stop_reason_deserialization() {
3226        let sr: StopReason = serde_json::from_str("\"end_turn\"").unwrap();
3227        assert_eq!(sr, StopReason::EndTurn);
3228
3229        let sr: StopReason = serde_json::from_str("\"tool_use\"").unwrap();
3230        assert_eq!(sr, StopReason::ToolUse);
3231    }
3232
3233    // -----------------------------------------------------------------------
3234    // Anthropic driver tests
3235    // -----------------------------------------------------------------------
3236
3237    #[test]
3238    fn anthropic_request_body_simple() {
3239        let driver = AnthropicDriver::new("test-key".to_string(), None);
3240        let body = driver.build_request_body(&simple_request());
3241
3242        assert_eq!(body["model"], "test-model");
3243        assert_eq!(body["max_tokens"], 4096);
3244        assert_eq!(body["system"], "You are helpful.");
3245        assert!((body["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
3246
3247        let messages = body["messages"].as_array().unwrap();
3248        assert_eq!(messages.len(), 1);
3249        assert_eq!(messages[0]["role"], "user");
3250        assert_eq!(messages[0]["content"], "Hello");
3251    }
3252
3253    #[test]
3254    fn anthropic_request_body_with_tools() {
3255        let driver = AnthropicDriver::new("test-key".to_string(), None);
3256        let body = driver.build_request_body(&request_with_tools());
3257
3258        let tools = body["tools"].as_array().unwrap();
3259        assert_eq!(tools.len(), 1);
3260        assert_eq!(tools[0]["name"], "get_weather");
3261        assert!(tools[0]["input_schema"]["properties"].is_object());
3262    }
3263
3264    #[test]
3265    fn anthropic_request_body_no_system_prompt() {
3266        let driver = AnthropicDriver::new("test-key".to_string(), None);
3267        let req = CompletionRequest {
3268            model: "test".into(),
3269            messages: vec![Message::new(Role::User, "Hi")],
3270            tools: Vec::new(),
3271            max_tokens: 100,
3272            temperature: None,
3273            system_prompt: None,
3274        };
3275        let body = driver.build_request_body(&req);
3276        assert!(body.get("system").is_none());
3277        assert!(body.get("temperature").is_none());
3278    }
3279
3280    #[test]
3281    fn anthropic_parse_response_text() {
3282        let driver = AnthropicDriver::new("test-key".to_string(), None);
3283        let response_body = serde_json::json!({
3284            "content": [
3285                {"type": "text", "text": "Hello!"}
3286            ],
3287            "stop_reason": "end_turn",
3288            "usage": {
3289                "input_tokens": 10,
3290                "output_tokens": 5
3291            }
3292        });
3293
3294        let resp = driver.parse_response(&response_body).unwrap();
3295        assert_eq!(resp.message.content, "Hello!");
3296        assert_eq!(resp.stop_reason, StopReason::EndTurn);
3297        assert_eq!(resp.usage.input_tokens, 10);
3298        assert_eq!(resp.usage.output_tokens, 5);
3299        assert!(resp.message.tool_calls.is_empty());
3300    }
3301
3302    #[test]
3303    fn anthropic_parse_response_tool_use() {
3304        let driver = AnthropicDriver::new("test-key".to_string(), None);
3305        let response_body = serde_json::json!({
3306            "content": [
3307                {"type": "text", "text": "Let me check."},
3308                {
3309                    "type": "tool_use",
3310                    "id": "tool_abc",
3311                    "name": "get_weather",
3312                    "input": {"city": "NYC"}
3313                }
3314            ],
3315            "stop_reason": "tool_use",
3316            "usage": {"input_tokens": 20, "output_tokens": 15}
3317        });
3318
3319        let resp = driver.parse_response(&response_body).unwrap();
3320        assert_eq!(resp.message.content, "Let me check.");
3321        assert_eq!(resp.stop_reason, StopReason::ToolUse);
3322        assert_eq!(resp.message.tool_calls.len(), 1);
3323        assert_eq!(resp.message.tool_calls[0].id, "tool_abc");
3324        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
3325        assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
3326    }
3327
3328    #[test]
3329    fn anthropic_parse_response_max_tokens() {
3330        let driver = AnthropicDriver::new("test-key".to_string(), None);
3331        let response_body = serde_json::json!({
3332            "content": [{"type": "text", "text": "truncated"}],
3333            "stop_reason": "max_tokens",
3334            "usage": {"input_tokens": 5, "output_tokens": 100}
3335        });
3336
3337        let resp = driver.parse_response(&response_body).unwrap();
3338        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
3339    }
3340
3341    #[test]
3342    fn anthropic_parse_response_unknown_stop_reason() {
3343        let driver = AnthropicDriver::new("test-key".to_string(), None);
3344        let response_body = serde_json::json!({
3345            "content": [{"type": "text", "text": "err"}],
3346            "stop_reason": "something_unknown",
3347            "usage": {"input_tokens": 0, "output_tokens": 0}
3348        });
3349
3350        let resp = driver.parse_response(&response_body).unwrap();
3351        assert_eq!(resp.stop_reason, StopReason::Error);
3352    }
3353
3354    #[test]
3355    fn anthropic_request_body_with_assistant_and_tool_messages() {
3356        let driver = AnthropicDriver::new("test-key".to_string(), None);
3357        let req = CompletionRequest {
3358            model: "test".into(),
3359            messages: vec![
3360                Message::new(Role::User, "Hi"),
3361                Message {
3362                    role: Role::Assistant,
3363                    content: "I'll check".into(),
3364                    tool_calls: vec![ToolCall {
3365                        id: "call_1".into(),
3366                        name: "file_read".into(),
3367                        input: serde_json::json!({"path": "/tmp/test"}),
3368                    }],
3369                    tool_results: Vec::new(),
3370                    timestamp: chrono::Utc::now(),
3371                },
3372                Message {
3373                    role: Role::Tool,
3374                    content: String::new(),
3375                    tool_calls: Vec::new(),
3376                    tool_results: vec![punch_types::ToolCallResult {
3377                        id: "call_1".into(),
3378                        content: "file contents".into(),
3379                        is_error: false,
3380                    }],
3381                    timestamp: chrono::Utc::now(),
3382                },
3383            ],
3384            tools: Vec::new(),
3385            max_tokens: 100,
3386            temperature: None,
3387            system_prompt: None,
3388        };
3389
3390        let body = driver.build_request_body(&req);
3391        let messages = body["messages"].as_array().unwrap();
3392        assert_eq!(messages.len(), 3);
3393        assert_eq!(messages[0]["role"], "user");
3394        assert_eq!(messages[1]["role"], "assistant");
3395        assert_eq!(messages[2]["role"], "user"); // Tool results go as user role
3396    }
3397
3398    #[test]
3399    fn anthropic_request_body_system_message_skipped() {
3400        let driver = AnthropicDriver::new("test-key".to_string(), None);
3401        let req = CompletionRequest {
3402            model: "test".into(),
3403            messages: vec![
3404                Message::new(Role::System, "System instruction"),
3405                Message::new(Role::User, "Hi"),
3406            ],
3407            tools: Vec::new(),
3408            max_tokens: 100,
3409            temperature: None,
3410            system_prompt: None,
3411        };
3412
3413        let body = driver.build_request_body(&req);
3414        let messages = body["messages"].as_array().unwrap();
3415        // System messages are skipped in messages array
3416        assert_eq!(messages.len(), 1);
3417        assert_eq!(messages[0]["role"], "user");
3418    }
3419
3420    // -----------------------------------------------------------------------
3421    // OpenAI-compatible driver tests
3422    // -----------------------------------------------------------------------
3423
3424    #[test]
3425    fn openai_request_body_simple() {
3426        let driver = OpenAiCompatibleDriver::new(
3427            "key".into(),
3428            "https://api.openai.com".into(),
3429            "openai".into(),
3430        );
3431        let body = driver.build_request_body(&simple_request());
3432
3433        assert_eq!(body["model"], "test-model");
3434        let messages = body["messages"].as_array().unwrap();
3435        assert_eq!(messages.len(), 2);
3436        assert_eq!(messages[0]["role"], "system");
3437        assert_eq!(messages[0]["content"], "You are helpful.");
3438        assert_eq!(messages[1]["role"], "user");
3439    }
3440
3441    #[test]
3442    fn openai_request_body_with_tools() {
3443        let driver = OpenAiCompatibleDriver::new(
3444            "key".into(),
3445            "https://api.openai.com".into(),
3446            "openai".into(),
3447        );
3448        let body = driver.build_request_body(&request_with_tools());
3449
3450        let tools = body["tools"].as_array().unwrap();
3451        assert_eq!(tools.len(), 1);
3452        assert_eq!(tools[0]["type"], "function");
3453        assert_eq!(tools[0]["function"]["name"], "get_weather");
3454    }
3455
3456    #[test]
3457    fn openai_parse_response_text() {
3458        let driver = OpenAiCompatibleDriver::new(
3459            "key".into(),
3460            "https://api.openai.com".into(),
3461            "openai".into(),
3462        );
3463        let response_body = serde_json::json!({
3464            "choices": [{
3465                "message": {
3466                    "role": "assistant",
3467                    "content": "Hello!"
3468                },
3469                "finish_reason": "stop"
3470            }],
3471            "usage": {
3472                "prompt_tokens": 10,
3473                "completion_tokens": 5
3474            }
3475        });
3476
3477        let resp = driver.parse_response(&response_body).unwrap();
3478        assert_eq!(resp.message.content, "Hello!");
3479        assert_eq!(resp.stop_reason, StopReason::EndTurn);
3480        assert_eq!(resp.usage.input_tokens, 10);
3481        assert_eq!(resp.usage.output_tokens, 5);
3482    }
3483
3484    #[test]
3485    fn openai_parse_response_tool_calls() {
3486        let driver = OpenAiCompatibleDriver::new(
3487            "key".into(),
3488            "https://api.openai.com".into(),
3489            "openai".into(),
3490        );
3491        let response_body = serde_json::json!({
3492            "choices": [{
3493                "message": {
3494                    "role": "assistant",
3495                    "content": null,
3496                    "tool_calls": [{
3497                        "id": "call_123",
3498                        "type": "function",
3499                        "function": {
3500                            "name": "get_weather",
3501                            "arguments": "{\"city\": \"NYC\"}"
3502                        }
3503                    }]
3504                },
3505                "finish_reason": "tool_calls"
3506            }],
3507            "usage": {"prompt_tokens": 10, "completion_tokens": 5}
3508        });
3509
3510        let resp = driver.parse_response(&response_body).unwrap();
3511        assert_eq!(resp.stop_reason, StopReason::ToolUse);
3512        assert_eq!(resp.message.tool_calls.len(), 1);
3513        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
3514        assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
3515    }
3516
3517    #[test]
3518    fn openai_parse_response_tool_calls_fix_stop_reason() {
3519        let driver = OpenAiCompatibleDriver::new(
3520            "key".into(),
3521            "https://api.openai.com".into(),
3522            "openai".into(),
3523        );
3524        // finish_reason is "stop" but there are tool_calls — should fix to ToolUse
3525        let response_body = serde_json::json!({
3526            "choices": [{
3527                "message": {
3528                    "role": "assistant",
3529                    "content": "Using tool",
3530                    "tool_calls": [{
3531                        "id": "call_1",
3532                        "type": "function",
3533                        "function": {
3534                            "name": "test_tool",
3535                            "arguments": "{}"
3536                        }
3537                    }]
3538                },
3539                "finish_reason": "stop"
3540            }],
3541            "usage": {"prompt_tokens": 0, "completion_tokens": 0}
3542        });
3543
3544        let resp = driver.parse_response(&response_body).unwrap();
3545        assert_eq!(resp.stop_reason, StopReason::ToolUse);
3546    }
3547
3548    #[test]
3549    fn openai_parse_response_length_stop_reason() {
3550        let driver = OpenAiCompatibleDriver::new(
3551            "key".into(),
3552            "https://api.openai.com".into(),
3553            "openai".into(),
3554        );
3555        let response_body = serde_json::json!({
3556            "choices": [{
3557                "message": {"role": "assistant", "content": "cut off"},
3558                "finish_reason": "length"
3559            }],
3560            "usage": {"prompt_tokens": 0, "completion_tokens": 0}
3561        });
3562
3563        let resp = driver.parse_response(&response_body).unwrap();
3564        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
3565    }
3566
3567    #[test]
3568    fn openai_parse_response_no_choices_error() {
3569        let driver = OpenAiCompatibleDriver::new(
3570            "key".into(),
3571            "https://api.openai.com".into(),
3572            "openai".into(),
3573        );
3574        let response_body = serde_json::json!({"choices": []});
3575
3576        let result = driver.parse_response(&response_body);
3577        assert!(result.is_err());
3578    }
3579
3580    // -----------------------------------------------------------------------
3581    // Gemini driver additional tests
3582    // -----------------------------------------------------------------------
3583
3584    #[test]
3585    fn gemini_assistant_message_formatting() {
3586        let driver = GeminiDriver::new("key".to_string(), None);
3587        let req = CompletionRequest {
3588            model: "gemini-pro".into(),
3589            messages: vec![
3590                Message::new(Role::User, "Hi"),
3591                Message {
3592                    role: Role::Assistant,
3593                    content: "Let me help".into(),
3594                    tool_calls: vec![ToolCall {
3595                        id: "tc1".into(),
3596                        name: "get_weather".into(),
3597                        input: serde_json::json!({"city": "NYC"}),
3598                    }],
3599                    tool_results: Vec::new(),
3600                    timestamp: chrono::Utc::now(),
3601                },
3602            ],
3603            tools: Vec::new(),
3604            max_tokens: 100,
3605            temperature: None,
3606            system_prompt: None,
3607        };
3608
3609        let body = driver.build_request_body(&req);
3610        let contents = body["contents"].as_array().unwrap();
3611        assert_eq!(contents[1]["role"], "model"); // Gemini uses "model" not "assistant"
3612        let parts = contents[1]["parts"].as_array().unwrap();
3613        assert!(parts.len() >= 2); // text part + functionCall part
3614    }
3615
3616    #[test]
3617    fn gemini_max_tokens_stop_reason() {
3618        let driver = GeminiDriver::new("key".to_string(), None);
3619        let response_body = serde_json::json!({
3620            "candidates": [{
3621                "content": {
3622                    "parts": [{"text": "truncated"}],
3623                    "role": "model"
3624                },
3625                "finishReason": "MAX_TOKENS"
3626            }],
3627            "usageMetadata": {"promptTokenCount": 0, "candidatesTokenCount": 0}
3628        });
3629
3630        let resp = driver.parse_response(&response_body).unwrap();
3631        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
3632    }
3633
3634    #[test]
3635    fn gemini_custom_base_url() {
3636        let driver =
3637            GeminiDriver::new("key".to_string(), Some("https://custom.example.com".into()));
3638        let url = driver.build_url("gemini-pro");
3639        assert!(url.starts_with("https://custom.example.com/"));
3640    }
3641
3642    // -----------------------------------------------------------------------
3643    // Ollama driver additional tests
3644    // -----------------------------------------------------------------------
3645
3646    #[test]
3647    fn ollama_response_with_tool_calls() {
3648        let driver = OllamaDriver::new(None);
3649        let response_body = serde_json::json!({
3650            "message": {
3651                "role": "assistant",
3652                "content": "",
3653                "tool_calls": [{
3654                    "function": {
3655                        "name": "get_weather",
3656                        "arguments": {"city": "London"}
3657                    }
3658                }]
3659            },
3660            "done": true,
3661            "prompt_eval_count": 10,
3662            "eval_count": 5
3663        });
3664
3665        let resp = driver.parse_response(&response_body).unwrap();
3666        assert_eq!(resp.stop_reason, StopReason::ToolUse);
3667        assert_eq!(resp.message.tool_calls.len(), 1);
3668        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
3669    }
3670
3671    #[test]
3672    fn ollama_response_not_done() {
3673        let driver = OllamaDriver::new(None);
3674        let response_body = serde_json::json!({
3675            "message": {"role": "assistant", "content": "partial"},
3676            "done": false,
3677            "prompt_eval_count": 10,
3678            "eval_count": 5
3679        });
3680
3681        let resp = driver.parse_response(&response_body).unwrap();
3682        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
3683    }
3684
3685    // -----------------------------------------------------------------------
3686    // Bedrock driver additional tests
3687    // -----------------------------------------------------------------------
3688
3689    #[test]
3690    fn bedrock_request_with_tools() {
3691        let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
3692        let body = driver.build_request_body(&request_with_tools());
3693
3694        let tool_config = &body["toolConfig"]["tools"];
3695        assert!(tool_config.is_array());
3696        let tools = tool_config.as_array().unwrap();
3697        assert_eq!(tools.len(), 1);
3698        assert_eq!(tools[0]["toolSpec"]["name"], "get_weather");
3699    }
3700
3701    #[test]
3702    fn bedrock_response_with_tool_use() {
3703        let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
3704        let response_body = serde_json::json!({
3705            "output": {
3706                "message": {
3707                    "role": "assistant",
3708                    "content": [
3709                        {"text": "Using tool"},
3710                        {"toolUse": {
3711                            "toolUseId": "tu_123",
3712                            "name": "get_weather",
3713                            "input": {"city": "NYC"}
3714                        }}
3715                    ]
3716                }
3717            },
3718            "stopReason": "tool_use",
3719            "usage": {"inputTokens": 10, "outputTokens": 20}
3720        });
3721
3722        let resp = driver.parse_response(&response_body).unwrap();
3723        assert_eq!(resp.stop_reason, StopReason::ToolUse);
3724        assert_eq!(resp.message.tool_calls.len(), 1);
3725        assert_eq!(resp.message.tool_calls[0].id, "tu_123");
3726        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
3727    }
3728
3729    #[test]
3730    fn bedrock_request_with_tool_results() {
3731        let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
3732        let req = CompletionRequest {
3733            model: "test".into(),
3734            messages: vec![
3735                Message::new(Role::User, "Hi"),
3736                Message {
3737                    role: Role::Tool,
3738                    content: String::new(),
3739                    tool_calls: Vec::new(),
3740                    tool_results: vec![punch_types::ToolCallResult {
3741                        id: "tu_1".into(),
3742                        content: "result data".into(),
3743                        is_error: false,
3744                    }],
3745                    timestamp: chrono::Utc::now(),
3746                },
3747            ],
3748            tools: Vec::new(),
3749            max_tokens: 100,
3750            temperature: None,
3751            system_prompt: None,
3752        };
3753
3754        let body = driver.build_request_body(&req);
3755        let messages = body["messages"].as_array().unwrap();
3756        assert_eq!(messages[1]["role"], "user"); // Bedrock sends tool results as user
3757        let content = messages[1]["content"].as_array().unwrap();
3758        assert!(content[0]["toolResult"].is_object());
3759        assert_eq!(content[0]["toolResult"]["status"], "success");
3760    }
3761
3762    #[test]
3763    fn bedrock_url_different_regions() {
3764        let driver = BedrockDriver::new("k".into(), "s".into(), "ap-southeast-1".into());
3765        let url = driver.build_url("model-id");
3766        assert!(url.contains("ap-southeast-1"));
3767    }
3768
3769    // -----------------------------------------------------------------------
3770    // Azure OpenAI additional tests
3771    // -----------------------------------------------------------------------
3772
3773    #[test]
3774    fn azure_openai_delegates_parse_to_openai() {
3775        let driver = AzureOpenAiDriver::new("key".into(), "res".into(), "dep".into(), None);
3776        let response_body = serde_json::json!({
3777            "choices": [{
3778                "message": {"role": "assistant", "content": "Azure response"},
3779                "finish_reason": "stop"
3780            }],
3781            "usage": {"prompt_tokens": 5, "completion_tokens": 3}
3782        });
3783
3784        let resp = driver.parse_response(&response_body).unwrap();
3785        assert_eq!(resp.message.content, "Azure response");
3786    }
3787
3788    // -----------------------------------------------------------------------
3789    // default_base_url tests
3790    // -----------------------------------------------------------------------
3791
3792    #[test]
3793    fn default_base_url_anthropic() {
3794        assert_eq!(
3795            default_base_url(&Provider::Anthropic),
3796            "https://api.anthropic.com"
3797        );
3798    }
3799
3800    #[test]
3801    fn default_base_url_openai() {
3802        assert_eq!(
3803            default_base_url(&Provider::OpenAI),
3804            "https://api.openai.com"
3805        );
3806    }
3807
3808    #[test]
3809    fn default_base_url_google() {
3810        assert_eq!(
3811            default_base_url(&Provider::Google),
3812            "https://generativelanguage.googleapis.com"
3813        );
3814    }
3815
3816    #[test]
3817    fn default_base_url_ollama() {
3818        assert_eq!(
3819            default_base_url(&Provider::Ollama),
3820            "http://localhost:11434"
3821        );
3822    }
3823
3824    #[test]
3825    fn default_base_url_groq() {
3826        assert_eq!(
3827            default_base_url(&Provider::Groq),
3828            "https://api.groq.com/openai"
3829        );
3830    }
3831
3832    #[test]
3833    fn default_base_url_deepseek() {
3834        assert_eq!(
3835            default_base_url(&Provider::DeepSeek),
3836            "https://api.deepseek.com"
3837        );
3838    }
3839
3840    // -----------------------------------------------------------------------
3841    // hex_sha256 and hex_encode tests
3842    // -----------------------------------------------------------------------
3843
3844    #[test]
3845    fn test_hex_sha256() {
3846        let hash = hex_sha256(b"");
3847        assert_eq!(
3848            hash,
3849            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
3850        );
3851    }
3852
3853    #[test]
3854    fn test_hex_encode() {
3855        assert_eq!(hex_encode(&[0x00, 0xff, 0x0a, 0xbc]), "00ff0abc");
3856    }
3857
3858    #[test]
3859    fn test_hmac_sha256_basic() {
3860        let result = hmac_sha256(b"key", b"data");
3861        assert!(!result.is_empty());
3862        assert_eq!(result.len(), 32); // SHA-256 produces 32 bytes
3863    }
3864
3865    // -----------------------------------------------------------------------
3866    // create_driver error cases
3867    // -----------------------------------------------------------------------
3868
3869    #[test]
3870    fn create_driver_missing_api_key_env() {
3871        let config = ModelConfig {
3872            provider: Provider::Anthropic,
3873            model: "claude-3".into(),
3874            api_key_env: Some("PUNCH_TEST_NONEXISTENT_KEY_XYZ".into()),
3875            base_url: None,
3876            max_tokens: None,
3877            temperature: None,
3878        };
3879        let result = create_driver(&config);
3880        assert!(result.is_err());
3881    }
3882
3883    #[test]
3884    fn create_driver_openai_compatible_fallback() {
3885        // Custom provider should fall through to OpenAI-compatible
3886        unsafe { std::env::set_var("TEST_CUSTOM_KEY_DRIVER", "fake-key") };
3887        let config = ModelConfig {
3888            provider: Provider::Custom("my-custom".into()),
3889            model: "custom-model".into(),
3890            api_key_env: Some("TEST_CUSTOM_KEY_DRIVER".into()),
3891            base_url: Some("https://custom.api.com".into()),
3892            max_tokens: None,
3893            temperature: None,
3894        };
3895        let result = create_driver(&config);
3896        assert!(result.is_ok());
3897        unsafe { std::env::remove_var("TEST_CUSTOM_KEY_DRIVER") };
3898    }
3899
3900    // -----------------------------------------------------------------------
3901    // strip_thinking_tags tests
3902    // -----------------------------------------------------------------------
3903
3904    #[test]
3905    fn strip_thinking_tags_removes_think_block() {
3906        let input = "<think>internal reasoning here</think>The answer is 42.";
3907        assert_eq!(strip_thinking_tags(input), "The answer is 42.");
3908    }
3909
3910    #[test]
3911    fn strip_thinking_tags_removes_thinking_block() {
3912        let input = "<thinking>step by step reasoning</thinking>Hello world!";
3913        assert_eq!(strip_thinking_tags(input), "Hello world!");
3914    }
3915
3916    #[test]
3917    fn strip_thinking_tags_removes_reasoning_block() {
3918        let input = "<reasoning>let me figure this out</reasoning>The result is correct.";
3919        assert_eq!(strip_thinking_tags(input), "The result is correct.");
3920    }
3921
3922    #[test]
3923    fn strip_thinking_tags_removes_reflection_block() {
3924        let input = "<reflection>checking my work</reflection>Yes, that's right.";
3925        assert_eq!(strip_thinking_tags(input), "Yes, that's right.");
3926    }
3927
3928    #[test]
3929    fn strip_thinking_tags_removes_multiple_blocks() {
3930        let input = "<think>first thought</think>Hello <thinking>second thought</thinking>world!";
3931        assert_eq!(strip_thinking_tags(input), "Hello world!");
3932    }
3933
3934    #[test]
3935    fn strip_thinking_tags_preserves_content_without_tags() {
3936        let input = "Just a normal response with no thinking tags.";
3937        assert_eq!(strip_thinking_tags(input), input);
3938    }
3939
3940    #[test]
3941    fn strip_thinking_tags_handles_multiline_tags() {
3942        let input = "<think>\nLine 1\nLine 2\nLine 3\n</think>\nThe final answer.";
3943        assert_eq!(strip_thinking_tags(input), "The final answer.");
3944    }
3945
3946    #[test]
3947    fn strip_thinking_tags_returns_original_if_all_thinking() {
3948        // If the entire response is thinking with no visible output,
3949        // return the original so the user sees something.
3950        let input = "<think>this is all thinking content and nothing else</think>";
3951        assert_eq!(strip_thinking_tags(input), input);
3952    }
3953
3954    #[test]
3955    fn strip_thinking_tags_handles_unclosed_tag() {
3956        let input = "Some text<think>unclosed thinking block";
3957        assert_eq!(strip_thinking_tags(input), "Some text");
3958    }
3959
3960    #[test]
3961    fn strip_thinking_tags_handles_empty_input() {
3962        assert_eq!(strip_thinking_tags(""), "");
3963    }
3964
3965    #[test]
3966    fn strip_thinking_tags_handles_empty_think_block() {
3967        let input = "<think></think>Visible content.";
3968        assert_eq!(strip_thinking_tags(input), "Visible content.");
3969    }
3970
3971    #[test]
3972    fn strip_thinking_tags_trims_whitespace() {
3973        let input = "  <think>reasoning</think>  Result  ";
3974        assert_eq!(strip_thinking_tags(input), "Result");
3975    }
3976
3977    #[test]
3978    fn strip_thinking_tags_mixed_tag_types() {
3979        let input = "<think>t1</think>A<reasoning>r1</reasoning>B<reflection>f1</reflection>C";
3980        assert_eq!(strip_thinking_tags(input), "ABC");
3981    }
3982
3983    #[test]
3984    fn ollama_response_strips_thinking_tags() {
3985        let driver = OllamaDriver::new(None);
3986        let response_body = serde_json::json!({
3987            "message": {
3988                "role": "assistant",
3989                "content": "<think>\nLet me think about this...\nThe user wants hello world.\n</think>\nHello, world!"
3990            },
3991            "done": true,
3992            "prompt_eval_count": 20,
3993            "eval_count": 50
3994        });
3995
3996        let resp = driver.parse_response(&response_body).unwrap();
3997        assert_eq!(resp.message.content, "Hello, world!");
3998        assert!(!resp.message.content.contains("<think>"));
3999    }
4000
4001    #[test]
4002    fn gemini_response_strips_thinking_tags() {
4003        let driver = GeminiDriver::new("test-key".to_string(), None);
4004        let response_body = serde_json::json!({
4005            "candidates": [{
4006                "content": {
4007                    "parts": [{"text": "<thinking>reasoning step</thinking>The answer is 7."}],
4008                    "role": "model"
4009                },
4010                "finishReason": "STOP"
4011            }],
4012            "usageMetadata": {
4013                "promptTokenCount": 10,
4014                "candidatesTokenCount": 20
4015            }
4016        });
4017
4018        let resp = driver.parse_response(&response_body).unwrap();
4019        assert_eq!(resp.message.content, "The answer is 7.");
4020        assert!(!resp.message.content.contains("<thinking>"));
4021    }
4022
4023    #[test]
4024    fn anthropic_response_strips_thinking_tags() {
4025        let driver = AnthropicDriver::new("test-key".to_string(), None);
4026        let response_body = serde_json::json!({
4027            "content": [
4028                {"type": "text", "text": "<think>internal thought</think>Clean output."}
4029            ],
4030            "stop_reason": "end_turn",
4031            "usage": {"input_tokens": 10, "output_tokens": 5}
4032        });
4033
4034        let resp = driver.parse_response(&response_body).unwrap();
4035        assert_eq!(resp.message.content, "Clean output.");
4036    }
4037
4038    #[test]
4039    fn bedrock_response_strips_thinking_tags() {
4040        let driver = BedrockDriver::new(
4041            "key".to_string(),
4042            "secret".to_string(),
4043            "us-east-1".to_string(),
4044        );
4045        let response_body = serde_json::json!({
4046            "output": {
4047                "message": {
4048                    "role": "assistant",
4049                    "content": [{"text": "<reasoning>deep thought</reasoning>Result here."}]
4050                }
4051            },
4052            "stopReason": "end_turn",
4053            "usage": {"inputTokens": 50, "outputTokens": 25}
4054        });
4055
4056        let resp = driver.parse_response(&response_body).unwrap();
4057        assert_eq!(resp.message.content, "Result here.");
4058    }
4059
4060    // -----------------------------------------------------------------------
4061    // StreamChunk / ToolCallDelta serialization tests
4062    // -----------------------------------------------------------------------
4063
4064    #[test]
4065    fn stream_chunk_serialization_roundtrip() {
4066        let chunk = StreamChunk {
4067            delta: "Hello".to_string(),
4068            is_final: false,
4069            tool_call_delta: None,
4070        };
4071        let json = serde_json::to_string(&chunk).unwrap();
4072        let deserialized: StreamChunk = serde_json::from_str(&json).unwrap();
4073        assert_eq!(deserialized.delta, "Hello");
4074        assert!(!deserialized.is_final);
4075        assert!(deserialized.tool_call_delta.is_none());
4076    }
4077
4078    #[test]
4079    fn stream_chunk_with_tool_call_delta_serialization() {
4080        let chunk = StreamChunk {
4081            delta: String::new(),
4082            is_final: false,
4083            tool_call_delta: Some(ToolCallDelta {
4084                index: 0,
4085                id: Some("call_123".to_string()),
4086                name: Some("get_weather".to_string()),
4087                arguments_delta: "{\"city\":".to_string(),
4088            }),
4089        };
4090        let json = serde_json::to_string(&chunk).unwrap();
4091        let deserialized: StreamChunk = serde_json::from_str(&json).unwrap();
4092        let tcd = deserialized.tool_call_delta.unwrap();
4093        assert_eq!(tcd.index, 0);
4094        assert_eq!(tcd.id.unwrap(), "call_123");
4095        assert_eq!(tcd.name.unwrap(), "get_weather");
4096        assert_eq!(tcd.arguments_delta, "{\"city\":");
4097    }
4098
4099    #[test]
4100    fn stream_chunk_final_serialization() {
4101        let chunk = StreamChunk {
4102            delta: String::new(),
4103            is_final: true,
4104            tool_call_delta: None,
4105        };
4106        let json = serde_json::to_string(&chunk).unwrap();
4107        assert!(json.contains("\"is_final\":true"));
4108    }
4109
4110    #[test]
4111    fn tool_call_delta_serialization_roundtrip() {
4112        let tcd = ToolCallDelta {
4113            index: 2,
4114            id: None,
4115            name: None,
4116            arguments_delta: "\"NYC\"}".to_string(),
4117        };
4118        let json = serde_json::to_string(&tcd).unwrap();
4119        let deserialized: ToolCallDelta = serde_json::from_str(&json).unwrap();
4120        assert_eq!(deserialized.index, 2);
4121        assert!(deserialized.id.is_none());
4122        assert!(deserialized.name.is_none());
4123        assert_eq!(deserialized.arguments_delta, "\"NYC\"}");
4124    }
4125
4126    // -----------------------------------------------------------------------
4127    // SSE parsing tests
4128    // -----------------------------------------------------------------------
4129
4130    #[test]
4131    fn parse_sse_events_basic() {
4132        let raw = "event: message_start\ndata: {\"type\":\"message_start\"}\n\nevent: content_block_delta\ndata: {\"delta\":{\"text\":\"Hi\"}}\n\n";
4133        let events = parse_sse_events(raw);
4134        assert_eq!(events.len(), 2);
4135        assert_eq!(events[0].0, "message_start");
4136        assert_eq!(events[1].0, "content_block_delta");
4137    }
4138
4139    #[test]
4140    fn parse_sse_events_with_done() {
4141        let raw = "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\ndata: [DONE]\n\n";
4142        let events = parse_sse_events(raw);
4143        assert_eq!(events.len(), 2);
4144        assert_eq!(events[1].1, "[DONE]");
4145    }
4146
4147    #[test]
4148    fn parse_sse_events_empty_input() {
4149        let events = parse_sse_events("");
4150        assert!(events.is_empty());
4151    }
4152
4153    #[test]
4154    fn parse_sse_events_no_trailing_newline() {
4155        let raw = "event: test\ndata: {\"value\":1}";
4156        let events = parse_sse_events(raw);
4157        assert_eq!(events.len(), 1);
4158        assert_eq!(events[0].0, "test");
4159    }
4160
4161    #[test]
4162    fn parse_sse_events_multiline_data() {
4163        let raw = "data: line1\ndata: line2\n\n";
4164        let events = parse_sse_events(raw);
4165        assert_eq!(events.len(), 1);
4166        assert_eq!(events[0].1, "line1\nline2");
4167    }
4168
4169    #[test]
4170    fn parse_sse_events_no_event_field() {
4171        let raw = "data: {\"hello\":\"world\"}\n\n";
4172        let events = parse_sse_events(raw);
4173        assert_eq!(events.len(), 1);
4174        assert_eq!(events[0].0, "message"); // default event type
4175    }
4176
4177    // -----------------------------------------------------------------------
4178    // Anthropic streaming tests
4179    // -----------------------------------------------------------------------
4180
4181    #[test]
4182    fn anthropic_stream_text_only() {
4183        let raw = concat!(
4184            "event: message_start\n",
4185            "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":25}}}\n\n",
4186            "event: content_block_start\n",
4187            "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
4188            "event: content_block_delta\n",
4189            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n",
4190            "event: content_block_delta\n",
4191            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\" world\"}}\n\n",
4192            "event: message_delta\n",
4193            "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":10}}\n\n",
4194            "event: message_stop\n",
4195            "data: {\"type\":\"message_stop\"}\n\n",
4196        );
4197
4198        let events = parse_sse_events(raw);
4199        let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
4200            Arc::new(std::sync::Mutex::new(Vec::new()));
4201        let chunks_clone = chunks.clone();
4202        let callback: StreamCallback = Arc::new(move |chunk| {
4203            chunks_clone.lock().unwrap().push(chunk);
4204        });
4205
4206        // Simulate the Anthropic stream processing
4207        let mut text_content = String::new();
4208        let mut usage = TokenUsage::default();
4209        let mut stop_reason = StopReason::EndTurn;
4210
4211        for (event_type, data) in &events {
4212            let parsed: serde_json::Value = match serde_json::from_str(data) {
4213                Ok(v) => v,
4214                Err(_) => continue,
4215            };
4216
4217            match event_type.as_str() {
4218                "message_start" => {
4219                    if let Some(inp) = parsed["message"]["usage"]["input_tokens"].as_u64() {
4220                        usage.input_tokens = inp;
4221                    }
4222                }
4223                "content_block_delta" => {
4224                    if let Some(text) = parsed["delta"]["text"].as_str() {
4225                        text_content.push_str(text);
4226                        callback(StreamChunk {
4227                            delta: text.to_string(),
4228                            is_final: false,
4229                            tool_call_delta: None,
4230                        });
4231                    }
4232                }
4233                "message_delta" => {
4234                    if let Some(sr) = parsed["delta"]["stop_reason"].as_str() {
4235                        stop_reason = match sr {
4236                            "end_turn" => StopReason::EndTurn,
4237                            "tool_use" => StopReason::ToolUse,
4238                            _ => StopReason::Error,
4239                        };
4240                    }
4241                    if let Some(out) = parsed["usage"]["output_tokens"].as_u64() {
4242                        usage.output_tokens = out;
4243                    }
4244                }
4245                "message_stop" => {
4246                    callback(StreamChunk {
4247                        delta: String::new(),
4248                        is_final: true,
4249                        tool_call_delta: None,
4250                    });
4251                }
4252                _ => {}
4253            }
4254        }
4255
4256        assert_eq!(text_content, "Hello world");
4257        assert_eq!(usage.input_tokens, 25);
4258        assert_eq!(usage.output_tokens, 10);
4259        assert_eq!(stop_reason, StopReason::EndTurn);
4260
4261        let received = chunks.lock().unwrap();
4262        assert_eq!(received.len(), 3); // "Hello", " world", final
4263        assert_eq!(received[0].delta, "Hello");
4264        assert_eq!(received[1].delta, " world");
4265        assert!(received[2].is_final);
4266    }
4267
4268    #[test]
4269    fn anthropic_stream_with_tool_use() {
4270        let raw = concat!(
4271            "event: message_start\n",
4272            "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":15}}}\n\n",
4273            "event: content_block_start\n",
4274            "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
4275            "event: content_block_delta\n",
4276            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Checking.\"}}\n\n",
4277            "event: content_block_start\n",
4278            "data: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"tool_1\",\"name\":\"get_weather\"}}\n\n",
4279            "event: content_block_delta\n",
4280            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\"\"}}\n\n",
4281            "event: content_block_delta\n",
4282            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\": \\\"NYC\\\"}\"}}\n\n",
4283            "event: message_delta\n",
4284            "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":20}}\n\n",
4285            "event: message_stop\n",
4286            "data: {\"type\":\"message_stop\"}\n\n",
4287        );
4288
4289        let events = parse_sse_events(raw);
4290        // Verify we can parse all events
4291        assert!(events.len() >= 7);
4292
4293        // Verify tool JSON reconstruction
4294        let mut tool_json_bufs: Vec<String> = Vec::new();
4295        let mut tc_idx: Option<usize> = None;
4296
4297        for (event_type, data) in &events {
4298            let parsed: serde_json::Value = match serde_json::from_str(data) {
4299                Ok(v) => v,
4300                Err(_) => continue,
4301            };
4302            match event_type.as_str() {
4303                "content_block_start" => {
4304                    if parsed["content_block"]["type"].as_str() == Some("tool_use") {
4305                        tool_json_bufs.push(String::new());
4306                        tc_idx = Some(tool_json_bufs.len() - 1);
4307                    } else {
4308                        tc_idx = None;
4309                    }
4310                }
4311                "content_block_delta" => {
4312                    if parsed["delta"]["type"].as_str() == Some("input_json_delta")
4313                        && let Some(idx) = tc_idx
4314                        && let Some(buf) = tool_json_bufs.get_mut(idx)
4315                    {
4316                        buf.push_str(parsed["delta"]["partial_json"].as_str().unwrap_or(""));
4317                    }
4318                }
4319                _ => {}
4320            }
4321        }
4322
4323        assert_eq!(tool_json_bufs.len(), 1);
4324        assert_eq!(tool_json_bufs[0], "{\"city\": \"NYC\"}");
4325
4326        let parsed_input: serde_json::Value = serde_json::from_str(&tool_json_bufs[0]).unwrap();
4327        assert_eq!(parsed_input["city"], "NYC");
4328    }
4329
4330    // -----------------------------------------------------------------------
4331    // OpenAI streaming tests
4332    // -----------------------------------------------------------------------
4333
4334    #[test]
4335    fn openai_stream_text_only() {
4336        let driver = OpenAiCompatibleDriver::new(
4337            "key".into(),
4338            "https://api.openai.com".into(),
4339            "openai".into(),
4340        );
4341
4342        let raw = concat!(
4343            "data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"},\"index\":0}]}\n\n",
4344            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"index\":0}]}\n\n",
4345            "data: {\"choices\":[{\"delta\":{\"content\":\" world\"},\"index\":0}]}\n\n",
4346            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n",
4347            "data: [DONE]\n\n",
4348        );
4349
4350        let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
4351            Arc::new(std::sync::Mutex::new(Vec::new()));
4352        let chunks_clone = chunks.clone();
4353        let callback: StreamCallback = Arc::new(move |chunk| {
4354            chunks_clone.lock().unwrap().push(chunk);
4355        });
4356
4357        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4358
4359        assert_eq!(resp.message.content, "Hello world");
4360        assert_eq!(resp.stop_reason, StopReason::EndTurn);
4361        assert!(resp.message.tool_calls.is_empty());
4362
4363        let received = chunks.lock().unwrap();
4364        // "Hello", " world", final [DONE]
4365        assert!(received.len() >= 3);
4366        assert_eq!(received[0].delta, "Hello");
4367        assert_eq!(received[1].delta, " world");
4368        assert!(received.last().unwrap().is_final);
4369    }
4370
4371    #[test]
4372    fn openai_stream_with_tool_calls() {
4373        let driver = OpenAiCompatibleDriver::new(
4374            "key".into(),
4375            "https://api.openai.com".into(),
4376            "openai".into(),
4377        );
4378
4379        let raw = concat!(
4380            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_abc\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"\"}}]},\"index\":0}]}\n\n",
4381            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"ci\"}}]},\"index\":0}]}\n\n",
4382            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"ty\\\": \\\"NYC\\\"}\"}}]},\"index\":0}]}\n\n",
4383            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"tool_calls\"}]}\n\n",
4384            "data: [DONE]\n\n",
4385        );
4386
4387        let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
4388            Arc::new(std::sync::Mutex::new(Vec::new()));
4389        let chunks_clone = chunks.clone();
4390        let callback: StreamCallback = Arc::new(move |chunk| {
4391            chunks_clone.lock().unwrap().push(chunk);
4392        });
4393
4394        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4395
4396        assert_eq!(resp.stop_reason, StopReason::ToolUse);
4397        assert_eq!(resp.message.tool_calls.len(), 1);
4398        assert_eq!(resp.message.tool_calls[0].id, "call_abc");
4399        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
4400        assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
4401
4402        let received = chunks.lock().unwrap();
4403        // Should have tool call delta chunks and a final chunk
4404        let tool_chunks: Vec<_> = received
4405            .iter()
4406            .filter(|c| c.tool_call_delta.is_some())
4407            .collect();
4408        assert!(tool_chunks.len() >= 3); // id+name, partial args, more args
4409        assert!(received.last().unwrap().is_final);
4410    }
4411
4412    #[test]
4413    fn openai_stream_with_mixed_content_and_tools() {
4414        let driver = OpenAiCompatibleDriver::new(
4415            "key".into(),
4416            "https://api.openai.com".into(),
4417            "openai".into(),
4418        );
4419
4420        let raw = concat!(
4421            "data: {\"choices\":[{\"delta\":{\"content\":\"Sure, \"},\"index\":0}]}\n\n",
4422            "data: {\"choices\":[{\"delta\":{\"content\":\"checking.\"},\"index\":0}]}\n\n",
4423            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"search\",\"arguments\":\"{\\\"q\\\":\\\"test\\\"}\"}}]},\"index\":0}]}\n\n",
4424            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"tool_calls\"}]}\n\n",
4425            "data: [DONE]\n\n",
4426        );
4427
4428        let callback: StreamCallback = Arc::new(|_| {});
4429        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4430
4431        assert_eq!(resp.message.content, "Sure, checking.");
4432        assert_eq!(resp.stop_reason, StopReason::ToolUse);
4433        assert_eq!(resp.message.tool_calls.len(), 1);
4434        assert_eq!(resp.message.tool_calls[0].name, "search");
4435    }
4436
4437    #[test]
4438    fn openai_stream_length_stop_reason() {
4439        let driver = OpenAiCompatibleDriver::new(
4440            "key".into(),
4441            "https://api.openai.com".into(),
4442            "openai".into(),
4443        );
4444
4445        let raw = concat!(
4446            "data: {\"choices\":[{\"delta\":{\"content\":\"truncated\"},\"index\":0}]}\n\n",
4447            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"length\"}]}\n\n",
4448            "data: [DONE]\n\n",
4449        );
4450
4451        let callback: StreamCallback = Arc::new(|_| {});
4452        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4453        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
4454    }
4455
4456    // -----------------------------------------------------------------------
4457    // Ollama streaming tests
4458    // -----------------------------------------------------------------------
4459
4460    #[test]
4461    fn ollama_stream_text_only() {
4462        let driver = OllamaDriver::new(None);
4463
4464        let raw = concat!(
4465            "{\"message\":{\"role\":\"assistant\",\"content\":\"Hello\"},\"done\":false}\n",
4466            "{\"message\":{\"role\":\"assistant\",\"content\":\" world\"},\"done\":false}\n",
4467            "{\"message\":{\"role\":\"assistant\",\"content\":\"!\"},\"done\":false}\n",
4468            "{\"done\":true,\"prompt_eval_count\":15,\"eval_count\":8}\n",
4469        );
4470
4471        let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
4472            Arc::new(std::sync::Mutex::new(Vec::new()));
4473        let chunks_clone = chunks.clone();
4474        let callback: StreamCallback = Arc::new(move |chunk| {
4475            chunks_clone.lock().unwrap().push(chunk);
4476        });
4477
4478        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();
4479
4480        assert_eq!(resp.message.content, "Hello world!");
4481        assert_eq!(resp.stop_reason, StopReason::EndTurn);
4482        assert_eq!(resp.usage.input_tokens, 15);
4483        assert_eq!(resp.usage.output_tokens, 8);
4484
4485        let received = chunks.lock().unwrap();
4486        assert_eq!(received.len(), 4); // 3 content + 1 final
4487        assert_eq!(received[0].delta, "Hello");
4488        assert_eq!(received[1].delta, " world");
4489        assert_eq!(received[2].delta, "!");
4490        assert!(received[3].is_final);
4491    }
4492
4493    #[test]
4494    fn ollama_stream_with_tool_calls() {
4495        let driver = OllamaDriver::new(None);
4496
4497        let raw = concat!(
4498            "{\"message\":{\"role\":\"assistant\",\"content\":\"Let me check.\"},\"done\":false}\n",
4499            "{\"message\":{\"role\":\"assistant\",\"content\":\"\",\"tool_calls\":[{\"function\":{\"name\":\"get_weather\",\"arguments\":{\"city\":\"London\"}}}]},\"done\":true,\"prompt_eval_count\":10,\"eval_count\":5}\n",
4500        );
4501
4502        let callback: StreamCallback = Arc::new(|_| {});
4503        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();
4504
4505        assert_eq!(resp.message.content, "Let me check.");
4506        assert_eq!(resp.stop_reason, StopReason::ToolUse);
4507        assert_eq!(resp.message.tool_calls.len(), 1);
4508        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
4509        assert_eq!(resp.usage.input_tokens, 10);
4510    }
4511
4512    #[test]
4513    fn ollama_stream_strips_thinking_tags() {
4514        let driver = OllamaDriver::new(None);
4515
4516        let raw = concat!(
4517            "{\"message\":{\"role\":\"assistant\",\"content\":\"<think>hmm</think>\"},\"done\":false}\n",
4518            "{\"message\":{\"role\":\"assistant\",\"content\":\"Clean answer.\"},\"done\":false}\n",
4519            "{\"done\":true,\"prompt_eval_count\":5,\"eval_count\":3}\n",
4520        );
4521
4522        let callback: StreamCallback = Arc::new(|_| {});
4523        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();
4524        assert_eq!(resp.message.content, "Clean answer.");
4525    }
4526
4527    // -----------------------------------------------------------------------
4528    // Gemini streaming tests
4529    // -----------------------------------------------------------------------
4530
4531    #[test]
4532    fn gemini_stream_url_construction() {
4533        let driver = GeminiDriver::new("my-key".to_string(), None);
4534        let url = driver.build_stream_url("gemini-pro");
4535        assert!(url.contains("streamGenerateContent"));
4536        assert!(url.contains("alt=sse"));
4537        assert!(url.contains("key=my-key"));
4538        assert!(url.contains("models/gemini-pro"));
4539    }
4540
4541    #[test]
4542    fn gemini_stream_custom_base_url() {
4543        let driver = GeminiDriver::new(
4544            "key".to_string(),
4545            Some("https://custom.example.com".to_string()),
4546        );
4547        let url = driver.build_stream_url("gemini-pro");
4548        assert!(url.starts_with("https://custom.example.com/"));
4549        assert!(url.contains("streamGenerateContent"));
4550    }
4551
4552    // -----------------------------------------------------------------------
4553    // Callback mechanism tests
4554    // -----------------------------------------------------------------------
4555
4556    #[test]
4557    fn callback_receives_all_chunks_in_order() {
4558        let driver = OpenAiCompatibleDriver::new(
4559            "key".into(),
4560            "https://api.openai.com".into(),
4561            "openai".into(),
4562        );
4563
4564        let raw = concat!(
4565            "data: {\"choices\":[{\"delta\":{\"content\":\"A\"},\"index\":0}]}\n\n",
4566            "data: {\"choices\":[{\"delta\":{\"content\":\"B\"},\"index\":0}]}\n\n",
4567            "data: {\"choices\":[{\"delta\":{\"content\":\"C\"},\"index\":0}]}\n\n",
4568            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n",
4569            "data: [DONE]\n\n",
4570        );
4571
4572        let deltas: Arc<std::sync::Mutex<Vec<String>>> =
4573            Arc::new(std::sync::Mutex::new(Vec::new()));
4574        let deltas_clone = deltas.clone();
4575        let callback: StreamCallback = Arc::new(move |chunk| {
4576            if !chunk.delta.is_empty() || chunk.is_final {
4577                deltas_clone.lock().unwrap().push(chunk.delta.clone());
4578            }
4579        });
4580
4581        let _resp = driver.parse_openai_stream(raw, &callback).unwrap();
4582        let received = deltas.lock().unwrap();
4583        assert_eq!(received.as_slice(), &["A", "B", "C", ""]);
4584    }
4585
4586    #[test]
4587    fn openai_stream_multiple_tool_calls() {
4588        let driver = OpenAiCompatibleDriver::new(
4589            "key".into(),
4590            "https://api.openai.com".into(),
4591            "openai".into(),
4592        );
4593
4594        let raw = concat!(
4595            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"tool_a\",\"arguments\":\"{\\\"x\\\":1}\"}}]},\"index\":0}]}\n\n",
4596            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":1,\"id\":\"call_2\",\"type\":\"function\",\"function\":{\"name\":\"tool_b\",\"arguments\":\"{\\\"y\\\":2}\"}}]},\"index\":0}]}\n\n",
4597            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"tool_calls\"}]}\n\n",
4598            "data: [DONE]\n\n",
4599        );
4600
4601        let callback: StreamCallback = Arc::new(|_| {});
4602        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4603
4604        assert_eq!(resp.message.tool_calls.len(), 2);
4605        assert_eq!(resp.message.tool_calls[0].id, "call_1");
4606        assert_eq!(resp.message.tool_calls[0].name, "tool_a");
4607        assert_eq!(resp.message.tool_calls[0].input["x"], 1);
4608        assert_eq!(resp.message.tool_calls[1].id, "call_2");
4609        assert_eq!(resp.message.tool_calls[1].name, "tool_b");
4610        assert_eq!(resp.message.tool_calls[1].input["y"], 2);
4611    }
4612
4613    // -----------------------------------------------------------------------
    // Miscellaneous edge cases: chunk construction, empty streams, thinking tags
4615    // -----------------------------------------------------------------------
4616
4617    #[test]
4618    fn stream_chunk_default_values() {
4619        let chunk = StreamChunk {
4620            delta: String::new(),
4621            is_final: false,
4622            tool_call_delta: None,
4623        };
4624        assert!(chunk.delta.is_empty());
4625        assert!(!chunk.is_final);
4626        assert!(chunk.tool_call_delta.is_none());
4627    }
4628
4629    #[test]
4630    fn openai_stream_empty_input() {
4631        let driver = OpenAiCompatibleDriver::new(
4632            "key".into(),
4633            "https://api.openai.com".into(),
4634            "openai".into(),
4635        );
4636
4637        let raw = "data: [DONE]\n\n";
4638
4639        let callback: StreamCallback = Arc::new(|_| {});
4640        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4641
4642        assert_eq!(resp.message.content, "");
4643        assert!(resp.message.tool_calls.is_empty());
4644    }
4645
4646    #[test]
4647    fn ollama_stream_empty_input() {
4648        let driver = OllamaDriver::new(None);
4649        let raw = "";
4650
4651        let callback: StreamCallback = Arc::new(|_| {});
4652        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();
4653
4654        assert_eq!(resp.message.content, "");
4655        assert_eq!(resp.stop_reason, StopReason::MaxTokens); // not done
4656    }
4657
4658    #[test]
4659    fn openai_stream_strips_thinking_tags() {
4660        let driver = OpenAiCompatibleDriver::new(
4661            "key".into(),
4662            "https://api.openai.com".into(),
4663            "openai".into(),
4664        );
4665
4666        let raw = concat!(
4667            "data: {\"choices\":[{\"delta\":{\"content\":\"<think>internal</think>\"},\"index\":0}]}\n\n",
4668            "data: {\"choices\":[{\"delta\":{\"content\":\"Result\"},\"index\":0}]}\n\n",
4669            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n",
4670            "data: [DONE]\n\n",
4671        );
4672
4673        let callback: StreamCallback = Arc::new(|_| {});
4674        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4675        assert_eq!(resp.message.content, "Result");
4676    }
4677}