1use std::sync::Arc;
8
9use async_trait::async_trait;
10use hmac::{Hmac, Mac};
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use sha2::{Digest, Sha256};
14
15use punch_types::{
16 Message, ModelConfig, Provider, PunchError, PunchResult, Role, ToolCall, ToolDefinition,
17};
18
/// Why the model stopped generating. Callers use this to decide whether to
/// continue the agent loop (tool use), surface output (end turn), or flag
/// truncation (max tokens).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// The model requested one or more tool invocations.
    ToolUse,
    /// Generation was cut off by the token limit.
    MaxTokens,
    /// The provider reported an unrecognized or error stop state.
    Error,
}
36
37#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
39pub struct TokenUsage {
40 pub input_tokens: u64,
41 pub output_tokens: u64,
42}
43
44impl TokenUsage {
45 pub fn accumulate(&mut self, other: &TokenUsage) {
47 self.input_tokens += other.input_tokens;
48 self.output_tokens += other.output_tokens;
49 }
50
51 pub fn total(&self) -> u64 {
53 self.input_tokens + self.output_tokens
54 }
55}
56
/// Provider-agnostic completion request; each driver translates this into
/// its provider's wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Provider-specific model identifier.
    pub model: String,
    /// Conversation history, oldest first.
    pub messages: Vec<Message>,
    /// Tools the model may call; empty when tool use is disabled.
    #[serde(default)]
    pub tools: Vec<ToolDefinition>,
    /// Hard cap on generated tokens.
    pub max_tokens: u32,
    /// Sampling temperature; `None` means provider default.
    pub temperature: Option<f32>,
    /// Optional system prompt, delivered however the provider expects.
    pub system_prompt: Option<String>,
}
74
/// Normalized result of a completion call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// The assistant turn the model produced (text and/or tool calls).
    pub message: Message,
    /// Token counts reported by the provider (parsers default missing
    /// counts to 0).
    pub usage: TokenUsage,
    /// Why generation stopped.
    pub stop_reason: StopReason,
}
85
/// Removes `<think>`, `<thinking>`, `<reasoning>`, and `<reflection>` blocks
/// from model output.
///
/// An unterminated open tag discards everything from that tag onward. If
/// stripping would leave nothing but whitespace, the original `content` is
/// returned unchanged so the caller never ends up with an empty reply.
pub fn strip_thinking_tags(content: &str) -> String {
    let mut stripped = content.to_owned();

    for tag in ["think", "thinking", "reasoning", "reflection"] {
        let open_tag = format!("<{tag}>");
        let close_tag = format!("</{tag}>");

        while let Some(open_at) = stripped.find(&open_tag) {
            match stripped[open_at..].find(&close_tag) {
                Some(rel_close) => {
                    // Splice out the whole block, open tag through close tag.
                    let after_close = open_at + rel_close + close_tag.len();
                    let mut next = String::with_capacity(stripped.len());
                    next.push_str(&stripped[..open_at]);
                    next.push_str(&stripped[after_close..]);
                    stripped = next;
                }
                None => {
                    // No closing tag: drop the tail and stop scanning this tag.
                    stripped.truncate(open_at);
                    break;
                }
            }
        }
    }

    let cleaned = stripped.trim();
    if cleaned.is_empty() {
        content.to_owned()
    } else {
        cleaned.to_owned()
    }
}
129
/// Provider-agnostic interface implemented by every completion backend.
#[async_trait]
pub trait LlmDriver: Send + Sync + 'static {
    /// Performs a single (non-streaming) completion round-trip.
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse>;

    /// Streaming variant; the default implementation simply delegates to
    /// [`Self::complete`] for drivers without native streaming support.
    async fn stream_complete(
        &self,
        request: CompletionRequest,
    ) -> PunchResult<CompletionResponse> {
        self.complete(request).await
    }
}
148
/// Driver for the Anthropic Messages API (`POST {base_url}/v1/messages`).
pub struct AnthropicDriver {
    // HTTP client (reused across requests).
    client: Client,
    // Sent as the `x-api-key` header.
    api_key: String,
    // API origin; defaults to `https://api.anthropic.com`.
    base_url: String,
}
159
impl AnthropicDriver {
    /// Creates a driver with a fresh HTTP client; `base_url` defaults to the
    /// public Anthropic endpoint.
    pub fn new(api_key: String, base_url: Option<String>) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
        }
    }

    /// Like [`Self::new`] but reuses a caller-supplied [`Client`] (shared
    /// pools, custom TLS/proxy configuration, tests).
    pub fn with_client(client: Client, api_key: String, base_url: Option<String>) -> Self {
        Self {
            client,
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
        }
    }

    /// Translates a [`CompletionRequest`] into a Messages API JSON body.
    ///
    /// - User text is sent as a plain string `content`.
    /// - Assistant turns become content-block arrays (`text` plus
    ///   `tool_use`), padded with an empty text block if otherwise empty.
    /// - Tool results are delivered as a `user` message of `tool_result`
    ///   blocks.
    /// - `temperature`, `system`, and `tools` are attached only when set.
    fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        for msg in &request.messages {
            match msg.role {
                Role::User => {
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": msg.content,
                    }));
                }
                Role::Assistant => {
                    let mut content_blocks: Vec<serde_json::Value> = Vec::new();

                    if !msg.content.is_empty() {
                        content_blocks.push(serde_json::json!({
                            "type": "text",
                            "text": msg.content,
                        }));
                    }

                    for tc in &msg.tool_calls {
                        content_blocks.push(serde_json::json!({
                            "type": "tool_use",
                            "id": tc.id,
                            "name": tc.name,
                            "input": tc.input,
                        }));
                    }

                    // Keep the turn non-empty: pad with an empty text block.
                    if content_blocks.is_empty() {
                        content_blocks.push(serde_json::json!({
                            "type": "text",
                            "text": "",
                        }));
                    }

                    messages.push(serde_json::json!({
                        "role": "assistant",
                        "content": content_blocks,
                    }));
                }
                Role::Tool => {
                    let mut result_blocks: Vec<serde_json::Value> = Vec::new();
                    for tr in &msg.tool_results {
                        result_blocks.push(serde_json::json!({
                            "type": "tool_result",
                            "tool_use_id": tr.id,
                            "content": tr.content,
                            "is_error": tr.is_error,
                        }));
                    }
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": result_blocks,
                    }));
                }
                Role::System => {
                    // Dropped from the message list: the system prompt travels
                    // in the top-level `system` field below.
                }
            }
        }

        let tools: Vec<serde_json::Value> = request
            .tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "name": t.name,
                    "description": t.description,
                    "input_schema": t.input_schema,
                })
            })
            .collect();

        let mut body = serde_json::json!({
            "model": request.model,
            "messages": messages,
            "max_tokens": request.max_tokens,
        });

        if let Some(temp) = request.temperature {
            body["temperature"] = serde_json::json!(temp);
        }

        if let Some(ref system) = request.system_prompt {
            body["system"] = serde_json::json!(system);
        }

        if !tools.is_empty() {
            body["tools"] = serde_json::json!(tools);
        }

        body
    }

    /// Parses a successful Messages API response: text blocks are
    /// newline-joined and run through [`strip_thinking_tags`], `tool_use`
    /// blocks become [`ToolCall`]s, unknown stop reasons map to
    /// [`StopReason::Error`], and missing usage counts default to 0.
    fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let stop_reason = match body["stop_reason"].as_str() {
            Some("end_turn") => StopReason::EndTurn,
            Some("tool_use") => StopReason::ToolUse,
            Some("max_tokens") => StopReason::MaxTokens,
            _ => StopReason::Error,
        };

        let usage = TokenUsage {
            input_tokens: body["usage"]["input_tokens"].as_u64().unwrap_or(0),
            output_tokens: body["usage"]["output_tokens"].as_u64().unwrap_or(0),
        };

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        if let Some(content_array) = body["content"].as_array() {
            for block in content_array {
                match block["type"].as_str() {
                    Some("text") => {
                        if let Some(text) = block["text"].as_str() {
                            if !text_content.is_empty() {
                                text_content.push('\n');
                            }
                            text_content.push_str(text);
                        }
                    }
                    Some("tool_use") => {
                        tool_calls.push(ToolCall {
                            id: block["id"].as_str().unwrap_or_default().to_string(),
                            name: block["name"].as_str().unwrap_or_default().to_string(),
                            input: block["input"].clone(),
                        });
                    }
                    _ => {}
                }
            }
        }

        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}
