1use crate::tools::ToolSpec;
2use async_trait::async_trait;
3use futures_util::{StreamExt, stream};
4use serde::{Deserialize, Serialize};
5use std::fmt::Write;
6
/// One message in a chat conversation.
///
/// `role` is a free-form string. The associated constructors (`ChatMessage::system`,
/// `ChatMessage::user`, `ChatMessage::assistant`, `ChatMessage::tool`) produce the roles
/// `"system"`, `"user"`, `"assistant"` and `"tool"`. The default `Provider` history
/// methods match on `"system"` and `"user"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Speaker role, e.g. `"system"` or `"user"`.
    pub role: String,
    /// Message text.
    pub content: String,
}
13
14impl ChatMessage {
15 pub fn system(content: impl Into<String>) -> Self {
16 Self {
17 role: "system".into(),
18 content: content.into(),
19 }
20 }
21
22 pub fn user(content: impl Into<String>) -> Self {
23 Self {
24 role: "user".into(),
25 content: content.into(),
26 }
27 }
28
29 pub fn assistant(content: impl Into<String>) -> Self {
30 Self {
31 role: "assistant".into(),
32 content: content.into(),
33 }
34 }
35
36 pub fn tool(content: impl Into<String>) -> Self {
37 Self {
38 role: "tool".into(),
39 content: content.into(),
40 }
41 }
42}
43
/// A tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call. Presumably echoed back as `ToolResultMessage::tool_call_id`;
    /// confirm against provider implementations.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Call arguments as raw text. It is JSON in this file's tests (e.g. `{"path":"test.txt"}`)
    /// and is not parsed here.
    pub arguments: String,
}
51
/// Token accounting for a request.
///
/// Every field is optional and `Default` leaves all of them `None`, presumably because
/// providers report different subsets of these numbers.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Tokens consumed by the input/prompt, when reported.
    pub input_tokens: Option<u64>,
    /// Tokens produced in the response, when reported.
    pub output_tokens: Option<u64>,
    /// Input tokens served from a prompt cache, when reported. TODO: confirm the exact
    /// per-provider semantics.
    pub cached_input_tokens: Option<u64>,
}
61
/// Result of a non-streaming chat call.
#[derive(Debug, Clone)]
pub struct ChatResponse {
    /// Reply text, if any. The default `Provider` chat paths always set this to `Some`.
    pub text: Option<String>,
    /// Tool calls requested by the model; empty when none were requested.
    pub tool_calls: Vec<ToolCall>,
    /// Token usage, when the provider reports it.
    pub usage: Option<TokenUsage>,
    /// Separate reasoning text, when the provider returns it.
    pub reasoning_content: Option<String>,
}
77
78impl ChatResponse {
79 pub fn has_tool_calls(&self) -> bool {
81 !self.tool_calls.is_empty()
82 }
83
84 pub fn text_or_empty(&self) -> &str {
86 self.text.as_deref().unwrap_or("")
87 }
88}
89
/// Borrowed inputs for `Provider::chat` and `Provider::stream_chat`.
#[derive(Debug, Clone, Copy)]
pub struct ChatRequest<'a> {
    /// The conversation so far, in order.
    pub messages: &'a [ChatMessage],
    /// Tools the model may call. `None` or an empty slice means no tools.
    pub tools: Option<&'a [ToolSpec]>,
}
96
/// Output of one executed tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMessage {
    /// Presumably the `ToolCall::id` this result answers; verify against callers.
    pub tool_call_id: String,
    /// Tool output text.
    pub content: String,
}
103
/// An entry in a stored conversation transcript.
///
/// Serialized adjacently tagged: `{"type": "<VariantName>", "data": <payload>}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum ConversationMessage {
    /// A plain chat message.
    Chat(ChatMessage),
    /// An assistant turn that requested tool calls.
    AssistantToolCalls {
        /// Text accompanying the tool calls, if any.
        text: Option<String>,
        /// The requested tool calls.
        tool_calls: Vec<ToolCall>,
        /// Reasoning text returned with this turn, if any.
        reasoning_content: Option<String>,
    },
    /// Results of executed tool calls. Presumably these answer the preceding
    /// `AssistantToolCalls` turn; confirm the ordering invariants with callers.
    ToolResults(Vec<ToolResultMessage>),
}
121
/// One piece of a legacy, chunk-based streaming response.
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Response text in this chunk. For chunks built with `StreamChunk::error`, this is the
    /// error message.
    pub delta: String,
    /// Reasoning text in this chunk, if any.
    pub reasoning: Option<String>,
    /// Marks the last chunk of the stream (also set on error chunks).
    pub is_final: bool,
    /// Estimated token count. It stays 0 unless `StreamChunk::with_token_estimate` was applied.
    pub token_count: usize,
}
134
135impl StreamChunk {
136 pub fn delta(text: impl Into<String>) -> Self {
138 Self {
139 delta: text.into(),
140 reasoning: None,
141 is_final: false,
142 token_count: 0,
143 }
144 }
145
146 pub fn reasoning(text: impl Into<String>) -> Self {
148 Self {
149 delta: String::new(),
150 reasoning: Some(text.into()),
151 is_final: false,
152 token_count: 0,
153 }
154 }
155
156 pub fn final_chunk() -> Self {
158 Self {
159 delta: String::new(),
160 reasoning: None,
161 is_final: true,
162 token_count: 0,
163 }
164 }
165
166 pub fn error(message: impl Into<String>) -> Self {
168 Self {
169 delta: message.into(),
170 reasoning: None,
171 is_final: true,
172 token_count: 0,
173 }
174 }
175
176 pub fn with_token_estimate(mut self) -> Self {
178 self.token_count = self.delta.len().div_ceil(4);
179 self
180 }
181}
182
/// One event produced by `Provider::stream_chat`.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// Incremental response and/or reasoning text.
    TextDelta(StreamChunk),
    /// A tool call requested by the model.
    ToolCall(ToolCall),
    /// A tool call that was executed before reaching the caller, presumably by the provider
    /// itself. Confirm with the providers that emit it.
    PreExecutedToolCall { name: String, args: String },
    /// The output of such a pre-executed tool call.
    PreExecutedToolResult { name: String, output: String },
    /// Token usage reported during the stream.
    Usage(TokenUsage),
    /// The stream has finished.
    Final,
}
205
206impl StreamEvent {
207 pub(crate) fn from_chunk(chunk: StreamChunk) -> Self {
208 if chunk.is_final {
209 Self::Final
210 } else {
211 Self::TextDelta(chunk)
212 }
213 }
214}
215
/// Options controlling a streaming request. `Default` leaves both flags `false`.
#[derive(Debug, Clone, Copy, Default)]
pub struct StreamOptions {
    /// Whether streaming is requested.
    pub enabled: bool,
    /// Whether token counts are requested. Presumably providers honour this via
    /// `StreamChunk::with_token_estimate`; confirm in provider implementations.
    pub count_tokens: bool,
}
224
225impl StreamOptions {
226 pub fn new(enabled: bool) -> Self {
228 Self {
229 enabled,
230 count_tokens: false,
231 }
232 }
233
234 pub fn with_token_count(mut self) -> Self {
236 self.count_tokens = true;
237 self
238 }
239}
240
/// Result type for streaming operations, using [`StreamError`] as the error type.
pub type StreamResult<T> = std::result::Result<T, StreamError>;
243
244#[derive(Debug, thiserror::Error)]
246pub enum StreamError {
247 #[error("HTTP error: {0}")]
248 Http(reqwest::Error),
249
250 #[error("JSON parse error: {0}")]
251 Json(serde_json::Error),
252
253 #[error("Invalid SSE format: {0}")]
254 InvalidSse(String),
255
256 #[error("Provider error: {0}")]
257 Provider(String),
258
259 #[error("IO error: {0}")]
260 Io(#[from] std::io::Error),
261}
262
/// Error describing a capability problem for a named provider.
///
/// It is not constructed in this file; presumably it is raised when a provider is asked for
/// a feature it lacks. Displays as
/// `provider_capability_error provider=<provider> capability=<capability> message=<message>`.
#[derive(Debug, Clone, thiserror::Error)]
#[error("provider_capability_error provider={provider} capability={capability} message={message}")]
pub struct ProviderCapabilityError {
    /// Provider name.
    pub provider: String,
    /// The capability in question.
    pub capability: String,
    /// Human-readable detail.
    pub message: String,
}
271
/// Feature flags a provider advertises via `Provider::capabilities`.
///
/// `Default` sets every flag to `false`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProviderCapabilities {
    /// Native tool/function calling. The default `Provider::supports_native_tools` returns this.
    pub native_tool_calling: bool,
    /// Vision (image input) support. The default `Provider::supports_vision` returns this.
    pub vision: bool,
    /// Prompt caching support. Not read anywhere in this file; TODO: confirm its consumers.
    pub prompt_caching: bool,
}
291
/// Tool definitions converted into the shape a particular provider expects.
#[derive(Debug, Clone)]
pub enum ToolsPayload {
    /// Gemini-style `function_declarations`. Presumably JSON in Gemini's schema.
    Gemini {
        function_declarations: Vec<serde_json::Value>,
    },
    /// Anthropic-style `tools`. Presumably JSON in Anthropic's schema.
    Anthropic { tools: Vec<serde_json::Value> },
    /// OpenAI-style `tools`. Presumably JSON in OpenAI's schema.
    OpenAI { tools: Vec<serde_json::Value> },
    /// Plain-text instructions injected into the system prompt for providers without native
    /// tool calling. This is what the default `Provider::convert_tools` returns, and the only
    /// variant the default `Provider::chat` accepts for such providers.
    PromptGuided { instructions: String },
}
310
311#[async_trait]
312pub trait Provider: Send + Sync {
313 fn capabilities(&self) -> ProviderCapabilities {
318 ProviderCapabilities::default()
319 }
320
321 fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
328 ToolsPayload::PromptGuided {
329 instructions: build_tool_instructions_text(tools),
330 }
331 }
332
333 async fn simple_chat(
337 &self,
338 message: &str,
339 model: &str,
340 temperature: f64,
341 ) -> anyhow::Result<String> {
342 self.chat_with_system(None, message, model, temperature)
343 .await
344 }
345
346 async fn chat_with_system(
350 &self,
351 system_prompt: Option<&str>,
352 message: &str,
353 model: &str,
354 temperature: f64,
355 ) -> anyhow::Result<String>;
356
357 async fn chat_with_history(
360 &self,
361 messages: &[ChatMessage],
362 model: &str,
363 temperature: f64,
364 ) -> anyhow::Result<String> {
365 let system = messages
366 .iter()
367 .find(|m| m.role == "system")
368 .map(|m| m.content.as_str());
369 let last_user = messages
370 .iter()
371 .rfind(|m| m.role == "user")
372 .map(|m| m.content.as_str())
373 .unwrap_or("");
374 self.chat_with_system(system, last_user, model, temperature)
375 .await
376 }
377
378 async fn chat(
380 &self,
381 request: ChatRequest<'_>,
382 model: &str,
383 temperature: f64,
384 ) -> anyhow::Result<ChatResponse> {
385 if let Some(tools) = request.tools {
388 if !tools.is_empty() && !self.supports_native_tools() {
389 let tool_instructions = match self.convert_tools(tools) {
390 ToolsPayload::PromptGuided { instructions } => instructions,
391 payload => {
392 anyhow::bail!(
393 "Provider returned non-prompt-guided tools payload ({payload:?}) while supports_native_tools() is false"
394 )
395 }
396 };
397 let mut modified_messages = request.messages.to_vec();
398
399 if let Some(system_message) =
402 modified_messages.iter_mut().find(|m| m.role == "system")
403 {
404 if !system_message.content.is_empty() {
405 system_message.content.push_str("\n\n");
406 }
407 system_message.content.push_str(&tool_instructions);
408 } else {
409 modified_messages.insert(0, ChatMessage::system(tool_instructions));
410 }
411
412 let text = self
413 .chat_with_history(&modified_messages, model, temperature)
414 .await?;
415 return Ok(ChatResponse {
416 text: Some(text),
417 tool_calls: Vec::new(),
418 usage: None,
419 reasoning_content: None,
420 });
421 }
422 }
423
424 let text = self
425 .chat_with_history(request.messages, model, temperature)
426 .await?;
427 Ok(ChatResponse {
428 text: Some(text),
429 tool_calls: Vec::new(),
430 usage: None,
431 reasoning_content: None,
432 })
433 }
434
435 fn supports_native_tools(&self) -> bool {
437 self.capabilities().native_tool_calling
438 }
439
440 fn supports_vision(&self) -> bool {
442 self.capabilities().vision
443 }
444
445 async fn warmup(&self) -> anyhow::Result<()> {
448 Ok(())
449 }
450
451 async fn chat_with_tools(
455 &self,
456 messages: &[ChatMessage],
457 _tools: &[serde_json::Value],
458 model: &str,
459 temperature: f64,
460 ) -> anyhow::Result<ChatResponse> {
461 let text = self.chat_with_history(messages, model, temperature).await?;
462 Ok(ChatResponse {
463 text: Some(text),
464 tool_calls: Vec::new(),
465 usage: None,
466 reasoning_content: None,
467 })
468 }
469
470 fn supports_streaming(&self) -> bool {
473 false
474 }
475
476 fn supports_streaming_tool_events(&self) -> bool {
481 false
482 }
483
484 fn stream_chat_with_system(
488 &self,
489 _system_prompt: Option<&str>,
490 _message: &str,
491 _model: &str,
492 _temperature: f64,
493 _options: StreamOptions,
494 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
495 stream::empty().boxed()
497 }
498
499 fn stream_chat_with_history(
503 &self,
504 messages: &[ChatMessage],
505 model: &str,
506 temperature: f64,
507 options: StreamOptions,
508 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
509 let system = messages
510 .iter()
511 .find(|m| m.role == "system")
512 .map(|m| m.content.as_str());
513 let last_user = messages
514 .iter()
515 .rfind(|m| m.role == "user")
516 .map(|m| m.content.as_str())
517 .unwrap_or("");
518 self.stream_chat_with_system(system, last_user, model, temperature, options)
519 }
520
521 fn stream_chat(
526 &self,
527 request: ChatRequest<'_>,
528 model: &str,
529 temperature: f64,
530 options: StreamOptions,
531 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
532 self.stream_chat_with_history(request.messages, model, temperature, options)
533 .map(|chunk_result| chunk_result.map(StreamEvent::from_chunk))
534 .boxed()
535 }
536}
537
538pub fn build_tool_instructions_text(tools: &[ToolSpec]) -> String {
544 let mut instructions = String::new();
545
546 instructions.push_str("## Tool Use Protocol\n\n");
547 instructions.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
548 instructions.push_str("<tool_call>\n");
549 instructions.push_str(r#"{"name": "tool_name", "arguments": {"param": "value"}}"#);
550 instructions.push_str("\n</tool_call>\n\n");
551 instructions.push_str("You may use multiple tool calls in a single response. ");
552 instructions.push_str("After tool execution, results appear in <tool_result> tags. ");
553 instructions
554 .push_str("Continue reasoning with the results until you can give a final answer.\n\n");
555 instructions.push_str("### Available Tools\n\n");
556
557 for tool in tools {
558 writeln!(&mut instructions, "**{}**: {}", tool.name, tool.description)
559 .expect("writing to String cannot fail");
560
561 let parameters =
562 serde_json::to_string(&tool.parameters).unwrap_or_else(|_| "{}".to_string());
563 writeln!(&mut instructions, "Parameters: `{parameters}`")
564 .expect("writing to String cannot fail");
565 instructions.push('\n');
566 }
567
568 instructions
569}
570
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    /// Mock that only overrides `capabilities`, to exercise the capability-derived defaults.
    struct CapabilityMockProvider;

    #[async_trait]
    impl Provider for CapabilityMockProvider {
        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities {
                native_tool_calling: true,
                vision: true,
                prompt_caching: false,
            }
        }

        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("ok".into())
        }
    }

    /// Single-tool fixture shared by the prompt-guided tests.
    fn shell_tools() -> Vec<ToolSpec> {
        vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }]
    }

    #[test]
    fn chat_message_constructors() {
        let sys = ChatMessage::system("Be helpful");
        assert_eq!(sys.role, "system");
        assert_eq!(sys.content, "Be helpful");

        let user = ChatMessage::user("Hello");
        assert_eq!(user.role, "user");
        assert_eq!(user.content, "Hello");

        let asst = ChatMessage::assistant("Hi there");
        assert_eq!(asst.role, "assistant");
        assert_eq!(asst.content, "Hi there");

        let tool = ChatMessage::tool("{}");
        assert_eq!(tool.role, "tool");
        assert_eq!(tool.content, "{}");
    }

    #[test]
    fn chat_response_helpers() {
        let empty = ChatResponse {
            text: None,
            tool_calls: vec![],
            usage: None,
            reasoning_content: None,
        };
        assert!(!empty.has_tool_calls());
        assert_eq!(empty.text_or_empty(), "");

        let with_tools = ChatResponse {
            text: Some("Let me check".into()),
            tool_calls: vec![ToolCall {
                id: "1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
            usage: None,
            reasoning_content: None,
        };
        assert!(with_tools.has_tool_calls());
        assert_eq!(with_tools.text_or_empty(), "Let me check");
    }

    #[test]
    fn token_usage_default_is_none() {
        let usage = TokenUsage::default();
        assert!(usage.input_tokens.is_none());
        assert!(usage.output_tokens.is_none());
        assert!(usage.cached_input_tokens.is_none());
    }

    #[test]
    fn chat_response_with_usage() {
        let resp = ChatResponse {
            text: Some("Hello".into()),
            tool_calls: vec![],
            usage: Some(TokenUsage {
                input_tokens: Some(100),
                output_tokens: Some(50),
                cached_input_tokens: None,
            }),
            reasoning_content: None,
        };
        let usage = resp.usage.as_ref().unwrap();
        assert_eq!(usage.input_tokens, Some(100));
        assert_eq!(usage.output_tokens, Some(50));
        assert_eq!(usage.cached_input_tokens, None);
    }

    #[test]
    fn tool_call_serialization() {
        let tc = ToolCall {
            id: "call_123".into(),
            name: "file_read".into(),
            arguments: r#"{"path":"test.txt"}"#.into(),
        };
        let json = serde_json::to_string(&tc).unwrap();
        assert!(json.contains("call_123"));
        assert!(json.contains("file_read"));

        // Round-trip must preserve every field, including the raw argument text.
        let back: ToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(back.id, tc.id);
        assert_eq!(back.name, tc.name);
        assert_eq!(back.arguments, tc.arguments);
    }

    #[test]
    fn conversation_message_variants() {
        let chat = ConversationMessage::Chat(ChatMessage::user("hi"));
        let json = serde_json::to_string(&chat).unwrap();
        assert!(json.contains("\"type\":\"Chat\""));
        assert!(json.contains("\"data\""));

        let tool_result = ConversationMessage::ToolResults(vec![ToolResultMessage {
            tool_call_id: "1".into(),
            content: "done".into(),
        }]);
        let json = serde_json::to_string(&tool_result).unwrap();
        assert!(json.contains("\"type\":\"ToolResults\""));

        let assistant = ConversationMessage::AssistantToolCalls {
            text: None,
            tool_calls: vec![ToolCall {
                id: "1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
            reasoning_content: None,
        };
        let json = serde_json::to_string(&assistant).unwrap();
        assert!(json.contains("\"type\":\"AssistantToolCalls\""));
        let back: ConversationMessage = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            &back,
            ConversationMessage::AssistantToolCalls { tool_calls, .. } if tool_calls.len() == 1
        ));
    }

    #[test]
    fn provider_capabilities_default() {
        let caps = ProviderCapabilities::default();
        assert!(!caps.native_tool_calling);
        assert!(!caps.vision);
        assert!(!caps.prompt_caching);
    }

    #[test]
    fn provider_capabilities_equality() {
        let caps1 = ProviderCapabilities {
            native_tool_calling: true,
            vision: false,
            prompt_caching: false,
        };
        let caps2 = ProviderCapabilities {
            native_tool_calling: true,
            vision: false,
            prompt_caching: false,
        };
        let caps3 = ProviderCapabilities {
            native_tool_calling: false,
            vision: false,
            prompt_caching: false,
        };

        assert_eq!(caps1, caps2);
        assert_ne!(caps1, caps3);
    }

    #[test]
    fn supports_native_tools_reflects_capabilities_default_mapping() {
        let provider = CapabilityMockProvider;
        assert!(provider.supports_native_tools());
    }

    #[test]
    fn supports_vision_reflects_capabilities_default_mapping() {
        let provider = CapabilityMockProvider;
        assert!(provider.supports_vision());
    }

    #[test]
    fn tools_payload_variants() {
        let gemini = ToolsPayload::Gemini {
            function_declarations: vec![serde_json::json!({"name": "test"})],
        };
        assert!(matches!(gemini, ToolsPayload::Gemini { .. }));

        let anthropic = ToolsPayload::Anthropic {
            tools: vec![serde_json::json!({"name": "test"})],
        };
        assert!(matches!(anthropic, ToolsPayload::Anthropic { .. }));

        let openai = ToolsPayload::OpenAI {
            tools: vec![serde_json::json!({"type": "function"})],
        };
        assert!(matches!(openai, ToolsPayload::OpenAI { .. }));

        let prompt_guided = ToolsPayload::PromptGuided {
            instructions: "Use tools...".to_string(),
        };
        assert!(matches!(prompt_guided, ToolsPayload::PromptGuided { .. }));
    }

    #[test]
    fn build_tool_instructions_text_format() {
        let tools = vec![
            ToolSpec {
                name: "shell".to_string(),
                description: "Execute commands".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "command": {"type": "string"}
                    }
                }),
            },
            ToolSpec {
                name: "file_read".to_string(),
                description: "Read files".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "path": {"type": "string"}
                    }
                }),
            },
        ];

        let instructions = build_tool_instructions_text(&tools);

        assert!(instructions.contains("Tool Use Protocol"));
        assert!(instructions.contains("<tool_call>"));
        assert!(instructions.contains("</tool_call>"));

        assert!(instructions.contains("**shell**: Execute commands\n"));
        assert!(instructions.contains("**file_read**: Read files\n"));

        assert!(instructions.contains("Parameters:"));
        assert!(instructions.contains(r#""type":"object""#));
        // Tools are listed in input order.
        assert!(instructions.find("**shell**") < instructions.find("**file_read**"));
    }

    #[test]
    fn build_tool_instructions_text_empty() {
        let instructions = build_tool_instructions_text(&[]);

        assert!(instructions.starts_with("## Tool Use Protocol\n\n"));
        assert!(instructions.ends_with("### Available Tools\n\n"));
    }

    /// Mock whose native-tool support is configurable and which always answers "response".
    struct MockProvider {
        supports_native: bool,
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn supports_native_tools(&self) -> bool {
            self.supports_native
        }

        async fn chat_with_system(
            &self,
            _system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("response".to_string())
        }
    }

    #[test]
    fn provider_convert_tools_default() {
        let provider = MockProvider {
            supports_native: false,
        };

        let tools = vec![ToolSpec {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];

        let ToolsPayload::PromptGuided { instructions } = provider.convert_tools(&tools) else {
            panic!("default convert_tools must return a prompt-guided payload");
        };
        assert!(instructions.contains("test_tool"));
        assert!(instructions.contains("A test tool"));
    }

    #[tokio::test]
    async fn provider_chat_prompt_guided_fallback() {
        let provider = MockProvider {
            supports_native: false,
        };

        let tools = shell_tools();
        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: Some(&tools),
        };

        let response = provider.chat(request, "model", 0.7).await.unwrap();

        assert_eq!(response.text.as_deref(), Some("response"));
        assert!(!response.has_tool_calls());
        assert!(response.usage.is_none());
    }

    #[tokio::test]
    async fn provider_chat_without_tools() {
        let provider = MockProvider {
            supports_native: true,
        };

        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: None,
        };

        let response = provider.chat(request, "model", 0.7).await.unwrap();

        assert_eq!(response.text.as_deref(), Some("response"));
        assert!(!response.has_tool_calls());
    }

    /// Mock that replies with whatever system prompt it received.
    struct EchoSystemProvider {
        supports_native: bool,
    }

    #[async_trait]
    impl Provider for EchoSystemProvider {
        fn supports_native_tools(&self) -> bool {
            self.supports_native
        }

        async fn chat_with_system(
            &self,
            system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok(system.unwrap_or_default().to_string())
        }
    }

    /// Mock that overrides `convert_tools` with a fixed prompt-guided payload.
    struct CustomConvertProvider;

    #[async_trait]
    impl Provider for CustomConvertProvider {
        fn supports_native_tools(&self) -> bool {
            false
        }

        fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
            ToolsPayload::PromptGuided {
                instructions: "CUSTOM_TOOL_INSTRUCTIONS".to_string(),
            }
        }

        async fn chat_with_system(
            &self,
            system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok(system.unwrap_or_default().to_string())
        }
    }

    /// Mock that (incorrectly) returns a native payload while claiming no native tools.
    struct InvalidConvertProvider;

    #[async_trait]
    impl Provider for InvalidConvertProvider {
        fn supports_native_tools(&self) -> bool {
            false
        }

        fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
            ToolsPayload::OpenAI {
                tools: vec![serde_json::json!({"type": "function"})],
            }
        }

        async fn chat_with_system(
            &self,
            _system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("should_not_reach".to_string())
        }
    }

    #[tokio::test]
    async fn chat_with_history_default_uses_first_system_prompt() {
        let provider = EchoSystemProvider {
            supports_native: true,
        };
        let messages = [
            ChatMessage::system("FIRST"),
            ChatMessage::user("u1"),
            ChatMessage::system("SECOND"),
            ChatMessage::user("u2"),
        ];

        let text = provider
            .chat_with_history(&messages, "model", 0.0)
            .await
            .unwrap();

        assert_eq!(text, "FIRST");
    }

    #[tokio::test]
    async fn provider_chat_prompt_guided_preserves_existing_system_not_first() {
        let provider = EchoSystemProvider {
            supports_native: false,
        };

        let tools = shell_tools();
        let request = ChatRequest {
            messages: &[
                ChatMessage::user("Hello"),
                ChatMessage::system("BASE_SYSTEM_PROMPT"),
            ],
            tools: Some(&tools),
        };

        let response = provider.chat(request, "model", 0.7).await.unwrap();
        let text = response.text.unwrap_or_default();

        assert!(text.starts_with("BASE_SYSTEM_PROMPT\n\n"));
        assert!(text.contains("Tool Use Protocol"));
    }

    #[tokio::test]
    async fn provider_chat_prompt_guided_inserts_system_when_missing() {
        let provider = EchoSystemProvider {
            supports_native: false,
        };

        let tools = shell_tools();
        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: Some(&tools),
        };

        let text = provider
            .chat(request, "model", 0.7)
            .await
            .unwrap()
            .text
            .unwrap_or_default();

        assert_eq!(text, build_tool_instructions_text(&tools));
    }

    #[tokio::test]
    async fn provider_chat_prompt_guided_empty_system_gets_no_separator() {
        let provider = EchoSystemProvider {
            supports_native: false,
        };

        let tools = shell_tools();
        let request = ChatRequest {
            messages: &[ChatMessage::system(""), ChatMessage::user("Hello")],
            tools: Some(&tools),
        };

        let text = provider
            .chat(request, "model", 0.7)
            .await
            .unwrap()
            .text
            .unwrap_or_default();

        assert_eq!(text, build_tool_instructions_text(&tools));
    }

    #[tokio::test]
    async fn provider_chat_native_or_empty_tools_leave_messages_untouched() {
        let tools = shell_tools();
        let no_tools: Vec<ToolSpec> = Vec::new();

        let native = EchoSystemProvider {
            supports_native: true,
        };
        let request = ChatRequest {
            messages: &[ChatMessage::system("BASE"), ChatMessage::user("Hello")],
            tools: Some(&tools),
        };
        let text = native.chat(request, "model", 0.7).await.unwrap().text;
        assert_eq!(text.as_deref(), Some("BASE"));

        let prompt_guided = EchoSystemProvider {
            supports_native: false,
        };
        let request = ChatRequest {
            messages: &[ChatMessage::system("BASE"), ChatMessage::user("Hello")],
            tools: Some(&no_tools),
        };
        let text = prompt_guided.chat(request, "model", 0.7).await.unwrap().text;
        assert_eq!(text.as_deref(), Some("BASE"));
    }

    #[tokio::test]
    async fn provider_chat_prompt_guided_uses_convert_tools_override() {
        let provider = CustomConvertProvider;

        let tools = shell_tools();
        let request = ChatRequest {
            messages: &[ChatMessage::system("BASE"), ChatMessage::user("Hello")],
            tools: Some(&tools),
        };

        let response = provider.chat(request, "model", 0.7).await.unwrap();
        let text = response.text.unwrap_or_default();

        assert_eq!(text, "BASE\n\nCUSTOM_TOOL_INSTRUCTIONS");
    }

    #[tokio::test]
    async fn provider_chat_prompt_guided_rejects_non_prompt_payload() {
        let provider = InvalidConvertProvider;

        let tools = shell_tools();
        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: Some(&tools),
        };

        let err = provider.chat(request, "model", 0.7).await.unwrap_err();
        let message = err.to_string();

        assert!(message.contains("non-prompt-guided"));
    }

    #[test]
    fn stream_chunk_constructors() {
        let delta = StreamChunk::delta("hi");
        assert_eq!(delta.delta, "hi");
        assert!(delta.reasoning.is_none());
        assert!(!delta.is_final);
        assert_eq!(delta.token_count, 0);

        let reasoning = StreamChunk::reasoning("think");
        assert!(reasoning.delta.is_empty());
        assert_eq!(reasoning.reasoning.as_deref(), Some("think"));
        assert!(!reasoning.is_final);

        let fin = StreamChunk::final_chunk();
        assert!(fin.is_final);
        assert!(fin.delta.is_empty());
        assert!(fin.reasoning.is_none());

        let err = StreamChunk::error("boom");
        assert!(err.is_final);
        assert_eq!(err.delta, "boom");
    }

    #[test]
    fn stream_chunk_token_estimate_rounds_up() {
        assert_eq!(StreamChunk::delta("").with_token_estimate().token_count, 0);
        assert_eq!(StreamChunk::delta("abcd").with_token_estimate().token_count, 1);
        assert_eq!(StreamChunk::delta("abcde").with_token_estimate().token_count, 2);
    }

    #[test]
    fn stream_options_builders() {
        let defaults = StreamOptions::default();
        assert!(!defaults.enabled);
        assert!(!defaults.count_tokens);

        let opts = StreamOptions::new(true);
        assert!(opts.enabled);
        assert!(!opts.count_tokens);

        let counted = opts.with_token_count();
        assert!(counted.enabled);
        assert!(counted.count_tokens);
    }

    #[test]
    fn stream_event_from_chunk_maps_final_flag() {
        assert!(matches!(
            StreamEvent::from_chunk(StreamChunk::final_chunk()),
            StreamEvent::Final
        ));
        match StreamEvent::from_chunk(StreamChunk::delta("x")) {
            StreamEvent::TextDelta(chunk) => assert_eq!(chunk.delta, "x"),
            other => panic!("expected text delta event, got {other:?}"),
        }
    }

    /// Mock that only implements the legacy chunk stream, to test the `stream_chat` adapter.
    struct StreamingChunkOnlyProvider;

    #[async_trait]
    impl Provider for StreamingChunkOnlyProvider {
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("ok".to_string())
        }

        fn supports_streaming(&self) -> bool {
            true
        }

        fn stream_chat_with_history(
            &self,
            _messages: &[ChatMessage],
            _model: &str,
            _temperature: f64,
            _options: StreamOptions,
        ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
            stream::iter(vec![
                Ok(StreamChunk::delta("hello")),
                Ok(StreamChunk::final_chunk()),
            ])
            .boxed()
        }
    }

    #[tokio::test]
    async fn provider_stream_chat_default_maps_legacy_chunks_to_events() {
        let provider = StreamingChunkOnlyProvider;
        let mut stream = provider.stream_chat(
            ChatRequest {
                messages: &[ChatMessage::user("hi")],
                tools: None,
            },
            "model",
            0.0,
            StreamOptions::new(true),
        );

        let first = stream.next().await.unwrap().unwrap();
        let second = stream.next().await.unwrap().unwrap();
        assert!(stream.next().await.is_none());

        match first {
            StreamEvent::TextDelta(chunk) => assert_eq!(chunk.delta, "hello"),
            other => panic!("expected text delta event, got {other:?}"),
        }
        assert!(matches!(second, StreamEvent::Final));
    }
}