1use crate::stream::StreamEvent;
8use crate::tasks::generate::{ContentBlock, Message, ResponseFormat, ToolCall};
9use crate::InferenceError;
10use serde_json::Value;
11use std::collections::HashMap;
12
/// Provider-neutral request handed to a `ProtocolHandler` for serialization.
#[derive(Debug, Clone)]
pub struct ApiRequest {
    // Model identifier exactly as the provider expects it.
    pub model: String,
    // Pre-built message JSON (output of `ProtocolHandler::build_messages`).
    pub messages: Vec<Value>,
    // Out-of-band system prompt for providers that keep it separate (Anthropic/Google).
    pub system: Option<String>,
    // Sampling temperature; negative values mean "omit from the request".
    pub temperature: f64,
    pub max_tokens: usize,
    // Tool specs already converted by `ProtocolHandler::build_tools`.
    pub tools: Option<Vec<Value>>,
    pub tool_choice: Option<String>,
    pub parallel_tool_calls: Option<bool>,
    // Request SSE streaming from the provider.
    pub stream: bool,
    // Extended-thinking token budget; 0 disables thinking (used by Anthropic).
    pub budget_tokens: usize,
    // Attach ephemeral cache_control markers (Anthropic prompt caching).
    pub cache_control: bool,
    // Structured-output constraint, where the protocol supports one.
    pub response_format: Option<ResponseFormat>,
}
33
/// Provider-neutral parsed response returned by `ProtocolHandler::parse_response`.
#[derive(Debug, Clone)]
pub struct ApiResponse {
    // Concatenated assistant text content.
    pub text: String,
    // Tool invocations requested by the model, if any.
    pub tool_calls: Vec<ToolCall>,
    // Token accounting when the provider reported it.
    pub usage: Option<crate::TokenUsage>,
}
42
/// Abstraction over one vendor's wire protocol: request shape, auth headers,
/// and response/stream parsing.
///
/// Implementations translate the crate's provider-neutral types
/// (`ApiRequest`, `Message`, `ContentBlock`) into vendor JSON and back.
pub trait ProtocolHandler: Send + Sync {
    /// URL path appended to the configured endpoint (may be empty when the
    /// full URL is assembled elsewhere, as with Google's `google_url`).
    fn endpoint_path(&self) -> &str;

    /// Header name/value pairs carrying authentication and content type.
    fn auth_headers(&self, api_key: &str) -> Vec<(String, String)>;

    /// Serializes `req` into the provider's JSON request body.
    fn build_request_body(&self, req: &ApiRequest) -> Value;

    /// Parses a complete (non-streaming) response body.
    fn parse_response(&self, body: &str) -> Result<ApiResponse, InferenceError>;

    /// Converts one SSE event into zero or more stream events.
    fn parse_stream_event(&self, event_type: &str, data: &str) -> Vec<StreamEvent>;

    /// Builds the provider's message array from either a multi-turn history
    /// (`messages`) or a single `prompt` plus optional `context`/`images`.
    /// Returns the messages and an out-of-band system string for providers
    /// that keep the system prompt separate from the message list.
    fn build_messages(
        &self,
        messages: &[Message],
        prompt: &str,
        context: Option<&str>,
        images: Option<&[ContentBlock]>,
    ) -> (Vec<Value>, Option<String>);

    /// Converts provider-neutral tool specs into the provider's tool schema.
    fn build_tools(&self, tools: &[Value]) -> Vec<Value>;

    /// Whether SSE streaming is implemented for this protocol.
    fn supports_streaming(&self) -> bool {
        true
    }

    /// Whether extended "thinking" output is supported.
    fn supports_thinking(&self) -> bool {
        false
    }

    /// Whether video content blocks are accepted.
    fn supports_video(&self) -> bool {
        false
    }

    /// Whether audio content blocks are accepted.
    fn supports_audio(&self) -> bool {
        false
    }