339
#[async_trait]
impl LlmDriver for AnthropicDriver {
    /// POSTs to `{base_url}/v1/messages` and maps HTTP-level failures to
    /// typed errors before parsing:
    /// 429 -> [`PunchError::RateLimited`] (honors `retry-after`, default 60s),
    /// 401/403 -> [`PunchError::Auth`],
    /// other non-2xx -> [`PunchError::Provider`] with the API's own message.
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = format!("{}/v1/messages", self.base_url);
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "anthropic".to_string(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            // `retry-after` arrives in seconds; callers expect milliseconds.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap_or(60)
                * 1000;

            return Err(PunchError::RateLimited {
                provider: "anthropic".to_string(),
                retry_after_ms: retry_after,
            });
        }

        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(PunchError::Auth(
                "anthropic API key is invalid or lacks permissions".to_string(),
            ));
        }

        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: "anthropic".to_string(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "anthropic".to_string(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }
}
402
/// Driver for OpenAI-compatible Chat Completions endpoints
/// (`POST {base_url}/v1/chat/completions`). One driver type serves any
/// OpenAI-style provider; `provider_name` labels errors.
pub struct OpenAiCompatibleDriver {
    // HTTP client (reused across requests).
    client: Client,
    // Sent as `authorization: Bearer {api_key}`.
    api_key: String,
    // API origin; trailing slashes are trimmed when building the URL.
    base_url: String,
    // Human-readable provider label used in error variants.
    provider_name: String,
}
418
impl OpenAiCompatibleDriver {
    /// Creates a driver with a fresh HTTP client. `provider_name` is used
    /// only to label errors.
    pub fn new(api_key: String, base_url: String, provider_name: String) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url,
            provider_name,
        }
    }

    /// Like [`Self::new`] but reuses a caller-supplied [`Client`].
    pub fn with_client(
        client: Client,
        api_key: String,
        base_url: String,
        provider_name: String,
    ) -> Self {
        Self {
            client,
            api_key,
            base_url,
            provider_name,
        }
    }

    /// Builds a Chat Completions JSON body.
    ///
    /// - `system_prompt` is prepended as a leading `system` message.
    /// - Assistant tool calls use the function-calling form, with
    ///   `arguments` serialized as a JSON *string* (OpenAI wire format).
    /// - Each tool result becomes its own `tool` message keyed by
    ///   `tool_call_id`.
    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        if let Some(ref system) = request.system_prompt {
            messages.push(serde_json::json!({
                "role": "system",
                "content": system,
            }));
        }

        for msg in &request.messages {
            match msg.role {
                Role::System => {
                    messages.push(serde_json::json!({
                        "role": "system",
                        "content": msg.content,
                    }));
                }
                Role::User => {
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": msg.content,
                    }));
                }
                Role::Assistant => {
                    let mut m = serde_json::json!({
                        "role": "assistant",
                    });

                    // `content` is omitted entirely for tool-call-only turns.
                    if !msg.content.is_empty() {
                        m["content"] = serde_json::json!(msg.content);
                    }

                    if !msg.tool_calls.is_empty() {
                        let tc: Vec<serde_json::Value> = msg
                            .tool_calls
                            .iter()
                            .map(|tc| {
                                serde_json::json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.name,
                                        "arguments": tc.input.to_string(),
                                    },
                                })
                            })
                            .collect();
                        m["tool_calls"] = serde_json::json!(tc);
                    }

                    messages.push(m);
                }
                Role::Tool => {
                    for tr in &msg.tool_results {
                        messages.push(serde_json::json!({
                            "role": "tool",
                            "tool_call_id": tr.id,
                            "content": tr.content,
                        }));
                    }
                }
            }
        }

        let tools: Vec<serde_json::Value> = request
            .tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.input_schema,
                    },
                })
            })
            .collect();

        let mut body = serde_json::json!({
            "model": request.model,
            "messages": messages,
            "max_tokens": request.max_tokens,
        });

        if let Some(temp) = request.temperature {
            body["temperature"] = serde_json::json!(temp);
        }

        if !tools.is_empty() {
            body["tools"] = serde_json::json!(tools);
        }

        body
    }

    /// Parses the first choice of a Chat Completions response.
    ///
    /// Tool-call `arguments` strings are parsed back into JSON (falling back
    /// to `{}` on malformed payloads), the text content is run through
    /// [`strip_thinking_tags`], and the stop reason is forced to
    /// [`StopReason::ToolUse`] whenever tool calls are present, regardless
    /// of what `finish_reason` reported.
    ///
    /// # Errors
    /// Returns [`PunchError::Provider`] when `choices` is empty.
    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let choice = body["choices"]
            .get(0)
            .ok_or_else(|| PunchError::Provider {
                provider: self.provider_name.clone(),
                message: "no choices in response".to_string(),
            })?;

        let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
        let stop_reason = match finish_reason {
            "stop" => StopReason::EndTurn,
            "tool_calls" => StopReason::ToolUse,
            "length" => StopReason::MaxTokens,
            _ => StopReason::EndTurn,
        };

        let msg = &choice["message"];
        let raw_content = msg["content"].as_str().unwrap_or("");
        let content = strip_thinking_tags(raw_content);

        let mut tool_calls = Vec::new();
        if let Some(tc_array) = msg["tool_calls"].as_array() {
            for tc in tc_array {
                let id = tc["id"].as_str().unwrap_or_default().to_string();
                let name = tc["function"]["name"]
                    .as_str()
                    .unwrap_or_default()
                    .to_string();
                let args_str = tc["function"]["arguments"].as_str().unwrap_or("{}");
                let input: serde_json::Value =
                    serde_json::from_str(args_str).unwrap_or(serde_json::json!({}));

                tool_calls.push(ToolCall { id, name, input });
            }
        }

        let usage = TokenUsage {
            input_tokens: body["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
            output_tokens: body["usage"]["completion_tokens"].as_u64().unwrap_or(0),
        };

        // Override: tool calls present means tool use, even if the provider
        // reported a different finish reason.
        let stop_reason = if !tool_calls.is_empty() && stop_reason != StopReason::ToolUse {
            StopReason::ToolUse
        } else {
            stop_reason
        };

        let message = Message {
            role: Role::Assistant,
            content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}
609
#[async_trait]
impl LlmDriver for OpenAiCompatibleDriver {
    /// POSTs to `{base_url}/v1/chat/completions` (trailing slashes on the
    /// base URL are trimmed) with Bearer auth, mapping HTTP failures to the
    /// same typed errors as the other drivers: 429 -> rate limited (honors
    /// `retry-after`, default 60s), 401/403 -> auth, other non-2xx ->
    /// provider error carrying the API's message.
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = format!(
            "{}/v1/chat/completions",
            self.base_url.trim_end_matches('/')
        );
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: self.provider_name.clone(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            // `retry-after` arrives in seconds; callers expect milliseconds.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap_or(60)
                * 1000;

            return Err(PunchError::RateLimited {
                provider: self.provider_name.clone(),
                retry_after_ms: retry_after,
            });
        }

        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(PunchError::Auth(format!(
                "{} API key is invalid or lacks permissions",
                self.provider_name
            )));
        }

        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: self.provider_name.clone(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: self.provider_name.clone(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }
}
675
/// Driver for the Google Gemini `generateContent` REST API (v1beta).
pub struct GeminiDriver {
    // HTTP client (reused across requests).
    client: Client,
    // Passed as a `key=` query parameter by `build_url`.
    api_key: String,
    // API origin; defaults to `https://generativelanguage.googleapis.com`.
    base_url: String,
}
686
impl GeminiDriver {
    /// Creates a driver with a fresh HTTP client; `base_url` defaults to the
    /// public Generative Language endpoint.
    pub fn new(api_key: String, base_url: Option<String>) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url: base_url
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string()),
        }
    }

    /// Like [`Self::new`] but reuses a caller-supplied [`Client`].
    pub fn with_client(client: Client, api_key: String, base_url: Option<String>) -> Self {
        Self {
            client,
            api_key,
            base_url: base_url
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string()),
        }
    }

    /// Builds a `generateContent` request body.
    ///
    /// The body built here has no dedicated system field, so system text
    /// (`system_prompt` plus any `Role::System` messages) is folded into the
    /// next user turn — or, if never consumed, injected as a synthetic
    /// leading user turn so it is not silently lost.
    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut contents = Vec::new();
        // Pending system text waiting to be merged into a user turn.
        let mut system_text: Option<String> = request.system_prompt.clone();

        for msg in &request.messages {
            match msg.role {
                Role::System => {
                    // Accumulate; delivered with the next user message.
                    let existing = system_text.take().unwrap_or_default();
                    let combined = if existing.is_empty() {
                        msg.content.clone()
                    } else {
                        format!("{}\n{}", existing, msg.content)
                    };
                    system_text = Some(combined);
                }
                Role::User => {
                    let mut text = String::new();
                    if let Some(sys) = system_text.take()
                        && !sys.is_empty()
                    {
                        text.push_str(&sys);
                        text.push_str("\n\n");
                    }
                    text.push_str(&msg.content);
                    contents.push(serde_json::json!({
                        "role": "user",
                        "parts": [{"text": text}],
                    }));
                }
                Role::Assistant => {
                    let mut parts: Vec<serde_json::Value> = Vec::new();
                    if !msg.content.is_empty() {
                        parts.push(serde_json::json!({"text": msg.content}));
                    }
                    for tc in &msg.tool_calls {
                        parts.push(serde_json::json!({
                            "functionCall": {
                                "name": tc.name,
                                "args": tc.input,
                            }
                        }));
                    }
                    // Keep the turn non-empty: pad with an empty text part.
                    if parts.is_empty() {
                        parts.push(serde_json::json!({"text": ""}));
                    }
                    contents.push(serde_json::json!({
                        "role": "model",
                        "parts": parts,
                    }));
                }
                Role::Tool => {
                    let mut parts: Vec<serde_json::Value> = Vec::new();
                    for tr in &msg.tool_results {
                        // NOTE(review): `functionResponse.name` is filled with
                        // the synthesized call id (`tr.id`), but Gemini's API
                        // expects the *function name* here — verify tool
                        // results actually round-trip correctly.
                        parts.push(serde_json::json!({
                            "functionResponse": {
                                "name": tr.id.clone(),
                                "response": {"content": tr.content},
                            }
                        }));
                    }
                    contents.push(serde_json::json!({
                        "role": "user",
                        "parts": parts,
                    }));
                }
            }
        }

        // System text never attached to a user turn (e.g. no user messages
        // followed it): emit it as a leading user turn instead.
        if let Some(sys) = system_text
            && !sys.is_empty()
        {
            contents.insert(
                0,
                serde_json::json!({
                    "role": "user",
                    "parts": [{"text": sys}],
                }),
            );
        }

        let mut body = serde_json::json!({
            "contents": contents,
        });

        let mut gen_config = serde_json::json!({
            "maxOutputTokens": request.max_tokens,
        });
        if let Some(temp) = request.temperature {
            gen_config["temperature"] = serde_json::json!(temp);
        }
        body["generationConfig"] = gen_config;

        if !request.tools.is_empty() {
            let func_decls: Vec<serde_json::Value> = request
                .tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.input_schema,
                    })
                })
                .collect();
            body["tools"] = serde_json::json!([{"function_declarations": func_decls}]);
        }

        body
    }

    /// URL for `models/{model}:generateContent` on the configured origin.
    ///
    /// NOTE(review): the API key rides in the query string, where it can
    /// leak into logs and proxies; consider the `x-goog-api-key` header
    /// instead.
    pub fn build_url(&self, model: &str) -> String {
        format!(
            "{}/v1beta/models/{}:generateContent?key={}",
            self.base_url.trim_end_matches('/'),
            model,
            self.api_key,
        )
    }

    /// Parses the first candidate: text parts are newline-joined and run
    /// through [`strip_thinking_tags`]; `functionCall` parts become
    /// [`ToolCall`]s with synthesized `gemini-{uuid}` ids (the response
    /// carries no call ids). Any function call forces
    /// [`StopReason::ToolUse`].
    ///
    /// # Errors
    /// Returns [`PunchError::Provider`] when `candidates` is empty.
    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let candidate = body["candidates"]
            .get(0)
            .ok_or_else(|| PunchError::Provider {
                provider: "gemini".to_string(),
                message: "no candidates in response".to_string(),
            })?;

        let parts = candidate["content"]["parts"]
            .as_array()
            .cloned()
            .unwrap_or_default();

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        for part in &parts {
            if let Some(text) = part["text"].as_str() {
                if !text_content.is_empty() {
                    text_content.push('\n');
                }
                text_content.push_str(text);
            }
            if let Some(fc) = part.get("functionCall") {
                let name = fc["name"].as_str().unwrap_or_default().to_string();
                let args = fc["args"].clone();
                tool_calls.push(ToolCall {
                    id: format!("gemini-{}", uuid::Uuid::new_v4()),
                    name,
                    input: args,
                });
            }
        }

        let finish_reason = candidate["finishReason"].as_str().unwrap_or("STOP");
        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else {
            match finish_reason {
                "STOP" => StopReason::EndTurn,
                "MAX_TOKENS" => StopReason::MaxTokens,
                _ => StopReason::EndTurn,
            }
        };

        let usage = TokenUsage {
            input_tokens: body["usageMetadata"]["promptTokenCount"]
                .as_u64()
                .unwrap_or(0),
            output_tokens: body["usageMetadata"]["candidatesTokenCount"]
                .as_u64()
                .unwrap_or(0),
        };

        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}
