use std::sync::Arc;

use async_trait::async_trait;
use futures::StreamExt;
use hmac::{Hmac, Mac};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};

use punch_types::{
    Message, ModelConfig, Provider, PunchError, PunchResult, Role, ToolCall, ToolDefinition,
};

/// Why a completion stopped.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn naturally.
    EndTurn,
    /// The model wants to invoke one or more tools.
    ToolUse,
    /// Generation was cut off by the token limit.
    MaxTokens,
    /// The provider reported an error.
    Error,
}

/// Token counts reported by a provider for a single completion.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    pub input_tokens: u64,
    pub output_tokens: u64,
}

impl TokenUsage {
    /// Adds another usage record into this one.
    pub fn accumulate(&mut self, other: &TokenUsage) {
        self.input_tokens += other.input_tokens;
        self.output_tokens += other.output_tokens;
    }

    /// Total tokens across input and output.
    pub fn total(&self) -> u64 {
        self.input_tokens + self.output_tokens
    }
}
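
#[cfg(test)]
mod token_usage_tests {
    use super::*;

    // Illustrative test (not from the original source) for the two helpers
    // above.
    #[test]
    fn accumulates_and_totals() {
        let mut usage = TokenUsage { input_tokens: 10, output_tokens: 5 };
        usage.accumulate(&TokenUsage { input_tokens: 2, output_tokens: 3 });
        assert_eq!(usage.total(), 20);
    }
}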

/// A provider-agnostic completion request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Model identifier, in the target provider's naming scheme.
    pub model: String,
    /// Conversation history, oldest message first.
    pub messages: Vec<Message>,
    /// Tool definitions the model may call.
    #[serde(default)]
    pub tools: Vec<ToolDefinition>,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// Optional sampling-temperature override.
    pub temperature: Option<f32>,
    /// Optional system prompt, delivered in each provider's preferred shape.
    pub system_prompt: Option<String>,
}

/// A fully assembled (non-streaming) completion result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub message: Message,
    pub usage: TokenUsage,
    pub stop_reason: StopReason,
}

/// One incremental piece of a streamed completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    pub delta: String,
    pub is_final: bool,
    pub tool_call_delta: Option<ToolCallDelta>,
}

/// An incremental update to a tool call within a streamed completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
    pub index: usize,
    pub id: Option<String>,
    pub name: Option<String>,
    pub arguments_delta: String,
}

/// Callback invoked once per chunk while streaming.
pub type StreamCallback = Arc<dyn Fn(StreamChunk) + Send + Sync>;

/// Parses a server-sent-events body into `(event_type, data)` pairs.
/// Events without an explicit `event:` field default to `"message"`;
/// consecutive `data:` lines within one event are joined with newlines.
fn parse_sse_events(raw: &str) -> Vec<(String, String)> {
    let mut events = Vec::new();
    let mut current_event = String::new();
    let mut current_data = String::new();

    for line in raw.lines() {
        if line.is_empty() {
            if !current_data.is_empty() || !current_event.is_empty() {
                events.push((
                    if current_event.is_empty() {
                        "message".to_string()
                    } else {
                        current_event.clone()
                    },
                    current_data.clone(),
                ));
                current_event.clear();
                current_data.clear();
            }
        } else if let Some(val) = line.strip_prefix("event: ") {
            current_event = val.trim().to_string();
        } else if let Some(val) = line.strip_prefix("event:") {
            current_event = val.trim().to_string();
        } else if let Some(val) = line.strip_prefix("data: ") {
            if !current_data.is_empty() {
                current_data.push('\n');
            }
            current_data.push_str(val);
        } else if let Some(val) = line.strip_prefix("data:") {
            if !current_data.is_empty() {
                current_data.push('\n');
            }
            current_data.push_str(val.trim());
        }
    }

    if !current_data.is_empty() || !current_event.is_empty() {
        events.push((
            if current_event.is_empty() {
                "message".to_string()
            } else {
                current_event
            },
            current_data,
        ));
    }

    events
}
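
#[cfg(test)]
mod sse_tests {
    use super::*;

    // Illustrative test (not from the original source): exercises the SSE
    // parser with a hand-written body covering a named event, the implicit
    // "message" event type, and multi-line data joining.
    #[test]
    fn parses_named_and_default_events() {
        let raw = "event: ping\ndata: {}\n\ndata: one\ndata: two\n\n";
        let events = parse_sse_events(raw);
        assert_eq!(events.len(), 2);
        assert_eq!(events[0], ("ping".to_string(), "{}".to_string()));
        assert_eq!(events[1], ("message".to_string(), "one\ntwo".to_string()));
    }
}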

/// Reads a streaming response body to completion and returns it as UTF-8.
async fn read_stream_body(response: reqwest::Response) -> PunchResult<String> {
    let mut stream = response.bytes_stream();
    let mut body = Vec::new();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|e| PunchError::Provider {
            provider: "stream".to_string(),
            message: format!("stream read error: {e}"),
        })?;
        body.extend_from_slice(&chunk);
    }
    String::from_utf8(body).map_err(|e| PunchError::Provider {
        provider: "stream".to_string(),
        message: format!("invalid UTF-8 in stream: {e}"),
    })
}

/// Removes `<think>`-style inline reasoning blocks that some models emit.
/// If stripping would leave nothing, the original content is returned so
/// the caller never loses the whole message.
pub fn strip_thinking_tags(content: &str) -> String {
    let mut result = content.to_string();

    for tag in &["think", "thinking", "reasoning", "reflection"] {
        let open = format!("<{}>", tag);
        let close = format!("</{}>", tag);

        while let Some(start) = result.find(&open) {
            if let Some(end) = result[start..].find(&close) {
                let block_end = start + end + close.len();
                result = format!("{}{}", &result[..start], &result[block_end..]);
            } else {
                // Unterminated tag: drop everything from the opening tag on.
                result = result[..start].to_string();
                break;
            }
        }
    }

    let trimmed = result.trim().to_string();

    // If the entire message was a thinking block, keep the original text
    // rather than returning an empty string.
    if trimmed.is_empty() {
        content.to_string()
    } else {
        trimmed
    }
}
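
#[cfg(test)]
mod thinking_tag_tests {
    use super::*;

    // Illustrative tests (not from the original source) for the three
    // behaviors above: closed blocks are removed, unterminated blocks
    // truncate, and an all-thinking message falls back to the original.
    #[test]
    fn strips_closed_and_unterminated_blocks() {
        assert_eq!(strip_thinking_tags("<think>hmm</think>hi"), "hi");
        assert_eq!(strip_thinking_tags("hi<reasoning>dangling"), "hi");
        assert_eq!(
            strip_thinking_tags("<think>only</think>"),
            "<think>only</think>"
        );
    }
}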

/// A driver for one LLM provider. Implementations translate the
/// provider-agnostic request/response types above into the provider's wire
/// format, and may optionally support streaming.
#[async_trait]
pub trait LlmDriver: Send + Sync + 'static {
    /// Performs a blocking (non-streaming) completion.
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse>;

    /// Streams a completion, discarding the chunks.
    async fn stream_complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let noop: StreamCallback = Arc::new(|_| {});
        self.stream_complete_with_callback(request, noop).await
    }

    /// Streams a completion, invoking `callback` for each chunk. The default
    /// implementation falls back to `complete` and delivers the whole
    /// message as a single final chunk.
    async fn stream_complete_with_callback(
        &self,
        request: CompletionRequest,
        callback: StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        let response = self.complete(request).await?;
        callback(StreamChunk {
            delta: response.message.content.clone(),
            is_final: true,
            tool_call_delta: None,
        });
        Ok(response)
    }
}
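
// Illustrative sketch (not part of the original API surface): how a caller
// might drive any `LlmDriver` with a streaming callback. The model id and
// prompt are placeholders.
#[allow(dead_code)]
async fn example_stream_usage(driver: &dyn LlmDriver) -> PunchResult<()> {
    let request = CompletionRequest {
        model: "example-model".to_string(),
        messages: vec![Message {
            role: Role::User,
            content: "Hello!".to_string(),
            tool_calls: Vec::new(),
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        }],
        tools: Vec::new(),
        max_tokens: 256,
        temperature: Some(0.2),
        system_prompt: Some("You are terse.".to_string()),
    };
    // Print text deltas as they arrive; ignore tool-call deltas here.
    let on_chunk: StreamCallback = Arc::new(|chunk| {
        if !chunk.delta.is_empty() {
            print!("{}", chunk.delta);
        }
    });
    let response = driver
        .stream_complete_with_callback(request, on_chunk)
        .await?;
    println!("\n[{} tokens]", response.usage.total());
    Ok(())
}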

/// Driver for the Anthropic Messages API.
pub struct AnthropicDriver {
    client: Client,
    api_key: String,
    base_url: String,
}

impl AnthropicDriver {
    pub fn new(api_key: String, base_url: Option<String>) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
        }
    }

    pub fn with_client(client: Client, api_key: String, base_url: Option<String>) -> Self {
        Self {
            client,
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
        }
    }

    fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        for msg in &request.messages {
            match msg.role {
                Role::User => {
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": msg.content,
                    }));
                }
                Role::Assistant => {
                    let mut content_blocks: Vec<serde_json::Value> = Vec::new();

                    if !msg.content.is_empty() {
                        content_blocks.push(serde_json::json!({
                            "type": "text",
                            "text": msg.content,
                        }));
                    }

                    for tc in &msg.tool_calls {
                        content_blocks.push(serde_json::json!({
                            "type": "tool_use",
                            "id": tc.id,
                            "name": tc.name,
                            "input": tc.input,
                        }));
                    }

                    if content_blocks.is_empty() {
                        content_blocks.push(serde_json::json!({
                            "type": "text",
                            "text": "",
                        }));
                    }

                    messages.push(serde_json::json!({
                        "role": "assistant",
                        "content": content_blocks,
                    }));
                }
                Role::Tool => {
                    let mut result_blocks: Vec<serde_json::Value> = Vec::new();
                    for tr in &msg.tool_results {
                        result_blocks.push(serde_json::json!({
                            "type": "tool_result",
                            "tool_use_id": tr.id,
                            "content": tr.content,
                            "is_error": tr.is_error,
                        }));
                    }
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": result_blocks,
                    }));
                }
                Role::System => {
                    // System content is carried in the top-level `system`
                    // field, not in the messages array.
                }
            }
        }

        let tools: Vec<serde_json::Value> = request
            .tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "name": t.name,
                    "description": t.description,
                    "input_schema": t.input_schema,
                })
            })
            .collect();

        let mut body = serde_json::json!({
            "model": request.model,
            "messages": messages,
            "max_tokens": request.max_tokens,
        });

        if let Some(temp) = request.temperature {
            body["temperature"] = serde_json::json!(temp);
        }

        if let Some(ref system) = request.system_prompt {
            body["system"] = serde_json::json!(system);
        }

        if !tools.is_empty() {
            body["tools"] = serde_json::json!(tools);
        }

        body
    }

    fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let stop_reason = match body["stop_reason"].as_str() {
            Some("end_turn") => StopReason::EndTurn,
            Some("tool_use") => StopReason::ToolUse,
            Some("max_tokens") => StopReason::MaxTokens,
            _ => StopReason::Error,
        };

        let usage = TokenUsage {
            input_tokens: body["usage"]["input_tokens"].as_u64().unwrap_or(0),
            output_tokens: body["usage"]["output_tokens"].as_u64().unwrap_or(0),
        };

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        if let Some(content_array) = body["content"].as_array() {
            for block in content_array {
                match block["type"].as_str() {
                    Some("text") => {
                        if let Some(text) = block["text"].as_str() {
                            if !text_content.is_empty() {
                                text_content.push('\n');
                            }
                            text_content.push_str(text);
                        }
                    }
                    Some("tool_use") => {
                        tool_calls.push(ToolCall {
                            id: block["id"].as_str().unwrap_or_default().to_string(),
                            name: block["name"].as_str().unwrap_or_default().to_string(),
                            input: block["input"].clone(),
                        });
                    }
                    _ => {}
                }
            }
        }

        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}
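
#[cfg(test)]
mod anthropic_tests {
    use super::*;

    // Illustrative test (not from the original source) with a synthetic
    // Messages API response body.
    #[test]
    fn parses_text_and_tool_use_blocks() {
        let driver = AnthropicDriver::new("k".to_string(), None);
        let body = serde_json::json!({
            "stop_reason": "tool_use",
            "usage": {"input_tokens": 7, "output_tokens": 9},
            "content": [
                {"type": "text", "text": "Let me check."},
                {"type": "tool_use", "id": "tu_1", "name": "get_time", "input": {"tz": "UTC"}}
            ]
        });
        let resp = driver.parse_response(&body).unwrap();
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.content, "Let me check.");
        assert_eq!(resp.message.tool_calls[0].name, "get_time");
        assert_eq!(resp.usage.total(), 16);
    }
}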

#[async_trait]
impl LlmDriver for AnthropicDriver {
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = format!("{}/v1/messages", self.base_url);
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "anthropic".to_string(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap_or(60)
                * 1000;

            return Err(PunchError::RateLimited {
                provider: "anthropic".to_string(),
                retry_after_ms: retry_after,
            });
        }

        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(PunchError::Auth(
                "anthropic API key is invalid or lacks permissions".to_string(),
            ));
        }

        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: "anthropic".to_string(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "anthropic".to_string(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }

    async fn stream_complete_with_callback(
        &self,
        request: CompletionRequest,
        callback: StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        let url = format!("{}/v1/messages", self.base_url);
        let mut body = self.build_request_body(&request);
        body["stream"] = serde_json::json!(true);

        let response = self
            .client
            .post(&url)
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "anthropic".to_string(),
                message: format!("stream request failed: {e}"),
            })?;

        let status = response.status();
        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            return Err(PunchError::RateLimited {
                provider: "anthropic".to_string(),
                retry_after_ms: 60_000,
            });
        }
        if !status.is_success() {
            let err_body: serde_json::Value =
                response.json().await.unwrap_or(serde_json::json!({}));
            let msg = err_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "anthropic".to_string(),
                message: format!("API error ({}): {}", status, msg),
            });
        }

        let raw = read_stream_body(response).await?;
        let events = parse_sse_events(&raw);

        let mut text_content = String::new();
        let mut tool_calls: Vec<ToolCall> = Vec::new();
        let mut usage = TokenUsage::default();
        let mut stop_reason = StopReason::EndTurn;
        let mut current_tool_index: Option<usize> = None;

        for (event_type, data) in &events {
            let parsed: serde_json::Value = match serde_json::from_str(data) {
                Ok(v) => v,
                Err(_) => continue,
            };

            match event_type.as_str() {
                "message_start" => {
                    if let Some(inp) = parsed["message"]["usage"]["input_tokens"].as_u64() {
                        usage.input_tokens = inp;
                    }
                }
                "content_block_start" => {
                    let block = &parsed["content_block"];
                    match block["type"].as_str() {
                        Some("tool_use") => {
                            let id = block["id"].as_str().unwrap_or_default().to_string();
                            let name = block["name"].as_str().unwrap_or_default().to_string();
                            tool_calls.push(ToolCall {
                                id: id.clone(),
                                name: name.clone(),
                                input: serde_json::json!({}),
                            });
                            current_tool_index = Some(tool_calls.len() - 1);
                            callback(StreamChunk {
                                delta: String::new(),
                                is_final: false,
                                tool_call_delta: Some(ToolCallDelta {
                                    index: tool_calls.len() - 1,
                                    id: Some(id),
                                    name: Some(name),
                                    arguments_delta: String::new(),
                                }),
                            });
                        }
                        Some("text") => {
                            current_tool_index = None;
                        }
                        _ => {}
                    }
                }
                "content_block_delta" => {
                    let delta = &parsed["delta"];
                    match delta["type"].as_str() {
                        Some("text_delta") => {
                            let text = delta["text"].as_str().unwrap_or("");
                            text_content.push_str(text);
                            callback(StreamChunk {
                                delta: text.to_string(),
                                is_final: false,
                                tool_call_delta: None,
                            });
                        }
                        Some("input_json_delta") => {
                            let partial = delta["partial_json"].as_str().unwrap_or("");
                            if let Some(idx) = current_tool_index {
                                callback(StreamChunk {
                                    delta: String::new(),
                                    is_final: false,
                                    tool_call_delta: Some(ToolCallDelta {
                                        index: idx,
                                        id: None,
                                        name: None,
                                        arguments_delta: partial.to_string(),
                                    }),
                                });
                            }
                        }
                        _ => {}
                    }
                }
                "message_delta" => {
                    if let Some(sr) = parsed["delta"]["stop_reason"].as_str() {
                        stop_reason = match sr {
                            "end_turn" => StopReason::EndTurn,
                            "tool_use" => StopReason::ToolUse,
                            "max_tokens" => StopReason::MaxTokens,
                            _ => StopReason::Error,
                        };
                    }
                    if let Some(out) = parsed["usage"]["output_tokens"].as_u64() {
                        usage.output_tokens = out;
                    }
                }
                "message_stop" => {
                    callback(StreamChunk {
                        delta: String::new(),
                        is_final: true,
                        tool_call_delta: None,
                    });
                }
                _ => {}
            }
        }

        // Second pass over the events: accumulate each tool call's streamed
        // JSON argument fragments, then parse the assembled strings.
        let mut tool_json_bufs: Vec<String> = vec![String::new(); tool_calls.len()];
        let mut tc_idx: Option<usize> = None;
        let mut next_tool = 0usize;
        for (event_type, data) in &events {
            let parsed: serde_json::Value = match serde_json::from_str(data) {
                Ok(v) => v,
                Err(_) => continue,
            };
            match event_type.as_str() {
                "content_block_start" => {
                    if parsed["content_block"]["type"].as_str() == Some("tool_use") {
                        // Use a running counter so interleaved text blocks
                        // do not reset the tool index.
                        tc_idx = Some(next_tool);
                        next_tool += 1;
                    } else {
                        tc_idx = None;
                    }
                }
                "content_block_delta" => {
                    if parsed["delta"]["type"].as_str() == Some("input_json_delta")
                        && let Some(idx) = tc_idx
                        && let Some(buf) = tool_json_bufs.get_mut(idx)
                    {
                        buf.push_str(parsed["delta"]["partial_json"].as_str().unwrap_or(""));
                    }
                }
                _ => {}
            }
        }
        for (i, buf) in tool_json_bufs.into_iter().enumerate() {
            if !buf.is_empty()
                && let Some(tc) = tool_calls.get_mut(i)
            {
                tc.input = serde_json::from_str(&buf).unwrap_or(serde_json::json!({}));
            }
        }

        let text_content = strip_thinking_tags(&text_content);

        if !tool_calls.is_empty() && stop_reason != StopReason::ToolUse {
            stop_reason = StopReason::ToolUse;
        }

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}

/// Driver for OpenAI-compatible chat-completions APIs (OpenAI itself,
/// proxies, and local servers that speak the same protocol).
pub struct OpenAiCompatibleDriver {
    client: Client,
    api_key: String,
    base_url: String,
    provider_name: String,
}

impl OpenAiCompatibleDriver {
    pub fn new(api_key: String, base_url: String, provider_name: String) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url,
            provider_name,
        }
    }

    pub fn with_client(
        client: Client,
        api_key: String,
        base_url: String,
        provider_name: String,
    ) -> Self {
        Self {
            client,
            api_key,
            base_url,
            provider_name,
        }
    }

    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        if let Some(ref system) = request.system_prompt {
            messages.push(serde_json::json!({
                "role": "system",
                "content": system,
            }));
        }

        for msg in &request.messages {
            match msg.role {
                Role::System => {
                    messages.push(serde_json::json!({
                        "role": "system",
                        "content": msg.content,
                    }));
                }
                Role::User => {
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": msg.content,
                    }));
                }
                Role::Assistant => {
                    let mut m = serde_json::json!({
                        "role": "assistant",
                    });

                    if !msg.content.is_empty() {
                        m["content"] = serde_json::json!(msg.content);
                    }

                    if !msg.tool_calls.is_empty() {
                        let tc: Vec<serde_json::Value> = msg
                            .tool_calls
                            .iter()
                            .map(|tc| {
                                serde_json::json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.name,
                                        "arguments": tc.input.to_string(),
                                    },
                                })
                            })
                            .collect();
                        m["tool_calls"] = serde_json::json!(tc);
                    }

                    messages.push(m);
                }
                Role::Tool => {
                    for tr in &msg.tool_results {
                        messages.push(serde_json::json!({
                            "role": "tool",
                            "tool_call_id": tr.id,
                            "content": tr.content,
                        }));
                    }
                }
            }
        }

        let tools: Vec<serde_json::Value> = request
            .tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.input_schema,
                    },
                })
            })
            .collect();

        let mut body = serde_json::json!({
            "model": request.model,
            "messages": messages,
            "max_tokens": request.max_tokens,
        });

        if let Some(temp) = request.temperature {
            body["temperature"] = serde_json::json!(temp);
        }

        if !tools.is_empty() {
            body["tools"] = serde_json::json!(tools);
        }

        body
    }

    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let choice = body["choices"].get(0).ok_or_else(|| PunchError::Provider {
            provider: self.provider_name.clone(),
            message: "no choices in response".to_string(),
        })?;

        let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
        let stop_reason = match finish_reason {
            "stop" => StopReason::EndTurn,
            "tool_calls" => StopReason::ToolUse,
            "length" => StopReason::MaxTokens,
            _ => StopReason::EndTurn,
        };

        let msg = &choice["message"];
        let raw_content = msg["content"].as_str().unwrap_or("");
        let content = strip_thinking_tags(raw_content);

        let mut tool_calls = Vec::new();
        if let Some(tc_array) = msg["tool_calls"].as_array() {
            for tc in tc_array {
                let id = tc["id"].as_str().unwrap_or_default().to_string();
                let name = tc["function"]["name"]
                    .as_str()
                    .unwrap_or_default()
                    .to_string();
                let args_str = tc["function"]["arguments"].as_str().unwrap_or("{}");
                let input: serde_json::Value =
                    serde_json::from_str(args_str).unwrap_or(serde_json::json!({}));

                tool_calls.push(ToolCall { id, name, input });
            }
        }

        let usage = TokenUsage {
            input_tokens: body["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
            output_tokens: body["usage"]["completion_tokens"].as_u64().unwrap_or(0),
        };

        let stop_reason = if !tool_calls.is_empty() && stop_reason != StopReason::ToolUse {
            StopReason::ToolUse
        } else {
            stop_reason
        };

        let message = Message {
            role: Role::Assistant,
            content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}
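
#[cfg(test)]
mod openai_parse_tests {
    use super::*;

    // Illustrative test (not from the original source) with a synthetic
    // chat-completions response body.
    #[test]
    fn parses_tool_call_choice() {
        let driver = OpenAiCompatibleDriver::new(
            "k".to_string(),
            "http://localhost".to_string(),
            "test".to_string(),
        );
        let body = serde_json::json!({
            "choices": [{
                "finish_reason": "tool_calls",
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call_1",
                        "function": {"name": "lookup", "arguments": "{\"q\":\"rust\"}"}
                    }]
                }
            }],
            "usage": {"prompt_tokens": 11, "completion_tokens": 4}
        });
        let resp = driver.parse_response(&body).unwrap();
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.tool_calls[0].input["q"], "rust");
        assert_eq!(resp.usage.input_tokens, 11);
    }
}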

#[async_trait]
impl LlmDriver for OpenAiCompatibleDriver {
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = format!(
            "{}/v1/chat/completions",
            self.base_url.trim_end_matches('/')
        );
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: self.provider_name.clone(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap_or(60)
                * 1000;

            return Err(PunchError::RateLimited {
                provider: self.provider_name.clone(),
                retry_after_ms: retry_after,
            });
        }

        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(PunchError::Auth(format!(
                "{} API key is invalid or lacks permissions",
                self.provider_name
            )));
        }

        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: self.provider_name.clone(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: self.provider_name.clone(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }

    async fn stream_complete_with_callback(
        &self,
        request: CompletionRequest,
        callback: StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        let url = format!(
            "{}/v1/chat/completions",
            self.base_url.trim_end_matches('/')
        );
        let mut body = self.build_request_body(&request);
        body["stream"] = serde_json::json!(true);

        let response = self
            .client
            .post(&url)
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: self.provider_name.clone(),
                message: format!("stream request failed: {e}"),
            })?;

        let status = response.status();
        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            return Err(PunchError::RateLimited {
                provider: self.provider_name.clone(),
                retry_after_ms: 60_000,
            });
        }
        if !status.is_success() {
            let err_body: serde_json::Value =
                response.json().await.unwrap_or(serde_json::json!({}));
            let msg = err_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: self.provider_name.clone(),
                message: format!("API error ({}): {}", status, msg),
            });
        }

        let raw = read_stream_body(response).await?;
        let assembled = self.parse_openai_stream(&raw, &callback)?;
        Ok(assembled)
    }
}

impl OpenAiCompatibleDriver {
    pub fn parse_openai_stream(
        &self,
        raw: &str,
        callback: &StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        let events = parse_sse_events(raw);

        let mut text_content = String::new();
        // index -> (id, name, accumulated arguments JSON)
        let mut tool_map: std::collections::BTreeMap<usize, (String, String, String)> =
            std::collections::BTreeMap::new();
        let mut finish_reason = String::new();

        for (_event_type, data) in &events {
            if data.trim() == "[DONE]" {
                callback(StreamChunk {
                    delta: String::new(),
                    is_final: true,
                    tool_call_delta: None,
                });
                break;
            }

            let parsed: serde_json::Value = match serde_json::from_str(data) {
                Ok(v) => v,
                Err(_) => continue,
            };

            let choice = match parsed["choices"].get(0) {
                Some(c) => c,
                None => continue,
            };

            if let Some(fr) = choice["finish_reason"].as_str() {
                finish_reason = fr.to_string();
            }

            let delta = &choice["delta"];

            if let Some(content) = delta["content"].as_str() {
                text_content.push_str(content);
                callback(StreamChunk {
                    delta: content.to_string(),
                    is_final: false,
                    tool_call_delta: None,
                });
            }

            if let Some(tc_array) = delta["tool_calls"].as_array() {
                for tc in tc_array {
                    let idx = tc["index"].as_u64().unwrap_or(0) as usize;
                    let entry = tool_map
                        .entry(idx)
                        .or_insert_with(|| (String::new(), String::new(), String::new()));

                    let id_delta = tc["id"].as_str().unwrap_or("");
                    let name_delta = tc["function"]["name"].as_str().unwrap_or("");
                    let args_delta = tc["function"]["arguments"].as_str().unwrap_or("");

                    if !id_delta.is_empty() {
                        entry.0.push_str(id_delta);
                    }
                    if !name_delta.is_empty() {
                        entry.1.push_str(name_delta);
                    }
                    entry.2.push_str(args_delta);

                    callback(StreamChunk {
                        delta: String::new(),
                        is_final: false,
                        tool_call_delta: Some(ToolCallDelta {
                            index: idx,
                            id: if id_delta.is_empty() {
                                None
                            } else {
                                Some(id_delta.to_string())
                            },
                            name: if name_delta.is_empty() {
                                None
                            } else {
                                Some(name_delta.to_string())
                            },
                            arguments_delta: args_delta.to_string(),
                        }),
                    });
                }
            }
        }

        let tool_calls: Vec<ToolCall> = tool_map
            .into_values()
            .map(|(id, name, args)| {
                let input: serde_json::Value =
                    serde_json::from_str(&args).unwrap_or(serde_json::json!({}));
                ToolCall { id, name, input }
            })
            .collect();

        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else {
            match finish_reason.as_str() {
                "stop" => StopReason::EndTurn,
                "tool_calls" => StopReason::ToolUse,
                "length" => StopReason::MaxTokens,
                _ => StopReason::EndTurn,
            }
        };

        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        // Usage is not reliably reported by all OpenAI-compatible streams,
        // so it defaults to zero here.
        Ok(CompletionResponse {
            message,
            usage: TokenUsage::default(),
            stop_reason,
        })
    }
}
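
#[cfg(test)]
mod openai_stream_tests {
    use super::*;

    // Illustrative test (not from the original source): a hand-written SSE
    // body mimicking the OpenAI streaming wire format. The chunk payloads
    // are synthetic.
    #[test]
    fn assembles_text_deltas() {
        let driver = OpenAiCompatibleDriver::new(
            "test-key".to_string(),
            "http://localhost".to_string(),
            "test".to_string(),
        );
        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\"lo\"},\"finish_reason\":\"stop\"}]}\n\n",
            "data: [DONE]\n\n",
        );
        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
        assert_eq!(resp.message.content, "Hello");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
    }
}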

/// Driver for the Google Gemini generateContent API.
pub struct GeminiDriver {
    client: Client,
    api_key: String,
    base_url: String,
}

impl GeminiDriver {
    pub fn new(api_key: String, base_url: Option<String>) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url: base_url
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string()),
        }
    }

    pub fn with_client(client: Client, api_key: String, base_url: Option<String>) -> Self {
        Self {
            client,
            api_key,
            base_url: base_url
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string()),
        }
    }

    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut contents = Vec::new();
        let mut system_text: Option<String> = request.system_prompt.clone();

        for msg in &request.messages {
            match msg.role {
                Role::System => {
                    // Fold system messages into the pending system text.
                    let existing = system_text.take().unwrap_or_default();
                    let combined = if existing.is_empty() {
                        msg.content.clone()
                    } else {
                        format!("{}\n{}", existing, msg.content)
                    };
                    system_text = Some(combined);
                }
                Role::User => {
                    let mut text = String::new();
                    if let Some(sys) = system_text.take()
                        && !sys.is_empty()
                    {
                        text.push_str(&sys);
                        text.push_str("\n\n");
                    }
                    text.push_str(&msg.content);
                    contents.push(serde_json::json!({
                        "role": "user",
                        "parts": [{"text": text}],
                    }));
                }
                Role::Assistant => {
                    let mut parts: Vec<serde_json::Value> = Vec::new();
                    if !msg.content.is_empty() {
                        parts.push(serde_json::json!({"text": msg.content}));
                    }
                    for tc in &msg.tool_calls {
                        parts.push(serde_json::json!({
                            "functionCall": {
                                "name": tc.name,
                                "args": tc.input,
                            }
                        }));
                    }
                    if parts.is_empty() {
                        parts.push(serde_json::json!({"text": ""}));
                    }
                    contents.push(serde_json::json!({
                        "role": "model",
                        "parts": parts,
                    }));
                }
                Role::Tool => {
                    let mut parts: Vec<serde_json::Value> = Vec::new();
                    for tr in &msg.tool_results {
                        parts.push(serde_json::json!({
                            "functionResponse": {
                                "name": tr.id.clone(),
                                "response": {"content": tr.content},
                            }
                        }));
                    }
                    contents.push(serde_json::json!({
                        "role": "user",
                        "parts": parts,
                    }));
                }
            }
        }

        // Any system text not yet attached to a user message is prepended
        // as its own user turn.
        if let Some(sys) = system_text
            && !sys.is_empty()
        {
            contents.insert(
                0,
                serde_json::json!({
                    "role": "user",
                    "parts": [{"text": sys}],
                }),
            );
        }

        let mut body = serde_json::json!({
            "contents": contents,
        });

        let mut gen_config = serde_json::json!({
            "maxOutputTokens": request.max_tokens,
        });
        if let Some(temp) = request.temperature {
            gen_config["temperature"] = serde_json::json!(temp);
        }
        body["generationConfig"] = gen_config;

        if !request.tools.is_empty() {
            let func_decls: Vec<serde_json::Value> = request
                .tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.input_schema,
                    })
                })
                .collect();
            body["tools"] = serde_json::json!([{"function_declarations": func_decls}]);
        }

        body
    }

    pub fn build_url(&self, model: &str) -> String {
        format!(
            "{}/v1beta/models/{}:generateContent?key={}",
            self.base_url.trim_end_matches('/'),
            model,
            self.api_key,
        )
    }

    pub fn build_stream_url(&self, model: &str) -> String {
        format!(
            "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
            self.base_url.trim_end_matches('/'),
            model,
            self.api_key,
        )
    }

    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let candidate = body["candidates"]
            .get(0)
            .ok_or_else(|| PunchError::Provider {
                provider: "gemini".to_string(),
                message: "no candidates in response".to_string(),
            })?;

        let parts = candidate["content"]["parts"]
            .as_array()
            .cloned()
            .unwrap_or_default();

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        for part in &parts {
            if let Some(text) = part["text"].as_str() {
                if !text_content.is_empty() {
                    text_content.push('\n');
                }
                text_content.push_str(text);
            }
            if let Some(fc) = part.get("functionCall") {
                let name = fc["name"].as_str().unwrap_or_default().to_string();
                let args = fc["args"].clone();
                tool_calls.push(ToolCall {
                    id: format!("gemini-{}", uuid::Uuid::new_v4()),
                    name,
                    input: args,
                });
            }
        }

        let finish_reason = candidate["finishReason"].as_str().unwrap_or("STOP");
        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else {
            match finish_reason {
                "STOP" => StopReason::EndTurn,
                "MAX_TOKENS" => StopReason::MaxTokens,
                _ => StopReason::EndTurn,
            }
        };

        let usage = TokenUsage {
            input_tokens: body["usageMetadata"]["promptTokenCount"]
                .as_u64()
                .unwrap_or(0),
            output_tokens: body["usageMetadata"]["candidatesTokenCount"]
                .as_u64()
                .unwrap_or(0),
        };

        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}
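
#[cfg(test)]
mod gemini_tests {
    use super::*;

    // Illustrative tests (not from the original source); the model name and
    // key are placeholders.
    #[test]
    fn builds_generate_content_url() {
        let d = GeminiDriver::new("k".to_string(), None);
        assert_eq!(
            d.build_url("example-model"),
            "https://generativelanguage.googleapis.com/v1beta/models/example-model:generateContent?key=k"
        );
    }

    #[test]
    fn parses_text_candidate() {
        let d = GeminiDriver::new("k".to_string(), None);
        let body = serde_json::json!({
            "candidates": [{
                "content": {"parts": [{"text": "Hi"}]},
                "finishReason": "STOP"
            }],
            "usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 1}
        });
        let resp = d.parse_response(&body).unwrap();
        assert_eq!(resp.message.content, "Hi");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert_eq!(resp.usage.total(), 4);
    }
}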

#[async_trait]
impl LlmDriver for GeminiDriver {
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = self.build_url(&request.model);
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            return Err(PunchError::RateLimited {
                provider: "gemini".to_string(),
                retry_after_ms: 60_000,
            });
        }

        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(PunchError::Auth(
                "Gemini API key is invalid or lacks permissions".to_string(),
            ));
        }

        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }

    async fn stream_complete_with_callback(
        &self,
        request: CompletionRequest,
        callback: StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        let url = self.build_stream_url(&request.model);
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("stream request failed: {e}"),
            })?;

        let status = response.status();
        if !status.is_success() {
            let err_body: serde_json::Value =
                response.json().await.unwrap_or(serde_json::json!({}));
            let msg = err_body["error"]["message"]
                .as_str()
                .unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "gemini".to_string(),
                message: format!("API error ({}): {}", status, msg),
            });
        }

        let raw = read_stream_body(response).await?;
        let events = parse_sse_events(&raw);

        let mut text_content = String::new();
        let mut tool_calls: Vec<ToolCall> = Vec::new();
        let mut usage = TokenUsage::default();
        let mut finish_reason = String::new();

        for (_event_type, data) in &events {
            let parsed: serde_json::Value = match serde_json::from_str(data) {
                Ok(v) => v,
                Err(_) => continue,
            };

            if let Some(parts) = parsed["candidates"][0]["content"]["parts"].as_array() {
                for part in parts {
                    if let Some(text) = part["text"].as_str() {
                        text_content.push_str(text);
                        callback(StreamChunk {
                            delta: text.to_string(),
                            is_final: false,
                            tool_call_delta: None,
                        });
                    }
                    if let Some(fc) = part.get("functionCall") {
                        let name = fc["name"].as_str().unwrap_or_default().to_string();
                        let args = fc["args"].clone();
                        let idx = tool_calls.len();
                        tool_calls.push(ToolCall {
                            id: format!("gemini-{}", uuid::Uuid::new_v4()),
                            name: name.clone(),
                            input: args,
                        });
                        callback(StreamChunk {
                            delta: String::new(),
                            is_final: false,
                            tool_call_delta: Some(ToolCallDelta {
                                index: idx,
                                id: None,
                                name: Some(name),
                                arguments_delta: String::new(),
                            }),
                        });
                    }
                }
            }

            if let Some(fr) = parsed["candidates"][0]["finishReason"].as_str() {
                finish_reason = fr.to_string();
            }

            if let Some(inp) = parsed["usageMetadata"]["promptTokenCount"].as_u64() {
                usage.input_tokens = inp;
            }
            if let Some(out) = parsed["usageMetadata"]["candidatesTokenCount"].as_u64() {
                usage.output_tokens = out;
            }
        }

        callback(StreamChunk {
            delta: String::new(),
            is_final: true,
            tool_call_delta: None,
        });

        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else {
            match finish_reason.as_str() {
                "STOP" => StopReason::EndTurn,
                "MAX_TOKENS" => StopReason::MaxTokens,
                _ => StopReason::EndTurn,
            }
        };

        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}

/// Driver for a local Ollama server's /api/chat endpoint.
pub struct OllamaDriver {
    client: Client,
    base_url: String,
}

impl OllamaDriver {
    pub fn new(base_url: Option<String>) -> Self {
        Self {
            client: Client::new(),
            base_url: base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
        }
    }

    pub fn with_client(client: Client, base_url: Option<String>) -> Self {
        Self {
            client,
            base_url: base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
        }
    }

    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        if let Some(ref system) = request.system_prompt {
            messages.push(serde_json::json!({
                "role": "system",
                "content": system,
            }));
        }

        for msg in &request.messages {
            match msg.role {
                Role::System => {
                    messages.push(serde_json::json!({
                        "role": "system",
                        "content": msg.content,
                    }));
                }
                Role::User => {
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": msg.content,
                    }));
                }
                Role::Assistant => {
                    let mut m = serde_json::json!({
                        "role": "assistant",
                        "content": msg.content,
                    });
                    if !msg.tool_calls.is_empty() {
                        let tc: Vec<serde_json::Value> = msg
                            .tool_calls
                            .iter()
                            .map(|tc| {
                                serde_json::json!({
                                    "function": {
                                        "name": tc.name,
                                        "arguments": tc.input,
                                    }
                                })
                            })
                            .collect();
                        m["tool_calls"] = serde_json::json!(tc);
                    }
                    messages.push(m);
                }
                Role::Tool => {
                    for tr in &msg.tool_results {
                        messages.push(serde_json::json!({
                            "role": "tool",
                            "content": tr.content,
                        }));
                    }
                }
            }
        }

        let mut body = serde_json::json!({
            "model": request.model,
            "messages": messages,
            "stream": false,
        });

        let mut options = serde_json::json!({});
        if let Some(temp) = request.temperature {
            options["temperature"] = serde_json::json!(temp);
        }
        if request.max_tokens > 0 {
            options["num_predict"] = serde_json::json!(request.max_tokens);
        }
        body["options"] = options;

        // Ask the server not to emit separate thinking output for
        // reasoning-capable models.
        body["think"] = serde_json::json!(false);

        if !request.tools.is_empty() {
            let tools: Vec<serde_json::Value> = request
                .tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "type": "function",
                        "function": {
                            "name": t.name,
                            "description": t.description,
                            "parameters": t.input_schema,
                        }
                    })
                })
                .collect();
            body["tools"] = serde_json::json!(tools);
        }

        body
    }

    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let msg = &body["message"];
        let raw_content = msg["content"].as_str().unwrap_or("");
        let content = strip_thinking_tags(raw_content);

        let mut tool_calls = Vec::new();
        if let Some(tc_array) = msg["tool_calls"].as_array() {
            for tc in tc_array {
                let name = tc["function"]["name"]
                    .as_str()
                    .unwrap_or_default()
                    .to_string();
                let input = tc["function"]["arguments"].clone();
                tool_calls.push(ToolCall {
                    id: format!("ollama-{}", uuid::Uuid::new_v4()),
                    name,
                    input,
                });
            }
        }

        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else if body["done"].as_bool().unwrap_or(true) {
            StopReason::EndTurn
        } else {
            StopReason::MaxTokens
        };

        let usage = TokenUsage {
            input_tokens: body["prompt_eval_count"].as_u64().unwrap_or(0),
            output_tokens: body["eval_count"].as_u64().unwrap_or(0),
        };

        let message = Message {
            role: Role::Assistant,
            content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}

#[async_trait]
impl LlmDriver for OllamaDriver {
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
        let body = self.build_request_body(&request);

        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "ollama".to_string(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();
        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: "ollama".to_string(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["error"].as_str().unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "ollama".to_string(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }

    async fn stream_complete_with_callback(
        &self,
        request: CompletionRequest,
        callback: StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
        let mut body = self.build_request_body(&request);
        body["stream"] = serde_json::json!(true);

        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "ollama".to_string(),
                message: format!("stream request failed: {e}"),
            })?;

        let status = response.status();
        if !status.is_success() {
            let err_body: serde_json::Value =
                response.json().await.unwrap_or(serde_json::json!({}));
            let msg = err_body["error"].as_str().unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "ollama".to_string(),
                message: format!("API error ({}): {}", status, msg),
            });
        }

        let raw = read_stream_body(response).await?;
        let assembled = self.parse_ollama_stream(&raw, &callback)?;
        Ok(assembled)
    }
}

impl OllamaDriver {
    pub fn parse_ollama_stream(
        &self,
        raw: &str,
        callback: &StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        let mut text_content = String::new();
        let mut tool_calls: Vec<ToolCall> = Vec::new();
        let mut usage = TokenUsage::default();
        let mut done = false;

        for line in raw.lines() {
            let line = line.trim();
            if line.is_empty() {
                continue;
            }

            let parsed: serde_json::Value = match serde_json::from_str(line) {
                Ok(v) => v,
                Err(_) => continue,
            };

            if parsed["done"].as_bool() == Some(true) {
                done = true;
                if let Some(inp) = parsed["prompt_eval_count"].as_u64() {
                    usage.input_tokens = inp;
                }
                if let Some(out) = parsed["eval_count"].as_u64() {
                    usage.output_tokens = out;
                }
                if let Some(tc_array) = parsed["message"]["tool_calls"].as_array() {
                    for tc in tc_array {
                        let name = tc["function"]["name"]
                            .as_str()
                            .unwrap_or_default()
                            .to_string();
                        let input = tc["function"]["arguments"].clone();
                        tool_calls.push(ToolCall {
                            id: format!("ollama-{}", uuid::Uuid::new_v4()),
                            name,
                            input,
                        });
                    }
                }
                callback(StreamChunk {
                    delta: String::new(),
                    is_final: true,
                    tool_call_delta: None,
                });
                break;
            }

            let content = parsed["message"]["content"].as_str().unwrap_or("");
            if !content.is_empty() {
                text_content.push_str(content);
                callback(StreamChunk {
                    delta: content.to_string(),
                    is_final: false,
                    tool_call_delta: None,
                });
            }
        }

        let text_content = strip_thinking_tags(&text_content);

        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else if done {
            StopReason::EndTurn
        } else {
            StopReason::MaxTokens
        };

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }
}
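
#[cfg(test)]
mod ollama_stream_tests {
    use super::*;

    // Illustrative test (not from the original source): Ollama streams
    // newline-delimited JSON objects, with token counts on the final
    // `done` line.
    #[test]
    fn assembles_ndjson_stream() {
        let driver = OllamaDriver::new(None);
        let raw = concat!(
            "{\"message\":{\"content\":\"Hi\"},\"done\":false}\n",
            "{\"message\":{\"content\":\"!\"},\"done\":false}\n",
            "{\"done\":true,\"prompt_eval_count\":3,\"eval_count\":2}\n",
        );
        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();
        assert_eq!(resp.message.content, "Hi!");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert_eq!(resp.usage.total(), 5);
    }
}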

/// Driver for the AWS Bedrock Converse API, signed with Signature V4.
pub struct BedrockDriver {
    client: Client,
    access_key: String,
    secret_key: String,
    region: String,
}

impl BedrockDriver {
    pub fn new(access_key: String, secret_key: String, region: String) -> Self {
        Self {
            client: Client::new(),
            access_key,
            secret_key,
            region,
        }
    }

    pub fn with_client(
        client: Client,
        access_key: String,
        secret_key: String,
        region: String,
    ) -> Self {
        Self {
            client,
            access_key,
            secret_key,
            region,
        }
    }

    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        for msg in &request.messages {
            match msg.role {
                Role::User => {
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": [{"text": msg.content}],
                    }));
                }
                Role::Assistant => {
                    let mut content: Vec<serde_json::Value> = Vec::new();
                    if !msg.content.is_empty() {
                        content.push(serde_json::json!({"text": msg.content}));
                    }
                    for tc in &msg.tool_calls {
                        content.push(serde_json::json!({
                            "toolUse": {
                                "toolUseId": tc.id,
                                "name": tc.name,
                                "input": tc.input,
                            }
                        }));
                    }
                    if content.is_empty() {
                        content.push(serde_json::json!({"text": ""}));
                    }
                    messages.push(serde_json::json!({
                        "role": "assistant",
                        "content": content,
                    }));
                }
                Role::Tool => {
                    let mut content: Vec<serde_json::Value> = Vec::new();
                    for tr in &msg.tool_results {
                        content.push(serde_json::json!({
                            "toolResult": {
                                "toolUseId": tr.id,
                                "content": [{"text": tr.content}],
                                "status": if tr.is_error { "error" } else { "success" },
                            }
                        }));
                    }
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": content,
                    }));
                }
                Role::System => {
                    // System content is carried in the top-level `system`
                    // field instead of the messages array.
                }
            }
        }

        let mut body = serde_json::json!({
            "messages": messages,
        });

        let mut inference_config = serde_json::json!({
            "maxTokens": request.max_tokens,
        });
        if let Some(temp) = request.temperature {
            inference_config["temperature"] = serde_json::json!(temp);
        }
        body["inferenceConfig"] = inference_config;

        if let Some(ref system) = request.system_prompt {
            body["system"] = serde_json::json!([{"text": system}]);
        }

        if !request.tools.is_empty() {
            let tool_config: Vec<serde_json::Value> = request
                .tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "toolSpec": {
                            "name": t.name,
                            "description": t.description,
                            "inputSchema": {"json": t.input_schema},
                        }
                    })
                })
                .collect();
            body["toolConfig"] = serde_json::json!({"tools": tool_config});
        }

        body
    }

    pub fn build_url(&self, model_id: &str) -> String {
        format!(
            "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
            self.region, model_id,
        )
    }

    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        let content = body["output"]["message"]["content"]
            .as_array()
            .cloned()
            .unwrap_or_default();

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        for block in &content {
            if let Some(text) = block["text"].as_str() {
                if !text_content.is_empty() {
                    text_content.push('\n');
                }
                text_content.push_str(text);
            }
            if let Some(tu) = block.get("toolUse") {
                tool_calls.push(ToolCall {
                    id: tu["toolUseId"].as_str().unwrap_or_default().to_string(),
                    name: tu["name"].as_str().unwrap_or_default().to_string(),
                    input: tu["input"].clone(),
                });
            }
        }

        let stop_reason_str = body["stopReason"].as_str().unwrap_or("end_turn");
        let stop_reason = if !tool_calls.is_empty() {
            StopReason::ToolUse
        } else {
            match stop_reason_str {
                "end_turn" => StopReason::EndTurn,
                "tool_use" => StopReason::ToolUse,
                "max_tokens" => StopReason::MaxTokens,
                _ => StopReason::EndTurn,
            }
        };

        let usage = TokenUsage {
            input_tokens: body["usage"]["inputTokens"].as_u64().unwrap_or(0),
            output_tokens: body["usage"]["outputTokens"].as_u64().unwrap_or(0),
        };

        let text_content = strip_thinking_tags(&text_content);

        let message = Message {
            role: Role::Assistant,
            content: text_content,
            tool_calls,
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        };

        Ok(CompletionResponse {
            message,
            usage,
            stop_reason,
        })
    }

    /// Computes an AWS Signature Version 4 `Authorization` header value.
    /// `timestamp` must be formatted as `YYYYMMDDTHHMMSSZ`; its first eight
    /// characters are used as the credential-scope date.
    pub fn sign_request(
        &self,
        method: &str,
        url: &str,
        headers: &[(String, String)],
        payload: &[u8],
        timestamp: &str,
    ) -> PunchResult<String> {
        let date = &timestamp[..8];
        let service = "bedrock";

        let parsed = url::Url::parse(url).map_err(|e| PunchError::Provider {
            provider: "bedrock".to_string(),
            message: format!("invalid URL: {e}"),
        })?;
        let host = parsed.host_str().unwrap_or("");
        let path = parsed.path();

        let payload_hash = hex_sha256(payload);

        let mut signed_header_names: Vec<String> =
            headers.iter().map(|(k, _)| k.to_lowercase()).collect();
        signed_header_names.push("host".to_string());
        signed_header_names.push("x-amz-date".to_string());
        signed_header_names.sort();
        signed_header_names.dedup();

        let mut header_map: Vec<(String, String)> = headers
            .iter()
            .map(|(k, v)| (k.to_lowercase(), v.trim().to_string()))
            .collect();
        header_map.push(("host".to_string(), host.to_string()));
        header_map.push(("x-amz-date".to_string(), timestamp.to_string()));
        header_map.sort_by(|a, b| a.0.cmp(&b.0));
        header_map.dedup_by(|a, b| a.0 == b.0);

        let canonical_headers: String = header_map
            .iter()
            .map(|(k, v)| format!("{}:{}\n", k, v))
            .collect();

        let signed_headers = signed_header_names.join(";");

        // The blank segment after the path is the (empty) canonical query
        // string required by the SigV4 canonical-request layout.
        let canonical_request = format!(
            "{}\n{}\n\n{}\n{}\n{}",
            method, path, canonical_headers, signed_headers, payload_hash,
        );

        let credential_scope = format!("{}/{}/{}/aws4_request", date, self.region, service);
        let string_to_sign = format!(
            "AWS4-HMAC-SHA256\n{}\n{}\n{}",
            timestamp,
            credential_scope,
            hex_sha256(canonical_request.as_bytes()),
        );

        // Derive the signing key: date -> region -> service -> aws4_request.
        let k_date = hmac_sha256(
            format!("AWS4{}", self.secret_key).as_bytes(),
            date.as_bytes(),
        );
        let k_region = hmac_sha256(&k_date, self.region.as_bytes());
        let k_service = hmac_sha256(&k_region, service.as_bytes());
        let k_signing = hmac_sha256(&k_service, b"aws4_request");

        let signature = hex_encode(&hmac_sha256(&k_signing, string_to_sign.as_bytes()));

        Ok(format!(
            "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
            self.access_key, credential_scope, signed_headers, signature,
        ))
    }
}
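
#[cfg(test)]
mod bedrock_tests {
    use super::*;

    // Illustrative test (not from the original source): the Converse
    // endpoint URL is a pure function of region and model id. The model id
    // here is a placeholder.
    #[test]
    fn builds_converse_url() {
        let d = BedrockDriver::new("ak".to_string(), "sk".to_string(), "us-east-1".to_string());
        assert_eq!(
            d.build_url("example-model"),
            "https://bedrock-runtime.us-east-1.amazonaws.com/model/example-model/converse"
        );
    }
}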

/// Hex-encoded SHA-256 digest of `data`.
fn hex_sha256(data: &[u8]) -> String {
    let mut hasher = Sha256::new();
    hasher.update(data);
    hex_encode(hasher.finalize().as_slice())
}

/// HMAC-SHA256 of `data` under `key`.
fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
    type HmacSha256 = Hmac<Sha256>;
    let mut mac = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size");
    mac.update(data);
    mac.finalize().into_bytes().to_vec()
}

/// Lowercase hex encoding of `bytes`.
fn hex_encode(bytes: &[u8]) -> String {
    bytes.iter().map(|b| format!("{:02x}", b)).collect()
}
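
#[cfg(test)]
mod sigv4_helper_tests {
    use super::*;

    // Illustrative tests (not from the original source). The SHA-256 digest
    // of the empty string is a well-known constant.
    #[test]
    fn hex_sha256_of_empty_input() {
        assert_eq!(
            hex_sha256(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn hex_encode_pads_low_bytes() {
        assert_eq!(hex_encode(&[0x00, 0x0f, 0xff]), "000fff");
    }
}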

#[async_trait]
impl LlmDriver for BedrockDriver {
    async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
        let url = self.build_url(&request.model);
        let body = self.build_request_body(&request);
        let payload = serde_json::to_vec(&body).map_err(|e| PunchError::Provider {
            provider: "bedrock".to_string(),
            message: format!("failed to serialize request: {e}"),
        })?;

        let timestamp = chrono::Utc::now().format("%Y%m%dT%H%M%SZ").to_string();

        let auth_header = self.sign_request(
            "POST",
            &url,
            &[("content-type".to_string(), "application/json".to_string())],
            &payload,
            &timestamp,
        )?;

        let parsed_url = url::Url::parse(&url).map_err(|e| PunchError::Provider {
            provider: "bedrock".to_string(),
            message: format!("invalid URL: {e}"),
        })?;
        let host = parsed_url.host_str().unwrap_or_default().to_string();

        let response = self
            .client
            .post(&url)
            .header("content-type", "application/json")
            .header("host", &host)
            .header("x-amz-date", &timestamp)
            .header("authorization", &auth_header)
            .body(payload)
            .send()
            .await
            .map_err(|e| PunchError::Provider {
                provider: "bedrock".to_string(),
                message: format!("request failed: {e}"),
            })?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            return Err(PunchError::RateLimited {
                provider: "bedrock".to_string(),
                retry_after_ms: 60_000,
            });
        }

        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(PunchError::Auth(
                "AWS Bedrock credentials are invalid or lack permissions".to_string(),
            ));
        }

        let response_body: serde_json::Value =
            response.json().await.map_err(|e| PunchError::Provider {
                provider: "bedrock".to_string(),
                message: format!("failed to parse response: {e}"),
            })?;

        if !status.is_success() {
            let error_msg = response_body["message"].as_str().unwrap_or("unknown error");
            return Err(PunchError::Provider {
                provider: "bedrock".to_string(),
                message: format!("API error ({}): {}", status, error_msg),
            });
        }

        self.parse_response(&response_body)
    }

    async fn stream_complete_with_callback(
        &self,
        request: CompletionRequest,
        callback: StreamCallback,
    ) -> PunchResult<CompletionResponse> {
        // Streaming is not implemented for Bedrock here; fall back to a
        // blocking completion delivered as a single final chunk.
        let response = self.complete(request).await?;
        callback(StreamChunk {
            delta: response.message.content.clone(),
            is_final: true,
            tool_call_delta: None,
        });
        Ok(response)
    }
}

/// Driver for Azure OpenAI deployments. Wraps `OpenAiCompatibleDriver` for
/// request/response translation but uses Azure's deployment-scoped URL
/// scheme and `api-key` header.
pub struct AzureOpenAiDriver {
    inner: OpenAiCompatibleDriver,
    resource: String,
    deployment: String,
    api_version: String,
}

impl AzureOpenAiDriver {
    pub fn new(
        api_key: String,
        resource: String,
        deployment: String,
        api_version: Option<String>,
    ) -> Self {
        let base_url = format!("https://{}.openai.azure.com", resource);
        Self {
            inner: OpenAiCompatibleDriver::new(api_key, base_url, "azure_openai".to_string()),
            resource,
            deployment,
            api_version: api_version.unwrap_or_else(|| "2024-02-01".to_string()),
        }
    }

    pub fn with_client(
        client: Client,
        api_key: String,
        resource: String,
        deployment: String,
        api_version: Option<String>,
    ) -> Self {
        let base_url = format!("https://{}.openai.azure.com", resource);
        Self {
            inner: OpenAiCompatibleDriver::with_client(
                client,
                api_key,
                base_url,
                "azure_openai".to_string(),
            ),
            resource,
            deployment,
            api_version: api_version.unwrap_or_else(|| "2024-02-01".to_string()),
        }
    }

    pub fn build_url(&self) -> String {
        format!(
            "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
            self.resource, self.deployment, self.api_version,
        )
    }

    pub fn resource(&self) -> &str {
        &self.resource
    }

    pub fn deployment(&self) -> &str {
        &self.deployment
    }

    pub fn build_request_body(&self, request: &CompletionRequest) -> serde_json::Value {
        self.inner.build_request_body(request)
    }

    pub fn parse_response(&self, body: &serde_json::Value) -> PunchResult<CompletionResponse> {
        self.inner.parse_response(body)
    }
}
2429
2430#[async_trait]
2431impl LlmDriver for AzureOpenAiDriver {
2432 async fn complete(&self, request: CompletionRequest) -> PunchResult<CompletionResponse> {
2433 let url = self.build_url();
2434 let body = self.inner.build_request_body(&request);
2435
2436 let response = self
2437 .inner
2438 .client
2439 .post(&url)
2440 .header("api-key", &self.inner.api_key)
2441 .header("content-type", "application/json")
2442 .json(&body)
2443 .send()
2444 .await
2445 .map_err(|e| PunchError::Provider {
2446 provider: "azure_openai".to_string(),
2447 message: format!("request failed: {e}"),
2448 })?;
2449
2450 let status = response.status();
2451
2452 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
2453 let retry_after = response
2454 .headers()
2455 .get("retry-after")
2456 .and_then(|v| v.to_str().ok())
2457 .and_then(|s| s.parse::<u64>().ok())
2458 .unwrap_or(60)
2459 * 1000;
2460
2461 return Err(PunchError::RateLimited {
2462 provider: "azure_openai".to_string(),
2463 retry_after_ms: retry_after,
2464 });
2465 }
2466
2467 if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
2468 return Err(PunchError::Auth(
2469 "Azure OpenAI API key is invalid or lacks permissions".to_string(),
2470 ));
2471 }
2472
2473 let response_body: serde_json::Value =
2474 response.json().await.map_err(|e| PunchError::Provider {
2475 provider: "azure_openai".to_string(),
2476 message: format!("failed to parse response: {e}"),
2477 })?;
2478
2479 if !status.is_success() {
2480 let error_msg = response_body["error"]["message"]
2481 .as_str()
2482 .unwrap_or("unknown error");
2483 return Err(PunchError::Provider {
2484 provider: "azure_openai".to_string(),
2485 message: format!("API error ({}): {}", status, error_msg),
2486 });
2487 }
2488
2489 self.inner.parse_response(&response_body)
2490 }
2491
2492 async fn stream_complete_with_callback(
2493 &self,
2494 request: CompletionRequest,
2495 callback: StreamCallback,
2496 ) -> PunchResult<CompletionResponse> {
2497 let url = self.build_url();
2498 let mut body = self.inner.build_request_body(&request);
2499 body["stream"] = serde_json::json!(true);
2500
2501 let response = self
2502 .inner
2503 .client
2504 .post(&url)
2505 .header("api-key", &self.inner.api_key)
2506 .header("content-type", "application/json")
2507 .json(&body)
2508 .send()
2509 .await
2510 .map_err(|e| PunchError::Provider {
2511 provider: "azure_openai".to_string(),
2512 message: format!("stream request failed: {e}"),
2513 })?;
2514
2515 let status = response.status();
2516 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
2517 return Err(PunchError::RateLimited {
2518 provider: "azure_openai".to_string(),
2519 retry_after_ms: 60_000,
2520 });
2521 }
2522 if !status.is_success() {
2523 let err_body: serde_json::Value =
2524 response.json().await.unwrap_or(serde_json::json!({}));
2525 let msg = err_body["error"]["message"]
2526 .as_str()
2527 .unwrap_or("unknown error");
2528 return Err(PunchError::Provider {
2529 provider: "azure_openai".to_string(),
2530 message: format!("API error ({}): {}", status, msg),
2531 });
2532 }
2533
2534 let raw = read_stream_body(response).await?;
2535 let assembled = self.inner.parse_openai_stream(&raw, &callback)?;
2537 Ok(assembled)
2538 }
2539}
2540
2541fn default_base_url(provider: &Provider) -> &'static str {
2547 match provider {
2548 Provider::Anthropic => "https://api.anthropic.com",
2549 Provider::OpenAI => "https://api.openai.com",
2550 Provider::Google => "https://generativelanguage.googleapis.com",
2551 Provider::Groq => "https://api.groq.com/openai",
2552 Provider::DeepSeek => "https://api.deepseek.com",
2553 Provider::Ollama => "http://localhost:11434",
2554 Provider::Mistral => "https://api.mistral.ai",
2555 Provider::Together => "https://api.together.xyz",
2556 Provider::Fireworks => "https://api.fireworks.ai/inference",
2557 Provider::Cerebras => "https://api.cerebras.ai",
2558 Provider::XAI => "https://api.x.ai",
2559 Provider::Cohere => "https://api.cohere.ai",
2560 Provider::Bedrock => "https://bedrock-runtime.us-east-1.amazonaws.com",
2561 Provider::AzureOpenAi => "",
2562 Provider::Custom(_) => "",
2563 }
2564}
2565
2566pub fn create_driver(config: &ModelConfig) -> PunchResult<Arc<dyn LlmDriver>> {
2576 create_driver_with_client(config, None)
2577}
2578
2579pub fn create_driver_with_client(
2581 config: &ModelConfig,
2582 shared_client: Option<&Client>,
2583) -> PunchResult<Arc<dyn LlmDriver>> {
2584 let api_key = match &config.api_key_env {
2585 Some(env_var) => std::env::var(env_var).map_err(|_| {
2586 PunchError::Auth(format!(
2587 "environment variable '{}' not set for {} driver",
2588 env_var, config.provider
2589 ))
2590 })?,
2591 None => {
2592 String::new()
2594 }
2595 };
2596
2597 let base_url = config
2598 .base_url
2599 .clone()
2600 .unwrap_or_else(|| default_base_url(&config.provider).to_string());
2601
2602 match &config.provider {
2603 Provider::Anthropic => {
2604 if let Some(client) = shared_client {
2605 Ok(Arc::new(AnthropicDriver::with_client(
2606 client.clone(),
2607 api_key,
2608 Some(base_url),
2609 )))
2610 } else {
2611 Ok(Arc::new(AnthropicDriver::new(api_key, Some(base_url))))
2612 }
2613 }
2614 Provider::Google => {
2615 if let Some(client) = shared_client {
2616 Ok(Arc::new(GeminiDriver::with_client(
2617 client.clone(),
2618 api_key,
2619 Some(base_url),
2620 )))
2621 } else {
2622 Ok(Arc::new(GeminiDriver::new(api_key, Some(base_url))))
2623 }
2624 }
2625 Provider::Ollama => {
2626 if let Some(client) = shared_client {
2627 Ok(Arc::new(OllamaDriver::with_client(
2628 client.clone(),
2629 Some(base_url),
2630 )))
2631 } else {
2632 Ok(Arc::new(OllamaDriver::new(Some(base_url))))
2633 }
2634 }
2635 Provider::Bedrock => {
2636 let (access_key, secret_key) = if api_key.contains(':') {
2639 let parts: Vec<&str> = api_key.splitn(2, ':').collect();
2640 (parts[0].to_string(), parts[1].to_string())
2641 } else {
2642 let ak = std::env::var("AWS_ACCESS_KEY_ID").unwrap_or(api_key);
2643 let sk = std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_default();
2644 (ak, sk)
2645 };
2646 let region = if base_url.contains("bedrock-runtime.") {
2648 base_url
2649 .trim_start_matches("https://bedrock-runtime.")
2650 .split('.')
2651 .next()
2652 .unwrap_or("us-east-1")
2653 .to_string()
2654 } else {
2655 "us-east-1".to_string()
2656 };
2657 if let Some(client) = shared_client {
2658 Ok(Arc::new(BedrockDriver::with_client(
2659 client.clone(),
2660 access_key,
2661 secret_key,
2662 region,
2663 )))
2664 } else {
2665 Ok(Arc::new(BedrockDriver::new(access_key, secret_key, region)))
2666 }
2667 }
2668 Provider::AzureOpenAi => {
2669 let resource = if base_url.contains(".openai.azure.com") {
2672 base_url
2673 .trim_start_matches("https://")
2674 .split('.')
2675 .next()
2676 .unwrap_or("default")
2677 .to_string()
2678 } else {
2679 base_url.clone()
2680 };
2681 let deployment = config.model.clone();
2682 if let Some(client) = shared_client {
2683 Ok(Arc::new(AzureOpenAiDriver::with_client(
2684 client.clone(),
2685 api_key,
2686 resource,
2687 deployment,
2688 None,
2689 )))
2690 } else {
2691 Ok(Arc::new(AzureOpenAiDriver::new(
2692 api_key, resource, deployment, None,
2693 )))
2694 }
2695 }
2696 provider => {
2697 let name = provider.to_string();
2698 if let Some(client) = shared_client {
2699 Ok(Arc::new(OpenAiCompatibleDriver::with_client(
2700 client.clone(),
2701 api_key,
2702 base_url,
2703 name,
2704 )))
2705 } else {
2706 Ok(Arc::new(OpenAiCompatibleDriver::new(
2707 api_key, base_url, name,
2708 )))
2709 }
2710 }
2711 }
2712}
2713
2714#[cfg(test)]
2719mod tests {
2720 use super::*;
2721 use punch_types::ToolCategory;
2722
2723 fn simple_request() -> CompletionRequest {
2725 CompletionRequest {
2726 model: "test-model".to_string(),
2727 messages: vec![Message::new(Role::User, "Hello")],
2728 tools: Vec::new(),
2729 max_tokens: 4096,
2730 temperature: Some(0.7),
2731 system_prompt: Some("You are helpful.".to_string()),
2732 }
2733 }
2734
2735 fn request_with_tools() -> CompletionRequest {
2737 CompletionRequest {
2738 model: "test-model".to_string(),
2739 messages: vec![Message::new(Role::User, "Use the tool")],
2740 tools: vec![ToolDefinition {
2741 name: "get_weather".to_string(),
2742 description: "Get weather for a city".to_string(),
2743 input_schema: serde_json::json!({
2744 "type": "object",
2745 "properties": {
2746 "city": {"type": "string"}
2747 }
2748 }),
2749 category: ToolCategory::Web,
2750 }],
2751 max_tokens: 4096,
2752 temperature: Some(0.7),
2753 system_prompt: None,
2754 }
2755 }
2756
2757 #[test]
2762 fn gemini_request_formatting() {
2763 let driver = GeminiDriver::new("test-key".to_string(), None);
2764 let body = driver.build_request_body(&simple_request());
2765
2766 let contents = body["contents"].as_array().unwrap();
2767 assert_eq!(contents.len(), 1);
2768 let first_text = contents[0]["parts"][0]["text"].as_str().unwrap();
2770 assert!(first_text.contains("You are helpful."));
2771 assert!(first_text.contains("Hello"));
2772 assert_eq!(contents[0]["role"].as_str().unwrap(), "user");
2774
2775 assert_eq!(body["generationConfig"]["maxOutputTokens"], 4096);
2776 assert!((body["generationConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
2777 }
2778
2779 #[test]
2780 fn gemini_response_parsing() {
2781 let driver = GeminiDriver::new("test-key".to_string(), None);
2782 let response_body = serde_json::json!({
2783 "candidates": [{
2784 "content": {
2785 "parts": [{"text": "Hello there!"}],
2786 "role": "model"
2787 },
2788 "finishReason": "STOP"
2789 }],
2790 "usageMetadata": {
2791 "promptTokenCount": 10,
2792 "candidatesTokenCount": 5
2793 }
2794 });
2795
2796 let resp = driver.parse_response(&response_body).unwrap();
2797 assert_eq!(resp.message.content, "Hello there!");
2798 assert_eq!(resp.stop_reason, StopReason::EndTurn);
2799 assert_eq!(resp.usage.input_tokens, 10);
2800 assert_eq!(resp.usage.output_tokens, 5);
2801 }
2802
2803 #[test]
2804 fn gemini_role_mapping_system_prepended() {
2805 let driver = GeminiDriver::new("test-key".to_string(), None);
2806 let req = CompletionRequest {
2807 model: "gemini-pro".to_string(),
2808 messages: vec![
2809 Message::new(Role::System, "Be concise."),
2810 Message::new(Role::User, "Hi"),
2811 ],
2812 tools: Vec::new(),
2813 max_tokens: 1024,
2814 temperature: None,
2815 system_prompt: None,
2816 };
2817 let body = driver.build_request_body(&req);
2818 let contents = body["contents"].as_array().unwrap();
2819 assert_eq!(contents.len(), 1);
2821 let text = contents[0]["parts"][0]["text"].as_str().unwrap();
2822 assert!(text.contains("Be concise."));
2823 assert!(text.contains("Hi"));
2824 }
2825
2826 #[test]
2827 fn gemini_function_call_parsing() {
2828 let driver = GeminiDriver::new("test-key".to_string(), None);
2829 let response_body = serde_json::json!({
2830 "candidates": [{
2831 "content": {
2832 "parts": [
2833 {"text": "Let me check the weather."},
2834 {
2835 "functionCall": {
2836 "name": "get_weather",
2837 "args": {"city": "London"}
2838 }
2839 }
2840 ],
2841 "role": "model"
2842 },
2843 "finishReason": "STOP"
2844 }],
2845 "usageMetadata": {
2846 "promptTokenCount": 15,
2847 "candidatesTokenCount": 8
2848 }
2849 });
2850
2851 let resp = driver.parse_response(&response_body).unwrap();
2852 assert_eq!(resp.message.content, "Let me check the weather.");
2853 assert_eq!(resp.stop_reason, StopReason::ToolUse);
2854 assert_eq!(resp.message.tool_calls.len(), 1);
2855 assert_eq!(resp.message.tool_calls[0].name, "get_weather");
2856 assert_eq!(resp.message.tool_calls[0].input["city"], "London");
2857 }
2858
2859 #[test]
2860 fn gemini_api_key_in_url() {
2861 let driver = GeminiDriver::new("my-secret-key".to_string(), None);
2862 let url = driver.build_url("gemini-pro");
2863 assert!(url.contains("key=my-secret-key"));
2864 assert!(url.contains("models/gemini-pro:generateContent"));
2865 }
2866
2867 #[test]
2872 fn ollama_request_formatting() {
2873 let driver = OllamaDriver::new(None);
2874 let body = driver.build_request_body(&simple_request());
2875
2876 assert_eq!(body["model"], "test-model");
2877 assert_eq!(body["stream"], false);
2878 let messages = body["messages"].as_array().unwrap();
2879 assert_eq!(messages.len(), 2);
2881 assert_eq!(messages[0]["role"], "system");
2882 assert_eq!(messages[0]["content"], "You are helpful.");
2883 assert_eq!(messages[1]["role"], "user");
2884 assert_eq!(messages[1]["content"], "Hello");
2885 assert!((body["options"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
2886 }
2887
2888 #[test]
2889 fn ollama_response_parsing() {
2890 let driver = OllamaDriver::new(None);
2891 let response_body = serde_json::json!({
2892 "message": {
2893 "role": "assistant",
2894 "content": "Hi there!"
2895 },
2896 "done": true,
2897 "prompt_eval_count": 20,
2898 "eval_count": 10
2899 });
2900
2901 let resp = driver.parse_response(&response_body).unwrap();
2902 assert_eq!(resp.message.content, "Hi there!");
2903 assert_eq!(resp.stop_reason, StopReason::EndTurn);
2904 assert_eq!(resp.usage.input_tokens, 20);
2905 assert_eq!(resp.usage.output_tokens, 10);
2906 }
2907
2908 #[test]
2909 fn ollama_default_endpoint() {
2910 let driver = OllamaDriver::new(None);
2911 assert_eq!(driver.base_url(), "http://localhost:11434");
2912 }
2913
2914 #[test]
2915 fn ollama_custom_endpoint() {
2916 let driver = OllamaDriver::new(Some("http://myhost:9999".to_string()));
2917 assert_eq!(driver.base_url(), "http://myhost:9999");
2918 }
2919
2920 #[test]
2925 fn bedrock_request_formatting() {
2926 let driver = BedrockDriver::new(
2927 "TESTKEY".to_string(),
2928 "testsecret".to_string(),
2929 "us-west-2".to_string(),
2930 );
2931 let body = driver.build_request_body(&simple_request());
2932
2933 let messages = body["messages"].as_array().unwrap();
2934 assert_eq!(messages.len(), 1);
2935 assert_eq!(messages[0]["role"], "user");
2936 assert_eq!(messages[0]["content"][0]["text"], "Hello");
2937
2938 assert_eq!(body["inferenceConfig"]["maxTokens"], 4096);
2939 assert!((body["inferenceConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
2940 assert_eq!(body["system"][0]["text"], "You are helpful.");
2941 }
2942
2943 #[test]
2944 fn bedrock_sigv4_canonical_request() {
2945 let driver = BedrockDriver::new(
2946 "TESTACCESS1234567890".to_string(),
2947 "TestSecretKeyValue1234567890abcdefghijk".to_string(),
2948 "us-east-1".to_string(),
2949 );
2950
2951 let payload = b"{}";
2952 let timestamp = "20260313T120000Z";
2953
2954 let auth = driver
2955 .sign_request(
2956 "POST",
2957 "https://bedrock-runtime.us-east-1.amazonaws.com/model/test/converse",
2958 &[("content-type".to_string(), "application/json".to_string())],
2959 payload,
2960 timestamp,
2961 )
2962 .unwrap();
2963
2964 assert!(auth.starts_with(
2965 "AWS4-HMAC-SHA256 Credential=TESTACCESS1234567890/20260313/us-east-1/bedrock/aws4_request"
2966 ));
2967 assert!(auth.contains("SignedHeaders=content-type;host;x-amz-date"));
2968 assert!(auth.contains("Signature="));
2969 }
2970
2971 #[test]
2972 fn bedrock_response_parsing() {
2973 let driver = BedrockDriver::new(
2974 "key".to_string(),
2975 "secret".to_string(),
2976 "us-east-1".to_string(),
2977 );
2978 let response_body = serde_json::json!({
2979 "output": {
2980 "message": {
2981 "role": "assistant",
2982 "content": [{"text": "The answer is 42."}]
2983 }
2984 },
2985 "stopReason": "end_turn",
2986 "usage": {
2987 "inputTokens": 100,
2988 "outputTokens": 50
2989 }
2990 });
2991
2992 let resp = driver.parse_response(&response_body).unwrap();
2993 assert_eq!(resp.message.content, "The answer is 42.");
2994 assert_eq!(resp.stop_reason, StopReason::EndTurn);
2995 assert_eq!(resp.usage.input_tokens, 100);
2996 assert_eq!(resp.usage.output_tokens, 50);
2997 }
2998
2999 #[test]
3004 fn azure_openai_url_construction() {
3005 let driver = AzureOpenAiDriver::new(
3006 "my-azure-key".to_string(),
3007 "myresource".to_string(),
3008 "gpt-4-deployment".to_string(),
3009 None,
3010 );
3011 let url = driver.build_url();
3012 assert_eq!(
3013 url,
3014 "https://myresource.openai.azure.com/openai/deployments/gpt-4-deployment/chat/completions?api-version=2024-02-01"
3015 );
3016 }
3017
3018 #[test]
3019 fn azure_openai_custom_api_version() {
3020 let driver = AzureOpenAiDriver::new(
3021 "key".to_string(),
3022 "res".to_string(),
3023 "dep".to_string(),
3024 Some("2024-06-01".to_string()),
3025 );
3026 let url = driver.build_url();
3027 assert!(url.contains("api-version=2024-06-01"));
3028 }
3029
3030 #[test]
3031 fn azure_openai_request_formatting() {
3032 let driver = AzureOpenAiDriver::new(
3033 "key".to_string(),
3034 "res".to_string(),
3035 "dep".to_string(),
3036 None,
3037 );
3038 let body = driver.build_request_body(&simple_request());
3039 let messages = body["messages"].as_array().unwrap();
3041 assert_eq!(messages.len(), 2);
3043 assert_eq!(messages[0]["role"], "system");
3044 assert_eq!(messages[1]["role"], "user");
3045 assert_eq!(body["model"], "test-model");
3046 }
3047
3048 #[test]
3049 fn azure_openai_resource_and_deployment() {
3050 let driver = AzureOpenAiDriver::new(
3051 "key".to_string(),
3052 "my-resource".to_string(),
3053 "my-deploy".to_string(),
3054 None,
3055 );
3056 assert_eq!(driver.resource(), "my-resource");
3057 assert_eq!(driver.deployment(), "my-deploy");
3058 }
3059
3060 #[test]
3065 fn create_driver_dispatches_ollama() {
3066 let config = ModelConfig {
3067 provider: Provider::Ollama,
3068 model: "llama3".to_string(),
3069 api_key_env: None,
3070 base_url: None,
3071 max_tokens: None,
3072 temperature: None,
3073 };
3074 let driver = create_driver(&config);
3076 assert!(driver.is_ok());
3077 }
3078
3079 #[test]
3080 fn create_driver_dispatches_gemini() {
3081 unsafe { std::env::set_var("TEST_GEMINI_KEY_DISPATCH", "fake-key") };
3084 let config = ModelConfig {
3085 provider: Provider::Google,
3086 model: "gemini-pro".to_string(),
3087 api_key_env: Some("TEST_GEMINI_KEY_DISPATCH".to_string()),
3088 base_url: None,
3089 max_tokens: None,
3090 temperature: None,
3091 };
3092 let driver = create_driver(&config);
3093 assert!(driver.is_ok());
3094 unsafe { std::env::remove_var("TEST_GEMINI_KEY_DISPATCH") };
3095 }
3096
3097 #[test]
3098 fn create_driver_dispatches_bedrock() {
3099 unsafe { std::env::set_var("TEST_BEDROCK_KEY_DISPATCH", "TESTKEY:TESTSECRET") };
3101 let config = ModelConfig {
3102 provider: Provider::Bedrock,
3103 model: "anthropic.claude-v2".to_string(),
3104 api_key_env: Some("TEST_BEDROCK_KEY_DISPATCH".to_string()),
3105 base_url: None,
3106 max_tokens: None,
3107 temperature: None,
3108 };
3109 let driver = create_driver(&config);
3110 assert!(driver.is_ok());
3111 unsafe { std::env::remove_var("TEST_BEDROCK_KEY_DISPATCH") };
3112 }
3113
3114 #[test]
3115 fn create_driver_dispatches_azure_openai() {
3116 unsafe { std::env::set_var("TEST_AZURE_KEY_DISPATCH", "azure-key") };
3118 let config = ModelConfig {
3119 provider: Provider::AzureOpenAi,
3120 model: "gpt-4".to_string(),
3121 api_key_env: Some("TEST_AZURE_KEY_DISPATCH".to_string()),
3122 base_url: Some("https://myres.openai.azure.com".to_string()),
3123 max_tokens: None,
3124 temperature: None,
3125 };
3126 let driver = create_driver(&config);
3127 assert!(driver.is_ok());
3128 unsafe { std::env::remove_var("TEST_AZURE_KEY_DISPATCH") };
3129 }
3130
3131 #[test]
3132 fn gemini_tools_in_request() {
3133 let driver = GeminiDriver::new("key".to_string(), None);
3134 let body = driver.build_request_body(&request_with_tools());
3135
3136 let tools = body["tools"].as_array().unwrap();
3137 assert_eq!(tools.len(), 1);
3138 let func_decls = tools[0]["function_declarations"].as_array().unwrap();
3139 assert_eq!(func_decls.len(), 1);
3140 assert_eq!(func_decls[0]["name"], "get_weather");
3141 }
3142
3143 #[test]
3144 fn ollama_tools_in_request() {
3145 let driver = OllamaDriver::new(None);
3146 let body = driver.build_request_body(&request_with_tools());
3147
3148 let tools = body["tools"].as_array().unwrap();
3149 assert_eq!(tools.len(), 1);
3150 assert_eq!(tools[0]["type"], "function");
3151 assert_eq!(tools[0]["function"]["name"], "get_weather");
3152 }
3153
3154 #[test]
3155 fn bedrock_url_construction() {
3156 let driver = BedrockDriver::new(
3157 "key".to_string(),
3158 "secret".to_string(),
3159 "eu-west-1".to_string(),
3160 );
3161 let url = driver.build_url("anthropic.claude-3-sonnet");
3162 assert_eq!(
3163 url,
3164 "https://bedrock-runtime.eu-west-1.amazonaws.com/model/anthropic.claude-3-sonnet/converse"
3165 );
3166 }
3167
3168 #[test]
3173 fn token_usage_default() {
3174 let u = TokenUsage::default();
3175 assert_eq!(u.input_tokens, 0);
3176 assert_eq!(u.output_tokens, 0);
3177 assert_eq!(u.total(), 0);
3178 }
3179
3180 #[test]
3181 fn token_usage_accumulate() {
3182 let mut u = TokenUsage {
3183 input_tokens: 10,
3184 output_tokens: 20,
3185 };
3186 let other = TokenUsage {
3187 input_tokens: 5,
3188 output_tokens: 15,
3189 };
3190 u.accumulate(&other);
3191 assert_eq!(u.input_tokens, 15);
3192 assert_eq!(u.output_tokens, 35);
3193 assert_eq!(u.total(), 50);
3194 }
3195
3196 #[test]
3197 fn token_usage_total() {
3198 let u = TokenUsage {
3199 input_tokens: 100,
3200 output_tokens: 200,
3201 };
3202 assert_eq!(u.total(), 300);
3203 }
3204
3205 #[test]
3210 fn stop_reason_serialization() {
3211 let json = serde_json::to_string(&StopReason::EndTurn).unwrap();
3212 assert_eq!(json, "\"end_turn\"");
3213
3214 let json = serde_json::to_string(&StopReason::ToolUse).unwrap();
3215 assert_eq!(json, "\"tool_use\"");
3216
3217 let json = serde_json::to_string(&StopReason::MaxTokens).unwrap();
3218 assert_eq!(json, "\"max_tokens\"");
3219
3220 let json = serde_json::to_string(&StopReason::Error).unwrap();
3221 assert_eq!(json, "\"error\"");
3222 }
3223
3224 #[test]
3225 fn stop_reason_deserialization() {
3226 let sr: StopReason = serde_json::from_str("\"end_turn\"").unwrap();
3227 assert_eq!(sr, StopReason::EndTurn);
3228
3229 let sr: StopReason = serde_json::from_str("\"tool_use\"").unwrap();
3230 assert_eq!(sr, StopReason::ToolUse);
3231 }
3232
3233 #[test]
3238 fn anthropic_request_body_simple() {
3239 let driver = AnthropicDriver::new("test-key".to_string(), None);
3240 let body = driver.build_request_body(&simple_request());
3241
3242 assert_eq!(body["model"], "test-model");
3243 assert_eq!(body["max_tokens"], 4096);
3244 assert_eq!(body["system"], "You are helpful.");
3245 assert!((body["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
3246
3247 let messages = body["messages"].as_array().unwrap();
3248 assert_eq!(messages.len(), 1);
3249 assert_eq!(messages[0]["role"], "user");
3250 assert_eq!(messages[0]["content"], "Hello");
3251 }
3252
3253 #[test]
3254 fn anthropic_request_body_with_tools() {
3255 let driver = AnthropicDriver::new("test-key".to_string(), None);
3256 let body = driver.build_request_body(&request_with_tools());
3257
3258 let tools = body["tools"].as_array().unwrap();
3259 assert_eq!(tools.len(), 1);
3260 assert_eq!(tools[0]["name"], "get_weather");
3261 assert!(tools[0]["input_schema"]["properties"].is_object());
3262 }
3263
3264 #[test]
3265 fn anthropic_request_body_no_system_prompt() {
3266 let driver = AnthropicDriver::new("test-key".to_string(), None);
3267 let req = CompletionRequest {
3268 model: "test".into(),
3269 messages: vec![Message::new(Role::User, "Hi")],
3270 tools: Vec::new(),
3271 max_tokens: 100,
3272 temperature: None,
3273 system_prompt: None,
3274 };
3275 let body = driver.build_request_body(&req);
3276 assert!(body.get("system").is_none());
3277 assert!(body.get("temperature").is_none());
3278 }
3279
3280 #[test]
3281 fn anthropic_parse_response_text() {
3282 let driver = AnthropicDriver::new("test-key".to_string(), None);
3283 let response_body = serde_json::json!({
3284 "content": [
3285 {"type": "text", "text": "Hello!"}
3286 ],
3287 "stop_reason": "end_turn",
3288 "usage": {
3289 "input_tokens": 10,
3290 "output_tokens": 5
3291 }
3292 });
3293
3294 let resp = driver.parse_response(&response_body).unwrap();
3295 assert_eq!(resp.message.content, "Hello!");
3296 assert_eq!(resp.stop_reason, StopReason::EndTurn);
3297 assert_eq!(resp.usage.input_tokens, 10);
3298 assert_eq!(resp.usage.output_tokens, 5);
3299 assert!(resp.message.tool_calls.is_empty());
3300 }
3301
3302 #[test]
3303 fn anthropic_parse_response_tool_use() {
3304 let driver = AnthropicDriver::new("test-key".to_string(), None);
3305 let response_body = serde_json::json!({
3306 "content": [
3307 {"type": "text", "text": "Let me check."},
3308 {
3309 "type": "tool_use",
3310 "id": "tool_abc",
3311 "name": "get_weather",
3312 "input": {"city": "NYC"}
3313 }
3314 ],
3315 "stop_reason": "tool_use",
3316 "usage": {"input_tokens": 20, "output_tokens": 15}
3317 });
3318
3319 let resp = driver.parse_response(&response_body).unwrap();
3320 assert_eq!(resp.message.content, "Let me check.");
3321 assert_eq!(resp.stop_reason, StopReason::ToolUse);
3322 assert_eq!(resp.message.tool_calls.len(), 1);
3323 assert_eq!(resp.message.tool_calls[0].id, "tool_abc");
3324 assert_eq!(resp.message.tool_calls[0].name, "get_weather");
3325 assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
3326 }
3327
3328 #[test]
3329 fn anthropic_parse_response_max_tokens() {
3330 let driver = AnthropicDriver::new("test-key".to_string(), None);
3331 let response_body = serde_json::json!({
3332 "content": [{"type": "text", "text": "truncated"}],
3333 "stop_reason": "max_tokens",
3334 "usage": {"input_tokens": 5, "output_tokens": 100}
3335 });
3336
3337 let resp = driver.parse_response(&response_body).unwrap();
3338 assert_eq!(resp.stop_reason, StopReason::MaxTokens);
3339 }
3340
3341 #[test]
3342 fn anthropic_parse_response_unknown_stop_reason() {
3343 let driver = AnthropicDriver::new("test-key".to_string(), None);
3344 let response_body = serde_json::json!({
3345 "content": [{"type": "text", "text": "err"}],
3346 "stop_reason": "something_unknown",
3347 "usage": {"input_tokens": 0, "output_tokens": 0}
3348 });
3349
3350 let resp = driver.parse_response(&response_body).unwrap();
3351 assert_eq!(resp.stop_reason, StopReason::Error);
3352 }
3353
3354 #[test]
3355 fn anthropic_request_body_with_assistant_and_tool_messages() {
3356 let driver = AnthropicDriver::new("test-key".to_string(), None);
3357 let req = CompletionRequest {
3358 model: "test".into(),
3359 messages: vec![
3360 Message::new(Role::User, "Hi"),
3361 Message {
3362 role: Role::Assistant,
3363 content: "I'll check".into(),
3364 tool_calls: vec![ToolCall {
3365 id: "call_1".into(),
3366 name: "file_read".into(),
3367 input: serde_json::json!({"path": "/tmp/test"}),
3368 }],
3369 tool_results: Vec::new(),
3370 timestamp: chrono::Utc::now(),
3371 },
3372 Message {
3373 role: Role::Tool,
3374 content: String::new(),
3375 tool_calls: Vec::new(),
3376 tool_results: vec![punch_types::ToolCallResult {
3377 id: "call_1".into(),
3378 content: "file contents".into(),
3379 is_error: false,
3380 }],
3381 timestamp: chrono::Utc::now(),
3382 },
3383 ],
3384 tools: Vec::new(),
3385 max_tokens: 100,
3386 temperature: None,
3387 system_prompt: None,
3388 };
3389
3390 let body = driver.build_request_body(&req);
3391 let messages = body["messages"].as_array().unwrap();
3392 assert_eq!(messages.len(), 3);
3393 assert_eq!(messages[0]["role"], "user");
3394 assert_eq!(messages[1]["role"], "assistant");
3395 assert_eq!(messages[2]["role"], "user"); }
3397
3398 #[test]
3399 fn anthropic_request_body_system_message_skipped() {
3400 let driver = AnthropicDriver::new("test-key".to_string(), None);
3401 let req = CompletionRequest {
3402 model: "test".into(),
3403 messages: vec![
3404 Message::new(Role::System, "System instruction"),
3405 Message::new(Role::User, "Hi"),
3406 ],
3407 tools: Vec::new(),
3408 max_tokens: 100,
3409 temperature: None,
3410 system_prompt: None,
3411 };
3412
3413 let body = driver.build_request_body(&req);
3414 let messages = body["messages"].as_array().unwrap();
3415 assert_eq!(messages.len(), 1);
3417 assert_eq!(messages[0]["role"], "user");
3418 }
3419
3420 #[test]
3425 fn openai_request_body_simple() {
3426 let driver = OpenAiCompatibleDriver::new(
3427 "key".into(),
3428 "https://api.openai.com".into(),
3429 "openai".into(),
3430 );
3431 let body = driver.build_request_body(&simple_request());
3432
3433 assert_eq!(body["model"], "test-model");
3434 let messages = body["messages"].as_array().unwrap();
3435 assert_eq!(messages.len(), 2);
3436 assert_eq!(messages[0]["role"], "system");
3437 assert_eq!(messages[0]["content"], "You are helpful.");
3438 assert_eq!(messages[1]["role"], "user");
3439 }
3440
3441 #[test]
3442 fn openai_request_body_with_tools() {
3443 let driver = OpenAiCompatibleDriver::new(
3444 "key".into(),
3445 "https://api.openai.com".into(),
3446 "openai".into(),
3447 );
3448 let body = driver.build_request_body(&request_with_tools());
3449
3450 let tools = body["tools"].as_array().unwrap();
3451 assert_eq!(tools.len(), 1);
3452 assert_eq!(tools[0]["type"], "function");
3453 assert_eq!(tools[0]["function"]["name"], "get_weather");
3454 }
3455
3456 #[test]
3457 fn openai_parse_response_text() {
3458 let driver = OpenAiCompatibleDriver::new(
3459 "key".into(),
3460 "https://api.openai.com".into(),
3461 "openai".into(),
3462 );
3463 let response_body = serde_json::json!({
3464 "choices": [{
3465 "message": {
3466 "role": "assistant",
3467 "content": "Hello!"
3468 },
3469 "finish_reason": "stop"
3470 }],
3471 "usage": {
3472 "prompt_tokens": 10,
3473 "completion_tokens": 5
3474 }
3475 });
3476
3477 let resp = driver.parse_response(&response_body).unwrap();
3478 assert_eq!(resp.message.content, "Hello!");
3479 assert_eq!(resp.stop_reason, StopReason::EndTurn);
3480 assert_eq!(resp.usage.input_tokens, 10);
3481 assert_eq!(resp.usage.output_tokens, 5);
3482 }
3483
3484 #[test]
3485 fn openai_parse_response_tool_calls() {
3486 let driver = OpenAiCompatibleDriver::new(
3487 "key".into(),
3488 "https://api.openai.com".into(),
3489 "openai".into(),
3490 );
3491 let response_body = serde_json::json!({
3492 "choices": [{
3493 "message": {
3494 "role": "assistant",
3495 "content": null,
3496 "tool_calls": [{
3497 "id": "call_123",
3498 "type": "function",
3499 "function": {
3500 "name": "get_weather",
3501 "arguments": "{\"city\": \"NYC\"}"
3502 }
3503 }]
3504 },
3505 "finish_reason": "tool_calls"
3506 }],
3507 "usage": {"prompt_tokens": 10, "completion_tokens": 5}
3508 });
3509
3510 let resp = driver.parse_response(&response_body).unwrap();
3511 assert_eq!(resp.stop_reason, StopReason::ToolUse);
3512 assert_eq!(resp.message.tool_calls.len(), 1);
3513 assert_eq!(resp.message.tool_calls[0].name, "get_weather");
3514 assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");
3515 }
3516
3517 #[test]
3518 fn openai_parse_response_tool_calls_fix_stop_reason() {
3519 let driver = OpenAiCompatibleDriver::new(
3520 "key".into(),
3521 "https://api.openai.com".into(),
3522 "openai".into(),
3523 );
3524 let response_body = serde_json::json!({
3526 "choices": [{
3527 "message": {
3528 "role": "assistant",
3529 "content": "Using tool",
3530 "tool_calls": [{
3531 "id": "call_1",
3532 "type": "function",
3533 "function": {
3534 "name": "test_tool",
3535 "arguments": "{}"
3536 }
3537 }]
3538 },
3539 "finish_reason": "stop"
3540 }],
3541 "usage": {"prompt_tokens": 0, "completion_tokens": 0}
3542 });
3543
3544 let resp = driver.parse_response(&response_body).unwrap();
3545 assert_eq!(resp.stop_reason, StopReason::ToolUse);
3546 }
3547
3548 #[test]
3549 fn openai_parse_response_length_stop_reason() {
3550 let driver = OpenAiCompatibleDriver::new(
3551 "key".into(),
3552 "https://api.openai.com".into(),
3553 "openai".into(),
3554 );
3555 let response_body = serde_json::json!({
3556 "choices": [{
3557 "message": {"role": "assistant", "content": "cut off"},
3558 "finish_reason": "length"
3559 }],
3560 "usage": {"prompt_tokens": 0, "completion_tokens": 0}
3561 });
3562
3563 let resp = driver.parse_response(&response_body).unwrap();
3564 assert_eq!(resp.stop_reason, StopReason::MaxTokens);
3565 }
3566
3567 #[test]
3568 fn openai_parse_response_no_choices_error() {
3569 let driver = OpenAiCompatibleDriver::new(
3570 "key".into(),
3571 "https://api.openai.com".into(),
3572 "openai".into(),
3573 );
3574 let response_body = serde_json::json!({"choices": []});
3575
3576 let result = driver.parse_response(&response_body);
3577 assert!(result.is_err());
3578 }
3579
3580 #[test]
3585 fn gemini_assistant_message_formatting() {
3586 let driver = GeminiDriver::new("key".to_string(), None);
3587 let req = CompletionRequest {
3588 model: "gemini-pro".into(),
3589 messages: vec![
3590 Message::new(Role::User, "Hi"),
3591 Message {
3592 role: Role::Assistant,
3593 content: "Let me help".into(),
3594 tool_calls: vec![ToolCall {
3595 id: "tc1".into(),
3596 name: "get_weather".into(),
3597 input: serde_json::json!({"city": "NYC"}),
3598 }],
3599 tool_results: Vec::new(),
3600 timestamp: chrono::Utc::now(),
3601 },
3602 ],
3603 tools: Vec::new(),
3604 max_tokens: 100,
3605 temperature: None,
3606 system_prompt: None,
3607 };
3608
3609 let body = driver.build_request_body(&req);
3610 let contents = body["contents"].as_array().unwrap();
3611 assert_eq!(contents[1]["role"], "model"); let parts = contents[1]["parts"].as_array().unwrap();
3613 assert!(parts.len() >= 2); }
3615
3616 #[test]
3617 fn gemini_max_tokens_stop_reason() {
3618 let driver = GeminiDriver::new("key".to_string(), None);
3619 let response_body = serde_json::json!({
3620 "candidates": [{
3621 "content": {
3622 "parts": [{"text": "truncated"}],
3623 "role": "model"
3624 },
3625 "finishReason": "MAX_TOKENS"
3626 }],
3627 "usageMetadata": {"promptTokenCount": 0, "candidatesTokenCount": 0}
3628 });
3629
3630 let resp = driver.parse_response(&response_body).unwrap();
3631 assert_eq!(resp.stop_reason, StopReason::MaxTokens);
3632 }
3633
3634 #[test]
3635 fn gemini_custom_base_url() {
3636 let driver =
3637 GeminiDriver::new("key".to_string(), Some("https://custom.example.com".into()));
3638 let url = driver.build_url("gemini-pro");
3639 assert!(url.starts_with("https://custom.example.com/"));
3640 }
3641
3642 #[test]
3647 fn ollama_response_with_tool_calls() {
3648 let driver = OllamaDriver::new(None);
3649 let response_body = serde_json::json!({
3650 "message": {
3651 "role": "assistant",
3652 "content": "",
3653 "tool_calls": [{
3654 "function": {
3655 "name": "get_weather",
3656 "arguments": {"city": "London"}
3657 }
3658 }]
3659 },
3660 "done": true,
3661 "prompt_eval_count": 10,
3662 "eval_count": 5
3663 });
3664
3665 let resp = driver.parse_response(&response_body).unwrap();
3666 assert_eq!(resp.stop_reason, StopReason::ToolUse);
3667 assert_eq!(resp.message.tool_calls.len(), 1);
3668 assert_eq!(resp.message.tool_calls[0].name, "get_weather");
3669 }
3670
3671 #[test]
3672 fn ollama_response_not_done() {
3673 let driver = OllamaDriver::new(None);
3674 let response_body = serde_json::json!({
3675 "message": {"role": "assistant", "content": "partial"},
3676 "done": false,
3677 "prompt_eval_count": 10,
3678 "eval_count": 5
3679 });
3680
3681 let resp = driver.parse_response(&response_body).unwrap();
3682 assert_eq!(resp.stop_reason, StopReason::MaxTokens);
3683 }
3684
3685 #[test]
3690 fn bedrock_request_with_tools() {
3691 let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
3692 let body = driver.build_request_body(&request_with_tools());
3693
3694 let tool_config = &body["toolConfig"]["tools"];
3695 assert!(tool_config.is_array());
3696 let tools = tool_config.as_array().unwrap();
3697 assert_eq!(tools.len(), 1);
3698 assert_eq!(tools[0]["toolSpec"]["name"], "get_weather");
3699 }
3700
3701 #[test]
3702 fn bedrock_response_with_tool_use() {
3703 let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
3704 let response_body = serde_json::json!({
3705 "output": {
3706 "message": {
3707 "role": "assistant",
3708 "content": [
3709 {"text": "Using tool"},
3710 {"toolUse": {
3711 "toolUseId": "tu_123",
3712 "name": "get_weather",
3713 "input": {"city": "NYC"}
3714 }}
3715 ]
3716 }
3717 },
3718 "stopReason": "tool_use",
3719 "usage": {"inputTokens": 10, "outputTokens": 20}
3720 });
3721
3722 let resp = driver.parse_response(&response_body).unwrap();
3723 assert_eq!(resp.stop_reason, StopReason::ToolUse);
3724 assert_eq!(resp.message.tool_calls.len(), 1);
3725 assert_eq!(resp.message.tool_calls[0].id, "tu_123");
3726 assert_eq!(resp.message.tool_calls[0].name, "get_weather");
3727 }
3728
3729 #[test]
3730 fn bedrock_request_with_tool_results() {
3731 let driver = BedrockDriver::new("key".into(), "secret".into(), "us-east-1".into());
3732 let req = CompletionRequest {
3733 model: "test".into(),
3734 messages: vec![
3735 Message::new(Role::User, "Hi"),
3736 Message {
3737 role: Role::Tool,
3738 content: String::new(),
3739 tool_calls: Vec::new(),
3740 tool_results: vec![punch_types::ToolCallResult {
3741 id: "tu_1".into(),
3742 content: "result data".into(),
3743 is_error: false,
3744 }],
3745 timestamp: chrono::Utc::now(),
3746 },
3747 ],
3748 tools: Vec::new(),
3749 max_tokens: 100,
3750 temperature: None,
3751 system_prompt: None,
3752 };
3753
3754 let body = driver.build_request_body(&req);
3755 let messages = body["messages"].as_array().unwrap();
3756 assert_eq!(messages[1]["role"], "user"); let content = messages[1]["content"].as_array().unwrap();
3758 assert!(content[0]["toolResult"].is_object());
3759 assert_eq!(content[0]["toolResult"]["status"], "success");
3760 }
3761
3762 #[test]
3763 fn bedrock_url_different_regions() {
3764 let driver = BedrockDriver::new("k".into(), "s".into(), "ap-southeast-1".into());
3765 let url = driver.build_url("model-id");
3766 assert!(url.contains("ap-southeast-1"));
3767 }
3768
3769 #[test]
3774 fn azure_openai_delegates_parse_to_openai() {
3775 let driver = AzureOpenAiDriver::new("key".into(), "res".into(), "dep".into(), None);
3776 let response_body = serde_json::json!({
3777 "choices": [{
3778 "message": {"role": "assistant", "content": "Azure response"},
3779 "finish_reason": "stop"
3780 }],
3781 "usage": {"prompt_tokens": 5, "completion_tokens": 3}
3782 });
3783
3784 let resp = driver.parse_response(&response_body).unwrap();
3785 assert_eq!(resp.message.content, "Azure response");
3786 }
3787
3788 #[test]
3793 fn default_base_url_anthropic() {
3794 assert_eq!(
3795 default_base_url(&Provider::Anthropic),
3796 "https://api.anthropic.com"
3797 );
3798 }
3799
3800 #[test]
3801 fn default_base_url_openai() {
3802 assert_eq!(
3803 default_base_url(&Provider::OpenAI),
3804 "https://api.openai.com"
3805 );
3806 }
3807
3808 #[test]
3809 fn default_base_url_google() {
3810 assert_eq!(
3811 default_base_url(&Provider::Google),
3812 "https://generativelanguage.googleapis.com"
3813 );
3814 }
3815
3816 #[test]
3817 fn default_base_url_ollama() {
3818 assert_eq!(
3819 default_base_url(&Provider::Ollama),
3820 "http://localhost:11434"
3821 );
3822 }
3823
3824 #[test]
3825 fn default_base_url_groq() {
3826 assert_eq!(
3827 default_base_url(&Provider::Groq),
3828 "https://api.groq.com/openai"
3829 );
3830 }
3831
3832 #[test]
3833 fn default_base_url_deepseek() {
3834 assert_eq!(
3835 default_base_url(&Provider::DeepSeek),
3836 "https://api.deepseek.com"
3837 );
3838 }
3839
3840 #[test]
3845 fn test_hex_sha256() {
3846 let hash = hex_sha256(b"");
3847 assert_eq!(
3848 hash,
3849 "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
3850 );
3851 }
3852
3853 #[test]
3854 fn test_hex_encode() {
3855 assert_eq!(hex_encode(&[0x00, 0xff, 0x0a, 0xbc]), "00ff0abc");
3856 }
3857
3858 #[test]
3859 fn test_hmac_sha256_basic() {
3860 let result = hmac_sha256(b"key", b"data");
3861 assert!(!result.is_empty());
3862 assert_eq!(result.len(), 32); }
3864
3865 #[test]
3870 fn create_driver_missing_api_key_env() {
3871 let config = ModelConfig {
3872 provider: Provider::Anthropic,
3873 model: "claude-3".into(),
3874 api_key_env: Some("PUNCH_TEST_NONEXISTENT_KEY_XYZ".into()),
3875 base_url: None,
3876 max_tokens: None,
3877 temperature: None,
3878 };
3879 let result = create_driver(&config);
3880 assert!(result.is_err());
3881 }
3882
3883 #[test]
3884 fn create_driver_openai_compatible_fallback() {
3885 unsafe { std::env::set_var("TEST_CUSTOM_KEY_DRIVER", "fake-key") };
3887 let config = ModelConfig {
3888 provider: Provider::Custom("my-custom".into()),
3889 model: "custom-model".into(),
3890 api_key_env: Some("TEST_CUSTOM_KEY_DRIVER".into()),
3891 base_url: Some("https://custom.api.com".into()),
3892 max_tokens: None,
3893 temperature: None,
3894 };
3895 let result = create_driver(&config);
3896 assert!(result.is_ok());
3897 unsafe { std::env::remove_var("TEST_CUSTOM_KEY_DRIVER") };
3898 }
3899
3900 #[test]
3905 fn strip_thinking_tags_removes_think_block() {
3906 let input = "<think>internal reasoning here</think>The answer is 42.";
3907 assert_eq!(strip_thinking_tags(input), "The answer is 42.");
3908 }
3909
3910 #[test]
3911 fn strip_thinking_tags_removes_thinking_block() {
3912 let input = "<thinking>step by step reasoning</thinking>Hello world!";
3913 assert_eq!(strip_thinking_tags(input), "Hello world!");
3914 }
3915
3916 #[test]
3917 fn strip_thinking_tags_removes_reasoning_block() {
3918 let input = "<reasoning>let me figure this out</reasoning>The result is correct.";
3919 assert_eq!(strip_thinking_tags(input), "The result is correct.");
3920 }
3921
3922 #[test]
3923 fn strip_thinking_tags_removes_reflection_block() {
3924 let input = "<reflection>checking my work</reflection>Yes, that's right.";
3925 assert_eq!(strip_thinking_tags(input), "Yes, that's right.");
3926 }
3927
3928 #[test]
3929 fn strip_thinking_tags_removes_multiple_blocks() {
3930 let input = "<think>first thought</think>Hello <thinking>second thought</thinking>world!";
3931 assert_eq!(strip_thinking_tags(input), "Hello world!");
3932 }
3933
3934 #[test]
3935 fn strip_thinking_tags_preserves_content_without_tags() {
3936 let input = "Just a normal response with no thinking tags.";
3937 assert_eq!(strip_thinking_tags(input), input);
3938 }
3939
3940 #[test]
3941 fn strip_thinking_tags_handles_multiline_tags() {
3942 let input = "<think>\nLine 1\nLine 2\nLine 3\n</think>\nThe final answer.";
3943 assert_eq!(strip_thinking_tags(input), "The final answer.");
3944 }
3945
3946 #[test]
3947 fn strip_thinking_tags_returns_original_if_all_thinking() {
3948 let input = "<think>this is all thinking content and nothing else</think>";
3951 assert_eq!(strip_thinking_tags(input), input);
3952 }
3953
3954 #[test]
3955 fn strip_thinking_tags_handles_unclosed_tag() {
3956 let input = "Some text<think>unclosed thinking block";
3957 assert_eq!(strip_thinking_tags(input), "Some text");
3958 }
3959
3960 #[test]
3961 fn strip_thinking_tags_handles_empty_input() {
3962 assert_eq!(strip_thinking_tags(""), "");
3963 }
3964
3965 #[test]
3966 fn strip_thinking_tags_handles_empty_think_block() {
3967 let input = "<think></think>Visible content.";
3968 assert_eq!(strip_thinking_tags(input), "Visible content.");
3969 }
3970
3971 #[test]
3972 fn strip_thinking_tags_trims_whitespace() {
3973 let input = " <think>reasoning</think> Result ";
3974 assert_eq!(strip_thinking_tags(input), "Result");
3975 }
3976
3977 #[test]
3978 fn strip_thinking_tags_mixed_tag_types() {
3979 let input = "<think>t1</think>A<reasoning>r1</reasoning>B<reflection>f1</reflection>C";
3980 assert_eq!(strip_thinking_tags(input), "ABC");
3981 }
3982
3983 #[test]
3984 fn ollama_response_strips_thinking_tags() {
3985 let driver = OllamaDriver::new(None);
3986 let response_body = serde_json::json!({
3987 "message": {
3988 "role": "assistant",
3989 "content": "<think>\nLet me think about this...\nThe user wants hello world.\n</think>\nHello, world!"
3990 },
3991 "done": true,
3992 "prompt_eval_count": 20,
3993 "eval_count": 50
3994 });
3995
3996 let resp = driver.parse_response(&response_body).unwrap();
3997 assert_eq!(resp.message.content, "Hello, world!");
3998 assert!(!resp.message.content.contains("<think>"));
3999 }
4000
4001 #[test]
4002 fn gemini_response_strips_thinking_tags() {
4003 let driver = GeminiDriver::new("test-key".to_string(), None);
4004 let response_body = serde_json::json!({
4005 "candidates": [{
4006 "content": {
4007 "parts": [{"text": "<thinking>reasoning step</thinking>The answer is 7."}],
4008 "role": "model"
4009 },
4010 "finishReason": "STOP"
4011 }],
4012 "usageMetadata": {
4013 "promptTokenCount": 10,
4014 "candidatesTokenCount": 20
4015 }
4016 });
4017
4018 let resp = driver.parse_response(&response_body).unwrap();
4019 assert_eq!(resp.message.content, "The answer is 7.");
4020 assert!(!resp.message.content.contains("<thinking>"));
4021 }
4022
4023 #[test]
4024 fn anthropic_response_strips_thinking_tags() {
4025 let driver = AnthropicDriver::new("test-key".to_string(), None);
4026 let response_body = serde_json::json!({
4027 "content": [
4028 {"type": "text", "text": "<think>internal thought</think>Clean output."}
4029 ],
4030 "stop_reason": "end_turn",
4031 "usage": {"input_tokens": 10, "output_tokens": 5}
4032 });
4033
4034 let resp = driver.parse_response(&response_body).unwrap();
4035 assert_eq!(resp.message.content, "Clean output.");
4036 }
4037
4038 #[test]
4039 fn bedrock_response_strips_thinking_tags() {
4040 let driver = BedrockDriver::new(
4041 "key".to_string(),
4042 "secret".to_string(),
4043 "us-east-1".to_string(),
4044 );
4045 let response_body = serde_json::json!({
4046 "output": {
4047 "message": {
4048 "role": "assistant",
4049 "content": [{"text": "<reasoning>deep thought</reasoning>Result here."}]
4050 }
4051 },
4052 "stopReason": "end_turn",
4053 "usage": {"inputTokens": 50, "outputTokens": 25}
4054 });
4055
4056 let resp = driver.parse_response(&response_body).unwrap();
4057 assert_eq!(resp.message.content, "Result here.");
4058 }
4059
4060 #[test]
4065 fn stream_chunk_serialization_roundtrip() {
4066 let chunk = StreamChunk {
4067 delta: "Hello".to_string(),
4068 is_final: false,
4069 tool_call_delta: None,
4070 };
4071 let json = serde_json::to_string(&chunk).unwrap();
4072 let deserialized: StreamChunk = serde_json::from_str(&json).unwrap();
4073 assert_eq!(deserialized.delta, "Hello");
4074 assert!(!deserialized.is_final);
4075 assert!(deserialized.tool_call_delta.is_none());
4076 }
4077
4078 #[test]
4079 fn stream_chunk_with_tool_call_delta_serialization() {
4080 let chunk = StreamChunk {
4081 delta: String::new(),
4082 is_final: false,
4083 tool_call_delta: Some(ToolCallDelta {
4084 index: 0,
4085 id: Some("call_123".to_string()),
4086 name: Some("get_weather".to_string()),
4087 arguments_delta: "{\"city\":".to_string(),
4088 }),
4089 };
4090 let json = serde_json::to_string(&chunk).unwrap();
4091 let deserialized: StreamChunk = serde_json::from_str(&json).unwrap();
4092 let tcd = deserialized.tool_call_delta.unwrap();
4093 assert_eq!(tcd.index, 0);
4094 assert_eq!(tcd.id.unwrap(), "call_123");
4095 assert_eq!(tcd.name.unwrap(), "get_weather");
4096 assert_eq!(tcd.arguments_delta, "{\"city\":");
4097 }
4098
4099 #[test]
4100 fn stream_chunk_final_serialization() {
4101 let chunk = StreamChunk {
4102 delta: String::new(),
4103 is_final: true,
4104 tool_call_delta: None,
4105 };
4106 let json = serde_json::to_string(&chunk).unwrap();
4107 assert!(json.contains("\"is_final\":true"));
4108 }
4109
4110 #[test]
4111 fn tool_call_delta_serialization_roundtrip() {
4112 let tcd = ToolCallDelta {
4113 index: 2,
4114 id: None,
4115 name: None,
4116 arguments_delta: "\"NYC\"}".to_string(),
4117 };
4118 let json = serde_json::to_string(&tcd).unwrap();
4119 let deserialized: ToolCallDelta = serde_json::from_str(&json).unwrap();
4120 assert_eq!(deserialized.index, 2);
4121 assert!(deserialized.id.is_none());
4122 assert!(deserialized.name.is_none());
4123 assert_eq!(deserialized.arguments_delta, "\"NYC\"}");
4124 }
4125
4126 #[test]
4131 fn parse_sse_events_basic() {
4132 let raw = "event: message_start\ndata: {\"type\":\"message_start\"}\n\nevent: content_block_delta\ndata: {\"delta\":{\"text\":\"Hi\"}}\n\n";
4133 let events = parse_sse_events(raw);
4134 assert_eq!(events.len(), 2);
4135 assert_eq!(events[0].0, "message_start");
4136 assert_eq!(events[1].0, "content_block_delta");
4137 }
4138
4139 #[test]
4140 fn parse_sse_events_with_done() {
4141 let raw = "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\ndata: [DONE]\n\n";
4142 let events = parse_sse_events(raw);
4143 assert_eq!(events.len(), 2);
4144 assert_eq!(events[1].1, "[DONE]");
4145 }
4146
4147 #[test]
4148 fn parse_sse_events_empty_input() {
4149 let events = parse_sse_events("");
4150 assert!(events.is_empty());
4151 }
4152
4153 #[test]
4154 fn parse_sse_events_no_trailing_newline() {
4155 let raw = "event: test\ndata: {\"value\":1}";
4156 let events = parse_sse_events(raw);
4157 assert_eq!(events.len(), 1);
4158 assert_eq!(events[0].0, "test");
4159 }
4160
4161 #[test]
4162 fn parse_sse_events_multiline_data() {
4163 let raw = "data: line1\ndata: line2\n\n";
4164 let events = parse_sse_events(raw);
4165 assert_eq!(events.len(), 1);
4166 assert_eq!(events[0].1, "line1\nline2");
4167 }
4168
4169 #[test]
4170 fn parse_sse_events_no_event_field() {
4171 let raw = "data: {\"hello\":\"world\"}\n\n";
4172 let events = parse_sse_events(raw);
4173 assert_eq!(events.len(), 1);
4174 assert_eq!(events[0].0, "message"); }
4176
4177 #[test]
4182 fn anthropic_stream_text_only() {
4183 let raw = concat!(
4184 "event: message_start\n",
4185 "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":25}}}\n\n",
4186 "event: content_block_start\n",
4187 "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
4188 "event: content_block_delta\n",
4189 "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n",
4190 "event: content_block_delta\n",
4191 "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\" world\"}}\n\n",
4192 "event: message_delta\n",
4193 "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":10}}\n\n",
4194 "event: message_stop\n",
4195 "data: {\"type\":\"message_stop\"}\n\n",
4196 );
4197
4198 let events = parse_sse_events(raw);
4199 let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
4200 Arc::new(std::sync::Mutex::new(Vec::new()));
4201 let chunks_clone = chunks.clone();
4202 let callback: StreamCallback = Arc::new(move |chunk| {
4203 chunks_clone.lock().unwrap().push(chunk);
4204 });
4205
4206 let mut text_content = String::new();
4208 let mut usage = TokenUsage::default();
4209 let mut stop_reason = StopReason::EndTurn;
4210
4211 for (event_type, data) in &events {
4212 let parsed: serde_json::Value = match serde_json::from_str(data) {
4213 Ok(v) => v,
4214 Err(_) => continue,
4215 };
4216
4217 match event_type.as_str() {
4218 "message_start" => {
4219 if let Some(inp) = parsed["message"]["usage"]["input_tokens"].as_u64() {
4220 usage.input_tokens = inp;
4221 }
4222 }
4223 "content_block_delta" => {
4224 if let Some(text) = parsed["delta"]["text"].as_str() {
4225 text_content.push_str(text);
4226 callback(StreamChunk {
4227 delta: text.to_string(),
4228 is_final: false,
4229 tool_call_delta: None,
4230 });
4231 }
4232 }
4233 "message_delta" => {
4234 if let Some(sr) = parsed["delta"]["stop_reason"].as_str() {
4235 stop_reason = match sr {
4236 "end_turn" => StopReason::EndTurn,
4237 "tool_use" => StopReason::ToolUse,
4238 _ => StopReason::Error,
4239 };
4240 }
4241 if let Some(out) = parsed["usage"]["output_tokens"].as_u64() {
4242 usage.output_tokens = out;
4243 }
4244 }
4245 "message_stop" => {
4246 callback(StreamChunk {
4247 delta: String::new(),
4248 is_final: true,
4249 tool_call_delta: None,
4250 });
4251 }
4252 _ => {}
4253 }
4254 }
4255
4256 assert_eq!(text_content, "Hello world");
4257 assert_eq!(usage.input_tokens, 25);
4258 assert_eq!(usage.output_tokens, 10);
4259 assert_eq!(stop_reason, StopReason::EndTurn);
4260
4261 let received = chunks.lock().unwrap();
4262 assert_eq!(received.len(), 3); assert_eq!(received[0].delta, "Hello");
4264 assert_eq!(received[1].delta, " world");
4265 assert!(received[2].is_final);
4266 }
4267
4268 #[test]
4269 fn anthropic_stream_with_tool_use() {
4270 let raw = concat!(
4271 "event: message_start\n",
4272 "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":15}}}\n\n",
4273 "event: content_block_start\n",
4274 "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
4275 "event: content_block_delta\n",
4276 "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Checking.\"}}\n\n",
4277 "event: content_block_start\n",
4278 "data: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"tool_1\",\"name\":\"get_weather\"}}\n\n",
4279 "event: content_block_delta\n",
4280 "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\"\"}}\n\n",
4281 "event: content_block_delta\n",
4282 "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\": \\\"NYC\\\"}\"}}\n\n",
4283 "event: message_delta\n",
4284 "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":20}}\n\n",
4285 "event: message_stop\n",
4286 "data: {\"type\":\"message_stop\"}\n\n",
4287 );
4288
4289 let events = parse_sse_events(raw);
4290 assert!(events.len() >= 7);
4292
4293 let mut tool_json_bufs: Vec<String> = Vec::new();
4295 let mut tc_idx: Option<usize> = None;
4296
4297 for (event_type, data) in &events {
4298 let parsed: serde_json::Value = match serde_json::from_str(data) {
4299 Ok(v) => v,
4300 Err(_) => continue,
4301 };
4302 match event_type.as_str() {
4303 "content_block_start" => {
4304 if parsed["content_block"]["type"].as_str() == Some("tool_use") {
4305 tool_json_bufs.push(String::new());
4306 tc_idx = Some(tool_json_bufs.len() - 1);
4307 } else {
4308 tc_idx = None;
4309 }
4310 }
4311 "content_block_delta" => {
4312 if parsed["delta"]["type"].as_str() == Some("input_json_delta")
4313 && let Some(idx) = tc_idx
4314 && let Some(buf) = tool_json_bufs.get_mut(idx)
4315 {
4316 buf.push_str(parsed["delta"]["partial_json"].as_str().unwrap_or(""));
4317 }
4318 }
4319 _ => {}
4320 }
4321 }
4322
4323 assert_eq!(tool_json_bufs.len(), 1);
4324 assert_eq!(tool_json_bufs[0], "{\"city\": \"NYC\"}");
4325
4326 let parsed_input: serde_json::Value = serde_json::from_str(&tool_json_bufs[0]).unwrap();
4327 assert_eq!(parsed_input["city"], "NYC");
4328 }
4329
    // ---- OpenAI-compatible streaming ----

    #[test]
    fn openai_stream_text_only() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\" world\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n",
            "data: [DONE]\n\n",
        );

        let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let chunks_clone = chunks.clone();
        let callback: StreamCallback = Arc::new(move |chunk| {
            chunks_clone.lock().unwrap().push(chunk);
        });

        let resp = driver.parse_openai_stream(raw, &callback).unwrap();

        assert_eq!(resp.message.content, "Hello world");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert!(resp.message.tool_calls.is_empty());

        let received = chunks.lock().unwrap();
        assert!(received.len() >= 3);
        assert_eq!(received[0].delta, "Hello");
        assert_eq!(received[1].delta, " world");
        assert!(received.last().unwrap().is_final);
    }

    #[test]
    fn openai_stream_with_tool_calls() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_abc\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"\"}}]},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"ci\"}}]},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"ty\\\": \\\"NYC\\\"}\"}}]},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"tool_calls\"}]}\n\n",
            "data: [DONE]\n\n",
        );

        let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let chunks_clone = chunks.clone();
        let callback: StreamCallback = Arc::new(move |chunk| {
            chunks_clone.lock().unwrap().push(chunk);
        });

        let resp = driver.parse_openai_stream(raw, &callback).unwrap();

        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].id, "call_abc");
        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
        assert_eq!(resp.message.tool_calls[0].input["city"], "NYC");

        let received = chunks.lock().unwrap();
        let tool_chunks: Vec<_> = received
            .iter()
            .filter(|c| c.tool_call_delta.is_some())
            .collect();
        assert!(tool_chunks.len() >= 3);
        assert!(received.last().unwrap().is_final);
    }

    #[test]
    fn openai_stream_with_mixed_content_and_tools() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Sure, \"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\"checking.\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"search\",\"arguments\":\"{\\\"q\\\":\\\"test\\\"}\"}}]},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"tool_calls\"}]}\n\n",
            "data: [DONE]\n\n",
        );

        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_openai_stream(raw, &callback).unwrap();

        assert_eq!(resp.message.content, "Sure, checking.");
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].name, "search");
    }

    #[test]
    fn openai_stream_length_stop_reason() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"truncated\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"length\"}]}\n\n",
            "data: [DONE]\n\n",
        );

        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
    }

    // ---- Ollama streaming ----

    #[test]
    fn ollama_stream_text_only() {
        let driver = OllamaDriver::new(None);

        let raw = concat!(
            "{\"message\":{\"role\":\"assistant\",\"content\":\"Hello\"},\"done\":false}\n",
            "{\"message\":{\"role\":\"assistant\",\"content\":\" world\"},\"done\":false}\n",
            "{\"message\":{\"role\":\"assistant\",\"content\":\"!\"},\"done\":false}\n",
            "{\"done\":true,\"prompt_eval_count\":15,\"eval_count\":8}\n",
        );

        let chunks: Arc<std::sync::Mutex<Vec<StreamChunk>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let chunks_clone = chunks.clone();
        let callback: StreamCallback = Arc::new(move |chunk| {
            chunks_clone.lock().unwrap().push(chunk);
        });

        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();

        assert_eq!(resp.message.content, "Hello world!");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert_eq!(resp.usage.input_tokens, 15);
        assert_eq!(resp.usage.output_tokens, 8);

        let received = chunks.lock().unwrap();
        assert_eq!(received.len(), 4);
        assert_eq!(received[0].delta, "Hello");
        assert_eq!(received[1].delta, " world");
        assert_eq!(received[2].delta, "!");
        assert!(received[3].is_final);
    }

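    // A hedged format sketch: Ollama's streaming body is newline-delimited
    // JSON rather than SSE, so each line should parse on its own. This only
    // exercises serde_json against the same fixture shape used above.
    #[test]
    fn ollama_stream_lines_parse_independently() {
        let raw = concat!(
            "{\"message\":{\"role\":\"assistant\",\"content\":\"Hi\"},\"done\":false}\n",
            "{\"done\":true,\"prompt_eval_count\":2,\"eval_count\":1}\n",
        );

        let values: Vec<serde_json::Value> = raw
            .lines()
            .filter(|line| !line.trim().is_empty())
            .map(|line| serde_json::from_str(line).unwrap())
            .collect();

        assert_eq!(values.len(), 2);
        assert_eq!(values[0]["message"]["content"], "Hi");
        assert_eq!(values[1]["done"], true);
    }
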
    #[test]
    fn ollama_stream_with_tool_calls() {
        let driver = OllamaDriver::new(None);

        let raw = concat!(
            "{\"message\":{\"role\":\"assistant\",\"content\":\"Let me check.\"},\"done\":false}\n",
            "{\"message\":{\"role\":\"assistant\",\"content\":\"\",\"tool_calls\":[{\"function\":{\"name\":\"get_weather\",\"arguments\":{\"city\":\"London\"}}}]},\"done\":true,\"prompt_eval_count\":10,\"eval_count\":5}\n",
        );

        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();

        assert_eq!(resp.message.content, "Let me check.");
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].name, "get_weather");
        assert_eq!(resp.usage.input_tokens, 10);
    }

    #[test]
    fn ollama_stream_strips_thinking_tags() {
        let driver = OllamaDriver::new(None);

        let raw = concat!(
            "{\"message\":{\"role\":\"assistant\",\"content\":\"<think>hmm</think>\"},\"done\":false}\n",
            "{\"message\":{\"role\":\"assistant\",\"content\":\"Clean answer.\"},\"done\":false}\n",
            "{\"done\":true,\"prompt_eval_count\":5,\"eval_count\":3}\n",
        );

        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();
        assert_eq!(resp.message.content, "Clean answer.");
    }

    // ---- Gemini streaming ----

    #[test]
    fn gemini_stream_url_construction() {
        let driver = GeminiDriver::new("my-key".to_string(), None);
        let url = driver.build_stream_url("gemini-pro");
        assert!(url.contains("streamGenerateContent"));
        assert!(url.contains("alt=sse"));
        assert!(url.contains("key=my-key"));
        assert!(url.contains("models/gemini-pro"));
    }

    #[test]
    fn gemini_stream_custom_base_url() {
        let driver = GeminiDriver::new(
            "key".to_string(),
            Some("https://custom.example.com".to_string()),
        );
        let url = driver.build_stream_url("gemini-pro");
        assert!(url.starts_with("https://custom.example.com/"));
        assert!(url.contains("streamGenerateContent"));
    }

    // ---- Callback behavior ----

    #[test]
    fn callback_receives_all_chunks_in_order() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"A\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\"B\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\"C\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n",
            "data: [DONE]\n\n",
        );

        let deltas: Arc<std::sync::Mutex<Vec<String>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let deltas_clone = deltas.clone();
        let callback: StreamCallback = Arc::new(move |chunk| {
            if !chunk.delta.is_empty() || chunk.is_final {
                deltas_clone.lock().unwrap().push(chunk.delta.clone());
            }
        });

        let _resp = driver.parse_openai_stream(raw, &callback).unwrap();
        let received = deltas.lock().unwrap();
        assert_eq!(received.as_slice(), &["A", "B", "C", ""]);
    }

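    // A hedged usage sketch: StreamCallback is just an Arc'd closure, so
    // chunks can be forwarded over an mpsc channel instead of a shared Vec.
    // The Mutex around the Sender is only there to satisfy the Sync bound on
    // the callback; nothing new from the driver API is assumed.
    #[test]
    fn callback_can_forward_chunks_over_a_channel() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n",
            "data: [DONE]\n\n",
        );

        let (tx, rx) = std::sync::mpsc::channel::<StreamChunk>();
        let tx = std::sync::Mutex::new(tx);
        let callback: StreamCallback = Arc::new(move |chunk| {
            // The receiver outlives the parse call, so send errors are ignored.
            let _ = tx.lock().unwrap().send(chunk);
        });

        let _resp = driver.parse_openai_stream(raw, &callback).unwrap();
        // Dropping the callback drops the Sender, closing the channel so the
        // iterator below terminates.
        drop(callback);

        let received: Vec<StreamChunk> = rx.iter().collect();
        assert_eq!(received[0].delta, "hi");
        assert!(received.last().unwrap().is_final);
    }
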
    #[test]
    fn openai_stream_multiple_tool_calls() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"tool_a\",\"arguments\":\"{\\\"x\\\":1}\"}}]},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":1,\"id\":\"call_2\",\"type\":\"function\",\"function\":{\"name\":\"tool_b\",\"arguments\":\"{\\\"y\\\":2}\"}}]},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"tool_calls\"}]}\n\n",
            "data: [DONE]\n\n",
        );

        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_openai_stream(raw, &callback).unwrap();

        assert_eq!(resp.message.tool_calls.len(), 2);
        assert_eq!(resp.message.tool_calls[0].id, "call_1");
        assert_eq!(resp.message.tool_calls[0].name, "tool_a");
        assert_eq!(resp.message.tool_calls[0].input["x"], 1);
        assert_eq!(resp.message.tool_calls[1].id, "call_2");
        assert_eq!(resp.message.tool_calls[1].name, "tool_b");
        assert_eq!(resp.message.tool_calls[1].input["y"], 2);
    }

    // ---- StreamChunk basics and edge cases ----

    #[test]
    fn stream_chunk_default_values() {
        let chunk = StreamChunk {
            delta: String::new(),
            is_final: false,
            tool_call_delta: None,
        };
        assert!(chunk.delta.is_empty());
        assert!(!chunk.is_final);
        assert!(chunk.tool_call_delta.is_none());
    }

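    // A hedged serde sketch: StreamChunk and ToolCallDelta both derive
    // Serialize and Deserialize, so a serde_json round trip should preserve
    // every field. Neither type derives PartialEq, so fields are compared
    // individually.
    #[test]
    fn stream_chunk_serde_round_trip() {
        let chunk = StreamChunk {
            delta: "partial".to_string(),
            is_final: false,
            tool_call_delta: Some(ToolCallDelta {
                index: 0,
                id: Some("call_1".to_string()),
                name: Some("search".to_string()),
                arguments_delta: "{\"q\":".to_string(),
            }),
        };

        let json = serde_json::to_string(&chunk).unwrap();
        let back: StreamChunk = serde_json::from_str(&json).unwrap();

        assert_eq!(back.delta, "partial");
        assert!(!back.is_final);
        let tc = back.tool_call_delta.unwrap();
        assert_eq!(tc.index, 0);
        assert_eq!(tc.id.as_deref(), Some("call_1"));
        assert_eq!(tc.name.as_deref(), Some("search"));
        assert_eq!(tc.arguments_delta, "{\"q\":");
    }
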
    #[test]
    fn openai_stream_empty_input() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = "data: [DONE]\n\n";

        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_openai_stream(raw, &callback).unwrap();

        assert_eq!(resp.message.content, "");
        assert!(resp.message.tool_calls.is_empty());
    }

    #[test]
    fn ollama_stream_empty_input() {
        let driver = OllamaDriver::new(None);
        let raw = "";

        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_ollama_stream(raw, &callback).unwrap();

        assert_eq!(resp.message.content, "");
        assert_eq!(resp.stop_reason, StopReason::MaxTokens);
    }

    #[test]
    fn openai_stream_strips_thinking_tags() {
        let driver = OpenAiCompatibleDriver::new(
            "key".into(),
            "https://api.openai.com".into(),
            "openai".into(),
        );

        let raw = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"<think>internal</think>\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\"Result\"},\"index\":0}]}\n\n",
            "data: {\"choices\":[{\"delta\":{},\"index\":0,\"finish_reason\":\"stop\"}]}\n\n",
            "data: [DONE]\n\n",
        );

        let callback: StreamCallback = Arc::new(|_| {});
        let resp = driver.parse_openai_stream(raw, &callback).unwrap();
        assert_eq!(resp.message.content, "Result");
    }
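
    // A small hedged sketch of the usage bookkeeping the streaming tests
    // rely on: TokenUsage::accumulate adds per-turn counts and total() sums
    // both directions. This exercises only the public TokenUsage API.
    #[test]
    fn token_usage_accumulates_across_turns() {
        let mut usage = TokenUsage::default();
        usage.accumulate(&TokenUsage {
            input_tokens: 25,
            output_tokens: 10,
        });
        usage.accumulate(&TokenUsage {
            input_tokens: 15,
            output_tokens: 20,
        });
        assert_eq!(usage.input_tokens, 40);
        assert_eq!(usage.output_tokens, 30);
        assert_eq!(usage.total(), 70);
    }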
}