1use std::sync::Arc;
8
9use async_trait::async_trait;
10use futures::StreamExt;
11use hmac::{Hmac, Mac};
12use reqwest::Client;
13use serde::{Deserialize, Serialize};
14use sha2::{Digest, Sha256};
15
16use punch_types::{
17 Message, ModelConfig, Provider, PunchError, PunchResult, Role, ToolCall, ToolDefinition,
18};
19
/// Why the model stopped generating, normalized across all providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn naturally.
    EndTurn,
    /// The model stopped to request one or more tool invocations.
    ToolUse,
    /// Generation was cut off by the `max_tokens` limit.
    MaxTokens,
    /// The provider reported an unrecognized or error stop state.
    Error,
}
37
/// Token counts reported by a provider for a single request.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt (input side).
    pub input_tokens: u64,
    /// Tokens produced by the model (output side).
    pub output_tokens: u64,
}
44
45impl TokenUsage {
46 pub fn accumulate(&mut self, other: &TokenUsage) {
48 self.input_tokens += other.input_tokens;
49 self.output_tokens += other.output_tokens;
50 }
51
52 pub fn total(&self) -> u64 {
54 self.input_tokens + self.output_tokens
55 }
56}
57
/// A provider-agnostic chat completion request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Provider-specific model identifier.
    pub model: String,
    /// Conversation history, oldest message first.
    pub messages: Vec<Message>,
    /// Tools the model may call; empty when tool use is disabled.
    #[serde(default)]
    pub tools: Vec<ToolDefinition>,
    /// Hard cap on the number of generated tokens.
    pub max_tokens: u32,
    /// Sampling temperature; `None` uses the provider's default.
    pub temperature: Option<f32>,
    /// Optional system prompt, sent out-of-band from `messages`.
    pub system_prompt: Option<String>,
}
75
/// The assembled result of a completion (streaming or not).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// The assistant message, including any tool calls the model requested.
    pub message: Message,
    /// Token accounting reported by the provider (zeroes when unavailable).
    pub usage: TokenUsage,
    /// Normalized reason the model stopped.
    pub stop_reason: StopReason,
}
86
/// One incremental unit delivered to a [`StreamCallback`] during streaming.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// New text produced since the previous chunk (may be empty for
    /// tool-call-only chunks).
    pub delta: String,
    /// `true` on the terminal chunk of the stream.
    pub is_final: bool,
    /// Incremental tool-call information, when this chunk carries any.
    pub tool_call_delta: Option<ToolCallDelta>,
}
101
/// Incremental piece of a tool call observed while streaming.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
    /// Position of the tool call within the response's tool-call list.
    pub index: usize,
    /// Tool-call id, present only on the chunk that introduces the call.
    pub id: Option<String>,
    /// Tool name, present only on the chunk that introduces the call.
    pub name: Option<String>,
    /// Fragment of the JSON-encoded arguments; concatenate across chunks.
    pub arguments_delta: String,
}
110
/// Shared callback invoked once per [`StreamChunk`] during a streaming completion.
pub type StreamCallback = Arc<dyn Fn(StreamChunk) + Send + Sync>;
113
/// Parses a Server-Sent Events (SSE) payload into `(event_type, data)` pairs.
///
/// Events are separated by blank lines. An `event:` field sets the event type
/// (defaulting to `"message"` when absent), and multiple `data:` lines inside
/// one event are joined with `'\n'`. Per the SSE spec, at most one leading
/// space after the field colon is stripped; other fields (`id:`, `retry:`,
/// `:` comments) are ignored. A final event not terminated by a blank line is
/// still emitted.
fn parse_sse_events(raw: &str) -> Vec<(String, String)> {
    // Returns the value of `field` on this line, stripping exactly one
    // optional leading space as required by the SSE specification.
    fn field_value<'a>(line: &'a str, field: &str) -> Option<&'a str> {
        let rest = line.strip_prefix(field)?;
        Some(rest.strip_prefix(' ').unwrap_or(rest))
    }

    // Pushes the pending event (if any) and resets the accumulators.
    fn flush(event: &mut String, data: &mut String, out: &mut Vec<(String, String)>) {
        if !data.is_empty() || !event.is_empty() {
            let name = if event.is_empty() {
                "message".to_string()
            } else {
                event.clone()
            };
            out.push((name, data.clone()));
            event.clear();
            data.clear();
        }
    }

    let mut events = Vec::new();
    let mut current_event = String::new();
    let mut current_data = String::new();

    for line in raw.lines() {
        if line.is_empty() {
            // Blank line terminates the current event.
            flush(&mut current_event, &mut current_data, &mut events);
        } else if let Some(val) = field_value(line, "event:") {
            current_event = val.trim().to_string();
        } else if let Some(val) = field_value(line, "data:") {
            if !current_data.is_empty() {
                current_data.push('\n');
            }
            current_data.push_str(val);
        }
    }

    // Emit a trailing event that was not followed by a blank line.
    flush(&mut current_event, &mut current_data, &mut events);

    events
}
172
173async fn read_stream_body(response: reqwest::Response) -> PunchResult<String> {
175 let mut stream = response.bytes_stream();
176 let mut body = Vec::new();
177 while let Some(chunk) = stream.next().await {
178 let chunk = chunk.map_err(|e| PunchError::Provider {
179 provider: "stream".to_string(),
180 message: format!("stream read error: {e}"),
181 })?;
182 body.extend_from_slice(&chunk);
183 }
184 String::from_utf8(body).map_err(|e| PunchError::Provider {
185 provider: "stream".to_string(),
186 message: format!("invalid UTF-8 in stream: {e}"),
187 })
188}
189
/// Removes `<think>`-style reasoning blocks from model output.
///
/// Handles the tags `think`, `thinking`, `reasoning`, and `reflection`. A
/// closed block is removed together with its tags; an opening tag with no
/// matching close removes everything from the tag to the end of the string.
/// If stripping leaves nothing but whitespace, the original input is returned
/// unchanged so callers never end up with an empty message.
pub fn strip_thinking_tags(content: &str) -> String {
    let mut cleaned = content.to_string();

    for tag in ["think", "thinking", "reasoning", "reflection"] {
        let open = format!("<{}>", tag);
        let close = format!("</{}>", tag);

        loop {
            let Some(open_at) = cleaned.find(&open) else {
                break;
            };
            match cleaned[open_at..].find(&close) {
                Some(rel) => {
                    // Splice out the whole block, closing tag included.
                    let resume = open_at + rel + close.len();
                    let mut spliced = String::with_capacity(cleaned.len());
                    spliced.push_str(&cleaned[..open_at]);
                    spliced.push_str(&cleaned[resume..]);
                    cleaned = spliced;
                }
                None => {
                    // Unterminated block: drop it through the end of the text.
                    cleaned.truncate(open_at);
                    break;
                }
            }
        }
    }

    let trimmed = cleaned.trim();
    if trimmed.is_empty() {
        // Everything was "thinking" — fall back to the raw content.
        content.to_string()
    } else {
        trimmed.to_string()
    }
}
233
/// A chat-completion backend (Anthropic, OpenAI-compatible, etc.).
///
/// Implementors must provide [`complete`](LlmDriver::complete); the streaming
/// entry points have default implementations that degrade gracefully to a
/// single-shot completion.
#[async_trait]
pub trait LlmDriver: Send + Sync + 'static {
    /// Performs a full (non-streaming) completion.
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse>;

    /// Streaming completion without a consumer callback; chunks are discarded.
    async fn stream_complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        // No-op callback: the caller only wants the assembled response.
        let noop: StreamCallback = Arc::new(|_| {});
        self.stream_complete_with_callback(request, noop).await
    }

    /// Streaming completion that reports progress through `callback`.
    ///
    /// Default implementation falls back to `complete` and emits the whole
    /// response as a single final chunk, so non-streaming drivers still
    /// satisfy the streaming API.
    async fn stream_complete_with_callback(
        &self,
        request: CompletionRequest,
        callback: StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        let response = self.complete(request).await?;
        callback(StreamChunk {
            delta: response.message.content.clone(),
            is_final: true,
            tool_call_delta: None,
        });
        Ok(response)
    }
}
267
/// Driver for the Anthropic Messages API.
pub struct AnthropicDriver {
    /// HTTP client (connection pooling is per-client in reqwest).
    client: Client,
    /// Value sent in the `x-api-key` header.
    api_key: String,
    /// API origin, e.g. `https://api.anthropic.com`; overridable for proxies/tests.
    base_url: String,
}
278
impl AnthropicDriver {
    /// Creates a driver with a fresh HTTP client. `base_url` overrides the
    /// default `https://api.anthropic.com` (useful for proxies and tests).
    pub fn new(api_key: String, base_url: Option<String>) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
        }
    }

    /// Creates a driver reusing an existing `reqwest::Client` (shared
    /// connection pool, custom TLS/proxy configuration, etc.).
    pub fn with_client(client: Client, api_key: String, base_url: Option<String>) -> Self {
        Self {
            client,
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
        }
    }

    /// Translates a provider-agnostic request into the Anthropic Messages API
    /// JSON body.
    ///
    /// Notable mappings: tool results are sent as `tool_result` blocks inside
    /// a `user`-role message (the Anthropic convention), and the system prompt
    /// plus the last tool get `cache_control: ephemeral` markers to enable
    /// prompt caching.
    ///
    /// NOTE(review): `Role::System` messages inside `request.messages` are
    /// silently dropped here — only `request.system_prompt` reaches the API.
    /// Confirm callers never place system content in the message list.
    fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        for msg in &request.messages {
            match msg.role {
                Role::User => {
                    if msg.content_parts.is_empty() {
                        // Plain-text user message: send content as a bare string.
                        messages.push(serde_json::json!({
                            "role": "user",
                            "content": msg.content,
                        }));
                    } else {
                        // Multimodal message: build an array of content blocks.
                        let mut content_blocks: Vec<serde_json::Value> = Vec::new();
                        if !msg.content.is_empty() {
                            content_blocks.push(serde_json::json!({
                                "type": "text",
                                "text": msg.content,
                            }));
                        }
                        for part in &msg.content_parts {
                            match part {
                                punch_types::ContentPart::Text { text } => {
                                    content_blocks.push(serde_json::json!({
                                        "type": "text",
                                        "text": text,
                                    }));
                                }
                                punch_types::ContentPart::Image { media_type, data } => {
                                    // Images are sent inline as base64 source blocks.
                                    content_blocks.push(serde_json::json!({
                                        "type": "image",
                                        "source": {
                                            "type": "base64",
                                            "media_type": media_type,
                                            "data": data,
                                        },
                                    }));
                                }
                            }
                        }
                        messages.push(serde_json::json!({
                            "role": "user",
                            "content": content_blocks,
                        }));
                    }
                }
                Role::Assistant => {
                    let mut content_blocks: Vec<serde_json::Value> = Vec::new();

                    if !msg.content.is_empty() {
                        content_blocks.push(serde_json::json!({
                            "type": "text",
                            "text": msg.content,
                        }));
                    }

                    // Prior tool calls are replayed as tool_use blocks so the
                    // model sees its own earlier requests.
                    for tc in &msg.tool_calls {
                        content_blocks.push(serde_json::json!({
                            "type": "tool_use",
                            "id": tc.id,
                            "name": tc.name,
                            "input": tc.input,
                        }));
                    }

                    // The API rejects empty content arrays; send one empty
                    // text block instead.
                    if content_blocks.is_empty() {
                        content_blocks.push(serde_json::json!({
                            "type": "text",
                            "text": "",
                        }));
                    }

                    messages.push(serde_json::json!({
                        "role": "assistant",
                        "content": content_blocks,
                    }));
                }
                Role::Tool => {
                    // Anthropic expects tool results as tool_result blocks in
                    // a user-role message.
                    let mut result_blocks: Vec<serde_json::Value> = Vec::new();
                    for tr in &msg.tool_results {
                        if let Some(ref image) = tr.image {
                            // Result with an image attachment: content becomes
                            // a block array (text first, then the image).
                            let mut content: Vec<serde_json::Value> = vec![serde_json::json!({
                                "type": "text",
                                "text": tr.content,
                            })];
                            if let punch_types::ContentPart::Image { media_type, data } = image {
                                content.push(serde_json::json!({
                                    "type": "image",
                                    "source": {
                                        "type": "base64",
                                        "media_type": media_type,
                                        "data": data,
                                    },
                                }));
                            }
                            result_blocks.push(serde_json::json!({
                                "type": "tool_result",
                                "tool_use_id": tr.id,
                                "content": content,
                                "is_error": tr.is_error,
                            }));
                        } else {
                            // Text-only result: content stays a plain string.
                            result_blocks.push(serde_json::json!({
                                "type": "tool_result",
                                "tool_use_id": tr.id,
                                "content": tr.content,
                                "is_error": tr.is_error,
                            }));
                        }
                    }
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": result_blocks,
                    }));
                }
                Role::System => {
                    // Intentionally not forwarded; see NOTE(review) above.
                }
            }
        }

        let tools: Vec<serde_json::Value> = request
            .tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "name": t.name,
                    "description": t.description,
                    "input_schema": t.input_schema,
                })
            })
            .collect();

        let mut body = serde_json::json!({
            "model": request.model,
            "messages": messages,
            "max_tokens": request.max_tokens,
        });

        if let Some(temp) = request.temperature {
            body["temperature"] = serde_json::json!(temp);
        }

        if let Some(ref system) = request.system_prompt {
            // Mark the system prompt as cacheable for prompt caching.
            body["system"] = serde_json::json!([
                {
                    "type": "text",
                    "text": system,
                    "cache_control": {"type": "ephemeral"},
                }
            ]);
        }

        if !tools.is_empty() {
            let mut tools_json = serde_json::json!(tools);
            // Cache-control on the final tool marks the whole tool list as a
            // cacheable prefix.
            if let Some(arr) = tools_json.as_array_mut()
                && let Some(last) = arr.last_mut()
            {
                last["cache_control"] = serde_json::json!({"type": "ephemeral"});
            }
            body["tools"] = tools_json;
        }

        body
    }

    /// Converts a non-streaming Messages API response into a
    /// [`CompletionResponse`]. Unknown stop reasons map to
    /// `StopReason::Error`; missing usage fields default to 0.
    fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let stop_reason = match body["stop_reason"].as_str() {
            Some("end_turn") => StopReason::EndTurn,
            Some("tool_use") => StopReason::ToolUse,
            Some("max_tokens") => StopReason::MaxTokens,
            _ => StopReason::Error,
        };

        let usage = TokenUsage {
            input_tokens: body["usage"]["input_tokens"].as_u64().unwrap_or(0),
            output_tokens: body["usage"]["output_tokens"].as_u64().unwrap_or(0),
        };

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        // The response content is an array of text / tool_use blocks; text
        // blocks are joined with newlines, tool_use blocks become ToolCalls.
        if let Some(content_array) = body["content"].as_array() {
            for block in content_array {
                match block["type"].as_str() {
                    Some("text") => {
                        if let Some(text) = block["text"].as_str() {
                            if !text_content.is_empty() {
                                text_content.push('\n');
                            }
                            text_content.push_str(text);
                        }
                    }
                    Some("tool_use") => {
                        tool_calls.push(ToolCall {
                            id: block["id"].as_str().unwrap_or_default().to_string(),
                            name: block["name"].as_str().unwrap_or_default().to_string(),
                            input: block["input"].clone(),
                        });
                    }
                    _ => {}
                }
            }
        }

        // Drop any <think>-style reasoning blocks from the visible text.
        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            content_parts: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}