904
#[async_trait]
impl LlmDriver for GeminiDriver {
    /// POSTs to the `generateContent` URL (auth travels as the `key=` query
    /// parameter, so no auth header is sent). 429 maps to a fixed 60s
    /// rate-limit backoff — Gemini's `retry-after` is not consulted here —
    /// 401/403 to an auth error, and other non-2xx to a provider error.
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = self.build_url(&request.model);
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            return Err(PunchError::RateLimited {
                provider: "gemini".to_string(),
                retry_after_ms: 60_000,
            });
        }

        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(PunchError::Auth(
                "Gemini API key is invalid or lacks permissions".to_string(),
            ));
        }

        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }
}
957
/// Driver for a local Ollama server (`POST {base_url}/api/chat`); no API
/// key is required.
pub struct OllamaDriver {
    // HTTP client (reused across requests).
    client: Client,
    // Server address; defaults to `http://localhost:11434`.
    base_url: String,
}
967
impl OllamaDriver {
    /// Creates a driver with a fresh HTTP client; `base_url` defaults to the
    /// standard local Ollama address.
    pub fn new(base_url: Option<String>) -> Self {
        Self {
            client: Client::new(),
            base_url: base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
        }
    }

    /// Like [`Self::new`] but reuses a caller-supplied [`Client`].
    pub fn with_client(client: Client, base_url: Option<String>) -> Self {
        Self {
            client,
            base_url: base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
        }
    }

    /// The configured server address.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Builds an `/api/chat` request body.
    ///
    /// - `stream` is `false`, so the server answers with one JSON object.
    /// - `max_tokens` maps to `options.num_predict` (only when > 0).
    /// - `think` is set to `false` to disable separate thinking output.
    /// - Assistant tool-call `arguments` stay a JSON object (unlike the
    ///   OpenAI wire format, which stringifies them); tool results carry no
    ///   `tool_call_id`.
    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        if let Some(ref system) = request.system_prompt {
            messages.push(serde_json::json!({
                "role": "system",
                "content": system,
            }));
        }

        for msg in &request.messages {
            match msg.role {
                Role::System => {
                    messages.push(serde_json::json!({
                        "role": "system",
                        "content": msg.content,
                    }));
                }
                Role::User => {
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": msg.content,
                    }));
                }
                Role::Assistant => {
                    let mut m = serde_json::json!({
                        "role": "assistant",
                        "content": msg.content,
                    });
                    if !msg.tool_calls.is_empty() {
                        let tc: Vec<serde_json::Value> = msg
                            .tool_calls
                            .iter()
                            .map(|tc| {
                                serde_json::json!({
                                    "function": {
                                        "name": tc.name,
                                        "arguments": tc.input,
                                    }
                                })
                            })
                            .collect();
                        m["tool_calls"] = serde_json::json!(tc);
                    }
                    messages.push(m);
                }
                Role::Tool => {
                    for tr in &msg.tool_results {
                        messages.push(serde_json::json!({
                            "role": "tool",
                            "content": tr.content,
                        }));
                    }
                }
            }
        }

        let mut body = serde_json::json!({
            "model": request.model,
            "messages": messages,
            "stream": false,
        });

        let mut options = serde_json::json!({});
        if let Some(temp) = request.temperature {
            options["temperature"] = serde_json::json!(temp);
        }
        if request.max_tokens > 0 {
            options["num_predict"] = serde_json::json!(request.max_tokens);
        }
        body["options"] = options;

        // Suppress model "thinking" output at the server side.
        body["think"] = serde_json::json!(false);

        if !request.tools.is_empty() {
            let tools: Vec<serde_json::Value> = request
                .tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "type": "function",
                        "function": {
                            "name": t.name,
                            "description": t.description,
                            "parameters": t.input_schema,
                        }
                    })
                })
                .collect();
            body["tools"] = serde_json::json!(tools);
        }

        body
    }

    /// Parses a non-streaming `/api/chat` response.
    ///
    /// Tool calls get synthesized `ollama-{uuid}` ids (the server returns
    /// none). Stop reason: any tool call -> [`StopReason::ToolUse`];
    /// otherwise `done` missing or `true` -> [`StopReason::EndTurn`], and
    /// `done: false` -> [`StopReason::MaxTokens`]. Usage comes from
    /// `prompt_eval_count` / `eval_count`.
    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let msg = &body["message"];
        let raw_content = msg["content"].as_str().unwrap_or("");
        let content = strip_thinking_tags(raw_content);

        let mut tool_calls = Vec::new();
        if let Some(tc_array) = msg["tool_calls"].as_array() {
            for tc in tc_array {
                let name = tc["function"]["name"]
                    .as_str()
                    .unwrap_or_default()
                    .to_string();
                let input = tc["function"]["arguments"].clone();
                tool_calls.push(ToolCall {
                    id: format!("ollama-{}", uuid::Uuid::new_v4()),
                    name,
                    input,
                });
            }
        }

        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else if body["done"].as_bool().unwrap_or(true) {
            StopReason::EndTurn
        } else {
            StopReason::MaxTokens
        };

        let usage = TokenUsage {
            input_tokens: body["prompt_eval_count"].as_u64().unwrap_or(0),
            output_tokens: body["eval_count"].as_u64().unwrap_or(0),
        };

        let message = Message {
            role: Role::Assistant,
            content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}