    /// Short identifier used for logging/diagnostics.
    fn protocol_name(&self) -> &'static str {
        "remote"
    }
}
108
/// Handler for the OpenAI Chat Completions wire format (`/v1/chat/completions`).
pub struct OpenAiHandler;
114
115impl ProtocolHandler for OpenAiHandler {
    /// Chat Completions path, joined onto the configured base endpoint.
    fn endpoint_path(&self) -> &str {
        "/v1/chat/completions"
    }

    /// Bearer-token auth plus JSON content type.
    fn auth_headers(&self, api_key: &str) -> Vec<(String, String)> {
        vec![
            ("Authorization".into(), format!("Bearer {}", api_key)),
            ("Content-Type".into(), "application/json".into()),
        ]
    }
126
    /// Builds a Chat Completions request body, applying per-model quirks
    /// (newer models use `max_completion_tokens`; o-series models reject
    /// `temperature`).
    fn build_request_body(&self, req: &ApiRequest) -> Value {
        let quirks = openai_quirks(&req.model);
        let mut body = serde_json::json!({
            "model": req.model,
            "messages": req.messages,
        });
        // Newer models renamed the token-limit field.
        if quirks.uses_max_completion_tokens {
            body["max_completion_tokens"] = serde_json::json!(req.max_tokens);
        } else {
            body["max_tokens"] = serde_json::json!(req.max_tokens);
        }
        // Negative temperature is the crate's "not set" sentinel.
        if req.temperature >= 0.0 && !quirks.rejects_temperature {
            body["temperature"] = serde_json::json!(req.temperature);
        }
        if let Some(ref tools) = req.tools {
            body["tools"] = serde_json::json!(tools);
            // Default to letting the model decide when no explicit choice given.
            body["tool_choice"] = serde_json::json!(req.tool_choice.as_deref().unwrap_or("auto"));
            if let Some(parallel_tool_calls) = req.parallel_tool_calls {
                body["parallel_tool_calls"] = serde_json::json!(parallel_tool_calls);
            }
        }
        match &req.response_format {
            Some(ResponseFormat::JsonSchema {
                schema,
                strict,
                name,
            }) => {
                body["response_format"] = serde_json::json!({
                    "type": "json_schema",
                    "json_schema": {
                        "name": name.as_deref().unwrap_or("response"),
                        "schema": schema,
                        "strict": strict,
                    },
                });
            }
            Some(ResponseFormat::JsonObject) => {
                body["response_format"] = serde_json::json!({ "type": "json_object" });
            }
            None => {}
        }
        if req.stream {
            body["stream"] = serde_json::json!(true);
            // Ask the API to append a final usage chunk to the stream.
            body["stream_options"] = serde_json::json!({ "include_usage": true });
        }
        body
    }
183
184 fn parse_response(&self, body: &str) -> Result<ApiResponse, InferenceError> {
185 let parsed: Value = serde_json::from_str(body)
186 .map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
187
188 let choice = parsed
189 .get("choices")
190 .and_then(|c| c.as_array())
191 .and_then(|a| a.first())
192 .ok_or_else(|| InferenceError::InferenceFailed("empty response".into()))?;
193
194 let message = choice
195 .get("message")
196 .ok_or_else(|| InferenceError::InferenceFailed("no message in choice".into()))?;
197
198 let text = message
199 .get("content")
200 .and_then(|c| c.as_str())
201 .unwrap_or("")
202 .to_string();
203
204 let mut tool_calls = Vec::new();
205 if let Some(tcs) = message.get("tool_calls").and_then(|t| t.as_array()) {
206 for tc in tcs {
207 if let Some(func) = tc.get("function") {
208 let id = tc.get("id").and_then(|i| i.as_str()).map(|s| s.to_string());
209 let name = func
210 .get("name")
211 .and_then(|n| n.as_str())
212 .unwrap_or("")
213 .to_string();
214 let args_str = func
215 .get("arguments")
216 .and_then(|a| a.as_str())
217 .unwrap_or("{}");
218 let arguments: HashMap<String, Value> =
219 serde_json::from_str(args_str).unwrap_or_default();
220 tool_calls.push(ToolCall {
221 id,
222 name,
223 arguments,
224 });
225 }
226 }
227 }
228
229 let usage = parsed.get("usage").and_then(|u| {
231 Some(crate::TokenUsage {
232 prompt_tokens: u.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
233 completion_tokens: u
234 .get("completion_tokens")
235 .and_then(|v| v.as_u64())
236 .unwrap_or(0),
237 total_tokens: u.get("total_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
238 context_window: 0, })
240 });
241
242 Ok(ApiResponse {
243 text,
244 tool_calls,
245 usage,
246 })
247 }
248
    /// OpenAI SSE payloads are self-describing, so the event type is ignored
    /// and the `data: ` prefix is re-attached for the shared line parser.
    fn parse_stream_event(&self, _event_type: &str, data: &str) -> Vec<StreamEvent> {
        crate::stream::parse_openai_sse_line(&format!("data: {}", data))
    }
252
    /// Builds the OpenAI message array.
    ///
    /// A non-empty multi-turn history takes precedence over the single
    /// `prompt`; the optional `context` becomes a leading system message in
    /// both paths. The system return slot is always `None` because OpenAI
    /// carries the system prompt inline in the messages array.
    fn build_messages(
        &self,
        messages: &[Message],
        prompt: &str,
        context: Option<&str>,
        images: Option<&[ContentBlock]>,
    ) -> (Vec<Value>, Option<String>) {
        if !messages.is_empty() {
            let mut result = Vec::new();
            if let Some(ctx) = context {
                result.push(serde_json::json!({"role": "system", "content": ctx}));
            }
            for msg in messages {
                match msg {
                    Message::System { content } => {
                        result.push(serde_json::json!({"role": "system", "content": content}));
                    }
                    Message::User { content } => {
                        result.push(serde_json::json!({"role": "user", "content": content}));
                    }
                    Message::UserMultimodal { content } => {
                        let blocks: Vec<Value> = content
                            .iter()
                            .map(|block| match block {
                                ContentBlock::Text { text } => {
                                    serde_json::json!({"type": "text", "text": text})
                                }
                                // Inline images travel as data: URLs.
                                ContentBlock::ImageBase64 { data, media_type } => {
                                    serde_json::json!({
                                        "type": "image_url",
                                        "image_url": {
                                            "url": format!("data:{};base64,{}", media_type, data),
                                        }
                                    })
                                }
                                ContentBlock::ImageUrl { url, detail } => {
                                    serde_json::json!({
                                        "type": "image_url",
                                        "image_url": {
                                            "url": url,
                                            "detail": detail,
                                        }
                                    })
                                }
                                ContentBlock::VideoPath { .. }
                                | ContentBlock::VideoUrl { .. }
                                | ContentBlock::VideoBase64 { .. }
                                | ContentBlock::AudioPath { .. }
                                | ContentBlock::AudioUrl { .. }
                                | ContentBlock::AudioBase64 { .. } => {
                                    unreachable!(
                                        "video/audio ContentBlock reached OpenAI \
                                         build_messages — should have been rejected \
                                         by RemoteBackend::execute_request"
                                    )
                                }
                            })
                            .collect();
                        result.push(serde_json::json!({"role": "user", "content": blocks}));
                    }
                    Message::Assistant {
                        content,
                        tool_calls,
                    } => {
                        if tool_calls.is_empty() {
                            result
                                .push(serde_json::json!({"role": "assistant", "content": content}));
                        } else {
                            // Synthesize `call_<i>` ids when the caller didn't record one;
                            // arguments are re-encoded to the JSON-string form OpenAI uses.
                            let tc: Vec<Value> = tool_calls.iter().enumerate().map(|(i, tc)| {
                                let id = tc.id.clone().unwrap_or_else(|| format!("call_{}", i));
                                serde_json::json!({
                                    "id": id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.name,
                                        "arguments": serde_json::to_string(&tc.arguments).unwrap_or_default(),
                                    }
                                })
                            }).collect();
                            let mut msg =
                                serde_json::json!({"role": "assistant", "tool_calls": tc});
                            if !content.is_empty() {
                                msg["content"] = serde_json::json!(content);
                            }
                            result.push(msg);
                        }
                    }
                    Message::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        result.push(serde_json::json!({
                            "role": "tool",
                            "tool_call_id": tool_use_id,
                            "content": content,
                        }));
                    }
                    // Provider-specific output items have no OpenAI equivalent.
                    Message::ProviderOutputItems { .. } => continue,
                }
            }
            (result, None)
        } else {
            // Single-turn path: prompt text plus any image blocks.
            let mut msgs = Vec::new();
            if let Some(ctx) = context {
                msgs.push(serde_json::json!({"role": "system", "content": ctx}));
            }
            if let Some(images) = images.filter(|images| !images.is_empty()) {
                let mut blocks = vec![serde_json::json!({"type": "text", "text": prompt})];
                for image in images {
                    let block = match image {
                        ContentBlock::Text { text } => {
                            serde_json::json!({"type": "text", "text": text})
                        }
                        ContentBlock::ImageBase64 { data, media_type } => {
                            serde_json::json!({
                                "type": "image_url",
                                "image_url": {
                                    "url": format!("data:{};base64,{}", media_type, data),
                                }
                            })
                        }
                        ContentBlock::ImageUrl { url, detail } => {
                            serde_json::json!({
                                "type": "image_url",
                                "image_url": {
                                    "url": url,
                                    "detail": detail,
                                }
                            })
                        }
                        ContentBlock::VideoPath { .. }
                        | ContentBlock::VideoUrl { .. }
                        | ContentBlock::VideoBase64 { .. }
                        | ContentBlock::AudioPath { .. }
                        | ContentBlock::AudioUrl { .. }
                        | ContentBlock::AudioBase64 { .. } => {
                            unreachable!(
                                "video/audio ContentBlock reached OpenAI build_messages \
                                 — should have been rejected by RemoteBackend::execute_request"
                            )
                        }
                    };
                    blocks.push(block);
                }
                msgs.push(serde_json::json!({"role": "user", "content": blocks}));
            } else {
                msgs.push(serde_json::json!({"role": "user", "content": prompt}));
            }
            (msgs, None)
        }
    }
422
423 fn build_tools(&self, tools: &[Value]) -> Vec<Value> {
424 tools
425 .iter()
426 .map(|t| {
427 if t.get("type").is_some() {
428 t.clone()
429 } else {
430 serde_json::json!({"type": "function", "function": t})
431 }
432 })
433 .collect()
434 }
435
    /// Identifier reported for logging/diagnostics.
    fn protocol_name(&self) -> &'static str {
        "openai"
    }
439}
440
/// Per-model request-shape differences within the OpenAI API family.
struct OpenAiQuirks {
    /// Model takes `max_completion_tokens` instead of the legacy `max_tokens`.
    uses_max_completion_tokens: bool,
    /// Model errors when a `temperature` field is present.
    rejects_temperature: bool,
}

