use core::fmt;
use std::future::Future;

use futures_util::{Stream, StreamExt};
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;

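/// A single Server-Sent Event (SSE) parsed from a streaming response.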
#[derive(Debug, Serialize, Deserialize)]
pub struct SSEvents {
    pub data: String,
    pub event: Option<String>,
    pub retry: Option<u64>,
}

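/// Errors returned by the Inference Gateway client.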
#[derive(Error, Debug)]
pub enum GatewayError {
    #[error("Unauthorized: {0}")]
    Unauthorized(String),

    #[error("Forbidden: {0}")]
    Forbidden(String),

    #[error("Bad request: {0}")]
    BadRequest(String),

    #[error("Internal server error: {0}")]
    InternalError(String),

    #[error("Stream error: {0}")]
    StreamError(reqwest::Error),

    #[error("Decoding error: {0}")]
    DecodingError(std::string::FromUtf8Error),

    #[error("Request error: {0}")]
    RequestError(#[from] reqwest::Error),

    #[error("Deserialization error: {0}")]
    DeserializationError(serde_json::Error),

    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    #[error("Other error: {0}")]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}

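/// JSON error body returned by the gateway, e.g. `{"error": "Invalid token"}`.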
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: String,
}

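/// A model entry as returned by the `/models` endpoints.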
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Model {
    pub id: String,
    pub object: Option<String>,
    pub created: Option<i64>,
    pub owned_by: Option<String>,
    pub served_by: Option<String>,
}

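/// Response payload of the `/models` endpoints.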
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider: Option<Provider>,
    pub object: String,
    pub data: Vec<Model>,
}

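/// A tool exposed through the gateway's MCP (Model Context Protocol) endpoint.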
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MCPTool {
    pub name: String,
    pub description: String,
    pub server: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_schema: Option<Value>,
}

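/// Response payload of the `/mcp/tools` endpoint.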
#[derive(Debug, Serialize, Deserialize)]
pub struct ListToolsResponse {
    pub object: String,
    pub data: Vec<MCPTool>,
}

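/// The inference providers supported by the gateway.
///
/// Serializes to lowercase (e.g. `"openai"`); capitalized and uppercase
/// spellings are accepted when deserializing via the aliases below.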
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Copy)]
#[serde(rename_all = "lowercase")]
pub enum Provider {
    #[serde(alias = "Ollama", alias = "OLLAMA")]
    Ollama,
    #[serde(alias = "Groq", alias = "GROQ")]
    Groq,
    #[serde(alias = "OpenAI", alias = "OPENAI")]
    OpenAI,
    #[serde(alias = "Cloudflare", alias = "CLOUDFLARE")]
    Cloudflare,
    #[serde(alias = "Cohere", alias = "COHERE")]
    Cohere,
    #[serde(alias = "Anthropic", alias = "ANTHROPIC")]
    Anthropic,
    #[serde(alias = "Deepseek", alias = "DEEPSEEK")]
    Deepseek,
    #[serde(alias = "Google", alias = "GOOGLE")]
    Google,
}

impl fmt::Display for Provider {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Provider::Ollama => write!(f, "ollama"),
            Provider::Groq => write!(f, "groq"),
            Provider::OpenAI => write!(f, "openai"),
            Provider::Cloudflare => write!(f, "cloudflare"),
            Provider::Cohere => write!(f, "cohere"),
            Provider::Anthropic => write!(f, "anthropic"),
            Provider::Deepseek => write!(f, "deepseek"),
            Provider::Google => write!(f, "google"),
        }
    }
}

impl TryFrom<&str> for Provider {
    type Error = GatewayError;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        match s.to_lowercase().as_str() {
            "ollama" => Ok(Self::Ollama),
            "groq" => Ok(Self::Groq),
            "openai" => Ok(Self::OpenAI),
            "cloudflare" => Ok(Self::Cloudflare),
            "cohere" => Ok(Self::Cohere),
            "anthropic" => Ok(Self::Anthropic),
            "deepseek" => Ok(Self::Deepseek),
            "google" => Ok(Self::Google),
            _ => Err(GatewayError::BadRequest(format!("Unknown provider: {s}"))),
        }
    }
}

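/// The role of a chat message; serializes to lowercase (e.g. `"user"`).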
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    System,
    #[default]
    User,
    Assistant,
    Tool,
}

impl fmt::Display for MessageRole {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            MessageRole::System => write!(f, "system"),
            MessageRole::User => write!(f, "user"),
            MessageRole::Assistant => write!(f, "assistant"),
            MessageRole::Tool => write!(f, "tool"),
        }
    }
}

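/// A single chat message. Optional fields are omitted from the serialized
/// payload when `None`.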
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct Message {
    pub role: MessageRole,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ChatCompletionMessageToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}

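/// A tool call made by the assistant as part of a chat completion message.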
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub r#type: ChatCompletionToolType,
    pub function: ChatCompletionMessageToolCallFunction,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum ChatCompletionToolType {
    #[serde(rename = "function")]
    Function,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageToolCallFunction {
    pub name: String,
    pub arguments: String,
}

impl ChatCompletionMessageToolCallFunction {
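    /// Parses the JSON-encoded `arguments` string into a `serde_json::Value`.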
    pub fn parse_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
        serde_json::from_str(&self.arguments)
    }
}

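/// A function definition attached to a [`Tool`].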
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionObject {
    pub name: String,
    pub description: String,
    pub parameters: Value,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    Function,
}

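/// A tool the model may call, in the OpenAI-compatible `tools` format.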
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: FunctionObject,
}

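/// Request body for `POST /chat/completions`.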
#[derive(Debug, Serialize)]
struct CreateChatCompletionRequest {
    model: String,
    messages: Vec<Message>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<i32>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ToolCallResponse {
    pub id: String,
    #[serde(rename = "type")]
    pub r#type: ToolType,
    pub function: ChatCompletionMessageToolCallFunction,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionChoice {
    pub finish_reason: String,
    pub message: Message,
    pub index: i32,
}

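/// Response body of a non-streaming chat completion.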
#[derive(Debug, Deserialize, Clone)]
pub struct CreateChatCompletionResponse {
    pub id: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub created: i64,
    pub model: String,
    pub object: String,
}

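/// One parsed chunk of a streaming chat completion.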
#[derive(Debug, Deserialize, Clone)]
pub struct CreateChatCompletionStreamResponse {
    pub id: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    pub created: i64,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    pub object: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<CompletionUsage>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionStreamChoice {
    pub delta: ChatCompletionStreamDelta,
    pub index: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<MessageRole>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct CompletionUsage {
    pub completion_tokens: i64,
    pub prompt_tokens: i64,
    pub total_tokens: i64,
}

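/// HTTP client for the Inference Gateway API.
///
/// A minimal usage sketch (illustrative only; it assumes a gateway is
/// reachable at the given base URL and that this runs inside an async
/// runtime):
///
/// ```ignore
/// let client = InferenceGatewayClient::new("http://localhost:8080/v1")
///     .with_token("my-api-token");
/// let models = client.list_models().await?;
/// println!("{} models available", models.data.len());
/// ```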
pub struct InferenceGatewayClient {
    base_url: String,
    client: Client,
    token: Option<String>,
    tools: Option<Vec<Tool>>,
    max_tokens: Option<i32>,
}

impl std::fmt::Debug for InferenceGatewayClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InferenceGatewayClient")
            .field("base_url", &self.base_url)
            .field("token", &self.token.as_ref().map(|_| "*****"))
            .finish()
    }
}

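/// The operations exposed by the Inference Gateway API.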
pub trait InferenceGatewayAPI {
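    /// Lists all available models across providers.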
    fn list_models(&self) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;

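    /// Lists the models served by a single provider.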
    fn list_models_by_provider(
        &self,
        provider: Provider,
    ) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;

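    /// Generates a chat completion and returns the complete response.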
    fn generate_content(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Future<Output = Result<CreateChatCompletionResponse, GatewayError>> + Send;

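    /// Generates a chat completion as a stream of Server-Sent Events.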
    fn generate_content_stream(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send;

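    /// Lists the MCP tools exposed by the gateway.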
    fn list_tools(&self) -> impl Future<Output = Result<ListToolsResponse, GatewayError>> + Send;

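    /// Returns `true` if the gateway's `/health` endpoint answers with 200 OK.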
    fn health_check(&self) -> impl Future<Output = Result<bool, GatewayError>> + Send;
}

impl InferenceGatewayClient {
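    /// Creates a client that talks to the gateway at the given base URL.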
    pub fn new(base_url: &str) -> Self {
        Self {
            base_url: base_url.to_string(),
            client: Client::new(),
            token: None,
            tools: None,
            max_tokens: None,
        }
    }

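    /// Creates a client from the `INFERENCE_GATEWAY_URL` environment variable,
    /// falling back to `http://localhost:8080/v1` when it is unset.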
    pub fn new_default() -> Self {
        let base_url = std::env::var("INFERENCE_GATEWAY_URL")
            .unwrap_or_else(|_| "http://localhost:8080/v1".to_string());

        Self {
            base_url,
            client: Client::new(),
            token: None,
            tools: None,
            max_tokens: None,
        }
    }

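    /// Returns the base URL this client is configured with.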
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

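    /// Sets the tools attached to non-streaming chat completion requests.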
    pub fn with_tools(mut self, tools: Option<Vec<Tool>>) -> Self {
        self.tools = tools;
        self
    }

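    /// Sets a bearer token to be sent in the `Authorization` header.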
    pub fn with_token(mut self, token: impl Into<String>) -> Self {
        self.token = Some(token.into());
        self
    }

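    /// Caps `max_tokens` for chat completion requests.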
    pub fn with_max_tokens(mut self, max_tokens: Option<i32>) -> Self {
        self.max_tokens = max_tokens;
        self
    }
}

impl InferenceGatewayAPI for InferenceGatewayClient {
    async fn list_models(&self) -> Result<ListModelsResponse, GatewayError> {
        let url = format!("{}/models", self.base_url);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => {
                let json_response: ListModelsResponse = response.json().await?;
                Ok(json_response)
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::other(
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn list_models_by_provider(
        &self,
        provider: Provider,
    ) -> Result<ListModelsResponse, GatewayError> {
        let url = format!("{}/models?provider={}", self.base_url, provider);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => {
                let json_response: ListModelsResponse = response.json().await?;
                Ok(json_response)
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::other(
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn generate_content(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> Result<CreateChatCompletionResponse, GatewayError> {
        let url = format!("{}/chat/completions?provider={}", self.base_url, provider);
        let mut request = self.client.post(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let request_payload = CreateChatCompletionRequest {
            model: model.to_string(),
            messages,
            stream: false,
            tools: self.tools.clone(),
            max_tokens: self.max_tokens,
        };

        let response = request.json(&request_payload).send().await?;

        match response.status() {
            StatusCode::OK => Ok(response.json().await?),
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            status => Err(GatewayError::Other(Box::new(std::io::Error::other(
                format!("Unexpected status code: {status}"),
            )))),
        }
    }

    fn generate_content_stream(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send {
        let client = self.client.clone();
        let token = self.token.clone();
        // `Display` for `Provider` already yields lowercase.
        let url = format!("{}/chat/completions?provider={}", self.base_url, provider);

        // Streaming requests currently send no tools or max_tokens.
        let request = CreateChatCompletionRequest {
            model: model.to_string(),
            messages,
            stream: true,
            tools: None,
            max_tokens: None,
        };

        async_stream::try_stream! {
            // Attach the bearer token when configured, matching the
            // non-streaming methods.
            let mut builder = client.post(&url).json(&request);
            if let Some(token) = &token {
                builder = builder.bearer_auth(token);
            }
            let response = builder.send().await?;
            let mut stream = response.bytes_stream();

            // Accumulate `event:`/`data:` fields until a blank line
            // terminates the current event.
            let mut current_event: Option<String> = None;
            let mut current_data: Option<String> = None;

            while let Some(chunk) = stream.next().await {
                let chunk = chunk?;
                let chunk_str = String::from_utf8_lossy(&chunk);

                for line in chunk_str.lines() {
                    if line.is_empty() && current_data.is_some() {
                        yield SSEvents {
                            data: current_data.take().unwrap(),
                            event: current_event.take(),
                            retry: None,
                        };
                        continue;
                    }

                    if let Some(event) = line.strip_prefix("event:") {
                        current_event = Some(event.trim().to_string());
                    } else if let Some(data) = line.strip_prefix("data:") {
                        current_data = Some(data.trim().to_string());
                    }
                }
            }
        }
    }

    async fn list_tools(&self) -> Result<ListToolsResponse, GatewayError> {
        let url = format!("{}/mcp/tools", self.base_url);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => {
                let json_response: ListToolsResponse = response.json().await?;
                Ok(json_response)
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::FORBIDDEN => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Forbidden(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::other(
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn health_check(&self) -> Result<bool, GatewayError> {
        let url = format!("{}/health", self.base_url);

        let response = self.client.get(&url).send().await?;
        match response.status() {
            StatusCode::OK => Ok(true),
            _ => Ok(false),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        CreateChatCompletionRequest, CreateChatCompletionResponse,
        CreateChatCompletionStreamResponse, FunctionObject, GatewayError, InferenceGatewayAPI,
        InferenceGatewayClient, Message, MessageRole, Provider, Tool, ToolType,
    };
    use futures_util::{pin_mut, StreamExt};
    use mockito::{Matcher, Server};
    use serde_json::json;
    use std::io::Write;

    #[test]
    fn test_provider_serialization() {
        let providers = vec![
            (Provider::Ollama, "ollama"),
            (Provider::Groq, "groq"),
            (Provider::OpenAI, "openai"),
            (Provider::Cloudflare, "cloudflare"),
            (Provider::Cohere, "cohere"),
            (Provider::Anthropic, "anthropic"),
            (Provider::Deepseek, "deepseek"),
            (Provider::Google, "google"),
        ];

        for (provider, expected) in providers {
            let json = serde_json::to_string(&provider).unwrap();
            assert_eq!(json, format!("\"{}\"", expected));
        }
    }

    #[test]
    fn test_provider_deserialization() {
        let test_cases = vec![
            ("\"ollama\"", Provider::Ollama),
            ("\"groq\"", Provider::Groq),
            ("\"openai\"", Provider::OpenAI),
            ("\"cloudflare\"", Provider::Cloudflare),
            ("\"cohere\"", Provider::Cohere),
            ("\"anthropic\"", Provider::Anthropic),
            ("\"deepseek\"", Provider::Deepseek),
            ("\"google\"", Provider::Google),
        ];

        for (json, expected) in test_cases {
            let provider: Provider = serde_json::from_str(json).unwrap();
            assert_eq!(provider, expected);
        }
    }

    #[test]
    fn test_message_serialization_with_tool_call_id() {
        let message_with_tool = Message {
            role: MessageRole::Tool,
            content: "The weather is sunny".to_string(),
            tool_call_id: Some("call_123".to_string()),
            ..Default::default()
        };

        let serialized = serde_json::to_string(&message_with_tool).unwrap();
        let expected_with_tool =
            r#"{"role":"tool","content":"The weather is sunny","tool_call_id":"call_123"}"#;
        assert_eq!(serialized, expected_with_tool);

        let message_without_tool = Message {
            role: MessageRole::User,
            content: "What's the weather?".to_string(),
            ..Default::default()
        };

        let serialized = serde_json::to_string(&message_without_tool).unwrap();
        let expected_without_tool = r#"{"role":"user","content":"What's the weather?"}"#;
        assert_eq!(serialized, expected_without_tool);

        let deserialized: Message = serde_json::from_str(expected_with_tool).unwrap();
        assert_eq!(deserialized.role, MessageRole::Tool);
        assert_eq!(deserialized.content, "The weather is sunny");
        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));

        let deserialized: Message = serde_json::from_str(expected_without_tool).unwrap();
        assert_eq!(deserialized.role, MessageRole::User);
        assert_eq!(deserialized.content, "What's the weather?");
        assert_eq!(deserialized.tool_call_id, None);
    }

    #[test]
    fn test_provider_display() {
        let providers = vec![
            (Provider::Ollama, "ollama"),
            (Provider::Groq, "groq"),
            (Provider::OpenAI, "openai"),
            (Provider::Cloudflare, "cloudflare"),
            (Provider::Cohere, "cohere"),
            (Provider::Anthropic, "anthropic"),
            (Provider::Deepseek, "deepseek"),
            (Provider::Google, "google"),
        ];

        for (provider, expected) in providers {
            assert_eq!(provider.to_string(), expected);
        }
    }

    #[test]
    fn test_google_provider_case_insensitive() {
        let test_cases = vec!["google", "Google", "GOOGLE", "GoOgLe"];

        for test_case in test_cases {
            let provider: Result<Provider, _> = test_case.try_into();
            assert!(provider.is_ok(), "Failed to parse: {}", test_case);
            assert_eq!(provider.unwrap(), Provider::Google);
        }

        let json_cases = vec![r#""google""#, r#""Google""#, r#""GOOGLE""#];

        for json_case in json_cases {
            let provider: Provider = serde_json::from_str(json_case).unwrap();
            assert_eq!(provider, Provider::Google);
        }

        assert_eq!(Provider::Google.to_string(), "google");
    }

    #[test]
    fn test_generate_request_serialization() {
        let request_payload = CreateChatCompletionRequest {
            model: "llama3.2:1b".to_string(),
            messages: vec![
                Message {
                    role: MessageRole::System,
                    content: "You are a helpful assistant.".to_string(),
                    ..Default::default()
                },
                Message {
                    role: MessageRole::User,
                    content: "What is the current weather in Toronto?".to_string(),
                    ..Default::default()
                },
            ],
            stream: false,
            tools: Some(vec![Tool {
                r#type: ToolType::Function,
                function: FunctionObject {
                    name: "get_current_weather".to_string(),
                    description: "Get the current weather of a city".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "city": {
                                "type": "string",
                                "description": "The name of the city"
                            }
                        },
                        "required": ["city"]
                    }),
                },
            }]),
            max_tokens: None,
        };

        let serialized = serde_json::to_string_pretty(&request_payload).unwrap();
        let expected = r#"{
      "model": "llama3.2:1b",
      "messages": [
        {
          "role": "system",
          "content": "You are a helpful assistant."
        },
        {
          "role": "user",
          "content": "What is the current weather in Toronto?"
        }
      ],
      "stream": false,
      "tools": [
        {
          "type": "function",
          "function": {
            "name": "get_current_weather",
            "description": "Get the current weather of a city",
            "parameters": {
              "type": "object",
              "properties": {
                "city": {
                  "type": "string",
                  "description": "The name of the city"
                }
              },
              "required": ["city"]
            }
          }
        }
      ]
    }"#;

        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
            serde_json::from_str::<serde_json::Value>(expected).unwrap()
        );
    }

    #[tokio::test]
    async fn test_authentication_header() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock_response = r#"{
            "object": "list",
            "data": []
        }"#;

        let mock_with_auth = server
            .mock("GET", "/v1/models")
            .match_header("authorization", "Bearer test-token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_response)
            .expect(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_token("test-token");
        client.list_models().await?;
        mock_with_auth.assert();

        let mock_without_auth = server
            .mock("GET", "/v1/models")
            .match_header("authorization", Matcher::Missing)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_response)
            .expect(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        client.list_models().await?;
        mock_without_auth.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_unauthorized_error() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "error": "Invalid token"
        }"#;

        let mock = server
            .mock("GET", "/v1/models")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let error = client.list_models().await.unwrap_err();

        assert!(matches!(error, GatewayError::Unauthorized(_)));
        if let GatewayError::Unauthorized(msg) = error {
            assert_eq!(msg, "Invalid token");
        }
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_list_models() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_response_json = r#"{
            "object": "list",
            "data": [
                {
                    "id": "llama2",
                    "object": "model",
                    "created": 1630000001,
                    "owned_by": "ollama",
                    "served_by": "ollama"
                }
            ]
        }"#;

        let mock = server
            .mock("GET", "/v1/models")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_response_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let response = client.list_models().await?;

        assert!(response.provider.is_none());
        assert_eq!(response.object, "list");
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.data[0].id, "llama2");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_list_models_by_provider() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "provider":"ollama",
            "object":"list",
            "data": [
                {
                    "id": "llama2",
                    "object": "model",
                    "created": 1630000001,
                    "owned_by": "ollama",
                    "served_by": "ollama"
                }
            ]
        }"#;

        let mock = server
            .mock("GET", "/v1/models?provider=ollama")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let response = client.list_models_by_provider(Provider::Ollama).await?;

        assert!(response.provider.is_some());
        assert_eq!(response.provider, Some(Provider::Ollama));
        assert_eq!(response.data[0].id, "llama2");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hellloooo"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=ollama")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];
        let response = client
            .generate_content(Provider::Ollama, "llama2", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hellloooo");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_serialization() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let direct_parse: Result<CreateChatCompletionResponse, _> = serde_json::from_str(raw_json);
        assert!(
            direct_parse.is_ok(),
            "Direct JSON parse failed: {:?}",
            direct_parse.err()
        );

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hello");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_error_response() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "error":"Invalid request"
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];
        let error = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await
            .unwrap_err();

        assert!(matches!(error, GatewayError::BadRequest(_)));
        if let GatewayError::BadRequest(msg) = error {
            assert_eq!(msg, "Invalid request");
        }
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_gateway_errors() -> Result<(), GatewayError> {
        let mut server: mockito::ServerGuard = Server::new_async().await;

        let unauthorized_mock = server
            .mock("GET", "/v1/models")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid token"}"#)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        match client.list_models().await {
            Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
            _ => panic!("Expected Unauthorized error"),
        }
        unauthorized_mock.assert();

        let bad_request_mock = server
            .mock("GET", "/v1/models")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid provider"}"#)
            .create();

        match client.list_models().await {
            Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
            _ => panic!("Expected BadRequest error"),
        }
        bad_request_mock.assert();

        let internal_error_mock = server
            .mock("GET", "/v1/models")
            .with_status(500)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Internal server error occurred"}"#)
            .create();

        match client.list_models().await {
            Err(GatewayError::InternalError(msg)) => {
                assert_eq!(msg, "Internal server error occurred")
            }
            _ => panic!("Expected InternalError error"),
        }
        internal_error_mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_case_insensitive() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hello");
        assert_eq!(response.model, "mixtral-8x7b");
        assert_eq!(response.object, "chat.completion");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_stream() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(move |writer| -> std::io::Result<()> {
                let events = vec![
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}"#),
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268191,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":" World"},"finish_reason":null}]}"#),
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268192,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":17,"completion_tokens":40,"total_tokens":57}}"#),
                    "data: [DONE]\n\n".to_string(),
                ];
                for event in events {
                    writer.write_all(event.as_bytes())?;
                }
                Ok(())
            })
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Test message".to_string(),
            ..Default::default()
        }];

        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
        pin_mut!(stream);
        while let Some(result) = stream.next().await {
            let result = result?;
            let generate_response: CreateChatCompletionStreamResponse =
                serde_json::from_str(&result.data)
                    .expect("Failed to parse CreateChatCompletionStreamResponse");

            if generate_response.choices[0].finish_reason.is_some() {
                assert_eq!(
                    generate_response.choices[0].finish_reason.as_ref().unwrap(),
                    "stop"
                );
                break;
            }

            if let Some(content) = &generate_response.choices[0].delta.content {
                assert!(matches!(content.as_str(), "Hello" | " World"));
            }
            if let Some(role) = &generate_response.choices[0].delta.role {
                assert_eq!(role, &MessageRole::Assistant);
            }
        }

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_stream_error() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_chunked_body(move |writer| -> std::io::Result<()> {
                let events = vec![format!(
                    "event: {}\ndata: {}\nretry: {}\n\n",
                    r#"error"#, r#"{"error":"Invalid request"}"#, r#"1000"#,
                )];
                for event in events {
                    writer.write_all(event.as_bytes())?;
                }
                Ok(())
            })
            .expect_at_least(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Test message".to_string(),
            ..Default::default()
        }];

        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);

        pin_mut!(stream);
        while let Some(result) = stream.next().await {
            let result = result?;
            assert!(result.event.is_some());
            assert_eq!(result.event.unwrap(), "error");
            assert!(result.data.contains("Invalid request"));
            // `retry:` lines are not parsed by the client, so this stays `None`.
            assert!(result.retry.is_none());
        }

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "deepseek-r1-distill-llama-70b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "tool_calls",
                    "message": {
                        "role": "assistant",
                        "content": "Let me check the weather for you.",
                        "tool_calls": [
                            {
                                "id": "1234",
                                "type": "function",
                                "function": {
                                    "name": "get_weather",
                                    "arguments": "{\"location\": \"London\"}"
                                }
                            }
                        ]
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let tools = vec![Tool {
            r#type: ToolType::Function,
            function: FunctionObject {
                name: "get_weather".to_string(),
                description: "Get the weather for a location".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city name"
                        }
                    },
                    "required": ["location"]
                }),
            },
        }];

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_tools(Some(tools));

        let messages = vec![Message {
            role: MessageRole::User,
            content: "What's the weather in London?".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(
            response.choices[0].message.content,
            "Let me check the weather for you."
        );

        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "get_weather");

        let params = tool_calls[0]
            .function
            .parse_arguments()
            .expect("Failed to parse function arguments");
        assert_eq!(params["location"].as_str().unwrap(), "London");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_without_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "gpt-4",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello!"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=openai")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hi".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::OpenAI, "gpt-4", messages)
            .await?;

        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.choices[0].message.content, "Hello!");
        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert!(response.choices[0].message.tool_calls.is_none());

        mock.assert();
        Ok(())
    }

1578    #[tokio::test]
1579    async fn test_generate_content_with_tools_payload() -> Result<(), GatewayError> {
1580        let mut server = Server::new_async().await;
1581
1582        let raw_request_body = r#"{
1583            "model": "deepseek-r1-distill-llama-70b",
1584            "messages": [
1585                {
1586                    "role": "system",
1587                    "content": "You are a helpful assistant."
1588                },
1589                {
1590                    "role": "user",
1591                    "content": "What is the current weather in Toronto?"
1592                }
1593            ],
1594            "stream": false,
1595            "tools": [
1596                {
1597                    "type": "function",
1598                    "function": {
1599                        "name": "get_current_weather",
1600                        "description": "Get the current weather of a city",
1601                        "parameters": {
1602                            "type": "object",
1603                            "properties": {
1604                                "city": {
1605                                    "type": "string",
1606                                    "description": "The name of the city"
1607                                }
1608                            },
1609                            "required": ["city"]
1610                        }
1611                    }
1612                }
1613            ]
1614        }"#;
1615
1616        let raw_json_response = r#"{
1617            "id": "1234",
1618            "object": "chat.completion",
1619            "created": 1630000000,
1620            "model": "deepseek-r1-distill-llama-70b",
1621            "choices": [
1622                {
1623                    "index": 0,
1624                    "finish_reason": "stop",
1625                    "message": {
1626                        "role": "assistant",
1627                        "content": "Let me check the weather for you",
1628                        "tool_calls": [
1629                            {
1630                                "id": "1234",
1631                                "type": "function",
1632                                "function": {
1633                                    "name": "get_current_weather",
1634                                    "arguments": "{\"city\": \"Toronto\"}"
1635                                }
1636                            }
1637                        ]
1638                    }
1639                }
1640            ]
1641        }"#;
1642
1643        let mock = server
1644            .mock("POST", "/v1/chat/completions?provider=groq")
1645            .with_status(200)
1646            .with_header("content-type", "application/json")
1647            .match_body(mockito::Matcher::JsonString(raw_request_body.to_string()))
1648            .with_body(raw_json_response)
1649            .create();
1650
1651        let tools = vec![Tool {
1652            r#type: ToolType::Function,
1653            function: FunctionObject {
1654                name: "get_current_weather".to_string(),
1655                description: "Get the current weather of a city".to_string(),
1656                parameters: json!({
1657                    "type": "object",
1658                    "properties": {
1659                        "city": {
1660                            "type": "string",
1661                            "description": "The name of the city"
1662                        }
1663                    },
1664                    "required": ["city"]
1665                }),
1666            },
1667        }];
1668
1669        let base_url = format!("{}/v1", server.url());
1670        let client = InferenceGatewayClient::new(&base_url);
1671
1672        let messages = vec![
1673            Message {
1674                role: MessageRole::System,
1675                content: "You are a helpful assistant.".to_string(),
1676                ..Default::default()
1677            },
1678            Message {
1679                role: MessageRole::User,
1680                content: "What is the current weather in Toronto?".to_string(),
1681                ..Default::default()
1682            },
1683        ];
1684
1685        let response = client
1686            .with_tools(Some(tools))
1687            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
1688            .await?;
1689
1690        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1691        assert_eq!(
1692            response.choices[0].message.content,
1693            "Let me check the weather for you"
1694        );
1695        assert_eq!(
1696            response.choices[0]
1697                .message
1698                .tool_calls
1699                .as_ref()
1700                .unwrap()
1701                .len(),
1702            1
1703        );
1704
1705        mock.assert();
1706        Ok(())
1707    }
1708
    #[tokio::test]
    async fn test_generate_content_with_max_tokens() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Here's a poem with 100 tokens..."
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .match_body(mockito::Matcher::JsonString(
                r#"{
                "model": "mixtral-8x7b",
                "messages": [{"role":"user","content":"Write a poem"}],
                "stream": false,
                "max_tokens": 100
            }"#
                .to_string(),
            ))
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_max_tokens(Some(100));

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Write a poem".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(
            response.choices[0].message.content,
            "Here's a poem with 100 tokens..."
        );
        assert_eq!(response.model, "mixtral-8x7b");
        assert_eq!(response.created, 1630000000);
        assert_eq!(response.object, "chat.completion");

        mock.assert();
        Ok(())
    }

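    // A 200 from `/health` should map to `Ok(true)`.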
    #[tokio::test]
    async fn test_health_check() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;
        let mock = server.mock("GET", "/health").with_status(200).create();

        let client = InferenceGatewayClient::new(&server.url());
        let is_healthy = client.health_check().await?;

        assert!(is_healthy);
        mock.assert();

        Ok(())
    }

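    // Covers both explicit base-URL configuration (against a mock server)
    // and the default base URL reported by `new_default`.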
    #[tokio::test]
    async fn test_client_base_url_configuration() -> Result<(), GatewayError> {
        let mut custom_url_server = Server::new_async().await;

        let custom_url_mock = custom_url_server
            .mock("GET", "/health")
            .with_status(200)
            .create();

        let custom_client = InferenceGatewayClient::new(&custom_url_server.url());
        let is_healthy = custom_client.health_check().await?;
        assert!(is_healthy);
        custom_url_mock.assert();

        let default_client = InferenceGatewayClient::new_default();

        let default_url = "http://localhost:8080/v1";
        assert_eq!(default_client.base_url(), default_url);

        Ok(())
    }

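    // Deserializes an MCP tool listing that mixes a tool with an
    // `input_schema` and one without, exercising the optional field in
    // both directions.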
    #[tokio::test]
    async fn test_list_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_response_json = r#"{
            "object": "list",
            "data": [
                {
                    "name": "read_file",
                    "description": "Read content from a file",
                    "server": "http://mcp-filesystem-server:8083/mcp",
                    "input_schema": {
                        "type": "object",
                        "properties": {
                            "file_path": {
                                "type": "string",
                                "description": "Path to the file to read"
                            }
                        },
                        "required": ["file_path"]
                    }
                },
                {
                    "name": "write_file",
                    "description": "Write content to a file",
                    "server": "http://mcp-filesystem-server:8083/mcp"
                }
            ]
        }"#;

        let mock = server
            .mock("GET", "/v1/mcp/tools")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_response_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let response = client.list_tools().await?;

        assert_eq!(response.object, "list");
        assert_eq!(response.data.len(), 2);

        assert_eq!(response.data[0].name, "read_file");
        assert_eq!(response.data[0].description, "Read content from a file");
        assert_eq!(
            response.data[0].server,
            "http://mcp-filesystem-server:8083/mcp"
        );
        assert!(response.data[0].input_schema.is_some());

        assert_eq!(response.data[1].name, "write_file");
        assert_eq!(response.data[1].description, "Write content to a file");
        assert_eq!(
            response.data[1].server,
            "http://mcp-filesystem-server:8083/mcp"
        );
        assert!(response.data[1].input_schema.is_none());

        mock.assert();
        Ok(())
    }

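    // Ensures a bearer token supplied via `with_token` is forwarded in the
    // `Authorization` header.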
    #[tokio::test]
    async fn test_list_tools_with_authentication() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_response_json = r#"{
            "object": "list",
            "data": []
        }"#;

        let mock = server
            .mock("GET", "/v1/mcp/tools")
            .match_header("authorization", "Bearer test-token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_response_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_token("test-token");
        let response = client.list_tools().await?;

        assert_eq!(response.object, "list");
        assert_eq!(response.data.len(), 0);
        mock.assert();
        Ok(())
    }

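    // A 403 from the MCP tools endpoint should surface as
    // `GatewayError::Forbidden`, carrying the server's error message.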
    #[tokio::test]
    async fn test_list_tools_mcp_not_exposed() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("GET", "/v1/mcp/tools")
            .with_status(403)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"error":"MCP tools endpoint is not exposed. Set EXPOSE_MCP=true to enable."}"#,
            )
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        match client.list_tools().await {
            Err(GatewayError::Forbidden(msg)) => {
                assert_eq!(
                    msg,
                    "MCP tools endpoint is not exposed. Set EXPOSE_MCP=true to enable."
                );
            }
            _ => panic!("Expected Forbidden error for MCP not exposed"),
        }

        mock.assert();
        Ok(())
    }
}