1140
#[async_trait]
impl LlmDriver for OllamaDriver {
    /// POSTs to `{base_url}/api/chat`. No auth or rate-limit handling —
    /// Ollama is a local server; its errors arrive as a top-level `error`
    /// string, which is surfaced as [`PunchError::Provider`].
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "ollama".to_string(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();
        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: "ollama".to_string(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["error"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "ollama".to_string(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }
}
1179
/// Driver for the AWS Bedrock `Converse` API, authenticated with Signature
/// Version 4 (see `sign_request`).
pub struct BedrockDriver {
    // HTTP client (reused across requests).
    client: Client,
    // AWS access key id (SigV4 credential).
    access_key: String,
    // AWS secret access key (SigV4 signing key material).
    secret_key: String,
    // AWS region, used in both the endpoint host and the credential scope.
    region: String,
}
1191
1192impl BedrockDriver {
1193 pub fn new(access_key: String, secret_key: String, region: String) -> Self {
1195 Self {
1196 client: Client::new(),
1197 access_key,
1198 secret_key,
1199 region,
1200 }
1201 }
1202
1203 pub fn with_client(
1205 client: Client,
1206 access_key: String,
1207 secret_key: String,
1208 region: String,
1209 ) -> Self {
1210 Self {
1211 client,
1212 access_key,
1213 secret_key,
1214 region,
1215 }
1216 }
1217
1218 pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
1220 let mut messages = Vec::new();
1221
1222 for msg in &request.messages {
1223 match msg.role {
1224 Role::User => {
1225 messages.push(serde_json::json!({
1226 "role": "user",
1227 "content": [{"text": msg.content}],
1228 }));
1229 }
1230 Role::Assistant => {
1231 let mut content: Vec<serde_json::Value> = Vec::new();
1232 if !msg.content.is_empty() {
1233 content.push(serde_json::json!({"text": msg.content}));
1234 }
1235 for tc in &msg.tool_calls {
1236 content.push(serde_json::json!({
1237 "toolUse": {
1238 "toolUseId": tc.id,
1239 "name": tc.name,
1240 "input": tc.input,
1241 }
1242 }));
1243 }
1244 if content.is_empty() {
1245 content.push(serde_json::json!({"text": ""}));
1246 }
1247 messages.push(serde_json::json!({
1248 "role": "assistant",
1249 "content": content,
1250 }));
1251 }
1252 Role::Tool => {
1253 let mut content: Vec<serde_json::Value> = Vec::new();
1254 for tr in &msg.tool_results {
1255 content.push(serde_json::json!({
1256 "toolResult": {
1257 "toolUseId": tr.id,
1258 "content": [{"text": tr.content}],
1259 "status": if tr.is_error { "error" } else { "success" },
1260 }
1261 }));
1262 }
1263 messages.push(serde_json::json!({
1264 "role": "user",
1265 "content": content,
1266 }));
1267 }
1268 Role::System => {
1269 }
1271 }
1272 }
1273
1274 let mut body = serde_json::json!({
1275 "messages": messages,
1276 });
1277
1278 let mut inference_config = serde_json::json!({
1279 "maxTokens": request.max_tokens,
1280 });
1281 if let Some(temp) = request.temperature {
1282 inference_config["temperature"] = serde_json::json!(temp);
1283 }
1284 body["inferenceConfig"] = inference_config;
1285
1286 if let Some(ref system) = request.system_prompt {
1287 body["system"] = serde_json::json!([{"text": system}]);
1288 }
1289
1290 if !request.tools.is_empty() {
1291 let tool_config: Vec<serde_json::Value> = request
1292 .tools
1293 .iter()
1294 .map(|t| {
1295 serde_json::json!({
1296 "toolSpec": {
1297 "name": t.name,
1298 "description": t.description,
1299 "inputSchema": {"json": t.input_schema},
1300 }
1301 })
1302 })
1303 .collect();
1304 body["toolConfig"] = serde_json::json!({"tools": tool_config});
1305 }
1306
1307 body
1308 }
1309
1310 pub fn build_url(&self, model_id: &str) -> String {
1312 format!(
1313 "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
1314 self.region, model_id,
1315 )
1316 }
1317
1318 pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
1320 let content = body["output"]["message"]["content"]
1321 .as_array()
1322 .cloned()
1323 .unwrap_or_default();
1324
1325 let mut text_content = String::new();
1326 let mut tool_calls = Vec::new();
1327
1328 for block in &content {
1329 if let Some(text) = block["text"].as_str() {
1330 if !text_content.is_empty() {
1331 text_content.push('\n');
1332 }
1333 text_content.push_str(text);
1334 }
1335 if let Some(tu) = block.get("toolUse") {
1336 tool_calls.push(ToolCall {
1337 id: tu["toolUseId"].as_str().unwrap_or_default().to_string(),
1338 name: tu["name"].as_str().unwrap_or_default().to_string(),
1339 input: tu["input"].clone(),
1340 });
1341 }
1342 }
1343
1344 let stop_reason_str = body["stopReason"].as_str().unwrap_or("end_turn");
1345 let stop_reason = if !tool_calls.is_empty() {
1346 StopReason::ToolUse
1347 } else {
1348 match stop_reason_str {
1349 "end_turn" => StopReason::EndTurn,
1350 "tool_use" => StopReason::ToolUse,
1351 "max_tokens" => StopReason::MaxTokens,
1352 _ => StopReason::EndTurn,
1353 }
1354 };
1355
1356 let usage = TokenUsage {
1357 input_tokens: body["usage"]["inputTokens"].as_u64().unwrap_or(0),
1358 output_tokens: body["usage"]["outputTokens"].as_u64().unwrap_or(0),
1359 };
1360
1361 let text_content = strip_thinking_tags(&text_content);
1363
1364 let message = Message {
1365 role: Role::Assistant,
1366 content: text_content,
1367 tool_calls,
1368 tool_results: Vec::new(),
1369 timestamp: chrono::Utc::now(),
1370 };
1371
1372 Ok(CompletionResponse {
1373 message,
1374 usage,
1375 stop_reason,
1376 })
1377 }
1378
1379 pub fn sign_request(
1383 &self,
1384 method: &str,
1385 url: &str,
1386 headers: &[(String, String)],
1387 payload: &[u8],
1388 timestamp: &str, ) -> PunchResult<String> {
1390 let date = ×tamp[..8]; let service = "bedrock";
1392
1393 let parsed = url::Url::parse(url).map_err(|e| PunchError::Provider {
1395 provider: "bedrock".to_string(),
1396 message: format!("invalid URL: {e}"),
1397 })?;
1398 let host = parsed.host_str().unwrap_or("");
1399 let path = parsed.path();
1400
1401 let payload_hash = hex_sha256(payload);
1403
1404 let mut signed_header_names: Vec<String> =
1405 headers.iter().map(|(k, _)| k.to_lowercase()).collect();
1406 signed_header_names.push("host".to_string());
1407 signed_header_names.push("x-amz-date".to_string());
1408 signed_header_names.sort();
1409 signed_header_names.dedup();
1410
1411 let mut header_map: Vec<(String, String)> = headers
1412 .iter()
1413 .map(|(k, v)| (k.to_lowercase(), v.trim().to_string()))
1414 .collect();
1415 header_map.push(("host".to_string(), host.to_string()));
1416 header_map.push(("x-amz-date".to_string(), timestamp.to_string()));
1417 header_map.sort_by(|a, b| a.0.cmp(&b.0));
1418 header_map.dedup_by(|a, b| a.0 == b.0);
1419
1420 let canonical_headers: String = header_map
1421 .iter()
1422 .map(|(k, v)| format!("{}:{}\n", k, v))
1423 .collect();
1424
1425 let signed_headers = signed_header_names.join(";");
1426
1427 let canonical_request = format!(
1428 "{}\n{}\n\n{}\n{}\n{}",
1429 method, path, canonical_headers, signed_headers, payload_hash,
1430 );
1431
1432 let credential_scope = format!("{}/{}/{}/aws4_request", date, self.region, service);
1434 let string_to_sign = format!(
1435 "AWS4-HMAC-SHA256\n{}\n{}\n{}",
1436 timestamp,
1437 credential_scope,
1438 hex_sha256(canonical_request.as_bytes()),
1439 );
1440
1441 let k_date = hmac_sha256(
1443 format!("AWS4{}", self.secret_key).as_bytes(),
1444 date.as_bytes(),
1445 );
1446 let k_region = hmac_sha256(&k_date, self.region.as_bytes());
1447 let k_service = hmac_sha256(&k_region, service.as_bytes());
1448 let k_signing = hmac_sha256(&k_service, b"aws4_request");
1449
1450 let signature = hex_encode(&hmac_sha256(&k_signing, string_to_sign.as_bytes()));
1452
1453 Ok(format!(
1455 "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
1456 self.access_key, credential_scope, signed_headers, signature,
1457 ))
1458 }
1459}
1460
1461fn hex_sha256(data: &[u8]) -> String {
1463 let mut hasher = Sha256::new();
1464 hasher.update(data);
1465 hex_encode(hasher.finalize().as_slice())
1466}
1467
1468fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
1470 type HmacSha256 = Hmac<Sha256>;
1471 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size");
1472 mac.update(data);
1473 mac.finalize().into_bytes().to_vec()
1474}
1475
/// Encodes `bytes` as a lowercase hexadecimal string (two digits per byte).
fn hex_encode(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    // Pre-size the output: every input byte yields exactly two hex digits.
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(HEX_DIGITS[usize::from(byte >> 4)] as char);
        out.push(HEX_DIGITS[usize::from(byte & 0x0f)] as char);
    }
    out
}
1480
1481#[async_trait]
1482impl LlmDriver for BedrockDriver {
1483 async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
1484 let url = self.build_url(&request.model);
1485 let body = self.build_request_body(&request);
1486 let payload = serde_json::to_vec(&body).map_err(|e| PunchError::Provider {
1487 provider: "bedrock".to_string(),
1488 message: format!("failed to serialize request: {e}"),
1489 })?;
1490
1491 let timestamp = chrono::Utc::now().format("%Y%m%dT%H%M%SZ").to_string();
1492
1493 let auth_header = self.sign_request(
1494 "POST",
1495 &url,
1496 &[("content-type".to_string(), "application/json".to_string())],
1497 &payload,
1498 ×tamp,
1499 )?;
1500
1501 let parsed_url = url::Url::parse(&url).map_err(|e| PunchError::Provider {
1502 provider: "bedrock".to_string(),
1503 message: format!("invalid URL: {e}"),
1504 })?;
1505 let host = parsed_url.host_str().unwrap_or_default().to_string();
1506
1507 let response = self
1508 .client
1509 .post(&url)
1510 .header("content-type", "application/json")
1511 .header("host", &host)
1512 .header("x-amz-date", ×tamp)
1513 .header("authorization", &auth_header)
1514 .body(payload)
1515 .send()
1516 .await
1517 .map_err(|e| PunchError::Provider {
1518 provider: "bedrock".to_string(),
1519 message: format!("request failed: {e}"),
1520 })?;
1521
1522 let status = response.status();
1523
1524 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
1525 return Err(PunchError::RateLimited {
1526 provider: "bedrock".to_string(),
1527 retry_after_ms: 60_000,
1528 });
1529 }
1530
1531 if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
1532 return Err(PunchError::Auth(
1533 "AWS Bedrock credentials are invalid or lack permissions".to_string(),
1534 ));
1535 }
1536
1537 let response_body: serde_json::Value =
1538 response.json().await.map_err(|e| PunchError::Provider {
1539 provider: "bedrock".to_string(),
1540 message: format!("failed to parse response: {e}"),
1541 })?;
1542
1543 if !status.is_success() {
1544 let error_msg = response_body["message"]
1545 .as_str()
1546 .unwrap_or("unknown error");
1547 return Err(PunchError::Provider {
1548 provider: "bedrock".to_string(),
1549 message: format!("API error ({}): {}", status, error_msg),
1550 });
1551 }
1552
1553 self.parse_response(&response_body)
1554 }
1555}
1556
/// Driver for Azure-hosted OpenAI deployments.
///
/// Delegates request building and response parsing to an inner
/// `OpenAiCompatibleDriver`; only the URL layout (resource/deployment/
/// api-version) and the `api-key` auth header differ from stock OpenAI.
pub struct AzureOpenAiDriver {
    // Handles the OpenAI-format request body and response parsing.
    inner: OpenAiCompatibleDriver,
    // Azure resource name — the `<resource>.openai.azure.com` subdomain.
    resource: String,
    // Deployment name addressed in the request path.
    deployment: String,
    // Value of the `api-version` query parameter.
    api_version: String,
}
1571
1572impl AzureOpenAiDriver {
1573 pub fn new(
1580 api_key: String,
1581 resource: String,
1582 deployment: String,
1583 api_version: Option<String>,
1584 ) -> Self {
1585 let base_url = format!("https://{}.openai.azure.com", resource);
1586 Self {
1587 inner: OpenAiCompatibleDriver::new(api_key, base_url, "azure_openai".to_string()),
1588 resource,
1589 deployment,
1590 api_version: api_version.unwrap_or_else(|| "2024-02-01".to_string()),
1591 }
1592 }
1593
1594 pub fn with_client(
1596 client: Client,
1597 api_key: String,
1598 resource: String,
1599 deployment: String,
1600 api_version: Option<String>,
1601 ) -> Self {
1602 let base_url = format!("https://{}.openai.azure.com", resource);
1603 Self {
1604 inner: OpenAiCompatibleDriver::with_client(
1605 client,
1606 api_key,
1607 base_url,
1608 "azure_openai".to_string(),
1609 ),
1610 resource,
1611 deployment,
1612 api_version: api_version.unwrap_or_else(|| "2024-02-01".to_string()),
1613 }
1614 }
1615
1616 pub fn build_url(&self) -> String {
1618 format!(
1619 "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
1620 self.resource, self.deployment, self.api_version,
1621 )
1622 }
1623
1624 pub fn resource(&self) -> &str {
1626 &self.resource
1627 }
1628
1629 pub fn deployment(&self) -> &str {
1631 &self.deployment
1632 }
1633
1634 pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
1636 self.inner.build_request_body(request)
1637 }
1638
1639 pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
1641 self.inner.parse_response(body)
1642 }
1643}
1644
1645#[async_trait]
1646impl LlmDriver for AzureOpenAiDriver {
1647 async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
1648 let url = self.build_url();
1649 let body = self.inner.build_request_body(&request);
1650
1651 let response = self
1652 .inner
1653 .client
1654 .post(&url)
1655 .header("api-key", &self.inner.api_key)
1656 .header("content-type", "application/json")
1657 .json(&body)
1658 .send()
1659 .await
1660 .map_err(|e| PunchError::Provider {
1661 provider: "azure_openai".to_string(),
1662 message: format!("request failed: {e}"),
1663 })?;
1664
1665 let status = response.status();
1666
1667 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
1668 let retry_after = response
1669 .headers()
1670 .get("retry-after")
1671 .and_then(|v| v.to_str().ok())
1672 .and_then(|s| s.parse::<u64>().ok())
1673 .unwrap_or(60)
1674 * 1000;
1675
1676 return Err(PunchError::RateLimited {
1677 provider: "azure_openai".to_string(),
1678 retry_after_ms: retry_after,
1679 });
1680 }
1681
1682 if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
1683 return Err(PunchError::Auth(
1684 "Azure OpenAI API key is invalid or lacks permissions".to_string(),
1685 ));
1686 }
1687
1688 let response_body: serde_json::Value =
1689 response.json().await.map_err(|e| PunchError::Provider {
1690 provider: "azure_openai".to_string(),
1691 message: format!("failed to parse response: {e}"),
1692 })?;
1693
1694 if !status.is_success() {
1695 let error_msg = response_body["error"]["message"]
1696 .as_str()
1697 .unwrap_or("unknown error");
1698 return Err(PunchError::Provider {
1699 provider: "azure_openai".to_string(),
1700 message: format!("API error ({}): {}", status, error_msg),
1701 });
1702 }
1703
1704 self.inner.parse_response(&response_body)
1705 }
1706}
1707
1708fn default_base_url(provider: &Provider) -> &'static str {
1714 match provider {
1715 Provider::Anthropic => "https://api.anthropic.com",
1716 Provider::OpenAI => "https://api.openai.com",
1717 Provider::Google => "https://generativelanguage.googleapis.com",
1718 Provider::Groq => "https://api.groq.com/openai",
1719 Provider::DeepSeek => "https://api.deepseek.com",
1720 Provider::Ollama => "http://localhost:11434",
1721 Provider::Mistral => "https://api.mistral.ai",
1722 Provider::Together => "https://api.together.xyz",
1723 Provider::Fireworks => "https://api.fireworks.ai/inference",
1724 Provider::Cerebras => "https://api.cerebras.ai",
1725 Provider::XAI => "https://api.x.ai",
1726 Provider::Cohere => "https://api.cohere.ai",
1727 Provider::Bedrock => "https://bedrock-runtime.us-east-1.amazonaws.com",
1728 Provider::AzureOpenAi => "",
1729 Provider::Custom(_) => "",
1730 }
1731}
1732
/// Builds an [`LlmDriver`] for `config`, letting each driver create its own
/// HTTP client.
///
/// Thin wrapper over [`create_driver_with_client`] with no shared client.
pub fn create_driver(config: &ModelConfig) -> PunchResult<Arc<dyn LlmDriver>> {
    create_driver_with_client(config, None)
}
1745
/// Builds an [`LlmDriver`] for `config`, optionally reusing `shared_client`
/// as the HTTP client.
///
/// The API key is read from the environment variable named by
/// `config.api_key_env`; `None` yields an empty key (for keyless providers
/// such as local Ollama). The base URL falls back to the provider default
/// when `config.base_url` is unset.
///
/// # Errors
/// Returns [`PunchError::Auth`] when the configured env var is not set.
pub fn create_driver_with_client(
    config: &ModelConfig,
    shared_client: Option<&Client>,
) -> PunchResult<Arc<dyn LlmDriver>> {
    let api_key = match &config.api_key_env {
        Some(env_var) => std::env::var(env_var).map_err(|_| {
            PunchError::Auth(format!(
                "environment variable '{}' not set for {} driver",
                env_var, config.provider
            ))
        })?,
        None => {
            // No key configured: pass an empty string through.
            String::new()
        }
    };

    let base_url = config
        .base_url
        .clone()
        .unwrap_or_else(|| default_base_url(&config.provider).to_string());

    match &config.provider {
        Provider::Anthropic => {
            if let Some(client) = shared_client {
                Ok(Arc::new(AnthropicDriver::with_client(
                    client.clone(),
                    api_key,
                    Some(base_url),
                )))
            } else {
                Ok(Arc::new(AnthropicDriver::new(api_key, Some(base_url))))
            }
        }
        Provider::Google => {
            if let Some(client) = shared_client {
                Ok(Arc::new(GeminiDriver::with_client(
                    client.clone(),
                    api_key,
                    Some(base_url),
                )))
            } else {
                Ok(Arc::new(GeminiDriver::new(api_key, Some(base_url))))
            }
        }
        Provider::Ollama => {
            // Ollama takes no API key — only the endpoint.
            if let Some(client) = shared_client {
                Ok(Arc::new(OllamaDriver::with_client(
                    client.clone(),
                    Some(base_url),
                )))
            } else {
                Ok(Arc::new(OllamaDriver::new(Some(base_url))))
            }
        }
        Provider::Bedrock => {
            // Credentials may come as "ACCESS:SECRET" in the configured env
            // var, or fall back to the standard AWS_* environment variables.
            let (access_key, secret_key) = if api_key.contains(':') {
                let parts: Vec<&str> = api_key.splitn(2, ':').collect();
                (parts[0].to_string(), parts[1].to_string())
            } else {
                let ak = std::env::var("AWS_ACCESS_KEY_ID").unwrap_or(api_key);
                let sk = std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_default();
                (ak, sk)
            };
            // Region is extracted from a `bedrock-runtime.<region>.` base
            // URL when possible; otherwise us-east-1.
            let region = if base_url.contains("bedrock-runtime.") {
                base_url
                    .trim_start_matches("https://bedrock-runtime.")
                    .split('.')
                    .next()
                    .unwrap_or("us-east-1")
                    .to_string()
            } else {
                "us-east-1".to_string()
            };
            if let Some(client) = shared_client {
                Ok(Arc::new(BedrockDriver::with_client(
                    client.clone(),
                    access_key,
                    secret_key,
                    region,
                )))
            } else {
                Ok(Arc::new(BedrockDriver::new(access_key, secret_key, region)))
            }
        }
        Provider::AzureOpenAi => {
            // `resource` is the subdomain of an *.openai.azure.com base URL;
            // any other base_url value is used verbatim as the resource.
            let resource = if base_url.contains(".openai.azure.com") {
                base_url
                    .trim_start_matches("https://")
                    .split('.')
                    .next()
                    .unwrap_or("default")
                    .to_string()
            } else {
                base_url.clone()
            };
            // The configured model name doubles as the Azure deployment name.
            let deployment = config.model.clone();
            if let Some(client) = shared_client {
                Ok(Arc::new(AzureOpenAiDriver::with_client(
                    client.clone(),
                    api_key,
                    resource,
                    deployment,
                    None,
                )))
            } else {
                Ok(Arc::new(AzureOpenAiDriver::new(
                    api_key,
                    resource,
                    deployment,
                    None,
                )))
            }
        }
        provider => {
            // Every remaining provider speaks the OpenAI-compatible chat API.
            let name = provider.to_string();
            if let Some(client) = shared_client {
                Ok(Arc::new(OpenAiCompatibleDriver::with_client(
                    client.clone(),
                    api_key,
                    base_url,
                    name,
                )))
            } else {
                Ok(Arc::new(OpenAiCompatibleDriver::new(
                    api_key, base_url, name,
                )))
            }
        }
    }
}
1883
1884#[cfg(test)]
1889mod tests {
1890 use super::*;
1891 use punch_types::ToolCategory;
1892
    // Minimal fixture: one user message, a system prompt, no tools.
    fn simple_request() -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            messages: vec![Message::new(Role::User, "Hello")],
            tools: Vec::new(),
            max_tokens: 4096,
            temperature: Some(0.7),
            system_prompt: Some("You are helpful.".to_string()),
        }
    }

    // Fixture carrying a single `get_weather` tool definition.
    fn request_with_tools() -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            messages: vec![Message::new(Role::User, "Use the tool")],
            tools: vec![ToolDefinition {
                name: "get_weather".to_string(),
                description: "Get weather for a city".to_string(),
                input_schema: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"}
                    }
                }),
                category: ToolCategory::Web,
            }],
            max_tokens: 4096,
            temperature: Some(0.7),
            system_prompt: None,
        }
    }