/// Derives request quirks from the (case-insensitive) model name prefix.
fn openai_quirks(model: &str) -> OpenAiQuirks {
    let name = model.to_lowercase();
    let reasoning_series = ["o1", "o3", "o4"].iter().any(|p| name.starts_with(p));
    let newer_token_field =
        reasoning_series || name.starts_with("gpt-5") || name.starts_with("gpt-4.1");
    OpenAiQuirks {
        uses_max_completion_tokens: newer_token_field,
        rejects_temperature: reasoning_series,
    }
}
462
/// Handler for the Anthropic Messages API (`/v1/messages`).
pub struct AnthropicHandler;
468
469impl ProtocolHandler for AnthropicHandler {
    /// Messages API path.
    fn endpoint_path(&self) -> &str {
        "/v1/messages"
    }

    /// API-key auth plus the pinned API version and the prompt-caching beta
    /// opt-in header.
    fn auth_headers(&self, api_key: &str) -> Vec<(String, String)> {
        vec![
            ("x-api-key".into(), api_key.to_string()),
            ("anthropic-version".into(), "2023-06-01".into()),
            ("anthropic-beta".into(), "prompt-caching-2024-07-31".into()),
            ("Content-Type".into(), "application/json".into()),
        ]
    }
482
    /// Builds a Messages API request body.
    ///
    /// Extended thinking (`budget_tokens > 0`) and a custom temperature are
    /// mutually exclusive here: only one of the two fields is emitted. With
    /// `cache_control` set, ephemeral cache markers are attached to the
    /// system prompt and the last tool definition.
    fn build_request_body(&self, req: &ApiRequest) -> Value {
        let mut body = serde_json::json!({
            "model": req.model,
            "max_tokens": req.max_tokens,
            "messages": req.messages,
        });

        if req.budget_tokens > 0 {
            body["thinking"] = serde_json::json!({
                "type": "enabled",
                "budget_tokens": req.budget_tokens,
            });
        } else if req.temperature >= 0.0 {
            // Negative temperature is the crate's "not set" sentinel.
            body["temperature"] = serde_json::json!(req.temperature);
        }

        if let Some(ref system) = req.system {
            if req.cache_control {
                // Block form so a cache_control marker can be attached.
                body["system"] = serde_json::json!([{
                    "type": "text",
                    "text": system,
                    "cache_control": {"type": "ephemeral"}
                }]);
            } else {
                body["system"] = Value::String(system.clone());
            }
        }

        if let Some(ref tools) = req.tools {
            body["tools"] = Value::Array(tools.clone());
            // Marking only the last tool caches the whole tool prefix.
            if req.cache_control && !tools.is_empty() {
                if let Some(arr) = body["tools"].as_array_mut() {
                    if let Some(last) = arr.last_mut() {
                        if let Some(obj) = last.as_object_mut() {
                            obj.insert(
                                "cache_control".to_string(),
                                serde_json::json!({"type": "ephemeral"}),
                            );
                        }
                    }
                }
            }
            body["tool_choice"] = serde_json::json!({"type": "auto"});
        }

        if req.response_format.is_some() {
            // No native structured-output field; warn instead of failing.
            tracing::warn!(
                "response_format set on Anthropic request — Anthropic has no native \
                 JSON-schema field; the request will run unconstrained. Use tool_use \
                 with tool_choice=required to enforce a schema on Claude.",
            );
        }

        if req.stream {
            body["stream"] = serde_json::json!(true);
        }

        body
    }
551
552 fn parse_response(&self, body: &str) -> Result<ApiResponse, InferenceError> {
553 let parsed: Value = serde_json::from_str(body)
554 .map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
555
556 let mut text = String::new();
557 let mut tool_calls = Vec::new();
558 let mut thinking_text = String::new();
559
560 if let Some(content) = parsed.get("content").and_then(|c| c.as_array()) {
561 for block in content {
562 match block.get("type").and_then(|t| t.as_str()) {
563 Some("text") => {
564 if let Some(t) = block.get("text").and_then(|t| t.as_str()) {
565 text.push_str(t);
566 }
567 }
568 Some("thinking") => {
569 if let Some(t) = block.get("thinking").and_then(|t| t.as_str()) {
571 tracing::debug!(thinking_len = t.len(), "extended thinking block");
572 thinking_text.push_str(t);
573 }
574 }
575 Some("tool_use") => {
576 if let (Some(name), Some(input)) = (
577 block.get("name").and_then(|n| n.as_str()),
578 block.get("input"),
579 ) {
580 let id = block
581 .get("id")
582 .and_then(|i| i.as_str())
583 .map(|s| s.to_string());
584 let arguments: HashMap<String, Value> =
585 serde_json::from_value(input.clone()).unwrap_or_default();
586 tool_calls.push(ToolCall {
587 id,
588 name: name.to_string(),
589 arguments,
590 });
591 }
592 }
593 _ => {}
594 }
595 }
596 }
597
598 if text.is_empty() && !thinking_text.is_empty() && tool_calls.is_empty() {
602 tracing::warn!(
603 thinking_len = thinking_text.len(),
604 "response had only thinking blocks, using thinking content as text fallback"
605 );
606 text = thinking_text;
607 }
608
609 let usage = parsed.get("usage").and_then(|u| {
611 Some(crate::TokenUsage {
612 prompt_tokens: u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
613 completion_tokens: u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
614 total_tokens: u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0)
615 + u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
616 context_window: 0, })
618 });
619
620 Ok(ApiResponse {
621 text,
622 tool_calls,
623 usage,
624 })
625 }
626
    /// Anthropic SSE events are typed, so both the event name and the payload
    /// are forwarded to the shared parser.
    fn parse_stream_event(&self, event_type: &str, data: &str) -> Vec<StreamEvent> {
        crate::stream::parse_anthropic_sse_line(event_type, data)
    }
630
    /// Builds the Anthropic message array plus the out-of-band system prompt.
    ///
    /// System text comes from `context` and every `Message::System` entry,
    /// joined with blank lines — Anthropic takes it as a separate field, not
    /// a message. Assistant tool calls become `tool_use` blocks and tool
    /// results become `tool_result` blocks inside user-role messages.
    fn build_messages(
        &self,
        messages: &[Message],
        prompt: &str,
        context: Option<&str>,
        images: Option<&[ContentBlock]>,
    ) -> (Vec<Value>, Option<String>) {
        // Fold context plus all System messages into one system string.
        let mut system = context.map(|c| c.to_string());
        if !messages.is_empty() {
            for msg in messages {
                if let Message::System { content } = msg {
                    system = Some(match system {
                        Some(existing) if !existing.is_empty() => {
                            format!("{existing}\n\n{content}")
                        }
                        _ => content.clone(),
                    });
                }
            }
        }

        if !messages.is_empty() {
            let mut result = Vec::new();
            for msg in messages {
                match msg {
                    // Already folded into `system` above.
                    Message::System { .. } => continue,
                    Message::User { content } => {
                        result.push(serde_json::json!({"role": "user", "content": content}));
                    }
                    Message::UserMultimodal { content } => {
                        let blocks: Vec<Value> = content
                            .iter()
                            .map(|block| {
                                match block {
                                    ContentBlock::Text { text } => {
                                        serde_json::json!({"type": "text", "text": text})
                                    }
                                    ContentBlock::ImageBase64 { data, media_type } => {
                                        serde_json::json!({
                                            "type": "image",
                                            "source": {
                                                "type": "base64",
                                                "media_type": media_type,
                                                "data": data,
                                            }
                                        })
                                    }
                                    // `detail` has no Anthropic equivalent and is dropped.
                                    ContentBlock::ImageUrl { url, .. } => {
                                        serde_json::json!({
                                            "type": "image",
                                            "source": {
                                                "type": "url",
                                                "url": url,
                                            }
                                        })
                                    }
                                    ContentBlock::VideoPath { .. }
                                    | ContentBlock::VideoUrl { .. }
                                    | ContentBlock::VideoBase64 { .. }
                                    | ContentBlock::AudioPath { .. }
                                    | ContentBlock::AudioUrl { .. }
                                    | ContentBlock::AudioBase64 { .. } => {
                                        unreachable!(
                                            "video/audio ContentBlock reached Anthropic \
                                             build_messages — should have been rejected \
                                             by RemoteBackend::execute_request"
                                        )
                                    }
                                }
                            })
                            .collect();
                        result.push(serde_json::json!({"role": "user", "content": blocks}));
                    }
                    Message::Assistant {
                        content,
                        tool_calls,
                    } => {
                        let mut blocks: Vec<Value> = Vec::new();
                        if !content.is_empty() {
                            blocks.push(serde_json::json!({"type": "text", "text": content}));
                        }
                        for (i, tc) in tool_calls.iter().enumerate() {
                            // Synthesize `toolu_<i>` ids when the caller didn't record one.
                            let id = tc.id.clone().unwrap_or_else(|| format!("toolu_{}", i));
                            blocks.push(serde_json::json!({
                                "type": "tool_use",
                                "id": id,
                                "name": tc.name,
                                "input": tc.arguments,
                            }));
                        }
                        // Pad with empty text rather than send an empty content
                        // array — presumably the API rejects it (TODO confirm).
                        if blocks.is_empty() {
                            blocks.push(serde_json::json!({"type": "text", "text": ""}));
                        }
                        result.push(serde_json::json!({"role": "assistant", "content": blocks}));
                    }
                    Message::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        // Tool results go back as user-role tool_result blocks.
                        result.push(serde_json::json!({
                            "role": "user",
                            "content": [{
                                "type": "tool_result",
                                "tool_use_id": tool_use_id,
                                "content": content,
                            }]
                        }));
                    }
                    // No Anthropic equivalent; skipped.
                    Message::ProviderOutputItems { .. } => continue,
                }
            }
            (result, system)
        } else {
            // Single-turn path: one user message from prompt (+ optional images).
            let content = if let Some(images) = images.filter(|images| !images.is_empty()) {
                let mut blocks = vec![serde_json::json!({"type": "text", "text": prompt})];
                for image in images {
                    let block = match image {
                        ContentBlock::Text { text } => {
                            serde_json::json!({"type": "text", "text": text})
                        }
                        ContentBlock::ImageBase64 { data, media_type } => {
                            serde_json::json!({
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": media_type,
                                    "data": data,
                                }
                            })
                        }
                        ContentBlock::ImageUrl { url, .. } => {
                            serde_json::json!({
                                "type": "image",
                                "source": {
                                    "type": "url",
                                    "url": url,
                                }
                            })
                        }
                        ContentBlock::VideoPath { .. }
                        | ContentBlock::VideoUrl { .. }
                        | ContentBlock::VideoBase64 { .. }
                        | ContentBlock::AudioPath { .. }
                        | ContentBlock::AudioUrl { .. }
                        | ContentBlock::AudioBase64 { .. } => {
                            unreachable!(
                                "video/audio ContentBlock reached Anthropic build_messages \
                                 — should have been rejected by RemoteBackend::execute_request"
                            )
                        }
                    };
                    blocks.push(block);
                }
                Value::Array(blocks)
            } else {
                Value::String(prompt.to_string())
            };
            let msgs = vec![serde_json::json!({"role": "user", "content": content})];
            (msgs, system)
        }
    }