535
536#[async_trait]
537impl LlmDriver for AnthropicDriver {
538 async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
539 let url = format!("{}/v1/messages", self.base_url);
540 let body = self.build_request_body(&request);
541
542 let response = self
543 .client
544 .post(&url)
545 .header("x-api-key", &self.api_key)
546 .header("anthropic-version", "2023-06-01")
547 .header("content-type", "application/json")
548 .json(&body)
549 .send()
550 .await
551 .map_err(|e| PunchError::Provider {
552 provider: "anthropic".to_string(),
553 message: format!("request failed: {e}"),
554 })?;
555
556 let status = response.status();
557
558 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
559 let retry_after = response
560 .headers()
561 .get("retry-after")
562 .and_then(|v| v.to_str().ok())
563 .and_then(|s| s.parse::<u64>().ok())
564 .unwrap_or(60)
565 * 1000;
566
567 return Err(PunchError::RateLimited {
568 provider: "anthropic".to_string(),
569 retry_after_ms: retry_after,
570 });
571 }
572
573 if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
574 return Err(PunchError::Auth(
575 "anthropic API key is invalid or lacks permissions".to_string(),
576 ));
577 }
578
579 let response_body: serde_json::Value =
580 response.json().await.map_err(|e| PunchError::Provider {
581 provider: "anthropic".to_string(),
582 message: format!("failed to parse response: {e}"),
583 })?;
584
585 if !status.is_success() {
586 let error_msg = response_body["error"]["message"]
587 .as_str()
588 .unwrap_or("unknown error");
589 return Err(PunchError::Provider {
590 provider: "anthropic".to_string(),
591 message: format!("API error ({}): {}", status, error_msg),
592 });
593 }
594
595 self.parse_response(&response_body)
596 }
597
598 async fn stream_complete_with_callback(
599 &self,
600 request: CompletionRequest,
601 callback: StreamCallback,
602 ) -> PunchResult<CompletionResponse> {
603 let url = format!("{}/v1/messages", self.base_url);
604 let mut body = self.build_request_body(&request);
605 body["stream"] = serde_json::json!(true);
606
607 let response = self
608 .client
609 .post(&url)
610 .header("x-api-key", &self.api_key)
611 .header("anthropic-version", "2023-06-01")
612 .header("content-type", "application/json")
613 .json(&body)
614 .send()
615 .await
616 .map_err(|e| PunchError::Provider {
617 provider: "anthropic".to_string(),
618 message: format!("stream request failed: {e}"),
619 })?;
620
621 let status = response.status();
622 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
623 return Err(PunchError::RateLimited {
624 provider: "anthropic".to_string(),
625 retry_after_ms: 60_000,
626 });
627 }
628 if !status.is_success() {
629 let err_body: serde_json::Value =
630 response.json().await.unwrap_or(serde_json::json!({}));
631 let msg = err_body["error"]["message"]
632 .as_str()
633 .unwrap_or("unknown error");
634 return Err(PunchError::Provider {
635 provider: "anthropic".to_string(),
636 message: format!("API error ({}): {}", status, msg),
637 });
638 }
639
640 let raw = read_stream_body(response).await?;
641 let events = parse_sse_events(&raw);
642
643 let mut text_content = String::new();
644 let mut tool_calls: Vec<ToolCall> = Vec::new();
645 let mut usage = TokenUsage::default();
646 let mut stop_reason = StopReason::EndTurn;
647 let mut current_tool_index: Option<usize> = None;
649
650 for (event_type, data) in &events {
651 let parsed: serde_json::Value = match serde_json::from_str(data) {
652 Ok(v) => v,
653 Err(_) => continue,
654 };
655
656 match event_type.as_str() {
657 "message_start" => {
658 if let Some(inp) = parsed["message"]["usage"]["input_tokens"].as_u64() {
659 usage.input_tokens = inp;
660 }
661 }
662 "content_block_start" => {
663 let block = &parsed["content_block"];
664 match block["type"].as_str() {
665 Some("tool_use") => {
666 let id = block["id"].as_str().unwrap_or_default().to_string();
667 let name = block["name"].as_str().unwrap_or_default().to_string();
668 tool_calls.push(ToolCall {
669 id: id.clone(),
670 name: name.clone(),
671 input: serde_json::json!({}),
672 });
673 current_tool_index = Some(tool_calls.len() - 1);
674 callback(StreamChunk {
675 delta: String::new(),
676 is_final: false,
677 tool_call_delta: Some(ToolCallDelta {
678 index: tool_calls.len() - 1,
679 id: Some(id),
680 name: Some(name),
681 arguments_delta: String::new(),
682 }),
683 });
684 }
685 Some("text") => {
686 current_tool_index = None;
687 }
688 _ => {}
689 }
690 }
691 "content_block_delta" => {
692 let delta = &parsed["delta"];
693 match delta["type"].as_str() {
694 Some("text_delta") => {
695 let text = delta["text"].as_str().unwrap_or("");
696 text_content.push_str(text);
697 callback(StreamChunk {
698 delta: text.to_string(),
699 is_final: false,
700 tool_call_delta: None,
701 });
702 }
703 Some("input_json_delta") => {
704 let partial = delta["partial_json"].as_str().unwrap_or("");
705 if let Some(idx) = current_tool_index {
706 callback(StreamChunk {
707 delta: String::new(),
708 is_final: false,
709 tool_call_delta: Some(ToolCallDelta {
710 index: idx,
711 id: None,
712 name: None,
713 arguments_delta: partial.to_string(),
714 }),
715 });
716 }
717 }
718 _ => {}
719 }
720 }
721 "message_delta" => {
722 if let Some(sr) = parsed["delta"]["stop_reason"].as_str() {
723 stop_reason = match sr {
724 "end_turn" => StopReason::EndTurn,
725 "tool_use" => StopReason::ToolUse,
726 "max_tokens" => StopReason::MaxTokens,
727 _ => StopReason::Error,
728 };
729 }
730 if let Some(out) = parsed["usage"]["output_tokens"].as_u64() {
731 usage.output_tokens = out;
732 }
733 }
734 "message_stop" => {
735 callback(StreamChunk {
736 delta: String::new(),
737 is_final: true,
738 tool_call_delta: None,
739 });
740 }
741 _ => {}
742 }
743 }
744
745 let mut tool_json_bufs: Vec<String> = vec![String::new(); tool_calls.len()];
750 let mut tc_idx: Option<usize> = None;
751 for (event_type, data) in &events {
752 let parsed: serde_json::Value = match serde_json::from_str(data) {
753 Ok(v) => v,
754 Err(_) => continue,
755 };
756 match event_type.as_str() {
757 "content_block_start" => {
758 if parsed["content_block"]["type"].as_str() == Some("tool_use") {
759 tc_idx = Some(tc_idx.map_or(0, |i| i + 1));
760 } else {
761 tc_idx = None;
762 }
763 }
764 "content_block_delta" => {
765 if parsed["delta"]["type"].as_str() == Some("input_json_delta")
766 && let Some(idx) = tc_idx
767 && let Some(buf) = tool_json_bufs.get_mut(idx)
768 {
769 buf.push_str(parsed["delta"]["partial_json"].as_str().unwrap_or(""));
770 }
771 }
772 _ => {}
773 }
774 }
775 for (i, buf) in tool_json_bufs.into_iter().enumerate() {
776 if !buf.is_empty()
777 && let Some(tc) = tool_calls.get_mut(i)
778 {
779 tc.input = serde_json::from_str(&buf).unwrap_or(serde_json::json!({}));
780 }
781 }
782
783 let text_content = strip_thinking_tags(&text_content);
784
785 if !tool_calls.is_empty() && stop_reason != StopReason::ToolUse {
786 stop_reason = StopReason::ToolUse;
787 }
788
789 let message = Message {
790 role: Role::Assistant,
791 content: text_content,
792 tool_calls,
793 tool_results: Vec::new(),
794 content_parts: Vec::new(),
795 timestamp: chrono::Utc::now(),
796 };
797
798 Ok(CompletionResponse {
799 message,
800 usage,
801 stop_reason,
802 })
803 }
804}
805
/// Driver for any API speaking the OpenAI chat-completions protocol
/// (OpenAI itself, plus compatible gateways and local servers).
pub struct OpenAiCompatibleDriver {
    /// HTTP client (connection pooling is per-client in reqwest).
    client: Client,
    /// Bearer token sent in the `authorization` header.
    api_key: String,
    /// API origin; `/v1/chat/completions` is appended per request.
    base_url: String,
    /// Human-readable provider name used in error messages.
    provider_name: String,
}
821
impl OpenAiCompatibleDriver {
    /// Creates a driver with a fresh HTTP client. `provider_name` is only
    /// used to label errors.
    pub fn new(api_key: String, base_url: String, provider_name: String) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url,
            provider_name,
        }
    }

    /// Creates a driver reusing an existing `reqwest::Client` (shared
    /// connection pool, custom TLS/proxy configuration, etc.).
    pub fn with_client(
        client: Client,
        api_key: String,
        base_url: String,
        provider_name: String,
    ) -> Self {
        Self {
            client,
            api_key,
            base_url,
            provider_name,
        }
    }

    /// Translates a provider-agnostic request into an OpenAI
    /// chat-completions JSON body.
    ///
    /// The system prompt becomes a leading `system` message; tool results are
    /// flattened into individual `tool`-role messages; images are embedded as
    /// base64 `data:` URLs.
    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        // System prompt goes first so it precedes the whole conversation.
        if let Some(ref system) = request.system_prompt {
            messages.push(serde_json::json!({
                "role": "system",
                "content": system,
            }));
        }

        for msg in &request.messages {
            match msg.role {
                Role::System => {
                    messages.push(serde_json::json!({
                        "role": "system",
                        "content": msg.content,
                    }));
                }
                Role::User => {
                    if msg.content_parts.is_empty() {
                        // Plain-text user message: content as a bare string.
                        messages.push(serde_json::json!({
                            "role": "user",
                            "content": msg.content,
                        }));
                    } else {
                        // Multimodal message: content becomes a block array.
                        let mut content_blocks: Vec<serde_json::Value> = Vec::new();
                        if !msg.content.is_empty() {
                            content_blocks.push(serde_json::json!({
                                "type": "text",
                                "text": msg.content,
                            }));
                        }
                        for part in &msg.content_parts {
                            match part {
                                punch_types::ContentPart::Text { text } => {
                                    content_blocks.push(serde_json::json!({
                                        "type": "text",
                                        "text": text,
                                    }));
                                }
                                punch_types::ContentPart::Image { media_type, data } => {
                                    // OpenAI-style inline image: base64 data URL.
                                    content_blocks.push(serde_json::json!({
                                        "type": "image_url",
                                        "image_url": {
                                            "url": format!("data:{media_type};base64,{data}"),
                                        },
                                    }));
                                }
                            }
                        }
                        messages.push(serde_json::json!({
                            "role": "user",
                            "content": content_blocks,
                        }));
                    }
                }
                Role::Assistant => {
                    let mut m = serde_json::json!({
                        "role": "assistant",
                    });

                    // Omit `content` entirely for tool-call-only messages.
                    if !msg.content.is_empty() {
                        m["content"] = serde_json::json!(msg.content);
                    }

                    if !msg.tool_calls.is_empty() {
                        // Arguments must be re-serialized to a JSON string,
                        // per the chat-completions schema.
                        let tc: Vec<serde_json::Value> = msg
                            .tool_calls
                            .iter()
                            .map(|tc| {
                                serde_json::json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.name,
                                        "arguments": tc.input.to_string(),
                                    },
                                })
                            })
                            .collect();
                        m["tool_calls"] = serde_json::json!(tc);
                    }

                    messages.push(m);
                }
                Role::Tool => {
                    // Each tool result becomes its own tool-role message,
                    // linked back by tool_call_id.
                    for tr in &msg.tool_results {
                        messages.push(serde_json::json!({
                            "role": "tool",
                            "tool_call_id": tr.id,
                            "content": tr.content,
                        }));
                    }
                }
            }
        }

        let tools: Vec<serde_json::Value> = request
            .tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.input_schema,
                    },
                })
            })
            .collect();

        let mut body = serde_json::json!({
            "model": request.model,
            "messages": messages,
            "max_tokens": request.max_tokens,
        });

        if let Some(temp) = request.temperature {
            body["temperature"] = serde_json::json!(temp);
        }

        if !tools.is_empty() {
            body["tools"] = serde_json::json!(tools);
        }

        body
    }

    /// Converts a non-streaming chat-completions response into a
    /// [`CompletionResponse`]. Only the first choice is used; unknown finish
    /// reasons default to `EndTurn`; missing usage fields default to 0.
    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let choice = body["choices"].get(0).ok_or_else(|| PunchError::Provider {
            provider: self.provider_name.clone(),
            message: "no choices in response".to_string(),
        })?;

        let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
        let stop_reason = match finish_reason {
            "stop" => StopReason::EndTurn,
            "tool_calls" => StopReason::ToolUse,
            "length" => StopReason::MaxTokens,
            _ => StopReason::EndTurn,
        };

        let msg = &choice["message"];
        // Content may be null for tool-call-only responses; treat as empty.
        let raw_content = msg["content"].as_str().unwrap_or("");
        let content = strip_thinking_tags(raw_content);

        let mut tool_calls = Vec::new();
        if let Some(tc_array) = msg["tool_calls"].as_array() {
            for tc in tc_array {
                let id = tc["id"].as_str().unwrap_or_default().to_string();
                let name = tc["function"]["name"]
                    .as_str()
                    .unwrap_or_default()
                    .to_string();
                // Arguments arrive as a JSON-encoded string; malformed
                // arguments degrade to an empty object rather than failing.
                let args_str = tc["function"]["arguments"].as_str().unwrap_or("{}");
                let input: serde_json::Value =
                    serde_json::from_str(args_str).unwrap_or(serde_json::json!({}));

                tool_calls.push(ToolCall { id, name, input });
            }
        }

        let usage = TokenUsage {
            input_tokens: body["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
            output_tokens: body["usage"]["completion_tokens"].as_u64().unwrap_or(0),
        };

        // Some compatible servers return tool calls without the matching
        // finish_reason; normalize to ToolUse when calls are present.
        let stop_reason = if !tool_calls.is_empty() && stop_reason != StopReason::ToolUse {
            StopReason::ToolUse
        } else {
            stop_reason
        };

        let message = Message {
            role: Role::Assistant,
            content,
            tool_calls,
            tool_results: Vec::new(),
            content_parts: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}
1044
#[async_trait]
impl LlmDriver for OpenAiCompatibleDriver {
    /// Non-streaming completion via `POST {base_url}/v1/chat/completions`.
    ///
    /// Maps HTTP 429 to `PunchError::RateLimited` (honoring `retry-after`,
    /// converted from seconds to milliseconds), 401/403 to `PunchError::Auth`,
    /// and other non-2xx statuses to `PunchError::Provider`.
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = format!(
            "{}/v1/chat/completions",
            self.base_url.trim_end_matches('/')
        );
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: self.provider_name.clone(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            // retry-after is in seconds; the error type carries milliseconds.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap_or(60)
                * 1000;

            return Err(PunchError::RateLimited {
                provider: self.provider_name.clone(),
                retry_after_ms: retry_after,
            });
        }

        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(PunchError::Auth(format!(
                "{} API key is invalid or lacks permissions",
                self.provider_name
            )));
        }

        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: self.provider_name.clone(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: self.provider_name.clone(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }

    /// Streaming completion: requests SSE output and delegates assembly to
    /// [`parse_openai_stream`](OpenAiCompatibleDriver::parse_openai_stream).
    ///
    /// NOTE(review): unlike `complete`, a 429 here does not read the
    /// `retry-after` header and always reports a fixed 60 s backoff.
    async fn stream_complete_with_callback(
        &self,
        request: CompletionRequest,
        callback: StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        let url = format!(
            "{}/v1/chat/completions",
            self.base_url.trim_end_matches('/')
        );
        let mut body = self.build_request_body(&request);
        body["stream"] = serde_json::json!(true);

        let response = self
            .client
            .post(&url)
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: self.provider_name.clone(),
                message: format!("stream request failed: {e}"),
            })?;

        let status = response.status();
        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            return Err(PunchError::RateLimited {
                provider: self.provider_name.clone(),
                retry_after_ms: 60_000,
            });
        }
        if !status.is_success() {
            let err_body: serde_json::Value =
                response.json().await.unwrap_or(serde_json::json!({}));
            let msg = err_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: self.provider_name.clone(),
                message: format!("API error ({}): {}", status, msg),
            });
        }

        let raw = read_stream_body(response).await?;
        let assembled = self.parse_openai_stream(&raw, &callback)?;
        Ok(assembled)
    }
}
1159
impl OpenAiCompatibleDriver {
    /// Reassembles an OpenAI-style SSE stream into a [`CompletionResponse`],
    /// invoking `callback` for each text or tool-call delta and once with
    /// `is_final` when the `[DONE]` sentinel arrives.
    ///
    /// Tool-call fragments are keyed by the `index` field the protocol
    /// provides, so interleaved calls accumulate independently (BTreeMap
    /// keeps the final list in index order).
    ///
    /// NOTE(review): token usage is always returned as zeroes here — streamed
    /// usage (OpenAI's `stream_options.include_usage`) is not requested or
    /// parsed. Confirm whether callers rely on usage from streaming calls.
    pub fn parse_openai_stream(
        &self,
        raw: &str,
        callback: &StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        let events = parse_sse_events(raw);

        let mut text_content = String::new();
        // index -> (id, name, accumulated arguments JSON)
        let mut tool_map: std::collections::BTreeMap<usize, (String, String, String)> =
            std::collections::BTreeMap::new();
        let mut finish_reason = String::new();

        for (_event_type, data) in &events {
            // "[DONE]" terminates the stream; emit the final chunk and stop.
            if data.trim() == "[DONE]" {
                callback(StreamChunk {
                    delta: String::new(),
                    is_final: true,
                    tool_call_delta: None,
                });
                break;
            }

            let parsed: serde_json::Value = match serde_json::from_str(data) {
                Ok(v) => v,
                Err(_) => continue,
            };

            // Only the first choice is consumed.
            let choice = match parsed["choices"].get(0) {
                Some(c) => c,
                None => continue,
            };

            if let Some(fr) = choice["finish_reason"].as_str() {
                finish_reason = fr.to_string();
            }

            let delta = &choice["delta"];

            if let Some(content) = delta["content"].as_str() {
                text_content.push_str(content);
                callback(StreamChunk {
                    delta: content.to_string(),
                    is_final: false,
                    tool_call_delta: None,
                });
            }

            if let Some(tc_array) = delta["tool_calls"].as_array() {
                for tc in tc_array {
                    let idx = tc["index"].as_u64().unwrap_or(0) as usize;
                    let entry = tool_map
                        .entry(idx)
                        .or_insert_with(|| (String::new(), String::new(), String::new()));

                    // id/name usually arrive once; arguments arrive in
                    // fragments that must be concatenated.
                    let id_delta = tc["id"].as_str().unwrap_or("");
                    let name_delta = tc["function"]["name"].as_str().unwrap_or("");
                    let args_delta = tc["function"]["arguments"].as_str().unwrap_or("");

                    if !id_delta.is_empty() {
                        entry.0.push_str(id_delta);
                    }
                    if !name_delta.is_empty() {
                        entry.1.push_str(name_delta);
                    }
                    entry.2.push_str(args_delta);

                    callback(StreamChunk {
                        delta: String::new(),
                        is_final: false,
                        tool_call_delta: Some(ToolCallDelta {
                            index: idx,
                            id: if id_delta.is_empty() {
                                None
                            } else {
                                Some(id_delta.to_string())
                            },
                            name: if name_delta.is_empty() {
                                None
                            } else {
                                Some(name_delta.to_string())
                            },
                            arguments_delta: args_delta.to_string(),
                        }),
                    });
                }
            }
        }

        // Decode accumulated arguments; malformed JSON degrades to `{}`.
        let tool_calls: Vec<ToolCall> = tool_map
            .into_values()
            .map(|(id, name, args)| {
                let input: serde_json::Value =
                    serde_json::from_str(&args).unwrap_or(serde_json::json!({}));
                ToolCall { id, name, input }
            })
            .collect();

        // Presence of tool calls overrides whatever finish_reason said.
        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else {
            match finish_reason.as_str() {
                "stop" => StopReason::EndTurn,
                "tool_calls" => StopReason::ToolUse,
                "length" => StopReason::MaxTokens,
                _ => StopReason::EndTurn,
            }
        };

        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            content_parts: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage: TokenUsage::default(),
            stop_reason,
        })
    }
}
1293
/// Driver for the Google Gemini (Generative Language) API.
pub struct GeminiDriver {
    /// HTTP client (connection pooling is per-client in reqwest).
    client: Client,
    /// API key for the Generative Language API.
    api_key: String,
    /// API origin, defaulting to `https://generativelanguage.googleapis.com`.
    base_url: String,
}
1304
1305impl GeminiDriver {
1306 pub fn new(api_key: String, base_url: Option<String>) -> Self {
1308 Self {
1309 client: Client::new(),
1310 api_key,
1311 base_url: base_url
1312 .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string()),
1313 }
1314 }
1315
1316 pub fn with_client(client: Client, api_key: String, base_url: Option<String>) -> Self {
1318 Self {
1319 client,
1320 api_key,
1321 base_url: base_url
1322 .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string()),
1323 }
1324 }
1325
    /// Translate a provider-agnostic [`CompletionRequest`] into the Gemini
    /// `generateContent` JSON body.
    ///
    /// Gemini has no in-band system role: the request-level `system_prompt`
    /// and any `Role::System` history messages are merged (newline-joined)
    /// into the top-level `system_instruction` field. Tool results are sent
    /// back as `functionResponse` parts inside a `user` turn, and the
    /// assistant side is called `model` on the wire.
    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut contents = Vec::new();
        // Accumulates system text; seeded from the request-level prompt.
        let mut system_text: Option<String> = request.system_prompt.clone();

        for msg in &request.messages {
            match msg.role {
                Role::System => {
                    // Fold every system message into one newline-joined block.
                    let existing = system_text.take().unwrap_or_default();
                    let combined = if existing.is_empty() {
                        msg.content.clone()
                    } else {
                        format!("{}\n{}", existing, msg.content)
                    };
                    system_text = Some(combined);
                }
                Role::User => {
                    let mut parts: Vec<serde_json::Value> = Vec::new();
                    if !msg.content.is_empty() {
                        parts.push(serde_json::json!({"text": msg.content}));
                    }
                    // Multimodal attachments: text parts and inline images
                    // (data is passed through as-is; presumably base64 —
                    // confirm against ContentPart's producer).
                    for part in &msg.content_parts {
                        match part {
                            punch_types::ContentPart::Text { text: t } => {
                                parts.push(serde_json::json!({"text": t}));
                            }
                            punch_types::ContentPart::Image { media_type, data } => {
                                parts.push(serde_json::json!({
                                    "inline_data": {
                                        "mime_type": media_type,
                                        "data": data,
                                    }
                                }));
                            }
                        }
                    }
                    // Never send a turn with zero parts; pad with empty text.
                    if parts.is_empty() {
                        parts.push(serde_json::json!({"text": ""}));
                    }
                    contents.push(serde_json::json!({
                        "role": "user",
                        "parts": parts,
                    }));
                }
                Role::Assistant => {
                    let mut parts: Vec<serde_json::Value> = Vec::new();
                    if !msg.content.is_empty() {
                        parts.push(serde_json::json!({"text": msg.content}));
                    }
                    // Replay prior tool invocations as functionCall parts.
                    for tc in &msg.tool_calls {
                        parts.push(serde_json::json!({
                            "functionCall": {
                                "name": tc.name,
                                "args": tc.input,
                            }
                        }));
                    }
                    if parts.is_empty() {
                        parts.push(serde_json::json!({"text": ""}));
                    }
                    // Gemini names the assistant role "model".
                    contents.push(serde_json::json!({
                        "role": "model",
                        "parts": parts,
                    }));
                }
                Role::Tool => {
                    let mut parts: Vec<serde_json::Value> = Vec::new();
                    for tr in &msg.tool_results {
                        // NOTE(review): `tr.id` is sent as the functionResponse
                        // name, but Gemini expects the *function* name here —
                        // confirm the result id actually carries it.
                        parts.push(serde_json::json!({
                            "functionResponse": {
                                "name": tr.id.clone(),
                                "response": {"content": tr.content},
                            }
                        }));
                    }
                    // Tool results travel in a user-role turn.
                    contents.push(serde_json::json!({
                        "role": "user",
                        "parts": parts,
                    }));
                }
            }
        }

        let mut body = serde_json::json!({
            "contents": contents,
        });

        if let Some(sys) = system_text
            && !sys.is_empty()
        {
            body["system_instruction"] = serde_json::json!({
                "parts": [{"text": sys}],
            });
        }

        // Sampling / length limits live under generationConfig.
        let mut gen_config = serde_json::json!({
            "maxOutputTokens": request.max_tokens,
        });
        if let Some(temp) = request.temperature {
            gen_config["temperature"] = serde_json::json!(temp);
        }
        body["generationConfig"] = gen_config;

        if !request.tools.is_empty() {
            // All function declarations are nested under a single tools entry.
            let func_decls: Vec<serde_json::Value> = request
                .tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.input_schema,
                    })
                })
                .collect();
            body["tools"] = serde_json::json!([{"function_declarations": func_decls}]);
        }

        body
    }
