use core::fmt;
use std::future::Future;

use futures_util::{Stream, StreamExt};
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};

use serde_json::Value;
use thiserror::Error;

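/// A single server-sent event (SSE) frame, as parsed from the gateway's
/// streaming chat completions endpoint.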
#[derive(Debug, Serialize, Deserialize)]
pub struct SSEvents {
    pub data: String,
    pub event: Option<String>,
    pub retry: Option<u64>,
}

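/// Errors returned by the Inference Gateway client.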
#[derive(Error, Debug)]
pub enum GatewayError {
    #[error("Unauthorized: {0}")]
    Unauthorized(String),

    #[error("Forbidden: {0}")]
    Forbidden(String),

    #[error("Not found: {0}")]
    NotFound(String),

    #[error("Bad request: {0}")]
    BadRequest(String),

    #[error("Internal server error: {0}")]
    InternalError(String),

    #[error("Stream error: {0}")]
    StreamError(reqwest::Error),

    #[error("Decoding error: {0}")]
    DecodingError(std::string::FromUtf8Error),

    #[error("Request error: {0}")]
    RequestError(#[from] reqwest::Error),

    #[error("Deserialization error: {0}")]
    DeserializationError(serde_json::Error),

    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    #[error("Other error: {0}")]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}

#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: String,
}

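/// A model entry as returned by the gateway's `/models` endpoints.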
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Model {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub owned_by: String,
    pub served_by: Provider,
}

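/// Response payload for listing models, optionally scoped to a single provider.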
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider: Option<Provider>,
    pub object: String,
    pub data: Vec<Model>,
}

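/// A tool advertised by an MCP (Model Context Protocol) server behind the gateway.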
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MCPTool {
    pub name: String,
    pub description: String,
    pub server: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_schema: Option<Value>,
}

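/// Response payload for listing MCP tools via `/mcp/tools`.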
#[derive(Debug, Serialize, Deserialize)]
pub struct ListToolsResponse {
    pub object: String,
    pub data: Vec<MCPTool>,
}

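/// The upstream inference providers the gateway can route to.
///
/// Variants serialize to their lowercase wire names (e.g. `"openai"`,
/// `"ollama_cloud"`), and `TryFrom<&str>` parses names case-insensitively:
///
/// ```ignore
/// let provider = Provider::try_from("GROQ").unwrap();
/// assert_eq!(provider, Provider::Groq);
/// ```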
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Copy)]
#[serde(rename_all = "lowercase")]
pub enum Provider {
    #[serde(alias = "Ollama", alias = "OLLAMA")]
    Ollama,
    #[serde(alias = "OllamaCloud", alias = "OLLAMA_CLOUD", rename = "ollama_cloud")]
    OllamaCloud,
    #[serde(alias = "Groq", alias = "GROQ")]
    Groq,
    #[serde(alias = "OpenAI", alias = "OPENAI")]
    OpenAI,
    #[serde(alias = "Cloudflare", alias = "CLOUDFLARE")]
    Cloudflare,
    #[serde(alias = "Cohere", alias = "COHERE")]
    Cohere,
    #[serde(alias = "Anthropic", alias = "ANTHROPIC")]
    Anthropic,
    #[serde(alias = "Deepseek", alias = "DEEPSEEK")]
    Deepseek,
    #[serde(alias = "Google", alias = "GOOGLE")]
    Google,
    #[serde(alias = "Mistral", alias = "MISTRAL")]
    Mistral,
}

impl fmt::Display for Provider {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Provider::Ollama => write!(f, "ollama"),
            Provider::OllamaCloud => write!(f, "ollama_cloud"),
            Provider::Groq => write!(f, "groq"),
            Provider::OpenAI => write!(f, "openai"),
            Provider::Cloudflare => write!(f, "cloudflare"),
            Provider::Cohere => write!(f, "cohere"),
            Provider::Anthropic => write!(f, "anthropic"),
            Provider::Deepseek => write!(f, "deepseek"),
            Provider::Google => write!(f, "google"),
            Provider::Mistral => write!(f, "mistral"),
        }
    }
}

impl TryFrom<&str> for Provider {
    type Error = GatewayError;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        match s.to_lowercase().as_str() {
            "ollama" => Ok(Self::Ollama),
            "ollama_cloud" => Ok(Self::OllamaCloud),
            "groq" => Ok(Self::Groq),
            "openai" => Ok(Self::OpenAI),
            "cloudflare" => Ok(Self::Cloudflare),
            "cohere" => Ok(Self::Cohere),
            "anthropic" => Ok(Self::Anthropic),
            "deepseek" => Ok(Self::Deepseek),
            "google" => Ok(Self::Google),
            "mistral" => Ok(Self::Mistral),
            _ => Err(GatewayError::BadRequest(format!("Unknown provider: {s}"))),
        }
    }
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    System,
    #[default]
    User,
    Assistant,
    Tool,
}

impl fmt::Display for MessageRole {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            MessageRole::System => write!(f, "system"),
            MessageRole::User => write!(f, "user"),
            MessageRole::Assistant => write!(f, "assistant"),
            MessageRole::Tool => write!(f, "tool"),
        }
    }
}

#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct Message {
    pub role: MessageRole,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ChatCompletionMessageToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub r#type: ChatCompletionToolType,
    pub function: ChatCompletionMessageToolCallFunction,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum ChatCompletionToolType {
    #[serde(rename = "function")]
    Function,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageToolCallFunction {
    pub name: String,
    pub arguments: String,
}

impl ChatCompletionMessageToolCallFunction {
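    /// Parses the JSON-encoded `arguments` string into a `serde_json::Value`,
    /// e.g. `"{\"city\": \"Toronto\"}"` becomes an object with a `"city"` key.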
    pub fn parse_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
        serde_json::from_str(&self.arguments)
    }
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionObject {
    pub name: String,
    pub description: String,
    pub parameters: Value,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    Function,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: FunctionObject,
}

#[derive(Debug, Serialize)]
struct CreateChatCompletionRequest {
    model: String,
    messages: Vec<Message>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_format: Option<String>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionMessageToolCallChunk {
    pub index: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub r#type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<ChatCompletionMessageToolCallFunction>,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ToolCalls,
    ContentFilter,
    FunctionCall,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionChoice {
    pub finish_reason: FinishReason,
    pub message: Message,
    pub index: i32,
    pub logprobs: Option<ChoiceLogprobs>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct CreateChatCompletionResponse {
    pub id: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub created: i64,
    pub model: String,
    pub object: String,
}

#[derive(Debug, Deserialize, Clone)]
pub struct CreateChatCompletionStreamResponse {
    pub id: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    pub created: i64,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    pub object: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<CompletionUsage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_format: Option<String>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionTokenLogprob {
    pub token: String,
    pub logprob: f64,
    pub bytes: Option<Vec<i32>>,
    pub top_logprobs: Vec<TopLogprob>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct TopLogprob {
    pub token: String,
    pub logprob: f64,
    pub bytes: Option<Vec<i32>>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChoiceLogprobs {
    pub content: Option<Vec<ChatCompletionTokenLogprob>>,
    pub refusal: Option<Vec<ChatCompletionTokenLogprob>>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionStreamChoice {
    pub delta: ChatCompletionStreamDelta,
    pub index: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReason>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChoiceLogprobs>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<MessageRole>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ChatCompletionMessageToolCallChunk>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refusal: Option<String>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct CompletionUsage {
    pub completion_tokens: i64,
    pub prompt_tokens: i64,
    pub total_tokens: i64,
}

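/// HTTP client for the Inference Gateway API.
///
/// Minimal usage sketch (the crate name and gateway URL below are assumptions;
/// adjust both to your setup):
///
/// ```ignore
/// // Crate name assumed for illustration; use the name from this crate's Cargo.toml.
/// use inference_gateway_sdk::{
///     GatewayError, InferenceGatewayAPI, InferenceGatewayClient, Message, MessageRole, Provider,
/// };
///
/// #[tokio::main]
/// async fn main() -> Result<(), GatewayError> {
///     // Assumes a gateway listening at http://localhost:8080/v1.
///     let client = InferenceGatewayClient::new("http://localhost:8080/v1");
///     let messages = vec![Message {
///         role: MessageRole::User,
///         content: "Hello".to_string(),
///         ..Default::default()
///     }];
///     let response = client
///         .generate_content(Provider::Ollama, "llama2", messages)
///         .await?;
///     println!("{}", response.choices[0].message.content);
///     Ok(())
/// }
/// ```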
pub struct InferenceGatewayClient {
    base_url: String,
    client: Client,
    token: Option<String>,
    tools: Option<Vec<Tool>>,
    max_tokens: Option<i32>,
}

impl std::fmt::Debug for InferenceGatewayClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InferenceGatewayClient")
            .field("base_url", &self.base_url)
            .field("token", &self.token.as_ref().map(|_| "*****"))
            .finish()
    }
}

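/// The API surface implemented by [`InferenceGatewayClient`].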
pub trait InferenceGatewayAPI {
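    /// Lists all models available through the gateway.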
    fn list_models(&self) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;

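    /// Lists the models served by a single provider.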
    fn list_models_by_provider(
        &self,
        provider: Provider,
    ) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;

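    /// Generates a chat completion (non-streaming). Any tools and `max_tokens`
    /// configured on the client are included in the request.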
    fn generate_content(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Future<Output = Result<CreateChatCompletionResponse, GatewayError>> + Send;

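    /// Streams a chat completion as server-sent events. Each item is one SSE
    /// frame whose `data` field holds a JSON-encoded
    /// `CreateChatCompletionStreamResponse` chunk; the stream is typically
    /// terminated by a literal `[DONE]` frame. Tools and `max_tokens` are
    /// currently not sent on the streaming path. A consumption sketch
    /// (assumes a `client` and `messages` built as in the struct-level example):
    ///
    /// ```ignore
    /// use futures_util::{pin_mut, StreamExt};
    ///
    /// let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
    /// pin_mut!(stream);
    /// while let Some(event) = stream.next().await {
    ///     let event = event?;
    ///     if event.data == "[DONE]" {
    ///         break;
    ///     }
    ///     let chunk: CreateChatCompletionStreamResponse = serde_json::from_str(&event.data)?;
    /// }
    /// ```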
    fn generate_content_stream(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send;

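    /// Lists the MCP tools exposed by the gateway. The gateway answers with
    /// 403 Forbidden when the MCP tools endpoint is not exposed (`EXPOSE_MCP`).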
    fn list_tools(&self) -> impl Future<Output = Result<ListToolsResponse, GatewayError>> + Send;

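    /// Probes the gateway's `/health` endpoint; resolves to `true` on 200 OK.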
    fn health_check(&self) -> impl Future<Output = Result<bool, GatewayError>> + Send;
}

impl InferenceGatewayClient {
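    /// Creates a client for a gateway at the given base URL,
    /// e.g. `http://localhost:8080/v1`.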
    pub fn new(base_url: &str) -> Self {
        Self {
            base_url: base_url.to_string(),
            client: Client::new(),
            token: None,
            tools: None,
            max_tokens: None,
        }
    }

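    /// Creates a client from the `INFERENCE_GATEWAY_URL` environment variable,
    /// falling back to `http://localhost:8080/v1` when it is unset.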
    pub fn new_default() -> Self {
        let base_url = std::env::var("INFERENCE_GATEWAY_URL")
            .unwrap_or_else(|_| "http://localhost:8080/v1".to_string());

        Self {
            base_url,
            client: Client::new(),
            token: None,
            tools: None,
            max_tokens: None,
        }
    }

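    /// Returns the base URL this client sends requests to.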
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

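    /// Sets the tool definitions sent with non-streaming chat completion requests.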
    pub fn with_tools(mut self, tools: Option<Vec<Tool>>) -> Self {
        self.tools = tools;
        self
    }

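    /// Sets a bearer token to be sent in the `Authorization` header.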
    pub fn with_token(mut self, token: impl Into<String>) -> Self {
        self.token = Some(token.into());
        self
    }

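    /// Caps the number of tokens generated for non-streaming chat completions.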
    pub fn with_max_tokens(mut self, max_tokens: Option<i32>) -> Self {
        self.max_tokens = max_tokens;
        self
    }
}

impl InferenceGatewayAPI for InferenceGatewayClient {
    async fn list_models(&self) -> Result<ListModelsResponse, GatewayError> {
        let url = format!("{}/models", self.base_url);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => {
                let json_response: ListModelsResponse = response.json().await?;
                Ok(json_response)
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::other(
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn list_models_by_provider(
        &self,
        provider: Provider,
    ) -> Result<ListModelsResponse, GatewayError> {
        let url = format!("{}/models?provider={}", self.base_url, provider);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => {
                let json_response: ListModelsResponse = response.json().await?;
                Ok(json_response)
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::other(
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn generate_content(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> Result<CreateChatCompletionResponse, GatewayError> {
        let url = format!("{}/chat/completions?provider={}", self.base_url, provider);
        let mut request = self.client.post(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let request_payload = CreateChatCompletionRequest {
            model: model.to_string(),
            messages,
            stream: false,
            tools: self.tools.clone(),
            max_tokens: self.max_tokens,
            reasoning_format: None,
        };

        let response = request.json(&request_payload).send().await?;

        match response.status() {
            StatusCode::OK => Ok(response.json().await?),
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            status => Err(GatewayError::Other(Box::new(std::io::Error::other(
                format!("Unexpected status code: {status}"),
            )))),
        }
    }

    fn generate_content_stream(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send {
        let client = self.client.clone();
        let token = self.token.clone();
        // `Provider`'s `Display` impl already yields the lowercase wire format,
        // so no extra lowercasing is needed here.
        let url = format!("{}/chat/completions?provider={}", self.base_url, provider);

        let request = CreateChatCompletionRequest {
            model: model.to_string(),
            messages,
            stream: true,
            tools: None,
            max_tokens: None,
            reasoning_format: None,
        };

        async_stream::try_stream! {
            // Apply the bearer token on the streaming path too, matching the
            // other endpoints.
            let mut builder = client.post(&url).json(&request);
            if let Some(token) = &token {
                builder = builder.bearer_auth(token);
            }
            let response = builder.send().await?;
            let mut stream = response.bytes_stream();
            let mut current_event: Option<String> = None;
            let mut current_data: Option<String> = None;

            while let Some(chunk) = stream.next().await {
                let chunk = chunk?;
                let chunk_str = String::from_utf8_lossy(&chunk);

                for line in chunk_str.lines() {
                    // A blank line terminates an SSE frame; emit what we have.
                    if line.is_empty() && current_data.is_some() {
                        yield SSEvents {
                            data: current_data.take().unwrap(),
                            event: current_event.take(),
                            retry: None,
                        };
                        continue;
                    }

                    if let Some(event) = line.strip_prefix("event:") {
                        current_event = Some(event.trim().to_string());
                    } else if let Some(data) = line.strip_prefix("data:") {
                        // `lines()` never yields trailing newlines, so trimming
                        // whitespace is all that is needed.
                        current_data = Some(data.trim().to_string());
                    }
                }
            }
        }
    }

    async fn list_tools(&self) -> Result<ListToolsResponse, GatewayError> {
        let url = format!("{}/mcp/tools", self.base_url);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => {
                let json_response: ListToolsResponse = response.json().await?;
                Ok(json_response)
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::FORBIDDEN => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Forbidden(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::other(
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn health_check(&self) -> Result<bool, GatewayError> {
        let url = format!("{}/health", self.base_url);

        let response = self.client.get(&url).send().await?;
        match response.status() {
            StatusCode::OK => Ok(true),
            _ => Ok(false),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        CreateChatCompletionRequest, CreateChatCompletionResponse,
        CreateChatCompletionStreamResponse, FinishReason, FunctionObject, GatewayError,
        InferenceGatewayAPI, InferenceGatewayClient, Message, MessageRole, Provider, Tool,
        ToolType,
    };
    use futures_util::{pin_mut, StreamExt};
    use mockito::{Matcher, Server};
    use serde_json::json;
    // Needed for `writer.write_all` in the chunked-body mock closures below.
    use std::io::Write;

    #[test]
    fn test_provider_serialization() {
        let providers = vec![
            (Provider::Ollama, "ollama"),
            (Provider::OllamaCloud, "ollama_cloud"),
            (Provider::Groq, "groq"),
            (Provider::OpenAI, "openai"),
            (Provider::Cloudflare, "cloudflare"),
            (Provider::Cohere, "cohere"),
            (Provider::Anthropic, "anthropic"),
            (Provider::Deepseek, "deepseek"),
            (Provider::Google, "google"),
            (Provider::Mistral, "mistral"),
        ];

        for (provider, expected) in providers {
            let json = serde_json::to_string(&provider).unwrap();
            assert_eq!(json, format!("\"{}\"", expected));
        }
    }

    #[test]
    fn test_provider_deserialization() {
        let test_cases = vec![
            ("\"ollama\"", Provider::Ollama),
            ("\"ollama_cloud\"", Provider::OllamaCloud),
            ("\"groq\"", Provider::Groq),
            ("\"openai\"", Provider::OpenAI),
            ("\"cloudflare\"", Provider::Cloudflare),
            ("\"cohere\"", Provider::Cohere),
            ("\"anthropic\"", Provider::Anthropic),
            ("\"deepseek\"", Provider::Deepseek),
            ("\"google\"", Provider::Google),
            ("\"mistral\"", Provider::Mistral),
        ];

        for (json, expected) in test_cases {
            let provider: Provider = serde_json::from_str(json).unwrap();
            assert_eq!(provider, expected);
        }
    }

    #[test]
    fn test_message_serialization_with_tool_call_id() {
        let message_with_tool = Message {
            role: MessageRole::Tool,
            content: "The weather is sunny".to_string(),
            tool_call_id: Some("call_123".to_string()),
            ..Default::default()
        };

        let serialized = serde_json::to_string(&message_with_tool).unwrap();
        let expected_with_tool =
            r#"{"role":"tool","content":"The weather is sunny","tool_call_id":"call_123"}"#;
        assert_eq!(serialized, expected_with_tool);

        let message_without_tool = Message {
            role: MessageRole::User,
            content: "What's the weather?".to_string(),
            ..Default::default()
        };

        let serialized = serde_json::to_string(&message_without_tool).unwrap();
        let expected_without_tool = r#"{"role":"user","content":"What's the weather?"}"#;
        assert_eq!(serialized, expected_without_tool);

        let deserialized: Message = serde_json::from_str(expected_with_tool).unwrap();
        assert_eq!(deserialized.role, MessageRole::Tool);
        assert_eq!(deserialized.content, "The weather is sunny");
        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));

        let deserialized: Message = serde_json::from_str(expected_without_tool).unwrap();
        assert_eq!(deserialized.role, MessageRole::User);
        assert_eq!(deserialized.content, "What's the weather?");
        assert_eq!(deserialized.tool_call_id, None);
    }

    #[test]
    fn test_provider_display() {
        let providers = vec![
            (Provider::Ollama, "ollama"),
            (Provider::OllamaCloud, "ollama_cloud"),
            (Provider::Groq, "groq"),
            (Provider::OpenAI, "openai"),
            (Provider::Cloudflare, "cloudflare"),
            (Provider::Cohere, "cohere"),
            (Provider::Anthropic, "anthropic"),
            (Provider::Deepseek, "deepseek"),
            (Provider::Google, "google"),
            (Provider::Mistral, "mistral"),
        ];

        for (provider, expected) in providers {
            assert_eq!(provider.to_string(), expected);
        }
    }

    #[test]
    fn test_google_provider_case_insensitive() {
        let test_cases = vec!["google", "Google", "GOOGLE", "GoOgLe"];

        for test_case in test_cases {
            let provider: Result<Provider, _> = test_case.try_into();
            assert!(provider.is_ok(), "Failed to parse: {}", test_case);
            assert_eq!(provider.unwrap(), Provider::Google);
        }

        let json_cases = vec![r#""google""#, r#""Google""#, r#""GOOGLE""#];

        for json_case in json_cases {
            let provider: Provider = serde_json::from_str(json_case).unwrap();
            assert_eq!(provider, Provider::Google);
        }

        assert_eq!(Provider::Google.to_string(), "google");
    }

    #[test]
    fn test_generate_request_serialization() {
        let request_payload = CreateChatCompletionRequest {
            model: "llama3.2:1b".to_string(),
            messages: vec![
                Message {
                    role: MessageRole::System,
                    content: "You are a helpful assistant.".to_string(),
                    ..Default::default()
                },
                Message {
                    role: MessageRole::User,
                    content: "What is the current weather in Toronto?".to_string(),
                    ..Default::default()
                },
            ],
            stream: false,
            tools: Some(vec![Tool {
                r#type: ToolType::Function,
                function: FunctionObject {
                    name: "get_current_weather".to_string(),
                    description: "Get the current weather of a city".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "city": {
                                "type": "string",
                                "description": "The name of the city"
                            }
                        },
                        "required": ["city"]
                    }),
                },
            }]),
            max_tokens: None,
            reasoning_format: None,
        };

        let serialized = serde_json::to_string_pretty(&request_payload).unwrap();
        let expected = r#"{
            "model": "llama3.2:1b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is the current weather in Toronto?"
                }
            ],
            "stream": false,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_current_weather",
                        "description": "Get the current weather of a city",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "city": {
                                    "type": "string",
                                    "description": "The name of the city"
                                }
                            },
                            "required": ["city"]
                        }
                    }
                }
            ]
        }"#;

        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
            serde_json::from_str::<serde_json::Value>(expected).unwrap()
        );
    }

    #[tokio::test]
    async fn test_authentication_header() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock_response = r#"{
            "object": "list",
            "data": []
        }"#;

        let mock_with_auth = server
            .mock("GET", "/v1/models")
            .match_header("authorization", "Bearer test-token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_response)
            .expect(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_token("test-token");
        client.list_models().await?;
        mock_with_auth.assert();

        let mock_without_auth = server
            .mock("GET", "/v1/models")
            .match_header("authorization", Matcher::Missing)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_response)
            .expect(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        client.list_models().await?;
        mock_without_auth.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_unauthorized_error() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "error": "Invalid token"
        }"#;

        let mock = server
            .mock("GET", "/v1/models")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let error = client.list_models().await.unwrap_err();

        assert!(matches!(error, GatewayError::Unauthorized(_)));
        if let GatewayError::Unauthorized(msg) = error {
            assert_eq!(msg, "Invalid token");
        }
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_list_models() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_response_json = r#"{
            "object": "list",
            "data": [
                {
                    "id": "llama2",
                    "object": "model",
                    "created": 1630000001,
                    "owned_by": "ollama",
                    "served_by": "ollama"
                }
            ]
        }"#;

        let mock = server
            .mock("GET", "/v1/models")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_response_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let response = client.list_models().await?;

        assert!(response.provider.is_none());
        assert_eq!(response.object, "list");
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.data[0].id, "llama2");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_list_models_by_provider() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "provider": "ollama",
            "object": "list",
            "data": [
                {
                    "id": "llama2",
                    "object": "model",
                    "created": 1630000001,
                    "owned_by": "ollama",
                    "served_by": "ollama"
                }
            ]
        }"#;

        let mock = server
            .mock("GET", "/v1/models?provider=ollama")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let response = client.list_models_by_provider(Provider::Ollama).await?;

        assert!(response.provider.is_some());
        assert_eq!(response.provider, Some(Provider::Ollama));
        assert_eq!(response.data[0].id, "llama2");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "logprobs": null,
                    "message": {
                        "role": "assistant",
                        "content": "Hellloooo"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=ollama")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];
        let response = client
            .generate_content(Provider::Ollama, "llama2", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hellloooo");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_serialization() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "logprobs": null,
                    "message": {
                        "role": "assistant",
                        "content": "Hello"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let direct_parse: Result<CreateChatCompletionResponse, _> = serde_json::from_str(raw_json);
        assert!(
            direct_parse.is_ok(),
            "Direct JSON parse failed: {:?}",
            direct_parse.err()
        );

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hello");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_error_response() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "error":"Invalid request"
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];
        let error = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await
            .unwrap_err();

        assert!(matches!(error, GatewayError::BadRequest(_)));
        if let GatewayError::BadRequest(msg) = error {
            assert_eq!(msg, "Invalid request");
        }
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_gateway_errors() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let unauthorized_mock = server
            .mock("GET", "/v1/models")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid token"}"#)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        match client.list_models().await {
            Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
            _ => panic!("Expected Unauthorized error"),
        }
        unauthorized_mock.assert();

        let bad_request_mock = server
            .mock("GET", "/v1/models")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid provider"}"#)
            .create();

        match client.list_models().await {
            Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
            _ => panic!("Expected BadRequest error"),
        }
        bad_request_mock.assert();

        let internal_error_mock = server
            .mock("GET", "/v1/models")
            .with_status(500)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Internal server error occurred"}"#)
            .create();

        match client.list_models().await {
            Err(GatewayError::InternalError(msg)) => {
                assert_eq!(msg, "Internal server error occurred")
            }
            _ => panic!("Expected InternalError error"),
        }
        internal_error_mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_response_fields() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "logprobs": null,
                    "message": {
                        "role": "assistant",
                        "content": "Hello"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hello");
        assert_eq!(response.model, "mixtral-8x7b");
        assert_eq!(response.object, "chat.completion");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_stream() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(move |writer| -> std::io::Result<()> {
                let events = vec![
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}"#),
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268191,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":" World"},"finish_reason":null}]}"#),
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268192,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":17,"completion_tokens":40,"total_tokens":57}}"#),
                    "data: [DONE]\n\n".to_string(),
                ];
                for event in events {
                    writer.write_all(event.as_bytes())?;
                }
                Ok(())
            })
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Test message".to_string(),
            ..Default::default()
        }];

        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
        pin_mut!(stream);
        while let Some(result) = stream.next().await {
            let result = result?;
            let generate_response: CreateChatCompletionStreamResponse =
                serde_json::from_str(&result.data)
                    .expect("Failed to parse CreateChatCompletionStreamResponse");

            if generate_response.choices[0].finish_reason.is_some() {
                assert_eq!(
                    generate_response.choices[0].finish_reason.as_ref().unwrap(),
                    &FinishReason::Stop
                );
                break;
            }

            if let Some(content) = &generate_response.choices[0].delta.content {
                assert!(matches!(content.as_str(), "Hello" | " World"));
            }
            if let Some(role) = &generate_response.choices[0].delta.role {
                assert_eq!(role, &MessageRole::Assistant);
            }
        }

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_stream_error() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_chunked_body(move |writer| -> std::io::Result<()> {
                let events = vec![format!(
                    "event: {}\ndata: {}\nretry: {}\n\n",
                    r#"error"#, r#"{"error":"Invalid request"}"#, r#"1000"#,
                )];
                for event in events {
                    writer.write_all(event.as_bytes())?;
                }
                Ok(())
            })
            .expect_at_least(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Test message".to_string(),
            ..Default::default()
        }];

        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);

        pin_mut!(stream);
        while let Some(result) = stream.next().await {
            let result = result?;
            assert!(result.event.is_some());
            assert_eq!(result.event.unwrap(), "error");
            assert!(result.data.contains("Invalid request"));
            // `retry:` lines are not parsed by the client, so this stays `None`
            // even though the mock sends `retry: 1000`.
            assert!(result.retry.is_none());
        }

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "deepseek-r1-distill-llama-70b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "tool_calls",
                    "logprobs": null,
                    "message": {
                        "role": "assistant",
                        "content": "Let me check the weather for you.",
                        "tool_calls": [
                            {
                                "id": "1234",
                                "type": "function",
                                "function": {
                                    "name": "get_weather",
                                    "arguments": "{\"location\": \"London\"}"
                                }
                            }
                        ]
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let tools = vec![Tool {
            r#type: ToolType::Function,
            function: FunctionObject {
                name: "get_weather".to_string(),
                description: "Get the weather for a location".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city name"
                        }
                    },
                    "required": ["location"]
                }),
            },
        }];

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_tools(Some(tools));

        let messages = vec![Message {
            role: MessageRole::User,
            content: "What's the weather in London?".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(
            response.choices[0].message.content,
            "Let me check the weather for you."
        );

        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "get_weather");

        let params = tool_calls[0]
            .function
            .parse_arguments()
            .expect("Failed to parse function arguments");
        assert_eq!(params["location"].as_str().unwrap(), "London");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_without_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "gpt-4",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "logprobs": null,
                    "message": {
                        "role": "assistant",
                        "content": "Hello!"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=openai")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hi".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::OpenAI, "gpt-4", messages)
            .await?;

        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.choices[0].message.content, "Hello!");
        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert!(response.choices[0].message.tool_calls.is_none());

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_tools_payload() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_request_body = r#"{
            "model": "deepseek-r1-distill-llama-70b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is the current weather in Toronto?"
                }
            ],
            "stream": false,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_current_weather",
                        "description": "Get the current weather of a city",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "city": {
                                    "type": "string",
                                    "description": "The name of the city"
                                }
                            },
                            "required": ["city"]
                        }
                    }
                }
            ]
        }"#;

        let raw_json_response = r#"{
            "id": "1234",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "deepseek-r1-distill-llama-70b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "logprobs": null,
                    "message": {
                        "role": "assistant",
                        "content": "Let me check the weather for you",
                        "tool_calls": [
                            {
                                "id": "1234",
                                "type": "function",
                                "function": {
                                    "name": "get_current_weather",
                                    "arguments": "{\"city\": \"Toronto\"}"
                                }
                            }
                        ]
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .match_body(mockito::Matcher::JsonString(raw_request_body.to_string()))
            .with_body(raw_json_response)
            .create();

        let tools = vec![Tool {
            r#type: ToolType::Function,
            function: FunctionObject {
                name: "get_current_weather".to_string(),
                description: "Get the current weather of a city".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "The name of the city"
                        }
                    },
                    "required": ["city"]
                }),
            },
        }];

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![
            Message {
                role: MessageRole::System,
                content: "You are a helpful assistant.".to_string(),
                ..Default::default()
            },
            Message {
                role: MessageRole::User,
                content: "What is the current weather in Toronto?".to_string(),
                ..Default::default()
            },
        ];

        let response = client
            .with_tools(Some(tools))
            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(
            response.choices[0].message.content,
            "Let me check the weather for you"
        );
        assert_eq!(
            response.choices[0]
                .message
                .tool_calls
                .as_ref()
                .unwrap()
                .len(),
            1
        );

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_max_tokens() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "logprobs": null,
                    "message": {
                        "role": "assistant",
                        "content": "Here's a poem with 100 tokens..."
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .match_body(mockito::Matcher::JsonString(
                r#"{
                    "model": "mixtral-8x7b",
                    "messages": [{"role":"user","content":"Write a poem"}],
                    "stream": false,
                    "max_tokens": 100
                }"#
                .to_string(),
            ))
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_max_tokens(Some(100));

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Write a poem".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(
            response.choices[0].message.content,
            "Here's a poem with 100 tokens..."
        );
        assert_eq!(response.model, "mixtral-8x7b");
        assert_eq!(response.created, 1630000000);
        assert_eq!(response.object, "chat.completion");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_health_check() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;
        let mock = server.mock("GET", "/health").with_status(200).create();

        let client = InferenceGatewayClient::new(&server.url());
        let is_healthy = client.health_check().await?;

        assert!(is_healthy);
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_client_base_url_configuration() -> Result<(), GatewayError> {
        let mut custom_url_server = Server::new_async().await;

        let custom_url_mock = custom_url_server
            .mock("GET", "/health")
            .with_status(200)
            .create();

        let custom_client = InferenceGatewayClient::new(&custom_url_server.url());
        let is_healthy = custom_client.health_check().await?;
        assert!(is_healthy);
        custom_url_mock.assert();

        let default_client = InferenceGatewayClient::new_default();

        let default_url = "http://localhost:8080/v1";
        assert_eq!(default_client.base_url(), default_url);

        Ok(())
    }

    #[tokio::test]
    async fn test_list_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_response_json = r#"{
            "object": "list",
            "data": [
                {
                    "name": "read_file",
                    "description": "Read content from a file",
                    "server": "http://mcp-filesystem-server:8083/mcp",
                    "input_schema": {
                        "type": "object",
                        "properties": {
                            "file_path": {
                                "type": "string",
                                "description": "Path to the file to read"
                            }
                        },
                        "required": ["file_path"]
                    }
                },
                {
                    "name": "write_file",
                    "description": "Write content to a file",
                    "server": "http://mcp-filesystem-server:8083/mcp"
                }
            ]
        }"#;

        let mock = server
            .mock("GET", "/v1/mcp/tools")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_response_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let response = client.list_tools().await?;

        assert_eq!(response.object, "list");
        assert_eq!(response.data.len(), 2);

        assert_eq!(response.data[0].name, "read_file");
        assert_eq!(response.data[0].description, "Read content from a file");
        assert_eq!(
            response.data[0].server,
            "http://mcp-filesystem-server:8083/mcp"
        );
        assert!(response.data[0].input_schema.is_some());

        assert_eq!(response.data[1].name, "write_file");
        assert_eq!(response.data[1].description, "Write content to a file");
        assert_eq!(
            response.data[1].server,
            "http://mcp-filesystem-server:8083/mcp"
        );
        assert!(response.data[1].input_schema.is_none());

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_list_tools_with_authentication() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_response_json = r#"{
            "object": "list",
            "data": []
        }"#;

        let mock = server
            .mock("GET", "/v1/mcp/tools")
            .match_header("authorization", "Bearer test-token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_response_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_token("test-token");
        let response = client.list_tools().await?;

        assert_eq!(response.object, "list");
        assert_eq!(response.data.len(), 0);
        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_list_tools_mcp_not_exposed() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("GET", "/v1/mcp/tools")
            .with_status(403)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"error":"MCP tools endpoint is not exposed. Set EXPOSE_MCP=true to enable."}"#,
            )
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        match client.list_tools().await {
            Err(GatewayError::Forbidden(msg)) => {
                assert_eq!(
                    msg,
                    "MCP tools endpoint is not exposed. Set EXPOSE_MCP=true to enable."
                );
            }
            _ => panic!("Expected Forbidden error for MCP not exposed"),
        }

        mock.assert();
        Ok(())
    }
}