811
812 fn build_tools(&self, tools: &[Value]) -> Vec<Value> {
813 tools.iter().filter_map(|t| {
814 let func = t.get("function").unwrap_or(t);
815 Some(serde_json::json!({
816 "name": func.get("name")?,
817 "description": func.get("description").and_then(|d| d.as_str()).unwrap_or(""),
818 "input_schema": func.get("parameters").cloned().unwrap_or(serde_json::json!({"type": "object"})),
819 }))
820 }).collect()
821 }
822
    /// Extended thinking is supported (see `build_request_body`'s `thinking` field).
    fn supports_thinking(&self) -> bool {
        true
    }

    /// Identifier reported for logging/diagnostics.
    fn protocol_name(&self) -> &'static str {
        "anthropic"
    }
830}
831
/// Handler for the Google Gemini `generateContent` REST API.
pub struct GoogleHandler;
837
838impl ProtocolHandler for GoogleHandler {
839 fn endpoint_path(&self) -> &str {
840 ""
841 } fn auth_headers(&self, _api_key: &str) -> Vec<(String, String)> {
844 vec![("Content-Type".into(), "application/json".into())]
846 }
847
    /// Builds a `generateContent` request body.
    ///
    /// `generationConfig` carries the token limit, temperature and any
    /// structured-output settings; tools become a single
    /// `functionDeclarations` group with AUTO function calling.
    fn build_request_body(&self, req: &ApiRequest) -> Value {
        let mut body = serde_json::json!({
            "contents": req.messages,
        });

        if let Some(ref system) = req.system {
            body["systemInstruction"] = serde_json::json!({
                "parts": [{"text": system}],
            });
        }

        let mut generation_config = serde_json::json!({
            "maxOutputTokens": req.max_tokens,
        });
        // Negative temperature is the crate's "not set" sentinel.
        if req.temperature >= 0.0 {
            generation_config["temperature"] = serde_json::json!(req.temperature);
        }
        match &req.response_format {
            // Schema goes into responseSchema; the `strict`/`name` options
            // have no field here and are dropped.
            Some(ResponseFormat::JsonSchema { schema, .. }) => {
                generation_config["responseMimeType"] = serde_json::json!("application/json");
                generation_config["responseSchema"] = schema.clone();
            }
            Some(ResponseFormat::JsonObject) => {
                generation_config["responseMimeType"] = serde_json::json!("application/json");
            }
            None => {}
        }
        body["generationConfig"] = generation_config;

        if let Some(ref tools) = req.tools {
            body["tools"] = serde_json::json!([{
                "functionDeclarations": tools,
            }]);
            body["toolConfig"] = serde_json::json!({
                "functionCallingConfig": {
                    "mode": "AUTO",
                }
            });
        }

        body
    }
896
    /// Parses a non-streaming `generateContent` response.
    ///
    /// Reads the first candidate's parts, collecting text chunks and
    /// `functionCall` parts (accepting both camelCase and snake_case field
    /// spellings). Usage comes from `usageMetadata`.
    ///
    /// # Errors
    /// Returns `InferenceError::InferenceFailed` when the body is not valid JSON.
    fn parse_response(&self, body: &str) -> Result<ApiResponse, InferenceError> {
        let parsed: Value = serde_json::from_str(body)
            .map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;

        // Missing candidates/parts yields an empty response rather than an error.
        let parts = parsed
            .get("candidates")
            .and_then(|c| c.as_array())
            .and_then(|a| a.first())
            .and_then(|c| c.get("content"))
            .and_then(|c| c.get("parts"))
            .and_then(|p| p.as_array())
            .cloned()
            .unwrap_or_default();

        let mut text_chunks = Vec::new();
        let mut tool_calls = Vec::new();
        for part in parts {
            if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
                text_chunks.push(text.to_string());
            }
            if let Some(function_call) = part
                .get("functionCall")
                .or_else(|| part.get("function_call"))
            {
                let name = function_call
                    .get("name")
                    .and_then(|n| n.as_str())
                    .unwrap_or_default()
                    .to_string();
                // Arguments arrive as a JSON object (not a string as in OpenAI).
                let arguments = function_call
                    .get("args")
                    .or_else(|| function_call.get("arguments"))
                    .and_then(|args| args.as_object())
                    .map(|map| {
                        map.iter()
                            .map(|(k, v)| (k.clone(), v.clone()))
                            .collect::<HashMap<_, _>>()
                    })
                    .unwrap_or_default();
                // Gemini function calls carry no id; only a name is required.
                if !name.is_empty() {
                    tool_calls.push(ToolCall {
                        id: None,
                        name,
                        arguments,
                    });
                }
            }
        }

        let usage = parsed.get("usageMetadata").map(|usage| crate::TokenUsage {
            prompt_tokens: usage
                .get("promptTokenCount")
                .and_then(|v| v.as_u64())
                .unwrap_or(0),
            completion_tokens: usage
                .get("candidatesTokenCount")
                .and_then(|v| v.as_u64())
                .unwrap_or(0),
            total_tokens: usage
                .get("totalTokenCount")
                .and_then(|v| v.as_u64())
                .unwrap_or(0),
            // Not reported by this API; left at 0.
            context_window: 0,
        });

        Ok(ApiResponse {
            // NOTE(review): multiple text parts are joined with '\n' — confirm
            // a separator is desirable for adjacent parts.
            text: text_chunks.join("\n"),
            tool_calls,
            usage,
        })
    }
968
969 fn parse_stream_event(&self, _event_type: &str, _data: &str) -> Vec<StreamEvent> {
970 Vec::new() }
972
    /// Builds Gemini `contents` plus the `systemInstruction` string.
    ///
    /// `context` and every `Message::System` entry are folded into the
    /// returned system string; history turns map onto Gemini's part-based
    /// `user`/`model`/`tool` structure.
    fn build_messages(
        &self,
        messages: &[Message],
        prompt: &str,
        context: Option<&str>,
        images: Option<&[ContentBlock]>,
    ) -> (Vec<Value>, Option<String>) {
        // Fold context plus all System messages into one instruction string.
        let mut system_instruction: Option<String> = context.map(|c| c.to_string());
        if !messages.is_empty() {
            for msg in messages {
                if let Message::System { content } = msg {
                    system_instruction = Some(match system_instruction {
                        Some(existing) if !existing.is_empty() => {
                            format!("{existing}\n\n{content}")
                        }
                        _ => content.clone(),
                    });
                }
            }
        }

        if !messages.is_empty() {
            let contents = messages
                .iter()
                // System goes to systemInstruction; provider items have no mapping.
                .filter(|msg| {
                    !matches!(
                        msg,
                        Message::System { .. } | Message::ProviderOutputItems { .. }
                    )
                })
                .map(|msg| match msg {
                    Message::System { .. } => unreachable!("System filtered above"),
                    Message::ProviderOutputItems { .. } => {
                        unreachable!("ProviderOutputItems filtered above")
                    }
                    Message::User { content } => serde_json::json!({
                        "role": "user",
                        "parts": [{"text": content}],
                    }),
                    Message::UserMultimodal { content } => serde_json::json!({
                        "role": "user",
                        "parts": content.iter().map(google_part_from_block).collect::<Vec<_>>(),
                    }),
                    Message::Assistant {
                        content,
                        tool_calls,
                    } => {
                        let mut parts = Vec::new();
                        if !content.is_empty() {
                            parts.push(serde_json::json!({"text": content}));
                        }
                        for tool_call in tool_calls {
                            parts.push(serde_json::json!({
                                "functionCall": {
                                    "name": tool_call.name,
                                    "args": tool_call.arguments,
                                }
                            }));
                        }
                        serde_json::json!({
                            "role": "model",
                            "parts": parts,
                        })
                    }
                    Message::ToolResult {
                        tool_use_id,
                        content,
                    } => serde_json::json!({
                        // NOTE(review): Gemini's documented content roles are
                        // "user" and "model" — confirm the API accepts "tool".
                        "role": "tool",
                        "parts": [{
                            "functionResponse": {
                                // The stored id doubles as the function name here.
                                "name": tool_use_id,
                                "response": {"content": content},
                            }
                        }],
                    }),
                })
                .collect();
            (contents, system_instruction)
        } else {
            // Single-turn path: prompt text plus any image/media parts.
            let mut parts = vec![serde_json::json!({"text": prompt})];
            if let Some(images) = images.filter(|images| !images.is_empty()) {
                parts.extend(images.iter().map(google_part_from_block));
            }
            (
                vec![serde_json::json!({
                    "role": "user",
                    "parts": parts,
                })],
                system_instruction,
            )
        }
    }
