inference_gateway_sdk/
lib.rs

//! Inference Gateway SDK for Rust
//!
//! This crate provides a Rust client for the Inference Gateway API, allowing interaction
//! with various LLM providers through a unified interface.
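//!
//! # Example
//!
//! A minimal sketch of listing models, assuming a gateway is reachable at the
//! default `http://localhost:8080/v1` and Tokio is used as the async runtime:
//!
//! ```no_run
//! use inference_gateway_sdk::{GatewayError, InferenceGatewayAPI, InferenceGatewayClient};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), GatewayError> {
//!     let client = InferenceGatewayClient::new("http://localhost:8080/v1");
//!
//!     for provider_models in client.list_models().await? {
//!         println!("{}: {} models", provider_models.provider, provider_models.models.len());
//!     }
//!     Ok(())
//! }
//! ```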

use core::fmt;
use std::future::Future;

use futures_util::{Stream, StreamExt};
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};

use serde_json::Value;
use thiserror::Error;

/// A single Server-Sent Event (SSE) received from the Inference Gateway API
#[derive(Debug, Serialize, Deserialize)]
pub struct SSEvents {
    pub data: String,
    pub event: Option<String>,
    pub retry: Option<u64>,
}

/// Custom error types for the Inference Gateway SDK
#[derive(Error, Debug)]
pub enum GatewayError {
    #[error("Unauthorized: {0}")]
    Unauthorized(String),

    #[error("Bad request: {0}")]
    BadRequest(String),

    #[error("Internal server error: {0}")]
    InternalError(String),

    #[error("Stream error: {0}")]
    StreamError(reqwest::Error),

    #[error("Decoding error: {0}")]
    DecodingError(std::string::FromUtf8Error),

    #[error("Request error: {0}")]
    RequestError(#[from] reqwest::Error),

    #[error("Deserialization error: {0}")]
    DeserializationError(serde_json::Error),

    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    #[error("Other error: {0}")]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}

#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: String,
}

/// Represents a model available through a provider
#[derive(Debug, Serialize, Deserialize)]
pub struct Model {
    /// Unique identifier of the model
    pub name: String,
}

/// Collection of models available from a specific provider
#[derive(Debug, Serialize, Deserialize)]
pub struct ProviderModels {
    /// The LLM provider
    pub provider: Provider,
    /// List of available models
    pub models: Vec<Model>,
}

/// Supported LLM providers
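///
/// Provider names serialize to lowercase strings and can be parsed
/// case-insensitively:
///
/// ```
/// use inference_gateway_sdk::Provider;
///
/// let provider = Provider::try_from("OpenAI").expect("known provider");
/// assert_eq!(provider, Provider::OpenAI);
/// assert_eq!(provider.to_string(), "openai");
/// ```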
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Copy)]
#[serde(rename_all = "lowercase")]
pub enum Provider {
    #[serde(alias = "Ollama", alias = "OLLAMA")]
    Ollama,
    #[serde(alias = "Groq", alias = "GROQ")]
    Groq,
    #[serde(alias = "OpenAI", alias = "OPENAI")]
    OpenAI,
    #[serde(alias = "Cloudflare", alias = "CLOUDFLARE")]
    Cloudflare,
    #[serde(alias = "Cohere", alias = "COHERE")]
    Cohere,
    #[serde(alias = "Anthropic", alias = "ANTHROPIC")]
    Anthropic,
}

impl fmt::Display for Provider {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Provider::Ollama => write!(f, "ollama"),
            Provider::Groq => write!(f, "groq"),
            Provider::OpenAI => write!(f, "openai"),
            Provider::Cloudflare => write!(f, "cloudflare"),
            Provider::Cohere => write!(f, "cohere"),
            Provider::Anthropic => write!(f, "anthropic"),
        }
    }
}

impl TryFrom<&str> for Provider {
    type Error = GatewayError;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        match s.to_lowercase().as_str() {
            "ollama" => Ok(Self::Ollama),
            "groq" => Ok(Self::Groq),
            "openai" => Ok(Self::OpenAI),
            "cloudflare" => Ok(Self::Cloudflare),
            "cohere" => Ok(Self::Cohere),
            "anthropic" => Ok(Self::Anthropic),
            _ => Err(GatewayError::BadRequest(format!("Unknown provider: {}", s))),
        }
    }
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    System,
    #[default]
    User,
    Assistant,
    Tool,
}

impl fmt::Display for MessageRole {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            MessageRole::System => write!(f, "system"),
            MessageRole::User => write!(f, "user"),
            MessageRole::Assistant => write!(f, "assistant"),
            MessageRole::Tool => write!(f, "tool"),
        }
    }
}

/// A message in a conversation with an LLM
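///
/// # Examples
///
/// Building a user message, leaving the optional fields at their defaults:
///
/// ```
/// use inference_gateway_sdk::{Message, MessageRole};
///
/// let message = Message {
///     role: MessageRole::User,
///     content: "What is the capital of France?".to_string(),
///     ..Default::default()
/// };
/// assert!(message.tool_calls.is_none());
/// ```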
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct Message {
    /// Role of the message sender ("system", "user", "assistant" or "tool")
    pub role: MessageRole,
    /// Content of the message
    pub content: String,
    /// The tools an LLM wants to invoke
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ChatCompletionMessageToolCall>>,
    /// Unique identifier of the tool call
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Reasoning behind the message
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}

/// A tool call in a message response
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageToolCall {
    /// Unique identifier of the tool call
    pub id: String,
    /// Type of the tool being called
    #[serde(rename = "type")]
    pub r#type: ChatCompletionToolType,
    /// Function that was called
    pub function: ChatCompletionMessageToolCallFunction,
}

/// Type of tool that can be called
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum ChatCompletionToolType {
    /// Function tool type
    #[serde(rename = "function")]
    Function,
}

/// Function details in a tool call
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageToolCallFunction {
    /// Name of the function to call
    pub name: String,
    /// Arguments to the function in JSON string format
    pub arguments: String,
}

impl ChatCompletionMessageToolCallFunction {
    /// Parses the JSON-encoded `arguments` string into a [`serde_json::Value`].
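    ///
    /// # Examples
    ///
    /// A minimal sketch with a hypothetical tool call payload:
    ///
    /// ```
    /// use inference_gateway_sdk::ChatCompletionMessageToolCallFunction;
    ///
    /// let function = ChatCompletionMessageToolCallFunction {
    ///     name: "get_current_weather".to_string(),
    ///     arguments: r#"{"city": "Toronto"}"#.to_string(),
    /// };
    ///
    /// let args = function.parse_arguments().expect("arguments should be valid JSON");
    /// assert_eq!(args["city"], "Toronto");
    /// ```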
    pub fn parse_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
        serde_json::from_str(&self.arguments)
    }
}

/// Tool function to call
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionObject {
    pub name: String,
    pub description: String,
    pub parameters: Value,
}

/// Type of tool that can be used by the model
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    Function,
}

/// Tool to use for the LLM toolbox
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: FunctionObject,
}

/// Request payload for generating content
#[derive(Debug, Serialize)]
struct CreateChatCompletionRequest {
    /// Name of the model
    model: String,
    /// Conversation history and prompt
    messages: Vec<Message>,
    /// Enable streaming of responses
    stream: bool,
    /// Optional tools to use for generation
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    /// Maximum number of tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<i32>,
}

/// Function details in a tool call response
#[derive(Debug, Deserialize, Clone)]
pub struct ToolFunctionResponse {
    /// Name of the function that the LLM wants to call
    pub name: String,
    /// Description of the function
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// The arguments that the LLM wants to pass to the function
    pub arguments: Value,
}

/// A tool call in the response
#[derive(Debug, Deserialize, Clone)]
pub struct ToolCallResponse {
    /// Unique identifier of the tool call
    pub id: String,
    /// Type of tool that was called
    #[serde(rename = "type")]
    pub r#type: ToolType,
    /// Function that the LLM wants to call
    pub function: ToolFunctionResponse,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionChoice {
    pub finish_reason: String,
    pub message: Message,
    pub index: i32,
}

/// The response from generating content
#[derive(Debug, Deserialize, Clone)]
pub struct CreateChatCompletionResponse {
    pub id: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub created: i64,
    pub model: String,
    pub object: String,
}

/// The response from streaming content generation
#[derive(Debug, Deserialize, Clone)]
pub struct CreateChatCompletionStreamResponse {
    /// A unique identifier for the chat completion. Each chunk has the same ID.
    pub id: String,
    /// A list of chat completion choices. Can contain more than one element if `n` is greater than 1.
    pub choices: Vec<ChatCompletionStreamChoice>,
    /// The Unix timestamp (in seconds) of when the chat completion was created.
    pub created: i64,
    /// The model used to generate the completion.
    pub model: String,
    /// This fingerprint represents the backend configuration that the model runs with.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// The object type, which is always "chat.completion.chunk".
    pub object: String,
    /// Usage statistics for the completion request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<CompletionUsage>,
}

/// Choice in a streaming completion response
#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionStreamChoice {
    /// The delta content for this streaming chunk
    pub delta: ChatCompletionStreamDelta,
    /// Index of the choice in the choices array
    pub index: i32,
    /// The reason the model stopped generating tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}

/// Delta content for streaming responses
#[derive(Debug, Deserialize, Clone)]
pub struct ChatCompletionStreamDelta {
    /// Role of the message sender
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<MessageRole>,
    /// Content of the message delta
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool calls for this delta
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

/// Usage statistics for the completion request
#[derive(Debug, Deserialize, Clone)]
pub struct CompletionUsage {
    /// Number of tokens in the generated completion
    pub completion_tokens: i64,
    /// Number of tokens in the prompt
    pub prompt_tokens: i64,
    /// Total number of tokens used in the request (prompt + completion)
    pub total_tokens: i64,
}

/// Client for interacting with the Inference Gateway API
pub struct InferenceGatewayClient {
    base_url: String,
    client: Client,
    token: Option<String>,
    tools: Option<Vec<Tool>>,
    max_tokens: Option<i32>,
}

/// Custom `Debug` implementation that redacts the authentication token
impl std::fmt::Debug for InferenceGatewayClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InferenceGatewayClient")
            .field("base_url", &self.base_url)
            .field("token", &self.token.as_ref().map(|_| "*****"))
            .finish()
    }
}

/// Core API interface for the Inference Gateway
pub trait InferenceGatewayAPI {
    /// Lists available models from all providers
    ///
    /// # Errors
    /// - Returns [`GatewayError::Unauthorized`] if authentication fails
    /// - Returns [`GatewayError::BadRequest`] if the request is malformed
    /// - Returns [`GatewayError::InternalError`] if the server has an error
    /// - Returns [`GatewayError::Other`] for other errors
    ///
    /// # Returns
    /// A list of models available from all providers
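    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming a running gateway and Tokio as the async runtime:
    ///
    /// ```no_run
    /// # use inference_gateway_sdk::{GatewayError, InferenceGatewayAPI, InferenceGatewayClient};
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), GatewayError> {
    /// let client = InferenceGatewayClient::new("http://localhost:8080/v1");
    /// let models = client.list_models().await?;
    /// println!("found {} providers", models.len());
    /// # Ok(())
    /// # }
    /// ```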
    fn list_models(&self)
        -> impl Future<Output = Result<Vec<ProviderModels>, GatewayError>> + Send;

    /// Lists available models by a specific provider
    ///
    /// # Arguments
    /// * `provider` - The LLM provider to list models for
    ///
    /// # Errors
    /// - Returns [`GatewayError::Unauthorized`] if authentication fails
    /// - Returns [`GatewayError::BadRequest`] if the request is malformed
    /// - Returns [`GatewayError::InternalError`] if the server has an error
    /// - Returns [`GatewayError::Other`] for other errors
    ///
    /// # Returns
    /// A list of models available from the specified provider
    fn list_models_by_provider(
        &self,
        provider: Provider,
    ) -> impl Future<Output = Result<ProviderModels, GatewayError>> + Send;

    /// Generates content using a specified model
    ///
    /// Tools and token limits configured on the client are attached to the request
    /// automatically.
    ///
    /// # Arguments
    /// * `provider` - The LLM provider to use
    /// * `model` - Name of the model
    /// * `messages` - Conversation history and prompt
    ///
    /// # Errors
    /// - Returns [`GatewayError::Unauthorized`] if authentication fails
    /// - Returns [`GatewayError::BadRequest`] if the request is malformed
    /// - Returns [`GatewayError::InternalError`] if the server has an error
    /// - Returns [`GatewayError::Other`] for other errors
    ///
    /// # Returns
    /// The generated response
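    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming a Groq model named `"mixtral-8x7b"` is available
    /// behind the gateway and Tokio is the async runtime:
    ///
    /// ```no_run
    /// # use inference_gateway_sdk::{
    /// #     GatewayError, InferenceGatewayAPI, InferenceGatewayClient, Message, MessageRole, Provider,
    /// # };
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), GatewayError> {
    /// let client = InferenceGatewayClient::new("http://localhost:8080/v1");
    /// let messages = vec![Message {
    ///     role: MessageRole::User,
    ///     content: "Hello".to_string(),
    ///     ..Default::default()
    /// }];
    ///
    /// let response = client
    ///     .generate_content(Provider::Groq, "mixtral-8x7b", messages)
    ///     .await?;
    /// println!("{}", response.choices[0].message.content);
    /// # Ok(())
    /// # }
    /// ```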
    fn generate_content(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Future<Output = Result<CreateChatCompletionResponse, GatewayError>> + Send;

    /// Stream content generation directly using the backend SSE stream.
    ///
    /// # Arguments
    /// * `provider` - The LLM provider to use
    /// * `model` - Name of the model
    /// * `messages` - Conversation history and prompt
    ///
    /// # Returns
    /// A stream of Server-Sent Events (SSE) from the Inference Gateway API
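    ///
    /// # Examples
    ///
    /// A minimal sketch of consuming the stream, assuming Tokio and the
    /// `futures_util` combinators are available:
    ///
    /// ```no_run
    /// # use futures_util::{pin_mut, StreamExt};
    /// # use inference_gateway_sdk::{
    /// #     GatewayError, InferenceGatewayAPI, InferenceGatewayClient, Message, MessageRole, Provider,
    /// # };
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), GatewayError> {
    /// let client = InferenceGatewayClient::new("http://localhost:8080/v1");
    /// let messages = vec![Message {
    ///     role: MessageRole::User,
    ///     content: "Hello".to_string(),
    ///     ..Default::default()
    /// }];
    ///
    /// let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
    /// pin_mut!(stream);
    /// while let Some(event) = stream.next().await {
    ///     println!("{}", event?.data);
    /// }
    /// # Ok(())
    /// # }
    /// ```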
    fn generate_content_stream(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send;

    /// Checks if the API is available
    fn health_check(&self) -> impl Future<Output = Result<bool, GatewayError>> + Send;
}

impl InferenceGatewayClient {
    /// Creates a new client instance
    ///
    /// # Arguments
    /// * `base_url` - Base URL of the Inference Gateway API
    pub fn new(base_url: &str) -> Self {
        Self {
            base_url: base_url.to_string(),
            client: Client::new(),
            token: None,
            tools: None,
            max_tokens: None,
        }
    }

    /// Creates a new client instance using the `INFERENCE_GATEWAY_URL` environment
    /// variable when set, falling back to `http://localhost:8080/v1`
    pub fn new_default() -> Self {
        let base_url = std::env::var("INFERENCE_GATEWAY_URL")
            .unwrap_or_else(|_| "http://localhost:8080/v1".to_string());

        Self {
            base_url,
            client: Client::new(),
            token: None,
            tools: None,
            max_tokens: None,
        }
    }

    /// Returns the base URL this client is configured to use
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Provides tools to use for generation
    ///
    /// # Arguments
    /// * `tools` - List of tools to use for generation
    ///
    /// # Returns
    /// Self with the tools set
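    ///
    /// # Examples
    ///
    /// Registering a single function tool (the function name and JSON schema are
    /// illustrative):
    ///
    /// ```
    /// use inference_gateway_sdk::{FunctionObject, InferenceGatewayClient, Tool, ToolType};
    /// use serde_json::json;
    ///
    /// let tools = vec![Tool {
    ///     r#type: ToolType::Function,
    ///     function: FunctionObject {
    ///         name: "get_current_weather".to_string(),
    ///         description: "Get the current weather of a city".to_string(),
    ///         parameters: json!({
    ///             "type": "object",
    ///             "properties": {
    ///                 "city": { "type": "string", "description": "The name of the city" }
    ///             },
    ///             "required": ["city"]
    ///         }),
    ///     },
    /// }];
    ///
    /// let client = InferenceGatewayClient::new("http://localhost:8080/v1").with_tools(Some(tools));
    /// ```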
    pub fn with_tools(mut self, tools: Option<Vec<Tool>>) -> Self {
        self.tools = tools;
        self
    }

    /// Sets an authentication token for the client
    ///
    /// # Arguments
    /// * `token` - JWT token for authentication
    ///
    /// # Returns
    /// Self with the token set
    pub fn with_token(mut self, token: impl Into<String>) -> Self {
        self.token = Some(token.into());
        self
    }

    /// Sets the maximum number of tokens to generate
    ///
    /// # Arguments
    /// * `max_tokens` - Maximum number of tokens to generate
    ///
    /// # Returns
    /// Self with the maximum tokens set
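    ///
    /// # Examples
    ///
    /// Capping generation length at a hypothetical 100 tokens:
    ///
    /// ```
    /// use inference_gateway_sdk::InferenceGatewayClient;
    ///
    /// let client = InferenceGatewayClient::new("http://localhost:8080/v1").with_max_tokens(Some(100));
    /// ```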
    pub fn with_max_tokens(mut self, max_tokens: Option<i32>) -> Self {
        self.max_tokens = max_tokens;
        self
    }
}

impl InferenceGatewayAPI for InferenceGatewayClient {
    async fn list_models(&self) -> Result<Vec<ProviderModels>, GatewayError> {
        let url = format!("{}/models", self.base_url);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => Ok(response.json().await?),
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn list_models_by_provider(
        &self,
        provider: Provider,
    ) -> Result<ProviderModels, GatewayError> {
        let url = format!("{}/list/models?provider={}", self.base_url, provider);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let response = request.send().await?;
        match response.status() {
            StatusCode::OK => Ok(response.json().await?),
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            _ => Err(GatewayError::Other(Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Unexpected status code: {}", response.status()),
            )))),
        }
    }

    async fn generate_content(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> Result<CreateChatCompletionResponse, GatewayError> {
        let url = format!("{}/chat/completions?provider={}", self.base_url, provider);
        let mut request = self.client.post(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }

        let request_payload = CreateChatCompletionRequest {
            model: model.to_string(),
            messages,
            stream: false,
            tools: self.tools.clone(),
            max_tokens: self.max_tokens,
        };

        let response = request.json(&request_payload).send().await?;

        match response.status() {
            StatusCode::OK => Ok(response.json().await?),
            StatusCode::BAD_REQUEST => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::BadRequest(error.error))
            }
            StatusCode::UNAUTHORIZED => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::Unauthorized(error.error))
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                let error: ErrorResponse = response.json().await?;
                Err(GatewayError::InternalError(error.error))
            }
            status => Err(GatewayError::Other(Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Unexpected status code: {}", status),
            )))),
        }
    }

    /// Stream content generation directly using the backend SSE stream.
    fn generate_content_stream(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send {
        let client = self.client.clone();
        let base_url = self.base_url.clone();
        let url = format!("{}/chat/completions?provider={}", base_url, provider);

        // Streaming requests are currently sent without the client's configured
        // tools or max_tokens.
        let request = CreateChatCompletionRequest {
            model: model.to_string(),
            messages,
            stream: true,
            tools: None,
            max_tokens: None,
        };

        async_stream::try_stream! {
            let response = client.post(&url).json(&request).send().await?;
            let mut stream = response.bytes_stream();
            let mut current_event: Option<String> = None;
            let mut current_data: Option<String> = None;

            // Accumulate "event:" and "data:" fields and emit an event whenever a
            // blank line marks the end of an SSE frame.
            while let Some(chunk) = stream.next().await {
                let chunk = chunk?;
                let chunk_str = String::from_utf8_lossy(&chunk);

                for line in chunk_str.lines() {
                    if line.is_empty() && current_data.is_some() {
                        yield SSEvents {
                            data: current_data.take().unwrap(),
                            event: current_event.take(),
                            retry: None, // TODO: retry is not yet implemented in the backend
                        };
                        continue;
                    }

                    if let Some(event) = line.strip_prefix("event:") {
                        current_event = Some(event.trim().to_string());
                    } else if let Some(data) = line.strip_prefix("data:") {
                        // `lines()` already strips the trailing newline, so only trim whitespace.
                        current_data = Some(data.trim().to_string());
                    }
                }
            }
        }
    }

    async fn health_check(&self) -> Result<bool, GatewayError> {
        let url = format!("{}/health", self.base_url);

        let response = self.client.get(&url).send().await?;
        match response.status() {
            StatusCode::OK => Ok(true),
            _ => Ok(false),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        CreateChatCompletionRequest, CreateChatCompletionResponse,
        CreateChatCompletionStreamResponse, FunctionObject, GatewayError, InferenceGatewayAPI,
        InferenceGatewayClient, Message, MessageRole, Provider, Tool, ToolType,
    };
    use futures_util::{pin_mut, StreamExt};
    use mockito::{Matcher, Server};
    use serde_json::json;

    #[test]
    fn test_provider_serialization() {
        let providers = vec![
            (Provider::Ollama, "ollama"),
            (Provider::Groq, "groq"),
            (Provider::OpenAI, "openai"),
            (Provider::Cloudflare, "cloudflare"),
            (Provider::Cohere, "cohere"),
            (Provider::Anthropic, "anthropic"),
        ];

        for (provider, expected) in providers {
            let json = serde_json::to_string(&provider).unwrap();
            assert_eq!(json, format!("\"{}\"", expected));
        }
    }

    #[test]
    fn test_provider_deserialization() {
        let test_cases = vec![
            ("\"ollama\"", Provider::Ollama),
            ("\"groq\"", Provider::Groq),
            ("\"openai\"", Provider::OpenAI),
            ("\"cloudflare\"", Provider::Cloudflare),
            ("\"cohere\"", Provider::Cohere),
            ("\"anthropic\"", Provider::Anthropic),
        ];

        for (json, expected) in test_cases {
            let provider: Provider = serde_json::from_str(json).unwrap();
            assert_eq!(provider, expected);
        }
    }

    #[test]
    fn test_message_serialization_with_tool_call_id() {
        let message_with_tool = Message {
            role: MessageRole::Tool,
            content: "The weather is sunny".to_string(),
            tool_call_id: Some("call_123".to_string()),
            ..Default::default()
        };

        let serialized = serde_json::to_string(&message_with_tool).unwrap();
        let expected_with_tool =
            r#"{"role":"tool","content":"The weather is sunny","tool_call_id":"call_123"}"#;
        assert_eq!(serialized, expected_with_tool);

        let message_without_tool = Message {
            role: MessageRole::User,
            content: "What's the weather?".to_string(),
            ..Default::default()
        };

        let serialized = serde_json::to_string(&message_without_tool).unwrap();
        let expected_without_tool = r#"{"role":"user","content":"What's the weather?"}"#;
        assert_eq!(serialized, expected_without_tool);

        let deserialized: Message = serde_json::from_str(expected_with_tool).unwrap();
        assert_eq!(deserialized.role, MessageRole::Tool);
        assert_eq!(deserialized.content, "The weather is sunny");
        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));

        let deserialized: Message = serde_json::from_str(expected_without_tool).unwrap();
        assert_eq!(deserialized.role, MessageRole::User);
        assert_eq!(deserialized.content, "What's the weather?");
        assert_eq!(deserialized.tool_call_id, None);
    }

    #[test]
    fn test_provider_display() {
        let providers = vec![
            (Provider::Ollama, "ollama"),
            (Provider::Groq, "groq"),
            (Provider::OpenAI, "openai"),
            (Provider::Cloudflare, "cloudflare"),
            (Provider::Cohere, "cohere"),
            (Provider::Anthropic, "anthropic"),
        ];

        for (provider, expected) in providers {
            assert_eq!(provider.to_string(), expected);
        }
    }

    #[test]
    fn test_generate_request_serialization() {
        let request_payload = CreateChatCompletionRequest {
            model: "llama3.2:1b".to_string(),
            messages: vec![
                Message {
                    role: MessageRole::System,
                    content: "You are a helpful assistant.".to_string(),
                    ..Default::default()
                },
                Message {
                    role: MessageRole::User,
                    content: "What is the current weather in Toronto?".to_string(),
                    ..Default::default()
                },
            ],
            stream: false,
            tools: Some(vec![Tool {
                r#type: ToolType::Function,
                function: FunctionObject {
                    name: "get_current_weather".to_string(),
                    description: "Get the current weather of a city".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "city": {
                                "type": "string",
                                "description": "The name of the city"
                            }
                        },
                        "required": ["city"]
                    }),
                },
            }]),
            max_tokens: None,
        };

        let serialized = serde_json::to_string_pretty(&request_payload).unwrap();
        let expected = r#"{
      "model": "llama3.2:1b",
      "messages": [
        {
          "role": "system",
          "content": "You are a helpful assistant."
        },
        {
          "role": "user",
          "content": "What is the current weather in Toronto?"
        }
      ],
      "stream": false,
      "tools": [
        {
          "type": "function",
          "function": {
            "name": "get_current_weather",
            "description": "Get the current weather of a city",
            "parameters": {
              "type": "object",
              "properties": {
                "city": {
                  "type": "string",
                  "description": "The name of the city"
                }
              },
              "required": ["city"]
            }
          }
        }
      ]
    }"#;

        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
            serde_json::from_str::<serde_json::Value>(expected).unwrap()
        );
    }

    #[tokio::test]
    async fn test_authentication_header() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock_with_auth = server
            .mock("GET", "/v1/models")
            .match_header("authorization", "Bearer test-token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("[]")
            .expect(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_token("test-token");
        client.list_models().await?;
        mock_with_auth.assert();

        let mock_without_auth = server
            .mock("GET", "/v1/models")
            .match_header("authorization", Matcher::Missing)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("[]")
            .expect(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        client.list_models().await?;
        mock_without_auth.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_unauthorized_error() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "error": "Invalid token"
        }"#;

        let mock = server
            .mock("GET", "/v1/models")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let error = client.list_models().await.unwrap_err();

        assert!(matches!(error, GatewayError::Unauthorized(_)));
        if let GatewayError::Unauthorized(msg) = error {
            assert_eq!(msg, "Invalid token");
        }
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_list_models() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_response_json = r#"[
            {
                "provider": "ollama",
                "models": [
                    {"name": "llama2"}
                ]
            }
        ]"#;

        let mock = server
            .mock("GET", "/v1/models")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_response_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let models = client.list_models().await?;

        assert_eq!(models.len(), 1);
        assert_eq!(models[0].models[0].name, "llama2");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_list_models_by_provider() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "provider":"ollama",
            "models": [{
                "name": "llama2"
            }]
        }"#;

        let mock = server
            .mock("GET", "/v1/list/models?provider=ollama")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let models = client.list_models_by_provider(Provider::Ollama).await?;

        assert_eq!(models.provider, Provider::Ollama);
        assert_eq!(models.models[0].name, "llama2");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hellloooo"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=ollama")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];
        let response = client
            .generate_content(Provider::Ollama, "llama2", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hellloooo");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_serialization() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let direct_parse: Result<CreateChatCompletionResponse, _> = serde_json::from_str(raw_json);
        assert!(
            direct_parse.is_ok(),
            "Direct JSON parse failed: {:?}",
            direct_parse.err()
        );

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hello");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_error_response() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "error":"Invalid request"
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];
        let error = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await
            .unwrap_err();

        assert!(matches!(error, GatewayError::BadRequest(_)));
        if let GatewayError::BadRequest(msg) = error {
            assert_eq!(msg, "Invalid request");
        }
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_gateway_errors() -> Result<(), GatewayError> {
        let mut server: mockito::ServerGuard = Server::new_async().await;

        let unauthorized_mock = server
            .mock("GET", "/v1/models")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid token"}"#)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);
        match client.list_models().await {
            Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
            _ => panic!("Expected Unauthorized error"),
        }
        unauthorized_mock.assert();

        let bad_request_mock = server
            .mock("GET", "/v1/models")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid provider"}"#)
            .create();

        match client.list_models().await {
            Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
            _ => panic!("Expected BadRequest error"),
        }
        bad_request_mock.assert();

        let internal_error_mock = server
            .mock("GET", "/v1/models")
            .with_status(500)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Internal server error occurred"}"#)
            .create();

        match client.list_models().await {
            Err(GatewayError::InternalError(msg)) => {
                assert_eq!(msg, "Internal server error occurred")
            }
            _ => panic!("Expected InternalError error"),
        }
        internal_error_mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_case_insensitive() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json = r#"{
            "id": "chatcmpl-456",
            "object": "chat.completion",
            "created": 1630000001,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(response.choices[0].message.content, "Hello");
        assert_eq!(response.model, "mixtral-8x7b");
        assert_eq!(response.object, "chat.completion");
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_stream() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(move |writer| -> std::io::Result<()> {
                let events = vec![
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}"#),
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268191,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":" World"},"finish_reason":null}]}"#),
                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268192,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":17,"completion_tokens":40,"total_tokens":57}}"#),
                    "data: [DONE]\n\n".to_string(),
                ];
                for event in events {
                    writer.write_all(event.as_bytes())?;
                }
                Ok(())
            })
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Test message".to_string(),
            ..Default::default()
        }];

        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
        pin_mut!(stream);
        while let Some(result) = stream.next().await {
            let result = result?;
            let generate_response: CreateChatCompletionStreamResponse =
                serde_json::from_str(&result.data)
                    .expect("Failed to parse CreateChatCompletionResponse");

            if generate_response.choices[0].finish_reason.is_some() {
                assert_eq!(
                    generate_response.choices[0].finish_reason.as_ref().unwrap(),
                    "stop"
                );
                break;
            }

            if let Some(content) = &generate_response.choices[0].delta.content {
                assert!(matches!(content.as_str(), "Hello" | " World"));
            }
            if let Some(role) = &generate_response.choices[0].delta.role {
                assert_eq!(role, &MessageRole::Assistant);
            }
        }

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_stream_error() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_chunked_body(move |writer| -> std::io::Result<()> {
                let events = vec![format!(
                    "event: {}\ndata: {}\nretry: {}\n\n",
                    r#"error"#, r#"{"error":"Invalid request"}"#, r#"1000"#,
                )];
                for event in events {
                    writer.write_all(event.as_bytes())?;
                }
                Ok(())
            })
            .expect_at_least(1)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Test message".to_string(),
            ..Default::default()
        }];

        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);

        pin_mut!(stream);
        while let Some(result) = stream.next().await {
            let result = result?;
            assert!(result.event.is_some());
            assert_eq!(result.event.unwrap(), "error");
            assert!(result.data.contains("Invalid request"));
            assert!(result.retry.is_none());
        }

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "deepseek-r1-distill-llama-70b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "tool_calls",
                    "message": {
                        "role": "assistant",
                        "content": "Let me check the weather for you.",
                        "tool_calls": [
                            {
                                "id": "1234",
                                "type": "function",
                                "function": {
                                    "name": "get_weather",
                                    "arguments": "{\"location\": \"London\"}"
                                }
                            }
                        ]
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let tools = vec![Tool {
            r#type: ToolType::Function,
            function: FunctionObject {
                name: "get_weather".to_string(),
                description: "Get the weather for a location".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city name"
                        }
                    },
                    "required": ["location"]
                }),
            },
        }];

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_tools(Some(tools));

        let messages = vec![Message {
            role: MessageRole::User,
            content: "What's the weather in London?".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(
            response.choices[0].message.content,
            "Let me check the weather for you."
        );

        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "get_weather");

        let params = tool_calls[0]
            .function
            .parse_arguments()
            .expect("Failed to parse function arguments");
        assert_eq!(params["location"].as_str().unwrap(), "London");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_without_tools() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "gpt-4",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Hello!"
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=openai")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hi".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::OpenAI, "gpt-4", messages)
            .await?;

        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.choices[0].message.content, "Hello!");
        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert!(response.choices[0].message.tool_calls.is_none());

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_tools_payload() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_request_body = r#"{
            "model": "deepseek-r1-distill-llama-70b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What is the current weather in Toronto?"
                }
            ],
            "stream": false,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_current_weather",
                        "description": "Get the current weather of a city",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "city": {
                                    "type": "string",
                                    "description": "The name of the city"
                                }
                            },
                            "required": ["city"]
                        }
                    }
                }
            ]
        }"#;

        let raw_json_response = r#"{
            "id": "1234",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "deepseek-r1-distill-llama-70b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Let me check the weather for you",
                        "tool_calls": [
                            {
                                "id": "1234",
                                "type": "function",
                                "function": {
                                    "name": "get_current_weather",
                                    "arguments": "{\"city\": \"Toronto\"}"
                                }
                            }
                        ]
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .match_body(mockito::Matcher::JsonString(raw_request_body.to_string()))
            .with_body(raw_json_response)
            .create();

        let tools = vec![Tool {
            r#type: ToolType::Function,
            function: FunctionObject {
                name: "get_current_weather".to_string(),
                description: "Get the current weather of a city".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "The name of the city"
                        }
                    },
                    "required": ["city"]
                }),
            },
        }];

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url);

        let messages = vec![
            Message {
                role: MessageRole::System,
                content: "You are a helpful assistant.".to_string(),
                ..Default::default()
            },
            Message {
                role: MessageRole::User,
                content: "What is the current weather in Toronto?".to_string(),
                ..Default::default()
            },
        ];

        let response = client
            .with_tools(Some(tools))
            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
            .await?;

        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
        assert_eq!(
            response.choices[0].message.content,
            "Let me check the weather for you"
        );
        assert_eq!(
            response.choices[0]
                .message
                .tool_calls
                .as_ref()
                .unwrap()
                .len(),
            1
        );

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_generate_content_with_max_tokens() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;

        let raw_json_response = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1630000000,
            "model": "mixtral-8x7b",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": "Here's a poem with 100 tokens..."
                    }
                }
            ]
        }"#;

        let mock = server
            .mock("POST", "/v1/chat/completions?provider=groq")
            .with_status(200)
            .with_header("content-type", "application/json")
            .match_body(mockito::Matcher::JsonString(
                r#"{
                "model": "mixtral-8x7b",
                "messages": [{"role":"user","content":"Write a poem"}],
                "stream": false,
                "max_tokens": 100
            }"#
                .to_string(),
            ))
            .with_body(raw_json_response)
            .create();

        let base_url = format!("{}/v1", server.url());
        let client = InferenceGatewayClient::new(&base_url).with_max_tokens(Some(100));

        let messages = vec![Message {
            role: MessageRole::User,
            content: "Write a poem".to_string(),
            ..Default::default()
        }];

        let response = client
            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
            .await?;

        assert_eq!(
            response.choices[0].message.content,
            "Here's a poem with 100 tokens..."
        );
        assert_eq!(response.model, "mixtral-8x7b");
        assert_eq!(response.created, 1630000000);
        assert_eq!(response.object, "chat.completion");

        mock.assert();
        Ok(())
    }

    #[tokio::test]
    async fn test_health_check() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;
        let mock = server.mock("GET", "/health").with_status(200).create();

        let client = InferenceGatewayClient::new(&server.url());
        let is_healthy = client.health_check().await?;

        assert!(is_healthy);
        mock.assert();

        Ok(())
    }

    #[tokio::test]
    async fn test_client_base_url_configuration() -> Result<(), GatewayError> {
        let mut custom_url_server = Server::new_async().await;

        let custom_url_mock = custom_url_server
            .mock("GET", "/health")
            .with_status(200)
            .create();

        let custom_client = InferenceGatewayClient::new(&custom_url_server.url());
        let is_healthy = custom_client.health_check().await?;
        assert!(is_healthy);
        custom_url_mock.assert();

        let default_client = InferenceGatewayClient::new_default();

        let default_url = "http://localhost:8080/v1";
        assert_eq!(default_client.base_url(), default_url);

        Ok(())
    }
}