//! Client library for the Inference Gateway: typed request/response
//! models, error handling, and an async client for the HTTP API.

use core::fmt;
use std::future::Future;

use futures_util::{Stream, StreamExt};
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;

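/// A single Server-Sent Event (SSE) frame as parsed from the streaming
/// endpoint: the `data` payload, an optional `event` name, and an optional
/// `retry` interval in milliseconds (currently never populated by the parser).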
#[derive(Debug, Serialize, Deserialize)]
pub struct SSEvents {
    pub data: String,
    pub event: Option<String>,
    pub retry: Option<u64>,
}

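/// Errors returned by the gateway client, mapping HTTP status codes and
/// transport/serde failures onto one enum.
///
/// A minimal sketch of matching on the error (the crate name
/// `inference_gateway_rs` is an assumption; adjust to your actual crate):
///
/// ```rust,ignore
/// use inference_gateway_rs::{GatewayError, InferenceGatewayAPI, InferenceGatewayClient};
///
/// # async fn run() {
/// let client = InferenceGatewayClient::new("http://localhost:8080/v1");
/// match client.list_models().await {
///     Ok(models) => println!("{} models", models.data.len()),
///     Err(GatewayError::Unauthorized(msg)) => eprintln!("check your token: {msg}"),
///     Err(other) => eprintln!("request failed: {other}"),
/// }
/// # }
/// ```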
#[derive(Error, Debug)]
pub enum GatewayError {
    #[error("Unauthorized: {0}")]
    Unauthorized(String),

    #[error("Forbidden: {0}")]
    Forbidden(String),

    #[error("Bad request: {0}")]
    BadRequest(String),

    #[error("Internal server error: {0}")]
    InternalError(String),

    #[error("Stream error: {0}")]
    StreamError(reqwest::Error),

    #[error("Decoding error: {0}")]
    DecodingError(std::string::FromUtf8Error),

    #[error("Request error: {0}")]
    RequestError(#[from] reqwest::Error),

    #[error("Deserialization error: {0}")]
    DeserializationError(serde_json::Error),

    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    #[error("Other error: {0}")]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}

/// JSON error body returned by the gateway, e.g. `{"error":"Invalid token"}`.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: String,
}

/// A model entry as returned by the `/models` endpoints.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Model {
    pub id: String,
    pub object: Option<String>,
    pub created: Option<i64>,
    pub owned_by: Option<String>,
    pub served_by: Option<String>,
}

/// Response of `GET /models`; `provider` is present only when the listing
/// was filtered by provider.
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider: Option<Provider>,
    pub object: String,
    pub data: Vec<Model>,
}

/// A tool exposed by an MCP (Model Context Protocol) server via the gateway.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MCPTool {
    pub name: String,
    pub description: String,
    pub server: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_schema: Option<Value>,
}

/// Response of `GET /mcp/tools`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ListToolsResponse {
    pub object: String,
    pub data: Vec<MCPTool>,
}

/// The inference providers the gateway can route to. Serialized in
/// lowercase; capitalized and uppercase spellings are accepted on input.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Copy)]
#[serde(rename_all = "lowercase")]
pub enum Provider {
    #[serde(alias = "Ollama", alias = "OLLAMA")]
    Ollama,
    #[serde(alias = "Groq", alias = "GROQ")]
    Groq,
    #[serde(alias = "OpenAI", alias = "OPENAI")]
    OpenAI,
    #[serde(alias = "Cloudflare", alias = "CLOUDFLARE")]
    Cloudflare,
    #[serde(alias = "Cohere", alias = "COHERE")]
    Cohere,
    #[serde(alias = "Anthropic", alias = "ANTHROPIC")]
    Anthropic,
    #[serde(alias = "Deepseek", alias = "DEEPSEEK")]
    Deepseek,
    #[serde(alias = "Google", alias = "GOOGLE")]
    Google,
}

impl fmt::Display for Provider {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Provider::Ollama => write!(f, "ollama"),
            Provider::Groq => write!(f, "groq"),
            Provider::OpenAI => write!(f, "openai"),
            Provider::Cloudflare => write!(f, "cloudflare"),
            Provider::Cohere => write!(f, "cohere"),
            Provider::Anthropic => write!(f, "anthropic"),
            Provider::Deepseek => write!(f, "deepseek"),
            Provider::Google => write!(f, "google"),
        }
    }
}

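/// Case-insensitive conversion from a string, e.g. from CLI arguments or
/// configuration. A small sketch (crate name assumed):
///
/// ```rust,ignore
/// use inference_gateway_rs::Provider;
///
/// assert_eq!(Provider::try_from("OpenAI").unwrap(), Provider::OpenAI);
/// assert!(Provider::try_from("not-a-provider").is_err());
/// ```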
impl TryFrom<&str> for Provider {
    type Error = GatewayError;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        match s.to_lowercase().as_str() {
            "ollama" => Ok(Self::Ollama),
            "groq" => Ok(Self::Groq),
            "openai" => Ok(Self::OpenAI),
            "cloudflare" => Ok(Self::Cloudflare),
            "cohere" => Ok(Self::Cohere),
            "anthropic" => Ok(Self::Anthropic),
            "deepseek" => Ok(Self::Deepseek),
            "google" => Ok(Self::Google),
            _ => Err(GatewayError::BadRequest(format!("Unknown provider: {s}"))),
        }
    }
}

/// The role of a chat message; defaults to `User`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    System,
    #[default]
    User,
    Assistant,
    Tool,
}

impl fmt::Display for MessageRole {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            MessageRole::System => write!(f, "system"),
            MessageRole::User => write!(f, "user"),
            MessageRole::Assistant => write!(f, "assistant"),
            MessageRole::Tool => write!(f, "tool"),
        }
    }
}

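/// A chat message. Optional fields are omitted from the serialized JSON,
/// so `..Default::default()` keeps payloads minimal; a sketch (crate name
/// assumed):
///
/// ```rust,ignore
/// use inference_gateway_rs::{Message, MessageRole};
///
/// let msg = Message {
///     role: MessageRole::User,
///     content: "Hello".to_string(),
///     ..Default::default()
/// };
/// ```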
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct Message {
    pub role: MessageRole,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ChatCompletionMessageToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}

/// A tool call requested by the model as part of an assistant message.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub r#type: ChatCompletionToolType,
    pub function: ChatCompletionMessageToolCallFunction,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum ChatCompletionToolType {
    #[serde(rename = "function")]
    Function,
}

/// The function name and JSON-encoded arguments of a tool call.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageToolCallFunction {
    pub name: String,
    pub arguments: String,
}

impl ChatCompletionMessageToolCallFunction {
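    /// Parses the JSON-encoded `arguments` string into a `serde_json::Value`.
    ///
    /// A sketch, where `call` is a hypothetical
    /// `ChatCompletionMessageToolCallFunction` taken from a response:
    ///
    /// ```rust,ignore
    /// let args = call.parse_arguments()?;
    /// let city = args["city"].as_str().unwrap_or_default();
    /// ```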
    pub fn parse_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
        serde_json::from_str(&self.arguments)
    }
}

/// A function definition offered to the model as a tool.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionObject {
    pub name: String,
    pub description: String,
    pub parameters: Value,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    Function,
}

/// A tool the model may call during a chat completion.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: FunctionObject,
}

#[derive(Debug, Serialize)]
struct CreateChatCompletionRequest {
    model: String,
    messages: Vec<Message>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<i32>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ToolCallResponse {
    pub id: String,
    #[serde(rename = "type")]
    pub r#type: ToolType,
    pub function: ChatCompletionMessageToolCallFunction,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionChoice {
    pub finish_reason: String,
    pub message: Message,
    pub index: i32,
}

/// Response of a non-streaming `POST /chat/completions`.
#[derive(Debug, Deserialize, Clone)]
pub struct CreateChatCompletionResponse {
    pub id: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub created: i64,
    pub model: String,
    pub object: String,
}

/// One chunk of a streaming chat completion.
#[derive(Debug, Deserialize, Clone)]
pub struct CreateChatCompletionStreamResponse {
    pub id: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    pub created: i64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub object: String,
    pub usage: Option<CompletionUsage>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionStreamChoice {
    pub delta: ChatCompletionStreamDelta,
    pub index: i32,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionStreamDelta {
    pub role: Option<MessageRole>,
    pub content: Option<String>,
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

/// Token usage statistics reported at the end of a stream.
#[derive(Debug, Deserialize, Clone)]
pub struct CompletionUsage {
    pub completion_tokens: i64,
    pub prompt_tokens: i64,
    pub total_tokens: i64,
}

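/// HTTP client for the Inference Gateway. Configure it with the
/// builder-style `with_*` methods; a minimal sketch (crate name assumed):
///
/// ```rust,ignore
/// use inference_gateway_rs::InferenceGatewayClient;
///
/// let client = InferenceGatewayClient::new("http://localhost:8080/v1")
///     .with_token("my-api-token")
///     .with_max_tokens(Some(512));
/// ```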
pub struct InferenceGatewayClient {
    base_url: String,
    client: Client,
    token: Option<String>,
    tools: Option<Vec<Tool>>,
    max_tokens: Option<i32>,
}

// Manual Debug impl so the bearer token is redacted from debug output.
impl std::fmt::Debug for InferenceGatewayClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InferenceGatewayClient")
            .field("base_url", &self.base_url)
            .field("token", &self.token.as_ref().map(|_| "*****"))
            .finish()
    }
}

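/// The Inference Gateway API surface, implemented by [`InferenceGatewayClient`].
/// A sketch of listing models and generating a completion (crate name
/// assumed):
///
/// ```rust,ignore
/// use inference_gateway_rs::{
///     InferenceGatewayAPI, InferenceGatewayClient, Message, MessageRole, Provider,
/// };
///
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// let client = InferenceGatewayClient::new("http://localhost:8080/v1");
/// let models = client.list_models().await?;
///
/// let messages = vec![Message {
///     role: MessageRole::User,
///     content: "Hello".to_string(),
///     ..Default::default()
/// }];
/// let response = client
///     .generate_content(Provider::Ollama, "llama2", messages)
///     .await?;
/// # Ok(())
/// # }
/// ```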
pub trait InferenceGatewayAPI {
    /// Lists all models available through the gateway.
    fn list_models(&self) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;

    /// Lists the models served by a specific provider.
    fn list_models_by_provider(
        &self,
        provider: Provider,
    ) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;

    /// Generates a chat completion and returns the full response.
    fn generate_content(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Future<Output = Result<CreateChatCompletionResponse, GatewayError>> + Send;

    /// Generates a chat completion as a stream of Server-Sent Events.
    fn generate_content_stream(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send;

    /// Lists the MCP tools exposed by the gateway.
    fn list_tools(&self) -> impl Future<Output = Result<ListToolsResponse, GatewayError>> + Send;

    /// Returns `true` if the gateway's health endpoint responds with 200 OK.
    fn health_check(&self) -> impl Future<Output = Result<bool, GatewayError>> + Send;
}

impl InferenceGatewayClient {
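    /// Creates a client that talks to the gateway at `base_url`
    /// (e.g. `http://localhost:8080/v1`), with no token, tools, or
    /// token limit configured.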
    pub fn new(base_url: &str) -> Self {
        Self {
            base_url: base_url.to_string(),
            client: Client::new(),
            token: None,
            tools: None,
            max_tokens: None,
        }
    }

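    /// Creates a client from the `INFERENCE_GATEWAY_URL` environment
    /// variable, falling back to `http://localhost:8080/v1` when it is unset.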
    pub fn new_default() -> Self {
        let base_url = std::env::var("INFERENCE_GATEWAY_URL")
            .unwrap_or_else(|_| "http://localhost:8080/v1".to_string());
        Self::new(&base_url)
    }

    /// Returns the configured base URL.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

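    /// Sets the tools sent with every non-streaming completion request;
    /// `None` clears them. A sketch of defining one tool (the weather tool
    /// here is illustrative):
    ///
    /// ```rust,ignore
    /// use serde_json::json;
    ///
    /// let tools = vec![Tool {
    ///     r#type: ToolType::Function,
    ///     function: FunctionObject {
    ///         name: "get_weather".to_string(),
    ///         description: "Get the weather for a location".to_string(),
    ///         parameters: json!({
    ///             "type": "object",
    ///             "properties": { "location": { "type": "string" } },
    ///             "required": ["location"]
    ///         }),
    ///     },
    /// }];
    /// let client = client.with_tools(Some(tools));
    /// ```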
    pub fn with_tools(mut self, tools: Option<Vec<Tool>>) -> Self {
        self.tools = tools;
        self
    }

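    /// Sets a bearer token that is attached to every request as an
    /// `Authorization: Bearer <token>` header.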
    pub fn with_token(mut self, token: impl Into<String>) -> Self {
        self.token = Some(token.into());
        self
    }

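    /// Sets the `max_tokens` limit sent with non-streaming completion
    /// requests; `None` omits the field.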
    pub fn with_max_tokens(mut self, max_tokens: Option<i32>) -> Self {
        self.max_tokens = max_tokens;
        self
    }
}

impl InferenceGatewayAPI for InferenceGatewayClient {
    async fn list_models(&self) -> Result<ListModelsResponse, GatewayError> {
        let url = format!("{}/models", self.base_url);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => {
                let json_response: ListModelsResponse = response.json().await?;
                Ok(json_response)
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::other(
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn list_models_by_provider(
        &self,
        provider: Provider,
    ) -> Result<ListModelsResponse, GatewayError> {
        let url = format!("{}/models?provider={}", self.base_url, provider);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => {
                let json_response: ListModelsResponse = response.json().await?;
                Ok(json_response)
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::other(
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn generate_content(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> Result<CreateChatCompletionResponse, GatewayError> {
        let url = format!("{}/chat/completions?provider={}", self.base_url, provider);
        let mut request = self.client.post(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let request_payload = CreateChatCompletionRequest {
            model: model.to_string(),
            messages,
            stream: false,
            tools: self.tools.clone(),
            max_tokens: self.max_tokens,
        };

        let response = request.json(&request_payload).send().await?;

        match response.status() {
            StatusCode::OK => Ok(response.json().await?),
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            status => Err(GatewayError::Other(Box::new(std::io::Error::other(
                format!("Unexpected status code: {status}"),
            )))),
        }
    }

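    /// Streams a chat completion as parsed SSE frames. Note that streaming
    /// requests are currently sent without the client's configured tools or
    /// `max_tokens`. A sketch of consuming the stream:
    ///
    /// ```rust,ignore
    /// use futures_util::{pin_mut, StreamExt};
    ///
    /// let stream = client.generate_content_stream(Provider::Ollama, "llama2", messages);
    /// pin_mut!(stream);
    /// while let Some(event) = stream.next().await {
    ///     let event = event?;
    ///     if event.data == "[DONE]" {
    ///         break;
    ///     }
    ///     let chunk: CreateChatCompletionStreamResponse = serde_json::from_str(&event.data)?;
    ///     if let Some(content) = &chunk.choices[0].delta.content {
    ///         print!("{content}");
    ///     }
    /// }
    /// ```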
    fn generate_content_stream(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send {
        let client = self.client.clone();
        let token = self.token.clone();
        // `Display` for `Provider` is already lowercase.
        let url = format!("{}/chat/completions?provider={}", self.base_url, provider);

        let request = CreateChatCompletionRequest {
            model: model.to_string(),
            messages,
            stream: true,
            tools: None,
            max_tokens: None,
        };

        async_stream::try_stream! {
            let mut request_builder = client.post(&url).json(&request);
            if let Some(token) = &token {
                request_builder = request_builder.bearer_auth(token);
            }
            let response = request_builder.send().await?;
            let mut stream = response.bytes_stream();
            let mut current_event: Option<String> = None;
            let mut current_data: Option<String> = None;

            while let Some(chunk) = stream.next().await {
                let chunk = chunk?;
                let chunk_str = String::from_utf8_lossy(&chunk);

                for line in chunk_str.lines() {
                    // A blank line terminates an SSE frame; emit what we have.
                    if line.is_empty() && current_data.is_some() {
                        yield SSEvents {
                            data: current_data.take().unwrap(),
                            event: current_event.take(),
                            retry: None,
                        };
                        continue;
                    }

                    if let Some(event) = line.strip_prefix("event:") {
                        current_event = Some(event.trim().to_string());
                    } else if let Some(data) = line.strip_prefix("data:") {
                        // `lines()` already strips the trailing newline.
                        current_data = Some(data.trim().to_string());
                    }
                }
            }
        }
    }

    async fn list_tools(&self) -> Result<ListToolsResponse, GatewayError> {
        let url = format!("{}/mcp/tools", self.base_url);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => {
                let json_response: ListToolsResponse = response.json().await?;
                Ok(json_response)
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::FORBIDDEN => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Forbidden(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::other(
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn health_check(&self) -> Result<bool, GatewayError> {
        let url = format!("{}/health", self.base_url);

        let response = self.client.get(&url).send().await?;
        Ok(response.status() == StatusCode::OK)
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        CreateChatCompletionRequest, CreateChatCompletionResponse,
        CreateChatCompletionStreamResponse, FunctionObject, GatewayError, InferenceGatewayAPI,
        InferenceGatewayClient, Message, MessageRole, Provider, Tool, ToolType,
    };
    use futures_util::{pin_mut, StreamExt};
    use mockito::{Matcher, Server};
    use serde_json::json;

    #[test]
    fn test_provider_serialization() {
        let providers = vec![
            (Provider::Ollama, "ollama"),
            (Provider::Groq, "groq"),
            (Provider::OpenAI, "openai"),
            (Provider::Cloudflare, "cloudflare"),
            (Provider::Cohere, "cohere"),
            (Provider::Anthropic, "anthropic"),
            (Provider::Deepseek, "deepseek"),
            (Provider::Google, "google"),
        ];

        for (provider, expected) in providers {
            let json = serde_json::to_string(&provider).unwrap();
            assert_eq!(json, format!("\"{}\"", expected));
        }
    }

    #[test]
    fn test_provider_deserialization() {
        let test_cases = vec![
            ("\"ollama\"", Provider::Ollama),
            ("\"groq\"", Provider::Groq),
            ("\"openai\"", Provider::OpenAI),
            ("\"cloudflare\"", Provider::Cloudflare),
            ("\"cohere\"", Provider::Cohere),
            ("\"anthropic\"", Provider::Anthropic),
            ("\"deepseek\"", Provider::Deepseek),
            ("\"google\"", Provider::Google),
        ];

        for (json, expected) in test_cases {
            let provider: Provider = serde_json::from_str(json).unwrap();
            assert_eq!(provider, expected);
        }
    }

    #[test]
    fn test_message_serialization_with_tool_call_id() {
        let message_with_tool = Message {
            role: MessageRole::Tool,
            content: "The weather is sunny".to_string(),
            tool_call_id: Some("call_123".to_string()),
            ..Default::default()
        };

        let serialized = serde_json::to_string(&message_with_tool).unwrap();
        let expected_with_tool =
            r#"{"role":"tool","content":"The weather is sunny","tool_call_id":"call_123"}"#;
        assert_eq!(serialized, expected_with_tool);

        let message_without_tool = Message {
            role: MessageRole::User,
            content: "What's the weather?".to_string(),
            ..Default::default()
        };

        let serialized = serde_json::to_string(&message_without_tool).unwrap();
        let expected_without_tool = r#"{"role":"user","content":"What's the weather?"}"#;
        assert_eq!(serialized, expected_without_tool);

        let deserialized: Message = serde_json::from_str(expected_with_tool).unwrap();
        assert_eq!(deserialized.role, MessageRole::Tool);
        assert_eq!(deserialized.content, "The weather is sunny");
        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));

        let deserialized: Message = serde_json::from_str(expected_without_tool).unwrap();
        assert_eq!(deserialized.role, MessageRole::User);
        assert_eq!(deserialized.content, "What's the weather?");
        assert_eq!(deserialized.tool_call_id, None);
    }

    #[test]
    fn test_provider_display() {
        let providers = vec![
            (Provider::Ollama, "ollama"),
            (Provider::Groq, "groq"),
            (Provider::OpenAI, "openai"),
            (Provider::Cloudflare, "cloudflare"),
            (Provider::Cohere, "cohere"),
            (Provider::Anthropic, "anthropic"),
            (Provider::Deepseek, "deepseek"),
            (Provider::Google, "google"),
        ];

        for (provider, expected) in providers {
            assert_eq!(provider.to_string(), expected);
        }
    }

    #[test]
    fn test_google_provider_case_insensitive() {
        let test_cases = vec!["google", "Google", "GOOGLE", "GoOgLe"];

        for test_case in test_cases {
            let provider: Result<Provider, _> = test_case.try_into();
            assert!(provider.is_ok(), "Failed to parse: {}", test_case);
            assert_eq!(provider.unwrap(), Provider::Google);
        }

        let json_cases = vec![r#""google""#, r#""Google""#, r#""GOOGLE""#];

        for json_case in json_cases {
            let provider: Provider = serde_json::from_str(json_case).unwrap();
            assert_eq!(provider, Provider::Google);
        }

        assert_eq!(Provider::Google.to_string(), "google");
    }

    #[test]
    fn test_generate_request_serialization() {
        let request_payload = CreateChatCompletionRequest {
            model: "llama3.2:1b".to_string(),
            messages: vec![
                Message {
                    role: MessageRole::System,
                    content: "You are a helpful assistant.".to_string(),
                    ..Default::default()
                },
                Message {
                    role: MessageRole::User,
                    content: "What is the current weather in Toronto?".to_string(),
                    ..Default::default()
                },
            ],
            stream: false,
            tools: Some(vec![Tool {
                r#type: ToolType::Function,
                function: FunctionObject {
                    name: "get_current_weather".to_string(),
                    description: "Get the current weather of a city".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "city": {
                                "type": "string",
                                "description": "The name of the city"
                            }
                        },
                        "required": ["city"]
                    }),
                },
            }]),
            max_tokens: None,
        };

        let serialized = serde_json::to_string_pretty(&request_payload).unwrap();
        let expected = r#"{
            "model": "llama3.2:1b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is the current weather in Toronto?"
                }
            ],
            "stream": false,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_current_weather",
                        "description": "Get the current weather of a city",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "city": {
                                    "type": "string",
                                    "description": "The name of the city"
                                }
                            },
                            "required": ["city"]
                        }
                    }
                }
            ]
        }"#;

        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
            serde_json::from_str::<serde_json::Value>(expected).unwrap()
        );
    }

    #[tokio::test]
    async fn test_authentication_header() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock_response = r#"{
            "object": "list",
            "data": []
        }"#;

        let mock_with_auth = server
            .mock("GET", "/v1/models")
            .match_header("authorization", "Bearer test-token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_response)
            .expect(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_token("test-token");
        client.list_models().await?;
        mock_with_auth.assert();

        let mock_without_auth = server
            .mock("GET", "/v1/models")
            .match_header("authorization", Matcher::Missing)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_response)
            .expect(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        client.list_models().await?;
        mock_without_auth.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_unauthorized_error() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "error": "Invalid token"
        }"#;

        let mock = server
            .mock("GET", "/v1/models")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let error = client.list_models().await.unwrap_err();

        assert!(matches!(error, GatewayError::Unauthorized(_)));
        if let GatewayError::Unauthorized(msg) = error {
            assert_eq!(msg, "Invalid token");
        }
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_list_models() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_response_json = r#"{
            "object": "list",
            "data": [
                {
                    "id": "llama2",
                    "object": "model",
                    "created": 1630000001,
                    "owned_by": "ollama",
                    "served_by": "ollama"
                }
            ]
        }"#;

        let mock = server
            .mock("GET", "/v1/models")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_response_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let response = client.list_models().await?;

        assert!(response.provider.is_none());
        assert_eq!(response.object, "list");
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.data[0].id, "llama2");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_list_models_by_provider() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "provider":"ollama",
            "object":"list",
            "data": [
                {
                    "id": "llama2",
                    "object": "model",
                    "created": 1630000001,
                    "owned_by": "ollama",
                    "served_by": "ollama"
                }
            ]
        }"#;

        let mock = server
            .mock("GET", "/v1/models?provider=ollama")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let response = client.list_models_by_provider(Provider::Ollama).await?;

        assert!(response.provider.is_some());
        assert_eq!(response.provider, Some(Provider::Ollama));
        assert_eq!(response.data[0].id, "llama2");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hellloooo"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=ollama")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];
        let response = client
            .generate_content(Provider::Ollama, "llama2", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hellloooo");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_serialization() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let direct_parse: Result<CreateChatCompletionResponse, _> = serde_json::from_str(raw_json);
        assert!(
            direct_parse.is_ok(),
            "Direct JSON parse failed: {:?}",
            direct_parse.err()
        );

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hello");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_error_response() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "error":"Invalid request"
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];
        let error = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await
            .unwrap_err();

        assert!(matches!(error, GatewayError::BadRequest(_)));
        if let GatewayError::BadRequest(msg) = error {
            assert_eq!(msg, "Invalid request");
        }
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_gateway_errors() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let unauthorized_mock = server
            .mock("GET", "/v1/models")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid token"}"#)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        match client.list_models().await {
            Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
            _ => panic!("Expected Unauthorized error"),
        }
        unauthorized_mock.assert();

        let bad_request_mock = server
            .mock("GET", "/v1/models")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid provider"}"#)
            .create();

        match client.list_models().await {
            Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
            _ => panic!("Expected BadRequest error"),
        }
        bad_request_mock.assert();

        let internal_error_mock = server
            .mock("GET", "/v1/models")
            .with_status(500)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Internal server error occurred"}"#)
            .create();

        match client.list_models().await {
            Err(GatewayError::InternalError(msg)) => {
                assert_eq!(msg, "Internal server error occurred")
            }
            _ => panic!("Expected InternalError error"),
        }
        internal_error_mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_case_insensitive() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hello");
        assert_eq!(response.model, "mixtral-8x7b");
        assert_eq!(response.object, "chat.completion");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_stream() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(move |writer| -> std::io::Result<()> {
                let events = vec![
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}"#),
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268191,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":" World"},"finish_reason":null}]}"#),
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268192,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":17,"completion_tokens":40,"total_tokens":57}}"#),
                    "data: [DONE]\n\n".to_string(),
                ];
                for event in events {
                    writer.write_all(event.as_bytes())?;
                }
                Ok(())
            })
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Test message".to_string(),
            ..Default::default()
        }];

        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
        pin_mut!(stream);
        while let Some(result) = stream.next().await {
            let result = result?;
            let generate_response: CreateChatCompletionStreamResponse =
                serde_json::from_str(&result.data)
                    .expect("Failed to parse CreateChatCompletionStreamResponse");

            if generate_response.choices[0].finish_reason.is_some() {
                assert_eq!(
                    generate_response.choices[0].finish_reason.as_ref().unwrap(),
                    "stop"
                );
                break;
            }

            if let Some(content) = &generate_response.choices[0].delta.content {
                assert!(matches!(content.as_str(), "Hello" | " World"));
            }
            if let Some(role) = &generate_response.choices[0].delta.role {
                assert_eq!(role, &MessageRole::Assistant);
            }
        }

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_stream_error() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_chunked_body(move |writer| -> std::io::Result<()> {
                let events = vec![format!(
                    "event: {}\ndata: {}\nretry: {}\n\n",
                    r#"error"#, r#"{"error":"Invalid request"}"#, r#"1000"#,
                )];
                for event in events {
                    writer.write_all(event.as_bytes())?;
                }
                Ok(())
            })
            .expect_at_least(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Test message".to_string(),
            ..Default::default()
        }];

        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);

        pin_mut!(stream);
        while let Some(result) = stream.next().await {
            let result = result?;
            assert!(result.event.is_some());
            assert_eq!(result.event.unwrap(), "error");
            assert!(result.data.contains("Invalid request"));
            assert!(result.retry.is_none());
        }

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "deepseek-r1-distill-llama-70b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "tool_calls",
                    "message": {
                        "role": "assistant",
                        "content": "Let me check the weather for you.",
                        "tool_calls": [
                            {
                                "id": "1234",
                                "type": "function",
                                "function": {
                                    "name": "get_weather",
                                    "arguments": "{\"location\": \"London\"}"
                                }
                            }
                        ]
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let tools = vec![Tool {
            r#type: ToolType::Function,
            function: FunctionObject {
                name: "get_weather".to_string(),
                description: "Get the weather for a location".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city name"
                        }
                    },
                    "required": ["location"]
                }),
            },
        }];

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_tools(Some(tools));

        let messages = vec![Message {
            role: MessageRole::User,
            content: "What's the weather in London?".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(
            response.choices[0].message.content,
            "Let me check the weather for you."
        );

        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "get_weather");

        let params = tool_calls[0]
            .function
            .parse_arguments()
            .expect("Failed to parse function arguments");
        assert_eq!(params["location"].as_str().unwrap(), "London");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_without_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "gpt-4",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello!"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=openai")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hi".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::OpenAI, "gpt-4", messages)
            .await?;

        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.choices[0].message.content, "Hello!");
        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert!(response.choices[0].message.tool_calls.is_none());

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_tools_payload() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_request_body = r#"{
            "model": "deepseek-r1-distill-llama-70b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is the current weather in Toronto?"
                }
            ],
            "stream": false,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_current_weather",
                        "description": "Get the current weather of a city",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "city": {
                                    "type": "string",
                                    "description": "The name of the city"
                                }
                            },
                            "required": ["city"]
                        }
                    }
                }
            ]
        }"#;

        let raw_json_response = r#"{
            "id": "1234",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "deepseek-r1-distill-llama-70b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Let me check the weather for you",
                        "tool_calls": [
                            {
                                "id": "1234",
                                "type": "function",
                                "function": {
                                    "name": "get_current_weather",
                                    "arguments": "{\"city\": \"Toronto\"}"
                                }
                            }
                        ]
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .match_body(mockito::Matcher::JsonString(raw_request_body.to_string()))
            .with_body(raw_json_response)
            .create();

        let tools = vec![Tool {
            r#type: ToolType::Function,
            function: FunctionObject {
                name: "get_current_weather".to_string(),
                description: "Get the current weather of a city".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "The name of the city"
                        }
                    },
                    "required": ["city"]
                }),
            },
        }];

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![
            Message {
                role: MessageRole::System,
                content: "You are a helpful assistant.".to_string(),
                ..Default::default()
            },
            Message {
                role: MessageRole::User,
                content: "What is the current weather in Toronto?".to_string(),
                ..Default::default()
            },
        ];

        let response = client
            .with_tools(Some(tools))
            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(
            response.choices[0].message.content,
            "Let me check the weather for you"
        );
        assert_eq!(
            response.choices[0]
                .message
                .tool_calls
                .as_ref()
                .unwrap()
                .len(),
            1
        );

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_max_tokens() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Here's a poem with 100 tokens..."
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .match_body(mockito::Matcher::JsonString(
                r#"{
                    "model": "mixtral-8x7b",
                    "messages": [{"role":"user","content":"Write a poem"}],
                    "stream": false,
                    "max_tokens": 100
                }"#
                .to_string(),
            ))
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_max_tokens(Some(100));

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Write a poem".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(
            response.choices[0].message.content,
            "Here's a poem with 100 tokens..."
        );
        assert_eq!(response.model, "mixtral-8x7b");
        assert_eq!(response.created, 1630000000);
        assert_eq!(response.object, "chat.completion");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_health_check() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;
        let mock = server.mock("GET", "/health").with_status(200).create();

        let client = InferenceGatewayClient::new(&server.url());
        let is_healthy = client.health_check().await?;

        assert!(is_healthy);
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_client_base_url_configuration() -> Result<(), GatewayError> {
        let mut custom_url_server = Server::new_async().await;

        let custom_url_mock = custom_url_server
            .mock("GET", "/health")
            .with_status(200)
            .create();

        let custom_client = InferenceGatewayClient::new(&custom_url_server.url());
        let is_healthy = custom_client.health_check().await?;
        assert!(is_healthy);
        custom_url_mock.assert();

        let default_client = InferenceGatewayClient::new_default();

        let default_url = "http://localhost:8080/v1";
        assert_eq!(default_client.base_url(), default_url);

        Ok(())
    }

    #[tokio::test]
    async fn test_list_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_response_json = r#"{
            "object": "list",
            "data": [
                {
                    "name": "read_file",
                    "description": "Read content from a file",
                    "server": "http://mcp-filesystem-server:8083/mcp",
                    "input_schema": {
                        "type": "object",
                        "properties": {
                            "file_path": {
                                "type": "string",
                                "description": "Path to the file to read"
                            }
                        },
                        "required": ["file_path"]
                    }
                },
                {
                    "name": "write_file",
                    "description": "Write content to a file",
                    "server": "http://mcp-filesystem-server:8083/mcp"
                }
            ]
        }"#;

        let mock = server
            .mock("GET", "/v1/mcp/tools")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_response_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let response = client.list_tools().await?;

        assert_eq!(response.object, "list");
        assert_eq!(response.data.len(), 2);

        assert_eq!(response.data[0].name, "read_file");
        assert_eq!(response.data[0].description, "Read content from a file");
        assert_eq!(
            response.data[0].server,
            "http://mcp-filesystem-server:8083/mcp"
        );
        assert!(response.data[0].input_schema.is_some());

        assert_eq!(response.data[1].name, "write_file");
        assert_eq!(response.data[1].description, "Write content to a file");
        assert_eq!(
            response.data[1].server,
            "http://mcp-filesystem-server:8083/mcp"
        );
        assert!(response.data[1].input_schema.is_none());

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_list_tools_with_authentication() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_response_json = r#"{
            "object": "list",
            "data": []
        }"#;

        let mock = server
            .mock("GET", "/v1/mcp/tools")
            .match_header("authorization", "Bearer test-token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_response_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_token("test-token");
        let response = client.list_tools().await?;

        assert_eq!(response.object, "list");
        assert_eq!(response.data.len(), 0);
        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_list_tools_mcp_not_exposed() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("GET", "/v1/mcp/tools")
            .with_status(403)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"error":"MCP tools endpoint is not exposed. Set EXPOSE_MCP=true to enable."}"#,
            )
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        match client.list_tools().await {
            Err(GatewayError::Forbidden(msg)) => {
                assert_eq!(
                    msg,
                    "MCP tools endpoint is not exposed. Set EXPOSE_MCP=true to enable."
                );
            }
            _ => panic!("Expected Forbidden error for MCP not exposed"),
        }

        mock.assert();
        Ok(())
    }
}