use core::fmt;
use std::future::Future;

use futures_util::{Stream, StreamExt};
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};

use serde_json::Value;
use thiserror::Error;

/// A single server-sent event (SSE) frame received from the gateway's streaming endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct SSEvents {
    pub data: String,
    pub event: Option<String>,
    pub retry: Option<u64>,
}

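/// Errors returned by the gateway client.
///
/// Illustrative sketch of matching on the error variants (not from the
/// original sources; assumes this crate's items are in scope and an async
/// runtime such as Tokio is available):
///
/// ```rust,ignore
/// match client.list_models().await {
///     Ok(models) => println!("{} models available", models.data.len()),
///     Err(GatewayError::Unauthorized(msg)) => eprintln!("check your token: {msg}"),
///     Err(GatewayError::BadRequest(msg)) => eprintln!("invalid request: {msg}"),
///     Err(other) => eprintln!("gateway error: {other}"),
/// }
/// ```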
#[derive(Error, Debug)]
pub enum GatewayError {
    #[error("Unauthorized: {0}")]
    Unauthorized(String),

    #[error("Bad request: {0}")]
    BadRequest(String),

    #[error("Internal server error: {0}")]
    InternalError(String),

    #[error("Stream error: {0}")]
    StreamError(reqwest::Error),

    #[error("Decoding error: {0}")]
    DecodingError(std::string::FromUtf8Error),

    #[error("Request error: {0}")]
    RequestError(#[from] reqwest::Error),

    #[error("Deserialization error: {0}")]
    DeserializationError(serde_json::Error),

    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    #[error("Other error: {0}")]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}

#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Model {
    pub id: String,
    pub object: Option<String>,
    pub created: Option<i64>,
    pub owned_by: Option<String>,
    pub served_by: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider: Option<Provider>,
    pub object: String,
    pub data: Vec<Model>,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Copy)]
#[serde(rename_all = "lowercase")]
pub enum Provider {
    #[serde(alias = "Ollama", alias = "OLLAMA")]
    Ollama,
    #[serde(alias = "Groq", alias = "GROQ")]
    Groq,
    #[serde(alias = "OpenAI", alias = "OPENAI")]
    OpenAI,
    #[serde(alias = "Cloudflare", alias = "CLOUDFLARE")]
    Cloudflare,
    #[serde(alias = "Cohere", alias = "COHERE")]
    Cohere,
    #[serde(alias = "Anthropic", alias = "ANTHROPIC")]
    Anthropic,
}

impl fmt::Display for Provider {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Provider::Ollama => write!(f, "ollama"),
            Provider::Groq => write!(f, "groq"),
            Provider::OpenAI => write!(f, "openai"),
            Provider::Cloudflare => write!(f, "cloudflare"),
            Provider::Cohere => write!(f, "cohere"),
            Provider::Anthropic => write!(f, "anthropic"),
        }
    }
}

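/// Parses a provider from a case-insensitive string.
///
/// Illustrative sketch (not from the original sources; assumes this crate's
/// items are in scope):
///
/// ```rust,ignore
/// let provider = Provider::try_from("OpenAI")?; // case-insensitive
/// assert_eq!(provider, Provider::OpenAI);
/// assert_eq!(provider.to_string(), "openai");
///
/// // Unknown names are rejected with `GatewayError::BadRequest`.
/// assert!(Provider::try_from("unknown-provider").is_err());
/// ```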
impl TryFrom<&str> for Provider {
    type Error = GatewayError;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        match s.to_lowercase().as_str() {
            "ollama" => Ok(Self::Ollama),
            "groq" => Ok(Self::Groq),
            "openai" => Ok(Self::OpenAI),
            "cloudflare" => Ok(Self::Cloudflare),
            "cohere" => Ok(Self::Cohere),
            "anthropic" => Ok(Self::Anthropic),
            _ => Err(GatewayError::BadRequest(format!("Unknown provider: {}", s))),
        }
    }
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    System,
    #[default]
    User,
    Assistant,
    Tool,
}

impl fmt::Display for MessageRole {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            MessageRole::System => write!(f, "system"),
            MessageRole::User => write!(f, "user"),
            MessageRole::Assistant => write!(f, "assistant"),
            MessageRole::Tool => write!(f, "tool"),
        }
    }
}

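/// A chat message exchanged with a model.
///
/// Illustrative sketch of building messages with `..Default::default()` for
/// the optional fields (not from the original sources; assumes this crate's
/// items are in scope):
///
/// ```rust,ignore
/// let messages = vec![
///     Message {
///         role: MessageRole::System,
///         content: "You are a helpful assistant.".to_string(),
///         ..Default::default()
///     },
///     Message {
///         role: MessageRole::User,
///         content: "What is the capital of France?".to_string(),
///         ..Default::default()
///     },
/// ];
/// ```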
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct Message {
    pub role: MessageRole,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ChatCompletionMessageToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub r#type: ChatCompletionToolType,
    pub function: ChatCompletionMessageToolCallFunction,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum ChatCompletionToolType {
    #[serde(rename = "function")]
    Function,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageToolCallFunction {
    pub name: String,
    pub arguments: String,
}

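/// Helpers for working with tool-call arguments returned by a model.
///
/// Illustrative sketch (not from the original sources; assumes this crate's
/// items are in scope):
///
/// ```rust,ignore
/// let call = ChatCompletionMessageToolCallFunction {
///     name: "get_weather".to_string(),
///     arguments: r#"{"city": "Toronto"}"#.to_string(),
/// };
/// let args = call.parse_arguments()?; // serde_json::Value
/// assert_eq!(args["city"], "Toronto");
/// ```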
impl ChatCompletionMessageToolCallFunction {
    pub fn parse_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
        serde_json::from_str(&self.arguments)
    }
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionObject {
    pub name: String,
    pub description: String,
    pub parameters: Value,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    Function,
}

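/// A tool the model may call, described by a JSON Schema for its parameters.
///
/// Illustrative sketch of defining a tool (not from the original sources;
/// assumes this crate's items and `serde_json::json!` are in scope; the tool
/// name and schema are placeholders):
///
/// ```rust,ignore
/// let weather_tool = Tool {
///     r#type: ToolType::Function,
///     function: FunctionObject {
///         name: "get_current_weather".to_string(),
///         description: "Get the current weather of a city".to_string(),
///         parameters: json!({
///             "type": "object",
///             "properties": {
///                 "city": { "type": "string", "description": "The name of the city" }
///             },
///             "required": ["city"]
///         }),
///     },
/// };
/// ```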
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: FunctionObject,
}

#[derive(Debug, Serialize)]
struct CreateChatCompletionRequest {
    model: String,
    messages: Vec<Message>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<i32>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ToolCallResponse {
    pub id: String,
    #[serde(rename = "type")]
    pub r#type: ToolType,
    pub function: ChatCompletionMessageToolCallFunction,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionChoice {
    pub finish_reason: String,
    pub message: Message,
    pub index: i32,
}

#[derive(Debug, Deserialize, Clone)]
pub struct CreateChatCompletionResponse {
    pub id: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub created: i64,
    pub model: String,
    pub object: String,
}

#[derive(Debug, Deserialize, Clone)]
pub struct CreateChatCompletionStreamResponse {
    pub id: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    pub created: i64,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    pub object: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<CompletionUsage>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionStreamChoice {
    pub delta: ChatCompletionStreamDelta,
    pub index: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<MessageRole>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct CompletionUsage {
    pub completion_tokens: i64,
    pub prompt_tokens: i64,
    pub total_tokens: i64,
}

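/// HTTP client for the Inference Gateway.
///
/// Illustrative sketch of constructing a client and listing models (not from
/// the original sources; assumes this crate's items are in scope, a gateway
/// listening on `http://localhost:8080/v1`, an async runtime such as Tokio,
/// and a placeholder token):
///
/// ```rust,ignore
/// // `InferenceGatewayAPI` must be in scope to call the trait methods.
/// let client = InferenceGatewayClient::new("http://localhost:8080/v1")
///     .with_token("my-api-token");
///
/// let models = client.list_models().await?;
/// for model in models.data {
///     println!("{}", model.id);
/// }
/// ```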
pub struct InferenceGatewayClient {
    base_url: String,
    client: Client,
    token: Option<String>,
    tools: Option<Vec<Tool>>,
    max_tokens: Option<i32>,
}

impl std::fmt::Debug for InferenceGatewayClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InferenceGatewayClient")
            .field("base_url", &self.base_url)
            .field("token", &self.token.as_ref().map(|_| "*****"))
            .finish()
    }
}

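/// Operations exposed by the Inference Gateway.
///
/// Illustrative sketch of generating a chat completion through the trait (not
/// from the original sources; assumes this crate's items are in scope, an
/// async runtime such as Tokio, and placeholder provider/model names):
///
/// ```rust,ignore
/// let client = InferenceGatewayClient::new("http://localhost:8080/v1");
///
/// let messages = vec![Message {
///     role: MessageRole::User,
///     content: "Hello!".to_string(),
///     ..Default::default()
/// }];
///
/// let response = client
///     .generate_content(Provider::Ollama, "llama3.2:1b", messages)
///     .await?;
/// println!("{}", response.choices[0].message.content);
/// ```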
pub trait InferenceGatewayAPI {
    /// Lists the models available across all configured providers.
    fn list_models(&self) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;

    /// Lists the models available for a single provider.
    fn list_models_by_provider(
        &self,
        provider: Provider,
    ) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;

    /// Generates a chat completion for the given provider, model and messages.
    fn generate_content(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Future<Output = Result<CreateChatCompletionResponse, GatewayError>> + Send;

    /// Streams a chat completion as server-sent events.
    fn generate_content_stream(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send;

    /// Returns `true` when the gateway's health endpoint responds with 200 OK.
    fn health_check(&self) -> impl Future<Output = Result<bool, GatewayError>> + Send;
}

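/// Constructors and builder-style configuration.
///
/// Illustrative sketch of chaining the builder methods (not from the original
/// sources; assumes this crate's items and `serde_json::json!` are in scope;
/// the token, tool schema, and limits are placeholders):
///
/// ```rust,ignore
/// let tools = vec![Tool {
///     r#type: ToolType::Function,
///     function: FunctionObject {
///         name: "get_current_weather".to_string(),
///         description: "Get the current weather of a city".to_string(),
///         parameters: json!({ "type": "object", "properties": {} }),
///     },
/// }];
///
/// let client = InferenceGatewayClient::new("http://localhost:8080/v1")
///     .with_token("my-api-token")
///     .with_tools(Some(tools))
///     .with_max_tokens(Some(256));
/// ```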
impl InferenceGatewayClient {
    /// Creates a new client pointing at the given base URL, e.g. `http://localhost:8080/v1`.
    pub fn new(base_url: &str) -> Self {
        Self {
            base_url: base_url.to_string(),
            client: Client::new(),
            token: None,
            tools: None,
            max_tokens: None,
        }
    }

    /// Creates a client from the `INFERENCE_GATEWAY_URL` environment variable,
    /// falling back to `http://localhost:8080/v1`.
    pub fn new_default() -> Self {
        let base_url = std::env::var("INFERENCE_GATEWAY_URL")
            .unwrap_or_else(|_| "http://localhost:8080/v1".to_string());

        Self {
            base_url,
            client: Client::new(),
            token: None,
            tools: None,
            max_tokens: None,
        }
    }

    /// Returns the base URL this client sends requests to.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Sets the tools advertised with non-streaming chat completion requests.
    pub fn with_tools(mut self, tools: Option<Vec<Tool>>) -> Self {
        self.tools = tools;
        self
    }

    /// Sets the bearer token used to authenticate against the gateway.
    pub fn with_token(mut self, token: impl Into<String>) -> Self {
        self.token = Some(token.into());
        self
    }

    /// Caps the number of tokens the model may generate.
    pub fn with_max_tokens(mut self, max_tokens: Option<i32>) -> Self {
        self.max_tokens = max_tokens;
        self
    }
}

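/// Illustrative sketch of consuming the streaming API (not from the original
/// sources; assumes this crate's items, `futures_util::{pin_mut, StreamExt}`,
/// an async runtime such as Tokio, and placeholder provider/model names):
///
/// ```rust,ignore
/// let client = InferenceGatewayClient::new("http://localhost:8080/v1");
/// let messages = vec![Message {
///     role: MessageRole::User,
///     content: "Tell me a short story".to_string(),
///     ..Default::default()
/// }];
///
/// let stream = client.generate_content_stream(Provider::Ollama, "llama3.2:1b", messages);
/// pin_mut!(stream);
/// while let Some(event) = stream.next().await {
///     let event = event?;
///     if event.data == "[DONE]" {
///         break;
///     }
///     let chunk: CreateChatCompletionStreamResponse = serde_json::from_str(&event.data)?;
///     if let Some(content) = &chunk.choices[0].delta.content {
///         print!("{content}");
///     }
/// }
/// ```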
impl InferenceGatewayAPI for InferenceGatewayClient {
    async fn list_models(&self) -> Result<ListModelsResponse, GatewayError> {
        let url = format!("{}/models", self.base_url);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => {
                let json_response: ListModelsResponse = response.json().await?;
                Ok(json_response)
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn list_models_by_provider(
        &self,
        provider: Provider,
    ) -> Result<ListModelsResponse, GatewayError> {
        let url = format!("{}/models?provider={}", self.base_url, provider);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => {
                let json_response: ListModelsResponse = response.json().await?;
                Ok(json_response)
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn generate_content(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> Result<CreateChatCompletionResponse, GatewayError> {
        let url = format!("{}/chat/completions?provider={}", self.base_url, provider);
        let mut request = self.client.post(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let request_payload = CreateChatCompletionRequest {
            model: model.to_string(),
            messages,
            stream: false,
            tools: self.tools.clone(),
            max_tokens: self.max_tokens,
        };

        let response = request.json(&request_payload).send().await?;

        match response.status() {
            StatusCode::OK => Ok(response.json().await?),
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            status => Err(GatewayError::Other(Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Unexpected status code: {}", status),
            )))),
        }
    }

    fn generate_content_stream(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send {
        let client = self.client.clone();
        let base_url = self.base_url.clone();
        let url = format!("{}/chat/completions?provider={}", base_url, provider);

        let request = CreateChatCompletionRequest {
            model: model.to_string(),
            messages,
            stream: true,
            tools: None,
            max_tokens: None,
        };

        async_stream::try_stream! {
            let response = client.post(&url).json(&request).send().await?;
            let mut stream = response.bytes_stream();
            let mut current_event: Option<String> = None;
            let mut current_data: Option<String> = None;

            while let Some(chunk) = stream.next().await {
                let chunk = chunk?;
                let chunk_str = String::from_utf8_lossy(&chunk);

                for line in chunk_str.lines() {
                    if line.is_empty() && current_data.is_some() {
                        yield SSEvents {
                            data: current_data.take().unwrap(),
                            event: current_event.take(),
                            retry: None,
                        };
                        continue;
                    }

                    if let Some(event) = line.strip_prefix("event:") {
                        current_event = Some(event.trim().to_string());
                    } else if let Some(data) = line.strip_prefix("data:") {
                        let processed_data = data.strip_suffix('\n').unwrap_or(data);
                        current_data = Some(processed_data.trim().to_string());
                    }
                }
            }
        }
    }

    async fn health_check(&self) -> Result<bool, GatewayError> {
        let url = format!("{}/health", self.base_url);

        let response = self.client.get(&url).send().await?;
        match response.status() {
            StatusCode::OK => Ok(true),
            _ => Ok(false),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        CreateChatCompletionRequest, CreateChatCompletionResponse,
        CreateChatCompletionStreamResponse, FunctionObject, GatewayError, InferenceGatewayAPI,
        InferenceGatewayClient, Message, MessageRole, Provider, Tool, ToolType,
    };
    use futures_util::{pin_mut, StreamExt};
    use mockito::{Matcher, Server};
    use serde_json::json;

    #[test]
    fn test_provider_serialization() {
        let providers = vec![
            (Provider::Ollama, "ollama"),
            (Provider::Groq, "groq"),
            (Provider::OpenAI, "openai"),
            (Provider::Cloudflare, "cloudflare"),
            (Provider::Cohere, "cohere"),
            (Provider::Anthropic, "anthropic"),
        ];

        for (provider, expected) in providers {
            let json = serde_json::to_string(&provider).unwrap();
            assert_eq!(json, format!("\"{}\"", expected));
        }
    }

    #[test]
    fn test_provider_deserialization() {
        let test_cases = vec![
            ("\"ollama\"", Provider::Ollama),
            ("\"groq\"", Provider::Groq),
            ("\"openai\"", Provider::OpenAI),
            ("\"cloudflare\"", Provider::Cloudflare),
            ("\"cohere\"", Provider::Cohere),
            ("\"anthropic\"", Provider::Anthropic),
        ];

        for (json, expected) in test_cases {
            let provider: Provider = serde_json::from_str(json).unwrap();
            assert_eq!(provider, expected);
        }
    }

    #[test]
    fn test_message_serialization_with_tool_call_id() {
        let message_with_tool = Message {
            role: MessageRole::Tool,
            content: "The weather is sunny".to_string(),
            tool_call_id: Some("call_123".to_string()),
            ..Default::default()
        };

        let serialized = serde_json::to_string(&message_with_tool).unwrap();
        let expected_with_tool =
            r#"{"role":"tool","content":"The weather is sunny","tool_call_id":"call_123"}"#;
        assert_eq!(serialized, expected_with_tool);

        let message_without_tool = Message {
            role: MessageRole::User,
            content: "What's the weather?".to_string(),
            ..Default::default()
        };

        let serialized = serde_json::to_string(&message_without_tool).unwrap();
        let expected_without_tool = r#"{"role":"user","content":"What's the weather?"}"#;
        assert_eq!(serialized, expected_without_tool);

        let deserialized: Message = serde_json::from_str(expected_with_tool).unwrap();
        assert_eq!(deserialized.role, MessageRole::Tool);
        assert_eq!(deserialized.content, "The weather is sunny");
        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));

        let deserialized: Message = serde_json::from_str(expected_without_tool).unwrap();
        assert_eq!(deserialized.role, MessageRole::User);
        assert_eq!(deserialized.content, "What's the weather?");
        assert_eq!(deserialized.tool_call_id, None);
    }

    #[test]
    fn test_provider_display() {
        let providers = vec![
            (Provider::Ollama, "ollama"),
            (Provider::Groq, "groq"),
            (Provider::OpenAI, "openai"),
            (Provider::Cloudflare, "cloudflare"),
            (Provider::Cohere, "cohere"),
            (Provider::Anthropic, "anthropic"),
        ];

        for (provider, expected) in providers {
            assert_eq!(provider.to_string(), expected);
        }
    }

    #[test]
    fn test_generate_request_serialization() {
        let request_payload = CreateChatCompletionRequest {
            model: "llama3.2:1b".to_string(),
            messages: vec![
                Message {
                    role: MessageRole::System,
                    content: "You are a helpful assistant.".to_string(),
                    ..Default::default()
                },
                Message {
                    role: MessageRole::User,
                    content: "What is the current weather in Toronto?".to_string(),
                    ..Default::default()
                },
            ],
            stream: false,
            tools: Some(vec![Tool {
                r#type: ToolType::Function,
                function: FunctionObject {
                    name: "get_current_weather".to_string(),
                    description: "Get the current weather of a city".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "city": {
                                "type": "string",
                                "description": "The name of the city"
                            }
                        },
                        "required": ["city"]
                    }),
                },
            }]),
            max_tokens: None,
        };

        let serialized = serde_json::to_string_pretty(&request_payload).unwrap();
        let expected = r#"{
            "model": "llama3.2:1b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is the current weather in Toronto?"
                }
            ],
            "stream": false,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_current_weather",
                        "description": "Get the current weather of a city",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "city": {
                                    "type": "string",
                                    "description": "The name of the city"
                                }
                            },
                            "required": ["city"]
                        }
                    }
                }
            ]
        }"#;

        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
            serde_json::from_str::<serde_json::Value>(expected).unwrap()
        );
    }

    #[tokio::test]
    async fn test_authentication_header() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock_response = r#"{
            "object": "list",
            "data": []
        }"#;

        let mock_with_auth = server
            .mock("GET", "/v1/models")
            .match_header("authorization", "Bearer test-token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_response)
            .expect(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_token("test-token");
        client.list_models().await?;
        mock_with_auth.assert();

        let mock_without_auth = server
            .mock("GET", "/v1/models")
            .match_header("authorization", Matcher::Missing)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_response)
            .expect(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        client.list_models().await?;
        mock_without_auth.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_unauthorized_error() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "error": "Invalid token"
        }"#;

        let mock = server
            .mock("GET", "/v1/models")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let error = client.list_models().await.unwrap_err();

        assert!(matches!(error, GatewayError::Unauthorized(_)));
        if let GatewayError::Unauthorized(msg) = error {
            assert_eq!(msg, "Invalid token");
        }
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_list_models() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_response_json = r#"{
            "object": "list",
            "data": [
                {
                    "id": "llama2",
                    "object": "model",
                    "created": 1630000001,
                    "owned_by": "ollama",
                    "served_by": "ollama"
                }
            ]
        }"#;

        let mock = server
            .mock("GET", "/v1/models")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_response_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let response = client.list_models().await?;

        assert!(response.provider.is_none());
        assert_eq!(response.object, "list");
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.data[0].id, "llama2");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_list_models_by_provider() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "provider": "ollama",
            "object": "list",
            "data": [
                {
                    "id": "llama2",
                    "object": "model",
                    "created": 1630000001,
                    "owned_by": "ollama",
                    "served_by": "ollama"
                }
            ]
        }"#;

        let mock = server
            .mock("GET", "/v1/models?provider=ollama")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let response = client.list_models_by_provider(Provider::Ollama).await?;

        assert!(response.provider.is_some());
        assert_eq!(response.provider, Some(Provider::Ollama));
        assert_eq!(response.data[0].id, "llama2");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hellloooo"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=ollama")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];
        let response = client
            .generate_content(Provider::Ollama, "llama2", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hellloooo");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_serialization() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let direct_parse: Result<CreateChatCompletionResponse, _> = serde_json::from_str(raw_json);
        assert!(
            direct_parse.is_ok(),
            "Direct JSON parse failed: {:?}",
            direct_parse.err()
        );

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hello");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_error_response() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "error": "Invalid request"
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];
        let error = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await
            .unwrap_err();

        assert!(matches!(error, GatewayError::BadRequest(_)));
        if let GatewayError::BadRequest(msg) = error {
            assert_eq!(msg, "Invalid request");
        }
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_gateway_errors() -> Result<(), GatewayError> {
        let mut server: mockito::ServerGuard = Server::new_async().await;

        let unauthorized_mock = server
            .mock("GET", "/v1/models")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid token"}"#)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        match client.list_models().await {
            Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
            _ => panic!("Expected Unauthorized error"),
        }
        unauthorized_mock.assert();

        let bad_request_mock = server
            .mock("GET", "/v1/models")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid provider"}"#)
            .create();

        match client.list_models().await {
            Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
            _ => panic!("Expected BadRequest error"),
        }
        bad_request_mock.assert();

        let internal_error_mock = server
            .mock("GET", "/v1/models")
            .with_status(500)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Internal server error occurred"}"#)
            .create();

        match client.list_models().await {
            Err(GatewayError::InternalError(msg)) => {
                assert_eq!(msg, "Internal server error occurred")
            }
            _ => panic!("Expected InternalError error"),
        }
        internal_error_mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_case_insensitive() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hello");
        assert_eq!(response.model, "mixtral-8x7b");
        assert_eq!(response.object, "chat.completion");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_stream() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(move |writer| -> std::io::Result<()> {
                let events = vec![
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}"#),
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268191,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":" World"},"finish_reason":null}]}"#),
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268192,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":17,"completion_tokens":40,"total_tokens":57}}"#),
                    format!("data: [DONE]\n\n"),
                ];
                for event in events {
                    writer.write_all(event.as_bytes())?;
                }
                Ok(())
            })
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Test message".to_string(),
            ..Default::default()
        }];

        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
        pin_mut!(stream);
        while let Some(result) = stream.next().await {
            let result = result?;
            let generate_response: CreateChatCompletionStreamResponse =
                serde_json::from_str(&result.data)
                    .expect("Failed to parse CreateChatCompletionResponse");

            if generate_response.choices[0].finish_reason.is_some() {
                assert_eq!(
                    generate_response.choices[0].finish_reason.as_ref().unwrap(),
                    "stop"
                );
                break;
            }

            if let Some(content) = &generate_response.choices[0].delta.content {
                assert!(matches!(content.as_str(), "Hello" | " World"));
            }
            if let Some(role) = &generate_response.choices[0].delta.role {
                assert_eq!(role, &MessageRole::Assistant);
            }
        }

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_stream_error() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_chunked_body(move |writer| -> std::io::Result<()> {
                let events = vec![format!(
                    "event: {}\ndata: {}\nretry: {}\n\n",
                    r#"error"#, r#"{"error":"Invalid request"}"#, r#"1000"#,
                )];
                for event in events {
                    writer.write_all(event.as_bytes())?;
                }
                Ok(())
            })
            .expect_at_least(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Test message".to_string(),
            ..Default::default()
        }];

        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);

        pin_mut!(stream);
        while let Some(result) = stream.next().await {
            let result = result?;
            assert!(result.event.is_some());
            assert_eq!(result.event.unwrap(), "error");
            assert!(result.data.contains("Invalid request"));
            assert!(result.retry.is_none());
        }

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "deepseek-r1-distill-llama-70b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "tool_calls",
                    "message": {
                        "role": "assistant",
                        "content": "Let me check the weather for you.",
                        "tool_calls": [
                            {
                                "id": "1234",
                                "type": "function",
                                "function": {
                                    "name": "get_weather",
                                    "arguments": "{\"location\": \"London\"}"
                                }
                            }
                        ]
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let tools = vec![Tool {
            r#type: ToolType::Function,
            function: FunctionObject {
                name: "get_weather".to_string(),
                description: "Get the weather for a location".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city name"
                        }
                    },
                    "required": ["location"]
                }),
            },
        }];

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_tools(Some(tools));

        let messages = vec![Message {
            role: MessageRole::User,
            content: "What's the weather in London?".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(
            response.choices[0].message.content,
            "Let me check the weather for you."
        );

        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "get_weather");

        let params = tool_calls[0]
            .function
            .parse_arguments()
            .expect("Failed to parse function arguments");
        assert_eq!(params["location"].as_str().unwrap(), "London");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_without_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "gpt-4",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello!"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=openai")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hi".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::OpenAI, "gpt-4", messages)
            .await?;

        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.choices[0].message.content, "Hello!");
        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert!(response.choices[0].message.tool_calls.is_none());

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_tools_payload() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_request_body = r#"{
            "model": "deepseek-r1-distill-llama-70b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is the current weather in Toronto?"
                }
            ],
            "stream": false,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_current_weather",
                        "description": "Get the current weather of a city",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "city": {
                                    "type": "string",
                                    "description": "The name of the city"
                                }
                            },
                            "required": ["city"]
                        }
                    }
                }
            ]
        }"#;

        let raw_json_response = r#"{
            "id": "1234",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "deepseek-r1-distill-llama-70b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Let me check the weather for you",
                        "tool_calls": [
                            {
                                "id": "1234",
                                "type": "function",
                                "function": {
                                    "name": "get_current_weather",
                                    "arguments": "{\"city\": \"Toronto\"}"
                                }
                            }
                        ]
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .match_body(mockito::Matcher::JsonString(raw_request_body.to_string()))
            .with_body(raw_json_response)
            .create();

        let tools = vec![Tool {
            r#type: ToolType::Function,
            function: FunctionObject {
                name: "get_current_weather".to_string(),
                description: "Get the current weather of a city".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "The name of the city"
                        }
                    },
                    "required": ["city"]
                }),
            },
        }];

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![
            Message {
                role: MessageRole::System,
                content: "You are a helpful assistant.".to_string(),
                ..Default::default()
            },
            Message {
                role: MessageRole::User,
                content: "What is the current weather in Toronto?".to_string(),
                ..Default::default()
            },
        ];

        let response = client
            .with_tools(Some(tools))
            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(
            response.choices[0].message.content,
            "Let me check the weather for you"
        );
        assert_eq!(
            response.choices[0]
                .message
                .tool_calls
                .as_ref()
                .unwrap()
                .len(),
            1
        );

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_max_tokens() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Here's a poem with 100 tokens..."
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .match_body(mockito::Matcher::JsonString(
                r#"{
                    "model": "mixtral-8x7b",
                    "messages": [{"role":"user","content":"Write a poem"}],
                    "stream": false,
                    "max_tokens": 100
                }"#
                .to_string(),
            ))
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_max_tokens(Some(100));

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Write a poem".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(
            response.choices[0].message.content,
            "Here's a poem with 100 tokens..."
        );
        assert_eq!(response.model, "mixtral-8x7b");
        assert_eq!(response.created, 1630000000);
        assert_eq!(response.object, "chat.completion");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_health_check() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;
        let mock = server.mock("GET", "/health").with_status(200).create();

        let client = InferenceGatewayClient::new(&server.url());
        let is_healthy = client.health_check().await?;

        assert!(is_healthy);
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_client_base_url_configuration() -> Result<(), GatewayError> {
        let mut custom_url_server = Server::new_async().await;

        let custom_url_mock = custom_url_server
            .mock("GET", "/health")
            .with_status(200)
            .create();

        let custom_client = InferenceGatewayClient::new(&custom_url_server.url());
        let is_healthy = custom_client.health_check().await?;
        assert!(is_healthy);
        custom_url_mock.assert();

        let default_client = InferenceGatewayClient::new_default();

        let default_url = "http://localhost:8080/v1";
        assert_eq!(default_client.base_url(), default_url);

        Ok(())
    }
}