use core::fmt;
use std::future::Future;

use futures_util::{Stream, StreamExt};
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;

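/// A single server-sent event (SSE) parsed from a streaming response.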
#[derive(Debug, Serialize, Deserialize)]
pub struct SSEvents {
    pub data: String,
    pub event: Option<String>,
    pub retry: Option<u64>,
}

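/// Errors that the Inference Gateway client can return.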
#[derive(Error, Debug)]
pub enum GatewayError {
    #[error("Unauthorized: {0}")]
    Unauthorized(String),

    #[error("Bad request: {0}")]
    BadRequest(String),

    #[error("Internal server error: {0}")]
    InternalError(String),

    #[error("Stream error: {0}")]
    StreamError(reqwest::Error),

    #[error("Decoding error: {0}")]
    DecodingError(std::string::FromUtf8Error),

    #[error("Request error: {0}")]
    RequestError(#[from] reqwest::Error),

    #[error("Deserialization error: {0}")]
    DeserializationError(serde_json::Error),

    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    #[error("Other error: {0}")]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}

#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: String,
}

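/// A model exposed by a provider.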
#[derive(Debug, Serialize, Deserialize)]
pub struct Model {
    pub name: String,
}

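/// The set of models available from a single provider.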
#[derive(Debug, Serialize, Deserialize)]
pub struct ProviderModels {
    pub provider: Provider,
    pub models: Vec<Model>,
}

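/// Inference providers supported by the gateway.
///
/// Serialized in lowercase; capitalized and upper-case spellings are accepted as aliases.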
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Copy)]
#[serde(rename_all = "lowercase")]
pub enum Provider {
    #[serde(alias = "Ollama", alias = "OLLAMA")]
    Ollama,
    #[serde(alias = "Groq", alias = "GROQ")]
    Groq,
    #[serde(alias = "OpenAI", alias = "OPENAI")]
    OpenAI,
    #[serde(alias = "Cloudflare", alias = "CLOUDFLARE")]
    Cloudflare,
    #[serde(alias = "Cohere", alias = "COHERE")]
    Cohere,
    #[serde(alias = "Anthropic", alias = "ANTHROPIC")]
    Anthropic,
}

impl fmt::Display for Provider {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Provider::Ollama => write!(f, "ollama"),
            Provider::Groq => write!(f, "groq"),
            Provider::OpenAI => write!(f, "openai"),
            Provider::Cloudflare => write!(f, "cloudflare"),
            Provider::Cohere => write!(f, "cohere"),
            Provider::Anthropic => write!(f, "anthropic"),
        }
    }
}

impl TryFrom<&str> for Provider {
    type Error = GatewayError;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        match s.to_lowercase().as_str() {
            "ollama" => Ok(Self::Ollama),
            "groq" => Ok(Self::Groq),
            "openai" => Ok(Self::OpenAI),
            "cloudflare" => Ok(Self::Cloudflare),
            "cohere" => Ok(Self::Cohere),
            "anthropic" => Ok(Self::Anthropic),
            _ => Err(GatewayError::BadRequest(format!("Unknown provider: {}", s))),
        }
    }
}

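/// The role attributed to a chat message.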
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    System,
    #[default]
    User,
    Assistant,
    Tool,
}

impl fmt::Display for MessageRole {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            MessageRole::System => write!(f, "system"),
            MessageRole::User => write!(f, "user"),
            MessageRole::Assistant => write!(f, "assistant"),
            MessageRole::Tool => write!(f, "tool"),
        }
    }
}

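/// A single chat message exchanged with a model.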
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct Message {
    pub role: MessageRole,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ChatCompletionMessageToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}

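/// A tool call emitted by the model as part of an assistant message.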
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub r#type: ChatCompletionToolType,
    pub function: ChatCompletionMessageToolCallFunction,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum ChatCompletionToolType {
    #[serde(rename = "function")]
    Function,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageToolCallFunction {
    pub name: String,
    pub arguments: String,
}

impl ChatCompletionMessageToolCallFunction {
    pub fn parse_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
        serde_json::from_str(&self.arguments)
    }
}

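/// A function definition that can be offered to the model as a tool.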
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionObject {
    pub name: String,
    pub description: String,
    pub parameters: Value,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    Function,
}

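/// A tool the model is allowed to call during generation.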
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: FunctionObject,
}

#[derive(Debug, Serialize)]
struct CreateChatCompletionRequest {
    model: String,
    messages: Vec<Message>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<i32>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ToolFunctionResponse {
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    pub arguments: Value,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ToolCallResponse {
    pub id: String,
    #[serde(rename = "type")]
    pub r#type: ToolType,
    pub function: ToolFunctionResponse,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionChoice {
    pub finish_reason: String,
    pub message: Message,
    pub index: i32,
}

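/// A complete (non-streaming) chat completion response.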
#[derive(Debug, Deserialize, Clone)]
pub struct CreateChatCompletionResponse {
    pub id: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub created: i64,
    pub model: String,
    pub object: String,
}

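/// A single chunk of a streamed chat completion.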
#[derive(Debug, Deserialize, Clone)]
pub struct CreateChatCompletionStreamResponse {
    pub id: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    pub created: i64,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    pub object: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<CompletionUsage>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionStreamChoice {
    pub delta: ChatCompletionStreamDelta,
    pub index: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<MessageRole>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct CompletionUsage {
    pub completion_tokens: i64,
    pub prompt_tokens: i64,
    pub total_tokens: i64,
}

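/// HTTP client for the Inference Gateway API.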
pub struct InferenceGatewayClient {
    base_url: String,
    client: Client,
    token: Option<String>,
    tools: Option<Vec<Tool>>,
    max_tokens: Option<i32>,
}

impl std::fmt::Debug for InferenceGatewayClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InferenceGatewayClient")
            .field("base_url", &self.base_url)
            .field("token", &self.token.as_ref().map(|_| "*****"))
            .finish()
    }
}

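/// The operations exposed by the Inference Gateway API.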
pub trait InferenceGatewayAPI {
    fn list_models(&self)
        -> impl Future<Output = Result<Vec<ProviderModels>, GatewayError>> + Send;

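    /// Lists the models available from a single provider.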
    fn list_models_by_provider(
        &self,
        provider: Provider,
    ) -> impl Future<Output = Result<ProviderModels, GatewayError>> + Send;

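    /// Requests a chat completion from the given provider and model.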
    fn generate_content(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Future<Output = Result<CreateChatCompletionResponse, GatewayError>> + Send;

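    /// Streams a chat completion as server-sent events.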
    fn generate_content_stream(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send;

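    /// Returns `true` if the gateway health endpoint responds with 200 OK.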
    fn health_check(&self) -> impl Future<Output = Result<bool, GatewayError>> + Send;
}

impl InferenceGatewayClient {
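    /// Creates a client that talks to the given base URL.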
    pub fn new(base_url: &str) -> Self {
        Self {
            base_url: base_url.to_string(),
            client: Client::new(),
            token: None,
            tools: None,
            max_tokens: None,
        }
    }

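    /// Creates a client from the `INFERENCE_GATEWAY_URL` environment variable,
    /// falling back to `http://localhost:8080/v1`.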
    pub fn new_default() -> Self {
        let base_url = std::env::var("INFERENCE_GATEWAY_URL")
            .unwrap_or_else(|_| "http://localhost:8080/v1".to_string());

        Self {
            base_url,
            client: Client::new(),
            token: None,
            tools: None,
            max_tokens: None,
        }
    }

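    /// Returns the base URL this client sends requests to.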
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

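    /// Sets the tools sent with non-streaming chat completion requests.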
    pub fn with_tools(mut self, tools: Option<Vec<Tool>>) -> Self {
        self.tools = tools;
        self
    }

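    /// Sets the bearer token attached to authenticated requests.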
    pub fn with_token(mut self, token: impl Into<String>) -> Self {
        self.token = Some(token.into());
        self
    }

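    /// Limits the number of tokens the model may generate per request.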
    pub fn with_max_tokens(mut self, max_tokens: Option<i32>) -> Self {
        self.max_tokens = max_tokens;
        self
    }
}

impl InferenceGatewayAPI for InferenceGatewayClient {
    async fn list_models(&self) -> Result<Vec<ProviderModels>, GatewayError> {
        let url = format!("{}/models", self.base_url);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => Ok(response.json().await?),
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn list_models_by_provider(
        &self,
        provider: Provider,
    ) -> Result<ProviderModels, GatewayError> {
        let url = format!("{}/list/models?provider={}", self.base_url, provider);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => Ok(response.json().await?),
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn generate_content(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> Result<CreateChatCompletionResponse, GatewayError> {
        let url = format!("{}/chat/completions?provider={}", self.base_url, provider);
        let mut request = self.client.post(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let request_payload = CreateChatCompletionRequest {
            model: model.to_string(),
            messages,
            stream: false,
            tools: self.tools.clone(),
            max_tokens: self.max_tokens,
        };

        let response = request.json(&request_payload).send().await?;

        match response.status() {
            StatusCode::OK => Ok(response.json().await?),
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            status => Err(GatewayError::Other(Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Unexpected status code: {}", status),
            )))),
        }
    }

    fn generate_content_stream(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send {
        let client = self.client.clone();
        let base_url = self.base_url.clone();
        let url = format!(
            "{}/chat/completions?provider={}",
            base_url,
            provider.to_string().to_lowercase()
        );

        let request = CreateChatCompletionRequest {
            model: model.to_string(),
            messages,
            stream: true,
            tools: None,
            max_tokens: None,
        };

        async_stream::try_stream! {
            let response = client.post(&url).json(&request).send().await?;
            let mut stream = response.bytes_stream();
            let mut current_event: Option<String> = None;
            let mut current_data: Option<String> = None;

            while let Some(chunk) = stream.next().await {
                let chunk = chunk?;
                let chunk_str = String::from_utf8_lossy(&chunk);

                for line in chunk_str.lines() {
                    if line.is_empty() && current_data.is_some() {
                        yield SSEvents {
                            data: current_data.take().unwrap(),
                            event: current_event.take(),
                            retry: None,
                        };
                        continue;
                    }

                    if let Some(event) = line.strip_prefix("event:") {
                        current_event = Some(event.trim().to_string());
                    } else if let Some(data) = line.strip_prefix("data:") {
                        let processed_data = data.strip_suffix('\n').unwrap_or(data);
                        current_data = Some(processed_data.trim().to_string());
                    }
                }
            }
        }
    }

    async fn health_check(&self) -> Result<bool, GatewayError> {
        let url = format!("{}/health", self.base_url);

        let response = self.client.get(&url).send().await?;
        match response.status() {
            StatusCode::OK => Ok(true),
            _ => Ok(false),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        CreateChatCompletionRequest, CreateChatCompletionResponse,
        CreateChatCompletionStreamResponse, FunctionObject, GatewayError, InferenceGatewayAPI,
        InferenceGatewayClient, Message, MessageRole, Provider, Tool, ToolType,
    };
    use futures_util::{pin_mut, StreamExt};
    use mockito::{Matcher, Server};
    use serde_json::json;

    #[test]
    fn test_provider_serialization() {
        let providers = vec![
            (Provider::Ollama, "ollama"),
            (Provider::Groq, "groq"),
            (Provider::OpenAI, "openai"),
            (Provider::Cloudflare, "cloudflare"),
            (Provider::Cohere, "cohere"),
            (Provider::Anthropic, "anthropic"),
        ];

        for (provider, expected) in providers {
            let json = serde_json::to_string(&provider).unwrap();
            assert_eq!(json, format!("\"{}\"", expected));
        }
    }

    #[test]
    fn test_provider_deserialization() {
        let test_cases = vec![
            ("\"ollama\"", Provider::Ollama),
            ("\"groq\"", Provider::Groq),
            ("\"openai\"", Provider::OpenAI),
            ("\"cloudflare\"", Provider::Cloudflare),
            ("\"cohere\"", Provider::Cohere),
            ("\"anthropic\"", Provider::Anthropic),
        ];

        for (json, expected) in test_cases {
            let provider: Provider = serde_json::from_str(json).unwrap();
            assert_eq!(provider, expected);
        }
    }

    #[test]
    fn test_message_serialization_with_tool_call_id() {
        let message_with_tool = Message {
            role: MessageRole::Tool,
            content: "The weather is sunny".to_string(),
            tool_call_id: Some("call_123".to_string()),
            ..Default::default()
        };

        let serialized = serde_json::to_string(&message_with_tool).unwrap();
        let expected_with_tool =
            r#"{"role":"tool","content":"The weather is sunny","tool_call_id":"call_123"}"#;
        assert_eq!(serialized, expected_with_tool);

        let message_without_tool = Message {
            role: MessageRole::User,
            content: "What's the weather?".to_string(),
            ..Default::default()
        };

        let serialized = serde_json::to_string(&message_without_tool).unwrap();
        let expected_without_tool = r#"{"role":"user","content":"What's the weather?"}"#;
        assert_eq!(serialized, expected_without_tool);

        let deserialized: Message = serde_json::from_str(expected_with_tool).unwrap();
        assert_eq!(deserialized.role, MessageRole::Tool);
        assert_eq!(deserialized.content, "The weather is sunny");
        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));

        let deserialized: Message = serde_json::from_str(expected_without_tool).unwrap();
        assert_eq!(deserialized.role, MessageRole::User);
        assert_eq!(deserialized.content, "What's the weather?");
        assert_eq!(deserialized.tool_call_id, None);
    }

    #[test]
    fn test_provider_display() {
        let providers = vec![
            (Provider::Ollama, "ollama"),
            (Provider::Groq, "groq"),
            (Provider::OpenAI, "openai"),
            (Provider::Cloudflare, "cloudflare"),
            (Provider::Cohere, "cohere"),
            (Provider::Anthropic, "anthropic"),
        ];

        for (provider, expected) in providers {
            assert_eq!(provider.to_string(), expected);
        }
    }

    #[test]
    fn test_generate_request_serialization() {
        let request_payload = CreateChatCompletionRequest {
            model: "llama3.2:1b".to_string(),
            messages: vec![
                Message {
                    role: MessageRole::System,
                    content: "You are a helpful assistant.".to_string(),
                    ..Default::default()
                },
                Message {
                    role: MessageRole::User,
                    content: "What is the current weather in Toronto?".to_string(),
                    ..Default::default()
                },
            ],
            stream: false,
            tools: Some(vec![Tool {
                r#type: ToolType::Function,
                function: FunctionObject {
                    name: "get_current_weather".to_string(),
                    description: "Get the current weather of a city".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "city": {
                                "type": "string",
                                "description": "The name of the city"
                            }
                        },
                        "required": ["city"]
                    }),
                },
            }]),
            max_tokens: None,
        };

        let serialized = serde_json::to_string_pretty(&request_payload).unwrap();
        let expected = r#"{
            "model": "llama3.2:1b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is the current weather in Toronto?"
                }
            ],
            "stream": false,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_current_weather",
                        "description": "Get the current weather of a city",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "city": {
                                    "type": "string",
                                    "description": "The name of the city"
                                }
                            },
                            "required": ["city"]
                        }
                    }
                }
            ]
        }"#;

        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
            serde_json::from_str::<serde_json::Value>(expected).unwrap()
        );
    }

    #[tokio::test]
    async fn test_authentication_header() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock_with_auth = server
            .mock("GET", "/v1/models")
            .match_header("authorization", "Bearer test-token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("[]")
            .expect(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_token("test-token");
        client.list_models().await?;
        mock_with_auth.assert();

        let mock_without_auth = server
            .mock("GET", "/v1/models")
            .match_header("authorization", Matcher::Missing)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("[]")
            .expect(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        client.list_models().await?;
        mock_without_auth.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_unauthorized_error() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "error": "Invalid token"
        }"#;

        let mock = server
            .mock("GET", "/v1/models")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let error = client.list_models().await.unwrap_err();

        assert!(matches!(error, GatewayError::Unauthorized(_)));
        if let GatewayError::Unauthorized(msg) = error {
            assert_eq!(msg, "Invalid token");
        }
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_list_models() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_response_json = r#"[
            {
                "provider": "ollama",
                "models": [
                    {"name": "llama2"}
                ]
            }
        ]"#;

        let mock = server
            .mock("GET", "/v1/models")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_response_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let models = client.list_models().await?;

        assert_eq!(models.len(), 1);
        assert_eq!(models[0].models[0].name, "llama2");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_list_models_by_provider() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "provider": "ollama",
            "models": [{
                "name": "llama2"
            }]
        }"#;

        let mock = server
            .mock("GET", "/v1/list/models?provider=ollama")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let models = client.list_models_by_provider(Provider::Ollama).await?;

        assert_eq!(models.provider, Provider::Ollama);
        assert_eq!(models.models[0].name, "llama2");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hellloooo"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=ollama")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];
        let response = client
            .generate_content(Provider::Ollama, "llama2", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hellloooo");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_serialization() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let direct_parse: Result<CreateChatCompletionResponse, _> = serde_json::from_str(raw_json);
        assert!(
            direct_parse.is_ok(),
            "Direct JSON parse failed: {:?}",
            direct_parse.err()
        );

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hello");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_error_response() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "error": "Invalid request"
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];
        let error = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await
            .unwrap_err();

        assert!(matches!(error, GatewayError::BadRequest(_)));
        if let GatewayError::BadRequest(msg) = error {
            assert_eq!(msg, "Invalid request");
        }
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_gateway_errors() -> Result<(), GatewayError> {
        let mut server: mockito::ServerGuard = Server::new_async().await;

        let unauthorized_mock = server
            .mock("GET", "/v1/models")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid token"}"#)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        match client.list_models().await {
            Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
            _ => panic!("Expected Unauthorized error"),
        }
        unauthorized_mock.assert();

        let bad_request_mock = server
            .mock("GET", "/v1/models")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid provider"}"#)
            .create();

        match client.list_models().await {
            Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
            _ => panic!("Expected BadRequest error"),
        }
        bad_request_mock.assert();

        let internal_error_mock = server
            .mock("GET", "/v1/models")
            .with_status(500)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Internal server error occurred"}"#)
            .create();

        match client.list_models().await {
            Err(GatewayError::InternalError(msg)) => {
                assert_eq!(msg, "Internal server error occurred")
            }
            _ => panic!("Expected InternalError error"),
        }
        internal_error_mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_case_insensitive() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hello");
        assert_eq!(response.model, "mixtral-8x7b");
        assert_eq!(response.object, "chat.completion");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_stream() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(move |writer| -> std::io::Result<()> {
                let events = vec![
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}"#),
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268191,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":" World"},"finish_reason":null}]}"#),
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268192,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":17,"completion_tokens":40,"total_tokens":57}}"#),
                    format!("data: [DONE]\n\n"),
                ];
                for event in events {
                    writer.write_all(event.as_bytes())?;
                }
                Ok(())
            })
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Test message".to_string(),
            ..Default::default()
        }];

        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
        pin_mut!(stream);
        while let Some(result) = stream.next().await {
            let result = result?;
            let generate_response: CreateChatCompletionStreamResponse =
                serde_json::from_str(&result.data)
                    .expect("Failed to parse CreateChatCompletionStreamResponse");

            if generate_response.choices[0].finish_reason.is_some() {
                assert_eq!(
                    generate_response.choices[0].finish_reason.as_ref().unwrap(),
                    "stop"
                );
                break;
            }

            if let Some(content) = &generate_response.choices[0].delta.content {
                assert!(matches!(content.as_str(), "Hello" | " World"));
            }
            if let Some(role) = &generate_response.choices[0].delta.role {
                assert_eq!(role, &MessageRole::Assistant);
            }
        }

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_stream_error() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_chunked_body(move |writer| -> std::io::Result<()> {
                let events = vec![format!(
                    "event: {}\ndata: {}\nretry: {}\n\n",
                    r#"error"#, r#"{"error":"Invalid request"}"#, r#"1000"#,
                )];
                for event in events {
                    writer.write_all(event.as_bytes())?;
                }
                Ok(())
            })
            .expect_at_least(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Test message".to_string(),
            ..Default::default()
        }];

        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);

        pin_mut!(stream);
        while let Some(result) = stream.next().await {
            let result = result?;
            assert!(result.event.is_some());
            assert_eq!(result.event.unwrap(), "error");
            assert!(result.data.contains("Invalid request"));
            assert!(result.retry.is_none());
        }

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "deepseek-r1-distill-llama-70b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "tool_calls",
                    "message": {
                        "role": "assistant",
                        "content": "Let me check the weather for you.",
                        "tool_calls": [
                            {
                                "id": "1234",
                                "type": "function",
                                "function": {
                                    "name": "get_weather",
                                    "arguments": "{\"location\": \"London\"}"
                                }
                            }
                        ]
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let tools = vec![Tool {
            r#type: ToolType::Function,
            function: FunctionObject {
                name: "get_weather".to_string(),
                description: "Get the weather for a location".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city name"
                        }
                    },
                    "required": ["location"]
                }),
            },
        }];

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_tools(Some(tools));

        let messages = vec![Message {
            role: MessageRole::User,
            content: "What's the weather in London?".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(
            response.choices[0].message.content,
            "Let me check the weather for you."
        );

        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "get_weather");

        let params = tool_calls[0]
            .function
            .parse_arguments()
            .expect("Failed to parse function arguments");
        assert_eq!(params["location"].as_str().unwrap(), "London");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_without_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "gpt-4",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello!"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=openai")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hi".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::OpenAI, "gpt-4", messages)
            .await?;

        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.choices[0].message.content, "Hello!");
        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert!(response.choices[0].message.tool_calls.is_none());

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_tools_payload() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_request_body = r#"{
            "model": "deepseek-r1-distill-llama-70b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is the current weather in Toronto?"
                }
            ],
            "stream": false,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_current_weather",
                        "description": "Get the current weather of a city",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "city": {
                                    "type": "string",
                                    "description": "The name of the city"
                                }
                            },
                            "required": ["city"]
                        }
                    }
                }
            ]
        }"#;

        let raw_json_response = r#"{
            "id": "1234",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "deepseek-r1-distill-llama-70b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Let me check the weather for you",
                        "tool_calls": [
                            {
                                "id": "1234",
                                "type": "function",
                                "function": {
                                    "name": "get_current_weather",
                                    "arguments": "{\"city\": \"Toronto\"}"
                                }
                            }
                        ]
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .match_body(mockito::Matcher::JsonString(raw_request_body.to_string()))
            .with_body(raw_json_response)
            .create();

        let tools = vec![Tool {
            r#type: ToolType::Function,
            function: FunctionObject {
                name: "get_current_weather".to_string(),
                description: "Get the current weather of a city".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "The name of the city"
                        }
                    },
                    "required": ["city"]
                }),
            },
        }];

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![
            Message {
                role: MessageRole::System,
                content: "You are a helpful assistant.".to_string(),
                ..Default::default()
            },
            Message {
                role: MessageRole::User,
                content: "What is the current weather in Toronto?".to_string(),
                ..Default::default()
            },
        ];

        let response = client
            .with_tools(Some(tools))
            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(
            response.choices[0].message.content,
            "Let me check the weather for you"
        );
        assert_eq!(
            response.choices[0]
                .message
                .tool_calls
                .as_ref()
                .unwrap()
                .len(),
            1
        );

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_max_tokens() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Here's a poem with 100 tokens..."
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .match_body(mockito::Matcher::JsonString(
                r#"{
                    "model": "mixtral-8x7b",
                    "messages": [{"role":"user","content":"Write a poem"}],
                    "stream": false,
                    "max_tokens": 100
                }"#
                .to_string(),
            ))
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_max_tokens(Some(100));

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Write a poem".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(
            response.choices[0].message.content,
            "Here's a poem with 100 tokens..."
        );
        assert_eq!(response.model, "mixtral-8x7b");
        assert_eq!(response.created, 1630000000);
        assert_eq!(response.object, "chat.completion");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_health_check() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;
        let mock = server.mock("GET", "/health").with_status(200).create();

        let client = InferenceGatewayClient::new(&server.url());
        let is_healthy = client.health_check().await?;

        assert!(is_healthy);
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_client_base_url_configuration() -> Result<(), GatewayError> {
        let mut custom_url_server = Server::new_async().await;

        let custom_url_mock = custom_url_server
            .mock("GET", "/health")
            .with_status(200)
            .create();

        let custom_client = InferenceGatewayClient::new(&custom_url_server.url());
        let is_healthy = custom_client.health_check().await?;
        assert!(is_healthy);
        custom_url_mock.assert();

        let default_client = InferenceGatewayClient::new_default();

        let default_url = "http://localhost:8080/v1";
        assert_eq!(default_client.base_url(), default_url);

        Ok(())
    }
}