1926
    // --- GeminiDriver ---

    // System prompt and user text end up merged into the first `contents`
    // entry's text part.
    #[test]
    fn gemini_request_formatting() {
        let driver = GeminiDriver::new("test-key".to_string(), None);
        let body = driver.build_request_body(&simple_request());

        let contents = body["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 1);
        let first_text = contents[0]["parts"][0]["text"].as_str().unwrap();
        assert!(first_text.contains("You are helpful."));
        assert!(first_text.contains("Hello"));
        assert_eq!(contents[0]["role"].as_str().unwrap(), "user");

        assert_eq!(body["generationConfig"]["maxOutputTokens"], 4096);
        assert!((body["generationConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
    }

    // A plain text candidate parses to EndTurn with usage taken from
    // `usageMetadata`.
    #[test]
    fn gemini_response_parsing() {
        let driver = GeminiDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hello there!"}],
                    "role": "model"
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5
            }
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Hello there!");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert_eq!(resp.usage.input_tokens, 10);
        assert_eq!(resp.usage.output_tokens, 5);
    }

    // A Role::System message is folded into the single user turn instead of
    // producing its own `contents` entry.
    #[test]
    fn gemini_role_mapping_system_prepended() {
        let driver = GeminiDriver::new("test-key".to_string(), None);
        let req = CompletionRequest {
            model: "gemini-pro".to_string(),
            messages: vec![
                Message::new(Role::System, "Be concise."),
                Message::new(Role::User, "Hi"),
            ],
            tools: Vec::new(),
            max_tokens: 1024,
            temperature: None,
            system_prompt: None,
        };
        let body = driver.build_request_body(&req);
        let contents = body["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 1);
        let text = contents[0]["parts"][0]["text"].as_str().unwrap();
        assert!(text.contains("Be concise."));
        assert!(text.contains("Hi"));
    }

    // `functionCall` parts become ToolCalls and force StopReason::ToolUse.
    #[test]
    fn gemini_function_call_parsing() {
        let driver = GeminiDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Let me check the weather."},
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "London"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 15,
                "candidatesTokenCount": 8
            }
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Let me check the weather.");
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
        assert_eq!(resp.message.tool_calls[0].input["city"], "London");
    }

    // Gemini authenticates via a `key=` query parameter, not a header.
    #[test]
    fn gemini_api_key_in_url() {
        let driver = GeminiDriver::new("my-secret-key".to_string(), None);
        let url = driver.build_url("gemini-pro");
        assert!(url.contains("key=my-secret-key"));
        assert!(url.contains("models/gemini-pro:generateContent"));
    }
2036
    // --- OllamaDriver ---

    // System prompt becomes a separate leading "system" message; temperature
    // is nested under `options`.
    #[test]
    fn ollama_request_formatting() {
        let driver = OllamaDriver::new(None);
        let body = driver.build_request_body(&simple_request());

        assert_eq!(body["model"], "test-model");
        assert_eq!(body["stream"], false);
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0]["role"], "system");
        assert_eq!(messages[0]["content"], "You are helpful.");
        assert_eq!(messages[1]["role"], "user");
        assert_eq!(messages[1]["content"], "Hello");
        assert!((body["options"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
    }

    // Token counts map from prompt_eval_count / eval_count.
    #[test]
    fn ollama_response_parsing() {
        let driver = OllamaDriver::new(None);
        let response_body = serde_json::json!({
            "message": {
                "role": "assistant",
                "content": "Hi there!"
            },
            "done": true,
            "prompt_eval_count": 20,
            "eval_count": 10
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Hi there!");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert_eq!(resp.usage.input_tokens, 20);
        assert_eq!(resp.usage.output_tokens, 10);
    }

    // Default endpoint is the local Ollama server.
    #[test]
    fn ollama_default_endpoint() {
        let driver = OllamaDriver::new(None);
        assert_eq!(driver.base_url(), "http://localhost:11434");
    }

    // An explicit endpoint overrides the default.
    #[test]
    fn ollama_custom_endpoint() {
        let driver = OllamaDriver::new(Some("http://myhost:9999".to_string()));
        assert_eq!(driver.base_url(), "http://myhost:9999");
    }
2089
    // --- BedrockDriver ---

    // Converse-API body shape: content blocks, inferenceConfig, top-level
    // `system` array.
    #[test]
    fn bedrock_request_formatting() {
        let driver = BedrockDriver::new(
            "TESTKEY".to_string(),
            "testsecret".to_string(),
            "us-west-2".to_string(),
        );
        let body = driver.build_request_body(&simple_request());

        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0]["role"], "user");
        assert_eq!(messages[0]["content"][0]["text"], "Hello");

        assert_eq!(body["inferenceConfig"]["maxTokens"], 4096);
        assert!((body["inferenceConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
        assert_eq!(body["system"][0]["text"], "You are helpful.");
    }

    // SigV4 header: credential scope is date/region/service/aws4_request,
    // and host + x-amz-date are always in the signed set.
    #[test]
    fn bedrock_sigv4_canonical_request() {
        let driver = BedrockDriver::new(
            "TESTACCESS1234567890".to_string(),
            "TestSecretKeyValue1234567890abcdefghijk".to_string(),
            "us-east-1".to_string(),
        );

        let payload = b"{}";
        let timestamp = "20260313T120000Z";

        let auth = driver
            .sign_request(
                "POST",
                "https://bedrock-runtime.us-east-1.amazonaws.com/model/test/converse",
                &[("content-type".to_string(), "application/json".to_string())],
                payload,
                timestamp,
            )
            .unwrap();

        assert!(auth.starts_with(
            "AWS4-HMAC-SHA256 Credential=TESTACCESS1234567890/20260313/us-east-1/bedrock/aws4_request"
        ));
        assert!(auth.contains("SignedHeaders=content-type;host;x-amz-date"));
        assert!(auth.contains("Signature="));
    }

    // Converse-API response: text lives under output.message.content[].text.
    #[test]
    fn bedrock_response_parsing() {
        let driver = BedrockDriver::new(
            "key".to_string(),
            "secret".to_string(),
            "us-east-1".to_string(),
        );
        let response_body = serde_json::json!({
            "output": {
                "message": {
                    "role": "assistant",
                    "content": [{"text": "The answer is 42."}]
                }
            },
            "stopReason": "end_turn",
            "usage": {
                "inputTokens": 100,
                "outputTokens": 50
            }
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "The answer is 42.");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert_eq!(resp.usage.input_tokens, 100);
        assert_eq!(resp.usage.output_tokens, 50);
    }
2168
    // --- AzureOpenAiDriver ---

    // URL embeds resource, deployment, and the default api-version.
    #[test]
    fn azure_openai_url_construction() {
        let driver = AzureOpenAiDriver::new(
            "my-azure-key".to_string(),
            "myresource".to_string(),
            "gpt-4-deployment".to_string(),
            None,
        );
        let url = driver.build_url();
        assert_eq!(
            url,
            "https://myresource.openai.azure.com/openai/deployments/gpt-4-deployment/chat/completions?api-version=2024-02-01"
        );
    }

    // An explicit api_version overrides the default.
    #[test]
    fn azure_openai_custom_api_version() {
        let driver = AzureOpenAiDriver::new(
            "key".to_string(),
            "res".to_string(),
            "dep".to_string(),
            Some("2024-06-01".to_string()),
        );
        let url = driver.build_url();
        assert!(url.contains("api-version=2024-06-01"));
    }

    // Request body is delegated to the OpenAI-compatible builder unchanged.
    #[test]
    fn azure_openai_request_formatting() {
        let driver = AzureOpenAiDriver::new(
            "key".to_string(),
            "res".to_string(),
            "dep".to_string(),
            None,
        );
        let body = driver.build_request_body(&simple_request());
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0]["role"], "system");
        assert_eq!(messages[1]["role"], "user");
        assert_eq!(body["model"], "test-model");
    }

    // Accessors expose the configured resource and deployment.
    #[test]
    fn azure_openai_resource_and_deployment() {
        let driver = AzureOpenAiDriver::new(
            "key".to_string(),
            "my-resource".to_string(),
            "my-deploy".to_string(),
            None,
        );
        assert_eq!(driver.resource(), "my-resource");
        assert_eq!(driver.deployment(), "my-deploy");
    }
2229
    // --- create_driver dispatch ---

    // Ollama needs no API key, so no env var is required.
    #[test]
    fn create_driver_dispatches_ollama() {
        let config = ModelConfig {
            provider: Provider::Ollama,
            model: "llama3".to_string(),
            api_key_env: None,
            base_url: None,
            max_tokens: None,
            temperature: None,
        };
        let driver = create_driver(&config);
        assert!(driver.is_ok());
    }

    // Key is read from the named env var (set_var is unsafe in edition 2024).
    #[test]
    fn create_driver_dispatches_gemini() {
        unsafe { std::env::set_var("TEST_GEMINI_KEY_DISPATCH", "fake-key") };
        let config = ModelConfig {
            provider: Provider::Google,
            model: "gemini-pro".to_string(),
            api_key_env: Some("TEST_GEMINI_KEY_DISPATCH".to_string()),
            base_url: None,
            max_tokens: None,
            temperature: None,
        };
        let driver = create_driver(&config);
        assert!(driver.is_ok());
        unsafe { std::env::remove_var("TEST_GEMINI_KEY_DISPATCH") };
    }

    // Bedrock accepts "ACCESS:SECRET" packed into a single env var.
    #[test]
    fn create_driver_dispatches_bedrock() {
        unsafe { std::env::set_var("TEST_BEDROCK_KEY_DISPATCH", "TESTKEY:TESTSECRET") };
        let config = ModelConfig {
            provider: Provider::Bedrock,
            model: "anthropic.claude-v2".to_string(),
            api_key_env: Some("TEST_BEDROCK_KEY_DISPATCH".to_string()),
            base_url: None,
            max_tokens: None,
            temperature: None,
        };
        let driver = create_driver(&config);
        assert!(driver.is_ok());
        unsafe { std::env::remove_var("TEST_BEDROCK_KEY_DISPATCH") };
    }

    // Azure requires an explicit base URL; the resource is parsed from it.
    #[test]
    fn create_driver_dispatches_azure_openai() {
        unsafe { std::env::set_var("TEST_AZURE_KEY_DISPATCH", "azure-key") };
        let config = ModelConfig {
            provider: Provider::AzureOpenAi,
            model: "gpt-4".to_string(),
            api_key_env: Some("TEST_AZURE_KEY_DISPATCH".to_string()),
            base_url: Some("https://myres.openai.azure.com".to_string()),
            max_tokens: None,
            temperature: None,
        };
        let driver = create_driver(&config);
        assert!(driver.is_ok());
        unsafe { std::env::remove_var("TEST_AZURE_KEY_DISPATCH") };
    }
2300
    // Gemini wraps tools in `function_declarations`.
    #[test]
    fn gemini_tools_in_request() {
        let driver = GeminiDriver::new("key".to_string(), None);
        let body = driver.build_request_body(&request_with_tools());

        let tools = body["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 1);
        let func_decls = tools[0]["function_declarations"].as_array().unwrap();
        assert_eq!(func_decls.len(), 1);
        assert_eq!(func_decls[0]["name"], "get_weather");
    }

    // Ollama uses the OpenAI-style `{"type": "function", "function": {...}}`
    // tool shape.
    #[test]
    fn ollama_tools_in_request() {
        let driver = OllamaDriver::new(None);
        let body = driver.build_request_body(&request_with_tools());

        let tools = body["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["type"], "function");
        assert_eq!(tools[0]["function"]["name"], "get_weather");
    }

    // Bedrock URL embeds the region and model id on the /converse path.
    #[test]
    fn bedrock_url_construction() {
        let driver = BedrockDriver::new(
            "key".to_string(),
            "secret".to_string(),
            "eu-west-1".to_string(),
        );
        let url = driver.build_url("anthropic.claude-3-sonnet");
        assert_eq!(
            url,
            "https://bedrock-runtime.eu-west-1.amazonaws.com/model/anthropic.claude-3-sonnet/converse"
        );
    }
2337
    // --- TokenUsage ---

    // Default is all-zero.
    #[test]
    fn token_usage_default() {
        let u = TokenUsage::default();
        assert_eq!(u.input_tokens, 0);
        assert_eq!(u.output_tokens, 0);
        assert_eq!(u.total(), 0);
    }

    // accumulate() adds both fields component-wise.
    #[test]
    fn token_usage_accumulate() {
        let mut u = TokenUsage { input_tokens: 10, output_tokens: 20 };
        let other = TokenUsage { input_tokens: 5, output_tokens: 15 };
        u.accumulate(&other);
        assert_eq!(u.input_tokens, 15);
        assert_eq!(u.output_tokens, 35);
        assert_eq!(u.total(), 50);
    }

    // total() is input + output.
    #[test]
    fn token_usage_total() {
        let u = TokenUsage { input_tokens: 100, output_tokens: 200 };
        assert_eq!(u.total(), 300);
    }
2365
    // --- StopReason serde (snake_case) ---

    #[test]
    fn stop_reason_serialization() {
        let json = serde_json::to_string(&StopReason::EndTurn).unwrap();
        assert_eq!(json, "\"end_turn\"");

        let json = serde_json::to_string(&StopReason::ToolUse).unwrap();
        assert_eq!(json, "\"tool_use\"");

        let json = serde_json::to_string(&StopReason::MaxTokens).unwrap();
        assert_eq!(json, "\"max_tokens\"");

        let json = serde_json::to_string(&StopReason::Error).unwrap();
        assert_eq!(json, "\"error\"");
    }

    #[test]
    fn stop_reason_deserialization() {
        let sr: StopReason = serde_json::from_str("\"end_turn\"").unwrap();
        assert_eq!(sr, StopReason::EndTurn);

        let sr: StopReason = serde_json::from_str("\"tool_use\"").unwrap();
        assert_eq!(sr, StopReason::ToolUse);
    }
2393
    // --- AnthropicDriver ---

    // model / max_tokens / system / temperature sit at the top level of the
    // Messages API body.
    #[test]
    fn anthropic_request_body_simple() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let body = driver.build_request_body(&simple_request());

        assert_eq!(body["model"], "test-model");
        assert_eq!(body["max_tokens"], 4096);
        assert_eq!(body["system"], "You are helpful.");
        assert!((body["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);

        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0]["role"], "user");
        assert_eq!(messages[0]["content"], "Hello");
    }

    // Tool definitions keep Anthropic's native `input_schema` shape.
    #[test]
    fn anthropic_request_body_with_tools() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let body = driver.build_request_body(&request_with_tools());

        let tools = body["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["name"], "get_weather");
        assert!(tools[0]["input_schema"]["properties"].is_object());
    }

    // Optional fields are omitted entirely rather than serialized as null.
    #[test]
    fn anthropic_request_body_no_system_prompt() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let req = CompletionRequest {
            model: "test".into(),
            messages: vec![Message::new(Role::User, "Hi")],
            tools: Vec::new(),
            max_tokens: 100,
            temperature: None,
            system_prompt: None,
        };
        let body = driver.build_request_body(&req);
        assert!(body.get("system").is_none());
        assert!(body.get("temperature").is_none());
    }

    // Text-only content block -> EndTurn, no tool calls.
    #[test]
    fn anthropic_parse_response_text() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "content": [
                {"type": "text", "text": "Hello!"}
            ],
            "stop_reason": "end_turn",
            "usage": {
                "input_tokens": 10,
                "output_tokens": 5
            }
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Hello!");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert_eq!(resp.usage.input_tokens, 10);
        assert_eq!(resp.usage.output_tokens, 5);
        assert!(resp.message.tool_calls.is_empty());
    }

    // `tool_use` content blocks become ToolCalls with id/name/input intact.
    #[test]
    fn anthropic_parse_response_tool_use() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "content": [
                {"type": "text", "text": "Let me check."},
                {
                    "type": "tool_use",
                    "id": "tool_abc",
                    "name": "get_weather",
                    "input": {"city": "NYC"}
                }
            ],
            "stop_reason": "tool_use",
            "usage": {"input_tokens": 20, "output_tokens": 15}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Let me check.");
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].id, "tool_abc");
        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
        assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
    }

    #[test]
    fn anthropic_parse_response_max_tokens() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "content": [{"type": "text", "text": "truncated"}],
            "stop_reason": "max_tokens",
            "usage": {"input_tokens": 5, "output_tokens": 100}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
    }

    // Unrecognized stop reasons map to StopReason::Error.
    #[test]
    fn anthropic_parse_response_unknown_stop_reason() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "content": [{"type": "text", "text": "err"}],
            "stop_reason": "something_unknown",
            "usage": {"input_tokens": 0, "output_tokens": 0}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.stop_reason, StopReason::Error);
    }

    // A Role::Tool message carrying tool results is re-encoded as a "user"
    // turn, matching the Messages API's tool-result convention.
    #[test]
    fn anthropic_request_body_with_assistant_and_tool_messages() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let req = CompletionRequest {
            model: "test".into(),
            messages: vec![
                Message::new(Role::User, "Hi"),
                Message {
                    role: Role::Assistant,
                    content: "I'll check".into(),
                    tool_calls: vec![ToolCall {
                        id: "call_1".into(),
                        name: "file_read".into(),
                        input: serde_json::json!({"path": "/tmp/test"}),
                    }],
                    tool_results: Vec::new(),
                    timestamp: chrono::Utc::now(),
                },
                Message {
                    role: Role::Tool,
                    content: String::new(),
                    tool_calls: Vec::new(),
                    tool_results: vec![punch_types::ToolCallResult {
                        id: "call_1".into(),
                        content: "file contents".into(),
                        is_error: false,
                    }],
                    timestamp: chrono::Utc::now(),
                },
            ],
            tools: Vec::new(),
            max_tokens: 100,
            temperature: None,
            system_prompt: None,
        };

        let body = driver.build_request_body(&req);
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0]["role"], "user");
        assert_eq!(messages[1]["role"], "assistant");
        assert_eq!(messages[2]["role"], "user");
    }

    // Role::System messages are not emitted into the `messages` array.
    #[test]
    fn anthropic_request_body_system_message_skipped() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let req = CompletionRequest {
            model: "test".into(),
            messages: vec![
                Message::new(Role::System, "System instruction"),
                Message::new(Role::User, "Hi"),
            ],
            tools: Vec::new(),
            max_tokens: 100,
            temperature: None,
            system_prompt: None,
        };

        let body = driver.build_request_body(&req);
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0]["role"], "user");
    }