1452
1453 pub fn build_url(&self, model: &str) -> String {
1455 format!(
1456 "{}/v1beta/models/{}:generateContent?key={}",
1457 self.base_url.trim_end_matches('/'),
1458 model,
1459 self.api_key,
1460 )
1461 }
1462
1463 pub fn build_stream_url(&self, model: &str) -> String {
1465 format!(
1466 "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
1467 self.base_url.trim_end_matches('/'),
1468 model,
1469 self.api_key,
1470 )
1471 }
1472
    /// Parse a non-streaming `generateContent` response body.
    ///
    /// Only the first candidate is considered. All of its text parts are
    /// concatenated (newline-joined), `functionCall` parts become
    /// [`ToolCall`]s with locally minted ids (Gemini supplies none), and
    /// `finishReason` is mapped onto [`StopReason`]. Token usage is read
    /// from `usageMetadata`.
    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        // Absence of candidates is a provider error.
        let candidate = body["candidates"]
            .get(0)
            .ok_or_else(|| PunchError::Provider {
                provider: "gemini".to_string(),
                message: "no candidates in response".to_string(),
            })?;

        let parts = candidate["content"]["parts"]
            .as_array()
            .cloned()
            .unwrap_or_default();

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        for part in &parts {
            if let Some(text) = part["text"].as_str() {
                // Join multiple text parts with a newline separator.
                if !text_content.is_empty() {
                    text_content.push('\n');
                }
                text_content.push_str(text);
            }
            if let Some(fc) = part.get("functionCall") {
                let name = fc["name"].as_str().unwrap_or_default().to_string();
                let args = fc["args"].clone();
                // Gemini does not supply call ids; mint one locally.
                tool_calls.push(ToolCall {
                    id: format!("gemini-{}", uuid::Uuid::new_v4()),
                    name,
                    input: args,
                });
            }
        }

        let finish_reason = candidate["finishReason"].as_str().unwrap_or("STOP");
        // Any function call wins over the textual finish reason.
        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else {
            match finish_reason {
                "STOP" => StopReason::EndTurn,
                "MAX_TOKENS" => StopReason::MaxTokens,
                // Unknown reasons degrade to EndTurn.
                _ => StopReason::EndTurn,
            }
        };

        let usage = TokenUsage {
            input_tokens: body["usageMetadata"]["promptTokenCount"]
                .as_u64()
                .unwrap_or(0),
            output_tokens: body["usageMetadata"]["candidatesTokenCount"]
                .as_u64()
                .unwrap_or(0),
        };

        // Strip reasoning/thinking markup (see strip_thinking_tags) before
        // surfacing the text to callers.
        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            content_parts: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
1546}
1547
1548#[async_trait]
1549impl LlmDriver for GeminiDriver {
    /// One-shot (non-streaming) `generateContent` call.
    ///
    /// HTTP 429 maps to [`PunchError::RateLimited`] with a fixed 60s hint,
    /// 401/403 to [`PunchError::Auth`], and any other non-2xx status to
    /// [`PunchError::Provider`] carrying the API's error message.
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = self.build_url(&request.model);
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            return Err(PunchError::RateLimited {
                provider: "gemini".to_string(),
                retry_after_ms: 60_000,
            });
        }

        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(PunchError::Auth(
                "Gemini API key is invalid or lacks permissions".to_string(),
            ));
        }

        // Error bodies are JSON too; parse before branching on status so the
        // provider's message can be surfaced.
        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }
1599
1600 async fn stream_complete_with_callback(
1601 &self,
1602 request: CompletionRequest,
1603 callback: StreamCallback,
1604 ) -> PunchResult<CompletionResponse> {
1605 let url = self.build_stream_url(&request.model);
1606 let body = self.build_request_body(&request);
1607
1608 let response = self
1609 .client
1610 .post(&url)
1611 .header("content-type", "application/json")
1612 .json(&body)
1613 .send()
1614 .await
1615 .map_err(|e| PunchError::Provider {
1616 provider: "gemini".to_string(),
1617 message: format!("stream request failed: {e}"),
1618 })?;
1619
1620 let status = response.status();
1621 if !status.is_success() {
1622 let err_body: serde_json::Value =
1623 response.json().await.unwrap_or(serde_json::json!({}));
1624 let msg = err_body["error"]["message"]
1625 .as_str()
1626 .unwrap_or("unknown error");
1627 return Err(PunchError::Provider {
1628 provider: "gemini".to_string(),
1629 message: format!("API error ({}): {}", status, msg),
1630 });
1631 }
1632
1633 let raw = read_stream_body(response).await?;
1634 let events = parse_sse_events(&raw);
1635
1636 let mut text_content = String::new();
1637 let mut tool_calls: Vec<ToolCall> = Vec::new();
1638 let mut usage = TokenUsage::default();
1639 let mut finish_reason = String::new();
1640
1641 for (_event_type, data) in &events {
1642 let parsed: serde_json::Value = match serde_json::from_str(data) {
1643 Ok(v) => v,
1644 Err(_) => continue,
1645 };
1646
1647 if let Some(parts) = parsed["candidates"][0]["content"]["parts"].as_array() {
1649 for part in parts {
1650 if let Some(text) = part["text"].as_str() {
1651 text_content.push_str(text);
1652 callback(StreamChunk {
1653 delta: text.to_string(),
1654 is_final: false,
1655 tool_call_delta: None,
1656 });
1657 }
1658 if let Some(fc) = part.get("functionCall") {
1659 let name = fc["name"].as_str().unwrap_or_default().to_string();
1660 let args = fc["args"].clone();
1661 let idx = tool_calls.len();
1662 tool_calls.push(ToolCall {
1663 id: format!("gemini-{}", uuid::Uuid::new_v4()),
1664 name: name.clone(),
1665 input: args,
1666 });
1667 callback(StreamChunk {
1668 delta: String::new(),
1669 is_final: false,
1670 tool_call_delta: Some(ToolCallDelta {
1671 index: idx,
1672 id: None,
1673 name: Some(name),
1674 arguments_delta: String::new(),
1675 }),
1676 });
1677 }
1678 }
1679 }
1680
1681 if let Some(fr) = parsed["candidates"][0]["finishReason"].as_str() {
1682 finish_reason = fr.to_string();
1683 }
1684
1685 if let Some(inp) = parsed["usageMetadata"]["promptTokenCount"].as_u64() {
1687 usage.input_tokens = inp;
1688 }
1689 if let Some(out) = parsed["usageMetadata"]["candidatesTokenCount"].as_u64() {
1690 usage.output_tokens = out;
1691 }
1692 }
1693
1694 callback(StreamChunk {
1695 delta: String::new(),
1696 is_final: true,
1697 tool_call_delta: None,
1698 });
1699
1700 let stop_reason = if !tool_calls.is_empty() {
1701 StopReason::ToolUse
1702 } else {
1703 match finish_reason.as_str() {
1704 "STOP" => StopReason::EndTurn,
1705 "MAX_TOKENS" => StopReason::MaxTokens,
1706 _ => StopReason::EndTurn,
1707 }
1708 };
1709
1710 let text_content = strip_thinking_tags(&text_content);
1711
1712 let message = Message {
1713 role: Role::Assistant,
1714 content: text_content,
1715 tool_calls,
1716 tool_results: Vec::new(),
1717 content_parts: Vec::new(),
1718 timestamp: chrono::Utc::now(),
1719 };
1720
1721 Ok(CompletionResponse {
1722 message,
1723 usage,
1724 stop_reason,
1725 })
1726 }
1727}
1728
/// Driver for a local or remote Ollama server (`/api/chat`).
///
/// No credentials: Ollama's API is unauthenticated.
pub struct OllamaDriver {
    client: Client,
    // Server root, defaults to http://localhost:11434.
    base_url: String,
}
1738
1739impl OllamaDriver {
1740 pub fn new(base_url: Option<String>) -> Self {
1742 Self {
1743 client: Client::new(),
1744 base_url: base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
1745 }
1746 }
1747
1748 pub fn with_client(client: Client, base_url: Option<String>) -> Self {
1750 Self {
1751 client,
1752 base_url: base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
1753 }
1754 }
1755
    /// The server root this driver talks to (as configured; not normalized).
    pub fn base_url(&self) -> &str {
        &self.base_url
    }
1760
    /// Translate a [`CompletionRequest`] into an Ollama `/api/chat` body
    /// (non-streaming; the stream path flips the `stream` flag afterwards).
    ///
    /// System prompts stay in-band as `system` role messages; user images
    /// travel in the Ollama-specific `images` array; assistant tool calls
    /// and tool results use the OpenAI-style `tool_calls` / `tool` shapes.
    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        // The request-level system prompt leads the conversation.
        if let Some(ref system) = request.system_prompt {
            messages.push(serde_json::json!({
                "role": "system",
                "content": system,
            }));
        }

        for msg in &request.messages {
            match msg.role {
                Role::System => {
                    messages.push(serde_json::json!({
                        "role": "system",
                        "content": msg.content,
                    }));
                }
                Role::User => {
                    // Collect image payloads from the multimodal parts.
                    let images: Vec<&str> = msg
                        .content_parts
                        .iter()
                        .filter_map(|p| match p {
                            punch_types::ContentPart::Image { data, .. } => Some(data.as_str()),
                            _ => None,
                        })
                        .collect();
                    let mut m = serde_json::json!({
                        "role": "user",
                        "content": msg.content,
                    });
                    if !images.is_empty() {
                        m["images"] = serde_json::json!(images);
                    }
                    messages.push(m);
                }
                Role::Assistant => {
                    let mut m = serde_json::json!({
                        "role": "assistant",
                        "content": msg.content,
                    });
                    if !msg.tool_calls.is_empty() {
                        // Replay prior tool invocations; call ids are not sent.
                        let tc: Vec<serde_json::Value> = msg
                            .tool_calls
                            .iter()
                            .map(|tc| {
                                serde_json::json!({
                                    "function": {
                                        "name": tc.name,
                                        "arguments": tc.input,
                                    }
                                })
                            })
                            .collect();
                        m["tool_calls"] = serde_json::json!(tc);
                    }
                    messages.push(m);
                }
                Role::Tool => {
                    // One `tool` message per result; only the content is sent
                    // (no name/id — NOTE(review): confirm Ollama matches
                    // results to calls positionally).
                    for tr in &msg.tool_results {
                        messages.push(serde_json::json!({
                            "role": "tool",
                            "content": tr.content,
                        }));
                    }
                }
            }
        }

        let mut body = serde_json::json!({
            "model": request.model,
            "messages": messages,
            "stream": false,
        });

        // Sampling / length limits go in the `options` object.
        let mut options = serde_json::json!({});
        if let Some(temp) = request.temperature {
            options["temperature"] = serde_json::json!(temp);
        }
        if request.max_tokens > 0 {
            options["num_predict"] = serde_json::json!(request.max_tokens);
        }
        body["options"] = options;

        // Disable "thinking" output for models that support it; thinking
        // tags are stripped defensively during parsing anyway.
        body["think"] = serde_json::json!(false);

        if !request.tools.is_empty() {
            let tools: Vec<serde_json::Value> = request
                .tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "type": "function",
                        "function": {
                            "name": t.name,
                            "description": t.description,
                            "parameters": t.input_schema,
                        }
                    })
                })
                .collect();
            body["tools"] = serde_json::json!(tools);
        }

        body
    }
