inference_gateway_sdk/
lib.rs

1//! Inference Gateway SDK for Rust
2//!
3//! This crate provides a Rust client for the Inference Gateway API, allowing interaction
4//! with various LLM providers through a unified interface.
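//!
//! # Example
//!
//! A minimal usage sketch (marked `no_run`, so it is compiled but not executed); it
//! assumes the crate is named `inference_gateway_sdk`, that `tokio` is available, and
//! that a gateway is listening on the default `http://localhost:8080/v1` URL:
//!
//! ```no_run
//! use inference_gateway_sdk::{
//!     InferenceGatewayAPI, InferenceGatewayClient, Message, MessageRole, Provider,
//! };
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let client = InferenceGatewayClient::new("http://localhost:8080/v1");
//!     let messages = vec![Message {
//!         role: MessageRole::User,
//!         content: "Hello".to_string(),
//!         ..Default::default()
//!     }];
//!     let response = client
//!         .generate_content(Provider::Ollama, "llama2", messages)
//!         .await?;
//!     println!("{}", response.choices[0].message.content);
//!     Ok(())
//! }
//! ```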
5
6use core::fmt;
7use std::future::Future;
8
9use futures_util::{Stream, StreamExt};
10use reqwest::{Client, StatusCode};
11use serde::{Deserialize, Serialize};
12
13use serde_json::Value;
14use thiserror::Error;
15
/// A single Server-Sent Event (SSE) received from the Inference Gateway API
17#[derive(Debug, Serialize, Deserialize)]
18pub struct SSEvents {
19    pub data: String,
20    pub event: Option<String>,
21    pub retry: Option<u64>,
22}
23
24/// Custom error types for the Inference Gateway SDK
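///
/// Callers can match on the variants to distinguish failure modes. A short sketch
/// (`no_run`, assuming the crate name `inference_gateway_sdk`):
///
/// ```no_run
/// use inference_gateway_sdk::{GatewayError, InferenceGatewayAPI, InferenceGatewayClient};
///
/// # async fn run() {
/// let client = InferenceGatewayClient::new("http://localhost:8080/v1");
/// match client.list_models().await {
///     Ok(models) => println!("{} models available", models.data.len()),
///     Err(GatewayError::Unauthorized(msg)) => eprintln!("authentication failed: {}", msg),
///     Err(err) => eprintln!("request failed: {}", err),
/// }
/// # }
/// ```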
25#[derive(Error, Debug)]
26pub enum GatewayError {
27    #[error("Unauthorized: {0}")]
28    Unauthorized(String),
29
30    #[error("Bad request: {0}")]
31    BadRequest(String),
32
33    #[error("Internal server error: {0}")]
34    InternalError(String),
35
36    #[error("Stream error: {0}")]
37    StreamError(reqwest::Error),
38
39    #[error("Decoding error: {0}")]
40    DecodingError(std::string::FromUtf8Error),
41
42    #[error("Request error: {0}")]
43    RequestError(#[from] reqwest::Error),
44
45    #[error("Deserialization error: {0}")]
46    DeserializationError(serde_json::Error),
47
48    #[error("Serialization error: {0}")]
49    SerializationError(#[from] serde_json::Error),
50
51    #[error("Other error: {0}")]
52    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
53}
54
55#[derive(Debug, Deserialize)]
56struct ErrorResponse {
57    error: String,
58}
59
60/// Common model information
61#[derive(Debug, Serialize, Deserialize, Clone)]
62pub struct Model {
63    /// The model identifier
64    pub id: String,
65    /// The object type, usually "model"
66    pub object: Option<String>,
67    /// The Unix timestamp (in seconds) of when the model was created
68    pub created: Option<i64>,
69    /// The organization that owns the model
70    pub owned_by: Option<String>,
71    /// The provider that serves the model
72    pub served_by: Option<String>,
73}
74
75/// Response structure for listing models
76#[derive(Debug, Serialize, Deserialize)]
77pub struct ListModelsResponse {
78    /// The provider identifier
79    #[serde(skip_serializing_if = "Option::is_none")]
80    pub provider: Option<Provider>,
81    /// Response object type, usually "list"
82    pub object: String,
83    /// List of available models
84    pub data: Vec<Model>,
85}
86
87/// Supported LLM providers
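///
/// Provider values serialize to their lowercase names and can be parsed from strings
/// case-insensitively. A small sketch (`no_run`):
///
/// ```no_run
/// use inference_gateway_sdk::Provider;
/// use std::convert::TryFrom;
///
/// let provider = Provider::try_from("OpenAI").expect("known provider");
/// assert_eq!(provider, Provider::OpenAI);
/// assert_eq!(provider.to_string(), "openai");
/// ```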
88#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Copy)]
89#[serde(rename_all = "lowercase")]
90pub enum Provider {
91    #[serde(alias = "Ollama", alias = "OLLAMA")]
92    Ollama,
93    #[serde(alias = "Groq", alias = "GROQ")]
94    Groq,
95    #[serde(alias = "OpenAI", alias = "OPENAI")]
96    OpenAI,
97    #[serde(alias = "Cloudflare", alias = "CLOUDFLARE")]
98    Cloudflare,
99    #[serde(alias = "Cohere", alias = "COHERE")]
100    Cohere,
101    #[serde(alias = "Anthropic", alias = "ANTHROPIC")]
102    Anthropic,
103}
104
105impl fmt::Display for Provider {
106    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
107        match self {
108            Provider::Ollama => write!(f, "ollama"),
109            Provider::Groq => write!(f, "groq"),
110            Provider::OpenAI => write!(f, "openai"),
111            Provider::Cloudflare => write!(f, "cloudflare"),
112            Provider::Cohere => write!(f, "cohere"),
113            Provider::Anthropic => write!(f, "anthropic"),
114        }
115    }
116}
117
118impl TryFrom<&str> for Provider {
119    type Error = GatewayError;
120
121    fn try_from(s: &str) -> Result<Self, Self::Error> {
122        match s.to_lowercase().as_str() {
123            "ollama" => Ok(Self::Ollama),
124            "groq" => Ok(Self::Groq),
125            "openai" => Ok(Self::OpenAI),
126            "cloudflare" => Ok(Self::Cloudflare),
127            "cohere" => Ok(Self::Cohere),
128            "anthropic" => Ok(Self::Anthropic),
129            _ => Err(GatewayError::BadRequest(format!("Unknown provider: {}", s))),
130        }
131    }
132}
133
134#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
135#[serde(rename_all = "lowercase")]
136pub enum MessageRole {
137    System,
138    #[default]
139    User,
140    Assistant,
141    Tool,
142}
143
144impl fmt::Display for MessageRole {
145    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
146        match self {
147            MessageRole::System => write!(f, "system"),
148            MessageRole::User => write!(f, "user"),
149            MessageRole::Assistant => write!(f, "assistant"),
150            MessageRole::Tool => write!(f, "tool"),
151        }
152    }
153}
154
155/// A message in a conversation with an LLM
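///
/// Messages are usually built with struct-update syntax over the default value,
/// as in this sketch (`no_run`):
///
/// ```no_run
/// use inference_gateway_sdk::{Message, MessageRole};
///
/// let system = Message {
///     role: MessageRole::System,
///     content: "You are a helpful assistant.".to_string(),
///     ..Default::default()
/// };
/// let user = Message {
///     role: MessageRole::User,
///     content: "What is the current weather in Toronto?".to_string(),
///     ..Default::default()
/// };
/// ```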
156#[derive(Debug, Serialize, Deserialize, Clone, Default)]
157pub struct Message {
158    /// Role of the message sender ("system", "user", "assistant" or "tool")
159    pub role: MessageRole,
160    /// Content of the message
161    pub content: String,
162    /// The tools an LLM wants to invoke
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub tool_calls: Option<Vec<ChatCompletionMessageToolCall>>,
165    /// Unique identifier of the tool call
166    #[serde(skip_serializing_if = "Option::is_none")]
167    pub tool_call_id: Option<String>,
168    /// Reasoning behind the message
169    #[serde(skip_serializing_if = "Option::is_none")]
170    pub reasoning: Option<String>,
171}
172
173/// A tool call in a message response
174#[derive(Debug, Deserialize, Serialize, Clone)]
175pub struct ChatCompletionMessageToolCall {
176    /// Unique identifier of the tool call
177    pub id: String,
178    /// Type of the tool being called
179    #[serde(rename = "type")]
180    pub r#type: ChatCompletionToolType,
181    /// Function that was called
182    pub function: ChatCompletionMessageToolCallFunction,
183}
184
185/// Type of tool that can be called
186#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
187pub enum ChatCompletionToolType {
188    /// Function tool type
189    #[serde(rename = "function")]
190    Function,
191}
192
193/// Function details in a tool call
194#[derive(Debug, Deserialize, Serialize, Clone)]
195pub struct ChatCompletionMessageToolCallFunction {
196    /// Name of the function to call
197    pub name: String,
198    /// Arguments to the function in JSON string format
199    pub arguments: String,
200}
201
// Convenience helper for accessing tool-call arguments
203impl ChatCompletionMessageToolCallFunction {
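    /// Parses the JSON-encoded `arguments` string into a `serde_json::Value`.
    ///
    /// A small sketch (`no_run`); the function name and payload are illustrative:
    ///
    /// ```no_run
    /// use inference_gateway_sdk::ChatCompletionMessageToolCallFunction;
    ///
    /// let function = ChatCompletionMessageToolCallFunction {
    ///     name: "get_current_weather".to_string(),
    ///     arguments: "{\"city\": \"Toronto\"}".to_string(),
    /// };
    /// let args = function.parse_arguments().expect("arguments should be valid JSON");
    /// assert_eq!(args["city"], "Toronto");
    /// ```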
204    pub fn parse_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
205        serde_json::from_str(&self.arguments)
206    }
207}
208
209/// Tool function to call
210#[derive(Debug, Serialize, Deserialize, Clone)]
211pub struct FunctionObject {
212    pub name: String,
213    pub description: String,
214    pub parameters: Value,
215}
216
217/// Type of tool that can be used by the model
218#[derive(Debug, Serialize, Deserialize, Clone)]
219#[serde(rename_all = "lowercase")]
220pub enum ToolType {
221    Function,
222}
223
224/// Tool to use for the LLM toolbox
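///
/// A construction sketch mirroring the tool definitions used in the tests below
/// (`no_run`):
///
/// ```no_run
/// use inference_gateway_sdk::{FunctionObject, Tool, ToolType};
/// use serde_json::json;
///
/// let weather_tool = Tool {
///     r#type: ToolType::Function,
///     function: FunctionObject {
///         name: "get_current_weather".to_string(),
///         description: "Get the current weather of a city".to_string(),
///         parameters: json!({
///             "type": "object",
///             "properties": {
///                 "city": { "type": "string", "description": "The name of the city" }
///             },
///             "required": ["city"]
///         }),
///     },
/// };
/// ```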
225#[derive(Debug, Serialize, Deserialize, Clone)]
226pub struct Tool {
227    pub r#type: ToolType,
228    pub function: FunctionObject,
229}
230
231/// Request payload for generating content
232#[derive(Debug, Serialize)]
233struct CreateChatCompletionRequest {
234    /// Name of the model
235    model: String,
236    /// Conversation history and prompt
237    messages: Vec<Message>,
238    /// Enable streaming of responses
239    stream: bool,
240    /// Optional tools to use for generation
241    #[serde(skip_serializing_if = "Option::is_none")]
242    tools: Option<Vec<Tool>>,
243    /// Maximum number of tokens to generate
244    #[serde(skip_serializing_if = "Option::is_none")]
245    max_tokens: Option<i32>,
246}
247
248/// Function details in a tool call response
249#[derive(Debug, Deserialize, Clone)]
250pub struct ToolFunctionResponse {
251    /// Name of the function that the LLM wants to call
252    pub name: String,
253    /// Description of the function
254    #[serde(skip_serializing_if = "Option::is_none")]
255    pub description: Option<String>,
256    /// The arguments that the LLM wants to pass to the function
257    pub arguments: Value,
258}
259
260/// A tool call in the response
261#[derive(Debug, Deserialize, Clone)]
262pub struct ToolCallResponse {
263    /// Unique identifier of the tool call
264    pub id: String,
265    /// Type of tool that was called
266    #[serde(rename = "type")]
267    pub r#type: ToolType,
268    /// Function that the LLM wants to call
269    pub function: ToolFunctionResponse,
270}
271
272#[derive(Debug, Deserialize, Clone)]
273pub struct ChatCompletionChoice {
274    pub finish_reason: String,
275    pub message: Message,
276    pub index: i32,
277}
278
279/// The response from generating content
280#[derive(Debug, Deserialize, Clone)]
281pub struct CreateChatCompletionResponse {
282    pub id: String,
283    pub choices: Vec<ChatCompletionChoice>,
284    pub created: i64,
285    pub model: String,
286    pub object: String,
287}
288
289/// The response from streaming content generation
290#[derive(Debug, Deserialize, Clone)]
291pub struct CreateChatCompletionStreamResponse {
292    /// A unique identifier for the chat completion. Each chunk has the same ID.
293    pub id: String,
294    /// A list of chat completion choices. Can contain more than one element if `n` is greater than 1.
295    pub choices: Vec<ChatCompletionStreamChoice>,
296    /// The Unix timestamp (in seconds) of when the chat completion was created.
297    pub created: i64,
298    /// The model used to generate the completion.
299    pub model: String,
300    /// This fingerprint represents the backend configuration that the model runs with.
301    #[serde(skip_serializing_if = "Option::is_none")]
302    pub system_fingerprint: Option<String>,
303    /// The object type, which is always "chat.completion.chunk".
304    pub object: String,
305    /// Usage statistics for the completion request.
306    #[serde(skip_serializing_if = "Option::is_none")]
307    pub usage: Option<CompletionUsage>,
308}
309
310/// Choice in a streaming completion response
311#[derive(Debug, Deserialize, Clone)]
312pub struct ChatCompletionStreamChoice {
313    /// The delta content for this streaming chunk
314    pub delta: ChatCompletionStreamDelta,
315    /// Index of the choice in the choices array
316    pub index: i32,
317    /// The reason the model stopped generating tokens
318    #[serde(skip_serializing_if = "Option::is_none")]
319    pub finish_reason: Option<String>,
320}
321
322/// Delta content for streaming responses
323#[derive(Debug, Deserialize, Clone)]
324pub struct ChatCompletionStreamDelta {
325    /// Role of the message sender
326    #[serde(skip_serializing_if = "Option::is_none")]
327    pub role: Option<MessageRole>,
328    /// Content of the message delta
329    #[serde(skip_serializing_if = "Option::is_none")]
330    pub content: Option<String>,
331    /// Tool calls for this delta
332    #[serde(skip_serializing_if = "Option::is_none")]
333    pub tool_calls: Option<Vec<ToolCallResponse>>,
334}
335
336/// Usage statistics for the completion request
337#[derive(Debug, Deserialize, Clone)]
338pub struct CompletionUsage {
339    /// Number of tokens in the generated completion
340    pub completion_tokens: i64,
341    /// Number of tokens in the prompt
342    pub prompt_tokens: i64,
343    /// Total number of tokens used in the request (prompt + completion)
344    pub total_tokens: i64,
345}
346
347/// Client for interacting with the Inference Gateway API
348pub struct InferenceGatewayClient {
349    base_url: String,
350    client: Client,
351    token: Option<String>,
352    tools: Option<Vec<Tool>>,
353    max_tokens: Option<i32>,
354}
355
/// Manual Debug implementation that redacts the authentication token
357impl std::fmt::Debug for InferenceGatewayClient {
358    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359        f.debug_struct("InferenceGatewayClient")
360            .field("base_url", &self.base_url)
361            .field("token", &self.token.as_ref().map(|_| "*****"))
362            .finish()
363    }
364}
365
366/// Core API interface for the Inference Gateway
367pub trait InferenceGatewayAPI {
368    /// Lists available models from all providers
369    ///
370    /// # Errors
371    /// - Returns [`GatewayError::Unauthorized`] if authentication fails
372    /// - Returns [`GatewayError::BadRequest`] if the request is malformed
373    /// - Returns [`GatewayError::InternalError`] if the server has an error
374    /// - Returns [`GatewayError::Other`] for other errors
375    ///
376    /// # Returns
377    /// A list of models available from all providers
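    ///
    /// # Example
    ///
    /// A short sketch (`no_run`):
    ///
    /// ```no_run
    /// use inference_gateway_sdk::{InferenceGatewayAPI, InferenceGatewayClient};
    ///
    /// # async fn run() -> Result<(), inference_gateway_sdk::GatewayError> {
    /// let client = InferenceGatewayClient::new("http://localhost:8080/v1");
    /// let models = client.list_models().await?;
    /// for model in &models.data {
    ///     println!("{} (served by {:?})", model.id, model.served_by);
    /// }
    /// # Ok(())
    /// # }
    /// ```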
378    fn list_models(&self) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;
379
380    /// Lists available models by a specific provider
381    ///
382    /// # Arguments
383    /// * `provider` - The LLM provider to list models for
384    ///
385    /// # Errors
386    /// - Returns [`GatewayError::Unauthorized`] if authentication fails
387    /// - Returns [`GatewayError::BadRequest`] if the request is malformed
388    /// - Returns [`GatewayError::InternalError`] if the server has an error
389    /// - Returns [`GatewayError::Other`] for other errors
390    ///
391    /// # Returns
392    /// A list of models available from the specified provider
393    fn list_models_by_provider(
394        &self,
395        provider: Provider,
396    ) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;
397
398    /// Generates content using a specified model
399    ///
400    /// # Arguments
401    /// * `provider` - The LLM provider to use
402    /// * `model` - Name of the model
    /// * messages - Conversation history and prompt
    ///
    /// Tools and token limits are not passed here; [`InferenceGatewayClient`] takes them
    /// from its own configuration (see `with_tools` and `with_max_tokens`).
405    ///
406    /// # Errors
407    /// - Returns [`GatewayError::Unauthorized`] if authentication fails
408    /// - Returns [`GatewayError::BadRequest`] if the request is malformed
409    /// - Returns [`GatewayError::InternalError`] if the server has an error
410    /// - Returns [`GatewayError::Other`] for other errors
411    ///
412    /// # Returns
413    /// The generated response
414    fn generate_content(
415        &self,
416        provider: Provider,
417        model: &str,
418        messages: Vec<Message>,
419    ) -> impl Future<Output = Result<CreateChatCompletionResponse, GatewayError>> + Send;
420
421    /// Stream content generation directly using the backend SSE stream.
422    ///
423    /// # Arguments
424    /// * `provider` - The LLM provider to use
425    /// * `model` - Name of the model
426    /// * `messages` - Conversation history and prompt
427    ///
428    /// # Returns
429    /// A stream of Server-Sent Events (SSE) from the Inference Gateway API
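    ///
    /// # Example
    ///
    /// A consumption sketch (`no_run`); each event's `data` field is expected to hold a
    /// JSON chunk, mirroring the streaming test at the bottom of this file:
    ///
    /// ```no_run
    /// use futures_util::{pin_mut, StreamExt};
    /// use inference_gateway_sdk::{
    ///     CreateChatCompletionStreamResponse, InferenceGatewayAPI, InferenceGatewayClient,
    ///     Message, MessageRole, Provider,
    /// };
    ///
    /// # async fn run() -> Result<(), inference_gateway_sdk::GatewayError> {
    /// let client = InferenceGatewayClient::new("http://localhost:8080/v1");
    /// let messages = vec![Message {
    ///     role: MessageRole::User,
    ///     content: "Hello".to_string(),
    ///     ..Default::default()
    /// }];
    /// let stream = client.generate_content_stream(Provider::Ollama, "llama2", messages);
    /// pin_mut!(stream);
    /// while let Some(event) = stream.next().await {
    ///     let event = event?;
    ///     if event.data == "[DONE]" {
    ///         break;
    ///     }
    ///     let chunk: CreateChatCompletionStreamResponse = serde_json::from_str(&event.data)?;
    ///     if let Some(content) = &chunk.choices[0].delta.content {
    ///         print!("{}", content);
    ///     }
    /// }
    /// # Ok(())
    /// # }
    /// ```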
430    fn generate_content_stream(
431        &self,
432        provider: Provider,
433        model: &str,
434        messages: Vec<Message>,
435    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send;
436
437    /// Checks if the API is available
438    fn health_check(&self) -> impl Future<Output = Result<bool, GatewayError>> + Send;
439}
440
441impl InferenceGatewayClient {
442    /// Creates a new client instance
443    ///
444    /// # Arguments
445    /// * `base_url` - Base URL of the Inference Gateway API
446    pub fn new(base_url: &str) -> Self {
447        Self {
448            base_url: base_url.to_string(),
449            client: Client::new(),
450            token: None,
451            tools: None,
452            max_tokens: None,
453        }
454    }
455
    /// Creates a new client instance using the INFERENCE_GATEWAY_URL environment
    /// variable, falling back to http://localhost:8080/v1 when it is not set
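    ///
    /// A small sketch (`no_run`):
    ///
    /// ```no_run
    /// use inference_gateway_sdk::InferenceGatewayClient;
    ///
    /// // Honors INFERENCE_GATEWAY_URL when set; otherwise falls back to http://localhost:8080/v1.
    /// let client = InferenceGatewayClient::new_default();
    /// assert!(!client.base_url().is_empty());
    /// ```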
458    pub fn new_default() -> Self {
459        let base_url = std::env::var("INFERENCE_GATEWAY_URL")
460            .unwrap_or_else(|_| "http://localhost:8080/v1".to_string());
461
462        Self {
463            base_url,
464            client: Client::new(),
465            token: None,
466            tools: None,
467            max_tokens: None,
468        }
469    }
470
471    /// Returns the base URL this client is configured to use
472    pub fn base_url(&self) -> &str {
473        &self.base_url
474    }
475
476    /// Provides tools to use for generation
477    ///
478    /// # Arguments
479    /// * `tools` - List of tools to use for generation
480    ///
481    /// # Returns
482    /// Self with the tools set
483    pub fn with_tools(mut self, tools: Option<Vec<Tool>>) -> Self {
484        self.tools = tools;
485        self
486    }
487
488    /// Sets an authentication token for the client
489    ///
490    /// # Arguments
491    /// * `token` - JWT token for authentication
492    ///
493    /// # Returns
494    /// Self with the token set
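    ///
    /// # Example
    ///
    /// A builder-style configuration sketch (`no_run`; the URL and token are placeholders):
    ///
    /// ```no_run
    /// use inference_gateway_sdk::InferenceGatewayClient;
    ///
    /// let client = InferenceGatewayClient::new("http://localhost:8080/v1")
    ///     .with_token("my-jwt-token")
    ///     .with_max_tokens(Some(256));
    /// ```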
495    pub fn with_token(mut self, token: impl Into<String>) -> Self {
496        self.token = Some(token.into());
497        self
498    }
499
500    /// Sets the maximum number of tokens to generate
501    ///
502    /// # Arguments
503    /// * `max_tokens` - Maximum number of tokens to generate
504    ///
505    /// # Returns
506    /// Self with the maximum tokens set
507    pub fn with_max_tokens(mut self, max_tokens: Option<i32>) -> Self {
508        self.max_tokens = max_tokens;
509        self
510    }
511}
512
513impl InferenceGatewayAPI for InferenceGatewayClient {
514    async fn list_models(&self) -> Result<ListModelsResponse, GatewayError> {
515        let url = format!("{}/models", self.base_url);
516        let mut request = self.client.get(&url);
517        if let Some(token) = &self.token {
518            request = request.bearer_auth(token);
519        }
520
521        let response = request.send().await?;
522        match response.status() {
523            StatusCode::OK => {
524                let json_response: ListModelsResponse = response.json().await?;
525                Ok(json_response)
526            }
527            StatusCode::UNAUTHORIZED => {
528                let error: ErrorResponse = response.json().await?;
529                Err(GatewayError::Unauthorized(error.error))
530            }
531            StatusCode::BAD_REQUEST => {
532                let error: ErrorResponse = response.json().await?;
533                Err(GatewayError::BadRequest(error.error))
534            }
535            StatusCode::INTERNAL_SERVER_ERROR => {
536                let error: ErrorResponse = response.json().await?;
537                Err(GatewayError::InternalError(error.error))
538            }
539            _ => Err(GatewayError::Other(Box::new(std::io::Error::new(
540                std::io::ErrorKind::Other,
541                format!("Unexpected status code: {}", response.status()),
542            )))),
543        }
544    }
545
546    async fn list_models_by_provider(
547        &self,
548        provider: Provider,
549    ) -> Result<ListModelsResponse, GatewayError> {
550        let url = format!("{}/list/models?provider={}", self.base_url, provider);
        let mut request = self.client.get(&url);
        if let Some(token) = &self.token {
            request = request.bearer_auth(token);
        }
555
556        let response = request.send().await?;
557        match response.status() {
558            StatusCode::OK => {
559                let json_response: ListModelsResponse = response.json().await?;
560                Ok(json_response)
561            }
562            StatusCode::UNAUTHORIZED => {
563                let error: ErrorResponse = response.json().await?;
564                Err(GatewayError::Unauthorized(error.error))
565            }
566            StatusCode::BAD_REQUEST => {
567                let error: ErrorResponse = response.json().await?;
568                Err(GatewayError::BadRequest(error.error))
569            }
570            StatusCode::INTERNAL_SERVER_ERROR => {
571                let error: ErrorResponse = response.json().await?;
572                Err(GatewayError::InternalError(error.error))
573            }
574            _ => Err(GatewayError::Other(Box::new(std::io::Error::new(
575                std::io::ErrorKind::Other,
576                format!("Unexpected status code: {}", response.status()),
577            )))),
578        }
579    }
580
581    async fn generate_content(
582        &self,
583        provider: Provider,
584        model: &str,
585        messages: Vec<Message>,
586    ) -> Result<CreateChatCompletionResponse, GatewayError> {
587        let url = format!("{}/chat/completions?provider={}", self.base_url, provider);
588        let mut request = self.client.post(&url);
589        if let Some(token) = &self.token {
590            request = request.bearer_auth(token);
591        }
592
593        let request_payload = CreateChatCompletionRequest {
594            model: model.to_string(),
595            messages,
596            stream: false,
597            tools: self.tools.clone(),
598            max_tokens: self.max_tokens,
599        };
600
601        let response = request.json(&request_payload).send().await?;
602
603        match response.status() {
604            StatusCode::OK => Ok(response.json().await?),
605            StatusCode::BAD_REQUEST => {
606                let error: ErrorResponse = response.json().await?;
607                Err(GatewayError::BadRequest(error.error))
608            }
609            StatusCode::UNAUTHORIZED => {
610                let error: ErrorResponse = response.json().await?;
611                Err(GatewayError::Unauthorized(error.error))
612            }
613            StatusCode::INTERNAL_SERVER_ERROR => {
614                let error: ErrorResponse = response.json().await?;
615                Err(GatewayError::InternalError(error.error))
616            }
617            status => Err(GatewayError::Other(Box::new(std::io::Error::new(
618                std::io::ErrorKind::Other,
619                format!("Unexpected status code: {}", status),
620            )))),
621        }
622    }
623
624    /// Stream content generation directly using the backend SSE stream.
625    fn generate_content_stream(
626        &self,
627        provider: Provider,
628        model: &str,
629        messages: Vec<Message>,
630    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send {
        let client = self.client.clone();
        let base_url = self.base_url.clone();
        let token = self.token.clone();
        // Provider's Display impl already produces the lowercase form the API expects.
        let url = format!("{}/chat/completions?provider={}", base_url, provider);
638
639        let request = CreateChatCompletionRequest {
640            model: model.to_string(),
641            messages,
642            stream: true,
643            tools: None,
644            max_tokens: None,
645        };
646
647        async_stream::try_stream! {
            // Attach the bearer token when configured, matching the non-streaming methods.
            let mut builder = client.post(&url).json(&request);
            if let Some(token) = &token {
                builder = builder.bearer_auth(token);
            }
            let response = builder.send().await?;
649            let mut stream = response.bytes_stream();
650            let mut current_event: Option<String> = None;
651            let mut current_data: Option<String> = None;
652
653            while let Some(chunk) = stream.next().await {
654                let chunk = chunk?;
655                let chunk_str = String::from_utf8_lossy(&chunk);
656
657                for line in chunk_str.lines() {
658                    if line.is_empty() && current_data.is_some() {
659                        yield SSEvents {
660                            data: current_data.take().unwrap(),
661                            event: current_event.take(),
662                            retry: None, // TODO - implement this, for now it's not implemented in the backend
663                        };
664                        continue;
665                    }
666
667                    if let Some(event) = line.strip_prefix("event:") {
668                        current_event = Some(event.trim().to_string());
669                    } else if let Some(data) = line.strip_prefix("data:") {
                        // Lines from `lines()` never include a trailing newline; trimming is enough.
                        current_data = Some(data.trim().to_string());
672                    }
673                }
674            }
675        }
676    }
677
678    async fn health_check(&self) -> Result<bool, GatewayError> {
679        let url = format!("{}/health", self.base_url);
680
681        let response = self.client.get(&url).send().await?;
682        match response.status() {
683            StatusCode::OK => Ok(true),
684            _ => Ok(false),
685        }
686    }
687}
688
689#[cfg(test)]
690mod tests {
691    use crate::{
692        CreateChatCompletionRequest, CreateChatCompletionResponse,
693        CreateChatCompletionStreamResponse, FunctionObject, GatewayError, InferenceGatewayAPI,
694        InferenceGatewayClient, Message, MessageRole, Provider, Tool, ToolType,
695    };
696    use futures_util::{pin_mut, StreamExt};
697    use mockito::{Matcher, Server};
698    use serde_json::json;
699
700    #[test]
701    fn test_provider_serialization() {
702        let providers = vec![
703            (Provider::Ollama, "ollama"),
704            (Provider::Groq, "groq"),
705            (Provider::OpenAI, "openai"),
706            (Provider::Cloudflare, "cloudflare"),
707            (Provider::Cohere, "cohere"),
708            (Provider::Anthropic, "anthropic"),
709        ];
710
711        for (provider, expected) in providers {
712            let json = serde_json::to_string(&provider).unwrap();
713            assert_eq!(json, format!("\"{}\"", expected));
714        }
715    }
716
717    #[test]
718    fn test_provider_deserialization() {
719        let test_cases = vec![
720            ("\"ollama\"", Provider::Ollama),
721            ("\"groq\"", Provider::Groq),
722            ("\"openai\"", Provider::OpenAI),
723            ("\"cloudflare\"", Provider::Cloudflare),
724            ("\"cohere\"", Provider::Cohere),
725            ("\"anthropic\"", Provider::Anthropic),
726        ];
727
728        for (json, expected) in test_cases {
729            let provider: Provider = serde_json::from_str(json).unwrap();
730            assert_eq!(provider, expected);
731        }
732    }
733
734    #[test]
735    fn test_message_serialization_with_tool_call_id() {
736        let message_with_tool = Message {
737            role: MessageRole::Tool,
738            content: "The weather is sunny".to_string(),
739            tool_call_id: Some("call_123".to_string()),
740            ..Default::default()
741        };
742
743        let serialized = serde_json::to_string(&message_with_tool).unwrap();
744        let expected_with_tool =
745            r#"{"role":"tool","content":"The weather is sunny","tool_call_id":"call_123"}"#;
746        assert_eq!(serialized, expected_with_tool);
747
748        let message_without_tool = Message {
749            role: MessageRole::User,
750            content: "What's the weather?".to_string(),
751            ..Default::default()
752        };
753
754        let serialized = serde_json::to_string(&message_without_tool).unwrap();
755        let expected_without_tool = r#"{"role":"user","content":"What's the weather?"}"#;
756        assert_eq!(serialized, expected_without_tool);
757
758        let deserialized: Message = serde_json::from_str(expected_with_tool).unwrap();
759        assert_eq!(deserialized.role, MessageRole::Tool);
760        assert_eq!(deserialized.content, "The weather is sunny");
761        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));
762
763        let deserialized: Message = serde_json::from_str(expected_without_tool).unwrap();
764        assert_eq!(deserialized.role, MessageRole::User);
765        assert_eq!(deserialized.content, "What's the weather?");
766        assert_eq!(deserialized.tool_call_id, None);
767    }
768
769    #[test]
770    fn test_provider_display() {
771        let providers = vec![
772            (Provider::Ollama, "ollama"),
773            (Provider::Groq, "groq"),
774            (Provider::OpenAI, "openai"),
775            (Provider::Cloudflare, "cloudflare"),
776            (Provider::Cohere, "cohere"),
777            (Provider::Anthropic, "anthropic"),
778        ];
779
780        for (provider, expected) in providers {
781            assert_eq!(provider.to_string(), expected);
782        }
783    }
784
785    #[test]
786    fn test_generate_request_serialization() {
787        let request_payload = CreateChatCompletionRequest {
788            model: "llama3.2:1b".to_string(),
789            messages: vec![
790                Message {
791                    role: MessageRole::System,
792                    content: "You are a helpful assistant.".to_string(),
793                    ..Default::default()
794                },
795                Message {
796                    role: MessageRole::User,
797                    content: "What is the current weather in Toronto?".to_string(),
798                    ..Default::default()
799                },
800            ],
801            stream: false,
802            tools: Some(vec![Tool {
803                r#type: ToolType::Function,
804                function: FunctionObject {
805                    name: "get_current_weather".to_string(),
806                    description: "Get the current weather of a city".to_string(),
807                    parameters: json!({
808                        "type": "object",
809                        "properties": {
810                            "city": {
811                                "type": "string",
812                                "description": "The name of the city"
813                            }
814                        },
815                        "required": ["city"]
816                    }),
817                },
818            }]),
819            max_tokens: None,
820        };
821
822        let serialized = serde_json::to_string_pretty(&request_payload).unwrap();
823        let expected = r#"{
824      "model": "llama3.2:1b",
825      "messages": [
826        {
827          "role": "system",
828          "content": "You are a helpful assistant."
829        },
830        {
831          "role": "user",
832          "content": "What is the current weather in Toronto?"
833        }
834      ],
835      "stream": false,
836      "tools": [
837        {
838          "type": "function",
839          "function": {
840            "name": "get_current_weather",
841            "description": "Get the current weather of a city",
842            "parameters": {
843              "type": "object",
844              "properties": {
845                "city": {
846                  "type": "string",
847                  "description": "The name of the city"
848                }
849              },
850              "required": ["city"]
851            }
852          }
853        }
854      ]
855    }"#;
856
857        assert_eq!(
858            serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
859            serde_json::from_str::<serde_json::Value>(expected).unwrap()
860        );
861    }
862
863    #[tokio::test]
864    async fn test_authentication_header() -> Result<(), GatewayError> {
865        let mut server = Server::new_async().await;
866
867        let mock_response = r#"{
868            "object": "list",
869            "data": []
870        }"#;
871
872        let mock_with_auth = server
873            .mock("GET", "/v1/models")
874            .match_header("authorization", "Bearer test-token")
875            .with_status(200)
876            .with_header("content-type", "application/json")
877            .with_body(mock_response)
878            .expect(1)
879            .create();
880
881        let base_url = format!("{}/v1", server.url());
882        let client = InferenceGatewayClient::new(&base_url).with_token("test-token");
883        client.list_models().await?;
884        mock_with_auth.assert();
885
886        let mock_without_auth = server
887            .mock("GET", "/v1/models")
888            .match_header("authorization", Matcher::Missing)
889            .with_status(200)
890            .with_header("content-type", "application/json")
891            .with_body(mock_response)
892            .expect(1)
893            .create();
894
895        let base_url = format!("{}/v1", server.url());
896        let client = InferenceGatewayClient::new(&base_url);
897        client.list_models().await?;
898        mock_without_auth.assert();
899
900        Ok(())
901    }
902
903    #[tokio::test]
904    async fn test_unauthorized_error() -> Result<(), GatewayError> {
905        let mut server = Server::new_async().await;
906
907        let raw_json_response = r#"{
908            "error": "Invalid token"
909        }"#;
910
911        let mock = server
912            .mock("GET", "/v1/models")
913            .with_status(401)
914            .with_header("content-type", "application/json")
915            .with_body(raw_json_response)
916            .create();
917
918        let base_url = format!("{}/v1", server.url());
919        let client = InferenceGatewayClient::new(&base_url);
920        let error = client.list_models().await.unwrap_err();
921
922        assert!(matches!(error, GatewayError::Unauthorized(_)));
923        if let GatewayError::Unauthorized(msg) = error {
924            assert_eq!(msg, "Invalid token");
925        }
926        mock.assert();
927
928        Ok(())
929    }
930
931    #[tokio::test]
932    async fn test_list_models() -> Result<(), GatewayError> {
933        let mut server = Server::new_async().await;
934
935        let raw_response_json = r#"{
936            "object": "list",
937            "data": [
938                {
939                    "id": "llama2",
940                    "object": "model",
941                    "created": 1630000001,
942                    "owned_by": "ollama",
943                    "served_by": "ollama"
944                }
945            ]
946        }"#;
947
948        let mock = server
949            .mock("GET", "/v1/models")
950            .with_status(200)
951            .with_header("content-type", "application/json")
952            .with_body(raw_response_json)
953            .create();
954
955        let base_url = format!("{}/v1", server.url());
956        let client = InferenceGatewayClient::new(&base_url);
957        let response = client.list_models().await?;
958
959        assert!(response.provider.is_none());
960        assert_eq!(response.object, "list");
961        assert_eq!(response.data.len(), 1);
962        assert_eq!(response.data[0].id, "llama2");
963        mock.assert();
964
965        Ok(())
966    }
967
968    #[tokio::test]
969    async fn test_list_models_by_provider() -> Result<(), GatewayError> {
970        let mut server = Server::new_async().await;
971
972        let raw_json_response = r#"{
973            "provider":"ollama",
974            "object":"list",
975            "data": [
976                {
977                    "id": "llama2",
978                    "object": "model",
979                    "created": 1630000001,
980                    "owned_by": "ollama",
981                    "served_by": "ollama"
982                }
983            ]
984        }"#;
985
986        let mock = server
987            .mock("GET", "/v1/list/models?provider=ollama")
988            .with_status(200)
989            .with_header("content-type", "application/json")
990            .with_body(raw_json_response)
991            .create();
992
993        let base_url = format!("{}/v1", server.url());
994        let client = InferenceGatewayClient::new(&base_url);
995        let response = client.list_models_by_provider(Provider::Ollama).await?;
996
997        assert!(response.provider.is_some());
998        assert_eq!(response.provider, Some(Provider::Ollama));
999        assert_eq!(response.data[0].id, "llama2");
1000        mock.assert();
1001
1002        Ok(())
1003    }
1004
1005    #[tokio::test]
1006    async fn test_generate_content() -> Result<(), GatewayError> {
1007        let mut server = Server::new_async().await;
1008
1009        let raw_json_response = r#"{
1010            "id": "chatcmpl-456",
1011            "object": "chat.completion",
1012            "created": 1630000001,
1013            "model": "mixtral-8x7b",
1014            "choices": [
1015                {
1016                    "index": 0,
1017                    "finish_reason": "stop",
1018                    "message": {
1019                        "role": "assistant",
1020                        "content": "Hellloooo"
1021                    }
1022                }
1023            ]
1024        }"#;
1025
1026        let mock = server
1027            .mock("POST", "/v1/chat/completions?provider=ollama")
1028            .with_status(200)
1029            .with_header("content-type", "application/json")
1030            .with_body(raw_json_response)
1031            .create();
1032
1033        let base_url = format!("{}/v1", server.url());
1034        let client = InferenceGatewayClient::new(&base_url);
1035
1036        let messages = vec![Message {
1037            role: MessageRole::User,
1038            content: "Hello".to_string(),
1039            ..Default::default()
1040        }];
1041        let response = client
1042            .generate_content(Provider::Ollama, "llama2", messages)
1043            .await?;
1044
1045        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1046        assert_eq!(response.choices[0].message.content, "Hellloooo");
1047        mock.assert();
1048
1049        Ok(())
1050    }
1051
1052    #[tokio::test]
1053    async fn test_generate_content_serialization() -> Result<(), GatewayError> {
1054        let mut server = Server::new_async().await;
1055
1056        let raw_json = r#"{
1057            "id": "chatcmpl-456",
1058            "object": "chat.completion",
1059            "created": 1630000001,
1060            "model": "mixtral-8x7b",
1061            "choices": [
1062                {
1063                    "index": 0,
1064                    "finish_reason": "stop",
1065                    "message": {
1066                        "role": "assistant",
1067                        "content": "Hello"
1068                    }
1069                }
1070            ]
1071        }"#;
1072
1073        let mock = server
1074            .mock("POST", "/v1/chat/completions?provider=groq")
1075            .with_status(200)
1076            .with_header("content-type", "application/json")
1077            .with_body(raw_json)
1078            .create();
1079
1080        let base_url = format!("{}/v1", server.url());
1081        let client = InferenceGatewayClient::new(&base_url);
1082
1083        let direct_parse: Result<CreateChatCompletionResponse, _> = serde_json::from_str(raw_json);
1084        assert!(
1085            direct_parse.is_ok(),
1086            "Direct JSON parse failed: {:?}",
1087            direct_parse.err()
1088        );
1089
1090        let messages = vec![Message {
1091            role: MessageRole::User,
1092            content: "Hello".to_string(),
1093            ..Default::default()
1094        }];
1095
1096        let response = client
1097            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
1098            .await?;
1099
1100        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1101        assert_eq!(response.choices[0].message.content, "Hello");
1102
1103        mock.assert();
1104        Ok(())
1105    }
1106
1107    #[tokio::test]
1108    async fn test_generate_content_error_response() -> Result<(), GatewayError> {
1109        let mut server = Server::new_async().await;
1110
1111        let raw_json_response = r#"{
1112            "error":"Invalid request"
1113        }"#;
1114
1115        let mock = server
1116            .mock("POST", "/v1/chat/completions?provider=groq")
1117            .with_status(400)
1118            .with_header("content-type", "application/json")
1119            .with_body(raw_json_response)
1120            .create();
1121
1122        let base_url = format!("{}/v1", server.url());
1123        let client = InferenceGatewayClient::new(&base_url);
1124        let messages = vec![Message {
1125            role: MessageRole::User,
1126            content: "Hello".to_string(),
1127            ..Default::default()
1128        }];
1129        let error = client
1130            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
1131            .await
1132            .unwrap_err();
1133
1134        assert!(matches!(error, GatewayError::BadRequest(_)));
1135        if let GatewayError::BadRequest(msg) = error {
1136            assert_eq!(msg, "Invalid request");
1137        }
1138        mock.assert();
1139
1140        Ok(())
1141    }
1142
1143    #[tokio::test]
1144    async fn test_gateway_errors() -> Result<(), GatewayError> {
        let mut server = Server::new_async().await;
1146
1147        let unauthorized_mock = server
1148            .mock("GET", "/v1/models")
1149            .with_status(401)
1150            .with_header("content-type", "application/json")
1151            .with_body(r#"{"error":"Invalid token"}"#)
1152            .create();
1153
1154        let base_url = format!("{}/v1", server.url());
1155        let client = InferenceGatewayClient::new(&base_url);
1156        match client.list_models().await {
1157            Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
1158            _ => panic!("Expected Unauthorized error"),
1159        }
1160        unauthorized_mock.assert();
1161
1162        let bad_request_mock = server
1163            .mock("GET", "/v1/models")
1164            .with_status(400)
1165            .with_header("content-type", "application/json")
1166            .with_body(r#"{"error":"Invalid provider"}"#)
1167            .create();
1168
1169        match client.list_models().await {
1170            Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
1171            _ => panic!("Expected BadRequest error"),
1172        }
1173        bad_request_mock.assert();
1174
1175        let internal_error_mock = server
1176            .mock("GET", "/v1/models")
1177            .with_status(500)
1178            .with_header("content-type", "application/json")
1179            .with_body(r#"{"error":"Internal server error occurred"}"#)
1180            .create();
1181
1182        match client.list_models().await {
1183            Err(GatewayError::InternalError(msg)) => {
1184                assert_eq!(msg, "Internal server error occurred")
1185            }
1186            _ => panic!("Expected InternalError error"),
1187        }
1188        internal_error_mock.assert();
1189
1190        Ok(())
1191    }
1192
1193    #[tokio::test]
1194    async fn test_generate_content_case_insensitive() -> Result<(), GatewayError> {
1195        let mut server = Server::new_async().await;
1196
1197        let raw_json = r#"{
1198            "id": "chatcmpl-456",
1199            "object": "chat.completion",
1200            "created": 1630000001,
1201            "model": "mixtral-8x7b",
1202            "choices": [
1203                {
1204                    "index": 0,
1205                    "finish_reason": "stop",
1206                    "message": {
1207                        "role": "assistant",
1208                        "content": "Hello"
1209                    }
1210                }
1211            ]
1212        }"#;
1213
1214        let mock = server
1215            .mock("POST", "/v1/chat/completions?provider=groq")
1216            .with_status(200)
1217            .with_header("content-type", "application/json")
1218            .with_body(raw_json)
1219            .create();
1220
1221        let base_url = format!("{}/v1", server.url());
1222        let client = InferenceGatewayClient::new(&base_url);
1223
1224        let messages = vec![Message {
1225            role: MessageRole::User,
1226            content: "Hello".to_string(),
1227            ..Default::default()
1228        }];
1229
1230        let response = client
1231            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
1232            .await?;
1233
1234        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1235        assert_eq!(response.choices[0].message.content, "Hello");
1236        assert_eq!(response.model, "mixtral-8x7b");
1237        assert_eq!(response.object, "chat.completion");
1238        mock.assert();
1239
1240        Ok(())
1241    }
1242
1243    #[tokio::test]
1244    async fn test_generate_content_stream() -> Result<(), GatewayError> {
1245        let mut server = Server::new_async().await;
1246
1247        let mock = server
1248            .mock("POST", "/v1/chat/completions?provider=groq")
1249            .with_status(200)
1250            .with_header("content-type", "text/event-stream")
1251            .with_chunked_body(move |writer| -> std::io::Result<()> {
1252                let events = vec![
1253                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}"#),
1254                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268191,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":" World"},"finish_reason":null}]}"#),
1255                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268192,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":17,"completion_tokens":40,"total_tokens":57}}"#),
                    "data: [DONE]\n\n".to_string()
1257                ];
1258                for event in events {
1259                    writer.write_all(event.as_bytes())?;
1260                }
1261                Ok(())
1262            })
1263            .create();
1264
1265        let base_url = format!("{}/v1", server.url());
1266        let client = InferenceGatewayClient::new(&base_url);
1267
1268        let messages = vec![Message {
1269            role: MessageRole::User,
1270            content: "Test message".to_string(),
1271            ..Default::default()
1272        }];
1273
1274        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
1275        pin_mut!(stream);
1276        while let Some(result) = stream.next().await {
1277            let result = result?;
1278            let generate_response: CreateChatCompletionStreamResponse =
1279                serde_json::from_str(&result.data)
                    .expect("Failed to parse CreateChatCompletionStreamResponse");
1281
1282            if generate_response.choices[0].finish_reason.is_some() {
1283                assert_eq!(
1284                    generate_response.choices[0].finish_reason.as_ref().unwrap(),
1285                    "stop"
1286                );
1287                break;
1288            }
1289
1290            if let Some(content) = &generate_response.choices[0].delta.content {
1291                assert!(matches!(content.as_str(), "Hello" | " World"));
1292            }
1293            if let Some(role) = &generate_response.choices[0].delta.role {
1294                assert_eq!(role, &MessageRole::Assistant);
1295            }
1296        }
1297
1298        mock.assert();
1299        Ok(())
1300    }
1301
1302    #[tokio::test]
1303    async fn test_generate_content_stream_error() -> Result<(), GatewayError> {
1304        let mut server = Server::new_async().await;
1305
1306        let mock = server
1307            .mock("POST", "/v1/chat/completions?provider=groq")
1308            .with_status(400)
1309            .with_header("content-type", "application/json")
1310            .with_chunked_body(move |writer| -> std::io::Result<()> {
1311                let events = vec![format!(
1312                    "event: {}\ndata: {}\nretry: {}\n\n",
1313                    r#"error"#, r#"{"error":"Invalid request"}"#, r#"1000"#,
1314                )];
1315                for event in events {
1316                    writer.write_all(event.as_bytes())?;
1317                }
1318                Ok(())
1319            })
1320            .expect_at_least(1)
1321            .create();
1322
1323        let base_url = format!("{}/v1", server.url());
1324        let client = InferenceGatewayClient::new(&base_url);
1325
1326        let messages = vec![Message {
1327            role: MessageRole::User,
1328            content: "Test message".to_string(),
1329            ..Default::default()
1330        }];
1331
1332        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
1333
1334        pin_mut!(stream);
1335        while let Some(result) = stream.next().await {
1336            let result = result?;
1337            assert!(result.event.is_some());
1338            assert_eq!(result.event.unwrap(), "error");
1339            assert!(result.data.contains("Invalid request"));
1340            assert!(result.retry.is_none());
1341        }
1342
1343        mock.assert();
1344        Ok(())
1345    }
1346
1347    #[tokio::test]
1348    async fn test_generate_content_with_tools() -> Result<(), GatewayError> {
1349        let mut server = Server::new_async().await;
1350
1351        let raw_json_response = r#"{
1352            "id": "chatcmpl-123",
1353            "object": "chat.completion",
1354            "created": 1630000000,
1355            "model": "deepseek-r1-distill-llama-70b",
1356            "choices": [
1357                {
1358                    "index": 0,
1359                    "finish_reason": "tool_calls",
1360                    "message": {
1361                        "role": "assistant",
1362                        "content": "Let me check the weather for you.",
1363                        "tool_calls": [
1364                            {
1365                                "id": "1234",
1366                                "type": "function",
1367                                "function": {
1368                                    "name": "get_weather",
1369                                    "arguments": "{\"location\": \"London\"}"
1370                                }
1371                            }
1372                        ]
1373                    }
1374                }
1375            ]
1376        }"#;
1377
1378        let mock = server
1379            .mock("POST", "/v1/chat/completions?provider=groq")
1380            .with_status(200)
1381            .with_header("content-type", "application/json")
1382            .with_body(raw_json_response)
1383            .create();
1384
1385        let tools = vec![Tool {
1386            r#type: ToolType::Function,
1387            function: FunctionObject {
1388                name: "get_weather".to_string(),
1389                description: "Get the weather for a location".to_string(),
1390                parameters: json!({
1391                    "type": "object",
1392                    "properties": {
1393                        "location": {
1394                            "type": "string",
1395                            "description": "The city name"
1396                        }
1397                    },
1398                    "required": ["location"]
1399                }),
1400            },
1401        }];
1402
1403        let base_url = format!("{}/v1", server.url());
1404        let client = InferenceGatewayClient::new(&base_url).with_tools(Some(tools));
1405
1406        let messages = vec![Message {
1407            role: MessageRole::User,
1408            content: "What's the weather in London?".to_string(),
1409            ..Default::default()
1410        }];
1411
1412        let response = client
1413            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
1414            .await?;
1415
1416        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1417        assert_eq!(
1418            response.choices[0].message.content,
1419            "Let me check the weather for you."
1420        );
1421
1422        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
1423        assert_eq!(tool_calls.len(), 1);
1424        assert_eq!(tool_calls[0].function.name, "get_weather");
1425
1426        let params = tool_calls[0]
1427            .function
1428            .parse_arguments()
1429            .expect("Failed to parse function arguments");
1430        assert_eq!(params["location"].as_str().unwrap(), "London");
1431
1432        mock.assert();
1433        Ok(())
1434    }
1435
1436    #[tokio::test]
1437    async fn test_generate_content_without_tools() -> Result<(), GatewayError> {
1438        let mut server = Server::new_async().await;
1439
1440        let raw_json_response = r#"{
1441            "id": "chatcmpl-123",
1442            "object": "chat.completion",
1443            "created": 1630000000,
1444            "model": "gpt-4",
1445            "choices": [
1446                {
1447                    "index": 0,
1448                    "finish_reason": "stop",
1449                    "message": {
1450                        "role": "assistant",
1451                        "content": "Hello!"
1452                    }
1453                }
1454            ]
1455        }"#;
1456
1457        let mock = server
1458            .mock("POST", "/v1/chat/completions?provider=openai")
1459            .with_status(200)
1460            .with_header("content-type", "application/json")
1461            .with_body(raw_json_response)
1462            .create();
1463
1464        let base_url = format!("{}/v1", server.url());
1465        let client = InferenceGatewayClient::new(&base_url);
1466
1467        let messages = vec![Message {
1468            role: MessageRole::User,
1469            content: "Hi".to_string(),
1470            ..Default::default()
1471        }];
1472
1473        let response = client
1474            .generate_content(Provider::OpenAI, "gpt-4", messages)
1475            .await?;
1476
1477        assert_eq!(response.model, "gpt-4");
1478        assert_eq!(response.choices[0].message.content, "Hello!");
1479        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1480        assert!(response.choices[0].message.tool_calls.is_none());
1481
1482        mock.assert();
1483        Ok(())
1484    }
1485
1486    #[tokio::test]
1487    async fn test_generate_content_with_tools_payload() -> Result<(), GatewayError> {
1488        let mut server = Server::new_async().await;
1489
1490        let raw_request_body = r#"{
1491            "model": "deepseek-r1-distill-llama-70b",
1492            "messages": [
1493                {
1494                    "role": "system",
1495                    "content": "You are a helpful assistant."
1496                },
1497                {
1498                    "role": "user",
1499                    "content": "What is the current weather in Toronto?"
1500                }
1501            ],
1502            "stream": false,
1503            "tools": [
1504                {
1505                    "type": "function",
1506                    "function": {
1507                        "name": "get_current_weather",
1508                        "description": "Get the current weather of a city",
1509                        "parameters": {
1510                            "type": "object",
1511                            "properties": {
1512                                "city": {
1513                                    "type": "string",
1514                                    "description": "The name of the city"
1515                                }
1516                            },
1517                            "required": ["city"]
1518                        }
1519                    }
1520                }
1521            ]
1522        }"#;
1523
1524        let raw_json_response = r#"{
1525            "id": "1234",
1526            "object": "chat.completion",
1527            "created": 1630000000,
1528            "model": "deepseek-r1-distill-llama-70b",
1529            "choices": [
1530                {
1531                    "index": 0,
1532                    "finish_reason": "stop",
1533                    "message": {
1534                        "role": "assistant",
1535                        "content": "Let me check the weather for you",
1536                        "tool_calls": [
1537                            {
1538                                "id": "1234",
1539                                "type": "function",
1540                                "function": {
1541                                    "name": "get_current_weather",
1542                                    "arguments": "{\"city\": \"Toronto\"}"
1543                                }
1544                            }
1545                        ]
1546                    }
1547                }
1548            ]
1549        }"#;
1550
1551        let mock = server
1552            .mock("POST", "/v1/chat/completions?provider=groq")
1553            .with_status(200)
1554            .with_header("content-type", "application/json")
1555            .match_body(mockito::Matcher::JsonString(raw_request_body.to_string()))
1556            .with_body(raw_json_response)
1557            .create();
1558
1559        let tools = vec![Tool {
1560            r#type: ToolType::Function,
1561            function: FunctionObject {
1562                name: "get_current_weather".to_string(),
1563                description: "Get the current weather of a city".to_string(),
1564                parameters: json!({
1565                    "type": "object",
1566                    "properties": {
1567                        "city": {
1568                            "type": "string",
1569                            "description": "The name of the city"
1570                        }
1571                    },
1572                    "required": ["city"]
1573                }),
1574            },
1575        }];
1576
1577        let base_url = format!("{}/v1", server.url());
1578        let client = InferenceGatewayClient::new(&base_url);
1579
1580        let messages = vec![
1581            Message {
1582                role: MessageRole::System,
1583                content: "You are a helpful assistant.".to_string(),
1584                ..Default::default()
1585            },
1586            Message {
1587                role: MessageRole::User,
1588                content: "What is the current weather in Toronto?".to_string(),
1589                ..Default::default()
1590            },
1591        ];
1592
1593        let response = client
1594            .with_tools(Some(tools))
1595            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
1596            .await?;
1597
1598        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1599        assert_eq!(
1600            response.choices[0].message.content,
1601            "Let me check the weather for you"
1602        );
1603        assert_eq!(
1604            response.choices[0]
1605                .message
1606                .tool_calls
1607                .as_ref()
1608                .unwrap()
1609                .len(),
1610            1
1611        );
1612
1613        mock.assert();
1614        Ok(())
1615    }
1616
1617    #[tokio::test]
1618    async fn test_generate_content_with_max_tokens() -> Result<(), GatewayError> {
1619        let mut server = Server::new_async().await;
1620
1621        let raw_json_response = r#"{
1622            "id": "chatcmpl-123",
1623            "object": "chat.completion",
1624            "created": 1630000000,
1625            "model": "mixtral-8x7b",
1626            "choices": [
1627                {
1628                    "index": 0,
1629                    "finish_reason": "stop",
1630                    "message": {
1631                        "role": "assistant",
1632                        "content": "Here's a poem with 100 tokens..."
1633                    }
1634                }
1635            ]
1636        }"#;
1637
1638        let mock = server
1639            .mock("POST", "/v1/chat/completions?provider=groq")
1640            .with_status(200)
1641            .with_header("content-type", "application/json")
1642            .match_body(mockito::Matcher::JsonString(
1643                r#"{
1644                "model": "mixtral-8x7b",
1645                "messages": [{"role":"user","content":"Write a poem"}],
1646                "stream": false,
1647                "max_tokens": 100
1648            }"#
1649                .to_string(),
1650            ))
1651            .with_body(raw_json_response)
1652            .create();
1653
1654        let base_url = format!("{}/v1", server.url());
1655        let client = InferenceGatewayClient::new(&base_url).with_max_tokens(Some(100));
1656
1657        let messages = vec![Message {
1658            role: MessageRole::User,
1659            content: "Write a poem".to_string(),
1660            ..Default::default()
1661        }];
1662
1663        let response = client
1664            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
1665            .await?;
1666
1667        assert_eq!(
1668            response.choices[0].message.content,
1669            "Here's a poem with 100 tokens..."
1670        );
1671        assert_eq!(response.model, "mixtral-8x7b");
1672        assert_eq!(response.created, 1630000000);
1673        assert_eq!(response.object, "chat.completion");
1674
1675        mock.assert();
1676        Ok(())
1677    }
1678
1679    #[tokio::test]
1680    async fn test_health_check() -> Result<(), GatewayError> {
1681        let mut server = Server::new_async().await;
1682        let mock = server.mock("GET", "/health").with_status(200).create();
1683
1684        let client = InferenceGatewayClient::new(&server.url());
1685        let is_healthy = client.health_check().await?;
1686
1687        assert!(is_healthy);
1688        mock.assert();
1689
1690        Ok(())
1691    }
1692
1693    #[tokio::test]
1694    async fn test_client_base_url_configuration() -> Result<(), GatewayError> {
1695        let mut custom_url_server = Server::new_async().await;
1696
1697        let custom_url_mock = custom_url_server
1698            .mock("GET", "/health")
1699            .with_status(200)
1700            .create();
1701
1702        let custom_client = InferenceGatewayClient::new(&custom_url_server.url());
1703        let is_healthy = custom_client.health_check().await?;
1704        assert!(is_healthy);
1705        custom_url_mock.assert();
1706
1707        let default_client = InferenceGatewayClient::new_default();
1708
1709        let default_url = "http://localhost:8080/v1";
1710        assert_eq!(default_client.base_url(), default_url);
1711
1712        Ok(())
1713    }
1714}