1078
1079 fn build_tools(&self, tools: &[Value]) -> Vec<Value> {
1080 tools
1081 .iter()
1082 .filter_map(|t| {
1083 let func = t.get("function").unwrap_or(t);
1084 Some(serde_json::json!({
1085 "name": func.get("name")?,
1086 "description": func.get("description").and_then(|d| d.as_str()).unwrap_or(""),
1087 "parameters": func.get("parameters").cloned().unwrap_or(serde_json::json!({"type": "object"})),
1088 }))
1089 })
1090 .collect()
1091 }
1092
    /// Streaming is not implemented for this handler (`parse_stream_event`
    /// returns nothing).
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Video content blocks are mapped by `google_part_from_block`.
    fn supports_video(&self) -> bool {
        true
    }

    /// Audio content blocks are mapped by `google_part_from_block`.
    fn supports_audio(&self) -> bool {
        true
    }

    /// Identifier reported for logging/diagnostics.
    fn protocol_name(&self) -> &'static str {
        "google-gemini"
    }
1113}
1114
/// Maps one `ContentBlock` onto a Gemini `parts` entry.
///
/// Inline (base64) data becomes `inlineData`; path/URL variants become
/// `fileData` references.
fn google_part_from_block(block: &ContentBlock) -> Value {
    match block {
        ContentBlock::Text { text } => serde_json::json!({"text": text}),
        ContentBlock::ImageBase64 { data, media_type } => serde_json::json!({
            "inlineData": {
                "mimeType": media_type,
                "data": data,
            }
        }),
        // Image URLs get a MIME type guessed from the file extension.
        ContentBlock::ImageUrl { url, .. } => serde_json::json!({
            "fileData": {
                "mimeType": infer_mime_type_from_url(url),
                "fileUri": url,
            }
        }),
        // NOTE(review): the path/URL video and audio variants hard-code
        // video/mp4 and audio/wav regardless of the actual file type —
        // confirm this is acceptable or derive the type from the source.
        ContentBlock::VideoPath { path, .. } => serde_json::json!({
            "fileData": {
                "mimeType": "video/mp4",
                "fileUri": format!("file://{path}"),
            }
        }),
        ContentBlock::VideoUrl { url, .. } => serde_json::json!({
            "fileData": {
                "mimeType": "video/mp4",
                "fileUri": url,
            }
        }),
        ContentBlock::VideoBase64 {
            data, media_type, ..
        } => serde_json::json!({
            "inlineData": {
                "mimeType": media_type,
                "data": data,
            }
        }),
        ContentBlock::AudioPath { path, .. } => serde_json::json!({
            "fileData": {
                "mimeType": "audio/wav",
                "fileUri": format!("file://{path}"),
            }
        }),
        ContentBlock::AudioUrl { url, .. } => serde_json::json!({
            "fileData": {
                "mimeType": "audio/wav",
                "fileUri": url,
            }
        }),
        ContentBlock::AudioBase64 {
            data, media_type, ..
        } => serde_json::json!({
            "inlineData": {
                "mimeType": media_type,
                "data": data,
            }
        }),
    }
}
1179
/// Guesses an image MIME type from a URL's file extension.
///
/// The query string and fragment are stripped first so signed URLs like
/// `…/photo.png?sig=abc` still match. Unknown or missing extensions fall
/// back to `image/jpeg`.
fn infer_mime_type_from_url(url: &str) -> &'static str {
    let lower = url.to_ascii_lowercase();
    // Keep only the path portion: extension checks must ignore `?query#frag`.
    let path = lower
        .split(|c| c == '?' || c == '#')
        .next()
        .unwrap_or(lower.as_str());
    if path.ends_with(".png") {
        "image/png"
    } else if path.ends_with(".webp") {
        "image/webp"
    } else if path.ends_with(".heic") {
        "image/heic"
    } else if path.ends_with(".heif") {
        "image/heif"
    } else {
        // Default for .jpg/.jpeg and anything unrecognized.
        "image/jpeg"
    }
}
1194
/// Returns the wire-protocol handler for `protocol`.
///
/// Note: `OpenAiResponses` is currently served by the Chat Completions
/// handler as well.
pub fn handler_for(protocol: crate::schema::ApiProtocol) -> Box<dyn ProtocolHandler> {
    match protocol {
        crate::schema::ApiProtocol::OpenAiCompat | crate::schema::ApiProtocol::OpenAiResponses => {
            Box::new(OpenAiHandler)
        }
        crate::schema::ApiProtocol::Anthropic => Box::new(AnthropicHandler),
        crate::schema::ApiProtocol::Google => Box::new(GoogleHandler),
        crate::schema::ApiProtocol::AzureOpenAi => Box::new(AzureOpenAiHandler),
    }
}
1210
/// Handler for OpenAI models hosted on Azure (different routing and auth,
/// same Chat Completions body format).
pub struct AzureOpenAiHandler;
1216
/// Azure deployments speak the Chat Completions wire format, so everything
/// except the path and auth headers delegates to `OpenAiHandler`.
impl ProtocolHandler for AzureOpenAiHandler {
    /// Base path only; NOTE(review): the deployment name and `api-version`
    /// query parameter are not appended here — confirm the caller adds them.
    fn endpoint_path(&self) -> &str {
        "/openai/deployments"
    }

    /// Azure uses an `api-key` header instead of a bearer token.
    fn auth_headers(&self, api_key: &str) -> Vec<(String, String)> {
        vec![
            ("api-key".into(), api_key.to_string()),
            ("Content-Type".into(), "application/json".into()),
        ]
    }

    // The remaining methods reuse the OpenAI implementations verbatim.
    fn build_request_body(&self, req: &ApiRequest) -> Value {
        OpenAiHandler.build_request_body(req)
    }

    fn parse_response(&self, body: &str) -> Result<ApiResponse, crate::InferenceError> {
        OpenAiHandler.parse_response(body)
    }

    fn parse_stream_event(&self, event_type: &str, data: &str) -> Vec<crate::stream::StreamEvent> {
        OpenAiHandler.parse_stream_event(event_type, data)
    }

    fn build_messages(
        &self,
        messages: &[Message],
        prompt: &str,
        context: Option<&str>,
        images: Option<&[ContentBlock]>,
    ) -> (Vec<Value>, Option<String>) {
        OpenAiHandler.build_messages(messages, prompt, context, images)
    }

    fn build_tools(&self, tools: &[Value]) -> Vec<Value> {
        OpenAiHandler.build_tools(tools)
    }

    /// Identifier reported for logging/diagnostics.
    fn protocol_name(&self) -> &'static str {
        "azure-openai"
    }
}
1261
/// Builds the full Gemini `generateContent` URL.
///
/// The model is part of the path and the API key rides in the `key` query
/// parameter; a trailing slash on `endpoint` is tolerated.
pub fn google_url(endpoint: &str, model: &str, api_key: &str) -> String {
    let trimmed = endpoint.trim_end_matches('/');
    format!("{trimmed}/v1beta/models/{model}:generateContent?key={api_key}")
}
1274
1275#[cfg(test)]
1276mod tests {
1277 use super::*;
1278
    // Single-turn path: `context` becomes a leading system message and the
    // out-of-band system slot stays None (OpenAI carries system inline).
    #[test]
    fn openai_single_turn_messages() {
        let handler = OpenAiHandler;
        let (msgs, system) = handler.build_messages(&[], "Hello", Some("Be helpful"), None);
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0]["role"], "system");
        assert_eq!(msgs[1]["content"], "Hello");
        assert!(system.is_none()); }