1872
    /// Parse a non-streaming `/api/chat` response.
    ///
    /// Thinking tags are stripped from the text, `tool_calls` get locally
    /// minted ids (Ollama supplies none), and token usage is read from
    /// `prompt_eval_count` / `eval_count`.
    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let msg = &body["message"];
        let raw_content = msg["content"].as_str().unwrap_or("");
        let content = strip_thinking_tags(raw_content);

        let mut tool_calls = Vec::new();
        if let Some(tc_array) = msg["tool_calls"].as_array() {
            for tc in tc_array {
                let name = tc["function"]["name"]
                    .as_str()
                    .unwrap_or_default()
                    .to_string();
                let input = tc["function"]["arguments"].clone();
                tool_calls.push(ToolCall {
                    id: format!("ollama-{}", uuid::Uuid::new_v4()),
                    name,
                    input,
                });
            }
        }

        // `done: false` is treated as a truncated response; a missing `done`
        // defaults to a normal end of turn.
        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else if body["done"].as_bool().unwrap_or(true) {
            StopReason::EndTurn
        } else {
            StopReason::MaxTokens
        };

        let usage = TokenUsage {
            input_tokens: body["prompt_eval_count"].as_u64().unwrap_or(0),
            output_tokens: body["eval_count"].as_u64().unwrap_or(0),
        };

        let message = Message {
            role: Role::Assistant,
            content,
            tool_calls,
            tool_results: Vec::new(),
            content_parts: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
1924}
1925
1926#[async_trait]
1927impl LlmDriver for OllamaDriver {
    /// One-shot chat completion against the configured Ollama server.
    ///
    /// No auth or rate-limit mapping here: failures are surfaced via the
    /// response body's `error` field as [`PunchError::Provider`].
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "ollama".to_string(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();
        // Parse the body first so error responses can surface their message.
        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: "ollama".to_string(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["error"].as_str().unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "ollama".to_string(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }
1961
    /// Streaming chat: same body as [`Self::complete`] but with
    /// `stream: true`, yielding newline-delimited JSON chunks.
    ///
    /// The whole stream is buffered first and then replayed through
    /// `callback` by [`Self::parse_ollama_stream`].
    async fn stream_complete_with_callback(
        &self,
        request: CompletionRequest,
        callback: StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
        // Reuse the non-streaming body and flip the stream flag.
        let mut body = self.build_request_body(&request);
        body["stream"] = serde_json::json!(true);

        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "ollama".to_string(),
                message: format!("stream request failed: {e}"),
            })?;

        let status = response.status();
        if !status.is_success() {
            let err_body: serde_json::Value =
                response.json().await.unwrap_or(serde_json::json!({}));
            let msg = err_body["error"].as_str().unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "ollama".to_string(),
                message: format!("API error ({}): {}", status, msg),
            });
        }

        let raw = read_stream_body(response).await?;
        let assembled = self.parse_ollama_stream(&raw, &callback)?;
        Ok(assembled)
    }
1998}
1999
2000impl OllamaDriver {
    /// Assemble a complete response from a buffered NDJSON stream, invoking
    /// `callback` once per text delta and once with a final marker chunk
    /// when the terminal (`done: true`) line is seen.
    ///
    /// Tool calls and token counts only appear on the terminal chunk.
    /// Malformed lines are skipped. A stream that ends without a terminal
    /// chunk is reported as [`StopReason::MaxTokens`] (truncated) and emits
    /// no final marker.
    pub fn parse_ollama_stream(
        &self,
        raw: &str,
        callback: &StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        let mut text_content = String::new();
        let mut tool_calls: Vec<ToolCall> = Vec::new();
        let mut usage = TokenUsage::default();
        let mut done = false;

        for line in raw.lines() {
            let line = line.trim();
            if line.is_empty() {
                continue;
            }

            // Skip lines that are not valid JSON.
            let parsed: serde_json::Value = match serde_json::from_str(line) {
                Ok(v) => v,
                Err(_) => continue,
            };

            if parsed["done"].as_bool() == Some(true) {
                done = true;
                // Token counts only arrive on the terminal chunk.
                if let Some(inp) = parsed["prompt_eval_count"].as_u64() {
                    usage.input_tokens = inp;
                }
                if let Some(out) = parsed["eval_count"].as_u64() {
                    usage.output_tokens = out;
                }
                // Tool calls are delivered whole on the terminal chunk; mint
                // local ids since Ollama supplies none.
                if let Some(tc_array) = parsed["message"]["tool_calls"].as_array() {
                    for tc in tc_array {
                        let name = tc["function"]["name"]
                            .as_str()
                            .unwrap_or_default()
                            .to_string();
                        let input = tc["function"]["arguments"].clone();
                        tool_calls.push(ToolCall {
                            id: format!("ollama-{}", uuid::Uuid::new_v4()),
                            name,
                            input,
                        });
                    }
                }
                callback(StreamChunk {
                    delta: String::new(),
                    is_final: true,
                    tool_call_delta: None,
                });
                break;
            }

            let content = parsed["message"]["content"].as_str().unwrap_or("");
            if !content.is_empty() {
                text_content.push_str(content);
                callback(StreamChunk {
                    delta: content.to_string(),
                    is_final: false,
                    tool_call_delta: None,
                });
            }
        }

        let text_content = strip_thinking_tags(&text_content);

        // No terminal chunk means the stream was cut off mid-generation.
        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else if done {
            StopReason::EndTurn
        } else {
            StopReason::MaxTokens
        };

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            content_parts: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
2093}
2094
/// Driver for the AWS Bedrock Converse API, signing requests with SigV4
/// using static credentials.
pub struct BedrockDriver {
    client: Client,
    // AWS access key id (the Credential= part of the auth header).
    access_key: String,
    // AWS secret access key; used only to derive the signing key.
    secret_key: String,
    // AWS region, e.g. "us-east-1"; part of both the URL and the scope.
    region: String,
}
2106
2107impl BedrockDriver {
2108 pub fn new(access_key: String, secret_key: String, region: String) -> Self {
2110 Self {
2111 client: Client::new(),
2112 access_key,
2113 secret_key,
2114 region,
2115 }
2116 }
2117
2118 pub fn with_client(
2120 client: Client,
2121 access_key: String,
2122 secret_key: String,
2123 region: String,
2124 ) -> Self {
2125 Self {
2126 client,
2127 access_key,
2128 secret_key,
2129 region,
2130 }
2131 }
2132
2133 pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
2135 let mut messages = Vec::new();
2136
2137 for msg in &request.messages {
2138 match msg.role {
2139 Role::User => {
2140 let mut content: Vec<serde_json::Value> = Vec::new();
2141 if !msg.content.is_empty() {
2142 content.push(serde_json::json!({"text": msg.content}));
2143 }
2144 for part in &msg.content_parts {
2146 match part {
2147 punch_types::ContentPart::Text { text } => {
2148 content.push(serde_json::json!({"text": text}));
2149 }
2150 punch_types::ContentPart::Image { media_type, data } => {
2151 content.push(serde_json::json!({
2152 "image": {
2153 "format": media_type.rsplit('/').next().unwrap_or("png"),
2154 "source": {
2155 "bytes": data,
2156 }
2157 }
2158 }));
2159 }
2160 }
2161 }
2162 if content.is_empty() {
2163 content.push(serde_json::json!({"text": ""}));
2164 }
2165 messages.push(serde_json::json!({
2166 "role": "user",
2167 "content": content,
2168 }));
2169 }
2170 Role::Assistant => {
2171 let mut content: Vec<serde_json::Value> = Vec::new();
2172 if !msg.content.is_empty() {
2173 content.push(serde_json::json!({"text": msg.content}));
2174 }
2175 for tc in &msg.tool_calls {
2176 content.push(serde_json::json!({
2177 "toolUse": {
2178 "toolUseId": tc.id,
2179 "name": tc.name,
2180 "input": tc.input,
2181 }
2182 }));
2183 }
2184 if content.is_empty() {
2185 content.push(serde_json::json!({"text": ""}));
2186 }
2187 messages.push(serde_json::json!({
2188 "role": "assistant",
2189 "content": content,
2190 }));
2191 }
2192 Role::Tool => {
2193 let mut content: Vec<serde_json::Value> = Vec::new();
2194 for tr in &msg.tool_results {
2195 content.push(serde_json::json!({
2196 "toolResult": {
2197 "toolUseId": tr.id,
2198 "content": [{"text": tr.content}],
2199 "status": if tr.is_error { "error" } else { "success" },
2200 }
2201 }));
2202 }
2203 messages.push(serde_json::json!({
2204 "role": "user",
2205 "content": content,
2206 }));
2207 }
2208 Role::System => {
2209 }
2211 }
2212 }
2213
2214 let mut body = serde_json::json!({
2215 "messages": messages,
2216 });
2217
2218 let mut inference_config = serde_json::json!({
2219 "maxTokens": request.max_tokens,
2220 });
2221 if let Some(temp) = request.temperature {
2222 inference_config["temperature"] = serde_json::json!(temp);
2223 }
2224 body["inferenceConfig"] = inference_config;
2225
2226 if let Some(ref system) = request.system_prompt {
2227 body["system"] = serde_json::json!([{"text": system}]);
2228 }
2229
2230 if !request.tools.is_empty() {
2231 let tool_config: Vec<serde_json::Value> = request
2232 .tools
2233 .iter()
2234 .map(|t| {
2235 serde_json::json!({
2236 "toolSpec": {
2237 "name": t.name,
2238 "description": t.description,
2239 "inputSchema": {"json": t.input_schema},
2240 }
2241 })
2242 })
2243 .collect();
2244 body["toolConfig"] = serde_json::json!({"tools": tool_config});
2245 }
2246
2247 body
2248 }
2249
2250 pub fn build_url(&self, model_id: &str) -> String {
2252 format!(
2253 "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
2254 self.region, model_id,
2255 )
2256 }
2257
    /// Parse a Converse API response body.
    ///
    /// Text blocks are concatenated (newline-joined), `toolUse` blocks
    /// become [`ToolCall`]s keeping Bedrock's own `toolUseId`s, and
    /// `stopReason` is mapped onto [`StopReason`]. Usage is read from
    /// `usage.inputTokens` / `usage.outputTokens`.
    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let content = body["output"]["message"]["content"]
            .as_array()
            .cloned()
            .unwrap_or_default();

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        for block in &content {
            if let Some(text) = block["text"].as_str() {
                // Join multiple text blocks with a newline separator.
                if !text_content.is_empty() {
                    text_content.push('\n');
                }
                text_content.push_str(text);
            }
            if let Some(tu) = block.get("toolUse") {
                tool_calls.push(ToolCall {
                    id: tu["toolUseId"].as_str().unwrap_or_default().to_string(),
                    name: tu["name"].as_str().unwrap_or_default().to_string(),
                    input: tu["input"].clone(),
                });
            }
        }

        let stop_reason_str = body["stopReason"].as_str().unwrap_or("end_turn");
        // Any tool use wins over the textual stop reason.
        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else {
            match stop_reason_str {
                "end_turn" => StopReason::EndTurn,
                "tool_use" => StopReason::ToolUse,
                "max_tokens" => StopReason::MaxTokens,
                // Unknown reasons degrade to EndTurn.
                _ => StopReason::EndTurn,
            }
        };

        let usage = TokenUsage {
            input_tokens: body["usage"]["inputTokens"].as_u64().unwrap_or(0),
            output_tokens: body["usage"]["outputTokens"].as_u64().unwrap_or(0),
        };

        // Strip reasoning/thinking markup before surfacing the text.
        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            content_parts: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
2319
2320 pub fn sign_request(
2324 &self,
2325 method: &str,
2326 url: &str,
2327 headers: &[(String, String)],
2328 payload: &[u8],
2329 timestamp: &str, ) -> PunchResult<String> {
2331 let date = ×tamp[..8]; let service = "bedrock";
2333
2334 let parsed = url::Url::parse(url).map_err(|e| PunchError::Provider {
2336 provider: "bedrock".to_string(),
2337 message: format!("invalid URL: {e}"),
2338 })?;
2339 let host = parsed.host_str().unwrap_or("");
2340 let path = parsed.path();
2341
2342 let payload_hash = hex_sha256(payload);
2344
2345 let mut signed_header_names: Vec<String> =
2346 headers.iter().map(|(k, _)| k.to_lowercase()).collect();
2347 signed_header_names.push("host".to_string());
2348 signed_header_names.push("x-amz-date".to_string());
2349 signed_header_names.sort();
2350 signed_header_names.dedup();
2351
2352 let mut header_map: Vec<(String, String)> = headers
2353 .iter()
2354 .map(|(k, v)| (k.to_lowercase(), v.trim().to_string()))
2355 .collect();
2356 header_map.push(("host".to_string(), host.to_string()));
2357 header_map.push(("x-amz-date".to_string(), timestamp.to_string()));
2358 header_map.sort_by(|a, b| a.0.cmp(&b.0));
2359 header_map.dedup_by(|a, b| a.0 == b.0);
2360
2361 let canonical_headers: String = header_map
2362 .iter()
2363 .map(|(k, v)| format!("{}:{}\n", k, v))
2364 .collect();
2365
2366 let signed_headers = signed_header_names.join(";");
2367
2368 let canonical_request = format!(
2369 "{}\n{}\n\n{}\n{}\n{}",
2370 method, path, canonical_headers, signed_headers, payload_hash,
2371 );
2372
2373 let credential_scope = format!("{}/{}/{}/aws4_request", date, self.region, service);
2375 let string_to_sign = format!(
2376 "AWS4-HMAC-SHA256\n{}\n{}\n{}",
2377 timestamp,
2378 credential_scope,
2379 hex_sha256(canonical_request.as_bytes()),
2380 );
2381
2382 let k_date = hmac_sha256(
2384 format!("AWS4{}", self.secret_key).as_bytes(),
2385 date.as_bytes(),
2386 );
2387 let k_region = hmac_sha256(&k_date, self.region.as_bytes());
2388 let k_service = hmac_sha256(&k_region, service.as_bytes());
2389 let k_signing = hmac_sha256(&k_service, b"aws4_request");
2390
2391 let signature = hex_encode(&hmac_sha256(&k_signing, string_to_sign.as_bytes()));
2393
2394 Ok(format!(
2396 "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
2397 self.access_key, credential_scope, signed_headers, signature,
2398 ))
2399 }
2400}
2401
2402fn hex_sha256(data: &[u8]) -> String {
2404 let mut hasher = Sha256::new();
2405 hasher.update(data);
2406 hex_encode(hasher.finalize().as_slice())
2407}
2408
2409fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
2411 type HmacSha256 = Hmac<Sha256>;
2412 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size");
2413 mac.update(data);
2414 mac.finalize().into_bytes().to_vec()
2415}
2416
/// Render `bytes` as a lowercase hexadecimal string.
fn hex_encode(bytes: &[u8]) -> String {
    use std::fmt::Write;
    // Write into one preallocated buffer (two digits per byte) instead of
    // allocating a temporary String per byte via format!/collect.
    let mut out = String::with_capacity(bytes.len() * 2);
    for b in bytes {
        // Writing to a String is infallible.
        let _ = write!(out, "{:02x}", b);
    }
    out
}
2421
2422#[async_trait]
2423impl LlmDriver for BedrockDriver {
2424 async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
2425 let url = self.build_url(&request.model);
2426 let body = self.build_request_body(&request);
2427 let payload = serde_json::to_vec(&body).map_err(|e| PunchError::Provider {
2428 provider: "bedrock".to_string(),
2429 message: format!("failed to serialize request: {e}"),
2430 })?;
2431
2432 let timestamp = chrono::Utc::now().format("%Y%m%dT%H%M%SZ").to_string();
2433
2434 let auth_header = self.sign_request(
2435 "POST",
2436 &url,
2437 &[("content-type".to_string(), "application/json".to_string())],
2438 &payload,
2439 ×tamp,
2440 )?;
2441
2442 let parsed_url = url::Url::parse(&url).map_err(|e| PunchError::Provider {
2443 provider: "bedrock".to_string(),
2444 message: format!("invalid URL: {e}"),
2445 })?;
2446 let host = parsed_url.host_str().unwrap_or_default().to_string();
2447
2448 let response = self
2449 .client
2450 .post(&url)
2451 .header("content-type", "application/json")
2452 .header("host", &host)
2453 .header("x-amz-date", ×tamp)
2454 .header("authorization", &auth_header)
2455 .body(payload)
2456 .send()
2457 .await
2458 .map_err(|e| PunchError::Provider {
2459 provider: "bedrock".to_string(),
2460 message: format!("request failed: {e}"),
2461 })?;
2462
2463 let status = response.status();
2464
2465 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
2466 return Err(PunchError::RateLimited {
2467 provider: "bedrock".to_string(),
2468 retry_after_ms: 60_000,
2469 });
2470 }
2471
2472 if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
2473 return Err(PunchError::Auth(
2474 "AWS Bedrock credentials are invalid or lack permissions".to_string(),
2475 ));
2476 }
2477
2478 let response_body: serde_json::Value =
2479 response.json().await.map_err(|e| PunchError::Provider {
2480 provider: "bedrock".to_string(),
2481 message: format!("failed to parse response: {e}"),
2482 })?;
2483
2484 if !status.is_success() {
2485 let error_msg = response_body["message"].as_str().unwrap_or("unknown error");
2486 return Err(PunchError::Provider {
2487 provider: "bedrock".to_string(),
2488 message: format!("API error ({}): {}", status, error_msg),
2489 });
2490 }
2491
2492 self.parse_response(&response_body)
2493 }
2494
2495 async fn stream_complete_with_callback(
2496 &self,
2497 request: CompletionRequest,
2498 callback: StreamCallback,
2499 ) -> PunchResult<CompletionResponse> {
2500 let response = self.complete(request).await?;
2503 callback(StreamChunk {
2504 delta: response.message.content.clone(),
2505 is_final: true,
2506 tool_call_delta: None,
2507 });
2508 Ok(response)
2509 }
2510}
2511
/// Driver for Azure OpenAI deployments.
///
/// Wraps an `OpenAiCompatibleDriver` for request/response handling while
/// building Azure-specific URLs (api-version query parameter) and
/// authenticating with the `api-key` header.
pub struct AzureOpenAiDriver {
    inner: OpenAiCompatibleDriver,
    // Azure resource name: https://{resource}.openai.azure.com
    resource: String,
    // Deployment (model) name within the resource.
    deployment: String,
    // REST API version query parameter, e.g. "2024-02-01".
    api_version: String,
}
2526
2527impl AzureOpenAiDriver {
2528 pub fn new(
2535 api_key: String,
2536 resource: String,
2537 deployment: String,
2538 api_version: Option<String>,
2539 ) -> Self {
2540 let base_url = format!("https://{}.openai.azure.com", resource);
2541 Self {
2542 inner: OpenAiCompatibleDriver::new(api_key, base_url, "azure_openai".to_string()),
2543 resource,
2544 deployment,
2545 api_version: api_version.unwrap_or_else(|| "2024-02-01".to_string()),
2546 }
2547 }
2548
2549 pub fn with_client(
2551 client: Client,
2552 api_key: String,
2553 resource: String,
2554 deployment: String,
2555 api_version: Option<String>,
2556 ) -> Self {
2557 let base_url = format!("https://{}.openai.azure.com", resource);
2558 Self {
2559 inner: OpenAiCompatibleDriver::with_client(
2560 client,
2561 api_key,
2562 base_url,
2563 "azure_openai".to_string(),
2564 ),
2565 resource,
2566 deployment,
2567 api_version: api_version.unwrap_or_else(|| "2024-02-01".to_string()),
2568 }
2569 }
2570
2571 pub fn build_url(&self) -> String {
2573 format!(
2574 "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
2575 self.resource, self.deployment, self.api_version,
2576 )
2577 }
2578
    /// The Azure resource name (the `{resource}` in
    /// `https://{resource}.openai.azure.com`).
    pub fn resource(&self) -> &str {
        &self.resource
    }
2583
    /// The deployment name routed to by the request URL.
    pub fn deployment(&self) -> &str {
        &self.deployment
    }
2588
    /// Build the chat-completions body; Azure uses the OpenAI wire format,
    /// so this delegates to the wrapped driver.
    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        self.inner.build_request_body(request)
    }