2580
2581 #[test]
2586 fn openai_request_body_simple() {
2587 let driver = OpenAiCompatibleDriver::new(
2588 "key".into(),
2589 "https://api.openai.com".into(),
2590 "openai".into(),
2591 );
2592 let body = driver.build_request_body(&simple_request());
2593
2594 assert_eq!(body["model"], "test-model");
2595 let messages = body["messages"].as_array().unwrap();
2596 assert_eq!(messages.len(), 2);
2597 assert_eq!(messages[0]["role"], "system");
2598 assert_eq!(messages[0]["content"], "You are helpful.");
2599 assert_eq!(messages[1]["role"], "user");
2600 }
2601
2602 #[test]
2603 fn openai_request_body_with_tools() {
2604 let driver = OpenAiCompatibleDriver::new(
2605 "key".into(),
2606 "https://api.openai.com".into(),
2607 "openai".into(),
2608 );
2609 let body = driver.build_request_body(&request_with_tools());
2610
2611 let tools = body["tools"].as_array().unwrap();
2612 assert_eq!(tools.len(), 1);
2613 assert_eq!(tools[0]["type"], "function");
2614 assert_eq!(tools[0]["function"]["name"], "get_weather");
2615 }
2616
2617 #[test]
2618 fn openai_parse_response_text() {
2619 let driver = OpenAiCompatibleDriver::new(
2620 "key".into(),
2621 "https://api.openai.com".into(),
2622 "openai".into(),
2623 );
2624 let response_body = serde_json::json!({
2625 "choices": [{
2626 "message": {
2627 "role": "assistant",
2628 "content": "Hello!"
2629 },
2630 "finish_reason": "stop"
2631 }],
2632 "usage": {
2633 "prompt_tokens": 10,
2634 "completion_tokens": 5
2635 }
2636 });
2637
2638 let resp = driver.parse_response(&response_body).unwrap();
2639 assert_eq!(resp.message.content, "Hello!");
2640 assert_eq!(resp.stop_reason, StopReason::EndTurn);
2641 assert_eq!(resp.usage.input_tokens, 10);
2642 assert_eq!(resp.usage.output_tokens, 5);
2643 }
2644
2645 #[test]
2646 fn openai_parse_response_tool_calls() {
2647 let driver = OpenAiCompatibleDriver::new(
2648 "key".into(),
2649 "https://api.openai.com".into(),
2650 "openai".into(),
2651 );
2652 let response_body = serde_json::json!({
2653 "choices": [{
2654 "message": {
2655 "role": "assistant",
2656 "content": null,
2657 "tool_calls": [{
2658 "id": "call_123",
2659 "type": "function",
2660 "function": {
2661 "name": "get_weather",
2662 "arguments": "{\"city\": \"NYC\"}"
2663 }
2664 }]
2665 },
2666 "finish_reason": "tool_calls"
2667 }],
2668 "usage": {"prompt_tokens": 10, "completion_tokens": 5}
2669 });
2670
2671 let resp = driver.parse_response(&response_body).unwrap();
2672 assert_eq!(resp.stop_reason, StopReason::ToolUse);
2673 assert_eq!(resp.message.tool_calls.len(), 1);
2674 assert_eq!(resp.message.tool_calls[0].name, "get_weather");
2675 assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
2676 }
2677
2678 #[test]
2679 fn openai_parse_response_tool_calls_fix_stop_reason() {
2680 let driver = OpenAiCompatibleDriver::new(
2681 "key".into(),
2682 "https://api.openai.com".into(),
2683 "openai".into(),
2684 );
2685 let response_body = serde_json::json!({
2687 "choices": [{
2688 "message": {
2689 "role": "assistant",
2690 "content": "Using tool",
2691 "tool_calls": [{
2692 "id": "call_1",
2693 "type": "function",
2694 "function": {
2695 "name": "test_tool",
2696 "arguments": "{}"
2697 }
2698 }]
2699 },
2700 "finish_reason": "stop"
2701 }],
2702 "usage": {"prompt_tokens": 0, "completion_tokens": 0}
2703 });
2704
2705 let resp = driver.parse_response(&response_body).unwrap();
2706 assert_eq!(resp.stop_reason, StopReason::ToolUse);
2707 }
2708
2709 #[test]
2710 fn openai_parse_response_length_stop_reason() {
2711 let driver = OpenAiCompatibleDriver::new(
2712 "key".into(),
2713 "https://api.openai.com".into(),
2714 "openai".into(),
2715 );
2716 let response_body = serde_json::json!({
2717 "choices": [{
2718 "message": {"role": "assistant", "content": "cut off"},
2719 "finish_reason": "length"
2720 }],
2721 "usage": {"prompt_tokens": 0, "completion_tokens": 0}
2722 });
2723
2724 let resp = driver.parse_response(&response_body).unwrap();
2725 assert_eq!(resp.stop_reason, StopReason::MaxTokens);
2726 }
2727
2728 #[test]
2729 fn openai_parse_response_no_choices_error() {
2730 let driver = OpenAiCompatibleDriver::new(
2731 "key".into(),
2732 "https://api.openai.com".into(),
2733 "openai".into(),
2734 );
2735 let response_body = serde_json::json!({"choices": []});
2736
2737 let result = driver.parse_response(&response_body);
2738 assert!(result.is_err());
2739 }
2740
2741 #[test]
2746 fn gemini_assistant_message_formatting() {
2747 let driver = GeminiDriver::new("key".to_string(), None);
2748 let req = CompletionRequest {
2749 model: "gemini-pro".into(),
2750 messages: vec![
2751 Message::new(Role::User, "Hi"),
2752 Message {
2753 role: Role::Assistant,
2754 content: "Let me help".into(),
2755 tool_calls: vec![ToolCall {
2756 id: "tc1".into(),
2757 name: "get_weather".into(),
2758 input: serde_json::json!({"city": "NYC"}),
2759 }],
2760 tool_results: Vec::new(),
2761 timestamp: chrono::Utc::now(),
2762 },
2763 ],
2764 tools: Vec::new(),
2765 max_tokens: 100,
2766 temperature: None,
2767 system_prompt: None,
2768 };
2769
2770 let body = driver.build_request_body(&req);
2771 let contents = body["contents"].as_array().unwrap();
2772 assert_eq!(contents[1]["role"], "model"); let parts = contents[1]["parts"].as_array().unwrap();
2774 assert!(parts.len() >= 2); }
2776
2777 #[test]
2778 fn gemini_max_tokens_stop_reason() {
2779 let driver = GeminiDriver::new("key".to_string(), None);
2780 let response_body = serde_json::json!({
2781 "candidates": [{
2782 "content": {
2783 "parts": [{"text": "truncated"}],
2784 "role": "model"
2785 },
2786 "finishReason": "MAX_TOKENS"
2787 }],
2788 "usageMetadata": {"promptTokenCount": 0, "candidatesTokenCount": 0}
2789 });
2790
2791 let resp = driver.parse_response(&response_body).unwrap();
2792 assert_eq!(resp.stop_reason, StopReason::MaxTokens);
2793 }
2794
2795 #[test]
2796 fn gemini_custom_base_url() {
2797 let driver = GeminiDriver::new("key".to_string(), Some("https://custom.example.com".into()));
2798 let url = driver.build_url("gemini-pro");
2799 assert!(url.starts_with("https://custom.example.com/"));
2800 }
2801
2802 #[test]
2807 fn ollama_response_with_tool_calls() {
2808 let driver = OllamaDriver::new(None);
2809 let response_body = serde_json::json!({
2810 "message": {
2811 "role": "assistant",
2812 "content": "",
2813 "tool_calls": [{
2814 "function": {
2815 "name": "get_weather",
2816 "arguments": {"city": "London"}
2817 }
2818 }]
2819 },
2820 "done": true,
2821 "prompt_eval_count": 10,
2822 "eval_count": 5
2823 });
2824
2825 let resp = driver.parse_response(&response_body).unwrap();
2826 assert_eq!(resp.stop_reason, StopReason::ToolUse);
2827 assert_eq!(resp.message.tool_calls.len(), 1);
2828 assert_eq!(resp.message.tool_calls[0].name, "get_weather");
2829 }
2830
2831 #[test]
2832 fn ollama_response_not_done() {
2833 let driver = OllamaDriver::new(None);
2834 let response_body = serde_json::json!({
2835 "message": {"role": "assistant", "content": "partial"},
2836 "done": false,
2837 "prompt_eval_count": 10,
2838 "eval_count": 5
2839 });
2840
2841 let resp = driver.parse_response(&response_body).unwrap();
2842 assert_eq!(resp.stop_reason, StopReason::MaxTokens);
2843 }
2844
2845 #[test]
2850 fn bedrock_request_with_tools() {
2851 let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
2852 let body = driver.build_request_body(&request_with_tools());
2853
2854 let tool_config = &body["toolConfig"]["tools"];
2855 assert!(tool_config.is_array());
2856 let tools = tool_config.as_array().unwrap();
2857 assert_eq!(tools.len(), 1);
2858 assert_eq!(tools[0]["toolSpec"]["name"], "get_weather");
2859 }
2860
2861 #[test]
2862 fn bedrock_response_with_tool_use() {
2863 let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
2864 let response_body = serde_json::json!({
2865 "output": {
2866 "message": {
2867 "role": "assistant",
2868 "content": [
2869 {"text": "Using tool"},
2870 {"toolUse": {
2871 "toolUseId": "tu_123",
2872 "name": "get_weather",
2873 "input": {"city": "NYC"}
2874 }}
2875 ]
2876 }
2877 },
2878 "stopReason": "tool_use",
2879 "usage": {"inputTokens": 10, "outputTokens": 20}
2880 });
2881
2882 let resp = driver.parse_response(&response_body).unwrap();
2883 assert_eq!(resp.stop_reason, StopReason::ToolUse);
2884 assert_eq!(resp.message.tool_calls.len(), 1);
2885 assert_eq!(resp.message.tool_calls[0].id, "tu_123");
2886 assert_eq!(resp.message.tool_calls[0].name, "get_weather");
2887 }
2888
2889 #[test]
2890 fn bedrock_request_with_tool_results() {
2891 let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
2892 let req = CompletionRequest {
2893 model: "test".into(),
2894 messages: vec![
2895 Message::new(Role::User, "Hi"),
2896 Message {
2897 role: Role::Tool,
2898 content: String::new(),
2899 tool_calls: Vec::new(),
2900 tool_results: vec![punch_types::ToolCallResult {
2901 id: "tu_1".into(),
2902 content: "result data".into(),
2903 is_error: false,
2904 }],
2905 timestamp: chrono::Utc::now(),
2906 },
2907 ],
2908 tools: Vec::new(),
2909 max_tokens: 100,
2910 temperature: None,
2911 system_prompt: None,
2912 };
2913
2914 let body = driver.build_request_body(&req);
2915 let messages = body["messages"].as_array().unwrap();
2916 assert_eq!(messages[1]["role"], "user"); let content = messages[1]["content"].as_array().unwrap();
2918 assert!(content[0]["toolResult"].is_object());
2919 assert_eq!(content[0]["toolResult"]["status"], "success");
2920 }
2921
2922 #[test]
2923 fn bedrock_url_different_regions() {
2924 let driver = BedrockDriver::new("k".into(), "s".into(), "ap-southeast-1".into());
2925 let url = driver.build_url("model-id");
2926 assert!(url.contains("ap-southeast-1"));
2927 }
2928
2929 #[test]
2934 fn azure_openai_delegates_parse_to_openai() {
2935 let driver = AzureOpenAiDriver::new(
2936 "key".into(), "res".into(), "dep".into(), None,
2937 );
2938 let response_body = serde_json::json!({
2939 "choices": [{
2940 "message": {"role": "assistant", "content": "Azure response"},
2941 "finish_reason": "stop"
2942 }],
2943 "usage": {"prompt_tokens": 5, "completion_tokens": 3}
2944 });
2945
2946 let resp = driver.parse_response(&response_body).unwrap();
2947 assert_eq!(resp.message.content, "Azure response");
2948 }
2949
2950 #[test]
2955 fn default_base_url_anthropic() {
2956 assert_eq!(default_base_url(&Provider::Anthropic), "https://api.anthropic.com");
2957 }
2958
2959 #[test]
2960 fn default_base_url_openai() {
2961 assert_eq!(default_base_url(&Provider::OpenAI), "https://api.openai.com");
2962 }
2963
2964 #[test]
2965 fn default_base_url_google() {
2966 assert_eq!(default_base_url(&Provider::Google), "https://generativelanguage.googleapis.com");
2967 }
2968
2969 #[test]
2970 fn default_base_url_ollama() {
2971 assert_eq!(default_base_url(&Provider::Ollama), "http://localhost:11434");
2972 }
2973
2974 #[test]
2975 fn default_base_url_groq() {
2976 assert_eq!(default_base_url(&Provider::Groq), "https://api.groq.com/openai");
2977 }
2978
2979 #[test]
2980 fn default_base_url_deepseek() {
2981 assert_eq!(default_base_url(&Provider::DeepSeek), "https://api.deepseek.com");
2982 }
2983
2984 #[test]
2989 fn test_hex_sha256() {
2990 let hash = hex_sha256(b"");
2991 assert_eq!(hash, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855");
2992 }
2993
2994 #[test]
2995 fn test_hex_encode() {
2996 assert_eq!(hex_encode(&[0x00, 0xff, 0x0a, 0xbc]), "00ff0abc");
2997 }
2998
2999 #[test]
3000 fn test_hmac_sha256_basic() {
3001 let result = hmac_sha256(b"key", b"data");
3002 assert!(!result.is_empty());
3003 assert_eq!(result.len(), 32); }
3005
3006 #[test]
3011 fn create_driver_missing_api_key_env() {
3012 let config = ModelConfig {
3013 provider: Provider::Anthropic,
3014 model: "claude-3".into(),
3015 api_key_env: Some("PUNCH_TEST_NONEXISTENT_KEY_XYZ".into()),
3016 base_url: None,
3017 max_tokens: None,
3018 temperature: None,
3019 };
3020 let result = create_driver(&config);
3021 assert!(result.is_err());
3022 }
3023
3024 #[test]
3025 fn create_driver_openai_compatible_fallback() {
3026 unsafe { std::env::set_var("TEST_CUSTOM_KEY_DRIVER", "fake-key") };
3028 let config = ModelConfig {
3029 provider: Provider::Custom("my-custom".into()),
3030 model: "custom-model".into(),
3031 api_key_env: Some("TEST_CUSTOM_KEY_DRIVER".into()),
3032 base_url: Some("https://custom.api.com".into()),
3033 max_tokens: None,
3034 temperature: None,
3035 };
3036 let result = create_driver(&config);
3037 assert!(result.is_ok());
3038 unsafe { std::env::remove_var("TEST_CUSTOM_KEY_DRIVER") };
3039 }
3040
3041 #[test]
3046 fn strip_thinking_tags_removes_think_block() {
3047 let input = "<think>internal reasoning here</think>The answer is 42.";
3048 assert_eq!(strip_thinking_tags(input), "The answer is 42.");
3049 }
3050
3051 #[test]
3052 fn strip_thinking_tags_removes_thinking_block() {
3053 let input = "<thinking>step by step reasoning</thinking>Hello world!";
3054 assert_eq!(strip_thinking_tags(input), "Hello world!");
3055 }
3056
3057 #[test]
3058 fn strip_thinking_tags_removes_reasoning_block() {
3059 let input = "<reasoning>let me figure this out</reasoning>The result is correct.";
3060 assert_eq!(strip_thinking_tags(input), "The result is correct.");
3061 }
3062
3063 #[test]
3064 fn strip_thinking_tags_removes_reflection_block() {
3065 let input = "<reflection>checking my work</reflection>Yes, that's right.";
3066 assert_eq!(strip_thinking_tags(input), "Yes, that's right.");
3067 }
3068
3069 #[test]
3070 fn strip_thinking_tags_removes_multiple_blocks() {
3071 let input = "<think>first thought</think>Hello <thinking>second thought</thinking>world!";
3072 assert_eq!(strip_thinking_tags(input), "Hello world!");
3073 }
3074
3075 #[test]
3076 fn strip_thinking_tags_preserves_content_without_tags() {
3077 let input = "Just a normal response with no thinking tags.";
3078 assert_eq!(strip_thinking_tags(input), input);
3079 }
3080
3081 #[test]
3082 fn strip_thinking_tags_handles_multiline_tags() {
3083 let input = "<think>\nLine 1\nLine 2\nLine 3\n</think>\nThe final answer.";
3084 assert_eq!(strip_thinking_tags(input), "The final answer.");
3085 }
3086
3087 #[test]
3088 fn strip_thinking_tags_returns_original_if_all_thinking() {
3089 let input = "<think>this is all thinking content and nothing else</think>";
3092 assert_eq!(strip_thinking_tags(input), input);
3093 }
3094
3095 #[test]
3096 fn strip_thinking_tags_handles_unclosed_tag() {
3097 let input = "Some text<think>unclosed thinking block";
3098 assert_eq!(strip_thinking_tags(input), "Some text");
3099 }
3100
3101 #[test]
3102 fn strip_thinking_tags_handles_empty_input() {
3103 assert_eq!(strip_thinking_tags(""), "");
3104 }
3105
3106 #[test]
3107 fn strip_thinking_tags_handles_empty_think_block() {
3108 let input = "<think></think>Visible content.";
3109 assert_eq!(strip_thinking_tags(input), "Visible content.");
3110 }
3111
3112 #[test]
3113 fn strip_thinking_tags_trims_whitespace() {
3114 let input = " <think>reasoning</think> Result ";
3115 assert_eq!(strip_thinking_tags(input), "Result");
3116 }
3117
3118 #[test]
3119 fn strip_thinking_tags_mixed_tag_types() {
3120 let input = "<think>t1</think>A<reasoning>r1</reasoning>B<reflection>f1</reflection>C";
3121 assert_eq!(strip_thinking_tags(input), "ABC");
3122 }
3123
3124 #[test]
3125 fn ollama_response_strips_thinking_tags() {
3126 let driver = OllamaDriver::new(None);
3127 let response_body = serde_json::json!({
3128 "message": {
3129 "role": "assistant",
3130 "content": "<think>\nLet me think about this...\nThe user wants hello world.\n</think>\nHello, world!"
3131 },
3132 "done": true,
3133 "prompt_eval_count": 20,
3134 "eval_count": 50
3135 });
3136
3137 let resp = driver.parse_response(&response_body).unwrap();
3138 assert_eq!(resp.message.content, "Hello, world!");
3139 assert!(!resp.message.content.contains("<think>"));
3140 }
3141
3142 #[test]
3143 fn gemini_response_strips_thinking_tags() {
3144 let driver = GeminiDriver::new("test-key".to_string(), None);
3145 let response_body = serde_json::json!({
3146 "candidates": [{
3147 "content": {
3148 "parts": [{"text": "<thinking>reasoning step</thinking>The answer is 7."}],
3149 "role": "model"
3150 },
3151 "finishReason": "STOP"
3152 }],
3153 "usageMetadata": {
3154 "promptTokenCount": 10,
3155 "candidatesTokenCount": 20
3156 }
3157 });
3158
3159 let resp = driver.parse_response(&response_body).unwrap();
3160 assert_eq!(resp.message.content, "The answer is 7.");
3161 assert!(!resp.message.content.contains("<thinking>"));
3162 }
3163
3164 #[test]
3165 fn anthropic_response_strips_thinking_tags() {
3166 let driver = AnthropicDriver::new("test-key".to_string(), None);
3167 let response_body = serde_json::json!({
3168 "content": [
3169 {"type": "text", "text": "<think>internal thought</think>Clean output."}
3170 ],
3171 "stop_reason": "end_turn",
3172 "usage": {"input_tokens": 10, "output_tokens": 5}
3173 });
3174
3175 let resp = driver.parse_response(&response_body).unwrap();
3176 assert_eq!(resp.message.content, "Clean output.");
3177 }
3178
    #[test]
    fn bedrock_response_strips_thinking_tags() {
        let driver = BedrockDriver::new(
            "key".to_string(),
            "secret".to_string(),
            "us-east-1".to_string(),
        );
        // Converse-style payload whose text part embeds a <reasoning> block.
        let response_body = serde_json::json!({
            "output": {
                "message": {
                    "role": "assistant",
                    "content": [{"text": "<reasoning>deep thought</reasoning>Result here."}]
                }
            },
            "stopReason": "end_turn",
            "usage": {"inputTokens": 50, "outputTokens": 25}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        // The reasoning block must be stripped, leaving only user-visible text.
        assert_eq!(resp.message.content, "Result here.");
3199 }
3200}