1288
    // Multi-turn histories pass through one-to-one, preserving order.
    #[test]
    fn openai_multi_turn_messages() {
        let handler = OpenAiHandler;
        let messages = vec![
            Message::User {
                content: "Hi".into(),
            },
            Message::Assistant {
                content: "Hello!".into(),
                tool_calls: vec![],
            },
            Message::User {
                content: "Search for X".into(),
            },
        ];
        let (msgs, _) = handler.build_messages(&messages, "", None, None);
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[2]["content"], "Search for X");
    }
1308
    // Assistant tool calls serialize into `tool_calls` and tool results use
    // the `tool` role with a `tool_call_id`.
    #[test]
    fn openai_tool_call_messages() {
        let handler = OpenAiHandler;
        let tc = ToolCall {
            id: None,
            name: "search".into(),
            arguments: [("q".into(), Value::String("rust".into()))].into(),
        };
        let messages = vec![
            Message::User {
                content: "Search".into(),
            },
            Message::Assistant {
                content: String::new(),
                tool_calls: vec![tc],
            },
            Message::ToolResult {
                tool_use_id: "call_0".into(),
                content: "found it".into(),
            },
        ];
        let (msgs, _) = handler.build_messages(&messages, "", None, None);
        assert_eq!(msgs.len(), 3);
        assert!(msgs[1].get("tool_calls").is_some());
        assert_eq!(msgs[2]["role"], "tool");
    }
1335
    // Anthropic returns the system prompt out-of-band, not as a message.
    #[test]
    fn anthropic_system_separate() {
        let handler = AnthropicHandler;
        let (msgs, system) = handler.build_messages(&[], "Hello", Some("Be helpful"), None);
        assert_eq!(msgs.len(), 1); assert_eq!(system, Some("Be helpful".into()));
    }
1343
    // Tool calls become assistant `tool_use` blocks; results come back as
    // user-role `tool_result` blocks.
    #[test]
    fn anthropic_tool_use_format() {
        let handler = AnthropicHandler;
        let tc = ToolCall {
            id: None,
            name: "search".into(),
            arguments: [("q".into(), Value::String("test".into()))].into(),
        };
        let messages = vec![
            Message::User {
                content: "Search".into(),
            },
            Message::Assistant {
                content: String::new(),
                tool_calls: vec![tc],
            },
            Message::ToolResult {
                tool_use_id: "toolu_0".into(),
                content: "result".into(),
            },
        ];
        let (msgs, _) = handler.build_messages(&messages, "", None, None);
        assert_eq!(msgs.len(), 3);
        let assistant_content = msgs[1].get("content").unwrap().as_array().unwrap();
        assert_eq!(assistant_content[0]["type"], "tool_use");
        let user_content = msgs[2].get("content").unwrap().as_array().unwrap();
        assert_eq!(user_content[0]["type"], "tool_result");
    }