2593
    /// Parse an OpenAI-format response body; delegates to the wrapped driver.
    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        self.inner.parse_response(body)
    }
2598}
2599
2600#[async_trait]
2601impl LlmDriver for AzureOpenAiDriver {
    /// One-shot chat completion against the Azure deployment.
    ///
    /// Authenticates with the `api-key` header (not a bearer token).
    /// 429 honors the `retry-after` header (seconds, converted to
    /// milliseconds, default 60s); 401/403 map to [`PunchError::Auth`];
    /// other non-2xx statuses surface the API's error message.
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = self.build_url();
        let body = self.inner.build_request_body(&request);

        let response = self
            .inner
            .client
            .post(&url)
            .header("api-key", &self.inner.api_key)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "azure_openai".to_string(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            // retry-after arrives in seconds; callers expect milliseconds.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap_or(60)
                * 1000;

            return Err(PunchError::RateLimited {
                provider: "azure_openai".to_string(),
                retry_after_ms: retry_after,
            });
        }

        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(PunchError::Auth(
                "Azure OpenAI API key is invalid or lacks permissions".to_string(),
            ));
        }

        // Error bodies are JSON too; parse before branching on status.
        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: "azure_openai".to_string(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "azure_openai".to_string(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.inner.parse_response(&response_body)
    }
2661
2662 async fn stream_complete_with_callback(
2663 &self,
2664 request: CompletionRequest,
2665 callback: StreamCallback,
2666 ) -> PunchResult<CompletionResponse> {
2667 let url = self.build_url();
2668 let mut body = self.inner.build_request_body(&request);
2669 body["stream"] = serde_json::json!(true);
2670
2671 let response = self
2672 .inner
2673 .client
2674 .post(&url)
2675 .header("api-key", &self.inner.api_key)
2676 .header("content-type", "application/json")
2677 .json(&body)
2678 .send()
2679 .await
2680 .map_err(|e| PunchError::Provider {
2681 provider: "azure_openai".to_string(),
2682 message: format!("stream request failed: {e}"),
2683 })?;
2684
2685 let status = response.status();
2686 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
2687 return Err(PunchError::RateLimited {
2688 provider: "azure_openai".to_string(),
2689 retry_after_ms: 60_000,
2690 });
2691 }
2692 if !status.is_success() {
2693 let err_body: serde_json::Value =
2694 response.json().await.unwrap_or(serde_json::json!({}));
2695 let msg = err_body["error"]["message"]
2696 .as_str()
2697 .unwrap_or("unknown error");
2698 return Err(PunchError::Provider {
2699 provider: "azure_openai".to_string(),
2700 message: format!("API error ({}): {}", status, msg),
2701 });
2702 }
2703
2704 let raw = read_stream_body(response).await?;
2705 let assembled = self.inner.parse_openai_stream(&raw, &callback)?;
2707 Ok(assembled)
2708 }
2709}
2710
2711fn default_base_url(provider: &Provider) -> &'static str {
2717 match provider {
2718 Provider::Anthropic => "https://api.anthropic.com",
2719 Provider::OpenAI => "https://api.openai.com",
2720 Provider::Google => "https://generativelanguage.googleapis.com",
2721 Provider::Groq => "https://api.groq.com/openai",
2722 Provider::DeepSeek => "https://api.deepseek.com",
2723 Provider::Ollama => "http://localhost:11434",
2724 Provider::Mistral => "https://api.mistral.ai",
2725 Provider::Together => "https://api.together.xyz",
2726 Provider::Fireworks => "https://api.fireworks.ai/inference",
2727 Provider::Cerebras => "https://api.cerebras.ai",
2728 Provider::XAI => "https://api.x.ai",
2729 Provider::Cohere => "https://api.cohere.ai",
2730 Provider::Bedrock => "https://bedrock-runtime.us-east-1.amazonaws.com",
2731 Provider::AzureOpenAi => "",
2732 Provider::Custom(_) => "",
2733 }
2734}
2735
/// Constructs an [`LlmDriver`] for `config` using a fresh HTTP client.
///
/// Convenience wrapper around [`create_driver_with_client`] with no shared
/// `reqwest::Client`.
pub fn create_driver(config: &ModelConfig) -> PunchResult<Arc<dyn LlmDriver>> {
    create_driver_with_client(config, None)
}
2748
/// Constructs an [`LlmDriver`] for `config`, optionally reusing a shared
/// `reqwest::Client` so multiple drivers can pool connections.
///
/// The API key is read from the environment variable named by
/// `config.api_key_env`; a missing variable yields [`PunchError::Auth`].
/// When `api_key_env` is `None` the key is left empty (local providers such
/// as Ollama need none).
pub fn create_driver_with_client(
    config: &ModelConfig,
    shared_client: Option<&Client>,
) -> PunchResult<Arc<dyn LlmDriver>> {
    let api_key = match &config.api_key_env {
        Some(env_var) => std::env::var(env_var).map_err(|_| {
            PunchError::Auth(format!(
                "environment variable '{}' not set for {} driver",
                env_var, config.provider
            ))
        })?,
        None => {
            // No key configured; downstream drivers tolerate an empty key.
            String::new()
        }
    };

    // An explicitly configured base URL wins; otherwise fall back to the
    // provider's default endpoint.
    let base_url = config
        .base_url
        .clone()
        .unwrap_or_else(|| default_base_url(&config.provider).to_string());

    match &config.provider {
        Provider::Anthropic => {
            if let Some(client) = shared_client {
                Ok(Arc::new(AnthropicDriver::with_client(
                    client.clone(),
                    api_key,
                    Some(base_url),
                )))
            } else {
                Ok(Arc::new(AnthropicDriver::new(api_key, Some(base_url))))
            }
        }
        Provider::Google => {
            if let Some(client) = shared_client {
                Ok(Arc::new(GeminiDriver::with_client(
                    client.clone(),
                    api_key,
                    Some(base_url),
                )))
            } else {
                Ok(Arc::new(GeminiDriver::new(api_key, Some(base_url))))
            }
        }
        Provider::Ollama => {
            if let Some(client) = shared_client {
                Ok(Arc::new(OllamaDriver::with_client(
                    client.clone(),
                    Some(base_url),
                )))
            } else {
                Ok(Arc::new(OllamaDriver::new(Some(base_url))))
            }
        }
        Provider::Bedrock => {
            // Credentials may be packed as "ACCESS_KEY:SECRET_KEY" in the
            // configured env var; otherwise fall back to the standard AWS
            // environment variables.
            let (access_key, secret_key) = if api_key.contains(':') {
                let parts: Vec<&str> = api_key.splitn(2, ':').collect();
                (parts[0].to_string(), parts[1].to_string())
            } else {
                let ak = std::env::var("AWS_ACCESS_KEY_ID").unwrap_or(api_key);
                let sk = std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_default();
                (ak, sk)
            };
            // Derive the region from "https://bedrock-runtime.<region>...."
            // URLs; anything else defaults to us-east-1.
            let region = if base_url.contains("bedrock-runtime.") {
                base_url
                    .trim_start_matches("https://bedrock-runtime.")
                    .split('.')
                    .next()
                    .unwrap_or("us-east-1")
                    .to_string()
            } else {
                "us-east-1".to_string()
            };
            if let Some(client) = shared_client {
                Ok(Arc::new(BedrockDriver::with_client(
                    client.clone(),
                    access_key,
                    secret_key,
                    region,
                )))
            } else {
                Ok(Arc::new(BedrockDriver::new(access_key, secret_key, region)))
            }
        }
        Provider::AzureOpenAi => {
            // Extract the resource name from
            // "https://<resource>.openai.azure.com"; otherwise the whole
            // base_url value is treated as the resource name.
            let resource = if base_url.contains(".openai.azure.com") {
                base_url
                    .trim_start_matches("https://")
                    .split('.')
                    .next()
                    .unwrap_or("default")
                    .to_string()
            } else {
                base_url.clone()
            };
            // Azure deployments stand in for the model name.
            let deployment = config.model.clone();
            if let Some(client) = shared_client {
                Ok(Arc::new(AzureOpenAiDriver::with_client(
                    client.clone(),
                    api_key,
                    resource,
                    deployment,
                    None,
                )))
            } else {
                Ok(Arc::new(AzureOpenAiDriver::new(
                    api_key, resource, deployment, None,
                )))
            }
        }
        // Every remaining provider speaks the OpenAI-compatible API.
        provider => {
            let name = provider.to_string();
            if let Some(client) = shared_client {
                Ok(Arc::new(OpenAiCompatibleDriver::with_client(
                    client.clone(),
                    api_key,
                    base_url,
                    name,
                )))
            } else {
                Ok(Arc::new(OpenAiCompatibleDriver::new(
                    api_key, base_url, name,
                )))
            }
        }
    }
}
2883
2884#[cfg(test)]
2889mod tests {
2890 use super::*;
2891 use punch_types::ToolCategory;
2892
    /// Builds a minimal text-only completion request used by most tests.
    fn simple_request() -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            messages: vec![Message::new(Role::User, "Hello")],
            tools: Vec::new(),
            max_tokens: 4096,
            temperature: Some(0.7),
            system_prompt: Some("You are helpful.".to_string()),
        }
    }

    /// Builds a request carrying a single `get_weather` tool definition,
    /// for exercising the tool-formatting code paths.
    fn request_with_tools() -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            messages: vec![Message::new(Role::User, "Use the tool")],
            tools: vec![ToolDefinition {
                name: "get_weather".to_string(),
                description: "Get weather for a city".to_string(),
                input_schema: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"}
                    }
                }),
                category: ToolCategory::Web,
            }],
            max_tokens: 4096,
            temperature: Some(0.7),
            system_prompt: None,
        }
    }