1372
1373 #[test]
1374 fn anthropic_thinking_in_request() {
1375 let handler = AnthropicHandler;
1376 let req = ApiRequest {
1377 model: "claude".into(),
1378 messages: vec![serde_json::json!({"role": "user", "content": "plan"})],
1379 system: None,
1380 temperature: 0.7,
1381 max_tokens: 4096,
1382 tools: None,
1383 tool_choice: None,
1384 parallel_tool_calls: None,
1385 stream: false,
1386 budget_tokens: 8000,
1387 cache_control: false,
1388 response_format: None,
1389 };
1390 let body = handler.build_request_body(&req);
1391 assert!(body.get("thinking").is_some());
1392 assert_eq!(body["thinking"]["budget_tokens"], 8000);
1393 assert!(body.get("temperature").is_none()); }
1395
1396 fn empty_request(model: &str) -> ApiRequest {
1397 ApiRequest {
1398 model: model.into(),
1399 messages: vec![serde_json::json!({"role": "user", "content": "hi"})],
1400 system: None,
1401 temperature: 0.7,
1402 max_tokens: 256,
1403 tools: None,
1404 tool_choice: None,
1405 parallel_tool_calls: None,
1406 stream: false,
1407 budget_tokens: 0,
1408 cache_control: false,
1409 response_format: None,
1410 }
1411 }
1412
1413 #[test]
1414 fn openai_streaming_request_carries_stream_options_include_usage() {
1415 let handler = OpenAiHandler;
1422 let mut req = empty_request("gpt-5");
1423 req.stream = true;
1424 let body = handler.build_request_body(&req);
1425 assert_eq!(body["stream"], serde_json::json!(true));
1426 assert_eq!(
1427 body["stream_options"]["include_usage"],
1428 serde_json::json!(true),
1429 "streaming bodies must include `stream_options.include_usage` so usage flows back"
1430 );
1431 }
1432
1433 #[test]
1434 fn openai_non_streaming_request_omits_stream_options() {
1435 let handler = OpenAiHandler;
1439 let req = empty_request("gpt-5");
1440 let body = handler.build_request_body(&req);
1441 assert!(body.get("stream").is_none());
1442 assert!(body.get("stream_options").is_none());
1443 }
1444
1445 #[test]
1446 fn anthropic_streaming_request_marks_stream_true() {
1447 let handler = AnthropicHandler;
1452 let mut req = empty_request("claude-opus-4-7");
1453 req.stream = true;
1454 let body = handler.build_request_body(&req);
1455 assert_eq!(body["stream"], serde_json::json!(true));
1456 assert!(
1457 body.get("stream_options").is_none(),
1458 "Anthropic has no stream_options field; usage flows via SSE frames"
1459 );
1460 }
1461
1462 #[test]
1463 fn openai_emits_strict_json_schema_response_format() {
1464 let mut req = empty_request("gpt-5");
1465 req.response_format = Some(ResponseFormat::JsonSchema {
1466 schema: serde_json::json!({
1467 "type": "object",
1468 "properties": {"answer": {"type": "string"}},
1469 "required": ["answer"]
1470 }),
1471 strict: true,
1472 name: Some("answer_schema".into()),
1473 });
1474 let body = OpenAiHandler.build_request_body(&req);
1475 let rf = body.get("response_format").expect("response_format set");
1476 assert_eq!(rf["type"], "json_schema");
1477 assert_eq!(rf["json_schema"]["name"], "answer_schema");
1478 assert_eq!(rf["json_schema"]["strict"], true);
1479 assert_eq!(rf["json_schema"]["schema"]["required"][0], "answer");
1480 }
1481
1482 #[test]
1483 fn openai_emits_json_object_when_no_schema() {
1484 let mut req = empty_request("gpt-4o");
1485 req.response_format = Some(ResponseFormat::JsonObject);
1486 let body = OpenAiHandler.build_request_body(&req);
1487 assert_eq!(body["response_format"]["type"], "json_object");
1488 assert!(body["response_format"].get("json_schema").is_none());
1490 }
1491
1492 #[test]
1493 fn openai_omits_response_format_when_none() {
1494 let req = empty_request("gpt-4o");
1495 let body = OpenAiHandler.build_request_body(&req);
1496 assert!(body.get("response_format").is_none());
1497 }
1498
1499 #[test]
1500 fn google_emits_response_mime_and_schema() {
1501 let mut req = empty_request("gemini-2.5-pro");
1502 req.response_format = Some(ResponseFormat::JsonSchema {
1503 schema: serde_json::json!({"type": "object"}),
1504 strict: false,
1505 name: None,
1506 });
1507 let body = GoogleHandler.build_request_body(&req);
1508 let cfg = body.get("generationConfig").expect("generationConfig");
1509 assert_eq!(cfg["responseMimeType"], "application/json");
1510 assert_eq!(cfg["responseSchema"]["type"], "object");
1511 }
1512
1513 #[test]
1514 fn google_json_object_skips_schema() {
1515 let mut req = empty_request("gemini-2.5-pro");
1516 req.response_format = Some(ResponseFormat::JsonObject);
1517 let body = GoogleHandler.build_request_body(&req);
1518 let cfg = body.get("generationConfig").expect("generationConfig");
1519 assert_eq!(cfg["responseMimeType"], "application/json");
1520 assert!(cfg.get("responseSchema").is_none());
1521 }
1522
1523 #[test]
1524 fn anthropic_does_not_emit_response_format_field() {
1525 let mut req = empty_request("claude-opus-4-7");
1529 req.response_format = Some(ResponseFormat::JsonSchema {
1530 schema: serde_json::json!({"type": "object"}),
1531 strict: true,
1532 name: None,
1533 });
1534 let body = AnthropicHandler.build_request_body(&req);
1535 assert!(body.get("response_format").is_none());
1536 assert!(body.get("responseSchema").is_none());
1537 }
1538
1539 #[test]
1540 fn openai_tools_wrapped() {
1541 let handler = OpenAiHandler;
1542 let tools = vec![serde_json::json!({"name": "search", "parameters": {}})];
1543 let built = handler.build_tools(&tools);
1544 assert_eq!(built[0]["type"], "function");
1545 assert!(built[0].get("function").is_some());
1546 }
1547
1548 #[test]
1549 fn openai_request_preserves_required_tool_choice_and_parallel_tool_calls() {
1550 let handler = OpenAiHandler;
1551 let req = ApiRequest {
1552 model: "gpt-5.4-mini".into(),
1553 messages: vec![serde_json::json!({"role": "user", "content": "extract"})],
1554 system: None,
1555 temperature: 0.0,
1556 max_tokens: 1024,
1557 tools: Some(vec![serde_json::json!({
1558 "type": "function",
1559 "function": {
1560 "name": "extract_action_items",
1561 "parameters": {"type": "object", "additionalProperties": false}
1562 }
1563 })]),
1564 tool_choice: Some("required".into()),
1565 parallel_tool_calls: Some(false),
1566 stream: false,
1567 budget_tokens: 0,
1568 cache_control: false,
1569 response_format: None,
1570 };
1571
1572 let body = handler.build_request_body(&req);
1573
1574 assert_eq!(body["tool_choice"], "required");
1575 assert_eq!(body["parallel_tool_calls"], false);
1576 assert_eq!(body["tools"][0]["function"]["name"], "extract_action_items");
1577 }
1578
1579 #[test]
1580 fn anthropic_tools_format() {
1581 let handler = AnthropicHandler;
1582 let tools = vec![
1583 serde_json::json!({"function": {"name": "search", "description": "Search", "parameters": {"type": "object"}}}),
1584 ];
1585 let built = handler.build_tools(&tools);
1586 assert_eq!(built[0]["name"], "search");
1587 assert!(built[0].get("input_schema").is_some());
1588 }
1589
1590 #[test]
1591 fn google_no_streaming() {
1592 let handler = GoogleHandler;
1593 assert!(!handler.supports_streaming());
1594 }
1595
1596 #[test]
1597 fn anthropic_supports_thinking() {
1598 let handler = AnthropicHandler;
1599 assert!(handler.supports_thinking());
1600 }
1601
1602 #[test]
1603 fn handler_factory() {
1604 use crate::schema::ApiProtocol;
1605 let h = handler_for(ApiProtocol::Anthropic);
1606 assert!(h.supports_thinking());
1607
1608 let h = handler_for(ApiProtocol::OpenAiCompat);
1609 assert!(!h.supports_thinking());
1610 }
1611
1612 #[test]
1613 fn openai_parse_text_response() {
1614 let handler = OpenAiHandler;
1615 let body = r#"{"choices":[{"message":{"content":"Hello world"}}]}"#;
1616 let resp = handler.parse_response(body).unwrap();
1617 assert_eq!(resp.text, "Hello world");
1618 assert!(resp.tool_calls.is_empty());
1619 }
1620
1621 #[test]
1622 fn openai_parse_usage() {
1623 let handler = OpenAiHandler;
1624 let body = r#"{"choices":[{"message":{"content":"Hi"}}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}"#;
1625 let resp = handler.parse_response(body).unwrap();
1626 let usage = resp.usage.unwrap();
1627 assert_eq!(usage.prompt_tokens, 10);
1628 assert_eq!(usage.completion_tokens, 5);
1629 assert_eq!(usage.total_tokens, 15);
1630 }
1631
1632 #[test]
1633 fn openai_parse_multiple_tool_calls() {
1634 let handler = OpenAiHandler;
1635 let body = r#"{"choices":[{"message":{"content":"","tool_calls":[{"function":{"name":"read_file","arguments":"{\"path\":\"a.rs\"}"}},{"function":{"name":"read_file","arguments":"{\"path\":\"b.rs\"}"}}]}}]}"#;
1636 let resp = handler.parse_response(body).unwrap();
1637 assert_eq!(resp.tool_calls.len(), 2);
1638 assert_eq!(resp.tool_calls[0].name, "read_file");
1639 assert_eq!(resp.tool_calls[1].name, "read_file");
1640 }
1641
1642 #[test]
1643 fn anthropic_parse_tool_response() {
1644 let handler = AnthropicHandler;
1645 let body = r#"{"content":[{"type":"text","text":"Let me search"},{"type":"tool_use","name":"search","id":"t1","input":{"q":"rust"}}]}"#;
1646 let resp = handler.parse_response(body).unwrap();
1647 assert_eq!(resp.text, "Let me search");
1648 assert_eq!(resp.tool_calls.len(), 1);
1649 assert_eq!(resp.tool_calls[0].name, "search");
1650 }
1651
1652 #[test]
1653 fn anthropic_parse_usage() {
1654 let handler = AnthropicHandler;
1655 let body = r#"{"content":[{"type":"text","text":"Hi"}],"usage":{"input_tokens":12,"output_tokens":3}}"#;
1656 let resp = handler.parse_response(body).unwrap();
1657 let usage = resp.usage.unwrap();
1658 assert_eq!(usage.prompt_tokens, 12);
1659 assert_eq!(usage.completion_tokens, 3);
1660 assert_eq!(usage.total_tokens, 15);
1661 }
1662
1663 #[test]
1664 fn anthropic_parse_multiple_tool_calls() {
1665 let handler = AnthropicHandler;
1666 let body = r#"{"content":[{"type":"text","text":"I'll read both files"},{"type":"tool_use","name":"read","id":"t1","input":{"path":"a.rs"}},{"type":"tool_use","name":"read","id":"t2","input":{"path":"b.rs"}}]}"#;
1667 let resp = handler.parse_response(body).unwrap();
1668 assert_eq!(resp.text, "I'll read both files");
1669 assert_eq!(resp.tool_calls.len(), 2);
1670 assert_eq!(resp.tool_calls[0].name, "read");
1671 assert_eq!(resp.tool_calls[1].name, "read");
1672 }
1673
1674 #[test]
1675 fn anthropic_cache_control_system_prompt() {
1676 let handler = AnthropicHandler;
1677 let req = ApiRequest {
1678 model: "claude".into(),
1679 messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
1680 system: Some("You are helpful.".into()),
1681 temperature: 0.7,
1682 max_tokens: 1024,
1683 tools: None,
1684 tool_choice: None,
1685 parallel_tool_calls: None,
1686 stream: false,
1687 budget_tokens: 0,
1688 cache_control: true,
1689 response_format: None,
1690 };
1691 let body = handler.build_request_body(&req);
1692 let system = body.get("system").unwrap();
1694 assert!(system.is_array());
1695 let blocks = system.as_array().unwrap();
1696 assert_eq!(blocks.len(), 1);
1697 assert_eq!(blocks[0]["type"], "text");
1698 assert_eq!(blocks[0]["text"], "You are helpful.");
1699 assert_eq!(blocks[0]["cache_control"]["type"], "ephemeral");
1700 }
1701
1702 #[test]
1703 fn anthropic_cache_control_disabled() {
1704 let handler = AnthropicHandler;
1705 let req = ApiRequest {
1706 model: "claude".into(),
1707 messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
1708 system: Some("You are helpful.".into()),
1709 temperature: 0.7,
1710 max_tokens: 1024,
1711 tools: None,
1712 tool_choice: None,
1713 parallel_tool_calls: None,
1714 stream: false,
1715 budget_tokens: 0,
1716 cache_control: false,
1717 response_format: None,
1718 };
1719 let body = handler.build_request_body(&req);
1720 assert!(body.get("system").unwrap().is_string());
1722 }
1723
1724 #[test]
1725 fn anthropic_cache_control_tools() {
1726 let handler = AnthropicHandler;
1727 let tools = vec![
1728 serde_json::json!({"name": "search", "description": "Search", "input_schema": {"type": "object"}}),
1729 serde_json::json!({"name": "read", "description": "Read file", "input_schema": {"type": "object"}}),
1730 ];
1731 let req = ApiRequest {
1732 model: "claude".into(),
1733 messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
1734 system: None,
1735 temperature: 0.7,
1736 max_tokens: 1024,
1737 tools: Some(tools),
1738 tool_choice: None,
1739 parallel_tool_calls: None,
1740 stream: false,
1741 budget_tokens: 0,
1742 cache_control: true,
1743 response_format: None,
1744 };
1745 let body = handler.build_request_body(&req);
1746 let tools_arr = body["tools"].as_array().unwrap();
1747 assert!(tools_arr[0].get("cache_control").is_none());
1749 assert_eq!(tools_arr[1]["cache_control"]["type"], "ephemeral");
1750 }
1751
1752 #[test]
1753 fn anthropic_beta_header_included() {
1754 let handler = AnthropicHandler;
1755 let headers = handler.auth_headers("test-key");
1756 let beta = headers.iter().find(|(k, _)| k == "anthropic-beta");
1757 assert!(beta.is_some());
1758 assert_eq!(beta.unwrap().1, "prompt-caching-2024-07-31");
1759 }
1760
1761 #[test]
1762 fn google_parse_response() {
1763 let handler = GoogleHandler;
1764 let body = r#"{"candidates":[{"content":{"parts":[{"text":"Hello from Gemini"}]}}]}"#;
1765 let resp = handler.parse_response(body).unwrap();
1766 assert_eq!(resp.text, "Hello from Gemini");
1767 }
1768
1769 #[test]
1770 fn google_tools_format() {
1771 let handler = GoogleHandler;
1772 let tools = vec![serde_json::json!({
1773 "function": {
1774 "name": "search",
1775 "description": "Search docs",
1776 "parameters": {"type": "object"}
1777 }
1778 })];
1779 let built = handler.build_tools(&tools);
1780 assert_eq!(built[0]["name"], "search");
1781 assert!(built[0].get("parameters").is_some());
1782 }
1783
1784 #[test]
1785 fn google_builds_multimodal_messages() {
1786 let handler = GoogleHandler;
1787 let messages = vec![Message::UserMultimodal {
1788 content: vec![
1789 ContentBlock::Text {
1790 text: "Describe this image.".to_string(),
1791 },
1792 ContentBlock::ImageUrl {
1793 url: "https://example.com/cat.jpg".to_string(),
1794 detail: "auto".to_string(),
1795 },
1796 ],
1797 }];
1798 let (msgs, system) = handler.build_messages(&messages, "", Some("Be concise"), None);
1799 assert_eq!(msgs.len(), 1);
1800 assert_eq!(msgs[0]["role"], "user");
1801 let parts = msgs[0]["parts"].as_array().unwrap();
1802 assert_eq!(parts[0]["text"], "Describe this image.");
1803 assert!(parts[1].get("fileData").is_some());
1804 assert_eq!(system, Some("Be concise".to_string()));
1805 }
1806
1807 #[test]
1808 fn google_request_body_includes_tools_and_system() {
1809 let handler = GoogleHandler;
1810 let req = ApiRequest {
1811 model: "gemini-2.5-flash".into(),
1812 messages: vec![serde_json::json!({
1813 "role": "user",
1814 "parts": [{"text": "Find the file and summarize it."}],
1815 })],
1816 system: Some("Use tools when needed.".into()),
1817 temperature: 0.2,
1818 max_tokens: 512,
1819 tools: Some(vec![serde_json::json!({
1820 "name": "search",
1821 "description": "Search files",
1822 "parameters": {"type": "object"}
1823 })]),
1824 tool_choice: None,
1825 parallel_tool_calls: None,
1826 stream: false,
1827 budget_tokens: 0,
1828 cache_control: false,
1829 response_format: None,
1830 };
1831 let body = handler.build_request_body(&req);
1832 assert!(body.get("systemInstruction").is_some());
1833 assert!(body.get("tools").is_some());
1834 assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");
1835 assert_eq!(body["generationConfig"]["maxOutputTokens"], 512);
1836 }
1837
1838 #[test]
1839 fn google_parse_multiple_tool_calls_and_usage() {
1840 let handler = GoogleHandler;
1841 let body = r#"{
1842 "candidates":[{"content":{"parts":[
1843 {"text":"Let me do that."},
1844 {"functionCall":{"name":"search","args":{"q":"rust"}}},
1845 {"functionCall":{"name":"read_file","args":{"path":"src/lib.rs"}}}
1846 ]}}],
1847 "usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":4,"totalTokenCount":14}
1848 }"#;
1849 let resp = handler.parse_response(body).unwrap();
1850 assert_eq!(resp.text, "Let me do that.");
1851 assert_eq!(resp.tool_calls.len(), 2);
1852 assert_eq!(resp.tool_calls[0].name, "search");
1853 assert_eq!(resp.tool_calls[1].name, "read_file");
1854 let usage = resp.usage.unwrap();
1855 assert_eq!(usage.prompt_tokens, 10);
1856 assert_eq!(usage.completion_tokens, 4);
1857 assert_eq!(usage.total_tokens, 14);
1858 }
1859
1860 #[test]
1861 fn openai_vision_message() {
1862 let handler = OpenAiHandler;
1863 let messages = vec![Message::UserMultimodal {
1864 content: vec![
1865 ContentBlock::Text {
1866 text: "What is in this image?".to_string(),
1867 },
1868 ContentBlock::ImageUrl {
1869 url: "https://example.com/cat.jpg".to_string(),
1870 detail: "auto".to_string(),
1871 },
1872 ],
1873 }];
1874 let (msgs, _) = handler.build_messages(&messages, "", None, None);
1875 assert_eq!(msgs.len(), 1);
1876 let content = msgs[0]["content"].as_array().unwrap();
1877 assert_eq!(content.len(), 2);
1878 assert_eq!(content[0]["type"], "text");
1879 assert_eq!(content[1]["type"], "image_url");
1880 }
1881
1882 #[test]
1883 fn anthropic_vision_message() {
1884 let handler = AnthropicHandler;
1885 let messages = vec![Message::UserMultimodal {
1886 content: vec![
1887 ContentBlock::Text {
1888 text: "Describe this.".to_string(),
1889 },
1890 ContentBlock::ImageBase64 {
1891 data: "iVBOR...".to_string(),
1892 media_type: "image/png".to_string(),
1893 },
1894 ],
1895 }];
1896 let (msgs, _) = handler.build_messages(&messages, "", None, None);
1897 assert_eq!(msgs.len(), 1);
1898 let content = msgs[0]["content"].as_array().unwrap();
1899 assert_eq!(content[0]["type"], "text");
1900 assert_eq!(content[1]["type"], "image");
1901 assert_eq!(content[1]["source"]["type"], "base64");
1902 }
1903
1904 #[test]
1905 fn openai_single_turn_images() {
1906 let handler = OpenAiHandler;
1907 let images = vec![ContentBlock::ImageUrl {
1908 url: "https://example.com/cat.jpg".to_string(),
1909 detail: "high".to_string(),
1910 }];
1911 let (msgs, _) = handler.build_messages(&[], "Describe this image", None, Some(&images));
1912 let content = msgs[0]["content"].as_array().unwrap();
1913 assert_eq!(content.len(), 2);
1914 assert_eq!(content[0]["type"], "text");
1915 assert_eq!(content[1]["type"], "image_url");
1916 }
1917
1918 #[test]
1919 fn anthropic_single_turn_images() {
1920 let handler = AnthropicHandler;
1921 let images = vec![ContentBlock::ImageBase64 {
1922 data: "iVBOR...".to_string(),
1923 media_type: "image/png".to_string(),
1924 }];
1925 let (msgs, _) = handler.build_messages(&[], "Describe this image", None, Some(&images));
1926 let content = msgs[0]["content"].as_array().unwrap();
1927 assert_eq!(content.len(), 2);
1928 assert_eq!(content[0]["type"], "text");
1929 assert_eq!(content[1]["type"], "image");
1930 assert_eq!(content[1]["source"]["type"], "base64");
1931 }
1932}