2926
    // Gemini maps system prompts to `system_instruction` and request options
    // into `generationConfig`.
    #[test]
    fn gemini_request_formatting() {
        let driver = GeminiDriver::new("test-key".to_string(), None);
        let body = driver.build_request_body(&simple_request());

        let contents = body["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 1);
        let first_text = contents[0]["parts"][0]["text"].as_str().unwrap();
        assert_eq!(first_text, "Hello");
        assert_eq!(contents[0]["role"].as_str().unwrap(), "user");
        let sys_text = body["system_instruction"]["parts"][0]["text"]
            .as_str()
            .unwrap();
        assert_eq!(sys_text, "You are helpful.");

        assert_eq!(body["generationConfig"]["maxOutputTokens"], 4096);
        assert!((body["generationConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
    }

    // A plain text candidate with finishReason STOP parses to EndTurn plus
    // usage counts from usageMetadata.
    #[test]
    fn gemini_response_parsing() {
        let driver = GeminiDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hello there!"}],
                    "role": "model"
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5
            }
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Hello there!");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert_eq!(resp.usage.input_tokens, 10);
        assert_eq!(resp.usage.output_tokens, 5);
    }

    // A System-role message is lifted out of `contents` and into
    // `system_instruction` rather than appearing as a turn.
    #[test]
    fn gemini_role_mapping_system_prepended() {
        let driver = GeminiDriver::new("test-key".to_string(), None);
        let req = CompletionRequest {
            model: "gemini-pro".to_string(),
            messages: vec![
                Message::new(Role::System, "Be concise."),
                Message::new(Role::User, "Hi"),
            ],
            tools: Vec::new(),
            max_tokens: 1024,
            temperature: None,
            system_prompt: None,
        };
        let body = driver.build_request_body(&req);
        let contents = body["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 1);
        let text = contents[0]["parts"][0]["text"].as_str().unwrap();
        assert_eq!(text, "Hi");
        let sys_text = body["system_instruction"]["parts"][0]["text"]
            .as_str()
            .unwrap();
        assert_eq!(sys_text, "Be concise.");
    }

    // A `functionCall` part yields a ToolCall and flips the stop reason to
    // ToolUse even though Gemini reports finishReason STOP.
    #[test]
    fn gemini_function_call_parsing() {
        let driver = GeminiDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Let me check the weather."},
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "London"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 15,
                "candidatesTokenCount": 8
            }
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Let me check the weather.");
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
        assert_eq!(resp.message.tool_calls[0].input["city"], "London");
    }

    // Gemini authenticates via a `key=` query parameter in the URL.
    #[test]
    fn gemini_api_key_in_url() {
        let driver = GeminiDriver::new("my-secret-key".to_string(), None);
        let url = driver.build_url("gemini-pro");
        assert!(url.contains("key=my-secret-key"));
        assert!(url.contains("models/gemini-pro:generateContent"));
    }
3043
    // Ollama requests use flat chat messages (system prompt becomes the
    // first message) with sampling options under `options`.
    #[test]
    fn ollama_request_formatting() {
        let driver = OllamaDriver::new(None);
        let body = driver.build_request_body(&simple_request());

        assert_eq!(body["model"], "test-model");
        assert_eq!(body["stream"], false);
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0]["role"], "system");
        assert_eq!(messages[0]["content"], "You are helpful.");
        assert_eq!(messages[1]["role"], "user");
        assert_eq!(messages[1]["content"], "Hello");
        assert!((body["options"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
    }

    // Token usage comes from `prompt_eval_count` / `eval_count`.
    #[test]
    fn ollama_response_parsing() {
        let driver = OllamaDriver::new(None);
        let response_body = serde_json::json!({
            "message": {
                "role": "assistant",
                "content": "Hi there!"
            },
            "done": true,
            "prompt_eval_count": 20,
            "eval_count": 10
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Hi there!");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert_eq!(resp.usage.input_tokens, 20);
        assert_eq!(resp.usage.output_tokens, 10);
    }

    // No base URL configured defaults to the local Ollama daemon.
    #[test]
    fn ollama_default_endpoint() {
        let driver = OllamaDriver::new(None);
        assert_eq!(driver.base_url(), "http://localhost:11434");
    }

    // An explicit base URL overrides the default.
    #[test]
    fn ollama_custom_endpoint() {
        let driver = OllamaDriver::new(Some("http://myhost:9999".to_string()));
        assert_eq!(driver.base_url(), "http://myhost:9999");
    }
3096
    // Bedrock's Converse API nests text in content blocks and sampling
    // parameters under `inferenceConfig`.
    #[test]
    fn bedrock_request_formatting() {
        let driver = BedrockDriver::new(
            "TESTKEY".to_string(),
            "testsecret".to_string(),
            "us-west-2".to_string(),
        );
        let body = driver.build_request_body(&simple_request());

        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0]["role"], "user");
        assert_eq!(messages[0]["content"][0]["text"], "Hello");

        assert_eq!(body["inferenceConfig"]["maxTokens"], 4096);
        assert!((body["inferenceConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
        assert_eq!(body["system"][0]["text"], "You are helpful.");
    }

    // The SigV4 Authorization header carries the credential scope
    // (date/region/service) and the sorted signed-header list.
    #[test]
    fn bedrock_sigv4_canonical_request() {
        let driver = BedrockDriver::new(
            "TESTACCESS1234567890".to_string(),
            "TestSecretKeyValue1234567890abcdefghijk".to_string(),
            "us-east-1".to_string(),
        );

        let payload = b"{}";
        let timestamp = "20260313T120000Z";

        let auth = driver
            .sign_request(
                "POST",
                "https://bedrock-runtime.us-east-1.amazonaws.com/model/test/converse",
                &[("content-type".to_string(), "application/json".to_string())],
                payload,
                timestamp,
            )
            .unwrap();

        assert!(auth.starts_with(
            "AWS4-HMAC-SHA256 Credential=TESTACCESS1234567890/20260313/us-east-1/bedrock/aws4_request"
        ));
        assert!(auth.contains("SignedHeaders=content-type;host;x-amz-date"));
        assert!(auth.contains("Signature="));
    }

    // Converse responses nest the message under `output` with camelCase
    // usage fields.
    #[test]
    fn bedrock_response_parsing() {
        let driver = BedrockDriver::new(
            "key".to_string(),
            "secret".to_string(),
            "us-east-1".to_string(),
        );
        let response_body = serde_json::json!({
            "output": {
                "message": {
                    "role": "assistant",
                    "content": [{"text": "The answer is 42."}]
                }
            },
            "stopReason": "end_turn",
            "usage": {
                "inputTokens": 100,
                "outputTokens": 50
            }
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "The answer is 42.");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert_eq!(resp.usage.input_tokens, 100);
        assert_eq!(resp.usage.output_tokens, 50);
    }
3175
    // URL embeds resource, deployment, and the default api-version.
    #[test]
    fn azure_openai_url_construction() {
        let driver = AzureOpenAiDriver::new(
            "my-azure-key".to_string(),
            "myresource".to_string(),
            "gpt-4-deployment".to_string(),
            None,
        );
        let url = driver.build_url();
        assert_eq!(
            url,
            "https://myresource.openai.azure.com/openai/deployments/gpt-4-deployment/chat/completions?api-version=2024-02-01"
        );
    }

    // An explicit api_version overrides the default.
    #[test]
    fn azure_openai_custom_api_version() {
        let driver = AzureOpenAiDriver::new(
            "key".to_string(),
            "res".to_string(),
            "dep".to_string(),
            Some("2024-06-01".to_string()),
        );
        let url = driver.build_url();
        assert!(url.contains("api-version=2024-06-01"));
    }

    // Body formatting is delegated to the OpenAI-compatible driver.
    #[test]
    fn azure_openai_request_formatting() {
        let driver = AzureOpenAiDriver::new(
            "key".to_string(),
            "res".to_string(),
            "dep".to_string(),
            None,
        );
        let body = driver.build_request_body(&simple_request());
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0]["role"], "system");
        assert_eq!(messages[1]["role"], "user");
        assert_eq!(body["model"], "test-model");
    }

    // Accessors expose the configured resource and deployment.
    #[test]
    fn azure_openai_resource_and_deployment() {
        let driver = AzureOpenAiDriver::new(
            "key".to_string(),
            "my-resource".to_string(),
            "my-deploy".to_string(),
            None,
        );
        assert_eq!(driver.resource(), "my-resource");
        assert_eq!(driver.deployment(), "my-deploy");
    }
3236
    // Ollama needs no API key, so no env var is required.
    #[test]
    fn create_driver_dispatches_ollama() {
        let config = ModelConfig {
            provider: Provider::Ollama,
            model: "llama3".to_string(),
            api_key_env: None,
            base_url: None,
            max_tokens: None,
            temperature: None,
        };
        let driver = create_driver(&config);
        assert!(driver.is_ok());
    }

    // NOTE(review): these tests mutate process-wide env vars; they assume
    // the test harness does not run them concurrently with other env-reading
    // tests using the same variable names.
    #[test]
    fn create_driver_dispatches_gemini() {
        unsafe { std::env::set_var("TEST_GEMINI_KEY_DISPATCH", "fake-key") };
        let config = ModelConfig {
            provider: Provider::Google,
            model: "gemini-pro".to_string(),
            api_key_env: Some("TEST_GEMINI_KEY_DISPATCH".to_string()),
            base_url: None,
            max_tokens: None,
            temperature: None,
        };
        let driver = create_driver(&config);
        assert!(driver.is_ok());
        unsafe { std::env::remove_var("TEST_GEMINI_KEY_DISPATCH") };
    }

    // Bedrock credentials packed as "ACCESS:SECRET" in one variable.
    #[test]
    fn create_driver_dispatches_bedrock() {
        unsafe { std::env::set_var("TEST_BEDROCK_KEY_DISPATCH", "TESTKEY:TESTSECRET") };
        let config = ModelConfig {
            provider: Provider::Bedrock,
            model: "anthropic.claude-v2".to_string(),
            api_key_env: Some("TEST_BEDROCK_KEY_DISPATCH".to_string()),
            base_url: None,
            max_tokens: None,
            temperature: None,
        };
        let driver = create_driver(&config);
        assert!(driver.is_ok());
        unsafe { std::env::remove_var("TEST_BEDROCK_KEY_DISPATCH") };
    }

    // Azure requires an explicit base_url to derive the resource name.
    #[test]
    fn create_driver_dispatches_azure_openai() {
        unsafe { std::env::set_var("TEST_AZURE_KEY_DISPATCH", "azure-key") };
        let config = ModelConfig {
            provider: Provider::AzureOpenAi,
            model: "gpt-4".to_string(),
            api_key_env: Some("TEST_AZURE_KEY_DISPATCH".to_string()),
            base_url: Some("https://myres.openai.azure.com".to_string()),
            max_tokens: None,
            temperature: None,
        };
        let driver = create_driver(&config);
        assert!(driver.is_ok());
        unsafe { std::env::remove_var("TEST_AZURE_KEY_DISPATCH") };
    }
3307
    // Gemini wraps tools in a `function_declarations` array.
    #[test]
    fn gemini_tools_in_request() {
        let driver = GeminiDriver::new("key".to_string(), None);
        let body = driver.build_request_body(&request_with_tools());

        let tools = body["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 1);
        let func_decls = tools[0]["function_declarations"].as_array().unwrap();
        assert_eq!(func_decls.len(), 1);
        assert_eq!(func_decls[0]["name"], "get_weather");
    }

    // Ollama uses the OpenAI-style `{"type": "function", ...}` tool shape.
    #[test]
    fn ollama_tools_in_request() {
        let driver = OllamaDriver::new(None);
        let body = driver.build_request_body(&request_with_tools());

        let tools = body["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["type"], "function");
        assert_eq!(tools[0]["function"]["name"], "get_weather");
    }

    // The Converse endpoint path embeds the region and model id.
    #[test]
    fn bedrock_url_construction() {
        let driver = BedrockDriver::new(
            "key".to_string(),
            "secret".to_string(),
            "eu-west-1".to_string(),
        );
        let url = driver.build_url("anthropic.claude-3-sonnet");
        assert_eq!(
            url,
            "https://bedrock-runtime.eu-west-1.amazonaws.com/model/anthropic.claude-3-sonnet/converse"
        );
    }
3344
    // Default usage is all zeros.
    #[test]
    fn token_usage_default() {
        let u = TokenUsage::default();
        assert_eq!(u.input_tokens, 0);
        assert_eq!(u.output_tokens, 0);
        assert_eq!(u.total(), 0);
    }

    // accumulate() adds field-wise.
    #[test]
    fn token_usage_accumulate() {
        let mut u = TokenUsage {
            input_tokens: 10,
            output_tokens: 20,
        };
        let other = TokenUsage {
            input_tokens: 5,
            output_tokens: 15,
        };
        u.accumulate(&other);
        assert_eq!(u.input_tokens, 15);
        assert_eq!(u.output_tokens, 35);
        assert_eq!(u.total(), 50);
    }

    // total() is the sum of input and output tokens.
    #[test]
    fn token_usage_total() {
        let u = TokenUsage {
            input_tokens: 100,
            output_tokens: 200,
        };
        assert_eq!(u.total(), 300);
    }

    // StopReason serializes to snake_case string tags.
    #[test]
    fn stop_reason_serialization() {
        let json = serde_json::to_string(&StopReason::EndTurn).unwrap();
        assert_eq!(json, "\"end_turn\"");

        let json = serde_json::to_string(&StopReason::ToolUse).unwrap();
        assert_eq!(json, "\"tool_use\"");

        let json = serde_json::to_string(&StopReason::MaxTokens).unwrap();
        assert_eq!(json, "\"max_tokens\"");

        let json = serde_json::to_string(&StopReason::Error).unwrap();
        assert_eq!(json, "\"error\"");
    }

    // Round-trip: snake_case tags deserialize back to the variants.
    #[test]
    fn stop_reason_deserialization() {
        let sr: StopReason = serde_json::from_str("\"end_turn\"").unwrap();
        assert_eq!(sr, StopReason::EndTurn);

        let sr: StopReason = serde_json::from_str("\"tool_use\"").unwrap();
        assert_eq!(sr, StopReason::ToolUse);
    }
3409
    // Anthropic puts the system prompt in a top-level `system` array with
    // prompt caching enabled via `cache_control`.
    #[test]
    fn anthropic_request_body_simple() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let body = driver.build_request_body(&simple_request());

        assert_eq!(body["model"], "test-model");
        assert_eq!(body["max_tokens"], 4096);
        let system = body["system"].as_array().unwrap();
        assert_eq!(system.len(), 1);
        assert_eq!(system[0]["type"], "text");
        assert_eq!(system[0]["text"], "You are helpful.");
        assert_eq!(system[0]["cache_control"]["type"], "ephemeral");
        assert!((body["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);

        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0]["role"], "user");
        assert_eq!(messages[0]["content"], "Hello");
    }

    // Tool definitions carry `name` and `input_schema` directly (no
    // OpenAI-style `function` wrapper).
    #[test]
    fn anthropic_request_body_with_tools() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let body = driver.build_request_body(&request_with_tools());

        let tools = body["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["name"], "get_weather");
        assert!(tools[0]["input_schema"]["properties"].is_object());
    }

    // Absent optional fields are omitted from the JSON entirely.
    #[test]
    fn anthropic_request_body_no_system_prompt() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let req = CompletionRequest {
            model: "test".into(),
            messages: vec![Message::new(Role::User, "Hi")],
            tools: Vec::new(),
            max_tokens: 100,
            temperature: None,
            system_prompt: None,
        };
        let body = driver.build_request_body(&req);
        assert!(body.get("system").is_none());
        assert!(body.get("temperature").is_none());
    }

    // A plain text content block maps to message content and EndTurn.
    #[test]
    fn anthropic_parse_response_text() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "content": [
                {"type": "text", "text": "Hello!"}
            ],
            "stop_reason": "end_turn",
            "usage": {
                "input_tokens": 10,
                "output_tokens": 5
            }
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Hello!");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert_eq!(resp.usage.input_tokens, 10);
        assert_eq!(resp.usage.output_tokens, 5);
        assert!(resp.message.tool_calls.is_empty());
    }

    // `tool_use` content blocks become ToolCalls with id/name/input.
    #[test]
    fn anthropic_parse_response_tool_use() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "content": [
                {"type": "text", "text": "Let me check."},
                {
                    "type": "tool_use",
                    "id": "tool_abc",
                    "name": "get_weather",
                    "input": {"city": "NYC"}
                }
            ],
            "stop_reason": "tool_use",
            "usage": {"input_tokens": 20, "output_tokens": 15}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Let me check.");
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].id, "tool_abc");
        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
        assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
    }

    // "max_tokens" stop reason maps to StopReason::MaxTokens.
    #[test]
    fn anthropic_parse_response_max_tokens() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "content": [{"type": "text", "text": "truncated"}],
            "stop_reason": "max_tokens",
            "usage": {"input_tokens": 5, "output_tokens": 100}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
    }

    // Unrecognized stop reasons fall back to StopReason::Error.
    #[test]
    fn anthropic_parse_response_unknown_stop_reason() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "content": [{"type": "text", "text": "err"}],
            "stop_reason": "something_unknown",
            "usage": {"input_tokens": 0, "output_tokens": 0}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.stop_reason, StopReason::Error);
    }
3535
3536 #[test]
3537 fn anthropic_request_body_with_assistant_and_tool_messages() {
3538 let driver = AnthropicDriver::new("test-key".to_string(), None);
3539 let req = CompletionRequest {
3540 model: "test".into(),
3541 messages: vec![
3542 Message::new(Role::User, "Hi"),
3543 Message {
3544 role: Role::Assistant,
3545 content: "I'll check".into(),
3546 content_parts: Vec::new(),
3547 tool_calls: vec![ToolCall {
3548 id: "call_1".into(),
3549 name: "file_read".into(),
3550 input: serde_json::json!({"path": "/tmp/test"}),
3551 }],
3552 tool_results: Vec::new(),
3553 timestamp: chrono::Utc::now(),
3554 },
3555 Message {
3556 role: Role::Tool,
3557 content: String::new(),
3558 content_parts: Vec::new(),
3559 tool_calls: Vec::new(),
3560 tool_results: vec![punch_types::ToolCallResult {
3561 id: "call_1".into(),
3562 content: "file contents".into(),
3563 is_error: false,
3564 image: None,
3565 }],
3566 timestamp: chrono::Utc::now(),
3567 },
3568 ],
3569 tools: Vec::new(),
3570 max_tokens: 100,
3571 temperature: None,
3572 system_prompt: None,
3573 };
3574
3575 let body = driver.build_request_body(&req);
3576 let messages = body["messages"].as_array().unwrap();
3577 assert_eq!(messages.len(), 3);
3578 assert_eq!(messages[0]["role"], "user");
3579 assert_eq!(messages[1]["role"], "assistant");
3580 assert_eq!(messages[2]["role"], "user"); }
3582
    // System-role messages are dropped from the `messages` array (Anthropic
    // takes the system prompt via the top-level `system` field instead).
    #[test]
    fn anthropic_request_body_system_message_skipped() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let req = CompletionRequest {
            model: "test".into(),
            messages: vec![
                Message::new(Role::System, "System instruction"),
                Message::new(Role::User, "Hi"),
            ],
            tools: Vec::new(),
            max_tokens: 100,
            temperature: None,
            system_prompt: None,
        };

        let body = driver.build_request_body(&req);
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0]["role"], "user");
    }
3604
    // System prompt becomes the leading "system" chat message.
    #[test]
    fn openai_request_body_simple() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );
        let body = driver.build_request_body(&simple_request());

        assert_eq!(body["model"], "test-model");
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0]["role"], "system");
        assert_eq!(messages[0]["content"], "You are helpful.");
        assert_eq!(messages[1]["role"], "user");
    }

    // Tools use the `{"type": "function", "function": {...}}` wrapper.
    #[test]
    fn openai_request_body_with_tools() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );
        let body = driver.build_request_body(&request_with_tools());

        let tools = body["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["type"], "function");
        assert_eq!(tools[0]["function"]["name"], "get_weather");
    }

    // "stop" finish reason with text content maps to EndTurn, usage from
    // prompt/completion token counts.
    #[test]
    fn openai_parse_response_text() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );
        let response_body = serde_json::json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "Hello!"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5
            }
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Hello!");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert_eq!(resp.usage.input_tokens, 10);
        assert_eq!(resp.usage.output_tokens, 5);
    }

    // tool_calls with JSON-string arguments are decoded into ToolCall input.
    #[test]
    fn openai_parse_response_tool_calls() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );
        let response_body = serde_json::json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"city\": \"NYC\"}"
                        }
                    }]
                },
                "finish_reason": "tool_calls"
            }],
            "usage": {"prompt_tokens": 10, "completion_tokens": 5}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
        assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
    }

    // Even when the provider reports finish_reason "stop", the presence of
    // tool_calls forces StopReason::ToolUse.
    #[test]
    fn openai_parse_response_tool_calls_fix_stop_reason() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );
        let response_body = serde_json::json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "Using tool",
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "test_tool",
                            "arguments": "{}"
                        }
                    }]
                },
                "finish_reason": "stop"
            }],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
    }

    // "length" finish reason maps to StopReason::MaxTokens.
    #[test]
    fn openai_parse_response_length_stop_reason() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );
        let response_body = serde_json::json!({
            "choices": [{
                "message": {"role": "assistant", "content": "cut off"},
                "finish_reason": "length"
            }],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
    }

    // An empty choices array is a parse error, not a silent default.
    #[test]
    fn openai_parse_response_no_choices_error() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );
        let response_body = serde_json::json!({"choices": []});

        let result = driver.parse_response(&response_body);
        assert!(result.is_err());
    }
3764
3765 #[test]
3770 fn gemini_assistant_message_formatting() {
3771 let driver = GeminiDriver::new("key".to_string(), None);
3772 let req = CompletionRequest {
3773 model: "gemini-pro".into(),
3774 messages: vec![
3775 Message::new(Role::User, "Hi"),
3776 Message {
3777 role: Role::Assistant,
3778 content: "Let me help".into(),
3779 content_parts: Vec::new(),
3780 tool_calls: vec![ToolCall {
3781 id: "tc1".into(),
3782 name: "get_weather".into(),
3783 input: serde_json::json!({"city": "NYC"}),
3784 }],
3785 tool_results: Vec::new(),
3786 timestamp: chrono::Utc::now(),
3787 },
3788 ],
3789 tools: Vec::new(),
3790 max_tokens: 100,
3791 temperature: None,
3792 system_prompt: None,
3793 };
3794
3795 let body = driver.build_request_body(&req);
3796 let contents = body["contents"].as_array().unwrap();
3797 assert_eq!(contents[1]["role"], "model"); let parts = contents[1]["parts"].as_array().unwrap();
3799 assert!(parts.len() >= 2); }
3801
    /// Gemini's `finishReason: "MAX_TOKENS"` must map to `StopReason::MaxTokens`.
    #[test]
    fn gemini_max_tokens_stop_reason() {
        let driver = GeminiDriver::new("key".to_string(), None);
        let response_body = serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "truncated"}],
                    "role": "model"
                },
                "finishReason": "MAX_TOKENS"
            }],
            "usageMetadata": {"promptTokenCount": 0, "candidatesTokenCount": 0}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
    }
3819
3820 #[test]
3821 fn gemini_custom_base_url() {
3822 let driver =
3823 GeminiDriver::new("key".to_string(), Some("https://custom.example.com".into()));
3824 let url = driver.build_url("gemini-pro");
3825 assert!(url.starts_with("https://custom.example.com/"));
3826 }
3827
    /// Ollama delivers tool-call `arguments` as a structured JSON object (not
    /// a string); a completed response containing them must be flagged ToolUse.
    #[test]
    fn ollama_response_with_tool_calls() {
        let driver = OllamaDriver::new(None);
        let response_body = serde_json::json!({
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [{
                    "function": {
                        "name": "get_weather",
                        "arguments": {"city": "London"}
                    }
                }]
            },
            "done": true,
            "prompt_eval_count": 10,
            "eval_count": 5
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
    }
3856
3857 #[test]
3858 fn ollama_response_not_done() {
3859 let driver = OllamaDriver::new(None);
3860 let response_body = serde_json::json!({
3861 "message": {"role": "assistant", "content": "partial"},
3862 "done": false,
3863 "prompt_eval_count": 10,
3864 "eval_count": 5
3865 });
3866
3867 let resp = driver.parse_response(&response_body).unwrap();
3868 assert_eq!(resp.stop_reason, StopReason::MaxTokens);
3869 }
3870
    /// Tool definitions must be translated into Bedrock's
    /// `toolConfig.tools` array of `toolSpec` entries.
    /// NOTE(review): relies on the shared `request_with_tools()` fixture
    /// defined elsewhere in this module.
    #[test]
    fn bedrock_request_with_tools() {
        let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
        let body = driver.build_request_body(&request_with_tools());

        let tool_config = &body["toolConfig"]["tools"];
        assert!(tool_config.is_array());
        let tools = tool_config.as_array().unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["toolSpec"]["name"], "get_weather");
    }
3886
    /// Bedrock mixes text and `toolUse` blocks in one content array; the
    /// parser must extract the tool call (id + name) and report ToolUse.
    #[test]
    fn bedrock_response_with_tool_use() {
        let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
        let response_body = serde_json::json!({
            "output": {
                "message": {
                    "role": "assistant",
                    "content": [
                        {"text": "Using tool"},
                        {"toolUse": {
                            "toolUseId": "tu_123",
                            "name": "get_weather",
                            "input": {"city": "NYC"}
                        }}
                    ]
                }
            },
            "stopReason": "tool_use",
            "usage": {"inputTokens": 10, "outputTokens": 20}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].id, "tu_123");
        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
    }
3914
3915 #[test]
3916 fn bedrock_request_with_tool_results() {
3917 let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
3918 let req = CompletionRequest {
3919 model: "test".into(),
3920 messages: vec![
3921 Message::new(Role::User, "Hi"),
3922 Message {
3923 role: Role::Tool,
3924 content: String::new(),
3925 content_parts: Vec::new(),
3926 tool_calls: Vec::new(),
3927 tool_results: vec![punch_types::ToolCallResult {
3928 id: "tu_1".into(),
3929 content: "result data".into(),
3930 is_error: false,
3931 image: None,
3932 }],
3933 timestamp: chrono::Utc::now(),
3934 },
3935 ],
3936 tools: Vec::new(),
3937 max_tokens: 100,
3938 temperature: None,
3939 system_prompt: None,
3940 };
3941
3942 let body = driver.build_request_body(&req);
3943 let messages = body["messages"].as_array().unwrap();
3944 assert_eq!(messages[1]["role"], "user"); let content = messages[1]["content"].as_array().unwrap();
3946 assert!(content[0]["toolResult"].is_object());
3947 assert_eq!(content[0]["toolResult"]["status"], "success");
3948 }
3949
3950 #[test]
3951 fn bedrock_url_different_regions() {
3952 let driver = BedrockDriver::new("k".into(), "s".into(), "ap-southeast-1".into());
3953 let url = driver.build_url("model-id");
3954 assert!(url.contains("ap-southeast-1"));
3955 }
3956
    /// Azure responses use the OpenAI wire format; the Azure driver should
    /// parse a standard OpenAI-shaped body successfully.
    #[test]
    fn azure_openai_delegates_parse_to_openai() {
        let driver = AzureOpenAiDriver::new("key".into(), "res".into(), "dep".into(), None);
        let response_body = serde_json::json!({
            "choices": [{
                "message": {"role": "assistant", "content": "Azure response"},
                "finish_reason": "stop"
            }],
            "usage": {"prompt_tokens": 5, "completion_tokens": 3}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Azure response");
    }
3975
    // --- default_base_url(): one pinning test per built-in provider ---

    #[test]
    fn default_base_url_anthropic() {
        assert_eq!(
            default_base_url(&Provider::Anthropic),
            "https://api.anthropic.com"
        );
    }

    #[test]
    fn default_base_url_openai() {
        assert_eq!(
            default_base_url(&Provider::OpenAI),
            "https://api.openai.com"
        );
    }

    #[test]
    fn default_base_url_google() {
        assert_eq!(
            default_base_url(&Provider::Google),
            "https://generativelanguage.googleapis.com"
        );
    }

    #[test]
    fn default_base_url_ollama() {
        // Ollama is a local daemon, hence plain HTTP on the default port.
        assert_eq!(
            default_base_url(&Provider::Ollama),
            "http://localhost:11434"
        );
    }

    #[test]
    fn default_base_url_groq() {
        assert_eq!(
            default_base_url(&Provider::Groq),
            "https://api.groq.com/openai"
        );
    }

    #[test]
    fn default_base_url_deepseek() {
        assert_eq!(
            default_base_url(&Provider::DeepSeek),
            "https://api.deepseek.com"
        );
    }
4027
4028 #[test]
4033 fn test_hex_sha256() {
4034 let hash = hex_sha256(b"");
4035 assert_eq!(
4036 hash,
4037 "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
4038 );
4039 }
4040
4041 #[test]
4042 fn test_hex_encode() {
4043 assert_eq!(hex_encode(&[0x00, 0xff, 0x0a, 0xbc]), "00ff0abc");
4044 }
4045
4046 #[test]
4047 fn test_hmac_sha256_basic() {
4048 let result = hmac_sha256(b"key", b"data");
4049 assert!(!result.is_empty());
4050 assert_eq!(result.len(), 32); }
4052
    /// `create_driver` must fail cleanly when the configured API-key
    /// environment variable does not exist.
    #[test]
    fn create_driver_missing_api_key_env() {
        let config = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-3".into(),
            // Intentionally points at a variable that should never be set.
            api_key_env: Some("PUNCH_TEST_NONEXISTENT_KEY_XYZ".into()),
            base_url: None,
            max_tokens: None,
            temperature: None,
        };
        let result = create_driver(&config);
        assert!(result.is_err());
    }
4070
    /// A `Provider::Custom` with a base URL should fall back to the
    /// OpenAI-compatible driver.
    #[test]
    fn create_driver_openai_compatible_fallback() {
        // SAFETY: env mutation is only sound if no other thread touches the
        // environment concurrently. NOTE(review): this can race under the
        // default multi-threaded test runner if another test reads/writes
        // env vars — consider a serialization guard.
        unsafe { std::env::set_var("TEST_CUSTOM_KEY_DRIVER", "fake-key") };
        let config = ModelConfig {
            provider: Provider::Custom("my-custom".into()),
            model: "custom-model".into(),
            api_key_env: Some("TEST_CUSTOM_KEY_DRIVER".into()),
            base_url: Some("https://custom.api.com".into()),
            max_tokens: None,
            temperature: None,
        };
        let result = create_driver(&config);
        assert!(result.is_ok());
        // SAFETY: same single-threaded-environment assumption as above.
        unsafe { std::env::remove_var("TEST_CUSTOM_KEY_DRIVER") };
    }
4087
    // --- strip_thinking_tags(): removal of <think>/<thinking>/<reasoning>/
    // <reflection> blocks from model output ---

    #[test]
    fn strip_thinking_tags_removes_think_block() {
        let input = "<think>internal reasoning here</think>The answer is 42.";
        assert_eq!(strip_thinking_tags(input), "The answer is 42.");
    }

    #[test]
    fn strip_thinking_tags_removes_thinking_block() {
        let input = "<thinking>step by step reasoning</thinking>Hello world!";
        assert_eq!(strip_thinking_tags(input), "Hello world!");
    }

    #[test]
    fn strip_thinking_tags_removes_reasoning_block() {
        let input = "<reasoning>let me figure this out</reasoning>The result is correct.";
        assert_eq!(strip_thinking_tags(input), "The result is correct.");
    }

    #[test]
    fn strip_thinking_tags_removes_reflection_block() {
        let input = "<reflection>checking my work</reflection>Yes, that's right.";
        assert_eq!(strip_thinking_tags(input), "Yes, that's right.");
    }

    #[test]
    fn strip_thinking_tags_removes_multiple_blocks() {
        let input = "<think>first thought</think>Hello <thinking>second thought</thinking>world!";
        assert_eq!(strip_thinking_tags(input), "Hello world!");
    }

    #[test]
    fn strip_thinking_tags_preserves_content_without_tags() {
        let input = "Just a normal response with no thinking tags.";
        assert_eq!(strip_thinking_tags(input), input);
    }

    #[test]
    fn strip_thinking_tags_handles_multiline_tags() {
        let input = "<think>\nLine 1\nLine 2\nLine 3\n</think>\nThe final answer.";
        assert_eq!(strip_thinking_tags(input), "The final answer.");
    }

    #[test]
    fn strip_thinking_tags_returns_original_if_all_thinking() {
        // Stripping everything would leave no answer at all, so the
        // function keeps the original text in that case.
        let input = "<think>this is all thinking content and nothing else</think>";
        assert_eq!(strip_thinking_tags(input), input);
    }

    #[test]
    fn strip_thinking_tags_handles_unclosed_tag() {
        // An unterminated tag drops everything from the opening tag onward.
        let input = "Some text<think>unclosed thinking block";
        assert_eq!(strip_thinking_tags(input), "Some text");
    }

    #[test]
    fn strip_thinking_tags_handles_empty_input() {
        assert_eq!(strip_thinking_tags(""), "");
    }

    #[test]
    fn strip_thinking_tags_handles_empty_think_block() {
        let input = "<think></think>Visible content.";
        assert_eq!(strip_thinking_tags(input), "Visible content.");
    }

    #[test]
    fn strip_thinking_tags_trims_whitespace() {
        let input = "  <think>reasoning</think>  Result  ";
        assert_eq!(strip_thinking_tags(input), "Result");
    }

    #[test]
    fn strip_thinking_tags_mixed_tag_types() {
        let input = "<think>t1</think>A<reasoning>r1</reasoning>B<reflection>f1</reflection>C";
        assert_eq!(strip_thinking_tags(input), "ABC");
    }
4170
    // --- Each driver's parse_response() must strip thinking tags from the
    // assistant text it returns ---

    #[test]
    fn ollama_response_strips_thinking_tags() {
        let driver = OllamaDriver::new(None);
        let response_body = serde_json::json!({
            "message": {
                "role": "assistant",
                "content": "<think>\nLet me think about this...\nThe user wants hello world.\n</think>\nHello, world!"
            },
            "done": true,
            "prompt_eval_count": 20,
            "eval_count": 50
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Hello, world!");
        assert!(!resp.message.content.contains("<think>"));
    }

    #[test]
    fn gemini_response_strips_thinking_tags() {
        let driver = GeminiDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "<thinking>reasoning step</thinking>The answer is 7."}],
                    "role": "model"
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 20
            }
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "The answer is 7.");
        assert!(!resp.message.content.contains("<thinking>"));
    }

    #[test]
    fn anthropic_response_strips_thinking_tags() {
        let driver = AnthropicDriver::new("test-key".to_string(), None);
        let response_body = serde_json::json!({
            "content": [
                {"type": "text", "text": "<think>internal thought</think>Clean output."}
            ],
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 10, "output_tokens": 5}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Clean output.");
    }

    #[test]
    fn bedrock_response_strips_thinking_tags() {
        let driver = BedrockDriver::new(
            "key".to_string(),
            "secret".to_string(),
            "us-east-1".to_string(),
        );
        let response_body = serde_json::json!({
            "output": {
                "message": {
                    "role": "assistant",
                    "content": [{"text": "<reasoning>deep thought</reasoning>Result here."}]
                }
            },
            "stopReason": "end_turn",
            "usage": {"inputTokens": 50, "outputTokens": 25}
        });

        let resp = driver.parse_response(&response_body).unwrap();
        assert_eq!(resp.message.content, "Result here.");
    }
4247
4248 #[test]
4253 fn stream_chunk_serialization_roundtrip() {
4254 let chunk = StreamChunk {
4255 delta: "Hello".to_string(),
4256 is_final: false,
4257 tool_call_delta: None,
4258 };
4259 let json = serde_json::to_string(&chunk).unwrap();
4260 let deserialized: StreamChunk = serde_json::from_str(&json).unwrap();
4261 assert_eq!(deserialized.delta, "Hello");
4262 assert!(!deserialized.is_final);
4263 assert!(deserialized.tool_call_delta.is_none());
4264 }
4265
    /// A chunk carrying a partial tool call must round-trip the nested
    /// ToolCallDelta (index, id, name, and argument fragment) unchanged.
    #[test]
    fn stream_chunk_with_tool_call_delta_serialization() {
        let chunk = StreamChunk {
            delta: String::new(),
            is_final: false,
            tool_call_delta: Some(ToolCallDelta {
                index: 0,
                id: Some("call_123".to_string()),
                name: Some("get_weather".to_string()),
                // Deliberately an incomplete JSON fragment, as seen mid-stream.
                arguments_delta: "{\"city\":".to_string(),
            }),
        };
        let json = serde_json::to_string(&chunk).unwrap();
        let deserialized: StreamChunk = serde_json::from_str(&json).unwrap();
        let tcd = deserialized.tool_call_delta.unwrap();
        assert_eq!(tcd.index, 0);
        assert_eq!(tcd.id.unwrap(), "call_123");
        assert_eq!(tcd.name.unwrap(), "get_weather");
        assert_eq!(tcd.arguments_delta, "{\"city\":");
    }
4286
4287 #[test]
4288 fn stream_chunk_final_serialization() {
4289 let chunk = StreamChunk {
4290 delta: String::new(),
4291 is_final: true,
4292 tool_call_delta: None,
4293 };
4294 let json = serde_json::to_string(&chunk).unwrap();
4295 assert!(json.contains("\"is_final\":true"));
4296 }
4297
    /// Continuation deltas (no id/name, only an argument fragment) must
    /// round-trip with the optional fields still `None`.
    #[test]
    fn tool_call_delta_serialization_roundtrip() {
        let tcd = ToolCallDelta {
            index: 2,
            id: None,
            name: None,
            arguments_delta: "\"NYC\"}".to_string(),
        };
        let json = serde_json::to_string(&tcd).unwrap();
        let deserialized: ToolCallDelta = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.index, 2);
        assert!(deserialized.id.is_none());
        assert!(deserialized.name.is_none());
        assert_eq!(deserialized.arguments_delta, "\"NYC\"}");
    }
4313
    // --- parse_sse_events(): splitting a raw SSE body into
    // (event_type, data) pairs ---

    #[test]
    fn parse_sse_events_basic() {
        let raw = "event: message_start\ndata: {\"type\":\"message_start\"}\n\nevent: content_block_delta\ndata: {\"delta\":{\"text\":\"Hi\"}}\n\n";
        let events = parse_sse_events(raw);
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].0, "message_start");
        assert_eq!(events[1].0, "content_block_delta");
    }

    #[test]
    fn parse_sse_events_with_done() {
        // The OpenAI-style "[DONE]" sentinel is surfaced as ordinary data.
        let raw = "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\ndata: [DONE]\n\n";
        let events = parse_sse_events(raw);
        assert_eq!(events.len(), 2);
        assert_eq!(events[1].1, "[DONE]");
    }

    #[test]
    fn parse_sse_events_empty_input() {
        let events = parse_sse_events("");
        assert!(events.is_empty());
    }

    #[test]
    fn parse_sse_events_no_trailing_newline() {
        // A final event must be flushed even without the closing blank line.
        let raw = "event: test\ndata: {\"value\":1}";
        let events = parse_sse_events(raw);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].0, "test");
    }

    #[test]
    fn parse_sse_events_multiline_data() {
        // Consecutive `data:` lines of one event join with a newline.
        let raw = "data: line1\ndata: line2\n\n";
        let events = parse_sse_events(raw);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].1, "line1\nline2");
    }
4356
4357 #[test]
4358 fn parse_sse_events_no_event_field() {
4359 let raw = "data: {\"hello\":\"world\"}\n\n";
4360 let events = parse_sse_events(raw);
4361 assert_eq!(events.len(), 1);
4362 assert_eq!(events[0].0, "message"); }
4364
4365 #[test]
4370 fn anthropic_stream_text_only() {
4371 let raw = concat!(
4372 "event: message_start\n",
4373 "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":25}}}\n\n",
4374 "event: content_block_start\n",
4375 "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
4376 "event: content_block_delta\n",
4377 "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n",
4378 "event: content_block_delta\n",
4379 "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\" world\"}}\n\n",
4380 "event: message_delta\n",
4381 "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":10}}\n\n",
4382 "event: message_stop\n",
4383 "data: {\"type\":\"message_stop\"}\n\n",
4384 );
4385
4386 let events = parse_sse_events(raw);
4387 let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
4388 Arc::new(std::sync::Mutex::new(Vec::new()));
4389 let chunks_clone = chunks.clone();
4390 let callback: StreamCallback = Arc::new(move |chunk| {
4391 chunks_clone.lock().unwrap().push(chunk);
4392 });
4393
4394 let mut text_content = String::new();
4396 let mut usage = TokenUsage::default();
4397 let mut stop_reason = StopReason::EndTurn;
4398
4399 for (event_type, data) in &events {
4400 let parsed: serde_json::Value = match serde_json::from_str(data) {
4401 Ok(v) => v,
4402 Err(_) => continue,
4403 };
4404
4405 match event_type.as_str() {
4406 "message_start" => {
4407 if let Some(inp) = parsed["message"]["usage"]["input_tokens"].as_u64() {
4408 usage.input_tokens = inp;
4409 }
4410 }
4411 "content_block_delta" => {
4412 if let Some(text) = parsed["delta"]["text"].as_str() {
4413 text_content.push_str(text);
4414 callback(StreamChunk {
4415 delta: text.to_string(),
4416 is_final: false,
4417 tool_call_delta: None,
4418 });
4419 }
4420 }
4421 "message_delta" => {
4422 if let Some(sr) = parsed["delta"]["stop_reason"].as_str() {
4423 stop_reason = match sr {
4424 "end_turn" => StopReason::EndTurn,
4425 "tool_use" => StopReason::ToolUse,
4426 _ => StopReason::Error,
4427 };
4428 }
4429 if let Some(out) = parsed["usage"]["output_tokens"].as_u64() {
4430 usage.output_tokens = out;
4431 }
4432 }
4433 "message_stop" => {
4434 callback(StreamChunk {
4435 delta: String::new(),
4436 is_final: true,
4437 tool_call_delta: None,
4438 });
4439 }
4440 _ => {}
4441 }
4442 }
4443
4444 assert_eq!(text_content, "Hello world");
4445 assert_eq!(usage.input_tokens, 25);
4446 assert_eq!(usage.output_tokens, 10);
4447 assert_eq!(stop_reason, StopReason::EndTurn);
4448
4449 let received = chunks.lock().unwrap();
4450 assert_eq!(received.len(), 3); assert_eq!(received[0].delta, "Hello");
4452 assert_eq!(received[1].delta, " world");
4453 assert!(received[2].is_final);
4454 }
4455
    /// Replays an Anthropic SSE session mixing a text block with a
    /// `tool_use` block, and checks that `input_json_delta` fragments are
    /// accumulated per tool block into one valid JSON string.
    #[test]
    fn anthropic_stream_with_tool_use() {
        let raw = concat!(
            "event: message_start\n",
            "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":15}}}\n\n",
            "event: content_block_start\n",
            "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Checking.\"}}\n\n",
            "event: content_block_start\n",
            "data: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"tool_1\",\"name\":\"get_weather\"}}\n\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\"\"}}\n\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\": \\\"NYC\\\"}\"}}\n\n",
            "event: message_delta\n",
            "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":20}}\n\n",
            "event: message_stop\n",
            "data: {\"type\":\"message_stop\"}\n\n",
        );

        let events = parse_sse_events(raw);
        assert!(events.len() >= 7);

        // One JSON buffer per tool_use block; `tc_idx` tracks the open block.
        let mut tool_json_bufs: Vec<String> = Vec::new();
        let mut tc_idx: Option<usize> = None;

        for (event_type, data) in &events {
            let parsed: serde_json::Value = match serde_json::from_str(data) {
                Ok(v) => v,
                Err(_) => continue,
            };
            match event_type.as_str() {
                "content_block_start" => {
                    if parsed["content_block"]["type"].as_str() == Some("tool_use") {
                        tool_json_bufs.push(String::new());
                        tc_idx = Some(tool_json_bufs.len() - 1);
                    } else {
                        // A non-tool block closes any open tool accumulation.
                        tc_idx = None;
                    }
                }
                "content_block_delta" => {
                    if parsed["delta"]["type"].as_str() == Some("input_json_delta")
                        && let Some(idx) = tc_idx
                        && let Some(buf) = tool_json_bufs.get_mut(idx)
                    {
                        buf.push_str(parsed["delta"]["partial_json"].as_str().unwrap_or(""));
                    }
                }
                _ => {}
            }
        }

        assert_eq!(tool_json_bufs.len(), 1);
        assert_eq!(tool_json_bufs[0], "{\"city\": \"NYC\"}");

        // The concatenated fragments must form valid JSON.
        let parsed_input: serde_json::Value = serde_json::from_str(&tool_json_bufs[0]).unwrap();
        assert_eq!(parsed_input["city"], "NYC");
    }
4517
    /// Feeds a canned OpenAI SSE stream to parse_openai_stream and checks
    /// the aggregated text, stop reason, and callback chunk ordering.
    #[test]
    fn openai_stream_text_only() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\" world\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n",
            "data: [DONE]\n\n",
        );

        let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let chunks_clone = chunks.clone();
        let callback: StreamCallback = Arc::new(move |chunk| {
            chunks_clone.lock().unwrap().push(chunk);
        });

        let resp = driver.parse_openai_stream(raw, &callback).unwrap();

        assert_eq!(resp.message.content, "Hello world");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert!(resp.message.tool_calls.is_empty());

        let received = chunks.lock().unwrap();
        // Two content deltas plus at least a final marker.
        assert!(received.len() >= 3);
        assert_eq!(received[0].delta, "Hello");
        assert_eq!(received[1].delta, " world");
        assert!(received.last().unwrap().is_final);
    }
4558
4559 #[test]
4560 fn openai_stream_with_tool_calls() {
4561 let driver = OpenAiCompatibleDriver::new(
4562 "key".into(),
4563 "https://api.openai.com".into(),
4564 "openai".into(),
4565 );
4566
4567 let raw = concat!(
4568 "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_abc\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"\"}}]},\"index\":0}]}\n\n",
4569 "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"ci\"}}]},\"index\":0}]}\n\n",
4570 "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"ty\\\": \\\"NYC\\\"}\"}}]},\"index\":0}]}\n\n",
4571 "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"tool_calls\"}]}\n\n",
4572 "data: [DONE]\n\n",
4573 );
4574
4575 let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
4576 Arc::new(std::sync::Mutex::new(Vec::new()));
4577 let chunks_clone = chunks.clone();
4578 let callback: StreamCallback = Arc::new(move |chunk| {
4579 chunks_clone.lock().unwrap().push(chunk);
4580 });
4581
4582 let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4583
4584 assert_eq!(resp.stop_reason, StopReason::ToolUse);
4585 assert_eq!(resp.message.tool_calls.len(), 1);
4586 assert_eq!(resp.message.tool_calls[0].id, "call_abc");
4587 assert_eq!(resp.message.tool_calls[0].name, "get_weather");
4588 assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
4589
4590 let received = chunks.lock().unwrap();
4591 let tool_chunks: Vec<_> = received
4593 .iter()
4594 .filter(|c| c.tool_call_delta.is_some())
4595 .collect();
4596 assert!(tool_chunks.len() >= 3); assert!(received.last().unwrap().is_final);
4598 }
4599
    /// Text deltas followed by a tool call in one stream: both the
    /// concatenated content and the parsed call must survive.
    #[test]
    fn openai_stream_with_mixed_content_and_tools() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Sure, \"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\"checking.\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"search\",\"arguments\":\"{\\\"q\\\":\\\"test\\\"}\"}}]},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"tool_calls\"}]}\n\n",
            "data: [DONE]\n\n",
        );

        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_openai_stream(raw, &callback).unwrap();

        assert_eq!(resp.message.content, "Sure, checking.");
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].name, "search");
    }

    /// A streamed `finish_reason: "length"` must map to MaxTokens.
    #[test]
    fn openai_stream_length_stop_reason() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"truncated\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"length\"}]}\n\n",
            "data: [DONE]\n\n",
        );

        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
    }
4643
4644 #[test]
4649 fn ollama_stream_text_only() {
4650 let driver = OllamaDriver::new(None);
4651
4652 let raw = concat!(
4653 "{\"message\":{\"role\":\"assistant\",\"content\":\"Hello\"},\"done\":false}\n",
4654 "{\"message\":{\"role\":\"assistant\",\"content\":\" world\"},\"done\":false}\n",
4655 "{\"message\":{\"role\":\"assistant\",\"content\":\"!\"},\"done\":false}\n",
4656 "{\"done\":true,\"prompt_eval_count\":15,\"eval_count\":8}\n",
4657 );
4658
4659 let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
4660 Arc::new(std::sync::Mutex::new(Vec::new()));
4661 let chunks_clone = chunks.clone();
4662 let callback: StreamCallback = Arc::new(move |chunk| {
4663 chunks_clone.lock().unwrap().push(chunk);
4664 });
4665
4666 let resp = driver.parse_ollama_stream(raw, &callback).unwrap();
4667
4668 assert_eq!(resp.message.content, "Hello world!");
4669 assert_eq!(resp.stop_reason, StopReason::EndTurn);
4670 assert_eq!(resp.usage.input_tokens, 15);
4671 assert_eq!(resp.usage.output_tokens, 8);
4672
4673 let received = chunks.lock().unwrap();
4674 assert_eq!(received.len(), 4); assert_eq!(received[0].delta, "Hello");
4676 assert_eq!(received[1].delta, " world");
4677 assert_eq!(received[2].delta, "!");
4678 assert!(received[3].is_final);
4679 }
4680
    /// Tool calls arriving in the terminal Ollama chunk must be parsed and
    /// flip the stop reason to ToolUse while keeping earlier text.
    #[test]
    fn ollama_stream_with_tool_calls() {
        let driver = OllamaDriver::new(None);

        let raw = concat!(
            "{\"message\":{\"role\":\"assistant\",\"content\":\"Let me check.\"},\"done\":false}\n",
            "{\"message\":{\"role\":\"assistant\",\"content\":\"\",\"tool_calls\":[{\"function\":{\"name\":\"get_weather\",\"arguments\":{\"city\":\"London\"}}}]},\"done\":true,\"prompt_eval_count\":10,\"eval_count\":5}\n",
        );

        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();

        assert_eq!(resp.message.content, "Let me check.");
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
        assert_eq!(resp.usage.input_tokens, 10);
    }

    /// Thinking tags split across stream chunks must still be stripped from
    /// the final aggregated content.
    #[test]
    fn ollama_stream_strips_thinking_tags() {
        let driver = OllamaDriver::new(None);

        let raw = concat!(
            "{\"message\":{\"role\":\"assistant\",\"content\":\"<think>hmm</think>\"},\"done\":false}\n",
            "{\"message\":{\"role\":\"assistant\",\"content\":\"Clean answer.\"},\"done\":false}\n",
            "{\"done\":true,\"prompt_eval_count\":5,\"eval_count\":3}\n",
        );

        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();
        assert_eq!(resp.message.content, "Clean answer.");
    }
4714
4715 #[test]
4720 fn gemini_stream_url_construction() {
4721 let driver = GeminiDriver::new("my-key".to_string(), None);
4722 let url = driver.build_stream_url("gemini-pro");
4723 assert!(url.contains("streamGenerateContent"));
4724 assert!(url.contains("alt=sse"));
4725 assert!(url.contains("key=my-key"));
4726 assert!(url.contains("models/gemini-pro"));
4727 }
4728
    /// A custom base URL must also apply to the streaming endpoint.
    #[test]
    fn gemini_stream_custom_base_url() {
        let driver = GeminiDriver::new(
            "key".to_string(),
            Some("https://custom.example.com".to_string()),
        );
        let url = driver.build_stream_url("gemini-pro");
        assert!(url.starts_with("https://custom.example.com/"));
        assert!(url.contains("streamGenerateContent"));
    }
4739
    /// The stream callback must observe every delta in arrival order, with
    /// the empty final marker last.
    #[test]
    fn callback_receives_all_chunks_in_order() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"A\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\"B\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\"C\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n",
            "data: [DONE]\n\n",
        );

        let deltas: Arc<std::sync::Mutex<Vec<String>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let deltas_clone = deltas.clone();
        let callback: StreamCallback = Arc::new(move |chunk| {
            // Record content deltas and the (empty) final chunk only.
            if !chunk.delta.is_empty() || chunk.is_final {
                deltas_clone.lock().unwrap().push(chunk.delta.clone());
            }
        });

        let _resp = driver.parse_openai_stream(raw, &callback).unwrap();
        let received = deltas.lock().unwrap();
        assert_eq!(received.as_slice(), &["A", "B", "C", ""]);
    }

    /// Distinct `index` values in tool_call deltas must produce distinct
    /// ToolCalls, each with its own id, name, and parsed arguments.
    #[test]
    fn openai_stream_multiple_tool_calls() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"tool_a\",\"arguments\":\"{\\\"x\\\":1}\"}}]},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":1,\"id\":\"call_2\",\"type\":\"function\",\"function\":{\"name\":\"tool_b\",\"arguments\":\"{\\\"y\\\":2}\"}}]},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"tool_calls\"}]}\n\n",
            "data: [DONE]\n\n",
        );

        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_openai_stream(raw, &callback).unwrap();

        assert_eq!(resp.message.tool_calls.len(), 2);
        assert_eq!(resp.message.tool_calls[0].id, "call_1");
        assert_eq!(resp.message.tool_calls[0].name, "tool_a");
        assert_eq!(resp.message.tool_calls[0].input["x"], 1);
        assert_eq!(resp.message.tool_calls[1].id, "call_2");
        assert_eq!(resp.message.tool_calls[1].name, "tool_b");
        assert_eq!(resp.message.tool_calls[1].input["y"], 2);
    }
4800
4801 #[test]
4806 fn stream_chunk_default_values() {
4807 let chunk = StreamChunk {
4808 delta: String::new(),
4809 is_final: false,
4810 tool_call_delta: None,
4811 };
4812 assert!(chunk.delta.is_empty());
4813 assert!(!chunk.is_final);
4814 assert!(chunk.tool_call_delta.is_none());
4815 }
4816
4817 #[test]
4818 fn openai_stream_empty_input() {
4819 let driver = OpenAiCompatibleDriver::new(
4820 "key".into(),
4821 "https://api.openai.com".into(),
4822 "openai".into(),
4823 );
4824
4825 let raw = "data: [DONE]\n\n";
4826
4827 let callback: StreamCallback = Arc::new(|_| {});
4828 let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4829
4830 assert_eq!(resp.message.content, "");
4831 assert!(resp.message.tool_calls.is_empty());
4832 }
4833
4834 #[test]
4835 fn ollama_stream_empty_input() {
4836 let driver = OllamaDriver::new(None);
4837 let raw = "";
4838
4839 let callback: StreamCallback = Arc::new(|_| {});
4840 let resp = driver.parse_ollama_stream(raw, &callback).unwrap();
4841
4842 assert_eq!(resp.message.content, "");
4843 assert_eq!(resp.stop_reason, StopReason::MaxTokens); }
4845
4846 #[test]
4847 fn openai_stream_strips_thinking_tags() {
4848 let driver = OpenAiCompatibleDriver::new(
4849 "key".into(),
4850 "https://api.openai.com".into(),
4851 "openai".into(),
4852 );
4853
4854 let raw = concat!(
4855 "data: {\"choices\":[{\"delta\":{\"content\":\"<think>internal</think>\"},\"index\":0}]}\n\n",
4856 "data: {\"choices\":[{\"delta\":{\"content\":\"Result\"},\"index\":0}]}\n\n",
4857 "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n",
4858 "data: [DONE]\n\n",
4859 );
4860
4861 let callback: StreamCallback = Arc::new(|_| {});
4862 let resp = driver.parse_openai_stream(raw, &callback).unwrap();
4863 assert_eq!(resp.message.content, "Result");
4864 }
4865}